from __future__ import annotations

import re
from importlib import import_module
from typing import TYPE_CHECKING, Any

from polars._dependencies import _PYARROW_AVAILABLE, import_optional
from polars._utils.various import parse_version
from polars.convert import from_arrow
from polars.exceptions import ModuleUpgradeRequiredError

if TYPE_CHECKING:
    from collections.abc import Coroutine

    from polars import DataFrame
    from polars._typing import SchemaDict


def _run_async(co: Coroutine[Any, Any, Any]) -> Any:
    """Run asynchronous code as if it was synchronous."""
    import asyncio

    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        # no running loop; can use asyncio "as-is"
        return asyncio.run(co)
    else:
        # inside running loop; use vendored `nest_asyncio` (for now)
        import polars._utils.nest_asyncio

        polars._utils.nest_asyncio.apply()  # type: ignore[attr-defined]
        return running_loop.run_until_complete(co)


def _read_sql_connectorx(
    query: str | list[str],
    connection_uri: str,
    partition_on: str | None = None,
    partition_range: tuple[int, int] | None = None,
    partition_num: int | None = None,
    protocol: str | None = None,
    schema_overrides: SchemaDict | None = None,
    pre_execution_query: str | list[str] | None = None,
) -> DataFrame:
    cx = import_optional("connectorx")

    if parse_version(cx.__version__) < (0, 4, 2):
        if pre_execution_query:
            msg = "'pre_execution_query' is only supported in connectorx version 0.4.2 or later"
            raise ValueError(msg)
        return_type = "arrow2"
        pre_execution_args = {}
    else:
        return_type = "arrow"
        pre_execution_args = {"pre_execution_query": pre_execution_query}

    try:
        tbl = cx.read_sql(
            conn=connection_uri,
            query=query,
            return_type=return_type,
            partition_on=partition_on,
            partition_range=partition_range,
            partition_num=partition_num,
            protocol=protocol,
            **pre_execution_args,
        )
    except BaseException as err:
        # basic sanitisation of /user:pass/ credentials exposed in connectorx errs
        errmsg = re.sub("://[^:]+:[^:]+@", "://***:***@", str(err))
        raise type(err)(errmsg) from err

    return from_arrow(tbl, schema_overrides=schema_overrides)  # type: ignore[return-value]


def _read_sql_adbc(
    query: str,
    connection_uri: str,
    schema_overrides: SchemaDict | None,
    execute_options: dict[str, Any] | None = None,
) -> DataFrame:
    module_name = _get_adbc_module_name_from_uri(connection_uri)
    # import the driver first, to ensure a good error message if not installed
    _import_optional_adbc_driver(module_name, dbapi_submodule=False)
    adbc_driver_manager = import_optional("adbc_driver_manager")
    adbc_str_version = getattr(adbc_driver_manager, "__version__", "0.0")
    adbc_version = parse_version(adbc_str_version)

    # adbc_driver_manager must be >= 1.7.0 to support passing Python sequences into
    # parameterised queries (via execute_options) without PyArrow installed
    adbc_version_no_pyarrow_required = "1.7.0"
    has_required_adbc_version = adbc_version >= parse_version(
        adbc_version_no_pyarrow_required
    )

    if (
        execute_options is not None
        and not _PYARROW_AVAILABLE
        and not has_required_adbc_version
    ):
        msg = (
            "pyarrow is required for adbc-driver-manager < "
            f"{adbc_version_no_pyarrow_required} when using parameterized queries (via "
            f"`execute_options`), found {adbc_str_version}.\nEither upgrade "
            "`adbc-driver-manager` (suggested) or install `pyarrow`"
        )
        raise ModuleUpgradeRequiredError(msg)

    # From adbc_driver_manager version 1.6.0 Cursor.fetch_arrow() was introduced,
    # returning an object implementing the Arrow PyCapsule interface. This should be
    # used regardless of whether PyArrow is available.
    fetch_method_name = (
        "fetch_arrow" if adbc_version >= (1, 6, 0) else "fetch_arrow_table"
    )

    with _open_adbc_connection(connection_uri) as conn, conn.cursor() as cursor:
        cursor.execute(query, **(execute_options or {}))
        tbl = getattr(cursor, fetch_method_name)()
        return from_arrow(tbl, schema_overrides=schema_overrides)  # type: ignore[return-value]


def _get_adbc_driver_name_from_uri(connection_uri: str) -> str:
    driver_name = connection_uri.split(":", 1)[0].lower()
    # map uri prefix to ADBC name when not 1:1
    driver_suffix_map: dict[str, str] = {"postgres": "postgresql"}
    return driver_suffix_map.get(driver_name, driver_name)


def _get_adbc_module_name_from_uri(connection_uri: str) -> str:
    driver_name = _get_adbc_driver_name_from_uri(connection_uri)
    return f"adbc_driver_{driver_name}"


def _import_optional_adbc_driver(
    module_name: str,
    *,
    dbapi_submodule: bool = True,
) -> Any:
    # Always import top level module first. This will surface a better error for users
    # if the module does not exist. It doesn't negatively impact performance given the
    # dbapi submodule would also load it.
    adbc_driver = import_optional(
        module_name,
        err_prefix="ADBC",
        err_suffix="driver not detected",
        install_message=(
            "If ADBC supports this database, please run: pip install "
            f"{module_name.replace('_', '-')}"
        ),
    )
    if not dbapi_submodule:
        return adbc_driver
    # Importing the dbapi without pyarrow before adbc_driver_manager 1.6.0
    # raises ImportError: PyArrow is required for the DBAPI-compatible interface
    # Use importlib.import_module because Polars' import_optional clobbers this error
    try:
        adbc_driver_dbapi = import_module(f"{module_name}.dbapi")
    except ImportError as e:
        if "PyArrow is required for the DBAPI-compatible interface" in (str(e)):
            adbc_driver_manager = import_optional("adbc_driver_manager")
            adbc_str_version = getattr(adbc_driver_manager, "__version__", "0.0")

            msg = (
                "pyarrow is required for adbc-driver-manager < 1.6.0, found "
                f"{adbc_str_version}.\nEither upgrade `adbc-driver-manager` (suggested) or "
                "install `pyarrow`"
            )
            raise ModuleUpgradeRequiredError(msg) from None
        # if the error message was something different, re-raise it
        raise
    else:
        return adbc_driver_dbapi


def _open_adbc_connection(connection_uri: str) -> Any:
    driver_name = _get_adbc_driver_name_from_uri(connection_uri)
    module_name = _get_adbc_module_name_from_uri(connection_uri)
    adbc_driver = _import_optional_adbc_driver(module_name)

    # some backends require the driver name to be stripped from the URI
    if driver_name in ("duckdb", "snowflake", "sqlite"):
        connection_uri = re.sub(f"^{driver_name}:/{{,3}}", "", connection_uri)

    return adbc_driver.connect(connection_uri)


def _is_adbc_snowflake_conn(conn: Any) -> bool:
    import adbc_driver_manager

    # If PyArrow is available, prefer using the built in method
    if _PYARROW_AVAILABLE:
        return "snowflake" in conn.adbc_get_info()["vendor_name"].lower()
    # Otherwise, use a workaround checking a Snowflake specific ADBC option
    try:
        adbc_driver_snowflake = import_optional("adbc_driver_snowflake")

        return (
            "snowflake"
            in conn.adbc_database.get_option(
                adbc_driver_snowflake.DatabaseOptions.HOST.value
            ).lower()
        )
    except (ImportError, adbc_driver_manager.Error):
        return False
