Files
edgartools/venv/lib/python3.10/site-packages/pyrate_limiter/abstracts/bucket.py
2025-12-09 12:13:01 +01:00

305 lines
9.5 KiB
Python

""" Implement this class to create
a workable bucket for Limiter to use
"""
import asyncio
import logging
from abc import ABC
from abc import abstractmethod
from collections import defaultdict
from inspect import isawaitable
from inspect import iscoroutine
from threading import Thread
from typing import Awaitable
from typing import Dict
from typing import List
from typing import Optional
from typing import Type
from typing import Union
from .clock import AbstractClock
from .rate import Rate
from .rate import RateItem
logger = logging.getLogger("pyrate_limiter")
class AbstractBucket(ABC):
    """Interface every bucket backend must implement for the Limiter.

    Assumption: len(rates) always > 0
    TODO: allow empty rates
    """

    rates: List[Rate]
    failing_rate: Optional[Rate] = None

    @abstractmethod
    def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]:
        """Put an item (typically the current time) in the bucket
        return true if successful, otherwise false
        """

    @abstractmethod
    def leak(
        self,
        current_timestamp: Optional[int] = None,
    ) -> Union[int, Awaitable[int]]:
        """leaking bucket - removing items that are outdated"""

    @abstractmethod
    def flush(self) -> Union[None, Awaitable[None]]:
        """Flush the whole bucket
        - Must remove `failing-rate` after flushing
        """

    @abstractmethod
    def count(self) -> Union[int, Awaitable[int]]:
        """Count number of items in the bucket"""

    @abstractmethod
    def peek(self, index: int) -> Union[Optional[RateItem], Awaitable[Optional[RateItem]]]:
        """Peek at the rate-item at a specific index in latest-to-earliest order
        NOTE: The reason we cannot peek from the start of the queue(earliest-to-latest) is
        we can't really tell how many outdated items are still in the queue
        """

    def waiting(self, item: RateItem) -> Union[int, Awaitable[int]]:
        """Calculate the time until this bucket becomes available for `item` again.

        Returns 0 when no wait is needed, -1 when the item can never fit,
        otherwise the delay. May return an awaitable when `peek` is async.
        """
        rate = self.failing_rate
        if rate is None:
            return 0

        assert item.weight > 0, "Item's weight must > 0"

        # An item heavier than the rate limit can never be admitted
        if item.weight > rate.limit:
            return -1

        bound_item = self.peek(rate.limit - item.weight)

        if bound_item is None:
            # Bucket has room right away: no waiting required
            return 0

        def _delay_from(edge: RateItem) -> int:
            # Delay = how far `edge` sits inside the failing rate's window
            assert self.failing_rate is not None  # silence mypy
            window_start = item.timestamp - self.failing_rate.interval
            return edge.timestamp - window_start

        async def _delay_from_awaitable() -> int:
            nonlocal bound_item
            # Some backends wrap results in nested awaitables; unwrap them all
            while isawaitable(bound_item):
                bound_item = await bound_item
            if bound_item is None:
                # Bucket has room right away: no waiting required
                return 0
            assert isinstance(bound_item, RateItem)
            return _delay_from(bound_item)

        if isawaitable(bound_item):
            return _delay_from_awaitable()

        assert isinstance(bound_item, RateItem)
        return _delay_from(bound_item)

    def limiter_lock(self) -> Optional[object]:  # type: ignore
        """An additional lock to be used by Limiter in-front of the thread lock.
        Intended for multiprocessing environments where a thread lock is insufficient.
        Default: no extra lock.
        """
        return None
