from __future__ import annotations
import asyncio
import io
from cgi import FieldStorage, parse_header
from typing import Any, AnyStr, Awaitable, Callable, Generator, List, Optional, Union
from urllib.parse import parse_qs
from werkzeug.datastructures import CombinedMultiDict, Headers, MultiDict
from .base import BaseRequestWebsocket, JSONMixin
from ..datastructures import FileStorage
from ..json import dumps, loads
# Names of request headers whose values are copied from the current request
# onto the headers passed to the HTTP/2 push-promise sender
# (see Request.send_push_promise).
SERVER_PUSH_HEADERS_TO_COPY = {"accept", "accept-encoding", "accept-language", "cache-control", "user-agent"}
class Body:
    """A request body container.

    The request body can either be iterated over and consumed in parts
    (without building up memory usage) or awaited.

    .. code-block:: python

        async for data in body:
            ...

        # or simply
        complete = await body

    Note: It is not possible to iterate over the data and then await
    it.
    """

    def __init__(
        self, expected_content_length: Optional[int], max_content_length: Optional[int]
    ) -> None:
        """Create an empty body.

        Arguments:
            expected_content_length: The declared size of the body, if
                known. If it already exceeds ``max_content_length`` the
                body is rejected up front.
            max_content_length: The maximum number of body bytes to
                accept, or None for no limit.
        """
        self._data = bytearray()
        # Total number of bytes appended so far. Unlike len(self._data) this
        # is not reset when iteration drains the buffer, so the size limit
        # applies to the whole body and not just to what is buffered.
        self._received = 0
        self._complete: asyncio.Event = asyncio.Event()
        self._has_data: asyncio.Event = asyncio.Event()
        self._max_content_length = max_content_length
        # Exceptions must be raised within application (not ASGI)
        # calls, this is achieved by having the ASGI methods set this
        # to an exception on error.
        self._must_raise: Optional[Exception] = None
        if (
            expected_content_length is not None
            and max_content_length is not None
            and expected_content_length > max_content_length
        ):
            from ..exceptions import RequestEntityTooLarge  # noqa Avoiding circular import

            self._must_raise = RequestEntityTooLarge()

    def __aiter__(self) -> "Body":
        return self

    async def __anext__(self) -> bytes:
        """Return the data buffered since the previous call.

        Raises:
            The pending error (e.g. RequestEntityTooLarge) if one is set.
            StopAsyncIteration once the body is complete and drained.
        """
        if self._must_raise is not None:
            raise self._must_raise

        # if we got all of the data in the first shot, then self._complete is
        # set and self._has_data will not get set again, so skip the await
        # if we already have completed everything
        if not self._complete.is_set():
            await self._has_data.wait()

        # The error may have been set while waiting (append sets it and
        # completes the body when the limit is exceeded). Re-check so that
        # over-limit data is never handed out, matching __await__.
        if self._must_raise is not None:
            raise self._must_raise

        if self._complete.is_set() and len(self._data) == 0:
            raise StopAsyncIteration()

        data = bytes(self._data)
        self._data.clear()
        self._has_data.clear()
        return data

    def __await__(self) -> Generator[Any, None, Any]:
        # Must check the _must_raise before and after waiting on the
        # completion event as it may change whilst waiting and the
        # event may not be set if there is already an issue.
        if self._must_raise is not None:
            raise self._must_raise

        yield from self._complete.wait().__await__()

        if self._must_raise is not None:
            raise self._must_raise
        return bytes(self._data)

    def append(self, data: bytes) -> None:
        """Add a chunk of body data.

        Empty chunks, and any chunk arriving after an error has been
        recorded, are ignored. If the total received size exceeds the
        limit, RequestEntityTooLarge is recorded (to be raised to the
        consumer) and the body is marked complete.
        """
        if data == b"" or self._must_raise is not None:
            return

        self._data.extend(data)
        self._received += len(data)
        self._has_data.set()

        if self._max_content_length is not None and self._received > self._max_content_length:
            from ..exceptions import RequestEntityTooLarge  # noqa Avoiding circular import

            self._must_raise = RequestEntityTooLarge()
            self.set_complete()

    def set_complete(self) -> None:
        """Mark the body as fully received and wake any waiting consumer."""
        self._complete.set()
        self._has_data.set()

    def set_result(self, data: bytes) -> None:
        """Convenience method, mainly for testing."""
        self.append(data)
        self.set_complete()
