// Copyright 2025 RnD Center "ELVEES", JSC

#include <cmath>

#include <elcore50-matrix-lib/winograd_fl16.hpp>

// Float32 <-> float16 conversion routines with C linkage, defined outside this file. Float16 data is
// held in uint16_t buffers (presumably raw half-precision bit patterns — confirm against the library);
// `size` is the number of elements to convert.
extern "C" void float16_to_float32(uint16_t* src, float* dst, int size);
extern "C" void float32_to_float16(float* src, uint16_t* dst, int size);

// Reorders a DYXC weight tensor (output channel outermost, input channel innermost) into YXCD order
// (output channel innermost). Defined below.
static void weight_dyxc_to_yxcd(const float* input, float* output, int kernelY, int kernelX, int srcC, int dstC);

// Naive reference convolution (NHWC data, no bias) used to validate the optimized output. Defined below.
static void ref_convolution_wo_bias(const float* src, int batch, int srcH, int srcW, int srcC, int kernelY,
                                    int kernelX, int dilationY, int dilationX, int strideY, int strideX, int padY,
                                    int padX, int padH, int padW, int group, const float* weight, float* dst,
                                    int dstC);

// Number of parameter sets in each array_* table of main().
#define TEST_COUNT 40

// Prints the three-line column header of the results table followed by a separator line.
static void print_table_header() {
  printf(
      "| func_name        | input (b, h, w, ch)             | bias | store mode | opt "
      "mul/tic |    opt tic | status |\n");
  printf(
      "|                  | ker(X,Y),str(X,Y),pads(X,W,Y,H) |      |            |       "
      "      |            |        |\n");
  printf(
      "|                  | output (b, h, w, ch)            |      |            |       "
      "      |            |        |\n");
  printf(
      "------------------------------------------------------------------------"
      "--------------------------------------\n");
}

// Compares the optimized output `opt` against the reference `ref` (both `size` floats).
// An element passes when its relative error is <= 5 % (non-zero reference) or its magnitude is <= 0.1
// (reference exactly zero). NaN results fail. On the first failing element, prints its index, the
// reference value and the optimized value, then returns 1. Returns 0 when every element passes.
static int check_against_reference(const float* ref, const float* opt, int size) {
  static const double kRelTol = 0.05;
  static const double kAbsTolAtZero = 0.1;
  for (int i = 0; i < size; ++i) {
    // std::fabs selects the floating-point overload. A plain integer abs() would truncate the error to 0
    // and divide by zero whenever |ref| < 1.
    // The checks are written as "<= tolerance" so that a NaN result is rejected instead of accepted.
    const bool ok = (ref[i] != 0.0) ? (std::fabs(ref[i] - opt[i]) / std::fabs(ref[i]) <= kRelTol)
                                    : (std::fabs(opt[i]) <= kAbsTolAtZero);
    if (!ok) {
      printf("\n%d %f %f\n", i, ref[i], opt[i]);
      return 1;
    }
  }
  return 0;
}

