Initial commit
This commit is contained in:
@@ -0,0 +1,15 @@
|
||||
from ._version import __version__
|
||||
from .httpxclientmanager import HttpxThrottleCache
|
||||
|
||||
__all__ = ["HttpxThrottleCache", "__version__"]
|
||||
|
||||
|
||||
# Cache policy for SEC EDGAR hosts: {host regex -> {target regex -> TTL}}.
# TTL semantics (applied by controller.get_rule_for_request):
#   int  = number of seconds a cached response stays fresh
#   True = cache forever, never revalidate
# (False / 0 would mean "explicitly not cacheable".)
EDGAR_CACHE_RULES = {
    r".*\.sec\.gov": {
        "/submissions.*": 600,
        r"/include/ticker\.txt.*": 600,
        r"/files/company_tickers\.json.*": 600,
        ".*index/.*": 1800,
        "/Archives/edgar/data": True,  # cache forever
    }
}
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,6 @@
|
||||
import importlib.metadata
|
||||
|
||||
# Prefer the version recorded in the installed distribution's metadata.
try:
    __version__ = importlib.metadata.version(__package__ or __name__)
except importlib.metadata.PackageNotFoundError:  # pragma: no cover
    # Not installed (e.g. running from a source checkout) — use a placeholder.
    __version__ = "0.0.0"
|
||||
@@ -0,0 +1,108 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional, Union
|
||||
|
||||
import hishel
|
||||
import httpcore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_rules(
    request_host: str, cache_rules: dict[str, dict[str, Union[bool, int]]]
) -> Optional[dict[str, Union[bool, int]]]:
    """Return the per-target cache rules for the first site pattern matching *request_host*.

    Args:
        request_host: Hostname of the outgoing request (e.g. "www.sec.gov").
        cache_rules: Mapping of host regex -> {target regex -> TTL seconds | bool}.

    Returns:
        The matching site's rule mapping, or None when no site pattern matches.
    """
    for site_pattern, rules in cache_rules.items():
        if re.match(site_pattern, request_host):
            # BUGFIX: the previous message read "matched {pattern}, using value
            # {host}: {rules}" — the host was interpolated where a value was
            # announced. Reordered to state host, pattern, then rules.
            logger.info("%s matched %s, using rules %s", request_host, site_pattern, rules)
            return rules

    logger.debug("No patterns matched %s", request_host)
    # Explicit None for the no-match case (was an implicit fall-through).
    return None
|
||||
|
||||
|
||||
def match_request(target, cache_rules_for_site: dict[str, Union[bool, int]]):
    """Return the cache value of the first target pattern matching *target*, else None."""
    matched = next(
        (
            (pattern, value)
            for pattern, value in cache_rules_for_site.items()
            if re.match(pattern, target)
        ),
        None,
    )
    if matched is None:
        return None

    pattern, value = matched
    logger.info("%s matched %s, using value %s", target, pattern, value)
    return value
|
||||
|
||||
|
||||
def get_rule_for_request(
    request_host: str, target: str, cache_rules: dict[str, dict[str, Union[bool, int]]]
) -> Optional[Union[bool, int]]:
    """Resolve the cache rule (TTL seconds or bool) for a host/target pair.

    Returns None when no site pattern matches the host, when the matching
    site has an empty rule set, or when no target pattern matches.
    """
    site_rules = get_rules(request_host=request_host, cache_rules=cache_rules)
    if not site_rules:
        return None

    return match_request(target=target, cache_rules_for_site=site_rules)
|
||||
|
||||
|
||||
def get_cache_controller(key_generator, cache_rules: dict[str, dict[str, Union[bool, int]]], **kwargs):
    """Build a hishel.Controller that enforces per-site/per-target cache rules.

    Args:
        key_generator: Callable used by hishel to derive storage keys from requests.
        cache_rules: Mapping of host regex -> {target regex -> TTL seconds | bool}.
        **kwargs: Forwarded to the hishel.Controller constructor.

    Returns:
        An EdgarController instance configured to cache GET/POST with status 200.
    """

    class EdgarController(hishel.Controller):
        def is_cachable(self, request: httpcore.Request, response: httpcore.Response) -> bool:
            """Decide whether a fresh network response should be stored at all."""
            # Only statuses in the allow-list (200, per the constructor call below).
            if response.status not in self._cacheable_status_codes:
                return False

            cache_period = get_rule_for_request(
                request_host=request.url.host.decode(), target=request.url.target.decode(), cache_rules=cache_rules
            )

            if cache_period:  # True or an Int>0
                return True
            elif cache_period is False or cache_period == 0:  # Explicitly not cacheable
                return False
            else:
                # Fall through default caching policy
                super_is_cachable = super().is_cachable(request, response)
                logger.debug("%s is cacheable %s", request.url, super_is_cachable)
                return super_is_cachable

        def construct_response_from_cache(
            self, request: httpcore.Request, response: httpcore.Response, original_request: httpcore.Request
        ) -> Union[httpcore.Request, httpcore.Response, None]:
            """Serve a stored response, return a conditional request for revalidation,
            or return None to discard the stored entry."""
            if (
                response.status not in self._cacheable_status_codes
            ):  # pragma: no cover - would only occur if the cache was loaded then rules changed
                return None

            cache_period = get_rule_for_request(
                request_host=request.url.host.decode(), target=request.url.target.decode(), cache_rules=cache_rules
            )

            if cache_period is True:
                # Cache forever, never recheck
                logger.debug("Cache hit for %s", request.url)
                return response
            elif (
                cache_period is False or cache_period == 0
            ):  # pragma: no cover - would only occur if the cache was loaded then rules changed
                return None
            elif cache_period:  # int
                max_age = cache_period

                # NOTE(review): relies on hishel's private `_controller.get_age`
                # helper — verify this still exists on hishel upgrades.
                age_seconds = hishel._controller.get_age(response, self._clock)

                if age_seconds > max_age:
                    # Too old: turn the request into a conditional one so the
                    # origin can answer 304 Not Modified.
                    logger.debug(
                        "Request needs to be validated before using %s (age=%d, max_age=%d)",
                        request.url,
                        age_seconds,
                        max_age,
                    )
                    self._make_request_conditional(request=request, response=response)
                    return request
                else:
                    logger.debug("Cache hit for %s (age=%d, max_age=%d)", request.url, age_seconds, max_age)
                    return response
            else:
                # No rule matched (cache_period is None): defer to hishel's
                # standard RFC 9111 behavior.
                logger.debug("No rules applied to %s, using default", request.url)
                return super().construct_response_from_cache(request, response, original_request)

    controller = EdgarController(
        cacheable_methods=["GET", "POST"], cacheable_status_codes=[200], key_generator=key_generator, **kwargs
    )

    return controller