class Request(BaseRequestWebsocket, JSONMixin):
    """This class represents a request.

    It can be subclassed and the subclassed used in preference by
    replacing the :attr:`~quart.Quart.request_class` with your
    subclass.

    Attributes:
        body_class: The class to store the body data within.
    """

    body_class = Body

    def __init__(
        self,
        method: str,
        scheme: str,
        path: str,
        query_string: bytes,
        headers: Headers,
        root_path: str,
        http_version: str,
        *,
        max_content_length: Optional[int] = None,
        body_timeout: Optional[int] = None,
        send_push_promise: Callable[[str, Headers], Awaitable[None]],
        scope: Optional[dict] = None,
    ) -> None:
        """Create a request object.

        Arguments:
            method: The HTTP verb.
            scheme: The scheme used for the request.
            path: The full unquoted path of the request.
            query_string: The raw bytes for the query string part.
            headers: The request headers.
            root_path: The root path that should be prepended to all
                routes.
            http_version: The HTTP version of the request.
            max_content_length: The maximum length in bytes of the
                body (None implies no limit in Quart).
            body_timeout: The maximum time (seconds) to wait for the
                body before timing out.
            send_push_promise: An awaitable to send a push promise based
                off of this request (HTTP/2 feature).
            scope: Underlying ASGI scope dictionary.
        """
        super().__init__(method, scheme, path, query_string, headers, root_path, http_version)
        self.body_timeout = body_timeout
        self.body = self.body_class(self.content_length, max_content_length)
        self._form: Optional[MultiDict] = None
        self._files: Optional[MultiDict] = None
        self._send_push_promise = send_push_promise
        self.scope = scope

    async def get_data(self, raw: bool = True) -> AnyStr:
        """Return the request body, waiting at most ``body_timeout`` seconds.

        Arguments:
            raw: If True return the bytes as received, otherwise decode
                them using ``self.charset``.

        Raises:
            RequestTimeout: If the body is not complete within
                ``body_timeout`` seconds.
        """
        try:
            body_future = asyncio.ensure_future(self.body)
            raw_data = await asyncio.wait_for(body_future, timeout=self.body_timeout)
        except asyncio.TimeoutError:
            body_future.cancel()
            try:
                await body_future
            except asyncio.CancelledError:
                pass
            from ..exceptions import RequestTimeout  # noqa Avoiding circular import

            raise RequestTimeout()

        if raw:
            return raw_data
        else:
            return raw_data.decode(self.charset)

    @property
    async def data(self) -> bytes:
        """The raw request body bytes."""
        return await self.get_data()

    @property
    async def values(self) -> CombinedMultiDict:
        """The query-string arguments combined with the parsed form data."""
        form = await self.form
        return CombinedMultiDict([self.args, form])

    @property
    async def form(self) -> MultiDict:
        """The parsed form encoded data.

        Note file data is present in the :attr:`files`.
        """
        if self._form is None:
            await self._load_form_data()
        return self._form

    @property
    async def files(self) -> MultiDict:
        """The parsed files.

        This will return an empty multidict unless the request
        mimetype was ``enctype="multipart/form-data"`` and the method
        POST, PUT, or PATCH.
        """
        if self._files is None:
            await self._load_form_data()
        return self._files

    async def _load_form_data(self) -> None:
        """Parse the body into ``self._form`` and ``self._files``.

        URL-encoded and multipart bodies are parsed; any other (or
        missing) content type leaves both multidicts empty.

        Raises:
            BadRequest: If a URL-encoded body cannot be decoded with its
                declared (or default UTF-8) charset.
        """
        raw_data: bytes = await self.get_data(raw=True)
        self._form = MultiDict()
        self._files = MultiDict()
        content_header = self.content_type
        if content_header is None:
            return
        content_type, parameters = parse_header(content_header)
        if content_type == "application/x-www-form-urlencoded":
            try:
                data = raw_data.decode(parameters.get("charset", "utf-8"))
            except (UnicodeDecodeError, LookupError):
                # LookupError: the client named a charset Python does not
                # know. Both cases are client errors, not server faults.
                from ..exceptions import BadRequest  # noqa Avoiding circular import

                raise BadRequest()
            for key, values in parse_qs(data, keep_blank_values=True).items():
                for value in values:
                    self._form.add(key, value)
        elif content_type == "multipart/form-data":
            field_storage = FieldStorage(
                io.BytesIO(raw_data),
                headers={name.lower(): value for name, value in self.headers.items()},
                environ={"REQUEST_METHOD": "POST"},
                limit=len(raw_data),
            )
            for key in field_storage:  # type: ignore
                field_storage_key = field_storage[key]
                if isinstance(field_storage_key, list):
                    for item in field_storage_key:
                        self._load_field_storage(key, item)
                else:
                    self._load_field_storage(key, field_storage_key)

    def _load_field_storage(self, key: str, field_storage: FieldStorage) -> None:
        """Add one multipart field to ``_files`` (if it has a filename) or ``_form``."""
        if isinstance(field_storage, FieldStorage) and field_storage.filename is not None:
            self._files.add(
                key,
                FileStorage(
                    io.BytesIO(field_storage.file.read()),
                    field_storage.filename,
                    field_storage.name,  # type: ignore
                    field_storage.type,
                    field_storage.headers,  # type: ignore
                ),
            )
        else:
            self._form.add(key, field_storage.value)

    @property
    def content_encoding(self) -> Optional[str]:
        return self.headers.get("Content-Encoding")

    @property
    def content_length(self) -> Optional[int]:
        """The ``Content-Length`` header as an int.

        Returns None when the header is absent or not a valid integer.
        Negative values are clamped to 0. The value is read while the
        request is being constructed, so a malformed client header must
        not raise here.
        """
        value = self.headers.get("Content-Length")
        if value is None:
            return None
        try:
            return max(0, int(value))
        except ValueError:
            return None

    @property
    def content_md5(self) -> Optional[str]:
        return self.headers.get("Content-md5")

    @property
    def content_type(self) -> Optional[str]:
        return self.headers.get("Content-Type")

    async def _load_json_data(self) -> str:
        """Return the data after decoding."""
        return await self.get_data(raw=False)

    async def send_push_promise(self, path: str) -> None:
        """Send a push promise for ``path``, copying selected request headers."""
        headers = Headers()
        for name in SERVER_PUSH_HEADERS_TO_COPY:
            for value in self.headers.getlist(name):
                headers.add(name, value)
        await self._send_push_promise(path, headers)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.method}, {self.path})"
