Initial commit

Author: kdusek
Date: 2025-12-09 12:13:01 +01:00
Commit: 8e654ed209
13332 changed files with 2695056 additions and 0 deletions

View File: __init__.py

@@ -0,0 +1,15 @@
from ._version import __version__
from .httpxclientmanager import HttpxThrottleCache

__all__ = ["HttpxThrottleCache", "__version__"]

EDGAR_CACHE_RULES = {
    r".*\.sec\.gov": {
        "/submissions.*": 600,
        r"/include/ticker\.txt.*": 600,
        r"/files/company_tickers\.json.*": 600,
        ".*index/.*": 1800,
        "/Archives/edgar/data": True,  # cache forever
    }
}
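
The rule values follow one convention throughout the package: an integer is a freshness window in seconds, True means cache forever, and False or 0 marks a target explicitly uncacheable. A minimal sketch of custom rules for a hypothetical host:

custom_rules = {
    r".*\.example\.com": {       # host pattern, matched with re.match
        "/static/.*": True,      # cache forever
        "/api/quotes.*": 30,     # fresh for 30 seconds
        "/api/orders.*": False,  # never cache
    }
}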

View File: _version.py

@@ -0,0 +1,6 @@
import importlib.metadata

try:
    __version__ = importlib.metadata.version(__package__ or __name__)
except importlib.metadata.PackageNotFoundError:  # pragma: no cover
    __version__ = "0.0.0"

View File: controller.py

@@ -0,0 +1,108 @@
import logging
import re
from typing import Optional, Union

import hishel
import httpcore

logger = logging.getLogger(__name__)


def get_rules(
    request_host: str, cache_rules: dict[str, dict[str, Union[bool, int]]]
) -> Optional[dict[str, Union[bool, int]]]:
    for site_pattern, rules in cache_rules.items():
        if re.match(site_pattern, request_host):
            logger.info("%s matched %s, using rules %s", request_host, site_pattern, rules)
            return rules
    logger.debug("No patterns matched %s", request_host)
    return None


def match_request(target, cache_rules_for_site: dict[str, Union[bool, int]]):
    for pat, v in cache_rules_for_site.items():
        if re.match(pat, target):
            logger.info("%s matched %s, using value %s", target, pat, v)
            return v
    return None


def get_rule_for_request(
    request_host: str, target: str, cache_rules: dict[str, dict[str, Union[bool, int]]]
) -> Optional[Union[bool, int]]:
    cache_rules_for_site = get_rules(request_host=request_host, cache_rules=cache_rules)
    if cache_rules_for_site:
        is_cacheable = match_request(target=target, cache_rules_for_site=cache_rules_for_site)
        return is_cacheable
    return None


def get_cache_controller(key_generator, cache_rules: dict[str, dict[str, Union[bool, int]]], **kwargs):
    class EdgarController(hishel.Controller):
        def is_cachable(self, request: httpcore.Request, response: httpcore.Response) -> bool:
            if response.status not in self._cacheable_status_codes:
                return False
            cache_period = get_rule_for_request(
                request_host=request.url.host.decode(), target=request.url.target.decode(), cache_rules=cache_rules
            )
            if cache_period:  # True or an int > 0
                return True
            elif cache_period is False or cache_period == 0:  # explicitly not cacheable
                return False
            else:
                # Fall through to the default caching policy
                super_is_cachable = super().is_cachable(request, response)
                logger.debug("%s is cacheable %s", request.url, super_is_cachable)
                return super_is_cachable

        def construct_response_from_cache(
            self, request: httpcore.Request, response: httpcore.Response, original_request: httpcore.Request
        ) -> Union[httpcore.Request, httpcore.Response, None]:
            if (
                response.status not in self._cacheable_status_codes
            ):  # pragma: no cover - would only occur if the cache was loaded then rules changed
                return None
            cache_period = get_rule_for_request(
                request_host=request.url.host.decode(), target=request.url.target.decode(), cache_rules=cache_rules
            )
            if cache_period is True:
                # Cache forever, never recheck
                logger.debug("Cache hit for %s", request.url)
                return response
            elif (
                cache_period is False or cache_period == 0
            ):  # pragma: no cover - would only occur if the cache was loaded then rules changed
                return None
            elif cache_period:  # int
                max_age = cache_period
                age_seconds = hishel._controller.get_age(response, self._clock)
                if age_seconds > max_age:
                    logger.debug(
                        "Request needs to be validated before using %s (age=%d, max_age=%d)",
                        request.url,
                        age_seconds,
                        max_age,
                    )
                    self._make_request_conditional(request=request, response=response)
                    return request
                else:
                    logger.debug("Cache hit for %s (age=%d, max_age=%d)", request.url, age_seconds, max_age)
                    return response
            else:
                logger.debug("No rules applied to %s, using default", request.url)
                return super().construct_response_from_cache(request, response, original_request)

    controller = EdgarController(
        cacheable_methods=["GET", "POST"], cacheable_status_codes=[200], key_generator=key_generator, **kwargs
    )
    return controller
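
Rule resolution is two-stage: the host pattern selects a per-site table, then the first target pattern that matches supplies the value. A small sketch of the semantics, using an inline rules dict and hypothetical SEC-style paths:

rules = {r".*\.sec\.gov": {"/submissions.*": 600, "/Archives/edgar/data": True}}

get_rule_for_request(request_host="www.sec.gov", target="/submissions/CIK0000320193.json", cache_rules=rules)
# -> 600 (fresh for ten minutes, then revalidated)
get_rule_for_request(request_host="www.sec.gov", target="/Archives/edgar/data/320193/filing.txt", cache_rules=rules)
# -> True (cache forever)
get_rule_for_request(request_host="example.com", target="/anything", cache_rules=rules)
# -> None (no match; hishel's default policy applies)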

View File: filecache/transport.py

@@ -0,0 +1,363 @@
"""An alternative cache using:
- Flat files
"""
import calendar
import json
import logging
import os
import time
from pathlib import Path
from typing import Iterator, Optional, Union
from urllib.parse import quote, unquote
import aiofiles
import httpx
from filelock import AsyncFileLock, FileLock
from httpx._types import SyncByteStream # protocol type
from ..controller import get_rule_for_request
logger = logging.getLogger(__name__)
class AlreadyLockedError(Exception):
pass


class DualFileStream(httpx._types.SyncByteStream, httpx.AsyncByteStream):
    def __init__(
        self,
        path: Path,
        chunk_size: int = 1024 * 1024,
        on_close: Optional[Callable] = None,
        async_on_close: Optional[Callable] = None,
    ):
        self.path, self.chunk_size = Path(path), chunk_size
        self.on_close, self.async_on_close = on_close, async_on_close

    def __iter__(self):
        with open(self.path, "rb") as f:
            while True:
                b = f.read(self.chunk_size)
                if not b:
                    break
                yield b

    def close(self) -> None:
        if self.on_close:  # pragma: no cover
            self.on_close()

    async def __aiter__(self):
        async with aiofiles.open(self.path, "rb") as f:
            while True:
                b = await f.read(self.chunk_size)
                if not b:
                    break
                yield b

    async def aclose(self) -> None:
        if self.async_on_close:  # pragma: no cover
            await self.async_on_close()
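
# DualFileStream lets a single cached file back both the sync and the async httpx
# code paths, e.g. httpx.Response(200, stream=DualFileStream(path)) as in
# _cache_hit_response below.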


class FileCache:
    def __init__(self, cache_dir: Union[str, Path], locking: bool = True):
        self.cache_dir = Path(cache_dir)
        logger.info("cache_dir=%s", self.cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.locking = locking

    def _meta_path(self, p: Path) -> Path:
        return p.with_suffix(p.suffix + ".meta")

    def _load_meta(self, p: Path) -> dict:
        try:
            return json.loads(self._meta_path(p).read_text())
        except FileNotFoundError:  # pragma: no cover
            return {}
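
    # Sketch of the URL-to-file mapping performed below, for a hypothetical request:
    #   to_path("www.sec.gov", "/files/company_tickers.json", "a=1")
    #   -> cache_dir / "www.sec.gov" / "files-company_tickers.json-a-1"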
    def to_path(self, host: str, path: str, query: str) -> Path:
        site = host.lower().rstrip(".")
        (self.cache_dir / site).mkdir(parents=True, exist_ok=True)
        name = unquote(path).strip("/").replace("/", "-") or "index"
        if query:
            name += "-" + unquote(query).replace("&", "-").replace("=", "-")
        return self.cache_dir / site / quote(name, safe="._-~")

    def get_if_fresh(
        self, host: str, path: str, query: str, cache_rules: dict[str, dict[str, Union[bool, int]]]
    ) -> tuple[bool, Optional[Path]]:
        cached = get_rule_for_request(request_host=host, target=path, cache_rules=cache_rules)
        if not cached:
            logger.info("No cache policy for %s%s, not retrieving from cache", host, path)
            return False, None
        p = self.to_path(host=host, path=path, query=query)
        if not p.exists():
            logger.info("Cache file doesn't exist for %s: %s", path, p)
            return False, None
        meta = self._load_meta(p)
        fetched = meta.get("fetched")
        if not fetched:
            return False, p  # pragma: no cover
        if cached is True:
            logger.info("Cache policy allows unlimited cache, returning %s", p)
            return True, p
        age = round(time.time() - fetched)
        if age < 0:  # pragma: no cover
            raise ValueError(f"Age is less than 0, impossible {age=}, file {path=}")
        logger.info("file is %s seconds old, policy allows caching for up to %s", age, cached)
        return (age <= cached, p)


class _TeeCore:
    def __init__(self, resp: httpx.Response, path: Path, locking: bool, last_modified: str, access_date: str):
        assert path is not None
        self.resp = resp
        self.path = path
        self.tmp = path.with_name(path.name + ".tmp")
        self.lock = FileLock(str(path) + ".lock") if locking else None
        self.fh = None
        if last_modified:
            self.mtime = calendar.timegm(time.strptime(last_modified, "%a, %d %b %Y %H:%M:%S GMT"))
        else:
            self.mtime = None
        if access_date:
            self.atime = calendar.timegm(time.strptime(access_date, "%a, %d %b %Y %H:%M:%S GMT"))
        else:
            self.atime = None  # pragma: no cover

    def acquire(self):
        self.lock and self.lock.acquire()  # pyright: ignore[reportUnusedExpression]

    def open_tmp(self):
        self.fh = open(self.tmp, "wb")

    def write(self, chunk: bytes):
        self.fh.write(chunk)  # pyright: ignore[reportOptionalMemberAccess]

    def finalize(self):
        try:
            if self.fh and not self.fh.closed:
                self.fh.flush()
                os.fsync(self.fh.fileno())
                self.fh.close()
                os.replace(self.tmp, self.path)
                try:
                    meta_path = self.path.with_suffix(self.path.suffix + ".meta")
                    headers = {
                        "content-type": self.resp.headers.get("content-type"),
                        "content-encoding": self.resp.headers.get("content-encoding"),
                    }
                    meta_path.write_text(
                        json.dumps({"fetched": self.atime, "origin_lm": self.mtime, "headers": headers})
                    )
                except FileNotFoundError:  # pragma: no cover
                    pass
        finally:
            if self.lock and getattr(self.lock, "is_locked", False):
                self.lock.release()


class _TeeToDisk(SyncByteStream):
    def __init__(self, resp: httpx.Response, path: Path, locking: bool, last_modified: str, access_date: str) -> None:
        self.core = _TeeCore(resp, path, locking, last_modified, access_date)

    def __iter__(self) -> Iterator[bytes]:
        self.core.acquire()
        try:
            self.core.open_tmp()
            for chunk in self.core.resp.iter_raw():
                self.core.write(chunk)
                yield chunk
        finally:
            self.core.finalize()

    def close(self) -> None:
        try:
            self.core.resp.close()
        finally:
            self.core.finalize()


class _AsyncTeeToDisk(httpx.AsyncByteStream):
    def __init__(self, resp, path, locking, last_modified, access_date):
        self.resp = resp
        self.path = path
        self.tmp = path.with_name(path.name + ".tmp")
        self.lock = AsyncFileLock(str(path) + ".lock") if locking else None
        if last_modified:
            self.mtime = calendar.timegm(time.strptime(last_modified, "%a, %d %b %Y %H:%M:%S GMT"))
        else:
            self.mtime = None
        if access_date:
            self.atime = calendar.timegm(time.strptime(access_date, "%a, %d %b %Y %H:%M:%S GMT"))
        else:
            self.atime = None  # pragma: no cover

    async def __aiter__(self):
        if self.lock:
            await self.lock.acquire()
        try:
            async with aiofiles.open(self.tmp, "wb") as f:
                async for chunk in self.resp.aiter_raw():
                    await f.write(chunk)
                    yield chunk
            os.replace(self.tmp, self.path)
            async with aiofiles.open(self.path.with_suffix(self.path.suffix + ".meta"), "w") as m:
                headers = {
                    "content-type": self.resp.headers.get("content-type"),
                    "content-encoding": self.resp.headers.get("content-encoding"),
                }
                await m.write(json.dumps({"fetched": self.atime, "origin_lm": self.mtime, "headers": headers}))
        finally:
            if self.lock:
                await self.lock.release()

    async def aclose(self):
        try:
            await self.resp.aclose()
        finally:
            if self.lock:
                await self.lock.release()


class CachingTransport(httpx.BaseTransport, httpx.AsyncBaseTransport):
    cache_rules: dict[str, dict[str, Union[bool, int]]]
    streaming_cutoff: int = 8 * 1024 * 1024

    def __init__(
        self,
        cache_dir: Union[str, Path],
        cache_rules: dict[str, dict[str, Union[bool, int]]],
        transport: Optional[httpx.BaseTransport] = None,
    ):
        self._cache = FileCache(cache_dir=cache_dir, locking=True)
        self.transport = transport or httpx.HTTPTransport()
        self.cache_rules = cache_rules

    def _cache_hit_response(self, req, path: Path, status_code: int = 200):
        """
        TODO: More carefully consider async here: read_text and read_bytes are both blocking.
        Large files are streamed async, so the only blocking events here are reads of small(ish) files.
        """
        meta = json.loads(path.with_suffix(path.suffix + ".meta").read_text())
        date = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime(meta["fetched"]))
        last_modified = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime(meta["origin_lm"]))
        ct = meta.get("headers", {}).get("content-type", "application/octet-stream")
        ce = meta.get("headers", {}).get("content-encoding")
        size = os.path.getsize(path)
        headers = [
            ("x-cache", "HIT"),
            ("content-length", str(size)),
            ("Date", date),
            ("Last-Modified", last_modified),
        ]
        if ce:
            headers.append(("content-encoding", ce))
        if ct:
            headers.append(("content-type", ct))
        if size < self.streaming_cutoff:
            # If the file is small, just read it and return it
            return httpx.Response(
                status_code=status_code,
                headers=headers,
                content=path.read_bytes(),
                request=req,
            )
        else:
            # If the file is large, stream it
            return httpx.Response(
                status_code=status_code,
                headers=headers,
                stream=DualFileStream(path),
                request=req,
            )

    def _cache_miss_response(self, req, net, path, tee_factory):
        if net.status_code != 200:
            return net
        miss_headers = [
            (k, v)
            for k, v in net.headers.items()
            if k.lower() not in ("transfer-encoding",)  # only framing is stripped; content-* headers pass through
        ]
        miss_headers.append(("x-cache", "MISS"))
        return httpx.Response(
            status_code=net.status_code,
            headers=miss_headers,
            stream=tee_factory(
                net, path, self._cache.locking, net.headers.get("Last-Modified"), net.headers.get("Date")
            ),
            request=req,
            extensions={**net.extensions, "decode_content": False},
        )

    def return_if_fresh(self, request):
        host = request.url.host
        path = request.url.path
        query = request.url.query.decode() if request.url.query else ""
        fresh, path = self._cache.get_if_fresh(host, path, query, self.cache_rules)
        if path:
            if fresh:
                return self._cache_hit_response(request, path), path
            else:
                lm = json.loads(path.with_suffix(path.suffix + ".meta").read_text()).get("origin_lm")
                if lm:
                    request.headers["If-Modified-Since"] = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime(lm))
                return None, path
        else:
            return None, None

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        if request.method != "GET":
            return self.transport.handle_request(request)
        response, path = self.return_if_fresh(request)
        if response:
            return response
        net = self.transport.handle_request(request)
        if net.status_code == 304:
            logger.info("304 for %s", request)
            assert path is not None  # must be true
            return self._cache_hit_response(request, path, status_code=304)
        host = request.url.host
        path = request.url.path
        query = request.url.query.decode() if request.url.query else ""
        path = self._cache.to_path(host, path, query)
        return self._cache_miss_response(request, net, path, _TeeToDisk)

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        if request.method != "GET":
            return await self.transport.handle_async_request(request)  # type: ignore[attr-defined]
        response, path = self.return_if_fresh(request)
        if response:
            return response
        net = await self.transport.handle_async_request(request)
        if net.status_code == 304:
            assert path is not None  # must be true
            logger.info("304 for %s", request)
            return self._cache_hit_response(request, path, status_code=304)
        path = self._cache.to_path(request.url.host, request.url.path, request.url.query.decode())
        return self._cache_miss_response(request, net, path, _AsyncTeeToDisk)
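
CachingTransport also composes directly with httpx, without the HttpxThrottleCache wrapper. A minimal sketch, assuming a local ./cache directory and a hypothetical rules dict:

transport = CachingTransport(
    cache_dir="./cache",
    cache_rules={r".*\.sec\.gov": {r"/files/company_tickers\.json.*": 600}},
)
client = httpx.Client(transport=transport)
r = client.get("https://www.sec.gov/files/company_tickers.json")
r.headers.get("x-cache")  # "MISS" on first fetch, "HIT" while within the 600-second window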

View File: httpxclientmanager.py

@@ -0,0 +1,312 @@
import asyncio
import logging
import os
import threading
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, AsyncGenerator, Callable, Generator, Literal, Mapping, Optional, Sequence, Union

import hishel
import httpx
from httpx._types import ProxyTypes
from pyrate_limiter import Duration, Limiter

from .controller import get_cache_controller
from .filecache.transport import CachingTransport
from .key_generator import file_key_generator
from .ratelimiter import AsyncRateLimitingTransport, RateLimitingTransport, create_rate_limiter
from .serializer import JSONByteSerializer

logger = logging.getLogger(__name__)

try:
    # enable http2 if h2 is installed
    import h2  # type: ignore # noqa

    logger.debug("HTTP2 available")
    HTTP2 = True  # pragma: no cover
except ImportError:
    logger.debug("HTTP2 not available")
    HTTP2 = False


@dataclass
class HttpxThrottleCache:
    """
    A rate-limited, optionally cached HTTPX wrapper that provides http_client() (an httpx.Client)
    and async_http_client() (an httpx.AsyncClient).

    Rate limiting is shared across all connections, whether obtained via http_client or
    async_http_client, using pyrate_limiter. For multiprocessing, use pyrate_limiter's
    MultiprocessBucket or SqliteBucket with a file lock.

    Caching is implemented via Hishel, which allows a variety of configurations, including AWS S3 storage.
    """

    httpx_params: dict[str, Any] = field(
        default_factory=lambda: {"default_encoding": "utf-8", "http2": HTTP2, "verify": True}
    )
    cache_rules: dict[str, dict[str, Union[bool, int]]] = field(default_factory=lambda: {})
    rate_limiter_enabled: bool = True
    cache_mode: Literal[False, "Disabled", "Hishel-S3", "Hishel-File", "FileCache"] = "Hishel-File"
    request_per_sec_limit: int = 10
    max_delay: Duration = field(default_factory=lambda: Duration.DAY)
    _client: Optional[httpx.Client] = None
    rate_limiter: Optional[Limiter] = None
    s3_bucket: Optional[str] = None
    s3_client: Optional[Any] = None
    user_agent: Optional[str] = None
    user_agent_factory: Optional[Callable] = None
    cache_dir: Optional[Union[Path, str]] = None
    lock = threading.Lock()
    proxy: Optional[ProxyTypes] = None

    def __post_init__(self):
        self.cache_dir = Path(self.cache_dir) if isinstance(self.cache_dir, str) else self.cache_dir
        if self.rate_limiter_enabled and self.rate_limiter is None:
            self.rate_limiter = create_rate_limiter(
                requests_per_second=self.request_per_sec_limit, max_delay=self.max_delay
            )
        if self.cache_mode not in ("Disabled", False) and not self.cache_rules:
            logger.info("Cache is enabled, but no cache_rules provided. Will use default caching.")
        if self.cache_mode == "Disabled" or self.cache_mode is False:
            pass
        elif self.cache_mode == "Hishel-S3":
            if self.s3_bucket is None:
                raise ValueError("s3_bucket must be provided if using Hishel-S3 storage")
        else:  # Hishel-File or FileCache
            if self.cache_dir is None:
                raise ValueError(f"cache_dir must be provided if using a file based cache: {self.cache_mode}")
            else:
                if not self.cache_dir.exists():
                    self.cache_dir.mkdir()
        logger.debug(
            "Initialized cache with cache_mode=%s, cache_dir=%s, rate_limiter_enabled=%s",
            self.cache_mode,
            self.cache_dir,
            self.rate_limiter_enabled,
        )
        if os.environ.get("HTTPS_PROXY") is not None:
            self.proxy = os.environ.get("HTTPS_PROXY")

    def _populate_user_agent(self, params: dict):
        if self.user_agent_factory is not None:
            user_agent = self.user_agent_factory()
        else:
            user_agent = self.user_agent
        if user_agent is not None:
            if "headers" not in params:
                params["headers"] = {}
            params["headers"]["User-Agent"] = user_agent
        return params

    def populate_user_agent(self, params: dict):
        """Provided so clients can inspect the params that would be passed to HTTPX."""
        return self._populate_user_agent(params)

    def get_batch(self, *, urls: Sequence[str] | Mapping[str, Path], _client_mocker=None):
        """
        Fetch a batch of URLs concurrently, either returning their content in memory
        or streaming each response directly to a file.

        Runs an asyncio event loop on a background thread, so it is safe to call from
        synchronous code.

        Args:
            urls (Sequence[str] | Mapping[str, Path]): Either a sequence of URLs to fetch,
                or a mapping of URL -> destination Path to stream each response to.

        Returns:
            list: If given a mapping, a list of Paths; otherwise a list of response contents.

        Raises:
            RuntimeError: If any URL responds with a status code other than 200 or 304.
        """
        import aiofiles

        async def _run():
            async with self.async_http_client() as client:
                if _client_mocker:
                    # For testing
                    _client_mocker(client)

                async def task(url: str, path: Optional[Path]):
                    async with client.stream("GET", url) as r:
                        if r.status_code in (200, 304):
                            if path:
                                path.parent.mkdir(parents=True, exist_ok=True)
                                async with aiofiles.open(path, "wb") as f:
                                    async for chunk in r.aiter_bytes():
                                        await f.write(chunk)
                                return path
                            else:
                                return await r.aread()
                        else:
                            raise RuntimeError(f"URL status code is not 200 or 304: {url=}")

                if isinstance(urls, Mapping):
                    return await asyncio.gather(*(task(u, p) for u, p in urls.items()))
                else:
                    return await asyncio.gather(*(task(u, None) for u in urls))

        with ThreadPoolExecutor(1) as pool:
            return pool.submit(lambda: asyncio.run(_run())).result()
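
    # A sketch of both calling styles (destination paths are hypothetical):
    #   contents = hc.get_batch(urls=["https://www.sec.gov/files/company_tickers.json"])
    #   paths = hc.get_batch(urls={"https://www.sec.gov/files/company_tickers.json": Path("out/tickers.json")})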

    def _get_httpx_transport_params(self, params: dict[str, Any]):
        http2 = params.get("http2", False)
        proxy = self.proxy
        return {"http2": http2, "proxy": proxy}

    @contextmanager
    def http_client(self, bypass_cache: bool = False, **kwargs) -> Generator[httpx.Client, None, None]:
        """Provides and reuses a single client; the client is not closed on exit.
        Note that bypass_cache and kwargs only take effect when the client is first created."""
        if self._client is None:
            with self.lock:
                # Locking: not super critical, since the worst case is an extra httpx client created,
                # but future-proofs against TOCTOU races in a free-threading world
                if self._client is None:
                    logger.debug("Creating new HTTPX Client")
                    params = self.httpx_params.copy()
                    self._populate_user_agent(params)
                    params.update(**kwargs)
                    params["transport"] = self._get_transport(
                        bypass_cache=bypass_cache, httpx_transport_params=self._get_httpx_transport_params(params)
                    )
                    self._client = httpx.Client(**params)
        yield self._client

    def close(self):
        if self._client is not None:
            self._client.close()
            self._client = None

    def update_rate_limiter(self, requests_per_second: int, max_delay: Duration = Duration.DAY):
        self.rate_limiter = create_rate_limiter(requests_per_second=requests_per_second, max_delay=max_delay)
        self.close()

    def _client_factory_async(self, bypass_cache: bool, **kwargs) -> httpx.AsyncClient:
        params = self.httpx_params.copy()
        params.update(**kwargs)
        self._populate_user_agent(params)
        params["transport"] = self._get_async_transport(
            bypass_cache=bypass_cache, httpx_transport_params=self._get_httpx_transport_params(params)
        )
        return httpx.AsyncClient(**params)

    @asynccontextmanager
    async def async_http_client(
        self, client: Optional[httpx.AsyncClient] = None, bypass_cache: bool = False, **kwargs
    ) -> AsyncGenerator[httpx.AsyncClient, None]:
        """
        Async callers should create a single client for a group of tasks rather than one client per task.
        If an existing client is passed in, this is a no-op wrapper and that client is not closed; this
        occurs when a higher-level async task creates the client to be used by child calls.
        """
        if client is not None:
            yield client  # type: ignore # Caller is responsible for closing
            return
        async with self._client_factory_async(bypass_cache=bypass_cache, **kwargs) as client:
            yield client

    def _get_transport(self, bypass_cache: bool, httpx_transport_params: dict[str, Any]) -> httpx.BaseTransport:
        """
        Constructs the transport chain:
        Caching Transport (if enabled) => Rate Limiting Transport (if enabled) => httpx.HTTPTransport
        """
        if self.rate_limiter_enabled:
            assert self.rate_limiter is not None
            next_transport = RateLimitingTransport(self.rate_limiter, **httpx_transport_params)
        else:
            next_transport = httpx.HTTPTransport(**httpx_transport_params)
        if bypass_cache or self.cache_mode == "Disabled" or self.cache_mode is False:
            logger.info("Cache is DISABLED, rate limiting only")
            return next_transport
        elif self.cache_mode == "FileCache":
            assert self.cache_dir is not None
            return CachingTransport(cache_dir=self.cache_dir, transport=next_transport, cache_rules=self.cache_rules)
        else:
            # either Hishel-S3 or Hishel-File
            assert self.cache_mode == "Hishel-File" or self.cache_mode == "Hishel-S3"
            controller = get_cache_controller(key_generator=file_key_generator, cache_rules=self.cache_rules)
            if self.cache_mode == "Hishel-S3":
                assert self.s3_bucket is not None
                storage = hishel.S3Storage(
                    client=self.s3_client, bucket_name=self.s3_bucket, serializer=JSONByteSerializer()
                )
            else:
                assert self.cache_mode == "Hishel-File"
                assert self.cache_dir is not None
                storage = hishel.FileStorage(base_path=Path(self.cache_dir), serializer=JSONByteSerializer())
            return hishel.CacheTransport(transport=next_transport, storage=storage, controller=controller)

    def _get_async_transport(
        self, bypass_cache: bool, httpx_transport_params: dict[str, Any]
    ) -> httpx.AsyncBaseTransport:
        """
        Constructs the transport chain:
        Caching Transport (if enabled) => Rate Limiting Transport (if enabled) => httpx.AsyncHTTPTransport
        """
        if self.rate_limiter_enabled:
            assert self.rate_limiter is not None
            next_transport = AsyncRateLimitingTransport(self.rate_limiter, **httpx_transport_params)
        else:
            next_transport = httpx.AsyncHTTPTransport(**httpx_transport_params)
        if bypass_cache or self.cache_mode == "Disabled" or self.cache_mode is False:
            logger.info("Cache is DISABLED, rate limiting only")
            return next_transport
        elif self.cache_mode == "FileCache":
            assert self.cache_dir is not None
            return CachingTransport(cache_dir=self.cache_dir, transport=next_transport, cache_rules=self.cache_rules)  # pyright: ignore[reportArgumentType]
        else:
            # either Hishel-S3 or Hishel-File
            assert self.cache_mode == "Hishel-File" or self.cache_mode == "Hishel-S3"
            controller = get_cache_controller(key_generator=file_key_generator, cache_rules=self.cache_rules)
            if self.cache_mode == "Hishel-S3":
                assert self.s3_bucket is not None
                storage = hishel.AsyncS3Storage(
                    client=self.s3_client, bucket_name=self.s3_bucket, serializer=JSONByteSerializer()
                )
            else:
                assert self.cache_mode == "Hishel-File"
                assert self.cache_dir is not None
                storage = hishel.AsyncFileStorage(base_path=Path(self.cache_dir), serializer=JSONByteSerializer())
            return hishel.AsyncCacheTransport(transport=next_transport, storage=storage, controller=controller)

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()
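
Putting it together, a minimal usage sketch, assuming the EDGAR_CACHE_RULES defined in __init__.py above and a hypothetical contact address for the User-Agent:

with HttpxThrottleCache(
    cache_dir="./cache",
    cache_rules=EDGAR_CACHE_RULES,
    user_agent="Example Co. admin@example.com",  # hypothetical
    request_per_sec_limit=5,
) as hc:
    with hc.http_client() as client:
        r = client.get("https://www.sec.gov/files/company_tickers.json")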

View File: key_generator.py

@@ -0,0 +1,22 @@
from typing import Optional

import httpcore


def file_key_generator(request: httpcore.Request, body: Optional[bytes]) -> str:
    """Generates a stable, readable key for a given request.

    Args:
        request (httpcore.Request): The request to derive a key from
        body (Optional[bytes]): The request body (unused)

    Returns:
        str: Persistent key for the request
    """
    host = request.url.host.decode()
    path_b, _, query_b = request.url.target.partition(b"?")
    path = path_b.decode()
    query = query_b.decode()
    url_p = path.replace("/", "__") + (f"__{query.replace('&', '__').replace('=', '__')}" if query else "")
    key = f"{host}_{url_p}"
    return key
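
A quick sketch of the keys this produces, for a hypothetical EDGAR URL:

import httpcore

req = httpcore.Request("GET", "https://www.sec.gov/files/company_tickers.json?a=1")
file_key_generator(req, body=None)
# -> "www.sec.gov___files__company_tickers.json__a__1"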

View File: ratelimiter.py

@@ -0,0 +1,57 @@
"""
To control rate limit across multiple processes, see https://pyratelimiter.readthedocs.io/en/latest/#backends
"""
import logging
import httpx
from pyrate_limiter import Duration, InMemoryBucket, Limiter, Rate
logger = logging.getLogger(__name__)
def create_rate_limiter(requests_per_second: int, max_delay=Duration.DAY) -> Limiter:
rate = Rate(requests_per_second, Duration.SECOND)
rate_limits = [rate]
base_bucket = InMemoryBucket(rate_limits)
bucket = base_bucket
limiter = Limiter(bucket, max_delay=max_delay, raise_when_fail=False, retry_until_max_delay=True)
return limiter
class RateLimitingTransport(httpx.HTTPTransport):
def __init__(self, limiter: Limiter, **kwargs):
super().__init__(**kwargs)
self.limiter = limiter
def handle_request(self, request: httpx.Request, **kwargs) -> httpx.Response:
# using a constant string for item name means that the same
# rate is applied to all requests.
if self.limiter:
while not self.limiter.try_acquire(__name__):
logger.debug("Lock acquisition timed out, retrying") # pragma: no cover
logger.debug("Acquired lock")
logger.info("Making HTTP Request %s", request)
return super().handle_request(request, **kwargs)
class AsyncRateLimitingTransport(httpx.AsyncHTTPTransport):
def __init__(self, limiter: Limiter, **kwargs):
super().__init__(**kwargs)
self.limiter = limiter
async def handle_async_request(self, request: httpx.Request, **kwargs) -> httpx.Response:
if self.limiter:
while not await self.limiter.try_acquire_async(__name__):
logger.debug("Lock acquisition timed out, retrying") # pragma: no cover
logger.debug("Acquired lock")
logger.info("Making HTTP Request %s", request)
return await super().handle_async_request(request, **kwargs)
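
The transports also compose directly with httpx on their own; a minimal sketch:

limiter = create_rate_limiter(requests_per_second=5)
client = httpx.Client(transport=RateLimitingTransport(limiter))
# Every request through this client now waits until the shared limiter grants a slot.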

View File: serializer.py

@@ -0,0 +1,125 @@
import json
import logging
from datetime import datetime
from typing import Tuple, Union

from hishel._serializers import (
    HEADERS_ENCODING,
    KNOWN_REQUEST_EXTENSIONS,
    KNOWN_RESPONSE_EXTENSIONS,
    BaseSerializer,
    Metadata,
    normalized_url,
)
from httpcore import Request, Response

logger = logging.getLogger(__name__)


class JSONByteSerializer(BaseSerializer):
    """JSONByteSerializer stores HTTP metadata as compact UTF-8 JSON followed by the raw binary body,
    separated by a single null byte. This avoids base64 encoding, significantly reducing size and
    improving performance for large responses."""

    def dumps(self, response: Response, request: Request, metadata: Metadata) -> Union[str, bytes]:
        """
        Dumps the HTTP response and its HTTP request.

        :param response: An HTTP response
        :type response: Response
        :param request: An HTTP request
        :type request: Request
        :param metadata: Additional information about the stored response
        :type metadata: Metadata
        :return: Serialized response
        :rtype: Union[str, bytes]
        """
        response_dict = {
            "status": response.status,
            "headers": [
                (key.decode(HEADERS_ENCODING), value.decode(HEADERS_ENCODING)) for key, value in response.headers
            ],
            "extensions": {
                key: value.decode("ascii")
                for key, value in response.extensions.items()
                if key in KNOWN_RESPONSE_EXTENSIONS
            },
        }
        request_dict = {
            "method": request.method.decode("ascii"),
            "url": normalized_url(request.url),
            "headers": [
                (key.decode(HEADERS_ENCODING), value.decode(HEADERS_ENCODING)) for key, value in request.headers
            ],
            "extensions": {key: value for key, value in request.extensions.items() if key in KNOWN_REQUEST_EXTENSIONS},
        }
        metadata_dict = {
            "cache_key": metadata["cache_key"],
            "number_of_uses": metadata["number_of_uses"],
            "created_at": metadata["created_at"].strftime("%a, %d %b %Y %H:%M:%S GMT"),
        }
        full_json = {
            "response": response_dict,
            "request": request_dict,
            "metadata": metadata_dict,
        }
        return json.dumps(full_json, separators=(",", ":")).encode("utf-8") + b"\0" + response.content

    def loads(self, data: Union[str, bytes]) -> Tuple[Response, Request, Metadata]:
        """
        Loads the HTTP response and its HTTP request from serialized data.

        :param data: Serialized data
        :type data: Union[str, bytes]
        :return: HTTP response and its HTTP request
        :rtype: Tuple[Response, Request, Metadata]
        """
        data_b: bytes = data.encode("utf-8") if isinstance(data, str) else data
        full_json, body = data_b.split(b"\0", 1)
        full_json = json.loads(full_json.decode("utf-8"))
        response_dict = full_json["response"]
        request_dict = full_json["request"]
        metadata_dict = full_json["metadata"]
        metadata_dict["created_at"] = datetime.strptime(
            metadata_dict["created_at"],
            "%a, %d %b %Y %H:%M:%S GMT",
        )
        response = Response(
            status=response_dict["status"],
            headers=[
                (key.encode(HEADERS_ENCODING), value.encode(HEADERS_ENCODING))
                for key, value in response_dict["headers"]
            ],
            content=body,
            extensions={
                key: value.encode("ascii")
                for key, value in response_dict["extensions"].items()
                if key in KNOWN_RESPONSE_EXTENSIONS
            },
        )
        request = Request(
            method=request_dict["method"],
            url=request_dict["url"],
            headers=[
                (key.encode(HEADERS_ENCODING), value.encode(HEADERS_ENCODING))
                for key, value in request_dict["headers"]
            ],
            extensions={
                key: value for key, value in request_dict["extensions"].items() if key in KNOWN_REQUEST_EXTENSIONS
            },
        )
        metadata = Metadata(
            cache_key=metadata_dict["cache_key"],
            created_at=metadata_dict["created_at"],
            number_of_uses=metadata_dict["number_of_uses"],
        )
        return response, request, metadata

    @property
    def is_binary(self) -> bool:  # pragma: no cover
        return True
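
A minimal round-trip sketch of the on-disk format, with the module's imports in scope and hypothetical header and body values:

from datetime import datetime, timezone

s = JSONByteSerializer()
resp = Response(200, headers=[(b"content-type", b"application/json")], content=b'{"ok":true}')
resp.read()  # populate resp.content for dumps()
req = Request("GET", "https://www.sec.gov/files/company_tickers.json")
meta = Metadata(cache_key="example", number_of_uses=0, created_at=datetime.now(timezone.utc))

blob = s.dumps(resp, req, meta)  # compact JSON header + b"\0" + raw body bytes
resp2, req2, meta2 = s.loads(blob)
assert resp2.read() == b'{"ok":true}'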