// Copyright 2018-2025 RnD Center "ELVEES", JSC

/*! \file
 *  \brief Тестирование функции mat_trans_scalar
 *  \author Фролов Андрей
 */

#include "tests.h"

/*!
 * \brief Runs the transpose tests for int16, int32, float and double matrices.
 *
 * One source buffer and two destination buffers (reference and optimized
 * output) are sized for the largest matrix in the size table. Each element
 * slot is sizeof(int64_t) bytes so the same storage fits every element type
 * tested here. Every size/type combination prints one table row through the
 * matching test_mat_trans_scalar* helper.
 *
 * \return Sum of the helpers' results (0 when every case passed); 1 if the
 *         heap buffers cannot be allocated.
 */
int main(void) {
  int failed_count = 0;
  int rows[TEST_COUNT] = {8, 16, 32, 64, 128};
  int columns[TEST_COUNT] = {8, 16, 32, 64, 128};

  int print = 0;

  /* The last table entry is the largest matrix; every buffer is sized for it. */
  const int max_elems = columns[TEST_COUNT - 1] * rows[TEST_COUNT - 1];
  const size_t buf_bytes = (size_t)max_elems * sizeof(int64_t);

  print_table_header();

#ifndef LOCAL_MEM
  void* src0 = memalign(64, buf_bytes);
  void* dst_ref = memalign(64, buf_bytes);
  void* dst_opt = memalign(64, buf_bytes);
  if (src0 == NULL || dst_ref == NULL || dst_opt == NULL) {
    printf("mat_trans_scalar test: buffer allocation failed\n");
    free(src0); /* free(NULL) is a no-op, so partial failures are safe */
    free(dst_ref);
    free(dst_opt);
    return 1;
  }
#else
#ifdef BARE_METAL
  char* base = (char*)&__local_mem;
#else
  disable_l2_cache();
  char* base = (char*)&xyram_data;
#endif
  /* Three consecutive buffers carved out of local memory. Arithmetic is done
   * on char* because arithmetic on void* is a GNU extension, not standard C. */
  void* src0 = base;
  void* dst_ref = base + buf_bytes;
  void* dst_opt = base + 2 * buf_bytes;
#endif

  create_vector_s16((int16_t*)src0, max_elems, 0);
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = columns[i] * rows[i] * sizeof(int16_t) * 1;
    int32_t ti_tics = (int)(1.0 / 2 * rows[i] * columns[i] + 35);
    printf("| mat_trans_scalar_short |  %6d %6d |", rows[i], columns[i]);
    failed_count += test_mat_trans_scalar((int16_t*)src0, (int16_t*)dst_ref, (int16_t*)dst_opt, rows[i], columns[i],
                                          print, input_bytes, ti_tics);
  }

  create_vector_s32((int32_t*)src0, max_elems, 0);
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = columns[i] * rows[i] * sizeof(int32_t) * 1;
    int32_t ti_tics = 0;
    printf("| mat_trans_scalar_int   |  %6d %6d |", rows[i], columns[i]);
    failed_count += test_mat_trans_scalar_s32((int32_t*)src0, (int32_t*)dst_ref, (int32_t*)dst_opt, rows[i],
                                              columns[i], print, input_bytes, ti_tics);
  }

  create_vector_float((float*)src0, max_elems, 0);
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = columns[i] * rows[i] * sizeof(float) * 1;
    int32_t ti_tics = (int)(1.0 / 2 * rows[i] * columns[i] + 4.0 * rows[i] + 29);
    printf("| mat_trans_scalar_fl    |  %6d %6d |", rows[i], columns[i]);
    failed_count += test_mat_trans_scalar_fl((float*)src0, (float*)dst_ref, (float*)dst_opt, rows[i], columns[i],
                                             print, input_bytes, ti_tics);
  }

  create_vector_double((double*)src0, max_elems, 0);
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = columns[i] * rows[i] * sizeof(double) * 1;
    int32_t ti_tics = rows[i] * columns[i] + 6 * rows[i] + 28;
    printf("| mat_trans_scalar_db    |  %6d %6d |", rows[i], columns[i]);
    failed_count += test_mat_trans_scalar_db((double*)src0, (double*)dst_ref, (double*)dst_opt, rows[i], columns[i],
                                             print, input_bytes, ti_tics);
  }

#ifndef LOCAL_MEM
  free(src0);
  free(dst_ref);
  free(dst_opt);
#else
#ifndef BARE_METAL
  enable_l2_cache(L2_CACHE_SIZE);
#endif
#endif

  return failed_count;
}

