Initial commit

This commit is contained in:
kdusek
2025-12-09 12:13:01 +01:00
commit 8e654ed209
13332 changed files with 2695056 additions and 0 deletions

View File

@@ -0,0 +1,5 @@
from ._client import * # noqa: F403
from ._mock import * # noqa: F403
from ._pool import * # noqa: F403
from ._storages import * # noqa: F403
from ._transports import * # noqa: F403

View File

@@ -0,0 +1,30 @@
import typing as tp
import httpx
from hishel._async._transports import AsyncCacheTransport
__all__ = ("AsyncCacheClient",)
class AsyncCacheClient(httpx.AsyncClient):
    """An HTTPX async client with a built-in HTTP caching layer.

    Accepts every ``httpx.AsyncClient`` argument plus two optional keyword
    arguments that are consumed before delegating to the parent class:

    - ``storage``: backend used to persist cached responses (or ``None`` for the default).
    - ``controller``: controller that decides caching behavior (or ``None`` for the default).
    """

    def __init__(self, *args: tp.Any, **kwargs: tp.Any):
        # Pop cache-specific kwargs with a default so httpx.AsyncClient never sees them.
        self._storage = kwargs.pop("storage", None)
        self._controller = kwargs.pop("controller", None)
        super().__init__(*args, **kwargs)

    def _init_transport(self, *args, **kwargs) -> AsyncCacheTransport:  # type: ignore
        """Wrap the transport httpx would normally build with the caching transport."""
        _transport = super()._init_transport(*args, **kwargs)
        return AsyncCacheTransport(
            transport=_transport,
            storage=self._storage,
            controller=self._controller,
        )

    def _init_proxy_transport(self, *args, **kwargs) -> AsyncCacheTransport:  # type: ignore
        """Wrap proxy transports the same way as regular ones."""
        _transport = super()._init_proxy_transport(*args, **kwargs)  # pragma: no cover
        return AsyncCacheTransport(  # pragma: no cover
            transport=_transport,
            storage=self._storage,
            controller=self._controller,
        )

View File

@@ -0,0 +1,43 @@
import typing as tp
from types import TracebackType
import httpcore
import httpx
from httpcore._async.interfaces import AsyncRequestInterface
if tp.TYPE_CHECKING: # pragma: no cover
from typing_extensions import Self
__all__ = ("MockAsyncConnectionPool", "MockAsyncTransport")
class MockAsyncConnectionPool(AsyncRequestInterface):
    """A test double that replays pre-registered httpcore responses in FIFO order."""

    async def handle_async_request(self, request: httpcore.Request) -> httpcore.Response:
        """Fully consume the request body, then return the next queued response."""
        assert isinstance(request.stream, tp.AsyncIterable)
        consumed = []
        async for chunk in request.stream:
            consumed.append(chunk)
        _ = b"".join(consumed)  # the body is read but intentionally unused
        return self.mocked_responses.pop(0)

    def add_responses(self, responses: tp.List[httpcore.Response]) -> None:
        """Append *responses* to the queue, creating it lazily on first use."""
        try:
            queue = self.mocked_responses
        except AttributeError:
            queue = self.mocked_responses = []
        queue.extend(responses)

    async def __aenter__(self) -> "Self":
        return self

    async def __aexit__(
        self,
        exc_type: tp.Optional[tp.Type[BaseException]] = None,
        exc_value: tp.Optional[BaseException] = None,
        traceback: tp.Optional[TracebackType] = None,
    ) -> None: ...
class MockAsyncTransport(httpx.AsyncBaseTransport):
    """A test double httpx transport that replays pre-registered responses in FIFO order."""

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        """Return the next queued response, ignoring the request itself."""
        return self.mocked_responses.pop(0)

    def add_responses(self, responses: tp.List[httpx.Response]) -> None:
        """Append *responses* to the queue, creating it lazily on first use."""
        try:
            queue = self.mocked_responses
        except AttributeError:
            queue = self.mocked_responses = []
        queue.extend(responses)

View File

@@ -0,0 +1,201 @@
from __future__ import annotations
import types
import typing as tp
from httpcore._async.interfaces import AsyncRequestInterface
from httpcore._exceptions import ConnectError
from httpcore._models import Request, Response
from .._controller import Controller, allowed_stale
from .._headers import parse_cache_control
from .._serializers import JSONSerializer, Metadata
from .._utils import extract_header_values_decoded
from ._storages import AsyncBaseStorage, AsyncFileStorage
T = tp.TypeVar("T")
__all__ = ("AsyncCacheConnectionPool",)
async def fake_stream(content: bytes) -> tp.AsyncIterable[bytes]:
    """Wrap *content* in a single-chunk asynchronous byte stream."""
    yield content
def generate_504() -> Response:
    """Build a bare 504 (Gateway Timeout) httpcore response for only-if-cached misses."""
    return Response(status=504)
