import argparse
import os
import warnings
from pathlib import Path
from typing import List, Union

import numpy as np
import torch
import torch.distributed as dist
import torchvision.models.optical_flow
import torchvision.prototype.models.depth.stereo
import utils
import visualization

from parsing import make_dataset, make_eval_transform, make_train_transform, VALID_DATASETS
from torch import nn
from torchvision.transforms.functional import get_dimensions, InterpolationMode, resize
from utils.metrics import AVAILABLE_METRICS
from utils.norm import freeze_batch_norm


def make_stereo_flow(
    flow: Union[torch.Tensor, List[torch.Tensor]], model_out_channels: int
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Pad a 1-channel disparity into a 2-channel (X, Y) flow when the model outputs 2 channels.

    Args:
        flow: a ``(B, C, H, W)`` tensor, or a list of such tensors (e.g. the per-iteration
            predictions of a recurrent model); lists are processed element-wise.
        model_out_channels: number of output channels of the model.

    Returns:
        The same structure as ``flow``. A tensor with ``C == 1`` is concatenated with an all-zero
        channel along ``dim=1`` when ``model_out_channels == 2``; any other tensor is returned
        unchanged (same object).
    """
    if isinstance(flow, list):
        return [make_stereo_flow(flow_i, model_out_channels) for flow_i in flow]

    # unpacking also rejects tensors that are not 4D
    _, num_channels, _, _ = flow.shape
    if num_channels == 1 and model_out_channels == 2:
        # by convention the flow is X-Y ordered: the disparity is the X component,
        # so the (zero) Y flow goes last
        flow = torch.cat([flow, torch.zeros_like(flow)], dim=1)
    return flow


def make_lr_schedule(
    args: argparse.Namespace, optimizer: torch.optim.Optimizer
) -> torch.optim.lr_scheduler.SequentialLR:
    """Build the warmup -> flat -> decay learning-rate schedule used for CRE-stereo training.

    The schedule has up to three consecutive phases; a phase of length zero is skipped:

    1. warmup for ``args.warmup_steps`` steps: ``linear`` ramps from ``lr * lr_warmup_factor``
       up to ``lr``, ``constant`` holds ``lr * lr_warmup_factor``;
    2. a flat phase at ``args.lr`` until step ``args.decay_after_steps``;
    3. a decay phase until ``args.total_iterations``: ``cosine`` / ``linear`` anneal towards
       ``args.min_lr``, ``exponential`` multiplies the lr by ``args.lr_decay_gamma`` every step.

    ``args.warmup_steps`` / ``args.decay_after_steps`` may be ``None`` or 0 to disable the
    corresponding phase.

    Raises:
        ValueError: if ``decay_after_steps`` is set but smaller than ``warmup_steps``, or if an
            unknown warmup / decay method is requested.
    """
    warmup_steps = args.warmup_steps or 0
    decay_after_steps = args.decay_after_steps or 0
    if decay_after_steps and decay_after_steps < warmup_steps:
        raise ValueError(
            f"decay_after_steps: {args.decay_after_steps} must be greater than warmup_steps: {args.warmup_steps}"
        )

    flat_lr_steps = decay_after_steps - warmup_steps if decay_after_steps else 0
    # the decay runs over whatever is left after *both* the warmup and the flat phase
    decay_lr_steps = args.total_iterations - warmup_steps - flat_lr_steps

    max_lr = args.lr
    min_lr = args.min_lr

    # (scheduler, number of steps it is active for) for every non-empty phase, in order
    phases = []

    if warmup_steps > 0:
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_factor, total_iters=warmup_steps
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_factor, total_iters=warmup_steps
            )
        else:
            raise ValueError(f"Unknown lr warmup method {args.lr_warmup_method}")
        phases.append((warmup_lr_scheduler, warmup_steps))

    if flat_lr_steps > 0:
        # scheduler factors multiply the optimizer's base lr (== max_lr), so 1.0 holds the lr at max_lr
        flat_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=flat_lr_steps)
        phases.append((flat_lr_scheduler, flat_lr_steps))

    if decay_lr_steps > 0:
        if args.lr_decay_method == "cosine":
            # eta_min is an absolute learning rate, not a factor
            decay_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=decay_lr_steps, eta_min=min_lr
            )
        elif args.lr_decay_method == "linear":
            # LinearLR factors are relative to the base lr: go from max_lr down to min_lr
            end_factor = min_lr / max_lr if max_lr else 0.0
            decay_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=1.0, end_factor=end_factor, total_iters=decay_lr_steps
            )
        elif args.lr_decay_method == "exponential":
            decay_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
                optimizer, gamma=args.lr_decay_gamma, last_epoch=-1
            )
        else:
            raise ValueError(f"Unknown lr decay method {args.lr_decay_method}")
        phases.append((decay_lr_scheduler, decay_lr_steps))

    if not phases:
        # nothing to schedule: keep the learning rate constant at max_lr
        phases.append((torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=0), 0))

    schedulers = [phase_scheduler for phase_scheduler, _ in phases]
    # SequentialLR needs exactly one milestone per phase boundary, i.e. len(schedulers) - 1
    milestones = []
    boundary = 0
    for _, length in phases[:-1]:
        boundary += length
        milestones.append(boundary)

    return torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers, milestones=milestones)


def shuffle_dataset(dataset):
    """Return a ``Subset`` view of ``dataset`` that visits its samples in a random order.

    The permutation is drawn from torch's global RNG, so it is reproducible under
    ``torch.manual_seed``. No samples are copied.
    """
    n_samples = len(dataset)
    order = torch.randperm(n_samples)
    return torch.utils.data.Subset(dataset=dataset, indices=order)


def resize_dataset_to_n_steps(
    dataset: torch.utils.data.Dataset, dataset_steps: int, samples_per_step: int, args: argparse.Namespace
) -> torch.utils.data.Dataset:
    """Repeat and truncate ``dataset`` so it yields exactly ``dataset_steps`` steps worth of samples.

    Args:
        dataset: the dataset to expand; must be non-empty.
        dataset_steps: number of steps to cover, or number of epochs when ``args.steps_is_epochs``.
        samples_per_step: samples consumed per step (ignored when ``args.steps_is_epochs``, in
            which case one "step" is one full pass over ``dataset``).
        args: reads ``steps_is_epochs`` and ``dataset_shuffle``.

    Returns:
        A ``ConcatDataset`` of whole copies of ``dataset`` followed by a prefix of it covering the
        remainder. When ``args.dataset_shuffle`` is set, each piece is shuffled independently.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    original_size = len(dataset)
    if original_size == 0:
        raise ValueError("Cannot resize an empty dataset to a fixed number of steps")
    if args.steps_is_epochs:
        samples_per_step = original_size
    target_size = dataset_steps * samples_per_step

    n_expands, remainder = divmod(target_size, original_size)
    dataset_copies = [dataset] * n_expands
    if remainder > 0:
        # top up with the first `remainder` samples to hit the target size exactly
        dataset_copies.append(torch.utils.data.Subset(dataset, list(range(remainder))))

    if args.dataset_shuffle:
        dataset_copies = [shuffle_dataset(part) for part in dataset_copies]

    return torch.utils.data.ConcatDataset(dataset_copies)


def get_train_dataset(dataset_root: str, args: argparse.Namespace) -> torch.utils.data.Dataset:
    """Build the full training dataset from ``args.train_datasets``.

    Each named dataset is built with a freshly made train transform and resized to the matching
    entry of ``args.dataset_steps`` (one entry per dataset, in the same order). Each step consumes
    ``args.world_size * args.batch_size`` samples. The resized datasets are concatenated in order
    and, if ``args.dataset_order_shuffle`` is set, shuffled as a whole.

    Raises:
        ValueError: if no training dataset is given, or if the number of ``--dataset-steps``
            values does not match the number of training datasets.
    """
    datasets = []
    for dataset_name in args.train_datasets:
        transform = make_train_transform(args)
        dataset = make_dataset(dataset_name, dataset_root, transform)
        datasets.append(dataset)

    if len(datasets) == 0:
        raise ValueError("No datasets specified for training")

    # zip() would silently drop every dataset that has no matching step count
    if len(datasets) != len(args.dataset_steps):
        raise ValueError(
            f"Got {len(datasets)} training datasets but {len(args.dataset_steps)} values for "
            "--dataset-steps; provide exactly one step count per training dataset"
        )

    samples_per_step = args.world_size * args.batch_size

    for idx, (dataset, steps_per_dataset) in enumerate(zip(datasets, args.dataset_steps)):
        datasets[idx] = resize_dataset_to_n_steps(dataset, steps_per_dataset, samples_per_step, args)

    dataset = torch.utils.data.ConcatDataset(datasets)
    if args.dataset_order_shuffle:
        dataset = shuffle_dataset(dataset)

    print(f"Training dataset: {len(dataset)} samples")
    return dataset


@torch.inference_mode()
def _evaluate(
    model,
    args,
    val_loader,
    *,
    padder_mode,
    print_freq=10,
    writer=None,
    step=None,
    iterations=None,
    batch_size=None,
    header=None,
):
    """Compute the configured metrics (epe, etc.) for ``model`` on one validation loader.

    Args:
        model: stereo model called as ``model(left, right, flow_init=None, num_iters=...)``;
            the last element of its output list is evaluated, first channel only.
        args: reads ``device``, ``metrics``, ``recurrent_updates``, ``mixed_precision``, ``rank``.
        val_loader: yields ``(image_left, image_right, disp_gt, valid_disp_mask)`` batches.
        padder_mode: mode passed to ``utils.InputPadder``.
        print_freq: progress print frequency, in batches.
        writer: optional TensorBoard writer; the main process logs one scalar per metric.
        step: global step used for the TensorBoard scalars.
        iterations: recurrent refinement iterations (defaults to ``args.recurrent_updates``).
        batch_size: unused; kept for interface compatibility.
        header: prefix for the printed / logged results (defaults to ``"Test:"``).
    """
    model.eval()
    header = header or "Test:"
    device = torch.device(args.device)
    metric_logger = utils.MetricLogger(delimiter="  ")

    iterations = iterations or args.recurrent_updates

    logger = utils.MetricLogger()
    for meter_name in args.metrics:
        logger.add_meter(meter_name, fmt="{global_avg:.4f}")
    if "fl-all" not in args.metrics:
        logger.add_meter("fl-all", fmt="{global_avg:.4f}")

    num_processed_samples = 0
    with torch.cuda.amp.autocast(enabled=args.mixed_precision, dtype=torch.float16):
        for blob in metric_logger.log_every(val_loader, print_freq, header):
            image_left, image_right, disp_gt, valid_disp_mask = (x.to(device) for x in blob)
            padder = utils.InputPadder(image_left.shape, mode=padder_mode)
            image_left, image_right = padder.pad(image_left, image_right)

            disp_predictions = model(image_left, image_right, flow_init=None, num_iters=iterations)
            # predictions might have 2 channels; the disparity is the first one
            disp_pred = disp_predictions[-1][:, :1, :, :]
            disp_pred = padder.unpad(disp_pred)

            metrics, _ = utils.compute_metrics(disp_pred, disp_gt, valid_disp_mask, metrics=logger.meters.keys())
            num_processed_samples += image_left.shape[0]
            for name in metrics:
                logger.meters[name].update(metrics[name], n=1)

    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    # dist.get_rank() raises when no process group exists (plain single-process evaluation)
    is_main_process = not dist.is_initialized() or dist.get_rank() == 0

    print("Num_processed_samples: ", num_processed_samples)
    if (
        hasattr(val_loader.dataset, "__len__")
        and len(val_loader.dataset) != num_processed_samples
        and is_main_process
    ):
        warnings.warn(
            f"Number of processed samples {num_processed_samples} is different "
            f"from the dataset size {len(val_loader.dataset)}. This may happen if "
            "the dataset is not divisible by the batch size. Try lowering the batch size or GPU number for more accurate results."
        )

    # aggregate across processes *before* logging so the written values cover the whole dataset;
    # this is a collective call, so every process must reach it
    logger.synchronize_between_processes()

    if writer is not None and args.rank == 0:
        for meter_name, meter_value in logger.meters.items():
            scalar_name = f"{meter_name} {header}"
            # global_avg is the synchronized average over all batches, matching what is printed
            writer.add_scalar(scalar_name, meter_value.global_avg, step)

    print(header, logger)


def make_eval_loader(dataset_name: str, args: argparse.Namespace) -> torch.utils.data.DataLoader:
    """Build a non-shuffling evaluation DataLoader for ``dataset_name``.

    When ``args.weights`` is given, images are preprocessed with the weights' own transforms.
    The disparity and valid mask are then resized to the transformed image size. Disparity values
    are rescaled by the horizontal scale factor, since disparity is measured in pixels along the
    width. Otherwise ``make_eval_transform(args)`` is used.
    """
    if args.weights:
        weights = torchvision.models.get_weight(args.weights)
        trans = weights.transforms()

        def preprocessing(image_left, image_right, disp, valid_disp_mask):
            _, H_o, W_o = get_dimensions(image_left)
            image_left, image_right = trans(image_left, image_right)

            _, H_t, W_t = get_dimensions(image_left)
            needs_resize = (H_t, W_t) != (H_o, W_o)
            # disparity is a horizontal displacement, so its values scale with the width
            scale_factor = W_t / W_o

            # NOTE(review): assumes disp / mask are channel-first (1, H, W) arrays or tensors,
            # as required by `resize` - confirm against the datasets in `parsing`
            if disp is not None:
                if not isinstance(disp, torch.Tensor):
                    disp = torch.from_numpy(disp)
                if needs_resize:
                    disp = resize(disp, (H_t, W_t), interpolation=InterpolationMode.BILINEAR) * scale_factor
            if valid_disp_mask is not None:
                if not isinstance(valid_disp_mask, torch.Tensor):
                    valid_disp_mask = torch.from_numpy(valid_disp_mask)
                if needs_resize:
                    # nearest keeps the mask binary
                    valid_disp_mask = resize(valid_disp_mask, (H_t, W_t), interpolation=InterpolationMode.NEAREST)
            return image_left, image_right, disp, valid_disp_mask

    else:
        preprocessing = make_eval_transform(args)

    val_dataset = make_dataset(dataset_name, args.dataset_root, transforms=preprocessing)
    if args.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=False)
    else:
        sampler = torch.utils.data.SequentialSampler(val_dataset)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.workers,
    )

    return val_loader


