// Copyright 2025 RnD Center "ELVEES", JSC

#include "tests_tile_segmentation.hpp"

#include <cmath>

// Fills dst[0 .. size) with pseudo-random, never-zero test inputs.
//
// Each value is `rand() % INT16_MAX - INT16_MAX / 2`, i.e. in [-(INT16_MAX / 2), INT16_MAX / 2],
// so it fits any signed type at least 16 bits wide. A generated 0 is bumped to 1 (presumably
// because the recip tests cannot take the reciprocal of zero). The sequence depends on the
// current rand() seed; this function does not reseed.
template <class T>
void data_generator(T* dst, size_t size) {
  // size_t index: matches the type of `size` (no signed/unsigned comparison, no int overflow).
  for (size_t i = 0; i < size; ++i) {
    dst[i] = rand() % INT16_MAX - (INT16_MAX / 2);
    if (!dst[i]) ++dst[i];  // never emit zero
  }
}

// Runs one reciprocal test over `size` elements and compares the optimized path to a reference.
//
// Steps:
//   1. Fill `src` with non-zero pseudo-random inputs (data_generator).
//   2. Flush caches and time `reference(src, rfrac_ref, rexp_ref, size)`; print tics/instructions.
//   3. Build a TileSegConfig with CreateTileSegConfigRecip16 (output buffers rfrac_opt/rexp_opt),
//      flush caches and time `run_calc(&config)`; print tics/instructions.
//   4. For every element, rebuild each result as rfrac * 2^rexp and compare the two values.
//
// Parameters:
//   src, rfrac_ref, rexp_ref, rfrac_opt, rexp_opt - buffers of at least `size` elements.
//   create_vector - not used by this test; kept so existing call sites still compile.
//   reference     - callable (src, rfrac, rexp, size) producing the expected result.
//   run_calc      - callable taking TileSegConfig* that runs the optimized calculation.
//   localmem      - forwarded to CreateTileSegConfigRecip16.
//
// Returns true if any element mismatches, false otherwise. An element mismatches when its relative
// error exceeds EPS. Where the reference value is exactly zero, the absolute error is used instead.
// Each mismatching pair is printed, and comparison stops after 11 mismatches.
template <typename Type, class create_func, class ref_func, class run_calc_ptr>
bool test_recip(Type* src, Type* rfrac_ref, Type* rexp_ref, Type* rfrac_opt, Type* rexp_opt, create_func create_vector,
                ref_func reference, run_calc_ptr run_calc, int size, int* localmem) {
  (void)create_vector;  // intentionally unused

  data_generator(src, size);

  FLUSH_ALL_CACHES();
  uint32_t tic_count[2], instruction_count[2];
  count_tics(tic_count, instruction_count);
  reference(src, rfrac_ref, rexp_ref, size);
  count_tics(&tic_count[1], &instruction_count[1]);

  std::cout << "Ref func result (size = " << size << "): tic = " << tic_count[1] - tic_count[0]
            << " instr = " << instruction_count[1] - instruction_count[0] << std::endl;

  TileSegConfig config;
  CreateTileSegConfigRecip16(src, rfrac_opt, rexp_opt, size, &config, localmem);
  FLUSH_ALL_CACHES();

  count_tics(tic_count, instruction_count);
  run_calc(&config);
  count_tics(&tic_count[1], &instruction_count[1]);

  std::cout << "Opt func result (size = " << size << "): tic = " << tic_count[1] - tic_count[0]
            << " instr = " << instruction_count[1] - instruction_count[0] << std::endl;

  int mismatches = 0;
  for (int i = 0; i < size; ++i) {
    // Results are (fraction, exponent) pairs; ldexp scales by 2^exp exactly.
    const double vect0 = std::ldexp(static_cast<double>(rfrac_ref[i]), static_cast<int>(rexp_ref[i]));
    const double vect1 = std::ldexp(static_cast<double>(rfrac_opt[i]), static_cast<int>(rexp_opt[i]));

    // Relative error is undefined for a zero reference; use the absolute error there instead of
    // dividing by zero (which produced NaN/inf before).
    const double diff = std::fabs(vect0 - vect1);
    const double err = (vect0 != 0.) ? diff / std::fabs(vect0) : diff;
    const bool mismatch = err > EPS;

    // Zero reference values are always printed as a diagnostic; each pair is printed at most once.
    if (vect0 == 0. || mismatch) std::cout << vect0 << " " << vect1 << std::endl;
    if (mismatch && ++mismatches > 10) break;
  }

  return mismatches != 0;
}

// Entry point: runs the recip16 tile-segmentation test for sizes 1, 2, 4, ... up to SIZE.
//
// The callable passed as test_recip's `reference` argument is ref_recip16 when USE_REF_VER is
// defined, recip16 otherwise; the optimized path is always RunCalculationRecip16.
//
// Returns 0 if every size passed, non-zero if any size failed or a buffer allocation failed.
int main() {
  // NOTE(review): the L2 cache is disabled for the whole run, presumably so the tic counts are not
  // skewed by cached data - confirm. It is re-enabled on every exit path below.
  disable_l2_cache();

  // Buffers hold SIZE int64_t-sized slots (more than the int16_t data needs), 64-byte aligned.
  void* src0 = memalign(64, SIZE * sizeof(int64_t));
  void* rfrac_ref = memalign(64, SIZE * sizeof(int64_t));
  void* rexp_ref = memalign(64, SIZE * sizeof(int64_t));
  void* rfrac_opt = memalign(64, SIZE * sizeof(int64_t));
  void* rexp_opt = memalign(64, SIZE * sizeof(int64_t));
  int test_status = 0;

  if (!src0 || !rfrac_ref || !rexp_ref || !rfrac_opt || !rexp_opt) {
    // Writing through a NULL buffer would crash the test; report and fail instead.
    std::cout << "memory allocation failed" << std::endl;
    test_status = 1;
  } else {
    int16_t* const src_s16 = static_cast<int16_t*>(src0);
    int16_t* const rfrac_ref_s16 = static_cast<int16_t*>(rfrac_ref);
    int16_t* const rexp_ref_s16 = static_cast<int16_t*>(rexp_ref);
    int16_t* const rfrac_opt_s16 = static_cast<int16_t*>(rfrac_opt);
    int16_t* const rexp_opt_s16 = static_cast<int16_t*>(rexp_opt);

    for (int i = 1; i <= SIZE; i *= 2) {
      std::cout << "recip_s16" << std::endl;
#ifdef USE_REF_VER
      const int ret = test_recip(src_s16, rfrac_ref_s16, rexp_ref_s16, rfrac_opt_s16, rexp_opt_s16, create_vector_s16,
                                 ref_recip16, RunCalculationRecip16, i, &__local_mem);
#else
      const int ret = test_recip(src_s16, rfrac_ref_s16, rexp_ref_s16, rfrac_opt_s16, rexp_opt_s16, create_vector_s16,
                                 recip16, RunCalculationRecip16, i, &__local_mem);
#endif

      test_status |= ret;
      if (ret) std::cout << "recip16 error!\n";

      // Cumulative verdict: once any size has failed, every later size also reports failure.
      if (!test_status)
        std::cout << "Test passed" << std::endl;
      else
        std::cout << "Test failed" << std::endl;
    }
  }

  // free(NULL) is a no-op, so this is safe on the allocation-failure path too.
  free(src0);
  free(rfrac_ref);
  free(rexp_ref);
  free(rfrac_opt);
  free(rexp_opt);

  enable_l2_cache(L2_CACHE_SIZE);

  return test_status;
}