class AsyncCacheConnectionPool(AsyncRequestInterface):
    """An HTTP Core Connection Pool that supports HTTP caching.

    :param pool: `Connection Pool` that our class wraps in order to add an HTTP Cache layer on top of
    :type pool: AsyncRequestInterface
    :param storage: Storage that handles how the responses should be saved, defaults to None
    :type storage: tp.Optional[AsyncBaseStorage], optional
    :param controller: Controller that manages the cache behavior at the specification level, defaults to None
    :type controller: tp.Optional[Controller], optional
    """

    def __init__(
        self,
        pool: AsyncRequestInterface,
        storage: tp.Optional[AsyncBaseStorage] = None,
        controller: tp.Optional[Controller] = None,
    ) -> None:
        self._pool = pool
        # Default to file-based storage with JSON serialization when none is supplied.
        self._storage = storage if storage is not None else AsyncFileStorage(serializer=JSONSerializer())
        if not isinstance(self._storage, AsyncBaseStorage):  # pragma: no cover
            raise TypeError(f"Expected subclass of `AsyncBaseStorage` but got `{storage.__class__.__name__}`")
        self._controller = controller if controller is not None else Controller()

    async def handle_async_request(self, request: Request) -> Response:
        """
        Handles HTTP requests while also implementing HTTP caching.

        :param request: An HTTP request
        :type request: httpcore.Request
        :return: An HTTP response
        :rtype: httpcore.Response
        """
        if request.extensions.get("cache_disabled", False):
            # Per-request opt-out: force revalidation by injecting no-cache/max-age=0.
            request.headers.extend([(b"cache-control", b"no-cache"), (b"cache-control", b"max-age=0")])
        if request.method.upper() not in [b"GET", b"HEAD"]:
            # If the HTTP method is, for example, POST,
            # we must also use the request data to generate the hash.
            assert isinstance(request.stream, tp.AsyncIterable)
            body_for_key = b"".join([chunk async for chunk in request.stream])
            # The stream was consumed above, so replace it with an equivalent one
            # so the request can still be sent downstream.
            request.stream = fake_stream(body_for_key)
        else:
            body_for_key = b""
        key = self._controller._key_generator(request, body_for_key)
        stored_data = await self._storage.retrieve(key)
        request_cache_control = parse_cache_control(extract_header_values_decoded(request.headers, b"Cache-Control"))
        # RFC 9111 `only-if-cached`: answer from the cache or fail with 504.
        if request_cache_control.only_if_cached and not stored_data:
            return generate_504()
        if stored_data:
            # Try using the stored response if it was discovered.
            stored_response, stored_request, metadata = stored_data
            # Immediately read the stored response to avoid issues when trying to access the response body.
            # NOTE(review): sync read() here — presumably the serializer returns an
            # already-buffered response, so this never blocks; confirm.
            stored_response.read()
            res = self._controller.construct_response_from_cache(
                request=request,
                response=stored_response,
                original_request=stored_request,
            )
            if isinstance(res, Response):
                # Simply use the response if the controller determines it is ready for use.
                return await self._create_hishel_response(
                    key=key,
                    response=stored_response,
                    request=request,
                    metadata=metadata,
                    cached=True,
                    revalidated=False,
                )
            if request_cache_control.only_if_cached:
                # The stored response requires revalidation, which `only-if-cached` forbids.
                return generate_504()
            if isinstance(res, Request):
                # Controller has determined that the response needs to be re-validated.
                try:
                    revalidation_response = await self._pool.handle_async_request(res)
                except ConnectError:
                    # If there is a connection error, we can use the stale response if allowed.
                    if self._controller._allow_stale and allowed_stale(response=stored_response):
                        return await self._create_hishel_response(
                            key=key,
                            response=stored_response,
                            request=request,
                            metadata=metadata,
                            cached=True,
                            revalidated=False,
                        )
                    raise  # pragma: no cover
                # Merge headers with the stale response.
                final_response = self._controller.handle_validation_response(
                    old_response=stored_response, new_response=revalidation_response
                )
                await final_response.aread()
                # RFC 9111: 4.3.3. Handling a Validation Response
                # A 304 (Not Modified) response status code indicates that the stored response can be updated and
                # reused. A full response (i.e., one containing content) indicates that none of the stored responses
                # nominated in the conditional request are suitable. Instead, the cache MUST use the full response to
                # satisfy the request. The cache MAY store such a full response, subject to its constraints.
                if revalidation_response.status != 304 and self._controller.is_cachable(
                    request=request, response=final_response
                ):
                    await self._storage.store(key, response=final_response, request=request)
                return await self._create_hishel_response(
                    key=key,
                    response=final_response,
                    request=request,
                    cached=revalidation_response.status == 304,
                    revalidated=True,
                    metadata=metadata,
                )
        # Cache miss: forward to the wrapped pool and store the response if cachable.
        regular_response = await self._pool.handle_async_request(request)
        await regular_response.aread()
        if self._controller.is_cachable(request=request, response=regular_response):
            await self._storage.store(key, response=regular_response, request=request)
        return await self._create_hishel_response(
            key=key, response=regular_response, request=request, cached=False, revalidated=False
        )

    async def _create_hishel_response(
        self,
        key: str,
        response: Response,
        request: Request,
        cached: bool,
        revalidated: bool,
        metadata: Metadata | None = None,
    ) -> Response:
        """Annotate *response* with hishel cache extensions; bump and persist usage metadata for cache hits."""
        if cached:
            assert metadata
            metadata["number_of_uses"] += 1
            await self._storage.update_metadata(key=key, request=request, response=response, metadata=metadata)
            response.extensions["from_cache"] = True  # type: ignore[index]
            response.extensions["cache_metadata"] = metadata  # type: ignore[index]
        else:
            response.extensions["from_cache"] = False  # type: ignore[index]
        response.extensions["revalidated"] = revalidated  # type: ignore[index]
        return response

    async def aclose(self) -> None:
        # Close the storage first, then the wrapped pool if it supports closing.
        await self._storage.aclose()
        if hasattr(self._pool, "aclose"):  # pragma: no cover
            await self._pool.aclose()

    async def __aenter__(self: T) -> T:
        return self

    async def __aexit__(
        self,
        exc_type: tp.Optional[tp.Type[BaseException]] = None,
        exc_value: tp.Optional[BaseException] = None,
        traceback: tp.Optional[types.TracebackType] = None,
    ) -> None:
        await self.aclose()

View File

