Initial commit

kdusek
2025-12-09 12:13:01 +01:00
commit 8e654ed209
13332 changed files with 2695056 additions and 0 deletions

View File

@@ -0,0 +1,17 @@
import httpx
from ._async import *
from ._controller import *
from ._exceptions import *
from ._headers import *
from ._serializers import *
from ._sync import *
from ._lfu_cache import *
def install_cache() -> None: # pragma: no cover
httpx.AsyncClient = AsyncCacheClient # type: ignore
httpx.Client = CacheClient # type: ignore
__version__ = "0.1.3"
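
Usage sketch for the package entry point above (hedged; it assumes the re-exported `CacheClient` and the `from_cache` response extension set by the cache transport):

# A minimal sketch, not part of this commit.
import hishel
import httpx

hishel.install_cache()

with httpx.Client() as client:  # actually a hishel CacheClient after install_cache()
    response = client.get("https://example.com")
    print(response.extensions["from_cache"])  # False on a first, uncached request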

View File

@@ -0,0 +1,5 @@
from ._client import * # noqa: F403
from ._mock import * # noqa: F403
from ._pool import * # noqa: F403
from ._storages import * # noqa: F403
from ._transports import * # noqa: F403

View File

@@ -0,0 +1,30 @@
import typing as tp
import httpx
from hishel._async._transports import AsyncCacheTransport
__all__ = ("AsyncCacheClient",)
class AsyncCacheClient(httpx.AsyncClient):
def __init__(self, *args: tp.Any, **kwargs: tp.Any):
        self._storage = kwargs.pop("storage", None)
        self._controller = kwargs.pop("controller", None)
super().__init__(*args, **kwargs)
def _init_transport(self, *args, **kwargs) -> AsyncCacheTransport: # type: ignore
_transport = super()._init_transport(*args, **kwargs)
return AsyncCacheTransport(
transport=_transport,
storage=self._storage,
controller=self._controller,
)
def _init_proxy_transport(self, *args, **kwargs) -> AsyncCacheTransport: # type: ignore
_transport = super()._init_proxy_transport(*args, **kwargs) # pragma: no cover
return AsyncCacheTransport( # pragma: no cover
transport=_transport,
storage=self._storage,
controller=self._controller,
)
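
A hedged construction sketch for `AsyncCacheClient`: `storage` and `controller` are popped before the remaining kwargs reach `httpx.AsyncClient`, then forwarded to the `AsyncCacheTransport` built in `_init_transport`.

# A minimal sketch, assuming the package re-exports these names.
import asyncio

import hishel

async def main() -> None:
    async with hishel.AsyncCacheClient(
        storage=hishel.AsyncInMemoryStorage(),                # popped before httpx sees kwargs
        controller=hishel.Controller(allow_heuristics=True),
    ) as client:
        first = await client.get("https://example.com")
        second = await client.get("https://example.com")
        print(first.extensions["from_cache"], second.extensions["from_cache"])

asyncio.run(main())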

View File

@@ -0,0 +1,43 @@
import typing as tp
from types import TracebackType
import httpcore
import httpx
from httpcore._async.interfaces import AsyncRequestInterface
if tp.TYPE_CHECKING: # pragma: no cover
from typing_extensions import Self
__all__ = ("MockAsyncConnectionPool", "MockAsyncTransport")
class MockAsyncConnectionPool(AsyncRequestInterface):
async def handle_async_request(self, request: httpcore.Request) -> httpcore.Response:
assert isinstance(request.stream, tp.AsyncIterable)
data = b"".join([chunk async for chunk in request.stream]) # noqa: F841
return self.mocked_responses.pop(0)
def add_responses(self, responses: tp.List[httpcore.Response]) -> None:
if not hasattr(self, "mocked_responses"):
self.mocked_responses = []
self.mocked_responses.extend(responses)
async def __aenter__(self) -> "Self":
return self
async def __aexit__(
self,
exc_type: tp.Optional[tp.Type[BaseException]] = None,
exc_value: tp.Optional[BaseException] = None,
traceback: tp.Optional[TracebackType] = None,
) -> None: ...
class MockAsyncTransport(httpx.AsyncBaseTransport):
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
return self.mocked_responses.pop(0)
def add_responses(self, responses: tp.List[httpx.Response]) -> None:
if not hasattr(self, "mocked_responses"):
self.mocked_responses = []
self.mocked_responses.extend(responses)
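
A hedged test sketch for the mocks above: queued responses are popped in FIFO order, one per handled request.

# A minimal sketch; MockAsyncTransport is defined in this file.
import asyncio

import httpx

async def demo() -> None:
    transport = MockAsyncTransport()
    transport.add_responses([httpx.Response(200, text="hello")])
    response = await transport.handle_async_request(httpx.Request("GET", "https://example.com"))
    print(response.status_code, await response.aread())

asyncio.run(demo())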

View File

@@ -0,0 +1,201 @@
from __future__ import annotations
import types
import typing as tp
from httpcore._async.interfaces import AsyncRequestInterface
from httpcore._exceptions import ConnectError
from httpcore._models import Request, Response
from .._controller import Controller, allowed_stale
from .._headers import parse_cache_control
from .._serializers import JSONSerializer, Metadata
from .._utils import extract_header_values_decoded
from ._storages import AsyncBaseStorage, AsyncFileStorage
T = tp.TypeVar("T")
__all__ = ("AsyncCacheConnectionPool",)
async def fake_stream(content: bytes) -> tp.AsyncIterable[bytes]:
yield content
def generate_504() -> Response:
return Response(status=504)
class AsyncCacheConnectionPool(AsyncRequestInterface):
"""An HTTP Core Connection Pool that supports HTTP caching.
    :param pool: `Connection Pool` that our class wraps in order to add an HTTP cache layer on top
    :type pool: AsyncRequestInterface
    :param storage: Storage that handles how the responses should be saved, defaults to None
:type storage: tp.Optional[AsyncBaseStorage], optional
:param controller: Controller that manages the cache behavior at the specification level, defaults to None
:type controller: tp.Optional[Controller], optional
"""
def __init__(
self,
pool: AsyncRequestInterface,
storage: tp.Optional[AsyncBaseStorage] = None,
controller: tp.Optional[Controller] = None,
) -> None:
self._pool = pool
self._storage = storage if storage is not None else AsyncFileStorage(serializer=JSONSerializer())
if not isinstance(self._storage, AsyncBaseStorage): # pragma: no cover
raise TypeError(f"Expected subclass of `AsyncBaseStorage` but got `{storage.__class__.__name__}`")
self._controller = controller if controller is not None else Controller()
async def handle_async_request(self, request: Request) -> Response:
"""
Handles HTTP requests while also implementing HTTP caching.
:param request: An HTTP request
:type request: httpcore.Request
:return: An HTTP response
:rtype: httpcore.Response
"""
if request.extensions.get("cache_disabled", False):
request.headers.extend([(b"cache-control", b"no-cache"), (b"cache-control", b"max-age=0")])
if request.method.upper() not in [b"GET", b"HEAD"]:
# If the HTTP method is, for example, POST,
# we must also use the request data to generate the hash.
assert isinstance(request.stream, tp.AsyncIterable)
body_for_key = b"".join([chunk async for chunk in request.stream])
request.stream = fake_stream(body_for_key)
else:
body_for_key = b""
key = self._controller._key_generator(request, body_for_key)
stored_data = await self._storage.retrieve(key)
request_cache_control = parse_cache_control(extract_header_values_decoded(request.headers, b"Cache-Control"))
if request_cache_control.only_if_cached and not stored_data:
return generate_504()
if stored_data:
# Try using the stored response if it was discovered.
stored_response, stored_request, metadata = stored_data
# Immediately read the stored response to avoid issues when trying to access the response body.
stored_response.read()
res = self._controller.construct_response_from_cache(
request=request,
response=stored_response,
original_request=stored_request,
)
if isinstance(res, Response):
# Simply use the response if the controller determines it is ready for use.
return await self._create_hishel_response(
key=key,
response=stored_response,
request=request,
metadata=metadata,
cached=True,
revalidated=False,
)
if request_cache_control.only_if_cached:
return generate_504()
if isinstance(res, Request):
# Controller has determined that the response needs to be re-validated.
try:
revalidation_response = await self._pool.handle_async_request(res)
except ConnectError:
# If there is a connection error, we can use the stale response if allowed.
if self._controller._allow_stale and allowed_stale(response=stored_response):
return await self._create_hishel_response(
key=key,
response=stored_response,
request=request,
metadata=metadata,
cached=True,
revalidated=False,
)
raise # pragma: no cover
# Merge headers with the stale response.
final_response = self._controller.handle_validation_response(
old_response=stored_response, new_response=revalidation_response
)
await final_response.aread()
# RFC 9111: 4.3.3. Handling a Validation Response
# A 304 (Not Modified) response status code indicates that the stored response can be updated and
# reused. A full response (i.e., one containing content) indicates that none of the stored responses
# nominated in the conditional request are suitable. Instead, the cache MUST use the full response to
# satisfy the request. The cache MAY store such a full response, subject to its constraints.
if revalidation_response.status != 304 and self._controller.is_cachable(
request=request, response=final_response
):
await self._storage.store(key, response=final_response, request=request)
return await self._create_hishel_response(
key=key,
response=final_response,
request=request,
cached=revalidation_response.status == 304,
revalidated=True,
metadata=metadata,
)
regular_response = await self._pool.handle_async_request(request)
await regular_response.aread()
if self._controller.is_cachable(request=request, response=regular_response):
await self._storage.store(key, response=regular_response, request=request)
return await self._create_hishel_response(
key=key, response=regular_response, request=request, cached=False, revalidated=False
)
async def _create_hishel_response(
self,
key: str,
response: Response,
request: Request,
cached: bool,
revalidated: bool,
metadata: Metadata | None = None,
) -> Response:
if cached:
assert metadata
metadata["number_of_uses"] += 1
await self._storage.update_metadata(key=key, request=request, response=response, metadata=metadata)
response.extensions["from_cache"] = True # type: ignore[index]
response.extensions["cache_metadata"] = metadata # type: ignore[index]
else:
response.extensions["from_cache"] = False # type: ignore[index]
response.extensions["revalidated"] = revalidated # type: ignore[index]
return response
async def aclose(self) -> None:
await self._storage.aclose()
if hasattr(self._pool, "aclose"): # pragma: no cover
await self._pool.aclose()
async def __aenter__(self: T) -> T:
return self
async def __aexit__(
self,
exc_type: tp.Optional[tp.Type[BaseException]] = None,
exc_value: tp.Optional[BaseException] = None,
traceback: tp.Optional[types.TracebackType] = None,
) -> None:
await self.aclose()
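
A hedged usage sketch for the pool: `AsyncRequestInterface` also provides the convenience `request()` method, which funnels into `handle_async_request` above.

# A minimal sketch, assuming network access to the example host.
import asyncio

import httpcore

import hishel

async def fetch() -> None:
    async with hishel.AsyncCacheConnectionPool(
        pool=httpcore.AsyncConnectionPool(),
        storage=hishel.AsyncInMemoryStorage(),
    ) as pool:
        response = await pool.request("GET", "https://example.com")
        print(response.extensions["from_cache"])

asyncio.run(fetch())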

View File

@@ -0,0 +1,772 @@
from __future__ import annotations
import datetime
import logging
import os
import time
import typing as t
import typing as tp
import warnings
from copy import deepcopy
from pathlib import Path
try:
import boto3
from .._s3 import AsyncS3Manager
except ImportError: # pragma: no cover
boto3 = None # type: ignore
try:
import anysqlite
except ImportError: # pragma: no cover
anysqlite = None # type: ignore
from httpcore import Request, Response
if t.TYPE_CHECKING: # pragma: no cover
from typing_extensions import TypeAlias
from hishel._serializers import BaseSerializer, clone_model
from .._files import AsyncFileManager
from .._serializers import JSONSerializer, Metadata
from .._synchronization import AsyncLock
from .._utils import float_seconds_to_int_milliseconds
logger = logging.getLogger("hishel.storages")
__all__ = (
"AsyncBaseStorage",
"AsyncFileStorage",
"AsyncRedisStorage",
"AsyncSQLiteStorage",
"AsyncInMemoryStorage",
"AsyncS3Storage",
)
StoredResponse: TypeAlias = tp.Tuple[Response, Request, Metadata]
RemoveTypes = tp.Union[str, Response]
try:
import redis.asyncio as redis
except ImportError: # pragma: no cover
redis = None # type: ignore
class AsyncBaseStorage:
def __init__(
self,
serializer: tp.Optional[BaseSerializer] = None,
ttl: tp.Optional[tp.Union[int, float]] = None,
) -> None:
self._serializer = serializer or JSONSerializer()
self._ttl = ttl
async def store(self, key: str, response: Response, request: Request, metadata: Metadata | None = None) -> None:
raise NotImplementedError()
async def remove(self, key: RemoveTypes) -> None:
raise NotImplementedError()
async def update_metadata(self, key: str, response: Response, request: Request, metadata: Metadata) -> None:
raise NotImplementedError()
async def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
raise NotImplementedError()
async def aclose(self) -> None:
raise NotImplementedError()
class AsyncFileStorage(AsyncBaseStorage):
"""
A simple file storage.
:param serializer: Serializer capable of serializing and de-serializing http responses, defaults to None
:type serializer: tp.Optional[BaseSerializer], optional
:param base_path: A storage base path where the responses should be saved, defaults to None
:type base_path: tp.Optional[Path], optional
:param ttl: Specifies the maximum number of seconds that the response can be cached, defaults to None
:type ttl: tp.Optional[tp.Union[int, float]], optional
    :param check_ttl_every: How often, in seconds, to check the staleness of **all** cache files.
        Only takes effect when `ttl` is set, defaults to 60
:type check_ttl_every: tp.Union[int, float]
"""
def __init__(
self,
serializer: tp.Optional[BaseSerializer] = None,
base_path: tp.Optional[Path] = None,
ttl: tp.Optional[tp.Union[int, float]] = None,
check_ttl_every: tp.Union[int, float] = 60,
) -> None:
super().__init__(serializer, ttl)
self._base_path = Path(base_path) if base_path is not None else Path(".cache/hishel")
self._gitignore_file = self._base_path / ".gitignore"
if not self._base_path.is_dir():
self._base_path.mkdir(parents=True)
if not self._gitignore_file.is_file():
with open(self._gitignore_file, "w", encoding="utf-8") as f:
f.write("# Automatically created by Hishel\n*")
self._file_manager = AsyncFileManager(is_binary=self._serializer.is_binary)
self._lock = AsyncLock()
self._check_ttl_every = check_ttl_every
self._last_cleaned = time.monotonic()
async def store(self, key: str, response: Response, request: Request, metadata: Metadata | None = None) -> None:
"""
Stores the response in the cache.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:param response: An HTTP response
:type response: httpcore.Response
:param request: An HTTP request
:type request: httpcore.Request
:param metadata: Additional information about the stored response
:type metadata: Optional[Metadata]
"""
metadata = metadata or Metadata(
cache_key=key, created_at=datetime.datetime.now(datetime.timezone.utc), number_of_uses=0
)
response_path = self._base_path / key
async with self._lock:
await self._file_manager.write_to(
str(response_path),
self._serializer.dumps(response=response, request=request, metadata=metadata),
)
await self._remove_expired_caches(response_path)
async def remove(self, key: RemoveTypes) -> None:
"""
Removes the response from the cache.
:param key: Hashed value of concatenated HTTP method and URI or an HTTP response
:type key: Union[str, Response]
"""
if isinstance(key, Response): # pragma: no cover
key = t.cast(str, key.extensions["cache_metadata"]["cache_key"])
response_path = self._base_path / key
async with self._lock:
if response_path.exists():
response_path.unlink(missing_ok=True)
async def update_metadata(self, key: str, response: Response, request: Request, metadata: Metadata) -> None:
"""
Updates the metadata of the stored response.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:param response: An HTTP response
:type response: httpcore.Response
:param request: An HTTP request
:type request: httpcore.Request
:param metadata: Additional information about the stored response
:type metadata: Metadata
"""
response_path = self._base_path / key
async with self._lock:
if response_path.exists():
atime = response_path.stat().st_atime
old_mtime = response_path.stat().st_mtime
await self._file_manager.write_to(
str(response_path),
self._serializer.dumps(response=response, request=request, metadata=metadata),
)
# Restore the old atime and mtime (we use mtime to check the cache expiration time)
os.utime(response_path, (atime, old_mtime))
return
return await self.store(key, response, request, metadata) # pragma: no cover
async def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
"""
        Retrieves the response from the cache using its key.
        :param key: Hashed value of concatenated HTTP method and URI
        :type key: str
        :return: An HTTP response and its HTTP request.
:rtype: tp.Optional[StoredResponse]
"""
response_path = self._base_path / key
await self._remove_expired_caches(response_path)
async with self._lock:
if response_path.exists():
read_data = await self._file_manager.read_from(str(response_path))
if len(read_data) != 0:
return self._serializer.loads(read_data)
return None
async def aclose(self) -> None: # pragma: no cover
return
async def _remove_expired_caches(self, response_path: Path) -> None:
if self._ttl is None:
return
if time.monotonic() - self._last_cleaned < self._check_ttl_every:
if response_path.is_file():
age = time.time() - response_path.stat().st_mtime
if age > self._ttl:
response_path.unlink(missing_ok=True)
return
self._last_cleaned = time.monotonic()
async with self._lock:
with os.scandir(self._base_path) as entries:
for entry in entries:
try:
if entry.is_file():
age = time.time() - entry.stat().st_mtime
if age > self._ttl:
os.unlink(entry.path)
except FileNotFoundError: # pragma: no cover
pass
class AsyncSQLiteStorage(AsyncBaseStorage):
"""
A simple sqlite3 storage.
:param serializer: Serializer capable of serializing and de-serializing http responses, defaults to None
:type serializer: tp.Optional[BaseSerializer], optional
:param connection: A connection for sqlite, defaults to None
:type connection: tp.Optional[anysqlite.Connection], optional
:param ttl: Specifies the maximum number of seconds that the response can be cached, defaults to None
:type ttl: tp.Optional[tp.Union[int, float]], optional
"""
def __init__(
self,
serializer: tp.Optional[BaseSerializer] = None,
connection: tp.Optional[anysqlite.Connection] = None,
ttl: tp.Optional[tp.Union[int, float]] = None,
) -> None:
if anysqlite is None: # pragma: no cover
raise RuntimeError(
f"The `{type(self).__name__}` was used, but the required packages were not found. "
"Check that you have `Hishel` installed with the `sqlite` extension as shown.\n"
"```pip install hishel[sqlite]```"
)
super().__init__(serializer, ttl)
        self._connection: tp.Optional[anysqlite.Connection] = connection
self._setup_lock = AsyncLock()
self._setup_completed: bool = False
self._lock = AsyncLock()
async def _setup(self) -> None:
async with self._setup_lock:
if not self._setup_completed:
if not self._connection: # pragma: no cover
self._connection = await anysqlite.connect(".hishel.sqlite", check_same_thread=False)
await self._connection.execute(
"CREATE TABLE IF NOT EXISTS cache(key TEXT, data BLOB, date_created REAL)"
)
await self._connection.commit()
self._setup_completed = True
async def store(self, key: str, response: Response, request: Request, metadata: Metadata | None = None) -> None:
"""
Stores the response in the cache.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:param response: An HTTP response
:type response: httpcore.Response
:param request: An HTTP request
:type request: httpcore.Request
        :param metadata: Additional information about the stored response
:type metadata: Optional[Metadata]
"""
await self._setup()
assert self._connection
metadata = metadata or Metadata(
cache_key=key, created_at=datetime.datetime.now(datetime.timezone.utc), number_of_uses=0
)
async with self._lock:
await self._connection.execute("DELETE FROM cache WHERE key = ?", [key])
serialized_response = self._serializer.dumps(response=response, request=request, metadata=metadata)
await self._connection.execute(
"INSERT INTO cache(key, data, date_created) VALUES(?, ?, ?)", [key, serialized_response, time.time()]
)
await self._connection.commit()
await self._remove_expired_caches()
async def remove(self, key: RemoveTypes) -> None:
"""
Removes the response from the cache.
:param key: Hashed value of concatenated HTTP method and URI or an HTTP response
:type key: Union[str, Response]
"""
await self._setup()
assert self._connection
if isinstance(key, Response): # pragma: no cover
key = t.cast(str, key.extensions["cache_metadata"]["cache_key"])
async with self._lock:
await self._connection.execute("DELETE FROM cache WHERE key = ?", [key])
await self._connection.commit()
async def update_metadata(self, key: str, response: Response, request: Request, metadata: Metadata) -> None:
"""
Updates the metadata of the stored response.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:param response: An HTTP response
:type response: httpcore.Response
:param request: An HTTP request
:type request: httpcore.Request
:param metadata: Additional information about the stored response
:type metadata: Metadata
"""
await self._setup()
assert self._connection
async with self._lock:
cursor = await self._connection.execute("SELECT data FROM cache WHERE key = ?", [key])
row = await cursor.fetchone()
if row is not None:
serialized_response = self._serializer.dumps(response=response, request=request, metadata=metadata)
await self._connection.execute("UPDATE cache SET data = ? WHERE key = ?", [serialized_response, key])
await self._connection.commit()
return
return await self.store(key, response, request, metadata) # pragma: no cover
async def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
"""
        Retrieves the response from the cache using its key.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:return: An HTTP response and its HTTP request.
:rtype: tp.Optional[StoredResponse]
"""
await self._setup()
assert self._connection
await self._remove_expired_caches()
async with self._lock:
cursor = await self._connection.execute("SELECT data FROM cache WHERE key = ?", [key])
row = await cursor.fetchone()
if row is None:
return None
cached_response = row[0]
return self._serializer.loads(cached_response)
async def aclose(self) -> None: # pragma: no cover
if self._connection is not None:
await self._connection.close()
async def _remove_expired_caches(self) -> None:
assert self._connection
if self._ttl is None:
return
async with self._lock:
await self._connection.execute("DELETE FROM cache WHERE date_created + ? < ?", [self._ttl, time.time()])
await self._connection.commit()
class AsyncRedisStorage(AsyncBaseStorage):
"""
A simple redis storage.
:param serializer: Serializer capable of serializing and de-serializing http responses, defaults to None
:type serializer: tp.Optional[BaseSerializer], optional
:param client: A client for redis, defaults to None
:type client: tp.Optional["redis.Redis"], optional
:param ttl: Specifies the maximum number of seconds that the response can be cached, defaults to None
:type ttl: tp.Optional[tp.Union[int, float]], optional
"""
def __init__(
self,
serializer: tp.Optional[BaseSerializer] = None,
client: tp.Optional[redis.Redis] = None, # type: ignore
ttl: tp.Optional[tp.Union[int, float]] = None,
) -> None:
if redis is None: # pragma: no cover
raise RuntimeError(
f"The `{type(self).__name__}` was used, but the required packages were not found. "
"Check that you have `Hishel` installed with the `redis` extension as shown.\n"
"```pip install hishel[redis]```"
)
super().__init__(serializer, ttl)
if client is None:
self._client = redis.Redis() # type: ignore
else: # pragma: no cover
self._client = client
async def store(self, key: str, response: Response, request: Request, metadata: Metadata | None = None) -> None:
"""
Stores the response in the cache.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:param response: An HTTP response
:type response: httpcore.Response
:param request: An HTTP request
:type request: httpcore.Request
        :param metadata: Additional information about the stored response
:type metadata: Optional[Metadata]
"""
metadata = metadata or Metadata(
cache_key=key, created_at=datetime.datetime.now(datetime.timezone.utc), number_of_uses=0
)
if self._ttl is not None:
px = float_seconds_to_int_milliseconds(self._ttl)
else:
px = None
await self._client.set(
key, self._serializer.dumps(response=response, request=request, metadata=metadata), px=px
)
async def remove(self, key: RemoveTypes) -> None:
"""
Removes the response from the cache.
:param key: Hashed value of concatenated HTTP method and URI or an HTTP response
:type key: Union[str, Response]
"""
if isinstance(key, Response): # pragma: no cover
key = t.cast(str, key.extensions["cache_metadata"]["cache_key"])
await self._client.delete(key)
async def update_metadata(self, key: str, response: Response, request: Request, metadata: Metadata) -> None:
"""
Updates the metadata of the stored response.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:param response: An HTTP response
:type response: httpcore.Response
:param request: An HTTP request
:type request: httpcore.Request
:param metadata: Additional information about the stored response
:type metadata: Metadata
"""
ttl_in_milliseconds = await self._client.pttl(key)
# -2: if the key does not exist in Redis
# -1: if the key exists in Redis but has no expiration
if ttl_in_milliseconds == -2 or ttl_in_milliseconds == -1: # pragma: no cover
await self.store(key, response, request, metadata)
else:
await self._client.set(
key,
self._serializer.dumps(response=response, request=request, metadata=metadata),
px=ttl_in_milliseconds,
)
async def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
"""
        Retrieves the response from the cache using its key.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:return: An HTTP response and its HTTP request.
:rtype: tp.Optional[StoredResponse]
"""
cached_response = await self._client.get(key)
if cached_response is None:
return None
return self._serializer.loads(cached_response)
async def aclose(self) -> None: # pragma: no cover
await self._client.aclose()
class AsyncInMemoryStorage(AsyncBaseStorage):
"""
A simple in-memory storage.
:param serializer: Serializer capable of serializing and de-serializing http responses, defaults to None
:type serializer: tp.Optional[BaseSerializer], optional
:param ttl: Specifies the maximum number of seconds that the response can be cached, defaults to None
:type ttl: tp.Optional[tp.Union[int, float]], optional
:param capacity: The maximum number of responses that can be cached, defaults to 128
:type capacity: int, optional
"""
def __init__(
self,
serializer: tp.Optional[BaseSerializer] = None,
ttl: tp.Optional[tp.Union[int, float]] = None,
capacity: int = 128,
) -> None:
super().__init__(serializer, ttl)
if serializer is not None: # pragma: no cover
warnings.warn("The serializer is not used in the in-memory storage.", RuntimeWarning)
from hishel import LFUCache
self._cache: LFUCache[str, tp.Tuple[StoredResponse, float]] = LFUCache(capacity=capacity)
self._lock = AsyncLock()
async def store(self, key: str, response: Response, request: Request, metadata: Metadata | None = None) -> None:
"""
Stores the response in the cache.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:param response: An HTTP response
:type response: httpcore.Response
:param request: An HTTP request
:type request: httpcore.Request
        :param metadata: Additional information about the stored response
:type metadata: Optional[Metadata]
"""
metadata = metadata or Metadata(
cache_key=key, created_at=datetime.datetime.now(datetime.timezone.utc), number_of_uses=0
)
async with self._lock:
response_clone = clone_model(response)
request_clone = clone_model(request)
stored_response: StoredResponse = (deepcopy(response_clone), deepcopy(request_clone), metadata)
self._cache.put(key, (stored_response, time.monotonic()))
await self._remove_expired_caches()
async def remove(self, key: RemoveTypes) -> None:
"""
Removes the response from the cache.
:param key: Hashed value of concatenated HTTP method and URI or an HTTP response
:type key: Union[str, Response]
"""
if isinstance(key, Response): # pragma: no cover
key = t.cast(str, key.extensions["cache_metadata"]["cache_key"])
async with self._lock:
self._cache.remove_key(key)
async def update_metadata(self, key: str, response: Response, request: Request, metadata: Metadata) -> None:
"""
Updates the metadata of the stored response.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:param response: An HTTP response
:type response: httpcore.Response
:param request: An HTTP request
:type request: httpcore.Request
:param metadata: Additional information about the stored response
:type metadata: Metadata
"""
async with self._lock:
try:
stored_response, created_at = self._cache.get(key)
stored_response = (stored_response[0], stored_response[1], metadata)
self._cache.put(key, (stored_response, created_at))
return
except KeyError: # pragma: no cover
pass
await self.store(key, response, request, metadata) # pragma: no cover
async def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
"""
        Retrieves the response from the cache using its key.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:return: An HTTP response and its HTTP request.
:rtype: tp.Optional[StoredResponse]
"""
await self._remove_expired_caches()
async with self._lock:
try:
stored_response, _ = self._cache.get(key)
except KeyError:
return None
return stored_response
async def aclose(self) -> None: # pragma: no cover
return
async def _remove_expired_caches(self) -> None:
if self._ttl is None:
return
async with self._lock:
keys_to_remove = set()
for key in self._cache:
created_at = self._cache.get(key)[1]
if time.monotonic() - created_at > self._ttl:
keys_to_remove.add(key)
for key in keys_to_remove:
self._cache.remove_key(key)
class AsyncS3Storage(AsyncBaseStorage): # pragma: no cover
"""
AWS S3 storage.
:param bucket_name: The name of the bucket to store the responses in
:type bucket_name: str
:param serializer: Serializer capable of serializing and de-serializing http responses, defaults to None
:type serializer: tp.Optional[BaseSerializer], optional
:param ttl: Specifies the maximum number of seconds that the response can be cached, defaults to None
:type ttl: tp.Optional[tp.Union[int, float]], optional
    :param check_ttl_every: How often, in seconds, to check the staleness of **all** cache files.
        Only takes effect when `ttl` is set, defaults to 60
:type check_ttl_every: tp.Union[int, float]
:param client: A client for S3, defaults to None
:type client: tp.Optional[tp.Any], optional
:param path_prefix: A path prefix to use for S3 object keys, defaults to "hishel-"
:type path_prefix: str, optional
"""
def __init__(
self,
bucket_name: str,
serializer: tp.Optional[BaseSerializer] = None,
ttl: tp.Optional[tp.Union[int, float]] = None,
check_ttl_every: tp.Union[int, float] = 60,
client: tp.Optional[tp.Any] = None,
path_prefix: str = "hishel-",
) -> None:
super().__init__(serializer, ttl)
if boto3 is None: # pragma: no cover
raise RuntimeError(
f"The `{type(self).__name__}` was used, but the required packages were not found. "
"Check that you have `Hishel` installed with the `s3` extension as shown.\n"
"```pip install hishel[s3]```"
)
self._bucket_name = bucket_name
client = client or boto3.client("s3")
self._s3_manager = AsyncS3Manager(
client=client,
bucket_name=bucket_name,
is_binary=self._serializer.is_binary,
check_ttl_every=check_ttl_every,
path_prefix=path_prefix,
)
self._lock = AsyncLock()
async def store(self, key: str, response: Response, request: Request, metadata: Metadata | None = None) -> None:
"""
Stores the response in the cache.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:param response: An HTTP response
:type response: httpcore.Response
:param request: An HTTP request
:type request: httpcore.Request
        :param metadata: Additional information about the stored response
        :type metadata: Optional[Metadata]
"""
metadata = metadata or Metadata(
cache_key=key, created_at=datetime.datetime.now(datetime.timezone.utc), number_of_uses=0
)
async with self._lock:
serialized = self._serializer.dumps(response=response, request=request, metadata=metadata)
await self._s3_manager.write_to(path=key, data=serialized)
await self._remove_expired_caches(key)
async def remove(self, key: RemoveTypes) -> None:
"""
Removes the response from the cache.
:param key: Hashed value of concatenated HTTP method and URI or an HTTP response
:type key: Union[str, Response]
"""
if isinstance(key, Response): # pragma: no cover
key = t.cast(str, key.extensions["cache_metadata"]["cache_key"])
async with self._lock:
await self._s3_manager.remove_entry(key)
async def update_metadata(self, key: str, response: Response, request: Request, metadata: Metadata) -> None:
"""
Updates the metadata of the stored response.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:param response: An HTTP response
:type response: httpcore.Response
:param request: An HTTP request
:type request: httpcore.Request
:param metadata: Additional information about the stored response
:type metadata: Metadata
"""
async with self._lock:
serialized = self._serializer.dumps(response=response, request=request, metadata=metadata)
await self._s3_manager.write_to(path=key, data=serialized, only_metadata=True)
async def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
"""
        Retrieves the response from the cache using its key.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:return: An HTTP response and its HTTP request.
:rtype: tp.Optional[StoredResponse]
"""
await self._remove_expired_caches(key)
async with self._lock:
try:
return self._serializer.loads(await self._s3_manager.read_from(path=key))
except Exception:
return None
async def aclose(self) -> None: # pragma: no cover
return
async def _remove_expired_caches(self, key: str) -> None:
if self._ttl is None:
return
async with self._lock:
converted_ttl = float_seconds_to_int_milliseconds(self._ttl)
await self._s3_manager.remove_expired(ttl=converted_ttl, key=key)
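
A hedged overview sketch of the backends defined in this file; `ttl` is the number of seconds until an entry expires, and the Redis/SQLite/S3 classes require their optional extras.

# A minimal sketch of storage construction.
import hishel

file_storage = hishel.AsyncFileStorage(ttl=3600, check_ttl_every=60)
memory_storage = hishel.AsyncInMemoryStorage(ttl=60, capacity=64)
# Optional backends (require extras, e.g. `pip install hishel[redis]`):
# redis_storage = hishel.AsyncRedisStorage(ttl=3600)
# sqlite_storage = hishel.AsyncSQLiteStorage(ttl=3600)
# s3_storage = hishel.AsyncS3Storage(bucket_name="my-cache-bucket", ttl=3600)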

View File

@@ -0,0 +1,277 @@
from __future__ import annotations
import types
import typing as tp
import httpcore
import httpx
from httpx import AsyncByteStream, Request, Response
from httpx._exceptions import ConnectError
from hishel._utils import extract_header_values_decoded, normalized_url
from .._controller import Controller, allowed_stale
from .._headers import parse_cache_control
from .._serializers import JSONSerializer, Metadata
from ._storages import AsyncBaseStorage, AsyncFileStorage
if tp.TYPE_CHECKING: # pragma: no cover
from typing_extensions import Self
__all__ = ("AsyncCacheTransport",)
async def fake_stream(content: bytes) -> tp.AsyncIterable[bytes]:
yield content
def generate_504() -> Response:
return Response(status_code=504)
class AsyncCacheStream(AsyncByteStream):
def __init__(self, httpcore_stream: tp.AsyncIterable[bytes]):
self._httpcore_stream = httpcore_stream
async def __aiter__(self) -> tp.AsyncIterator[bytes]:
async for part in self._httpcore_stream:
yield part
async def aclose(self) -> None:
if hasattr(self._httpcore_stream, "aclose"):
await self._httpcore_stream.aclose()
class AsyncCacheTransport(httpx.AsyncBaseTransport):
"""
An HTTPX Transport that supports HTTP caching.
    :param transport: `Transport` that our class wraps in order to add an HTTP cache layer on top
    :type transport: httpx.AsyncBaseTransport
    :param storage: Storage that handles how the responses should be saved, defaults to None
:type storage: tp.Optional[AsyncBaseStorage], optional
:param controller: Controller that manages the cache behavior at the specification level, defaults to None
:type controller: tp.Optional[Controller], optional
"""
def __init__(
self,
transport: httpx.AsyncBaseTransport,
storage: tp.Optional[AsyncBaseStorage] = None,
controller: tp.Optional[Controller] = None,
) -> None:
self._transport = transport
self._storage = storage if storage is not None else AsyncFileStorage(serializer=JSONSerializer())
if not isinstance(self._storage, AsyncBaseStorage): # pragma: no cover
raise TypeError(f"Expected subclass of `AsyncBaseStorage` but got `{storage.__class__.__name__}`")
self._controller = controller if controller is not None else Controller()
async def handle_async_request(self, request: Request) -> Response:
"""
Handles HTTP requests while also implementing HTTP caching.
:param request: An HTTP request
:type request: httpx.Request
:return: An HTTP response
:rtype: httpx.Response
"""
if request.extensions.get("cache_disabled", False):
request.headers.update(
[
("Cache-Control", "no-store"),
("Cache-Control", "no-cache"),
*[("cache-control", value) for value in request.headers.get_list("cache-control")],
]
)
if request.method not in ["GET", "HEAD"]:
# If the HTTP method is, for example, POST,
# we must also use the request data to generate the hash.
body_for_key = await request.aread()
request.stream = AsyncCacheStream(fake_stream(body_for_key))
else:
body_for_key = b""
# Construct the HTTPCore request because Controllers and Storages work with HTTPCore requests.
httpcore_request = httpcore.Request(
method=request.method,
url=httpcore.URL(
scheme=request.url.raw_scheme,
host=request.url.raw_host,
port=request.url.port,
target=request.url.raw_path,
),
headers=request.headers.raw,
content=request.stream,
extensions=request.extensions,
)
key = self._controller._key_generator(httpcore_request, body_for_key)
stored_data = await self._storage.retrieve(key)
request_cache_control = parse_cache_control(
extract_header_values_decoded(request.headers.raw, b"Cache-Control")
)
if request_cache_control.only_if_cached and not stored_data:
return generate_504()
if stored_data:
# Try using the stored response if it was discovered.
stored_response, stored_request, metadata = stored_data
# Immediately read the stored response to avoid issues when trying to access the response body.
stored_response.read()
res = self._controller.construct_response_from_cache(
request=httpcore_request,
response=stored_response,
original_request=stored_request,
)
if isinstance(res, httpcore.Response):
# Simply use the response if the controller determines it is ready for use.
return await self._create_hishel_response(
key=key,
response=res,
request=httpcore_request,
cached=True,
revalidated=False,
metadata=metadata,
)
if request_cache_control.only_if_cached:
return generate_504()
if isinstance(res, httpcore.Request):
# Controller has determined that the response needs to be re-validated.
assert isinstance(res.stream, tp.AsyncIterable)
revalidation_request = Request(
method=res.method.decode(),
url=normalized_url(res.url),
headers=res.headers,
stream=AsyncCacheStream(res.stream),
extensions=res.extensions,
)
try:
revalidation_response = await self._transport.handle_async_request(revalidation_request)
except ConnectError:
# If there is a connection error, we can use the stale response if allowed.
if self._controller._allow_stale and allowed_stale(response=stored_response):
return await self._create_hishel_response(
key=key,
response=stored_response,
request=httpcore_request,
cached=True,
revalidated=False,
metadata=metadata,
)
raise # pragma: no cover
assert isinstance(revalidation_response.stream, tp.AsyncIterable)
httpcore_revalidation_response = httpcore.Response(
status=revalidation_response.status_code,
headers=revalidation_response.headers.raw,
content=AsyncCacheStream(revalidation_response.stream),
extensions=revalidation_response.extensions,
)
# Merge headers with the stale response.
final_httpcore_response = self._controller.handle_validation_response(
old_response=stored_response,
new_response=httpcore_revalidation_response,
)
await final_httpcore_response.aread()
await revalidation_response.aclose()
assert isinstance(final_httpcore_response.stream, tp.AsyncIterable)
# RFC 9111: 4.3.3. Handling a Validation Response
# A 304 (Not Modified) response status code indicates that the stored response can be updated and
# reused. A full response (i.e., one containing content) indicates that none of the stored responses
# nominated in the conditional request are suitable. Instead, the cache MUST use the full response to
# satisfy the request. The cache MAY store such a full response, subject to its constraints.
if revalidation_response.status_code != 304 and self._controller.is_cachable(
request=httpcore_request, response=final_httpcore_response
):
await self._storage.store(key, response=final_httpcore_response, request=httpcore_request)
return await self._create_hishel_response(
key=key,
response=final_httpcore_response,
request=httpcore_request,
cached=revalidation_response.status_code == 304,
revalidated=True,
metadata=metadata,
)
regular_response = await self._transport.handle_async_request(request)
assert isinstance(regular_response.stream, tp.AsyncIterable)
httpcore_regular_response = httpcore.Response(
status=regular_response.status_code,
headers=regular_response.headers.raw,
content=AsyncCacheStream(regular_response.stream),
extensions=regular_response.extensions,
)
await httpcore_regular_response.aread()
await httpcore_regular_response.aclose()
if self._controller.is_cachable(request=httpcore_request, response=httpcore_regular_response):
await self._storage.store(
key,
response=httpcore_regular_response,
request=httpcore_request,
)
return await self._create_hishel_response(
key=key,
response=httpcore_regular_response,
request=httpcore_request,
cached=False,
revalidated=False,
)
async def _create_hishel_response(
self,
key: str,
response: httpcore.Response,
request: httpcore.Request,
cached: bool,
revalidated: bool,
metadata: Metadata | None = None,
) -> Response:
if cached:
assert metadata
metadata["number_of_uses"] += 1
await self._storage.update_metadata(key=key, request=request, response=response, metadata=metadata)
response.extensions["from_cache"] = True # type: ignore[index]
response.extensions["cache_metadata"] = metadata # type: ignore[index]
else:
response.extensions["from_cache"] = False # type: ignore[index]
response.extensions["revalidated"] = revalidated # type: ignore[index]
return Response(
status_code=response.status,
headers=response.headers,
stream=AsyncCacheStream(fake_stream(response.content)),
extensions=response.extensions,
)
async def aclose(self) -> None:
await self._storage.aclose()
await self._transport.aclose()
async def __aenter__(self) -> Self:
return self
async def __aexit__(
self,
exc_type: tp.Optional[tp.Type[BaseException]] = None,
exc_value: tp.Optional[BaseException] = None,
traceback: tp.Optional[types.TracebackType] = None,
) -> None:
await self.aclose()
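
A hedged usage sketch for the transport: wrap httpx's default transport and hand the result to a plain `httpx.AsyncClient`.

# A minimal sketch layering the cache transport over httpx's default one.
import asyncio

import httpx

import hishel

async def main() -> None:
    cache_transport = hishel.AsyncCacheTransport(transport=httpx.AsyncHTTPTransport())
    async with httpx.AsyncClient(transport=cache_transport) as client:
        response = await client.get("https://example.com")
        print(response.extensions["from_cache"])

asyncio.run(main())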

View File

@@ -0,0 +1,582 @@
import logging
import typing as tp
from httpcore import Request, Response
from hishel._headers import Vary, parse_cache_control
from ._utils import (
BaseClock,
Clock,
extract_header_values,
extract_header_values_decoded,
generate_key,
get_safe_url,
header_presents,
parse_date,
)
logger = logging.getLogger("hishel.controller")
HEURISTICALLY_CACHEABLE_STATUS_CODES = (200, 203, 204, 206, 300, 301, 308, 404, 405, 410, 414, 501)
HTTP_METHODS = ["GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"]
__all__ = ("Controller", "HEURISTICALLY_CACHEABLE_STATUS_CODES")
def get_updated_headers(
stored_response_headers: tp.List[tp.Tuple[bytes, bytes]],
new_response_headers: tp.List[tp.Tuple[bytes, bytes]],
) -> tp.List[tp.Tuple[bytes, bytes]]:
updated_headers = []
checked = set()
for key, value in stored_response_headers:
if key not in checked and key.lower() != b"content-length":
checked.add(key)
values = extract_header_values(new_response_headers, key)
if values:
updated_headers.extend([(key, value) for value in values])
else:
values = extract_header_values(stored_response_headers, key)
updated_headers.extend([(key, value) for value in values])
for key, value in new_response_headers:
if key not in checked and key.lower() != b"content-length":
values = extract_header_values(new_response_headers, key)
updated_headers.extend([(key, value) for value in values])
return updated_headers
def get_freshness_lifetime(response: Response) -> tp.Optional[int]:
response_cache_control = parse_cache_control(extract_header_values_decoded(response.headers, b"Cache-Control"))
if response_cache_control.max_age is not None:
return response_cache_control.max_age
if header_presents(response.headers, b"expires"):
expires = extract_header_values_decoded(response.headers, b"expires", single=True)[0]
expires_timestamp = parse_date(expires)
if expires_timestamp is None:
return None
date = extract_header_values_decoded(response.headers, b"date", single=True)[0]
date_timestamp = parse_date(date)
if date_timestamp is None:
return None
return expires_timestamp - date_timestamp
return None
def get_heuristic_freshness(response: Response, clock: "BaseClock") -> int:
last_modified = extract_header_values_decoded(response.headers, b"last-modified", single=True)
if last_modified:
last_modified_timestamp = parse_date(last_modified[0])
if last_modified_timestamp is not None:
now = clock.now()
ONE_WEEK = 604_800
return min(ONE_WEEK, int((now - last_modified_timestamp) * 0.1))
ONE_DAY = 86_400
return ONE_DAY
def get_age(response: Response, clock: "BaseClock") -> int:
if not header_presents(response.headers, b"date"):
# If the response does not have a date header, then it is impossible to calculate the age.
# Instead of raising an exception, we return infinity to be sure that the response is not considered fresh.
return float("inf") # type: ignore
date = parse_date(extract_header_values_decoded(response.headers, b"date")[0])
if date is None:
return float("inf") # type: ignore
now = clock.now()
apparent_age = max(0, now - date)
return int(apparent_age)
def allowed_stale(response: Response) -> bool:
response_cache_control = parse_cache_control(extract_header_values_decoded(response.headers, b"Cache-Control"))
if response_cache_control.no_cache:
return False
if response_cache_control.must_revalidate:
return False
return True
class Controller:
def __init__(
self,
cacheable_methods: tp.Optional[tp.List[str]] = None,
cacheable_status_codes: tp.Optional[tp.List[int]] = None,
cache_private: bool = True,
allow_heuristics: bool = False,
clock: tp.Optional[BaseClock] = None,
allow_stale: bool = False,
always_revalidate: bool = False,
force_cache: bool = False,
key_generator: tp.Optional[tp.Callable[[Request, tp.Optional[bytes]], str]] = None,
):
self._cacheable_methods = []
if cacheable_methods is None:
self._cacheable_methods.append("GET")
else:
for method in cacheable_methods:
if method.upper() not in HTTP_METHODS:
raise RuntimeError(
f"Hishel does not support the HTTP method `{method}`.\n"
f"Please use the methods from this list: {HTTP_METHODS}"
)
self._cacheable_methods.append(method.upper())
self._cacheable_status_codes = cacheable_status_codes if cacheable_status_codes else [200, 301, 308]
self._cache_private = cache_private
self._clock = clock if clock else Clock()
self._allow_heuristics = allow_heuristics
self._allow_stale = allow_stale
self._always_revalidate = always_revalidate
self._force_cache = force_cache
self._key_generator = key_generator or generate_key
def is_cachable(self, request: Request, response: Response) -> bool:
"""
Determines whether the response may be cached.
The only thing this method does is determine whether the
response associated with this request can be cached for later use.
`https://www.rfc-editor.org/rfc/rfc9111.html#name-storing-responses-in-caches`
lists the steps that this method simply follows.
"""
method = request.method.decode("ascii")
force_cache = request.extensions.get("force_cache", None)
if response.status not in self._cacheable_status_codes:
logger.debug(
(
f"Considering the resource located at {get_safe_url(request.url)} "
f"as not cachable since its status code ({response.status})"
" is not in the list of cacheable status codes."
)
)
return False
if response.status in (301, 308):
logger.debug(
(
f"Considering the resource located at {get_safe_url(request.url)} "
"as cachable since its status code is a permanent redirect."
)
)
return True
# the request method is understood by the cache
if method not in self._cacheable_methods:
logger.debug(
(
f"Considering the resource located at {get_safe_url(request.url)} "
f"as not cachable since the request method ({method}) is not in the list of cacheable methods."
)
)
return False
if force_cache if force_cache is not None else self._force_cache:
logger.debug(
(
f"Considering the resource located at {get_safe_url(request.url)} "
"as cachable since the request is forced to use the cache."
)
)
return True
response_cache_control = parse_cache_control(extract_header_values_decoded(response.headers, b"cache-control"))
request_cache_control = parse_cache_control(extract_header_values_decoded(request.headers, b"cache-control"))
# the response status code is final
if response.status // 100 == 1:
logger.debug(
(
f"Considering the resource located at {get_safe_url(request.url)} "
"as not cachable since its status code is informational."
)
)
return False
# the no-store cache directive is not present (see Section 5.2.2.5)
if request_cache_control.no_store:
logger.debug(
(
f"Considering the resource located at {get_safe_url(request.url)} "
"as not cachable since the request contains the no-store directive."
)
)
return False
# note that the must-understand cache directive overrides
# no-store in certain circumstances; see Section 5.2.2.3.
if response_cache_control.no_store:
if response_cache_control.must_understand:
logger.debug(
(
f"Skipping the no-store directive for the resource located at {get_safe_url(request.url)} "
"since the response contains the must-understand directive."
)
)
else:
logger.debug(
(
f"Considering the resource located at {get_safe_url(request.url)} "
"as not cachable since the response contains the no-store directive."
)
)
return False
# a shared cache must not store a response with private directive
# Note that we do not implement special handling for the qualified form,
# which would only forbid storing specified headers.
if not self._cache_private and response_cache_control.private:
logger.debug(
(
f"Considering the resource located at {get_safe_url(request.url)} "
"as not cachable since the response contains the private directive."
)
)
return False
expires_presents = header_presents(response.headers, b"expires")
# the response contains at least one of the following:
# - a public response directive (see Section 5.2.2.9);
# - a private response directive, if the cache is not shared (see Section 5.2.2.7);
# - an Expires header field (see Section 5.3);
# - a max-age response directive (see Section 5.2.2.1);
# - if the cache is shared: an s-maxage response directive (see Section 5.2.2.10);
# - a cache extension that allows it to be cached (see Section 5.2.3); or
# - a status code that is defined as heuristically cacheable (see Section 4.2.2).
if self._allow_heuristics and response.status in HEURISTICALLY_CACHEABLE_STATUS_CODES:
logger.debug(
(
f"Considering the resource located at {get_safe_url(request.url)} "
"as cachable since its status code is heuristically cacheable."
)
)
return True
if not any(
[
response_cache_control.public,
response_cache_control.private,
expires_presents,
response_cache_control.max_age is not None,
]
):
logger.debug(
(
f"Considering the resource located at {get_safe_url(request.url)} "
"as not cachable since it does not contain any of the required cache directives."
)
)
return False
logger.debug(
(
f"Considering the resource located at {get_safe_url(request.url)} "
"as cachable since it meets the criteria for being stored in the cache."
)
)
        # the response is cachable!
return True
def _make_request_conditional(self, request: Request, response: Response) -> None:
"""
Adds the precondition headers needed for response validation.
This method will use the "Last-Modified" or "Etag" headers
if they are provided in order to create precondition headers.
See also (https://www.rfc-editor.org/rfc/rfc9111.html#name-sending-a-validation-reques)
"""
if header_presents(response.headers, b"last-modified"):
last_modified = extract_header_values(response.headers, b"last-modified", single=True)[0]
logger.debug(
(
f"Adding the 'If-Modified-Since' header with the value of '{last_modified.decode('ascii')}' "
f"to the request for the resource located at {get_safe_url(request.url)}."
)
)
else:
last_modified = None
if header_presents(response.headers, b"etag"):
etag = extract_header_values(response.headers, b"etag", single=True)[0]
logger.debug(
(
f"Adding the 'If-None-Match' header with the value of '{etag.decode('ascii')}' "
f"to the request for the resource located at {get_safe_url(request.url)}."
)
)
else:
etag = None
precondition_headers: tp.List[tp.Tuple[bytes, bytes]] = []
if last_modified:
precondition_headers.append((b"If-Modified-Since", last_modified))
if etag:
precondition_headers.append((b"If-None-Match", etag))
request.headers.extend(precondition_headers)
def _validate_vary(self, request: Request, response: Response, original_request: Request) -> bool:
"""
Determines whether the "vary" headers in the request and response headers are identical.
See also (https://www.rfc-editor.org/rfc/rfc9111.html#name-calculating-cache-keys-with).
"""
vary_headers = extract_header_values_decoded(response.headers, b"vary")
vary = Vary.from_value(vary_values=vary_headers)
for vary_header in vary._values:
if vary_header == "*":
return False # pragma: no cover
if extract_header_values(request.headers, vary_header) != extract_header_values(
original_request.headers, vary_header
):
return False
return True
def construct_response_from_cache(
self, request: Request, response: Response, original_request: Request
) -> tp.Union[Response, Request, None]:
"""
Specifies whether the response should be used, skipped, or validated by the cache.
This method makes a decision regarding what to do with
the stored response when it is retrieved from storage.
It might be ready for use or it might need to be revalidated.
This method mirrors the relevant section from RFC 9111,
see (https://www.rfc-editor.org/rfc/rfc9111.html#name-constructing-responses-from).
Returns:
Response: This response is applicable to the request.
Request: This response can be used for this request, but it must first be revalidated.
None: It is not possible to use this response for this request.
"""
        # Responses with status codes 301 and 308 may always be
        # used, regardless of any caching rules.
if response.status in (301, 308):
logger.debug(
(
f"Considering the resource located at {get_safe_url(request.url)} "
"as valid for cache use since its status code is a permanent redirect."
)
)
return response
response_cache_control = parse_cache_control(extract_header_values_decoded(response.headers, b"Cache-Control"))
request_cache_control = parse_cache_control(extract_header_values_decoded(request.headers, b"Cache-Control"))
# request header fields nominated by the stored
# response (if any) match those presented (see Section 4.1)
if not self._validate_vary(request=request, response=response, original_request=original_request):
            # If the vary headers do not match, do not use the response
logger.debug(
(
f"Considering the resource located at {get_safe_url(request.url)} "
"as invalid for cache use since the vary headers do not match."
)
)
return None # pragma: no cover
# !!! this should be after the "vary" header validation.
force_cache = request.extensions.get("force_cache", None)
if force_cache if force_cache is not None else self._force_cache:
logger.debug(
(
f"Considering the resource located at {get_safe_url(request.url)} "
"as valid for cache use since the request is forced to use the cache."
)
)
return response
# the stored response does not contain the
# no-cache directive (Section 5.2.2.4), unless
# it is successfully validated (Section 4.3)
if (
self._always_revalidate
or response_cache_control.no_cache
or response_cache_control.must_revalidate
or request_cache_control.no_cache
):
if self._always_revalidate:
log_text = (
f"Considering the resource located at {get_safe_url(request.url)} "
"as needing revalidation since the cache is set to always revalidate."
)
elif response_cache_control.no_cache:
log_text = (
f"Considering the resource located at {get_safe_url(request.url)} "
"as needing revalidation since the response contains the no-cache directive."
)
elif response_cache_control.must_revalidate:
log_text = (
f"Considering the resource located at {get_safe_url(request.url)} "
"as needing revalidation since the response contains the must-revalidate directive."
)
elif request_cache_control.no_cache:
log_text = (
f"Considering the resource located at {get_safe_url(request.url)} "
"as needing revalidation since the request contains the no-cache directive."
)
else:
                assert False, "Unreachable code"  # pragma: no cover
logger.debug(log_text)
self._make_request_conditional(request=request, response=response)
return request
freshness_lifetime = get_freshness_lifetime(response)
if freshness_lifetime is None:
logger.debug(
(
"Could not determine the freshness lifetime of "
f"the resource located at {get_safe_url(request.url)}, "
"trying to use heuristics to calculate it."
)
)
if self._allow_heuristics and response.status in HEURISTICALLY_CACHEABLE_STATUS_CODES:
freshness_lifetime = get_heuristic_freshness(response=response, clock=self._clock)
logger.debug(
(
f"Successfully calculated the freshness lifetime of the resource located at "
f"{get_safe_url(request.url)} using heuristics."
)
)
else:
logger.debug(
(
"Could not calculate the freshness lifetime of "
f"the resource located at {get_safe_url(request.url)}. "
"Making a conditional request to revalidate the response."
)
)
# If Freshness cannot be calculated, then send the request
self._make_request_conditional(request=request, response=response)
return request
age = get_age(response, self._clock)
is_fresh = freshness_lifetime > age
# The min-fresh request directive indicates that the client
# prefers a response whose freshness lifetime is no less than
# its current age plus the specified time in seconds.
# That is, the client wants a response that will still
# be fresh for at least the specified number of seconds.
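        # Worked example (illustrative): with freshness_lifetime=60, age=40
        # and "min-fresh=30", 60 < 40 + 30 holds, so the stored response is
        # rejected; with "min-fresh=10" it would still be served.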
if request_cache_control.min_fresh is not None:
if freshness_lifetime < (age + request_cache_control.min_fresh):
logger.debug(
(
f"Considering the resource located at {get_safe_url(request.url)} "
"as invalid for cache use since the time left for "
"freshness is less than the min-fresh directive."
)
)
return None
# The max-stale request directive indicates that the
# client will accept a response that has exceeded its freshness lifetime.
# If a value is present, then the client is willing to accept a response
# that has exceeded its freshness lifetime by no more than the specified
# number of seconds. If no value is assigned to max-stale, then
# the client will accept a stale response of any age.
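        # Worked example (illustrative): with freshness_lifetime=60 and
        # age=90, the response is stale by 30 seconds; "max-stale=40" still
        # permits reuse, while "max-stale=20" does not.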
if not is_fresh and request_cache_control.max_stale is not None:
exceeded_freshness_lifetime = age - freshness_lifetime
if request_cache_control.max_stale < exceeded_freshness_lifetime:
logger.debug(
(
f"Considering the resource located at {get_safe_url(request.url)} "
"as invalid for cache use since the freshness lifetime has been exceeded more than max-stale."
)
)
return None
else:
logger.debug(
(
f"Considering the resource located at {get_safe_url(request.url)} "
"as valid for cache use since the freshness lifetime has been exceeded less than max-stale."
)
)
return response
# The max-age request directive indicates that
# the client prefers a response whose age is
# less than or equal to the specified number of seconds.
# Unless the max-stale request directive is also present,
# the client does not wish to receive a stale response.
if request_cache_control.max_age is not None:
if request_cache_control.max_age < age:
logger.debug(
(
f"Considering the resource located at {get_safe_url(request.url)} "
"as invalid for cache use since the age of the response exceeds the max-age directive."
)
)
return None
# the stored response is one of the following:
# fresh (see Section 4.2), or
# allowed to be served stale (see Section 4.2.4), or
# successfully validated (see Section 4.3).
if is_fresh:
logger.debug(
(
f"Considering the resource located at {get_safe_url(request.url)} "
"as valid for cache use since it is fresh."
)
)
return response
else:
logger.debug(
(
f"Considering the resource located at {get_safe_url(request.url)} "
"as needing revalidation since it is not fresh."
)
)
# Otherwise, make a conditional request
self._make_request_conditional(request=request, response=response)
return request
def handle_validation_response(self, old_response: Response, new_response: Response) -> Response:
"""
        Handles an incoming validation response.
        If the validation response is a 304 (Not Modified), the stored
        response's headers are updated from the new response and the stored
        response is returned; otherwise, the new response is used as-is.
This method mirrors the relevant section from RFC 9111,
see (https://www.rfc-editor.org/rfc/rfc9111.html#name-handling-a-validation-respo).
"""
if new_response.status == 304:
headers = get_updated_headers(
stored_response_headers=old_response.headers,
new_response_headers=new_response.headers,
)
old_response.headers = headers
return old_response
else:
return new_response
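# Illustrative sketch (not part of the library): how a 304 validation
# response refreshes a stored response. The header names and values below
# are made up for demonstration.
if __name__ == "__main__":
    controller = Controller()
    stored = Response(200, headers=[(b"ETag", b'"v1"'), (b"Age", b"100")], content=b"cached body")
    not_modified = Response(304, headers=[(b"Age", b"0")])
    merged = controller.handle_validation_response(old_response=stored, new_response=not_modified)
    assert merged.status == 200  # the stored body is reused with refreshed headers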

View File

@@ -0,0 +1,10 @@
__all__ = ("CacheControlError", "ParseError", "ValidationError")
class CacheControlError(Exception): ...
class ParseError(CacheControlError): ...
class ValidationError(CacheControlError): ...

View File

@@ -0,0 +1,54 @@
import typing as tp
import anyio
class AsyncBaseFileManager:
def __init__(self, is_binary: bool) -> None:
self.is_binary = is_binary
async def write_to(self, path: str, data: tp.Union[bytes, str], is_binary: tp.Optional[bool] = None) -> None:
raise NotImplementedError()
async def read_from(self, path: str, is_binary: tp.Optional[bool] = None) -> tp.Union[bytes, str]:
raise NotImplementedError()
class AsyncFileManager(AsyncBaseFileManager):
async def write_to(self, path: str, data: tp.Union[bytes, str], is_binary: tp.Optional[bool] = None) -> None:
is_binary = self.is_binary if is_binary is None else is_binary
mode = "wb" if is_binary else "wt"
async with await anyio.open_file(path, mode) as f: # type: ignore[call-overload]
await f.write(data)
async def read_from(self, path: str, is_binary: tp.Optional[bool] = None) -> tp.Union[bytes, str]:
is_binary = self.is_binary if is_binary is None else is_binary
mode = "rb" if is_binary else "rt"
async with await anyio.open_file(path, mode) as f: # type: ignore[call-overload]
return tp.cast(tp.Union[bytes, str], await f.read())
class BaseFileManager:
def __init__(self, is_binary: bool) -> None:
self.is_binary = is_binary
def write_to(self, path: str, data: tp.Union[bytes, str], is_binary: tp.Optional[bool] = None) -> None:
raise NotImplementedError()
def read_from(self, path: str, is_binary: tp.Optional[bool] = None) -> tp.Union[bytes, str]:
raise NotImplementedError()
class FileManager(BaseFileManager):
def write_to(self, path: str, data: tp.Union[bytes, str], is_binary: tp.Optional[bool] = None) -> None:
is_binary = self.is_binary if is_binary is None else is_binary
mode = "wb" if is_binary else "wt"
with open(path, mode) as f:
f.write(data)
def read_from(self, path: str, is_binary: tp.Optional[bool] = None) -> tp.Union[bytes, str]:
is_binary = self.is_binary if is_binary is None else is_binary
mode = "rb" if is_binary else "rt"
with open(path, mode) as f:
return tp.cast(tp.Union[bytes, str], f.read())
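# Illustrative sketch: a write/read round trip with the synchronous
# FileManager, using a throwaway temporary directory.
if __name__ == "__main__":
    import os
    import tempfile

    manager = FileManager(is_binary=False)
    path = os.path.join(tempfile.mkdtemp(), "demo.txt")
    manager.write_to(path, "hello")
    assert manager.read_from(path) == "hello"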

View File

@@ -0,0 +1,215 @@
import string
from typing import Any, Dict, List, Optional, Union
from ._exceptions import ParseError, ValidationError
## Grammar
HTAB = "\t"
SP = " "
obs_text = "".join(chr(i) for i in range(0x80, 0xFF + 1)) # 0x80-0xFF
tchar = "!#$%&'*+-.^_`|~0123456789" + string.ascii_letters
qdtext = "".join(
[
HTAB,
SP,
"\x21",
"".join(chr(i) for i in range(0x23, 0x5B + 1)), # 0x23-0x5b
"".join(chr(i) for i in range(0x5D, 0x7E + 1)), # 0x5D-0x7E
obs_text,
]
)
TIME_FIELDS = [
"max_age",
"max_stale",
"min_fresh",
"s_maxage",
]
BOOLEAN_FIELDS = [
"immutable",
"must_revalidate",
"must_understand",
"no_store",
"no_transform",
"only_if_cached",
"public",
"proxy_revalidate",
]
LIST_FIELDS = ["no_cache", "private"]
__all__ = (
"CacheControl",
"Vary",
)
def strip_ows_around(text: str) -> str:
    # OWS (optional whitespace) is any mix of spaces and horizontal tabs.
    return text.strip(" \t")
def normalize_directive(text: str) -> str:
return text.replace("-", "_")
def parse_cache_control(cache_control_values: List[str]) -> "CacheControl":
directives = {}
for cache_control_value in cache_control_values:
if "no-cache=" in cache_control_value or "private=" in cache_control_value:
cache_control_splited = [cache_control_value]
else:
cache_control_splited = cache_control_value.split(",")
for directive in cache_control_splited:
key: str = ""
value: Optional[str] = None
dquote = False
if not directive:
raise ParseError("The directive should not be left blank.")
directive = strip_ows_around(directive)
if not directive:
raise ParseError("The directive should not contain only whitespaces.")
for i, key_char in enumerate(directive):
if key_char == "=":
value = directive[i + 1 :]
if not value:
raise ParseError("The directive value cannot be left blank.")
if value[0] == '"':
dquote = True
if dquote and value[-1] != '"':
raise ParseError("Invalid quotes around the value.")
if not dquote:
for value_char in value:
if value_char not in tchar:
raise ParseError(
f"The character '{value_char!r}' is not permitted for the unquoted values."
)
else:
for value_char in value[1:-1]:
if value_char not in qdtext:
raise ParseError(
f"The character '{value_char!r}' is not permitted for the quoted values."
)
break
if key_char not in tchar:
raise ParseError(f"The character '{key_char!r}' is not permitted in the directive name.")
key += key_char
directives[key] = value
validated_data = CacheControl.validate(directives)
return CacheControl(**validated_data)
class Vary:
def __init__(self, values: List[str]) -> None:
self._values = values
@classmethod
def from_value(cls, vary_values: List[str]) -> "Vary":
values = []
for vary_value in vary_values:
for field_name in vary_value.split(","):
field_name = field_name.strip()
values.append(field_name)
return Vary(values)
class CacheControl:
def __init__(
self,
immutable: bool = False, # [RFC8246]
max_age: Optional[int] = None, # [RFC9111, Section 5.2.1.1, 5.2.2.1]
max_stale: Optional[int] = None, # [RFC9111, Section 5.2.1.2]
min_fresh: Optional[int] = None, # [RFC9111, Section 5.2.1.3]
must_revalidate: bool = False, # [RFC9111, Section 5.2.2.2]
must_understand: bool = False, # [RFC9111, Section 5.2.2.3]
no_cache: Union[bool, List[str]] = False, # [RFC9111, Section 5.2.1.4, 5.2.2.4]
no_store: bool = False, # [RFC9111, Section 5.2.1.5, 5.2.2.5]
no_transform: bool = False, # [RFC9111, Section 5.2.1.6, 5.2.2.6]
only_if_cached: bool = False, # [RFC9111, Section 5.2.1.7]
private: Union[bool, List[str]] = False, # [RFC9111, Section 5.2.2.7]
proxy_revalidate: bool = False, # [RFC9111, Section 5.2.2.8]
public: bool = False, # [RFC9111, Section 5.2.2.9]
s_maxage: Optional[int] = None, # [RFC9111, Section 5.2.2.10]
) -> None:
self.immutable = immutable
self.max_age = max_age
self.max_stale = max_stale
self.min_fresh = min_fresh
self.must_revalidate = must_revalidate
self.must_understand = must_understand
self.no_cache = no_cache
self.no_store = no_store
self.no_transform = no_transform
self.only_if_cached = only_if_cached
self.private = private
self.proxy_revalidate = proxy_revalidate
self.public = public
self.s_maxage = s_maxage
@classmethod
def validate(cls, directives: Dict[str, Any]) -> Dict[str, Any]:
validated_data: Dict[str, Any] = {}
for key, value in directives.items():
key = normalize_directive(key)
if key in TIME_FIELDS:
if value is None:
raise ValidationError(f"The directive '{key}' necessitates a value.")
if value[0] == '"' or value[-1] == '"':
raise ValidationError(f"The argument '{key}' should be an integer, but a quote was found.")
try:
validated_data[key] = int(value)
except Exception:
raise ValidationError(f"The argument '{key}' should be an integer, but got '{value!r}'.")
elif key in BOOLEAN_FIELDS:
if value is not None:
raise ValidationError(f"The directive '{key}' should have no value, but it does.")
validated_data[key] = True
elif key in LIST_FIELDS:
if value is None:
validated_data[key] = True
else:
values = []
for list_value in value[1:-1].split(","):
if not list_value:
raise ValidationError("The list value must not be empty.")
list_value = strip_ows_around(list_value)
values.append(list_value)
validated_data[key] = values
return validated_data
    def __repr__(self) -> str:
        fields = ""
        # The field-name constants already use underscore-separated names.
        for key in TIME_FIELDS:
            value = getattr(self, key)
            if value:
                fields += f"{key}={value}, "
        for key in BOOLEAN_FIELDS:
            value = getattr(self, key)
            if value:
                fields += f"{key}, "
        fields = fields[:-2]
        return f"<{type(self).__name__} {fields}>"

View File

@@ -0,0 +1,71 @@
from collections import OrderedDict, defaultdict
from typing import DefaultDict, Dict, Generic, Iterator, Tuple, TypeVar
K = TypeVar("K")
V = TypeVar("V")
__all__ = ["LFUCache"]
class LFUCache(Generic[K, V]):
def __init__(self, capacity: int):
if capacity <= 0:
raise ValueError("Capacity must be positive")
self.capacity = capacity
        self.cache: Dict[K, Tuple[V, int]] = {}  # Maps each key to its (value, frequency) pair
        self.freq_count: DefaultDict[int, OrderedDict[K, V]] = defaultdict(
            OrderedDict
        )  # Maps each frequency to its keys, in insertion order
self.min_freq = 0 # To keep track of the minimum frequency
def get(self, key: K) -> V:
if key in self.cache:
value, freq = self.cache[key]
# Update frequency and move the key to the next frequency bucket
self.freq_count[freq].pop(key)
if not self.freq_count[freq]: # If the current frequency has no keys, update min_freq
del self.freq_count[freq]
if freq == self.min_freq:
self.min_freq += 1
freq += 1
self.freq_count[freq][key] = value
self.cache[key] = (value, freq)
return value
raise KeyError(f"Key {key} not found")
def put(self, key: K, value: V) -> None:
if key in self.cache:
_, freq = self.cache[key]
# Update frequency and move the key to the next frequency bucket
self.freq_count[freq].pop(key)
if not self.freq_count[freq]:
del self.freq_count[freq]
if freq == self.min_freq:
self.min_freq += 1
freq += 1
self.freq_count[freq][key] = value
self.cache[key] = (value, freq)
else:
# Check if cache is full, evict the least frequently used item
if len(self.cache) == self.capacity:
evicted_key, _ = self.freq_count[self.min_freq].popitem(last=False)
del self.cache[evicted_key]
# Add the new key-value pair with frequency 1
self.cache[key] = (value, 1)
self.freq_count[1][key] = value
self.min_freq = 1
def remove_key(self, key: K) -> None:
if key in self.cache:
_, freq = self.cache[key]
self.freq_count[freq].pop(key)
if not self.freq_count[freq]: # If the current frequency has no keys, update min_freq
del self.freq_count[freq]
if freq == self.min_freq:
self.min_freq += 1
del self.cache[key]
def __iter__(self) -> Iterator[K]:
yield from self.cache
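# Illustrative sketch: the eviction behavior of LFUCache. With capacity 2,
# inserting a third key evicts the least frequently used one.
if __name__ == "__main__":
    cache: LFUCache[str, int] = LFUCache(capacity=2)
    cache.put("a", 1)
    cache.put("b", 2)
    cache.get("a")  # "a" now has a higher use frequency than "b"
    cache.put("c", 3)  # evicts "b", the least frequently used key
    assert "b" not in list(cache)
    assert cache.get("c") == 3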

View File

@@ -0,0 +1,122 @@
import time
import typing as tp
from anyio import to_thread
from botocore.exceptions import ClientError
def get_timestamp_in_ms() -> float:
return time.time() * 1000
class S3Manager:
def __init__(
self,
client: tp.Any,
bucket_name: str,
check_ttl_every: tp.Union[int, float],
is_binary: bool = False,
path_prefix: str = "hishel-",
):
self._client = client
self._bucket_name = bucket_name
self._is_binary = is_binary
self._last_cleaned = time.monotonic()
self._check_ttl_every = check_ttl_every
self._path_prefix = path_prefix
def write_to(self, path: str, data: tp.Union[bytes, str], only_metadata: bool = False) -> None:
path = self._path_prefix + path
if isinstance(data, str):
data = data.encode("utf-8")
created_at = None
if only_metadata:
try:
response = self._client.get_object(
Bucket=self._bucket_name,
Key=path,
)
created_at = response["Metadata"]["created_at"]
except Exception:
pass
self._client.put_object(
Bucket=self._bucket_name,
Key=path,
Body=data,
Metadata={"created_at": created_at or str(get_timestamp_in_ms())},
)
def read_from(self, path: str) -> tp.Union[bytes, str]:
path = self._path_prefix + path
response = self._client.get_object(
Bucket=self._bucket_name,
Key=path,
)
content = response["Body"].read()
if self._is_binary: # pragma: no cover
return tp.cast(bytes, content)
return tp.cast(str, content.decode("utf-8"))
def remove_expired(self, ttl: int, key: str) -> None:
path = self._path_prefix + key
if time.monotonic() - self._last_cleaned < self._check_ttl_every:
try:
response = self._client.get_object(Bucket=self._bucket_name, Key=path)
if get_timestamp_in_ms() - float(response["Metadata"]["created_at"]) > ttl:
self._client.delete_object(Bucket=self._bucket_name, Key=path)
return
except ClientError as e:
if e.response["Error"]["Code"] == "NoSuchKey":
return
raise e
self._last_cleaned = time.monotonic()
for obj in self._client.list_objects(Bucket=self._bucket_name).get("Contents", []):
if not obj["Key"].startswith(self._path_prefix): # pragma: no cover
continue
            try:
                metadata_obj = self._client.head_object(Bucket=self._bucket_name, Key=obj["Key"]).get("Metadata", {})
            except ClientError as e:
                if e.response["Error"]["Code"] == "404":
                    continue
                raise e  # Do not silently swallow non-404 errors
if not metadata_obj or "created_at" not in metadata_obj:
continue
if get_timestamp_in_ms() - float(metadata_obj["created_at"]) > ttl:
self._client.delete_object(Bucket=self._bucket_name, Key=obj["Key"])
def remove_entry(self, key: str) -> None:
path = self._path_prefix + key
self._client.delete_object(Bucket=self._bucket_name, Key=path)
class AsyncS3Manager: # pragma: no cover
def __init__(
self,
client: tp.Any,
bucket_name: str,
check_ttl_every: tp.Union[int, float],
is_binary: bool = False,
path_prefix: str = "hishel-",
):
self._sync_manager = S3Manager(client, bucket_name, check_ttl_every, is_binary, path_prefix)
async def write_to(self, path: str, data: tp.Union[bytes, str], only_metadata: bool = False) -> None:
return await to_thread.run_sync(self._sync_manager.write_to, path, data, only_metadata)
async def read_from(self, path: str) -> tp.Union[bytes, str]:
return await to_thread.run_sync(self._sync_manager.read_from, path)
async def remove_expired(self, ttl: int, key: str) -> None:
return await to_thread.run_sync(self._sync_manager.remove_expired, ttl, key)
async def remove_entry(self, key: str) -> None:
return await to_thread.run_sync(self._sync_manager.remove_entry, key)
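# Illustrative sketch, assuming valid AWS credentials and an existing
# bucket named "my-cache-bucket" (both hypothetical):
if __name__ == "__main__":
    import boto3

    manager = S3Manager(client=boto3.client("s3"), bucket_name="my-cache-bucket", check_ttl_every=60)
    manager.write_to("demo-entry", "payload")
    assert manager.read_from("demo-entry") == "payload"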

View File

@@ -0,0 +1,329 @@
import base64
import json
import pickle
import typing as tp
from datetime import datetime
from httpcore import Request, Response
from hishel._utils import normalized_url
try:
import yaml
except ImportError: # pragma: no cover
yaml = None # type: ignore
HEADERS_ENCODING = "iso-8859-1"
KNOWN_RESPONSE_EXTENSIONS = ("http_version", "reason_phrase")
KNOWN_REQUEST_EXTENSIONS = ("timeout", "sni_hostname")
__all__ = ("PickleSerializer", "JSONSerializer", "YAMLSerializer", "BaseSerializer", "clone_model")
T = tp.TypeVar("T", Request, Response)
def clone_model(model: T) -> T:
if isinstance(model, Response):
return Response(
status=model.status,
headers=model.headers,
content=model.content,
extensions={key: value for key, value in model.extensions.items() if key in KNOWN_RESPONSE_EXTENSIONS},
) # type: ignore
else:
return Request(
method=model.method,
url=normalized_url(model.url),
headers=model.headers,
extensions={key: value for key, value in model.extensions.items() if key in KNOWN_REQUEST_EXTENSIONS},
) # type: ignore
class Metadata(tp.TypedDict):
number_of_uses: int
created_at: datetime
cache_key: str
class BaseSerializer:
def dumps(self, response: Response, request: Request, metadata: Metadata) -> tp.Union[str, bytes]:
raise NotImplementedError()
def loads(self, data: tp.Union[str, bytes]) -> tp.Tuple[Response, Request, Metadata]:
raise NotImplementedError()
@property
def is_binary(self) -> bool:
raise NotImplementedError()
class PickleSerializer(BaseSerializer):
"""
A simple pickle-based serializer.
"""
def dumps(self, response: Response, request: Request, metadata: Metadata) -> tp.Union[str, bytes]:
"""
Dumps the HTTP response and its HTTP request.
:param response: An HTTP response
:type response: Response
:param request: An HTTP request
:type request: Request
:param metadata: Additional information about the stored response
:type metadata: Metadata
:return: Serialized response
:rtype: tp.Union[str, bytes]
"""
clone_response = clone_model(response)
clone_request = clone_model(request)
return pickle.dumps((clone_response, clone_request, metadata))
def loads(self, data: tp.Union[str, bytes]) -> tp.Tuple[Response, Request, Metadata]:
"""
Loads the HTTP response and its HTTP request from serialized data.
:param data: Serialized data
:type data: tp.Union[str, bytes]
:return: HTTP response and its HTTP request
:rtype: tp.Tuple[Response, Request, Metadata]
"""
assert isinstance(data, bytes)
return tp.cast(tp.Tuple[Response, Request, Metadata], pickle.loads(data))
@property
def is_binary(self) -> bool: # pragma: no cover
return True
class JSONSerializer(BaseSerializer):
"""A simple json-based serializer."""
def dumps(self, response: Response, request: Request, metadata: Metadata) -> tp.Union[str, bytes]:
"""
Dumps the HTTP response and its HTTP request.
:param response: An HTTP response
:type response: Response
:param request: An HTTP request
:type request: Request
:param metadata: Additional information about the stored response
:type metadata: Metadata
:return: Serialized response
:rtype: tp.Union[str, bytes]
"""
response_dict = {
"status": response.status,
"headers": [
(key.decode(HEADERS_ENCODING), value.decode(HEADERS_ENCODING)) for key, value in response.headers
],
"content": base64.b64encode(response.content).decode("ascii"),
"extensions": {
key: value.decode("ascii")
for key, value in response.extensions.items()
if key in KNOWN_RESPONSE_EXTENSIONS
},
}
request_dict = {
"method": request.method.decode("ascii"),
"url": normalized_url(request.url),
"headers": [
(key.decode(HEADERS_ENCODING), value.decode(HEADERS_ENCODING)) for key, value in request.headers
],
"extensions": {key: value for key, value in request.extensions.items() if key in KNOWN_REQUEST_EXTENSIONS},
}
metadata_dict = {
"cache_key": metadata["cache_key"],
"number_of_uses": metadata["number_of_uses"],
"created_at": metadata["created_at"].strftime("%a, %d %b %Y %H:%M:%S GMT"),
}
full_json = {
"response": response_dict,
"request": request_dict,
"metadata": metadata_dict,
}
return json.dumps(full_json, indent=4)
def loads(self, data: tp.Union[str, bytes]) -> tp.Tuple[Response, Request, Metadata]:
"""
Loads the HTTP response and its HTTP request from serialized data.
:param data: Serialized data
:type data: tp.Union[str, bytes]
:return: HTTP response and its HTTP request
:rtype: tp.Tuple[Response, Request, Metadata]
"""
full_json = json.loads(data)
response_dict = full_json["response"]
request_dict = full_json["request"]
metadata_dict = full_json["metadata"]
metadata_dict["created_at"] = datetime.strptime(
metadata_dict["created_at"],
"%a, %d %b %Y %H:%M:%S GMT",
)
response = Response(
status=response_dict["status"],
headers=[
(key.encode(HEADERS_ENCODING), value.encode(HEADERS_ENCODING))
for key, value in response_dict["headers"]
],
content=base64.b64decode(response_dict["content"].encode("ascii")),
extensions={
key: value.encode("ascii")
for key, value in response_dict["extensions"].items()
if key in KNOWN_RESPONSE_EXTENSIONS
},
)
request = Request(
method=request_dict["method"],
url=request_dict["url"],
headers=[
(key.encode(HEADERS_ENCODING), value.encode(HEADERS_ENCODING)) for key, value in request_dict["headers"]
],
extensions={
key: value for key, value in request_dict["extensions"].items() if key in KNOWN_REQUEST_EXTENSIONS
},
)
metadata = Metadata(
cache_key=metadata_dict["cache_key"],
created_at=metadata_dict["created_at"],
number_of_uses=metadata_dict["number_of_uses"],
)
return response, request, metadata
@property
def is_binary(self) -> bool:
return False
class YAMLSerializer(BaseSerializer):
"""A simple yaml-based serializer."""
def dumps(self, response: Response, request: Request, metadata: Metadata) -> tp.Union[str, bytes]:
"""
Dumps the HTTP response and its HTTP request.
:param response: An HTTP response
:type response: Response
:param request: An HTTP request
:type request: Request
:param metadata: Additional information about the stored response
:type metadata: Metadata
:return: Serialized response
:rtype: tp.Union[str, bytes]
"""
if yaml is None: # pragma: no cover
raise RuntimeError(
f"The `{type(self).__name__}` was used, but the required packages were not found. "
"Check that you have `Hishel` installed with the `yaml` extension as shown.\n"
"```pip install hishel[yaml]```"
)
response_dict = {
"status": response.status,
"headers": [
(key.decode(HEADERS_ENCODING), value.decode(HEADERS_ENCODING)) for key, value in response.headers
],
"content": base64.b64encode(response.content).decode("ascii"),
"extensions": {
key: value.decode("ascii")
for key, value in response.extensions.items()
if key in KNOWN_RESPONSE_EXTENSIONS
},
}
request_dict = {
"method": request.method.decode("ascii"),
"url": normalized_url(request.url),
"headers": [
(key.decode(HEADERS_ENCODING), value.decode(HEADERS_ENCODING)) for key, value in request.headers
],
"extensions": {key: value for key, value in request.extensions.items() if key in KNOWN_REQUEST_EXTENSIONS},
}
metadata_dict = {
"cache_key": metadata["cache_key"],
"number_of_uses": metadata["number_of_uses"],
"created_at": metadata["created_at"].strftime("%a, %d %b %Y %H:%M:%S GMT"),
}
full_json = {
"response": response_dict,
"request": request_dict,
"metadata": metadata_dict,
}
return yaml.safe_dump(full_json, sort_keys=False)
def loads(self, data: tp.Union[str, bytes]) -> tp.Tuple[Response, Request, Metadata]:
"""
Loads the HTTP response and its HTTP request from serialized data.
:param data: Serialized data
:type data: tp.Union[str, bytes]
:raises RuntimeError: When used without the `yaml` extension installed
:return: HTTP response and its HTTP request
:rtype: tp.Tuple[Response, Request, Metadata]
"""
if yaml is None: # pragma: no cover
raise RuntimeError(
f"The `{type(self).__name__}` was used, but the required packages were not found. "
"Check that you have `Hishel` installed with the `yaml` extension as shown.\n"
"```pip install hishel[yaml]```"
)
full_json = yaml.safe_load(data)
response_dict = full_json["response"]
request_dict = full_json["request"]
metadata_dict = full_json["metadata"]
metadata_dict["created_at"] = datetime.strptime(
metadata_dict["created_at"],
"%a, %d %b %Y %H:%M:%S GMT",
)
response = Response(
status=response_dict["status"],
headers=[
(key.encode(HEADERS_ENCODING), value.encode(HEADERS_ENCODING))
for key, value in response_dict["headers"]
],
content=base64.b64decode(response_dict["content"].encode("ascii")),
extensions={
key: value.encode("ascii")
for key, value in response_dict["extensions"].items()
if key in KNOWN_RESPONSE_EXTENSIONS
},
)
request = Request(
method=request_dict["method"],
url=request_dict["url"],
headers=[
(key.encode(HEADERS_ENCODING), value.encode(HEADERS_ENCODING)) for key, value in request_dict["headers"]
],
extensions={
key: value for key, value in request_dict["extensions"].items() if key in KNOWN_REQUEST_EXTENSIONS
},
)
metadata = Metadata(
cache_key=metadata_dict["cache_key"],
created_at=metadata_dict["created_at"],
number_of_uses=metadata_dict["number_of_uses"],
)
return response, request, metadata
@property
def is_binary(self) -> bool: # pragma: no cover
return False
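# Illustrative sketch: a dumps/loads round trip with JSONSerializer. The
# URL, header, and metadata values are made up for demonstration.
if __name__ == "__main__":
    demo_request = Request(b"GET", "https://www.example.com/")
    demo_response = Response(200, headers=[(b"Content-Type", b"text/plain")], content=b"hello")
    demo_response.read()
    metadata = Metadata(cache_key="demo", number_of_uses=0, created_at=datetime.now())
    serializer = JSONSerializer()
    serialized = serializer.dumps(response=demo_response, request=demo_request, metadata=metadata)
    restored_response, restored_request, restored_metadata = serializer.loads(serialized)
    assert restored_response.status == 200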

View File

@@ -0,0 +1,5 @@
from ._client import * # noqa: F403
from ._mock import * # noqa: F403
from ._pool import * # noqa: F403
from ._storages import * # noqa: F403
from ._transports import * # noqa: F403

View File

@@ -0,0 +1,30 @@
import typing as tp
import httpx
from hishel._sync._transports import CacheTransport
__all__ = ("CacheClient",)
class CacheClient(httpx.Client):
def __init__(self, *args: tp.Any, **kwargs: tp.Any):
self._storage = kwargs.pop("storage") if "storage" in kwargs else None
self._controller = kwargs.pop("controller") if "controller" in kwargs else None
super().__init__(*args, **kwargs)
def _init_transport(self, *args, **kwargs) -> CacheTransport: # type: ignore
_transport = super()._init_transport(*args, **kwargs)
return CacheTransport(
transport=_transport,
storage=self._storage,
controller=self._controller,
)
def _init_proxy_transport(self, *args, **kwargs) -> CacheTransport: # type: ignore
_transport = super()._init_proxy_transport(*args, **kwargs) # pragma: no cover
return CacheTransport( # pragma: no cover
transport=_transport,
storage=self._storage,
controller=self._controller,
)
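# Illustrative sketch, assuming network access; whether the second call is
# served from the cache depends on the origin's caching headers.
if __name__ == "__main__":
    with CacheClient() as client:
        client.get("https://www.example.com/")
        response = client.get("https://www.example.com/")
        print(response.extensions.get("from_cache"))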

View File

@@ -0,0 +1,43 @@
import typing as tp
from types import TracebackType
import httpcore
import httpx
from httpcore._sync.interfaces import RequestInterface
if tp.TYPE_CHECKING: # pragma: no cover
from typing_extensions import Self
__all__ = ("MockConnectionPool", "MockTransport")
class MockConnectionPool(RequestInterface):
def handle_request(self, request: httpcore.Request) -> httpcore.Response:
assert isinstance(request.stream, tp.Iterable)
data = b"".join([chunk for chunk in request.stream]) # noqa: F841
return self.mocked_responses.pop(0)
def add_responses(self, responses: tp.List[httpcore.Response]) -> None:
if not hasattr(self, "mocked_responses"):
self.mocked_responses = []
self.mocked_responses.extend(responses)
def __enter__(self) -> "Self":
return self
def __exit__(
self,
exc_type: tp.Optional[tp.Type[BaseException]] = None,
exc_value: tp.Optional[BaseException] = None,
traceback: tp.Optional[TracebackType] = None,
) -> None: ...
class MockTransport(httpx.BaseTransport):
def handle_request(self, request: httpx.Request) -> httpx.Response:
return self.mocked_responses.pop(0)
def add_responses(self, responses: tp.List[httpx.Response]) -> None:
if not hasattr(self, "mocked_responses"):
self.mocked_responses = []
self.mocked_responses.extend(responses)
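# Illustrative sketch: pre-loading canned responses into MockTransport and
# serving them through a regular httpx.Client.
if __name__ == "__main__":
    transport = MockTransport()
    transport.add_responses([httpx.Response(200, text="hello")])
    with httpx.Client(transport=transport) as client:
        assert client.get("https://www.example.com/").text == "hello"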

View File

@@ -0,0 +1,201 @@
from __future__ import annotations
import types
import typing as tp
from httpcore._sync.interfaces import RequestInterface
from httpcore._exceptions import ConnectError
from httpcore._models import Request, Response
from .._controller import Controller, allowed_stale
from .._headers import parse_cache_control
from .._serializers import JSONSerializer, Metadata
from .._utils import extract_header_values_decoded
from ._storages import BaseStorage, FileStorage
T = tp.TypeVar("T")
__all__ = ("CacheConnectionPool",)
def fake_stream(content: bytes) -> tp.Iterable[bytes]:
yield content
def generate_504() -> Response:
return Response(status=504)
class CacheConnectionPool(RequestInterface):
"""An HTTP Core Connection Pool that supports HTTP caching.
    :param pool: `Connection Pool` that this class wraps in order to add an HTTP caching layer on top of it
    :type pool: RequestInterface
    :param storage: Storage that handles how the responses should be saved, defaults to None
    :type storage: tp.Optional[BaseStorage], optional
:param controller: Controller that manages the cache behavior at the specification level, defaults to None
:type controller: tp.Optional[Controller], optional
"""
def __init__(
self,
pool: RequestInterface,
storage: tp.Optional[BaseStorage] = None,
controller: tp.Optional[Controller] = None,
) -> None:
self._pool = pool
self._storage = storage if storage is not None else FileStorage(serializer=JSONSerializer())
if not isinstance(self._storage, BaseStorage): # pragma: no cover
raise TypeError(f"Expected subclass of `BaseStorage` but got `{storage.__class__.__name__}`")
self._controller = controller if controller is not None else Controller()
def handle_request(self, request: Request) -> Response:
"""
Handles HTTP requests while also implementing HTTP caching.
:param request: An HTTP request
:type request: httpcore.Request
:return: An HTTP response
:rtype: httpcore.Response
"""
if request.extensions.get("cache_disabled", False):
request.headers.extend([(b"cache-control", b"no-cache"), (b"cache-control", b"max-age=0")])
if request.method.upper() not in [b"GET", b"HEAD"]:
# If the HTTP method is, for example, POST,
# we must also use the request data to generate the hash.
assert isinstance(request.stream, tp.Iterable)
body_for_key = b"".join([chunk for chunk in request.stream])
request.stream = fake_stream(body_for_key)
else:
body_for_key = b""
key = self._controller._key_generator(request, body_for_key)
stored_data = self._storage.retrieve(key)
request_cache_control = parse_cache_control(extract_header_values_decoded(request.headers, b"Cache-Control"))
if request_cache_control.only_if_cached and not stored_data:
return generate_504()
if stored_data:
# Try using the stored response if it was discovered.
stored_response, stored_request, metadata = stored_data
# Immediately read the stored response to avoid issues when trying to access the response body.
stored_response.read()
res = self._controller.construct_response_from_cache(
request=request,
response=stored_response,
original_request=stored_request,
)
if isinstance(res, Response):
# Simply use the response if the controller determines it is ready for use.
return self._create_hishel_response(
key=key,
response=stored_response,
request=request,
metadata=metadata,
cached=True,
revalidated=False,
)
if request_cache_control.only_if_cached:
return generate_504()
if isinstance(res, Request):
# Controller has determined that the response needs to be re-validated.
try:
revalidation_response = self._pool.handle_request(res)
except ConnectError:
# If there is a connection error, we can use the stale response if allowed.
if self._controller._allow_stale and allowed_stale(response=stored_response):
return self._create_hishel_response(
key=key,
response=stored_response,
request=request,
metadata=metadata,
cached=True,
revalidated=False,
)
raise # pragma: no cover
# Merge headers with the stale response.
final_response = self._controller.handle_validation_response(
old_response=stored_response, new_response=revalidation_response
)
final_response.read()
# RFC 9111: 4.3.3. Handling a Validation Response
# A 304 (Not Modified) response status code indicates that the stored response can be updated and
# reused. A full response (i.e., one containing content) indicates that none of the stored responses
# nominated in the conditional request are suitable. Instead, the cache MUST use the full response to
# satisfy the request. The cache MAY store such a full response, subject to its constraints.
if revalidation_response.status != 304 and self._controller.is_cachable(
request=request, response=final_response
):
self._storage.store(key, response=final_response, request=request)
return self._create_hishel_response(
key=key,
response=final_response,
request=request,
cached=revalidation_response.status == 304,
revalidated=True,
metadata=metadata,
)
regular_response = self._pool.handle_request(request)
regular_response.read()
if self._controller.is_cachable(request=request, response=regular_response):
self._storage.store(key, response=regular_response, request=request)
return self._create_hishel_response(
key=key, response=regular_response, request=request, cached=False, revalidated=False
)
def _create_hishel_response(
self,
key: str,
response: Response,
request: Request,
cached: bool,
revalidated: bool,
metadata: Metadata | None = None,
) -> Response:
if cached:
assert metadata
metadata["number_of_uses"] += 1
self._storage.update_metadata(key=key, request=request, response=response, metadata=metadata)
response.extensions["from_cache"] = True # type: ignore[index]
response.extensions["cache_metadata"] = metadata # type: ignore[index]
else:
response.extensions["from_cache"] = False # type: ignore[index]
response.extensions["revalidated"] = revalidated # type: ignore[index]
return response
def close(self) -> None:
self._storage.close()
if hasattr(self._pool, "close"): # pragma: no cover
self._pool.close()
def __enter__(self: T) -> T:
return self
def __exit__(
self,
exc_type: tp.Optional[tp.Type[BaseException]] = None,
exc_value: tp.Optional[BaseException] = None,
traceback: tp.Optional[types.TracebackType] = None,
) -> None:
self.close()
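# Illustrative sketch, assuming network access: wrapping an httpcore
# connection pool so repeated requests can be answered from the cache.
if __name__ == "__main__":
    import httpcore

    with CacheConnectionPool(pool=httpcore.ConnectionPool()) as cache_pool:
        cache_pool.request("GET", "https://www.example.com/")
        response = cache_pool.request("GET", "https://www.example.com/")
        print(response.extensions.get("from_cache"))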

View File

@@ -0,0 +1,772 @@
from __future__ import annotations
import datetime
import logging
import os
import time
import typing as t
import typing as tp
import warnings
from copy import deepcopy
from pathlib import Path
try:
import boto3
from .._s3 import S3Manager
except ImportError: # pragma: no cover
boto3 = None # type: ignore
try:
import sqlite3
except ImportError: # pragma: no cover
sqlite3 = None # type: ignore
from httpcore import Request, Response
if t.TYPE_CHECKING: # pragma: no cover
from typing_extensions import TypeAlias
from hishel._serializers import BaseSerializer, clone_model
from .._files import FileManager
from .._serializers import JSONSerializer, Metadata
from .._synchronization import Lock
from .._utils import float_seconds_to_int_milliseconds
logger = logging.getLogger("hishel.storages")
__all__ = (
"BaseStorage",
"FileStorage",
"RedisStorage",
"SQLiteStorage",
"InMemoryStorage",
"S3Storage",
)
StoredResponse: TypeAlias = tp.Tuple[Response, Request, Metadata]
RemoveTypes = tp.Union[str, Response]
try:
import redis
except ImportError: # pragma: no cover
redis = None # type: ignore
class BaseStorage:
def __init__(
self,
serializer: tp.Optional[BaseSerializer] = None,
ttl: tp.Optional[tp.Union[int, float]] = None,
) -> None:
self._serializer = serializer or JSONSerializer()
self._ttl = ttl
def store(self, key: str, response: Response, request: Request, metadata: Metadata | None = None) -> None:
raise NotImplementedError()
def remove(self, key: RemoveTypes) -> None:
raise NotImplementedError()
def update_metadata(self, key: str, response: Response, request: Request, metadata: Metadata) -> None:
raise NotImplementedError()
def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
raise NotImplementedError()
def close(self) -> None:
raise NotImplementedError()
class FileStorage(BaseStorage):
"""
A simple file storage.
:param serializer: Serializer capable of serializing and de-serializing http responses, defaults to None
:type serializer: tp.Optional[BaseSerializer], optional
:param base_path: A storage base path where the responses should be saved, defaults to None
:type base_path: tp.Optional[Path], optional
:param ttl: Specifies the maximum number of seconds that the response can be cached, defaults to None
:type ttl: tp.Optional[tp.Union[int, float]], optional
    :param check_ttl_every: How often, in seconds, to check the staleness of **all** cache files.
        Only meaningful when `ttl` is set, defaults to 60
:type check_ttl_every: tp.Union[int, float]
"""
def __init__(
self,
serializer: tp.Optional[BaseSerializer] = None,
base_path: tp.Optional[Path] = None,
ttl: tp.Optional[tp.Union[int, float]] = None,
check_ttl_every: tp.Union[int, float] = 60,
) -> None:
super().__init__(serializer, ttl)
self._base_path = Path(base_path) if base_path is not None else Path(".cache/hishel")
self._gitignore_file = self._base_path / ".gitignore"
if not self._base_path.is_dir():
self._base_path.mkdir(parents=True)
if not self._gitignore_file.is_file():
with open(self._gitignore_file, "w", encoding="utf-8") as f:
f.write("# Automatically created by Hishel\n*")
self._file_manager = FileManager(is_binary=self._serializer.is_binary)
self._lock = Lock()
self._check_ttl_every = check_ttl_every
self._last_cleaned = time.monotonic()
def store(self, key: str, response: Response, request: Request, metadata: Metadata | None = None) -> None:
"""
Stores the response in the cache.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:param response: An HTTP response
:type response: httpcore.Response
:param request: An HTTP request
:type request: httpcore.Request
:param metadata: Additional information about the stored response
:type metadata: Optional[Metadata]
"""
metadata = metadata or Metadata(
cache_key=key, created_at=datetime.datetime.now(datetime.timezone.utc), number_of_uses=0
)
response_path = self._base_path / key
with self._lock:
self._file_manager.write_to(
str(response_path),
self._serializer.dumps(response=response, request=request, metadata=metadata),
)
self._remove_expired_caches(response_path)
def remove(self, key: RemoveTypes) -> None:
"""
Removes the response from the cache.
:param key: Hashed value of concatenated HTTP method and URI or an HTTP response
:type key: Union[str, Response]
"""
if isinstance(key, Response): # pragma: no cover
key = t.cast(str, key.extensions["cache_metadata"]["cache_key"])
response_path = self._base_path / key
with self._lock:
if response_path.exists():
response_path.unlink(missing_ok=True)
def update_metadata(self, key: str, response: Response, request: Request, metadata: Metadata) -> None:
"""
Updates the metadata of the stored response.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:param response: An HTTP response
:type response: httpcore.Response
:param request: An HTTP request
:type request: httpcore.Request
:param metadata: Additional information about the stored response
:type metadata: Metadata
"""
response_path = self._base_path / key
with self._lock:
if response_path.exists():
atime = response_path.stat().st_atime
old_mtime = response_path.stat().st_mtime
self._file_manager.write_to(
str(response_path),
self._serializer.dumps(response=response, request=request, metadata=metadata),
)
# Restore the old atime and mtime (we use mtime to check the cache expiration time)
os.utime(response_path, (atime, old_mtime))
return
return self.store(key, response, request, metadata) # pragma: no cover
def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
"""
        Retrieves the response from the cache using its key.
        :param key: Hashed value of concatenated HTTP method and URI
        :type key: str
        :return: An HTTP response and its HTTP request.
:rtype: tp.Optional[StoredResponse]
"""
response_path = self._base_path / key
self._remove_expired_caches(response_path)
with self._lock:
if response_path.exists():
read_data = self._file_manager.read_from(str(response_path))
if len(read_data) != 0:
return self._serializer.loads(read_data)
return None
def close(self) -> None: # pragma: no cover
return
def _remove_expired_caches(self, response_path: Path) -> None:
if self._ttl is None:
return
if time.monotonic() - self._last_cleaned < self._check_ttl_every:
if response_path.is_file():
age = time.time() - response_path.stat().st_mtime
if age > self._ttl:
response_path.unlink(missing_ok=True)
return
self._last_cleaned = time.monotonic()
with self._lock:
with os.scandir(self._base_path) as entries:
for entry in entries:
try:
if entry.is_file():
age = time.time() - entry.stat().st_mtime
if age > self._ttl:
os.unlink(entry.path)
except FileNotFoundError: # pragma: no cover
pass
class SQLiteStorage(BaseStorage):
"""
A simple sqlite3 storage.
:param serializer: Serializer capable of serializing and de-serializing http responses, defaults to None
:type serializer: tp.Optional[BaseSerializer], optional
:param connection: A connection for sqlite, defaults to None
:type connection: tp.Optional[sqlite3.Connection], optional
:param ttl: Specifies the maximum number of seconds that the response can be cached, defaults to None
:type ttl: tp.Optional[tp.Union[int, float]], optional
"""
def __init__(
self,
serializer: tp.Optional[BaseSerializer] = None,
connection: tp.Optional[sqlite3.Connection] = None,
ttl: tp.Optional[tp.Union[int, float]] = None,
) -> None:
if sqlite3 is None: # pragma: no cover
raise RuntimeError(
f"The `{type(self).__name__}` was used, but the required packages were not found. "
"Check that you have `Hishel` installed with the `sqlite` extension as shown.\n"
"```pip install hishel[sqlite]```"
)
super().__init__(serializer, ttl)
self._connection: tp.Optional[sqlite3.Connection] = connection or None
self._setup_lock = Lock()
self._setup_completed: bool = False
self._lock = Lock()
def _setup(self) -> None:
with self._setup_lock:
if not self._setup_completed:
if not self._connection: # pragma: no cover
self._connection = sqlite3.connect(".hishel.sqlite", check_same_thread=False)
self._connection.execute(
"CREATE TABLE IF NOT EXISTS cache(key TEXT, data BLOB, date_created REAL)"
)
self._connection.commit()
self._setup_completed = True
def store(self, key: str, response: Response, request: Request, metadata: Metadata | None = None) -> None:
"""
Stores the response in the cache.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:param response: An HTTP response
:type response: httpcore.Response
:param request: An HTTP request
:type request: httpcore.Request
        :param metadata: Additional information about the stored response
:type metadata: Optional[Metadata]
"""
self._setup()
assert self._connection
metadata = metadata or Metadata(
cache_key=key, created_at=datetime.datetime.now(datetime.timezone.utc), number_of_uses=0
)
with self._lock:
self._connection.execute("DELETE FROM cache WHERE key = ?", [key])
serialized_response = self._serializer.dumps(response=response, request=request, metadata=metadata)
self._connection.execute(
"INSERT INTO cache(key, data, date_created) VALUES(?, ?, ?)", [key, serialized_response, time.time()]
)
self._connection.commit()
self._remove_expired_caches()
def remove(self, key: RemoveTypes) -> None:
"""
Removes the response from the cache.
:param key: Hashed value of concatenated HTTP method and URI or an HTTP response
:type key: Union[str, Response]
"""
self._setup()
assert self._connection
if isinstance(key, Response): # pragma: no cover
key = t.cast(str, key.extensions["cache_metadata"]["cache_key"])
with self._lock:
self._connection.execute("DELETE FROM cache WHERE key = ?", [key])
self._connection.commit()
def update_metadata(self, key: str, response: Response, request: Request, metadata: Metadata) -> None:
"""
Updates the metadata of the stored response.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:param response: An HTTP response
:type response: httpcore.Response
:param request: An HTTP request
:type request: httpcore.Request
:param metadata: Additional information about the stored response
:type metadata: Metadata
"""
self._setup()
assert self._connection
with self._lock:
cursor = self._connection.execute("SELECT data FROM cache WHERE key = ?", [key])
row = cursor.fetchone()
if row is not None:
serialized_response = self._serializer.dumps(response=response, request=request, metadata=metadata)
self._connection.execute("UPDATE cache SET data = ? WHERE key = ?", [serialized_response, key])
self._connection.commit()
return
return self.store(key, response, request, metadata) # pragma: no cover
def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
"""
        Retrieves the response from the cache using its key.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:return: An HTTP response and its HTTP request.
:rtype: tp.Optional[StoredResponse]
"""
self._setup()
assert self._connection
self._remove_expired_caches()
with self._lock:
cursor = self._connection.execute("SELECT data FROM cache WHERE key = ?", [key])
row = cursor.fetchone()
if row is None:
return None
cached_response = row[0]
return self._serializer.loads(cached_response)
def close(self) -> None: # pragma: no cover
if self._connection is not None:
self._connection.close()
def _remove_expired_caches(self) -> None:
assert self._connection
if self._ttl is None:
return
with self._lock:
self._connection.execute("DELETE FROM cache WHERE date_created + ? < ?", [self._ttl, time.time()])
self._connection.commit()
class RedisStorage(BaseStorage):
"""
A simple redis storage.
:param serializer: Serializer capable of serializing and de-serializing http responses, defaults to None
:type serializer: tp.Optional[BaseSerializer], optional
:param client: A client for redis, defaults to None
:type client: tp.Optional["redis.Redis"], optional
:param ttl: Specifies the maximum number of seconds that the response can be cached, defaults to None
:type ttl: tp.Optional[tp.Union[int, float]], optional
"""
def __init__(
self,
serializer: tp.Optional[BaseSerializer] = None,
client: tp.Optional[redis.Redis] = None, # type: ignore
ttl: tp.Optional[tp.Union[int, float]] = None,
) -> None:
if redis is None: # pragma: no cover
raise RuntimeError(
f"The `{type(self).__name__}` was used, but the required packages were not found. "
"Check that you have `Hishel` installed with the `redis` extension as shown.\n"
"```pip install hishel[redis]```"
)
super().__init__(serializer, ttl)
if client is None:
self._client = redis.Redis() # type: ignore
else: # pragma: no cover
self._client = client
def store(self, key: str, response: Response, request: Request, metadata: Metadata | None = None) -> None:
"""
Stores the response in the cache.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:param response: An HTTP response
:type response: httpcore.Response
:param request: An HTTP request
:type request: httpcore.Request
        :param metadata: Additional information about the stored response
:type metadata: Optional[Metadata]
"""
metadata = metadata or Metadata(
cache_key=key, created_at=datetime.datetime.now(datetime.timezone.utc), number_of_uses=0
)
if self._ttl is not None:
px = float_seconds_to_int_milliseconds(self._ttl)
else:
px = None
self._client.set(
key, self._serializer.dumps(response=response, request=request, metadata=metadata), px=px
)
def remove(self, key: RemoveTypes) -> None:
"""
Removes the response from the cache.
:param key: Hashed value of concatenated HTTP method and URI or an HTTP response
:type key: Union[str, Response]
"""
if isinstance(key, Response): # pragma: no cover
key = t.cast(str, key.extensions["cache_metadata"]["cache_key"])
self._client.delete(key)
def update_metadata(self, key: str, response: Response, request: Request, metadata: Metadata) -> None:
"""
Updates the metadata of the stored response.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:param response: An HTTP response
:type response: httpcore.Response
:param request: An HTTP request
:type request: httpcore.Request
:param metadata: Additional information about the stored response
:type metadata: Metadata
"""
ttl_in_milliseconds = self._client.pttl(key)
# -2: if the key does not exist in Redis
# -1: if the key exists in Redis but has no expiration
if ttl_in_milliseconds == -2 or ttl_in_milliseconds == -1: # pragma: no cover
self.store(key, response, request, metadata)
else:
self._client.set(
key,
self._serializer.dumps(response=response, request=request, metadata=metadata),
px=ttl_in_milliseconds,
)
def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
"""
        Retrieves the response from the cache using its key.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:return: An HTTP response and its HTTP request.
:rtype: tp.Optional[StoredResponse]
"""
cached_response = self._client.get(key)
if cached_response is None:
return None
return self._serializer.loads(cached_response)
def close(self) -> None: # pragma: no cover
self._client.close()
class InMemoryStorage(BaseStorage):
"""
A simple in-memory storage.
:param serializer: Serializer capable of serializing and de-serializing http responses, defaults to None
:type serializer: tp.Optional[BaseSerializer], optional
:param ttl: Specifies the maximum number of seconds that the response can be cached, defaults to None
:type ttl: tp.Optional[tp.Union[int, float]], optional
:param capacity: The maximum number of responses that can be cached, defaults to 128
:type capacity: int, optional
"""
def __init__(
self,
serializer: tp.Optional[BaseSerializer] = None,
ttl: tp.Optional[tp.Union[int, float]] = None,
capacity: int = 128,
) -> None:
super().__init__(serializer, ttl)
if serializer is not None: # pragma: no cover
warnings.warn("The serializer is not used in the in-memory storage.", RuntimeWarning)
from hishel import LFUCache
self._cache: LFUCache[str, tp.Tuple[StoredResponse, float]] = LFUCache(capacity=capacity)
self._lock = Lock()
def store(self, key: str, response: Response, request: Request, metadata: Metadata | None = None) -> None:
"""
Stores the response in the cache.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:param response: An HTTP response
:type response: httpcore.Response
:param request: An HTTP request
:type request: httpcore.Request
        :param metadata: Additional information about the stored response
:type metadata: Optional[Metadata]
"""
metadata = metadata or Metadata(
cache_key=key, created_at=datetime.datetime.now(datetime.timezone.utc), number_of_uses=0
)
with self._lock:
response_clone = clone_model(response)
request_clone = clone_model(request)
stored_response: StoredResponse = (deepcopy(response_clone), deepcopy(request_clone), metadata)
self._cache.put(key, (stored_response, time.monotonic()))
self._remove_expired_caches()
def remove(self, key: RemoveTypes) -> None:
"""
Removes the response from the cache.
:param key: Hashed value of concatenated HTTP method and URI or an HTTP response
:type key: Union[str, Response]
"""
if isinstance(key, Response): # pragma: no cover
key = t.cast(str, key.extensions["cache_metadata"]["cache_key"])
with self._lock:
self._cache.remove_key(key)
def update_metadata(self, key: str, response: Response, request: Request, metadata: Metadata) -> None:
"""
Updates the metadata of the stored response.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:param response: An HTTP response
:type response: httpcore.Response
:param request: An HTTP request
:type request: httpcore.Request
:param metadata: Additional information about the stored response
:type metadata: Metadata
"""
with self._lock:
try:
stored_response, created_at = self._cache.get(key)
stored_response = (stored_response[0], stored_response[1], metadata)
self._cache.put(key, (stored_response, created_at))
return
except KeyError: # pragma: no cover
pass
self.store(key, response, request, metadata) # pragma: no cover
def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
"""
        Retrieves the response from the cache using its key.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:return: An HTTP response and its HTTP request.
:rtype: tp.Optional[StoredResponse]
"""
self._remove_expired_caches()
with self._lock:
try:
stored_response, _ = self._cache.get(key)
except KeyError:
return None
return stored_response
def close(self) -> None: # pragma: no cover
return
def _remove_expired_caches(self) -> None:
if self._ttl is None:
return
with self._lock:
keys_to_remove = set()
for key in self._cache:
created_at = self._cache.get(key)[1]
if time.monotonic() - created_at > self._ttl:
keys_to_remove.add(key)
for key in keys_to_remove:
self._cache.remove_key(key)
class S3Storage(BaseStorage): # pragma: no cover
"""
AWS S3 storage.
:param bucket_name: The name of the bucket to store the responses in
:type bucket_name: str
:param serializer: Serializer capable of serializing and de-serializing http responses, defaults to None
:type serializer: tp.Optional[BaseSerializer], optional
:param ttl: Specifies the maximum number of seconds that the response can be cached, defaults to None
:type ttl: tp.Optional[tp.Union[int, float]], optional
    :param check_ttl_every: How often, in seconds, to check the staleness of **all** cache entries.
        Only meaningful when `ttl` is set, defaults to 60
:type check_ttl_every: tp.Union[int, float]
:param client: A client for S3, defaults to None
:type client: tp.Optional[tp.Any], optional
:param path_prefix: A path prefix to use for S3 object keys, defaults to "hishel-"
:type path_prefix: str, optional
"""
def __init__(
self,
bucket_name: str,
serializer: tp.Optional[BaseSerializer] = None,
ttl: tp.Optional[tp.Union[int, float]] = None,
check_ttl_every: tp.Union[int, float] = 60,
client: tp.Optional[tp.Any] = None,
path_prefix: str = "hishel-",
) -> None:
super().__init__(serializer, ttl)
if boto3 is None: # pragma: no cover
raise RuntimeError(
f"The `{type(self).__name__}` was used, but the required packages were not found. "
"Check that you have `Hishel` installed with the `s3` extension as shown.\n"
"```pip install hishel[s3]```"
)
self._bucket_name = bucket_name
client = client or boto3.client("s3")
self._s3_manager = S3Manager(
client=client,
bucket_name=bucket_name,
is_binary=self._serializer.is_binary,
check_ttl_every=check_ttl_every,
path_prefix=path_prefix,
)
self._lock = Lock()
def store(self, key: str, response: Response, request: Request, metadata: Metadata | None = None) -> None:
"""
Stores the response in the cache.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:param response: An HTTP response
:type response: httpcore.Response
:param request: An HTTP request
:type request: httpcore.Request
        :param metadata: Additional information about the stored response
        :type metadata: Optional[Metadata]
"""
metadata = metadata or Metadata(
cache_key=key, created_at=datetime.datetime.now(datetime.timezone.utc), number_of_uses=0
)
with self._lock:
serialized = self._serializer.dumps(response=response, request=request, metadata=metadata)
self._s3_manager.write_to(path=key, data=serialized)
self._remove_expired_caches(key)
def remove(self, key: RemoveTypes) -> None:
"""
Removes the response from the cache.
:param key: Hashed value of concatenated HTTP method and URI or an HTTP response
:type key: Union[str, Response]
"""
if isinstance(key, Response): # pragma: no cover
key = t.cast(str, key.extensions["cache_metadata"]["cache_key"])
with self._lock:
self._s3_manager.remove_entry(key)
def update_metadata(self, key: str, response: Response, request: Request, metadata: Metadata) -> None:
"""
Updates the metadata of the stored response.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:param response: An HTTP response
:type response: httpcore.Response
:param request: An HTTP request
:type request: httpcore.Request
:param metadata: Additional information about the stored response
:type metadata: Metadata
"""
with self._lock:
serialized = self._serializer.dumps(response=response, request=request, metadata=metadata)
self._s3_manager.write_to(path=key, data=serialized, only_metadata=True)
def retrieve(self, key: str) -> tp.Optional[StoredResponse]:
"""
        Retrieves the response from the cache using its key.
:param key: Hashed value of concatenated HTTP method and URI
:type key: str
:return: An HTTP response and its HTTP request.
:rtype: tp.Optional[StoredResponse]
"""
self._remove_expired_caches(key)
with self._lock:
try:
return self._serializer.loads(self._s3_manager.read_from(path=key))
except Exception:
return None
def close(self) -> None: # pragma: no cover
return
def _remove_expired_caches(self, key: str) -> None:
if self._ttl is None:
return
with self._lock:
converted_ttl = float_seconds_to_int_milliseconds(self._ttl)
self._s3_manager.remove_expired(ttl=converted_ttl, key=key)
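# Usage sketch (illustrative, not part of the original file). It assumes this
# class is the sync `S3Storage`, that Hishel is installed with the `s3` extra,
# and that boto3 credentials are configured; "my-cache-bucket" and `key` are
# hypothetical values:
#
#     storage = S3Storage(bucket_name="my-cache-bucket", ttl=3600, check_ttl_every=60)
#     storage.store(key, response=response, request=request)
#     stored = storage.retrieve(key)  # None on a miss or once `ttl` has expired
#     storage.remove(key)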

View File

@@ -0,0 +1,277 @@
from __future__ import annotations
import types
import typing as tp
import httpcore
import httpx
from httpx import SyncByteStream, Request, Response
from httpx._exceptions import ConnectError
from hishel._utils import extract_header_values_decoded, normalized_url
from .._controller import Controller, allowed_stale
from .._headers import parse_cache_control
from .._serializers import JSONSerializer, Metadata
from ._storages import BaseStorage, FileStorage
if tp.TYPE_CHECKING: # pragma: no cover
from typing_extensions import Self
__all__ = ("CacheTransport",)
def fake_stream(content: bytes) -> tp.Iterable[bytes]:
yield content
def generate_504() -> Response:
return Response(status_code=504)
class CacheStream(SyncByteStream):
def __init__(self, httpcore_stream: tp.Iterable[bytes]):
self._httpcore_stream = httpcore_stream
def __iter__(self) -> tp.Iterator[bytes]:
for part in self._httpcore_stream:
yield part
def close(self) -> None:
if hasattr(self._httpcore_stream, "close"):
self._httpcore_stream.close()
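# A minimal sketch of how `CacheStream` adapts any bytes iterable into an
# httpx `SyncByteStream` (the payload below is illustrative):
#
#     stream = CacheStream(fake_stream(b"hello"))
#     body = b"".join(stream)  # b"hello"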
class CacheTransport(httpx.BaseTransport):
"""
An HTTPX Transport that supports HTTP caching.
:param transport: `Transport` that our class wraps in order to add an HTTP Cache layer on top of
:type transport: httpx.BaseTransport
:param storage: Storage that handles how the responses should be saved., defaults to None
:type storage: tp.Optional[BaseStorage], optional
:param controller: Controller that manages the cache behavior at the specification level, defaults to None
:type controller: tp.Optional[Controller], optional
"""
def __init__(
self,
transport: httpx.BaseTransport,
storage: tp.Optional[BaseStorage] = None,
controller: tp.Optional[Controller] = None,
) -> None:
self._transport = transport
self._storage = storage if storage is not None else FileStorage(serializer=JSONSerializer())
if not isinstance(self._storage, BaseStorage): # pragma: no cover
raise TypeError(f"Expected subclass of `BaseStorage` but got `{storage.__class__.__name__}`")
self._controller = controller if controller is not None else Controller()
def handle_request(self, request: Request) -> Response:
"""
Handles HTTP requests while also implementing HTTP caching.
:param request: An HTTP request
:type request: httpx.Request
:return: An HTTP response
:rtype: httpx.Response
"""
if request.extensions.get("cache_disabled", False):
request.headers.update(
[
("Cache-Control", "no-store"),
("Cache-Control", "no-cache"),
*[("cache-control", value) for value in request.headers.get_list("cache-control")],
]
)
if request.method not in ["GET", "HEAD"]:
# If the HTTP method is, for example, POST,
# we must also use the request data to generate the hash.
body_for_key = request.read()
request.stream = CacheStream(fake_stream(body_for_key))
else:
body_for_key = b""
# Construct the HTTPCore request because Controllers and Storages work with HTTPCore requests.
httpcore_request = httpcore.Request(
method=request.method,
url=httpcore.URL(
scheme=request.url.raw_scheme,
host=request.url.raw_host,
port=request.url.port,
target=request.url.raw_path,
),
headers=request.headers.raw,
content=request.stream,
extensions=request.extensions,
)
key = self._controller._key_generator(httpcore_request, body_for_key)
stored_data = self._storage.retrieve(key)
request_cache_control = parse_cache_control(
extract_header_values_decoded(request.headers.raw, b"Cache-Control")
)
if request_cache_control.only_if_cached and not stored_data:
return generate_504()
if stored_data:
# Try using the stored response if it was discovered.
stored_response, stored_request, metadata = stored_data
# Read the stored response up front so its body is fully buffered before it is reused.
stored_response.read()
res = self._controller.construct_response_from_cache(
request=httpcore_request,
response=stored_response,
original_request=stored_request,
)
if isinstance(res, httpcore.Response):
# Simply use the response if the controller determines it is ready for use.
return self._create_hishel_response(
key=key,
response=res,
request=httpcore_request,
cached=True,
revalidated=False,
metadata=metadata,
)
if request_cache_control.only_if_cached:
return generate_504()
if isinstance(res, httpcore.Request):
# Controller has determined that the response needs to be re-validated.
assert isinstance(res.stream, tp.Iterable)
revalidation_request = Request(
method=res.method.decode(),
url=normalized_url(res.url),
headers=res.headers,
stream=CacheStream(res.stream),
extensions=res.extensions,
)
try:
revalidation_response = self._transport.handle_request(revalidation_request)
except ConnectError:
# If there is a connection error, we can use the stale response if allowed.
if self._controller._allow_stale and allowed_stale(response=stored_response):
return self._create_hishel_response(
key=key,
response=stored_response,
request=httpcore_request,
cached=True,
revalidated=False,
metadata=metadata,
)
raise # pragma: no cover
assert isinstance(revalidation_response.stream, tp.Iterable)
httpcore_revalidation_response = httpcore.Response(
status=revalidation_response.status_code,
headers=revalidation_response.headers.raw,
content=CacheStream(revalidation_response.stream),
extensions=revalidation_response.extensions,
)
# Merge headers with the stale response.
final_httpcore_response = self._controller.handle_validation_response(
old_response=stored_response,
new_response=httpcore_revalidation_response,
)
final_httpcore_response.read()
revalidation_response.close()
assert isinstance(final_httpcore_response.stream, tp.Iterable)
# RFC 9111: 4.3.3. Handling a Validation Response
# A 304 (Not Modified) response status code indicates that the stored response can be updated and
# reused. A full response (i.e., one containing content) indicates that none of the stored responses
# nominated in the conditional request are suitable. Instead, the cache MUST use the full response to
# satisfy the request. The cache MAY store such a full response, subject to its constraints.
if revalidation_response.status_code != 304 and self._controller.is_cachable(
request=httpcore_request, response=final_httpcore_response
):
self._storage.store(key, response=final_httpcore_response, request=httpcore_request)
return self._create_hishel_response(
key=key,
response=final_httpcore_response,
request=httpcore_request,
cached=revalidation_response.status_code == 304,
revalidated=True,
metadata=metadata,
)
regular_response = self._transport.handle_request(request)
assert isinstance(regular_response.stream, tp.Iterable)
httpcore_regular_response = httpcore.Response(
status=regular_response.status_code,
headers=regular_response.headers.raw,
content=CacheStream(regular_response.stream),
extensions=regular_response.extensions,
)
httpcore_regular_response.read()
httpcore_regular_response.close()
if self._controller.is_cachable(request=httpcore_request, response=httpcore_regular_response):
self._storage.store(
key,
response=httpcore_regular_response,
request=httpcore_request,
)
return self._create_hishel_response(
key=key,
response=httpcore_regular_response,
request=httpcore_request,
cached=False,
revalidated=False,
)
def _create_hishel_response(
self,
key: str,
response: httpcore.Response,
request: httpcore.Request,
cached: bool,
revalidated: bool,
metadata: Metadata | None = None,
) -> Response:
if cached:
assert metadata
metadata["number_of_uses"] += 1
self._storage.update_metadata(key=key, request=request, response=response, metadata=metadata)
response.extensions["from_cache"] = True # type: ignore[index]
response.extensions["cache_metadata"] = metadata # type: ignore[index]
else:
response.extensions["from_cache"] = False # type: ignore[index]
response.extensions["revalidated"] = revalidated # type: ignore[index]
return Response(
status_code=response.status,
headers=response.headers,
stream=CacheStream(fake_stream(response.content)),
extensions=response.extensions,
)
def close(self) -> None:
self._storage.close()
self._transport.close()
def __enter__(self) -> Self:
return self
def __exit__(
self,
exc_type: tp.Optional[tp.Type[BaseException]] = None,
exc_value: tp.Optional[BaseException] = None,
traceback: tp.Optional[types.TracebackType] = None,
) -> None:
self.close()
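# Usage sketch (illustrative, not part of the original file). `httpx.HTTPTransport`
# and the `transport=` argument of `httpx.Client` are standard httpx APIs;
# "https://example.org" is a placeholder URL:
#
#     cache_transport = CacheTransport(transport=httpx.HTTPTransport())
#     with httpx.Client(transport=cache_transport) as client:
#         client.get("https://example.org")             # fetched and stored
#         response = client.get("https://example.org")  # served from the cache
#         assert response.extensions["from_cache"]
#         # Bypass the cache for a single request:
#         client.get("https://example.org", extensions={"cache_disabled": True})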

View File

@@ -0,0 +1,37 @@
import types
import typing as tp
from threading import Lock as T_LOCK
import anyio
class AsyncLock:
def __init__(self) -> None:
self._lock = anyio.Lock()
async def __aenter__(self) -> None:
await self._lock.acquire()
async def __aexit__(
self,
exc_type: tp.Optional[tp.Type[BaseException]] = None,
exc_value: tp.Optional[BaseException] = None,
traceback: tp.Optional[types.TracebackType] = None,
) -> None:
self._lock.release()
class Lock:
def __init__(self) -> None:
self._lock = T_LOCK()
def __enter__(self) -> None:
self._lock.acquire()
def __exit__(
self,
exc_type: tp.Optional[tp.Type[BaseException]] = None,
exc_value: tp.Optional[BaseException] = None,
traceback: tp.Optional[types.TracebackType] = None,
) -> None:
self._lock.release()
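# A minimal sketch of how these wrappers are used (illustrative only):
#
#     lock = Lock()
#     with lock:  # serializes access across threads
#         ...
#
#     async_lock = AsyncLock()
#     async def critical_section() -> None:
#         async with async_lock:  # serializes access across tasks on any anyio backend
#             ...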

View File

@@ -0,0 +1,118 @@
import calendar
import hashlib
import time
import typing as tp
from email.utils import parsedate_tz
import anyio
import httpcore
import httpx
HEADERS_ENCODING = "iso-8859-1"
class BaseClock:
def now(self) -> int:
raise NotImplementedError()
class Clock(BaseClock):
def now(self) -> int:
return int(time.time())
def normalized_url(url: tp.Union[httpcore.URL, str, bytes]) -> str:
if isinstance(url, str): # pragma: no cover
return url
if isinstance(url, bytes): # pragma: no cover
return url.decode("ascii")
if isinstance(url, httpcore.URL):
port = f":{url.port}" if url.port is not None else ""
return f"{url.scheme.decode('ascii')}://{url.host.decode('ascii')}{port}{url.target.decode('ascii')}"
assert False, "Invalid type for `normalized_url`" # pragma: no cover
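# For example (the URL below is illustrative):
#
#     url = httpcore.URL(scheme=b"https", host=b"example.com", port=None, target=b"/path?q=1")
#     normalized_url(url)  # "https://example.com/path?q=1"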
def get_safe_url(url: httpcore.URL) -> str:
httpx_url = httpx.URL(bytes(url).decode("ascii"))
scheme = httpx_url.scheme
host = httpx_url.host
path = httpx_url.path
return f"{scheme}://{host}{path}"
def generate_key(request: httpcore.Request, body: bytes = b"") -> str:
encoded_url = normalized_url(request.url).encode("ascii")
key_parts = [request.method, encoded_url, body]
# FIPS mode disables the blake2 algorithm; fall back to sha256 when blake2 is unavailable.
blake2b_hasher = None
sha256_hasher = hashlib.sha256(usedforsecurity=False)
try:
blake2b_hasher = hashlib.blake2b(digest_size=16, usedforsecurity=False)
except (ValueError, TypeError, AttributeError):
pass
hexdigest: str
if blake2b_hasher:
for part in key_parts:
blake2b_hasher.update(part)
hexdigest = blake2b_hasher.hexdigest()
else:
for part in key_parts:
sha256_hasher.update(part)
hexdigest = sha256_hasher.hexdigest()
return hexdigest
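# For example (an illustrative request; the exact digest depends on the platform):
#
#     request = httpcore.Request("GET", "https://example.com/")
#     generate_key(request)                # 32-char blake2b hex digest (64-char sha256 under FIPS)
#     generate_key(request, body=b"data")  # a different key: the body is hashed as well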
def extract_header_values(
headers: tp.List[tp.Tuple[bytes, bytes]],
header_key: tp.Union[bytes, str],
single: bool = False,
) -> tp.List[bytes]:
if isinstance(header_key, str):
header_key = header_key.encode(HEADERS_ENCODING)
extracted_headers = []
for key, value in headers:
if key.lower() == header_key.lower():
extracted_headers.append(value)
if single:
break
return extracted_headers
def extract_header_values_decoded(
headers: tp.List[tp.Tuple[bytes, bytes]], header_key: bytes, single: bool = False
) -> tp.List[str]:
values = extract_header_values(headers=headers, header_key=header_key, single=single)
return [value.decode(HEADERS_ENCODING) for value in values]
def header_presents(headers: tp.List[tp.Tuple[bytes, bytes]], header_key: bytes) -> bool:
return bool(extract_header_values(headers, header_key, single=True))
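# For example, given a raw header list (values are illustrative):
#
#     headers = [(b"Vary", b"Accept"), (b"vary", b"Cookie")]
#     extract_header_values(headers, b"Vary")               # [b"Accept", b"Cookie"]
#     extract_header_values(headers, b"Vary", single=True)  # [b"Accept"]
#     extract_header_values_decoded(headers, b"Vary")       # ["Accept", "Cookie"]
#     header_presents(headers, b"etag")                     # False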
def parse_date(date: str) -> tp.Optional[int]:
expires = parsedate_tz(date)
if expires is None:
return None
timestamp = calendar.timegm(expires[:6])
return timestamp
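# For example (an RFC 9110 HTTP-date; the result is a Unix timestamp in UTC):
#
#     parse_date("Mon, 25 Aug 2014 12:00:00 GMT")  # 1408968000
#     parse_date("not a date")                     # None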
async def asleep(seconds: tp.Union[int, float]) -> None:
await anyio.sleep(seconds)
def sleep(seconds: tp.Union[int, float]) -> None:
time.sleep(seconds)
def float_seconds_to_int_milliseconds(seconds: float) -> int:
return int(seconds * 1000)