/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#define FBGEMM_EXPORTS
#include <algorithm>
#include <cmath>
#include <iterator>
#include <numeric>
#include <type_traits>

#include "fbgemm/QuantUtils.h"

#include <cpuinfo.h>

#include "fbgemm/Fbgemm.h"

#include "fbgemm/Types.h"

namespace fbgemm {

using namespace std;

// Lower bound applied to the quantization scale chosen by
// ChooseQuantizationParams.
// Use fp16_min as the small scale cutoff because we don't want to use scales in
// the fp16 subnormal range. This is to be consistent with the Glow and FakeLowP
// implementations for NNPI.
constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f;

// Returns the real value that quantized code 0 dequantizes to under these
// parameters.
float TensorQuantizationParams::Min() const {
  constexpr int kLowestCode = 0;
  return Dequantize(kLowestCode, *this);
}

// Returns the real value that the largest `precision`-bit code,
// 2^precision - 1, dequantizes to under these parameters.
// NOTE(review): the shift is done on a plain int, so this presumably expects
// precision < 31 -- confirm against callers.
float TensorQuantizationParams::Max() const {
  const int highest_code = (1 << precision) - 1;
  return Dequantize(highest_code, *this);
}

// Chooses the affine quantization parameters (scale, zero_point) that map the
// real interval [min, max] onto the integer interval [qmin, qmax].
//
//   min, max                  real range to cover; it is widened to contain 0
//                             so that 0 is exactly representable.
//   qmin, qmax                bounds of the quantized integer range.
//   preserve_sparsity         when the range straddles 0, make it symmetric
//                             and put the zero point midway between qmin and
//                             qmax.
//   force_scale_power_of_two  round the scale up to a power of two.
//
// Only `scale` and `zero_point` of the returned struct are assigned here.
TensorQuantizationParams ChooseQuantizationParams(
    float min,
    float max,
    int32_t qmin,
    int32_t qmax,
    bool preserve_sparsity,
    bool force_scale_power_of_two) {
  if (min < 0 && max > 0 && preserve_sparsity) {
    // Stretch [min, max] so that it is symmetric with respect to the
    // quantized range centered on zero.
    int symmetric_qmin = -((qmax - qmin) / 2 + 1);
    int symmetric_qmax = (qmax - qmin) / 2;
    double max_scale =
        std::max(fabs(min / symmetric_qmin), fabs(max / symmetric_qmax));
    min = max_scale * symmetric_qmin;
    max = max_scale * symmetric_qmax;
  }

  // We extend the [min, max] interval to ensure that it contains 0.
  // Otherwise, we would not meet the requirement that 0 be an exactly
  // representable value.
  min = std::min(min, 0.f);
  max = std::max(max, 0.f);

  // Use double precision for intermediate computation but use single precision
  // in final number to reflect the actual number used during quantization.
  float scale = (static_cast<double>(max) - min) / (qmax - qmin);
  // If scale is 0 or too small so its reciprocal is infinity, we arbitrarily
  // adjust the scale to 0.1 . We want to avoid scale's reciprocal being
  // infinity because some of fbgemm code pre-computes scale's reciprocal to do
  // multiplication instead of division in the time critical part of code.
  if (scale == 0.0f || isinf(1.0f / scale)) {
    scale = 0.1;
  }
  assert(scale > 0);

  if (force_scale_power_of_two) {
    // The power of two is built with exp2 instead of an integer shift:
    // `1 << k` on an int is undefined once k reaches 31, which happens for
    // scales below 2^-31 or at/above 2^31. For smaller exponents exp2 yields
    // exactly the same value as the shift did.
    if (scale < 1) {
      const int exponent = static_cast<int>(floor(log2(1.0 / scale)));
      scale = std::exp2(static_cast<double>(-exponent));
    } else {
      const int exponent = static_cast<int>(ceil(log2(scale)));
      scale = std::exp2(static_cast<double>(exponent));
    }
  }

  // Cut off small scale
  if (scale < SMALL_SCALE_THRESHOLD) {
    float org_scale = scale;
    scale = SMALL_SCALE_THRESHOLD;
    // Adjust the min and max based on the new scale
    if (min == 0.0f) {
      max = SMALL_SCALE_THRESHOLD * (qmax - qmin);
    } else if (max == 0.0f) {
      min = -SMALL_SCALE_THRESHOLD * (qmax - qmin);
    } else {
      // Both ends are non-zero: grow them by the same factor the scale grew.
      float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
      min *= amplifier;
      max *= amplifier;
    }
  }

  // Zero-point computation.
  // First the initial floating-point computation. The zero-point can be
  // determined from solving an affine equation for any known pair
  // (real value, corresponding quantized value).
  // We know two such pairs: (rmin, qmin) and (rmax, qmax).
  // The arithmetic error on the zero point computed from either pair
  // will be roughly machine_epsilon * (sum of absolute values of terms)
  // so we want to use the variant that adds the smaller terms.
  double zero_point_from_min = qmin - min / static_cast<double>(scale);
  double zero_point_from_max = qmax - max / static_cast<double>(scale);
  double zero_point_from_min_error =
      std::abs(qmin) + std::abs(min / static_cast<double>(scale));
  double zero_point_from_max_error =
      std::abs(qmax) + std::abs(max / static_cast<double>(scale));
  double initial_zero_point =
      zero_point_from_min_error < zero_point_from_max_error
      ? zero_point_from_min
      : zero_point_from_max;

  // Note: preserve_sparsity here means symmetric quantization.
  // for symmetric quantization, we force zero_point
  // to be a middle value between qmin and qmax.
  // If either min or max is 0, then we just use 0 as zero_point.
  if (min < 0 && max > 0 && preserve_sparsity) {
    initial_zero_point = static_cast<double>(qmin + qmax) / 2;
  }

  // Now we need to nudge the zero point to be an integer
  // (our zero points are integer, and this is motivated by the requirement
  // to be able to represent the real value "0" exactly as a quantized value,
  // which is required in multiple places, for example in Im2col with zero
  // padding).
  int32_t nudged_zero_point = 0;
  if (initial_zero_point < qmin) {
    nudged_zero_point = qmin;
  } else if (initial_zero_point > qmax) {
    nudged_zero_point = qmax;
  } else {
    nudged_zero_point = nearbyint(initial_zero_point);
  }

  TensorQuantizationParams result;
  result.scale = scale;
  result.zero_point = nudged_zero_point;
  return result;
}

// Splits a real multiplier into an integer multiplier and a right-shift
// amount, written to *quantized_multiplier and *right_shift.
//
// The multiplier is first scaled by powers of two until it no longer lies
// below 1/2 or above 1, with the shift count adjusted to compensate. It is
// then rounded to a fixed-point integer with
// (requantization_multiplier_precision - 1) fractional bits.
//
// NOTE(review): the starting shift is hard-coded to 31 while the fixed-point
// width comes from requantization_multiplier_precision; presumably callers
// pass 32 -- confirm.
void ChooseRequantizationMultiplier(
    float real_multiplier,
    int32_t* quantized_multiplier,
    int* right_shift,
    int requantization_multiplier_precision) {
  assert(real_multiplier != 0.f);

  // A multiplier already inside [1/2, 1) keeps the default shift of 31: a
  // 32-bit signed integer using all 31 non-sign bits, followed by a 31-bit
  // right shift, multiplies by a real number in that interval.
  int shift = 31;

  // Each doubling is compensated by one more bit of right shift, each halving
  // by one bit less.
  if (real_multiplier > 0.f) {
    for (; real_multiplier < 0.5f; ++shift) {
      real_multiplier *= 2.f;
    }
    for (; real_multiplier > 1.f; --shift) {
      real_multiplier /= 2.f;
    }
  }

  // Fixed-point representation of 1.0 at the requested precision.
  const long long fixed_point_one =
      1ll << (requantization_multiplier_precision - 1);
  int64_t mantissa = std::nearbyint(real_multiplier * fixed_point_one);
  assert(mantissa <= fixed_point_one);

  // A multiplier so close to 1 that it rounded to exactly 1.0: halve it and
  // drop one bit of shift so the represented value is unchanged.
  if (mantissa == fixed_point_one) {
    mantissa /= 2;
    --shift;
  }

  assert(shift >= 0);
  assert(mantissa >= 0);
  assert(mantissa <= std::numeric_limits<int32_t>::max());
  *quantized_multiplier = static_cast<int32_t>(mantissa);
  *right_shift = shift;
  assert(shift < 64);
}

////////////////////////////////////////////////////////////////////////////////
// Utility functions

// Generates the array form of Quantize<T, LEGACY>. fbgemmPartition1D is asked
// for the [lo, hi) index range belonging to (thread_id, num_threads) out of
// len elements, and each element in that range is quantized with the scalar
// Quantize<T, LEGACY>.
#define FBGEMM_SPECIALIZED_QUANTIZE(T, LEGACY)                      \
  template <>                                                       \
  FBGEMM_API void Quantize<T, LEGACY>(                              \
      const float* src,                                             \
      T* dst,                                                       \
      const int64_t len,                                            \
      const TensorQuantizationParams& qparams,                      \
      int thread_id,                                                \
      int num_threads) {                                            \
    int64_t lo, hi;                                                 \
    fbgemmPartition1D(thread_id, num_threads, len, lo, hi);         \
    int64_t pos = lo;                                               \
    while (pos < hi) {                                              \
      dst[pos] = Quantize<T, LEGACY>(src[pos], qparams);            \
      ++pos;                                                        \
    }                                                               \
  }

FBGEMM_SPECIALIZED_QUANTIZE(uint16_t, true)
FBGEMM_SPECIALIZED_QUANTIZE(int16_t, true)
FBGEMM_SPECIALIZED_QUANTIZE(int32_t, true)
FBGEMM_SPECIALIZED_QUANTIZE(uint16_t, false)
FBGEMM_SPECIALIZED_QUANTIZE(int16_t, false)
FBGEMM_SPECIALIZED_QUANTIZE(int32_t, false)

// Generates the array form of Quantize<T, LEGACY> with a vectorized path.
// QuantizeAvx2 handles this thread's [lo, hi) range when cpuinfo reports AVX2
// and FMA3 and the target precision is 8 bits; otherwise every element goes
// through the scalar Quantize<T, LEGACY>.
#define FBGEMM_SPECIALIZED_QUANTIZE_AVX2(T, LEGACY)                     \
  template <>                                                           \
  FBGEMM_API void Quantize<T, LEGACY>(                                  \
      const float* src,                                                 \
      T* dst,                                                           \
      int64_t len,                                                      \
      const TensorQuantizationParams& qparams,                          \
      int thread_id,                                                    \
      int num_threads) {                                                \
    bool has_avx2 = cpuinfo_initialize() && fbgemmHasAvx2Support();     \
    bool has_fma3 = cpuinfo_has_x86_fma3();                             \
    int64_t lo, hi;                                                     \
    fbgemmPartition1D(thread_id, num_threads, len, lo, hi);             \
    if (!(has_avx2 && has_fma3 && qparams.precision == 8)) {            \
      /* scalar fallback */                                             \
      for (int64_t k = lo; k < hi; ++k) {                               \
        dst[k] = Quantize<T, LEGACY>(src[k], qparams);                  \
      }                                                                 \
      return;                                                           \
    }                                                                   \
    /* fast path  */                                                    \
    QuantizeAvx2<T, LEGACY>(src + lo, dst + lo, hi - lo, qparams);      \
  }

// 8-bit types get the AVX2-capable definition on x86 builds and the plain
// scalar definition elsewhere.
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
FBGEMM_SPECIALIZED_QUANTIZE_AVX2(int8_t, true)
FBGEMM_SPECIALIZED_QUANTIZE_AVX2(uint8_t, true)
FBGEMM_SPECIALIZED_QUANTIZE_AVX2(int8_t, false)
FBGEMM_SPECIALIZED_QUANTIZE_AVX2(uint8_t, false)
#else
FBGEMM_SPECIALIZED_QUANTIZE(int8_t, true)
FBGEMM_SPECIALIZED_QUANTIZE(uint8_t, true)
FBGEMM_SPECIALIZED_QUANTIZE(int8_t, false)
FBGEMM_SPECIALIZED_QUANTIZE(uint8_t, false)
#endif

#undef FBGEMM_SPECIALIZED_QUANTIZE
#undef FBGEMM_SPECIALIZED_QUANTIZE_AVX2

// Generates the scalar-only array form of FusedQuantizeDequantize<T>: every
// element of the [lo, hi) range that fbgemmPartition1D returns for
// (thread_id, num_threads) is passed through the scalar
// FusedQuantizeDequantize<T>. noise_ratio is accepted but not used.
#define FBGEMM_SPECIALIZED_FUSED_QUANTIZE_DEQUANTIZE(T)             \
  template <>                                                       \
  FBGEMM_API void FusedQuantizeDequantize<T>(                       \
      const float* src,                                             \
      float* dst,                                                   \
      int64_t len,                                                  \
      const TensorQuantizationParams& qparams,                      \
      int thread_id,                                                \
      int num_threads,                                              \
      [[maybe_unused]] float noise_ratio) {                         \
    int64_t lo, hi;                                                 \
    fbgemmPartition1D(thread_id, num_threads, len, lo, hi);         \
    int64_t pos = lo;                                               \
    while (pos < hi) {                                              \
      dst[pos] = FusedQuantizeDequantize<T>(src[pos], qparams);     \
      ++pos;                                                        \
    }                                                               \
  }

// Generates the array form of FusedQuantizeDequantize<T> with a vectorized
// path. Three outcomes for this thread's [lo, hi) range:
//   - AVX2 + FMA3 available and precision == 8: FusedQuantizeDequantizeAvx2.
//   - otherwise, noise_ratio <= 0: scalar FusedQuantizeDequantize<T> loop.
//   - otherwise: std::runtime_error is thrown.
// NOTE(review): noise_ratio is not forwarded to the AVX2 call here, and the
// exception text mentions cpuinfo although this branch is reached whenever the
// fast path is unavailable with a positive (or NaN) noise_ratio -- confirm
// both are intended.
#define FBGEMM_SPECIALIZED_FUSED_QUANTIZE_DEQUANTIZE_AVX2(T)            \
  template <>                                                           \
  FBGEMM_API void FusedQuantizeDequantize<T>(                           \
      const float* src,                                                 \
      float* dst,                                                       \
      int64_t len,                                                      \
      const TensorQuantizationParams& qparams,                          \
      int thread_id,                                                    \
      int num_threads,                                                  \
      float noise_ratio) {                                              \
    bool has_avx2 = cpuinfo_initialize() && fbgemmHasAvx2Support();     \
    bool has_fma3 = cpuinfo_has_x86_fma3();                             \
    int64_t lo, hi;                                                     \
    fbgemmPartition1D(thread_id, num_threads, len, lo, hi);             \
    if (has_avx2 && has_fma3 && qparams.precision == 8) {               \
      /* fast path  */                                                  \
      FusedQuantizeDequantizeAvx2<T>(                                   \
          src + lo, dst + lo, hi - lo, qparams);                        \
      return;                                                           \
    }                                                                   \
    if (!(noise_ratio <= 0.0f)) {                                       \
      throw std::runtime_error("Failed to initialize cpuinfo!");        \
    }                                                                   \
    for (int64_t k = lo; k < hi; ++k) {                                 \
      dst[k] = FusedQuantizeDequantize<T>(src[k], qparams);             \
    }                                                                   \
  }

// x86 builds use the AVX2-capable definition; other targets use the scalar
// one.
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
FBGEMM_SPECIALIZED_FUSED_QUANTIZE_DEQUANTIZE_AVX2(int8_t)
FBGEMM_SPECIALIZED_FUSED_QUANTIZE_DEQUANTIZE_AVX2(uint8_t)
#else
FBGEMM_SPECIALIZED_FUSED_QUANTIZE_DEQUANTIZE(int8_t)
FBGEMM_SPECIALIZED_FUSED_QUANTIZE_DEQUANTIZE(uint8_t)
#endif
#undef FBGEMM_SPECIALIZED_FUSED_QUANTIZE_DEQUANTIZE
#undef FBGEMM_SPECIALIZED_FUSED_QUANTIZE_DEQUANTIZE_AVX2

// Generates QuantizeGroupwise<T, layout_t::KCX>. Elements are addressed as
// (n * C + channel) * X + x; the C channels are split into G consecutive
// groups of C / G channels, and group g is quantized with scales[g] and
// zero_points[g] at a precision of 8 * sizeof(T) bits.
#define FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX(T)                       \
  template <>                                                            \
  FBGEMM_API void QuantizeGroupwise<T, layout_t::KCX>(                   \
      const float* src,                                                  \
      int N,                                                             \
      int C,                                                             \
      int X,                                                             \
      int G,                                                             \
      const float* scales,                                               \
      const std::int32_t* zero_points,                                   \
      T* dst) {                                                          \
    assert(C % G == 0);                                                  \
    const int C_per_G = C / G;                                           \
    for (int64_t n = 0; n < N; ++n) {                                    \
      for (int64_t grp = 0; grp < G; ++grp) {                            \
        const float grp_scale = scales[grp];                             \
        const int32_t grp_zero_point = zero_points[grp];                 \
        /* Index of the first element of group grp inside slice n. */    \
        const int64_t grp_base = (n * C + grp * C_per_G) * X;            \
        for (int64_t ch = 0; ch < C_per_G; ++ch) {                       \
          for (int64_t x = 0; x < X; ++x) {                              \
            const int64_t at = grp_base + ch * X + x;                    \
            dst[at] = Quantize<T>(                                       \
                src[at], grp_zero_point, grp_scale, 8 * sizeof(T));      \
          }                                                              \
        }                                                                \
      }                                                                  \
    }                                                                    \
  }

FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX(int8_t)
FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX(int32_t)
#undef FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX

// uint8_t specialization of the KCX group-wise quantizer. Elements are
// addressed as (k * C + channel) * X + x, and channel group g uses scales[g]
// and zero_points[g]. When cpuinfo reports AVX2 and FMA3, each group's
// contiguous run of (C / G) * X elements is handed to QuantizeAvx2; otherwise
// the elements are quantized one by one.
template <>
FBGEMM_API void QuantizeGroupwise<uint8_t, layout_t::KCX>(
    const float* src,
    int K,
    int C,
    int X,
    int G,
    const float* scales,
    const std::int32_t* zero_points,
    uint8_t* dst) {
  assert(C % G == 0);
  const int C_per_G = C / G;

  fbgemm::TensorQuantizationParams qparams;
  qparams.precision = 8 * sizeof(uint8_t);

  const bool takeFastPath =
      cpuinfo_initialize() && fbgemmHasAvx2Support() && cpuinfo_has_x86_fma3();

  for (int64_t k = 0; k < K; ++k) {
    for (int64_t grp = 0; grp < G; ++grp) {
      // Per-group parameters; precision stays fixed at 8 bits.
      qparams.scale = scales[grp];
      qparams.zero_point = zero_points[grp];

      if (!takeFastPath) {
        for (int64_t ch = 0; ch < C_per_G; ++ch) {
          const int64_t ch_base = (k * C + grp * C_per_G + ch) * X;
          for (int64_t x = 0; x < X; ++x) {
            const int64_t at = ch_base + x;
            dst[at] = Quantize<uint8_t>(
                src[at], qparams.zero_point, qparams.scale, qparams.precision);
          }
        }
        continue;
      }
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
      const int64_t grp_base = (k * C + grp * C_per_G) * X;
      QuantizeAvx2(
          src + grp_base,
          dst + grp_base,
          static_cast<int64_t>(C_per_G) * X,
          qparams);
#endif
    }
  }
}

// Generates QuantizeGroupwise<T, layout_t::KXC>. Elements are addressed as
// (k * X + x) * C + channel; the C channels are split into G consecutive
// groups of C / G channels, and group g is quantized with scales[g] and
// zero_points[g] at a precision of 8 * sizeof(T) bits.
#define FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC(T)                       \
  template <>                                                            \
  FBGEMM_API void QuantizeGroupwise<T, layout_t::KXC>(                   \
      const float* src,                                                  \
      int K,                                                             \
      int C,                                                             \
      int X,                                                             \
      int G,                                                             \
      const float* scales,                                               \
      const std::int32_t* zero_points,                                   \
      T* dst) {                                                          \
    assert(C % G == 0);                                                  \
    const int C_per_G = C / G;                                           \
    for (int64_t k = 0; k < K; ++k) {                                    \
      for (int64_t x = 0; x < X; ++x) {                                  \
        /* Index of channel 0 at position (k, x). */                     \
        const int64_t pixel_base = (k * X + x) * C;                      \
        for (int64_t grp = 0; grp < G; ++grp) {                          \
          const float grp_scale = scales[grp];                           \
          const int32_t grp_zero_point = zero_points[grp];               \
          for (int64_t ch = 0; ch < C_per_G; ++ch) {                     \
            const int64_t at = pixel_base + grp * C_per_G + ch;          \
            dst[at] = Quantize<T>(                                       \
                src[at], grp_zero_point, grp_scale, 8 * sizeof(T));      \
          }                                                              \
        }                                                                \
      }                                                                  \
    }                                                                    \
  }
FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC(int8_t)
FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC(uint8_t)
FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC(int32_t)
#undef FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC

////////////////////////////////////////////////////////////////////////////////
// Requantization (pure fixed-point)

// Multiplies a and b in 64-bit arithmetic and shifts the product right by
// right_shift bits, rounding to nearest by adding half of the discarded range
// (2^(right_shift - 1)) before the shift.
//
//   a, b         32-bit factors; their product always fits in int64_t.
//   right_shift  number of bits to shift, expected in [0, 64).
//
// Returns the rounded, shifted product. Despite the name, no saturation is
// applied here; the full 64-bit result is returned.
int64_t SaturatingRoundingMulWithShift(int32_t a, int32_t b, int right_shift) {
  assert(right_shift >= 0 && right_shift < 64);
  const int64_t product = static_cast<int64_t>(a) * static_cast<int64_t>(b);

  // With no shift there is nothing to round. Handling it separately also
  // avoids computing `1 << -1` for the nudge, which is undefined behavior.
  if (right_shift == 0) {
    return product;
  }

  const int64_t nudge = 1ll << (right_shift - 1);
  return (product + nudge) >> right_shift;
}

// Generates the array form of Requantize<T>: each element of the [lo, hi)
// range that fbgemmPartition1D returns for (thread_id, num_threads) is
// converted with the scalar Requantize<T>.
#define FBGEMM_SPECIALIZED_REQUANTIZE(T)                            \
  template <>                                                       \
  FBGEMM_API void Requantize<T>(                                    \
      const int32_t* src,                                           \
      T* dst,                                                       \
      const int64_t len,                                            \
      const RequantizationParams& params,                           \
      int thread_id,                                                \
      int num_threads) {                                            \
    int64_t lo, hi;                                                 \
    fbgemmPartition1D(thread_id, num_threads, len, lo, hi);         \
    int64_t pos = lo;                                               \
    while (pos < hi) {                                              \
      dst[pos] = Requantize<T>(src[pos], params);                   \
      ++pos;                                                        \
    }                                                               \
  }
FBGEMM_SPECIALIZED_REQUANTIZE(uint16_t)
FBGEMM_SPECIALIZED_REQUANTIZE(int32_t)
#undef FBGEMM_SPECIALIZED_REQUANTIZE

// uint8_t array form of Requantize. This thread's [lo, hi) range goes through
// RequantizeAvx2 when the target precision is 8 bits and cpuinfo reports
// AVX2; otherwise each element is converted with the scalar
// Requantize<uint8_t>.
template <>
FBGEMM_API void Requantize<uint8_t>(
    const int32_t* src,
    uint8_t* dst,
    const int64_t len,
    const RequantizationParams& params,
    int thread_id,
    int num_threads) {
  int64_t lo, hi;
  fbgemmPartition1D(thread_id, num_threads, len, lo, hi);

  const bool use_avx2 = params.target_qparams.precision == 8 &&
      cpuinfo_initialize() && fbgemmHasAvx2Support();
  if (!use_avx2) {
    for (int64_t k = lo; k < hi; ++k) {
      dst[k] = Requantize<uint8_t>(src[k], params);
    }
    return;
  }
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
  RequantizeAvx2(src + lo, dst + lo, hi - lo, params);
#endif
}

// Generic array form of RequantizeFixedPoint. This thread's [lo, hi) range
// goes through RequantizeFixedPointAvx2 when T is uint8_t, the target
// precision is 8 bits and cpuinfo reports AVX2; otherwise each element is
// converted with the scalar RequantizeFixedPoint<T>.
template <typename T>
FBGEMM_API void RequantizeFixedPoint(
    const std::int32_t* src,
    T* dst,
    int64_t len,
    const RequantizationParams& params,
    int thread_id,
    int num_threads) {
  int64_t lo, hi;
  fbgemmPartition1D(thread_id, num_threads, len, lo, hi);

  const bool use_avx2 = std::is_same<T, uint8_t>::value &&
      params.target_qparams.precision == 8 && cpuinfo_initialize() &&
      fbgemmHasAvx2Support();
  if (!use_avx2) {
    for (int64_t k = lo; k < hi; ++k) {
      dst[k] = RequantizeFixedPoint<T>(src[k], params);
    }
    return;
  }
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
  RequantizeFixedPointAvx2(src + lo, dst + lo, hi - lo, params);
#endif
}

// Generates a scalar-only array specialization of RequantizeFixedPoint<T>:
// each element of the [lo, hi) range that fbgemmPartition1D returns for
// (thread_id, num_threads) is converted with the scalar
// RequantizeFixedPoint<T>.
#define FBGEMM_SPECIALIZED_REQUANTIZE(T)                            \
  template <>                                                       \
  FBGEMM_API void RequantizeFixedPoint<T>(                          \
      const int32_t* src,                                           \
      T* dst,                                                       \
      const int64_t len,                                            \
      const RequantizationParams& params,                           \
      int thread_id,                                                \
      int num_threads) {                                            \
    int64_t lo, hi;                                                 \
    fbgemmPartition1D(thread_id, num_threads, len, lo, hi);         \
    int64_t pos = lo;                                               \
    while (pos < hi) {                                              \
      dst[pos] = RequantizeFixedPoint<T>(src[pos], params);         \
      ++pos;                                                        \
    }                                                               \
  }
FBGEMM_SPECIALIZED_REQUANTIZE(uint16_t)
FBGEMM_SPECIALIZED_REQUANTIZE(int32_t)
#undef FBGEMM_SPECIALIZED_REQUANTIZE

// uint8_t array form of RequantizeFixedPoint. This thread's [lo, hi) range
// goes through RequantizeFixedPointAvx2 when the target precision is 8 bits
// and cpuinfo reports AVX2; otherwise each element is converted with the
// scalar RequantizeFixedPoint<uint8_t>.
template <>
FBGEMM_API void RequantizeFixedPoint<uint8_t>(
    const int32_t* src,
    uint8_t* dst,
    const int64_t len,
    const RequantizationParams& params,
    int thread_id,
    int num_threads) {
  int64_t lo, hi;
  fbgemmPartition1D(thread_id, num_threads, len, lo, hi);

  const bool use_avx2 = params.target_qparams.precision == 8 &&
      cpuinfo_initialize() && fbgemmHasAvx2Support();
  if (!use_avx2) {
    for (int64_t k = lo; k < hi; ++k) {
      dst[k] = RequantizeFixedPoint<uint8_t>(src[k], params);
    }
    return;
  }
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
  RequantizeFixedPointAvx2(src + lo, dst + lo, hi - lo, params);
#endif
}

template <typename InputType>
void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef(
    int bit_rate,
    const InputType* input,
    size_t input_rows,
    int input_columns,
    std::uint8_t* output) {
  if (input_rows == 0 || input_columns == 0) {
    return;
  }

  static_assert(
      std::is_same<InputType, float>() || std::is_same<InputType, float16>(),
      "Only float and float16 types are allowed.");
  int num_elem_per_byte = 8 / bit_rate;
  const int output_columns =
      (static_cast<int64_t>(input_columns) + num_elem_per_byte - 1) /
          num_elem_per_byte +
      2 * sizeof(float16);
  std::vector<float> input_row_float(input_columns);
  for (size_t row = 0; row < input_rows; ++row) {
    const InputType* input_row = input + row * input_columns;
    std::uint8_t* output_row = output + row * output_columns;
    float16* output_row_scale_bias = reinterpret_cast<float16*>(
        output_row +
        (static_cast<int64_t>(input_columns) + num_elem_per_byte - 1) /
            num_elem_per_byte);

    // NOTE: this can be optimized, however we don't care much about performance
    // for reference implementation.
    for (int col = 0; col < input_columns; ++col) {
      if (std::is_same<InputType, float>()) {
        input_row_float[col] = input_row[col];
      } else {
        input_row_float[col] = cpu_half2float(input_row[col]);
      }
    }

    float minimum_element =
        *std::min_element(input_row_float.begin(), input_row_float.end());
    float maximum_element =
        *std::max_element(input_row_float.begin(), input_row_float.end());
    // Truncate since bias will be represented by fp16. Keep higher precision
    // max untouched.
    float16 minimum_element_fp16 = cpu_float2half_rn(minimum_element);
    minimum_element = cpu_half2float(minimum_element_fp16);
    const float range = maximum_element - minimum_element;

    float scale = range == 0 ? 1.0f : range / ((1 << bit_rate) - 1);
    float16 scale_fp16 = cpu_float2half_rn(scale);
    scale = cpu_half2float(scale_fp16);
    if (scale == 0) {
      // Corner case handling when maximum_element == minimum_element
      // Any scale would work because X - minimum_element will be 0 for all X
      scale = 1.0f;
    }
    float inverse_scale = 1.0f / scale;
    if (std::isinf(inverse_scale)) {
      scale = 1.0f;
      inverse_scale = 1.0f;
    }

    output_row_scale_bias[0] = cpu_float2half_rn(scale);
    output_row_scale_bias[1] = minimum_element_fp16;
    for (int col = 0; col < input_columns; ++col) {
      float X = input_row_float[col];
      std::uint8_t quantized = std::max(
          0,
          std::min<int>(
              std::lrintf((X - minimum_element) * inverse_scale),
              (1 << bit_rate) - 1));
      if (col % num_elem_per_byte == 0) {
        output_row[col / num_elem_per_byte] = quantized;
      } else {
        output_row[col / num_elem_per_byte] |=
            (quantized << ((col % num_elem_per_byte) * bit_rate));
      }
    }
  } // for each row
}

// Row-wise N-bit quantization entry point.
//
// Throws std::runtime_error unless input_columns is a multiple of the number
// of elements per byte (8 / bit_rate). When cpuinfo reports AVX2, bit rates 2,
// 4 and 8 use the matching AVX2 kernel and any other bit rate uses the
// reference implementation; without AVX2 the reference implementation is
// always used.
template <typename InputType>
void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
    int bit_rate,
    const InputType* input,
    size_t input_rows,
    int input_columns,
    std::uint8_t* output) {
  // Currently we can only dequantize if the number of input columns
  // is a multiple of the number of elements per byte.
  const int num_elem_per_byte = 8 / bit_rate;
  if (input_columns % num_elem_per_byte != 0) {
    throw std::runtime_error("Unsupported number of columns");
  }

  const bool use_avx2 = cpuinfo_initialize() && fbgemmHasAvx2Support();
  if (!use_avx2) {
    FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef<InputType>(
        bit_rate, input, input_rows, input_columns, output);
    return;
  }

#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
  switch (bit_rate) {
    case 2:
      FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2<InputType, 2>(
          input, input_rows, input_columns, output);
      return;
    case 4:
      FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2<InputType, 4>(
          input, input_rows, input_columns, output);
      return;
    case 8:
      FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2<InputType, 8>(
          input, input_rows, input_columns, output);
      return;
    default:
      FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef<InputType>(
          bit_rate, input, input_rows, input_columns, output);
      return;
  }
#endif
}

