// Copyright 2025 RnD Center "ELVEES", JSC

#ifndef WINOGRAD_FL16_H
#define WINOGRAD_FL16_H

#include "elcore50-matrix-lib/common.h"
#include "elcore50-matrix-lib/dmainit.h"
#include "elcore50-matrix-lib/mat_mul_with_dma_fl16.hpp"
#include "elcore50-matrix-lib/tile_segmentation.hpp"

/// Преобразование входного тензора
extern "C" void transform_input(uint16_t* src,  ///< [in]  указатель на данные входного тензора
                                int srcC,       ///< [in]  количество каналов тензора
                                int srcH,       ///< [in]  высота тензора
                                int srcW,       ///< [in]  ширина тензора
                                uint16_t* dst,  ///< [out] указатель на преобразованные данные
                                int size        ///< [in]  размер блока преобразованных данных
);

/// Преобразование входного тензора с окном 4x4
extern "C" void transform_input4x4(uint16_t* src,  ///< [in]  указатель на данные входного тензора
                                   int srcC,       ///< [in]  количество каналов тензора
                                   int srcH,       ///< [in]  высота тензора
                                   int srcW,       ///< [in]  ширина тензора
                                   uint16_t* dst,  ///< [out] указатель на преобразованные данные
                                   int size        ///< [in]  размер блока преобразованных данных
);

/// Преобразование для получения выходного тензора
extern "C" void transform_output(uint16_t* src,  ///< [in]  указатель на преобразованные данные
                                 int sizeD,      ///< [in]  размер блока преобразованных данных
                                 uint16_t* dst,  ///< [out] указатель на данные выходного тензора
                                 int dstC,       ///< [in]  количество каналов тензора
                                 int dstH,       ///< [in]  высота тензора
                                 int dstW        ///< [in]  ширина тензора
);

/// Преобразование для получения выходного тензора с окном 4x4
extern "C" void transform_output4x4(uint16_t* src,  ///< [in]  указатель на преобразованные данные
                                    int sizeD,      ///< [in]  размер блока преобразованных данных
                                    uint16_t* dst,  ///< [out] указатель на данные выходного тензора
                                    int dstC,       ///< [in]  количество каналов тензора
                                    int dstH,       ///< [in]  высота тензора
                                    int dstW        ///< [in]  ширина тензора
);

/// Версии свертки по методу Винограда
enum Winograd_version {
  VER_2x2,  ///< версия с окном преобразования 2 на 2 элемента
  VER_4x4   ///< версия с окном преобразования 4 на 4 элемента
};

/// Подбор оптимальных размеров тайлов для функции transform_input
void tile_size_selector_transform_input(
    int srcC,               ///< [in]  количество каналов тензора
    int srcH,               ///< [in]  высота тензора
    int srcW,               ///< [in]  ширина тензора
    TileSegConfig* config,  ///< [out] структура для запуска потайлового преобразования
    Winograd_version ver    ///< [in]  версия выходного окна алгоритма
);

/// Заполнение конфигурационной структуры для запуска потайлового преобразования
/// входного тензора
void init_chain_dma_transform_input(uint16_t* src,          ///< [in]  указатель на данные входного тензора
                                    int srcC,               ///< [in]  количество каналов тензора
                                    int srcH,               ///< [in]  высота тензора
                                    int srcW,               ///< [in]  ширина тензора
                                    uint16_t* dst,          ///< [in]  указатель на преобразованные данные
                                    int size,               ///< [in]  размер блока преобразованных данных
                                    TileSegConfig* config,  ///< [out] структура для запуска потайлового преобразования
                                    Winograd_version ver    ///< [in]  версия выходного окна алгоритма
);

/// Запуск потайлового преобразования входного тензора
void run_dma_transform_input(uint16_t* src,          ///< [in]  указатель на данные входного тензора
                             int srcC,               ///< [in]  количество каналов тензора
                             int srcH,               ///< [in]  высота тензора
                             int srcW,               ///< [in]  ширина тензора
                             uint16_t* dst,          ///< [out] указатель на преобразованные данные
                             int size,               ///< [in]  размер блока преобразованных данных
                             TileSegConfig* config,  ///< [in]  структура для запуска потайлового преобразования
                             Winograd_version ver    ///< [in]  версия выходного окна алгоритма
);

/// Подбор оптимальных размеров тайлов для функции transform_output
void tile_size_selector_transform_output(
    int dstC,               ///< [in]  количество каналов тензора
    int dstH,               ///< [in]  высота тензора
    int dstW,               ///< [in]  ширина тензора
    TileSegConfig* config,  ///< [out] структура для запуска потайлового преобразования
    Winograd_version ver    ///< [in]  версия выходного окна алгоритма
);

