# mypy: allow-untyped-defs
"""
This module implements nonuniform observers used to collect statistics about
the values observed during calibration (PTQ) or training (QAT).
"""

import itertools

import torch
from torch.ao.quantization.experimental.apot_utils import apot_to_float, float_to_apot
from torch.ao.quantization.observer import ObserverBase


# TODO: Consider adding NonUniformQuantizationObserverBase class
# when more than one non-uniform method is implemented


class APoTObserver(ObserverBase):
    r"""Observer for Additive Powers-of-Two (APoT) nonuniform quantization.

    Records the running minimum and maximum of every tensor passed to
    :meth:`forward` and derives APoT quantization parameters from them,
    following https://arxiv.org/pdf/1909.13144.pdf.

    Args:
        b: bitwidth parameter; must be divisible by ``k``
        k: base bitwidth of each additive term; must be nonzero
        dtype: forwarded unchanged to ``ObserverBase.__init__``

    Attributes:
        n: number of additive terms, ``b // k``; only set once qparams
            have been calculated
        min_val: smallest value observed so far (empty tensor until the
            first observation)
        max_val: largest value observed so far (empty tensor until the
            first observation)
    """

    b: int
    k: int
    n: int
    min_val: torch.Tensor
    max_val: torch.Tensor

    def __init__(self, b, k, dtype=torch.quint8) -> None:
        super().__init__(dtype)
        self.b = b
        self.k = k

        # Empty tensors mark "nothing observed yet"; forward() checks numel().
        self.min_val = torch.tensor([])
        self.max_val = torch.tensor([])

    def calculate_qparams(self, signed):  # type:ignore[override]
        r"""Calculates APoT quantization parameters from the observed range.

        Uses the ``min_val`` / ``max_val`` currently stored on the observer.

        Args:
            signed: specifies whether to include signed values in quantization level calculations
        Returns:
            Tuple ``(alpha, gamma, quantization_levels, level_indices)``;
            see :meth:`_calculate_qparams`.
        """
        return self._calculate_qparams(signed, self.min_val, self.max_val)

    def _calculate_qparams(self, signed: bool, min_val=None, max_val=None):
        r"""Calculates nonuniform quantization parameters according to APoT paper:
        https://arxiv.org/pdf/1909.13144.pdf.

        Side effects: stores ``self.n`` and, when overrides are given,
        replaces ``self.min_val`` / ``self.max_val`` with them.

        Args:
            signed: specifies whether to include signed values in quantization level calculations
            min_val: optional arg that can override min_val internal attribute
            max_val: optional arg that can override max_val internal attribute
        Returns:
            alpha: alpha quantization parameter, max of abs value of observed values
            gamma: gamma quantization parameter, defined to ensure that alpha is the maximum of the range
            quantization_levels: non-uniform quantization levels (fp representation)
            level_indices: int representation of quantization_levels indices
        Raises:
            AssertionError: if ``k`` is zero/falsy or ``b`` is not divisible by ``k``
        """
        if min_val is not None:
            self.min_val = min_val
        if max_val is not None:
            self.max_val = max_val

        # compute alpha
        alpha = torch.max(-self.min_val, self.max_val)

        # check for valid inputs of b, k
        # Raised explicitly (instead of ``assert``) so the checks survive
        # ``python -O`` while callers still see an AssertionError.
        if not self.k:
            raise AssertionError("k must be a nonzero integer")
        if self.b % self.k != 0:
            raise AssertionError("b must be divisible by k")

        # compute n and store as member variable
        self.n = self.b // self.k

        # one tensor of candidate values per additive term
        p_all = []

        # create levels
        for i in range(self.n):
            # every set contains 0 plus the powers of two for this term
            values = [0]
            for j in range(2**self.k - 1):
                curr_ele = 2 ** (-(i + j * self.n))
                values.append(curr_ele)
                # introduce signed numbers
                if signed:
                    values.append(-curr_ele)

            p_curr = torch.tensor(values)
            if signed:
                # sort tensor in reverse order before adding to list if signed
                p_curr = torch.sort(p_curr, descending=True).values
            p_all.append(p_curr)

        # gamma calculation:
        # sum the largest element of every set; after the descending sort
        # that is index 0 when signed, otherwise index 1 (index 0 holds 0).
        # gamma defined to ensure alpha is at max of range
        largest_index = 0 if signed else 1
        p_sum = 0.0
        for tens in p_all:
            p_sum += float(tens[largest_index])

        # assign gamma
        gamma = alpha / p_sum

        # every combination of one value per set, summed, is an unscaled level
        unscaled_levels = [sum(row, 0.0) for row in itertools.product(*p_all)]

        gamma_scalar = float(gamma)
        quantization_levels = torch.tensor(
            [gamma_scalar * level for level in unscaled_levels]
        )
        quantization_levels, level_indices = quantization_levels.sort()

        return (alpha, gamma, quantization_levels, level_indices)

    def forward(self, x_orig):
        r"""Records the running minimum and maximum of ``x``.

        Empty inputs are returned untouched and do not update the statistics.

        Args:
            x_orig: Tensor to be observed for min and max val
        Returns:
            ``x_orig`` itself, unmodified.
        """
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.detach()
        min_val, max_val = torch.aminmax(x)
        # Only fold in the previous extrema once something has been observed.
        if self.min_val.numel():
            min_val = torch.min(min_val, self.min_val)
        if self.max_val.numel():
            max_val = torch.max(max_val, self.max_val)
        self.min_val = min_val
        self.max_val = max_val
        return x_orig

    def quant_levels_visualization(self, signed=False):
        r"""Displays visualization of APoT quantization levels.

        Plots, for 1000 evenly spaced inputs in ``[0, 1)``, the value obtained
        by quantizing to APoT and converting back to float. Requires
        ``matplotlib``, which is imported lazily.

        Args:
            signed: bool to indicate if qparams should be signed/unsigned
        """
        # matplotlib is optional dep
        import matplotlib.pyplot as plt

        alpha, _gamma, quantization_levels, level_indices = self.calculate_qparams(
            signed
        )

        xs = [float(x) / 1000.0 for x in range(1000)]
        ys = [
            apot_to_float(
                float_to_apot(x, quantization_levels, level_indices, alpha),
                quantization_levels,
                level_indices,
            ).item()
            for x in xs
        ]

        plt.figure(figsize=(15, 10))

        plt.plot(xs, ys)
        plt.title("APoT Quantization Plot")
        plt.xlabel("Full Precision")
        plt.ylabel("Quantized")
        plt.show()