// Reference row-wise 8-bit quantizer.
//
// Each input row of input_columns float/float16 values becomes one output row
// of input_columns byte codes followed by a float scale (range / 255) and a
// float bias (the row minimum). A code is round((x - bias) * 255 /
// (range + epsilon)).
template <typename InputType>
void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef(
    const InputType* input,
    size_t input_rows,
    int input_columns,
    std::uint8_t* output) {
  // Keeps the inverse scale finite when every value in a row is equal.
  constexpr float kEpsilon = 1e-8f;

  // Nothing to write for an empty matrix.
  if (input_rows == 0 || input_columns == 0) {
    return;
  }

  const int64_t output_columns =
      static_cast<int64_t>(input_columns) + 2 * sizeof(float);

  // Scratch buffer holding the current row converted to float.
  std::vector<float> row_values(input_columns);

  for (size_t row = 0; row < input_rows; ++row) {
    const InputType* in_row = input + row * input_columns;
    std::uint8_t* out_row = output + row * output_columns;
    float* scale_bias = reinterpret_cast<float*>(out_row + input_columns);

    for (int col = 0; col < input_columns; ++col) {
      row_values[col] = std::is_same<InputType, float>()
          ? static_cast<float>(in_row[col])
          : cpu_half2float(in_row[col]);
    }

    const float row_min =
        *std::min_element(row_values.begin(), row_values.end());
    const float row_max =
        *std::max_element(row_values.begin(), row_values.end());
    const float range = row_max - row_min;

    scale_bias[0] = range / 255.0f;
    scale_bias[1] = row_min;

    const float inverse_scale = 255.0f / (range + kEpsilon);
    for (int64_t col = 0; col < input_columns; ++col) {
      out_row[col] = std::lrintf((row_values[col] - row_min) * inverse_scale);
    }
  } // for each row
}

