#!/usr/bin/env python3

# NB: the following functions are used in Meta-internal workflows
# (github_first_try_merge/my_handler.py) and thus have functionality limitations
# (no `git` command access, no network access besides the strict allow list):
#
# find_matching_merge_rule
# read_merge_rules
#
# Also any signature changes of these functions, as well as changes to the `GitHubPR`
# class, will likely require corresponding changes for the internal workflows.

import base64
import json
import os
import re
import time
import urllib.parse
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from functools import cache
from pathlib import Path
from re import Pattern
from typing import Any, Callable, cast, NamedTuple, Optional
from warnings import warn

import yaml
from github_utils import (
    gh_close_pr,
    gh_fetch_json_list,
    gh_fetch_merge_base,
    gh_fetch_url,
    gh_graphql,
    gh_post_commit_comment,
    gh_post_pr_comment,
    gh_update_pr_state,
    GitHubComment,
)
from gitutils import (
    are_ghstack_branches_in_sync,
    get_git_remote_name,
    get_git_repo_dir,
    GitRepo,
    patterns_to_regex,
    retries_decorator,
)
from label_utils import (
    gh_add_labels,
    gh_remove_label,
    has_required_labels,
    LABEL_ERR_MSG,
)
from trymerge_explainer import get_revert_message, TryMergeExplainer


# PR labels; by their names, one marks a merge in progress and the other a
# completed merge.
MERGE_IN_PROGRESS_LABEL, MERGE_COMPLETE_LABEL = "merging", "merged"


class JobCheckState(NamedTuple):
    """State of a single CI check: a check run, a job-less workflow, or a legacy
    commit status (all three are built into this shape in this module)."""

    # Full check name; check runs of a workflow are prefixed "<workflow> / ".
    name: str
    # detailsUrl of a check run, workflow-run URL, or status targetUrl.
    url: str
    # Conclusion/state string such as "SUCCESS"; may be None (e.g. no
    # conclusion reported yet -- TODO confirm).
    status: Optional[str]
    # Always None where constructed in this chunk; NOTE(review): presumably
    # filled in by failure classification elsewhere.
    classification: Optional[str]
    # Check-run databaseId; None for workflow and legacy status entries.
    job_id: Optional[int]
    # Check-run output title; None for workflow and legacy status entries.
    title: Optional[str]
    # Check-run output summary; None for workflow and legacy status entries.
    summary: Optional[str]


# Maps a check name to its state.
JobNameToStateDict = dict[str, JobCheckState]


class WorkflowCheckState:
    """Aggregated state of one workflow run and the check-run jobs it contains.

    Attributes:
        name: workflow name.
        url: URL of the workflow run.
        run_id: database ID of the run; add_workflow_conclusions keeps the run
            with the largest ID for each workflow.
        status: workflow conclusion, possibly None.
        jobs: job name -> JobCheckState for the run's check runs.
    """

    def __init__(self, name: str, url: str, run_id: int, status: Optional[str]):
        self.name: str = name
        self.url: str = url
        self.run_id: int = run_id
        self.status: Optional[str] = status
        self.jobs: JobNameToStateDict = {}

    def __repr__(self) -> str:
        # Show the job count rather than every job so that logs stay readable.
        return (
            f"{type(self).__name__}(name={self.name!r}, url={self.url!r}, "
            f"run_id={self.run_id!r}, status={self.status!r}, jobs={len(self.jobs)})"
        )


# Fields fetched for each PR review. Review nodes carry every field an issue
# comment has (plus `state`), so they can be parsed like comments. Reviews are
# paginated backwards via startCursor / hasPreviousPage.
GH_PR_REVIEWS_FRAGMENT = """
fragment PRReviews on PullRequestReviewConnection {
  nodes {
    author {
      login
    }
    bodyText
    createdAt
    authorAssociation
    editor {
      login
    }
    databaseId
    url
    state
  }
  pageInfo {
    startCursor
    hasPreviousPage
  }
}
"""

# Check suites of a commit. Each suite includes its app, its workflow run (null
# for suites not backed by a workflow) and the first 50 check runs. Each edge
# exposes a `cursor` so that a specific suite, or the next page of suites, can
# be re-queried later.
GH_CHECKSUITES_FRAGMENT = """
fragment PRCheckSuites on CheckSuiteConnection {
  edges {
    node {
      app {
        name
        databaseId
      }
      workflowRun {
        workflow {
          name
          databaseId
        }
        databaseId
        url
      }
      checkRuns(first: 50) {
        nodes {
          name
          conclusion
          detailsUrl
          databaseId
          title
          summary
        }
        pageInfo {
          endCursor
          hasNextPage
        }
      }
      conclusion
    }
    cursor
  }
  pageInfo {
    hasNextPage
  }
}
"""

# Up to two authors per PR commit. `user` is null when the author is not a
# GitHub user, so name/email are fetched as well. Commits page forwards via
# endCursor / hasNextPage.
GH_COMMIT_AUTHORS_FRAGMENT = """
fragment CommitAuthors on PullRequestCommitConnection {
  nodes {
    commit {
      authors(first: 2) {
        nodes {
          user {
            login
          }
          email
          name
        }
      }
      oid
    }
  }
  pageInfo {
    endCursor
    hasNextPage
  }
}
"""

# Main PR query (used by gh_get_pr_info). Several connections are only
# partially fetched here and paginated by the follow-up queries below:
# commits_with_authors (first 100), checkSuites of the last commit (first 10),
# files (first 100), reviews (last 100) and comments (last 5).
GH_GET_PR_INFO_QUERY = (
    GH_PR_REVIEWS_FRAGMENT
    + GH_CHECKSUITES_FRAGMENT
    + GH_COMMIT_AUTHORS_FRAGMENT
    + """
query ($owner: String!, $name: String!, $number: Int!) {
  repository(owner: $owner, name: $name) {
    pullRequest(number: $number) {
      closed
      isCrossRepository
      author {
        login
      }
      title
      body
      headRefName
      headRepository {
        nameWithOwner
      }
      baseRefName
      baseRefOid
      baseRepository {
        nameWithOwner
        isPrivate
        defaultBranchRef {
          name
        }
      }
      mergeCommit {
        oid
      }
      commits_with_authors: commits(first: 100) {
        ...CommitAuthors
        totalCount
      }
      commits(last: 1) {
        nodes {
          commit {
            checkSuites(first: 10) {
              ...PRCheckSuites
            }
            status {
              contexts {
                context
                state
                targetUrl
              }
            }
            oid
          }
        }
      }
      changedFiles
      files(first: 100) {
        nodes {
          path
        }
        pageInfo {
          endCursor
          hasNextPage
        }
      }
      reviews(last: 100) {
        ...PRReviews
      }
      comments(last: 5) {
        nodes {
          bodyText
          createdAt
          author {
            login
          }
          authorAssociation
          editor {
            login
          }
          databaseId
          url
        }
        pageInfo {
          startCursor
          hasPreviousPage
        }
      }
      labels(first: 100) {
        edges {
          node {
            name
          }
        }
      }
    }
  }
}
"""
)

# Next page (100 entries) of the PR's changed files, starting after `$cursor`.
GH_GET_PR_NEXT_FILES_QUERY = """
query ($owner: String!, $name: String!, $number: Int!, $cursor: String!) {
  repository(name: $name, owner: $owner) {
    pullRequest(number: $number) {
      files(first: 100, after: $cursor) {
        nodes {
          path
        }
        pageInfo {
          endCursor
          hasNextPage
        }
      }
    }
  }
}
"""

# Next 10 check suites of the PR's last commit after `$cursor`. The commit `oid`
# is fetched so that callers can detect that the head commit changed while
# paginating.
GH_GET_PR_NEXT_CHECKSUITES = (
    GH_CHECKSUITES_FRAGMENT
    + """
query ($owner: String!, $name: String!, $number: Int!, $cursor: String!) {
  repository(name: $name, owner: $owner) {
    pullRequest(number: $number) {
      commits(last: 1) {
        nodes {
          commit {
            oid
            checkSuites(first: 10, after: $cursor) {
              ...PRCheckSuites
            }
          }
        }
      }
    }
  }
}
"""
)

# Next 100 check runs of a single check suite on the PR's last commit.
# `$cs_cursor` is the cursor of the edge *before* the target suite (null for
# the first suite), so `checkSuites(first: 1, after: $cs_cursor)` selects exactly
# that suite; `$cr_cursor` is where the check-run page resumes.
GH_GET_PR_NEXT_CHECK_RUNS = """
query ($owner: String!, $name: String!, $number: Int!, $cs_cursor: String, $cr_cursor: String!) {
  repository(name: $name, owner: $owner) {
    pullRequest(number: $number) {
      commits(last: 1) {
        nodes {
          commit {
            oid
            checkSuites(first: 1, after: $cs_cursor) {
              nodes {
                checkRuns(first: 100, after: $cr_cursor) {
                  nodes {
                    name
                    conclusion
                    detailsUrl
                    databaseId
                    title
                    summary
                  }
                  pageInfo {
                    endCursor
                    hasNextPage
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}
"""

# Previous (older) page of up to 100 PR comments before `$cursor`. Comments are
# paginated backwards from the most recent ones.
GH_GET_PR_PREV_COMMENTS = """
query ($owner: String!, $name: String!, $number: Int!, $cursor: String!) {
  repository(name: $name, owner: $owner) {
    pullRequest(number: $number) {
      comments(last: 100, before: $cursor) {
        nodes {
          bodyText
          createdAt
          author {
            login
          }
          authorAssociation
          editor {
            login
          }
          databaseId
          url
        }
        pageInfo {
          startCursor
          hasPreviousPage
        }
      }
    }
  }
}
"""

# This query needs read-org permission
# Logins of a team's members, 100 per page; `$cursor` is null for the first page.
GH_GET_TEAM_MEMBERS_QUERY = """
query($org: String!, $name: String!, $cursor: String) {
  organization(login: $org) {
    team(slug: $name) {
      members(first: 100, after: $cursor) {
        nodes {
          login
        }
        pageInfo {
          hasNextPage
          endCursor
        }
      }
    }
  }
}
"""

# Next page of 100 PR commits with their authors, after `$cursor`. The alias
# matches the main query's `commits_with_authors` so the same parser handles
# both responses.
GH_GET_PR_NEXT_AUTHORS_QUERY = (
    GH_COMMIT_AUTHORS_FRAGMENT
    + """
query ($owner: String!, $name: String!, $number: Int!, $cursor: String) {
  repository(name: $name, owner: $owner) {
    pullRequest(number: $number) {
      commits_with_authors: commits(first: 100, after: $cursor) {
        ...CommitAuthors
      }
    }
  }
}
"""
)

# Previous (older) page of up to 100 PR reviews before `$cursor`.
GH_GET_PR_PREV_REVIEWS_QUERY = (
    GH_PR_REVIEWS_FRAGMENT
    + """
query ($owner: String!, $name: String!, $number: Int!, $cursor: String!) {
  repository(name: $name, owner: $owner) {
    pullRequest(number: $number) {
      reviews(last: 100, before: $cursor) {
        ...PRReviews
      }
    }
  }
}
"""
)

# Paths of the repository's first 100 submodules.
# NOTE(review): pageInfo is requested but the query takes no cursor, so
# repositories with more than 100 submodules would be truncated.
GH_GET_REPO_SUBMODULES = """
query ($owner: String!, $name: String!) {
  repository(owner: $owner, name: $name) {
    submodules(first: 100) {
      nodes {
        path
      }
      pageInfo {
        endCursor
        hasNextPage
      }
    }
  }
}
"""

# ghstack head branch, e.g. "gh/<user>/<id>/head"; group 1 captures the
# "gh/<user>/<id>/" prefix.
RE_GHSTACK_HEAD_REF = re.compile(r"^(gh/[^/]+/[0-9]+/)head$")
# ghstack "Stack" section of a PR description: a "Stack...:" header line
# followed by one or more "* ..." bullet lines.
RE_GHSTACK_DESC = re.compile(r"Stack.*:\r?\n(\* [^\r\n]+\r?\n)+", re.MULTILINE)
# "Pull Request resolved: <PR URL>" trailer in commit messages. The dot in
# "github.com" is escaped so that it only matches a literal ".".
RE_PULL_REQUEST_RESOLVED = re.compile(
    r"Pull Request resolved: "
    r"https://github\.com/(?P<owner>[^/]+)/(?P<repo>[^/]+)/pull/(?P<number>[0-9]+)",
    re.MULTILINE,
)
# A whole "cc: @user ..." line (colon optional), including its line break.
RE_PR_CC_LINE = re.compile(r"^cc:? @\w+.*\r?\n?$", re.MULTILINE)
# "Differential Revision: ... D<digits>" line; group 1 is the "D<digits>" id.
RE_DIFF_REV = re.compile(r"^Differential Revision:.+?(D[0-9]+)", re.MULTILINE)
CIFLOW_LABEL = re.compile(r"^ciflow/.+")
CIFLOW_TRUNK_LABEL = re.compile(r"^ciflow/trunk")
MERGE_RULE_PATH = Path(".github") / "merge_rules.yaml"
REMOTE_MAIN_BRANCH = "origin/main"
DRCI_CHECKRUN_NAME = "Dr.CI"
INTERNAL_CHANGES_CHECKRUN_NAME = "Meta Internal-Only Changes Check"
# Title reported by the internal-changes check run when no internal Diff is
# connected to the PR.
HAS_NO_CONNECTED_DIFF_TITLE = (
    "There is no internal Diff connected, this can be merged now"
)
# This could be set to -1 to ignore all flaky and broken trunk failures. On the
# other hand, using a large value like 10 here might be useful in a SEV
# situation.
# NOTE(review): the name misspells "THRESHOLD"; it is left unchanged because
# code outside this chunk may reference it by this name.
IGNORABLE_FAILED_CHECKS_THESHOLD = 10