def evaluate(model, loaders, args, writer=None, step=None):
    """Run ``_evaluate`` on every named validation loader in ``loaders``, one after another.

    Each loader's results are reported under the header ``"<name> evaluation"``; all loaders share
    the recurrent-update count, padder type and batch size configured in ``args``.
    """
    shared_options = dict(
        iterations=args.recurrent_updates,
        padder_mode=args.padder_type,
        batch_size=args.batch_size,
        writer=writer,
        step=step,
    )
    for name in loaders:
        _evaluate(model, args, loaders[name], header=f"{name} evaluation", **shared_options)


def _save_checkpoint(model, optimizer, scheduler, step, args):
    """Save model/optimizer/scheduler state after ``step`` completed training steps.

    Writes two identical files into ``args.checkpoint_dir``: ``{args.name}_{step}.pth`` (a
    per-step snapshot) and ``{args.name}.pth`` (overwritten on every save).
    """
    model_without_ddp = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
    checkpoint = {
        "model": model_without_ddp.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "step": step,
        # stored as a plain dict rather than the argparse.Namespace itself: the weights-only
        # unpickler used by `torch.load(..., weights_only=True)` rejects arbitrary classes
        "args": dict(vars(args)),
    }
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}_{step}.pth")
    torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}.pth")


def run(model, optimizer, scheduler, train_loader, val_loaders, logger, writer, scaler, args):
    """Train ``model`` from step ``args.start_step + 1`` up to ``args.total_iterations`` (inclusive).

    Each step draws one batch and computes the sequence loss plus any enabled auxiliary losses
    (consistency, PSNR, photometric, smoothness). It then backpropagates (through ``scaler`` when
    given) and steps both the optimizer and the lr ``scheduler``.

    Side effects, all driven by ``args``:
      * logs to ``writer`` every ``tensorboard_log_frequency`` steps on the main process;
      * checkpoints every ``save_frequency`` steps and once more at the end;
      * runs ``evaluate`` on ``val_loaders`` every ``valid_frequency`` steps.
    """
    device = torch.device(args.device)
    # wrap the loader in a logger
    loader = iter(logger.log_every(train_loader))
    # output channels
    model_out_channels = model.module.output_channels if args.distributed else model.output_channels

    torch.set_num_threads(args.threads)

    sequence_criterion = utils.SequenceLoss(
        gamma=args.gamma,
        max_flow=args.max_disparity,
        exclude_large_flows=args.flow_loss_exclude_large,
    ).to(device)

    if args.consistency_weight:
        consistency_criterion = utils.FlowSequenceConsistencyLoss(
            args.gamma,
            # honour --consistency-resize-factor (its default, 0.25, matches the old hard-coded value)
            resize_factor=args.consistency_resize_factor,
            # NOTE(review): rescale_factor is still fixed at 0.25; confirm whether it should track
            # the resize factor when a non-default --consistency-resize-factor is used
            rescale_factor=0.25,
            rescale_mode="bilinear",
        ).to(device)
    else:
        consistency_criterion = None

    if args.psnr_weight:
        psnr_criterion = utils.PSNRLoss().to(device)
    else:
        psnr_criterion = None

    if args.smoothness_weight:
        smoothness_criterion = utils.SmoothnessLoss().to(device)
    else:
        smoothness_criterion = None

    if args.photometric_weight:
        photometric_criterion = utils.FlowPhotoMetricLoss(
            ssim_weight=args.photometric_ssim_weight,
            max_displacement_ratio=args.photometric_max_displacement_ratio,
            ssim_use_padding=False,
        ).to(device)
    else:
        photometric_criterion = None

    # bound up-front so the final checkpoint still has a valid step when the loop body never runs
    step = args.start_step
    for step in range(args.start_step + 1, args.total_iterations + 1):
        data_blob = next(loader)
        optimizer.zero_grad()

        # unpack the data blob
        image_left, image_right, disp_mask, valid_disp_mask = (x.to(device) for x in data_blob)
        with torch.cuda.amp.autocast(enabled=args.mixed_precision, dtype=torch.float16):
            disp_predictions = model(image_left, image_right, flow_init=None, num_iters=args.recurrent_updates)
            # different models have different outputs, make sure we get the right ones for this task
            disp_predictions = make_stereo_flow(disp_predictions, model_out_channels)
            # should the architecture or training loop require it, we have to adjust the disparity mask
            # target to possibly look like an optical flow mask
            disp_mask = make_stereo_flow(disp_mask, model_out_channels)

        # sequence loss on top of the model outputs
        loss = sequence_criterion(disp_predictions, disp_mask, valid_disp_mask) * args.flow_loss_weight

        if args.consistency_weight > 0:
            loss_consistency = consistency_criterion(disp_predictions)
            loss += loss_consistency * args.consistency_weight

        if args.psnr_weight > 0:
            loss_psnr = 0.0
            for pred in disp_predictions:
                # predictions might have 2 channels
                loss_psnr += psnr_criterion(
                    pred * valid_disp_mask.unsqueeze(1),
                    disp_mask * valid_disp_mask.unsqueeze(1),
                ).mean()  # mean the psnr loss over the batch
            loss += loss_psnr / len(disp_predictions) * args.psnr_weight

        if args.photometric_weight > 0:
            loss_photometric = 0.0
            for pred in disp_predictions:
                # predictions might have 1 channel, therefore we need to impute 0s for the second channel
                if model_out_channels == 1:
                    pred = torch.cat([pred, torch.zeros_like(pred)], dim=1)

                loss_photometric += photometric_criterion(
                    image_left, image_right, pred, valid_disp_mask
                )  # photometric loss already comes out meaned over the batch
            loss += loss_photometric / len(disp_predictions) * args.photometric_weight

        if args.smoothness_weight > 0:
            loss_smoothness = 0.0
            for pred in disp_predictions:
                # predictions might have 2 channels
                loss_smoothness += smoothness_criterion(
                    image_left, pred[:, :1, :, :]
                ).mean()  # mean the smoothness loss over the batch
            loss += loss_smoothness / len(disp_predictions) * args.smoothness_weight

        with torch.no_grad():
            metrics, _ = utils.compute_metrics(
                disp_predictions[-1][:, :1, :, :],  # predictions might have 2 channels
                disp_mask[:, :1, :, :],  # so does the ground truth
                valid_disp_mask,
                args.metrics,
            )

        metrics.pop("fl-all", None)
        logger.update(loss=loss, **metrics)

        if scaler is not None:
            scaler.scale(loss).backward()
            # unscale before clipping so the norm threshold applies to the true gradients
            scaler.unscale_(optimizer)
            if args.clip_grad_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if args.clip_grad_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)
            optimizer.step()

        scheduler.step()

        if not dist.is_initialized() or dist.get_rank() == 0:
            if writer is not None and step % args.tensorboard_log_frequency == 0:
                # log the loss and metrics to tensorboard
                writer.add_scalar("loss", loss, step)
                for name, value in logger.meters.items():
                    writer.add_scalar(name, value.avg, step)
                # log the images to tensorboard
                pred_grid = visualization.make_training_sample_grid(
                    image_left, image_right, disp_mask, valid_disp_mask, disp_predictions
                )
                writer.add_image("predictions", pred_grid, step, dataformats="HWC")

                # second thing we want to see is how relevant the iterative refinement is
                pred_sequence_grid = visualization.make_disparity_sequence_grid(disp_predictions, disp_mask)
                writer.add_image("sequence", pred_sequence_grid, step, dataformats="HWC")

        if step % args.save_frequency == 0:
            if not args.distributed or args.rank == 0:
                _save_checkpoint(model, optimizer, scheduler, step, args)

        if step % args.valid_frequency == 0:
            evaluate(model, val_loaders, args, writer, step)
            # evaluation switched the model to eval mode; restore the training configuration
            model.train()
            if args.freeze_batch_norm:
                if isinstance(model, nn.parallel.DistributedDataParallel):
                    freeze_batch_norm(model.module)
                else:
                    freeze_batch_norm(model)

    # one final save at the end
    if not args.distributed or args.rank == 0:
        _save_checkpoint(model, optimizer, scheduler, step, args)


