#pragma once #include #include #include #include #include namespace torch::jit { class TEWrapper { public: TEWrapper() = default; void call(const std::vector& args); template bool checkInput(const at::Tensor& t) { #ifdef TORCH_ENABLE_LLVM return t.is_contiguous() && t.dtype().Match(); #else return false; #endif } #ifdef TORCH_ENABLE_LLVM void update(std::unique_ptr&& cg_); #endif private: #ifdef TORCH_ENABLE_LLVM std::unique_ptr cg; #endif }; std::shared_ptr createDiv(); std::shared_ptr createLogit(); std::shared_ptr createRelu(); std::shared_ptr createTanh(); std::shared_ptr createSigmoid(); std::shared_ptr createSignedLog1p(); std::shared_ptr createClamp(); std::shared_ptr createClampNanToNum(); } // namespace torch::jit