Initial commit
@@ -0,0 +1,312 @@
import asyncio
import logging
import os
import threading
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, AsyncGenerator, Callable, Generator, Literal, Mapping, Optional, Sequence, Union

import hishel
import httpx
from httpx._types import ProxyTypes
from pyrate_limiter import Duration, Limiter

from .controller import get_cache_controller
from .filecache.transport import CachingTransport
from .key_generator import file_key_generator
from .ratelimiter import AsyncRateLimitingTransport, RateLimitingTransport, create_rate_limiter
from .serializer import JSONByteSerializer

logger = logging.getLogger(__name__)

try:
    # enable http2 if h2 is installed
    import h2  # type: ignore # noqa

    logger.debug("HTTP2 available")

    HTTP2 = True  # pragma: no cover
except ImportError:
    logger.debug("HTTP2 not available")
    HTTP2 = False


@dataclass
class HttpxThrottleCache:
    """
    Rate-limited, optionally cached HTTPX wrapper that returns http_client() (httpx.Client)
    or async_http_client() (httpx.AsyncClient).

    Rate limiting is shared across all connections, whether made via http_client or
    async_http_client, using pyrate_limiter. For multiprocessing, use pyrate_limiter's
    MultiprocessBucket or SqliteBucket with a file lock.

    Caching is implemented via Hishel, which allows a variety of configurations, including
    AWS S3 storage. http_client() is used for all synchronous requests.
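
    Example (a minimal sketch; values are illustrative and the URL is hypothetical):

        cache = HttpxThrottleCache(cache_dir="./http_cache", request_per_sec_limit=5)
        with cache.http_client() as client:
            response = client.get("https://example.com/data.json")
        cache.close()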
"""
|
||||

    httpx_params: dict[str, Any] = field(
        default_factory=lambda: {"default_encoding": "utf-8", "http2": HTTP2, "verify": True}
    )

    cache_rules: dict[str, dict[str, Union[bool, int]]] = field(default_factory=dict)
    rate_limiter_enabled: bool = True
    cache_mode: Literal[False, "Disabled", "Hishel-S3", "Hishel-File", "FileCache"] = "Hishel-File"
    request_per_sec_limit: int = 10
    max_delay: Duration = field(default_factory=lambda: Duration.DAY)
    _client: Optional[httpx.Client] = None

    rate_limiter: Optional[Limiter] = None
    s3_bucket: Optional[str] = None
    s3_client: Optional[Any] = None
    user_agent: Optional[str] = None
    user_agent_factory: Optional[Callable[[], str]] = None

    cache_dir: Optional[Union[Path, str]] = None

    # class-level attribute (not a dataclass field): a single lock shared across instances
    lock = threading.Lock()

    proxy: Optional[ProxyTypes] = None

    def __post_init__(self):
        self.cache_dir = Path(self.cache_dir) if isinstance(self.cache_dir, str) else self.cache_dir

        if self.rate_limiter_enabled and self.rate_limiter is None:
            self.rate_limiter = create_rate_limiter(
                requests_per_second=self.request_per_sec_limit, max_delay=self.max_delay
            )

        if self.cache_mode not in ("Disabled", False) and not self.cache_rules:
            logger.info("Cache is enabled, but no cache_rules provided. Will use default caching.")

if self.cache_mode == "Disabled" or self.cache_mode is False:
|
||||
pass
|
||||
elif self.cache_mode == "Hishel-S3":
|
||||
if self.s3_bucket is None:
|
||||
raise ValueError("s3_bucket must be provided if using Hishel-S3 storage")
|
||||
else: # Hishel-File or FileCache
|
||||
if self.cache_dir is None:
|
||||
raise ValueError(f"cache_dir must be provided if using a file based cache: {self.cache_mode}")
|
||||
else:
|
||||
if not self.cache_dir.exists():
|
||||
self.cache_dir.mkdir()
|
||||
|
||||
logger.debug(
|
||||
"Initialized cache with cache_mode=%s, cache_dir=%s, rate_limiter_enabled=%s",
|
||||
self.cache_mode,
|
||||
self.cache_dir,
|
||||
self.rate_limiter_enabled,
|
||||
)
|
||||
|
||||
        https_proxy = os.environ.get("HTTPS_PROXY")
        if https_proxy is not None:
            self.proxy = https_proxy

    def _populate_user_agent(self, params: dict):
        if self.user_agent_factory is not None:
            user_agent = self.user_agent_factory()
        else:
            user_agent = self.user_agent

        if user_agent is not None:
            params.setdefault("headers", {})["User-Agent"] = user_agent
        return params

    def populate_user_agent(self, params: dict):
        """
        Provided so clients can inspect the params that would be passed to HTTPX.
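
        Example (illustrative; assumes user_agent or user_agent_factory was set):

            params = cache.populate_user_agent({})
            # params["headers"]["User-Agent"] now holds the configured user agent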
        """
        return self._populate_user_agent(params)

    def get_batch(self, *, urls: Sequence[str] | Mapping[str, Path], _client_mocker=None):
        """
        Fetch a batch of URLs concurrently and either return their content in-memory
        or stream them directly to files.

        Uses a background thread running its own asyncio event loop, so it is safe to
        call from synchronous code.

        Args:
            urls (Sequence[str] | Mapping[str, Path]): Either a sequence of URLs, whose
                content is returned in-memory, or a mapping of URL to destination Path,
                in which case each response is streamed to its file.

        Returns:
            list:
                - If given a mapping, then returns a list of Paths. Else, a list of the content.

        Raises:
            RuntimeError: If any URL responds with a status code other than 200 or 304.
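
        Example (an illustrative sketch; URLs and paths are hypothetical):

            contents = cache.get_batch(urls=["https://example.com/a", "https://example.com/b"])
            paths = cache.get_batch(urls={"https://example.com/a": Path("out/a.bin")})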
"""
|
||||
|
||||
import aiofiles
|
||||
|
||||
        async def _run():
            async with self.async_http_client() as client:
                if _client_mocker:
                    # For testing
                    _client_mocker(client)

                async def task(url: str, path: Optional[Path]):
                    async with client.stream("GET", url) as r:
                        if r.status_code in (200, 304):
                            if path:
                                path.parent.mkdir(parents=True, exist_ok=True)
                                async with aiofiles.open(path, "wb") as f:
                                    async for chunk in r.aiter_bytes():
                                        await f.write(chunk)
                                return path
                            else:
                                return await r.aread()
                        else:
                            raise RuntimeError(f"URL status code is not 200 or 304: {url=}")

                if isinstance(urls, Mapping):
                    return await asyncio.gather(*(task(u, p) for u, p in urls.items()))
                else:
                    return await asyncio.gather(*(task(u, None) for u in urls))

        with ThreadPoolExecutor(1) as pool:
            return pool.submit(lambda: asyncio.run(_run())).result()

    def _get_httpx_transport_params(self, params: dict[str, Any]):
        http2 = params.get("http2", False)
        proxy = self.proxy

        return {"http2": http2, "proxy": proxy}

    @contextmanager
    def http_client(self, bypass_cache: bool = False, **kwargs) -> Generator[httpx.Client, None, None]:
        """
        Provides and reuses a shared client; the client is not closed on exit (use close()).

        Note: bypass_cache and kwargs only take effect when the shared client is first created.
        """
        if self._client is None:
            with self.lock:
                # Locking: not super critical, since worst case might be extra httpx clients created,
                # but future proofing against TOCTOU races in a free-threading world
                if self._client is None:
                    logger.debug("Creating new HTTPX Client")
                    params = self.httpx_params.copy()

                    self._populate_user_agent(params)

                    params.update(**kwargs)
                    params["transport"] = self._get_transport(
                        bypass_cache=bypass_cache, httpx_transport_params=self._get_httpx_transport_params(params)
                    )
                    self._client = httpx.Client(**params)

        yield self._client

    def close(self):
        if self._client is not None:
            self._client.close()
            self._client = None

    def update_rate_limiter(self, requests_per_second: int, max_delay: Duration = Duration.DAY):
        """Replaces the rate limiter and closes the shared client so it is rebuilt with the new limiter."""
        self.rate_limiter = create_rate_limiter(requests_per_second=requests_per_second, max_delay=max_delay)

        self.close()

    def _client_factory_async(self, bypass_cache: bool, **kwargs) -> httpx.AsyncClient:
        params = self.httpx_params.copy()
        params.update(**kwargs)
        self._populate_user_agent(params)
        params["transport"] = self._get_async_transport(
            bypass_cache=bypass_cache, httpx_transport_params=self._get_httpx_transport_params(params)
        )

        return httpx.AsyncClient(**params)

    @asynccontextmanager
    async def async_http_client(
        self, client: Optional[httpx.AsyncClient] = None, bypass_cache: bool = False, **kwargs
    ) -> AsyncGenerator[httpx.AsyncClient, None]:
        """
        Async callers should create a single client for a group of tasks, rather than creating one client per task.

        If an existing client is passed in, this is a pass-through and the client is not closed here.
        This occurs when a higher-level async task creates the client to be used by child calls.
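
        Example (illustrative sketch; URLs are hypothetical):

            async with cache.async_http_client() as client:
                a, b = await asyncio.gather(
                    client.get("https://example.com/a"),
                    client.get("https://example.com/b"),
                )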
"""
|
||||
|
||||
        if client is not None:
            yield client  # type: ignore # Caller is responsible for closing
            return

        async with self._client_factory_async(bypass_cache=bypass_cache, **kwargs) as client:
            yield client

    def _get_transport(self, bypass_cache: bool, httpx_transport_params: dict[str, Any]) -> httpx.BaseTransport:
        """
        Constructs the transport chain:

            Caching Transport (if enabled) => Rate Limiting Transport (if enabled) => httpx.HTTPTransport
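
        For example, with rate limiting enabled and cache_mode="FileCache", the chain is
        conceptually (a sketch, not literal construction code):

            CachingTransport(transport=RateLimitingTransport(limiter, ...))

        so cache hits are served without consuming rate-limit capacity.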
"""
|
||||
        if self.rate_limiter_enabled:
            assert self.rate_limiter is not None
            next_transport = RateLimitingTransport(self.rate_limiter, **httpx_transport_params)
        else:
            next_transport = httpx.HTTPTransport(**httpx_transport_params)

        if bypass_cache or self.cache_mode == "Disabled" or self.cache_mode is False:
            logger.info("Cache is DISABLED, rate limiting only")
            return next_transport
        elif self.cache_mode == "FileCache":
            assert self.cache_dir is not None
            return CachingTransport(cache_dir=self.cache_dir, transport=next_transport, cache_rules=self.cache_rules)
        else:
            # either Hishel-S3 or Hishel-File
            assert self.cache_mode in ("Hishel-File", "Hishel-S3")
            controller = get_cache_controller(key_generator=file_key_generator, cache_rules=self.cache_rules)

            if self.cache_mode == "Hishel-S3":
                assert self.s3_bucket is not None
                storage = hishel.S3Storage(
                    client=self.s3_client, bucket_name=self.s3_bucket, serializer=JSONByteSerializer()
                )
            else:
                assert self.cache_mode == "Hishel-File"
                assert self.cache_dir is not None
                storage = hishel.FileStorage(base_path=Path(self.cache_dir), serializer=JSONByteSerializer())

            return hishel.CacheTransport(transport=next_transport, storage=storage, controller=controller)

    def _get_async_transport(
        self, bypass_cache: bool, httpx_transport_params: dict[str, Any]
    ) -> httpx.AsyncBaseTransport:
        """
        Constructs the transport chain:

            Caching Transport (if enabled) => Rate Limiting Transport (if enabled) => httpx.AsyncHTTPTransport
        """

        if self.rate_limiter_enabled:
            assert self.rate_limiter is not None
            next_transport = AsyncRateLimitingTransport(self.rate_limiter, **httpx_transport_params)
        else:
            next_transport = httpx.AsyncHTTPTransport(**httpx_transport_params)

        if bypass_cache or self.cache_mode == "Disabled" or self.cache_mode is False:
            logger.info("Cache is DISABLED, rate limiting only")
            return next_transport
        elif self.cache_mode == "FileCache":
            assert self.cache_dir is not None
            return CachingTransport(cache_dir=self.cache_dir, transport=next_transport, cache_rules=self.cache_rules)  # pyright: ignore[reportArgumentType]
        else:
            # either Hishel-S3 or Hishel-File
            assert self.cache_mode in ("Hishel-File", "Hishel-S3")
            controller = get_cache_controller(key_generator=file_key_generator, cache_rules=self.cache_rules)

            if self.cache_mode == "Hishel-S3":
                assert self.s3_bucket is not None
                storage = hishel.AsyncS3Storage(
                    client=self.s3_client, bucket_name=self.s3_bucket, serializer=JSONByteSerializer()
                )
            else:
                assert self.cache_mode == "Hishel-File"
                assert self.cache_dir is not None
                storage = hishel.AsyncFileStorage(base_path=Path(self.cache_dir), serializer=JSONByteSerializer())

            return hishel.AsyncCacheTransport(transport=next_transport, storage=storage, controller=controller)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()