def main(args):
    """Entry point: set up DDP, build model / data / optimizer / lr schedule, then train or evaluate.

    When ``args.resume_path`` points to a training checkpoint (a dict with a ``"model"`` key), the
    model weights are restored. If ``args.resume_schedule`` is also set, the optimizer and scheduler
    state are restored and training continues right after the checkpointed step. Otherwise the file
    is treated as a bare model state dict for fine-tuning.
    """
    args.total_iterations = sum(args.dataset_steps)

    # initialize DDP setting
    utils.setup_ddp(args)
    print(args)

    args.test_only = args.train_datasets is None

    # set the appropriate devices
    if args.distributed and args.device == "cpu":
        raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
    device = torch.device(args.device)

    # select model architecture
    model = torchvision.prototype.models.depth.stereo.__dict__[args.model](weights=args.weights)

    # convert to DDP if need be
    if args.distributed:
        model = model.to(args.gpu)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    else:
        model.to(device)
        model_without_ddp = model

    os.makedirs(args.checkpoint_dir, exist_ok=True)

    val_loaders = {name: make_eval_loader(name, args) for name in args.test_datasets}

    # EVAL ONLY configurations
    if args.test_only:
        evaluate(model, val_loaders, args)
        return

    # Sanity check for the parameter count
    print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    # Compose the training dataset
    train_dataset = get_train_dataset(args.dataset_root, args)

    # with --steps-is-epochs the dataset_steps are epoch counts, not optimizer steps, so the step
    # budget has to be derived from the expanded dataset (same samples-per-step as get_train_dataset)
    if args.steps_is_epochs:
        args.total_iterations = len(train_dataset) // (args.batch_size * args.world_size)

    # initialize the optimizer
    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9)
    else:
        raise ValueError(f"Unknown optimizer {args.optimizer}. Please choose between adam and sgd")

    # initialize the learning rate schedule
    scheduler = make_lr_schedule(args, optimizer)

    # load them from checkpoint if needed
    args.start_step = 0
    if args.resume_path is not None:
        checkpoint = torch.load(args.resume_path, map_location="cpu", weights_only=True)
        if "model" in checkpoint:
            # this means the user requested to resume from a training checkpoint
            model_without_ddp.load_state_dict(checkpoint["model"])
            # this means the user wants to continue training from where it was left off
            if args.resume_schedule:
                optimizer.load_state_dict(checkpoint["optimizer"])
                scheduler.load_state_dict(checkpoint["scheduler"])
                # checkpoints are written after `step` steps have completed and `run` starts at
                # `start_step + 1`, so this resumes exactly at the next, not-yet-run step
                args.start_step = checkpoint["step"]
                # skip the samples already consumed by the completed steps; datasets do not
                # support slicing, so take a contiguous Subset instead
                sample_start_step = args.start_step * args.batch_size * args.world_size
                train_dataset = torch.utils.data.Subset(
                    train_dataset, range(sample_start_step, len(train_dataset))
                )

        else:
            # this means the user wants to finetune on top of a model state dict
            # and that no other changes are required
            model_without_ddp.load_state_dict(checkpoint)

    torch.backends.cudnn.benchmark = True

    # enable training mode
    model.train()
    if args.freeze_batch_norm:
        freeze_batch_norm(model_without_ddp)

    # put dataloader on top of the dataset
    # make sure to disable shuffling since the dataset is already shuffled
    # in order to guarantee quasi randomness whilst retaining a deterministic
    # dataset consumption order
    if args.distributed:
        # the train dataset is preshuffled in order to respect the iteration order
        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=False, drop_last=True)
    else:
        # the train dataset is already shuffled, so we can use a simple SequentialSampler
        sampler = torch.utils.data.SequentialSampler(train_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.workers,
    )

    # initialize the logger
    if args.tensorboard_summaries:
        from torch.utils.tensorboard import SummaryWriter

        tensorboard_path = Path(args.checkpoint_dir) / "tensorboard"
        os.makedirs(tensorboard_path, exist_ok=True)

        tensorboard_run = tensorboard_path / f"{args.name}"
        writer = SummaryWriter(tensorboard_run)
    else:
        writer = None

    logger = utils.MetricLogger(delimiter="  ")

    scaler = torch.cuda.amp.GradScaler() if args.mixed_precision else None
    # run the training loop
    # this will perform optimization, respectively logging and saving checkpoints
    # when need be
    run(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        train_loader=train_loader,
        val_loaders=val_loaders,
        logger=logger,
        writer=writer,
        scaler=scaler,
        args=args,
    )


