// Copyright 2025 RnD Center "ELVEES", JSC

#include <math.h>

#include <elcore50-matrix-lib/convolution_fl16.hpp>
#include <elcore50-matrix-lib/winograd_fl16.hpp>

#include "convolve_tests_helper.h"

// Externally defined conversion routines with C linkage. In this file they are
// always called with `size` equal to the number of elements in both buffers.
// NOTE(review): the names suggest uint16_t holds IEEE half-precision bit
// patterns - confirm against the implementation.
extern "C" void float16_to_float32(uint16_t* src, float* dst, int size);
extern "C" void float32_to_float16(float* src, uint16_t* dst, int size);

// Reorders a weight tensor from [dstC][kernelY][kernelX][srcC] layout into
// [kernelY][kernelX][srcC][dstC] layout (defined at the bottom of this file).
static void weight_dyxc_to_yxcd(const float* input, float* output, int kernelY, int kernelX, int srcC, int dstC);

// Number of tensor shapes exercised per Winograd version; must match the
// length of the array_src*/array_dst* tables in main().
#define TEST_COUNT 20

/**
 * Compares the fp16 Winograd convolution against the direct fp16 convolution.
 *
 * For each Winograd version (2x2, then 4x4) and each of the TEST_COUNT tensor
 * shapes, the test:
 *   1. fills the input and the 3x3 weights with pseudo-random values,
 *   2. runs run_conv_fl16() and run_winograd() on the same data, timing each
 *      with count_tics(),
 *   3. compares the two outputs element-wise with a 1% relative tolerance,
 *   4. prints one table row with the throughput and the speed ratio.
 *
 * @return the number of failed test cases (0 when every case passed).
 */
