Files
2025-12-09 12:13:01 +01:00

483 lines
17 KiB
Python

"""Limiter class implementation
"""
import asyncio
import logging
from contextlib import contextmanager
from functools import wraps
from inspect import isawaitable
from threading import local
from threading import RLock
from time import sleep
from typing import Any
from typing import Awaitable
from typing import Callable
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from .abstracts import AbstractBucket
from .abstracts import AbstractClock
from .abstracts import BucketFactory
from .abstracts import Duration
from .abstracts import Rate
from .abstracts import RateItem
from .buckets import InMemoryBucket
from .clocks import TimeClock
from .exceptions import BucketFullException
from .exceptions import LimiterDelayException
# Module-level logger; every diagnostic emitted by this module goes through
# the "pyrate_limiter" logger.
logger = logging.getLogger("pyrate_limiter")
# Type of the mapping callable handed to `Limiter.as_decorator()`: it is called
# with the decorated function's arguments and must return a `(name, weight)`
# tuple describing the item to acquire.
ItemMapping = Callable[[Any], Tuple[str, int]]
# Type of the decorator that `Limiter.as_decorator()` ultimately produces:
# it takes a callable and returns the rate-limited callable.
DecoratorWrapper = Callable[[Callable[[Any], Any]], Callable[[Any], Any]]
class SingleBucketFactory(BucketFactory):
    """Bucket factory that always serves one and the same bucket.

    Convenience factory used by `Limiter` when it is given a bare bucket
    (or plain rates) instead of a full factory.
    """

    bucket: AbstractBucket
    clock: AbstractClock

    def __init__(self, bucket: AbstractBucket, clock: AbstractClock):
        """Store the bucket/clock pair and hand both to the inherited `schedule_leak`."""
        self.bucket = bucket
        self.clock = clock
        self.schedule_leak(bucket, clock)

    def wrap_item(self, name: str, weight: int = 1):
        """Build a `RateItem` stamped with the clock's current time.

        Returns the item directly when the clock is synchronous, or a
        coroutine resolving to the item when `clock.now()` is awaitable.
        """
        timestamp = self.clock.now()

        if not isawaitable(timestamp):
            return RateItem(name, timestamp, weight=weight)

        async def _await_clock_then_wrap():
            return RateItem(name, await timestamp, weight=weight)

        return _await_clock_then_wrap()

    def get(self, _: RateItem) -> AbstractBucket:
        """Return the single bucket, regardless of the item."""
        return self.bucket
@contextmanager
def combined_lock(locks: Iterable, timeout_sec: Optional[float] = None):
    """Hold several locks for the duration of a `with` block.

    Locks are acquired in iteration order and released in reverse order when
    the block exits (normally or through an exception). Intended for combining
    a cross-process lock with in-process thread RLocks under multiprocessing.

    Parameters:
        locks: Lock objects exposing `acquire()` / `release()`.
        timeout_sec: When given, it is passed as `timeout=` to each individual
            `acquire` call (so it bounds every lock separately, not the total).

    Raises:
        TimeoutError: A lock could not be acquired within `timeout_sec`; the
            locks already taken are released before the error propagates.
    """
    held = []
    try:
        for current in locks:
            if timeout_sec is None:
                current.acquire()
            elif not current.acquire(timeout=timeout_sec):
                raise TimeoutError("Timeout while acquiring combined lock.")
            held.append(current)
        yield
    finally:
        # Unwind the stack: last acquired, first released.
        while held:
            held.pop().release()
