# Delete old branches
import os
import re
from datetime import datetime, timezone
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable
from urllib.parse import quote

from github_utils import gh_fetch_json_dict, gh_graphql
from gitutils import GitRepo


SEC_IN_DAY = 24 * 60 * 60
# A branch is kept while its PR, or the branch itself, was updated within this span.
CLOSED_PR_RETENTION = 30 * SEC_IN_DAY
# A branch with no PR among the fetched ones is kept while it was updated within this span.
NO_PR_RETENTION = 1.5 * 365 * SEC_IN_DAY
# How far back (by last update time) PRs are fetched when matching them to branches.
PR_WINDOW = 90 * SEC_IN_DAY  # Set to None to look at all PRs (may take a lot of tokens)
REPO_OWNER = "pytorch"
REPO_NAME = "pytorch"
# Running count of API requests and ref deletions.  Held in a one-element list
# so functions can bump it in place without a `global` statement.
ESTIMATED_TOKENS = [0]

# Use .get() so an unset variable reaches the explicit check below instead of
# dying earlier with a bare KeyError; an empty value is rejected the same way.
TOKEN = os.environ.get("GITHUB_TOKEN", "")
if not TOKEN:
    raise Exception("GITHUB_TOKEN is not set")  # noqa: TRY002

# Two directory levels above the directory containing this file; passed to
# GitRepo as the path of the local checkout.
REPO_ROOT = Path(__file__).parents[2]

# Query for all PRs instead of just closed/merged because it's faster
# Pages through every PR, 100 per page, most recently updated first, so a
# caller can stop paginating once it reaches PRs that are old enough.
GRAPHQL_ALL_PRS_BY_UPDATED_AT = """
query ($owner: String!, $repo: String!, $cursor: String) {
  repository(owner: $owner, name: $repo) {
    pullRequests(
      first: 100
      after: $cursor
      orderBy: {field: UPDATED_AT, direction: DESC}
    ) {
      totalCount
      pageInfo {
        hasNextPage
        endCursor
      }
      nodes {
        headRefName
        number
        updatedAt
        state
      }
    }
  }
}
"""

# Pages through PRs in the OPEN state only, 100 per page; no ordering is
# requested.  Returns the same node fields as the query above.
GRAPHQL_OPEN_PRS = """
query ($owner: String!, $repo: String!, $cursor: String) {
  repository(owner: $owner, name: $repo) {
    pullRequests(
      first: 100
      after: $cursor
      states: [OPEN]
    ) {
      totalCount
      pageInfo {
        hasNextPage
        endCursor
      }
      nodes {
        headRefName
        number
        updatedAt
        state
      }
    }
  }
}
"""

# Pages through PRs carrying the "no-delete-branch" label, 100 per page.  No
# state filter is applied, so PRs in any state are returned.
GRAPHQL_NO_DELETE_BRANCH_LABEL = """
query ($owner: String!, $repo: String!, $cursor: String) {
  repository(owner: $owner, name: $repo) {
    label(name: "no-delete-branch") {
      pullRequests(first: 100, after: $cursor) {
        totalCount
        pageInfo {
          hasNextPage
          endCursor
        }
        nodes {
          headRefName
          number
          updatedAt
          state
        }
      }
    }
  }
}
"""


def is_protected(branch: str) -> bool:
    """Return whether GitHub reports ``branch`` as protected.

    The lookup fails closed: if anything goes wrong while fetching or reading
    the response, the failure is printed and the branch is reported as
    protected, so callers that skip protected branches leave it alone.

    Args:
        branch: Branch name relative to the remote, e.g. ``gh/user/1/head``.

    Returns:
        ``True`` if the API marks the branch as protected or the lookup
        failed, ``False`` otherwise.
    """
    try:
        ESTIMATED_TOKENS[0] += 1
        # Percent-encode the name (leaving "/" literal) so characters that are
        # legal in git refs but special in URLs, such as "#" or "%", cannot
        # truncate or corrupt the request path.
        res = gh_fetch_json_dict(
            f"https://api.github.com/repos/{REPO_OWNER}/{REPO_NAME}/branches/{quote(branch)}"
        )
        return bool(res["protected"])
    except Exception as e:
        print(f"[{branch}] Failed to fetch branch protections: {e}")
        return True