// Row-wise 8-bit quantization entry point: uses the AVX2 kernel when cpuinfo
// reports AVX2 support and the reference implementation otherwise.
template <typename InputType>
void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat(
    const InputType* input,
    size_t input_rows,
    int input_columns,
    std::uint8_t* output) {
  const bool use_avx2 = cpuinfo_initialize() && fbgemmHasAvx2Support();
  if (!use_avx2) {
    FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef<InputType>(
        input, input_rows, input_columns, output);
    return;
  }
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
  FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2<InputType>(
      input, input_rows, input_columns, output);
#endif
}

template <typename OutputType>
void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
    int bit_rate,
    const uint8_t* input,
    size_t input_rows,
    int input_columns,
    OutputType* output) {
  static_assert(
      std::is_same<OutputType, float>() || std::is_same<OutputType, float16>(),
      "Only float and float16 types are allowed.");
  int num_elem_per_byte = 8 / bit_rate;
  const int64_t output_columns =
      static_cast<int64_t>(input_columns - 2 * sizeof(float16)) *
      num_elem_per_byte;

  for (size_t row = 0; row < input_rows; ++row) {
    const std::uint8_t* input_row = input + row * input_columns;
    const float16* input_row_scale_bias = reinterpret_cast<const float16*>(
        input_row +
        (output_columns + num_elem_per_byte - 1) / num_elem_per_byte);
    float scale = cpu_half2float(input_row_scale_bias[0]);
    float bias = cpu_half2float(input_row_scale_bias[1]);
    OutputType* output_row = output + row * output_columns;

    for (int64_t col = 0; col < output_columns; ++col) {
      std::uint8_t quantized = input_row[col / num_elem_per_byte];
      quantized >>= (col % num_elem_per_byte) * bit_rate;
      quantized &= (1 << bit_rate) - 1;
      float output_value = scale * quantized + bias;
      if (std::is_same<OutputType, float>()) {
        output_row[col] = output_value;
      } else {
        output_row[col] = cpu_float2half_rn(output_value);
      }
    }
  }
}

