// Copyright 2025 RnD Center "ELVEES", JSC

#ifndef MAT_MUL_WITH_DMA_H
#define MAT_MUL_WITH_DMA_H
#endif

#include "elcore50-matrix-lib/common.h"
#include "elcore50-matrix-lib/elcore50.h"

/// Структура для запуска потайлового алгоритма матричного умножения формата
/// float16
typedef struct MatMulFl16Config {
  uint16_t* buf_A[2];   ///< указатель на буферы тайловой обработки первой матрицы
  uint16_t* buf_B[2];   ///< указатель на буферы тайловой обработки второй матрицы
  uint16_t* buf_C[2];   ///< указатель на буферы тайловой обработки
                        ///< результирующей матрицы
  float* buf_init_vec;  ///< указатель на буфер с инициализирующим вектором
  int buf_row0;         ///< к-во строк буфера первой матрицы
  int buf_row1col0;     ///< к-во строк буфера второй матрицы и столбцов первой
  int buf_col1;         ///< к-во столбцов буфера второй матрицы

  VDMAChain chain_A;         ///< dma цепочки для тайлов первой матрицы
  VDMAChain chain_B;         ///< dma цепочки для тайлов второй матрицы
  VDMAChain chain_ld_C;      ///< dma цепочки для тайлов загрузки результата
  VDMAChain chain_st_C;      ///< dma цепочки для тайлов выгрузки результата
  VDMAChain chain_init_vec;  ///< dma цепочки для тайлов инициализирующего вектора
  int len_chain_A;           ///< к-во dma цепочек для тайлов первой матрицы
  int len_chain_B;           ///< к-во dma цепочек для тайлов второй матрицы
  int len_chain_ld_C;        ///< к-во dma цепочек для загрузки тайлов результирующей
                             ///< матрицы
  int len_chain_st_C;        ///< к-во dma цепочек для выгрузки тайлов результирующей
                             ///< матрицы
} MatMulFl16Config;

/// Подбор оптимальных размеров тайлов для умножения матриц формата float16
void size_selector_mat_mul_fl16(int M,            ///< [in]  к-во строк первой матрицы
                                int K,            ///< [in]  к-во столбцов первой матрицы
                                int N,            ///< [in]  к-во столбцов второй матрицы
                                int& buf_M,       ///< [out] к-во строк тайла первой матрицы
                                int& buf_K,       ///< [in]  к-во столбцов тайла первой матрицы
                                int& buf_N,       ///< [in]  к-во столбцов тайла второй матрицы
                                int count_buf_A,  ///< [in]  к-во используемых буферов (двубуферная
                                                  ///< или однобуферная схема)
                                int count_buf_B,  ///< [in]  к-во используемых буферов
                                int count_buf_C,  ///< [in]  к-во используемых буферов
                                int& offsetA,     ///< [out] смещение тайла первой матрицы
                                int& offsetB      ///< [out] смещение тайла второй матрицы
);

/// Освобождение данных структуры запуска
void destroy_dma_chain_mat_mul_fl16(MatMulFl16Config* config  ///< [in] структура для запуска потайловой обработки
);

/// Инициализация и заполнение конфигурационной структуры для алгоритма
/// матричного умножения формата float16
void init_dma_chain_matmul_fl16(uint16_t* src0,            ///< [in]  указатель на первую матрицу
                                int row0,                  ///< [in]  к-во строк первой матрицы
                                int row1col0,              ///< [in]  к-во столбцов первой матрицы
                                uint16_t* src1,            ///< [in]  указатель на вторую матрицу
                                int col1,                  ///< [in]  к-во столбцов второй матрицы
                                uint16_t* dst,             ///< [in]  указатель на результирующую матрицу
                                int& offset_A,             ///< [out] смещение тайла первой матрицы
                                int& offset_B,             ///< [out] смещение тайла второй матрицы
                                MatMulFl16Config* config,  ///< [out] структура для запуска потайловой обработки
                                float* init_vector,        ///< [in]  указатель на вектор начальных значений
                                int offset_src0,           ///< [in]  смещение первой матрицы
                                int offset_dst,            ///< [in]  смещение результата
                                uint16_t* start_adr        ///< [in]  начальный адрес локальной памяти
);

/// Запуск потайловой обработки алгоритма матричного умножения формата float16
void run_matmul_fl16(uint16_t* src0,            ///< [in]  указатель на первую матрицу
                     int row0,                  ///< [in]  к-во строк первой матрицы
                     int row1col0,              ///< [in]  к-во столбцов первой матрицы
                     uint16_t* src1,            ///< [in]  указатель на вторую матрицу
                     int col1,                  ///< [in]  к-во столбцов второй матрицы
                     uint16_t* dst,             ///< [in]  указатель на результирующую матрицу
                     int offset_A,              ///< [in]  смещение тайла первой матрицы
                     int offset_B,              ///< [in]  смещение тайла второй матрицы
                     MatMulFl16Config* config,  ///< [in]  структура для запуска потайловой обработки
                     Store_version st_ver,      ///< [in]  версия постобработки
                     float* init_vector         ///< [in]  указатель на вектор начальных значений
);