def get_args_parser(add_help=True):
    """Build the command-line parser for stereo-matching training and evaluation."""

    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would turn every non-empty string (including "False") into True.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description="PyTorch Stereo Matching Training", add_help=add_help)
    # checkpointing
    parser.add_argument("--name", default="crestereo", help="name of the experiment")
    parser.add_argument("--resume", type=str, default=None, help="from which checkpoint to resume")
    parser.add_argument("--checkpoint-dir", type=str, default="checkpoints", help="path to the checkpoint directory")

    # dataset
    parser.add_argument("--dataset-root", type=str, default="", help="path to the dataset root directory")
    parser.add_argument(
        "--train-datasets",
        type=str,
        nargs="+",
        default=["crestereo"],
        help="dataset(s) to train on",
        choices=list(VALID_DATASETS.keys()),
    )
    parser.add_argument(
        "--dataset-steps", type=int, nargs="+", default=[300_000], help="number of steps for each dataset"
    )
    parser.add_argument(
        "--steps-is-epochs", action="store_true", help="if set, dataset-steps are interpreted as epochs"
    )
    parser.add_argument(
        "--test-datasets",
        type=str,
        nargs="+",
        default=["middlebury2014-train"],
        help="dataset(s) to test on",
        choices=["middlebury2014-train"],
    )
    parser.add_argument("--dataset-shuffle", type=_str2bool, help="shuffle the dataset", default=True)
    parser.add_argument("--dataset-order-shuffle", type=_str2bool, help="shuffle the dataset order", default=True)
    parser.add_argument("--batch-size", type=int, default=2, help="batch size per GPU")
    parser.add_argument("--workers", type=int, default=4, help="number of workers per GPU")
    parser.add_argument(
        "--threads",
        type=int,
        default=16,
        help="number of CPU threads per GPU. This can be changed around to speed-up transforms if needed. This can lead to worker thread contention so use with care.",
    )

    # model architecture
    parser.add_argument(
        "--model",
        type=str,
        default="crestereo_base",
        help="model architecture",
        choices=["crestereo_base", "raft_stereo"],
    )
    parser.add_argument("--recurrent-updates", type=int, default=10, help="number of recurrent updates")
    parser.add_argument("--freeze-batch-norm", action="store_true", help="freeze batch norm parameters")

    # loss parameters
    parser.add_argument("--gamma", type=float, default=0.8, help="gamma parameter for the flow sequence loss")
    parser.add_argument("--flow-loss-weight", type=float, default=1.0, help="weight for the flow loss")
    parser.add_argument(
        "--flow-loss-exclude-large",
        action="store_true",
        help="exclude large flow values from the loss. A large value is defined as a value greater than the ground truth flow norm",
        default=False,
    )
    parser.add_argument("--consistency-weight", type=float, default=0.0, help="consistency loss weight")
    parser.add_argument(
        "--consistency-resize-factor",
        type=float,
        default=0.25,
        help="consistency loss resize factor to account for the fact that the flow is computed on a downsampled image",
    )
    parser.add_argument("--psnr-weight", type=float, default=0.0, help="psnr loss weight")
    parser.add_argument("--smoothness-weight", type=float, default=0.0, help="smoothness loss weight")
    parser.add_argument("--photometric-weight", type=float, default=0.0, help="photometric loss weight")
    parser.add_argument(
        "--photometric-max-displacement-ratio",
        type=float,
        default=0.15,
        help="Only pixels with a displacement smaller than this ratio of the image width will be considered for the photometric loss",
    )
    parser.add_argument("--photometric-ssim-weight", type=float, default=0.85, help="photometric ssim loss weight")

    # transforms parameters
    parser.add_argument("--gpu-transforms", action="store_true", help="use GPU transforms")
    parser.add_argument(
        "--eval-size", type=int, nargs="+", default=[384, 512], help="size of the images for evaluation"
    )
    parser.add_argument("--resize-size", type=int, nargs=2, default=None, help="resize size")
    parser.add_argument("--crop-size", type=int, nargs=2, default=[384, 512], help="crop size")
    parser.add_argument("--scale-range", type=float, nargs=2, default=[0.6, 1.0], help="random scale range")
    parser.add_argument("--rescale-prob", type=float, default=1.0, help="probability of resizing the image")
    parser.add_argument(
        "--scaling-type", type=str, default="linear", help="scaling type", choices=["exponential", "linear"]
    )
    parser.add_argument("--flip-prob", type=float, default=0.5, help="probability of flipping the image")
    parser.add_argument(
        "--norm-mean", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="mean for image normalization"
    )
    parser.add_argument(
        "--norm-std", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="std for image normalization"
    )
    parser.add_argument(
        "--use-grayscale", action="store_true", help="use grayscale images instead of RGB", default=False
    )
    parser.add_argument("--max-disparity", type=float, default=None, help="maximum disparity")
    parser.add_argument(
        "--interpolation-strategy",
        type=str,
        default="bilinear",
        help="interpolation strategy",
        choices=["bilinear", "bicubic", "mixed"],
    )
    parser.add_argument("--spatial-shift-prob", type=float, default=1.0, help="probability of shifting the image")
    parser.add_argument(
        "--spatial-shift-max-angle", type=float, default=0.1, help="maximum angle for the spatial shift"
    )
    parser.add_argument(
        "--spatial-shift-max-displacement", type=float, default=2.0, help="maximum displacement for the spatial shift"
    )
    parser.add_argument("--gamma-range", type=float, nargs="+", default=[0.8, 1.2], help="range for gamma correction")
    parser.add_argument(
        "--brightness-range", type=float, nargs="+", default=[0.8, 1.2], help="range for brightness correction"
    )
    parser.add_argument(
        "--contrast-range", type=float, nargs="+", default=[0.8, 1.2], help="range for contrast correction"
    )
    parser.add_argument(
        "--saturation-range", type=float, nargs="+", default=0.0, help="range for saturation correction"
    )
    parser.add_argument("--hue-range", type=float, nargs="+", default=0.0, help="range for hue correction")
    parser.add_argument(
        "--asymmetric-jitter-prob",
        type=float,
        default=1.0,
        help="probability of using asymmetric jitter instead of symmetric jitter",
    )
    parser.add_argument("--occlusion-prob", type=float, default=0.5, help="probability of occluding the rightimage")
    parser.add_argument(
        "--occlusion-px-range", type=int, nargs="+", default=[50, 100], help="range for the number of occluded pixels"
    )
    parser.add_argument("--erase-prob", type=float, default=0.0, help="probability of erasing in both images")
    parser.add_argument(
        "--erase-px-range", type=int, nargs="+", default=[50, 100], help="range for the number of erased pixels"
    )
    parser.add_argument(
        "--erase-num-repeats", type=int, default=1, help="number of times to repeat the erase operation"
    )

    # optimizer parameters
    parser.add_argument("--optimizer", type=str, default="adam", help="optimizer", choices=["adam", "sgd"])
    parser.add_argument("--lr", type=float, default=4e-4, help="learning rate")
    parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay")
    parser.add_argument("--clip-grad-norm", type=float, default=0.0, help="clip grad norm")

    # lr_scheduler parameters
    parser.add_argument("--min-lr", type=float, default=2e-5, help="minimum learning rate")
    parser.add_argument("--warmup-steps", type=int, default=6_000, help="number of warmup steps")
    parser.add_argument(
        "--decay-after-steps", type=int, default=180_000, help="number of steps after which to start decay the lr"
    )
    # must match the warmup methods implemented by make_lr_schedule
    parser.add_argument(
        "--lr-warmup-method", type=str, default="linear", help="warmup method", choices=["linear", "constant"]
    )
    parser.add_argument("--lr-warmup-factor", type=float, default=0.02, help="warmup factor for the learning rate")
    parser.add_argument(
        "--lr-decay-method",
        type=str,
        default="linear",
        help="decay method",
        choices=["linear", "cosine", "exponential"],
    )
    parser.add_argument("--lr-decay-gamma", type=float, default=0.8, help="decay factor for the learning rate")

    # deterministic behaviour
    parser.add_argument("--seed", type=int, default=42, help="seed for random number generators")

    # mixed precision training
    parser.add_argument("--mixed-precision", action="store_true", help="use mixed precision training")

    # logging
    parser.add_argument("--tensorboard-summaries", action="store_true", help="log to tensorboard")
    parser.add_argument("--tensorboard-log-frequency", type=int, default=100, help="log frequency")
    parser.add_argument("--save-frequency", type=int, default=1_000, help="save frequency")
    parser.add_argument("--valid-frequency", type=int, default=1_000, help="validation frequency")
    parser.add_argument(
        "--metrics",
        type=str,
        nargs="+",
        default=["mae", "rmse", "1px", "3px", "5px", "relepe"],
        help="metrics to log",
        choices=AVAILABLE_METRICS,
    )

    # distributed parameters
    parser.add_argument("--world-size", type=int, default=8, help="number of distributed processes")
    parser.add_argument("--dist-url", type=str, default="env://", help="url used to set up distributed training")
    parser.add_argument("--device", type=str, default="cuda", help="device to use for training")

    # weights API
    parser.add_argument("--weights", type=str, default=None, help="weights API url")
    parser.add_argument(
        "--resume-path", type=str, default=None, help="a path from which to resume or start fine-tuning"
    )
    parser.add_argument("--resume-schedule", action="store_true", help="resume optimizer state")

    # padder parameters
    parser.add_argument("--padder-type", type=str, default="kitti", help="padder type", choices=["kitti", "sintel"])
    return parser


if __name__ == "__main__":
    # parse the command line and hand the resulting namespace to the training / evaluation entry point
    cli_parser = get_args_parser()
    args = cli_parser.parse_args()
    main(args)