def gh_get_pr_info(org: str, proj: str, pr_no: int) -> Any:
    """Run GH_GET_PR_INFO_QUERY for ``org/proj#pr_no`` and return the pullRequest node."""
    response = gh_graphql(GH_GET_PR_INFO_QUERY, name=proj, owner=org, number=pr_no)
    repository = response["data"]["repository"]
    return repository["pullRequest"]


@cache
def gh_get_team_members(org: str, name: str) -> list[str]:
    """Return the logins of every member of the GitHub team ``org/name``.

    Members are fetched 100 per page until the connection is exhausted. If the
    team does not exist, a warning is emitted and an empty list is returned.
    Results are memoized per ``(org, name)`` by ``functools.cache``, so repeated
    callers receive the same list object. The query needs read-org permission.
    """
    logins: list[str] = []
    cursor: Optional[str] = None
    while True:
        response = gh_graphql(
            GH_GET_TEAM_MEMBERS_QUERY,
            org=org,
            name=name,
            cursor=cursor,
        )
        team = response["data"]["organization"]["team"]
        if team is None:
            warn(f"Requested non-existing team {org}/{name}")
            return []
        members = team["members"]
        logins.extend(member["login"] for member in members["nodes"])
        page_info = members["pageInfo"]
        if not page_info["hasNextPage"]:
            return logins
        cursor = page_info["endCursor"]


def get_check_run_name_prefix(workflow_run: Any) -> str:
    """Return the "<workflow name> / " prefix for check runs of ``workflow_run``.

    Check suites that are not backed by a workflow run (``None``) get no prefix.
    """
    return "" if workflow_run is None else f"{workflow_run['workflow']['name']} / "


def is_passing_status(status: Optional[str]) -> bool:
    """Return True for SUCCESS, SKIPPED or NEUTRAL (case-insensitive); False otherwise, including None."""
    if status is None:
        return False
    return status.upper() in ("SUCCESS", "SKIPPED", "NEUTRAL")


def add_workflow_conclusions(
    checksuites: Any,
    get_next_checkruns_page: Callable[[list[dict[str, dict[str, Any]]], int, Any], Any],
    get_next_checksuites: Callable[[Any], Any],
) -> JobNameToStateDict:
    """Collapse paginated GitHub check suites into a flat check-name -> state map.

    Args:
        checksuites: first page of a ``PRCheckSuites`` connection (``edges`` and
            ``pageInfo``).
        get_next_checkruns_page: called as ``(edges, edge_idx, checkruns)`` to fetch
            the next page of check runs of ``edges[edge_idx]`` while
            ``checkruns.pageInfo.hasNextPage`` is set.
        get_next_checksuites: called with the current check-suite page to fetch
            the next one while ``pageInfo.hasNextPage`` is set.

    Returns:
        A dict keyed by check-run name, prefixed with "<workflow name> / " when the
        suite belongs to a workflow run. A workflow without any check runs is
        reported under its own name instead.
    """
    # graphql seems to favor the most recent workflow run, so in theory we
    # shouldn't need to account for reruns, but do it just in case

    # workflow ID -> workflow state (including its jobs)
    workflows: dict[str, WorkflowCheckState] = {}

    # for the jobs that don't have a workflow
    no_workflow_obj: WorkflowCheckState = WorkflowCheckState("", "", 0, None)

    def add_conclusions(edges: Any) -> None:
        for edge_idx, edge in enumerate(edges):
            node = edge["node"]
            workflow_run = node["workflowRun"]
            checkruns = node["checkRuns"]

            workflow_obj: WorkflowCheckState = no_workflow_obj

            if workflow_run is not None:
                # This is the usual workflow run ID we see on GitHub
                workflow_run_id = workflow_run["databaseId"]
                # While this is the metadata name and ID of the workflow itself
                workflow_name = workflow_run["workflow"]["name"]
                workflow_id = workflow_run["workflow"]["databaseId"]

                workflow_conclusion = node["conclusion"]
                # Do not override existing status with cancelled. `workflows` is
                # keyed by workflow ID, so that is what must be looked up here.
                if workflow_conclusion == "CANCELLED" and workflow_id in workflows:
                    continue

                # Only keep the latest workflow run for each workflow, heuristically,
                # it's the run with largest run ID
                if (
                    workflow_id not in workflows
                    or workflows[workflow_id].run_id < workflow_run_id
                ):
                    workflows[workflow_id] = WorkflowCheckState(
                        name=workflow_name,
                        status=workflow_conclusion,
                        url=workflow_run["url"],
                        run_id=workflow_run_id,
                    )
                workflow_obj = workflows[workflow_id]

            while checkruns is not None:
                for checkrun_node in checkruns["nodes"]:
                    if not isinstance(checkrun_node, dict):
                        warn(f"Expected dictionary, but got {type(checkrun_node)}")
                        continue
                    checkrun_name = f"{get_check_run_name_prefix(workflow_run)}{checkrun_node['name']}"
                    existing_checkrun = workflow_obj.jobs.get(checkrun_name)
                    # A passing result is never replaced by a later, non-passing
                    # duplicate of the same check run
                    if existing_checkrun is None or not is_passing_status(
                        existing_checkrun.status
                    ):
                        workflow_obj.jobs[checkrun_name] = JobCheckState(
                            checkrun_name,
                            checkrun_node["detailsUrl"],
                            checkrun_node["conclusion"],
                            classification=None,
                            job_id=checkrun_node["databaseId"],
                            title=checkrun_node["title"],
                            summary=checkrun_node["summary"],
                        )

                if bool(checkruns["pageInfo"]["hasNextPage"]):
                    checkruns = get_next_checkruns_page(edges, edge_idx, checkruns)
                else:
                    checkruns = None

    # Gather every check-suite page first so edge indices are global, which is
    # what get_next_checkruns_page relies on to locate a suite
    all_edges = checksuites["edges"].copy()
    while bool(checksuites["pageInfo"]["hasNextPage"]):
        checksuites = get_next_checksuites(checksuites)
        all_edges.extend(checksuites["edges"])

    add_conclusions(all_edges)

    # Flatten the dictionaries.  If there exists jobs in the workflow run, put
    # the jobs in but don't put the workflow in.  We care more about the jobs in
    # the workflow that ran than the container workflow.
    res: JobNameToStateDict = {}
    for workflow in workflows.values():
        if len(workflow.jobs) > 0:
            for job_name, job in workflow.jobs.items():
                res[job_name] = job
        else:
            res[workflow.name] = JobCheckState(
                workflow.name,
                workflow.url,
                workflow.status,
                classification=None,
                job_id=None,
                title=None,
                summary=None,
            )
    for job_name, job in no_workflow_obj.jobs.items():
        res[job_name] = job
    return res


def parse_args() -> Any:
    """Parse the trymerge command line (flags plus the positional PR number)."""
    from argparse import ArgumentParser

    parser = ArgumentParser("Merge PR into default branch")
    boolean_flags = (
        "--dry-run",
        "--revert",
        "--force",
        "--ignore-current",
        "--check-mergeability",
    )
    for flag in boolean_flags:
        parser.add_argument(flag, action="store_true")
    parser.add_argument("--comment-id", type=int)
    parser.add_argument("--reason", type=str)
    parser.add_argument("pr_num", type=int)
    return parser.parse_args()


def can_skip_internal_checks(pr: "GitHubPR", comment_id: Optional[int] = None) -> bool:
    """Return True only if ``comment_id`` is an unedited comment by facebook-github-bot.

    Without a comment id there is nothing to vouch for the skip, so this returns False.
    """
    if comment_id is None:
        return False
    comment = pr.get_comment_by_id(comment_id)
    # An edited comment could have been altered after the bot posted it
    return comment.editor_login is None and comment.author_login == "facebook-github-bot"


def _revlist_to_prs(
    repo: GitRepo,
    pr: "GitHubPR",
    rev_list: Iterable[str],
    should_skip: Optional[Callable[[int, "GitHubPR"], bool]] = None,
) -> list[tuple["GitHubPR", str]]:
    """Map each commit in ``rev_list`` to the PR named in its "Pull Request resolved:" trailer.

    ``pr`` itself is reused when a commit resolves to it; other PRs are loaded
    fresh. ``should_skip(index, candidate)`` may drop entries.

    Raises:
        RuntimeError: if a commit has no trailer, or the trailer points at a
            different owner/repo than ``pr``.
    """
    resolved: list[tuple["GitHubPR", str]] = []
    for position, rev in enumerate(rev_list):
        commit_msg = repo.commit_message(rev)
        match = RE_PULL_REQUEST_RESOLVED.search(commit_msg)
        if match is None:
            raise RuntimeError(
                f"Could not find PR-resolved string in {commit_msg} of ghstacked PR {pr.pr_num}"
            )
        owner, repo_name, number = match.group("owner", "repo", "number")
        if (pr.org, pr.project) != (owner, repo_name):
            raise RuntimeError(f"PR {number} resolved to wrong owner/repo pair")
        stacked_num = int(number)
        candidate = (
            pr if stacked_num == pr.pr_num else GitHubPR(pr.org, pr.project, stacked_num)
        )
        if should_skip is None or not should_skip(position, candidate):
            resolved.append((candidate, rev))
    return resolved


def get_ghstack_prs(
    repo: GitRepo, pr: "GitHubPR", open_only: bool = True
) -> list[tuple["GitHubPR", str]]:
    """Return ``(PR, commit)`` pairs for the ghstack stack up to and including ``pr``.

    Commits are taken from the remote ``.../orig`` branch of ``pr`` that are not on
    the default branch, in reversed rev-list order (presumably bottom of the
    stack first -- TODO confirm against GitRepo.revlist).

    Args:
        open_only: when True, closed PRs are left out of the result.

    Raises:
        RuntimeError: if any open PR in the stack is out of sync with its
            ``orig`` revision.
    """
    # For ghstack, cherry-pick commits based from origin
    orig_ref = f"{repo.remote}/{pr.get_ghstack_orig_ref()}"
    rev_list = repo.revlist(f"{pr.default_branch()}..{orig_ref}")

    def skip_func(idx: int, candidate: "GitHubPR") -> bool:
        if open_only and candidate.is_closed():
            print(
                f"Skipping {idx + 1} of {len(rev_list)} PR (#{candidate.pr_num}) as its already been merged"
            )
            return True
        return False

    assert pr.is_ghstack_pr()
    entire_stack = _revlist_to_prs(repo, pr, reversed(rev_list), skip_func)

    for stacked_pr, rev in entire_stack:
        if not stacked_pr.is_closed():
            base_ref = stacked_pr.base_ref()
            # The bottom PR targets the default branch directly; compare against
            # its merge base with the head branch instead
            if base_ref == pr.default_branch():
                base_ref = repo.get_merge_base(
                    f"{repo.remote}/{base_ref}", f"{repo.remote}/{stacked_pr.head_ref()}"
                )
            if not are_ghstack_branches_in_sync(repo, stacked_pr.head_ref(), base_ref):
                raise RuntimeError(
                    f"PR {stacked_pr.pr_num} is out of sync with the corresponding revision {rev} on "
                    f"branch {stacked_pr.get_ghstack_orig_ref()} that would be merged into {stacked_pr.default_branch()}.  "
                    "This usually happens because there is a non ghstack change in the PR.  "
                    f"Please sync them and try again (ex. make the changes on {orig_ref} and run ghstack)."
                )
    return entire_stack