@@ -0,0 +1,772 @@
from __future__ import annotations
import datetime
import logging
import os
import time
import typing as t
import typing as tp
import warnings
from copy import deepcopy
from pathlib import Path
try:
import boto3
from .._s3 import AsyncS3Manager
except ImportError: # pragma: no cover
boto3 = None # type: ignore
try:
import anysqlite
except ImportError: # pragma: no cover
anysqlite = None # type: ignore
from httpcore import Request, Response
if t.TYPE_CHECKING: # pragma: no cover
from typing_extensions import TypeAlias
from hishel._serializers import BaseSerializer, clone_model
from .._files import AsyncFileManager
from .._serializers import JSONSerializer, Metadata
from .._synchronization import AsyncLock
from .._utils import float_seconds_to_int_milliseconds
logger = logging.getLogger("hishel.storages")
__all__ = (
"AsyncBaseStorage",
"AsyncFileStorage",
"AsyncRedisStorage",
"AsyncSQLiteStorage",
"AsyncInMemoryStorage",
"AsyncS3Storage",
)
StoredResponse: TypeAlias = tp.Tuple[Response, Request, Metadata]
RemoveTypes = tp.Union[str, Response]
try:
import redis.asyncio as redis
except ImportError: # pragma: no cover
redis = None # type: ignore
class AsyncBaseStorage:
    """Abstract base class for asynchronous response storages.

    Concrete storages must implement ``store``, ``remove``, ``update_metadata``,
    ``retrieve`` and ``aclose``.

    :param serializer: Serializer capable of serializing and de-serializing http responses, defaults to None
    :type serializer: tp.Optional[BaseSerializer], optional
    :param ttl: Specifies the maximum number of seconds that the response can be cached, defaults to None
    :type ttl: tp.Optional[tp.Union[int, float]], optional
    """

    def __init__(
        self,
        serializer: tp.Optional[BaseSerializer] = None,
        ttl: tp.Optional[tp.Union[int, float]] = None,
    ) -> None:
        # Fall back to the JSON serializer when none was provided.
        self._serializer = serializer or JSONSerializer()
        self._ttl = ttl

    async def store(self, key: str, response: Response, request: Request, metadata: Metadata | None = None) -> None:
        """Persist a response/request pair under *key*."""
        raise NotImplementedError()

    async def remove(self, key: RemoveTypes) -> None:
        """Delete the entry for *key* (a cache key or a cached Response)."""
        raise NotImplementedError()

    async def update_metadata(self, key: str, response: Response, request: Request, metadata: Metadata) -> None:
        """Rewrite only the metadata of an already stored entry."""
        raise NotImplementedError()

    async def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
        """Fetch the stored response/request/metadata triple for *key*, or None."""
        raise NotImplementedError()

    async def aclose(self) -> None:
        """Release any resources held by the storage."""
        raise NotImplementedError()
class AsyncFileStorage(AsyncBaseStorage):
    """
    A simple file storage.

    :param serializer: Serializer capable of serializing and de-serializing http responses, defaults to None
    :type serializer: tp.Optional[BaseSerializer], optional
    :param base_path: A storage base path where the responses should be saved, defaults to None
    :type base_path: tp.Optional[Path], optional
    :param ttl: Specifies the maximum number of seconds that the response can be cached, defaults to None
    :type ttl: tp.Optional[tp.Union[int, float]], optional
    :param check_ttl_every: How often in seconds to check staleness of **all** cache files.
        Makes sense only with set `ttl`, defaults to 60
    :type check_ttl_every: tp.Union[int, float]
    """

    def __init__(
        self,
        serializer: tp.Optional[BaseSerializer] = None,
        base_path: tp.Optional[Path] = None,
        ttl: tp.Optional[tp.Union[int, float]] = None,
        check_ttl_every: tp.Union[int, float] = 60,
    ) -> None:
        super().__init__(serializer, ttl)
        self._base_path = Path(base_path) if base_path is not None else Path(".cache/hishel")
        self._gitignore_file = self._base_path / ".gitignore"
        if not self._base_path.is_dir():
            self._base_path.mkdir(parents=True)
        # Keep the cache directory out of version control by default.
        if not self._gitignore_file.is_file():
            with open(self._gitignore_file, "w", encoding="utf-8") as f:
                f.write("# Automatically created by Hishel\n*")
        self._file_manager = AsyncFileManager(is_binary=self._serializer.is_binary)
        self._lock = AsyncLock()
        self._check_ttl_every = check_ttl_every
        # Monotonic timestamp of the last full-directory TTL sweep.
        self._last_cleaned = time.monotonic()

    async def store(self, key: str, response: Response, request: Request, metadata: Metadata | None = None) -> None:
        """
        Stores the response in the cache.

        :param key: Hashed value of concatenated HTTP method and URI
        :type key: str
        :param response: An HTTP response
        :type response: httpcore.Response
        :param request: An HTTP request
        :type request: httpcore.Request
        :param metadata: Additional information about the stored response
        :type metadata: Optional[Metadata]
        """
        metadata = metadata or Metadata(
            cache_key=key, created_at=datetime.datetime.now(datetime.timezone.utc), number_of_uses=0
        )
        response_path = self._base_path / key
        async with self._lock:
            await self._file_manager.write_to(
                str(response_path),
                self._serializer.dumps(response=response, request=request, metadata=metadata),
            )
        # Opportunistically evict expired entries after every write.
        await self._remove_expired_caches(response_path)

    async def remove(self, key: RemoveTypes) -> None:
        """
        Removes the response from the cache.

        :param key: Hashed value of concatenated HTTP method and URI or an HTTP response
        :type key: Union[str, Response]
        """
        if isinstance(key, Response):  # pragma: no cover
            # A cached Response carries its own key inside the cache_metadata extension.
            key = t.cast(str, key.extensions["cache_metadata"]["cache_key"])
        response_path = self._base_path / key
        async with self._lock:
            if response_path.exists():
                response_path.unlink(missing_ok=True)

    async def update_metadata(self, key: str, response: Response, request: Request, metadata: Metadata) -> None:
        """
        Updates the metadata of the stored response.

        :param key: Hashed value of concatenated HTTP method and URI
        :type key: str
        :param response: An HTTP response
        :type response: httpcore.Response
        :param request: An HTTP request
        :type request: httpcore.Request
        :param metadata: Additional information about the stored response
        :type metadata: Metadata
        """
        response_path = self._base_path / key
        async with self._lock:
            if response_path.exists():
                atime = response_path.stat().st_atime
                old_mtime = response_path.stat().st_mtime
                await self._file_manager.write_to(
                    str(response_path),
                    self._serializer.dumps(response=response, request=request, metadata=metadata),
                )
                # Restore the old atime and mtime (we use mtime to check the cache expiration time)
                os.utime(response_path, (atime, old_mtime))
                return
        # Entry vanished between retrieve and update: store it from scratch.
        return await self.store(key, response, request, metadata)  # pragma: no cover

    async def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
        """
        Retrieves the response from the cache using its key.

        :param key: Hashed value of concatenated HTTP method and URI
        :type key: str
        :return: An HTTP response and its HTTP request.
        :rtype: tp.Optional[StoredResponse]
        """
        response_path = self._base_path / key
        # Evict first so an expired entry is never served.
        await self._remove_expired_caches(response_path)
        async with self._lock:
            if response_path.exists():
                read_data = await self._file_manager.read_from(str(response_path))
                # An empty file means a write is in flight or failed; treat as a miss.
                if len(read_data) != 0:
                    return self._serializer.loads(read_data)
        return None

    async def aclose(self) -> None:  # pragma: no cover
        # Nothing to release: files are opened and closed per operation.
        return

    async def _remove_expired_caches(self, response_path: Path) -> None:
        # TTL disabled: entries never expire.
        if self._ttl is None:
            return
        if time.monotonic() - self._last_cleaned < self._check_ttl_every:
            # Cheap path between full sweeps: only check the one file involved
            # in the current operation. mtime serves as the freshness clock.
            if response_path.is_file():
                age = time.time() - response_path.stat().st_mtime
                if age > self._ttl:
                    response_path.unlink(missing_ok=True)
            return
        self._last_cleaned = time.monotonic()
        # Full sweep: walk the cache directory and unlink everything expired.
        # NOTE(review): the sweep also covers the .gitignore helper file, which
        # would be deleted once it ages past ttl — confirm that is intended.
        async with self._lock:
            with os.scandir(self._base_path) as entries:
                for entry in entries:
                    try:
                        if entry.is_file():
                            age = time.time() - entry.stat().st_mtime
                            if age > self._ttl:
                                os.unlink(entry.path)
                    except FileNotFoundError:  # pragma: no cover
                        pass