class Websocket(BaseRequestWebsocket):
    """A websocket connection built on receive/send/accept callables."""

    def __init__(
        self,
        path: str,
        query_string: bytes,
        scheme: str,
        headers: Headers,
        root_path: str,
        http_version: str,
        subprotocols: List[str],
        receive: Callable,
        send: Callable,
        accept: Callable,
    ) -> None:
        """Create a websocket object.

        Arguments:
            path: The full unquoted path of the request.
            query_string: The raw bytes for the query string part.
            scheme: The scheme used for the request.
            headers: The request headers.
            root_path: The root path that should be prepended to all
                routes.
            http_version: The HTTP version of the request.
            subprotocols: The subprotocols requested.
            receive: Returns an awaitable of the current data.
            send: Awaitable callable that sends data to the client.
            accept: Idempotent callable to accept the websocket connection.
        """
        # The websocket handshake is an HTTP GET, so the method is fixed.
        super().__init__("GET", scheme, path, query_string, headers, root_path, http_version)
        self._subprotocols = subprotocols
        self._receive = receive
        self._send = send
        self._accept = accept

    @property
    def requested_subprotocols(self) -> List[str]:
        """The subprotocols the client asked for."""
        return self._subprotocols

    async def receive(self) -> AnyStr:
        """Ensure the connection is accepted, then return the next message."""
        await self.accept()
        return await self._receive()

    async def send(self, data: AnyStr) -> None:
        """Ensure the connection is accepted, then send ``data``."""
        # Yield to the event loop first so that a tight user loop that only
        # sends (e.g. ``while True: await ws.send(...)``) cannot starve
        # other tasks.
        await asyncio.sleep(0)
        await self.accept()
        await self._send(data)

    async def receive_json(self) -> Any:
        """Receive a message and decode it as JSON."""
        return loads(await self.receive())

    async def send_json(self, data: Any) -> None:
        """Encode ``data`` as JSON and send it."""
        await self.send(dumps(data))

    async def accept(
        self, headers: Optional[Union[dict, Headers]] = None, subprotocol: Optional[str] = None
    ) -> None:
        """Manually choose to accept the websocket connection.

        Arguments:
            headers: Additional headers to send with the acceptance
                response.
            subprotocol: The chosen subprotocol, optional.
        """
        response_headers = Headers() if headers is None else Headers(headers)
        await self._accept(response_headers, subprotocol)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.path})"