# -*- coding: utf-8 -*-

import pytest

from urllib3 import HTTPConnectionPool
from urllib3.util.retry import Retry
from dummyserver.testcase import SocketDummyServerTestCase, consume_socket

# Retry failed tests: assigning ``pytestmark`` at module level applies the
# ``flaky`` marker to every test in this module. The marker is supplied by a
# pytest plugin (presumably ``flaky`` or ``pytest-rerunfailures`` — confirm).
pytestmark = pytest.mark.flaky


class TestChunkedTransfer(SocketDummyServerTestCase):
    """Check what ``HTTPConnectionPool.urlopen(..., chunked=True)`` sends.

    Every test installs its own raw socket handler on the dummy server, so the
    request bytes can be captured (``self.buffer``) or counted
    (``self.chunked_requests``) and inspected directly.
    """

    # Request body shared by several tests. The empty string in the middle
    # checks that empty chunks are skipped, because a zero-length chunk would
    # be indistinguishable from the end of the body.
    CHUNKS = ("foo", "bar", "", "bazzzzzzzzzzzzzzzzzzzzzz")

    def start_chunked_handler(self):
        """Start a server that captures one chunked request into ``self.buffer``.

        The handler accepts a single connection and reads until the captured
        bytes end with the terminating zero-length chunk. It then replies with
        an empty ``200 OK`` and closes the connection.
        """
        self.buffer = b""

        def socket_handler(listener):
            sock = listener.accept()[0]
            try:
                while not self.buffer.endswith(b"\r\n0\r\n\r\n"):
                    data = sock.recv(65536)
                    if not data:
                        # The peer closed before sending the terminating chunk.
                        # recv() would keep returning b"" and this loop would
                        # spin forever, so stop reading and let the test's
                        # assertions report whatever was captured.
                        break
                    self.buffer += data

                sock.send(
                    b"HTTP/1.1 200 OK\r\n"
                    b"Content-type: text/plain\r\n"
                    b"Content-Length: 0\r\n"
                    b"\r\n"
                )
            finally:
                sock.close()

        self._start_server(socket_handler)

    def _captured_host_headers(self):
        """Return the lower-cased ``Host`` header lines found in ``self.buffer``."""
        header_block = self.buffer.split(b"\r\n\r\n", 1)[0].lower()
        # The first line is the request line, not a header.
        header_lines = header_block.split(b"\r\n")[1:]
        # Match the full "host:" name so that other headers whose names merely
        # start with "host" are not counted.
        return [line for line in header_lines if line.startswith(b"host:")]

    def _start_counting_server(self, responses):
        """Serve one canned raw response per connection, counting chunked requests.

        :param responses: raw HTTP response bytes, sent in order, one per
            accepted connection.

        ``self.chunked_requests`` is incremented for every received request
        whose headers contain the exact line ``Transfer-Encoding: chunked``.
        """
        self.chunked_requests = 0

        def socket_handler(listener):
            for response in responses:
                sock = listener.accept()[0]
                try:
                    request = consume_socket(sock)
                    if b"Transfer-Encoding: chunked" in request.split(b"\r\n"):
                        self.chunked_requests += 1
                    sock.send(response)
                finally:
                    sock.close()

        self._start_server(socket_handler)

    def test_chunks(self):
        """Each non-empty chunk is framed as: hex length, CRLF, data, CRLF."""
        self.start_chunked_handler()
        chunks = self.CHUNKS
        with HTTPConnectionPool(self.host, self.port, retries=False) as pool:
            pool.urlopen("GET", "/", chunks, headers=dict(DNT="1"), chunked=True)

            header, body = self.buffer.split(b"\r\n\r\n", 1)
            assert b"Transfer-Encoding: chunked" in header.split(b"\r\n")
            lines = body.split(b"\r\n")
            # Empty chunks should have been skipped, because a zero-length
            # chunk could not be distinguished from the end of the body.
            non_empty = [c for c in chunks if c]
            for i, chunk in enumerate(non_empty):
                assert lines[i * 2] == hex(len(chunk))[2:].encode("utf-8")
                assert lines[i * 2 + 1] == chunk.encode("utf-8")
            # Exactly one zero-length chunk closes the body.
            assert lines[len(non_empty) * 2 :] == [b"0", b"", b""]

    def _test_body(self, data):
        """Send ``data`` as a chunked body and check how it was framed.

        A non-empty ``data`` (bytes, or text encoded as UTF-8) must arrive as a
        single chunk whose hex length prefix matches its byte length, followed
        by the terminating chunk. An empty ``data`` must produce only the
        terminating chunk.
        """
        self.start_chunked_handler()
        with HTTPConnectionPool(self.host, self.port, retries=False) as pool:
            pool.urlopen("GET", "/", data, chunked=True)
            header, body = self.buffer.split(b"\r\n\r\n", 1)

            assert b"Transfer-Encoding: chunked" in header.split(b"\r\n")
            if data:
                bdata = data if isinstance(data, bytes) else data.encode("utf-8")
                assert b"\r\n" + bdata + b"\r\n" in body
                assert body.endswith(b"\r\n0\r\n\r\n")

                len_str = body.split(b"\r\n", 1)[0]
                stated_len = int(len_str, 16)
                assert stated_len == len(bdata)
            else:
                assert body == b"0\r\n\r\n"

    def test_bytestring_body(self):
        """A bytes body containing CRLF is still sent as one chunk."""
        self._test_body(b"thisshouldbeonechunk\r\nasdf")

    def test_unicode_body(self):
        """A text body is UTF-8 encoded; the chunk length counts bytes, not characters."""
        self._test_body(u"thisshouldbeonechunk\r\näöüß")

    def test_empty_body(self):
        """A ``None`` body sends only the terminating chunk."""
        self._test_body(None)

    def test_empty_string_body(self):
        """An empty string body sends only the terminating chunk."""
        self._test_body("")

    def test_empty_iterable_body(self):
        """An empty iterable body sends only the terminating chunk."""
        self._test_body([])

    def test_removes_duplicate_host_header(self):
        """Exactly one Host header is sent when the caller supplies one."""
        self.start_chunked_handler()
        with HTTPConnectionPool(self.host, self.port, retries=False) as pool:
            pool.urlopen(
                "GET", "/", self.CHUNKS, headers={"Host": "test.org"}, chunked=True
            )
            assert len(self._captured_host_headers()) == 1

    def test_provides_default_host_header(self):
        """Exactly one Host header is sent when the caller supplies none."""
        self.start_chunked_handler()
        with HTTPConnectionPool(self.host, self.port, retries=False) as pool:
            pool.urlopen("GET", "/", self.CHUNKS, chunked=True)
            assert len(self._captured_host_headers()) == 1

    def test_preserve_chunked_on_retry(self):
        """A retried request (429 + Retry-After) is still sent chunked."""
        too_many_requests = (
            b"HTTP/1.1 429 Too Many Requests\r\n"
            b"Content-Type: text/plain\r\n"
            b"Retry-After: 1\r\n"
            b"\r\n"
        )
        self._start_counting_server([too_many_requests, too_many_requests])
        with HTTPConnectionPool(self.host, self.port) as pool:
            retries = Retry(total=1)
            pool.urlopen(
                "GET", "/", chunked=True, preload_content=False, retries=retries
            )
        assert self.chunked_requests == 2

    def test_preserve_chunked_on_redirect(self):
        """The request that follows a 301 redirect is still sent chunked."""
        self._start_counting_server(
            [
                b"HTTP/1.1 301 Moved Permanently\r\n"
                b"Location: /redirect\r\n\r\n",
                b"HTTP/1.1 200 OK\r\n\r\n",
            ]
        )
        with HTTPConnectionPool(self.host, self.port) as pool:
            retries = Retry(redirect=1)
            pool.urlopen(
                "GET", "/", chunked=True, preload_content=False, retries=retries
            )
        assert self.chunked_requests == 2

    def test_preserve_chunked_on_broken_connection(self):
        """A request retried after a read error is still sent chunked."""
        self._start_counting_server(
            [
                # Bad HTTP version will trigger a connection close
                b"HTTP/0.5 200 OK\r\n\r\n",
                b"HTTP/1.1 200 OK\r\n\r\n",
            ]
        )
        with HTTPConnectionPool(self.host, self.port) as pool:
            retries = Retry(read=1)
            pool.urlopen(
                "GET", "/", chunked=True, preload_content=False, retries=retries
            )
        assert self.chunked_requests == 2
