# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from enum import Enum

import numpy as np

import paddle

try:
    import tensorrt as trt
except Exception as e:
    pass
from paddle import pir
from paddle.base.log_helper import get_logger
from paddle.pir.core import _PADDLE_PIR_DTYPE_2_NUMPY_DTYPE

_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)


class RefitRole(Enum):
    SHIFT = "SHIFT"
    SCALE = "SCALE"
    CONSTANT = "CONSTANT"
    BIAS = "BIAS"
    KERNEL = "KERNEL"


def map_dtype(pd_dtype):
    version_list = get_trt_version_list()
    if pd_dtype == "FLOAT32":
        return trt.float32
    elif pd_dtype == "FLOAT16":
        return trt.float16
    elif pd_dtype == "INT32":
        return trt.int32
    elif pd_dtype == "INT8":
        return trt.int8
    elif pd_dtype == "BOOL":
        return trt.bool
    # trt version<10.0 not support int64,so convert int64 to int32
    elif pd_dtype == "INT64":
        return trt.int64 if version_list[0] >= 10 else trt.int32
    # Add other dtype mappings as needed
    else:
        raise TypeError(f"Unsupported dtype: {pd_dtype}")


def support_constant_folding_pass(program):
    for op in program.global_block().ops:
        if op.name() == "pd_op.while" or op.name() == "pd_op.if":
            return False
    return True


def all_ops_into_trt(program):
    for op in program.global_block().ops:
        if (
            op.name() == "pd_op.fetch"
            or op.name() == "pd_op.data"
            or op.name() == "pd_op.tensorrt_engine"
            or op.name() == "cinn_op.group"
            or op.name().split('.')[0] == "builtin"
        ):
            continue
        if op.has_attr("__l_trt__") is False:
            return False
        if op.attrs()["__l_trt__"] is False:
            return False
    _logger.info("All ops convert to trt.")
    return True


def run_pir_pass(program, disable_passes=[], scope=None, precision_mode=None):
    def _add_pass_(pm, passes, disable_passes):
        for pass_item in passes:
            for pass_name, pass_attr in pass_item.items():
                if pass_name in disable_passes:
                    continue
                pm.add_pass(pass_name, pass_attr)

    pm = pir.PassManager(opt_level=4)
    pm.enable_print_statistics()
    if scope is None:
        scope = paddle.static.global_scope()
    place = paddle.CUDAPlace(0)

    # run marker pass
    passes = [
        {'trt_op_marker_pass': {}},
    ]
    if precision_mode is not None and precision_mode.value == "INT8":
        passes.append(
            {
                'delete_quant_dequant_linear_op_pass': {
                    "__param_scope__": scope,
                }
            }
        )
        passes.append(
            {
                'trt_delete_weight_dequant_linear_op_pass': {
                    "__param_scope__": scope,
                }
            }
        )
    _add_pass_(pm, passes, disable_passes)
    pm.run(program)

    # run other passes
    pm.clear()
    passes = []
    if support_constant_folding_pass(program):
        # only run constant_folding_pass when all ops into trt
        passes.append(
            {
                'constant_folding_pass': {
                    "__place__": place,
                    "__param_scope__": scope,
                }
            }
        )
        passes.append(
            {
                'dead_code_elimination_pass': {
                    "__place__": place,
                    "__param_scope__": scope,
                }
            }
        )
        passes.append({'conv2d_add_fuse_pass': {}})
    passes.append({'trt_op_marker_pass': {}})  # for op that created by pass
    _add_pass_(pm, passes, disable_passes)
    pm.run(program)

    return program


def run_trt_partition(program):
    pm = pir.PassManager(opt_level=4)
    pm.enable_print_statistics()
    pm.add_pass("trt_sub_graph_extract_pass", {})
    pm.run(program)
    return program


def forbid_op_lower_trt(program, disabled_ops):
    if isinstance(disabled_ops, str):
        disabled_ops = [disabled_ops]
    for op in program.global_block().ops:
        if op.name() in disabled_ops:
            op.set_bool_attr("__l_trt__", False)


def enforce_op_lower_trt(program, op_name):
    for op in program.global_block().ops:
        if op.name() == op_name:
            op.set_bool_attr("__l_trt__", True)


