# Copyright (c) 2024, Tri Dao.

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from flash_attn.ops.triton.layer_norm import (
    layer_norm_fn,
    layer_norm_ref,
    rms_norm_ref,
    layer_norm_linear_fn,
)


# True when the current GPU reports compute capability major version >= 8; the bfloat16
# parametrizations below are only generated when this is True. The is_available() guard keeps
# importing this module (e.g. during pytest collection) from raising on a host without a
# visible CUDA device.
is_sm8x = torch.cuda.is_available() and torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("has_weight1", [False, True])
# @pytest.mark.parametrize("has_weight1", [True])
@pytest.mark.parametrize("has_x1", [False, True])
# @pytest.mark.parametrize("has_x1", [False])
@pytest.mark.parametrize("has_rowscale", [False, True])
# @pytest.mark.parametrize("has_rowscale", [False])
@pytest.mark.parametrize("dropout_p", [0.0, 0.27])
# @pytest.mark.parametrize("dropout_p", [0.0])
@pytest.mark.parametrize("prenorm", [True, False])
# @pytest.mark.parametrize("prenorm", [False])
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize("is_rms_norm", [True])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize("has_residual", [False])
@pytest.mark.parametrize(
    "weight_dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize("weight_dtype", [torch.float32])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.float16, torch.float16)])
@pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000, 4096])
# @pytest.mark.parametrize("hidden_size", [256])
def test_layer_norm(
    hidden_size,
    input_dtype,
    residual_dtype,
    weight_dtype,
    has_residual,
    is_rms_norm,
    prenorm,
    dropout_p,
    has_rowscale,
    has_x1,
    has_weight1,
):
    """Check the Triton ``layer_norm_fn`` (forward and backward) against reference implementations.

    Every input is created once and cloned into three independent leaf tensors:

    * no suffix -- fed to the Triton ``layer_norm_fn``;
    * ``_pt``   -- fed to the reference (``layer_norm_ref`` or ``rms_norm_ref``) at input
      precision;
    * ``_ref``  -- fed to the same reference with ``upcast=True``.

    A Triton result passes when its max abs error against ``_ref`` is at most twice the max abs
    error of ``_pt`` against ``_ref``, plus ``atol`` (see ``allclose``). The dropout masks returned
    by the kernel are handed to both references (presumably so that they reproduce the kernel's
    dropout pattern). Outputs, the prenorm residual, dropout rates and all parameter/input
    gradients are checked.
    """
    # rowscale together with x1 is reported as "Not supported"; skip that combination.
    if has_rowscale and has_x1:
        pytest.skip("Not supported")
    device = "cuda"
    # The tolerance is chosen by the lowest-precision dtype that takes part in the test.
    if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
        atol = 5e-2
    elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]):
        atol = 1e-2
    else:
        atol = 1e-4
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
    allclose = (
        # Pass if the kernel's max error vs. the upcast reference is within 2x the low-precision
        # reference's max error plus atol. Positions where x_pt is NaN are masked out, because
        # sometimes x0_pt.grad is NaN.
        lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
        <= 2 * (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() + atol
        or (
            # Sometimes x_pt and x_ref are identical (e.g. bfloat16), which shrinks the bound to
            # just atol; perturb x_pt a bit by multiplying and dividing by 0.3 and retry.
            (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() == 0.0
            and (x - x_ref).abs().max()
            <= 2 * (x_pt[~x_pt.isnan()] * 0.3 / 0.3 - x_ref[~x_pt.isnan()]).abs().max() + atol
        )
    )
    x0 = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0_pt = x0.detach().clone().requires_grad_()
    x0_ref = x0.detach().clone().requires_grad_()
    if has_residual:
        res = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res_pt = res.detach().clone().requires_grad_()
        res_ref = res.detach().clone().requires_grad_()
    else:
        res, res_pt, res_ref = None, None, None
    weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    # RMSNorm cases have no bias.
    if not is_rms_norm:
        bias = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    else:
        bias = None
    weight_pt = weight.detach().clone().requires_grad_()
    weight_ref = weight.detach().clone().requires_grad_()
    bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None
    bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
    if has_x1:
        x1 = torch.randn_like(x0, dtype=input_dtype, requires_grad=True)
        x1_pt = x1.detach().clone().requires_grad_()
        x1_ref = x1.detach().clone().requires_grad_()
    else:
        x1, x1_pt, x1_ref = None, None, None
    # weight1/bias1 request a second normalized output (out1) from the same kernel call.
    if has_weight1:
        weight1 = torch.randn(
            hidden_size, device=device, dtype=weight_dtype, requires_grad=True
        )
        weight1_pt = weight1.detach().clone().requires_grad_()
        weight1_ref = weight1.detach().clone().requires_grad_()
        if not is_rms_norm:
            bias1 = torch.randn(
                hidden_size, device=device, dtype=weight_dtype, requires_grad=True
            )
        else:
            bias1 = None
        bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
        bias1_ref = bias1.detach().clone().requires_grad_() if bias1 is not None else None
    else:
        weight1, weight1_pt, weight1_ref = None, None, None
        bias1, bias1_pt, bias1_ref = None, None, None

    # One rowscale tensor (no grad) is shared by all three paths.
    rowscale = (
        torch.randn(batch_size, seqlen, dtype=input_dtype, device=device)
        if has_rowscale
        else None
    )

    # Without an input residual, ask for an fp32 residual output when residual_dtype is fp32, so
    # that the `residual.dtype == residual_dtype` check below can hold.
    # NOTE(review): assumes residual_in_fp32 only matters when no residual is passed -- confirm.
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    # return_dropout_mask=True so the masks can be fed to the references and checked below.
    out, *rest = layer_norm_fn(
        x0,
        weight,
        bias,
        residual=res,
        x1=x1,
        weight1=weight1,
        bias1=bias1,
        eps=1e-6,
        dropout_p=dropout_p,
        rowscale=rowscale,
        prenorm=prenorm,
        residual_in_fp32=residual_in_fp32,
        is_rms_norm=is_rms_norm,
        return_dropout_mask=True,
    )
    # The two dropout masks are the last two returned values; the second one is only used
    # when x1 is given.
    dropout_mask = rest[-2] if dropout_p > 0.0 else None
    dropout_mask1 = rest[-1] if dropout_p > 0.0 and x1 is not None else None
    out_pt = layer_norm_ref_fn(
        x0_pt,
        weight_pt,
        bias_pt,
        residual=res_pt,
        x1=x1_pt,
        weight1=weight1_pt,
        bias1=bias1_pt,
        eps=1e-6,
        dropout_p=dropout_p,
        rowscale=rowscale,
        prenorm=prenorm,
        dropout_mask=dropout_mask,
        dropout_mask1=dropout_mask1,
    )
    out_ref = layer_norm_ref_fn(
        x0_ref,
        weight_ref,
        bias_ref,
        residual=res_ref,
        x1=x1_ref,
        weight1=weight1_ref,
        bias1=bias1_ref,
        eps=1e-6,
        dropout_p=dropout_p,
        rowscale=rowscale,
        prenorm=prenorm,
        dropout_mask=dropout_mask,
        dropout_mask1=dropout_mask1,
        upcast=True,
    )
    # Unpack results. Kernel order: out, [out1 if weight1], [residual if prenorm], masks.
    # Reference order: out, [out1 if weight1], [residual if prenorm]; a bare tensor when only
    # `out` is returned.
    if not has_weight1:
        if prenorm:
            residual = rest[0]
            out_pt, residual_pt = out_pt
            out_ref, residual_ref = out_ref
        out1, out1_pt, out1_ref = None, None, None
    else:
        out1 = rest.pop(0)
        if prenorm:
            residual = rest[0]
            out_pt, out1_pt, residual_pt = out_pt
            out_ref, out1_ref, residual_ref = out_ref
        else:
            out_pt, out1_pt = out_pt
            out_ref, out1_ref = out_ref
    assert out.dtype == input_dtype
    if prenorm:
        assert residual.dtype == residual_dtype
        assert allclose(residual, residual_pt, residual_ref)
    assert allclose(out, out_pt, out_ref)
    if out1 is not None:
        assert out1.dtype == input_dtype
        assert allclose(out1, out1_pt, out1_ref)
    # The empirical drop rate must be within 0.01 (absolute) of dropout_p.
    if dropout_mask is not None:
        dropout_fraction = 1.0 - dropout_mask.float().mean()
        assert abs(dropout_fraction - dropout_p) < 0.01
    if dropout_mask1 is not None:
        dropout_fraction = 1.0 - dropout_mask1.float().mean()
        assert abs(dropout_fraction - dropout_p) < 0.01
        # The two masks must not be identical.
        assert not torch.equal(dropout_mask, dropout_mask1)

    # Backpropagate the same random cotangent (scaled by 1/batch_size) through all three paths.
    g = torch.randn_like(out) / batch_size
    # With weight1, combine both outputs so gradients flow through both normalization branches.
    if has_weight1:
        out = out * F.gelu(out1)
        out_pt = out_pt * F.gelu(out1_pt)
        out_ref = out_ref * F.gelu(out1_ref)
    if not prenorm:
        out.backward(g)
        out_pt.backward(g)
        out_ref.backward(g)
    else:
        # Also route gradients through the returned residual (via a sigmoid gate). The upcast
        # reference's residual is cast back to residual_dtype (presumably to match the dtype
        # seen by the other two paths).
        (out * F.sigmoid(residual)).backward(g)
        (out_pt * F.sigmoid(residual_pt)).backward(g)
        (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
    assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)
    if has_residual:
        assert allclose(res.grad, res_pt.grad, res_ref.grad)
    if has_x1:
        assert allclose(x1.grad, x1_pt.grad, x1_ref.grad)
    assert allclose(weight.grad, weight_pt.grad, weight_ref.grad)
    if bias is not None:
        assert allclose(bias.grad, bias_pt.grad, bias_ref.grad)
    if has_weight1:
        assert allclose(weight1.grad, weight1_pt.grad, weight1_ref.grad)
        if bias1 is not None:
            assert allclose(bias1.grad, bias1_pt.grad, bias1_ref.grad)


def _split_outputs(result):
    """Split a norm function's return value into ``(out, rest)``.

    The functions under test return a bare tensor when only the main output is produced, and a
    tuple otherwise. Starred unpacking (``out, *rest = result``) on a bare tensor would iterate
    over its first (batch) dimension, so handle that case explicitly.
    """
    if isinstance(result, torch.Tensor):
        return result, []
    out, *rest = result
    return out, rest


@pytest.mark.parametrize("prenorm", [True, False])
# @pytest.mark.parametrize("prenorm", [True])
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize("is_rms_norm", [True])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize("has_residual", [False])
@pytest.mark.parametrize("weight_dtype", [torch.float32])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.bfloat16, torch.float32)])
@pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000])
# @pytest.mark.parametrize("hidden_size", [256])
def test_layer_norm_linear(
    hidden_size, input_dtype, residual_dtype, weight_dtype, has_residual, is_rms_norm, prenorm
):
    """Check the fused Triton ``layer_norm_linear_fn`` against norm-then-``F.linear`` references.

    Every input is cloned three ways:

    * no suffix -- ``layer_norm_linear_fn`` under autocast to ``input_dtype``;
    * ``_pt``   -- reference norm at input precision, then ``F.linear`` under autocast;
    * ``_ref``  -- reference norm with ``upcast=True``, then ``F.linear`` in
      ``linear_weight_ref.dtype``.

    The output, the prenorm residual and all gradients are compared with ``allclose``. The
    kernel's max error vs. ``_ref`` must be within twice ``_pt``'s max error plus ``atol``.
    """
    device = "cuda"
    # The tolerance is chosen by the lowest-precision dtype that takes part in the test.
    if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
        atol = 5e-2
    elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]):
        atol = 1e-2
    else:
        atol = 1e-4
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    seqlen = 512
    # batch_size = 1
    # seqlen = 1
    layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
    allclose = (
        lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
        <= 2 * (x_pt - x_ref).abs().max() + atol
    )
    x0 = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0_pt = x0.detach().clone().requires_grad_()
    x0_ref = x0.detach().clone().requires_grad_()
    if has_residual:
        res = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res_pt = res.detach().clone().requires_grad_()
        res_ref = res.detach().clone().requires_grad_()
    else:
        res, res_pt, res_ref = None, None, None
    norm_weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    # RMSNorm cases have no norm bias.
    if not is_rms_norm:
        norm_bias = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    else:
        norm_bias = None
    norm_weight_pt = norm_weight.detach().clone().requires_grad_()
    norm_weight_ref = norm_weight.detach().clone().requires_grad_()
    norm_bias_pt = norm_bias.detach().clone().requires_grad_() if norm_bias is not None else None
    norm_bias_ref = norm_bias.detach().clone().requires_grad_() if norm_bias is not None else None
    # The linear layer projects hidden_size -> 2 * hidden_size.
    linear_weight = torch.empty(
        2 * hidden_size, hidden_size, device=device, dtype=weight_dtype, requires_grad=True
    )
    torch.nn.init.xavier_uniform_(linear_weight)
    # The linear bias is only exercised in the LayerNorm (non-RMS) cases.
    if not is_rms_norm:
        linear_bias = torch.randn(
            2 * hidden_size, device=device, dtype=weight_dtype, requires_grad=True
        )
    else:
        linear_bias = None
    linear_weight_pt = linear_weight.detach().clone().requires_grad_()
    linear_weight_ref = linear_weight.detach().clone().requires_grad_()
    linear_bias_pt = (
        linear_bias.detach().clone().requires_grad_() if linear_bias is not None else None
    )
    linear_bias_ref = (
        linear_bias.detach().clone().requires_grad_() if linear_bias is not None else None
    )

    # Without an input residual, ask for an fp32 residual output when residual_dtype is fp32, so
    # that the `residual.dtype == residual_dtype` check below can hold.
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    # With prenorm=False these calls return a bare tensor, so _split_outputs is used instead of
    # `out, *rest = ...`. The starred form would unpack along the batch dimension and silently
    # restrict every check below to the first batch element.
    with torch.autocast(device_type="cuda", dtype=input_dtype):
        out, rest = _split_outputs(
            layer_norm_linear_fn(
                x0,
                norm_weight,
                norm_bias,
                linear_weight,
                linear_bias,
                residual=res,
                eps=1e-6,
                prenorm=prenorm,
                residual_in_fp32=residual_in_fp32,
                is_rms_norm=is_rms_norm,
            )
        )
    out_pt, rest_pt = _split_outputs(
        layer_norm_ref_fn(
            x0_pt, norm_weight_pt, norm_bias_pt, residual=res_pt, eps=1e-6, prenorm=prenorm
        )
    )
    with torch.autocast(device_type="cuda", dtype=input_dtype):
        out_pt = F.linear(out_pt, linear_weight_pt, linear_bias_pt)
    out_ref, rest_ref = _split_outputs(
        layer_norm_ref_fn(
            x0_ref,
            norm_weight_ref,
            norm_bias_ref,
            residual=res_ref,
            eps=1e-6,
            prenorm=prenorm,
            upcast=True,
        )
    )
    out_ref = F.linear(out_ref.to(linear_weight_ref.dtype), linear_weight_ref, linear_bias_ref)
    if prenorm:
        residual = rest[0]
        residual_pt = rest_pt[0]
        residual_ref = rest_ref[0]
    assert out.dtype == input_dtype
    # The whole (batch, seqlen, 2 * hidden) output must be compared, not a slice of it.
    assert out.shape == (batch_size, seqlen, 2 * hidden_size)
    if prenorm:
        assert residual.dtype == residual_dtype
        assert allclose(residual, residual_pt, residual_ref)
    assert allclose(out, out_pt, out_ref)

    # Backpropagate the same random cotangent (scaled by 1/batch_size) through all three paths.
    g = torch.randn_like(out) / batch_size
    out.backward(g)
    out_pt.backward(g)
    out_ref.backward(g)
    assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)
    if has_residual:
        assert allclose(res.grad, res_pt.grad, res_ref.grad)
    assert allclose(norm_weight.grad, norm_weight_pt.grad, norm_weight_ref.grad)
    if norm_bias is not None:
        assert allclose(norm_bias.grad, norm_bias_pt.grad, norm_bias_ref.grad)
    assert allclose(linear_weight.grad, linear_weight_pt.grad, linear_weight_ref.grad)
    if linear_bias is not None:
        assert allclose(linear_bias.grad, linear_bias_pt.grad, linear_bias_ref.grad)