class GitHubPR:
    def __init__(self, org: str, project: str, pr_num: int) -> None:
        """Load PR ``org/project#pr_num`` via the main PR query.

        Everything beyond that query's payload is fetched lazily by the getters
        and cached in the attributes below (None means "not computed yet").
        """
        assert isinstance(pr_num, int)
        self.org, self.project, self.pr_num = org, project, pr_num
        self.info = gh_get_pr_info(org, project, pr_num)
        # Lazily populated caches
        self.changed_files: Optional[list[str]] = None
        self.labels: Optional[list[str]] = None
        self.conclusions: Optional[JobNameToStateDict] = None
        self.comments: Optional[list[GitHubComment]] = None
        self._authors: Optional[list[tuple[str, str]]] = None
        self._reviews: Optional[list[tuple[str, str]]] = None
        self.merge_base: Optional[str] = None
        self.submodules: Optional[list[str]] = None

    def is_closed(self) -> bool:
        """Return the PR's ``closed`` flag as a bool."""
        closed = self.info["closed"]
        return bool(closed)

    def is_cross_repo(self) -> bool:
        """Return True if the PR's head lives in a different repository than its base."""
        cross_repo = self.info["isCrossRepository"]
        return bool(cross_repo)

    def base_ref(self) -> str:
        """Return the name of the branch this PR targets."""
        base_name = self.info["baseRefName"]
        return cast(str, base_name)

    def default_branch(self) -> str:
        """Return the default branch name of the base repository."""
        default_ref = self.info["baseRepository"]["defaultBranchRef"]
        return cast(str, default_ref["name"])

    def head_ref(self) -> str:
        """Return the name of the PR's head branch."""
        head_name = self.info["headRefName"]
        return cast(str, head_name)

    def is_ghstack_pr(self) -> bool:
        """Return True if the head branch looks like a ghstack ``gh/<user>/<id>/head`` ref."""
        match = RE_GHSTACK_HEAD_REF.match(self.head_ref())
        return match is not None

    def get_ghstack_orig_ref(self) -> str:
        """Return the ghstack ``.../orig`` branch that pairs with this PR's ``.../head`` branch."""
        assert self.is_ghstack_pr()
        head = self.head_ref()
        return re.sub(r"/head$", "/orig", head)

    def is_base_repo_private(self) -> bool:
        """Return True if the PR's base repository is private."""
        base_repo = self.info["baseRepository"]
        return bool(base_repo["isPrivate"])

    def get_changed_files_count(self) -> int:
        """Return the changed-file count GitHub reports for the PR."""
        count = self.info["changedFiles"]
        return int(count)

    def last_commit(self) -> Any:
        """Return the commit node of the PR's last commit (as fetched with ``commits(last: 1)``)."""
        commit_nodes = self.info["commits"]["nodes"]
        return commit_nodes[-1]["commit"]

    def get_merge_base(self) -> str:
        """Return (and cache) the merge base of the PR's last commit with the default branch.

        If the merge-base lookup yields nothing (i.e. rate limit), ``baseRefOid`` is
        used instead. That is the head of the base branch when the PR was created
        or rebased -- not necessarily the true merge base, but a readily available
        approximation.
        """
        if not self.merge_base:
            head_oid = self.last_commit()["oid"]
            # NB: self.base_ref() would work for regular PRs but not for ghstack,
            # whose base is a custom gh/USER/ID/base branch, so use the default
            # branch instead
            self.merge_base = gh_fetch_merge_base(
                self.org, self.project, head_oid, self.default_branch()
            ) or cast(str, self.info["baseRefOid"])
        return self.merge_base

    def get_changed_files(self) -> list[str]:
        """Return the unique paths changed by the PR, paging 100 files at a time.

        Raises:
            RuntimeError: if the collected paths disagree with GitHub's
                ``changedFiles`` count.
        """
        if self.changed_files is None:
            seen: set[str] = set()
            page = self.info
            # Do not try to fetch more than 10K files (100 pages of 100)
            for _ in range(100):
                files = page["files"]
                seen.update(node["path"] for node in files["nodes"])
                if not files["pageInfo"]["hasNextPage"]:
                    break
                response = gh_graphql(
                    GH_GET_PR_NEXT_FILES_QUERY,
                    name=self.project,
                    owner=self.org,
                    number=self.pr_num,
                    cursor=files["pageInfo"]["endCursor"],
                )
                page = response["data"]["repository"]["pullRequest"]
            self.changed_files = list(seen)

        if len(self.changed_files) != self.get_changed_files_count():
            raise RuntimeError("Changed file count mismatch")
        return self.changed_files

    def get_submodules(self) -> list[str]:
        """Return (and cache) the repository's submodule paths.

        NOTE(review): only the first 100 submodules are fetched; there is no pagination.
        """
        if self.submodules is None:
            response = gh_graphql(
                GH_GET_REPO_SUBMODULES, name=self.project, owner=self.org
            )
            submodule_nodes = response["data"]["repository"]["submodules"]["nodes"]
            self.submodules = [node["path"] for node in submodule_nodes]
        return self.submodules

    def get_changed_submodules(self) -> list[str]:
        """Return the changed files that are submodule paths, in changed-file order."""
        submodule_paths = set(self.get_submodules())
        return [path for path in self.get_changed_files() if path in submodule_paths]

    def has_invalid_submodule_updates(self) -> bool:
        """Return True if the PR updates a submodule without acknowledging it.

        A submodule update is acknowledged when the word "submodule" appears,
        case-insensitively, in the title, in the body/description, or in any of
        the labels. Labels are now lower-cased like the title and body, so a
        label such as "Submodule-update" also counts.
        """
        return (
            len(self.get_changed_submodules()) > 0
            and "submodule" not in self.get_title().lower()
            and "submodule" not in self.get_body().lower()
            and all("submodule" not in label.lower() for label in self.get_labels())
        )

    def _get_reviews(self) -> list[tuple[str, str]]:
        """Return ``(login, state)`` for each reviewer's latest non-COMMENTED review.

        All reviews are fetched on first use, walking backwards 100 at a time (at
        most 100 pages). Each page is prepended so the raw list stays oldest-first,
        which lets later reviews overwrite earlier ones below. The raw list is
        cached on ``self._reviews`` only after pagination completes, so a failed
        fetch does not leave a truncated cache behind.
        """
        if self._reviews is None:
            collected: list[tuple[str, str]] = []
            info = self.info
            for _ in range(100):
                nodes = info["reviews"]["nodes"]
                collected = [
                    (node["author"]["login"], node["state"]) for node in nodes
                ] + collected
                if not info["reviews"]["pageInfo"]["hasPreviousPage"]:
                    break
                rc = gh_graphql(
                    GH_GET_PR_PREV_REVIEWS_QUERY,
                    name=self.project,
                    owner=self.org,
                    number=self.pr_num,
                    cursor=info["reviews"]["pageInfo"]["startCursor"],
                )
                info = rc["data"]["repository"]["pullRequest"]
            self._reviews = collected
        reviews: dict[str, str] = {}
        for author, state in self._reviews:
            # Plain comments do not change a reviewer's approval state
            if state != "COMMENTED":
                reviews[author] = state
        return list(reviews.items())

    def get_approved_by(self) -> list[str]:
        """Return logins whose latest non-comment review is APPROVED."""
        latest_reviews = self._get_reviews()
        return [login for login, state in latest_reviews if state == "APPROVED"]

    def get_commit_count(self) -> int:
        """Return the total number of commits in the PR."""
        total = self.info["commits_with_authors"]["totalCount"]
        return int(total)

    def get_pr_creator_login(self) -> str:
        """Return the GitHub login of the PR author."""
        pr_author = self.info["author"]
        return cast(str, pr_author["login"])

    def _fetch_authors(self) -> list[tuple[str, str]]:
        """Return (and cache) ``(login, "Name <email>")`` for every commit author.

        Commits are paged forward 100 at a time (at most 100 pages). Authors that
        are not GitHub users get an empty login.
        """
        if self._authors is None:
            collected: list[tuple[str, str]] = []
            page = self.info
            for _ in range(100):
                commits = page["commits_with_authors"]
                for commit_node in commits["nodes"]:
                    for author_node in commit_node["commit"]["authors"]["nodes"]:
                        user = author_node["user"]
                        signature = f"{author_node['name']} <{author_node['email']}>"
                        # If author is not github user, user node will be null
                        login = "" if user is None else cast(str, user["login"])
                        collected.append((login, signature))
                if not commits["pageInfo"]["hasNextPage"]:
                    break
                response = gh_graphql(
                    GH_GET_PR_NEXT_AUTHORS_QUERY,
                    name=self.project,
                    owner=self.org,
                    number=self.pr_num,
                    cursor=commits["pageInfo"]["endCursor"],
                )
                page = response["data"]["repository"]["pullRequest"]
            self._authors = collected
        return self._authors

    def get_committer_login(self, num: int = 0) -> str:
        """Return the GitHub login ("" for non-GitHub users) of the ``num``-th commit author."""
        login, _signature = self._fetch_authors()[num]
        return login

    def get_committer_author(self, num: int = 0) -> str:
        """Return the ``"Name <email>"`` string of the ``num``-th commit author."""
        _login, signature = self._fetch_authors()[num]
        return signature

    def get_labels(self) -> list[str]:
        """Return (and cache) the PR's label names; empty if the PR info has no labels."""
        if self.labels is None:
            if "labels" in self.info:
                self.labels = [edge["node"]["name"] for edge in self.info["labels"]["edges"]]
            else:
                self.labels = []
        return self.labels

    def get_checkrun_conclusions(self) -> JobNameToStateDict:
        """Return (and cache) a map of check name -> JobCheckState for the PR's last commit.

        GitHub check runs are gathered across paginated check suites and check
        runs, and old-style commit statuses are merged on top.

        Raises:
            RuntimeError: if the PR's last commit changes while paginating, since
                pages from different commits must not be mixed.
        """
        if self.conclusions is not None:
            return self.conclusions
        orig_last_commit = self.last_commit()

        def fetch_same_last_commit(query: str, **cursors: Any) -> Any:
            # Re-query the last commit and make sure it is still the one we started from
            rc = gh_graphql(
                query,
                name=self.project,
                owner=self.org,
                number=self.pr_num,
                **cursors,
            )
            last_commit = rc["data"]["repository"]["pullRequest"]["commits"]["nodes"][
                -1
            ]["commit"]
            if last_commit["oid"] != orig_last_commit["oid"]:
                raise RuntimeError("Last commit changed on PR")
            return last_commit

        def get_pr_next_check_runs(
            edges: list[dict[str, dict[str, Any]]], edge_idx: int, checkruns: Any
        ) -> Any:
            # The cursor of the previous edge makes `checkSuites(first: 1, after: ...)`
            # select exactly edges[edge_idx]; None selects the first suite
            last_commit = fetch_same_last_commit(
                GH_GET_PR_NEXT_CHECK_RUNS,
                cs_cursor=edges[edge_idx - 1]["cursor"] if edge_idx > 0 else None,
                cr_cursor=checkruns["pageInfo"]["endCursor"],
            )
            return last_commit["checkSuites"]["nodes"][-1]["checkRuns"]

        def get_pr_next_checksuites(checksuites: Any) -> Any:
            last_commit = fetch_same_last_commit(
                GH_GET_PR_NEXT_CHECKSUITES,
                cursor=checksuites["edges"][-1]["cursor"],
            )
            return last_commit["checkSuites"]

        checksuites = orig_last_commit["checkSuites"]

        self.conclusions = add_workflow_conclusions(
            checksuites, get_pr_next_check_runs, get_pr_next_checksuites
        )

        # Append old style statuses(like ones populated by CircleCI or EasyCLA) to conclusions
        if orig_last_commit["status"] and orig_last_commit["status"]["contexts"]:
            for status in orig_last_commit["status"]["contexts"]:
                name = status["context"]
                self.conclusions[name] = JobCheckState(
                    name,
                    status["targetUrl"],
                    status["state"],
                    classification=None,
                    job_id=None,
                    title=None,
                    summary=None,
                )

        return self.conclusions

    def get_authors(self) -> dict[str, str]:
        """Return login -> "Name <email>" for commit authors; a later entry wins for a repeated login."""
        return {login: signature for login, signature in self._fetch_authors()}

    def get_author(self) -> str:
        """Return the ``"Name <email>"`` to attribute the PR to.

        This is the sole author if there is only one. Otherwise it is the PR
        creator's commit identity, falling back to the first commit's author
        when the creator authored none of the commits.
        """
        authors = self.get_authors()
        if len(authors) == 1:
            (only_author,) = authors.values()
            return only_author
        creator = self.get_pr_creator_login()
        if creator in authors:
            return authors[creator]
        return self.get_committer_author(0)

    def get_title(self) -> str:
        """Return the PR title."""
        title = self.info["title"]
        return cast(str, title)

    def get_body(self) -> str:
        """Return the PR description body."""
        body = self.info["body"]
        return cast(str, body)

    def get_merge_commit(self) -> Optional[str]:
        """Return the oid of the PR's merge commit, or None if there is none."""
        merge_commit = self.info["mergeCommit"]
        if merge_commit is None:
            return None
        return merge_commit["oid"]

    def get_pr_url(self) -> str:
        """Return the github.com web URL of this PR."""
        return "https://github.com/{}/{}/pull/{}".format(
            self.org, self.project, self.pr_num
        )

    @staticmethod
    def _comment_from_node(node: Any) -> GitHubComment:
        """Build a GitHubComment from a comment (or review) GraphQL node.

        ``created_at`` defaults to "" when the node has no ``createdAt`` field, and
        ``editor_login`` is None for comments that were never edited.
        """
        editor = node["editor"]
        editor_login = editor["login"] if editor else None
        return GitHubComment(
            body_text=node["bodyText"],
            created_at=node.get("createdAt", ""),
            author_login=node["author"]["login"],
            author_association=node["authorAssociation"],
            editor_login=editor_login,
            database_id=node["databaseId"],
            url=node["url"],
        )

    def get_comments(self) -> list[GitHubComment]:
        """Return all PR comments, oldest first.

        The PR info query only holds the last few comments; older ones are
        fetched backwards 100 at a time, capped at 100 pages (10K comments). The
        list is stored on ``self.comments`` only once pagination has finished,
        so a failed fetch does not leave a truncated list cached. A truncated
        cache would also disable the prefetched fast path in get_comment_by_id.
        """
        if self.comments is not None:
            return self.comments
        comments: list[GitHubComment] = []
        info = self.info["comments"]
        # Do not try to fetch more than 10K comments
        for _ in range(100):
            # Older pages are prepended to keep chronological order
            comments = [
                self._comment_from_node(node) for node in info["nodes"]
            ] + comments
            if not info["pageInfo"]["hasPreviousPage"]:
                break
            rc = gh_graphql(
                GH_GET_PR_PREV_COMMENTS,
                name=self.project,
                owner=self.org,
                number=self.pr_num,
                cursor=info["pageInfo"]["startCursor"],
            )
            info = rc["data"]["repository"]["pullRequest"]["comments"]
        self.comments = comments
        return self.comments

    def get_last_comment(self) -> GitHubComment:
        """Return the most recent comment among those prefetched with the PR info."""
        prefetched_nodes = self.info["comments"]["nodes"]
        return self._comment_from_node(prefetched_nodes[-1])

    def get_comment_by_id(self, database_id: int) -> GitHubComment:
        """Return the PR comment, or review message, whose database ID is ``database_id``.

        The search first checks the comments prefetched with the PR info (only
        while the full list is not loaded yet), then every comment, then the
        prefetched reviews.

        Raises:
            RuntimeError: if nothing with that ID is found.
        """
        if self.comments is None:
            # Fastpath - try searching in partial prefetched comments
            prefetched = (
                self._comment_from_node(node) for node in self.info["comments"]["nodes"]
            )
            hit = next((c for c in prefetched if c.database_id == database_id), None)
            if hit is not None:
                return hit

        hit = next(
            (c for c in self.get_comments() if c.database_id == database_id), None
        )
        if hit is not None:
            return hit

        # The comment could also be the message written alongside a review (often
        # used to trigger a merge right when the review is left). Review nodes carry
        # every field a regular comment needs.
        review_comments = (
            self._comment_from_node(node) for node in self.info["reviews"]["nodes"]
        )
        hit = next((c for c in review_comments if c.database_id == database_id), None)
        if hit is not None:
            return hit

        raise RuntimeError(f"Comment with id {database_id} not found")

    def get_diff_revision(self) -> Optional[str]:
        """Return the first capture group of RE_DIFF_REV in the PR body, or None."""
        match = RE_DIFF_REV.search(self.get_body())
        if match is None:
            return None
        return match.group(1)

    def has_internal_changes(self) -> bool:
        """Report whether the internal-changes check run is present and not SUCCESS.

        Always False when the PR body has no diff revision (see
        :meth:`get_diff_revision`) or when the check run is missing.
        """
        if self.get_diff_revision() is None:
            return False
        conclusions = self.get_checkrun_conclusions()
        if conclusions is None or INTERNAL_CHANGES_CHECKRUN_NAME not in conclusions:
            return False
        internal_check = conclusions[INTERNAL_CHANGES_CHECKRUN_NAME]
        return internal_check.status != "SUCCESS"

    def has_no_connected_diff(self) -> bool:
        """True when the internal-changes check run's title is HAS_NO_CONNECTED_DIFF_TITLE.

        Returns False if there are no check-run conclusions or that check run is absent.
        """
        conclusions = self.get_checkrun_conclusions()
        if conclusions is None or INTERNAL_CHANGES_CHECKRUN_NAME not in conclusions:
            return False
        internal_check = conclusions[INTERNAL_CHANGES_CHECKRUN_NAME]
        return internal_check.title == HAS_NO_CONNECTED_DIFF_TITLE

    def merge_ghstack_into(
        self,
        repo: GitRepo,
        skip_mandatory_checks: bool,
        comment_id: Optional[int] = None,
        skip_all_rule_checks: bool = False,
    ) -> list["GitHubPR"]:
        """Cherry-pick every open PR of this ghstack onto the checked-out branch.

        Each open PR gets its own commit whose message is produced by
        ``gen_commit_message`` with the ghstack description filtered out and the
        already-processed stack PRs listed as dependencies. Stack PRs other than
        this one are validated with ``find_matching_merge_rule`` first, unless
        ``skip_all_rule_checks`` is set. Closed PRs are not cherry-picked but still
        count as dependencies of the PRs above them.

        Returns:
            The stack PRs (this one included) that were not closed.
        """
        assert self.is_ghstack_pr()
        # get_ghstack_prs raises an error if the stack is out of sync
        stack = get_ghstack_prs(repo, self, open_only=False)
        landed_below: list["GitHubPR"] = []
        for stack_pr, revision in stack:
            if not stack_pr.is_closed():
                commit_msg = stack_pr.gen_commit_message(
                    filter_ghstack=True, ghstack_deps=landed_below
                )
                needs_rule_check = (
                    stack_pr.pr_num != self.pr_num and not skip_all_rule_checks
                )
                if needs_rule_check:
                    # Raises exception if matching rule is not found
                    find_matching_merge_rule(
                        stack_pr,
                        repo,
                        skip_mandatory_checks=skip_mandatory_checks,
                        skip_internal_checks=can_skip_internal_checks(self, comment_id),
                    )
                repo.cherry_pick(revision)
                repo.amend_commit_message(commit_msg)
            landed_below.append(stack_pr)
        return [stack_pr for stack_pr, _ in stack if not stack_pr.is_closed()]

    def gen_commit_message(
        self,
        filter_ghstack: bool = False,
        ghstack_deps: Optional[list["GitHubPR"]] = None,
    ) -> str:
        """Build the squash-commit message for this PR.

        Layout: ``<title> (#<num>)``, a blank line, the PR body with RE_PR_CC_LINE
        matches removed (and RE_GHSTACK_DESC matches too if ``filter_ghstack``),
        then the "Pull Request resolved" / "Approved by" trailers, an optional
        "ghstack dependencies" line, and finally one "Co-authored-by" line per
        author other than the PR creator, preceded by a blank line.
        """
        # Adding the url here makes it clickable within the Github UI
        approver_links = ", ".join(
            prefix_with_github_url(login) for login in self.get_approved_by()
        )
        # Remove "cc: " line from the message body
        body = re.sub(RE_PR_CC_LINE, "", self.get_body())
        if filter_ghstack:
            body = re.sub(RE_GHSTACK_DESC, "", body)

        pieces = [
            self.get_title() + f" (#{self.pr_num})\n\n",
            body,
            f"\nPull Request resolved: {self.get_pr_url()}\n",
            f"Approved by: {approver_links}\n",
        ]
        if ghstack_deps:
            dep_refs = ", ".join(f"#{dep.pr_num}" for dep in ghstack_deps)
            pieces.append(f"ghstack dependencies: {dep_refs}\n")

        # Co-authors must come last, separated from the trailers by a blank line
        coauthor_lines = [
            f"\nCo-authored-by: {author_name}"
            for author_login, author_name in self.get_authors().items()
            if author_login != self.get_pr_creator_login()
        ]
        if coauthor_lines:
            pieces.append("\n")
            pieces.extend(coauthor_lines)

        return "".join(pieces)

    def add_numbered_label(self, label_base: str, dry_run: bool) -> None:
        """Add ``label_base`` to the PR, numbered if it has been applied before.

        Every existing label containing ``label_base`` counts as a previous
        application. With none, the plain ``label_base`` is added. With ``n`` of
        them, ``f"{label_base}X{n}"`` is added instead (e.g. ``merged`` then
        ``mergedX1``, ``mergedX2``, ...).

        Args:
            label_base: label name without the numeric suffix.
            dry_run: forwarded to :meth:`add_label`.
        """
        # Always consult get_labels(). The former ``self.labels is not None``
        # guard treated a PR whose labels had not been loaded into self.labels
        # yet as unlabeled, so a re-merged PR got the bare label again instead of
        # a numbered one.
        count = sum(1 for label in self.get_labels() if label_base in label)
        full_label = f"{label_base}X{count}" if count else label_base
        self.add_label(full_label, dry_run)

    def add_label(self, label: str, dry_run: bool) -> None:
        """Attach a single ``label`` to this PR via ``gh_add_labels``."""
        labels_to_add = [label]
        gh_add_labels(self.org, self.project, self.pr_num, labels_to_add, dry_run)

    def merge_into(
        self,
        repo: GitRepo,
        *,
        skip_mandatory_checks: bool = False,
        dry_run: bool = False,
        comment_id: Optional[int] = None,
        ignore_current_checks: Optional[list[str]] = None,
    ) -> None:
        """Merge this PR (plus any ghstack PRs landed with it) and push the result.

        Steps, in order:
          1. validate the PR against the merge rules,
          2. apply the changes locally with ``merge_changes``,
          3. push the default branch,
          4. add the numbered MERGE_COMPLETE_LABEL to the merged PRs (not on dry runs),
          5. write the merge record (only when ``comment_id`` and ``self.pr_num`` are set),
          6. wait 60s and close any merged PR that GitHub left open.

        Args:
            repo: local git checkout the merge is performed in.
            skip_mandatory_checks: forwarded to rule matching and ``merge_changes``;
                also stored in the merge record.
            dry_run: forwarded to the push, label and close helpers.
            comment_id: id of the comment that requested the merge; passed to
                ``can_skip_internal_checks`` and stored in the merge record.
            ignore_current_checks: check names the requester asked to ignore;
                forwarded to rule matching and recorded as a boolean.

        Raises:
            Whatever ``find_matching_merge_rule`` raises when no rule matches, and
            errors from ``merge_changes`` (e.g. RuntimeError when the PR head
            moved during the merge).
        """
        # Raises exception if matching rule is not found
        # (merge_rule itself is not used below; the check lists go into the merge record)
        (
            merge_rule,
            pending_checks,
            failed_checks,
            ignorable_checks,
        ) = find_matching_merge_rule(
            self,
            repo,
            skip_mandatory_checks=skip_mandatory_checks,
            skip_internal_checks=can_skip_internal_checks(self, comment_id),
            ignore_current_checks=ignore_current_checks,
        )
        # [] for regular PRs; for ghstack PRs, the non-closed PRs of the stack as
        # returned by merge_ghstack_into (which may include this PR itself)
        additional_merged_prs = self.merge_changes(
            repo, skip_mandatory_checks, comment_id
        )

        repo.push(self.default_branch(), dry_run)
        if not dry_run:
            self.add_numbered_label(MERGE_COMPLETE_LABEL, dry_run)
            for pr in additional_merged_prs:
                pr.add_numbered_label(MERGE_COMPLETE_LABEL, dry_run)

        # When the merge process reaches this part, we can assume that the commit
        # has been successfully pushed to trunk
        merge_commit_sha = repo.rev_parse(name=self.default_branch())

        if comment_id and self.pr_num:
            # Finally, upload the record to s3. The list of pending and failed
            # checks are at the time of the merge
            save_merge_record(
                comment_id=comment_id,
                pr_num=self.pr_num,
                owner=self.org,
                project=self.project,
                author=self.get_author(),
                pending_checks=pending_checks,
                failed_checks=failed_checks,
                ignore_current_checks=ignorable_checks.get("IGNORE_CURRENT_CHECK", []),
                broken_trunk_checks=ignorable_checks.get("BROKEN_TRUNK", []),
                flaky_checks=ignorable_checks.get("FLAKY", []),
                unstable_checks=ignorable_checks.get("UNSTABLE", []),
                last_commit_sha=self.last_commit().get("oid", ""),
                merge_base_sha=self.get_merge_base(),
                merge_commit_sha=merge_commit_sha,
                is_failed=False,
                skip_mandatory_checks=skip_mandatory_checks,
                ignore_current=bool(ignore_current_checks),
            )
        else:
            print("Missing comment ID or PR number, couldn't upload to s3")

        # Usually Github will see that the commit has "resolves <pr_num>" in the
        # commit message and close the PR, but sometimes it doesn't, leading to
        # confusion.  When it doesn't, we close it manually.
        time.sleep(60)  # Give Github some time to close the PR
        manually_close_merged_pr(
            pr=self,
            additional_merged_prs=additional_merged_prs,
            merge_commit_sha=merge_commit_sha,
            dry_run=dry_run,
        )

    def merge_changes(
        self,
        repo: GitRepo,
        skip_mandatory_checks: bool = False,
        comment_id: Optional[int] = None,
        branch: Optional[str] = None,
        skip_all_rule_checks: bool = False,
    ) -> list["GitHubPR"]:
        """
        Apply this PR's changes as local commit(s) on ``branch`` (default branch if None).

        ghstack PRs are delegated to :meth:`merge_ghstack_into`, and its result is
        returned. Regular PRs are squash-merged into a single commit, and [] is
        returned.

        :param skip_all_rule_checks: If true, skips all rule checks, useful for dry-running merge locally
        :raises RuntimeError: if a regular PR's head changed while it was being merged
        """
        target_branch = branch if branch is not None else self.default_branch()
        if repo.current_branch() != target_branch:
            repo.checkout(target_branch)

        if self.is_ghstack_pr():
            return self.merge_ghstack_into(
                repo,
                skip_mandatory_checks,
                comment_id=comment_id,
                skip_all_rule_checks=skip_all_rule_checks,
            )

        commit_msg = self.gen_commit_message()
        local_ref = f"__pull-request-{self.pr_num}__init__"
        repo.fetch(self.last_commit()["oid"], local_ref)
        repo._run_git("merge", "--squash", local_ref)
        repo._run_git("commit", f'--author="{self.get_author()}"', "-m", commit_msg)

        # Did the PR change since we started the merge?
        fetched_sha = repo.show_ref(local_ref)
        fresh_pr = GitHubPR(self.org, self.project, self.pr_num)
        if fetched_sha != fresh_pr.last_commit()["oid"]:
            raise RuntimeError(
                "PR has been updated since CI checks last passed. Please rerun the merge command."
            )
        return []