def predict_program(program, feed_data, fetch_var_list, scope=None):
    with (
        paddle.pir_utils.IrGuard(),
        paddle.static.program_guard(program),
    ):
        place = paddle.CUDAPlace(0)
        executor = paddle.static.Executor(place)
        output = executor.run(
            program,
            feed=feed_data,
            fetch_list=fetch_var_list,
            scope=scope,
        )
        return output


def warmup_shape_infer(program, feeds, scope=None):
    paddle.framework.set_flags({"FLAGS_enable_collect_shape": True})
    with paddle.pir_utils.IrGuard(), paddle.static.program_guard(program):
        executor = paddle.static.Executor()
        # Run the program with input_data
        for i in range(len(feeds)):
            executor.run(program, feed=feeds[i], scope=scope)

        exe_program, _, _ = (
            executor._executor_cache.get_pir_program_and_executor(
                program,
                feed=feeds[-1],
                fetch_list=None,
                feed_var_name='feed',
                fetch_var_name='fetch',
                place=paddle.framework._current_expected_place_(),
                scope=scope,
                plan=None,
            )
        )
    paddle.framework.set_flags({"FLAGS_enable_collect_shape": False})

    return exe_program


def get_trt_version_list():
    version = trt.__version__
    return list(map(int, version.split('.')))


# Adding marker labels to builtin ops facilitates convert processing, but they ultimately do not enter the TensorRT subgraph.
def mark_builtin_op(program):
    for op in program.global_block().ops:
        if op.name() == "builtin.split":
            defining_op = op.operands()[0].source().get_defining_op()
            if defining_op is not None:
                if (
                    defining_op.has_attr("__l_trt__")
                    and defining_op.attrs()["__l_trt__"]
                ):
                    op.set_bool_attr("__l_trt__", True)
        if op.name() == "builtin.combine":
            defining_op = op.results()[0].all_used_ops()[0]
            if defining_op is not None:
                if (
                    defining_op.has_attr("__l_trt__")
                    and defining_op.attrs()["__l_trt__"]
                ):
                    op.set_bool_attr("__l_trt__", True)


class TensorRTConfigManager:
    _instance = None

    def __new__(cls, trt_config=None):
        if not cls._instance:
            cls._instance = super().__new__(cls)
            cls._instance.trt_config = trt_config
        else:
            if trt_config is not None:
                cls._instance.trt_config = trt_config
        return cls._instance

    def _init(self, trt_config=None):
        self.trt_config = trt_config

    def get_precision_mode(self):
        if self.trt_config and self.trt_config.precision_mode:
            return self.trt_config.precision_mode
        return None

    def get_force_fp32_ops(self):
        if self.trt_config and self.trt_config.ops_run_float:
            return self.trt_config.ops_run_float
        return []

    def get_refit_params_path(self):
        if self.trt_config and self.trt_config.refit_params_path:
            return self.trt_config.refit_params_path
        return None


class TensorRTConstantManager:
    _instance = None

    def __new__(cls, trt_config=None):
        if not cls._instance:
            cls._instance = super().__new__(cls)
            cls._instance.constant_dict = {}
        return cls._instance

    def set_constant_value(self, name, tensor_data, value):
        out_dtype = np.dtype(_PADDLE_PIR_DTYPE_2_NUMPY_DTYPE[value.dtype])
        if out_dtype == np.dtype("float64"):
            out_dtype = np.dtype("float32")
        if out_dtype == np.dtype("int64"):
            out_dtype = np.dtype("int32")
        constant_array = np.array(tensor_data, dtype=out_dtype)
        self.constant_dict.update({name: constant_array})

    def get_constant_value(self, name):
        return self.constant_dict[name]


class RefitManager:
    _instance = None

    def __new__(cls, trt_config=None):
        if not cls._instance:
            cls._instance = super().__new__(cls)
            cls._instance.trt_weights_dict = {}
            cls._instance.refit_param_names2trt_names = {}
        return cls._instance

    def set_trt_weight_tensor(self, name, trt_weights):
        self.trt_weights_dict[name] = trt_weights

    def get_trt_weight_tensor(self, name):
        return self.trt_weights_dict[name]

    def set_mapping(self, param_name, layer_name, role):
        if isinstance(role, RefitRole):
            role = role.value
        if param_name not in self.refit_param_names2trt_names:
            self.refit_param_names2trt_names[param_name] = {}
        self.refit_param_names2trt_names[param_name][role] = layer_name

    def get_mapping(self, param_name, role=None):
        if param_name in self.refit_param_names2trt_names:
            if role is None:
                return self.refit_param_names2trt_names[param_name]
            if isinstance(role, RefitRole):
                role = role.value
            if role in self.refit_param_names2trt_names[param_name]:
                return self.refit_param_names2trt_names[param_name][role]
            else:
                return None
        else:
            return None

    def get_all_mappings(self):
        return self.refit_param_names2trt_names