// Row-wise N-bit dequantization entry point. When cpuinfo reports AVX2, bit
// rates 2, 4 and 8 use the matching AVX2 kernel and any other bit rate uses
// the reference implementation; without AVX2 the reference implementation is
// always used.
template <typename OutputType>
void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
    int bit_rate,
    const uint8_t* input,
    size_t input_rows,
    int input_columns,
    OutputType* output) {
  const bool use_avx2 = cpuinfo_initialize() && fbgemmHasAvx2Support();
  if (!use_avx2) {
    FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<OutputType>(
        bit_rate, input, input_rows, input_columns, output);
    return;
  }

#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
  switch (bit_rate) {
    case 2:
      FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2<OutputType, 2>(
          input, input_rows, input_columns, output);
      return;
    case 4:
      FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2<OutputType, 4>(
          input, input_rows, input_columns, output);
      return;
    case 8:
      FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2<OutputType, 8>(
          input, input_rows, input_columns, output);
      return;
    default:
      FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<OutputType>(
          bit_rate, input, input_rows, input_columns, output);
      return;
  }
#endif
}

/**
 * Scalar reference dequantization of fused 8-bit rowwise-quantized rows.
 *
 * Every input row is input_columns bytes long: one quantized byte per output
 * value, followed by two floats that are read as the row's scale and bias
 * (in that order). Each byte q is emitted as q * scale + bias, stored
 * directly for float output or converted with cpu_float2half_rn otherwise.
 *
 * @param input quantized rows, input_columns bytes apart
 * @param input_rows number of rows to convert
 * @param input_columns bytes per input row, including the scale/bias tail
 * @param output destination; receives input_columns - 2 * sizeof(float)
 *               values per row, stored contiguously
 */