class AsyncSQLiteStorage(AsyncBaseStorage):
    """
    A simple sqlite3 storage.

    :param serializer: Serializer capable of serializing and de-serializing http responses, defaults to None
    :type serializer: tp.Optional[BaseSerializer], optional
    :param connection: A connection for sqlite, defaults to None
    :type connection: tp.Optional[anysqlite.Connection], optional
    :param ttl: Specifies the maximum number of seconds that the response can be cached, defaults to None
    :type ttl: tp.Optional[tp.Union[int, float]], optional
    """

    def __init__(
        self,
        serializer: tp.Optional[BaseSerializer] = None,
        connection: tp.Optional[anysqlite.Connection] = None,
        ttl: tp.Optional[tp.Union[int, float]] = None,
    ) -> None:
        # anysqlite is an optional dependency; fail early with install guidance.
        if anysqlite is None:  # pragma: no cover
            raise RuntimeError(
                f"The `{type(self).__name__}` was used, but the required packages were not found. "
                "Check that you have `Hishel` installed with the `sqlite` extension as shown.\n"
                "```pip install hishel[sqlite]```"
            )
        super().__init__(serializer, ttl)
        self._connection: tp.Optional[anysqlite.Connection] = connection or None
        # Guards the lazy, one-time connection/schema initialization in _setup().
        self._setup_lock = AsyncLock()
        self._setup_completed: bool = False
        # Serializes every read/write against the shared connection.
        self._lock = AsyncLock()

    async def _setup(self) -> None:
        # Lazily open the connection and create the schema exactly once.
        async with self._setup_lock:
            if not self._setup_completed:
                if not self._connection:  # pragma: no cover
                    self._connection = await anysqlite.connect(".hishel.sqlite", check_same_thread=False)
                await self._connection.execute(
                    "CREATE TABLE IF NOT EXISTS cache(key TEXT, data BLOB, date_created REAL)"
                )
                await self._connection.commit()
                self._setup_completed = True

    async def store(self, key: str, response: Response, request: Request, metadata: Metadata | None = None) -> None:
        """
        Stores the response in the cache.

        :param key: Hashed value of concatenated HTTP method and URI
        :type key: str
        :param response: An HTTP response
        :type response: httpcore.Response
        :param request: An HTTP request
        :type request: httpcore.Request
        :param metadata: Additional information about the stored response
        :type metadata: Optional[Metadata]
        """
        await self._setup()
        assert self._connection
        metadata = metadata or Metadata(
            cache_key=key, created_at=datetime.datetime.now(datetime.timezone.utc), number_of_uses=0
        )
        async with self._lock:
            # Delete-then-insert acts as an upsert for the key.
            await self._connection.execute("DELETE FROM cache WHERE key = ?", [key])
            serialized_response = self._serializer.dumps(response=response, request=request, metadata=metadata)
            await self._connection.execute(
                "INSERT INTO cache(key, data, date_created) VALUES(?, ?, ?)", [key, serialized_response, time.time()]
            )
            await self._connection.commit()
        await self._remove_expired_caches()

    async def remove(self, key: RemoveTypes) -> None:
        """
        Removes the response from the cache.

        :param key: Hashed value of concatenated HTTP method and URI or an HTTP response
        :type key: Union[str, Response]
        """
        await self._setup()
        assert self._connection
        if isinstance(key, Response):  # pragma: no cover
            # A cached Response carries its own key inside the cache_metadata extension.
            key = t.cast(str, key.extensions["cache_metadata"]["cache_key"])
        async with self._lock:
            await self._connection.execute("DELETE FROM cache WHERE key = ?", [key])
            await self._connection.commit()

    async def update_metadata(self, key: str, response: Response, request: Request, metadata: Metadata) -> None:
        """
        Updates the metadata of the stored response.

        :param key: Hashed value of concatenated HTTP method and URI
        :type key: str
        :param response: An HTTP response
        :type response: httpcore.Response
        :param request: An HTTP request
        :type request: httpcore.Request
        :param metadata: Additional information about the stored response
        :type metadata: Metadata
        """
        await self._setup()
        assert self._connection
        async with self._lock:
            cursor = await self._connection.execute("SELECT data FROM cache WHERE key = ?", [key])
            row = await cursor.fetchone()
            if row is not None:
                # Rewrite the blob in place; date_created is deliberately untouched
                # so the entry's expiration clock keeps running.
                serialized_response = self._serializer.dumps(response=response, request=request, metadata=metadata)
                await self._connection.execute("UPDATE cache SET data = ? WHERE key = ?", [serialized_response, key])
                await self._connection.commit()
                return
        # Entry vanished between retrieve and update: store it from scratch.
        return await self.store(key, response, request, metadata)  # pragma: no cover

    async def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
        """
        Retrieves the response from the cache using its key.

        :param key: Hashed value of concatenated HTTP method and URI
        :type key: str
        :return: An HTTP response and its HTTP request.
        :rtype: tp.Optional[StoredResponse]
        """
        await self._setup()
        assert self._connection
        # Evict first so an expired entry is never served.
        await self._remove_expired_caches()
        async with self._lock:
            cursor = await self._connection.execute("SELECT data FROM cache WHERE key = ?", [key])
            row = await cursor.fetchone()
            if row is None:
                return None
            cached_response = row[0]
        return self._serializer.loads(cached_response)

    async def aclose(self) -> None:  # pragma: no cover
        if self._connection is not None:
            await self._connection.close()

    async def _remove_expired_caches(self) -> None:
        # Callers invoke this only after _setup(), so the connection exists.
        assert self._connection
        if self._ttl is None:
            return
        async with self._lock:
            # Drop every row whose creation time plus ttl is in the past.
            await self._connection.execute("DELETE FROM cache WHERE date_created + ? < ?", [self._ttl, time.time()])
            await self._connection.commit()