class MergeRuleFailedError(RuntimeError):
    """A PR could not be matched to any merge rule.

    Attributes:
        rule: the MergeRule attached by the raiser, or None.
    """

    def __init__(self, message: str, rule: Optional["MergeRule"] = None) -> None:
        RuntimeError.__init__(self, message)
        self.rule = rule


class MandatoryChecksMissingError(MergeRuleFailedError):
    """Raised by find_matching_merge_rule when the closest rule is blocked only by pending mandatory checks."""


class PostCommentError(Exception):
    """User-facing failure, e.g. an invalid revert request (see validate_revert).

    NOTE(review): the name suggests the message is posted back as a PR comment;
    the handler is not visible here.
    """


@dataclass
class MergeRule:
    """One entry of the merge rules YAML file; each key maps to a field (see read_merge_rules)."""

    # Rule name, shown in rejection messages
    name: str
    # Patterns (compiled with patterns_to_regex) that every changed file must match
    patterns: list[str]
    # GitHub logins, or "org/team" entries that get expanded to team members
    approved_by: list[str]
    # Names of checks that must pass; None means no mandatory checks
    mandatory_checks_name: Optional[list[str]]
    # When True, IGNORABLE_FAILED_CHECKS_THESHOLD is passed to categorize_checks
    # as ok_failed_checks_threshold; otherwise 0 is passed
    ignore_flaky_failures: bool = True