def convert_gh_timestamp(date: str) -> float:
    """Convert a ``YYYY-MM-DDTHH:MM:SSZ`` timestamp string to POSIX seconds.

    Args:
        date: Timestamp in exactly the ``%Y-%m-%dT%H:%M:%SZ`` format.

    Returns:
        Seconds since the Unix epoch as a float.

    Raises:
        ValueError: If ``date`` does not match the expected format.
    """
    parsed = datetime.strptime(date, "%Y-%m-%dT%H:%M:%SZ")
    # The trailing "Z" means UTC, but strptime yields a naive datetime that
    # .timestamp() would interpret in the machine's local zone.  Attach UTC
    # explicitly so the result does not depend on where the script runs.
    return parsed.replace(tzinfo=timezone.utc).timestamp()


def get_branches(repo: GitRepo) -> dict[str, Any]:
    """Group the remote-tracking refs under ``origin`` by branch base name.

    Branches matching ``gh/<...>/(head|base|orig)`` share the base name
    ``gh/<...>``; every other branch is its own base name.

    Args:
        repo: Repository wrapper used to run ``git for-each-ref``.

    Returns:
        Mapping of base name to ``[latest_timestamp, [branch, ...]]``, where
        the timestamp is the newest committer date (POSIX seconds) among the
        grouped branches.
    """
    ref_listing = repo._run_git(
        "for-each-ref",
        "--sort=creatordate",
        "--format=%(refname) %(committerdate:iso-strict)",
        "refs/remotes/origin",
    )
    grouped: dict[str, Any] = {}
    for ref_line in ref_listing.splitlines():
        ref_name, iso_date = ref_line.split(" ")
        origin_match = re.match(r"refs/remotes/origin/(.*)", ref_name)
        assert origin_match
        branch = origin_match.group(1)
        gh_match = re.match(r"(gh\/.+)\/(head|base|orig)", branch)
        base_name = gh_match.group(1) if gh_match else branch
        commit_ts = datetime.fromisoformat(iso_date).timestamp()

        entry = grouped.get(base_name)
        if entry is None:
            grouped[base_name] = [commit_ts, [branch]]
            continue
        entry[1].append(branch)
        # Keep only the newest timestamp seen for this group.
        if commit_ts > entry[0]:
            entry[0] = commit_ts
    return grouped


def paginate_graphql(
    query: str,
    kwargs: dict[str, Any],
    termination_func: Callable[[list[dict[str, Any]]], bool],
    get_data: Callable[[dict[str, Any]], list[dict[str, Any]]],
    get_page_info: Callable[[dict[str, Any]], dict[str, Any]],
) -> list[Any]:
    """Run a cursor-paginated GraphQL query and collect the rows of every page.

    Args:
        query: GraphQL document that accepts a ``cursor`` variable.
        kwargs: Extra query variables forwarded to ``gh_graphql``.
        termination_func: Called with everything collected so far after each
            page; returning a truthy value stops pagination early.
        get_data: Extracts the list of rows from one response.
        get_page_info: Extracts the mapping holding ``hasNextPage`` and
            ``endCursor`` from one response.

    Returns:
        The rows of all fetched pages, in fetch order.
    """
    collected: list[dict[str, Any]] = []
    cursor = None
    # At least one page is always requested; each request counts as one token.
    while True:
        ESTIMATED_TOKENS[0] += 1
        response = gh_graphql(query, cursor=cursor, **kwargs)
        collected.extend(get_data(response))
        more_pages = get_page_info(response)["hasNextPage"]
        cursor = get_page_info(response)["endCursor"]
        if termination_func(collected) or not more_pages:
            break
    return collected