class AsyncRedisStorage(AsyncBaseStorage):
    """
    A simple redis storage.

    :param serializer: Serializer capable of serializing and de-serializing http responses, defaults to None
    :type serializer: tp.Optional[BaseSerializer], optional
    :param client: A client for redis, defaults to None
    :type client: tp.Optional["redis.Redis"], optional
    :param ttl: Specifies the maximum number of seconds that the response can be cached, defaults to None
    :type ttl: tp.Optional[tp.Union[int, float]], optional
    """

    def __init__(
        self,
        serializer: tp.Optional[BaseSerializer] = None,
        client: tp.Optional[redis.Redis] = None,  # type: ignore
        ttl: tp.Optional[tp.Union[int, float]] = None,
    ) -> None:
        # redis is an optional dependency; fail early with install guidance.
        if redis is None:  # pragma: no cover
            raise RuntimeError(
                f"The `{type(self).__name__}` was used, but the required packages were not found. "
                "Check that you have `Hishel` installed with the `redis` extension as shown.\n"
                "```pip install hishel[redis]```"
            )
        super().__init__(serializer, ttl)
        if client is None:
            # Default client connects to localhost with redis-py defaults.
            self._client = redis.Redis()  # type: ignore
        else:  # pragma: no cover
            self._client = client

    async def store(self, key: str, response: Response, request: Request, metadata: Metadata | None = None) -> None:
        """
        Stores the response in the cache.

        :param key: Hashed value of concatenated HTTP method and URI
        :type key: str
        :param response: An HTTP response
        :type response: httpcore.Response
        :param request: An HTTP request
        :type request: httpcore.Request
        :param metadata: Additional information about the stored response
        :type metadata: Optional[Metadata]
        """
        metadata = metadata or Metadata(
            cache_key=key, created_at=datetime.datetime.now(datetime.timezone.utc), number_of_uses=0
        )
        # Expiration is delegated to Redis itself via SET's `px` (milliseconds) option.
        if self._ttl is not None:
            px = float_seconds_to_int_milliseconds(self._ttl)
        else:
            px = None
        await self._client.set(
            key, self._serializer.dumps(response=response, request=request, metadata=metadata), px=px
        )

    async def remove(self, key: RemoveTypes) -> None:
        """
        Removes the response from the cache.

        :param key: Hashed value of concatenated HTTP method and URI or an HTTP response
        :type key: Union[str, Response]
        """
        if isinstance(key, Response):  # pragma: no cover
            # A cached Response carries its own key inside the cache_metadata extension.
            key = t.cast(str, key.extensions["cache_metadata"]["cache_key"])
        await self._client.delete(key)

    async def update_metadata(self, key: str, response: Response, request: Request, metadata: Metadata) -> None:
        """
        Updates the metadata of the stored response.

        :param key: Hashed value of concatenated HTTP method and URI
        :type key: str
        :param response: An HTTP response
        :type response: httpcore.Response
        :param request: An HTTP request
        :type request: httpcore.Request
        :param metadata: Additional information about the stored response
        :type metadata: Metadata
        """
        # PTTL returns the remaining lifetime so the rewrite can preserve it.
        ttl_in_milliseconds = await self._client.pttl(key)
        # -2: if the key does not exist in Redis
        # -1: if the key exists in Redis but has no expiration
        if ttl_in_milliseconds == -2 or ttl_in_milliseconds == -1:  # pragma: no cover
            await self.store(key, response, request, metadata)
        else:
            # Re-set the value with the remaining TTL so expiration is unchanged.
            await self._client.set(
                key,
                self._serializer.dumps(response=response, request=request, metadata=metadata),
                px=ttl_in_milliseconds,
            )

    async def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
        """
        Retrieves the response from the cache using its key.

        :param key: Hashed value of concatenated HTTP method and URI
        :type key: str
        :return: An HTTP response and its HTTP request.
        :rtype: tp.Optional[StoredResponse]
        """
        cached_response = await self._client.get(key)
        if cached_response is None:
            return None
        return self._serializer.loads(cached_response)

    async def aclose(self) -> None:  # pragma: no cover
        await self._client.aclose()
