#ifndef MetalConvParams_h #define MetalConvParams_h #include namespace at::native::metal { struct Conv2DParams final { Conv2DParams() = default; Conv2DParams( c10::IntArrayRef inputSizes, c10::IntArrayRef weightSizes, c10::IntArrayRef padding, c10::IntArrayRef stride, c10::IntArrayRef dilation, int64_t groups); std::vector output_sizes() const { return {N, OC, OH, OW}; } bool isDepthwise() const { // Currently, only channel multiplier of 1 is supported // i.e. inputFeatureChannels == outputFeatureChannels return G > 1 && IC == 1 && OC == G && OC == C; } int64_t N; // batch size int64_t C; // channels int64_t H; // input height int64_t W; // input width int64_t OC; // output channels int64_t IC; // input channels int64_t KH; // kernel height int64_t KW; // kernel width int64_t SY; // stride y (height) int64_t SX; // stride x (width) int64_t PY; // padding y (height) int64_t PX; // padding x (width) int64_t DY; // dilation y (height) int64_t DX; // dilation x (width) int64_t G; // groups int64_t OW; // output width int64_t OH; // output height }; } // namespace at::native::metal #endif /* MetalConvParams_h */