def get_recent_prs() -> dict[str, Any]:
    """Fetch recently updated PRs and index the newest one per branch base name.

    PRs are requested most recently updated first, and pagination stops once
    the oldest PR fetched so far lies outside ``PR_WINDOW`` (never, when
    ``PR_WINDOW`` is ``None``).  Head branches matching
    ``gh/<...>/(head|base|orig)`` are grouped under ``gh/<...>``.

    Returns:
        Mapping of branch base name to the PR node with the latest
        ``updatedAt``.  ``updatedAt`` in each node is converted in place from
        the API string to POSIX seconds.
    """
    now = datetime.now().timestamp()

    def past_window(data: list[dict[str, Any]]) -> bool:
        # Nothing to compare against when the window is disabled or no PR has
        # been returned yet (e.g. a repository without PRs).
        if PR_WINDOW is None or not data:
            return False
        return now - convert_gh_timestamp(data[-1]["updatedAt"]) > PR_WINDOW

    # Grab all PRs updated within the last PR_WINDOW seconds (the final page
    # may also include some older ones)
    pr_infos: list[dict[str, Any]] = paginate_graphql(
        GRAPHQL_ALL_PRS_BY_UPDATED_AT,
        {"owner": REPO_OWNER, "repo": REPO_NAME},
        past_window,
        lambda res: res["data"]["repository"]["pullRequests"]["nodes"],
        lambda res: res["data"]["repository"]["pullRequests"]["pageInfo"],
    )

    # Get the most recent PR for each branch base (group gh together)
    prs_by_branch_base: dict[str, Any] = {}
    for pr in pr_infos:
        pr["updatedAt"] = convert_gh_timestamp(pr["updatedAt"])
        branch_base_name = pr["headRefName"]
        if x := re.match(r"(gh\/.+)\/(head|base|orig)", branch_base_name):
            branch_base_name = x.group(1)
        current = prs_by_branch_base.get(branch_base_name)
        if current is None or pr["updatedAt"] > current["updatedAt"]:
            prs_by_branch_base[branch_base_name] = pr
    return prs_by_branch_base


@lru_cache(maxsize=1)
def get_open_prs() -> list[dict[str, Any]]:
    """Return every open PR of ``REPO_OWNER/REPO_NAME``.

    The result is cached, so the API is paged through only once per process
    and every caller receives the same list object; callers should treat it
    as read-only.

    Returns:
        PR nodes with ``headRefName``, ``number``, ``updatedAt`` and ``state``.
    """
    return paginate_graphql(
        GRAPHQL_OPEN_PRS,
        # Use the module-level repository constants rather than repeating the
        # literal names, matching how is_protected builds its URL.
        {"owner": REPO_OWNER, "repo": REPO_NAME},
        lambda data: False,  # never stop early: all open PRs are needed
        lambda res: res["data"]["repository"]["pullRequests"]["nodes"],
        lambda res: res["data"]["repository"]["pullRequests"]["pageInfo"],
    )


def get_branches_with_magic_label_or_open_pr() -> set[str]:
    """Collect base names of branches that must never be deleted.

    A branch qualifies when it is the head of a PR labelled
    ``no-delete-branch`` (in any state) or of a PR that is currently open.

    Returns:
        Set of branch base names; heads matching ``gh/<...>/(head|base|orig)``
        are reduced to ``gh/<...>``.
    """
    pr_infos: list[dict[str, Any]] = paginate_graphql(
        GRAPHQL_NO_DELETE_BRANCH_LABEL,
        {"owner": REPO_OWNER, "repo": REPO_NAME},
        lambda data: False,  # never stop early: every labelled PR matters
        lambda res: res["data"]["repository"]["label"]["pullRequests"]["nodes"],
        lambda res: res["data"]["repository"]["label"]["pullRequests"]["pageInfo"],
    )

    # extend() copies the cached open-PR entries into the local list, so the
    # list cached by get_open_prs itself is left untouched.
    pr_infos.extend(get_open_prs())

    # Reduce each PR head branch to its base name (group gh together)
    branch_bases = set()
    for pr in pr_infos:
        branch_base_name = pr["headRefName"]
        if x := re.match(r"(gh\/.+)\/(head|base|orig)", branch_base_name):
            branch_base_name = x.group(1)
        branch_bases.add(branch_base_name)
    return branch_bases


def delete_branch(repo: GitRepo, branch: str) -> None:
    """Delete the ref ``branch`` on ``origin`` by running ``git push origin -d``.

    Args:
        repo: Repository wrapper used to run git.
        branch: Ref to delete; this file passes branch names and fully
            qualified ``refs/tags/...`` names.
    """
    push_args = ("push", "origin", "-d", branch)
    repo._run_git(*push_args)