// Benchmarks the fp16 Winograd convolution (2x2 and 4x4 variants) on TEST_COUNT geometries. For each case
// it prints a table row with the achieved multiply-accumulates per tick and the tick count, then checks the
// result against a float32 reference convolution. Returns the number of failed cases.
int main() {
  disable_l2_cache();

  int test_result = 0;

  int array_srcH[TEST_COUNT] = {10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 66, 18, 34, 10,
                                34, 10, 10, 17, 9,  24, 10, 9,  41, 8,  10, 9,  9,  8,  10, 73, 9,  8,  30, 6};

  int array_srcW[TEST_COUNT] = {10, 10, 10, 10, 10, 10, 10, 10, 66, 10, 10, 66, 10, 10, 10, 66, 10, 18, 34, 34,
                                10, 34, 18, 42, 10, 10, 9,  25, 9,  9,  9,  41, 9,  9,  8,  8,  72, 8,  30, 6};

  int array_srcC[TEST_COUNT] = {32, 128, 256, 64, 32, 32, 32, 64, 64, 64,  32, 32, 32, 32, 64, 64,  32, 128, 64,  192,
                                64, 192, 64,  32, 64, 64, 64, 64, 64, 128, 64, 32, 64, 64, 64, 192, 64, 64,  128, 64};

  int array_dstC[TEST_COUNT] = {32, 32,  32, 32, 128, 64, 64, 32,  32, 32, 128, 64, 64, 32, 64, 64, 32, 128, 128, 256,
                                64, 256, 32, 64, 32,  64, 32, 128, 32, 32, 32,  32, 32, 32, 32, 32, 32, 32,  32,  64};

  int array_padX[TEST_COUNT] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 3};
  int array_padW[TEST_COUNT] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1};
  int array_padY[TEST_COUNT] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 2};
  int array_padH[TEST_COUNT] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 2};

  print_table_header();
  for (int winograd_ver_id = 0; winograd_ver_id < 2; ++winograd_ver_id) {
    for (int test_id = 0; test_id < TEST_COUNT; ++test_id) {
      const int padX = array_padX[test_id];
      const int padW = array_padW[test_id];
      const int padY = array_padY[test_id];
      const int padH = array_padH[test_id];

      const int srcH = array_srcH[test_id];
      const int srcW = array_srcW[test_id];
      const int srcC = array_srcC[test_id];
      const int dstC = array_dstC[test_id];
      printf("| winograd f16 %s | (%1d %4d %4d %4d)              |    %d |", winograd_ver_id ? "4x4" : "2x2", 1, srcH,
             srcW, srcC, 0);
      printf("      none  | ");

      // Only 3x3 kernels with unit stride and dilation are exercised; the output size follows from that.
      const int kernelY = 3;
      const int kernelX = 3;
      const int dstH = srcH - kernelY + 1 + padY + padH;
      const int dstW = srcW - kernelX + 1 + padX + padW;

      const int src_size = srcH * srcW * srcC;
      const int weight_size = kernelY * kernelX * dstC * srcC;
      const int dst_size = dstC * dstH * dstW;

      float* src = (float*)memalign(64, src_size * sizeof(float));
      float* weight_dyxc = (float*)memalign(64, weight_size * sizeof(float));
      float* weight_yxcd = (float*)memalign(64, weight_size * sizeof(float));
      float* dst_ref = (float*)memalign(64, dst_size * sizeof(float));
      float* dst_opt = (float*)memalign(64, dst_size * sizeof(float));
      uint16_t* src_fl16 = (uint16_t*)memalign(64, src_size * sizeof(uint16_t));
      uint16_t* dst_fl16 = (uint16_t*)memalign(64, dst_size * sizeof(uint16_t));

      int error = 0;
      if (!src || !weight_dyxc || !weight_yxcd || !dst_ref || !dst_opt || !src_fl16 || !dst_fl16) {
        // Report the case as failed rather than writing through a null pointer.
        printf("\nmemory allocation failed\n");
        error = 1;
      } else {
        for (int i = 0; i < src_size; ++i) {
          src[i] = rand() % 3;
        }
        for (int i = 0; i < weight_size; ++i) {
          weight_dyxc[i] = (rand() % 3) * 0.1;
        }

        // The optimized path is set up with the YXCD copy of the weights; the reference uses DYXC.
        weight_dyxc_to_yxcd(weight_dyxc, weight_yxcd, kernelY, kernelX, srcC, dstC);
        float32_to_float16(src, src_fl16, src_size);

        Winograd_version ver = winograd_ver_id ? VER_4x4 : VER_2x2;
        WinogradConfig config(ver);
        Padding pads = {padY, padX, padH, padW};
        init_dma_chain_winograd(src_fl16, 1, srcC, srcH, srcW, weight_yxcd, dst_fl16, dstC, &config, &pads);

        int tic[2], instr[2];

        count_tics(tic, instr);
        run_winograd(src_fl16, 1, srcC, srcH, srcW, dst_fl16, dstC, &config, &pads);
        count_tics(&tic[1], &instr[1]);

        printf("   %7.3f  | ", (float)dstH * dstW * srcC * kernelY * kernelX * dstC / (tic[1] - tic[0]));
        printf("%10d | ", tic[1] - tic[0]);
        ref_convolution_wo_bias(src, 1, srcH, srcW, srcC, kernelY, kernelX, 1, 1, 1, 1, padY, padX, padH, padW, 1,
                                weight_dyxc, dst_ref, dstC);

        float16_to_float32(dst_fl16, dst_opt, dst_size);
        // Tear down the DMA chain before its fp16 buffers are released below.
        destroy_dma_chain_winograd(&config);

        error = check_against_reference(dst_ref, dst_opt, dst_size);
      }

      free(src_fl16);
      free(dst_fl16);
      free(src);
      free(weight_dyxc);
      free(weight_yxcd);
      free(dst_ref);
      free(dst_opt);

      if (!error)
        printf("passed |\n");
      else
        printf("failed |\n");

      test_result += error;

      printf(
          "|                  | (%1d %2d) (%1d %2d) (%1d %1d %1d %1d)         |      |            "
          "|             |            |       "
          " |\n",
          kernelX, kernelY, 1, 1, padX, padW, padY, padH);
      printf(
          "|                  | (%1d %4d %4d %4d)              |      |            |      "
          "       |            |        |\n",
          1, dstH, dstW, dstC);
      printf(
          "----------------------------------------------------------------------"
          "----------------------------------------\n");
    }
  }

  enable_l2_cache(L2_CACHE_SIZE);

  return test_result;
}