def gen_new_issue_link(
    org: str, project: str, labels: list[str], template: str = "bug-report.yml"
) -> str:
    """Return a "new issue" URL for org/project prefilled with ``labels`` and ``template``.

    Labels are comma-joined, and both query values are percent-encoded.
    """
    query_parts = (
        "labels=" + urllib.parse.quote(",".join(labels)),
        "template=" + urllib.parse.quote(template),
    )
    return f"https://github.com/{org}/{project}/issues/new?" + "&".join(query_parts)


def read_merge_rules(
    repo: Optional[GitRepo], org: str, project: str
) -> list[MergeRule]:
    """Returns the list of all merge rules for the repo or project.

    With a local ``repo``, the rules file (MERGE_RULE_PATH, relative to the repo
    root) is read from disk, and a missing file yields []. Without one, the file
    is fetched through the GitHub contents API for ``org/project``. An empty rules
    file also yields [], so callers report "no rules defined" instead of failing
    with a TypeError while iterating None.

    NB: this function is used in Meta-internal workflows, see the comment
    at the top of this file for details.
    """
    repo_relative_rules_path = MERGE_RULE_PATH
    if repo is None:
        json_data = gh_fetch_url(
            f"https://api.github.com/repos/{org}/{project}/contents/{repo_relative_rules_path}",
            headers={"Accept": "application/vnd.github.v3+json"},
            reader=json.load,
        )
        content = base64.b64decode(json_data["content"])
        raw_rules = yaml.safe_load(content)
    else:
        rules_path = Path(repo.repo_dir) / repo_relative_rules_path
        if not rules_path.exists():
            print(f"{rules_path} does not exist, returning empty rules")
            return []
        # Read bytes, as the remote branch above does, so the YAML loader detects
        # the encoding instead of relying on the locale's default text encoding.
        with open(rules_path, "rb") as fp:
            raw_rules = yaml.safe_load(fp)
    # yaml.safe_load returns None for an empty document
    if raw_rules is None:
        return []
    return [MergeRule(**x) for x in raw_rules]


def find_matching_merge_rule(
    pr: GitHubPR,
    repo: Optional[GitRepo] = None,
    skip_mandatory_checks: bool = False,
    skip_internal_checks: bool = False,
    ignore_current_checks: Optional[list[str]] = None,
) -> tuple[
    MergeRule,
    list[tuple[str, Optional[str], Optional[int]]],
    list[tuple[str, Optional[str], Optional[int]]],
    dict[str, list[Any]],
]:
    """
    Returns merge rule matching to this pr together with the list of associated pending
    and failing jobs OR raises an exception.

    A rule matches when all of the following hold:
      - every changed file matches its patterns,
      - at least one PR approver is in its approver set ("org/team" entries are
        expanded to team members),
      - none of its mandatory checks are failing or pending.
    The first matching rule in file order wins.

    Args:
        pr: the pull request to evaluate.
        repo: local checkout to read the rules from; None fetches them from GitHub.
        skip_mandatory_checks: ignore mandatory checks, except those whose names
            contain "EasyCLA" or "Facebook CLA Check".
        skip_internal_checks: do not reject the PR when it has internal changes.
        ignore_current_checks: check names to classify as IGNORE_CURRENT_CHECK.

    Returns:
        ``(rule, pending_checks, failed_checks, ignorable_checks)``. The check lists
        come from categorizing *all* checks of the PR, not only the rule's
        mandatory ones. ``ignorable_checks`` is keyed by classification
        (e.g. "FLAKY", "BROKEN_TRUNK").

    Raises:
        RuntimeError: if no rules are defined, or if the PR has internal changes
            and ``skip_internal_checks`` is False.
        MandatoryChecksMissingError: if the closest rule is blocked only by pending
            mandatory checks.
        MergeRuleFailedError: otherwise. Its ``rule`` is the rule behind the
            reported reason, or None if no rule matched any changed file.

    NB: this function is used in Meta-internal workflows, see the comment at the top of
    this file for details.
    """
    changed_files = pr.get_changed_files()
    approved_by = set(pr.get_approved_by())

    issue_link = gen_new_issue_link(
        org=pr.org,
        project=pr.project,
        labels=["module: ci"],
    )
    # Markdown links need the URL in parentheses: [text](url)
    reject_reason = f"No rule found to match PR. Please [report]({issue_link}) this issue to DevX team."
    # The rule that produced the current reject_reason. It is attached to the
    # error raised at the end; previously the last rule in the file was attached
    # regardless of which rule the reason described.
    reject_reason_rule: Optional[MergeRule] = None

    rules = read_merge_rules(repo, pr.org, pr.project)
    if not rules:
        reject_reason = f"Rejecting the merge as no rules are defined for the repository in {MERGE_RULE_PATH}"
        raise RuntimeError(reject_reason)

    checks = pr.get_checkrun_conclusions()
    checks = get_classifications(
        pr.pr_num,
        pr.project,
        checks,
        ignore_current_checks=ignore_current_checks,
    )

    # This keeps the list of all approvers that could stamp the change
    all_rule_approvers = {}

    # PRs can fail multiple merge rules, but it only needs to pass one rule to be approved.
    # If it fails all rules, we need to find the rule that it came closest to passing and report
    # that to the dev.
    #
    # reject_reason_score ranks rules by relevancy. The higher the score, the more relevant the
    # rule & rejection reason, and we only care about the most relevant rule/reason
    #
    # reject_reason_score interpretation:
    # Score 0 to 10K - how many files rule matched
    # Score 10K - matched all files, but no overlapping approvers
    # Score 20K - matched all files and approvers, but mandatory checks are pending
    # Score 30K - Matched all files and approvers, but mandatory checks failed
    reject_reason_score = 0
    for rule in rules:
        rule_name = rule.name
        patterns_re = patterns_to_regex(rule.patterns)

        # Does this rule apply to all the files?
        non_matching_files = [
            fname for fname in changed_files if not patterns_re.match(fname)
        ]
        if non_matching_files:
            num_matching_files = len(changed_files) - len(non_matching_files)
            if num_matching_files > reject_reason_score:
                reject_reason_score = num_matching_files
                reject_reason_rule = rule
                reject_reason = "\n".join(
                    (
                        f"Not all files match rule `{rule_name}`.",
                        f"{num_matching_files} files matched, but there are still non-matching files:",
                        f"{','.join(non_matching_files[:5])}{', ...' if len(non_matching_files) > 5 else ''}",
                    )
                )
            continue

        # If rule needs approvers but PR has not been reviewed, skip it
        if len(rule.approved_by) > 0 and len(approved_by) == 0:
            if reject_reason_score < 10000:
                reject_reason_score = 10000
                reject_reason_rule = rule
                reject_reason = f"PR #{pr.pr_num} has not been reviewed yet"
            continue

        # Does the PR have the required approvals for this rule?
        rule_approvers = set()
        for approver in rule.approved_by:
            if "/" in approver:
                team_org, team_name = approver.split("/")
                rule_approvers.update(gh_get_team_members(team_org, team_name))
            else:
                rule_approvers.add(approver)
        approvers_intersection = approved_by.intersection(rule_approvers)
        # If rule requires approvers but they aren't the ones that reviewed PR
        if len(approvers_intersection) == 0 and len(rule_approvers) > 0:
            # Less than or equal is intentionally used here to gather all potential
            # approvers
            if reject_reason_score <= 10000:
                reject_reason_score = 10000
                reject_reason_rule = rule

                all_rule_approvers[rule.name] = rule.approved_by
                # Prepare the reject reason
                all_rule_approvers_msg = [
                    f"- {set_name} ({', '.join(set_approvers[:5])}{', ...' if len(set_approvers) > 5 else ''})"
                    for set_name, set_approvers in all_rule_approvers.items()
                ]

                reject_reason = "Approvers from one of the following sets are needed:\n"
                reject_reason += "\n".join(all_rule_approvers_msg)

            continue

        # Does the PR pass the checks required by this rule?
        mandatory_checks = (
            rule.mandatory_checks_name if rule.mandatory_checks_name is not None else []
        )
        # CLA checks stay mandatory even when mandatory checks are skipped
        required_checks = list(
            filter(
                lambda x: ("EasyCLA" in x)
                or ("Facebook CLA Check" in x)
                or not skip_mandatory_checks,
                mandatory_checks,
            )
        )
        pending_checks, failed_checks, _ = categorize_checks(
            checks,
            required_checks,
            ok_failed_checks_threshold=IGNORABLE_FAILED_CHECKS_THESHOLD
            if rule.ignore_flaky_failures
            else 0,
        )

        # categorize_checks assumes all tests are required if required_checks is empty.
        # this is a workaround as we want to keep that behavior for categorize_checks
        # generally.
        if not required_checks:
            pending_checks = []
            failed_checks = []

        hud_link = f"https://hud.pytorch.org/{pr.org}/{pr.project}/commit/{pr.last_commit()['oid']}"
        if len(failed_checks) > 0:
            if reject_reason_score < 30000:
                reject_reason_score = 30000
                reject_reason_rule = rule
                reject_reason = "\n".join(
                    (
                        f"{len(failed_checks)} mandatory check(s) failed.  The first few are:",
                        *checks_to_markdown_bullets(failed_checks),
                        "",
                        f"Dig deeper by [viewing the failures on hud]({hud_link})",
                    )
                )
            continue
        elif len(pending_checks) > 0:
            if reject_reason_score < 20000:
                reject_reason_score = 20000
                reject_reason_rule = rule
                reject_reason = "\n".join(
                    (
                        f"{len(pending_checks)} mandatory check(s) are pending/not yet run.  The first few are:",
                        *checks_to_markdown_bullets(pending_checks),
                        "",
                        f"Dig deeper by [viewing the pending checks on hud]({hud_link})",
                    )
                )
            continue

        if not skip_internal_checks and pr.has_internal_changes():
            raise RuntimeError(
                "This PR has internal changes and must be landed via Phabricator! Please try reimporting/rexporting the PR!"
            )

        # Categorize all checks when skip_mandatory_checks (force merge) is set. Do it here
        # where the list of checks is readily available. These records will be saved into
        # s3 merge records
        (
            pending_mandatory_checks,
            failed_mandatory_checks,
            ignorable_checks,
        ) = categorize_checks(
            checks,
            [],
            ok_failed_checks_threshold=IGNORABLE_FAILED_CHECKS_THESHOLD,
        )
        return (
            rule,
            pending_mandatory_checks,
            failed_mandatory_checks,
            ignorable_checks,
        )

    if reject_reason_score == 20000:
        raise MandatoryChecksMissingError(reject_reason, reject_reason_rule)
    raise MergeRuleFailedError(reject_reason, reject_reason_rule)


def checks_to_str(checks: list[tuple[str, Optional[str]]]) -> str:
    """Render checks as ", "-separated text; checks with a URL become markdown links."""
    rendered = []
    for check in checks:
        check_name, check_url = check[0], check[1]
        if check_url is None:
            rendered.append(check_name)
        else:
            rendered.append(f"[{check_name}]({check_url})")
    return ", ".join(rendered)


def checks_to_markdown_bullets(
    checks: list[tuple[str, Optional[str], Optional[int]]],
) -> list[str]:
    """Return markdown bullets for at most the first 5 checks, linking those with a URL."""
    bullets = []
    for check in checks[:5]:
        check_name, check_url = check[0], check[1]
        if check_url is None:
            bullets.append(f"- {check_name}")
        else:
            bullets.append(f"- [{check_name}]({check_url})")
    return bullets