template <typename OutputType>
void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef(
    const std::uint8_t* input,
    size_t input_rows,
    int input_columns,
    OutputType* output) {
  constexpr bool is_float_output = std::is_same<OutputType, float>::value;
  const int values_per_row = input_columns - 2 * sizeof(float);

  for (size_t r = 0; r < input_rows; ++r) {
    const std::uint8_t* src = input + r * input_columns;
    OutputType* dst = output + r * values_per_row;
    // The scale/bias pair sits right after the quantized bytes of the row.
    const float* scale_bias =
        reinterpret_cast<const float*>(src + values_per_row);

    for (int c = 0; c < values_per_row; ++c) {
      const float value = src[c] * scale_bias[0] + scale_bias[1];
      if (!is_float_output) {
        dst[c] = cpu_float2half_rn(value);
      } else {
        dst[c] = value;
      }
    }
  }
}

/**
 * Dequantizes fused 8-bit rowwise-quantized rows, choosing an implementation
 * at run time.
 *
 * When cpuinfo_initialize() and fbgemmHasAvx2Support() both succeed the AVX2
 * kernel is used; otherwise the reference implementation is used.
 *
 * Parameters are forwarded unchanged; see
 * Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef for their meaning.
 */
template <typename OutputType>
void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
    const std::uint8_t* input,
    size_t input_rows,
    int input_columns,
    OutputType* output) {
  const bool avx2_available = cpuinfo_initialize() && fbgemmHasAvx2Support();
  if (!avx2_available) {
    Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef<OutputType>(
        input, input_rows, input_columns, output);
    return;
  }

  // The AVX2 kernel is only compiled in for x86 / x86-64 builds; on any
  // other build this path performs no work.
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
  Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2<OutputType>(
      input, input_rows, input_columns, output);
#endif
}

