from contextlib import nullcontext
import math
from typing import Optional, Tuple
# from megatron.model import LayerNorm
from easydict import EasyDict as adict
import torch
from torch.nn import functional as F
from torch import nn
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
# from optimus import flash_attn_func
# from megatron.core import tensor_parallel
# from megatron.core import parallel_state as mpu
# from megatron.core.utils import make_viewless_tensor, divide
# from megatron.model.fused_rms_norm import RMSNorm
# from megatron.model.transformer import (
#     FlashSelfAttention,
#     NoopTransformerLayer,
#     _cfg_to_kwargs,
# )
# from megatron.model.enums import AttnMaskType, AttnType
# from megatron.model.fused_softmax import FusedScaleMaskSoftmax
# from megatron.model.utils import attention_mask_func

# from megatron.model.module import MegatronModule

# try:
#     from einops import rearrange
# except ImportError:
#     rearrange = None

# from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func

# try:
#     # flash attention 2.x
#     from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func
# except ImportError:
#     try:
#         # flash attention 1.x
#         from flash_attn.flash_attn_interface import flash_attn_unpadded_func
#     except ImportError:
#         flash_attn_unpadded_func = None

# try:
#     from flash_attn.flash_attn_interface import flash_attn_unpadded_relative_attention_bias_func
# except ImportError:
#     flash_attn_unpadded_relative_attention_bias_func = None

# try:
#     from flash_attn.flash_attn_interface import mask_flash_attn_unpadded_func
# except ImportError:
#     mask_flash_attn_unpadded_func = None


class LayerNormfp32(torch.nn.LayerNorm):
    """LayerNorm that normalizes in float32 and returns the caller's dtype.

    The input is cast to ``torch.float32`` before the parent ``LayerNorm``
    forward runs, and the result is cast back to the dtype the input arrived
    in (for example fp16).
    """

    def forward(self, x: torch.Tensor):
        """Normalize ``x`` in float32 and return a tensor of ``x``'s dtype."""
        input_dtype = x.dtype
        normed = super().forward(x.to(torch.float32))
        return normed.to(input_dtype)


def get_abs_pos(abs_pos, tgt_size):
    """Resize a class-token + square-grid position embedding.

    The first row of ``abs_pos`` (after dropping a leading size-1 dim) is
    treated as the class-token embedding; the remaining rows are treated as
    a square grid of side ``int(sqrt(rows - 1))``.

    Args:
        abs_pos: position embedding whose last dim is the channel size.
            The caller in this file passes a ``(1, L, C)`` tensor.
        tgt_size: target token count. Only ``int(sqrt(tgt_size))`` is used
            (the target grid side), so passing the count including the
            class token yields the same side as long as truncation hides
            the extra one.

    Returns:
        ``abs_pos`` itself when source and target grid sides match.
        Otherwise a ``(1, side * side + 1, C)`` tensor in ``abs_pos``'s
        dtype, made of the untouched class-token row followed by the
        bicubically resized grid.
    """
    embed_dim = abs_pos.size(-1)
    flat_pos = abs_pos.squeeze(0)
    cls_pos = flat_pos[:1]
    grid_pos = flat_pos[1:]

    src_side = int(math.sqrt(flat_pos.shape[0] - 1))
    tgt_side = int(math.sqrt(tgt_size))

    # Nothing to resize: hand back the input object unchanged.
    if src_side == tgt_side:
        return abs_pos

    out_dtype = abs_pos.dtype

    # (rows, C) -> (1, C, side, side), the layout F.interpolate expects.
    grid = grid_pos.view(1, src_side, src_side, embed_dim)
    grid = grid.permute(0, 3, 1, 2).contiguous()
    # Interpolate in float32, then cast back to the original dtype.
    grid = grid.to(torch.float32)
    resized = F.interpolate(
        grid,
        size=(tgt_side, tgt_side),
        mode='bicubic',
        antialias=True,
        align_corners=False,
    ).to(out_dtype)

    # (1, C, side, side) -> (side * side, C)
    resized = resized.permute(0, 2, 3, 1)
    resized = resized.view(tgt_side * tgt_side, embed_dim)

    merged = torch.cat([cls_pos, resized], dim=0)
    return merged.view(1, tgt_side * tgt_side + 1, embed_dim)

@torch.jit.script
def quick_gelu(x):
    """Return ``x * sigmoid(1.702 * x)`` elementwise."""
    gate = torch.sigmoid(1.702 * x)
    return x * gate



