/** \brief Fusing linear patterns as single at::linear for easier pattern * matching in later passes */ #pragma once #include namespace torch::jit { /** \brief Match the at::linear pattern and fuse it into a single at::linear * This pass fuse the addmm or matmul + add generated by JIT back to linear * This pass can be deleted once the JIT can emit the aten::linear in the future */ TORCH_API void FuseLinear(std::shared_ptr& graph); /** Swap functional linear CallFunctions to aten::linear */ TORCH_API void SwapFunctionalLinear(std::shared_ptr& graph); /** Swap all functional linear CallFunctions in module */ TORCH_API void SwapFunctionalLinear(Module& module); } // namespace torch::jit