int main() {
  disable_l2_cache();

  int test_result = 0;

  int tic[2], instr[2];
  Store_version st_ver = STORE_NONE;

  int array_srcH[TEST_COUNT] = {18, 34, 34, 58, 58, 18, 34, 58, 18, 34, 114, 58, 154, 74, 10, 10, 34, 34, 34, 66};

  int array_srcW[TEST_COUNT] = {18, 34, 34, 58, 58, 18, 34, 58, 18, 34, 114, 58, 154, 74, 34, 34, 34, 34, 34, 66};

  int array_srcC[TEST_COUNT] = {128, 128, 128, 128, 128, 256, 256, 256, 512, 512,
                                64,  64,  32,  96,  192, 256, 256, 512, 512, 512};

  int array_dstC[TEST_COUNT] = {32,  128, 32, 256, 32,  256, 512, 256, 512, 512,
                                128, 64,  32, 192, 256, 256, 256, 256, 512, 512};

  /* Table header. */
  printf(
      "| func_name        | input (b, h, w, ch)    | bias | store mode | "
      "conv2d mul/tic | conv_tic/win_tic | "
      "status |\n");
  printf(
      "|                  | ker (X, Y), str (X, Y) |      |            |       "
      "         |                  | "
      "       |\n");
  printf(
      "|                  | output (b, h, w, ch)   |      |            |       "
      "         |                  | "
      "       |\n");
  printf(
      "------------------------------------------------------------------------"
      "----------------------------"
      "----------\n");

  for (int winograd_ver_id = 0; winograd_ver_id < 2; ++winograd_ver_id) {
    for (int test_id = 0; test_id < TEST_COUNT; ++test_id) {
      int srcH = array_srcH[test_id];
      int srcW = array_srcW[test_id];
      int srcC = array_srcC[test_id];
      int dstC = array_dstC[test_id];

      printf("| win%s vs conv   | (%1d %4d %4d %4d)     |    %d |", winograd_ver_id ? "4x4" : "2x2", 1, srcH, srcW,
             srcC, 0);
      printf("      none  | ");

      /* The output is 2 smaller than the input in each spatial dimension
         for the 3x3 kernel used here. */
      int dstH = srcH - 2;
      int dstW = srcW - 2;

      int kernelY = 3;
      int kernelX = 3;

      /* Element counts shared by the allocations, conversions and loops below. */
      const int src_count = srcH * srcW * srcC;
      const int weight_count = kernelY * kernelX * dstC * srcC;
      const int dst_count = dstC * dstH * dstW;

      float* src = (float*)memalign(64, src_count * sizeof(float));
      float* weight_dyxc = (float*)memalign(64, weight_count * sizeof(float));
      float* weight_yxcd = (float*)memalign(64, weight_count * sizeof(float));

      /* Small integer inputs and coarse weights: values are drawn from
         {0, 1, 2} and {0, 0.1, 0.2} respectively. */
      for (int i = 0; i < src_count; ++i) {
        src[i] = rand() % 3;
      }

      for (int i = 0; i < weight_count; ++i) {
        weight_dyxc[i] = (rand() % 3) * 0.1;
      }

      weight_dyxc_to_yxcd(weight_dyxc, weight_yxcd, kernelY, kernelX, srcC, dstC);

      Weight_fl16 input_weight = {NULL, kernelX, kernelY, srcC, dstC, 1, 1, 1, 1};
      input_weight.data = (uint16_t*)memalign(64, weight_count * sizeof(uint16_t));
      float32_to_float16(weight_yxcd, input_weight.data, weight_count);

      float* dst_conv2d = (float*)memalign(64, dst_count * sizeof(float));
      float* dst_winograd = (float*)memalign(64, dst_count * sizeof(float));

      uint16_t* src_fl16 = (uint16_t*)memalign(64, src_count * sizeof(uint16_t));
      uint16_t* dst_conv2d_fl16 = (uint16_t*)memalign(64, dst_count * sizeof(uint16_t));
      uint16_t* dst_winograd_fl16 = (uint16_t*)memalign(64, dst_count * sizeof(uint16_t));
      float32_to_float16(src, src_fl16, src_count);

      /* Conv2d: reference result and baseline timing. */
      Tensor_fl16 input_fl16 = {src_fl16, 1, srcH, srcW, srcC};
      Tensor_fl16 output_opt_fl16 = {dst_conv2d_fl16, 1, dstH, dstW, dstC};

      ConvFl16Config conv_config;
      init_dma_chain_conv_fl16(&input_fl16, 1, 0, 0, 0, 0, &input_weight, NULL, &output_opt_fl16, &conv_config, 0,
                               (uint16_t*)__local_mem, 524288);
      count_tics(tic, instr);
      run_conv_fl16(&input_fl16, 1, &input_weight, NULL, &output_opt_fl16, &conv_config, st_ver, 0);
      count_tics(&tic[1], &instr[1]);

      destroy_dma_chain_conv_fl16(&conv_config);
      flush_all_caches();

      float16_to_float32(dst_conv2d_fl16, dst_conv2d, dst_count);
      int conv_tic = tic[1] - tic[0];
      /* Multiplications of the direct convolution per elapsed tic. */
      printf("    %9.3f  | ", (float)dstH * dstW * srcC * 3 * 3 * dstC / conv_tic);

      /* Winograd: the implementation under test. */
      Winograd_version ver = winograd_ver_id ? VER_4x4 : VER_2x2;
      WinogradConfig config(ver);
      init_dma_chain_winograd(src_fl16, 1, srcC, srcH, srcW, weight_yxcd, dst_winograd_fl16, dstC, &config);

      count_tics(tic, instr);
      run_winograd(src_fl16, 1, srcC, srcH, srcW, dst_winograd_fl16, dstC, &config);
      count_tics(&tic[1], &instr[1]);

      int win_tic = tic[1] - tic[0];
      /* Speed ratio: values above 1 mean Winograd took fewer tics. */
      printf("     %10.3f  | ", (float)conv_tic / win_tic);

      float16_to_float32(dst_winograd_fl16, dst_winograd, dst_count);
      destroy_dma_chain_winograd(&config);

      int error = 0;

      /* Relative comparison with a 1% tolerance; where the reference is
         exactly zero the absolute difference is used instead. fabsf() is
         called explicitly so the check never falls back to the integer
         abs(int), which would truncate both operands and hide mismatches. */
      for (int i = 0; i < dst_count; ++i) {
        const float ref = dst_conv2d[i];
        const float diff = fabsf(ref - dst_winograd[i]);
        const float scale = (ref != 0.0f) ? fabsf(ref) : 1.0f;
        if (diff / scale > 0.01) {
          /* One mismatch is enough to fail the case. */
          error++;
          break;
        }
      }

      free(src_fl16);
      free(dst_conv2d_fl16);
      free(dst_winograd_fl16);
      free(src);
      free(weight_dyxc);
      free(weight_yxcd);
      free(dst_conv2d);
      free(dst_winograd);
      free(input_weight.data);

      if (!error)
        printf("passed |\n");
      else
        printf("failed |\n");

      test_result += error;

      /* Second and third lines of the table row: kernel/stride and output shape. */
      printf(
          "|                  | (%1d %2d) (%1d %2d)          |      |            "
          "|                |                 "
          " |        |\n",
          3, 3, 1, 1);
      printf(
          "|                  | (%1d %4d %4d %4d)     |      |            |      "
          "          |                  |     "
          "   |\n",
          1, dstH, dstW, dstC);
      printf(
          "----------------------------------------------------------------------"
          "-----------------------------------"
          "-----\n");
    }
  }
  enable_l2_cache(L2_CACHE_SIZE);

  return test_result;
}

/**
 * Transposes convolution weights from [dstC][kernelY][kernelX][srcC] layout
 * to [kernelY][kernelX][srcC][dstC] layout.
 *
 * @param input   source weights, kernelY * kernelX * srcC * dstC elements
 * @param output  destination buffer of the same element count
 * @param kernelY kernel height
 * @param kernelX kernel width
 * @param srcC    number of input channels
 * @param dstC    number of output channels
 *
 * NOTE(review): optimisation is disabled for this function via optnone; the
 * reason is not visible in this file.
 */
__attribute__((optnone)) static void weight_dyxc_to_yxcd(const float* input, float* output, int kernelY, int kernelX,
                                                         int srcC, int dstC) {
  /* The loop nest visits the output in storage order, so a running write
     position replaces recomputing the destination index every iteration. */
  int write_pos = 0;
  for (int ky = 0; ky < kernelY; ++ky) {
    for (int kx = 0; kx < kernelX; ++kx) {
      for (int c = 0; c < srcC; ++c) {
        for (int d = 0; d < dstC; ++d) {
          const int read_pos = ((d * kernelY + ky) * kernelX + kx) * srcC + c;
          output[write_pos] = input[read_pos];
          ++write_pos;
        }
      }
    }
  }
}
