"""AsyncIO support for zmq

Requires asyncio and Python 3.
"""

# Copyright (c) PyZMQ Developers.
# Distributed under the terms of the Modified BSD License.

import asyncio
import selectors
import sys
import warnings
from asyncio import SelectorEventLoop, Future
from typing import Union
from weakref import WeakKeyDictionary

import zmq as _zmq
from zmq import _future

# registry of asyncio loop : selector thread
_selectors: WeakKeyDictionary = WeakKeyDictionary()


class ProactorSelectorThreadWarning(RuntimeWarning):
    """Warning class for notifying about the extra thread spawned by tornado

    We automatically support proactor via tornado's AddThreadSelectorEventLoop"""


def _get_selector_windows(
    asyncio_loop,
) -> asyncio.AbstractEventLoop:
    """Get selector-compatible loop

    Returns an object with ``add_reader`` family of methods,
    either the loop itself or a SelectorThread instance.

    Workaround Windows proactor removal of
    *reader methods, which we need for zmq sockets.
    """

    if asyncio_loop in _selectors:
        return _selectors[asyncio_loop]

    # detect add_reader instead of checking for proactor?
    if hasattr(asyncio, "ProactorEventLoop") and isinstance(
        asyncio_loop, asyncio.ProactorEventLoop  # type: ignore
    ):
        try:
            from tornado.platform.asyncio import AddThreadSelectorEventLoop
        except ImportError:
            raise RuntimeError(
                "Proactor event loop does not implement add_reader family of methods required for zmq."
                " zmq will work with proactor if tornado >= 6.1 can be found."
                " Use `asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())`"
                " or install 'tornado>=6.1' to avoid this error."
            )

        warnings.warn(
            "Proactor event loop does not implement add_reader family of methods required for zmq."
            " Registering an additional selector thread for add_reader support via tornado."
            " Use `asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())`"
            " to avoid this warning.",
            RuntimeWarning,
            # stacklevel 5 matches most likely zmq.asyncio.Context().socket()
            stacklevel=5,
        )

        selector_loop = _selectors[asyncio_loop] = AddThreadSelectorEventLoop(asyncio_loop)  # type: ignore

        # patch loop.close to also close the selector thread
        loop_close = asyncio_loop.close

        def _close_selector_and_loop():
            # restore original before calling selector.close,
            # which in turn calls eventloop.close!
            asyncio_loop.close = loop_close
            _selectors.pop(asyncio_loop, None)
            selector_loop.close()

        asyncio_loop.close = _close_selector_and_loop
        return selector_loop
    else:
        return asyncio_loop


def _get_selector_noop(loop) -> asyncio.AbstractEventLoop:
    """no-op on non-Windows"""
    return loop


if sys.platform == "win32":
    _get_selector = _get_selector_windows
else:
    _get_selector = _get_selector_noop


class _AsyncIO(object):
    _Future = Future
    _WRITE = selectors.EVENT_WRITE
    _READ = selectors.EVENT_READ

    def _default_loop(self):
        return asyncio.get_event_loop()


class Poller(_AsyncIO, _future._AsyncPoller):
    """Poller returning asyncio.Future for poll results."""

    def _watch_raw_socket(self, loop, socket, evt, f):
        """Schedule callback for a raw socket"""
        selector = _get_selector(loop)
        if evt & self._READ:
            selector.add_reader(socket, lambda *args: f())
        if evt & self._WRITE:
            selector.add_writer(socket, lambda *args: f())

    def _unwatch_raw_sockets(self, loop, *sockets):
        """Unschedule callback for a raw socket"""
        selector = _get_selector(loop)
        for socket in sockets:
            selector.remove_reader(socket)
            selector.remove_writer(socket)


class Socket(_AsyncIO, _future._AsyncSocket):
    """Socket returning asyncio Futures for send/recv/poll methods."""

    _poller_class = Poller

    __selector = None

    @property
    def _selector(self):
        if self.__selector is None:
            self.__selector = _get_selector(self.io_loop)
        return self.__selector

    def _init_io_state(self):
        """initialize the ioloop event handler"""
        self._selector.add_reader(self._fd, lambda: self._handle_events(0, 0))

    def _clear_io_state(self):
        """clear any ioloop event handler

        called once at close
        """
        if not self.io_loop.is_closed():
            self._selector.remove_reader(self._fd)


Poller._socket_class = Socket


class Context(_zmq.Context):
    """Context for creating asyncio-compatible Sockets"""

    _socket_class = Socket

    # avoid sharing instance with base Context class
    _instance = None


class ZMQEventLoop(SelectorEventLoop):
    """DEPRECATED: AsyncIO eventloop using zmq_poll.

    pyzmq sockets should work with any asyncio event loop as of pyzmq 17.
    """

    def __init__(self, selector=None):
        _deprecated()
        return super(ZMQEventLoop, self).__init__(selector)


_loop = None


def _deprecated():
    if _deprecated.called:  # type: ignore
        return
    _deprecated.called = True  # type: ignore

    warnings.warn(
        "ZMQEventLoop and zmq.asyncio.install are deprecated in pyzmq 17. Special eventloop integration is no longer needed.",
        DeprecationWarning,
        stacklevel=3,
    )


_deprecated.called = False  # type: ignore


def install():
    """DEPRECATED: No longer needed in pyzmq 17"""
    _deprecated()


__all__ = [
    "Context",
    "Socket",
    "Poller",
    "ZMQEventLoop",
    "install",
]