def post_starting_merge_comment(
    repo: GitRepo,
    pr: GitHubPR,
    explainer: TryMergeExplainer,
    dry_run: bool,
    ignore_current_checks_info: Optional[
        list[tuple[str, Optional[str], Optional[int]]]
    ] = None,
) -> None:
    """Post the merge-starting message on ``pr``.

    For ghstack PRs, also post a short note on every other PR of the stack.
    """
    gh_post_pr_comment(
        pr.org,
        pr.project,
        pr.pr_num,
        explainer.get_merge_message(ignore_current_checks_info),
        dry_run=dry_run,
    )
    if not pr.is_ghstack_pr():
        return
    stack_notice = f"Starting merge as part of PR stack under #{pr.pr_num}"
    for stacked_pr, _rev in get_ghstack_prs(repo, pr):
        if stacked_pr.pr_num == pr.pr_num:
            continue
        gh_post_pr_comment(
            stacked_pr.org,
            stacked_pr.project,
            stacked_pr.pr_num,
            stack_notice,
            dry_run=dry_run,
        )


def manually_close_merged_pr(
    pr: GitHubPR,
    additional_merged_prs: list[GitHubPR],
    merge_commit_sha: str,
    dry_run: bool,
) -> None:
    """Comment on and close ``pr`` and each of ``additional_merged_prs`` if still open.

    Each PR is re-fetched first, so PRs GitHub already closed are left alone.
    """

    def close_if_still_open(target: GitHubPR, comment: str) -> None:
        refreshed = GitHubPR(target.org, target.project, target.pr_num)
        if refreshed.is_closed():
            return
        gh_post_pr_comment(refreshed.org, refreshed.project, refreshed.pr_num, comment, dry_run)
        gh_close_pr(refreshed.org, refreshed.project, refreshed.pr_num, dry_run)

    main_message = (
        f"This PR (#{pr.pr_num}) was merged in {merge_commit_sha} but it is still open, likely due to a Github bug, "
        "so mergebot is closing it manually.  If you think this is a mistake, please feel free to reopen and contact Dev Infra."
    )
    close_if_still_open(pr, main_message)

    for stacked_pr in additional_merged_prs:
        stacked_message = (
            f"This PR (#{stacked_pr.pr_num}) was merged as part of PR #{pr.pr_num} in the stack under {merge_commit_sha} "
            "but it is still open, likely due to a Github bug, so mergebot is closing it manually. "
            "If you think this is a mistake, please feel free to reopen and contact Dev Infra."
        )
        close_if_still_open(stacked_pr, stacked_message)

    print(f"PR {pr.pr_num} and all additional PRs in the stack have been closed.")


@retries_decorator()
def save_merge_record(
    comment_id: int,
    pr_num: int,
    owner: str,
    project: str,
    author: str,
    pending_checks: list[tuple[str, Optional[str], Optional[int]]],
    failed_checks: list[tuple[str, Optional[str], Optional[int]]],
    ignore_current_checks: list[tuple[str, Optional[str], Optional[int]]],
    broken_trunk_checks: list[tuple[str, Optional[str], Optional[int]]],
    flaky_checks: list[tuple[str, Optional[str], Optional[int]]],
    unstable_checks: list[tuple[str, Optional[str], Optional[int]]],
    last_commit_sha: str,
    merge_base_sha: str,
    merge_commit_sha: str = "",
    is_failed: bool = False,
    skip_mandatory_checks: bool = False,
    ignore_current: bool = False,
    error: str = "",
) -> None:
    """
    This saves the merge records as a json, which can later be uploaded to s3

    The record is dumped as a one-element JSON list to ``merge_record.json`` in
    ``Path(__file__).resolve().parent.parent.parent``.
    """
    # Prepare the record to be written into s3
    record = {
        "comment_id": comment_id,
        "pr_num": pr_num,
        "owner": owner,
        "project": project,
        "author": author,
        "pending_checks": pending_checks,
        "failed_checks": failed_checks,
        "ignore_current_checks": ignore_current_checks,
        "broken_trunk_checks": broken_trunk_checks,
        "flaky_checks": flaky_checks,
        "unstable_checks": unstable_checks,
        "last_commit_sha": last_commit_sha,
        "merge_base_sha": merge_base_sha,
        "merge_commit_sha": merge_commit_sha,
        "is_failed": is_failed,
        "skip_mandatory_checks": skip_mandatory_checks,
        "ignore_current": ignore_current,
        "error": error,
        # This is a unique identifier for the record for deduping purposes
        # in Rockset.  Any unique string would work.  This will not be used
        # after we migrate off Rockset
        "_id": f"{project}-{pr_num}-{comment_id}-{os.environ.get('GITHUB_RUN_ID')}",
    }
    output_path = Path(__file__).resolve().parent.parent.parent / "merge_record.json"
    with open(output_path, "w") as out:
        json.dump([record], out)


@retries_decorator()
def get_drci_classifications(pr_num: int, project: str = "pytorch") -> Any:
    """
    Query HUD API to find similar failures to decide if they are flaky

    Returns the entry for ``str(pr_num)`` from the Dr.CI response, or {} when the
    response is empty or has no such entry.
    """
    # NB: This doesn't work internally atm because this requires making an
    # external API call to HUD
    response = gh_fetch_url(
        f"https://hud.pytorch.org/api/drci/drci?prNumber={pr_num}",
        data=f"repo={project}",
        headers={
            "Authorization": os.getenv("DRCI_BOT_KEY", ""),
            "Accept": "application/vnd.github.v3+json",
        },
        method="POST",
        reader=json.load,
    )
    if not response:
        return {}
    return response.get(str(pr_num), {})


# Matches the trailing ", <shard>, <num_shards>, <runner>)" part of a job name
REMOVE_JOB_NAME_SUFFIX_REGEX = re.compile(r", [0-9]+, [0-9]+, .+\)$")


def remove_job_name_suffix(name: str, replacement: str = ")") -> str:
    """Replace the ", N, M, ...)" suffix of ``name`` with ``replacement``.

    e.g. "test (default, 1, 5, linux.2xlarge)" -> "test (default)"
    """
    return REMOVE_JOB_NAME_SUFFIX_REGEX.sub(replacement, name)


def is_broken_trunk(
    check: JobCheckState,
    drci_classifications: Any,
) -> bool:
    """True if Dr.CI lists ``check`` under BROKEN_TRUNK.

    The match is by job name, or by job id when the check has a truthy id.
    """
    if not check or not drci_classifications:
        return False

    job_name, job_id = check.name, check.job_id
    # Consult the list of broken trunk failures from Dr.CI
    for entry in drci_classifications.get("BROKEN_TRUNK", []):
        if job_name == entry["name"] or (job_id and job_id == entry["id"]):
            return True
    return False


def is_unstable(
    check: JobCheckState,
    drci_classifications: Any,
) -> bool:
    """True if ``check`` counts as unstable.

    A check is unstable if its name contains "unstable" or Dr.CI lists it under
    UNSTABLE (matched by name, or by job id when the check has one). Always False
    when ``check`` or ``drci_classifications`` is falsy.
    """
    if not check or not drci_classifications:
        return False

    job_name, job_id = check.name, check.job_id

    # The job name has the unstable keyword. This is the original way to mark a job
    # as unstable on HUD, Dr.CI, and trymerge
    if "unstable" in job_name:
        return True

    # Consult the list of unstable failures from Dr.CI
    for entry in drci_classifications.get("UNSTABLE", []):
        if job_name == entry["name"] or (job_id and job_id == entry["id"]):
            return True
    return False


def is_flaky(
    check: JobCheckState,
    drci_classifications: Any,
) -> bool:
    """True if Dr.CI lists ``check`` under FLAKY (by name, or by a truthy job id)."""
    if not check or not drci_classifications:
        return False

    job_name, job_id = check.name, check.job_id
    # Consult the list of flaky failures from Dr.CI
    for entry in drci_classifications.get("FLAKY", []):
        if job_name == entry["name"] or (job_id and job_id == entry["id"]):
            return True
    return False


def is_invalid_cancel(
    name: str,
    conclusion: Optional[str],
    drci_classifications: Any,
) -> bool:
    """
    After https://github.com/pytorch/test-infra/pull/4579, invalid cancelled
    signals have been removed from HUD and Dr.CI. The same needs to be done
    here for consistency

    Returns True for a CANCELLED job (case-insensitive) that Dr.CI does not list
    under FAILED. Returns False if any argument is empty.
    """
    if not name or not drci_classifications or not conclusion:
        return False
    if conclusion.upper() != "CANCELLED":
        return False

    # If a job is cancelled and not listed as a failure by Dr.CI, it's an
    # invalid signal and can be ignored
    listed_as_failure = any(
        name == failure["name"] for failure in drci_classifications.get("FAILED", [])
    )
    return not listed_as_failure


def get_classifications(
    pr_num: int,
    project: str,
    checks: dict[str, JobCheckState],
    ignore_current_checks: Optional[list[str]],
) -> dict[str, JobCheckState]:
    """Return a copy of ``checks`` with failing checks tagged by classification.

    Checks whose status is SUCCESS or NEUTRAL are left untouched. For the others,
    the first matching rule wins:
      1. UNSTABLE
      2. BROKEN_TRUNK
      3. FLAKY
      4. INVALID_CANCEL
      5. IGNORE_CURRENT_CHECK (name listed in ``ignore_current_checks``)
    Checks matching none of these are unchanged.
    """
    # Get the failure classification from Dr.CI, which is the source of truth
    # going forward. It's preferable to try calling Dr.CI API directly first
    # to get the latest results as well as update Dr.CI PR comment
    drci_classifications = get_drci_classifications(pr_num=pr_num, project=project)

    def describe(results: Any) -> str:
        try:
            lines = [f"From Dr.CI API ({pr_num}):\n"]
            for category, jobs in results.items():
                lines.append(f"  {category}: \n")
                lines.extend(f"    {job['id']} {job['name']}\n" for job in jobs)
            return "".join(lines)
        except Exception:
            return f"From Dr.CI API: {json.dumps(results)}"

    print(describe(drci_classifications))

    # NB: if the latest results from Dr.CI is not available, i.e. when calling from
    # SandCastle, we fallback to any results we can find on Dr.CI check run summary
    if (
        not drci_classifications
        and DRCI_CHECKRUN_NAME in checks
        and checks[DRCI_CHECKRUN_NAME]
        and checks[DRCI_CHECKRUN_NAME].summary
    ):
        drci_summary = checks[DRCI_CHECKRUN_NAME].summary
        try:
            print(f"From Dr.CI checkrun summary: {drci_summary}")
            drci_classifications = json.loads(str(drci_summary))
        except json.JSONDecodeError:
            warn("Invalid Dr.CI checkrun summary")
            drci_classifications = {}

    classified = checks.copy()
    for name, check in checks.items():
        if check.status in ("SUCCESS", "NEUTRAL"):
            continue

        if is_unstable(check, drci_classifications):
            label = "UNSTABLE"
        # NB: It's important to note that when it comes to ghstack and broken trunk classification,
        # Dr.CI uses the base of the whole stack
        elif is_broken_trunk(check, drci_classifications):
            label = "BROKEN_TRUNK"
        elif is_flaky(check, drci_classifications):
            label = "FLAKY"
        # NB: INVALID_CANCEL is its own category because invalid cancelled signals
        # usually come in large numbers, so they shouldn't be counted toward the
        # ignorable failures threshold
        elif is_invalid_cancel(name, check.status, drci_classifications):
            label = "INVALID_CANCEL"
        elif ignore_current_checks is not None and name in ignore_current_checks:
            label = "IGNORE_CURRENT_CHECK"
        else:
            continue

        # Same job state, only the classification field changes
        classified[name] = check._replace(classification=label)

    return classified


def filter_checks_with_lambda(
    checks: JobNameToStateDict, status_filter: Callable[[Optional[str]], bool]
) -> list[JobCheckState]:
    """Return the checks whose ``status`` passes ``status_filter``, in dict order."""
    return list(filter(lambda job: status_filter(job.status), checks.values()))


def get_pr_commit_sha(repo: GitRepo, pr: GitHubPR) -> str:
    """Return the commit that landed ``pr``.

    GitHub's merge commit is used if there is one. Otherwise the first local
    commit that resolves the PR is returned.

    Raises:
        PostCommentError: if no such commit can be found.
    """
    landed_sha = pr.get_merge_commit()
    if landed_sha is None:
        resolving = repo.commits_resolving_gh_pr(pr.pr_num)
        if not resolving:
            raise PostCommentError("Can't find any commits resolving PR")
        landed_sha = resolving[0]
    return landed_sha


def validate_revert(
    repo: GitRepo, pr: GitHubPR, *, comment_id: Optional[int] = None
) -> tuple[str, str]:
    """
    Check that the revert request for ``pr`` is allowed.

    If ``comment_id`` is given, the request is that comment. Otherwise it is
    the PR's last comment.

    Returns:
        ``(author_login, commit_sha)``. ``author_login`` is the login of the
        requesting user. ``commit_sha`` is the commit to revert.

    Raises:
        PostCommentError: if the comment was edited, or its author's
            association is not allowed to revert.
        Whatever ``find_matching_merge_rule`` raises when no merge rule
        matches. Status checks are ignored for that lookup.
    """
    if comment_id is None:
        revert_comment = pr.get_last_comment()
    else:
        revert_comment = pr.get_comment_by_id(comment_id)

    # Never act on a command whose text may have changed after it was posted
    if revert_comment.editor_login is not None:
        raise PostCommentError("Don't want to revert based on edited command")

    login = revert_comment.author_login
    association = revert_comment.author_association

    reverter_roles = ["COLLABORATOR", "MEMBER", "OWNER"]
    # For some reason, one can not be a member of private repo, only CONTRIBUTOR
    if pr.is_base_repo_private():
        reverter_roles.append("CONTRIBUTOR")

    if association not in reverter_roles:
        roles_str = ", ".join(reverter_roles)
        raise PostCommentError(
            f"Will not revert as @{login} is not one of "
            f"[{roles_str}], but instead is {association}."
        )

    # Raises exception if matching rule is not found, but ignores all status checks
    find_matching_merge_rule(
        pr, repo, skip_mandatory_checks=True, skip_internal_checks=True
    )
    return login, get_pr_commit_sha(repo, pr)