|
||||
Binary file not shown.
@@ -0,0 +1,363 @@
|
||||
"""An alternative cache backend (to Hishel) that stores responses as flat
files on disk, one file per URL plus a sidecar ``.meta`` JSON file."""
|
||||
|
||||
import calendar
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Optional, Union
|
||||
from urllib.parse import quote, unquote
|
||||
|
||||
import aiofiles
|
||||
import httpx
|
||||
from filelock import AsyncFileLock, FileLock
|
||||
from httpx._types import SyncByteStream # protocol type
|
||||
|
||||
from ..controller import get_rule_for_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AlreadyLockedError(Exception):
    """Raised for attempts to acquire a lock that is already held."""
|
||||
|
||||
|
||||
class DualFileStream(httpx._types.SyncByteStream, httpx.AsyncByteStream):
    """Byte stream over a file on disk, usable from both sync and async httpx responses.

    The optional on_close / async_on_close callbacks run when the respective
    close method is invoked (e.g. to release a resource held for the stream).
    """

    def __init__(
        self,
        path: Path,
        chunk_size: int = 1024 * 1024,  # 1 MiB per read
        # NOTE(review): builtin `callable` used inside Optional[...]; the
        # conventional spelling is typing.Callable — confirm and tidy.
        on_close: Optional[callable] = None,
        async_on_close: Optional[callable] = None,
    ):
        self.path, self.chunk_size = Path(path), chunk_size
        self.on_close, self.async_on_close = on_close, async_on_close

    def __iter__(self):
        # Synchronous chunked read of the whole file.
        with open(self.path, "rb") as f:
            while True:
                b = f.read(self.chunk_size)
                if not b:
                    break
                yield b

    def close(self) -> None:
        if self.on_close:  # pragma: no cover
            self.on_close()

    async def __aiter__(self):
        # Asynchronous chunked read via aiofiles.
        async with aiofiles.open(self.path, "rb") as f:
            while True:
                b = await f.read(self.chunk_size)
                if not b:
                    break
                yield b

    async def aclose(self) -> None:
        if self.async_on_close:  # pragma: no cover
            await self.async_on_close()
