"""test_check_labels.py: unit tests for the label checks in check_labels.py."""

from typing import Any
from unittest import main, mock, TestCase

from check_labels import (
    add_label_err_comment,
    delete_all_label_err_comments,
    main as check_labels_main,
)
from github_utils import GitHubComment
from label_utils import BOT_AUTHORS, LABEL_ERR_MSG_TITLE
from test_trymerge import mock_gh_get_info, mocked_gh_graphql
from trymerge import GitHubPR


def mock_parse_args() -> object:
    """Return a fake of the parsed command-line arguments for ``check_labels``.

    The returned object exposes ``pr_num=76123`` and ``exit_non_zero=False``
    as instance attributes, which is all the patched ``parse_args`` needs to
    provide in these tests.
    """

    class _FakeArgs:
        def __init__(self, pr_num: int, exit_non_zero: bool) -> None:
            self.pr_num = pr_num
            self.exit_non_zero = exit_non_zero

    return _FakeArgs(pr_num=76123, exit_non_zero=False)


def mock_add_label_err_comment(pr: "GitHubPR") -> None:
    """No-op replacement for ``check_labels.add_label_err_comment``; *pr* is ignored."""
    return None


def mock_delete_all_label_err_comments(pr: "GitHubPR") -> None:
    """No-op replacement for ``check_labels.delete_all_label_err_comments``; *pr* is ignored."""
    return None


def mock_get_comments() -> list[GitHubComment]:
    """Return two fake PR comments: an ordinary one and a label-error one.

    Index 0 is a plain comment (database id 1). Index 1 has a body built from
    ``LABEL_ERR_MSG_TITLE`` with backticks stripped, is authored by
    ``BOT_AUTHORS[1]``, and has database id 2.
    """

    def _comment(body_text: str, author_login: str, database_id: int) -> GitHubComment:
        # Only the body, author and id differ between the fake comments; every
        # other field is left empty.
        return GitHubComment(
            body_text=body_text,
            created_at="",
            author_login=author_login,
            author_association="",
            editor_login=None,
            database_id=database_id,
            url="",
        )

    # Case 1 - an ordinary comment that is not a label error.
    plain_comment = _comment("mock_body_text", "", 1)
    # Case 2 - a label-error comment whose author comes from BOT_AUTHORS.
    label_err_comment = _comment(
        " #" + LABEL_ERR_MSG_TITLE.replace("`", ""), BOT_AUTHORS[1], 2
    )
    return [plain_comment, label_err_comment]


class TestCheckLabels(TestCase):
    """Tests for the label-error comment helpers and ``check_labels.main``."""

    # Stacked ``mock.patch`` decorators are applied bottom-up, so the mock for
    # the bottom-most patch is the first argument after ``self``. Parameter
    # names are chosen so they do not shadow the module-level helpers.

    @mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
    @mock.patch("trymerge.GitHubPR.get_comments", return_value=[mock_get_comments()[0]])
    @mock.patch("check_labels.gh_post_pr_comment")
    def test_correctly_add_label_err_comment(
        self,
        mock_gh_post_pr_comment: Any,
        mock_pr_get_comments: Any,
        mock_gh_graphql: Any,
    ) -> None:
        """Post a label error comment when no similar comment exists yet."""
        pr = GitHubPR("pytorch", "pytorch", 75095)
        add_label_err_comment(pr)
        mock_gh_post_pr_comment.assert_called_once()

    @mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
    @mock.patch("trymerge.GitHubPR.get_comments", return_value=[mock_get_comments()[1]])
    @mock.patch("check_labels.gh_post_pr_comment")
    def test_not_add_label_err_comment(
        self,
        mock_gh_post_pr_comment: Any,
        mock_pr_get_comments: Any,
        mock_gh_graphql: Any,
    ) -> None:
        """Do not post a label error comment when a similar one already exists."""
        pr = GitHubPR("pytorch", "pytorch", 75095)
        add_label_err_comment(pr)
        mock_gh_post_pr_comment.assert_not_called()

    @mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
    @mock.patch("trymerge.GitHubPR.get_comments", return_value=mock_get_comments())
    @mock.patch("check_labels.gh_delete_comment")
    def test_correctly_delete_all_label_err_comments(
        self,
        mock_gh_delete_comment: Any,
        mock_pr_get_comments: Any,
        mock_gh_graphql: Any,
    ) -> None:
        """Delete only the label error comment (database id 2), not the other one."""
        pr = GitHubPR("pytorch", "pytorch", 75095)
        delete_all_label_err_comments(pr)
        mock_gh_delete_comment.assert_called_once_with("pytorch", "pytorch", 2)

    @mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
    @mock.patch("check_labels.parse_args", return_value=mock_parse_args())
    @mock.patch("check_labels.has_required_labels", return_value=False)
    @mock.patch(
        "check_labels.delete_all_label_err_comments",
        side_effect=mock_delete_all_label_err_comments,
    )
    @mock.patch(
        "check_labels.add_label_err_comment", side_effect=mock_add_label_err_comment
    )
    def test_ci_comments_and_exit0_without_required_labels(
        self,
        mock_add_comment: Any,
        mock_delete_comments: Any,
        mock_has_labels: Any,
        mock_args: Any,
        mock_get_pr_info: Any,
    ) -> None:
        """Without required labels, main() adds an error comment and exits with 0."""
        with self.assertRaises(SystemExit) as sys_exit:
            check_labels_main()
        # Compare the exit status itself rather than str(exception); the string
        # form would also accept a message such as sys.exit("0").
        self.assertEqual(sys_exit.exception.code, 0)
        mock_add_comment.assert_called_once()
        mock_delete_comments.assert_not_called()

    @mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
    @mock.patch("check_labels.parse_args", return_value=mock_parse_args())
    @mock.patch("check_labels.has_required_labels", return_value=True)
    @mock.patch(
        "check_labels.delete_all_label_err_comments",
        side_effect=mock_delete_all_label_err_comments,
    )
    @mock.patch(
        "check_labels.add_label_err_comment", side_effect=mock_add_label_err_comment
    )
    def test_ci_exit0_with_required_labels(
        self,
        mock_add_comment: Any,
        mock_delete_comments: Any,
        mock_has_labels: Any,
        mock_args: Any,
        mock_get_pr_info: Any,
    ) -> None:
        """With required labels, main() clears error comments and exits with 0."""
        with self.assertRaises(SystemExit) as sys_exit:
            check_labels_main()
        # Compare the exit status itself rather than str(exception); the string
        # form would also accept a message such as sys.exit("0").
        self.assertEqual(sys_exit.exception.code, 0)
        mock_add_comment.assert_not_called()
        mock_delete_comments.assert_called_once()


if __name__ == "__main__":
    # Run this module's tests through unittest's command-line entry point
    # (``main`` is ``unittest.main`` from the import at the top of the file).
    main()