def get_ghstack_dependent_prs(
    repo: GitRepo, pr: GitHubPR, only_closed: bool = True
) -> list[tuple[str, GitHubPR]]:
    """
    Get the PRs in the stack that are above this PR (inclusive).
    Throws an error if the stack has branched or the original branches are gone.

    Args:
        repo: local git checkout used to inspect the ghstack ``orig`` refs.
        pr: a ghstack PR (asserted below).
        only_closed: when True (the default), PRs in the stack that are not
            closed are skipped. When False, they are included with an empty
            commit SHA.

    Returns:
        A list of ``(commit_sha, pr)`` pairs. For closed PRs ``commit_sha``
        comes from ``get_pr_commit_sha``. For non-closed PRs (only included
        when ``only_closed`` is False) it is ``""``.

    Raises:
        RuntimeError: if the PR has no commits beyond the default branch, or if
            the rev-lists of the branches containing the PR's orig ref are not
            suffixes of one another.
    """
    assert pr.is_ghstack_pr()
    orig_ref = f"{repo.remote}/{pr.get_ghstack_orig_ref()}"
    # Commits reachable from this PR's orig ref but not from the default branch
    rev_list = repo.revlist(f"{pr.default_branch()}..{orig_ref}")
    if len(rev_list) == 0:
        raise RuntimeError(
            f"PR {pr.pr_num} does not have any revisions associated with it"
        )
    # Number of entries beyond this PR's own commit. They are trimmed off below.
    # NOTE(review): assumes revlist is newest-first (git rev-list default), so the
    # commits this PR depends on are at the tail of the list. Confirm in gitutils.
    skip_len = len(rev_list) - 1
    for branch in repo.branches_containing_ref(orig_ref):
        candidate = repo.revlist(f"{pr.default_branch()}..{branch}")
        # Pick longest candidate
        # (after the swap, rev_list always holds the longest list seen so far)
        if len(candidate) > len(rev_list):
            candidate, rev_list = rev_list, candidate
        # Validate that candidate always ends rev-list
        # (the shorter list must be a suffix of the longer one; otherwise the
        # stack has branched and there is no single linear set of dependents)
        if rev_list[-len(candidate) :] != candidate:
            raise RuntimeError(
                f"Branch {branch} revlist {', '.join(candidate)} is not a subset of {', '.join(rev_list)}"
            )
    # Remove commits original PR depends on
    # (guarded because rev_list[:-0] would produce an empty list)
    if skip_len > 0:
        rev_list = rev_list[:-skip_len]
    rc: list[tuple[str, GitHubPR]] = []
    for pr_, sha in _revlist_to_prs(repo, pr, rev_list):
        if not pr_.is_closed():
            # Non-closed PRs are only reported (with an empty SHA) on request
            if not only_closed:
                rc.append(("", pr_))
            continue
        commit_sha = get_pr_commit_sha(repo, pr_)
        rc.append((commit_sha, pr_))
    return rc


def do_revert_prs(
    repo: GitRepo,
    shas_and_prs: list[tuple[str, GitHubPR]],
    *,
    author_login: str,
    extra_msg: str = "",
    skip_internal_checks: bool = False,
    dry_run: bool = False,
) -> None:
    """
    Revert every ``(commit_sha, pr)`` pair, push, then notify each PR.

    For each pair, a revert commit is created on the PR's default branch. Its
    message has the "Pull Request resolved" trailer removed and a "Reverted
    <pr> on behalf of <author>" line (plus ``extra_msg``) appended. All reverts
    are pushed in one go to the default branch of the first PR. Each PR then
    receives:
    - a PR comment,
    - the ``reverted`` (numbered) and ``ci-no-td`` labels,
    - outside dry-run mode, a commit comment carrying *its own* revert message
      and a state update.

    Args:
        repo: local checkout in which the revert commits are created.
        shas_and_prs: pairs of the landed commit SHA and its PR. The list must
            be non-empty, because the push target is taken from the first PR.
        author_login: GitHub login credited with the revert.
        extra_msg: suffix appended to every revert message.
        skip_internal_checks: suppress the "might contain internal changes"
            warning.
        dry_run: forwarded to the push, comment and label calls. Commit
            comments and state updates are skipped entirely.
    """
    # One revert message per PR, kept in the same order as shas_and_prs. The
    # commit comment posted later must name the PR it belongs to, not whichever
    # PR happened to be processed last.
    revert_msgs: list[str] = []

    # Prepare and push revert commits
    for commit_sha, pr in shas_and_prs:
        revert_msg = f"\nReverted {pr.get_pr_url()} on behalf of {prefix_with_github_url(author_login)}"
        revert_msg += extra_msg
        revert_msgs.append(revert_msg)
        repo.checkout(pr.default_branch())
        repo.revert(commit_sha)
        msg = repo.commit_message("HEAD")
        msg = re.sub(RE_PULL_REQUEST_RESOLVED, "", msg)
        msg += revert_msg
        repo.amend_commit_message(msg)
    repo.push(shas_and_prs[0][1].default_branch(), dry_run)

    # Comment/reopen PRs
    for (commit_sha, pr), revert_msg in zip(shas_and_prs, revert_msgs):
        revert_message = (
            f"@{pr.get_pr_creator_login()} your PR has been successfully reverted."
        )
        if (
            pr.has_internal_changes()
            and not pr.has_no_connected_diff()
            and not skip_internal_checks
        ):
            revert_message += "\n:warning: This PR might contain internal changes"
            revert_message += "\ncc: @pytorch/pytorch-dev-infra"
        gh_post_pr_comment(
            pr.org, pr.project, pr.pr_num, revert_message, dry_run=dry_run
        )

        pr.add_numbered_label("reverted", dry_run)
        pr.add_label("ci-no-td", dry_run)
        if not dry_run:
            gh_post_commit_comment(pr.org, pr.project, commit_sha, revert_msg)
            gh_update_pr_state(pr.org, pr.project, pr.pr_num)


def try_revert(
    repo: GitRepo,
    pr: GitHubPR,
    *,
    dry_run: bool = False,
    comment_id: Optional[int] = None,
    reason: Optional[str] = None,
) -> None:
    """
    Validate a revert request for ``pr`` and perform it.

    Steps:
    - If validation fails with ``PostCommentError``, its message is posted on
      the PR and nothing is reverted.
    - For ghstack PRs, the closed PRs stacked above (and including) ``pr`` are
      reverted together.
    - If that stack cannot be determined, or it comes back empty, only ``pr``
      itself is reverted.

    Args:
        repo: local checkout used to create and push the revert commits.
        pr: the PR to revert.
        dry_run: forwarded to all mutating calls.
        comment_id: id of the comment requesting the revert. When given, it is
            linked from the revert message.
        reason: optional free-text reason appended to the revert message.
    """
    try:
        author_login, commit_sha = validate_revert(repo, pr, comment_id=comment_id)
    except PostCommentError as e:
        gh_post_pr_comment(pr.org, pr.project, pr.pr_num, str(e), dry_run=dry_run)
        return

    extra_msg = f" due to {reason}" if reason is not None else ""
    extra_msg += (
        f" ([comment]({pr.get_comment_by_id(comment_id).url}))\n"
        if comment_id is not None
        else "\n"
    )
    shas_and_prs = [(commit_sha, pr)]
    if pr.is_ghstack_pr():
        # Best effort: any failure to resolve the stack degrades to a single revert
        try:
            dependent_prs = get_ghstack_dependent_prs(repo, pr)
        except Exception as e:
            print(f"Failed to fetch dependent PRs: {e}, falling back to single revert")
        else:
            if dependent_prs:
                shas_and_prs = dependent_prs
                prs_to_revert = " ".join(t[1].get_pr_url() for t in shas_and_prs)
                print(f"About to revert stack of PRs: {prs_to_revert}")
            else:
                # do_revert_prs needs at least one entry; keep the single revert
                print("No dependent PRs found, falling back to single revert")

    do_revert_prs(
        repo,
        shas_and_prs,
        author_login=author_login,
        extra_msg=extra_msg,
        dry_run=dry_run,
        skip_internal_checks=can_skip_internal_checks(pr, comment_id),
    )


def prefix_with_github_url(suffix_str: str) -> str:
    """Return ``suffix_str`` as a path under ``https://github.com/``."""
    github_root = "https://github.com"
    return "{}/{}".format(github_root, suffix_str)


def check_for_sev(org: str, project: str, skip_mandatory_checks: bool) -> None:
    """
    Refuse to merge while a merge-blocking SEV issue is open in ``org/project``.

    The check is skipped entirely when ``skip_mandatory_checks`` is set.
    Otherwise the GitHub issue search is queried for open issues carrying both
    the ``ci: sev`` and ``merge blocking`` labels.

    Raises:
        RuntimeError: if at least one such issue exists. The message links the
            first issue returned by the search.
    """
    if not skip_mandatory_checks:
        # Having two label: queries is an AND operation
        sev_query = f'repo:{org}/{project} is:open is:issue label:"ci: sev" label:"merge blocking"'
        search_result = cast(
            dict[str, Any],
            gh_fetch_json_list(
                "https://api.github.com/search/issues",
                params={"q": sev_query},
            ),
        )
        if search_result["total_count"] != 0:
            first_issue_url = search_result["items"][0]["html_url"]
            raise RuntimeError(
                "Not merging any PRs at the moment because there is a "
                "merge blocking https://github.com/pytorch/pytorch/labels/ci:%20sev issue open at: \n"
                f"{first_issue_url}"
            )


