// Copyright 2025 RnD Center "ELVEES", JSC

#include <elcore50-matrix-lib/mat_mul_with_dma_fl32.hpp>

#include "convolve_tests_helper.h"

// Benchmark/regression driver for the fl32 matrix multiplication kernels.
//
// For every size triple (row, row1col0, col1) taken from the tables below it:
//   1. fills two input buffers with small integer values in [-2, 2],
//   2. runs mm_v0_vliw_1_sub_matrix_pre_load_real_out_offset() into `dst`
//      (reported in the "ref mul/tic" column),
//   3. runs the DMA-chain variant (init/run/destroy_*_fl32) into `dst1`
//      (reported in the "opt mul/tic" column),
//   4. compares both output buffers byte-for-byte and prints passed/failed.
//
// Returns the number of failed cases (0 when everything passed), so the
// process exit status can be used directly by a test runner.
int main() {
  disable_l2_cache();

  int test_result = 0;

  // Per-case matrix dimensions. The three tables are indexed in parallel and
  // must be kept the same length.
  const int array_row[] = {4, 512, 256, 512, 512, 512, 512, 49, 12544, 3136, 784, 784, 196, 196, 49, 49};
  const int array_row1col0[] = {4, 512, 256, 256, 512, 512, 1024, 2048, 16, 24, 192, 32, 384, 64, 160, 960};
  const int array_col1[] = {128, 512, 256, 256, 256, 128, 256, 512, 96, 144, 32, 192, 64, 384, 960, 160};
  const int num_cases = (int)(sizeof(array_row) / sizeof(array_row[0]));

  printf(
      "| func_name     | size (r0, c1r1, c1) | ref mul/tic | opt mul/tic | "
      "status |\n");
  printf(
      "------------------------------------------------------------------------"
      "-------\n");

  for (int i = 0; i < num_cases; ++i) {
    printf("| mat_mul_fl32  |");

    const int row = array_row[i];
    const int row1col0 = array_row1col0[i];
    const int col1 = array_col1[i];
    printf("   %5d %5d %5d |", row, row1col0, col1);

    int tic[2], instr[2], block_tic[6], block_instr[6], func_tic;

    const int src0_len = row * row1col0;
    const int src1_len = row1col0 * col1;
    // Output buffers are sized with the project-defined COEF multiplier.
    const size_t dst_bytes = row * col1 * sizeof(float) * COEF;

    float* src0_fl32 = (float*)memalign(64, src0_len * sizeof(float));
    float* src1_fl32 = (float*)memalign(64, src1_len * sizeof(float));
    float* dst = (float*)memalign(64, dst_bytes);
    float* dst1 = (float*)memalign(64, dst_bytes);

    if (!src0_fl32 || !src1_fl32 || !dst || !dst1) {
      // Out of memory: report the case as failed instead of dereferencing a
      // null pointer. Column widths match the numeric columns of the table.
      printf("         n/a |         n/a | failed |\n");
      ++test_result;
      free(src0_fl32);
      free(src1_fl32);
      free(dst);
      free(dst1);
      continue;
    }

    for (int j = 0; j < src0_len; ++j) {
      src0_fl32[j] = rand() % 5 - 2;
    }

    for (int j = 0; j < src1_len; ++j) {
      src1_fl32[j] = rand() % 5 - 2;
    }

    memset(dst, 0, dst_bytes);
    memset(dst1, 0, dst_bytes);

    // First measured run; its result in `dst` is the baseline for comparison.
    flush_all_caches();
    count_tics(tic, instr);
    mm_v0_vliw_1_sub_matrix_pre_load_real_out_offset(src0_fl32, row, row1col0, src1_fl32, col1, dst, 0, 0, 0, col1,
                                                     block_tic, block_instr, 0, 0);
    count_tics(&tic[1], &instr[1]);

    func_tic = tic[1] - tic[0];
    printf("     %7.3f |", (float)row * row1col0 * col1 / func_tic);

    flush_all_caches();

    int offset_A = 0;
    int offset_B = 0;

    MatMulFl32Config config;

    // The DMA chain setup is done outside the timed region on purpose: only
    // run_matmul_fl32() is bracketed by count_tics().
    init_dma_chain_matmul_fl32(src0_fl32, row, row1col0, src1_fl32, col1, dst1, offset_A, offset_B, &config);

    flush_all_caches();
    count_tics(tic, instr);
    run_matmul_fl32(src0_fl32, row, row1col0, src1_fl32, col1, dst1, offset_A, offset_B, &config);
    count_tics(&tic[1], &instr[1]);

    destroy_dma_chain_mat_mul_fl32(&config);

    func_tic = tic[1] - tic[0];
    printf("     %7.3f |", (float)row * row1col0 * col1 / func_tic);

    // memcmp() may return any negative or positive value on mismatch; collapse
    // it to 0/1 so failures from different cases cannot cancel each other out
    // in the accumulated exit status.
    const int failed = memcmp(dst, dst1, dst_bytes) != 0;
    if (!failed)
      printf(" passed |\n");
    else
      printf(" failed |\n");

    test_result += failed;

    free(src0_fl32);
    free(src1_fl32);
    free(dst);
    free(dst1);
  }

  printf(
      "------------------------------------------------------------------------"
      "-------\n");

  enable_l2_cache(L2_CACHE_SIZE);
  return test_result;
}
