# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import abc
from typing import cast, Any, Optional, TypeVar

from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions
from .. import CredentialUnavailableError
from .._internal.managed_identity_client import ManagedIdentityClient
from .._internal.get_token_mixin import GetTokenMixin

# TypeVar bound to ManagedIdentityBase, used to annotate ``__enter__`` so that type
# checkers see the concrete subclass (not just the base class) as the value bound by
# ``with <credential> as x``.
T = TypeVar("T", bound="ManagedIdentityBase")


class ManagedIdentityBase(GetTokenMixin):
    """Common base for internal managed identity credentials backed by ManagedIdentityClient.

    A subclass decides through :meth:`get_client` whether a client can be created. If it
    cannot (a falsy result), every token request raises :class:`CredentialUnavailableError`
    with the text produced by :meth:`get_unavailable_message`.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__()
        # All constructor keyword arguments go straight to the subclass factory.
        self._client = self.get_client(**kwargs)

    @abc.abstractmethod
    def get_client(self, **kwargs: Any) -> Optional[ManagedIdentityClient]:
        """Create the client for this credential, or return None when it cannot be used.

        :keyword kwargs: the keyword arguments this credential was constructed with.
        """

    @abc.abstractmethod
    def get_unavailable_message(self, desc: str = "") -> str:
        """Return the message for the CredentialUnavailableError raised when there is no client.

        :param str desc: extra description text; this base class always calls the method
            without it. (Presumably appended to the message — confirm in subclasses.)
        """

    def __enter__(self: T) -> T:
        client = self._client
        if client:
            client.__enter__()
        return self

    def __exit__(self, *args: Any) -> None:
        client = self._client
        if client:
            client.__exit__(*args)

    def close(self) -> None:
        """Close the underlying client, if one exists, by calling ``__exit__`` with no arguments."""
        self.__exit__()

    def get_token(
        self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs: Any
    ) -> AccessToken:
        """Request an access token for ``scopes`` through the inherited implementation.

        :raises CredentialUnavailableError: when no ManagedIdentityClient was created.
        """
        if self._client:
            return super().get_token(*scopes, claims=claims, tenant_id=tenant_id, **kwargs)
        raise CredentialUnavailableError(message=self.get_unavailable_message())

    def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo:
        """Request access token details for ``scopes`` through the inherited implementation.

        :raises CredentialUnavailableError: when no ManagedIdentityClient was created.
        """
        if self._client:
            return super().get_token_info(*scopes, options=options)
        raise CredentialUnavailableError(message=self.get_unavailable_message())

    def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]:
        """Return the client's cached token for ``scopes``, if it has one.

        This hook is expected to run only after get_token/get_token_info have confirmed a
        client exists. ``cast`` merely narrows the Optional for the type checker and has
        no runtime effect.
        """
        client = cast(ManagedIdentityClient, self._client)
        return client.get_cached_token(*scopes, **kwargs)

    def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo:
        """Ask the client for a new token for ``scopes``.

        This has the same precondition as ``_acquire_token_silently``: a client must exist.
        """
        client = cast(ManagedIdentityClient, self._client)
        return client.request_token(*scopes, **kwargs)
