#pragma once #include #include #include #include #include #include namespace torch::jit::tensorexpr { class TORCH_API Tensor { public: Tensor(BufPtr buf, const std::vector& args, const ExprPtr& body) : buf_(std::move(buf)) { stmt_ = constructStmt(args, body, {}, {}); } Tensor(BufHandle buf, const std::vector& args, ExprHandle body) : Tensor(buf.node(), VarHandleVectorToVarVector(args), body.node()) {} Tensor( BufPtr buf, const std::vector& args, const std::vector& reduce_dims, const std::vector& reduce_args, const ExprPtr& body) : buf_(std::move(buf)) { stmt_ = constructStmt(args, body, reduce_dims, reduce_args); } Tensor( BufHandle buf, const std::vector& args, const std::vector& reduce_dims, const std::vector& reduce_args, ExprHandle body) : Tensor( buf.node(), VarHandleVectorToVarVector(args), ExprHandleVectorToExprVector(reduce_dims), VarHandleVectorToVarVector(reduce_args), body.node()) {} Tensor(BufPtr buf, StmtPtr stmt) : buf_(std::move(buf)), stmt_(std::move(stmt)) {} BufPtr buf() const { return buf_; } StmtPtr stmt() const { return stmt_; } template inline ExprHandle load(const std::vector& args) const; template inline ExprHandle load(const Ts&... ts) const; private: StmtPtr constructStmt( const std::vector& args, const ExprPtr& body, const std::vector& reduce_dims, const std::vector& reduce_args) const; BufPtr buf_; StmtPtr stmt_; }; TORCH_API Tensor Compute( const std::string& func_name, const std::vector& dims, const std::optional>& strides, const std::function& body_func); TORCH_API Tensor Compute( const std::string& func_name, const std::vector& dims, const std::function& body_func); TORCH_API Tensor Compute( const std::string& func_name, const std::vector& dims, const std::optional>& strides, const std::function& body_func); TORCH_API Tensor Compute( const std::string& func_name, const std::vector& dims, const std::function& body_func); TORCH_API Tensor Compute( const std::string& func_name, const std::vector& dims, const std::optional>& strides, const std::function< ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>& body_func); TORCH_API Tensor Compute( const std::string& func_name, const std::vector& dims, const std::function< ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>& body_func); TORCH_API Tensor Compute( const std::string& func_name, const std::vector& dims, const std::optional>& strides, const std::function& body_func); TORCH_API Tensor Compute( const std::string& func_name, const std::vector& dims, const std::function& body_func); TORCH_API Tensor Compute( const std::string& func_name, const std::vector& dims, const std::optional>& strides, const std::function&)>& body_func); TORCH_API Tensor Compute( const std::string& func_name, const std::vector& dims, const std::function&)>& body_func); inline std::vector create_index_vars( const std::vector& dims) { std::vector vars; vars.reserve(dims.size()); for (const ExprHandle& dim : dims) { vars.emplace_back(alloc( "i", dim.dtype().scalar_type() == ScalarType::Long ? kLong : kInt)); } return vars; } // Handle reductions over a Reducer and a body_func which produces values. template Tensor Reduce( const std::string& func_name, const std::vector& dims, const std::optional>& strides, const Reducer& reducer, const InitFunc& init_func, const BodyFunc& body_func, const std::vector& reduce_dims) { std::vector vars = create_index_vars(dims); std::vector reduce_vars = create_index_vars(reduce_dims); // If reduce_vars is empty, then it's not a reduction, but rather a simple // copy if (reduce_vars.empty()) { ExprHandle body = Reducer::getReduceBody(body_func, vars); BufHandle func_result = Buf::make(func_name, dims, body.dtype(), std::nullopt, strides); return Tensor(std::move(func_result), vars, std::move(body)); } std::vector all_vars; all_vars.insert(all_vars.end(), vars.begin(), vars.end()); all_vars.insert(all_vars.end(), reduce_vars.begin(), reduce_vars.end()); ExprHandle body = Reducer::getReduceBody(body_func, all_vars); std::vector output_args(vars.begin(), vars.end()); ExprHandle init_expr = Cast::make(body.dtype(), init_func(vars)); BufHandle func_result = Buf::make(func_name, dims, body.dtype(), init_expr); ExprHandle reduce_op = reducer(func_result, body, output_args, reduce_vars); if (body.dtype() == kBFloat16) { ExprHandle init_expr_acc = Cast::make(kFloat, init_func(vars)); BufHandle func_result_acc = Buf::make(func_name + "_acc", dims, kFloat, init_expr_acc); reduce_op = reducer( func_result, std::move(func_result_acc), body, output_args, reduce_vars); } Tensor t = Tensor( std::move(func_result), vars, reduce_dims, reduce_vars, std::move(reduce_op)); return t; } template Tensor Reduce( const std::string& func_name, const std::vector& dims, const Reducer& reducer, const InitFunc& init_func, const BodyFunc& body_func, const std::vector& reduce_dims) { return Reduce( func_name, dims, std::nullopt, reducer, init_func, body_func, reduce_dims); } template Tensor Reduce( const std::string& func_name, const std::vector& dims, const std::optional>& strides, const Reducer& reducer, const BodyFunc& body_func, const std::vector& reduce_dims) { return Reduce( func_name, dims, strides, reducer, [&](ParameterList& p [[maybe_unused]]) { return ExprHandle(reducer.initializer()); }, body_func, reduce_dims); } template Tensor Reduce( const std::string& func_name, const std::vector& dims, const Reducer& reducer, const BodyFunc& body_func, const std::vector& reduce_dims) { return Reduce( func_name, dims, std::nullopt, reducer, body_func, reduce_dims); } // Overload which allows inline lambda functions for the body_func. template Tensor Reduce( const std::string& func_name, const std::vector& dims, const std::optional>& strides, const Reducer& reducer, const BodyFunc&& body_func, const std::vector& reduce_dims) { return Reduce(func_name, dims, strides, reducer, body_func, reduce_dims); } template Tensor Reduce( const std::string& func_name, const std::vector& dims, const Reducer& reducer, const BodyFunc&& body_func, const std::vector& reduce_dims) { return Reduce(func_name, dims, std::nullopt, reducer, body_func, reduce_dims); } TORCH_API Tensor Reduce( const std::string& name, const std::vector& dims, const std::optional>& strides, const Reducer& reducer, const BufHandle& buffer, const std::vector& reduce_dims); TORCH_API Tensor Reduce( const std::string& name, const std::vector& dims, const Reducer& reducer, const BufHandle& buffer, const std::vector& reduce_dims); // Overload for the common case of all dimensions of a previously Computed // Tensor. TORCH_API Tensor Reduce( const std::string& func_name, const std::vector& dims, const std::optional>& strides, const Reducer& reducer, const Tensor& tensor, const std::vector& reduce_dims); TORCH_API Tensor Reduce( const std::string& func_name, const std::vector& dims, const Reducer& reducer, const Tensor& tensor, const std::vector& reduce_dims); template inline ExprHandle Tensor::load(const Ts&... ts) const { std::vector params({ExprHandle(ts)...}); return Load::make(BufHandle(this->buf()), params); } template inline ExprHandle Tensor::load(const std::vector& args) const { std::vector params(args.begin(), args.end()); return Load::make(BufHandle(this->buf()), params); } template inline ExprHandle BufHandle::load(const Ts&... ts) const { std::vector params({ExprHandle(ts)...}); return ExprHandle(alloc(node(), ExprHandleVectorToExprVector(params))); } template inline ExprHandle BufHandle::load(const std::vector& args) const { std::vector params(args.begin(), args.end()); return ExprHandle(alloc(node(), ExprHandleVectorToExprVector(params))); } inline ExprHandle BufHandle::load(const std::vector& args) const { return this->template load(args); } } // namespace torch::jit::tensorexpr