from __future__ import annotations
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from email.utils import parsedate_to_datetime
from hashlib import md5
from inspect import isasyncgen, isgenerator
from io import BytesIO
from os import PathLike
from types import TracebackType
from typing import (
AnyStr,
AsyncGenerator,
AsyncIterable,
AsyncIterator,
Iterable,
Optional,
Tuple,
Union,
)
from wsgiref.handlers import format_date_time
from aiofiles import open as async_open
from aiofiles.base import AiofilesContextManager
from aiofiles.threadpool import AsyncFileIO
from werkzeug.datastructures import ( # type: ignore
ContentRange,
ContentSecurityPolicy,
Headers,
HeaderSet,
Range,
ResponseCacheControl,
)
from werkzeug.http import ( # type: ignore
dump_cookie,
dump_csp_header,
dump_header,
parse_cache_control_header,
parse_content_range_header,
parse_csp_header,
parse_set_header,
)
from .base import _BaseRequestResponse, JSONMixin
from ..utils import file_path_to_path, run_sync_iterable
# Unique marker object. ``Response.__init__`` uses it as the initial value of
# ``Response.timeout``; presumably this lets "no timeout configured" be told
# apart from an explicit ``None`` (both are permitted by the annotation).
sentinel = object()
class ResponseBody(ABC):
    """Abstract wrapper around the data that makes up a response body.

    Quart assumes the following usage when it hands the body to the
    ASGI server, so every concrete wrapper must be an async context
    manager whose entered value is async-iterable::

        async with wrapper as response:
            async for data in response:
                send(data)

    """

    @abstractmethod
    async def __aenter__(self) -> AsyncIterable:
        """Prepare the body and return the object to iterate over."""

    @abstractmethod
    async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None:
        """Clean up after the body has been iterated (or iteration failed)."""

    @abstractmethod
    async def convert_to_sequence(self) -> bytes:
        """Return the complete body as a single bytes object."""
def _raise_if_invalid_range(begin: int, end: int, size: int) -> None:
if begin >= end or abs(begin) > size or end > size:
from ..exceptions import RequestRangeNotSatisfiable
raise RequestRangeNotSatisfiable()
class DataBody(ResponseBody):
    """Response body backed by an in-memory bytes object.

    Attributes:
        data: The complete, unsliced body.
        begin: Offset of the first byte that is served.
        end: Offset one past the last byte that is served.
    """

    def __init__(self, data: bytes) -> None:
        self.data = data
        self.begin = 0
        self.end = len(self.data)

    async def __aenter__(self) -> "DataBody":
        return self

    async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None:
        pass

    def __aiter__(self) -> AsyncIterator:
        # A fresh generator per call, so the body can be iterated repeatedly.
        async def _aiter() -> AsyncGenerator[bytes, None]:
            yield self.data[self.begin : self.end]

        return _aiter()

    async def convert_to_sequence(self) -> bytes:
        """Return the currently selected slice of the data."""
        return self.data[self.begin : self.end]

    async def make_conditional(
        self, begin: int, end: Optional[int], max_partial_size: Optional[int] = None
    ) -> int:
        """Restrict the body to the requested byte range.

        Arguments:
            begin: First byte to serve; a negative value is a suffix
                range, i.e. an offset counted from the end of the data.
            end: Offset one past the last byte to serve, or None to
                serve up to the end of the data.
            max_partial_size: Optional cap on the number of bytes
                served.

        Returns:
            The complete (unsliced) length of the data.

        Raises:
            RequestRangeNotSatisfiable: If the range selects no bytes
                of the data.
        """
        size = len(self.data)
        if begin < 0:
            # Suffix range ("bytes=-N"). Suffixes longer than the body are
            # still rejected; valid ones are converted to an absolute offset
            # so that ``end - begin`` is the real number of bytes served.
            _raise_if_invalid_range(begin, size, size)
            begin += size
        # RFC 7233 section 2.1: a last byte position at or beyond the end of
        # the representation means "up to the end", it is not an error.
        stop = size if end is None else min(end, size)
        if max_partial_size is not None:
            stop = min(begin + max_partial_size, stop)
        _raise_if_invalid_range(begin, stop, size)
        # Only commit once the range is known to be valid.
        self.begin = begin
        self.end = stop
        return size
