# Source code for quart.wrappers.request

from __future__ import annotations

import asyncio
import io
from cgi import FieldStorage, parse_header
from typing import Any, AnyStr, Awaitable, Callable, Generator, List, Optional, Union
from urllib.parse import parse_qs

from werkzeug.datastructures import CombinedMultiDict, Headers, MultiDict

from .base import BaseRequestWebsocket, JSONMixin
from ..datastructures import FileStorage
from ..json import dumps, loads

# Names of the request headers that ``Request.send_push_promise``
# copies from the incoming request onto the push promise it sends.
SERVER_PUSH_HEADERS_TO_COPY = {
    "accept",
    "accept-encoding",
    "accept-language",
    "cache-control",
    "user-agent",
}


class Body:
    """A request body container.

    The request body can either be iterated over and consumed in parts
    (without building up memory usage) or awaited.

    .. code-block:: python

        async for data in body:
            ...
        # or simply
        complete = await body

    Note: It is not possible to iterate over the data and then await
    it.
    """

    def __init__(
        self, expected_content_length: Optional[int], max_content_length: Optional[int]
    ) -> None:
        self._data = bytearray()
        self._complete: asyncio.Event = asyncio.Event()
        self._has_data: asyncio.Event = asyncio.Event()
        self._max_content_length = max_content_length
        # Exceptions must be raised within application (not ASGI)
        # calls, this is achieved by having the ASGI methods set this
        # to an exception on error.
        self._must_raise: Optional[Exception] = None
        if (
            expected_content_length is not None
            and max_content_length is not None
            and expected_content_length > max_content_length
        ):
            from ..exceptions import RequestEntityTooLarge  # noqa Avoiding circular import

            self._must_raise = RequestEntityTooLarge()

    def __aiter__(self) -> "Body":
        return self

    async def __anext__(self) -> bytes:
        if self._must_raise is not None:
            raise self._must_raise

        # if we got all of the data in the first shot, then self._complete is
        # set and self._has_data will not get set again, so skip the await
        # if we already have completed everything
        if not self._complete.is_set():
            await self._has_data.wait()

        if self._complete.is_set() and len(self._data) == 0:
            raise StopAsyncIteration()

        data = bytes(self._data)
        self._data.clear()
        self._has_data.clear()
        return data

    def __await__(self) -> Generator[Any, None, Any]:
        # Must check the _must_raise before and after waiting on the
        # completion event as it may change whilst waiting and the
        # event may not be set if there is already an issue.

        if self._must_raise is not None:
            raise self._must_raise

        yield from self._complete.wait().__await__()

        if self._must_raise is not None:
            raise self._must_raise
        return bytes(self._data)

    def append(self, data: bytes) -> None:
        if data == b"" or self._must_raise is not None:
            return
        self._data.extend(data)
        self._has_data.set()
        if self._max_content_length is not None and len(self._data) > self._max_content_length:
            from ..exceptions import RequestEntityTooLarge  # noqa Avoiding circular import

            self._must_raise = RequestEntityTooLarge()
            self.set_complete()

    def set_complete(self) -> None:
        self._complete.set()
        self._has_data.set()

    def set_result(self, data: bytes) -> None:
        """Convienience method, mainly for testing."""
        self.append(data)
        self.set_complete()


