from __future__ import annotations
import asyncio
import warnings
from functools import partial
from typing import Any, AnyStr, Callable, cast, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
from urllib.parse import urlparse
from werkzeug.datastructures import Headers
from .debug import traceback_response
from .wrappers import Request, Response, sentinel, Websocket # noqa: F401
if TYPE_CHECKING:
from .app import Quart # noqa: F401
class ASGIHTTPConnection:
    """Adapter that serves one ASGI HTTP scope using the application.

    An instance is built from the app and the ASGI scope, then called
    with the ``receive`` and ``send`` callables of the connection.
    """

    def __init__(self, app: "Quart", scope: dict) -> None:
        self.app = app
        self.scope = scope

    async def __call__(self, receive: Callable, send: Callable) -> None:
        """Run the message receiver and the request handler concurrently.

        Whichever task finishes first ends the connection: the other
        task is cancelled and any exception raised by a finished task
        is re-raised to the caller.
        """
        request = self._create_request_from_scope(send)
        receiver_task = asyncio.ensure_future(self.handle_messages(request, receive))
        handler_task = asyncio.ensure_future(self.handle_request(request, send))
        finished, unfinished = await asyncio.wait(
            [handler_task, receiver_task], return_when=asyncio.FIRST_COMPLETED
        )
        await _cancel_tasks(unfinished)
        _raise_exceptions(finished)

    async def handle_messages(self, request: Request, receive: Callable) -> None:
        """Feed incoming body messages into ``request.body``.

        Returns once an ``http.disconnect`` message is received; other
        message types are ignored.
        """
        while True:
            message = await receive()
            kind = message["type"]
            if kind == "http.disconnect":
                return
            if kind == "http.request":
                request.body.append(message.get("body", b""))
                # A missing ``more_body`` key means this was the last chunk.
                if not message.get("more_body", False):
                    request.body.set_complete()

    def _create_request_from_scope(self, send: Callable) -> Request:
        """Build the app's request object from the ASGI scope."""
        scope = self.scope

        headers = Headers()
        client = scope.get("client") or ["<local>"]
        headers["Remote-Addr"] = client[0]
        for raw_name, raw_value in scope["headers"]:
            headers.add(raw_name.decode("latin1").title(), raw_value.decode("latin1"))

        # Before HTTP/1.1 a Host header may be absent, so default it from
        # the configured server name.
        if scope["http_version"] < "1.1":
            headers.setdefault("Host", self.app.config["SERVER_NAME"] or "")

        path = scope["path"]
        if path[0] != "/":
            # Not a plain path (presumably a full URL target), keep only
            # its path component.
            path = urlparse(path).path

        return self.app.request_class(
            scope["method"],
            scope["scheme"],
            path,
            scope["query_string"],
            headers,
            scope.get("root_path", ""),
            scope["http_version"],
            max_content_length=self.app.config["MAX_CONTENT_LENGTH"],
            body_timeout=self.app.config["BODY_TIMEOUT"],
            send_push_promise=partial(self._send_push_promise, send),
            scope=scope,
        )

    async def handle_request(self, request: Request, send: Callable) -> None:
        """Produce a response for ``request`` and send it.

        Any exception raised by the app is replaced with the traceback
        response. Sending is bounded by the response's own timeout when
        one is set, otherwise by the ``RESPONSE_TIMEOUT`` config value.
        """
        try:
            response = await self.app.handle_request(request)
        except Exception:
            response = await traceback_response()

        timeout = (
            cast(Optional[float], response.timeout)
            if response.timeout != sentinel
            else self.app.config["RESPONSE_TIMEOUT"]
        )
        try:
            await asyncio.wait_for(self._send_response(send, response), timeout=timeout)
        except asyncio.TimeoutError:
            # Running out of time while sending is deliberately ignored.
            pass

    async def _send_response(self, send: Callable, response: Response) -> None:
        """Send the response start, each body chunk and a closing message."""
        start_message = {
            "type": "http.response.start",
            "status": response.status_code,
            "headers": _encode_headers(response.headers),
        }
        await send(start_message)

        async with response.response as body:
            async for chunk in body:
                await send({"type": "http.response.body", "body": chunk, "more_body": True})
        # An empty final message marks the end of the body.
        await send({"type": "http.response.body", "body": b"", "more_body": False})

    async def _send_push_promise(self, send: Callable, path: str, headers: Headers) -> None:
        """Send a push promise when the scope advertises the extension."""
        if "http.response.push" not in self.scope.get("extensions", {}):
            return
        await send(
            {"type": "http.response.push", "path": path, "headers": _encode_headers(headers)}
        )