// Explicitly instantiates, for one element type, the eight rowwise
// quantization templates: the reference ("Ref") and run-time dispatching
// variants of
//   - FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf
//   - FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf
//   - FloatOrHalfToFused8BitRowwiseQuantizedSBFloat
//   - Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf
// Each instantiation is marked FBGEMM_API. `type` is used as the input element
// type of the quantizing functions and the output element type of the
// dequantizing ones.
#define INSTANTIATE_QuantizationFunctions(type)                                \
  template FBGEMM_API void                                                     \
  FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef<type>(                       \
      int bit_rate,                                                            \
      const type* input,                                                       \
      size_t input_rows,                                                       \
      int input_columns,                                                       \
      std::uint8_t* output);                                                   \
  template FBGEMM_API void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<type>( \
      int bit_rate,                                                            \
      const type* input,                                                       \
      size_t input_rows,                                                       \
      int input_columns,                                                       \
      std::uint8_t* output);                                                   \
  template FBGEMM_API void                                                     \
  FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<type>(                       \
      int bit_rate,                                                            \
      const uint8_t* input,                                                    \
      size_t input_rows,                                                       \
      int input_columns,                                                       \
      type* output);                                                           \
  template FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf<type>( \
      int bit_rate,                                                            \
      const uint8_t* input,                                                    \
      size_t input_rows,                                                       \
      int input_columns,                                                       \
      type* output);                                                           \
  template FBGEMM_API void                                                     \
  FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef<type>(                      \
      const type* input,                                                       \
      size_t input_rows,                                                       \
      int input_columns,                                                       \
      std::uint8_t* output);                                                   \
  template FBGEMM_API void                                                     \
  FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<type>(                         \
      const type* input,                                                       \
      size_t input_rows,                                                       \
      int input_columns,                                                       \
      std::uint8_t* output);                                                   \
  template FBGEMM_API void                                                     \
  Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef<type>(                      \
      const uint8_t* input,                                                    \
      size_t input_rows,                                                       \
      int input_columns,                                                       \
      type* output);                                                           \
  template FBGEMM_API void                                                     \
  Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<type>(                         \
      const uint8_t* input,                                                    \
      size_t input_rows,                                                       \
      int input_columns,                                                       \
      type* output);

// Instantiate for both supported element types. The macro body already ends
// in a semicolon, so the invocations carry none; clang-format is switched off
// so it does not try to reflow them.
// clang-format off
INSTANTIATE_QuantizationFunctions(float)
INSTANTIATE_QuantizationFunctions(float16)
// clang-format on

// The macro is only needed for the two instantiations above.
#undef INSTANTIATE_QuantizationFunctions

} // namespace fbgemm