class IterableBody(ResponseBody):
    """Response body that streams chunks from an iterable.

    Async generators are used directly, sync generators are adapted via
    ``run_sync_iterable`` and any other iterable is wrapped in a small
    async generator. Whichever applies is stored as :attr:`iter`.
    """

    def __init__(self, iterable: Union[AsyncGenerator[bytes, None], Iterable]) -> None:
        self.iter: AsyncGenerator[bytes, None]

        if isasyncgen(iterable):
            self.iter = iterable  # type: ignore
            return

        if isgenerator(iterable):
            self.iter = run_sync_iterable(iterable)  # type: ignore
            return

        async def _from_plain_iterable() -> AsyncGenerator[bytes, None]:
            for chunk in iterable:  # type: ignore
                yield chunk

        self.iter = _from_plain_iterable()

    async def __aenter__(self) -> "IterableBody":
        return self

    async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None:
        # Close the underlying generator when the context is left.
        await self.iter.aclose()

    def __aiter__(self) -> AsyncIterator:
        return self.iter

    async def convert_to_sequence(self) -> bytes:
        """Consume the iterable and return everything it yielded."""
        collected = bytearray()
        async for chunk in self.iter:
            collected.extend(chunk)
        return bytes(collected)
class FileBody(ResponseBody):
    """Provides an async file accessor with range setting.

    The :attr:`Response.response` attribute must be async-iterable and
    yield bytes, which this wrapper does for a file. In addition it
    allows a range to be set on the file, thereby supporting
    conditional requests.

    Attributes:
        buffer_size: Size in bytes to load per iteration.
    """

    buffer_size = 8192

    def __init__(
        self, file_path: Union[str, PathLike], *, buffer_size: Optional[int] = None
    ) -> None:
        self.file_path = file_path_to_path(file_path)
        # The size is read once, at construction time.
        self.size = self.file_path.stat().st_size
        self.begin = 0
        self.end = self.size
        if buffer_size is not None:
            self.buffer_size = buffer_size
        self.file: Optional[AsyncFileIO] = None
        self.file_manager: Optional[AiofilesContextManager] = None

    async def __aenter__(self) -> "FileBody":
        self.file_manager = async_open(self.file_path, mode="rb")
        self.file = await self.file_manager.__aenter__()
        # Start reading at the first byte of the selected range.
        await self.file.seek(self.begin)
        return self

    async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None:
        await self.file_manager.__aexit__(exc_type, exc_value, tb)

    def __aiter__(self) -> "FileBody":
        return self

    async def __anext__(self) -> bytes:
        position = await self.file.tell()
        if position >= self.end:
            raise StopAsyncIteration()
        # Never read past the end of the selected range.
        chunk = await self.file.read(min(self.buffer_size, self.end - position))
        if not chunk:
            raise StopAsyncIteration()
        return chunk

    async def convert_to_sequence(self) -> bytes:
        """Read the selected range of the file into a bytes object."""
        result = bytearray()
        async with self as response:
            async for data in response:
                result.extend(data)
        return bytes(result)

    async def make_conditional(
        self, begin: int, end: Optional[int], max_partial_size: Optional[int] = None
    ) -> int:
        """Restrict the body to the requested byte range.

        Arguments:
            begin: First byte to serve; a negative value is a suffix
                range, i.e. an offset counted from the end of the file.
            end: Offset one past the last byte to serve, or None to
                serve up to the end of the file.
            max_partial_size: Optional cap on the number of bytes
                served.

        Returns:
            The complete size of the file.

        Raises:
            RequestRangeNotSatisfiable: If the range selects no bytes
                of the file.
        """
        begin_offset = begin
        if begin_offset < 0:
            # Suffix range ("bytes=-N"). Suffixes longer than the file are
            # still rejected; valid ones become an absolute offset, as a
            # negative value cannot be passed to ``seek`` in ``__aenter__``.
            _raise_if_invalid_range(begin_offset, self.size, self.size)
            begin_offset += self.size
        # RFC 7233 section 2.1: a last byte position at or beyond the end of
        # the representation means "up to the end", it is not an error.
        stop = self.size if end is None else min(end, self.size)
        if max_partial_size is not None:
            stop = min(begin_offset + max_partial_size, stop)
        _raise_if_invalid_range(begin_offset, stop, self.size)
        # Only commit once the range is known to be valid.
        self.begin = begin_offset
        self.end = stop
        return self.size