class Request(BaseRequestWebsocket, JSONMixin):
    """This class represents a request.

    It can be subclassed and the subclassed used in preference by
    replacing the :attr:`~quart.Quart.request_class` with your
    subclass.

    Attributes:
        body_class: The class to store the body data within.
    """

    body_class = Body

    def __init__(
        self,
        method: str,
        scheme: str,
        path: str,
        query_string: bytes,
        headers: Headers,
        root_path: str,
        http_version: str,
        *,
        max_content_length: Optional[int] = None,
        body_timeout: Optional[int] = None,
        send_push_promise: Callable[[str, Headers], Awaitable[None]],
        scope: Optional[dict] = None,
    ) -> None:
        """Create a request object.

        Arguments:
            method: The HTTP verb.
            scheme: The scheme used for the request.
            path: The full unquoted path of the request.
            query_string: The raw bytes for the query string part.
            headers: The request headers.
            root_path: The root path that should be prepended to all
                routes.
            http_version: The HTTP version of the request.
            max_content_length: The maximum length in bytes of the
                body (None implies no limit in Quart).
            body_timeout: The maximum time (seconds) to wait for the
                body before timing out.
            send_push_promise: An awaitable to send a push promise based
                off of this request (HTTP/2 feature).
            scope: Underlying ASGI scope dictionary.
        """
        super().__init__(method, scheme, path, query_string, headers, root_path, http_version)
        self.body_timeout = body_timeout
        self.body = self.body_class(self.content_length, max_content_length)
        # The form and files are parsed lazily, on first access.
        self._form: Optional[MultiDict] = None
        self._files: Optional[MultiDict] = None
        self._send_push_promise = send_push_promise
        self.scope = scope

    async def get_data(self, raw: bool = True) -> AnyStr:
        """The request body data.

        Arguments:
            raw: If True (the default) return the body as bytes,
                otherwise return it decoded using the request charset.

        Raises:
            RequestTimeout: If the body is not complete within
                :attr:`body_timeout` seconds.
        """
        body_future = asyncio.ensure_future(self.body)
        try:
            raw_data = await asyncio.wait_for(body_future, timeout=self.body_timeout)
        except asyncio.TimeoutError:
            # Ensure the task waiting on the body has finished before
            # reporting the timeout.
            body_future.cancel()
            try:
                await body_future
            except asyncio.CancelledError:
                pass

            from ..exceptions import RequestTimeout  # noqa Avoiding circular import

            raise RequestTimeout()

        if raw:
            return raw_data
        else:
            return raw_data.decode(self.charset)

    @property
    async def data(self) -> bytes:
        """The raw request body, must be awaited."""
        return await self.get_data()

    @property
    async def values(self) -> CombinedMultiDict:
        """The query string arguments combined with the form data."""
        form = await self.form
        return CombinedMultiDict([self.args, form])

    @property
    async def form(self) -> MultiDict:
        """The parsed form encoded data.

        Note file data is present in the :attr:`files`.
        """
        if self._form is None:
            await self._load_form_data()
        return self._form

    @property
    async def files(self) -> MultiDict:
        """The parsed files.

        This will return an empty multidict unless the request
        mimetype was ``enctype="multipart/form-data"`` and the method
        POST, PUT, or PATCH.
        """
        if self._files is None:
            await self._load_form_data()
        return self._files

    async def _load_form_data(self) -> None:
        """Parse the body into the form and files multidicts.

        Both are left empty if there is no Content-Type header, or if
        it is neither urlencoded nor multipart form data.

        Raises:
            BadRequest: If urlencoded data cannot be decoded, either as
                the bytes are invalid for the charset or the charset
                itself is unknown.
        """
        raw_data: bytes = await self.get_data(raw=True)
        self._form = MultiDict()
        self._files = MultiDict()
        content_header = self.content_type
        if content_header is None:
            return
        content_type, parameters = parse_header(content_header)
        if content_type == "application/x-www-form-urlencoded":
            try:
                data = raw_data.decode(parameters.get("charset", "utf-8"))
            except (LookupError, UnicodeDecodeError):
                # The charset is supplied by the client, an unknown
                # codec name (LookupError) is as much the client's
                # error as bytes that do not decode.
                from ..exceptions import BadRequest  # noqa Avoiding circular import

                raise BadRequest()
            for key, values in parse_qs(data, keep_blank_values=True).items():
                for value in values:
                    self._form.add(key, value)
        elif content_type == "multipart/form-data":
            # NOTE: the ``cgi`` module is deprecated (PEP 594) and is
            # removed in Python 3.13, this parsing will need replacing
            # to support that version.
            field_storage = FieldStorage(
                io.BytesIO(raw_data),
                headers={name.lower(): value for name, value in self.headers.items()},
                environ={"REQUEST_METHOD": "POST"},
                limit=len(raw_data),
            )
            for key in field_storage:  # type: ignore
                field_storage_key = field_storage[key]
                # Repeated field names are given as a list of items.
                if isinstance(field_storage_key, list):
                    for item in field_storage_key:
                        self._load_field_storage(key, item)
                else:
                    self._load_field_storage(key, field_storage_key)

    def _load_field_storage(self, key: str, field_storage: FieldStorage) -> None:
        """Store a single parsed multipart item.

        Items with a filename are stored in the files, everything else
        is stored in the form.
        """
        if isinstance(field_storage, FieldStorage) and field_storage.filename is not None:
            self._files.add(
                key,
                FileStorage(
                    io.BytesIO(field_storage.file.read()),
                    field_storage.filename,
                    field_storage.name,  # type: ignore
                    field_storage.type,
                    field_storage.headers,  # type: ignore
                ),
            )
        else:
            self._form.add(key, field_storage.value)

    @property
    def content_encoding(self) -> Optional[str]:
        """The Content-Encoding header value, if present."""
        return self.headers.get("Content-Encoding")

    @property
    def content_length(self) -> Optional[int]:
        """The Content-Length header value as an integer.

        This is None if the header is absent or is not an integer, as
        the value is supplied by the client and cannot be trusted to
        be well formed.
        """
        raw_length = self.headers.get("Content-Length")
        if raw_length is None:
            return None
        try:
            return int(raw_length)
        except ValueError:
            return None

    @property
    def content_md5(self) -> Optional[str]:
        """The Content-md5 header value, if present."""
        return self.headers.get("Content-md5")

    @property
    def content_type(self) -> Optional[str]:
        """The Content-Type header value, if present."""
        return self.headers.get("Content-Type")

    async def _load_json_data(self) -> str:
        """Return the data after decoding."""
        return await self.get_data(raw=False)

    async def send_push_promise(self, path: str) -> None:
        """Send a push promise for the given path.

        The headers named in ``SERVER_PUSH_HEADERS_TO_COPY`` are
        copied from this request and sent with the push promise.

        Arguments:
            path: The path to send the push promise for.
        """
        headers = Headers()
        for name in SERVER_PUSH_HEADERS_TO_COPY:
            for value in self.headers.getlist(name):
                headers.add(name, value)
        await self._send_push_promise(path, headers)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.method}, {self.path})"