def delete_branches() -> None:
    """Delete stale remote branches, treating each branch group as a unit.

    A group is kept when any of the following holds:
      * it has an open PR or a PR carrying the magic label;
      * it has a fetched PR, and either the PR or the branch was updated
        within CLOSED_PR_RETENTION;
      * it has no fetched PR and the branch was updated within
        NO_PR_RETENTION;
      * any of its branches is protected (or its protection lookup failed).

    Which PRs count as "fetched" depends on PR_WINDOW, so changing it changes
    how branches with older closed PRs are treated.  Examination of further
    groups stops once the estimated token budget is exceeded; everything
    selected up to that point is still deleted.
    """
    now = datetime.now().timestamp()
    git_repo = GitRepo(str(REPO_ROOT), "origin", debug=True)
    branches = get_branches(git_repo)
    prs_by_branch = get_recent_prs()
    keep_branches = get_branches_with_magic_label_or_open_pr()

    delete = []
    for base_branch, (date, sub_branches) in branches.items():
        branch_age = now - date
        print(f"[{base_branch}] Updated {(now - date) / SEC_IN_DAY} days ago")
        if base_branch in keep_branches:
            print(f"[{base_branch}] Has magic label or open PR, skipping")
            continue

        pr = prs_by_branch.get(base_branch)
        if not pr:
            # No PR among the fetched ones: only the branch's own age counts.
            if branch_age < NO_PR_RETENTION:
                continue
        else:
            print(
                f"[{base_branch}] Has PR {pr['number']}: {pr['state']}, updated {(now - pr['updatedAt']) / SEC_IN_DAY} days ago"
            )
            pr_age = now - pr["updatedAt"]
            if pr_age < CLOSED_PR_RETENTION or branch_age < CLOSED_PR_RETENTION:
                continue

        print(f"[{base_branch}] Checking for branch protections")
        # Stops at the first protected branch, so no more lookups are made
        # (and no more tokens spent) than needed.
        if any(map(is_protected, sub_branches)):
            print(f"[{base_branch}] Is protected")
            continue

        for sub_branch in sub_branches:
            print(f"[{base_branch}] Deleting {sub_branch}")
        delete.extend(sub_branches)

        # The budget is only checked after a group has been queued for deletion.
        if ESTIMATED_TOKENS[0] <= 400:
            continue
        print("Estimated tokens exceeded, exiting")
        break

    print(f"To delete ({len(delete)}):")
    for branch in delete:
        print(f"About to delete branch {branch}")
        delete_branch(git_repo, branch)


def delete_old_ciflow_tags() -> None:
    """Delete ``ciflow/`` tags that no longer belong to an open PR.

    Two tag shapes are handled:
      * ``ciflow/<...>/<5-6 digit PR number>`` is deleted unless that PR is
        currently open;
      * ``ciflow/<...>/<40 hex digit sha>`` is always deleted.
    Any other tag, including ciflow tags whose number has fewer than 5 or
    more than 6 digits, is left alone.

    Lightweight tags don't have information about the date they were created,
    so we can't check how old they are.  The script just assumes that ciflow
    tags should be deleted regardless of creation date.

    Each deletion counts against the shared token estimate, and the loop
    stops once that estimate exceeds 400.  A failure on one tag is printed
    and does not stop the remaining tags from being processed.
    """
    git_repo = GitRepo(str(REPO_ROOT), "origin", debug=True)

    def delete_tag(tag: str) -> None:
        print(f"Deleting tag {tag}")
        ESTIMATED_TOKENS[0] += 1
        delete_branch(git_repo, f"refs/tags/{tag}")

    tags = git_repo._run_git("tag").splitlines()
    # A set gives constant-time membership tests; with a list every tag would
    # trigger a linear scan over all open PRs.
    open_pr_numbers = {x["number"] for x in get_open_prs()}

    # Compiled once up front instead of being looked up for every tag.
    pr_tag_pattern = re.compile(r"^ciflow\/.*\/(\d{5,6})$")
    sha_tag_pattern = re.compile(r"^ciflow\/.*\/([0-9a-f]{40})$")

    for tag in tags:
        try:
            if ESTIMATED_TOKENS[0] > 400:
                print("Estimated tokens exceeded, exiting")
                break
            if not tag.startswith("ciflow/"):
                continue
            if re_match_pr := pr_tag_pattern.match(tag):
                pr_number = int(re_match_pr.group(1))
                if pr_number in open_pr_numbers:
                    continue
                delete_tag(tag)
            elif sha_tag_pattern.match(tag):
                delete_tag(tag)
        except Exception as e:
            print(f"Failed to check tag {tag}: {e}")


if __name__ == "__main__":
    # Branch cleanup runs first.  Both steps draw on the same ESTIMATED_TOKENS
    # counter, so the tag cleanup stops immediately if the branch cleanup has
    # already pushed the estimate past the limit.
    delete_branches()
    delete_old_ciflow_tags()
