#pragma once

// Only meaningful when PyTorch is built with LLVM support; otherwise this
// header compiles to nothing.
#ifdef TORCH_ENABLE_LLVM

// NOTE(review): the extraction that produced this file dropped every
// angle-bracket span (#include targets and template arguments). The include
// targets below are reconstructed from the names used in this header
// (TORCH_API, TORCH_INTERNAL_ASSERT, C10_DIAGNOSTIC_* macros, llvm::Error,
// llvm::orc::*, TargetMachine) — confirm against the upstream file.
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/tensorexpr/exceptions.h>

#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

// LLVM headers emit -Wsuggest-override warnings under our flags; suppress
// them just for this include.
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wsuggest-override")
#include <llvm/ExecutionEngine/JITSymbol.h>
C10_DIAGNOSTIC_POP()
#include <llvm/ExecutionEngine/Orc/Core.h>
#include <llvm/ExecutionEngine/Orc/ThreadSafeModule.h>
#include <llvm/Target/TargetMachine.h>

#include <memory>
#include <optional>
#include <string>

namespace torch {
namespace jit {
namespace tensorexpr {

/// Renders an llvm::Error as "<msg>: <error details>".
/// Falls back to a default message when `msg` is null. Takes the error by
/// rvalue reference because streaming it into a raw_ostream consumes it.
inline std::string formatError(llvm::Error&& err, const char* msg) {
  static constexpr const char* defaultErrorMsg =
      "Unexpected failure in LLVM JIT";
  std::string errorMsg(msg ? msg : defaultErrorMsg);
  llvm::raw_string_ostream ss(errorMsg);
  ss << ": " << err;
  return ss.str();
}

/// Unwraps an llvm::Expected<T>, aborting via TORCH_INTERNAL_ASSERT with a
/// formatted message on failure. Returns the contained value by move.
/// NOTE(review): the template parameter list and the <T> argument of
/// llvm::Expected were lost in extraction; reconstructed from the body's
/// use of `valOrErr.takeError()` and `std::move(*valOrErr)`.
template <typename T>
T assertSuccess(llvm::Expected<T> valOrErr, const char* msg = nullptr) {
  TORCH_INTERNAL_ASSERT(valOrErr, formatError(valOrErr.takeError(), msg));
  return std::move(*valOrErr);
}

/// Asserts that `err` carries no failure state; consumes the error.
inline void assertSuccess(llvm::Error err, const char* msg = nullptr) {
  TORCH_INTERNAL_ASSERT(!err, formatError(std::move(err), msg));
}

} // namespace tensorexpr
} // namespace jit
} // namespace torch

namespace llvm {
namespace orc {

// Hidden implementation (PImpl) — defined in the corresponding .cpp.
class PytorchLLVMJITImpl;

/// JIT wrapper used by the TensorExpr LLVM code generator: owns compiled
/// modules and resolves their symbols.
class TORCH_API PytorchLLVMJIT {
 public:
  /// `triple`/`cpu`/`attrs` optionally override the target configuration.
  /// NOTE(review): reconstructed as std::optional<std::string> — the
  /// template arguments were stripped in extraction; verify against the
  /// definition in the .cpp.
  PytorchLLVMJIT(
      std::optional<std::string> triple,
      std::optional<std::string> cpu,
      std::optional<std::string> attrs);
  ~PytorchLLVMJIT();

  /// Adds an LLVM module (with the context that owns its IR) to the JIT.
  void addModule(std::unique_ptr<Module> M, std::unique_ptr<LLVMContext> C);

  /// Looks up a JIT'd symbol by name.
  JITSymbol findSymbol(const std::string Name);
  /// Returns true if `Name` resolves in the JIT.
  bool hasSymbol(const std::string& Name);

  TargetMachine& getTargetMachine();
  const DataLayout& getDataLayout();

 private:
  // Use the PImpl idiom here to hide the no-rtti parts of the JIT structure.
  std::unique_ptr<PytorchLLVMJITImpl> impl_;
};

} // end namespace orc
} // end namespace llvm

#endif // TORCH_ENABLE_LLVM