extern "C" void mat_mul_fl16_16x128_fl16_aliquantM(uint16_t* in1, int row, int col0row1, uint16_t* in2, int col1,
                                                   uint16_t* out, int offsetA, int offsetB, int real_col, int flag,
                                                   int* tics, int* instr);
extern "C" void mat_mul_fl16_16x128_fl16_aliquantM_relu(uint16_t* in1, int row, int col0row1, uint16_t* in2, int col1,
                                                        uint16_t* out, int offsetA, int offsetB, int real_col,
                                                        int flag, int* tics, int* instr);
extern "C" void mat_mul_fl16_16x128_fl16_aliquantM_relu6(uint16_t* in1, int row, int col0row1, uint16_t* in2, int col1,
                                                         uint16_t* out, int offsetA, int offsetB, int real_col,
                                                         int flag, int* tics, int* instr);

extern "C" void mat_mul_fl16_16x128_fl16_aliquantM_init_vec(uint16_t* in1, int row, int col0row1, uint16_t* in2,
                                                            int col1, uint16_t* out, int offsetA, int offsetB,
                                                            int real_col, int flag, int* tics, int* instr,
                                                            float* init_vector);
extern "C" void mat_mul_fl16_16x128_fl16_aliquantM_relu_init_vec(uint16_t* in1, int row, int col0row1, uint16_t* in2,
                                                                 int col1, uint16_t* out, int offsetA, int offsetB,
                                                                 int real_col, int flag, int* tics, int* instr,
                                                                 float* init_vector);
extern "C" void mat_mul_fl16_16x128_fl16_aliquantM_relu6_init_vec(uint16_t* in1, int row, int col0row1, uint16_t* in2,
                                                                  int col1, uint16_t* out, int offsetA, int offsetB,
                                                                  int real_col, int flag, int* tics, int* instr,
                                                                  float* init_vector);

extern "C" void mat_mul_fl16_16x128_fl16_general(uint16_t* in1, int row, int col0row1, uint16_t* in2, int col1,
                                                 uint16_t* out, int offsetA, int offsetB, int real_col, int flag,
                                                 int* tics, int* instr);
extern "C" void mat_mul_fl16_16x128_fl16_general_relu(uint16_t* in1, int row, int col0row1, uint16_t* in2, int col1,
                                                      uint16_t* out, int offsetA, int offsetB, int real_col, int flag,
                                                      int* tics, int* instr);
extern "C" void mat_mul_fl16_16x128_fl16_general_relu6(uint16_t* in1, int row, int col0row1, uint16_t* in2, int col1,
                                                       uint16_t* out, int offsetA, int offsetB, int real_col, int flag,
                                                       int* tics, int* instr);

extern "C" void mat_mul_fl16_16x128_fl16_general_no_ld_res(uint16_t* in1, int row, int col0row1, uint16_t* in2,
                                                           int col1, uint16_t* out, int offsetA, int offsetB,
                                                           int real_col, int flag, int* tics, int* instr);
extern "C" void mat_mul_fl16_16x128_fl16_general_no_ld_res_relu(uint16_t* in1, int row, int col0row1, uint16_t* in2,
                                                                int col1, uint16_t* out, int offsetA, int offsetB,
                                                                int real_col, int flag, int* tics, int* instr);
extern "C" void mat_mul_fl16_16x128_fl16_general_no_ld_res_relu6(uint16_t* in1, int row, int col0row1, uint16_t* in2,
                                                                 int col1, uint16_t* out, int offsetA, int offsetB,
                                                                 int real_col, int flag, int* tics, int* instr);

extern "C" void mat_mul_fl16_16x128_fl16_general_no_ld_res_init_vec(uint16_t* in1, int row, int col0row1,
                                                                    uint16_t* in2, int col1, uint16_t* out,
                                                                    int offsetA, int offsetB, int real_col, int flag,
                                                                    int* tics, int* instr, float* init_vector);
extern "C" void mat_mul_fl16_16x128_fl16_general_no_ld_res_relu_init_vec(uint16_t* in1, int row, int col0row1,
                                                                         uint16_t* in2, int col1, uint16_t* out,
                                                                         int offsetA, int offsetB, int real_col,
                                                                         int flag, int* tics, int* instr,
                                                                         float* init_vector);
