import os
import warnings

import torch
import torchvision
import torchvision.prototype.models.depth.stereo
import utils
from torch.nn import functional as F
from train import make_eval_loader

from utils.metrics import AVAILABLE_METRICS
from visualization import make_prediction_image_side_to_side


def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Stereo Matching Evaluation", add_help=add_help)
    parser.add_argument("--dataset", type=str, default="middlebury2014-train", help="dataset to use")
    parser.add_argument("--dataset-root", type=str, default="", help="root of the dataset")

    parser.add_argument("--checkpoint", type=str, default="", help="path to weights")
    parser.add_argument("--weights", type=str, default=None, help="torchvision API weight")
    parser.add_argument(
        "--model",
        type=str,
        default="crestereo_base",
        help="which model to use if not speciffying a training checkpoint",
    )
    parser.add_argument("--img-folder", type=str, default="images")

    parser.add_argument("--batch-size", type=int, default=1, help="batch size")
    parser.add_argument("--workers", type=int, default=0, help="number of workers")

    parser.add_argument("--eval-size", type=int, nargs="+", default=[384, 512], help="resize size")
    parser.add_argument(
        "--norm-mean", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="mean for image normalization"
    )
    parser.add_argument(
        "--norm-std", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="std for image normalization"
    )
    parser.add_argument(
        "--use-grayscale", action="store_true", help="use grayscale images instead of RGB", default=False
    )
    parser.add_argument("--max-disparity", type=float, default=None, help="maximum disparity")
    parser.add_argument(
        "--interpolation-strategy",
        type=str,
        default="bilinear",
        help="interpolation strategy",
        choices=["bilinear", "bicubic", "mixed"],
    )

    parser.add_argument("--n_iterations", nargs="+", type=int, default=[10], help="number of recurent iterations")
    parser.add_argument("--n_cascades", nargs="+", type=int, default=[1], help="number of cascades")
    parser.add_argument(
        "--metrics",
        type=str,
        nargs="+",
        default=["mae", "rmse", "1px", "3px", "5px", "relepe"],
        help="metrics to log",
        choices=AVAILABLE_METRICS,
    )
    parser.add_argument("--mixed-precision", action="store_true", help="use mixed precision training")

    parser.add_argument("--world-size", type=int, default=1, help="number of distributed processes")
    parser.add_argument("--dist-url", type=str, default="env://", help="url used to set up distributed training")
    parser.add_argument("--device", type=str, default="cuda", help="device to use for training")

    parser.add_argument("--save-images", action="store_true", help="save images of the predictions")
    parser.add_argument("--padder-type", type=str, default="kitti", help="padder type", choices=["kitti", "sintel"])

    return parser