|
||||
|
||||
|
||||
class FileCache:
    """Maps request URLs to flat files under *cache_dir* and answers freshness queries.

    Each cached body file has a sibling "<name>.meta" JSON file holding the
    fetch timestamp used for age checks.
    """

    def __init__(self, cache_dir: Union[str, Path], locking: bool = True):
        self.cache_dir = Path(cache_dir)
        logger.info("cache_dir=%s", self.cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.locking = locking

    def _meta_path(self, p: Path) -> Path:
        """Sidecar metadata path for a cached body file."""
        return p.with_suffix(p.suffix + ".meta")

    def _load_meta(self, p: Path) -> dict:
        """Load the sidecar metadata, or {} when it does not exist."""
        meta_file = self._meta_path(p)
        try:
            raw = meta_file.read_text()
        except FileNotFoundError:  # pragma: no cover
            return {}
        return json.loads(raw)

    def to_path(self, host: str, path: str, query: str) -> Path:
        """Derive (and prepare the directory for) the cache file of a URL."""
        host_dir = host.lower().rstrip(".")
        (self.cache_dir / host_dir).mkdir(parents=True, exist_ok=True)
        stem = unquote(path).strip("/").replace("/", "-")
        if not stem:
            stem = "index"
        if query:
            flattened = unquote(query).replace("&", "-").replace("=", "-")
            stem = f"{stem}-{flattened}"
        return self.cache_dir / host_dir / quote(stem, safe="._-~")

    def get_if_fresh(
        self, host: str, path: str, query: str, cache_rules: dict[str, dict[str, Union[bool, int]]]
    ) -> tuple[bool, Optional[Path]]:
        """Return (is_fresh, cache_path); cache_path is None when nothing is cached."""
        rule = get_rule_for_request(request_host=host, target=path, cache_rules=cache_rules)

        if not rule:
            logger.info("No cache policy for %s://%s, not retrieving from cache", host, path)
            return False, None

        cache_file = self.to_path(host=host, path=path, query=query)
        if not cache_file.exists():
            logger.info("Cache file doesn't exist: %s for %s", path, cache_file)
            return False, None

        fetched = self._load_meta(cache_file).get("fetched")
        if not fetched:
            return False, cache_file  # pragma: no cover

        if rule is True:
            logger.info("Cache policy allows unlimited cache, returning %s", cache_file)
            return True, cache_file

        age = round(time.time() - fetched)
        if age < 0:  # pragma: no cover
            raise ValueError(f"Age is less than 0, impossible {age=}, file {path=}")
        logger.info("file is %s seconds old, policy allows caching for up to %s", age, rule)
        return (age <= rule, cache_file)
|
||||
|
||||
|
||||
class _TeeCore:
    """Shared bookkeeping for teeing a network response into a cache file.

    Bytes are written to "<path>.tmp" and atomically promoted to *path* on
    finalize; fetch/origin timestamps and selected headers are then recorded
    in "<path>.meta".
    """

    def __init__(self, resp: httpx.Response, path: Path, locking: bool, last_modified: str, access_date: str):
        assert path is not None

        self.resp = resp
        self.path = path
        self.tmp = path.with_name(path.name + ".tmp")
        # Optional inter-process lock guarding concurrent writers of the same path.
        self.lock = FileLock(str(path) + ".lock") if locking else None
        self.fh = None  # opened lazily by open_tmp()
        if last_modified:
            # Parse the HTTP-date header (IMF-fixdate) into epoch seconds.
            self.mtime = calendar.timegm(time.strptime(last_modified, "%a, %d %b %Y %H:%M:%S GMT"))
        else:
            self.mtime = None

        if access_date:
            self.atime = calendar.timegm(time.strptime(access_date, "%a, %d %b %Y %H:%M:%S GMT"))
        else:
            self.atime = None  # pragma: no cover

    def acquire(self):
        # No-op when locking was disabled.
        self.lock and self.lock.acquire()  # pyright: ignore[reportUnusedExpression]

    def open_tmp(self):
        self.fh = open(self.tmp, "wb")

    def write(self, chunk: bytes):
        self.fh.write(chunk)  # pyright: ignore[reportOptionalMemberAccess]

    def finalize(self):
        """Flush, fsync, atomically rename, write metadata, then release the lock.

        Idempotent: guarded by the open-file check so a second call (stream
        exhaustion followed by close()) does nothing.
        """
        try:
            if self.fh and not self.fh.closed:
                self.fh.flush()
                # fsync before rename so the promoted file is durable on disk.
                os.fsync(self.fh.fileno())
                self.fh.close()
                os.replace(self.tmp, self.path)
                try:
                    meta_path = self.path.with_suffix(self.path.suffix + ".meta")
                    headers = {
                        "content-type": self.resp.headers.get("content-type"),
                        "content-encoding": self.resp.headers.get("content-encoding"),
                    }

                    meta_path.write_text(json.dumps({"fetched": self.atime, "origin_lm": self.mtime, "headers": headers}))
                except FileNotFoundError:  # pragma: no cover
                    pass
        finally:
            if self.lock and getattr(self.lock, "is_locked", False):
                self.lock.release()
|
||||
|
||||
|
||||
class _TeeToDisk(SyncByteStream):
    """Synchronous byte stream that forwards response chunks while copying them to disk."""

    def __init__(self, resp: httpx.Response, path: Path, locking: bool, last_modified: str, access_date: str) -> None:
        self.core = _TeeCore(resp, path, locking, last_modified, access_date)

    def __iter__(self) -> Iterator[bytes]:
        core = self.core
        core.acquire()
        try:
            core.open_tmp()
            for chunk in core.resp.iter_raw():
                core.write(chunk)
                yield chunk
        finally:
            # Promotes the temp file and releases the lock (idempotent).
            core.finalize()

    def close(self) -> None:
        try:
            self.core.resp.close()
        finally:
            self.core.finalize()
|
||||
|
||||
|
||||
class _AsyncTeeToDisk(httpx.AsyncByteStream):
    """Async byte stream that forwards response chunks while persisting them to disk.

    Async counterpart of _TeeToDisk/_TeeCore: writes to "<path>.tmp",
    atomically renames on success, then records timestamps and selected
    headers in "<path>.meta".
    """

    def __init__(self, resp, path, locking, last_modified, access_date):
        self.resp = resp
        self.path = path
        self.tmp = path.with_name(path.name + ".tmp")
        # Optional inter-process lock guarding concurrent writers of the same path.
        self.lock = AsyncFileLock(str(path) + ".lock") if locking else None
        if last_modified:
            # Parse the HTTP-date header (IMF-fixdate) into epoch seconds.
            self.mtime = calendar.timegm(time.strptime(last_modified, "%a, %d %b %Y %H:%M:%S GMT"))
        else:
            self.mtime = None

        if access_date:
            self.atime = calendar.timegm(time.strptime(access_date, "%a, %d %b %Y %H:%M:%S GMT"))
        else:
            self.atime = None  # pragma: no cover

    async def __aiter__(self):
        if self.lock:
            await self.lock.acquire()
        try:
            async with aiofiles.open(self.tmp, "wb") as f:
                async for chunk in self.resp.aiter_raw():
                    await f.write(chunk)
                    yield chunk
            # Promote the temp file only after the full body streamed through.
            os.replace(self.tmp, self.path)
            async with aiofiles.open(self.path.with_suffix(self.path.suffix + ".meta"), "w") as m:
                headers = {
                    "content-type": self.resp.headers.get("content-type"),
                    "content-encoding": self.resp.headers.get("content-encoding"),
                }
                await m.write(json.dumps({"fetched": self.atime, "origin_lm": self.mtime, "headers": headers}))
        finally:
            if self.lock:
                await self.lock.release()

    async def aclose(self):
        try:
            await self.resp.aclose()
        finally:
            if self.lock:
                # NOTE(review): release may run without a prior acquire when the
                # stream was never iterated — confirm AsyncFileLock tolerates this.
                await self.lock.release()
|
||||
|
||||
|
||||
class CachingTransport(httpx.BaseTransport, httpx.AsyncBaseTransport):
    """Sync+async httpx transport that serves GET responses from a flat-file cache.

    Freshness is decided by *cache_rules* (host regex -> {target regex -> TTL|bool}).
    Stale-but-present entries are revalidated with If-Modified-Since; misses are
    teed to disk while streaming to the caller.
    """

    cache_rules: dict[str, dict[str, Union[bool, int]]]
    # Cached bodies at least this large are streamed from disk rather than read
    # fully into memory.
    streaming_cutoff: int = 8 * 1024 * 1024

    def __init__(
        self,
        cache_dir: Union[str, Path],
        cache_rules: dict[str, dict[str, Union[bool, int]]],
        transport: Optional[httpx.BaseTransport] = None,
    ):
        self._cache = FileCache(cache_dir=cache_dir, locking=True)
        self.transport = transport or httpx.HTTPTransport()
        self.cache_rules = cache_rules

    def _cache_hit_response(self, req, path: Path, status_code: int = 200):
        """Build an httpx.Response from the cached file at *path*.

        TODO: More carefully consider async here. read_text, read_bytes both are blocking.

        Large files are streamed async, so the only blocking events here are for reading small(ish) files
        """
        meta = json.loads(path.with_suffix(path.suffix + ".meta").read_text())
        date = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime(meta["fetched"]))
        last_modified = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime(meta["origin_lm"]))

        ct = meta.get("headers", {}).get("content-type", "application/octet-stream")
        ce = meta.get("headers", {}).get("content-encoding")
        size = os.path.getsize(path)

        headers = [
            ("x-cache", "HIT"),
            ("content-length", str(size)),
            ("Date", date),
            ("Last-Modified", last_modified),
        ]
        if ce:
            headers.append(("content-encoding", ce))
        if ct:
            headers.append(("content-type", ct))

        if size < self.streaming_cutoff:
            # If the file is small, just read it and return it
            return httpx.Response(
                status_code=status_code,
                headers=headers,
                content=path.read_bytes(),
                request=req,
            )
        else:
            # If the file is large, stream it
            return httpx.Response(
                status_code=status_code,
                headers=headers,
                stream=DualFileStream(path),
                request=req,
            )

    def _cache_miss_response(self, req, net, path, tee_factory):
        """Wrap a network response so its body is teed into the cache file at *path*."""
        # Non-200 responses pass through untouched (and uncached).
        if net.status_code != 200:
            return net

        miss_headers = [
            (k, v)
            for k, v in net.headers.items()
            if k.lower() not in ("transfer-encoding",)  # "content-encoding", "content-length", "transfer-encoding")
        ]
        miss_headers.append(("x-cache", "MISS"))
        return httpx.Response(
            status_code=net.status_code,
            headers=miss_headers,
            stream=tee_factory(
                net, path, self._cache.locking, net.headers.get("Last-Modified"), net.headers.get("Date")
            ),
            request=req,
            # decode_content=False keeps the raw (possibly compressed) bytes so
            # the tee writes exactly what the server sent.
            extensions={**net.extensions, "decode_content": False},
        )

    def return_if_fresh(self, request):
        """Return (cached_response, cache_path); response is None when stale or missing.

        When the entry exists but is stale, an If-Modified-Since validator is
        added to *request* so the origin can answer 304.
        """
        host = request.url.host
        path = request.url.path
        query = request.url.query.decode() if request.url.query else ""

        fresh, path = self._cache.get_if_fresh(host, path, query, self.cache_rules)

        if path:
            if fresh:
                return self._cache_hit_response(request, path), path
            else:
                lm = json.loads(path.with_suffix(path.suffix + ".meta").read_text()).get("origin_lm")
                if lm:
                    request.headers["If-Modified-Since"] = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime(lm))
                return None, path
        else:
            return None, None

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        # Only GETs participate in caching; everything else goes to the network.
        if request.method != "GET":
            return self.transport.handle_request(request)

        response, path = self.return_if_fresh(request)
        if response:
            return response

        net = self.transport.handle_request(request)
        if net.status_code == 304:
            # Origin confirmed our copy is still valid; serve it from disk.
            logger.info("304 for %s", request)
            assert path is not None  # must be true
            return self._cache_hit_response(request, path, status_code=304)

        host = request.url.host
        path = request.url.path
        query = request.url.query.decode() if request.url.query else ""

        path = self._cache.to_path(host, path, query)
        return self._cache_miss_response(request, net, path, _TeeToDisk)

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        if request.method != "GET":
            return await self.transport.handle_async_request(request)  # type: ignore[attr-defined]

        response, path = self.return_if_fresh(request)
        if response:
            return response

        net = await self.transport.handle_async_request(request)
        if net.status_code == 304:
            assert path is not None  # must be true
            logger.info("304 for %s", request)
            return self._cache_hit_response(request, path, status_code=304)

        path = self._cache.to_path(request.url.host, request.url.path, request.url.query.decode())
        return self._cache_miss_response(request, net, path, _AsyncTeeToDisk)
