"""Limiter class implementation
|
|
"""
|
|
import asyncio
|
|
import logging
|
|
from contextlib import contextmanager
|
|
from functools import wraps
|
|
from inspect import isawaitable
|
|
from threading import local
|
|
from threading import RLock
|
|
from time import sleep
|
|
from typing import Any
|
|
from typing import Awaitable
|
|
from typing import Callable
|
|
from typing import Iterable
|
|
from typing import List
|
|
from typing import Optional
|
|
from typing import Tuple
|
|
from typing import Union
|
|
|
|
from .abstracts import AbstractBucket
|
|
from .abstracts import AbstractClock
|
|
from .abstracts import BucketFactory
|
|
from .abstracts import Duration
|
|
from .abstracts import Rate
|
|
from .abstracts import RateItem
|
|
from .buckets import InMemoryBucket
|
|
from .clocks import TimeClock
|
|
from .exceptions import BucketFullException
|
|
from .exceptions import LimiterDelayException
|
|
|
|
logger = logging.getLogger("pyrate_limiter")
|
|
|
|
ItemMapping = Callable[[Any], Tuple[str, int]]
|
|
DecoratorWrapper = Callable[[Callable[[Any], Any]], Callable[[Any], Any]]
|
|
|
|
|
|
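
# An ItemMapping turns a decorated call's arguments into an (item_name, weight)
# pair consumed by `Limiter.as_decorator()` below. A minimal sketch, assuming a
# hypothetical call shape (`map_by_user`, `user_id` and the key format are made up):
#
#   def map_by_user(user_id: str, *args, **kwargs) -> Tuple[str, int]:
#       return f"user:{user_id}", 1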


class SingleBucketFactory(BucketFactory):
    """Single-bucket factory for quick use with Limiter"""

    bucket: AbstractBucket
    clock: AbstractClock

    def __init__(self, bucket: AbstractBucket, clock: AbstractClock):
        self.clock = clock
        self.bucket = bucket
        self.schedule_leak(bucket, clock)

    def wrap_item(self, name: str, weight: int = 1):
        now = self.clock.now()

        async def wrap_async():
            return RateItem(name, await now, weight=weight)

        def wrap_sync():
            return RateItem(name, now, weight=weight)

        return wrap_async() if isawaitable(now) else wrap_sync()

    def get(self, _: RateItem) -> AbstractBucket:
        return self.bucket


@contextmanager
def combined_lock(locks: Iterable, timeout_sec: Optional[float] = None):
    """Acquire and release multiple locks as one. Intended for multiprocessing,
    where a cross-process lock is combined with in-process thread RLocks."""
    acquired = []
    try:
        for lock in locks:
            if timeout_sec is not None:
                if not lock.acquire(timeout=timeout_sec):
                    raise TimeoutError("Timeout while acquiring combined lock.")
            else:
                lock.acquire()
            acquired.append(lock)
        yield
    finally:
        for lock in reversed(acquired):
            lock.release()
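
# A minimal usage sketch for `combined_lock` (illustrative only; the lock objects
# below are ad-hoc stand-ins, not ones managed by this module):
#
#   from multiprocessing import Lock as ProcessLock
#   from threading import RLock
#
#   with combined_lock((ProcessLock(), RLock()), timeout_sec=5.0):
#       ...  # critical section guarded across processes and threads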


class Limiter:
    """This class's responsibility is to sum up all the underlying logic
    and make working with async/sync functions easy.
    """

    bucket_factory: BucketFactory
    raise_when_fail: bool
    retry_until_max_delay: bool
    max_delay: Optional[int] = None
    lock: Union[RLock, Iterable]
    buffer_ms: int

    # async_lock is thread-local, created on first use
    _thread_local: local

    def __init__(
        self,
        argument: Union[BucketFactory, AbstractBucket, Rate, List[Rate]],
        clock: AbstractClock = TimeClock(),
        raise_when_fail: bool = True,
        max_delay: Optional[Union[int, Duration]] = None,
        retry_until_max_delay: bool = False,
        buffer_ms: int = 50
    ):
        """Init Limiter using either a single bucket / multiple-bucket factory
        / single rate / rate list.

        Parameters:
            argument (Union[BucketFactory, AbstractBucket, Rate, List[Rate]]): The bucket or rate configuration.
            clock (AbstractClock, optional): The clock instance to use for rate limiting. Defaults to TimeClock().
            raise_when_fail (bool, optional): Whether to raise an exception when rate limiting fails. Defaults to True.
            max_delay (Optional[Union[int, Duration]], optional): The maximum delay allowed for rate limiting.
                Defaults to None.
            retry_until_max_delay (bool, optional): If True, retry operations until the maximum delay is reached.
                Useful for ensuring operations eventually succeed within the allowed delay window. Defaults to False.
            buffer_ms (int, optional): Extra buffer in milliseconds added to each computed delay before retrying.
                Defaults to 50.
        """
        self.bucket_factory = self._init_bucket_factory(argument, clock=clock)
        self.raise_when_fail = raise_when_fail
        self.retry_until_max_delay = retry_until_max_delay
        self.buffer_ms = buffer_ms
        if max_delay is not None:
            if isinstance(max_delay, Duration):
                max_delay = int(max_delay)

            assert max_delay >= 0, "Max-delay must not be negative"

        self.max_delay = max_delay
        self.lock = RLock()
        self._thread_local = local()

        if isinstance(argument, AbstractBucket):
            limiter_lock = argument.limiter_lock()
            if limiter_lock is not None:
                self.lock = (limiter_lock, self.lock)
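
    # Constructor usage sketches (illustrative only; the exact `Rate`/`Duration`
    # call signatures are assumed from their imports above, not defined here):
    #
    #   limiter = Limiter(Rate(5, Duration.SECOND), max_delay=500)
    #   limiter = Limiter([Rate(5, Duration.SECOND), Rate(100, Duration.MINUTE)])
    #   limiter = Limiter(InMemoryBucket([Rate(5, Duration.SECOND)]))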

    def buckets(self) -> List[AbstractBucket]:
        """Get list of active buckets
        """
        return self.bucket_factory.get_buckets()

    def dispose(self, bucket: Union[int, AbstractBucket]) -> bool:
        """Dispose/remove a specific bucket,
        using either its bucket id or the bucket object itself as the argument.
        """
        return self.bucket_factory.dispose(bucket)

    def _init_bucket_factory(
        self,
        argument: Union[BucketFactory, AbstractBucket, Rate, List[Rate]],
        clock: AbstractClock,
    ) -> BucketFactory:
        if isinstance(argument, Rate):
            argument = [argument]

        if isinstance(argument, list):
            assert len(argument) > 0, "Rates must not be empty"
            assert isinstance(argument[0], Rate), "Not valid rates list"
            rates = argument
            logger.info("Initializing default bucket(InMemoryBucket) with rates: %s", rates)
            argument = InMemoryBucket(rates)

        if isinstance(argument, AbstractBucket):
            argument = SingleBucketFactory(argument, clock)

        assert isinstance(argument, BucketFactory), "Not a valid bucket/bucket-factory"

        return argument

    def _raise_bucket_full_if_necessary(
        self,
        bucket: AbstractBucket,
        item: RateItem,
    ):
        if self.raise_when_fail:
            assert bucket.failing_rate is not None  # NOTE: silence mypy
            raise BucketFullException(item, bucket.failing_rate)

    def _raise_delay_exception_if_necessary(
        self,
        bucket: AbstractBucket,
        item: RateItem,
        delay: int,
    ):
        if self.raise_when_fail:
            assert bucket.failing_rate is not None  # NOTE: silence mypy
            assert isinstance(self.max_delay, int)
            raise LimiterDelayException(
                item,
                bucket.failing_rate,
                delay,
                self.max_delay,
            )

    def delay_or_raise(
        self,
        bucket: AbstractBucket,
        item: RateItem,
    ) -> Union[bool, Awaitable[bool]]:
        """On `try_acquire` failure, either wait out the required delay or raise an error immediately"""
        assert bucket.failing_rate is not None

        if self.max_delay is None:
            self._raise_bucket_full_if_necessary(bucket, item)
            return False

        delay = bucket.waiting(item)

        def _handle_reacquire(re_acquire: bool) -> bool:
            if not re_acquire:
                logger.error(
                    "Failed to re-acquire after the expected delay; either the clock or the bucket is unstable. "
                    "If using asyncio, use try_acquire_async(). If using multiprocessing, use retry_until_max_delay=True."
                )
                self._raise_bucket_full_if_necessary(bucket, item)

            return re_acquire

        if isawaitable(delay):

            async def _handle_async():
                nonlocal delay
                delay = await delay
                assert isinstance(delay, int), "Delay not integer"

                total_delay = 0
                delay += self.buffer_ms

                while True:
                    total_delay += delay

                    if self.retry_until_max_delay:
                        if self.max_delay is not None and total_delay > self.max_delay:
                            logger.error("Total delay exceeded max_delay: total_delay=%s, max_delay=%s",
                                         total_delay, self.max_delay)
                            self._raise_delay_exception_if_necessary(bucket, item, total_delay)
                            return False
                    else:
                        if self.max_delay is not None and delay > self.max_delay:
                            logger.error(
                                "Required delay too large: actual=%s, expected=%s",
                                delay,
                                self.max_delay,
                            )
                            self._raise_delay_exception_if_necessary(bucket, item, delay)
                            return False

                    await asyncio.sleep(delay / 1000)
                    item.timestamp += delay
                    re_acquire = bucket.put(item)

                    if isawaitable(re_acquire):
                        re_acquire = await re_acquire

                    if not self.retry_until_max_delay:
                        return _handle_reacquire(re_acquire)
                    elif re_acquire:
                        return True

            return _handle_async()

        assert isinstance(delay, int)

        if delay < 0:
            logger.error(
                "Cannot fit item into bucket: item=%s, rate=%s, bucket=%s",
                item,
                bucket.failing_rate,
                bucket,
            )
            self._raise_bucket_full_if_necessary(bucket, item)
            return False

        total_delay = 0

        while True:
            logger.debug("delay=%d, total_delay=%s", delay, total_delay)
            delay = bucket.waiting(item)
            assert isinstance(delay, int)

            delay += self.buffer_ms
            total_delay += delay

            if self.max_delay is not None and total_delay > self.max_delay:
                logger.error(
                    "Required delay too large: actual=%s, expected=%s",
                    delay,
                    self.max_delay,
                )

                if self.retry_until_max_delay:
                    self._raise_delay_exception_if_necessary(bucket, item, total_delay)
                else:
                    self._raise_delay_exception_if_necessary(bucket, item, delay)

                return False

            sleep(delay / 1000)
            item.timestamp += delay
            re_acquire = bucket.put(item)
            # NOTE: if delay is not Awaitable, then `bucket.put` is not Awaitable
            assert isinstance(re_acquire, bool)

            if not self.retry_until_max_delay:
                return _handle_reacquire(re_acquire)
            elif re_acquire:
                return True
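
    # A worked example of the delay handling above (values are illustrative): with
    # buffer_ms=50 and bucket.waiting() returning 200, the limiter sleeps
    # (200 + 50) / 1000 = 0.25 seconds, advances item.timestamp by 250 ms, and then
    # re-attempts bucket.put(); once the delay budget set by max_delay is exhausted,
    # it returns False or raises LimiterDelayException (depending on raise_when_fail).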

    def handle_bucket_put(
        self,
        bucket: AbstractBucket,
        item: RateItem,
    ) -> Union[bool, Awaitable[bool]]:
        """Putting item into bucket"""

        def _handle_result(is_success: bool):
            if not is_success:
                return self.delay_or_raise(bucket, item)

            return True

        acquire = bucket.put(item)

        if isawaitable(acquire):

            async def _put_async():
                nonlocal acquire
                acquire = await acquire
                result = _handle_result(acquire)

                while isawaitable(result):
                    result = await result

                return result

            return _put_async()

        return _handle_result(acquire)  # type: ignore

    def _get_async_lock(self):
        """Must be called before first try_acquire_async for each thread"""
        try:
            return self._thread_local.async_lock
        except AttributeError:
            lock = asyncio.Lock()
            self._thread_local.async_lock = lock
            return lock

    async def try_acquire_async(self, name: str, weight: int = 1) -> bool:
        """
        Async version of try_acquire.

        This uses a top-level, thread-local async lock to ensure that the async loop doesn't block.

        This does not make the entire code async: use an async bucket for that.
        """
        async with self._get_async_lock():
            acquired = self.try_acquire(name=name, weight=weight)

            if isawaitable(acquired):
                return await acquired
            else:
                logger.warning("async call made without an async bucket.")
                return acquired
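
    # A minimal async usage sketch (illustrative only; `fetch` is a made-up coroutine):
    #
    #   async def fetch(limiter: "Limiter") -> None:
    #       await limiter.try_acquire_async("outbound-request")
    #       ...  # perform the rate-limited work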

    def try_acquire(self, name: str, weight: int = 1) -> Union[bool, Awaitable[bool]]:
        """Try acquiring an item with the given name & weight.
        Return True on success, False on failure.
        """
        with self.lock if not isinstance(self.lock, Iterable) else combined_lock(self.lock):
            assert weight >= 0, "item's weight must be >= 0"

            if weight == 0:
                # NOTE: if item is weightless, just let it go through
                # NOTE: this might change in the future
                return True

            item = self.bucket_factory.wrap_item(name, weight)

            if isawaitable(item):

                async def _handle_async():
                    nonlocal item
                    item = await item
                    bucket = self.bucket_factory.get(item)
                    if isawaitable(bucket):
                        bucket = await bucket
                    assert isinstance(bucket, AbstractBucket), f"Invalid bucket: item: {name}"
                    result = self.handle_bucket_put(bucket, item)

                    while isawaitable(result):
                        result = await result

                    return result

                return _handle_async()

            assert isinstance(item, RateItem)  # NOTE: this is to silence mypy warning
            bucket = self.bucket_factory.get(item)
            if isawaitable(bucket):

                async def _handle_async_bucket():
                    nonlocal bucket
                    bucket = await bucket
                    assert isinstance(bucket, AbstractBucket), f"Invalid bucket: item: {name}"
                    result = self.handle_bucket_put(bucket, item)

                    while isawaitable(result):
                        result = await result

                    return result

                return _handle_async_bucket()

            assert isinstance(bucket, AbstractBucket), f"Invalid bucket: item: {name}"
            result = self.handle_bucket_put(bucket, item)

            if isawaitable(result):

                async def _handle_async_result():
                    nonlocal result
                    while isawaitable(result):
                        result = await result

                    return result

                return _handle_async_result()

            return result
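
    # A minimal sync usage sketch (illustrative only; the Rate/Duration call
    # signatures are assumed from their imports above):
    #
    #   limiter = Limiter(Rate(10, Duration.SECOND), max_delay=1000)
    #   if limiter.try_acquire("job", weight=2):
    #       ...  # proceed with the rate-limited work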

    def as_decorator(self) -> Callable[[ItemMapping], DecoratorWrapper]:
        """Use the limiter as a decorator.
        Works with both sync & async functions.
        """

        def with_mapping_func(mapping: ItemMapping) -> DecoratorWrapper:
            def decorator_wrapper(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
                """Actual function wrapper"""

                if asyncio.iscoroutinefunction(func):

                    @wraps(func)
                    async def wrapper_async(*args, **kwargs):
                        (name, weight) = mapping(*args, **kwargs)
                        assert isinstance(name, str), "Mapping name is expected but not found"
                        assert isinstance(weight, int), "Mapping weight is expected but not found"
                        acquire_ok = self.try_acquire_async(name, weight)

                        if isawaitable(acquire_ok):
                            await acquire_ok

                        return await func(*args, **kwargs)

                    return wrapper_async
                else:

                    @wraps(func)
                    def wrapper(*args, **kwargs):
                        (name, weight) = mapping(*args, **kwargs)
                        assert isinstance(name, str), "Mapping name is expected but not found"
                        assert isinstance(weight, int), "Mapping weight is expected but not found"
                        acquire_ok = self.try_acquire(name, weight)

                        if not isawaitable(acquire_ok):
                            return func(*args, **kwargs)

                        async def _handle_acquire_async():
                            nonlocal acquire_ok
                            acquire_ok = await acquire_ok

                            result = func(*args, **kwargs)

                            if isawaitable(result):
                                return await result

                            return result

                        return _handle_acquire_async()

                    return wrapper

            return decorator_wrapper

        return with_mapping_func
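

# A decorator usage sketch for `Limiter.as_decorator()` (illustrative only; the
# mapping function, `user_id` argument, and `send_notification` are made-up names):
#
#   limiter = Limiter(Rate(5, Duration.SECOND))
#   decorator = limiter.as_decorator()
#
#   def mapping(user_id: str, *args, **kwargs):
#       return f"user:{user_id}", 1
#
#   @decorator(mapping)
#   def send_notification(user_id: str) -> None:
#       ...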