class ASGIWebsocketConnection:
    """Adapter that serves one ASGI websocket scope using the application.

    Received payloads are placed on ``self.queue``; the websocket object
    handed to the app reads from that queue and writes through
    ``send_data`` / ``accept_connection``.
    """

    def __init__(self, app: "Quart", scope: dict) -> None:
        self.app = app
        self.scope = scope
        self.queue: asyncio.Queue = asyncio.Queue()
        # Set once a ``websocket.accept`` message has been sent.
        self._accepted = False

    async def __call__(self, receive: Callable, send: Callable) -> None:
        """Run the message receiver and the websocket handler concurrently.

        Whichever task finishes first ends the connection: the other
        task is cancelled and any exception raised by a finished task
        is re-raised to the caller.
        """
        websocket = self._create_websocket_from_scope(send)
        receiver_task = asyncio.ensure_future(self.handle_messages(receive))
        handler_task = asyncio.ensure_future(self.handle_websocket(websocket, send))
        done, pending = await asyncio.wait(
            [handler_task, receiver_task], return_when=asyncio.FIRST_COMPLETED
        )
        await _cancel_tasks(pending)
        _raise_exceptions(done)

    async def handle_messages(self, receive: Callable) -> None:
        """Queue every received payload until the client disconnects.

        The ``bytes`` value is used whenever it is present (not ``None``),
        so an empty binary frame is queued as ``b""``; otherwise the
        ``text`` value is queued. Returns on ``websocket.disconnect``.
        """
        while True:
            event = await receive()
            if event["type"] == "websocket.receive":
                # Test for None rather than truthiness: ``b""`` is a valid
                # binary payload and must not fall through to ``text``.
                payload = event.get("bytes")
                if payload is None:
                    payload = event["text"]
                await self.queue.put(payload)
            elif event["type"] == "websocket.disconnect":
                return

    def _create_websocket_from_scope(self, send: Callable) -> Websocket:
        """Build the app's websocket object from the ASGI scope."""
        headers = Headers()
        headers["Remote-Addr"] = (self.scope.get("client") or ["<local>"])[0]
        for name, value in self.scope["headers"]:
            headers.add(name.decode("latin1").title(), value.decode("latin1"))

        path = self.scope["path"]
        # Not a plain path (presumably a full URL target): keep only its
        # path component.
        path = path if path[0] == "/" else urlparse(path).path

        return self.app.websocket_class(
            path,
            self.scope["query_string"],
            self.scope["scheme"],
            headers,
            self.scope.get("root_path", ""),
            self.scope.get("http_version", "1.1"),
            self.scope.get("subprotocols", []),
            self.queue.get,
            partial(self.send_data, send),
            partial(self.accept_connection, send),
        )

    async def handle_websocket(self, websocket: Websocket, send: Callable) -> None:
        """Run the app's websocket handler and finish the connection.

        If the app returns a response without having accepted the
        connection, the response is sent via the
        ``websocket.http.response`` extension when the scope lists it,
        otherwise the connection is closed with code 1000. If the
        connection was accepted it is closed with code 1000 once the
        handler returns.
        """
        response = await self.app.handle_websocket(websocket)
        if response is not None and not self._accepted:
            if "websocket.http.response" in self.scope.get("extensions", {}):
                await send(
                    {
                        "type": "websocket.http.response.start",
                        "status": response.status_code,
                        "headers": _encode_headers(response.headers),
                    }
                )
                async with response.response as body:
                    async for data in body:
                        await send(
                            {
                                "type": "websocket.http.response.body",
                                "body": data,
                                "more_body": True,
                            }
                        )
                # An empty final message marks the end of the body.
                await send(
                    {"type": "websocket.http.response.body", "body": b"", "more_body": False}
                )
            else:
                await send({"type": "websocket.close", "code": 1000})
        elif self._accepted:
            await send({"type": "websocket.close", "code": 1000})

    async def send_data(self, send: Callable, data: AnyStr) -> None:
        """Send ``data`` as a text frame if it is a ``str``, else as bytes."""
        if isinstance(data, str):
            await send({"type": "websocket.send", "text": data})
        else:
            await send({"type": "websocket.send", "bytes": data})

    async def accept_connection(
        self, send: Callable, headers: Headers, subprotocol: Optional[str]
    ) -> None:
        """Accept the websocket connection, at most once.

        Headers are included in the accept message only when the scope's
        ASGI ``spec_version`` is greater than 2.0; otherwise a warning is
        emitted if headers were supplied, and they are not sent.
        """
        if not self._accepted:
            message: Dict[str, Any] = {"subprotocol": subprotocol, "type": "websocket.accept"}
            spec_version = _convert_version(self.scope.get("asgi", {}).get("spec_version", "2.0"))
            if spec_version > [2, 0]:
                message["headers"] = _encode_headers(headers)
            elif headers:
                warnings.warn("The ASGI Server does not support accept headers, headers not sent")
            await send(message)
            self._accepted = True
