#include <c10/util/Exception.h>
#include <c10/util/Logging.h>
#include <fmt/format.h>
#include <fmt/ranges.h>
#include <nlohmann/json.hpp>
#include <algorithm>
#include <array>
#include <iostream>

#include <torch/csrc/utils/generated_serialization_types.h>
#include <torch/nativert/graph/GraphSignature.h>

namespace torch::nativert {

namespace {

// Reports whether an output argument tag belongs to the set of tensor,
// symbolic-scalar, or custom-object kinds enumerated below. Any tag that is
// not in the table yields false.
bool isSymbolicOutput(torch::_export::Argument::Tag t) {
  using ArgTag = torch::_export::Argument::Tag;
  static constexpr std::array<ArgTag, 10> kSymbolicTags = {{
      ArgTag::AS_TENSOR,
      ArgTag::AS_TENSORS,
      ArgTag::AS_OPTIONAL_TENSORS,
      ArgTag::AS_SYM_BOOL,
      ArgTag::AS_SYM_BOOLS,
      ArgTag::AS_SYM_INT,
      ArgTag::AS_SYM_INTS,
      ArgTag::AS_SYM_FLOAT,
      ArgTag::AS_SYM_FLOATS,
      ArgTag::AS_CUSTOM_OBJ,
  }};
  return std::find(kSymbolicTags.begin(), kSymbolicTags.end(), t) !=
      kSymbolicTags.end();
}

// Produces {argument name, InputSpec kind label} for `inputSpec`; the pair is
// used to build diagnostics. A TORCH_CHECK failure is raised for TOKEN specs,
// for USER_INPUT arguments that are neither tensors nor custom objects, and
// for any tag not handled here.
std::pair<std::string, std::string> getSpecDetails(
    const torch::_export::InputSpec& inputSpec) {
  using SpecTag = torch::_export::InputSpec::Tag;
  using ArgTag = torch::_export::Argument::Tag;

  const auto specTag = inputSpec.tag();

  if (specTag == SpecTag::PARAMETER) {
    return {inputSpec.get_parameter().get_arg().get_name(), "PARAMETER"};
  }
  if (specTag == SpecTag::BUFFER) {
    return {inputSpec.get_buffer().get_arg().get_name(), "BUFFER"};
  }
  if (specTag == SpecTag::TENSOR_CONSTANT) {
    return {
        inputSpec.get_tensor_constant().get_arg().get_name(),
        "TENSOR_CONSTANT"};
  }
  if (specTag == SpecTag::CUSTOM_OBJ) {
    return {inputSpec.get_custom_obj().get_arg().get_name(), "CUSTOM_OBJ"};
  }
  if (specTag == SpecTag::USER_INPUT) {
    // Only tensor and custom-object user inputs carry a name we can report.
    const auto& userArg = inputSpec.get_user_input().get_arg();
    switch (userArg.tag()) {
      case ArgTag::AS_TENSOR:
        return {userArg.get_as_tensor().get_name(), "USER_INPUT"};
      case ArgTag::AS_CUSTOM_OBJ:
        return {userArg.get_as_custom_obj().get_name(), "USER_INPUT"};
      default:
        TORCH_CHECK(false, "Unsupported USER_INPUT argument type.");
    }
  }
  if (specTag == SpecTag::CONSTANT_INPUT) {
    return {inputSpec.get_constant_input().get_name(), "CONSTANT_INPUT"};
  }
  if (specTag == SpecTag::TOKEN) {
    TORCH_CHECK(false, "Token inputs not implemented yet.");
  }
  TORCH_CHECK(false, "Unknown InputSpec tag encountered.");
}

// Validates the relative ordering of the non-user input specs.
//
// USER_INPUT and CONSTANT_INPUT specs are ignored. The remaining specs must
// appear in non-decreasing rank: TOKEN, PARAMETER, BUFFER, TENSOR_CONSTANT,
// CUSTOM_OBJ. Within the BUFFER group, a persistent buffer may not follow a
// non-persistent one. Any violation raises a TORCH_CHECK failure.
void checkInputOrders(
    const std::vector<torch::_export::InputSpec>& inputSpecs) {
  using Tag = torch::_export::InputSpec::Tag;

  // Rank table: a spec's rank is its position in the required ordering.
  static constexpr std::array<std::pair<Tag, uint32_t>, 5> tagOrderArray = {
      {{Tag::TOKEN, 0},
       {Tag::PARAMETER, 1},
       {Tag::BUFFER, 2},
       {Tag::TENSOR_CONSTANT, 3},
       {Tag::CUSTOM_OBJ, 4}}};

  uint32_t lastRank = 0;
  bool seenNonPersistentBuffer = false;

  for (const auto& inputSpec : inputSpecs) {
    const auto tag = inputSpec.tag();
    if (tag == Tag::USER_INPUT || tag == Tag::CONSTANT_INPUT) {
      continue;
    }

    auto it = std::find_if(
        tagOrderArray.begin(), tagOrderArray.end(), [tag](const auto& entry) {
          return entry.first == tag;
        });
    TORCH_CHECK(
        it != tagOrderArray.end(), "Unknown InputSpec tag encountered.");

    const uint32_t rank = it->second;
    if (rank < lastRank) {
      // Only resolve the names on the failure path.
      const auto [argName, tagName] = getSpecDetails(inputSpec);
      TORCH_CHECK(
          false,
          fmt::format(
              "Input arg {} with InputSpec {} is out of order!",
              argName,
              tagName));
    }
    lastRank = rank;

    if (tag != Tag::BUFFER) {
      continue;
    }
    // Buffers have an extra constraint: persistent ones come first.
    if (inputSpec.get_buffer().get_persistent()) {
      TORCH_CHECK(
          !seenNonPersistentBuffer,
          "Persistent buffers must come before non-persistent buffers.");
    } else {
      seenNonPersistentBuffer = true;
    }
  }
}

// Requires the signature's input names and the graph's input names to be the
// same set. On mismatch, raises a TORCH_CHECK failure whose message lists
// both sets.
void checkInputNames(
    const c10::FastSet<std::string>& sigNames,
    const c10::FastSet<std::string>& graphNames) {
  const bool namesAgree = (sigNames == graphNames);
  if (!namesAgree) {
    TORCH_CHECK(
        false,
        fmt::format(
            "Error: Value name difference detected between graph signature and graph nodes:\n"
            "Signature value names:\n[{}]\n"
            "Graph node names:\n[{}]",
            fmt::join(sigNames, ", "),
            fmt::join(graphNames, ", ")));
  }
}

// Requires every named signature output to be present in `graphNames`.
// Entries of `sigNames` without a value are skipped. Unlike the input check,
// this is a subset test: `graphNames` may contain extra names. On failure the
// TORCH_CHECK message lists the named signature outputs and the graph names.
void checkOutputNames(
    const c10::FastSet<std::optional<std::string>>& sigNames,
    const c10::FastSet<std::string>& graphNames) {
  // Keep only the outputs that actually carry a name.
  std::vector<std::string> validNames;
  for (const auto& maybeName : sigNames) {
    if (maybeName) {
      validNames.push_back(*maybeName);
    }
  }

  const bool allPresent = std::all_of(
      validNames.begin(), validNames.end(), [&graphNames](const auto& name) {
        return graphNames.find(name) != graphNames.end();
      });
  if (allPresent) {
    return;
  }

  TORCH_CHECK(
      false,
      fmt::format(
          "Error: Value name difference detected between graph signature and graph nodes:\n"
          "Signature value names:\n[{}]\n"
          "Graph node names:\n[{}]",
          fmt::join(validNames, ", "),
          fmt::join(graphNames, ", ")));
}

// Re-keys the entry stored under `old` so it is stored under `replacement`
// instead, keeping its mapped value. Does nothing when `old` is not a key.
// NOTE(review): if `replacement` is already a key, emplace presumably leaves
// the existing entry untouched and the moved value is dropped — confirm this
// is acceptable for callers.
void replaceInMap(
    c10::FastMap<std::string, std::string>& map,
    std::string_view old,
    std::string_view replacement) {
  const auto entry = map.find(std::string{old});
  if (entry != map.end()) {
    // Take the value out before erasing so it survives the re-insert.
    std::string carried = std::move(entry->second);
    map.erase(entry);
    map.emplace(replacement, std::move(carried));
  }
}

} // namespace

// Builds the runtime signature from its serialized form.
//
// Input specs are first validated with checkInputOrders(), then distributed:
//   - USER_INPUT (tensor or custom object) -> userInputs_
//   - PARAMETER / BUFFER / TENSOR_CONSTANT -> inputsToWeights_, with the
//     matching counter incremented
//   - CUSTOM_OBJ                           -> inputsToCustomObjs_
//   - CONSTANT_INPUT                       -> ignored
// Output specs populate the user outputs, loss output, and the mutation /
// gradient maps. A user output whose argument is not symbolic is recorded as
// std::nullopt because it has no value name. TOKEN specs and unsupported
// argument kinds raise a TORCH_CHECK failure.
//
// When FLAGS_caffe2_log_level > 2 the finished signature is printed to
// std::cout.
GraphSignature::GraphSignature(const torch::_export::GraphSignature& storage) {
  using InputTag = torch::_export::InputSpec::Tag;
  using OutputTag = torch::_export::OutputSpec::Tag;
  using ArgTag = torch::_export::Argument::Tag;

  checkInputOrders(storage.get_input_specs());

  for (const torch::_export::InputSpec& inputSpec : storage.get_input_specs()) {
    switch (inputSpec.tag()) {
      case InputTag::USER_INPUT: {
        const auto& arg = inputSpec.get_user_input().get_arg();
        switch (arg.tag()) {
          case ArgTag::AS_TENSOR:
            userInputs_.emplace_back(arg.get_as_tensor().get_name());
            break;
          case ArgTag::AS_CUSTOM_OBJ:
            userInputs_.emplace_back(arg.get_as_custom_obj().get_name());
            break;
          default:
            // TODO: handle other types
            TORCH_CHECK(false, "Non tensor inputs not implemented yet.");
        }
        break;
      }
      case InputTag::PARAMETER: {
        const auto& parameter = inputSpec.get_parameter();
        numParameters_++;
        inputsToWeights_.emplace_back(
            parameter.get_arg().get_name(), parameter.get_parameter_name());
        break;
      }
      case InputTag::BUFFER: {
        const auto& buffer = inputSpec.get_buffer();
        if (buffer.get_persistent()) {
          numPersistentBuffers_++;
        } else {
          numNonPersistentBuffers_++;
        }
        inputsToWeights_.emplace_back(
            buffer.get_arg().get_name(), buffer.get_buffer_name());
        break;
      }
      case InputTag::TENSOR_CONSTANT: {
        const auto& tensorConstant = inputSpec.get_tensor_constant();
        numTensorConstants_++;
        inputsToWeights_.emplace_back(
            tensorConstant.get_arg().get_name(),
            tensorConstant.get_tensor_constant_name());
        break;
      }
      case InputTag::CUSTOM_OBJ: {
        const auto& customObj = inputSpec.get_custom_obj();
        numCustomObjs_++;
        inputsToCustomObjs_.emplace_back(
            customObj.get_arg().get_name(), customObj.get_custom_obj_name());
        break;
      }
      case InputTag::CONSTANT_INPUT:
        // Constant inputs carry no value name to record.
        break;
      case InputTag::TOKEN:
        TORCH_CHECK(false, "Token inputs not implemented yet.");
        break;
      default:
        TORCH_CHECK(false, "Unknown InputSpec tag encountered.");
        break;
    }
  }

  for (const torch::_export::OutputSpec& outputSpec :
       storage.get_output_specs()) {
    switch (outputSpec.tag()) {
      case OutputTag::LOSS_OUTPUT: {
        lossOutput_ = outputSpec.get_loss_output().get_arg().get_name();
        break;
      }
      case OutputTag::USER_OUTPUT: {
        const auto& arg = outputSpec.get_user_output().get_arg();
        const auto argTag = arg.tag();
        if (!isSymbolicOutput(argTag)) {
          // for constant outputs, we don't have a name
          userOutputs_.emplace_back(std::nullopt);
          break;
        }
        // Of the symbolic kinds, only single tensors and sym ints are
        // supported as user outputs.
        if (argTag == ArgTag::AS_TENSOR) {
          userOutputs_.emplace_back(arg.get_as_tensor().get_name());
        } else if (argTag == ArgTag::AS_SYM_INT) {
          userOutputs_.emplace_back(arg.get_as_sym_int().get_as_name());
        } else {
          TORCH_CHECK(
              false, "Unsupported symbolic user output type encountered.");
        }
        break;
      }
      case OutputTag::BUFFER_MUTATION: {
        const auto& mutation = outputSpec.get_buffer_mutation();
        buffersToMutate_.emplace(
            mutation.get_arg().get_name(), mutation.get_buffer_name());
        break;
      }
      case OutputTag::GRADIENT_TO_PARAMETER: {
        const auto& gradient = outputSpec.get_gradient_to_parameter();
        gradientsToParameters_.emplace(
            gradient.get_arg().get_name(), gradient.get_parameter_name());
        break;
      }
      case OutputTag::GRADIENT_TO_USER_INPUT: {
        const auto& gradient = outputSpec.get_gradient_to_user_input();
        gradientsToUserInputs_.emplace(
            gradient.get_arg().get_name(), gradient.get_user_input_name());
        break;
      }
      case OutputTag::USER_INPUT_MUTATION: {
        const auto& mutation = outputSpec.get_user_input_mutation();
        userInputsToMutate_.emplace(
            mutation.get_arg().get_name(), mutation.get_user_input_name());
        break;
      }
      case OutputTag::TOKEN:
        TORCH_CHECK(false, "Token outputs not implemented yet.");
        break;
      default:
        TORCH_CHECK(false, "Unknown OutputSpec tag encountered.");
    }
  }

  if (FLAGS_caffe2_log_level > 2) {
    std::cout << *this << "\n";
  }
}

// Gathers the value names of every signature input: user inputs, inputs
// mapped to weights, and inputs mapped to custom objects.
c10::FastSet<std::string> GraphSignature::inputNames() const {
  const size_t expectedCount = userInputs().size() +
      inputsToWeights().size() + inputsToCustomObjs().size();

  c10::FastSet<std::string> names;
  names.reserve(expectedCount);

  for (const auto& userInputName : userInputs()) {
    names.insert(userInputName);
  }
  // For the two mapping lists only the input-side name is of interest.
  for (const auto& [weightInputName, _] : inputsToWeights()) {
    names.insert(weightInputName);
  }
  for (const auto& [customObjInputName, _] : inputsToCustomObjs()) {
    names.insert(customObjInputName);
  }
  return names;
}

// Gathers the value names of every signature output. User outputs may be
// std::nullopt and are inserted as such. Buffer and user-input mutation
// outputs are always included; gradient outputs and a non-empty loss output
// are included only when hasBackward() is true.
c10::FastSet<std::optional<std::string>> GraphSignature::outputNames() const {
  const bool withBackward = hasBackward();

  size_t expectedCount = userOutputs().size() + buffersToMutate().size() +
      userInputsToMutate().size();
  if (withBackward) {
    expectedCount +=
        gradientsToParameters().size() + gradientsToUserInputs().size();
    if (!lossOutput().empty()) {
      expectedCount += 1;
    }
  }

  c10::FastSet<std::optional<std::string>> names;
  names.reserve(expectedCount);

  for (const auto& userOutputName : userOutputs()) {
    names.insert(userOutputName);
  }
  for (const auto& [bufferOutputName, _] : buffersToMutate()) {
    names.insert(bufferOutputName);
  }
  for (const auto& [inputOutputName, _] : userInputsToMutate()) {
    names.insert(inputOutputName);
  }

  if (!withBackward) {
    return names;
  }

  for (const auto& [gradOutputName, _] : gradientsToParameters()) {
    names.insert(gradOutputName);
  }
  for (const auto& [gradOutputName, _] : gradientsToUserInputs()) {
    names.insert(gradOutputName);
  }
  if (!lossOutput().empty()) {
    names.insert(lossOutput());
  }
  return names;
}

// Cross-checks this signature against the graph's input and output value
// names. Inputs are verified first, then outputs; either check raises a
// TORCH_CHECK failure on mismatch.
void GraphSignature::lint(
    const c10::FastSet<std::string>& graphInputs,
    const c10::FastSet<std::string>& graphOutputs) const {
  const auto signatureInputs = inputNames();
  checkInputNames(signatureInputs, graphInputs);

  const auto signatureOutputs = outputNames();
  checkOutputNames(signatureOutputs, graphOutputs);
}

void GraphSignature::replaceAllUses(
    std::string_view old,
    std::string_view replacement) {
  if (old == replacement) {
    return;
  }
  for (auto& name : userOutputs_) {
    if (name == old) {
      name = replacement;
    }
  }
  replaceInMap(buffersToMutate_, old, replacement);
  if (hasBackward()) {
    replaceInMap(gradientsToParameters_, old, replacement);
    replaceInMap(gradientsToUserInputs_, old, replacement);
    if (old == lossOutput_) {
      lossOutput_ = replacement;
    }
  }
}

// Writes a human-readable dump of the signature to `out` and returns `out`.
// Each non-empty mapping is printed as a titled, brace-delimited section with
// one tab-indented line per entry; empty mappings are omitted entirely.
// Gradient sections and the loss output line appear only when
// hasBackward() is true.
std::ostream& operator<<(std::ostream& out, const GraphSignature& sig) {
  // Prints one `name : value` section. `header` is the full opening line.
  const auto printSection = [&out](const char* header, auto&& entries) {
    if (entries.empty()) {
      return;
    }
    out << header;
    for (const auto& [lhsName, rhsName] : entries) {
      out << "\t" << lhsName << " : " << rhsName << "\n";
    }
    out << "}\n";
  };

  printSection("inputsToParameters: {\n", sig.inputsToParameters());
  printSection("inputsToBuffers: {\n", sig.inputsToBuffers());
  printSection("inputsToTensorConstants: {\n", sig.inputsToTensorConstants());
  printSection("inputsToCustomObjs: {\n", sig.inputsToCustomObjs());

  // User outputs are a flat list; unnamed (constant) outputs get a
  // placeholder label.
  if (!sig.userOutputs().empty()) {
    out << "userOutputs: {\n";
    for (const auto& maybeName : sig.userOutputs()) {
      out << "\t" << maybeName.value_or("Constant") << "\n";
    }
    out << "}\n";
  }

  printSection("buffersToMutate: {\n", sig.buffersToMutate());
  printSection("userInputsToMutate: {\n", sig.userInputsToMutate());

  if (sig.hasBackward()) {
    printSection("gradientsToParameters: {\n", sig.gradientsToParameters());
    printSection("gradientsToUserInputs: {\n", sig.gradientsToUserInputs());
    out << "lossOutput: " << sig.lossOutput() << "\n";
  }
  return out;
}

} // namespace torch::nativert
