#pragma once #include #include #include #include #include #include #include #include #include #include #include #include namespace torch::jit::tensorexpr { // A class that analyzes the given program relevant for Block backend. class BlockAnalysis : public IRVisitor { public: bool is_buf_store_target(const BufPtr& buf) const { return store_targets_.count(buf) > 0; } const std::unordered_set& loads() const { return loads_; } const std::unordered_set& stores() const { return store_targets_; } int64_t block_size() const { return block_size_; } bool areBufsInMap(const std::unordered_set& bufs) const; BufPtr getMultiDimBuf(const BufPtr& buf) const; std::string getInputName(const BufPtr& buf) const; std::string getFlatInputName(const BufPtr& buf) const { return getInputName(buf) + "_flat"; } std::unordered_map getBufferMap() const { return map_input_to_tensor_bufs_; } private: void visit(const StorePtr& v) override; void visit(const LoadPtr& v) override; void visit(const ForPtr& v) override; std::unordered_map map_input_to_tensor_bufs_; std::unordered_set store_targets_; std::unordered_set loads_; int64_t block_size_ = 32; }; // A class that overrides the underlying IRPrinter to produce Block. class BlockPrinter : public IRPrinter { public: BlockPrinter(std::ostream* os, BlockAnalysis* block_analysis) : IRPrinter(*os), block_analysis_(block_analysis) {} using IRPrinter::name_manager; using IRPrinter::visit; private: BlockAnalysis* block_analysis_; std::unordered_map dim_values_map; std::vector dim_names = {"N", "H", "W", "C"}; std::vector flat_dim_names = {"N", "NH", "NHW", "NHWC"}; void PrintTensorInfo(const std::unordered_set& bufs); void PrintArguments(const std::unordered_set& bufs); void PrintBufferInfo(const std::unordered_set& bufs); void PrintDistribution(const std::unordered_set& bufs); void PrintLoop(const std::unordered_set& bufs, bool block_idx = true); void PrintReshapeInfo( const std::unordered_set& bufs, bool reverse = false); void PrintDMAs(const std::unordered_set& bufs); void PrintAdjustBuffers(const std::unordered_set& bufs); void visit(const ForPtr& v) override; void visit(const LoadPtr& v) override; void visit(const StorePtr& v) override; void visit(const BlockPtr& v) override; void visit(const AddPtr& v) override; void visit(const MulPtr& v) override; }; class TORCH_API BlockCodeGen : public CodeGen { public: template /* implicit */ BlockCodeGen(StmtPtr stmt, Ts... ts) : CodeGen( stmt, std::vector({BufferArg(ts)...}), at::Device(at::kCPU)) { Initialize(); } BlockCodeGen( StmtPtr stmt, const std::vector& buffer_args, at::Device device = at::Device(at::kCPU), const std::string& kernel_func_name = "func") : CodeGen(std::move(stmt), buffer_args, device, kernel_func_name) { Initialize(); } ~BlockCodeGen() override; void call(const std::vector& args) override; void call_raw(const std::vector& args) override; void Initialize(); std::string getCodeText(const std::string& attr = "") override { return oss_.str(); } private: UniqueNameManager* name_manager() { if (!printer_) { throw std::runtime_error("Null IRPrinter is not expected"); } return printer_->name_manager(); } std::ostream& os() { return printer_->os(); } std::ostringstream oss_; std::unique_ptr printer_; std::unique_ptr block_analysis_; std::string GetUniqueFuncName(const std::string& func_prefix); }; } // namespace torch::jit::tensorexpr