#pragma once

#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

namespace torch::jit::tensorexpr {

struct TensorInfo {
  std::vector<int64_t> dims;
  c10::ScalarType dtype;
};

std::optional<TensorInfo> getTensorInfo(const BufHandle& b);

// Normalize a (possibly negative) index into [0, list_size) and check that it
// is in range.
int64_t normalizeAndCheckIndex(int64_t idx, int64_t list_size);

// Convert boolean to integer, if needed.
ExprHandle boolToInteger(const ExprHandle& x);

// Type-promotion helpers.
ExprHandle promoteToDtype(ExprHandle e, ScalarType dt);
void promoteInputs(
    std::vector<ExprHandle>& inputs,
    const int typeConstraints = kAllTypes);
ExprHandle promoteIntegerToDefaultType(const ExprHandle& e);
ExprHandle promoteHalfToFloat(const ExprHandle& e);
ExprHandle demoteOutput(
    const ExprHandle& e,
    const std::optional<ScalarType> type);

// Compute the broadcasted output shape of the given input shapes.
std::vector<ExprHandle> broadcastShapes(
    std::vector<std::vector<ExprHandle>> shapes);
std::vector<ExprHandle> broadcastShapes(
    const std::vector<ExprHandle>& a,
    const std::vector<ExprHandle>& b);

std::vector<ExprHandle> valueShape(const ArgValue& v);

// Build a (broadcast) load over `axes` for a tensor argument, or a constant
// expression for a scalar argument.
ExprHandle tensorOrConstant(
    const ArgValue& v,
    const std::vector<ExprHandle>& axes);
ExprHandle scalarOrConstant(const ArgValue& v);
ExprHandle broadcast(const BufHandle& b, const std::vector<ExprHandle>& axes);
ExprHandle constant(const ArgValue& v);
ExprHandle clamp(
    const ExprHandle& cmin,
    const ExprHandle& cmax,
    const ExprHandle& input);

// Lowerings for shape-manipulation and miscellaneous ops.
Tensor computeChunk(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device);
Tensor computeTranspose(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device);
Tensor computeExpand(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device);
Tensor computeReshape(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device);
Tensor computeFlatten(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device);
Tensor computeCatWoConditionals(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape);
Tensor computeCat(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device);
Tensor computeEmbedding(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device);

} // namespace torch::jit::tensorexpr
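
// Illustrative sketch (not part of the original header): the helpers above are
// typically combined inside a lowering that builds a Compute() tensor. The
// function name `computeExampleAdd` is hypothetical; it mirrors the signature
// of the compute* declarations above and is meant only to show how
// tensorOrConstant, promoteInputs, and demoteOutput fit together.
//
//   Tensor computeExampleAdd(
//       const std::vector<ArgValue>& inputs,
//       const std::vector<ExprHandle>& outputShape,
//       const std::vector<ExprHandle>& outputStrides,
//       const std::optional<ScalarType>& outputType,
//       at::Device device) {
//     return Compute(
//         "example_add", outputShape, [&](const std::vector<VarHandle>& axes) {
//           std::vector<ExprHandle> indices(axes.begin(), axes.end());
//           // Index (or broadcast) each argument at the output coordinates.
//           std::vector<ExprHandle> operands = {
//               tensorOrConstant(inputs[0], indices),
//               tensorOrConstant(inputs[1], indices)};
//           // Apply the standard type-promotion rules, then compute.
//           promoteInputs(operands, kAllTypes);
//           ExprHandle result = operands[0] + operands[1];
//           // Cast back to the requested output dtype, if one was given.
//           return demoteOutput(result, outputType);
//         });
//   }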