class AsyncInMemoryStorage(AsyncBaseStorage):
    """
    A simple in-memory storage.

    :param serializer: Serializer capable of serializing and de-serializing http responses, defaults to None
    :type serializer: tp.Optional[BaseSerializer], optional
    :param ttl: Specifies the maximum number of seconds that the response can be cached, defaults to None
    :type ttl: tp.Optional[tp.Union[int, float]], optional
    :param capacity: The maximum number of responses that can be cached, defaults to 128
    :type capacity: int, optional
    """

    def __init__(
        self,
        serializer: tp.Optional[BaseSerializer] = None,
        ttl: tp.Optional[tp.Union[int, float]] = None,
        capacity: int = 128,
    ) -> None:
        super().__init__(serializer, ttl)
        # Responses are kept as live objects, so serialization never happens here.
        if serializer is not None:  # pragma: no cover
            warnings.warn("The serializer is not used in the in-memory storage.", RuntimeWarning)
        # Imported locally, presumably to avoid a circular import — confirm.
        from hishel import LFUCache
        # Each entry is ((response, request, metadata), monotonic_creation_time).
        self._cache: LFUCache[str, tp.Tuple[StoredResponse, float]] = LFUCache(capacity=capacity)
        self._lock = AsyncLock()

    async def store(self, key: str, response: Response, request: Request, metadata: Metadata | None = None) -> None:
        """
        Stores the response in the cache.

        :param key: Hashed value of concatenated HTTP method and URI
        :type key: str
        :param response: An HTTP response
        :type response: httpcore.Response
        :param request: An HTTP request
        :type request: httpcore.Request
        :param metadata: Additional information about the stored response
        :type metadata: Optional[Metadata]
        """
        metadata = metadata or Metadata(
            cache_key=key, created_at=datetime.datetime.now(datetime.timezone.utc), number_of_uses=0
        )
        async with self._lock:
            # Clone and deep-copy so later mutation of the live request/response
            # objects cannot corrupt what is stored in the cache.
            response_clone = clone_model(response)
            request_clone = clone_model(request)
            stored_response: StoredResponse = (deepcopy(response_clone), deepcopy(request_clone), metadata)
            self._cache.put(key, (stored_response, time.monotonic()))
        await self._remove_expired_caches()

    async def remove(self, key: RemoveTypes) -> None:
        """
        Removes the response from the cache.

        :param key: Hashed value of concatenated HTTP method and URI or an HTTP response
        :type key: Union[str, Response]
        """
        if isinstance(key, Response):  # pragma: no cover
            # A cached Response carries its own key inside the cache_metadata extension.
            key = t.cast(str, key.extensions["cache_metadata"]["cache_key"])
        async with self._lock:
            self._cache.remove_key(key)

    async def update_metadata(self, key: str, response: Response, request: Request, metadata: Metadata) -> None:
        """
        Updates the metadata of the stored response.

        :param key: Hashed value of concatenated HTTP method and URI
        :type key: str
        :param response: An HTTP response
        :type response: httpcore.Response
        :param request: An HTTP request
        :type request: httpcore.Request
        :param metadata: Additional information about the stored response
        :type metadata: Metadata
        """
        async with self._lock:
            try:
                # Replace only the metadata slot; keep the original creation time
                # so the entry's TTL clock is unaffected.
                stored_response, created_at = self._cache.get(key)
                stored_response = (stored_response[0], stored_response[1], metadata)
                self._cache.put(key, (stored_response, created_at))
                return
            except KeyError:  # pragma: no cover
                pass
        # Entry vanished between retrieve and update: store it from scratch.
        await self.store(key, response, request, metadata)  # pragma: no cover

    async def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
        """
        Retrieves the response from the cache using its key.

        :param key: Hashed value of concatenated HTTP method and URI
        :type key: str
        :return: An HTTP response and its HTTP request.
        :rtype: tp.Optional[StoredResponse]
        """
        # Evict first so an expired entry is never served.
        await self._remove_expired_caches()
        async with self._lock:
            try:
                stored_response, _ = self._cache.get(key)
            except KeyError:
                return None
            return stored_response

    async def aclose(self) -> None:  # pragma: no cover
        # Nothing to release: everything lives in process memory.
        return

    async def _remove_expired_caches(self) -> None:
        if self._ttl is None:
            return
        async with self._lock:
            # Collect first, then remove, to avoid mutating the cache while iterating it.
            keys_to_remove = set()
            for key in self._cache:
                created_at = self._cache.get(key)[1]
                if time.monotonic() - created_at > self._ttl:
                    keys_to_remove.add(key)
            for key in keys_to_remove:
                self._cache.remove_key(key)
class AsyncS3Storage(AsyncBaseStorage): # pragma: no cover
"""
AWS S3 storage.
:param bucket_name: The name of the bucket to store the responses in
:type bucket_name: str
:param serializer: Serializer capable of serializing and de-serializing http responses, defaults to None
:type serializer: tp.Optional[BaseSerializer], optional
:param ttl: Specifies the maximum number of seconds that the response can be cached, defaults to None
:type ttl: tp.Optional[tp.Union[int, float]], optional
:param check_ttl_every: How often in seconds to check staleness of **all** cache files.
Makes sense only with set `ttl`, defaults to 60
:type check_ttl_every: tp.Union[int, float]
:param client: A client for S3, defaults to None
:type client: tp.Optional[tp.Any], optional
:param path_prefix: A path prefix to use for S3 object keys, defaults to "hishel-"
:type path_prefix: str, optional
"""
def __init__(
self,
bucket_name: str,
serializer: tp.Optional[BaseSerializer] = None,
ttl: tp.Optional[tp.Union[int, float]] = None,
check_ttl_every: tp.Union[int, float] = 60,
client: tp.Optional[tp.Any] = None,
path_prefix: str = "hishel-",
) -> None:
super().__init__(serializer, ttl)
if boto3 is None: # pragma: no cover
raise RuntimeError(
f"The `{type(self).__name__}` was used, but the required packages were not found. "
"Check that you have `Hishel` installed with the `s3` extension as shown.\n"
"```pip install hishel[s3]```"
)
self._bucket_name = bucket_name
client = client or boto3.client("s3")
self._s3_manager = AsyncS3Manager(
client=client,
bucket_name=bucket_name,
is_binary=self._serializer.is_binary,
check_ttl_every=check_ttl_every,
path_prefix=path_prefix,
)
self._lock = AsyncLock()
async def store(self, key: str, response: Response, request: Request, metadata: Metadata | None = None) -> None:
"""
Stores the response in the cache.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:param response: An HTTP response
:type response: httpcore.Response
:param request: An HTTP request
:type request: httpcore.Request
:param metadata: Additioal information about the stored response
:type metadata: Optional[Metadata]`
"""
metadata = metadata or Metadata(
cache_key=key, created_at=datetime.datetime.now(datetime.timezone.utc), number_of_uses=0
)
async with self._lock:
serialized = self._serializer.dumps(response=response, request=request, metadata=metadata)
await self._s3_manager.write_to(path=key, data=serialized)
await self._remove_expired_caches(key)
async def remove(self, key: RemoveTypes) -> None:
"""
Removes the response from the cache.
:param key: Hashed value of concatenated HTTP method and URI or an HTTP response
:type key: Union[str, Response]
"""
if isinstance(key, Response): # pragma: no cover
key = t.cast(str, key.extensions["cache_metadata"]["cache_key"])
async with self._lock:
await self._s3_manager.remove_entry(key)
async def update_metadata(self, key: str, response: Response, request: Request, metadata: Metadata) -> None:
"""
Updates the metadata of the stored response.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:param response: An HTTP response
:type response: httpcore.Response
:param request: An HTTP request
:type request: httpcore.Request
:param metadata: Additional information about the stored response
:type metadata: Metadata
"""
async with self._lock:
serialized = self._serializer.dumps(response=response, request=request, metadata=metadata)
await self._s3_manager.write_to(path=key, data=serialized, only_metadata=True)
async def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
    """
    Retrieves the response from the cache using its key.

    :param key: Hashed value of concatenated HTTP method and URI
    :type key: str
    :return: An HTTP response and its HTTP request, or None when absent/unreadable.
    :rtype: tp.Optional[StoredResponse]
    """
    # Evict the entry first if its TTL has elapsed, so a stale hit is impossible.
    await self._remove_expired_caches(key)
    async with self._lock:
        try:
            raw = await self._s3_manager.read_from(path=key)
            return self._serializer.loads(raw)
        except Exception:
            # Missing or undeserializable entry — treat both as a cache miss.
            return None
