Source code for quart.routing

from __future__ import annotations

import warnings
from typing import Iterable, Optional

from werkzeug.routing import Map, MapAdapter, Rule

from .wrappers.base import BaseRequestWebsocket


class QuartRule(Rule):
    """A werkzeug ``Rule`` that additionally records ``provide_automatic_options``.

    Every routing argument is handed to ``Rule.__init__`` unchanged;
    ``provide_automatic_options`` is not forwarded and is only stored
    as an attribute on the instance.
    """

    def __init__(
        self,
        string: str,
        defaults: Optional[dict] = None,
        subdomain: Optional[str] = None,
        methods: Optional[Iterable[str]] = None,
        endpoint: Optional[str] = None,
        strict_slashes: Optional[bool] = None,
        merge_slashes: Optional[bool] = None,
        host: Optional[str] = None,
        websocket: bool = False,
        provide_automatic_options: bool = False,
    ) -> None:
        """Initialise the underlying rule and remember the automatic-options flag.

        Arguments:
            string: The rule string, passed positionally to ``Rule``.
            defaults, subdomain, methods, endpoint, strict_slashes,
            merge_slashes, host, websocket: Forwarded by keyword to
                ``Rule.__init__`` without modification.
            provide_automatic_options: Stored as
                ``self.provide_automatic_options``.
        """
        # Gather the options understood by the base class so they are
        # forwarded in a single call.
        rule_options = {
            "defaults": defaults,
            "subdomain": subdomain,
            "methods": methods,
            "endpoint": endpoint,
            "strict_slashes": strict_slashes,
            "merge_slashes": merge_slashes,
            "host": host,
            "websocket": websocket,
        }
        super().__init__(string, **rule_options)  # type: ignore
        self.provide_automatic_options = provide_automatic_options


class QuartMap(Map):
    """A werkzeug ``Map`` that can be bound straight from a request object."""

    def bind_to_request(
        self, request: BaseRequestWebsocket, subdomain: Optional[str], server_name: Optional[str]
    ) -> MapAdapter:
        """Bind this map to *request* and return the resulting adapter.

        Arguments:
            request: The request (or websocket) supplying the host, scheme,
                root path, method, path and query string used for binding.
            subdomain: An explicit subdomain to bind with. If ``None`` and
                the map is not host matching, it is derived by comparing
                the request host against the configured host.
            server_name: The configured server name. If ``None`` the
                request's own host is used as the configured host.

        Returns:
            The ``MapAdapter`` produced by ``Map.bind``.

        A warning is emitted, and the subdomain set to ``"<invalid>"``,
        when the request host does not end with the configured host.
        """
        host: str
        if server_name is None:
            host = request.host.lower()
        else:
            host = server_name.lower()

        host = _normalise_host(request.scheme, host)

        if subdomain is None and not self.host_matching:
            # Host names are case-insensitive, so lower-case the request
            # host exactly as the configured host was above; otherwise a
            # header such as "Example.COM" is wrongly reported as a
            # mismatch and routed with the "<invalid>" subdomain.
            request_host = _normalise_host(request.scheme, request.host.lower())
            request_host_parts = request_host.split(".")
            config_host_parts = host.split(".")
            # Negative offset selecting the trailing labels of the request
            # host that should line up with the configured host.
            offset = -len(config_host_parts)

            if request_host_parts[offset:] != config_host_parts:
                warnings.warn(
                    f"Current server name '{request.host}' doesn't match configured"
                    f" server name '{host}'",
                    stacklevel=2,
                )
                subdomain = "<invalid>"
            else:
                # Whatever precedes the configured host is the subdomain;
                # empty labels are dropped before re-joining.
                subdomain = ".".join(filter(None, request_host_parts[:offset]))

        return super().bind(
            host,
            request.root_path,
            subdomain,
            request.scheme,
            request.method,
            request.path,
            request.query_string.decode(),
        )


def _normalise_host(scheme: str, host: str) -> str:
    """Return *host* without a trailing default port for *scheme*.

    ``:80`` is removed for the ``http`` and ``ws`` schemes and ``:443``
    for ``https`` and ``wss``, as these ports are not commonly written
    out in a hostname. Any other scheme, or a host that does not end
    with the matching port, is returned unchanged.
    """
    default_port_suffix = {
        "http": ":80",
        "ws": ":80",
        "https": ":443",
        "wss": ":443",
    }.get(scheme)

    if default_port_suffix is None or not host.endswith(default_port_suffix):
        return host
    return host[: -len(default_port_suffix)]