import asyncio
import logging
import os
import threading
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass, field
from pathlib import Path
from typing import (
    Any,
    AsyncGenerator,
    Callable,
    ClassVar,
    Generator,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Union,
)

import hishel
import httpx
from httpx._types import ProxyTypes
from pyrate_limiter import Duration, Limiter

from .controller import get_cache_controller
from .filecache.transport import CachingTransport
from .key_generator import file_key_generator
from .ratelimiter import AsyncRateLimitingTransport, RateLimitingTransport, create_rate_limiter
from .serializer import JSONByteSerializer

logger = logging.getLogger(__name__)

try:
    # enable http2 if h2 is installed
    import h2  # type: ignore # noqa

    logger.debug("HTTP2 available")
    HTTP2 = True  # pragma: no cover
except ImportError:
    logger.debug("HTTP2 not available")
    HTTP2 = False


@dataclass
class HttpxThrottleCache:
    """
    Rate-limited, optionally cached wrapper around HTTPX.

    Provides ``httpx.Client`` instances via ``http_client()`` and ``httpx.AsyncClient``
    instances via ``async_http_client()``.

    Rate limiting: every transport built by this instance (sync and async) wraps the
    instance's pyrate_limiter ``Limiter`` (``self.rate_limiter``), so the limit is shared
    across all of its clients. For multiprocessing, use pyrate_limiter's MultiprocessBucket
    or SqliteBucket with a file lock.

    Caching is selected by ``cache_mode``:
      - ``"Hishel-File"`` / ``"Hishel-S3"``: Hishel storage on local files or S3.
      - ``"FileCache"``: this package's ``CachingTransport`` on local files.
      - ``"Disabled"`` / ``False``: no caching (rate limiting only).
    """

    # Base keyword arguments for httpx.Client / httpx.AsyncClient.
    httpx_params: dict[str, Any] = field(
        default_factory=lambda: {"default_encoding": "utf-8", "http2": HTTP2, "verify": True}
    )
    # Passed to the cache controller / CachingTransport; empty means default caching.
    cache_rules: dict[str, dict[str, Union[bool, int]]] = field(default_factory=lambda: {})
    rate_limiter_enabled: bool = True
    cache_mode: Literal[False, "Disabled", "Hishel-S3", "Hishel-File", "FileCache"] = "Hishel-File"
    request_per_sec_limit: int = 10
    max_delay: Duration = field(default_factory=lambda: Duration.DAY)
    # Lazily created shared sync client (see http_client / close).
    _client: Optional[httpx.Client] = None
    # Built in __post_init__ from request_per_sec_limit / max_delay when not supplied.
    rate_limiter: Optional[Limiter] = None
    s3_bucket: Optional[str] = None
    s3_client: Optional[Any] = None
    user_agent: Optional[str] = None
    # If set, called for every new client and takes precedence over ``user_agent``.
    user_agent_factory: Optional[Callable] = None
    cache_dir: Optional[Union[Path, str]] = None
    # Class-level (not a dataclass field): one lock shared by ALL instances. It only guards
    # lazy creation of ``_client``.
    lock: ClassVar[threading.Lock] = threading.Lock()
    # Explicit proxy; when None, HTTPS_PROXY from the environment is used (see __post_init__).
    proxy: Optional[ProxyTypes] = None

    def __post_init__(self):
        """Normalize ``cache_dir``, build the default rate limiter, validate the cache config and resolve the proxy.

        Raises:
            ValueError: If ``cache_mode`` is "Hishel-S3" without ``s3_bucket``, or a file-based
                mode ("Hishel-File" / "FileCache") without ``cache_dir``.
        """
        if isinstance(self.cache_dir, str):
            self.cache_dir = Path(self.cache_dir)

        if self.rate_limiter_enabled and self.rate_limiter is None:
            self.rate_limiter = create_rate_limiter(
                requests_per_second=self.request_per_sec_limit, max_delay=self.max_delay
            )

        cache_disabled = self.cache_mode == "Disabled" or self.cache_mode is False
        # Only warn about missing rules when caching is actually on (previously this also fired
        # for cache_mode=False).
        if not cache_disabled and not self.cache_rules:
            logger.info("Cache is enabled, but no cache_rules provided. Will use default caching.")

        if self.cache_mode == "Hishel-S3":
            if self.s3_bucket is None:
                raise ValueError("s3_bucket must be provided if using Hishel-S3 storage")
        elif not cache_disabled:  # Hishel-File or FileCache
            if self.cache_dir is None:
                raise ValueError(f"cache_dir must be provided if using a file based cache: {self.cache_mode}")
            # parents/exist_ok: support nested paths and tolerate concurrent creation.
            self.cache_dir.mkdir(parents=True, exist_ok=True)

        logger.debug(
            "Initialized cache with cache_mode=%s, cache_dir=%s, rate_limiter_enabled=%s",
            self.cache_mode,
            self.cache_dir,
            self.rate_limiter_enabled,
        )

        # The transports are built here, and httpx does not apply proxy env vars when a custom
        # transport is supplied, so HTTPS_PROXY is honoured explicitly. An explicitly passed
        # proxy wins; an empty HTTPS_PROXY is treated as unset.
        if self.proxy is None:
            self.proxy = os.environ.get("HTTPS_PROXY") or None

    def _populate_user_agent(self, params: dict):
        """Set the ``User-Agent`` header in ``params`` (in place) and return ``params``.

        The value comes from ``user_agent_factory()`` when set, otherwise ``user_agent``; if that
        is None, ``params`` is left untouched. The existing headers mapping is copied before
        writing so dicts shared with ``self.httpx_params`` or with caller kwargs are not mutated.
        """
        if self.user_agent_factory is not None:
            user_agent = self.user_agent_factory()
        else:
            user_agent = self.user_agent
        if user_agent is not None:
            headers = dict(params.get("headers") or {})
            headers["User-Agent"] = user_agent
            params["headers"] = headers
        return params

    def populate_user_agent(self, params: dict):
        """Provided so clients can inspect the params that would be passed to HTTPX"""
        return self._populate_user_agent(params)

    def get_batch(self, *, urls: Sequence[str] | Mapping[str, Path], _client_mocker=None):
        """
        Fetch a batch of URLs concurrently, returning content in memory or streaming it to files.

        The requests run on a fresh asyncio event loop inside a single background worker thread,
        so this works even when the calling thread already has a running event loop (that loop
        is blocked while the batch runs).

        Args:
            urls: Either a sequence of URLs (bodies returned in memory) or a mapping of
                URL -> destination Path (bodies streamed to those files; parent directories
                are created as needed).
            _client_mocker: Testing hook, called with the AsyncClient before any request.

        Returns:
            list: In input order. For a mapping, the destination Paths; otherwise the response
            bodies as bytes.

        Raises:
            RuntimeError: If any URL responds with a status code other than 200 or 304.
        """
        import aiofiles

        async def _run():
            async with self.async_http_client() as client:
                if _client_mocker:
                    # For testing
                    _client_mocker(client)

                async def task(url: str, path: Optional[Path]):
                    async with client.stream("GET", url) as r:
                        if r.status_code in (200, 304):
                            if path:
                                path.parent.mkdir(parents=True, exist_ok=True)
                                async with aiofiles.open(path, "wb") as f:
                                    async for chunk in r.aiter_bytes():
                                        await f.write(chunk)
                                return path
                            else:
                                return await r.aread()
                        else:
                            raise RuntimeError(f"URL status code is not 200 or 304: {url=}")

                if isinstance(urls, Mapping):
                    return await asyncio.gather(*(task(u, p) for u, p in urls.items()))
                else:
                    return await asyncio.gather(*(task(u, None) for u in urls))

        with ThreadPoolExecutor(1) as pool:
            return pool.submit(lambda: asyncio.run(_run())).result()

    def _get_httpx_transport_params(self, params: dict[str, Any]):
        """Return the kwargs for the innermost transport: ``http2`` (from ``params``, default False) and ``proxy``."""
        http2 = params.get("http2", False)
        proxy = self.proxy
        return {"http2": http2, "proxy": proxy}

    @contextmanager
    def http_client(self, bypass_cache: bool = False, **kwargs) -> Generator[httpx.Client, None, None]:
        """Provides and reuses a client. Does not close.

        The shared ``httpx.Client`` is created on first use and is NOT closed when the context
        exits; call ``close()`` to release it.

        NOTE: ``bypass_cache`` and ``kwargs`` only take effect when the client is first created;
        later calls reuse the existing client unchanged.
        """
        if self._client is None:
            with self.lock:
                # Locking: not super critical, since worst case might be extra httpx clients created,
                # but future proofing against TOCTOU races in free-threading world
                if self._client is None:
                    logger.debug("Creating new HTTPX Client")
                    params = self.httpx_params.copy()
                    params.update(**kwargs)
                    # Applied after kwargs (same order as the async factory) so that passing custom
                    # headers does not silently drop the configured User-Agent.
                    self._populate_user_agent(params)
                    params["transport"] = self._get_transport(
                        bypass_cache=bypass_cache, httpx_transport_params=self._get_httpx_transport_params(params)
                    )
                    self._client = httpx.Client(**params)

        yield self._client

    def close(self):
        """Close and discard the shared sync client, if any; the next ``http_client()`` builds a new one."""
        if self._client is not None:
            self._client.close()
            self._client = None

    def update_rate_limiter(self, requests_per_second: int, max_delay: Duration = Duration.DAY):
        """Replace the rate limiter and close the shared sync client so it is rebuilt with the new limiter.

        Async clients that already exist keep the limiter they were built with.

        Args:
            requests_per_second: New request rate.
            max_delay: Maximum time a request may wait for the limiter (previously ignored).
        """
        self.rate_limiter = create_rate_limiter(requests_per_second=requests_per_second, max_delay=max_delay)
        # Keep the recorded configuration in sync with the active limiter.
        self.request_per_sec_limit = requests_per_second
        self.max_delay = max_delay
        self.close()

    def _client_factory_async(self, bypass_cache: bool, **kwargs) -> httpx.AsyncClient:
        """Build a new (not shared) ``httpx.AsyncClient`` from ``httpx_params`` + ``kwargs``, with User-Agent and transport chain applied."""
        params = self.httpx_params.copy()
        params.update(**kwargs)
        self._populate_user_agent(params)
        params["transport"] = self._get_async_transport(
            bypass_cache=bypass_cache, httpx_transport_params=self._get_httpx_transport_params(params)
        )
        return httpx.AsyncClient(**params)

    @asynccontextmanager
    async def async_http_client(
        self, client: Optional[httpx.AsyncClient] = None, bypass_cache: bool = False, **kwargs
    ) -> AsyncGenerator[httpx.AsyncClient, None]:
        """
        Async context manager yielding an ``httpx.AsyncClient``.

        Async callers should create a single client for a group of tasks, rather than one client
        per task.

        If ``client`` is passed (not None), it is yielded unchanged and NOT closed; the caller
        that created it owns it. This happens when a higher-level async task creates the client
        and hands it to child calls. Otherwise a new client is built and closed on exit.
        """
        if client is not None:
            yield client  # type: ignore # Caller is responsible for closing
            return

        async with self._client_factory_async(bypass_cache=bypass_cache, **kwargs) as client:
            yield client

    def _get_transport(self, bypass_cache: bool, httpx_transport_params: dict[str, Any]) -> httpx.BaseTransport:
        """
        Construct the sync transport chain:
            caching transport (if enabled) => RateLimitingTransport, or plain httpx.HTTPTransport
            when rate limiting is disabled

        Args:
            bypass_cache: When True, omit the caching layer regardless of ``cache_mode``.
            httpx_transport_params: ``http2`` / ``proxy`` kwargs for the innermost transport.
        """
        if self.rate_limiter_enabled:
            assert self.rate_limiter is not None
            next_transport = RateLimitingTransport(self.rate_limiter, **httpx_transport_params)
        else:
            next_transport = httpx.HTTPTransport(**httpx_transport_params)

        if bypass_cache or self.cache_mode == "Disabled" or self.cache_mode is False:
            logger.info("Cache is DISABLED, rate limiting only")
            return next_transport
        elif self.cache_mode == "FileCache":
            assert self.cache_dir is not None
            return CachingTransport(cache_dir=self.cache_dir, transport=next_transport, cache_rules=self.cache_rules)
        else:
            # either Hishel-S3 or Hishel-File
            assert self.cache_mode == "Hishel-File" or self.cache_mode == "Hishel-S3"
            controller = get_cache_controller(key_generator=file_key_generator, cache_rules=self.cache_rules)

            if self.cache_mode == "Hishel-S3":
                assert self.s3_bucket is not None
                storage = hishel.S3Storage(
                    client=self.s3_client, bucket_name=self.s3_bucket, serializer=JSONByteSerializer()
                )
            else:
                assert self.cache_mode == "Hishel-File"
                assert self.cache_dir is not None
                storage = hishel.FileStorage(base_path=Path(self.cache_dir), serializer=JSONByteSerializer())

            return hishel.CacheTransport(transport=next_transport, storage=storage, controller=controller)

    def _get_async_transport(
        self, bypass_cache: bool, httpx_transport_params: dict[str, Any]
    ) -> httpx.AsyncBaseTransport:
        """
        Construct the async transport chain:
            caching transport (if enabled) => AsyncRateLimitingTransport, or plain
            httpx.AsyncHTTPTransport when rate limiting is disabled

        Args:
            bypass_cache: When True, omit the caching layer regardless of ``cache_mode``.
            httpx_transport_params: ``http2`` / ``proxy`` kwargs for the innermost transport.
        """
        if self.rate_limiter_enabled:
            assert self.rate_limiter is not None
            next_transport = AsyncRateLimitingTransport(self.rate_limiter, **httpx_transport_params)
        else:
            next_transport = httpx.AsyncHTTPTransport(**httpx_transport_params)

        if bypass_cache or self.cache_mode == "Disabled" or self.cache_mode is False:
            logger.info("Cache is DISABLED, rate limiting only")
            return next_transport
        elif self.cache_mode == "FileCache":
            assert self.cache_dir is not None
            return CachingTransport(cache_dir=self.cache_dir, transport=next_transport, cache_rules=self.cache_rules)  # pyright: ignore[reportArgumentType]
        else:
            # either Hishel-S3 or Hishel-File
            assert self.cache_mode == "Hishel-File" or self.cache_mode == "Hishel-S3"
            controller = get_cache_controller(key_generator=file_key_generator, cache_rules=self.cache_rules)

            if self.cache_mode == "Hishel-S3":
                assert self.s3_bucket is not None
                storage = hishel.AsyncS3Storage(
                    client=self.s3_client, bucket_name=self.s3_bucket, serializer=JSONByteSerializer()
                )
            else:
                assert self.cache_mode == "Hishel-File"
                assert self.cache_dir is not None
                storage = hishel.AsyncFileStorage(base_path=Path(self.cache_dir), serializer=JSONByteSerializer())

            return hishel.AsyncCacheTransport(transport=next_transport, storage=storage, controller=controller)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the shared sync client on leaving a ``with`` block."""
        self.close()