"""
Policy Resolver - Resolves final guardrail list from policies.

Handles:
- Inheritance chain resolution (inherit with add/remove)
- Applying add/remove guardrails
- Evaluating model conditions
- Combining guardrails from multiple matching policies
"""

from typing import Dict, List, Optional, Set, Tuple

from litellm._logging import verbose_proxy_logger
from litellm.types.proxy.policy_engine import (GuardrailPipeline, Policy,
                                               PolicyMatchContext,
                                               ResolvedPolicy)


class PolicyResolver:
    """
    Resolves the final list of guardrails from policies.

    Handles:
    - Inheritance chains with add/remove operations
    - Model-based conditions
    """

    @staticmethod
    def resolve_inheritance_chain(
        policy_name: str,
        policies: Dict[str, Policy],
        visited: Optional[Set[str]] = None,
    ) -> List[str]:
        """
        Get the inheritance chain for a policy (from root to policy).

        Args:
            policy_name: Name of the policy
            policies: Dictionary of all policies
            visited: Set of visited policies (for cycle detection)

        Returns:
            List of policy names from root ancestor to the given policy
        """
        if visited is None:
            visited = set()

        if policy_name in visited:
            verbose_proxy_logger.warning(
                f"Circular inheritance detected for policy '{policy_name}'"
            )
            return []

        policy = policies.get(policy_name)
        if policy is None:
            return []

        visited.add(policy_name)

        if policy.inherit:
            parent_chain = PolicyResolver.resolve_inheritance_chain(
                policy_name=policy.inherit, policies=policies, visited=visited
            )
            return parent_chain + [policy_name]

        return [policy_name]

    @staticmethod
    def resolve_policy_guardrails(
        policy_name: str,
        policies: Dict[str, Policy],
        context: Optional[PolicyMatchContext] = None,
    ) -> ResolvedPolicy:
        """
        Resolve the final guardrails for a single policy, including inheritance.

        This method:
        1. Resolves the inheritance chain (root ancestor first)
        2. Applies add/remove from each policy in the chain, in that order
        3. Skips a chain entry entirely when a context is provided and that
           entry's condition does not match it

        Args:
            policy_name: Name of the policy to resolve
            policies: Dictionary of all policies
            context: Optional request context for evaluating conditions. When
                None, conditions are not evaluated and every chain entry applies.

        Returns:
            ResolvedPolicy whose ``guardrails`` list contains each surviving
            guardrail once, in the order it was first added while walking the
            chain from root to leaf. A guardrail that is removed and later
            re-added takes the position of the later add.
        """
        from litellm.proxy.policy_engine.condition_evaluator import \
            ConditionEvaluator

        inheritance_chain = PolicyResolver.resolve_inheritance_chain(
            policy_name=policy_name, policies=policies
        )

        # Dict keys act as an insertion-ordered set, so the returned list has
        # a stable order instead of the arbitrary iteration order of a set.
        guardrails: Dict[str, None] = {}

        # Apply each policy in the chain (from root to leaf)
        for chain_policy_name in inheritance_chain:
            policy = policies.get(chain_policy_name)
            if policy is None:
                continue

            # Check if policy condition matches (only when a context is given)
            if (
                context is not None
                and policy.condition is not None
                and not ConditionEvaluator.evaluate(
                    condition=policy.condition,
                    context=context,
                )
            ):
                verbose_proxy_logger.debug(
                    f"Policy '{chain_policy_name}' condition did not match, skipping guardrails"
                )
                continue

            # Adds run before removes, so a name listed in both by the same
            # policy ends up removed.
            for guardrail in policy.guardrails.get_add():
                guardrails.setdefault(guardrail, None)

            for guardrail in policy.guardrails.get_remove():
                guardrails.pop(guardrail, None)

        return ResolvedPolicy(
            policy_name=policy_name,
            guardrails=list(guardrails),
            inheritance_chain=inheritance_chain,
        )

    @staticmethod
    def resolve_guardrails_for_context(
        context: PolicyMatchContext,
        policies: Optional[Dict[str, Policy]] = None,
        policy_names: Optional[List[str]] = None,
    ) -> List[str]:
        """
        Resolve the final list of guardrails for a request context.

        This:
        1. Determines the matching policies: ``policy_names`` if provided,
           otherwise ``PolicyMatcher.get_matching_policies(context=context)``
        2. Resolves each policy's guardrails (including inheritance and
           condition evaluation against ``context``)
        3. Combines all guardrails (union, without duplicates)

        Args:
            context: The request context
            policies: Dictionary of all policies (if None, uses global registry)
            policy_names: If provided, use this list instead of attachment matching

        Returns:
            List of guardrail names to apply, each listed once, in the order
            in which they are first contributed while iterating the matching
            policies. Empty when the global registry is needed but not
            initialized, or when no policy matches.
        """
        from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher
        from litellm.proxy.policy_engine.policy_registry import \
            get_policy_registry

        if policies is None:
            registry = get_policy_registry()
            if not registry.is_initialized():
                return []
            policies = registry.get_all_policies()

        # Use provided policy names or get matching policies via attachments
        matching_policy_names = (
            policy_names
            if policy_names is not None
            else PolicyMatcher.get_matching_policies(context=context)
        )

        if not matching_policy_names:
            verbose_proxy_logger.debug(
                f"No policies match context: team_alias={context.team_alias}, "
                f"key_alias={context.key_alias}, model={context.model}"
            )
            return []

        # Dict keys act as an insertion-ordered set: duplicates collapse and
        # the first occurrence fixes a guardrail's position in the result,
        # rather than relying on the arbitrary iteration order of a set.
        all_guardrails: Dict[str, None] = {}

        for policy_name in matching_policy_names:
            resolved = PolicyResolver.resolve_policy_guardrails(
                policy_name=policy_name,
                policies=policies,
                context=context,
            )
            all_guardrails.update(dict.fromkeys(resolved.guardrails))
            verbose_proxy_logger.debug(
                f"Policy '{policy_name}' contributes guardrails: {resolved.guardrails}"
            )

        result = list(all_guardrails)
        verbose_proxy_logger.debug(
            f"Final guardrails for context: {result}"
        )

        return result

    @staticmethod
    def resolve_pipelines_for_context(
        context: PolicyMatchContext,
        policies: Optional[Dict[str, Policy]] = None,
        policy_names: Optional[List[str]] = None,
    ) -> List[Tuple[str, GuardrailPipeline]]:
        """
        Collect the pipelines declared by the policies matching a request.

        Only the matching policies themselves are inspected; a policy
        contributes an entry when its ``pipeline`` attribute is not None.
        Guardrails managed by pipelines should be excluded from the flat
        guardrails list to avoid double execution.

        Args:
            context: The request context
            policies: Dictionary of all policies (if None, uses global registry)
            policy_names: If provided, use this list instead of attachment matching

        Returns:
            List of (policy_name, GuardrailPipeline) tuples, in the order the
            matching policy names are iterated. Empty when the global registry
            is needed but not initialized, or when nothing matches.
        """
        from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher
        from litellm.proxy.policy_engine.policy_registry import \
            get_policy_registry

        # Fall back to the global registry when no policy map was supplied.
        if policies is None:
            registry = get_policy_registry()
            if not registry.is_initialized():
                return []
            policies = registry.get_all_policies()

        # An explicit (even empty) list of names wins over attachment matching.
        if policy_names is None:
            candidates = PolicyMatcher.get_matching_policies(context=context)
        else:
            candidates = policy_names

        if not candidates:
            return []

        found: List[Tuple[str, GuardrailPipeline]] = []
        for policy_name in candidates:
            policy = policies.get(policy_name)
            # Unknown names and policies without a pipeline contribute nothing.
            if policy is None or policy.pipeline is None:
                continue

            found.append((policy_name, policy.pipeline))
            step_count = len(policy.pipeline.steps)
            verbose_proxy_logger.debug(
                f"Policy '{policy_name}' has pipeline with {step_count} steps"
            )

        return found

    @staticmethod
    def get_pipeline_managed_guardrails(
        pipelines: List[Tuple[str, GuardrailPipeline]],
    ) -> Set[str]:
        """
        Get the set of guardrail names managed by pipelines.

        These guardrails should be excluded from normal independent execution.
        """
        managed: Set[str] = set()
        for _policy_name, pipeline in pipelines:
            for step in pipeline.steps:
                managed.add(step.guardrail)
        return managed

    @staticmethod
    def get_all_resolved_policies(
        policies: Optional[Dict[str, Policy]] = None,
        context: Optional[PolicyMatchContext] = None,
    ) -> Dict[str, ResolvedPolicy]:
        """
        Resolve every known policy to its final guardrails.

        Useful for debugging and displaying policy configurations.

        Args:
            policies: Dictionary of all policies (if None, uses global registry)
            context: Optional context for evaluating conditions

        Returns:
            Dictionary mapping each policy name to its ResolvedPolicy. Empty
            when the global registry is needed but not initialized.
        """
        from litellm.proxy.policy_engine.policy_registry import \
            get_policy_registry

        # Fall back to the global registry when no policy map was supplied.
        if policies is None:
            registry = get_policy_registry()
            if not registry.is_initialized():
                return {}
            policies = registry.get_all_policies()

        return {
            name: PolicyResolver.resolve_policy_guardrails(
                policy_name=name,
                policies=policies,
                context=context,
            )
            for name in policies
        }