async def aclose(self) -> None:  # pragma: no cover
    """Close the storage; this backend holds no resources to release."""
    return None
async def _remove_expired_caches(self, key: str) -> None:
    # With no TTL configured, entries never expire — nothing to sweep.
    if self._ttl is None:
        return
    # Pure unit conversion; safe to do before taking the lock.
    ttl_ms = float_seconds_to_int_milliseconds(self._ttl)
    async with self._lock:
        await self._s3_manager.remove_expired(ttl=ttl_ms, key=key)

View File

@@ -0,0 +1,277 @@
from __future__ import annotations
import types
import typing as tp
import httpcore
import httpx
from httpx import AsyncByteStream, Request, Response
from httpx._exceptions import ConnectError
from hishel._utils import extract_header_values_decoded, normalized_url
from .._controller import Controller, allowed_stale
from .._headers import parse_cache_control
from .._serializers import JSONSerializer, Metadata
from ._storages import AsyncBaseStorage, AsyncFileStorage
if tp.TYPE_CHECKING: # pragma: no cover
from typing_extensions import Self
__all__ = ("AsyncCacheTransport",)
async def fake_stream(content: bytes) -> tp.AsyncIterable[bytes]:
    """Expose an in-memory byte string as a single-chunk async byte stream."""
    yield content
def generate_504() -> Response:
    """Build a bare 504 (Gateway Timeout) response, returned for
    only-if-cached requests that cannot be satisfied from the cache."""
    # status_code is the first positional parameter of httpx.Response.
    return Response(504)
class AsyncCacheStream(AsyncByteStream):
    """Adapter that exposes an httpcore async byte iterable as an httpx AsyncByteStream."""

    def __init__(self, httpcore_stream: tp.AsyncIterable[bytes]):
        self._httpcore_stream = httpcore_stream

    async def __aiter__(self) -> tp.AsyncIterator[bytes]:
        # Forward every chunk of the wrapped stream unchanged.
        async for chunk in self._httpcore_stream:
            yield chunk

    async def aclose(self) -> None:
        # Close the wrapped stream only if it actually supports closing.
        underlying_aclose = getattr(self._httpcore_stream, "aclose", None)
        if underlying_aclose is not None:
            await underlying_aclose()
