diff --git a/src/ahttpx/__init__.py b/src/ahttpx/__init__.py index aafb928..478a50b 100644 --- a/src/ahttpx/__init__.py +++ b/src/ahttpx/__init__.py @@ -7,7 +7,7 @@ from ._pool import * # Connection, ConnectionPool, Transport from ._quickstart import * # get, post, put, patch, delete from ._response import * # Response -from ._request import * # Request +from ._request import * # Method, Request from ._streams import * # ByteStream, DuplexStream, FileStream, Stream from ._server import * # serve_http, run from ._urlencode import * # quote, unquote, urldecode, urlencode @@ -34,6 +34,7 @@ "HTTPParser", "HTTPStream", "JSON", + "Method", "MultiPart", "NetworkBackend", "NetworkStream", @@ -57,9 +58,3 @@ "urldecode", "urlencode", ] - - -__locals = locals() -for __name in __all__: - if not __name.startswith('__'): - setattr(__locals[__name], "__module__", "httpx") diff --git a/src/ahttpx/_client.py b/src/ahttpx/_client.py index 6326ac5..45c46b9 100644 --- a/src/ahttpx/_client.py +++ b/src/ahttpx/_client.py @@ -4,7 +4,7 @@ from ._content import Content from ._headers import Headers from ._pool import ConnectionPool, Transport -from ._request import Request +from ._request import Method, Request from ._response import Response from ._streams import Stream from ._urls import URL @@ -33,7 +33,7 @@ def __init__( def build_request( self, - method: str, + method: Method | str, url: URL | str, headers: Headers | typing.Mapping[str, str] | None = None, content: Content | Stream | bytes | None = None, @@ -47,7 +47,7 @@ def build_request( async def request( self, - method: str, + method: Method | str, url: URL | str, headers: Headers | typing.Mapping[str, str] | None = None, content: Content | Stream | bytes | None = None, @@ -59,7 +59,7 @@ async def request( async def stream( self, - method: str, + method: Method | str, url: URL | str, headers: Headers | typing.Mapping[str, str] | None = None, content: Content | Stream | bytes | None = None, diff --git a/src/ahttpx/_pool.py b/src/ahttpx/_pool.py index ff12246..a322c84 100644 --- a/src/ahttpx/_pool.py +++ b/src/ahttpx/_pool.py @@ -7,7 +7,7 @@ from ._network import Lock, NetworkBackend, Semaphore from ._parsers import HTTPParser, HTTPStream from ._response import Response -from ._request import Request +from ._request import Method, Request from ._streams import Stream from ._urls import URL @@ -29,7 +29,7 @@ async def close(self): async def request( self, - method: str, + method: Method | str, url: URL | str, headers: Headers | dict[str, str] | None = None, content: Content | Stream | bytes | None = None, @@ -41,7 +41,7 @@ async def request( async def stream( self, - method: str, + method: Method | str, url: URL | str, headers: Headers | dict[str, str] | None = None, content: Content | Stream | bytes | None = None, @@ -141,7 +141,7 @@ async def __aexit__( class Connection(Transport): def __init__(self, stream: Stream, origin: URL | str): self._stream = stream - self._origin = URL(origin) + self._origin = URL(origin) if not isinstance(origin, URL) else origin self._keepalive_duration = 5.0 self._idle_expiry = time.monotonic() + self._keepalive_duration self._request_lock = Lock() @@ -183,7 +183,7 @@ async def close(self) -> None: # Top-level API for working directly with a connection. async def request( self, - method: str, + method: Method | str, url: URL | str, headers: Headers | typing.Mapping[str, str] | None = None, content: Content | Stream | bytes | None = None, @@ -196,7 +196,7 @@ async def request( async def stream( self, - method: str, + method: Method | str, url: URL | str, headers: Headers | typing.Mapping[str, str] | None = None, content: Content | Stream | bytes | None = None, @@ -207,7 +207,7 @@ async def stream( # Send the request... async def _send_head(self, request: Request) -> None: - method = request.method.encode('ascii') + method = bytes(request.method) target = request.url.target.encode('ascii') protocol = b'HTTP/1.1' await self._parser.send_method_line(method, target, protocol) diff --git a/src/ahttpx/_request.py b/src/ahttpx/_request.py index 78b8228..b5f1345 100644 --- a/src/ahttpx/_request.py +++ b/src/ahttpx/_request.py @@ -9,17 +9,38 @@ __all__ = ["Request"] +class Method: + def __init__(self, method: str, standard=True): + if standard: + method = method.upper() + if method not in ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE"): + raise ValueError("Non-standard method {method!r}") + self._method = method + + def __eq__(self, other) -> bool: + return str(self) == str(other) + + def __bytes__(self) -> bytes: + return self._method.encode('ascii') + + def __str__(self) -> str: + return self._method + + def __repr__(self): + return "" + + class Request: def __init__( self, - method: str, + method: Method | str, url: URL | str, headers: Headers | typing.Mapping[str, str] | None = None, content: Content | Stream | bytes | None = None, ): - self.method = method - self.url = URL(url) - self.headers = Headers(headers) + self.method = Method(method) if not isinstance(method, Method) else method + self.url = URL(url) if not isinstance(url, URL) else url + self.headers = Headers(headers) if not isinstance(headers, Headers) else headers self.stream: Stream = ByteStream(b"") # https://datatracker.ietf.org/doc/html/rfc2616#section-14.23 diff --git a/src/ahttpx/_response.py b/src/ahttpx/_response.py index db1de83..d695488 100644 --- a/src/ahttpx/_response.py +++ b/src/ahttpx/_response.py @@ -85,7 +85,7 @@ def __init__( content: Content | Stream | bytes | None = None, ): self.status_code = status_code - self.headers = Headers(headers) + self.headers = Headers(headers) if not isinstance(headers, Headers) else headers self.stream: Stream = ByteStream(b"") if content is not None: diff --git a/src/httpx/__init__.py b/src/httpx/__init__.py index aafb928..478a50b 100644 --- a/src/httpx/__init__.py +++ b/src/httpx/__init__.py @@ -7,7 +7,7 @@ from ._pool import * # Connection, ConnectionPool, Transport from ._quickstart import * # get, post, put, patch, delete from ._response import * # Response -from ._request import * # Request +from ._request import * # Method, Request from ._streams import * # ByteStream, DuplexStream, FileStream, Stream from ._server import * # serve_http, run from ._urlencode import * # quote, unquote, urldecode, urlencode @@ -34,6 +34,7 @@ "HTTPParser", "HTTPStream", "JSON", + "Method", "MultiPart", "NetworkBackend", "NetworkStream", @@ -57,9 +58,3 @@ "urldecode", "urlencode", ] - - -__locals = locals() -for __name in __all__: - if not __name.startswith('__'): - setattr(__locals[__name], "__module__", "httpx") diff --git a/src/httpx/_client.py b/src/httpx/_client.py index 2dd54fd..aaf8ad9 100644 --- a/src/httpx/_client.py +++ b/src/httpx/_client.py @@ -4,7 +4,7 @@ from ._content import Content from ._headers import Headers from ._pool import ConnectionPool, Transport -from ._request import Request +from ._request import Method, Request from ._response import Response from ._streams import Stream from ._urls import URL @@ -33,7 +33,7 @@ def __init__( def build_request( self, - method: str, + method: Method | str, url: URL | str, headers: Headers | typing.Mapping[str, str] | None = None, content: Content | Stream | bytes | None = None, @@ -47,7 +47,7 @@ def build_request( def request( self, - method: str, + method: Method | str, url: URL | str, headers: Headers | typing.Mapping[str, str] | None = None, content: Content | Stream | bytes | None = None, @@ -59,7 +59,7 @@ def request( def stream( self, - method: str, + method: Method | str, url: URL | str, headers: Headers | typing.Mapping[str, str] | None = None, content: Content | Stream | bytes | None = None, diff --git a/src/httpx/_pool.py b/src/httpx/_pool.py index 71cb942..f537be8 100644 --- a/src/httpx/_pool.py +++ b/src/httpx/_pool.py @@ -7,7 +7,7 @@ from ._network import Lock, NetworkBackend, Semaphore from ._parsers import HTTPParser, HTTPStream from ._response import Response -from ._request import Request +from ._request import Method, Request from ._streams import Stream from ._urls import URL @@ -29,7 +29,7 @@ def close(self): def request( self, - method: str, + method: Method | str, url: URL | str, headers: Headers | dict[str, str] | None = None, content: Content | Stream | bytes | None = None, @@ -41,7 +41,7 @@ def request( def stream( self, - method: str, + method: Method | str, url: URL | str, headers: Headers | dict[str, str] | None = None, content: Content | Stream | bytes | None = None, @@ -141,7 +141,7 @@ def __exit__( class Connection(Transport): def __init__(self, stream: Stream, origin: URL | str): self._stream = stream - self._origin = URL(origin) + self._origin = URL(origin) if not isinstance(origin, URL) else origin self._keepalive_duration = 5.0 self._idle_expiry = time.monotonic() + self._keepalive_duration self._request_lock = Lock() @@ -183,7 +183,7 @@ def close(self) -> None: # Top-level API for working directly with a connection. def request( self, - method: str, + method: Method | str, url: URL | str, headers: Headers | typing.Mapping[str, str] | None = None, content: Content | Stream | bytes | None = None, @@ -196,7 +196,7 @@ def request( def stream( self, - method: str, + method: Method | str, url: URL | str, headers: Headers | typing.Mapping[str, str] | None = None, content: Content | Stream | bytes | None = None, @@ -207,7 +207,7 @@ def stream( # Send the request... def _send_head(self, request: Request) -> None: - method = request.method.encode('ascii') + method = bytes(request.method) target = request.url.target.encode('ascii') protocol = b'HTTP/1.1' self._parser.send_method_line(method, target, protocol) diff --git a/src/httpx/_request.py b/src/httpx/_request.py index 1b739b1..3cf030c 100644 --- a/src/httpx/_request.py +++ b/src/httpx/_request.py @@ -9,17 +9,38 @@ __all__ = ["Request"] +class Method: + def __init__(self, method: str, standard=True): + if standard: + method = method.upper() + if method not in ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE"): + raise ValueError("Non-standard method {method!r}") + self._method = method + + def __eq__(self, other) -> bool: + return str(self) == str(other) + + def __bytes__(self) -> bytes: + return self._method.encode('ascii') + + def __str__(self) -> str: + return self._method + + def __repr__(self): + return "" + + class Request: def __init__( self, - method: str, + method: Method | str, url: URL | str, headers: Headers | typing.Mapping[str, str] | None = None, content: Content | Stream | bytes | None = None, ): - self.method = method - self.url = URL(url) - self.headers = Headers(headers) + self.method = Method(method) if not isinstance(method, Method) else method + self.url = URL(url) if not isinstance(url, URL) else url + self.headers = Headers(headers) if not isinstance(headers, Headers) else headers self.stream: Stream = ByteStream(b"") # https://datatracker.ietf.org/doc/html/rfc2616#section-14.23 diff --git a/src/httpx/_response.py b/src/httpx/_response.py index abfec81..0dc43ff 100644 --- a/src/httpx/_response.py +++ b/src/httpx/_response.py @@ -85,7 +85,7 @@ def __init__( content: Content | Stream | bytes | None = None, ): self.status_code = status_code - self.headers = Headers(headers) + self.headers = Headers(headers) if not isinstance(headers, Headers) else headers self.stream: Stream = ByteStream(b"") if content is not None: diff --git a/tests/test_ahttpx/test_client.py b/tests/test_ahttpx/test_client.py index f7be2c2..a059733 100644 --- a/tests/test_ahttpx/test_client.py +++ b/tests/test_ahttpx/test_client.py @@ -6,7 +6,7 @@ async def echo(request): await request.read() response = ahttpx.Response(200, content=ahttpx.JSON({ - 'method': request.method, + 'method': str(request.method), 'query-params': dict(request.url.params.items()), 'content-type': request.headers.get('Content-Type'), 'json': json.loads(request.body) if request.body else None, diff --git a/tests/test_ahttpx/test_quickstart.py b/tests/test_ahttpx/test_quickstart.py index ef3963c..e996482 100644 --- a/tests/test_ahttpx/test_quickstart.py +++ b/tests/test_ahttpx/test_quickstart.py @@ -6,7 +6,7 @@ async def echo(request): await request.read() response = ahttpx.Response(200, content=ahttpx.JSON({ - 'method': request.method, + 'method': str(request.method), 'query-params': dict(request.url.params.items()), 'content-type': request.headers.get('Content-Type'), 'json': json.loads(request.body) if request.body else None, diff --git a/tests/test_httpx/test_client.py b/tests/test_httpx/test_client.py index 6aa76f5..b558319 100644 --- a/tests/test_httpx/test_client.py +++ b/tests/test_httpx/test_client.py @@ -6,7 +6,7 @@ def echo(request): request.read() response = httpx.Response(200, content=httpx.JSON({ - 'method': request.method, + 'method': str(request.method), 'query-params': dict(request.url.params.items()), 'content-type': request.headers.get('Content-Type'), 'json': json.loads(request.body) if request.body else None, diff --git a/tests/test_httpx/test_quickstart.py b/tests/test_httpx/test_quickstart.py index 55c34b1..6843049 100644 --- a/tests/test_httpx/test_quickstart.py +++ b/tests/test_httpx/test_quickstart.py @@ -6,7 +6,7 @@ def echo(request): request.read() response = httpx.Response(200, content=httpx.JSON({ - 'method': request.method, + 'method': str(request.method), 'query-params': dict(request.url.params.items()), 'content-type': request.headers.get('Content-Type'), 'json': json.loads(request.body) if request.body else None,