from __future__ import annotations

from functools import partial
from socket import AddressFamily, SocketKind
from ssl import SSLContext
from typing import Any, NoReturn

import attr
import pytest

import trio
import trio.testing
from trio.abc import Stream
from trio.socket import AF_INET, IPPROTO_TCP, SOCK_STREAM

from .._highlevel_socket import SocketListener
from .._highlevel_ssl_helpers import (
    open_ssl_over_tcp_listeners,
    open_ssl_over_tcp_stream,
    serve_ssl_over_tcp,
)
from .._ssl import SSLListener

# using noqa because linters don't understand how pytest fixtures work.
from .test_ssl import SERVER_CTX, client_ctx  # noqa: F401


async def echo_handler(stream: Stream) -> None:
    """Send every chunk received on ``stream`` straight back to the peer.

    The loop ends when ``receive_some`` returns an empty chunk (end of
    stream). A ``trio.BrokenResourceError`` raised while echoing is
    swallowed, so a peer that drops the connection does not fail the handler.
    """
    async with stream:
        try:
            # An empty chunk is falsy and signals end-of-stream.
            while chunk := await stream.receive_some(10000):
                await stream.send_all(chunk)
        except trio.BrokenResourceError:
            pass


# Resolver that always returns the given sockaddr, no matter what host/port
# you ask for.
@attr.s
class FakeHostnameResolver(trio.abc.HostnameResolver):
    """Hostname resolver stub that answers every lookup with one fixed address.

    ``getaddrinfo`` ignores all of its arguments and reports a single
    ``AF_INET`` / ``SOCK_STREAM`` / ``IPPROTO_TCP`` entry pointing at
    ``sockaddr``. ``getnameinfo`` is not implemented.
    """

    # Socket address handed back for every lookup.
    sockaddr: tuple[str, int] | tuple[str, int, int, int] = attr.ib()

    async def getaddrinfo(
        self,
        host: bytes | str | None,
        port: bytes | str | int | None,
        family: int = 0,
        type: int = 0,
        proto: int = 0,
        flags: int = 0,
    ) -> list[
        tuple[
            AddressFamily,
            SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Return a one-element result list built from ``self.sockaddr``.

        The requested host, port and hint arguments are not consulted.
        """
        canned_entry = (AF_INET, SOCK_STREAM, IPPROTO_TCP, "", self.sockaddr)
        return [canned_entry]

    async def getnameinfo(self, *args: Any) -> NoReturn:  # pragma: no cover
        """Reverse lookups are not supported by this stub."""
        raise NotImplementedError


# This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners...
# The noqa on the fixture argument is there because linters don't understand
# how pytest fixtures work.
async def test_open_ssl_over_tcp_stream_and_everything_else(
    client_ctx: SSLContext,  # noqa: F811 # linters don't understand fixtures
) -> None:
    """Exercise ``open_ssl_over_tcp_stream`` against a local SSL echo server.

    The server is started with ``serve_ssl_over_tcp`` on port 0 of 127.0.0.1,
    and a ``FakeHostnameResolver`` answers every lookup with the listener's
    socket address, so the host names and port 80 used below all reach it.

    Scenarios, in order:

    * default SSL context: the handshake raises ``BrokenResourceError``;
    * ``client_ctx`` with a non-matching host name: the handshake raises
      ``BrokenResourceError``;
    * ``client_ctx`` with the matching host name: one byte is echoed back;
    * ``https_compatible`` is off by default and is set when requested.
    """
    async with trio.open_nursery() as nursery:
        start_echo_server = partial(
            serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1"
        )
        # TODO: this function wraps an SSLListener around a SocketListener, this is illegal
        # according to current type hints, and probably for good reason. But there should
        # maybe be a different wrapper class/function that could be used instead?
        listeners: list[SSLListener[SocketListener]] = (  # type: ignore[type-var]
            await nursery.start(start_echo_server)
        )
        (ssl_listener,) = listeners
        async with ssl_listener:
            # ssl_listener.transport_listener is of type Listener[Stream]
            socket_listener: SocketListener = ssl_listener.transport_listener  # type: ignore[assignment]

            # Point every hostname lookup at the echo server's own address.
            server_addr = socket_listener.socket.getsockname()
            fake_resolver = FakeHostnameResolver(server_addr)
            trio.socket.set_custom_hostname_resolver(fake_resolver)

            # We don't have the right trust set up
            # (checks that ssl_context=None is doing some validation)
            untrusted_stream = await open_ssl_over_tcp_stream(
                "trio-test-1.example.org", 80
            )
            async with untrusted_stream:
                with pytest.raises(trio.BrokenResourceError):
                    await untrusted_stream.do_handshake()

            # We have the trust but not the hostname
            # (checks custom ssl_context + hostname checking)
            wrong_host_stream = await open_ssl_over_tcp_stream(
                "xyzzy.example.org", 80, ssl_context=client_ctx
            )
            async with wrong_host_stream:
                with pytest.raises(trio.BrokenResourceError):
                    await wrong_host_stream.do_handshake()

            # This one should work!
            echo_stream = await open_ssl_over_tcp_stream(
                "trio-test-1.example.org", 80, ssl_context=client_ctx
            )
            async with echo_stream:
                assert isinstance(echo_stream, trio.SSLStream)
                assert echo_stream.server_hostname == "trio-test-1.example.org"
                await echo_stream.send_all(b"x")
                assert await echo_stream.receive_some(1) == b"x"

            # Check https_compatible settings are being passed through:
            # off unless asked for...
            assert not echo_stream._https_compatible
            # ...and on when requested.
            https_stream = await open_ssl_over_tcp_stream(
                "trio-test-1.example.org",
                80,
                ssl_context=client_ctx,
                https_compatible=True,
                # also, smoke test happy_eyeballs_delay
                happy_eyeballs_delay=1,
            )
            async with https_stream:
                assert https_stream._https_compatible

            # Stop the echo server
            nursery.cancel_scope.cancel()


async def test_open_ssl_over_tcp_listeners() -> None:
    """Check what ``open_ssl_over_tcp_listeners`` returns for 127.0.0.1.

    With default arguments the single listener is an ``SSLListener`` wrapping
    a ``SocketListener`` bound to 127.0.0.1, with ``https_compatible`` off.
    Passing ``https_compatible=True`` turns the flag on.
    """
    # Default settings.
    (default_listener,) = await open_ssl_over_tcp_listeners(
        0, SERVER_CTX, host="127.0.0.1"
    )
    async with default_listener:
        assert isinstance(default_listener, trio.SSLListener)
        transport = default_listener.transport_listener
        assert isinstance(transport, trio.SocketListener)
        bound_host = transport.socket.getsockname()[0]
        assert bound_host == "127.0.0.1"

        assert not default_listener._https_compatible

    # Explicit https_compatible=True must reach the listener.
    (https_listener,) = await open_ssl_over_tcp_listeners(
        0, SERVER_CTX, host="127.0.0.1", https_compatible=True
    )
    async with https_listener:
        assert https_listener._https_compatible