class Limiter:
    """Front door for rate limiting: routes items to buckets for sync and async callers.

    A limiter owns a `BucketFactory`, wraps every acquisition request into a
    `RateItem`, puts it into the bucket selected by the factory and, when the
    bucket rejects it, either fails right away or waits (bounded by
    `max_delay`) and tries again. Every step may produce a plain value or an
    awaitable; the limiter follows whichever the clock/bucket returns.
    """

    # Factory that wraps items and selects the bucket for each of them.
    bucket_factory: BucketFactory
    # When True, failures raise; when False, they are reported as `False`.
    raise_when_fail: bool
    # When True, keep re-trying after each wait until `max_delay` is exhausted.
    retry_until_max_delay: bool
    # Upper bound for waiting, in milliseconds; None disables waiting entirely.
    max_delay: Optional[int] = None
    # Either a single RLock or a `(bucket lock, RLock)` pair used via `combined_lock`.
    lock: Union[RLock, Iterable]
    # Extra milliseconds added to every computed wait before sleeping.
    buffer_ms: int

    # async_lock is thread local, created on first use
    _thread_local: local

    def __init__(
        self,
        argument: Union[BucketFactory, AbstractBucket, Rate, List[Rate]],
        clock: Optional[AbstractClock] = None,
        raise_when_fail: bool = True,
        max_delay: Optional[Union[int, Duration]] = None,
        retry_until_max_delay: bool = False,
        buffer_ms: int = 50,
    ):
        """Init Limiter using either a single bucket / multiple-bucket factory
        / single rate / rate list.

        Parameters:
            argument (Union[BucketFactory, AbstractBucket, Rate, List[Rate]]): The bucket or rate configuration.
            clock (AbstractClock, optional): The clock instance to use for rate limiting.
                Defaults to a fresh TimeClock() created per limiter.
            raise_when_fail (bool, optional): Whether to raise an exception when rate limiting fails. Defaults to True.
            max_delay (Optional[Union[int, Duration]], optional): The maximum delay allowed for rate limiting,
                in milliseconds. Defaults to None (never wait).
            retry_until_max_delay (bool, optional): If True, retry operations until the maximum delay is reached.
                Useful for ensuring operations eventually succeed within the allowed delay window. Defaults to False.
            buffer_ms (int, optional): Milliseconds added on top of every wait computed by the bucket.
                Defaults to 50.
        """
        # Build the default clock here rather than in the signature: a call in a
        # default argument is evaluated once at definition time and then shared.
        if clock is None:
            clock = TimeClock()

        self.bucket_factory = self._init_bucket_factory(argument, clock=clock)
        self.raise_when_fail = raise_when_fail
        self.retry_until_max_delay = retry_until_max_delay
        self.buffer_ms = buffer_ms

        if max_delay is not None:
            if isinstance(max_delay, Duration):
                max_delay = int(max_delay)

            assert max_delay >= 0, "Max-delay must not be negative"

        self.max_delay = max_delay
        self.lock = RLock()
        self._thread_local = local()

        if isinstance(argument, AbstractBucket):
            # A bucket may bring its own lock; if so, take it together with
            # (and before) the limiter's own RLock.
            limiter_lock = argument.limiter_lock()

            if limiter_lock is not None:
                self.lock = (limiter_lock, self.lock)

    def buckets(self) -> List[AbstractBucket]:
        """Get list of active buckets
        """
        return self.bucket_factory.get_buckets()

    def dispose(self, bucket: Union[int, AbstractBucket]) -> bool:
        """Dispose/Remove a specific bucket,
        using bucket-id or bucket object as param
        """
        return self.bucket_factory.dispose(bucket)

    def _init_bucket_factory(
        self,
        argument: Union[BucketFactory, AbstractBucket, Rate, List[Rate]],
        clock: AbstractClock,
    ) -> BucketFactory:
        """Normalize the constructor argument into a `BucketFactory`.

        Rate -> [Rate] -> InMemoryBucket -> SingleBucketFactory; an argument
        that already is a factory is returned unchanged.
        """
        if isinstance(argument, Rate):
            argument = [argument]

        if isinstance(argument, list):
            assert len(argument) > 0, "Rates must not be empty"
            assert isinstance(argument[0], Rate), "Not valid rates list"
            rates = argument
            logger.info("Initializing default bucket(InMemoryBucket) with rates: %s", rates)
            argument = InMemoryBucket(rates)

        if isinstance(argument, AbstractBucket):
            argument = SingleBucketFactory(argument, clock)

        assert isinstance(argument, BucketFactory), "Not a valid bucket/bucket-factory"
        return argument

    def _raise_bucket_full_if_necessary(
        self,
        bucket: AbstractBucket,
        item: RateItem,
    ):
        """Raise `BucketFullException` when the limiter is configured to raise; else do nothing."""
        if self.raise_when_fail:
            assert bucket.failing_rate is not None  # NOTE: silence mypy
            raise BucketFullException(item, bucket.failing_rate)

    def _raise_delay_exception_if_necessary(
        self,
        bucket: AbstractBucket,
        item: RateItem,
        delay: int,
    ):
        """Raise `LimiterDelayException` when the limiter is configured to raise; else do nothing."""
        if self.raise_when_fail:
            assert bucket.failing_rate is not None  # NOTE: silence mypy
            assert isinstance(self.max_delay, int)
            raise LimiterDelayException(
                item,
                bucket.failing_rate,
                delay,
                self.max_delay,
            )

    def delay_or_raise(
        self,
        bucket: AbstractBucket,
        item: RateItem,
    ) -> Union[bool, Awaitable[bool]]:
        """On `try_acquire` failed, handle delay or raise error immediately

        Without `max_delay` the failure is reported at once. Otherwise the
        bucket is asked how long to wait; the limiter sleeps that long (plus
        `buffer_ms`), advances the item's timestamp and puts the item again.
        With `retry_until_max_delay` this repeats until the accumulated wait
        exceeds `max_delay`.

        Returns:
            True/False, or an awaitable of it when `bucket.waiting` is awaitable.

        Raises (only when `raise_when_fail` is set):
            BucketFullException: no waiting allowed, the item can never fit,
                or the single re-try failed.
            LimiterDelayException: the required wait exceeds `max_delay`.
        """
        assert bucket.failing_rate is not None

        if self.max_delay is None:
            self._raise_bucket_full_if_necessary(bucket, item)
            return False

        delay = bucket.waiting(item)

        def _handle_reacquire(re_acquire: bool) -> bool:
            if not re_acquire:
                logger.error(
                    "Failed to re-acquire after the expected delay. If it failed,\n"
                    "either clock or bucket is unstable.\n"
                    "If asyncio, use try_acquire_async(). If multiprocessing,\n"
                    "use retry_until_max_delay=True."
                )
                self._raise_bucket_full_if_necessary(bucket, item)

            return re_acquire

        def _reject_unfittable() -> bool:
            # A negative wait is the bucket's signal that the item can never be
            # admitted, so sleeping and re-trying would be pointless.
            logger.error(
                "Cannot fit item into bucket: item=%s, rate=%s, bucket=%s",
                item,
                bucket.failing_rate,
                bucket,
            )
            self._raise_bucket_full_if_necessary(bucket, item)
            return False

        if isawaitable(delay):

            async def _handle_async():
                nonlocal delay
                delay = await delay
                assert isinstance(delay, int), "Delay not integer"

                # Same guard as the synchronous path below: fail fast instead of
                # sleeping for an item that cannot fit.
                if delay < 0:
                    return _reject_unfittable()

                total_delay = 0
                delay += self.buffer_ms

                # NOTE(review): unlike the synchronous loop, this loop re-uses the
                # first computed delay on every round instead of asking the bucket
                # again -- confirm whether that is intended.
                while True:
                    total_delay += delay

                    if self.retry_until_max_delay:
                        if self.max_delay is not None and total_delay > self.max_delay:
                            logger.error(
                                "Total delay exceeded max_delay: total_delay=%s, max_delay=%s",
                                total_delay,
                                self.max_delay,
                            )
                            self._raise_delay_exception_if_necessary(bucket, item, total_delay)
                            return False
                    else:
                        if self.max_delay is not None and delay > self.max_delay:
                            logger.error(
                                "Required delay too large: actual=%s, expected=%s",
                                delay,
                                self.max_delay,
                            )
                            self._raise_delay_exception_if_necessary(bucket, item, delay)
                            return False

                    # Delays are in milliseconds; asyncio.sleep takes seconds.
                    await asyncio.sleep(delay / 1000)
                    item.timestamp += delay
                    re_acquire = bucket.put(item)

                    if isawaitable(re_acquire):
                        re_acquire = await re_acquire

                    if not self.retry_until_max_delay:
                        return _handle_reacquire(re_acquire)
                    elif re_acquire:
                        return True

            return _handle_async()

        assert isinstance(delay, int)

        if delay < 0:
            return _reject_unfittable()

        total_delay = 0

        while True:
            logger.debug("delay=%d, total_delay=%s", delay, total_delay)
            # Ask the bucket again on every round: the wait may have changed.
            delay = bucket.waiting(item)
            assert isinstance(delay, int)

            delay += self.buffer_ms
            total_delay += delay

            if self.max_delay is not None and total_delay > self.max_delay:
                logger.error(
                    "Required delay too large: actual=%s, expected=%s",
                    delay,
                    self.max_delay,
                )

                if self.retry_until_max_delay:
                    self._raise_delay_exception_if_necessary(bucket, item, total_delay)
                else:
                    self._raise_delay_exception_if_necessary(bucket, item, delay)

                return False

            # Delays are in milliseconds; time.sleep takes seconds.
            sleep(delay / 1000)
            item.timestamp += delay
            re_acquire = bucket.put(item)
            # NOTE: if delay is not Awaitable, then `bucket.put` is not Awaitable
            assert isinstance(re_acquire, bool)

            if not self.retry_until_max_delay:
                return _handle_reacquire(re_acquire)
            elif re_acquire:
                return True

    def handle_bucket_put(
        self,
        bucket: AbstractBucket,
        item: RateItem,
    ) -> Union[bool, Awaitable[bool]]:
        """Putting item into bucket

        Falls back to `delay_or_raise` when the bucket rejects the item.
        Returns a coroutine when `bucket.put` is awaitable.
        """

        def _handle_result(is_success: bool):
            if not is_success:
                return self.delay_or_raise(bucket, item)

            return True

        acquire = bucket.put(item)

        if isawaitable(acquire):

            async def _put_async():
                nonlocal acquire
                acquire = await acquire
                result = _handle_result(acquire)

                # `delay_or_raise` may itself hand back an awaitable.
                while isawaitable(result):
                    result = await result

                return result

            return _put_async()

        return _handle_result(acquire)  # type: ignore

    def _get_async_lock(self):
        """Return the calling thread's `asyncio.Lock`, creating it on first use."""
        try:
            return self._thread_local.async_lock
        except AttributeError:
            lock = asyncio.Lock()
            self._thread_local.async_lock = lock
            return lock

    async def try_acquire_async(self, name: str, weight: int = 1) -> bool:
        """
        async version of try_acquire.

        This uses a top level, thread-local async lock to ensure that the async loop doesn't block

        This does not make the entire code async: use an async bucket for that.
        """
        async with self._get_async_lock():
            acquired = self.try_acquire(name=name, weight=weight)

            if isawaitable(acquired):
                return await acquired
            else:
                logger.warning("async call made without an async bucket.")
                return acquired

    def try_acquire(self, name: str, weight: int = 1) -> Union[bool, Awaitable[bool]]:
        """Try acquiring an item with name & weight

        Return true on success, false on failure. When the clock, the factory
        or the bucket is asynchronous, a coroutine resolving to that boolean is
        returned instead.

        NOTE: the limiter lock is held only while this method itself runs; a
        returned coroutine is created under the lock but executes after it has
        been released.

        Raises (only when `raise_when_fail` is set):
            BucketFullException, LimiterDelayException: see `delay_or_raise`.
        """
        with self.lock if not isinstance(self.lock, Iterable) else combined_lock(self.lock):
            assert weight >= 0, "item's weight must be >= 0"

            if weight == 0:
                # NOTE: if item is weightless, just let it go through
                # NOTE: this might change in the future
                return True

            item = self.bucket_factory.wrap_item(name, weight)

            if isawaitable(item):

                async def _handle_async():
                    nonlocal item
                    item = await item
                    bucket = self.bucket_factory.get(item)

                    if isawaitable(bucket):
                        bucket = await bucket

                    assert isinstance(bucket, AbstractBucket), f"Invalid bucket: item: {name}"
                    result = self.handle_bucket_put(bucket, item)

                    while isawaitable(result):
                        result = await result

                    return result

                return _handle_async()

            assert isinstance(item, RateItem)  # NOTE: this is to silence mypy warning
            bucket = self.bucket_factory.get(item)

            if isawaitable(bucket):

                async def _handle_async_bucket():
                    nonlocal bucket
                    bucket = await bucket
                    assert isinstance(bucket, AbstractBucket), f"Invalid bucket: item: {name}"
                    result = self.handle_bucket_put(bucket, item)

                    while isawaitable(result):
                        result = await result

                    return result

                return _handle_async_bucket()

            assert isinstance(bucket, AbstractBucket), f"Invalid bucket: item: {name}"
            result = self.handle_bucket_put(bucket, item)

            if isawaitable(result):

                async def _handle_async_result():
                    nonlocal result

                    while isawaitable(result):
                        result = await result

                    return result

                return _handle_async_result()

            return result

    def as_decorator(self) -> Callable[[ItemMapping], DecoratorWrapper]:
        """Use limiter decorator
        Use with both sync & async function

        Usage: `limiter.as_decorator()(mapping)(func)`, where `mapping` receives
        the call's arguments and returns the `(name, weight)` to acquire.

        NOTE: the boolean outcome of the acquisition is not inspected, so with
        `raise_when_fail=False` the wrapped function runs even when the
        acquisition reported failure.
        """

        def with_mapping_func(mapping: ItemMapping) -> DecoratorWrapper:
            def decorator_wrapper(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
                """Actual function wrapper"""

                if asyncio.iscoroutinefunction(func):

                    @wraps(func)
                    async def wrapper_async(*args, **kwargs):
                        (name, weight) = mapping(*args, **kwargs)
                        assert isinstance(name, str), "Mapping name is expected but not found"
                        assert isinstance(weight, int), "Mapping weight is expected but not found"
                        # `try_acquire_async` is a coroutine function, so its
                        # result can be awaited unconditionally.
                        await self.try_acquire_async(name, weight)
                        return await func(*args, **kwargs)

                    return wrapper_async
                else:

                    @wraps(func)
                    def wrapper(*args, **kwargs):
                        (name, weight) = mapping(*args, **kwargs)
                        assert isinstance(name, str), "Mapping name is expected but not found"
                        assert isinstance(weight, int), "Mapping weight is expected but not found"
                        acquire_ok = self.try_acquire(name, weight)

                        if not isawaitable(acquire_ok):
                            return func(*args, **kwargs)

                        # Sync function behind an async bucket: the caller gets a
                        # coroutine that acquires first and then runs `func`.
                        async def _handle_acquire_async():
                            await acquire_ok
                            result = func(*args, **kwargs)

                            if isawaitable(result):
                                return await result

                            return result

                        return _handle_acquire_async()

                    return wrapper

            return decorator_wrapper

        return with_mapping_func