# In TensorRT FP16 inference, this function sets the precision of specific
# operators to FP32, ensuring numerical accuracy for these operations.
def support_fp32_mix_precision(op_type, layer, trt_config=None):
    trt_manager = TensorRTConfigManager()
    force_fp32_ops = trt_manager.get_force_fp32_ops()
    if op_type in force_fp32_ops:
        layer.reset_precision()
        layer.precision = trt.DataType.FLOAT


def weight_to_tensor(network, paddle_value, trt_tensor, use_op_name=None):
    # the following op needn't cast trt.Weight to ITensor, because the layer need weight as input
    forbid_cast_op = [
        "pd_op.depthwise_conv2d",
        "pd_op.conv2d",
        "pd_op.conv2d_transpose",
        "pd_op.conv3d",
        "pd_op.conv3d_transpose",
        "pd_op.batch_norm",
        "pd_op.batch_norm_",
        "pd_op.layer_norm",
        "pd_op.depthwise_conv2d_transpose",
        "pd_op.fused_conv2d_add_act",
        "pd_op.affine_channel",
        "pd_op.prelu",
        "pd_op.fused_bias_dropout_residual_layer_norm",
        "pd_op.deformable_conv",
    ]
    if use_op_name in forbid_cast_op:
        return trt_tensor
    if isinstance(trt_tensor, trt.Weights):
        input_shape = paddle_value.shape
        constant_layer = network.add_constant(input_shape, trt_tensor)
        return constant_layer.get_output(0)
    return trt_tensor


def zero_dims_to_one_dims(network, trt_tensor):
    if trt_tensor is None:
        return None
    if type(trt_tensor) == trt.Weights:
        return trt_tensor
    if len(trt_tensor.shape) != 0:
        return trt_tensor
    shuffle_layer = network.add_shuffle(trt_tensor)
    shuffle_layer.reshape_dims = (1,)
    return shuffle_layer.get_output(0)


# We use a special rule to judge whether a paddle value is a shape tensor.
# The rule is consistent with the rule in C++ source code(collect_shape_manager.cc).
# We use the rule for getting min/max/opt value shape from collect_shape_manager.
# We don't use trt_tensor.is_shape_tensor, because sometimes, the trt_tensor that corresponding to paddle value is not a shape tensor
# when it is a output in this trt graph, but it is a shape tensor when it is a input in next trt graph.
def is_shape_tensor(value):
    dims = value.shape
    total_elements = 1
    if (
        dims.count(-1) > 1
    ):  # we can only deal with the situation that is has one dynamic dims
        return False
    for dim in dims:
        total_elements *= abs(dim)  # add abs for dynamic shape -1
    is_int_dtype = value.dtype == paddle.int32 or value.dtype == paddle.int64
    return total_elements <= 8 and total_elements >= 1 and is_int_dtype


def get_cache_path(cache_path):
    if cache_path is not None:
        cache_path = cache_path
    else:
        home_path = os.path.expanduser("~")
        cache_path = os.path.join(home_path, ".pp_trt_cache")

    if not os.path.exists(cache_path):
        os.makedirs(cache_path)
    return cache_path


def remove_duplicate_value(value_list):
    ret_list = []
    ret_list_id = []
    for value in value_list:
        if value.id not in ret_list_id:
            ret_list.append(value)
            ret_list_id.append(value.id)
    return ret_list


def set_dynamic_range(paddle_op, trt_inputs):
    if paddle_op.has_attr("inputs_index"):
        inputs_index = paddle_op.attrs()["inputs_index"]
        inputs_scale = paddle_op.attrs()["inputs_scale"]
        for i, index in enumerate(inputs_index):
            scale = inputs_scale[i]
            trt_inputs[index].set_dynamic_range(-scale, scale)


def get_trt_version():
    return trt.__version__
