import time
import timeit

import numpy as np

import torch
import torch._dynamo.config


# to satisfy linter complaining about undefined variable
foo = None

args = [f"x{i}" for i in range(100)]
fn_str = f"""\
def foo({", ".join(args)}):
    n = {" + ".join(arg + ".shape[0]" for arg in args)}
    return x0 + n
"""

exec(fn_str, globals())
torch._dynamo.config.recompile_limit = 16


def bench(name, fn, *, warmup=3, number=1000, repeat=10):
    """Benchmark the steady-state call overhead of ``fn`` and print a summary.

    Dynamo state is reset first via ``torch._dynamo.reset()``. ``fn`` is then
    called with 100 positional tensors for each of 10 input lengths
    (10, 20, ..., 100). One "pass" is one call per length, i.e. 10 calls.

    The first ``warmup`` passes are timed with a wall clock and reported
    separately. For a ``torch.compile``'d ``fn`` these passes include
    compilation. The remaining passes are measured with ``timeit.repeat``.

    Args:
        name: Label printed at the start of the result line.
        fn: Callable accepting 100 tensor positional arguments.
        warmup: Number of passes run (and wall-clock timed) before measuring.
        number: Passes per ``timeit`` sample.
        repeat: Number of ``timeit`` samples; the median is reported.

    Prints ``"<name> <t>us (warmup=<w>s)"``. ``<t>`` is the median time of one
    pass in microseconds, and ``<w>`` is the total warmup time in seconds.

    Raises:
        ValueError: If ``number`` or ``repeat`` is less than 1.
    """
    if number < 1 or repeat < 1:
        raise ValueError(
            f"number and repeat must be >= 1, got number={number}, repeat={repeat}"
        )

    torch._dynamo.reset()
    inps = [[torch.randn(i) for _ in range(100)] for i in range(10, 101, 10)]

    def run_fn():
        # One pass: call ``fn`` once for each input length.
        for inp in inps:
            fn(*inp)

    start = time.perf_counter()
    for _ in range(warmup):
        run_fn()
    end = time.perf_counter()

    # Each timeit sample is the total seconds for ``number`` passes, so
    # normalize by ``number`` explicitly rather than relying on a constant
    # that only happens to match the default.
    results = timeit.repeat(run_fn, number=number, repeat=repeat)
    per_pass_us = np.median(results) / number * 1e6
    print(f"{name} {per_pass_us:.1f}us (warmup={end - start:.1f}s)")


def main():
    """Benchmark ``foo`` compiled by ``torch.compile`` with ``dynamic=False``."""
    compiled_foo = torch.compile(foo, dynamic=False)  # type: ignore[F821]
    bench("compiled", compiled_foo)


# Run the benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    main()