/*!
 * \brief Checks mat_trans_scalar against ref_mat_trans for one int16 matrix size.
 *
 * Both implementations transpose the same input. Each call is bracketed by
 * count_tics() so print_performance() can report reference and tested counts.
 * The table row is terminated with " passed |" or " failed |".
 *
 * \param src0        input matrix, rows x columns
 * \param dst_ref     receives the reference transpose (columns x rows)
 * \param dst_opt     receives the tested transpose (columns x rows)
 * \param rows        number of rows in src0
 * \param columns     number of columns in src0
 * \param print       non-zero to dump the input and both results
 * \param input_bytes forwarded unchanged to print_performance()
 * \param ti_tics     forwarded unchanged to print_performance()
 * \return result of compare_s16(); 0 is reported as "passed"
 */
int test_mat_trans_scalar(int16_t* src0, int16_t* dst_ref, int16_t* dst_opt, int rows, int columns, int print,
                          int32_t input_bytes, int32_t ti_tics) {
  /* Index 0 holds the counters sampled before the call, index 1 after it. */
  uint32_t tic_count[2], instruction_count[2];
  uint32_t ref_tic_count[2], ref_instruction_count[2];

  count_tics(ref_tic_count, ref_instruction_count);
  ref_mat_trans(src0, rows, columns, dst_ref);
  count_tics(&ref_tic_count[1], &ref_instruction_count[1]);

  count_tics(tic_count, instruction_count);
  mat_trans_scalar(src0, rows, columns, dst_opt);
  count_tics(&tic_count[1], &instruction_count[1]);

  if (print) {
    printf("\n%d X %d\n", rows, columns);
    printf("mat1:\n");
    print_matrix_s16(src0, rows, columns);
    /* Colon before the newline, matching the "mat1:" header above. */
    printf("ref_res %d X %d:\n", columns, rows);
    print_matrix_s16(dst_ref, columns, rows);
    printf("dsp_res %d X %d:\n", columns, rows);
    print_matrix_s16(dst_opt, columns, rows);
  }

  int ret = compare_s16(dst_ref, dst_opt, columns * rows);

  print_performance(ref_tic_count, ref_instruction_count, tic_count, instruction_count, input_bytes, ti_tics);

  if (ret == 0) {
    printf(" passed |\n");
  } else {
    printf(" failed |\n");
  }

  return ret;
}

/*!
 * \brief Checks mat_trans_scalar_s32 against ref_mat_trans_s32 for one int32 matrix size.
 *
 * Both implementations transpose the same input. Each call is bracketed by
 * count_tics() so print_performance() can report reference and tested counts.
 * The table row is terminated with " passed |" or " failed |".
 *
 * \param src0        input matrix, rows x columns
 * \param dst_ref     receives the reference transpose (columns x rows)
 * \param dst_opt     receives the tested transpose (columns x rows)
 * \param rows        number of rows in src0
 * \param columns     number of columns in src0
 * \param print       non-zero to dump the input and both results
 * \param input_bytes forwarded unchanged to print_performance()
 * \param ti_tics     forwarded unchanged to print_performance()
 * \return result of compare_s32(); 0 is reported as "passed"
 */
int test_mat_trans_scalar_s32(int32_t* src0, int32_t* dst_ref, int32_t* dst_opt, int rows, int columns, int print,
                              int32_t input_bytes, int32_t ti_tics) {
  /* Index 0 holds the counters sampled before the call, index 1 after it. */
  uint32_t tic_count[2], instruction_count[2];
  uint32_t ref_tic_count[2], ref_instruction_count[2];

  count_tics(ref_tic_count, ref_instruction_count);
  ref_mat_trans_s32(src0, rows, columns, dst_ref);
  count_tics(&ref_tic_count[1], &ref_instruction_count[1]);

  count_tics(tic_count, instruction_count);
  mat_trans_scalar_s32(src0, rows, columns, dst_opt);
  count_tics(&tic_count[1], &instruction_count[1]);

  if (print) {
    printf("\n%d X %d\n", rows, columns);
    printf("mat1:\n");
    print_matrix_s32(src0, rows, columns);
    /* Colon before the newline, matching the "mat1:" header above. */
    printf("ref_res %d X %d:\n", columns, rows);
    print_matrix_s32(dst_ref, columns, rows);
    printf("dsp_res %d X %d:\n", columns, rows);
    print_matrix_s32(dst_opt, columns, rows);
  }

  int ret = compare_s32(dst_ref, dst_opt, columns * rows);

  print_performance(ref_tic_count, ref_instruction_count, tic_count, instruction_count, input_bytes, ti_tics);

  if (ret == 0) {
    printf(" passed |\n");
  } else {
    printf(" failed |\n");
  }

  return ret;
}