extern "C" void mat_mul_fl16_16x128_fl16_general_no_ld_res_relu6_init_vec(uint16_t* in1, int row, int col0row1,
                                                                          uint16_t* in2, int col1, uint16_t* out,
                                                                          int offsetA, int offsetB, int real_col,
                                                                          int flag, int* tics, int* instr,
                                                                          float* init_vector);

extern "C" void mat_mul_fl16_32x64_fl16_aliquantM(uint16_t* in1, int row, int col0row1, uint16_t* in2, int col1,
                                                  uint16_t* out, int offsetA, int offsetB, int real_col, int flag,
                                                  int* tics, int* instr);
extern "C" void mat_mul_fl16_32x64_fl16_aliquantM_relu(uint16_t* in1, int row, int col0row1, uint16_t* in2, int col1,
                                                       uint16_t* out, int offsetA, int offsetB, int real_col, int flag,
                                                       int* tics, int* instr);
extern "C" void mat_mul_fl16_32x64_fl16_aliquantM_relu6(uint16_t* in1, int row, int col0row1, uint16_t* in2, int col1,
                                                        uint16_t* out, int offsetA, int offsetB, int real_col,
                                                        int flag, int* tics, int* instr);

extern "C" void mat_mul_fl16_32x64_fl16_aliquantM_init_vec(uint16_t* in1, int row, int col0row1, uint16_t* in2,
                                                           int col1, uint16_t* out, int offsetA, int offsetB,
                                                           int real_col, int flag, int* tics, int* instr,
                                                           float* init_vector);
extern "C" void mat_mul_fl16_32x64_fl16_aliquantM_relu_init_vec(uint16_t* in1, int row, int col0row1, uint16_t* in2,
                                                                int col1, uint16_t* out, int offsetA, int offsetB,
                                                                int real_col, int flag, int* tics, int* instr,
                                                                float* init_vector);
extern "C" void mat_mul_fl16_32x64_fl16_aliquantM_relu6_init_vec(uint16_t* in1, int row, int col0row1, uint16_t* in2,
                                                                 int col1, uint16_t* out, int offsetA, int offsetB,
                                                                 int real_col, int flag, int* tics, int* instr,
                                                                 float* init_vector);

extern "C" void mat_mul_fl16_32x64_fl16_general(uint16_t* in1, int row, int col0row1, uint16_t* in2, int col1,
                                                uint16_t* out, int offsetA, int offsetB, int real_col, int flag,
                                                int* tics, int* instr);
extern "C" void mat_mul_fl16_32x64_fl16_general_relu(uint16_t* in1, int row, int col0row1, uint16_t* in2, int col1,
                                                     uint16_t* out, int offsetA, int offsetB, int real_col, int flag,
                                                     int* tics, int* instr);
extern "C" void mat_mul_fl16_32x64_fl16_general_relu6(uint16_t* in1, int row, int col0row1, uint16_t* in2, int col1,
                                                      uint16_t* out, int offsetA, int offsetB, int real_col, int flag,
                                                      int* tics, int* instr);

extern "C" void mat_mul_fl16_32x64_fl16_general_no_ld_res(uint16_t* in1, int row, int col0row1, uint16_t* in2,
                                                          int col1, uint16_t* out, int offsetA, int offsetB,
                                                          int real_col, int flag, int* tics, int* instr);
extern "C" void mat_mul_fl16_32x64_fl16_general_no_ld_res_relu(uint16_t* in1, int row, int col0row1, uint16_t* in2,
                                                               int col1, uint16_t* out, int offsetA, int offsetB,
                                                               int real_col, int flag, int* tics, int* instr);
extern "C" void mat_mul_fl16_32x64_fl16_general_no_ld_res_relu6(uint16_t* in1, int row, int col0row1, uint16_t* in2,
                                                                int col1, uint16_t* out, int offsetA, int offsetB,
                                                                int real_col, int flag, int* tics, int* instr);

extern "C" void mat_mul_fl16_32x64_fl16_general_no_ld_res_init_vec(uint16_t* in1, int row, int col0row1, uint16_t* in2,
                                                                   int col1, uint16_t* out, int offsetA, int offsetB,
                                                                   int real_col, int flag, int* tics, int* instr,
                                                                   float* init_vector);
extern "C" void mat_mul_fl16_32x64_fl16_general_no_ld_res_relu_init_vec(uint16_t* in1, int row, int col0row1,
                                                                        uint16_t* in2, int col1, uint16_t* out,
                                                                        int offsetA, int offsetB, int real_col,
                                                                        int flag, int* tics, int* instr,
                                                                        float* init_vector);
extern "C" void mat_mul_fl16_32x64_fl16_general_no_ld_res_relu6_init_vec(uint16_t* in1, int row, int col0row1,
                                                                         uint16_t* in2, int col1, uint16_t* out,
                                                                         int offsetA, int offsetB, int real_col,
                                                                         int flag, int* tics, int* instr,
                                                                         float* init_vector);