def cascade_inference(model, image_left, image_right, iterations, cascades):
    # check that image size is divisible by 16 * (2 ** (cascades - 1))
    for image in [image_left, image_right]:
        if image.shape[-2] % ((2 ** (cascades - 1))) != 0:
            raise ValueError(
                f"image height is not divisible by {16 * (2 ** (cascades - 1))}. Image shape: {image.shape[-2]}"
            )

        if image.shape[-1] % ((2 ** (cascades - 1))) != 0:
            raise ValueError(
                f"image width is not divisible by {16 * (2 ** (cascades - 1))}. Image shape: {image.shape[-2]}"
            )

    left_image_pyramid = [image_left]
    right_image_pyramid = [image_right]
    for idx in range(0, cascades - 1):
        ds_factor = int(2 ** (idx + 1))
        ds_shape = (image_left.shape[-2] // ds_factor, image_left.shape[-1] // ds_factor)
        left_image_pyramid += F.interpolate(image_left, size=ds_shape, mode="bilinear", align_corners=True).unsqueeze(0)
        right_image_pyramid += F.interpolate(image_right, size=ds_shape, mode="bilinear", align_corners=True).unsqueeze(
            0
        )

    flow_init = None
    for left_image, right_image in zip(reversed(left_image_pyramid), reversed(right_image_pyramid)):
        flow_pred = model(left_image, right_image, flow_init, num_iters=iterations)
        # flow pred is a list
        flow_init = flow_pred[-1]

    return flow_init


@torch.inference_mode()
def _evaluate(
    model,
    args,
    val_loader,
    *,
    padder_mode,
    print_freq=10,
    writer=None,
    step=None,
    iterations=10,
    cascades=1,
    batch_size=None,
    header=None,
    save_images=False,
    save_path="",
):
    """Helper function to compute various metrics (epe, etc.) for a model on a given dataset.
    We process as many samples as possible with ddp.
    """
    model.eval()
    header = header or "Test:"
    device = torch.device(args.device)
    metric_logger = utils.MetricLogger(delimiter="  ")

    iterations = iterations or args.recurrent_updates

    logger = utils.MetricLogger()
    for meter_name in args.metrics:
        logger.add_meter(meter_name, fmt="{global_avg:.4f}")
    if "fl-all" not in args.metrics:
        logger.add_meter("fl-all", fmt="{global_avg:.4f}")

    num_processed_samples = 0
    with torch.cuda.amp.autocast(enabled=args.mixed_precision, dtype=torch.float16):
        batch_idx = 0
        for blob in metric_logger.log_every(val_loader, print_freq, header):
            image_left, image_right, disp_gt, valid_disp_mask = (x.to(device) for x in blob)
            padder = utils.InputPadder(image_left.shape, mode=padder_mode)
            image_left, image_right = padder.pad(image_left, image_right)

            disp_pred = cascade_inference(model, image_left, image_right, iterations, cascades)
            disp_pred = disp_pred[:, :1, :, :]
            disp_pred = padder.unpad(disp_pred)

            if save_images:
                if args.distributed:
                    rank_prefix = args.rank
                else:
                    rank_prefix = 0
                make_prediction_image_side_to_side(
                    disp_pred, disp_gt, valid_disp_mask, save_path, prefix=f"batch_{rank_prefix}_{batch_idx}"
                )

            metrics, _ = utils.compute_metrics(disp_pred, disp_gt, valid_disp_mask, metrics=logger.meters.keys())
            num_processed_samples += image_left.shape[0]
            for name in metrics:
                logger.meters[name].update(metrics[name], n=1)

            batch_idx += 1

    num_processed_samples = utils.reduce_across_processes(num_processed_samples) / args.world_size

    print("Num_processed_samples: ", num_processed_samples)
    if (
        hasattr(val_loader.dataset, "__len__")
        and len(val_loader.dataset) != num_processed_samples
        and torch.distributed.get_rank() == 0
    ):
        warnings.warn(
            f"Number of processed samples {num_processed_samples} is different"
            f"from the dataset size {len(val_loader.dataset)}. This may happen if"
            "the dataset is not divisible by the batch size. Try lowering the batch size for more accurate results."
        )

    if writer is not None and args.rank == 0:
        for meter_name, meter_value in logger.meters.items():
            scalar_name = f"{meter_name} {header}"
            writer.add_scalar(scalar_name, meter_value.avg, step)

    logger.synchronize_between_processes()
    print(header, logger)

    logger_metrics = {k: v.global_avg for k, v in logger.meters.items()}
    return logger_metrics


def evaluate(model, loader, args, writer=None, step=None):
    os.makedirs(args.img_folder, exist_ok=True)
    checkpoint_name = os.path.basename(args.checkpoint) or args.weights
    image_checkpoint_folder = os.path.join(args.img_folder, checkpoint_name)

    metrics = {}
    base_image_folder = os.path.join(image_checkpoint_folder, args.dataset)
    os.makedirs(base_image_folder, exist_ok=True)

    for n_cascades in args.n_cascades:
        for n_iters in args.n_iterations:

            config = f"{n_cascades}c_{n_iters}i"
            config_image_folder = os.path.join(base_image_folder, config)
            os.makedirs(config_image_folder, exist_ok=True)

            metrics[config] = _evaluate(
                model,
                args,
                loader,
                padder_mode=args.padder_type,
                header=f"{args.dataset} evaluation@ size:{args.eval_size} n_cascades:{n_cascades} n_iters:{n_iters}",
                batch_size=args.batch_size,
                writer=writer,
                step=step,
                iterations=n_iters,
                cascades=n_cascades,
                save_path=config_image_folder,
                save_images=args.save_images,
            )

    metric_log = []
    metric_log_dict = {}
    # print the final results
    for config in metrics:
        config_tokens = config.split("_")
        config_iters = config_tokens[1][:-1]
        config_cascades = config_tokens[0][:-1]

        metric_log_dict[config_cascades] = metric_log_dict.get(config_cascades, {})
        metric_log_dict[config_cascades][config_iters] = metrics[config]

        evaluation_str = f"{args.dataset} evaluation@ size:{args.eval_size} n_cascades:{config_cascades} recurrent_updates:{config_iters}"
        metrics_str = f"Metrics: {metrics[config]}"
        metric_log.extend([evaluation_str, metrics_str])

        print(evaluation_str)
        print(metrics_str)

    eval_log_name = f"{checkpoint_name.replace('.pth', '')}_eval.log"
    print("Saving eval log to: ", eval_log_name)
    with open(eval_log_name, "w") as f:
        f.write(f"Dataset: {args.dataset} @size: {args.eval_size}:\n")
        # write the dict line by line for each key, and each value in the keys
        for config_cascades in metric_log_dict:
            f.write("{\n")
            f.write(f"\t{config_cascades}: {{\n")
            for config_iters in metric_log_dict[config_cascades]:
                # convert every metric to 4 decimal places
                metrics = metric_log_dict[config_cascades][config_iters]
                metrics = {k: float(f"{v:.3f}") for k, v in metrics.items()}
                f.write(f"\t\t{config_iters}: {metrics}\n")
            f.write("\t},\n")
            f.write("}\n")


def load_checkpoint(args):
    utils.setup_ddp(args)

    if not args.weights:
        checkpoint = torch.load(args.checkpoint, map_location=torch.device("cpu"), weights_only=True)
        if "model" in checkpoint:
            experiment_args = checkpoint["args"]
            model = torchvision.prototype.models.depth.stereo.__dict__[experiment_args.model](weights=None)
            model.load_state_dict(checkpoint["model"])
        else:
            model = torchvision.prototype.models.depth.stereo.__dict__[args.model](weights=None)
            model.load_state_dict(checkpoint)

        # set the appropriate devices
        if args.distributed and args.device == "cpu":
            raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
        device = torch.device(args.device)
    else:
        model = torchvision.prototype.models.depth.stereo.__dict__[args.model](weights=args.weights)

    # convert to DDP if need be
    if args.distributed:
        model = model.to(args.device)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
    else:
        model.to(device)

    return model


def main(args):
    model = load_checkpoint(args)
    loader = make_eval_loader(args.dataset, args)
    evaluate(model, loader, args)


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)
