# Owner(s): ["module: inductor"]
import unittest

from torch._inductor.codegen.cpp import CppOverrides, CppVecOverrides
from torch._inductor.codegen.halide import HalideOverrides
from torch._inductor.codegen.mps import MetalOverrides
from torch._inductor.codegen.triton import TritonKernelOverrides
from torch._inductor.ops_handler import list_ops, OP_NAMES, OpsHandler
from torch._inductor.test_case import TestCase


class TestOpCompleteness(TestCase):
    def verify_ops_handler_completeness(self, handler):
        for op in OP_NAMES:
            self.assertIsNot(
                getattr(handler, op),
                getattr(OpsHandler, op),
                msg=f"{handler} must implement {op}",
            )
        extra_ops = list_ops(handler) - OP_NAMES
        if extra_ops:
            raise AssertionError(
                f"{handler} has an extra ops: {extra_ops}, add them to OpHandler class or prefix with `_`"
            )

    def test_triton_overrides(self):
        self.verify_ops_handler_completeness(TritonKernelOverrides)

    def test_cpp_overrides(self):
        self.verify_ops_handler_completeness(CppOverrides)

    def test_cpp_vec_overrides(self):
        self.verify_ops_handler_completeness(CppVecOverrides)

    def test_halide_overrides(self):
        self.verify_ops_handler_completeness(HalideOverrides)

    @unittest.skip("MPS backend not yet finished")
    def test_metal_overrides(self):
        self.verify_ops_handler_completeness(MetalOverrides)


if __name__ == "__main__":
    # Pull in the Inductor test runner only when this file is executed
    # directly, so importing the module for collection stays side-effect free.
    import torch._inductor.test_case as inductor_test_case

    inductor_test_case.run_tests()