class ASGILifespan:
    """Adapter that drives the application's startup and shutdown hooks."""

    def __init__(self, app: "Quart", scope: dict) -> None:
        # The scope is accepted for a uniform constructor but is not used.
        self.app = app

    async def __call__(self, receive: Callable, send: Callable) -> None:
        """Handle lifespan events until a shutdown event has been processed.

        Each startup/shutdown event runs the matching app hook and then
        reports ``complete``, or ``failed`` with the error text if the
        hook raised. Other event types are ignored.
        """
        while True:
            event = await receive()
            kind = event["type"]

            if kind == "lifespan.shutdown":
                try:
                    await self.app.shutdown()
                except Exception as error:
                    await send({"type": "lifespan.shutdown.failed", "message": str(error)})
                else:
                    await send({"type": "lifespan.shutdown.complete"})
                # Shutdown is the final event, whether or not it succeeded.
                return

            if kind == "lifespan.startup":
                try:
                    await self.app.startup()
                except Exception as error:
                    await send({"type": "lifespan.startup.failed", "message": str(error)})
                else:
                    await send({"type": "lifespan.startup.complete"})
async def _cancel_tasks(tasks: Set[asyncio.Future]) -> None:
    """Cancel every task in ``tasks`` and wait for them all to finish.

    After the tasks have settled, an exception raised by a task that
    did not end up cancelled is re-raised.
    """
    for pending_task in tasks:
        pending_task.cancel()
    # Collect exceptions as results so that the wait itself does not
    # raise; they are inspected and re-raised below instead.
    await asyncio.gather(*tasks, return_exceptions=True)
    _raise_exceptions(tasks)
def _raise_exceptions(tasks: Set[asyncio.Future]) -> None:
# Raise any unexcepted exceptions
for task in tasks:
if not task.cancelled() and task.exception() is not None:
raise task.exception()
def _encode_headers(headers: Headers) -> List[Tuple[bytes, bytes]]:
return [(key.lower().encode(), value.encode()) for key, value in headers.items()]
def _convert_version(raw: str) -> List[int]:
return list(map(int, raw.split(".")))