|
||||
@@ -0,0 +1,312 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, AsyncGenerator, Callable, Generator, Literal, Mapping, Optional, Sequence, Union
|
||||
|
||||
import hishel
|
||||
import httpx
|
||||
from httpx._types import ProxyTypes
|
||||
from pyrate_limiter import Duration, Limiter
|
||||
|
||||
from .controller import get_cache_controller
|
||||
from .filecache.transport import CachingTransport
|
||||
from .key_generator import file_key_generator
|
||||
from .ratelimiter import AsyncRateLimitingTransport, RateLimitingTransport, create_rate_limiter
|
||||
from .serializer import JSONByteSerializer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
# enable http2 if h2 is installed
|
||||
import h2 # type: ignore # noqa
|
||||
|
||||
logger.debug("HTTP2 available")
|
||||
|
||||
HTTP2 = True # pragma: no cover
|
||||
except ImportError:
|
||||
logger.debug("HTTP2 not available")
|
||||
HTTP2 = False
|
||||
|
||||
|
||||
@dataclass
class HttpxThrottleCache:
    """
    Implements a rate limited, optionally-cached HTTPX wrapper that provides http_client() (httpx.Client) or async_http_client() (httpx.AsyncClient).

    Rate Limiting is across all connections, whether via http_client & async_http_client, using pyrate_limiter. For multiprocessing, use pyrate_limiter's
    MultiprocessBucket or SqliteBucket w/ a file lock.

    Caching is implemented via Hishel, which allows a variety of configurations, including AWS storage,
    or via the flat-file FileCache backend.
    """

    # Base keyword arguments handed to httpx.Client / httpx.AsyncClient.
    httpx_params: dict[str, Any] = field(
        default_factory=lambda: {"default_encoding": "utf-8", "http2": HTTP2, "verify": True}
    )

    # Host regex -> {target regex -> TTL seconds | bool} cache policy.
    cache_rules: dict[str, dict[str, Union[bool, int]]] = field(default_factory=lambda: {})
    rate_limiter_enabled: bool = True
    cache_mode: Literal[False, "Disabled", "Hishel-S3", "Hishel-File", "FileCache"] = "Hishel-File"
    request_per_sec_limit: int = 10
    max_delay: Duration = field(default_factory=lambda: Duration.DAY)
    _client: Optional[httpx.Client] = None

    rate_limiter: Optional[Limiter] = None
    s3_bucket: Optional[str] = None
    s3_client: Optional[Any] = None
    user_agent: Optional[str] = None
    user_agent_factory: Optional[Callable] = None

    cache_dir: Optional[Union[Path, str]] = None

    # Class-level lock (not a dataclass field): guards lazy client creation.
    # NOTE(review): being class-level it is shared across ALL instances — confirm
    # that is intended rather than a per-instance lock.
    lock = threading.Lock()

    proxy: Optional[ProxyTypes] = None

    def __post_init__(self):
        """Validate configuration, build the rate limiter, and prepare the cache dir."""
        self.cache_dir = Path(self.cache_dir) if isinstance(self.cache_dir, str) else self.cache_dir

        if self.rate_limiter_enabled and self.rate_limiter is None:
            self.rate_limiter = create_rate_limiter(
                requests_per_second=self.request_per_sec_limit, max_delay=self.max_delay
            )

        # BUGFIX: was `!= "Disabled" or ... is False`, which logged "Cache is
        # enabled" even when caching was disabled via cache_mode=False.
        if self.cache_mode != "Disabled" and self.cache_mode is not False and not self.cache_rules:
            logger.info("Cache is enabled, but no cache_rules provided. Will use default caching.")

        if self.cache_mode == "Disabled" or self.cache_mode is False:
            pass
        elif self.cache_mode == "Hishel-S3":
            if self.s3_bucket is None:
                raise ValueError("s3_bucket must be provided if using Hishel-S3 storage")
        else:  # Hishel-File or FileCache
            if self.cache_dir is None:
                raise ValueError(f"cache_dir must be provided if using a file based cache: {self.cache_mode}")
            else:
                # exist_ok/parents avoid the TOCTOU race of the previous
                # `if not exists(): mkdir()` pattern.
                self.cache_dir.mkdir(parents=True, exist_ok=True)

        logger.debug(
            "Initialized cache with cache_mode=%s, cache_dir=%s, rate_limiter_enabled=%s",
            self.cache_mode,
            self.cache_dir,
            self.rate_limiter_enabled,
        )

        if os.environ.get("HTTPS_PROXY") is not None:
            self.proxy = os.environ.get("HTTPS_PROXY")

    def _populate_user_agent(self, params: dict):
        """Inject the User-Agent header (from factory or static value) into *params*."""
        if self.user_agent_factory is not None:
            user_agent = self.user_agent_factory()
        else:
            user_agent = self.user_agent

        if user_agent is not None:
            if "headers" not in params:
                params["headers"] = {}

            params["headers"]["User-Agent"] = user_agent
        return params

    def populate_user_agent(self, params: dict):
        """Provided so clients can inspect the params that would be passed to HTTPX"""
        return self._populate_user_agent(params)

    def get_batch(self, *, urls: Sequence[str] | Mapping[str, Path], _client_mocker=None):
        """
        Fetch a batch of URLs concurrently and either return their content in-memory
        or stream them directly to files.

        Uses a background thread with an asyncio event loop, so it is safe to call
        from code that already runs inside an event loop.

        Args:
            urls (Sequence[str] | Mapping[str, Path]): URLs to fetch, or a mapping
                of URL -> destination file path.
            _client_mocker: Test hook invoked with the client before any request.

        Returns:
            list:
                - If given mappings, then returns a list of Paths. Else, a list of the Content

        Raises:
            RuntimeError: If any URL responds with a status code other than 200 or 304.
        """

        import aiofiles

        async def _run():
            async with self.async_http_client() as client:
                if _client_mocker:
                    # For testing
                    _client_mocker(client)

                async def task(url: str, path: Optional[Path]):
                    async with client.stream("GET", url) as r:
                        if r.status_code in (200, 304):
                            if path:
                                path.parent.mkdir(parents=True, exist_ok=True)
                                async with aiofiles.open(path, "wb") as f:
                                    async for chunk in r.aiter_bytes():
                                        await f.write(chunk)
                                return path
                            else:
                                return await r.aread()
                        else:
                            raise RuntimeError(f"URL status code is not 200 or 304: {url=}")

                if isinstance(urls, Mapping):
                    return await asyncio.gather(*(task(u, p) for u, p in urls.items()))
                else:
                    return await asyncio.gather(*(task(u, None) for u in urls))

        with ThreadPoolExecutor(1) as pool:
            return pool.submit(lambda: asyncio.run(_run())).result()

    def _get_httpx_transport_params(self, params: dict[str, Any]):
        """Extract the client params that must also be applied to the underlying transport."""
        http2 = params.get("http2", False)
        proxy = self.proxy

        return {"http2": http2, "proxy": proxy}

    @contextmanager
    def http_client(self, bypass_cache: bool = False, **kwargs) -> Generator[httpx.Client, None, None]:
        """Provides and reuses a client. Does not close"""
        if self._client is None:
            with self.lock:
                # Locking: not super critical, since worst case might be extra httpx clients created,
                # but future proofing against TOCTOU races in free-threading world
                if self._client is None:
                    logger.debug("Creating new HTTPX Client")
                    params = self.httpx_params.copy()

                    self._populate_user_agent(params)

                    params.update(**kwargs)
                    params["transport"] = self._get_transport(
                        bypass_cache=bypass_cache, httpx_transport_params=self._get_httpx_transport_params(params)
                    )
                    self._client = httpx.Client(**params)

        yield self._client

    def close(self):
        """Close and drop the cached synchronous client; a new one is built on next use."""
        if self._client is not None:
            self._client.close()
            self._client = None

    def update_rate_limiter(self, requests_per_second: int, max_delay: Duration = Duration.DAY):
        """Swap in a new rate limiter and close the cached client so it takes effect.

        BUGFIX: previously the *max_delay* argument was ignored and Duration.DAY
        was always used.
        """
        self.rate_limiter = create_rate_limiter(requests_per_second=requests_per_second, max_delay=max_delay)

        self.close()

    def _client_factory_async(self, bypass_cache: bool, **kwargs) -> httpx.AsyncClient:
        """Build a fresh AsyncClient wired to the async transport chain."""
        params = self.httpx_params.copy()
        params.update(**kwargs)
        self._populate_user_agent(params)
        params["transport"] = self._get_async_transport(
            bypass_cache=bypass_cache, httpx_transport_params=self._get_httpx_transport_params(params)
        )

        return httpx.AsyncClient(**params)

    @asynccontextmanager
    async def async_http_client(
        self, client: Optional[httpx.AsyncClient] = None, bypass_cache: bool = False, **kwargs
    ) -> AsyncGenerator[httpx.AsyncClient, None]:
        """
        Async callers should create a single client for a group of tasks, rather than creating a single client per task.

        If a non-None client is passed, then this is a no-op and the client isn't closed. This (passing a client) occurs when a higher level async task creates the client to be used by child calls.
        """

        if client is not None:
            yield client  # type: ignore # Caller is responsible for closing
            return

        async with self._client_factory_async(bypass_cache=bypass_cache, **kwargs) as client:
            yield client

    def _get_transport(self, bypass_cache: bool, httpx_transport_params: dict[str, Any]) -> httpx.BaseTransport:
        """
        Constructs the Transport Chain:

        Caching Transport (if enabled) => Rate Limiting Transport (if enabled) => httpx.HTTPTransport
        """
        if self.rate_limiter_enabled:
            assert self.rate_limiter is not None
            next_transport = RateLimitingTransport(self.rate_limiter, **httpx_transport_params)
        else:
            next_transport = httpx.HTTPTransport(**httpx_transport_params)

        if bypass_cache or self.cache_mode == "Disabled" or self.cache_mode is False:
            logger.info("Cache is DISABLED, rate limiting only")
            return next_transport
        elif self.cache_mode == "FileCache":
            assert self.cache_dir is not None
            return CachingTransport(cache_dir=self.cache_dir, transport=next_transport, cache_rules=self.cache_rules)
        else:
            # either Hishel-S3 or Hishel-File
            assert self.cache_mode == "Hishel-File" or self.cache_mode == "Hishel-S3"
            controller = get_cache_controller(key_generator=file_key_generator, cache_rules=self.cache_rules)

            if self.cache_mode == "Hishel-S3":
                assert self.s3_bucket is not None
                storage = hishel.S3Storage(
                    client=self.s3_client, bucket_name=self.s3_bucket, serializer=JSONByteSerializer()
                )
            else:
                assert self.cache_mode == "Hishel-File"
                assert self.cache_dir is not None
                storage = hishel.FileStorage(base_path=Path(self.cache_dir), serializer=JSONByteSerializer())

            return hishel.CacheTransport(transport=next_transport, storage=storage, controller=controller)

    def _get_async_transport(
        self, bypass_cache: bool, httpx_transport_params: dict[str, Any]
    ) -> httpx.AsyncBaseTransport:
        """
        Constructs the Transport Chain (async mirror of _get_transport):

        Caching Transport (if enabled) => Rate Limiting Transport (if enabled) => httpx.AsyncHTTPTransport
        """

        if self.rate_limiter_enabled:
            assert self.rate_limiter is not None
            next_transport = AsyncRateLimitingTransport(self.rate_limiter, **httpx_transport_params)
        else:
            next_transport = httpx.AsyncHTTPTransport(**httpx_transport_params)

        if bypass_cache or self.cache_mode == "Disabled" or self.cache_mode is False:
            logger.info("Cache is DISABLED, rate limiting only")
            return next_transport
        elif self.cache_mode == "FileCache":
            assert self.cache_dir is not None
            return CachingTransport(cache_dir=self.cache_dir, transport=next_transport, cache_rules=self.cache_rules)  # pyright: ignore[reportArgumentType]
        else:
            # either Hishel-S3 or Hishel-File
            assert self.cache_mode == "Hishel-File" or self.cache_mode == "Hishel-S3"
            controller = get_cache_controller(key_generator=file_key_generator, cache_rules=self.cache_rules)

            if self.cache_mode == "Hishel-S3":
                assert self.s3_bucket is not None
                storage = hishel.AsyncS3Storage(
                    client=self.s3_client, bucket_name=self.s3_bucket, serializer=JSONByteSerializer()
                )
            else:
                assert self.cache_mode == "Hishel-File"
                assert self.cache_dir is not None
                storage = hishel.AsyncFileStorage(base_path=Path(self.cache_dir), serializer=JSONByteSerializer())

            return hishel.AsyncCacheTransport(transport=next_transport, storage=storage, controller=controller)

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()
|
||||
@@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
import httpcore
|
||||
|
||||
|
||||
def file_key_generator(request: httpcore.Request, body: Optional[bytes]) -> str:
    """Generate a stable, readable cache key for *request*.

    The key is "<host>_<path>" with "/" flattened to "__"; when a query string
    is present it is appended with "&"/"=" also flattened. *body* is ignored.

    Args:
        request (httpcore.Request): Request to derive the key from.
        body (Optional[bytes]): Unused request body.

    Returns:
        str: Persistent key for the request
    """
    raw_path, _, raw_query = request.url.target.partition(b"?")
    flattened = raw_path.decode().replace("/", "__")
    query = raw_query.decode()
    if query:
        flattened += f"__{query.replace('&', '__').replace('=', '__')}"
    return f"{request.url.host.decode()}_{flattened}"
