"""Implement this class to create
a workable bucket for Limiter to use
"""
|
|
import asyncio
|
|
import logging
|
|
from abc import ABC
|
|
from abc import abstractmethod
|
|
from collections import defaultdict
|
|
from inspect import isawaitable
|
|
from inspect import iscoroutine
|
|
from threading import Thread
|
|
from typing import Awaitable
|
|
from typing import Dict
|
|
from typing import List
|
|
from typing import Optional
|
|
from typing import Type
|
|
from typing import Union
|
|
|
|
from .clock import AbstractClock
|
|
from .rate import Rate
|
|
from .rate import RateItem
|
|
|
|
# Module logger registered under the "pyrate_limiter" name; used below to
# report when a leak task stops.
logger = logging.getLogger("pyrate_limiter")
|
|
|
|
|
|
class AbstractBucket(ABC):
    """Base bucket interface

    Every method may be implemented either synchronously (returning the
    value) or asynchronously (returning an awaitable of the value).

    Assumption: len(rates) always > 0
    TODO: allow empty rates
    """

    # Rates enforced by this bucket.
    rates: List["Rate"]
    # The rate that rejected the most recent item, if any; read by `waiting`.
    failing_rate: Optional["Rate"] = None

    @abstractmethod
    def put(self, item: "RateItem") -> Union[bool, Awaitable[bool]]:
        """Put an item (typically the current time) in the bucket

        Return True if successful, otherwise False.
        """

    @abstractmethod
    def leak(
        self,
        current_timestamp: Optional[int] = None,
    ) -> Union[int, Awaitable[int]]:
        """Leaking bucket - remove items that are outdated"""

    @abstractmethod
    def flush(self) -> Union[None, Awaitable[None]]:
        """Flush the whole bucket

        - Must remove `failing-rate` after flushing
        """

    @abstractmethod
    def count(self) -> Union[int, Awaitable[int]]:
        """Count number of items in the bucket"""

    @abstractmethod
    def peek(self, index: int) -> Union[Optional["RateItem"], Awaitable[Optional["RateItem"]]]:
        """Peek at the rate-item at a specific index in latest-to-earliest order

        NOTE: The reason we cannot peek from the start of the queue
        (earliest-to-latest) is that we can't really tell how many outdated
        items are still in the queue.
        """

    def waiting(self, item: "RateItem") -> Union[int, Awaitable[int]]:
        """Calculate time until the bucket becomes available to consume an item again

        Returns:
            - ``0`` when there is no failing rate, or when `peek` finds no
              item at the bounding index (bucket is immediately ready);
            - ``-1`` when the item's weight exceeds the failing rate's limit;
            - otherwise ``bound.timestamp - (item.timestamp - interval)``,
              expressed in the same unit as the timestamps and the rate
              interval.

            When `peek` returns an awaitable, a coroutine resolving to one of
            the values above is returned instead.

        Raises:
            AssertionError: if ``item.weight`` is not positive, or if `peek`
                yields something that is neither ``None`` nor a ``RateItem``.
        """
        # Snapshot the failing rate once, so the peek index and the interval
        # come from the same rate even if `self.failing_rate` is reset before
        # the async path is awaited.
        failing_rate = self.failing_rate

        if failing_rate is None:
            return 0

        assert item.weight > 0, "Item's weight must > 0"

        if item.weight > failing_rate.limit:
            return -1

        bound_item = self.peek(failing_rate.limit - item.weight)

        if bound_item is None:
            # NOTE: No waiting, bucket is immediately ready
            return 0

        def _calc_waiting(inner_bound_item: "RateItem") -> int:
            lower_time_bound = item.timestamp - failing_rate.interval
            upper_time_bound = inner_bound_item.timestamp
            return upper_time_bound - lower_time_bound

        async def _calc_waiting_async() -> int:
            resolved = bound_item

            # `peek` may hand back nested awaitables; unwrap until a value.
            while isawaitable(resolved):
                resolved = await resolved

            if resolved is None:
                # NOTE: No waiting, bucket is immediately ready
                return 0

            assert isinstance(resolved, RateItem)
            return _calc_waiting(resolved)

        if isawaitable(bound_item):
            return _calc_waiting_async()

        assert isinstance(bound_item, RateItem)
        return _calc_waiting(bound_item)

    def limiter_lock(self) -> Optional[object]:  # type: ignore
        """An additional lock to be used by Limiter in-front of the thread lock.

        Intended for multiprocessing environments where a thread lock is insufficient.
        The base implementation returns ``None`` (no extra lock).
        """
        return None
