Initial commit
This commit is contained in:
@@ -0,0 +1,304 @@
|
||||
""" Implement this class to create
|
||||
a workable bucket for Limiter to use
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from inspect import isawaitable
|
||||
from inspect import iscoroutine
|
||||
from threading import Thread
|
||||
from typing import Awaitable
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Type
|
||||
from typing import Union
|
||||
|
||||
from .clock import AbstractClock
|
||||
from .rate import Rate
|
||||
from .rate import RateItem
|
||||
|
||||
logger = logging.getLogger("pyrate_limiter")
|
||||
|
||||
|
||||
class AbstractBucket(ABC):
|
||||
"""Base bucket interface
|
||||
Assumption: len(rates) always > 0
|
||||
TODO: allow empty rates
|
||||
"""
|
||||
|
||||
rates: List[Rate]
|
||||
failing_rate: Optional[Rate] = None
|
||||
|
||||
@abstractmethod
|
||||
def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]:
|
||||
"""Put an item (typically the current time) in the bucket
|
||||
return true if successful, otherwise false
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def leak(
|
||||
self,
|
||||
current_timestamp: Optional[int] = None,
|
||||
) -> Union[int, Awaitable[int]]:
|
||||
"""leaking bucket - removing items that are outdated"""
|
||||
|
||||
@abstractmethod
|
||||
def flush(self) -> Union[None, Awaitable[None]]:
|
||||
"""Flush the whole bucket
|
||||
- Must remove `failing-rate` after flushing
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def count(self) -> Union[int, Awaitable[int]]:
|
||||
"""Count number of items in the bucket"""
|
||||
|
||||
@abstractmethod
|
||||
def peek(self, index: int) -> Union[Optional[RateItem], Awaitable[Optional[RateItem]]]:
|
||||
"""Peek at the rate-item at a specific index in latest-to-earliest order
|
||||
NOTE: The reason we cannot peek from the start of the queue(earliest-to-latest) is
|
||||
we can't really tell how many outdated items are still in the queue
|
||||
"""
|
||||
|
||||
def waiting(self, item: RateItem) -> Union[int, Awaitable[int]]:
|
||||
"""Calculate time until bucket become availabe to consume an item again"""
|
||||
if self.failing_rate is None:
|
||||
return 0
|
||||
|
||||
assert item.weight > 0, "Item's weight must > 0"
|
||||
|
||||
if item.weight > self.failing_rate.limit:
|
||||
return -1
|
||||
|
||||
bound_item = self.peek(self.failing_rate.limit - item.weight)
|
||||
|
||||
if bound_item is None:
|
||||
# NOTE: No waiting, bucket is immediately ready
|
||||
return 0
|
||||
|
||||
def _calc_waiting(inner_bound_item: RateItem) -> int:
|
||||
assert self.failing_rate is not None # NOTE: silence mypy
|
||||
lower_time_bound = item.timestamp - self.failing_rate.interval
|
||||
upper_time_bound = inner_bound_item.timestamp
|
||||
return upper_time_bound - lower_time_bound
|
||||
|
||||
async def _calc_waiting_async() -> int:
|
||||
nonlocal bound_item
|
||||
|
||||
while isawaitable(bound_item):
|
||||
bound_item = await bound_item
|
||||
|
||||
if bound_item is None:
|
||||
# NOTE: No waiting, bucket is immediately ready
|
||||
return 0
|
||||
|
||||
assert isinstance(bound_item, RateItem)
|
||||
return _calc_waiting(bound_item)
|
||||
|
||||
if isawaitable(bound_item):
|
||||
return _calc_waiting_async()
|
||||
|
||||
assert isinstance(bound_item, RateItem)
|
||||
return _calc_waiting(bound_item)
|
||||
|
||||
def limiter_lock(self) -> Optional[object]: # type: ignore
|
||||
"""An additional lock to be used by Limiter in-front of the thread lock.
|
||||
Intended for multiprocessing environments where a thread lock is insufficient.
|
||||
"""
|
||||
return None
|
||||
|
||||
|
||||
class Leaker(Thread):
|
||||
"""Responsible for scheduling buckets' leaking at the background either
|
||||
through a daemon task(for sync buckets) or a task using asyncio.Task
|
||||
"""
|
||||
|
||||
daemon = True
|
||||
name = "PyrateLimiter's Leaker"
|
||||
sync_buckets: Optional[Dict[int, AbstractBucket]] = None
|
||||
async_buckets: Optional[Dict[int, AbstractBucket]] = None
|
||||
clocks: Optional[Dict[int, AbstractClock]] = None
|
||||
leak_interval: int = 10_000
|
||||
aio_leak_task: Optional[asyncio.Task] = None
|
||||
|
||||
def __init__(self, leak_interval: int):
|
||||
self.sync_buckets = defaultdict()
|
||||
self.async_buckets = defaultdict()
|
||||
self.clocks = defaultdict()
|
||||
self.leak_interval = leak_interval
|
||||
super().__init__()
|
||||
|
||||
def register(self, bucket: AbstractBucket, clock: AbstractClock):
|
||||
"""Register a new bucket with its associated clock"""
|
||||
assert self.sync_buckets is not None
|
||||
assert self.clocks is not None
|
||||
assert self.async_buckets is not None
|
||||
|
||||
try_leak = bucket.leak(0)
|
||||
bucket_id = id(bucket)
|
||||
|
||||
if iscoroutine(try_leak):
|
||||
try_leak.close()
|
||||
self.async_buckets[bucket_id] = bucket
|
||||
else:
|
||||
self.sync_buckets[bucket_id] = bucket
|
||||
|
||||
self.clocks[bucket_id] = clock
|
||||
|
||||
def deregister(self, bucket_id: int) -> bool:
|
||||
"""Deregister a bucket"""
|
||||
if self.sync_buckets and bucket_id in self.sync_buckets:
|
||||
del self.sync_buckets[bucket_id]
|
||||
assert self.clocks
|
||||
del self.clocks[bucket_id]
|
||||
return True
|
||||
|
||||
if self.async_buckets and bucket_id in self.async_buckets:
|
||||
del self.async_buckets[bucket_id]
|
||||
assert self.clocks
|
||||
del self.clocks[bucket_id]
|
||||
|
||||
if not self.async_buckets and self.aio_leak_task:
|
||||
self.aio_leak_task.cancel()
|
||||
self.aio_leak_task = None
|
||||
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _leak(self, buckets: Dict[int, AbstractBucket]) -> None:
|
||||
assert self.clocks
|
||||
|
||||
while buckets:
|
||||
try:
|
||||
for bucket_id, bucket in list(buckets.items()):
|
||||
clock = self.clocks[bucket_id]
|
||||
now = clock.now()
|
||||
|
||||
while isawaitable(now):
|
||||
now = await now
|
||||
|
||||
assert isinstance(now, int)
|
||||
leak = bucket.leak(now)
|
||||
|
||||
while isawaitable(leak):
|
||||
leak = await leak
|
||||
|
||||
assert isinstance(leak, int)
|
||||
|
||||
await asyncio.sleep(self.leak_interval / 1000)
|
||||
except RuntimeError as e:
|
||||
logger.info("Leak task stopped due to event loop shutdown. %s", e)
|
||||
return
|
||||
|
||||
def leak_async(self):
|
||||
if self.async_buckets and not self.aio_leak_task:
|
||||
self.aio_leak_task = asyncio.create_task(self._leak(self.async_buckets))
|
||||
|
||||
def run(self) -> None:
|
||||
""" Override the original method of Thread
|
||||
Not meant to be called directly
|
||||
"""
|
||||
assert self.sync_buckets
|
||||
asyncio.run(self._leak(self.sync_buckets))
|
||||
|
||||
def start(self) -> None:
|
||||
""" Override the original method of Thread
|
||||
Call to run leaking sync buckets
|
||||
"""
|
||||
if self.sync_buckets and not self.is_alive():
|
||||
super().start()
|
||||
|
||||
|
||||
class BucketFactory(ABC):
|
||||
"""Asbtract BucketFactory class.
|
||||
It is reserved for user to implement/override this class with
|
||||
his own bucket-routing/creating logic
|
||||
"""
|
||||
|
||||
_leaker: Optional[Leaker] = None
|
||||
_leak_interval: int = 10_000
|
||||
|
||||
@property
|
||||
def leak_interval(self) -> int:
|
||||
"""Retrieve leak-interval from inner Leaker task"""
|
||||
if not self._leaker:
|
||||
return self._leak_interval
|
||||
return self._leaker.leak_interval
|
||||
|
||||
@leak_interval.setter
|
||||
def leak_interval(self, value: int):
|
||||
"""Set leak-interval for inner Leaker task"""
|
||||
if self._leaker:
|
||||
self._leaker.leak_interval = value
|
||||
self._leak_interval = value
|
||||
|
||||
@abstractmethod
|
||||
def wrap_item(
|
||||
self,
|
||||
name: str,
|
||||
weight: int = 1,
|
||||
) -> Union[RateItem, Awaitable[RateItem]]:
|
||||
"""Add the current timestamp to the receiving item using any clock backend
|
||||
- Turn it into a RateItem
|
||||
- Can return either a coroutine or a RateItem instance
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, item: RateItem) -> Union[AbstractBucket, Awaitable[AbstractBucket]]:
|
||||
"""Get the corresponding bucket to this item"""
|
||||
|
||||
def create(
|
||||
self,
|
||||
clock: AbstractClock,
|
||||
bucket_class: Type[AbstractBucket],
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> AbstractBucket:
|
||||
"""Creating a bucket dynamically"""
|
||||
bucket = bucket_class(*args, **kwargs)
|
||||
self.schedule_leak(bucket, clock)
|
||||
return bucket
|
||||
|
||||
def schedule_leak(self, new_bucket: AbstractBucket, associated_clock: AbstractClock) -> None:
|
||||
"""Schedule all the buckets' leak, reset bucket's failing rate"""
|
||||
assert new_bucket.rates, "Bucket rates are not set"
|
||||
|
||||
if not self._leaker:
|
||||
self._leaker = Leaker(self.leak_interval)
|
||||
|
||||
self._leaker.register(new_bucket, associated_clock)
|
||||
self._leaker.start()
|
||||
self._leaker.leak_async()
|
||||
|
||||
def get_buckets(self) -> List[AbstractBucket]:
|
||||
"""Iterator over all buckets in the factory
|
||||
"""
|
||||
if not self._leaker:
|
||||
return []
|
||||
|
||||
buckets = []
|
||||
|
||||
if self._leaker.sync_buckets:
|
||||
for _, bucket in self._leaker.sync_buckets.items():
|
||||
buckets.append(bucket)
|
||||
|
||||
if self._leaker.async_buckets:
|
||||
for _, bucket in self._leaker.async_buckets.items():
|
||||
buckets.append(bucket)
|
||||
|
||||
return buckets
|
||||
|
||||
def dispose(self, bucket: Union[int, AbstractBucket]) -> bool:
|
||||
"""Delete a bucket from the factory"""
|
||||
if isinstance(bucket, AbstractBucket):
|
||||
bucket = id(bucket)
|
||||
|
||||
assert isinstance(bucket, int), "not valid bucket id"
|
||||
|
||||
if not self._leaker:
|
||||
return False
|
||||
|
||||
return self._leaker.deregister(bucket)
|
||||
Reference in New Issue
Block a user