/// Заполнение конфигурационной структуры для запуска потайлового преобразования
/// выходного тензора
void init_chain_dma_transform_output(
    uint16_t* src,          ///< [in]  указатель на преобразованные данные
    int sizeD,              ///< [in]  размер блока преобразованных данных
    uint16_t* dst,          ///< [in]  указатель на данные выходного тензора
    int dstC,               ///< [in]  количество каналов тензора
    int dstH,               ///< [in]  высота тензора
    int dstW,               ///< [in]  ширина тензора
    TileSegConfig* config,  ///< [out] структура для запуска потайлового преобразования
    Winograd_version ver    ///< [in]  версия выходного окна алгоритма
);

/// Запуск потайлового преобразования выходного тензора
void run_dma_transform_output(uint16_t* src,          ///< [in]  указатель на преобразованные данные
                              int sizeD,              ///< [in]  размер блока преобразованных данных
                              uint16_t* dst,          ///< [out] указатель на данные выходного тензора
                              int dstC,               ///< [in]  количество каналов тензора
                              int dstH,               ///< [in]  высота тензора
                              int dstW,               ///< [in]  ширина тензора
                              TileSegConfig* config,  ///< [in]  структура для запуска потайлового преобразования
                              Winograd_version ver    ///< [in]  версия выходного окна алгоритма
);

/// Структура для запуска потайлового алгоритма Винограда
typedef struct WinogradConfig {
  WinogradConfig(Winograd_version version);
  Winograd_version ver;  ///< версия выходного окна алгоритма
  float* bufW_fl32;      ///< указатель на буфер с весами
  uint16_t* bufW_fl16;   ///< указатель на буфер с весами в fl16
  uint16_t* bufS_fl16;   ///< указатель на буфер с преобразованным входом
  uint16_t* bufD_fl16;   ///< указатель на буфер с преобразованным выходом

  TileSegConfig config_transform_input;   ///< конфиг для преобразования входа
  TileSegConfig config_transform_output;  ///< конфиг для преобразования выхода

  int offset_A[36];  ///< смещения первой матрицы
  int offset_B[36];  ///< смещения второй матрицы
  int offset_C[36];  ///< смещения выходной матрицы

  float* init_vector;                   ///< вектор начальных значений
  MatMulFl16Config mat_mul_config[36];  ///< структуры для запуска потайлового
                                        ///< алгоритма матричного умножения
} WinogradConfig;

/// Преобразования весов
void transform_filter(float* src,           ///< [in]  указатель на данные весов
                      int size,             ///< [in]  размер блока преобразованных данных
                      float* dst,           ///< [out] указатель на преобразованные данные
                      int srcC,             ///< [in]  количество каналов входного тензора
                      int dstC,             ///< [in]  количество каналов выходного тензора
                      Winograd_version ver  ///< [in]  версия выходного окна алгоритма
);

/// Инициализация и заполнение конфигурационной структуры для алгоритма
/// Винограда
void init_dma_chain_winograd(uint16_t* src_fl16,     ///< [in]  указатель на данные входного тензора
                             int batch,              ///< [in]  количество батчей входного тензора
                             int srcC,               ///< [in]  количество каналов входного тензора
                             int srcH,               ///< [in]  высота входного тензора
                             int srcW,               ///< [in]  ширина входного тензора
                             float* weight,          ///< [in]  указатель на данные весов
                             uint16_t* dst_fl16,     ///< [in]  указатель на данные выходного тензора
                             int dstC,               ///< [in]  количество каналов выходного тензора
                             WinogradConfig* config  ///< [out] структура для запуска потайловой обработки
);

// Освобождение данных структуры запуска алгоритма Винограда
void destroy_dma_chain_winograd(WinogradConfig* config  ///< [in] структура для запуска потайловой обработки
);

/// Запуск потайлового алгоритма Винограда.
/// Ограничения: srcC, dstC - кратны 32; srcH, srcW - больше 4 и кратны 2.
void run_winograd(uint16_t* src_fl16,     ///< [in]  указатель на данные входного тензора
                  int batch,              ///< [in]  количество батчей входного тензора
                  int srcC,               ///< [in]  количество каналов входного тензора
                  int srcH,               ///< [in]  высота входного тензора
                  int srcW,               ///< [in]  ширина входного тензора
                  uint16_t* dst_fl16,     ///< [in]  указатель на данные выходного тензора
                  int dstC,               ///< [in]  количество каналов выходного тензора
                  WinogradConfig* config  ///< [out] структура для запуска потайловой обработки
);

#endif
