import operator_benchmark as op_bench

import torch


"""Microbenchmarks for the torch.where operator."""


# Configurations tagged "short": CPU only, float32 (torch.float) only.
# Each row of `attrs` lists (cond_shape, input_shape, other_shape), matching
# `attr_names`; presumably each row is combined with every device/dtype in
# `cross_product_configs` by op_bench.config_list.
configs_short = op_bench.config_list(
    attr_names=["cond_shape", "input_shape", "other_shape"],
    attrs=[
        # input and other both broadcast from a single element.
        [(8, 16, 1), (1,), (1,)],
        # input (16, 1) broadcasts across the leading dim of the condition;
        # other matches the condition shape exactly.
        [(8, 16, 1), (16, 1), (8, 16, 1)],
        # input (8, 1, 1) broadcasts across the middle dim; other is 1-element.
        [(8, 16, 1), (8, 1, 1), (1,)],
    ],
    cross_product_configs={"device": ["cpu"], "dtype": [torch.float]},
    tags=["short"],
)


# Configurations tagged "long": the cross product of every listed condition,
# input and other shape on both CPU and CUDA, float32 (torch.float) only.
# Every input/other shape here is broadcast-compatible with every cond_shape.
# NOTE(review): the (1024, 64, 16, 128) condition has ~134M elements, so the
# float32 randn draw used to build it needs ~512 MiB, and the float32 result
# of torch.where is as large again — confirm this is intended for target hosts.
# NOTE(review): keyword order presumably determines attribute order in the
# generated configs/test names, so avoid reordering these arguments.
configs_long = op_bench.cross_product_configs(
    cond_shape=[(64, 16, 1), (64, 16, 8), (1024, 64, 16, 128)],
    input_shape=[(1,), (16, 1), (64, 16, 1)],
    other_shape=[(1,), (16, 1), (64, 16, 1)],
    device=["cpu", "cuda"],
    dtype=[torch.float],
    tags=["long"],
)


class WhereBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark ``torch.where`` on broadcastable condition/input/other tensors.

    All three operands start as standard-normal draws; the boolean condition
    is obtained by thresholding its draw at zero.
    """

    def init(self, cond_shape, input_shape, other_shape, dtype, device):
        """Allocate the operands for one benchmark configuration.

        Args:
            cond_shape: shape of the random draw thresholded into the mask.
            input_shape: shape of the values picked where the mask is True.
            other_shape: shape of the values picked where the mask is False.
            dtype: dtype of the random draws (the mask itself ends up bool).
            device: device on which every operand is allocated.
        """

        def randn_of(dims):
            return torch.randn(*dims, dtype=dtype, device=device)

        # Draw order (mask, then input, then other) matches the RNG usage of
        # building the operands in that sequence.
        mask = randn_of(cond_shape) > 0
        on_true = randn_of(input_shape)
        on_false = randn_of(other_shape)

        # Keep keys and insertion order aligned with forward()'s parameters,
        # in case the harness passes them positionally.
        self.inputs = {"condition": mask, "input": on_true, "other": on_false}
        self.set_module_name("where")

    def forward(self, condition, input, other):
        """Pick elements from ``input`` where ``condition`` holds, else ``other``."""
        selected = torch.where(condition, input, other)
        return selected


# Register WhereBenchmark with the operator_benchmark harness for every config
# in the short and long lists (presumably one test case per config).
op_bench.generate_pt_test(configs_short + configs_long, WhereBenchmark)


if __name__ == "__main__":
    # Hand control to the op_bench runner when this file is executed directly.
    op_bench.benchmark_runner.main()