// Copies a convolution weight tensor from DYXC layout
//   input[d][y][x][c]   (d = output channel, outermost)
// into YXCD layout
//   output[y][x][c][d]  (d innermost)
// Both buffers hold kernelY * kernelX * srcC * dstC floats. Output elements are written strictly in
// increasing address order.
// NOTE(review): optnone is kept as in the original; presumably it works around a compiler issue — confirm
// before removing.
__attribute__((optnone)) static void weight_dyxc_to_yxcd(const float* input, float* output, int kernelY, int kernelX,
                                                         int srcC, int dstC) {
  // Distance in `input` between consecutive output channels (one full Y*X*C slice).
  const int dstStride = kernelY * kernelX * srcC;
  float* out = output;
  for (int y = 0; y < kernelY; ++y) {
    for (int x = 0; x < kernelX; ++x) {
      for (int c = 0; c < srcC; ++c) {
        // Offset of (y, x, c) inside one output-channel slice of the input.
        const int yxc = (y * kernelX + x) * srcC + c;
        for (int d = 0; d < dstC; ++d) {
          *out++ = input[d * dstStride + yxc];
        }
      }
    }
  }
}

// Naive reference 2-D convolution without bias or activation, used to validate the optimized output.
//
// Layouts (fixed by the indexing below):
//   src    : [batch][srcH][srcW][srcC]
//   weight : [dstC][kernelY][kernelX][srcC / group]   (DYXC)
//   dst    : [batch][dstH][dstW][dstC]
// The output size is derived from the geometry:
//   dstH = (srcH + padY + padH - (dilationY * (kernelY - 1) + 1)) / strideY + 1   (dstW likewise)
// padY/padX shift the sampling window (leading padding); padH/padW only enlarge the output size
// (trailing padding). Taps that fall outside the source image contribute zero.
//
// group splits the channels into independent groups: output channel dc reads only the srcC / group input
// channels of group dc / (dstC / group). srcC and dstC must be divisible by group. group == 1 is an
// ordinary dense convolution; values < 1 are treated as 1, matching the earlier behavior of ignoring
// this parameter.
static void ref_convolution_wo_bias(const float* src, int batch, int srcH, int srcW, int srcC, int kernelY,
                                    int kernelX, int dilationY, int dilationX, int strideY, int strideX, int padY,
                                    int padX, int padH, int padW, int group, const float* weight, float* dst,
                                    int dstC) {
  if (group < 1) group = 1;

  const int dstH = (srcH + padY + padH - (dilationY * (kernelY - 1) + 1)) / strideY + 1;
  const int dstW = (srcW + padX + padW - (dilationX * (kernelX - 1) + 1)) / strideX + 1;

  const int srcCPerGroup = srcC / group;
  const int dstCPerGroup = dstC / group;

  for (int b = 0; b < batch; ++b) {
    for (int dc = 0; dc < dstC; ++dc) {
      // First input channel of the group this output channel belongs to (always 0 when group == 1).
      const int srcCBase = (dc / dstCPerGroup) * srcCPerGroup;
      for (int dh = 0; dh < dstH; ++dh) {
        for (int dw = 0; dw < dstW; ++dw) {
          float sum = 0;
          for (int sc = 0; sc < srcCPerGroup; ++sc) {
            for (int ky = 0; ky < kernelY; ky++) {
              for (int kx = 0; kx < kernelX; kx++) {
                int sy = dh * strideY + ky * dilationY - padY;
                int sx = dw * strideX + kx * dilationX - padX;
                if (sy >= 0 && sy < srcH && sx >= 0 && sx < srcW) {
                  sum += src[((b * srcH + sy) * srcW + sx) * srcC + srcCBase + sc] *
                         weight[((dc * kernelY + ky) * kernelX + kx) * srcCPerGroup + sc];
                }
              }
            }
          }
          dst[((b * dstH + dh) * dstW + dw) * dstC + dc] = sum;  // hook for extra post-processing (ReLU, etc.)
        }
      }
    }
  }
}
