import asyncio
import logging
import os
import pathlib
import sys
from datetime import datetime, timezone
from typing import List, Optional
from uuid import UUID

from fastapi import HTTPException
from pusher import Pusher
from sqlalchemy.orm.exc import StaleDataError
from sqlmodel import Session

from keep.api.arq_pool import get_pool
from keep.api.bl.enrichments_bl import EnrichmentsBl
from keep.api.core.db import (
    add_alerts_to_incident,
    add_audit,
    create_incident_from_dto,
    delete_incident_by_id,
    enrich_alerts_with_incidents,
    get_all_alerts_by_fingerprints,
    get_incident_by_id,
    get_incident_unique_fingerprint_count,
    is_all_alerts_resolved,
    is_first_incident_alert_resolved,
    is_last_incident_alert_resolved,
    remove_alerts_to_incident_by_incident_id,
    update_incident_from_dto_by_id,
    update_incident_severity,
)
from keep.api.core.elastic import ElasticClient
from keep.api.core.incidents import get_last_incidents_by_cel
from keep.api.models.action_type import ActionType
from keep.api.models.db.incident import Incident, IncidentSeverity, IncidentStatus
from keep.api.models.db.rule import ResolveOn
from keep.api.models.incident import IncidentDto, IncidentDtoIn, IncidentSorting
from keep.api.utils.enrichment_helpers import convert_db_alerts_to_dto_alerts
from keep.api.utils.pagination import IncidentsPaginatedResultsDto
from keep.identitymanager.authenticatedentity import AuthenticatedEntity
from keep.workflowmanager.workflowmanager import WorkflowManager

# Threshold on the number of unique alert fingerprints of an incident; summary
# generation is only scheduled when the count is strictly greater than this.
MIN_INCIDENT_ALERTS_FOR_SUMMARY_GENERATION = int(
    os.environ.get("MIN_INCIDENT_ALERTS_FOR_SUMMARY_GENERATION", 5)
)

# EE_ENABLED is parsed case-insensitively so that this module-level flag agrees
# with the value IncidentBl.__init__ derives from the same environment variable.
ee_enabled = os.environ.get("EE_ENABLED", "false").lower() == "true"
if ee_enabled:
    # Put the ee/experimental directory (relative to this file) first on the
    # import path.
    path_with_ee = (
        str(pathlib.Path(__file__).parent.resolve()) + "/../../../ee/experimental"
    )
    sys.path.insert(0, path_with_ee)
else:
    ALGORITHM_VERBOSE_NAME = NotImplemented