|
|
|
|
|
|
class Leaker(Thread):
    """Responsible for scheduling buckets' leaking in the background, either
    through a daemon thread (for sync buckets) or an asyncio.Task (for async
    buckets).
    """

    # NOTE: these two class attributes shadow the `Thread.daemon` and
    # `Thread.name` properties; `__init__` also forwards them to Thread.
    daemon = True
    name = "PyrateLimiter's Leaker"
    # Registered buckets, keyed by `id(bucket)`.
    sync_buckets: Optional[Dict[int, "AbstractBucket"]] = None
    async_buckets: Optional[Dict[int, "AbstractBucket"]] = None
    # Clock used to timestamp each bucket's leak, keyed by `id(bucket)`.
    clocks: Optional[Dict[int, "AbstractClock"]] = None
    # Pause between two leak rounds, in milliseconds (divided by 1000 before
    # being passed to `asyncio.sleep`).
    leak_interval: int = 10_000
    aio_leak_task: Optional[asyncio.Task] = None

    def __init__(self, leak_interval: int):
        self.sync_buckets = defaultdict()
        self.async_buckets = defaultdict()
        self.clocks = defaultdict()
        self.leak_interval = leak_interval
        # Forward name/daemon so Thread's own state agrees with the class
        # attributes that shadow its properties.
        super().__init__(name=self.name, daemon=self.daemon)

    def register(self, bucket: "AbstractBucket", clock: "AbstractClock"):
        """Register a new bucket with its associated clock

        The bucket is classified by probing `bucket.leak(0)`: a coroutine
        result marks it as async, anything else as sync.
        """
        assert self.sync_buckets is not None
        assert self.clocks is not None
        assert self.async_buckets is not None

        try_leak = bucket.leak(0)
        bucket_id = id(bucket)

        # Store the clock first: a leak loop that is already running must
        # never see a bucket whose clock is not registered yet.
        self.clocks[bucket_id] = clock

        if iscoroutine(try_leak):
            # The probe coroutine is never awaited; close it explicitly.
            try_leak.close()
            self.async_buckets[bucket_id] = bucket
        else:
            self.sync_buckets[bucket_id] = bucket

    def deregister(self, bucket_id: int) -> bool:
        """Deregister a bucket

        Returns True if the id was registered (sync or async), else False.
        When the last async bucket is removed the asyncio leak task is
        cancelled and forgotten.
        """
        if self.sync_buckets and bucket_id in self.sync_buckets:
            del self.sync_buckets[bucket_id]
            self._forget_clock(bucket_id)
            return True

        if self.async_buckets and bucket_id in self.async_buckets:
            del self.async_buckets[bucket_id]
            self._forget_clock(bucket_id)

            if not self.async_buckets and self.aio_leak_task:
                self.aio_leak_task.cancel()
                self.aio_leak_task = None

            return True

        return False

    def _forget_clock(self, bucket_id: int) -> None:
        """Drop the clock associated with a bucket id, if there is one."""
        if self.clocks is not None:
            self.clocks.pop(bucket_id, None)

    async def _leak(self, buckets: Dict[int, "AbstractBucket"]) -> None:
        """Leak every bucket in `buckets` once per `leak_interval`.

        The loop ends when `buckets` becomes empty, or when a RuntimeError is
        raised inside a round (logged as an event-loop shutdown).
        """
        assert self.clocks is not None

        while buckets:
            try:
                # Iterate over a snapshot: buckets may be (de)registered from
                # another thread while this round is in progress.
                for bucket_id, bucket in list(buckets.items()):
                    clock = self.clocks.get(bucket_id)

                    if clock is None:
                        # Deregistered since the snapshot was taken.
                        continue

                    now = clock.now()

                    while isawaitable(now):
                        now = await now

                    assert isinstance(now, int)
                    leak = bucket.leak(now)

                    while isawaitable(leak):
                        leak = await leak

                    assert isinstance(leak, int)

                await asyncio.sleep(self.leak_interval / 1000)
            except RuntimeError as e:
                logger.info("Leak task stopped due to event loop shutdown. %s", e)
                return

    def leak_async(self):
        """Ensure an asyncio task is leaking the async buckets.

        Does nothing when no async bucket is registered or when a leak task
        is still pending; a finished task is replaced by a new one.
        NOTE: `asyncio.create_task` needs a running event loop.
        """
        if not self.async_buckets:
            return

        task = self.aio_leak_task

        if task is None or task.done():
            self.aio_leak_task = asyncio.create_task(self._leak(self.async_buckets))

    def run(self) -> None:
        """Override the original method of Thread

        Not meant to be called directly
        """
        # An empty mapping is fine here: the leak loop then returns at once.
        assert self.sync_buckets is not None
        asyncio.run(self._leak(self.sync_buckets))

    def start(self) -> None:
        """Override the original method of Thread

        Call to run leaking sync buckets. Does nothing when there is no sync
        bucket or when the thread is already running.
        """
        if not self.sync_buckets or self.is_alive():
            return

        if self.ident is not None:
            # The thread already ran and exited (its bucket set went empty).
            # A Thread object can only be started once, so reset the Thread
            # state before starting the leak loop again.
            Thread.__init__(self, name=self.name, daemon=self.daemon)

        super().start()