class AsyncCacheTransport(httpx.AsyncBaseTransport):
    """
    An HTTPX Transport that supports HTTP caching.

    :param transport: `Transport` that our class wraps in order to add an HTTP Cache layer on top of
    :type transport: httpx.AsyncBaseTransport
    :param storage: Storage that handles how the responses should be saved, defaults to None
    :type storage: tp.Optional[AsyncBaseStorage], optional
    :param controller: Controller that manages the cache behavior at the specification level, defaults to None
    :type controller: tp.Optional[Controller], optional
    """

    def __init__(
        self,
        transport: httpx.AsyncBaseTransport,
        storage: tp.Optional[AsyncBaseStorage] = None,
        controller: tp.Optional[Controller] = None,
    ) -> None:
        self._transport = transport

        # Default to file-based storage with JSON serialization when none is given.
        self._storage = storage if storage is not None else AsyncFileStorage(serializer=JSONSerializer())

        if not isinstance(self._storage, AsyncBaseStorage):  # pragma: no cover
            raise TypeError(f"Expected subclass of `AsyncBaseStorage` but got `{storage.__class__.__name__}`")

        self._controller = controller if controller is not None else Controller()

    async def handle_async_request(self, request: Request) -> Response:
        """
        Handles HTTP requests while also implementing HTTP caching.

        :param request: An HTTP request
        :type request: httpx.Request
        :return: An HTTP response
        :rtype: httpx.Response
        """
        if request.extensions.get("cache_disabled", False):
            # Bypass the cache for this request by prepending no-store/no-cache
            # directives while keeping any cache-control values the caller set.
            request.headers.update(
                [
                    ("Cache-Control", "no-store"),
                    ("Cache-Control", "no-cache"),
                    *[("cache-control", value) for value in request.headers.get_list("cache-control")],
                ]
            )

        if request.method not in ["GET", "HEAD"]:
            # If the HTTP method is, for example, POST,
            # we must also use the request data to generate the hash.
            body_for_key = await request.aread()
            # Re-wrap the consumed body so the request can still be forwarded downstream.
            request.stream = AsyncCacheStream(fake_stream(body_for_key))
        else:
            body_for_key = b""

        # Construct the HTTPCore request because Controllers and Storages work with HTTPCore requests.
        httpcore_request = httpcore.Request(
            method=request.method,
            url=httpcore.URL(
                scheme=request.url.raw_scheme,
                host=request.url.raw_host,
                port=request.url.port,
                target=request.url.raw_path,
            ),
            headers=request.headers.raw,
            content=request.stream,
            extensions=request.extensions,
        )
        key = self._controller._key_generator(httpcore_request, body_for_key)
        stored_data = await self._storage.retrieve(key)

        request_cache_control = parse_cache_control(
            extract_header_values_decoded(request.headers.raw, b"Cache-Control")
        )

        # only-if-cached with no stored entry cannot be satisfied -> 504.
        if request_cache_control.only_if_cached and not stored_data:
            return generate_504()

        if stored_data:
            # Try using the stored response if it was discovered.
            stored_response, stored_request, metadata = stored_data

            # Immediately read the stored response to avoid issues when trying to access the response body.
            stored_response.read()

            # The controller returns either a ready-to-use Response or a
            # conditional Request that must be sent for revalidation.
            res = self._controller.construct_response_from_cache(
                request=httpcore_request,
                response=stored_response,
                original_request=stored_request,
            )

            if isinstance(res, httpcore.Response):
                # Simply use the response if the controller determines it is ready for use.
                return await self._create_hishel_response(
                    key=key,
                    response=res,
                    request=httpcore_request,
                    cached=True,
                    revalidated=False,
                    metadata=metadata,
                )

            if request_cache_control.only_if_cached:
                # Revalidation would require a network round trip, which
                # only-if-cached forbids -> 504.
                return generate_504()

            if isinstance(res, httpcore.Request):
                # Controller has determined that the response needs to be re-validated.
                assert isinstance(res.stream, tp.AsyncIterable)
                revalidation_request = Request(
                    method=res.method.decode(),
                    url=normalized_url(res.url),
                    headers=res.headers,
                    stream=AsyncCacheStream(res.stream),
                    extensions=res.extensions,
                )
                try:
                    revalidation_response = await self._transport.handle_async_request(revalidation_request)
                except ConnectError:
                    # If there is a connection error, we can use the stale response if allowed.
                    if self._controller._allow_stale and allowed_stale(response=stored_response):
                        return await self._create_hishel_response(
                            key=key,
                            response=stored_response,
                            request=httpcore_request,
                            cached=True,
                            revalidated=False,
                            metadata=metadata,
                        )
                    raise  # pragma: no cover
                assert isinstance(revalidation_response.stream, tp.AsyncIterable)
                httpcore_revalidation_response = httpcore.Response(
                    status=revalidation_response.status_code,
                    headers=revalidation_response.headers.raw,
                    content=AsyncCacheStream(revalidation_response.stream),
                    extensions=revalidation_response.extensions,
                )

                # Merge headers with the stale response.
                final_httpcore_response = self._controller.handle_validation_response(
                    old_response=stored_response,
                    new_response=httpcore_revalidation_response,
                )

                # Buffer the merged response before closing the network response.
                await final_httpcore_response.aread()
                await revalidation_response.aclose()

                assert isinstance(final_httpcore_response.stream, tp.AsyncIterable)

                # RFC 9111: 4.3.3. Handling a Validation Response
                # A 304 (Not Modified) response status code indicates that the stored response can be updated and
                # reused. A full response (i.e., one containing content) indicates that none of the stored responses
                # nominated in the conditional request are suitable. Instead, the cache MUST use the full response to
                # satisfy the request. The cache MAY store such a full response, subject to its constraints.
                if revalidation_response.status_code != 304 and self._controller.is_cachable(
                    request=httpcore_request, response=final_httpcore_response
                ):
                    await self._storage.store(key, response=final_httpcore_response, request=httpcore_request)

                return await self._create_hishel_response(
                    key=key,
                    response=final_httpcore_response,
                    request=httpcore_request,
                    cached=revalidation_response.status_code == 304,
                    revalidated=True,
                    metadata=metadata,
                )

        # No usable stored entry: forward the request as-is.
        regular_response = await self._transport.handle_async_request(request)
        assert isinstance(regular_response.stream, tp.AsyncIterable)
        httpcore_regular_response = httpcore.Response(
            status=regular_response.status_code,
            headers=regular_response.headers.raw,
            content=AsyncCacheStream(regular_response.stream),
            extensions=regular_response.extensions,
        )
        await httpcore_regular_response.aread()
        await httpcore_regular_response.aclose()

        # Store the fresh response when the controller deems it cacheable.
        if self._controller.is_cachable(request=httpcore_request, response=httpcore_regular_response):
            await self._storage.store(
                key,
                response=httpcore_regular_response,
                request=httpcore_request,
            )

        return await self._create_hishel_response(
            key=key,
            response=httpcore_regular_response,
            request=httpcore_request,
            cached=False,
            revalidated=False,
        )

    async def _create_hishel_response(
        self,
        key: str,
        response: httpcore.Response,
        request: httpcore.Request,
        cached: bool,
        revalidated: bool,
        metadata: Metadata | None = None,
    ) -> Response:
        """Convert an httpcore response into an httpx Response, tagging the
        `from_cache`/`revalidated`/`cache_metadata` extensions used by callers."""
        if cached:
            assert metadata
            # Every cache hit bumps the persisted usage counter.
            metadata["number_of_uses"] += 1
            await self._storage.update_metadata(key=key, request=request, response=response, metadata=metadata)
            response.extensions["from_cache"] = True  # type: ignore[index]
            response.extensions["cache_metadata"] = metadata  # type: ignore[index]
        else:
            response.extensions["from_cache"] = False  # type: ignore[index]
        response.extensions["revalidated"] = revalidated  # type: ignore[index]
        return Response(
            status_code=response.status,
            headers=response.headers,
            stream=AsyncCacheStream(fake_stream(response.content)),
            extensions=response.extensions,
        )

    async def aclose(self) -> None:
        # Close the cache storage first, then the wrapped transport.
        await self._storage.aclose()
        await self._transport.aclose()

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: tp.Optional[tp.Type[BaseException]] = None,
        exc_value: tp.Optional[BaseException] = None,
        traceback: tp.Optional[types.TracebackType] = None,
    ) -> None:
        await self.aclose()