class IOBody(ResponseBody):
    """Provides an async accessor over an in-memory IO stream with range setting.

    The :attr:`Response.response` attribute must be async-iterable and
    yield bytes, which this wrapper does for a ``BytesIO`` stream. In
    addition it allows a range to be set on the stream, thereby
    supporting conditional requests.

    Attributes:
        buffer_size: Size in bytes to load per iteration.
    """

    buffer_size = 8192

    def __init__(self, io_stream: BytesIO, *, buffer_size: Optional[int] = None) -> None:
        self.io_stream = io_stream
        # The size is read once, at construction time.
        self.size = io_stream.getbuffer().nbytes
        self.begin = 0
        self.end = self.size
        if buffer_size is not None:
            self.buffer_size = buffer_size

    async def __aenter__(self) -> "IOBody":
        # Start reading at the first byte of the selected range.
        self.io_stream.seek(self.begin)
        return self

    async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None:
        # The stream is owned by the caller, so it is left open.
        return None

    def __aiter__(self) -> "IOBody":
        return self

    async def __anext__(self) -> bytes:
        position = self.io_stream.tell()
        if position >= self.end:
            raise StopAsyncIteration()
        # Never read past the end of the selected range.
        chunk = self.io_stream.read(min(self.buffer_size, self.end - position))
        if not chunk:
            raise StopAsyncIteration()
        return chunk

    async def convert_to_sequence(self) -> bytes:
        """Read the selected range of the stream into a bytes object."""
        result = bytearray()
        async with self as response:
            async for data in response:
                result.extend(data)
        return bytes(result)

    async def make_conditional(
        self, begin: int, end: Optional[int], max_partial_size: Optional[int] = None
    ) -> int:
        """Restrict the body to the requested byte range.

        Arguments:
            begin: First byte to serve; a negative value is a suffix
                range, i.e. an offset counted from the end of the
                stream.
            end: Offset one past the last byte to serve, or None to
                serve up to the end of the stream.
            max_partial_size: Optional cap on the number of bytes
                served.

        Returns:
            The complete size of the stream.

        Raises:
            RequestRangeNotSatisfiable: If the range selects no bytes
                of the stream.
        """
        begin_offset = begin
        if begin_offset < 0:
            # Suffix range ("bytes=-N"). Suffixes longer than the stream are
            # still rejected; valid ones become an absolute offset, as a
            # negative value cannot be passed to ``seek`` in ``__aenter__``.
            _raise_if_invalid_range(begin_offset, self.size, self.size)
            begin_offset += self.size
        # RFC 7233 section 2.1: a last byte position at or beyond the end of
        # the representation means "up to the end", it is not an error.
        stop = self.size if end is None else min(end, self.size)
        if max_partial_size is not None:
            stop = min(begin_offset + max_partial_size, stop)
        _raise_if_invalid_range(begin_offset, stop, self.size)
        # Only commit once the range is known to be valid.
        self.begin = begin_offset
        self.end = stop
        return self.size