class CLIPVisionEmbeddings(nn.Module):
    """Patch + class-token + position embeddings for the vision transformer.

    Holds a learnable class embedding, a bias-free strided ``Conv2d`` patch
    projection, and a learned position table with
    ``(image_size // patch_size) ** 2 + 1`` entries (patches plus the class
    token).

    Args:
        hidden_size: embedding width.
        image_size: image side used to size the position table.
        patch_size: conv kernel size and stride.
        num_channels: input channels of the patch projection.
    """

    def __init__(self, hidden_size=1024, image_size=224, patch_size=14, num_channels=3):
        super().__init__()
        self.embed_dim = hidden_size
        self.image_size = image_size
        self.patch_size = patch_size

        self.class_embedding = torch.nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = torch.nn.Conv2d(
            in_channels=num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1
        self.position_embedding = torch.nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids", torch.arange(self.num_positions).expand((1, -1))
        )

    def forward(self, pixel_values, patch_embeds=None):
        """Build the token sequence fed to the transformer.

        Args:
            pixel_values: image batch. Its first dim gives the batch size,
                and it is passed through ``patch_embedding`` only when
                ``patch_embeds`` is ``None``.
            patch_embeds: optional precomputed patch features used in place
                of the conv output. They are flattened from dim 2 onward and
                transposed, so presumably ``(batch, embed_dim, grid, grid)``
                -- confirm against the caller.

        Returns:
            ``(batch, 1 + num_tokens, embed_dim)`` embeddings: class token
            first, with position embeddings added. The position table is
            resized by ``get_abs_pos`` when the token grid differs from the
            table's grid.
        """
        batch_size = pixel_values.shape[0]

        if patch_embeds is None:
            patch_embeds = self.patch_embedding(pixel_values)

        # (B, C, ...) -> (B, tokens, C)
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)

        # The token count (including the class token) goes to get_abs_pos,
        # which derives the target grid side from it.
        pos_embed = get_abs_pos(
            self.position_embedding(self.position_ids), embeddings.size(1)
        )
        return embeddings + pos_embed


class NoTPFeedForward(nn.Module):
    """Two-layer MLP: ``fc2(quick_gelu(fc1(x)))``.

    Args:
        cfg: accepted for signature compatibility; not read here.
        dim: input and output width.
        hidden_dim: width of the intermediate projection.
    """

    def __init__(
            self,
            cfg,
            dim: int,
            hidden_dim: int,
    ):
        super().__init__()

        # Both projections carry a bias.
        self.fc1 = nn.Linear(dim, hidden_dim, bias=True)
        self.fc2 = nn.Linear(hidden_dim, dim, bias=True)

    def forward(self, x):
        """Apply fc1, the quick-GELU activation, then fc2."""
        hidden = self.fc1(x)
        activated = quick_gelu(hidden)
        return self.fc2(activated)


# from optimus.flash_attn_interface import flash_attn_qkvpacked_func


# class NoTPAttention(nn.Module):
#     def __init__(self, cfg):
#         super().__init__()
#         self.num_heads = cfg.num_attention_heads
#         self.n_local_heads = cfg.num_attention_heads
#         self.head_dim = cfg.hidden_size // cfg.num_attention_heads
#         self.max_seq_len = cfg.seq_length
#         self.use_flash_attention = cfg.use_flash_attn

#         self.qkv_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size * 3, bias=True)
#         self.out_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=True)

#         # self.core_attention = CoreAttention(cfg, AttnType.self_attn)

#         self.attn_drop = cfg.attention_dropout

#     def forward(
#             self,
#             x: torch.Tensor,
#     ):
#         bsz, seqlen, _ = x.shape
#         xqkv = self.qkv_proj(x)
#         xqkv = xqkv.view(bsz, seqlen, 3, self.num_heads, self.head_dim)

#         if self.use_flash_attention:
#             output = flash_attn_qkvpacked_func(xqkv)
#             output = output.view(bsz, seqlen, -1)
#         else:
#             xq, xk, xv = torch.split(xqkv, 1, dim=2)
#             xq = xq.squeeze(2)
#             xk = xk.squeeze(2)
#             xv = xv.squeeze(2)
#             # xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]

#             (B, num_head, S, head_size)
#             xq = xq.permute(0, 2, 1, 3)
#             xk = xk.permute(0, 2, 1, 3)
#             xv = xv.permute(0, 2, 1, 3)

#             output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None)
#             output = output.permute(0, 2, 1, 3).view(bsz, seqlen, -1)
#         output = self.out_proj(output)
#         return output


# from optimus.flash_attn_interface import flash_attn_qkvpacked_func