extern "C" void mat_mul_fl16_64x32_fl16_aliquantM(uint16_t* in1, int row, int col0row1, uint16_t* in2, int col1,
                                                  uint16_t* out, int offsetA, int offsetB, int real_col, int flag,
                                                  int* tics, int* instr);
extern "C" void mat_mul_fl16_64x32_fl16_aliquantM_relu(uint16_t* in1, int row, int col0row1, uint16_t* in2, int col1,
                                                       uint16_t* out, int offsetA, int offsetB, int real_col, int flag,
                                                       int* tics, int* instr);
extern "C" void mat_mul_fl16_64x32_fl16_aliquantM_relu6(uint16_t* in1, int row, int col0row1, uint16_t* in2, int col1,
                                                        uint16_t* out, int offsetA, int offsetB, int real_col,
                                                        int flag, int* tics, int* instr);

extern "C" void mat_mul_fl16_64x32_fl16_aliquantM_init_vec(uint16_t* in1, int row, int col0row1, uint16_t* in2,
                                                           int col1, uint16_t* out, int offsetA, int offsetB,
                                                           int real_col, int flag, int* tics, int* instr,
                                                           float* init_vector);
extern "C" void mat_mul_fl16_64x32_fl16_aliquantM_relu_init_vec(uint16_t* in1, int row, int col0row1, uint16_t* in2,
                                                                int col1, uint16_t* out, int offsetA, int offsetB,
                                                                int real_col, int flag, int* tics, int* instr,
                                                                float* init_vector);
extern "C" void mat_mul_fl16_64x32_fl16_aliquantM_relu6_init_vec(uint16_t* in1, int row, int col0row1, uint16_t* in2,
                                                                 int col1, uint16_t* out, int offsetA, int offsetB,
                                                                 int real_col, int flag, int* tics, int* instr,
                                                                 float* init_vector);

extern "C" void mat_mul_fl16_64x32_fl16_general(uint16_t* in1, int row, int col0row1, uint16_t* in2, int col1,
                                                uint16_t* out, int offsetA, int offsetB, int real_col, int flag,
                                                int* tics, int* instr);
extern "C" void mat_mul_fl16_64x32_fl16_general_relu(uint16_t* in1, int row, int col0row1, uint16_t* in2, int col1,
                                                     uint16_t* out, int offsetA, int offsetB, int real_col, int flag,
                                                     int* tics, int* instr);
extern "C" void mat_mul_fl16_64x32_fl16_general_relu6(uint16_t* in1, int row, int col0row1, uint16_t* in2, int col1,
                                                      uint16_t* out, int offsetA, int offsetB, int real_col, int flag,
                                                      int* tics, int* instr);

extern "C" void mat_mul_fl16_64x32_fl16_general_no_ld_res(uint16_t* in1, int row, int col0row1, uint16_t* in2,
                                                          int col1, uint16_t* out, int offsetA, int offsetB,
                                                          int real_col, int flag, int* tics, int* instr);
extern "C" void mat_mul_fl16_64x32_fl16_general_no_ld_res_relu(uint16_t* in1, int row, int col0row1, uint16_t* in2,
                                                               int col1, uint16_t* out, int offsetA, int offsetB,
                                                               int real_col, int flag, int* tics, int* instr);
extern "C" void mat_mul_fl16_64x32_fl16_general_no_ld_res_relu6(uint16_t* in1, int row, int col0row1, uint16_t* in2,
                                                                int col1, uint16_t* out, int offsetA, int offsetB,
                                                                int real_col, int flag, int* tics, int* instr);

extern "C" void mat_mul_fl16_64x32_fl16_general_no_ld_res_init_vec(uint16_t* in1, int row, int col0row1, uint16_t* in2,
                                                                   int col1, uint16_t* out, int offsetA, int offsetB,
                                                                   int real_col, int flag, int* tics, int* instr,
                                                                   float* init_vector);
extern "C" void mat_mul_fl16_64x32_fl16_general_no_ld_res_relu_init_vec(uint16_t* in1, int row, int col0row1,
                                                                        uint16_t* in2, int col1, uint16_t* out,
                                                                        int offsetA, int offsetB, int real_col,
                                                                        int flag, int* tics, int* instr,
                                                                        float* init_vector);
extern "C" void mat_mul_fl16_64x32_fl16_general_no_ld_res_relu6_init_vec(uint16_t* in1, int row, int col0row1,
                                                                         uint16_t* in2, int col1, uint16_t* out,
                                                                         int offsetA, int offsetB, int real_col,
                                                                         int flag, int* tics, int* instr,
                                                                         float* init_vector);