class IncidentBl:

    def __init__(
        self,
        tenant_id: str,
        session: Session,
        pusher_client: Optional[Pusher] = None,
        user: str = None,
    ):
        self.tenant_id = tenant_id
        self.user = user
        self.session = session
        self.pusher_client = pusher_client
        self.logger = logging.getLogger(__name__)
        self.ee_enabled = os.environ.get("EE_ENABLED", "false").lower() == "true"
        self.redis = os.environ.get("REDIS", "false") == "true"

    def create_incident(
        self,
        incident_dto: IncidentDtoIn | IncidentDto,
        generated_from_ai: bool = False,
    ) -> IncidentDto:
        """
        Creates a new incident.

        After the incident is persisted, connected clients are notified through
        the pusher client (if one is configured) and a "created" workflow event
        is sent. Failures in either notification are logged by the respective
        helpers and do not abort the creation.

        Args:
            incident_dto (IncidentDtoIn | IncidentDto): The data transfer object containing the details of the incident to be created.
            generated_from_ai (bool, optional): Indicates if the incident was generated by Keep's AI. Defaults to False.

        Returns:
            IncidentDto: The newly created incident object, containing details of the incident.
        """
        self.logger.info(
            "Creating incident",
            extra={"incident_dto": incident_dto.dict(), "tenant_id": self.tenant_id},
        )
        incident = create_incident_from_dto(
            self.tenant_id,
            incident_dto,
            generated_from_ai=generated_from_ai,
            session=self.session,
        )
        self.logger.info(
            "Incident created",
            extra={"incident_id": incident.id, "tenant_id": self.tenant_id},
        )
        new_incident_dto = IncidentDto.from_db_incident(incident)
        self.logger.info(
            "Incident DTO created",
            extra={"incident_id": new_incident_dto.id, "tenant_id": self.tenant_id},
        )
        # No incident id is passed here, so the pushed payload carries
        # "incident_id": None.
        self.update_client_on_incident_change()
        self.logger.info(
            "Client updated on incident change",
            extra={"incident_id": new_incident_dto.id, "tenant_id": self.tenant_id},
        )
        self.send_workflow_event(new_incident_dto, "created")
        self.logger.info(
            "Workflows run on incident",
            extra={"incident_id": new_incident_dto.id, "tenant_id": self.tenant_id},
        )
        return new_incident_dto

    def sync_add_alerts_to_incident(self, *args, **kwargs) -> None:
        """
        Synchronous wrapper for the async add_alerts_to_incident method.
        """
        asyncio.run(self.add_alerts_to_incident(*args, **kwargs))

    async def add_alerts_to_incident(
        self,
        incident_id: UUID,
        alert_fingerprints: List[str],
        is_created_by_ai: bool = False,
        override_count: bool = False,
    ) -> None:
        """
        Attach alerts to an existing incident and run the follow-up steps.

        The incident is loaded through the instance session, the alerts are
        linked to it, the alert post-processing (Elasticsearch, client push,
        workflow event) is executed and finally summary generation is
        attempted.

        Args:
            incident_id (UUID): Identifier of the incident to extend.
            alert_fingerprints (List[str]): Fingerprints of the alerts to add.
            is_created_by_ai (bool, optional): Forwarded to the DB helper. Defaults to False.
            override_count (bool, optional): Forwarded to the DB helper. Defaults to False.

        Raises:
            HTTPException: 404 when the incident cannot be found.
        """
        log_context = {
            "incident_id": incident_id,
            "alert_fingerprints": alert_fingerprints,
        }
        self.logger.info("Adding alerts to incident", extra=log_context)

        incident = get_incident_by_id(
            tenant_id=self.tenant_id, incident_id=incident_id, session=self.session
        )
        if not incident:
            raise HTTPException(status_code=404, detail="Incident not found")

        add_alerts_to_incident(
            self.tenant_id,
            incident,
            alert_fingerprints,
            is_created_by_ai,
            session=self.session,
            override_count=override_count,
        )
        self.logger.info("Alerts added to incident", extra=log_context)

        self.__postprocess_alerts_change(incident, alert_fingerprints)

        # The summary helper catches and logs its own failures, so this log
        # line is reached whether or not a summary job was actually scheduled.
        await self.__generate_summary(incident_id, incident)
        self.logger.info("Summary generated", extra=log_context)

    def __update_elastic(self, alert_fingerprints: List[str]):
        """
        Re-index the given alerts in Elasticsearch, including their incidents.

        Does nothing when the tenant's Elasticsearch client is not enabled.
        Any failure is logged and then re-raised to the caller.

        Args:
            alert_fingerprints (List[str]): Fingerprints of the alerts to re-index.
        """
        try:
            client = ElasticClient(self.tenant_id)
            if not client.enabled:
                return

            alerts = get_all_alerts_by_fingerprints(
                tenant_id=self.tenant_id,
                fingerprints=alert_fingerprints,
                session=self.session,
            )
            alerts_with_incidents = enrich_alerts_with_incidents(
                self.tenant_id, alerts, session=self.session
            )
            client.index_alerts(
                alerts=convert_db_alerts_to_dto_alerts(
                    alerts_with_incidents, with_incidents=True
                )
            )
        except Exception:
            self.logger.exception("Failed to push alert to elasticsearch")
            raise

    def update_client_on_incident_change(self, incident_id: Optional[UUID] = None):
        if self.pusher_client is not None:
            self.logger.info(
                "Pushing incident change to client",
                extra={"incident_id": incident_id, "tenant_id": self.tenant_id},
            )
            try:
                self.pusher_client.trigger(
                    f"private-{self.tenant_id}",
                    "incident-change",
                    {"incident_id": str(incident_id) if incident_id else None},
                )
                self.logger.info(
                    "Incident change pushed to client",
                    extra={"incident_id": incident_id, "tenant_id": self.tenant_id},
                )
            except Exception:
                self.logger.exception(
                    "Failed to push incident change to client",
                    extra={"incident_id": incident_id, "tenant_id": self.tenant_id},
                )

    def send_workflow_event(self, incident_dto: IncidentDto, action: str) -> None:
        """
        Hand an incident event to the workflow manager.

        Best-effort: any exception raised while obtaining the manager or
        inserting the incident is logged and swallowed.

        Args:
            incident_dto (IncidentDto): The incident the event refers to.
            action (str): Event name; callers in this class use "created", "updated" and "deleted".
        """
        try:
            WorkflowManager.get_instance().insert_incident(
                self.tenant_id, incident_dto, action
            )
        except Exception:
            failure_context = {
                "incident_id": incident_dto.id,
                "tenant_id": self.tenant_id,
            }
            self.logger.exception(
                "Failed to run workflows based on incident", extra=failure_context
            )

    async def __generate_summary(self, incident_id: UUID, incident: Incident):
        """
        Schedule background summary generation for an incident when eligible.

        A "process_summary_generation" job is enqueued only if the module-level
        EE flag is on, Redis is enabled for this instance, the incident has
        more unique fingerprints than MIN_INCIDENT_ALERTS_FOR_SUMMARY_GENERATION
        and the incident has no user-provided summary. Every failure is logged
        and swallowed.

        Args:
            incident_id (UUID): Identifier of the incident.
            incident (Incident): The incident row; only ``user_summary`` is read.
        """
        try:
            unique_fingerprints = get_incident_unique_fingerprint_count(
                self.tenant_id, incident_id
            )
            # Same conditions, in the same evaluation order, expressed as a guard.
            if (
                not ee_enabled
                or not self.redis
                or unique_fingerprints <= MIN_INCIDENT_ALERTS_FOR_SUMMARY_GENERATION
                or incident.user_summary
            ):
                return

            pool = await get_pool()
            job = await pool.enqueue_job(
                "process_summary_generation",
                tenant_id=self.tenant_id,
                incident_id=incident_id,
            )
            self.logger.info(
                f"Summary generation for incident {incident_id} scheduled, job: {job}",
                extra={
                    "tenant_id": self.tenant_id,
                    "incident_id": incident_id,
                },
            )
        except Exception:
            self.logger.exception(
                "Failed to generate summary for incident",
                extra={"incident_id": incident_id, "tenant_id": self.tenant_id},
            )

    def delete_alerts_from_incident(
        self, incident_id: UUID, alert_fingerprints: List[str]
    ) -> None:
        """
        Detach alerts from an incident and run the alert post-processing.

        Args:
            incident_id (UUID): Identifier of the incident to modify.
            alert_fingerprints (List[str]): Fingerprints of the alerts to remove.

        Raises:
            HTTPException: 404 when the incident cannot be found.
        """
        log_context = {"incident_id": incident_id, "tenant_id": self.tenant_id}
        self.logger.info("Fetching incident", extra=log_context)

        # NOTE(review): unlike add_alerts_to_incident, the instance session is
        # not passed to the DB helpers here — confirm this is intentional.
        incident = get_incident_by_id(tenant_id=self.tenant_id, incident_id=incident_id)
        if not incident:
            raise HTTPException(status_code=404, detail="Incident not found")

        remove_alerts_to_incident_by_incident_id(
            self.tenant_id, incident_id, alert_fingerprints
        )
        self.__postprocess_alerts_change(incident, alert_fingerprints)

    def delete_incident(self, incident_id: UUID) -> None:
        """
        Delete an incident, then notify clients and workflows.

        The DTO is built before the deletion so that the "deleted" workflow
        event can still describe the incident afterwards.

        Args:
            incident_id (UUID): Identifier of the incident to delete.

        Raises:
            HTTPException: 404 when the incident cannot be found or the delete helper reports nothing was deleted.
        """
        log_context = {"incident_id": incident_id, "tenant_id": self.tenant_id}
        self.logger.info("Fetching incident", extra=log_context)

        incident = get_incident_by_id(tenant_id=self.tenant_id, incident_id=incident_id)
        if not incident:
            raise HTTPException(status_code=404, detail="Incident not found")

        snapshot_dto = IncidentDto.from_db_incident(incident)

        was_deleted = delete_incident_by_id(
            tenant_id=self.tenant_id, incident_id=incident_id
        )
        if not was_deleted:
            raise HTTPException(status_code=404, detail="Incident not found")

        self.update_client_on_incident_change()
        self.send_workflow_event(snapshot_dto, "deleted")

    def bulk_delete_incidents(self, incident_ids: List[UUID]) -> None:
        for incident_id in incident_ids:
            self.delete_incident(incident_id)

    def update_incident(
        self,
        incident_id: UUID,
        updated_incident_dto: IncidentDtoIn,
        generated_by_ai: bool,
    ) -> IncidentDto:
        """
        Apply the fields of a DTO to an existing incident.

        Args:
            incident_id (UUID): Identifier of the incident to update.
            updated_incident_dto (IncidentDtoIn): New incident data.
            generated_by_ai (bool): Forwarded to the DB update helper.

        Returns:
            IncidentDto: The updated incident.

        Raises:
            HTTPException: 404 when the update helper returns no incident.
        """
        log_context = {"incident_id": incident_id, "tenant_id": self.tenant_id}
        self.logger.info("Fetching incident", extra=log_context)

        updated = update_incident_from_dto_by_id(
            self.tenant_id, incident_id, updated_incident_dto, generated_by_ai
        )
        # Handles the not-found case, client notification and the workflow event.
        return self.__postprocess_incident_change(updated)

    def __postprocess_alerts_change(self, incident, alert_fingerprints):
        """
        Run the follow-up steps after an incident's alert set changed.

        In order: re-index the alerts in Elasticsearch (errors propagate),
        push the change to connected clients and send an "updated" workflow
        event for the incident.

        Args:
            incident: The incident whose alerts changed.
            alert_fingerprints: Fingerprints of the alerts that were added or removed.
        """
        log_context = {
            "incident_id": incident.id,
            "alert_fingerprints": alert_fingerprints,
        }

        self.__update_elastic(alert_fingerprints)
        self.logger.info("Alerts pushed to elastic", extra=log_context)

        self.update_client_on_incident_change(incident.id)
        self.logger.info("Client updated on incident change", extra=log_context)

        self.send_workflow_event(IncidentDto.from_db_incident(incident), "updated")
        self.logger.info("Workflows run on incident", extra=log_context)

    def update_severity(
        self,
        incident_id: UUID,
        severity: IncidentSeverity,
        comment: Optional[str] = None,
    ) -> IncidentDto:
        """
        Change the severity of an incident, optionally recording a comment.

        Args:
            incident_id (UUID): Identifier of the incident to update.
            severity (IncidentSeverity): The new severity.
            comment (Optional[str]): When non-empty, stored as an INCIDENT_COMMENT audit entry authored by ``self.user``.

        Returns:
            IncidentDto: The updated incident.

        Raises:
            HTTPException: 404 when the update helper returns no incident.
        """
        self.logger.info(
            "Fetching incident",
            extra={
                "incident_id": incident_id,
                "tenant_id": self.tenant_id,
            },
        )
        incident = update_incident_severity(
            self.tenant_id,
            incident_id,
            severity,
        )

        # Fail before writing the audit entry, so that a comment is never
        # recorded against an incident that does not exist.
        if not incident:
            raise HTTPException(status_code=404, detail="Incident not found")

        if comment:
            add_audit(
                self.tenant_id,
                str(incident_id),
                self.user,
                ActionType.INCIDENT_COMMENT,
                comment,
            )

        return self.__postprocess_incident_change(incident)

    def __postprocess_incident_change(self, incident):
        """
        Run the follow-up steps after an incident itself changed.

        Args:
            incident: The changed incident, or a falsy value when it was not found.

        Returns:
            IncidentDto: DTO built from the changed incident.

        Raises:
            HTTPException: 404 when ``incident`` is falsy.
        """
        if not incident:
            raise HTTPException(status_code=404, detail="Incident not found")

        changed_dto = IncidentDto.from_db_incident(incident)
        log_context = {"incident_id": incident.id}

        self.update_client_on_incident_change(incident.id)
        self.logger.info("Client updated on incident change", extra=log_context)

        self.send_workflow_event(changed_dto, "updated")
        self.logger.info("Workflows run on incident", extra=log_context)

        return changed_dto

    @staticmethod
    def query_incidents(
        tenant_id: str,
        limit: int = 25,
        offset: int = 0,
        timeframe: int = None,
        upper_timestamp: datetime = None,
        lower_timestamp: datetime = None,
        is_candidate: bool = False,
        sorting: Optional[IncidentSorting] = IncidentSorting.creation_time,
        with_alerts: bool = False,
        is_predicted: bool = None,
        cel: str = None,
        allowed_incident_ids: Optional[List[str]] = None,
    ):
        """
        Query a page of incidents and wrap it in a paginated result.

        Every argument is forwarded unchanged, by keyword, to
        get_last_incidents_by_cel; ``limit`` and ``offset`` are additionally
        echoed in the returned object.

        Returns:
            IncidentsPaginatedResultsDto: The requested page as IncidentDto items, plus the total count reported by the query.
        """
        query_args = dict(
            tenant_id=tenant_id,
            limit=limit,
            offset=offset,
            timeframe=timeframe,
            upper_timestamp=upper_timestamp,
            lower_timestamp=lower_timestamp,
            is_candidate=is_candidate,
            sorting=sorting,
            with_alerts=with_alerts,
            is_predicted=is_predicted,
            cel=cel,
            allowed_incident_ids=allowed_incident_ids,
        )
        db_incidents, total = get_last_incidents_by_cel(**query_args)
        items = [IncidentDto.from_db_incident(row) for row in db_incidents]

        return IncidentsPaginatedResultsDto(
            limit=limit, offset=offset, count=total, items=items
        )

    def resolve_incident_if_require(
        self, incident: Incident, max_retries=3
    ) -> Incident:
        """
        Mark the incident as resolved when its ``resolve_on`` rule is satisfied.

        Depending on ``incident.resolve_on`` the incident is resolved when all
        of its alerts, its first alert or its last alert is resolved. The
        status update is committed through the instance session and retried
        when a concurrent update is detected.

        Args:
            incident (Incident): The incident to evaluate.
            max_retries (int, optional): Maximum number of commit attempts. Defaults to 3.

        Returns:
            Incident: The same incident object that was passed in.

        Raises:
            StaleDataError: When the commit fails with a stale-data error other than the known "expected to update" race.
        """
        should_resolve = False

        if incident.resolve_on == ResolveOn.ALL.value and is_all_alerts_resolved(
            incident=incident, session=self.session
        ):
            should_resolve = True

        elif (
            incident.resolve_on == ResolveOn.FIRST.value
            and is_first_incident_alert_resolved(incident, session=self.session)
        ):
            should_resolve = True

        elif (
            incident.resolve_on == ResolveOn.LAST.value
            and is_last_incident_alert_resolved(incident, session=self.session)
        ):
            should_resolve = True

        # Captured up front so the id is available for logging after a rollback.
        incident_id = incident.id

        if not should_resolve:
            return incident

        for attempt in range(max_retries):
            try:
                incident.status = IncidentStatus.RESOLVED.value
                self.session.add(incident)
                self.session.commit()
                break
            except StaleDataError as ex:
                message = str(ex.args[0]) if ex.args else ""
                if "expected to update" not in message:
                    # Not the known concurrent-update race: surface it instead
                    # of silently retrying on a session that was not rolled back.
                    raise
                self.logger.info(
                    f"Phantom read detected while updating incident `{incident_id}`, retry #{attempt}"
                )
                self.session.rollback()
        else:
            # Every attempt hit the race; keep the best-effort contract but
            # leave a trace that the incident was not resolved.
            self.logger.warning(
                "Failed to resolve incident after retries",
                extra={
                    "incident_id": incident_id,
                    "tenant_id": self.tenant_id,
                    "max_retries": max_retries,
                },
            )

        return incident

    def change_status(
        self,
        incident_id: UUID | str,
        new_status: IncidentStatus,
        change_by: AuthenticatedEntity,
    ) -> IncidentDto:
        """
        Change the status of an incident on behalf of a user.

        For RESOLVED and ACKNOWLEDGED the incident's alerts are enriched with
        the new status as well. RESOLVED additionally stamps ``end_time`` with
        the current UTC time. If the incident is not already assigned to the
        acting user it is assigned to them. Audit entries are added without
        committing and are committed together with the incident update.

        Args:
            incident_id (UUID | str): Identifier of the incident.
            new_status (IncidentStatus): The status to set.
            change_by (AuthenticatedEntity): The acting entity; its ``email`` is used as author and assignee.

        Returns:
            IncidentDto: The updated incident.

        Raises:
            HTTPException: 404 when the incident cannot be found.
        """
        self.logger.info(
            "Fetching incident",
            extra={
                "incident_id": incident_id,
                "tenant_id": self.tenant_id,
            },
        )

        # Alerts are only needed (and only loaded) for statuses that are
        # propagated down to the alerts themselves.
        propagate_to_alerts = new_status in [
            IncidentStatus.RESOLVED,
            IncidentStatus.ACKNOWLEDGED,
        ]
        incident = get_incident_by_id(
            self.tenant_id,
            incident_id,
            with_alerts=propagate_to_alerts,
            session=self.session,
        )

        if not incident:
            raise HTTPException(status_code=404, detail="Incident not found")

        if propagate_to_alerts:
            status_enrichment = {"status": new_status.value}
            alert_fingerprints = [alert.fingerprint for alert in incident.alerts]
            enricher = EnrichmentsBl(self.tenant_id, db=self.session)
            # Only the action type and description are used below; the other
            # two values are unpacked to keep the expected 4-tuple shape.
            (
                action_type,
                action_description,
                _should_run_workflow,
                _should_check_incidents_resolution,
            ) = enricher.get_enrichment_metadata(status_enrichment, change_by)
            enricher.batch_enrich(
                alert_fingerprints,
                status_enrichment,
                action_type,
                change_by.email,
                action_description,
                dispose_on_new_alert=True,
            )

        if new_status == IncidentStatus.RESOLVED:
            incident.end_time = datetime.now(tz=timezone.utc)

        audit_target = str(incident_id)

        if incident.assignee != change_by.email:
            incident.assignee = change_by.email
            add_audit(
                self.tenant_id,
                audit_target,
                change_by.email,
                ActionType.INCIDENT_ASSIGN,
                f"Incident self-assigned to {change_by.email}",
                session=self.session,
                commit=False,
            )

        # The message must be built before the status is overwritten so that
        # it still contains the previous status.
        add_audit(
            self.tenant_id,
            audit_target,
            change_by.email,
            ActionType.INCIDENT_STATUS_CHANGE,
            f"Incident status changed from {incident.status} to {new_status.value}",
            session=self.session,
            commit=False,
        )
        incident.status = new_status.value
        self.session.add(incident)
        self.session.commit()

        return self.__postprocess_incident_change(incident)