|
|
|
|
|
|
class BucketFactory(ABC):
    """Abstract BucketFactory class.

    It is reserved for the user to implement/override this class with
    their own bucket-routing/creating logic.
    """

    # Created on the first `schedule_leak` call.
    _leaker: Optional["Leaker"] = None
    # Interval handed to the Leaker when it is created; also returned by
    # `leak_interval` while no Leaker exists.
    _leak_interval: int = 10_000

    @property
    def leak_interval(self) -> int:
        """Retrieve leak-interval from inner Leaker task"""
        if self._leaker is None:
            return self._leak_interval

        return self._leaker.leak_interval

    @leak_interval.setter
    def leak_interval(self, value: int):
        """Set leak-interval for inner Leaker task"""
        if self._leaker is not None:
            self._leaker.leak_interval = value

        # Always remember the value so a Leaker created later picks it up.
        self._leak_interval = value

    @abstractmethod
    def wrap_item(
        self,
        name: str,
        weight: int = 1,
    ) -> Union["RateItem", Awaitable["RateItem"]]:
        """Add the current timestamp to the receiving item using any clock backend

        - Turn it into a RateItem
        - Can return either a coroutine or a RateItem instance
        """

    @abstractmethod
    def get(self, item: "RateItem") -> Union["AbstractBucket", Awaitable["AbstractBucket"]]:
        """Get the corresponding bucket to this item"""

    def create(
        self,
        clock: "AbstractClock",
        bucket_class: Type["AbstractBucket"],
        *args,
        **kwargs,
    ) -> "AbstractBucket":
        """Create a bucket dynamically and schedule its leaking

        `args` and `kwargs` are forwarded to `bucket_class`.
        """
        bucket = bucket_class(*args, **kwargs)
        self.schedule_leak(bucket, clock)
        return bucket

    def schedule_leak(self, new_bucket: "AbstractBucket", associated_clock: "AbstractClock") -> None:
        """Register the bucket with the factory's Leaker and make sure it runs

        The Leaker is created on first use with the current leak interval.

        Raises:
            AssertionError: if the bucket has no rates.
        """
        assert new_bucket.rates, "Bucket rates are not set"

        if not self._leaker:
            self._leaker = Leaker(self.leak_interval)

        self._leaker.register(new_bucket, associated_clock)
        self._leaker.start()
        self._leaker.leak_async()

    def get_buckets(self) -> List["AbstractBucket"]:
        """Return all buckets known to the factory: sync ones first, then async"""
        leaker = self._leaker

        if leaker is None:
            return []

        buckets: List["AbstractBucket"] = []

        if leaker.sync_buckets:
            buckets.extend(leaker.sync_buckets.values())

        if leaker.async_buckets:
            buckets.extend(leaker.async_buckets.values())

        return buckets

    def dispose(self, bucket: Union[int, "AbstractBucket"]) -> bool:
        """Delete a bucket from the factory

        Accepts either the bucket itself or its id. Returns False when no
        Leaker exists yet, otherwise the result of deregistering the id.
        """
        if isinstance(bucket, AbstractBucket):
            bucket = id(bucket)

        assert isinstance(bucket, int), "not valid bucket id"

        if not self._leaker:
            return False

        return self._leaker.deregister(bucket)
|