|
||||
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
To control rate limit across multiple processes, see https://pyratelimiter.readthedocs.io/en/latest/#backends
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
from pyrate_limiter import Duration, InMemoryBucket, Limiter, Rate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_rate_limiter(requests_per_second: int, max_delay=Duration.DAY) -> Limiter:
    """Build a Limiter enforcing *requests_per_second* over an in-memory bucket.

    The limiter blocks (up to *max_delay*) instead of raising when the rate
    would be exceeded.
    """
    bucket = InMemoryBucket([Rate(requests_per_second, Duration.SECOND)])
    return Limiter(bucket, max_delay=max_delay, raise_when_fail=False, retry_until_max_delay=True)
|
||||
|
||||
|
||||
class RateLimitingTransport(httpx.HTTPTransport):
    """Synchronous httpx transport that throttles outgoing requests with a Limiter."""

    def __init__(self, limiter: Limiter, **kwargs):
        super().__init__(**kwargs)
        self.limiter = limiter

    def handle_request(self, request: httpx.Request, **kwargs) -> httpx.Response:
        limiter = self.limiter
        if limiter:
            # A single constant item name (__name__) means every request
            # draws from the same shared rate budget.
            while not limiter.try_acquire(__name__):
                logger.debug("Lock acquisition timed out, retrying")  # pragma: no cover
            logger.debug("Acquired lock")
        logger.info("Making HTTP Request %s", request)
        return super().handle_request(request, **kwargs)