class Response(_BaseRequestResponse, JSONMixin):
    """This class represents a response.

    It can be subclassed and the subclass used in preference by
    replacing the :attr:`~quart.Quart.response_class` with your
    subclass.

    Attributes:
        automatically_set_content_length: If False the content length
            header must be provided.
        default_status: The status code to use if not provided.
        default_mimetype: The mimetype to use if not provided.
        implicit_sequence_conversion: Implicitly convert the response
            to a iterable in the get_data method, to allow multiple
            iterations.
    """

    automatically_set_content_length = True
    default_status = 200
    default_mimetype = "text/html"
    data_body_class = DataBody
    file_body_class = FileBody
    implicit_sequence_conversion = True
    io_body_class = IOBody
    iterable_body_class = IterableBody
    max_cookie_size = 4093

    def __init__(
        self,
        response: Union[ResponseBody, AnyStr, Iterable],
        status: Optional[int] = None,
        headers: Optional[Union[dict, Headers]] = None,
        mimetype: Optional[str] = None,
        content_type: Optional[str] = None,
    ) -> None:
        """Create a response object.

        The response itself can be a chunk of data or a
        iterable/generator of data chunks.

        The Content-Type can either be specified as a mimetype or
        content_type header or omitted to use the
        :attr:`default_mimetype`.

        Arguments:
            response: The response data or iterable over the data.
            status: Status code of the response.
            headers: Headers to attach to the response.
            mimetype: Mimetype of the response.
            content_type: Content-Type header value.

        Attributes:
            response: An iterable of the response bytes-data.
        """
        super().__init__(headers)
        self.timeout: Union[int, None, object] = sentinel

        if status is None:
            status = self.default_status
        try:
            self.status_code = int(status)
        except ValueError as error:
            raise ValueError("Quart does not support non-integer status values") from error

        if content_type is None:
            # Only fall back to the default mimetype when the supplied
            # headers do not already carry a Content-Type.
            if mimetype is None and "content-type" not in self.headers:
                mimetype = self.default_mimetype
            if mimetype is not None:
                self.mimetype = mimetype

        if content_type is not None:
            self.headers["Content-Type"] = content_type

        self.response: ResponseBody
        if isinstance(response, ResponseBody):
            self.response = response
        elif isinstance(response, (str, bytes)):
            self.set_data(response)  # type: ignore
        else:
            self.response = self.iterable_body_class(response)

    async def get_data(self, raw: bool = True) -> AnyStr:
        """Return the body data.

        Arguments:
            raw: If True return bytes, otherwise return the data
                decoded using the :attr:`charset`.
        """
        if self.implicit_sequence_conversion:
            self.response = self.data_body_class(await self.response.convert_to_sequence())
        buffer = bytearray()
        async with self.response as body:  # type: ignore
            async for data in body:
                buffer.extend(data)
        payload = bytes(buffer)
        if raw:
            return payload  # type: ignore
        # Decode once over the whole payload; decoding chunk by chunk fails
        # when a multi-byte character is split across two chunks.
        return payload.decode(self.charset)  # type: ignore

    def set_data(self, data: AnyStr) -> None:
        """Set the response data.

        This will encode using the :attr:`charset`.
        """
        if isinstance(data, str):
            bytes_data = data.encode(self.charset)
        else:
            bytes_data = data
        self.response = self.data_body_class(bytes_data)
        if self.automatically_set_content_length:
            self.content_length = len(bytes_data)

    async def make_conditional(
        self, request_range: Optional[Range], max_partial_size: Optional[int] = None
    ) -> None:
        """Make the response conditional to the requested range.

        If a single satisfiable bytes range is requested, the body is
        restricted to it and the Content-Length, Content-Range and
        status code (206) are updated to match.

        Arguments:
            request_range: The range as requested by the request.
            max_partial_size: The maximum length the server is willing
                to serve in a single response. Defaults to unlimited.

        Raises:
            RequestRangeNotSatisfiable: If the units are not bytes,
                more than one range is requested or the body rejects
                the range.
        """
        self.accept_ranges = "bytes"  # Advertise this ability
        if request_range is None or len(request_range.ranges) == 0:  # Not a conditional request
            return

        if request_range.units != "bytes" or len(request_range.ranges) > 1:
            from ..exceptions import RequestRangeNotSatisfiable

            raise RequestRangeNotSatisfiable()

        begin, end = request_range.ranges[0]
        if not hasattr(self.response, "make_conditional"):
            # Bodies without range support (e.g. plain iterables) are loaded
            # into memory first. Checking for the method explicitly, rather
            # than catching AttributeError around the call, keeps genuine
            # errors raised inside the body from being swallowed and retried.
            self.response = self.data_body_class(await self.response.convert_to_sequence())
        complete_length = await self.response.make_conditional(  # type: ignore
            begin, end, max_partial_size
        )

        self.content_length = self.response.end - self.response.begin  # type: ignore
        if self.content_length != complete_length:
            # ContentRange takes a non-inclusive stop and subtracts one itself
            # when it is rendered to the header, so ``end`` is passed as is.
            self.content_range = ContentRange(
                request_range.units,
                self.response.begin,  # type: ignore
                self.response.end,  # type: ignore
                complete_length,
            )
            self.status_code = 206

    async def freeze(self) -> None:
        """Freeze this object ready for pickling."""
        self.set_data((await self.get_data()))

    def set_cookie(
        self,
        key: str,
        value: AnyStr = "",  # type: ignore
        max_age: Optional[Union[int, timedelta]] = None,
        expires: Optional[Union[int, float, datetime]] = None,
        path: str = "/",
        domain: Optional[str] = None,
        secure: bool = False,
        httponly: bool = False,
        samesite: Optional[str] = None,
    ) -> None:
        """Set a cookie in the response headers.

        The arguments are the standard cookie morsels; bytes values are
        decoded and the header value is built with werkzeug's
        ``dump_cookie``.
        """
        if isinstance(value, bytes):
            value = value.decode()  # type: ignore
        self.headers.add(
            "Set-Cookie",
            dump_cookie(  # type: ignore
                key,
                value=value,
                max_age=max_age,
                expires=expires,
                path=path,
                domain=domain,
                secure=secure,
                httponly=httponly,
                charset=self.charset,
                max_size=self.max_cookie_size,
                samesite=samesite,
            ),
        )

    def delete_cookie(
        self,
        key: str,
        path: str = "/",
        domain: Optional[str] = None,
        secure: bool = False,
        httponly: bool = False,
        samesite: Optional[str] = None,
    ) -> None:
        """Delete a cookie (set to expire immediately)."""
        self.set_cookie(
            key,
            expires=0,
            max_age=0,
            path=path,
            domain=domain,
            secure=secure,
            httponly=httponly,
            samesite=samesite,
        )

    async def add_etag(self, overwrite: bool = False, weak: bool = False) -> None:
        """Set the ETag header to the MD5 hex digest of the body.

        An existing ETag is kept unless ``overwrite`` is True.
        """
        if overwrite or "etag" not in self.headers:
            self.set_etag(md5((await self.get_data())).hexdigest(), weak)

    def get_etag(self) -> Tuple[Optional[str], Optional[bool]]:
        """Return the ETag as a ``(value, is_weak)`` pair.

        The value has the ``W/`` prefix and surrounding quotes removed.
        ``(None, None)`` is returned if there is no ETag header.
        """
        etag = self.headers.get("ETag")
        if etag is None:
            return None, None
        weak = etag.upper().startswith("W/")
        if weak:
            etag = etag[2:]
        return etag.strip('"'), weak

    def set_etag(self, etag: str, weak: bool = False) -> None:
        """Set the ETag header, quoting the value and marking it weak if asked."""
        if weak:
            self.headers["ETag"] = f'W/"{etag}"'
        else:
            self.headers["ETag"] = f'"{etag}"'

    @property
    def access_control_allow_credentials(self) -> bool:
        """Whether credentials can be shared by the browser to
        JavaScript code. As part of the preflight request it indicates
        whether credentials can be used on the cross origin request.
        """
        return "Access-Control-Allow-Credentials" in self.headers

    @access_control_allow_credentials.setter
    def access_control_allow_credentials(self, value: bool) -> None:
        if value is True:
            self.headers["Access-Control-Allow-Credentials"] = "true"
        else:
            self.headers.pop("Access-Control-Allow-Credentials", None)  # type: ignore

    @property
    def access_control_allow_headers(self) -> Optional[HeaderSet]:
        if "Access-Control-Allow-Headers" in self.headers:
            return parse_set_header(self.headers["Access-Control-Allow-Headers"])
        return None

    @access_control_allow_headers.setter
    def access_control_allow_headers(self, value: HeaderSet) -> None:
        self.headers["Access-Control-Allow-Headers"] = dump_header(value)

    @property
    def access_control_allow_methods(self) -> Optional[HeaderSet]:
        if "Access-Control-Allow-Methods" in self.headers:
            return parse_set_header(self.headers["Access-Control-Allow-Methods"])
        return None

    @access_control_allow_methods.setter
    def access_control_allow_methods(self, value: HeaderSet) -> None:
        self.headers["Access-Control-Allow-Methods"] = dump_header(value)

    @property
    def access_control_allow_origin(self) -> Optional[str]:
        return self.headers.get("Access-Control-Allow-Origin")

    @access_control_allow_origin.setter
    def access_control_allow_origin(self, value: str) -> None:
        self.headers["Access-Control-Allow-Origin"] = value

    @property
    def access_control_expose_headers(self) -> Optional[HeaderSet]:
        if "Access-Control-Expose-Headers" in self.headers:
            return parse_set_header(self.headers["Access-Control-Expose-Headers"])
        return None

    @access_control_expose_headers.setter
    def access_control_expose_headers(self, value: HeaderSet) -> None:
        self.headers["Access-Control-Expose-Headers"] = dump_header(value)

    @property
    def access_control_max_age(self) -> Optional[int]:
        if "Access-Control-Max-Age" in self.headers:
            return int(self.headers["Access-Control-Max-Age"])
        return None

    @access_control_max_age.setter
    def access_control_max_age(self, value: int) -> None:
        self.headers["Access-Control-Max-Age"] = str(value)

    @property
    def accept_ranges(self) -> Optional[str]:
        return self.headers.get("Accept-Ranges")

    @accept_ranges.setter
    def accept_ranges(self, value: str) -> None:
        self.headers["Accept-Ranges"] = value

    @property
    def age(self) -> Optional[int]:
        """The Age header in seconds.

        None is returned when the header is missing, is not an integer
        or is not greater than zero.
        """
        try:
            value = int(self.headers.get("Age", ""))
        except (TypeError, ValueError):
            return None
        return value if value > 0 else None

    @age.setter
    def age(self, value: Union[int, timedelta]) -> None:
        if isinstance(value, timedelta):
            # total_seconds() is a float; the header (and the getter above)
            # expects whole seconds, so "3600" rather than "3600.0".
            self.headers["Age"] = str(int(value.total_seconds()))
        else:
            self.headers["Age"] = str(value)

    @property
    def allow(self) -> HeaderSet:
        # Mutations of the returned set are written back to the header.
        def on_update(header_set: HeaderSet) -> None:
            self.allow = header_set

        return parse_set_header(self.headers.get("Allow"), on_update=on_update)

    @allow.setter
    def allow(self, value: HeaderSet) -> None:
        self._set_or_pop_header("Allow", value.to_header())

    @property
    def cache_control(self) -> ResponseCacheControl:
        # Mutations of the returned object are written back to the header.
        def on_update(cache_control: ResponseCacheControl) -> None:
            self.cache_control = cache_control

        return parse_cache_control_header(
            self.headers.get("Cache-Control"), on_update, ResponseCacheControl
        )

    @cache_control.setter
    def cache_control(self, value: ResponseCacheControl) -> None:
        self._set_or_pop_header("Cache-Control", value.to_header())

    @property
    def content_encoding(self) -> Optional[str]:
        return self.headers.get("Content-Encoding")

    @content_encoding.setter
    def content_encoding(self, value: str) -> None:
        self.headers["Content-Encoding"] = value

    @property
    def content_language(self) -> HeaderSet:
        # Mutations of the returned set are written back to the header.
        def on_update(header_set: HeaderSet) -> None:
            self.content_language = header_set

        return parse_set_header(self.headers.get("Content-Language"), on_update=on_update)

    @content_language.setter
    def content_language(self, value: HeaderSet) -> None:
        self._set_or_pop_header("Content-Language", value.to_header())

    @property
    def content_length(self) -> Optional[int]:
        """The Content-Length header, or None if missing or not an integer."""
        try:
            return int(self.headers.get("Content-Length"))
        except (ValueError, TypeError):
            return None

    @content_length.setter
    def content_length(self, value: int) -> None:
        self.headers["Content-Length"] = str(value)

    @property
    def content_location(self) -> Optional[str]:
        return self.headers.get("Content-Location")

    @content_location.setter
    def content_location(self, value: str) -> None:
        self.headers["Content-Location"] = value

    @property
    def content_md5(self) -> Optional[str]:
        return self.headers.get("Content-MD5")

    @content_md5.setter
    def content_md5(self, value: str) -> None:
        self.headers["Content-MD5"] = value

    @property
    def content_range(self) -> ContentRange:
        # Mutations of the returned object are written back to the header.
        def on_update(cache_range: ContentRange) -> None:
            self.content_range = cache_range

        return parse_content_range_header(self.headers.get("Content-Range"), on_update)

    @content_range.setter
    def content_range(self, value: ContentRange) -> None:
        self._set_or_pop_header("Content-Range", value.to_header())

    @property
    def content_security_policy(self) -> ContentSecurityPolicy:
        # Mutations of the returned object are written back to the header.
        def on_update(content_security_policy: ContentSecurityPolicy) -> None:
            self.content_security_policy = content_security_policy

        return parse_csp_header(self.headers.get("Content-Security-Policy"), on_update)

    @content_security_policy.setter
    def content_security_policy(self, value: ContentSecurityPolicy) -> None:
        self._set_or_pop_header("Content-Security-Policy", dump_csp_header(value))

    @property
    def content_security_policy_report_only(self) -> ContentSecurityPolicy:
        # Mutations of the returned object are written back to the header.
        def on_update(content_security_policy: ContentSecurityPolicy) -> None:
            self.content_security_policy_report_only = content_security_policy

        return parse_csp_header(self.headers.get("Content-Security-Policy-Report-Only"), on_update)

    @content_security_policy_report_only.setter
    def content_security_policy_report_only(self, value: ContentSecurityPolicy) -> None:
        # Serialised the same way as the enforcing policy's setter.
        self._set_or_pop_header("Content-Security-Policy-Report-Only", dump_csp_header(value))

    @property
    def content_type(self) -> Optional[str]:
        return self.headers.get("Content-Type")

    @content_type.setter
    def content_type(self, value: str) -> None:
        self.headers["Content-Type"] = value

    @property
    def date(self) -> Optional[datetime]:
        """The Date header as a datetime, or None if missing or unparsable."""
        try:
            return parsedate_to_datetime(self.headers.get("Date", ""))
        except (TypeError, ValueError):  # Not a date format
            return None

    @date.setter
    def date(self, value: datetime) -> None:
        self.headers["Date"] = format_date_time(value.timestamp())

    @property
    def expires(self) -> Optional[datetime]:
        """The Expires header as a datetime, or None if missing or unparsable."""
        try:
            return parsedate_to_datetime(self.headers.get("Expires", ""))
        except (TypeError, ValueError):  # Not a date format
            return None

    @expires.setter
    def expires(self, value: datetime) -> None:
        self.headers["Expires"] = format_date_time(value.timestamp())

    @property
    def last_modified(self) -> Optional[datetime]:
        """The Last-Modified header as a datetime, or None if missing or unparsable."""
        try:
            return parsedate_to_datetime(self.headers.get("Last-Modified", ""))
        except (TypeError, ValueError):  # Not a date format
            return None

    @last_modified.setter
    def last_modified(self, value: datetime) -> None:
        self.headers["Last-Modified"] = format_date_time(value.timestamp())

    @property
    def location(self) -> Optional[str]:
        return self.headers.get("Location")

    @location.setter
    def location(self, value: str) -> None:
        self.headers["Location"] = value

    @property
    def referrer(self) -> Optional[str]:
        return self.headers.get("Referer")

    @referrer.setter
    def referrer(self, value: str) -> None:
        self.headers["Referer"] = value

    @property
    def retry_after(self) -> Optional[datetime]:
        """The Retry-After header as a datetime.

        A value made only of digits is taken as a number of seconds
        from now (``datetime.utcnow()``); anything else is parsed as a
        date. None is returned if the header is missing or unparsable.
        """
        value = self.headers.get("Retry-After", "")
        if value.isdigit():
            return datetime.utcnow() + timedelta(seconds=int(value))
        try:
            return parsedate_to_datetime(value)
        except (TypeError, ValueError):  # Not a date format
            return None

    @retry_after.setter
    def retry_after(self, value: Union[datetime, int]) -> None:
        if isinstance(value, datetime):
            self.headers["Retry-After"] = format_date_time(value.timestamp())
        else:
            self.headers["Retry-After"] = str(value)

    @property
    def vary(self) -> HeaderSet:
        # Mutations of the returned set are written back to the header.
        def on_update(header_set: HeaderSet) -> None:
            self.vary = header_set

        return parse_set_header(self.headers.get("Vary"), on_update=on_update)

    @vary.setter
    def vary(self, value: HeaderSet) -> None:
        self._set_or_pop_header("Vary", value.to_header())

    async def _load_json_data(self) -> str:
        """Return the data after decoding."""
        return await self.get_data(raw=False)

    def _set_or_pop_header(self, key: str, value: str) -> None:
        """Set the header, or remove it when the value is the empty string."""
        if value == "":
            self.headers.pop(key, None)  # type: ignore
        else:
            self.headers[key] = value

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.status_code})"