/*!
 * \brief Checks mat_trans_scalar_fl against ref_mat_trans_fl for one float matrix size.
 *
 * Both implementations transpose the same input. Each call is bracketed by
 * count_tics() so print_performance() can report reference and tested counts.
 * The table row is terminated with " passed |" or " failed |".
 *
 * \param src0        input matrix, rows x columns
 * \param dst_ref     receives the reference transpose (columns x rows)
 * \param dst_opt     receives the tested transpose (columns x rows)
 * \param rows        number of rows in src0
 * \param columns     number of columns in src0
 * \param print       non-zero to dump the input and both results
 * \param input_bytes forwarded unchanged to print_performance()
 * \param ti_tics     forwarded unchanged to print_performance()
 * \return result of compare_float(); 0 is reported as "passed"
 */
int test_mat_trans_scalar_fl(float* src0, float* dst_ref, float* dst_opt, int rows, int columns, int print,
                             int32_t input_bytes, int32_t ti_tics) {
  /* Index 0 holds the counters sampled before the call, index 1 after it. */
  uint32_t tic_count[2], instruction_count[2];
  uint32_t ref_tic_count[2], ref_instruction_count[2];

  count_tics(ref_tic_count, ref_instruction_count);
  ref_mat_trans_fl(src0, rows, columns, dst_ref);
  count_tics(&ref_tic_count[1], &ref_instruction_count[1]);

  count_tics(tic_count, instruction_count);
  mat_trans_scalar_fl(src0, rows, columns, dst_opt);
  count_tics(&tic_count[1], &instruction_count[1]);

  if (print) {
    printf("\n%d X %d\n", rows, columns);
    printf("mat1:\n");
    print_matrix_float(src0, rows, columns);
    /* Colon before the newline, matching the "mat1:" header above. */
    printf("ref_res %d X %d:\n", columns, rows);
    print_matrix_float(dst_ref, columns, rows);
    printf("dsp_res %d X %d:\n", columns, rows);
    print_matrix_float(dst_opt, columns, rows);
  }

  int ret = compare_float(dst_ref, dst_opt, columns * rows);

  print_performance(ref_tic_count, ref_instruction_count, tic_count, instruction_count, input_bytes, ti_tics);

  if (ret == 0) {
    printf(" passed |\n");
  } else {
    printf(" failed |\n");
  }

  return ret;
}

/*!
 * \brief Checks mat_trans_scalar_db against ref_mat_trans_db for one double matrix size.
 *
 * Both implementations transpose the same input. Each call is bracketed by
 * count_tics() so print_performance() can report reference and tested counts.
 * The table row is terminated with " passed |" or " failed |".
 *
 * \param src0        input matrix, rows x columns
 * \param dst_ref     receives the reference transpose (columns x rows)
 * \param dst_opt     receives the tested transpose (columns x rows)
 * \param rows        number of rows in src0
 * \param columns     number of columns in src0
 * \param print       non-zero to dump the input and both results
 * \param input_bytes forwarded unchanged to print_performance()
 * \param ti_tics     forwarded unchanged to print_performance()
 * \return result of compare_double(); 0 is reported as "passed"
 */
int test_mat_trans_scalar_db(double* src0, double* dst_ref, double* dst_opt, int rows, int columns, int print,
                             int32_t input_bytes, int32_t ti_tics) {
  /* Index 0 holds the counters sampled before the call, index 1 after it. */
  uint32_t tic_count[2], instruction_count[2];
  uint32_t ref_tic_count[2], ref_instruction_count[2];

  count_tics(ref_tic_count, ref_instruction_count);
  ref_mat_trans_db(src0, rows, columns, dst_ref);
  count_tics(&ref_tic_count[1], &ref_instruction_count[1]);

  count_tics(tic_count, instruction_count);
  mat_trans_scalar_db(src0, rows, columns, dst_opt);
  count_tics(&tic_count[1], &instruction_count[1]);

  if (print) {
    printf("\n%d X %d\n", rows, columns);
    printf("mat1:\n");
    print_matrix_double(src0, rows, columns);
    /* Colon before the newline, matching the "mat1:" header above. */
    printf("ref_res %d X %d:\n", columns, rows);
    print_matrix_double(dst_ref, columns, rows);
    printf("dsp_res %d X %d:\n", columns, rows);
    print_matrix_double(dst_opt, columns, rows);
  }

  int ret = compare_double(dst_ref, dst_opt, columns * rows);

  print_performance(ref_tic_count, ref_instruction_count, tic_count, instruction_count, input_bytes, ti_tics);

  if (ret == 0) {
    printf(" passed |\n");
  } else {
    printf(" failed |\n");
  }

  return ret;
}