class Websocket(BaseRequestWebsocket):
    """This class represents a websocket connection.

    Sending or receiving data accepts the connection first, see
    :meth:`accept` to accept it manually.
    """

    def __init__(
        self,
        path: str,
        query_string: bytes,
        scheme: str,
        headers: Headers,
        root_path: str,
        http_version: str,
        subprotocols: List[str],
        receive: Callable,
        send: Callable,
        accept: Callable,
    ) -> None:
        """Create a websocket object.

        Arguments:
            path: The full unquoted path of the request.
            query_string: The raw bytes for the query string part.
            scheme: The scheme used for the request.
            headers: The request headers.
            root_path: The root path that should be prepended to all
                routes.
            http_version: The HTTP version of the request.
            subprotocols: The subprotocols requested.
            receive: Returns an awaitable of the current data
            send: Returns an awaitable that sends the data given to
                it.
            accept: Idempotent callable to accept the websocket connection.
        """
        super().__init__("GET", scheme, path, query_string, headers, root_path, http_version)
        self._subprotocols = subprotocols
        self._receive = receive
        self._send = send
        self._accept = accept

    @property
    def requested_subprotocols(self) -> List[str]:
        """The subprotocols given when this websocket was created."""
        return self._subprotocols

    async def receive(self) -> AnyStr:
        """Accept the connection and return the next received data."""
        await self.accept()
        message = await self._receive()
        return message

    async def send(self, data: AnyStr) -> None:
        """Accept the connection and send the data."""
        # A zero length sleep hands control back to the event loop
        # first, so that user code sending in a tight loop (as in the
        # example) cannot starve everything else.
        await asyncio.sleep(0)
        await self.accept()
        await self._send(data)

    async def receive_json(self) -> Any:
        """Receive data and return it parsed as JSON."""
        return loads(await self.receive())

    async def send_json(self, data: Any) -> None:
        """Serialise the data as JSON and send it."""
        await self.send(dumps(data))

    async def accept(
        self, headers: Optional[Union[dict, Headers]] = None, subprotocol: Optional[str] = None
    ) -> None:
        """Manually choose to accept the websocket connection.

        Arguments:
            headers: Additional headers to send with the acceptance
                response.
            subprotocol: The chosen subprotocol, optional.
        """
        accept_headers = Headers() if headers is None else Headers(headers)
        await self._accept(accept_headers, subprotocol)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.path})"