|
||||
|
||||
|
||||
class AsyncRateLimitingTransport(httpx.AsyncHTTPTransport):
    """Asynchronous httpx transport that throttles outgoing requests with a Limiter."""

    def __init__(self, limiter: Limiter, **kwargs):
        super().__init__(**kwargs)
        self.limiter = limiter

    async def handle_async_request(self, request: httpx.Request, **kwargs) -> httpx.Response:
        limiter = self.limiter
        if limiter:
            # A single constant item name (__name__) means every request
            # draws from the same shared rate budget.
            while not await limiter.try_acquire_async(__name__):
                logger.debug("Lock acquisition timed out, retrying")  # pragma: no cover
            logger.debug("Acquired lock")
        logger.info("Making HTTP Request %s", request)
        return await super().handle_async_request(request, **kwargs)
|
||||
@@ -0,0 +1,125 @@
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Tuple, Union
|
||||
|
||||
from hishel._serializers import (
|
||||
HEADERS_ENCODING,
|
||||
KNOWN_REQUEST_EXTENSIONS,
|
||||
KNOWN_RESPONSE_EXTENSIONS,
|
||||
BaseSerializer,
|
||||
Metadata,
|
||||
normalized_url,
|
||||
)
|
||||
from httpcore import Request, Response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JSONByteSerializer(BaseSerializer):
    """JSONByteSerializer stores HTTP metadata as compact UTF-8 JSON followed by raw binary body bytes,
    separated by a single null byte. This avoids base64 encoding, significantly reducing size and
    improving performance for large responses."""

    def dumps(self, response: Response, request: Request, metadata: Metadata) -> Union[str, bytes]:
        """
        Dumps the HTTP response and its HTTP request.

        The payload layout is ``<compact JSON header>\\0<raw body bytes>``;
        json.dumps escapes control characters, so the first NUL byte in the
        payload unambiguously separates the JSON header from the body.

        :param response: An HTTP response
        :type response: Response
        :param request: An HTTP request
        :type request: Request
        :param metadata: Additional information about the stored response
        :type metadata: Metadata
        :return: Serialized response
        :rtype: Union[str, bytes]
        """
        # Headers are (bytes, bytes) pairs; decode them so they are JSON-safe.
        response_dict = {
            "status": response.status,
            "headers": [
                (key.decode(HEADERS_ENCODING), value.decode(HEADERS_ENCODING)) for key, value in response.headers
            ],
            # NOTE(review): assumes the known response extension values are
            # ASCII-encodable bytes — mirrored by the .encode() in loads().
            "extensions": {
                key: value.decode("ascii")
                for key, value in response.extensions.items()
                if key in KNOWN_RESPONSE_EXTENSIONS
            },
        }

        request_dict = {
            "method": request.method.decode("ascii"),
            "url": normalized_url(request.url),
            "headers": [
                (key.decode(HEADERS_ENCODING), value.decode(HEADERS_ENCODING)) for key, value in request.headers
            ],
            # NOTE(review): request extension values are stored as-is; presumably
            # the known request extensions are JSON-serializable — confirm.
            "extensions": {key: value for key, value in request.extensions.items() if key in KNOWN_REQUEST_EXTENSIONS},
        }

        # created_at is rendered in HTTP-date format; loads() parses the exact
        # same format string back into a datetime.
        metadata_dict = {
            "cache_key": metadata["cache_key"],
            "number_of_uses": metadata["number_of_uses"],
            "created_at": metadata["created_at"].strftime("%a, %d %b %Y %H:%M:%S GMT"),
        }

        full_json = {
            "response": response_dict,
            "request": request_dict,
            "metadata": metadata_dict,
        }

        # Compact separators keep the JSON header small; the body is appended
        # verbatim after the NUL terminator, with no re-encoding.
        # NOTE(review): assumes response.content has already been read into
        # memory by the caller — confirm against the cache transport.
        return json.dumps(full_json, separators=(",", ":")).encode("utf-8") + b"\0" + response.content

    def loads(self, data: Union[str, bytes]) -> Tuple[Response, Request, Metadata]:
        """
        Loads the HTTP response and its HTTP request from serialized data.

        :param data: Serialized data
        :type data: Union[str, bytes]
        :return: HTTP response and its HTTP request
        :rtype: Tuple[Response, Request, Metadata]
        """
        data_b: bytes = data.encode("utf-8") if isinstance(data, str) else data
        # Split only at the FIRST NUL: everything before it is the JSON
        # header; everything after is the raw body (which may contain NULs).
        full_json, body = data_b.split(b"\0", 1)
        full_json = json.loads(full_json.decode("utf-8"))
        response_dict = full_json["response"]
        request_dict = full_json["request"]
        metadata_dict = full_json["metadata"]
        # Parse the HTTP-date written by dumps(); the result is a naive
        # datetime (no tzinfo attached).
        metadata_dict["created_at"] = datetime.strptime(
            metadata_dict["created_at"],
            "%a, %d %b %Y %H:%M:%S GMT",
        )

        # Re-encode header/extension strings back to the bytes httpcore expects.
        response = Response(
            status=response_dict["status"],
            headers=[
                (key.encode(HEADERS_ENCODING), value.encode(HEADERS_ENCODING))
                for key, value in response_dict["headers"]
            ],
            content=body,
            extensions={
                key: value.encode("ascii")
                for key, value in response_dict["extensions"].items()
                if key in KNOWN_RESPONSE_EXTENSIONS
            },
        )

        request = Request(
            method=request_dict["method"],
            url=request_dict["url"],
            headers=[
                (key.encode(HEADERS_ENCODING), value.encode(HEADERS_ENCODING)) for key, value in request_dict["headers"]
            ],
            extensions={
                key: value for key, value in request_dict["extensions"].items() if key in KNOWN_REQUEST_EXTENSIONS
            },
        )

        metadata = Metadata(
            cache_key=metadata_dict["cache_key"],
            created_at=metadata_dict["created_at"],
            number_of_uses=metadata_dict["number_of_uses"],
        )

        return response, request, metadata

    @property
    def is_binary(self) -> bool:  # pragma: no cover
        # Payloads carry raw body bytes after the NUL separator, so the
        # storage layer must read/write this serializer's output as bytes.
        return True
|
||||
Reference in New Issue
Block a user