class Leaker(Thread):
    """Responsible for scheduling buckets' leaking at the background either
    through a daemon task(for sync buckets) or a task using asyncio.Task

    Sync buckets are drained from this daemon thread (see `run`); async buckets
    are drained by an asyncio.Task created on the caller's loop (`leak_async`).
    """

    daemon = True
    name = "PyrateLimiter's Leaker"
    # Registries keyed by id(bucket); populated per-instance in __init__
    sync_buckets: Optional[Dict[int, AbstractBucket]] = None
    async_buckets: Optional[Dict[int, AbstractBucket]] = None
    clocks: Optional[Dict[int, AbstractClock]] = None
    leak_interval: int = 10_000
    aio_leak_task: Optional[asyncio.Task] = None

    def __init__(self, leak_interval: int):
        """leak_interval is in milliseconds (divided by 1000 before sleeping)"""
        self.sync_buckets = defaultdict()
        self.async_buckets = defaultdict()
        self.clocks = defaultdict()
        self.leak_interval = leak_interval
        super().__init__()

    def register(self, bucket: AbstractBucket, clock: AbstractClock):
        """Register a new bucket with its associated clock.

        Probes `bucket.leak(0)` to detect whether the bucket is async
        (returns a coroutine) or sync, and files it accordingly.
        """
        assert self.sync_buckets is not None
        assert self.clocks is not None
        assert self.async_buckets is not None
        try_leak = bucket.leak(0)
        bucket_id = id(bucket)
        if iscoroutine(try_leak):
            # The probe coroutine is never awaited; close it to avoid a warning
            try_leak.close()
            self.async_buckets[bucket_id] = bucket
        else:
            self.sync_buckets[bucket_id] = bucket
        self.clocks[bucket_id] = clock

    def deregister(self, bucket_id: int) -> bool:
        """Deregister a bucket; return True if it was found and removed"""
        if self.sync_buckets and bucket_id in self.sync_buckets:
            del self.sync_buckets[bucket_id]
            assert self.clocks
            del self.clocks[bucket_id]
            return True
        if self.async_buckets and bucket_id in self.async_buckets:
            del self.async_buckets[bucket_id]
            assert self.clocks
            del self.clocks[bucket_id]
            if not self.async_buckets and self.aio_leak_task:
                # No async buckets left: cancel the now-idle leak task
                self.aio_leak_task.cancel()
                self.aio_leak_task = None
            return True
        return False

    async def _leak(self, buckets: Dict[int, AbstractBucket]) -> None:
        """Leak every registered bucket once per `leak_interval` until
        the registry becomes empty, then exit.
        """
        assert self.clocks
        while buckets:
            try:
                for bucket_id, bucket in list(buckets.items()):
                    # FIX: a concurrent deregister() may remove this bucket's
                    # clock after the snapshot above was taken; a direct
                    # self.clocks[bucket_id] would raise an uncaught KeyError
                    # and kill the leak loop permanently. Skip instead.
                    clock = self.clocks.get(bucket_id)
                    if clock is None:
                        continue
                    now = clock.now()
                    while isawaitable(now):
                        now = await now
                    assert isinstance(now, int)
                    leak = bucket.leak(now)
                    while isawaitable(leak):
                        leak = await leak
                    assert isinstance(leak, int)
                await asyncio.sleep(self.leak_interval / 1000)
            except RuntimeError as e:
                logger.info("Leak task stopped due to event loop shutdown. %s", e)
                return

    def leak_async(self):
        """Start the async-bucket leak task on the current running loop,
        if there are async buckets and no task is running yet.
        """
        if self.async_buckets and not self.aio_leak_task:
            self.aio_leak_task = asyncio.create_task(self._leak(self.async_buckets))

    def run(self) -> None:
        """ Override the original method of Thread
        Not meant to be called directly
        """
        assert self.sync_buckets
        asyncio.run(self._leak(self.sync_buckets))

    def start(self) -> None:
        """ Override the original method of Thread
        Call to run leaking sync buckets
        NOTE(review): once the thread has finished (sync registry emptied),
        a plain Thread cannot be restarted -- a later start() would raise
        RuntimeError; confirm whether callers can hit that path.
        """
        if self.sync_buckets and not self.is_alive():
            super().start()
class BucketFactory(ABC):
    """Abstract BucketFactory class.
    It is reserved for user to implement/override this class with
    his own bucket-routing/creating logic
    """

    _leaker: Optional[Leaker] = None
    _leak_interval: int = 10_000

    @property
    def leak_interval(self) -> int:
        """Retrieve leak-interval from inner Leaker task"""
        # Fall back to the stored value until a Leaker exists
        return self._leaker.leak_interval if self._leaker else self._leak_interval

    @leak_interval.setter
    def leak_interval(self, value: int):
        """Set leak-interval for inner Leaker task"""
        if self._leaker:
            self._leaker.leak_interval = value
        self._leak_interval = value

    @abstractmethod
    def wrap_item(
        self,
        name: str,
        weight: int = 1,
    ) -> Union[RateItem, Awaitable[RateItem]]:
        """Add the current timestamp to the receiving item using any clock backend
        - Turn it into a RateItem
        - Can return either a coroutine or a RateItem instance
        """

    @abstractmethod
    def get(self, item: RateItem) -> Union[AbstractBucket, Awaitable[AbstractBucket]]:
        """Get the corresponding bucket to this item"""

    def create(
        self,
        clock: AbstractClock,
        bucket_class: Type[AbstractBucket],
        *args,
        **kwargs,
    ) -> AbstractBucket:
        """Creating a bucket dynamically"""
        new_bucket = bucket_class(*args, **kwargs)
        self.schedule_leak(new_bucket, clock)
        return new_bucket

    def schedule_leak(self, new_bucket: AbstractBucket, associated_clock: AbstractClock) -> None:
        """Schedule all the buckets' leak, reset bucket's failing rate"""
        assert new_bucket.rates, "Bucket rates are not set"
        # Lazily create the shared Leaker on first registration
        if self._leaker is None:
            self._leaker = Leaker(self.leak_interval)
        self._leaker.register(new_bucket, associated_clock)
        self._leaker.start()
        self._leaker.leak_async()

    def get_buckets(self) -> List[AbstractBucket]:
        """Iterator over all buckets in the factory
        (sync buckets first, then async buckets)
        """
        leaker = self._leaker
        if leaker is None:
            return []
        collected: List[AbstractBucket] = []
        for registry in (leaker.sync_buckets, leaker.async_buckets):
            if registry:
                collected.extend(registry.values())
        return collected

    def dispose(self, bucket: Union[int, AbstractBucket]) -> bool:
        """Delete a bucket from the factory"""
        bucket_id = id(bucket) if isinstance(bucket, AbstractBucket) else bucket
        assert isinstance(bucket_id, int), "not valid bucket id"
        return self._leaker.deregister(bucket_id) if self._leaker else False