#pragma once

// NOTE(review): both #include directives below carry no header name in this
// copy of the file. The targets appear to have been lost and must be restored
// before this header can be compiled.
#include
#include

// Declarations of the convolution-related entry points of the tensor
// expression (tensorexpr) component: depthwise conv2d builders, two
// "is supported" predicates, and a family of compute* functions that all share
// one signature.
//
// NOTE(review): every std::vector / std::optional parameter in this header is
// written without a template argument list in this copy. The element types
// appear to have been lost - restore them from the upstream header.
namespace torch::jit::tensorexpr {

// An API to compute 2D depthwise convolutions with bias.
// Takes the input, weight and bias buffers plus stride, pad and groups as
// plain ints, and returns the resulting Tensor.
// NOTE(review): a single int for stride/pad suggests the same value is used
// for both spatial dimensions - confirm against the implementation.
TORCH_API Tensor conv2d_depthwise(
    BufHandle input,
    BufHandle weight,
    BufHandle bias,
    int stride,
    int pad,
    int groups);

// An API to compute 2D depthwise convolutions without bias.
// Same parameters as the overload above, minus the bias buffer.
TORCH_API Tensor conv2d_depthwise(
    BufHandle input,
    BufHandle weight,
    int stride,
    int pad,
    int groups);

// Overload of conv2d_depthwise with bias in which the sizes and the
// stride/pad/groups settings are passed as ExprHandle values rather than ints.
// NOTE(review): presumably N/C/H/W describe the input (batch, channels,
// height, width), K the number of output channels, CperG the channels per
// group, and R/S the kernel height/width - TODO confirm against the
// implementation.
TORCH_API Tensor conv2d_depthwise(
    BufHandle input,
    BufHandle weight,
    BufHandle bias,
    ExprHandle N,
    ExprHandle C,
    ExprHandle H,
    ExprHandle W,
    ExprHandle K,
    ExprHandle CperG,
    ExprHandle R,
    ExprHandle S,
    ExprHandle stride,
    ExprHandle pad,
    ExprHandle groups);

// Overload of conv2d_depthwise without bias taking ExprHandle sizes and
// settings; the parameter list matches the overload above minus the bias
// buffer.
TORCH_API Tensor conv2d_depthwise(
    BufHandle input,
    BufHandle weight,
    ExprHandle N,
    ExprHandle C,
    ExprHandle H,
    ExprHandle W,
    ExprHandle K,
    ExprHandle CperG,
    ExprHandle R,
    ExprHandle S,
    ExprHandle stride,
    ExprHandle pad,
    ExprHandle groups);

// Predicate over a conv2d configuration: TensorInfo descriptions of the input,
// weight and bias, the stride/pad/dilation vectors, and the group count.
// NOTE(review): presumably returns true when the configuration can be handled
// by the conv2d lowering declared in this header - confirm the exact criteria
// in the implementation. Unlike conv2d_depthwise, this declaration is not
// marked TORCH_API.
bool conv2dIsSupported(
    const TensorInfo& input,
    const TensorInfo& weight,
    const TensorInfo& bias,
    const std::vector& stride,
    const std::vector& pad,
    const std::vector& dilation,
    int64_t groups);

// Predicate over a convolution configuration for the MKLDNN prepacked path.
// Takes the same arguments as conv2dIsSupported except that there is no bias
// parameter.
// NOTE(review): the supported-configuration criteria are not visible from this
// header - see the implementation.
bool mkldnnPrepackedConvIsSupported(
    const TensorInfo& input,
    const TensorInfo& weight,
    const std::vector& stride,
    const std::vector& pad,
    const std::vector& dilation,
    int64_t groups);

// The compute* functions below all share one signature: a list of inputs, the
// output shape, the output strides, an optional output type and a device; each
// returns a Tensor. None of them is marked TORCH_API.
// NOTE(review): the expected contents and ordering of `inputs` differ per
// operator and are not visible from this header - check each implementation.

// NOTE(review): presumably produces the Tensor for a 2D convolution - confirm.
Tensor computeConv2d(
    const std::vector& inputs,
    const std::vector& outputShape,
    const std::vector& outputStrides,
    const std::optional& outputType,
    at::Device device);

// NOTE(review): presumably produces the Tensor for a 1D convolution - confirm.
Tensor computeConv1d(
    const std::vector& inputs,
    const std::vector& outputShape,
    const std::vector& outputStrides,
    const std::optional& outputType,
    at::Device device);

// NOTE(review): presumably produces the Tensor for running a prepacked conv2d
// followed by a clamp - confirm in the implementation.
Tensor computePrepackedConv2dClampRun(
    const std::vector& inputs,
    const std::vector& outputShape,
    const std::vector& outputStrides,
    const std::optional& outputType,
    at::Device device);

// NOTE(review): presumably produces the Tensor for running a prepacked linear
// layer followed by a clamp - confirm in the implementation.
Tensor computePrepackedLinearClampRun(
    const std::vector& inputs,
    const std::vector& outputShape,
    const std::vector& outputStrides,
    const std::optional& outputType,
    at::Device device);

// NOTE(review): presumably produces the Tensor for running an MKLDNN prepacked
// convolution - confirm in the implementation.
Tensor computeMkldnnPrepackedConvRun(
    const std::vector& inputs,
    const std::vector& outputShape,
    const std::vector& outputStrides,
    const std::optional& outputType,
    at::Device device);

} // namespace torch::jit::tensorexpr