def has_label(labels: list[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
    """
    Return True if any label matches ``pattern`` at its start.

    ``re.Pattern.match`` anchors the match at the beginning of each label.
    ``any`` stops at the first hit instead of building an intermediate list.
    """
    return any(pattern.match(label) for label in labels)


def categorize_checks(
    check_runs: JobNameToStateDict,
    required_checks: list[str],
    ok_failed_checks_threshold: Optional[int] = None,
) -> tuple[
    list[tuple[str, Optional[str], Optional[int]]],
    list[tuple[str, Optional[str], Optional[int]]],
    dict[str, list[Any]],
]:
    """
    Categorize all jobs into lists of pending and failing jobs. Known flaky
    failures and broken trunk failures are ignored by default, i.e. when
    ok_failed_checks_threshold is not set (unlimited).

    Args:
        check_runs: job name -> job state (status, url, classification, job_id).
        required_checks: names that must be present. A job is considered
            relevant if any required name is a substring of its name. When
            this is empty, every job is relevant.
        ok_failed_checks_threshold: maximum number of BROKEN_TRUNK + FLAKY
            failures that may be ignored. If it is exceeded, those failures are
            added to the failed list. None means unlimited.

    Returns:
        ``(pending_checks, failed_checks, failed_checks_categorization)``.
        - The two lists hold ``(name, url, job_id)`` tuples.
        - The dict maps an ignorable classification (IGNORE_CURRENT_CHECK,
          BROKEN_TRUNK, FLAKY, UNSTABLE) to the failures ignored under it.
    """
    pending_checks: list[tuple[str, Optional[str], Optional[int]]] = []
    failed_checks: list[tuple[str, Optional[str], Optional[int]]] = []

    # failed_checks_categorization is used to keep track of all ignorable failures when saving the merge record on s3
    failed_checks_categorization: dict[str, list[Any]] = defaultdict(list)

    # If required_checks is not set or empty, consider all names are relevant
    relevant_checknames = [
        name
        for name in check_runs.keys()
        if not required_checks or any(x in name for x in required_checks)
    ]

    # A required check with no (substring-)matching job has not started yet:
    # report it as pending with no url/job_id
    for checkname in required_checks:
        if all(checkname not in x for x in check_runs.keys()):
            pending_checks.append((checkname, None, None))

    for checkname in relevant_checknames:
        status = check_runs[checkname].status
        url = check_runs[checkname].url
        classification = check_runs[checkname].classification
        job_id = check_runs[checkname].job_id

        if status is None and classification != "UNSTABLE":
            # NB: No need to wait if the job classification is unstable as it would be
            # ignored anyway. This is useful to not need to wait for scarce resources
            # like ROCm, which is also frequently in unstable mode
            pending_checks.append((checkname, url, job_id))
        elif classification == "INVALID_CANCEL":
            # Invalid cancellations are dropped entirely: neither failed nor ignorable
            continue
        elif not is_passing_status(check_runs[checkname].status):
            # Ignorable classifications are recorded separately; anything else is a hard failure
            target = (
                failed_checks_categorization[classification]
                if classification
                in ("IGNORE_CURRENT_CHECK", "BROKEN_TRUNK", "FLAKY", "UNSTABLE")
                else failed_checks
            )
            target.append((checkname, url, job_id))

    # NB: indexing the defaultdict here also inserts empty BROKEN_TRUNK / FLAKY
    # entries into the returned categorization when there were none
    flaky_or_broken_trunk = (
        failed_checks_categorization["BROKEN_TRUNK"]
        + failed_checks_categorization["FLAKY"]
    )

    if flaky_or_broken_trunk:
        warn(
            f"The following {len(flaky_or_broken_trunk)} checks failed but were likely due flakiness or broken trunk: "
            + ", ".join([x[0] for x in flaky_or_broken_trunk])
            + (
                f" but this is greater than the threshold of {ok_failed_checks_threshold} so merge will fail"
                if ok_failed_checks_threshold is not None
                and len(flaky_or_broken_trunk) > ok_failed_checks_threshold
                else ""
            )
        )

    # Too many "ignorable" failures: treat all of them as real failures.
    # UNSTABLE and IGNORE_CURRENT_CHECK failures do not count toward the threshold.
    if (
        ok_failed_checks_threshold is not None
        and len(flaky_or_broken_trunk) > ok_failed_checks_threshold
    ):
        failed_checks = failed_checks + flaky_or_broken_trunk

    # The list of failed_checks_categorization is returned so that it can be saved into the s3 merge record
    return (pending_checks, failed_checks, failed_checks_categorization)


def merge(
    pr: GitHubPR,
    repo: GitRepo,
    dry_run: bool = False,
    skip_mandatory_checks: bool = False,
    comment_id: Optional[int] = None,
    timeout_minutes: int = 400,
    stale_pr_days: int = 3,
    ignore_current: bool = False,
) -> None:
    """
    Merge ``pr`` once its approvals and CI signals allow it.

    With ``skip_mandatory_checks`` (force merge), the PR is merged right after
    the ghstack sync check and the SEV check.

    Otherwise, the function first verifies approvals (via the merge rules) and
    required labels. It then re-fetches the PR and re-evaluates its checks in a
    loop, sleeping 5 minutes between attempts, until either:
    - nothing relevant is pending or failing (the PR is merged), or
    - ``timeout_minutes`` elapse.

    Args:
        pr: the PR to merge.
        repo: local checkout used to perform the merge.
        dry_run: forwarded to the label, comment and merge calls.
        skip_mandatory_checks: merge without waiting for CI.
        comment_id: id of the comment that triggered the merge, if any.
        timeout_minutes: how long to keep polling before giving up.
        stale_pr_days: not referenced by this function.
        ignore_current: treat checks that are already failing at invocation
            time as ignorable.

    Raises:
        RuntimeError: on an open merge-blocking SEV, missing required labels,
            new commits pushed while merging, startup failures, failing jobs,
            or timeout.
        Errors raised by ``find_matching_merge_rule`` when the PR has no
        matching approval rule.
    """
    initial_commit_sha = pr.last_commit()["oid"]
    pr_link = f"https://github.com/{pr.org}/{pr.project}/pull/{pr.pr_num}"
    print(f"Attempting merge of {initial_commit_sha} ({pr_link})")

    if MERGE_IN_PROGRESS_LABEL not in pr.get_labels():
        gh_add_labels(pr.org, pr.project, pr.pr_num, [MERGE_IN_PROGRESS_LABEL], dry_run)

    explainer = TryMergeExplainer(
        skip_mandatory_checks,
        pr.get_labels(),
        pr.pr_num,
        pr.org,
        pr.project,
        ignore_current,
    )

    # probably a bad name, but this is a list of current checks that should be
    # ignored and is toggled by the --ignore-current flag
    ignore_current_checks_info: list[tuple[str, Optional[str], Optional[int]]] = []

    if pr.is_ghstack_pr():
        get_ghstack_prs(repo, pr)  # raises error if out of sync

    check_for_sev(pr.org, pr.project, skip_mandatory_checks)

    if skip_mandatory_checks:
        post_starting_merge_comment(repo, pr, explainer, dry_run)
        return pr.merge_into(
            repo,
            dry_run=dry_run,
            skip_mandatory_checks=skip_mandatory_checks,
            comment_id=comment_id,
        )

    # Check for approvals
    find_matching_merge_rule(pr, repo, skip_mandatory_checks=True)

    if not has_required_labels(pr):
        raise RuntimeError(LABEL_ERR_MSG.lstrip(" #"))

    if ignore_current:
        # Snapshot the checks failing right now; they will be ignored while polling
        checks = pr.get_checkrun_conclusions()
        _, failing, _ = categorize_checks(
            checks,
            list(checks.keys()),
            ok_failed_checks_threshold=IGNORABLE_FAILED_CHECKS_THESHOLD,
        )
        ignore_current_checks_info = failing

    post_starting_merge_comment(
        repo,
        pr,
        explainer,
        dry_run,
        ignore_current_checks_info=ignore_current_checks_info,
    )

    start_time = time.time()
    last_exception = ""
    elapsed_time = 0.0
    ignore_current_checks = [
        x[0] for x in ignore_current_checks_info
    ]  # convert to List[str] for convenience
    while elapsed_time < timeout_minutes * 60:
        check_for_sev(pr.org, pr.project, skip_mandatory_checks)
        current_time = time.time()
        elapsed_time = current_time - start_time
        print(f"Attempting merge of {pr_link} ({elapsed_time / 60} minutes elapsed)")
        # Re-fetch so commits, labels and check results are current
        pr = GitHubPR(pr.org, pr.project, pr.pr_num)
        if initial_commit_sha != pr.last_commit()["oid"]:
            raise RuntimeError(
                "New commits were pushed while merging. Please rerun the merge command."
            )
        try:
            required_checks = []
            failed_rule_message = None
            ignore_flaky_failures = True
            try:
                find_matching_merge_rule(
                    pr, repo, ignore_current_checks=ignore_current_checks
                )
            except MandatoryChecksMissingError as ex:
                # Remember the rule's requirements; the error is only re-raised
                # below if some check is still pending
                if ex.rule is not None:
                    ignore_flaky_failures = ex.rule.ignore_flaky_failures
                    if ex.rule.mandatory_checks_name is not None:
                        required_checks = ex.rule.mandatory_checks_name
                failed_rule_message = ex

            checks = pr.get_checkrun_conclusions()
            checks = get_classifications(
                pr.pr_num,
                pr.project,
                checks,
                ignore_current_checks=ignore_current_checks,
            )
            pending, failing, _ = categorize_checks(
                checks,
                required_checks
                + [x for x in checks.keys() if x not in required_checks],
                ok_failed_checks_threshold=IGNORABLE_FAILED_CHECKS_THESHOLD
                if ignore_flaky_failures
                else 0,
            )
            # HACK until GitHub will be better about surfacing those
            startup_failures = filter_checks_with_lambda(
                checks, lambda status: status == "STARTUP_FAILURE"
            )
            if len(startup_failures) > 0:
                raise RuntimeError(
                    f"{len(startup_failures)} STARTUP failures reported, please check workflows syntax! "
                    + ", ".join(f"[{x.name}]({x.url})" for x in startup_failures[:5])
                )
            # END of HACK

            if len(failing) > 0:
                raise RuntimeError(
                    f"{len(failing)} jobs have failed, first few of them are: "
                    + ", ".join(f"[{x[0]}]({x[1]})" for x in failing[:5])
                )
            if len(pending) > 0:
                if failed_rule_message is not None:
                    raise failed_rule_message
                else:
                    raise MandatoryChecksMissingError(
                        f"Still waiting for {len(pending)} jobs to finish, "
                        + f"first few of them are: {', '.join(x[0] for x in pending[:5])}"
                    )

            return pr.merge_into(
                repo,
                dry_run=dry_run,
                skip_mandatory_checks=skip_mandatory_checks,
                comment_id=comment_id,
                ignore_current_checks=ignore_current_checks,
            )
        except MandatoryChecksMissingError as ex:
            # Only "still waiting" conditions are retried; any other error propagates
            last_exception = str(ex)
            print(f"Merge of {pr_link} failed due to: {ex}. Retrying in 5 min")
            time.sleep(5 * 60)
    # Finally report timeout back
    msg = f"Merge timed out after {timeout_minutes} minutes. Please contact the pytorch_dev_infra team."
    # Leading space keeps the two sentences from running together
    msg += f" The last exception was: {last_exception}"
    gh_add_labels(pr.org, pr.project, pr.pr_num, ["land-failed"], dry_run)
    raise RuntimeError(msg)


def main() -> None:
    """
    CLI entry point: revert, check mergeability of, or merge a single PR.

    The PR and the requested action come from ``parse_args()``. Failures are
    reported back to the PR as a comment by ``handle_exception``. On a failed
    merge, a merge record is also saved when a comment id is available.
    """
    args = parse_args()
    repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
    org, project = repo.gh_owner_and_name()
    pr = GitHubPR(org, project, args.pr_num)

    def handle_exception(e: Exception, title: str = "Merge failed") -> None:
        # Post "## <title>" plus the exception text on the PR, then print the traceback.
        # When GH_RUN_URL is set, a collapsed details section links the workflow run
        # and, for MergeRuleFailedError, names the failing rule.
        exception = f"**Reason**: {e}"

        failing_rule = None
        if isinstance(e, MergeRuleFailedError):
            failing_rule = e.rule.name if e.rule else None

        internal_debugging = ""
        run_url = os.getenv("GH_RUN_URL")
        if run_url is not None:
            # Hide this behind a collapsed bullet since it's not helpful to most devs
            internal_debugging = "\n".join(
                line
                for line in (
                    "<details><summary>Details for Dev Infra team</summary>",
                    f'Raised by <a href="{run_url}">workflow job</a>\n',
                    f"Failing merge rule: {failing_rule}" if failing_rule else "",
                    "</details>",
                )
                if line
            )  # ignore empty lines during the join

        msg = "\n".join((f"## {title}", f"{exception}", "", f"{internal_debugging}"))

        gh_post_pr_comment(org, project, args.pr_num, msg, dry_run=args.dry_run)
        import traceback

        traceback.print_exc()

    # Revert path: announce the revert, attempt it, report any failure; never falls through to merge
    if args.revert:
        try:
            gh_post_pr_comment(
                org,
                project,
                args.pr_num,
                get_revert_message(org, project, pr.pr_num),
                args.dry_run,
            )
            try_revert(
                repo,
                pr,
                dry_run=args.dry_run,
                comment_id=args.comment_id,
                reason=args.reason,
            )
        except Exception as e:
            handle_exception(e, f"Reverting PR {args.pr_num} failed")
        return

    # Pre-merge guards: each posts an explanatory comment and stops
    if pr.is_closed():
        gh_post_pr_comment(
            org,
            project,
            args.pr_num,
            f"Can't merge closed PR #{args.pr_num}",
            dry_run=args.dry_run,
        )
        return

    if pr.is_cross_repo() and pr.is_ghstack_pr():
        gh_post_pr_comment(
            org,
            project,
            args.pr_num,
            "Cross-repo ghstack merges are not supported",
            dry_run=args.dry_run,
        )
        return
    if not pr.is_ghstack_pr() and pr.base_ref() != pr.default_branch():
        gh_post_pr_comment(
            org,
            project,
            args.pr_num,
            f"PR targets {pr.base_ref()} rather than {pr.default_branch()}, refusing merge request",
            dry_run=args.dry_run,
        )
        return

    # Mergeability-only mode: run merge_changes with all rule checks skipped, then stop.
    # Errors here are not routed through handle_exception.
    if args.check_mergeability:
        if pr.is_ghstack_pr():
            get_ghstack_prs(repo, pr)  # raises error if out of sync
        pr.merge_changes(
            repo,
            skip_mandatory_checks=True,
            skip_all_rule_checks=True,
        )
        return

    # Submodule updates must be explicitly acknowledged unless the merge is forced
    if not args.force and pr.has_invalid_submodule_updates():
        message = (
            f"This PR updates submodules {', '.join(pr.get_changed_submodules())}\n"
        )
        message += '\nIf those updates are intentional, please add "submodule" keyword to PR title/description.'
        gh_post_pr_comment(org, project, args.pr_num, message, dry_run=args.dry_run)
        return
    try:
        merge(
            pr,
            repo,
            dry_run=args.dry_run,
            skip_mandatory_checks=args.force,
            comment_id=args.comment_id,
            ignore_current=args.ignore_current,
        )
    except Exception as e:
        handle_exception(e)

        if args.comment_id and args.pr_num:
            # Finally, upload the record to s3, we don't have access to the
            # list of pending and failed checks here, but they are not really
            # needed at the moment
            save_merge_record(
                comment_id=args.comment_id,
                pr_num=args.pr_num,
                owner=org,
                project=project,
                author=pr.get_author(),
                pending_checks=[],
                failed_checks=[],
                ignore_current_checks=[],
                broken_trunk_checks=[],
                flaky_checks=[],
                unstable_checks=[],
                last_commit_sha=pr.last_commit().get("oid", ""),
                merge_base_sha=pr.get_merge_base(),
                is_failed=True,
                skip_mandatory_checks=args.force,
                ignore_current=args.ignore_current,
                error=str(e),
            )
        else:
            print("Missing comment ID or PR number, couldn't upload to s3")
    finally:
        # Always drop the "merging" label, whether the merge succeeded or failed.
        # NOTE(review): check_mergeability already returned above, so this condition
        # is always true here.
        if not args.check_mergeability:
            gh_remove_label(
                org, project, args.pr_num, MERGE_IN_PROGRESS_LABEL, args.dry_run
            )

if __name__ == "__main__":
    # Run the CLI entry point only when executed as a script, not on import.
    main()