class NoTPAttention(torch.nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    Reads from ``cfg``: ``num_attention_heads``, ``hidden_size``,
    ``seq_length``, ``use_flash_attn`` and ``attention_dropout``.
    ``attention_dropout`` is stored as ``attn_drop`` but is not applied
    anywhere in ``forward``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.num_heads = cfg.num_attention_heads
        self.n_local_heads = cfg.num_attention_heads
        self.head_dim = cfg.hidden_size // cfg.num_attention_heads
        self.max_seq_len = cfg.seq_length
        self.use_flash_attention = cfg.use_flash_attn

        # A single linear layer produces Q, K and V stacked along the last dim.
        self.qkv_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size * 3, bias=True)
        self.out_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=True)

        self.attn_drop = cfg.attention_dropout

    def forward(
            self,
            x: torch.Tensor,
    ):
        """Self-attend over ``x`` of shape ``(batch, seq, hidden)``.

        Returns a tensor of the same ``(batch, seq, hidden)`` shape after the
        output projection. No attention mask is used.
        """
        bsz, seqlen, _ = x.shape
        # (B, S, 3 * hidden) -> (B, S, 3, heads, head_dim)
        packed = self.qkv_proj(x).view(bsz, seqlen, 3, self.num_heads, self.head_dim)

        if self.use_flash_attention:
            # The flash-attn kernel consumes the packed QKV tensor directly.
            attn_out = flash_attn_qkvpacked_func(packed)
            attn_out = attn_out.view(bsz, seqlen, -1)
        else:
            # Split along the Q/K/V axis and move heads ahead of sequence:
            # each becomes (B, num_head, S, head_size).
            q, k, v = (part.transpose(1, 2) for part in packed.unbind(dim=2))
            attn_out = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=None
            )
            # Back to (B, S, heads * head_dim).
            attn_out = attn_out.transpose(1, 2).reshape(bsz, seqlen, -1)

        return self.out_proj(attn_out)

class NoTPTransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    Computes ``h = x + attn(ln1(x))`` followed by ``h + mlp(ln2(h))``.

    Args:
        cfg: config providing ``num_attention_heads``, ``hidden_size``,
            ``ffn_hidden_size`` and ``layernorm_epsilon`` (plus whatever
            ``NoTPAttention`` reads).
        layer_id: identifier stored on the block as ``self.layer_id``.
        multiple_of: accepted for signature compatibility; not used here.
    """

    def __init__(self, cfg, layer_id: int, multiple_of=256):
        super().__init__()

        self.n_heads = cfg.num_attention_heads
        self.dim = cfg.hidden_size
        self.head_dim = cfg.hidden_size // cfg.num_attention_heads
        self.self_attn = NoTPAttention(cfg)
        self.mlp = NoTPFeedForward(
            cfg, dim=cfg.hidden_size, hidden_dim=cfg.ffn_hidden_size
        )
        self.layer_id = layer_id
        self.layer_norm1 = torch.nn.LayerNorm(
            cfg.hidden_size, eps=cfg.layernorm_epsilon
        )
        self.layer_norm2 = torch.nn.LayerNorm(
            cfg.hidden_size, eps=cfg.layernorm_epsilon
        )

    def forward(self, x: torch.Tensor):
        """Run the attention and MLP sub-layers with residual connections."""
        # Call the sub-modules themselves rather than their .forward():
        # nn.Module.__call__ is what runs registered forward hooks.
        h = x + self.self_attn(self.layer_norm1(x))
        out = h + self.mlp(self.layer_norm2(h))
        return out


class NoTPTransformer(nn.Module):
    """Stack of ``cfg.num_layers`` ``NoTPTransformerBlock`` layers.

    Blocks are given 1-based ``layer_id`` values and applied in order.
    """

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg
        self.num_layers = cfg.num_layers

        blocks = [
            NoTPTransformerBlock(cfg, index + 1)
            for index in range(self.num_layers)
        ]
        self.layers = torch.nn.ModuleList(blocks)

    def forward(
            self,
            hidden_states,
    ):
        """Feed ``hidden_states`` through every block and return the result."""
        for block in self.layers:
            hidden_states = block(hidden_states)
        return hidden_states


# from megatron.core.tensor_parallel.layers import non_tensor_paralleled, local_dp_reduce, local_dp_scatter

class VitModel(nn.Module):
    """Vision transformer: embeddings -> pre-LayerNorm -> transformer stack.

    Args:
        cfg: dict-like config (``.get`` is used for optional keys). Reads
            ``hidden_size``, ``image_size``, ``patch_size``, the optional
            ``fp32norm`` and ``pre_layernorm_epsilon`` keys, and whatever
            ``NoTPTransformer`` reads.
        freeze_embed: when true, disable gradients for all embedding
            parameters.
        freeze_pre_norm: when true, disable gradients for the pre-LayerNorm
            parameters.
    """

    def __init__(
            self,
            cfg,
            freeze_embed=False,
            freeze_pre_norm=False
    ) -> None:
        super().__init__()

        self.embeddings = CLIPVisionEmbeddings(
            hidden_size=cfg.hidden_size,
            image_size=cfg.image_size,
            patch_size=cfg.patch_size,
        )

        if freeze_embed:
            for param in self.embeddings.parameters():
                param.requires_grad = False

        self.transformer = NoTPTransformer(cfg=cfg)

        # NOTE: the attribute is spelled "pre_layrnorm" (sic). It is kept
        # as-is because the name is part of the module's parameter names.
        pre_norm_eps = cfg.get("pre_layernorm_epsilon", 1e-5)
        if cfg.get("fp32norm", False):
            # No module-level logger exists in this file, so obtain one
            # locally instead of referencing an undefined name.
            import logging

            logging.getLogger(__name__).info("Load fp32 layernorm for ViT.")
            self.pre_layrnorm = LayerNormfp32(cfg.hidden_size, eps=pre_norm_eps)
        else:
            self.pre_layrnorm = torch.nn.LayerNorm(cfg.hidden_size, eps=pre_norm_eps)

        if freeze_pre_norm:
            for param in self.pre_layrnorm.parameters():
                param.requires_grad = False

        # Tag every parameter with a `micro_dp` attribute; it is only set
        # here and not read anywhere in this file.
        for p in self.parameters():
            p.micro_dp = True

    def set_input_tensor(self, input_tensor):
        """Forward the first input tensor to the transformer.

        NOTE(review): ``NoTPTransformer`` as defined in this file has no
        ``set_input_tensor`` method, so this call looks like it would raise
        ``AttributeError`` -- confirm whether this hook is still needed.
        """
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]
        self.transformer.set_input_tensor(input_tensor[0])

    def __str__(self) -> str:
        return "open_clip"

    def forward(
            self,
            x,
            patch_embeds=None
    ):
        """Encode an image batch.

        Args:
            x: pixel values passed to the embedding layer.
            patch_embeds: optional precomputed patch features; when ``None``
                the embedding layer computes them from ``x``.

        Returns:
            The transformer output for the class token plus patch tokens.
        """
        x = self.embeddings(x, patch_embeds)
        hidden_states = self.pre_layrnorm(x)
        output = self.transformer(hidden_states)
        return output


# Default configuration used by build_clip_l().
vit_model_cfg = adict(
    # Transformer stack.
    num_layers=24,
    hidden_size=1024,
    # Same value as num_attention_heads; the modules in this file read
    # num_attention_heads only.
    num_heads=16,
    num_attention_heads=16,
    ffn_hidden_size=4096,
    seq_length=256,
    max_position_embeddings=256,
    # False selects the scaled_dot_product_attention path in NoTPAttention.
    use_flash_attn=False,
    understand_projector_stride=2,
    hidden_dropout=0.0,
    attention_dropout=0.0,
    no_persist_layer_norm=False,
    layernorm_epsilon=1e-5,
    pre_layernorm_epsilon=1e-5,
    # 224 // 14 = 16 patches per side -> 256 patches + 1 class token.
    image_size=224,
    patch_size=14,
    recompute_list=[],
)

def build_clip_l():
    """Build a ``VitModel`` from the module-level ``vit_model_cfg``.

    Neither the embeddings nor the pre-LayerNorm are frozen.
    """
    model = VitModel(cfg=vit_model_cfg, freeze_embed=False, freeze_pre_norm=False)
    return model


if __name__ == '__main__':
    # Smoke test: run a SAM encoder on a dummy batch, feed its output to the
    # ViT as precomputed patch embeddings, and print the resulting shapes.
    from mmgpt.model.vision_encoder.sam_b import build_sam_vit_b

    # Rebinds the module-level config for this run.
    vit_model_cfg = adict(
        num_layers=24,
        hidden_size=1024,
        num_attention_heads=16,
        ffn_hidden_size=4096,
        seq_length=256,
        max_position_embeddings=256,
        use_flash_attn=False,
        understand_projector_stride=2,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        no_persist_layer_norm=False,
        layernorm_epsilon=1e-5,
        pre_layernorm_epsilon=1e-5,
        image_size=224,
        patch_size=14,
        recompute_list=[],
    )

    sam_model = build_sam_vit_b()
    vision_model = VitModel(cfg=vit_model_cfg, freeze_embed=False, freeze_pre_norm=False)

    # Dummy all-zero batch of two 3-channel 1024x1024 images.
    x = torch.zeros(2, 3, 1024, 1024)

    with torch.no_grad():
        patch_embed = sam_model(x)
        print(patch_embed.shape)

        y = vision_model(x, patch_embed)
        print(y.shape)

        # Drop the class token and add the flattened SAM features token-wise.
        sam_tokens = patch_embed.flatten(2).permute(0, 2, 1)
        image_feature = y[:, 1:] + sam_tokens
        print(image_feature.shape)
