Initial commit
This commit is contained in:
@@ -0,0 +1,7 @@
|
||||
# flake8: noqa
|
||||
from .abstracts import *
|
||||
from .buckets import *
|
||||
from .clocks import *
|
||||
from .exceptions import *
|
||||
from .limiter import *
|
||||
from .utils import *
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,4 @@
|
||||
from .bucket import * # noqa
|
||||
from .clock import * # noqa
|
||||
from .rate import * # noqa
|
||||
from .wrappers import * # noqa
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,304 @@
|
||||
""" Implement this class to create
|
||||
a workable bucket for Limiter to use
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from inspect import isawaitable
|
||||
from inspect import iscoroutine
|
||||
from threading import Thread
|
||||
from typing import Awaitable
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Type
|
||||
from typing import Union
|
||||
|
||||
from .clock import AbstractClock
|
||||
from .rate import Rate
|
||||
from .rate import RateItem
|
||||
|
||||
logger = logging.getLogger("pyrate_limiter")
|
||||
|
||||
|
||||
class AbstractBucket(ABC):
    """Base bucket interface

    Assumption: len(rates) always > 0
    TODO: allow empty rates
    """

    # Rate limits this bucket enforces (assumed non-empty, see docstring)
    rates: List[Rate]
    # Most recent rate that rejected a put(), or None if the last put passed
    failing_rate: Optional[Rate] = None

    @abstractmethod
    def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]:
        """Put an item (typically the current time) in the bucket
        return true if successful, otherwise false
        """

    @abstractmethod
    def leak(
        self,
        current_timestamp: Optional[int] = None,
    ) -> Union[int, Awaitable[int]]:
        """leaking bucket - removing items that are outdated"""

    @abstractmethod
    def flush(self) -> Union[None, Awaitable[None]]:
        """Flush the whole bucket
        - Must remove `failing-rate` after flushing
        """

    @abstractmethod
    def count(self) -> Union[int, Awaitable[int]]:
        """Count number of items in the bucket"""

    @abstractmethod
    def peek(self, index: int) -> Union[Optional[RateItem], Awaitable[Optional[RateItem]]]:
        """Peek at the rate-item at a specific index in latest-to-earliest order
        NOTE: The reason we cannot peek from the start of the queue(earliest-to-latest) is
        we can't really tell how many outdated items are still in the queue
        """

    def waiting(self, item: RateItem) -> Union[int, Awaitable[int]]:
        """Calculate time until bucket become available to consume an item again

        Returns 0 when no wait is needed, -1 when the item can never fit
        (its weight exceeds the failing rate's limit), otherwise the wait
        duration in the same unit as item.timestamp. May return an awaitable
        when the underlying peek() is asynchronous.
        """
        if self.failing_rate is None:
            # Last put() succeeded, nothing to wait for
            return 0

        assert item.weight > 0, "Item's weight must > 0"

        if item.weight > self.failing_rate.limit:
            # Heavier than the whole limit: can never be accepted
            return -1

        # The item at this index must expire before `weight` slots free up
        # in the failing rate's window (peek() is latest-to-earliest)
        bound_item = self.peek(self.failing_rate.limit - item.weight)

        if bound_item is None:
            # NOTE: No waiting, bucket is immediately ready
            return 0

        def _calc_waiting(inner_bound_item: RateItem) -> int:
            # Wait = time until inner_bound_item falls out of the window
            assert self.failing_rate is not None  # NOTE: silence mypy
            lower_time_bound = item.timestamp - self.failing_rate.interval
            upper_time_bound = inner_bound_item.timestamp
            return upper_time_bound - lower_time_bound

        async def _calc_waiting_async() -> int:
            # Async path: resolve the awaitable peek() result first
            nonlocal bound_item

            while isawaitable(bound_item):
                bound_item = await bound_item

            if bound_item is None:
                # NOTE: No waiting, bucket is immediately ready
                return 0

            assert isinstance(bound_item, RateItem)
            return _calc_waiting(bound_item)

        if isawaitable(bound_item):
            return _calc_waiting_async()

        assert isinstance(bound_item, RateItem)
        return _calc_waiting(bound_item)

    def limiter_lock(self) -> Optional[object]:  # type: ignore
        """An additional lock to be used by Limiter in-front of the thread lock.
        Intended for multiprocessing environments where a thread lock is insufficient.
        """
        return None
|
||||
|
||||
|
||||
class Leaker(Thread):
    """Responsible for scheduling buckets' leaking at the background either
    through a daemon task(for sync buckets) or a task using asyncio.Task
    """

    # Daemon thread: it must not keep the interpreter alive on shutdown
    daemon = True
    name = "PyrateLimiter's Leaker"
    # Buckets keyed by id(bucket); sync and async ones are leaked separately
    sync_buckets: Optional[Dict[int, AbstractBucket]] = None
    async_buckets: Optional[Dict[int, AbstractBucket]] = None
    # Clock associated with each bucket, keyed by the same id
    clocks: Optional[Dict[int, AbstractClock]] = None
    # How often to leak, in milliseconds
    leak_interval: int = 10_000
    # asyncio task that leaks async buckets (created lazily by leak_async)
    aio_leak_task: Optional[asyncio.Task] = None

    def __init__(self, leak_interval: int):
        self.sync_buckets = defaultdict()
        self.async_buckets = defaultdict()
        self.clocks = defaultdict()
        self.leak_interval = leak_interval
        super().__init__()

    def register(self, bucket: AbstractBucket, clock: AbstractClock):
        """Register a new bucket with its associated clock"""
        assert self.sync_buckets is not None
        assert self.clocks is not None
        assert self.async_buckets is not None

        # Probe leak() once to learn whether this bucket is sync or async
        try_leak = bucket.leak(0)
        bucket_id = id(bucket)

        if iscoroutine(try_leak):
            # Close the never-awaited probe coroutine to avoid a warning
            try_leak.close()
            self.async_buckets[bucket_id] = bucket
        else:
            self.sync_buckets[bucket_id] = bucket

        self.clocks[bucket_id] = clock

    def deregister(self, bucket_id: int) -> bool:
        """Deregister a bucket; returns True if it was found and removed"""
        if self.sync_buckets and bucket_id in self.sync_buckets:
            del self.sync_buckets[bucket_id]
            assert self.clocks
            del self.clocks[bucket_id]
            return True

        if self.async_buckets and bucket_id in self.async_buckets:
            del self.async_buckets[bucket_id]
            assert self.clocks
            del self.clocks[bucket_id]

            # Last async bucket gone: stop the asyncio leak task
            if not self.async_buckets and self.aio_leak_task:
                self.aio_leak_task.cancel()
                self.aio_leak_task = None

            return True

        return False

    async def _leak(self, buckets: Dict[int, AbstractBucket]) -> None:
        """Periodically leak every bucket in `buckets` until it is empty"""
        assert self.clocks

        while buckets:
            try:
                # Snapshot items() so deregistering during iteration is safe
                for bucket_id, bucket in list(buckets.items()):
                    clock = self.clocks[bucket_id]
                    now = clock.now()

                    # Clock may be sync (int) or async (awaitable)
                    while isawaitable(now):
                        now = await now

                    assert isinstance(now, int)
                    leak = bucket.leak(now)

                    while isawaitable(leak):
                        leak = await leak

                    assert isinstance(leak, int)

                # leak_interval is milliseconds; sleep() wants seconds
                await asyncio.sleep(self.leak_interval / 1000)
            except RuntimeError as e:
                # Typically raised when the event loop is shutting down
                logger.info("Leak task stopped due to event loop shutdown. %s", e)
                return

    def leak_async(self):
        """Start the asyncio leak task for async buckets (idempotent)"""
        if self.async_buckets and not self.aio_leak_task:
            self.aio_leak_task = asyncio.create_task(self._leak(self.async_buckets))

    def run(self) -> None:
        """ Override the original method of Thread
        Not meant to be called directly
        """
        assert self.sync_buckets
        # Own event loop on this daemon thread to drive the sync buckets
        asyncio.run(self._leak(self.sync_buckets))

    def start(self) -> None:
        """ Override the original method of Thread
        Call to run leaking sync buckets
        """
        # Only start once, and only when there is something to leak
        if self.sync_buckets and not self.is_alive():
            super().start()
|
||||
|
||||
|
||||
class BucketFactory(ABC):
    """Abstract BucketFactory class.
    It is reserved for user to implement/override this class with
    his own bucket-routing/creating logic
    """

    # Background Leaker shared by all buckets this factory schedules
    _leaker: Optional[Leaker] = None
    # Interval (ms) used when the Leaker has not been created yet
    _leak_interval: int = 10_000

    @property
    def leak_interval(self) -> int:
        """Retrieve leak-interval from inner Leaker task"""
        if not self._leaker:
            return self._leak_interval
        return self._leaker.leak_interval

    @leak_interval.setter
    def leak_interval(self, value: int):
        """Set leak-interval for inner Leaker task"""
        if self._leaker:
            self._leaker.leak_interval = value
        # Keep the local copy in sync for when no Leaker exists yet
        self._leak_interval = value

    @abstractmethod
    def wrap_item(
        self,
        name: str,
        weight: int = 1,
    ) -> Union[RateItem, Awaitable[RateItem]]:
        """Add the current timestamp to the receiving item using any clock backend
        - Turn it into a RateItem
        - Can return either a coroutine or a RateItem instance
        """

    @abstractmethod
    def get(self, item: RateItem) -> Union[AbstractBucket, Awaitable[AbstractBucket]]:
        """Get the corresponding bucket to this item"""

    def create(
        self,
        clock: AbstractClock,
        bucket_class: Type[AbstractBucket],
        *args,
        **kwargs,
    ) -> AbstractBucket:
        """Creating a bucket dynamically

        *args/**kwargs are forwarded to bucket_class; the new bucket is
        registered with the background leaker before being returned.
        """
        bucket = bucket_class(*args, **kwargs)
        self.schedule_leak(bucket, clock)
        return bucket

    def schedule_leak(self, new_bucket: AbstractBucket, associated_clock: AbstractClock) -> None:
        """Schedule all the buckets' leak, reset bucket's failing rate"""
        assert new_bucket.rates, "Bucket rates are not set"

        # Lazily create the shared Leaker on first registration
        if not self._leaker:
            self._leaker = Leaker(self.leak_interval)

        self._leaker.register(new_bucket, associated_clock)
        # Both calls are idempotent no-ops when nothing applies:
        # start() skips if already running / no sync buckets,
        # leak_async() skips if there are no async buckets
        self._leaker.start()
        self._leaker.leak_async()

    def get_buckets(self) -> List[AbstractBucket]:
        """Return all buckets known to the factory's leaker
        """
        if not self._leaker:
            return []

        buckets = []

        if self._leaker.sync_buckets:
            for _, bucket in self._leaker.sync_buckets.items():
                buckets.append(bucket)

        if self._leaker.async_buckets:
            for _, bucket in self._leaker.async_buckets.items():
                buckets.append(bucket)

        return buckets

    def dispose(self, bucket: Union[int, AbstractBucket]) -> bool:
        """Delete a bucket from the factory

        Accepts either the bucket object or its id(); returns True when a
        bucket was actually deregistered.
        """
        if isinstance(bucket, AbstractBucket):
            bucket = id(bucket)

        assert isinstance(bucket, int), "not valid bucket id"

        if not self._leaker:
            return False

        return self._leaker.deregister(bucket)
|
||||
@@ -0,0 +1,12 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Awaitable
|
||||
from typing import Union
|
||||
|
||||
|
||||
class AbstractClock(ABC):
    """Abstract time source used by buckets and limiters.

    Implementations return the current timestamp in milliseconds, either
    synchronously (int) or asynchronously (awaitable resolving to int).
    """

    @abstractmethod
    def now(self) -> Union[int, Awaitable[int]]:
        """Get time as of now, in milliseconds.

        May return an int directly or an awaitable yielding an int,
        depending on whether the clock backend is sync or async.
        """
|
||||
@@ -0,0 +1,96 @@
|
||||
"""Unit classes that deals with rate, item & duration
|
||||
"""
|
||||
from enum import Enum
|
||||
from typing import Union
|
||||
|
||||
|
||||
class Duration(Enum):
    """Interval helper class.

    Each member's value is the interval length in milliseconds. Members
    support arithmetic with plain numbers (always yielding int milliseconds)
    and equality comparison against both Duration members and ints.
    """

    SECOND = 1000
    MINUTE = 1000 * 60
    HOUR = 1000 * 60 * 60
    DAY = 1000 * 60 * 60 * 24
    WEEK = 1000 * 60 * 60 * 24 * 7

    def __mul__(self, multiplier: float) -> int:
        """Scale the duration, e.g. ``Duration.SECOND * 2 == 2000``."""
        return int(self.value * multiplier)

    def __rmul__(self, multiplier: float) -> int:
        return self.__mul__(multiplier)

    def __add__(self, another_duration: Union["Duration", int]) -> int:
        """Add a Duration or a raw millisecond count, returning int ms."""
        return self.value + int(another_duration)

    def __radd__(self, another_duration: Union["Duration", int]) -> int:
        return self.__add__(another_duration)

    def __int__(self) -> int:
        return self.value

    def __eq__(self, duration: object) -> bool:
        """Compare by millisecond value against Duration members or ints."""
        if not isinstance(duration, (Duration, int)):
            return NotImplemented

        return self.value == int(duration)

    # BUGFIX: a class that defines __eq__ without __hash__ gets
    # __hash__ = None (Python data model), which made Duration members
    # unhashable (unusable as dict keys / in sets). Restore Enum's
    # identity-based hash, which is consistent with member identity.
    __hash__ = Enum.__hash__

    @staticmethod
    def readable(value: int) -> str:
        """Render a millisecond count using the largest fitting unit,
        e.g. ``90000 -> '1.5m'``; values under one second use raw ms.
        """
        notes = [
            (Duration.WEEK, "w"),
            (Duration.DAY, "d"),
            (Duration.HOUR, "h"),
            (Duration.MINUTE, "m"),
            (Duration.SECOND, "s"),
        ]

        for note, shorten in notes:
            if value >= note.value:
                decimal_value = value / note.value
                return f"{decimal_value:0.1f}{shorten}"  # noqa: E231

        return f"{value}ms"
|
||||
|
||||
|
||||
class RateItem:
    """A single weighted entry that buckets store and count.

    Carries the item's name, its timestamp (assigned by a clock), and a
    weight describing how many slots it consumes (defaults to 1).
    """

    name: str
    weight: int
    timestamp: int

    def __init__(self, name: str, timestamp: int, weight: int = 1):
        # Plain value object: just record the three fields
        self.name = name
        self.weight = weight
        self.timestamp = timestamp

    def __str__(self) -> str:
        return f"RateItem(name={self.name}, weight={self.weight}, timestamp={self.timestamp})"
|
||||
|
||||
|
||||
class Rate:
|
||||
"""Rate definition.
|
||||
|
||||
Args:
|
||||
limit: Number of requests allowed within ``interval``
|
||||
interval: Time interval, in miliseconds
|
||||
"""
|
||||
|
||||
limit: int
|
||||
interval: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
limit: int,
|
||||
interval: Union[int, Duration],
|
||||
):
|
||||
self.limit = limit
|
||||
self.interval = int(interval)
|
||||
assert self.interval
|
||||
assert self.limit
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"limit={self.limit}/{Duration.readable(self.interval)}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"limit={self.limit}/{self.interval}"
|
||||
@@ -0,0 +1,77 @@
|
||||
""" Wrappers over different abstract types
|
||||
"""
|
||||
from inspect import isawaitable
|
||||
from typing import Optional
|
||||
|
||||
from .bucket import AbstractBucket
|
||||
from .rate import RateItem
|
||||
|
||||
|
||||
class BucketAsyncWrapper(AbstractBucket):
    """Adapter exposing any bucket through a uniformly async interface.

    Wraps a sync or async AbstractBucket; every method awaits the
    underlying result as many times as needed before returning it.
    """

    def __init__(self, bucket: AbstractBucket):
        assert isinstance(bucket, AbstractBucket)
        self.bucket = bucket

    async def put(self, item: RateItem):
        outcome = self.bucket.put(item)

        # Unwrap until we hold a concrete value (handles nested awaitables)
        while isawaitable(outcome):
            outcome = await outcome

        return outcome

    async def count(self):
        total = self.bucket.count()

        while isawaitable(total):
            total = await total

        return total

    async def leak(self, current_timestamp: Optional[int] = None) -> int:
        removed = self.bucket.leak(current_timestamp)

        while isawaitable(removed):
            removed = await removed

        assert isinstance(removed, int)
        return removed

    async def flush(self) -> None:
        outcome = self.bucket.flush()

        while isawaitable(outcome):
            # TODO: AbstractBucket.flush() may not have correct type annotation?
            outcome = await outcome  # type: ignore

        return None

    async def peek(self, index: int) -> Optional[RateItem]:
        peeked = self.bucket.peek(index)

        while isawaitable(peeked):
            peeked = await peeked

        assert peeked is None or isinstance(peeked, RateItem)
        return peeked

    async def waiting(self, item: RateItem) -> int:
        # Delegate to the shared waiting() logic on AbstractBucket
        delay = super().waiting(item)

        if isawaitable(delay):
            delay = await delay

        assert isinstance(delay, int)
        return delay

    @property
    def failing_rate(self):
        # Mirror the wrapped bucket's failing rate
        return self.bucket.failing_rate

    @property
    def rates(self):
        return self.bucket.rates
|
||||
@@ -0,0 +1,10 @@
|
||||
# flake8: noqa
|
||||
"""Conrete bucket implementations
|
||||
"""
|
||||
from .in_memory_bucket import InMemoryBucket
|
||||
from .mp_bucket import MultiprocessBucket
|
||||
from .postgres import PostgresBucket
|
||||
from .postgres import Queries as PgQueries
|
||||
from .redis_bucket import RedisBucket
|
||||
from .sqlite_bucket import Queries as SQLiteQueries
|
||||
from .sqlite_bucket import SQLiteBucket
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,92 @@
|
||||
"""Naive bucket implementation using built-in list
|
||||
"""
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from ..abstracts import AbstractBucket
|
||||
from ..abstracts import Rate
|
||||
from ..abstracts import RateItem
|
||||
from ..utils import binary_search
|
||||
|
||||
|
||||
class InMemoryBucket(AbstractBucket):
    """Simple In-memory Bucket using native list

    Clock can be either `time.time` or `time.monotonic`
    When leak, clock is required
    Pros: fast, safe, and precise
    Cons: since it resides in local memory, the data is not persistent, nor scalable
    Usecase: small applications, simple logic
    """

    # Items in insertion (timestamp) order; weight>1 stores duplicate entries
    items: List[RateItem]
    failing_rate: Optional[Rate]

    def __init__(self, rates: List[Rate]):
        # Sort by interval so rates[-1] is always the longest window
        self.rates = sorted(rates, key=lambda r: r.interval)
        self.items = []

    def put(self, item: RateItem) -> bool:
        """Append the item if every rate still has capacity; record the
        first failing rate and return False otherwise."""
        if item.weight == 0:
            # Zero-weight items consume no capacity
            return True

        current_length = len(self.items)
        after_length = item.weight + current_length

        for rate in self.rates:
            if after_length < rate.limit:
                # Even the untrimmed total fits under this limit.
                # NOTE(review): skipping the remaining rates assumes limits
                # grow with interval — rates are sorted by interval only;
                # confirm callers provide limits ordered that way.
                break

            lower_bound_value = item.timestamp - rate.interval
            # presumably binary_search returns the index of the first item
            # with timestamp >= lower_bound_value, or a negative value when
            # none qualifies — TODO confirm against ..utils.binary_search
            lower_bound_idx = binary_search(self.items, lower_bound_value)

            if lower_bound_idx >= 0:
                # Items from lower_bound_idx onward are inside this window
                count_existing_items = len(self.items) - lower_bound_idx
                space_available = rate.limit - count_existing_items
            else:
                space_available = rate.limit

            if space_available < item.weight:
                self.failing_rate = rate
                return False

        self.failing_rate = None

        if item.weight > 1:
            # Duplicate the entry so per-window counting reflects the weight
            self.items.extend([item for _ in range(item.weight)])
        else:
            self.items.append(item)

        return True

    def leak(self, current_timestamp: Optional[int] = None) -> int:
        """Drop items older than the longest rate interval; return how many
        were removed. A timestamp must be supplied by the caller."""
        assert current_timestamp is not None
        if self.items:
            max_interval = self.rates[-1].interval
            lower_bound = current_timestamp - max_interval

            if lower_bound > self.items[-1].timestamp:
                # Even the newest item is outdated: drop everything
                remove_count = len(self.items)
                del self.items[:]
                return remove_count

            if lower_bound < self.items[0].timestamp:
                # Even the oldest item is still fresh: nothing to remove
                return 0

            idx = binary_search(self.items, lower_bound)
            del self.items[:idx]
            return idx

        return 0

    def flush(self) -> None:
        """Empty the bucket and clear the failing-rate marker."""
        self.failing_rate = None
        del self.items[:]

    def count(self) -> int:
        """Number of stored entries (weighted items count multiple times)."""
        return len(self.items)

    def peek(self, index: int) -> Optional[RateItem]:
        """Return the item at `index` counting from the most recent
        (index 0 == newest), or None when out of range."""
        if not self.items:
            return None

        return self.items[-1 - index] if abs(index) < self.count() else None
|
||||
@@ -0,0 +1,53 @@
|
||||
"""multiprocessing In-memory Bucket using a multiprocessing.Manager.ListProxy
|
||||
and a multiprocessing.Lock.
|
||||
"""
|
||||
from multiprocessing import Manager
|
||||
from multiprocessing import RLock
|
||||
from multiprocessing.managers import ListProxy
|
||||
from multiprocessing.synchronize import RLock as LockType
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from ..abstracts import Rate
|
||||
from ..abstracts import RateItem
|
||||
from pyrate_limiter.buckets import InMemoryBucket
|
||||
|
||||
|
||||
class MultiprocessBucket(InMemoryBucket):
    """In-memory bucket usable across processes.

    Stores items in a ``multiprocessing`` ListProxy and guards every
    mutation with a shared re-entrant lock, reusing InMemoryBucket's
    counting/trimming logic underneath.
    """

    items: List[RateItem]  # ListProxy
    mp_lock: LockType

    def __init__(self, rates: List[Rate], items: List[RateItem], mp_lock: LockType):
        # A plain list would be process-local and silently break sharing
        if not isinstance(items, ListProxy):
            raise ValueError("items must be a ListProxy")

        # Keep the longest-interval rate last, as InMemoryBucket expects
        self.rates = sorted(rates, key=lambda rate: rate.interval)
        self.items = items
        self.mp_lock = mp_lock

    def put(self, item: RateItem) -> bool:
        # Serialize the check-then-insert across processes
        with self.mp_lock:
            return super().put(item)

    def leak(self, current_timestamp: Optional[int] = None) -> int:
        with self.mp_lock:
            return super().leak(current_timestamp)

    def limiter_lock(self):
        # Expose the process lock so Limiter can guard its own critical section
        return self.mp_lock

    @classmethod
    def init(
        cls,
        rates: List[Rate],
    ):
        """
        Creates a single ListProxy so that this bucket can be shared across multiple processes.
        """
        shared_items: List[RateItem] = Manager().list()  # type: ignore[assignment]

        shared_lock: LockType = RLock()

        return cls(rates=rates, items=shared_items, mp_lock=shared_lock)
|
||||
@@ -0,0 +1,166 @@
|
||||
"""A bucket using PostgreSQL as backend
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Awaitable
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from ..abstracts import AbstractBucket
|
||||
from ..abstracts import Rate
|
||||
from ..abstracts import RateItem
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from psycopg_pool import ConnectionPool # type: ignore[import-untyped]
|
||||
|
||||
|
||||
class Queries:
    """SQL templates for PostgresBucket, filled in via str.format."""

    # One row per stored item; weight>1 items are inserted multiple times
    CREATE_BUCKET_TABLE = """
        CREATE TABLE IF NOT EXISTS {table} (
            name VARCHAR,
            weight SMALLINT,
            item_timestamp TIMESTAMP
        )
        """
    # Window queries filter/sort on item_timestamp, so index it
    CREATE_INDEX_ON_TIMESTAMP = """
        CREATE INDEX IF NOT EXISTS {index} ON {table} (item_timestamp)
        """
    COUNT = """
        SELECT COUNT(*) FROM {table}
        """
    # Timestamp parameter is epoch seconds (converted from ms by the caller)
    PUT = """
        INSERT INTO {table} (name, weight, item_timestamp) VALUES (%s, %s, TO_TIMESTAMP(%s))
        """
    FLUSH = """
        DELETE FROM {table}
        """
    # Latest-to-earliest: newest row first, OFFSET selects the peek index
    PEEK = """
        SELECT name, weight, (extract(EPOCH FROM item_timestamp) * 1000) as item_timestamp
        FROM {table}
        ORDER BY item_timestamp DESC
        LIMIT 1
        OFFSET {offset}
        """
    # Remove rows older than the given epoch-seconds bound
    LEAK = """
        DELETE FROM {table} WHERE item_timestamp < TO_TIMESTAMP({timestamp})
        """
    # How many rows LEAK would remove (run before LEAK to report the count)
    LEAK_COUNT = """
        SELECT COUNT(*) FROM {table} WHERE item_timestamp < TO_TIMESTAMP({timestamp})
        """
|
||||
|
||||
|
||||
class PostgresBucket(AbstractBucket):
    """Bucket backed by a PostgreSQL table accessed through a psycopg
    connection pool.

    RateItem timestamps are milliseconds; they are divided by 1000 before
    being handed to TO_TIMESTAMP (which expects epoch seconds).
    """

    table: str
    pool: ConnectionPool

    def __init__(self, pool: ConnectionPool, table: str, rates: List[Rate]):
        self.table = table.lower()
        self.pool = pool
        assert rates
        self.rates = rates
        # Prefix keeps bucket tables from clashing with application tables
        self._full_tbl = f'ratelimit___{self.table}'
        self._create_table()

    @contextmanager
    def _get_conn(self):
        # Borrow a pooled connection for the duration of one operation
        with self.pool.connection() as conn:
            yield conn

    def _create_table(self):
        # Idempotent: both statements use IF NOT EXISTS
        with self._get_conn() as conn:
            conn.execute(Queries.CREATE_BUCKET_TABLE.format(table=self._full_tbl))
            index_name = f'timestampIndex_{self.table}'
            conn.execute(Queries.CREATE_INDEX_ON_TIMESTAMP.format(table=self._full_tbl, index=index_name))

    def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]:
        """Put an item (typically the current time) in the bucket
        return true if successful, otherwise false
        """
        if item.weight == 0:
            # Zero-weight items consume no capacity
            return True

        with self._get_conn() as conn:
            for rate in self.rates:
                # Count rows inside this rate's window ending at item.timestamp
                bound = f"SELECT TO_TIMESTAMP({item.timestamp / 1000}) - INTERVAL '{rate.interval} milliseconds'"
                query = f'SELECT COUNT(*) FROM {self._full_tbl} WHERE item_timestamp >= ({bound})'
                cur = conn.execute(query)
                count = int(cur.fetchone()[0])
                cur.close()

                if rate.limit - count < item.weight:
                    # Not enough remaining capacity under this rate
                    self.failing_rate = rate
                    return False

            self.failing_rate = None

            query = Queries.PUT.format(table=self._full_tbl)

            # https://www.psycopg.org/docs/extras.html#fast-exec

            # One row per unit of weight so window counts reflect the weight
            for _ in range(item.weight):
                conn.execute(query, (item.name, item.weight, item.timestamp / 1000))

            return True

    def leak(
        self,
        current_timestamp: Optional[int] = None,
    ) -> Union[int, Awaitable[int]]:
        """leaking bucket - removing items that are outdated"""
        assert current_timestamp is not None, "current-time must be passed on for leak"
        # rates[-1] is treated as the longest interval; anything older than
        # that window can never count toward any rate again
        lower_bound = current_timestamp - self.rates[-1].interval

        if lower_bound <= 0:
            return 0

        count = 0

        with self._get_conn() as conn:
            # NOTE: conn is rebound to the cursor returned by execute()
            # (psycopg3 Connection.execute returns a Cursor)
            conn = conn.execute(Queries.LEAK_COUNT.format(table=self._full_tbl, timestamp=lower_bound / 1000))
            result = conn.fetchone()

            if result:
                # Count first, then delete, so we can report how many leaked
                conn.execute(Queries.LEAK.format(table=self._full_tbl, timestamp=lower_bound / 1000))
                count = int(result[0])

        return count

    def flush(self) -> Union[None, Awaitable[None]]:
        """Flush the whole bucket
        - Must remove `failing-rate` after flushing
        """
        with self._get_conn() as conn:
            conn.execute(Queries.FLUSH.format(table=self._full_tbl))
            self.failing_rate = None

        return None

    def count(self) -> Union[int, Awaitable[int]]:
        """Count number of items in the bucket"""
        count = 0
        with self._get_conn() as conn:
            # NOTE: conn is rebound to the cursor returned by execute()
            conn = conn.execute(Queries.COUNT.format(table=self._full_tbl))
            result = conn.fetchone()
            assert result
            count = int(result[0])

        return count

    def peek(self, index: int) -> Union[Optional[RateItem], Awaitable[Optional[RateItem]]]:
        """Peek at the rate-item at a specific index in latest-to-earliest order
        NOTE: The reason we cannot peek from the start of the queue(earliest-to-latest) is
        we can't really tell how many outdated items are still in the queue
        """
        item = None

        with self._get_conn() as conn:
            # PEEK orders newest-first and uses OFFSET = index
            conn = conn.execute(Queries.PEEK.format(table=self._full_tbl, offset=index))
            result = conn.fetchone()
            if result:
                # PEEK returns the timestamp already converted back to ms
                name, weight, timestamp = result[0], int(result[1]), int(result[2])
                item = RateItem(name=name, weight=weight, timestamp=timestamp)

        return item
|
||||
@@ -0,0 +1,187 @@
|
||||
"""Bucket implementation using Redis
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from inspect import isawaitable
|
||||
from typing import Awaitable
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from ..abstracts import AbstractBucket
|
||||
from ..abstracts import Rate
|
||||
from ..abstracts import RateItem
|
||||
from ..utils import id_generator
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis import Redis
|
||||
from redis.asyncio import Redis as AsyncRedis
|
||||
|
||||
|
||||
class LuaScript:
    """Scripts that deal with bucket operations"""

    # Atomic check-then-insert against a ZSET (member -> timestamp score).
    # For each rate (interval, limit pairs in ARGV) it ZCOUNTs the window
    # [now - interval, now]; if any rate lacks room for `space_required`
    # members it returns that rate's zero-based index. Otherwise it ZADDs
    # `space_required` members (suffixing the name to keep members unique)
    # and returns -1. Running as one script makes the operation race-free.
    PUT_ITEM = """
    local bucket = KEYS[1]
    local now = ARGV[1]
    local space_required = tonumber(ARGV[2])
    local item_name = ARGV[3]
    local rates_count = tonumber(ARGV[4])

    for i=1,rates_count do
        local offset = (i - 1) * 2
        local interval = tonumber(ARGV[5 + offset])
        local limit = tonumber(ARGV[5 + offset + 1])
        local count = redis.call('ZCOUNT', bucket, now - interval, now)
        local space_available = limit - tonumber(count)
        if space_available < space_required then
            return i - 1
        end
    end

    for i=1,space_required do
        redis.call('ZADD', bucket, now, item_name..i)
    end
    return -1
    """
|
||||
|
||||
|
||||
class RedisBucket(AbstractBucket):
    """A bucket using redis for storing data
    - We are not using redis' built-in TIME since it is non-deterministic
    - In distributed context, use local server time or a remote time server
    - Each bucket instance use a dedicated connection to avoid race-condition
    - can be either sync or async
    """

    rates: List[Rate]
    failing_rate: Optional[Rate]
    # Redis key of the ZSET holding this bucket's items
    bucket_key: str
    # SHA of the preloaded LuaScript.PUT_ITEM script (for EVALSHA)
    script_hash: str
    redis: Union[Redis, AsyncRedis]

    def __init__(
        self,
        rates: List[Rate],
        redis: Union[Redis, AsyncRedis],
        bucket_key: str,
        script_hash: str,
    ):
        self.rates = rates
        self.redis = redis
        self.bucket_key = bucket_key
        self.script_hash = script_hash
        self.failing_rate = None

    @classmethod
    def init(
        cls,
        rates: List[Rate],
        redis: Union[Redis, AsyncRedis],
        bucket_key: str,
    ):
        """Preload the Lua script and build the bucket.

        Returns a RedisBucket for a sync client, or a coroutine resolving
        to one for an async client.
        """
        script_hash = redis.script_load(LuaScript.PUT_ITEM)

        if isawaitable(script_hash):

            async def _async_init():
                nonlocal script_hash
                script_hash = await script_hash
                return cls(rates, redis, bucket_key, script_hash)

            return _async_init()

        return cls(rates, redis, bucket_key, script_hash)

    def _check_and_insert(self, item: RateItem) -> Union[Rate, None, Awaitable[Optional[Rate]]]:
        """Run the atomic Lua check-and-insert; return the failing Rate,
        None on success, or an awaitable of either for async clients."""
        keys = [self.bucket_key]

        args = [
            item.timestamp,
            item.weight,
            # NOTE: this is to avoid key collision since we are using ZSET
            f"{item.name}:{id_generator()}:",  # noqa: E231
            len(self.rates),
            # Flatten (interval, limit) pairs for the script's ARGV layout
            *[value for rate in self.rates for value in (rate.interval, rate.limit)],
        ]

        idx = self.redis.evalsha(self.script_hash, len(keys), *keys, *args)

        def _handle_sync(returned_idx: int):
            # Script returns -1 on success, else the failing rate's index
            assert isinstance(returned_idx, int), "Not int"
            if returned_idx < 0:
                return None

            return self.rates[returned_idx]

        async def _handle_async(returned_idx: Awaitable[int]):
            assert isawaitable(returned_idx), "Not corotine"
            awaited_idx = await returned_idx
            return _handle_sync(awaited_idx)

        return _handle_async(idx) if isawaitable(idx) else _handle_sync(idx)

    def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]:
        """Add item to key"""
        failing_rate = self._check_and_insert(item)
        if isawaitable(failing_rate):

            async def _handle_async():
                self.failing_rate = await failing_rate
                return not bool(self.failing_rate)

            return _handle_async()

        assert isinstance(failing_rate, Rate) or failing_rate is None
        self.failing_rate = failing_rate
        # True means accepted (no rate failed)
        return not bool(self.failing_rate)

    def leak(self, current_timestamp: Optional[int] = None) -> Union[int, Awaitable[int]]:
        """Remove ZSET members older than the longest rate interval;
        returns the removed count (or an awaitable of it)."""
        assert current_timestamp is not None
        return self.redis.zremrangebyscore(
            self.bucket_key,
            0,
            current_timestamp - self.rates[-1].interval,
        )

    def flush(self):
        """Delete the whole ZSET and clear the failing-rate marker."""
        self.failing_rate = None
        return self.redis.delete(self.bucket_key)

    def count(self):
        """Number of members in the ZSET (or an awaitable of it)."""
        return self.redis.zcard(self.bucket_key)

    def peek(self, index: int) -> Union[RateItem, None, Awaitable[Optional[RateItem]]]:
        """Peek at index counting from the most recent member (index 0 ==
        newest), as a RateItem; None when out of range."""
        # ZSET is score-ordered ascending; -1 - index counts from the tail
        items = self.redis.zrange(
            self.bucket_key,
            -1 - index,
            -1 - index,
            withscores=True,
            score_cast_func=int,
        )

        if not items:
            return None

        def _handle_items(received_items: List[Tuple[str, int]]):
            if not received_items:
                return None

            item = received_items[0]
            # Score is the stored timestamp; weight is not recoverable here
            rate_item = RateItem(name=str(item[0]), timestamp=item[1])
            return rate_item

        if isawaitable(items):

            async def _awaiting():
                nonlocal items
                items = await items
                return _handle_items(items)

            return _awaiting()

        assert isinstance(items, list)
        return _handle_items(items)
|
||||
@@ -0,0 +1,250 @@
|
||||
"""Bucket implementation using SQLite"""
|
||||
import logging
|
||||
import sqlite3
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from tempfile import gettempdir
|
||||
from threading import RLock
|
||||
from time import time
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
from ..abstracts import AbstractBucket
|
||||
from ..abstracts import Rate
|
||||
from ..abstracts import RateItem
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Queries:
    """SQL statement templates used by SQLiteBucket.

    Templates are filled with str.format for identifiers (table/index
    names) and, where applicable, with named sqlite parameters for
    values.
    """

    # Bucket storage: one row per admitted item (weight N => N rows).
    CREATE_BUCKET_TABLE = """
    CREATE TABLE IF NOT EXISTS '{table}' (
        name VARCHAR,
        item_timestamp INTEGER
    )
    """
    # Index on the timestamp column, which every windowed query filters on.
    CREATE_INDEX_ON_TIMESTAMP = """
    CREATE INDEX IF NOT EXISTS '{index_name}' ON '{table_name}' (item_timestamp)
    """
    # Per-rate occupancy count; one of these is emitted (and UNION-ed)
    # per configured rate, using :interval{index} named parameters.
    COUNT_BEFORE_INSERT = """
    SELECT :interval{index} as interval, COUNT(*) FROM '{table}'
    WHERE item_timestamp >= :current_timestamp - :interval{index}
    """
    # NOTE(review): VALUES %s is filled by Python %-formatting at the
    # call site, not by sqlite placeholders.
    PUT_ITEM = """
    INSERT INTO '{table}' (name, item_timestamp) VALUES %s
    """
    # Delete the {count} oldest rows.
    LEAK = """
    DELETE FROM "{table}" WHERE rowid IN (
    SELECT rowid FROM "{table}" ORDER BY item_timestamp ASC LIMIT {count});
    """.strip()
    # How many rows have aged out of the widest window.
    COUNT_BEFORE_LEAK = """SELECT COUNT(*) FROM '{table}' WHERE item_timestamp < {current_timestamp} - {interval}"""
    FLUSH = """DELETE FROM '{table}'"""
    # The below sqls are for testing only
    DROP_TABLE = "DROP TABLE IF EXISTS '{table}'"
    DROP_INDEX = "DROP INDEX IF EXISTS '{index}'"
    COUNT_ALL = "SELECT COUNT(*) FROM '{table}'"
    GET_ALL_ITEM = "SELECT * FROM '{table}' ORDER BY item_timestamp ASC"
    GET_FIRST_ITEM = (
        "SELECT name, item_timestamp FROM '{table}' ORDER BY item_timestamp ASC"
    )
    # Milliseconds between "now" (per sqlite) and the oldest stored item.
    GET_LAG = """
    SELECT (strftime ('%s', 'now') || substr(strftime ('%f', 'now'), 4)) - (
        SELECT item_timestamp
        FROM '{table}'
        ORDER BY item_timestamp
        ASC
        LIMIT 1
    )
    """
    # Row at OFFSET {count} from the newest item.
    PEEK = 'SELECT * FROM "{table}" ORDER BY item_timestamp DESC LIMIT 1 OFFSET {count}'
|
||||
|
||||
|
||||
class SQLiteBucket(AbstractBucket):
    """For sqlite bucket, we are using the sql time function as the clock
    item's timestamp wont matter here
    """

    rates: List[Rate]                  # configured rate limits, ascending interval assumed
    failing_rate: Optional[Rate]       # rate violated by the last rejected put()
    conn: sqlite3.Connection           # shared connection (check_same_thread=False)
    table: str                         # bucket table name
    full_count_query: str
    lock: RLock                        # guards every conn access
    use_limiter_lock: bool             # True when `lock` came from the caller (e.g. a FileLock)

    def __init__(
        self, rates: List[Rate], conn: sqlite3.Connection, table: str, lock=None
    ):
        """Wrap an existing sqlite connection as a rate-limit bucket.

        When `lock` is supplied (e.g. a cross-process FileLock from
        init_from_file), it is used instead of a private RLock and is
        also exposed to the Limiter via limiter_lock().
        """
        self.conn = conn
        self.table = table
        self.rates = rates

        if not lock:
            # No external lock: protect the connection with our own RLock.
            self.use_limiter_lock = False
            self.lock = RLock()
        else:
            self.use_limiter_lock = True
            self.lock = lock
|
||||
|
||||
def limiter_lock(self):
    """Expose the externally supplied (cross-process) lock to the Limiter.

    Returns None when this bucket manages its own private RLock, in
    which case the Limiter need not stack an extra lock on top.
    """
    return self.lock if self.use_limiter_lock else None
|
||||
|
||||
def _build_full_count_query(self, current_timestamp: int) -> Tuple[str, dict]:
    """Compose one SELECT (a UNION over all rates) that counts, per
    rate, how many items fall inside that rate's window, together with
    its named-parameter bindings.
    """
    bindings = {"current_timestamp": current_timestamp}
    sub_queries: List[str] = []

    for idx, rate in enumerate(self.rates):
        bindings[f"interval{idx}"] = rate.interval
        sub_queries.append(
            Queries.COUNT_BEFORE_INSERT.format(table=self.table, index=idx)
        )

    if len(sub_queries) > 1:
        combined = " union ".join(sub_queries)
    else:
        combined = sub_queries[0]

    return combined, bindings
|
||||
|
||||
def put(self, item: RateItem) -> bool:
    """Try to admit `item` into the bucket.

    Under the bucket lock: counts, per configured rate, how many items
    already sit inside that rate's window. If any rate lacks room for
    the item's weight, records it in ``self.failing_rate`` and returns
    False; otherwise inserts ``item.weight`` rows and returns True.
    """
    with self.lock:
        query, parameters = self._build_full_count_query(item.timestamp)
        cur = self.conn.execute(query, parameters)
        rate_limit_counts = cur.fetchall()
        cur.close()

        for idx, result in enumerate(rate_limit_counts):
            interval, count = result
            rate = self.rates[idx]
            # UNION result rows come back in the same order the
            # sub-queries were emitted, i.e. aligned with self.rates.
            assert interval == rate.interval
            space_available = rate.limit - count

            if space_available < item.weight:
                self.failing_rate = rate
                return False

        # BUGFIX: insert via bound parameters instead of interpolating
        # item.name into the SQL text -- a name containing a quote would
        # previously have broken the statement (SQL injection risk).
        insert_query = f"INSERT INTO '{self.table}' (name, item_timestamp) VALUES (?, ?)"
        rows = [(item.name, item.timestamp)] * item.weight
        self.conn.executemany(insert_query, rows).close()
        self.conn.commit()
        return True
|
||||
|
||||
def leak(self, current_timestamp: Optional[int] = None) -> int:
    """Leaking/clean up bucket.

    Counts the rows that have aged out of the widest rate window
    (rates[-1].interval), deletes exactly that many of the oldest rows,
    and returns the count. The whole operation runs under the bucket
    lock and one commit.
    """
    with self.lock:
        assert current_timestamp is not None
        query = Queries.COUNT_BEFORE_LEAK.format(
            table=self.table,
            interval=self.rates[-1].interval,
            current_timestamp=current_timestamp,
        )
        cur = self.conn.execute(query)
        count = cur.fetchone()[0]
        # Reuse the same cursor to delete the `count` oldest rows.
        query = Queries.LEAK.format(table=self.table, count=count)
        cur.execute(query)
        cur.close()
        self.conn.commit()
        return count
|
||||
|
||||
def flush(self) -> None:
    """Remove every stored item and reset the failing-rate marker."""
    with self.lock:
        cursor = self.conn.execute(Queries.FLUSH.format(table=self.table))
        cursor.close()
        self.conn.commit()
        self.failing_rate = None
|
||||
|
||||
def count(self) -> int:
    """Total number of rows currently stored in the bucket table."""
    with self.lock:
        cursor = self.conn.execute(Queries.COUNT_ALL.format(table=self.table))
        (total,) = cursor.fetchone()
        cursor.close()
        return total
|
||||
|
||||
def peek(self, index: int) -> Optional[RateItem]:
    """Return the item `index` positions back from the newest one,
    or None when the bucket holds fewer items than that."""
    with self.lock:
        peek_query = Queries.PEEK.format(table=self.table, count=index)
        cursor = self.conn.execute(peek_query)
        row = cursor.fetchone()
        cursor.close()

        if row is None:
            return None

        return RateItem(row[0], row[1])
|
||||
|
||||
@classmethod
def init_from_file(
    cls,
    rates: List[Rate],
    table: str = "rate_bucket",
    db_path: Optional[str] = None,
    create_new_table: bool = True,
    use_file_lock: bool = False
) -> "SQLiteBucket":
    """Open (or create) a sqlite-file-backed bucket.

    When `db_path` is None a fresh temp file is used (so the bucket is
    NOT shared between runs). With `use_file_lock` a cross-process
    `filelock.FileLock` guards the database and WAL mode is enabled;
    filelock is an optional dependency and db_path becomes mandatory.
    """

    if db_path is None and use_file_lock:
        raise ValueError("db_path must be specified when using use_file_lock")

    if db_path is None:
        # Unique temp file per call: timestamp suffix avoids collisions.
        temp_dir = Path(gettempdir())
        db_path = str(temp_dir / f"pyrate_limiter_{time()}.sqlite")

    # TBD: FileLock switched to a thread-local FileLock in 3.11.0.
    # Should we set FileLock's thread_local to False, for cases where user is both multiprocessing & threading?
    # As is, the file lock should be Multi Process - Single Thread and non-filelock is Single Process - Multi Thread
    # A hybrid lock may be needed to gracefully handle both cases
    file_lock = None
    file_lock_ctx = nullcontext()

    if use_file_lock:
        try:
            from filelock import FileLock  # type: ignore[import-untyped]
            file_lock = FileLock(db_path + ".lock")  # type: ignore[no-redef]
            file_lock_ctx: Union[nullcontext, FileLock] = file_lock  # type: ignore[no-redef]
        except ImportError:
            raise ImportError(
                "filelock is required for file locking. "
                "Please install it as optional dependency"
            )

    # Hold the file lock (when any) while creating schema, so two
    # processes cannot race on table/index creation.
    with file_lock_ctx:
        assert db_path is not None
        assert db_path.endswith(".sqlite"), (
            "Please provide a valid sqlite file path"
        )

        sqlite_connection = sqlite3.connect(
            db_path,
            isolation_level="DEFERRED",
            check_same_thread=False,
        )

        cur = sqlite_connection.cursor()
        if use_file_lock:
            # https://www.sqlite.org/wal.html
            cur.execute("PRAGMA journal_mode=WAL;")

            # https://www.sqlite.org/pragma.html#pragma_synchronous
            cur.execute("PRAGMA synchronous=NORMAL;")

        if create_new_table:
            cur.execute(
                Queries.CREATE_BUCKET_TABLE.format(table=table)
            )

        # Index is created unconditionally (IF NOT EXISTS makes it idempotent).
        create_idx_query = Queries.CREATE_INDEX_ON_TIMESTAMP.format(
            index_name=f"idx_{table}_rate_item_timestamp",
            table_name=table,
        )

        cur.execute(create_idx_query)
        cur.close()
        sqlite_connection.commit()

    # The FileLock (or None) becomes the bucket's lock -- see __init__.
    return cls(rates, sqlite_connection, table=table, lock=file_lock)
|
||||
90
venv/lib/python3.10/site-packages/pyrate_limiter/clocks.py
Normal file
90
venv/lib/python3.10/site-packages/pyrate_limiter/clocks.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Clock implementation using different backend"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from contextlib import nullcontext
|
||||
from time import monotonic
|
||||
from time import time
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from .abstracts import AbstractClock
|
||||
from .buckets import SQLiteBucket
|
||||
from .utils import dedicated_sqlite_clock_connection
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from psycopg_pool import ConnectionPool
|
||||
from threading import RLock
|
||||
|
||||
|
||||
class MonotonicClock(AbstractClock):
    """Millisecond clock backed by ``time.monotonic()``.

    Immune to system clock adjustments, but its reference point is
    arbitrary and per-process -- timestamps are only comparable within
    one process.
    """

    def __init__(self):
        # NOTE(review): return value is discarded -- presumably a
        # warm-up call so the monotonic reference is touched at
        # construction time; confirm intent.
        monotonic()

    def now(self):
        # Milliseconds since the (arbitrary) monotonic reference point.
        return int(1000 * monotonic())
|
||||
|
||||
|
||||
class TimeClock(AbstractClock):
    """Wall-clock time source: epoch milliseconds via ``time.time()``."""

    def now(self):
        seconds = time()
        return int(seconds * 1000)
|
||||
|
||||
|
||||
class TimeAsyncClock(AbstractClock):
    """Time Async Clock, meant for testing only"""

    async def now(self) -> int:
        # Same epoch-millisecond value as TimeClock, awaitable form.
        milliseconds = time() * 1000
        return int(milliseconds)
|
||||
|
||||
|
||||
class SQLiteClock(AbstractClock):
    """Get timestamp using SQLite as remote clock backend"""

    # Epoch milliseconds computed from sqlite's julianday('now').
    time_query = (
        "SELECT CAST(ROUND((julianday('now') - 2440587.5)*86400000) As INTEGER)"
    )

    def __init__(self, conn: Union[sqlite3.Connection, SQLiteBucket]):
        """
        In multiprocessing cases, use the bucket, so that a shared lock is used.
        """
        self.lock: Optional[RLock] = None

        if isinstance(conn, SQLiteBucket):
            # Borrow both the bucket's connection and its lock so clock
            # reads serialize with bucket writes.
            self.conn = conn.conn
            self.lock = conn.lock
        else:
            self.conn = conn

    @classmethod
    def default(cls):
        # Fresh private connection -- no lock needed.
        conn = dedicated_sqlite_clock_connection()
        return cls(conn)

    def now(self) -> int:
        # Serialize against bucket operations when sharing a connection.
        with self.lock if self.lock else nullcontext():
            cur = self.conn.execute(self.time_query)
            now = cur.fetchone()[0]
            cur.close()
            return int(now)
|
||||
|
||||
|
||||
class PostgresClock(AbstractClock):
    """Get timestamp using Postgres as remote clock backend"""

    def __init__(self, pool: "ConnectionPool"):
        self.pool = pool

    def now(self) -> int:
        """Epoch milliseconds as reported by the Postgres server."""
        with self.pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute("SELECT EXTRACT(epoch FROM current_timestamp) * 1000")
                row = cur.fetchone()
                assert row, "unable to get current-timestamp from postgres"
                return int(row[0])
|
||||
@@ -0,0 +1,47 @@
|
||||
# pylint: disable=C0114,C0115
|
||||
from typing import Dict
|
||||
from typing import Union
|
||||
|
||||
from .abstracts.rate import Rate
|
||||
from .abstracts.rate import RateItem
|
||||
|
||||
|
||||
class BucketFullException(Exception):
    """Raised when an item cannot be admitted because a rate is full."""

    def __init__(self, item: RateItem, rate: Rate):
        self.item = item
        self.rate = rate
        error = f"Bucket for item={item.name} with Rate {rate} is already full"
        self.meta_info: Dict[str, Union[str, float]] = dict(
            error=error,
            name=item.name,
            weight=item.weight,
            rate=str(rate),
        )
        super().__init__(error)

    def __reduce__(self):
        # Rebuild from (item, rate) so the exception pickles cleanly
        # across process boundaries.
        return (self.__class__, (self.item, self.rate))
|
||||
|
||||
|
||||
class LimiterDelayException(Exception):
    """Raised when the delay required to admit an item exceeds the
    Limiter's configured max_delay."""

    def __init__(self, item: RateItem, rate: Rate, actual_delay: int, max_delay: int):
        self.item = item
        self.rate = rate
        self.actual_delay = actual_delay
        self.max_delay = max_delay
        error = f"""
        Actual delay exceeded allowance: actual={actual_delay}, allowed={max_delay}
        Bucket for {item.name} with Rate {rate} is already full
        """
        # Structured details for programmatic consumers.
        self.meta_info: Dict[str, Union[str, float]] = {
            "error": error,
            "name": item.name,
            "weight": item.weight,
            "rate": str(rate),
            "max_delay": max_delay,
            "actual_delay": actual_delay,
        }
        super().__init__(error)

    def __reduce__(self):
        # Keep the exception picklable (e.g. across process pools).
        return (self.__class__, (self.item, self.rate, self.actual_delay, self.max_delay))
|
||||
482
venv/lib/python3.10/site-packages/pyrate_limiter/limiter.py
Normal file
482
venv/lib/python3.10/site-packages/pyrate_limiter/limiter.py
Normal file
@@ -0,0 +1,482 @@
|
||||
"""Limiter class implementation
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from inspect import isawaitable
|
||||
from threading import local
|
||||
from threading import RLock
|
||||
from time import sleep
|
||||
from typing import Any
|
||||
from typing import Awaitable
|
||||
from typing import Callable
|
||||
from typing import Iterable
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
from .abstracts import AbstractBucket
|
||||
from .abstracts import AbstractClock
|
||||
from .abstracts import BucketFactory
|
||||
from .abstracts import Duration
|
||||
from .abstracts import Rate
|
||||
from .abstracts import RateItem
|
||||
from .buckets import InMemoryBucket
|
||||
from .clocks import TimeClock
|
||||
from .exceptions import BucketFullException
|
||||
from .exceptions import LimiterDelayException
|
||||
|
||||
logger = logging.getLogger("pyrate_limiter")
|
||||
|
||||
ItemMapping = Callable[[Any], Tuple[str, int]]
|
||||
DecoratorWrapper = Callable[[Callable[[Any], Any]], Callable[[Any], Any]]
|
||||
|
||||
|
||||
class SingleBucketFactory(BucketFactory):
    """Single-bucket factory for quick use with Limiter"""

    bucket: AbstractBucket
    clock: AbstractClock

    def __init__(self, bucket: AbstractBucket, clock: AbstractClock):
        self.clock = clock
        self.bucket = bucket
        # Start the periodic leak task for this bucket (inherited helper).
        self.schedule_leak(bucket, clock)

    def wrap_item(self, name: str, weight: int = 1):
        """Stamp (name, weight) with the clock's current time.

        Returns a RateItem directly for a sync clock, or a coroutine
        resolving to one when the clock's now() is awaitable.
        """
        now = self.clock.now()

        async def wrap_async():
            return RateItem(name, await now, weight=weight)

        def wrap_sync():
            return RateItem(name, now, weight=weight)

        return wrap_async() if isawaitable(now) else wrap_sync()

    def get(self, _: RateItem) -> AbstractBucket:
        # Single-bucket factory: every item maps to the same bucket.
        return self.bucket
|
||||
|
||||
|
||||
@contextmanager
def combined_lock(locks: Iterable, timeout_sec: Optional[float] = None):
    """Acquire every lock in *locks* in order, yield, then release them
    in reverse order. Intended for multiprocessing: a cross-process lock
    combined with in-process thread RLocks.

    Raises TimeoutError when any single acquisition exceeds
    *timeout_sec*; locks acquired before the failure are released.
    """
    held = []
    try:
        for lock in locks:
            if timeout_sec is None:
                lock.acquire()
            elif not lock.acquire(timeout=timeout_sec):
                raise TimeoutError("Timeout while acquiring combined lock.")
            held.append(lock)
        yield
    finally:
        # Unwind in reverse acquisition order.
        while held:
            held.pop().release()
|
||||
|
||||
|
||||
class Limiter:
    """This class responsibility is to sum up all underlying logic
    and make working with async/sync functions easily
    """

    bucket_factory: BucketFactory
    raise_when_fail: bool          # raise on failure vs. return False
    retry_until_max_delay: bool    # keep retrying while total delay <= max_delay
    max_delay: Optional[int] = None   # milliseconds; None => fail immediately
    lock: Union[RLock, Iterable]   # RLock, or (bucket lock, RLock) tuple
    buffer_ms: int                 # slack added to each computed delay

    # async_lock is thread local, created on first use
    _thread_local: local

    def __init__(
        self,
        argument: Union[BucketFactory, AbstractBucket, Rate, List[Rate]],
        clock: AbstractClock = TimeClock(),
        raise_when_fail: bool = True,
        max_delay: Optional[Union[int, Duration]] = None,
        retry_until_max_delay: bool = False,
        buffer_ms: int = 50
    ):
        """Init Limiter using either a single bucket / multiple-bucket factory
        / single rate / rate list.

        Parameters:
            argument (Union[BucketFactory, AbstractBucket, Rate, List[Rate]]): The bucket or rate configuration.
            clock (AbstractClock, optional): The clock instance to use for rate limiting. Defaults to TimeClock().
            raise_when_fail (bool, optional): Whether to raise an exception when rate limiting fails. Defaults to True.
            max_delay (Optional[Union[int, Duration]], optional): The maximum delay allowed for rate limiting.
            Defaults to None.
            retry_until_max_delay (bool, optional): If True, retry operations until the maximum delay is reached.
            Useful for ensuring operations eventually succeed within the allowed delay window. Defaults to False.
        """
        # NOTE(review): the default TimeClock() instance is created once
        # at def-time and shared by all Limiters using the default --
        # harmless here since TimeClock holds no state, but worth knowing.
        self.bucket_factory = self._init_bucket_factory(argument, clock=clock)
        self.raise_when_fail = raise_when_fail
        self.retry_until_max_delay = retry_until_max_delay
        self.buffer_ms = buffer_ms
        if max_delay is not None:
            if isinstance(max_delay, Duration):
                max_delay = int(max_delay)

            assert max_delay >= 0, "Max-delay must not be negative"

            self.max_delay = max_delay
        self.lock = RLock()
        self._thread_local = local()

        if isinstance(argument, AbstractBucket):
            # A bucket carrying its own cross-process lock (e.g. a
            # FileLock from SQLiteBucket) gets stacked under our RLock.
            limiter_lock = argument.limiter_lock()
            if limiter_lock is not None:
                self.lock = (limiter_lock, self.lock)
|
||||
|
||||
def buckets(self) -> List[AbstractBucket]:
    """Return every bucket currently managed by the factory."""
    active_buckets = self.bucket_factory.get_buckets()
    return active_buckets
|
||||
|
||||
def dispose(self, bucket: Union[int, AbstractBucket]) -> bool:
    """Remove a bucket, identified by id or by the object itself.

    Returns True when the factory actually disposed of it.
    """
    disposed = self.bucket_factory.dispose(bucket)
    return disposed
|
||||
|
||||
def _init_bucket_factory(
    self,
    argument: Union[BucketFactory, AbstractBucket, Rate, List[Rate]],
    clock: AbstractClock,
) -> BucketFactory:
    """Normalize the constructor argument into a BucketFactory.

    Each step feeds the next: a single Rate becomes a list, a rate list
    becomes an InMemoryBucket, and a bucket becomes a
    SingleBucketFactory. A BucketFactory passes straight through.
    """
    normalized = argument

    if isinstance(normalized, Rate):
        normalized = [normalized]

    if isinstance(normalized, list):
        assert len(normalized) > 0, "Rates must not be empty"
        assert isinstance(normalized[0], Rate), "Not valid rates list"
        rates = normalized
        logger.info("Initializing default bucket(InMemoryBucket) with rates: %s", rates)
        normalized = InMemoryBucket(rates)

    if isinstance(normalized, AbstractBucket):
        normalized = SingleBucketFactory(normalized, clock)

    assert isinstance(normalized, BucketFactory), "Not a valid bucket/bucket-factory"

    return normalized
|
||||
|
||||
def _raise_bucket_full_if_necessary(
    self,
    bucket: AbstractBucket,
    item: RateItem,
):
    """Raise BucketFullException for *item* unless raising is disabled."""
    if not self.raise_when_fail:
        return

    assert bucket.failing_rate is not None  # NOTE: silence mypy
    raise BucketFullException(item, bucket.failing_rate)
|
||||
|
||||
def _raise_delay_exception_if_necessary(
    self,
    bucket: AbstractBucket,
    item: RateItem,
    delay: int,
):
    """Raise LimiterDelayException (delay exceeded max_delay) unless
    raising is disabled."""
    if not self.raise_when_fail:
        return

    assert bucket.failing_rate is not None  # NOTE: silence mypy
    assert isinstance(self.max_delay, int)
    raise LimiterDelayException(item, bucket.failing_rate, delay, self.max_delay)
|
||||
|
||||
def delay_or_raise(
    self,
    bucket: AbstractBucket,
    item: RateItem,
) -> Union[bool, Awaitable[bool]]:
    """On `try_acquire` failed, handle delay or raise error immediately.

    With no max_delay configured, fail (or raise) at once. Otherwise
    sleep for the bucket's computed waiting time (+buffer_ms), bump the
    item's timestamp accordingly, and retry the put -- once, or in a
    loop bounded by max_delay when retry_until_max_delay is set. Two
    parallel code paths exist: coroutine-based when bucket.waiting()
    is awaitable, plain sync otherwise.
    """
    assert bucket.failing_rate is not None

    if self.max_delay is None:
        self._raise_bucket_full_if_necessary(bucket, item)
        return False

    delay = bucket.waiting(item)

    def _handle_reacquire(re_acquire: bool) -> bool:
        # Single-shot retry failed after sleeping the expected delay --
        # log loudly, optionally raise, and propagate the failure.
        if not re_acquire:
            logger.error("""Failed to re-acquire after the expected delay. If it failed,
            either clock or bucket is unstable.
            If asyncio, use try_acquire_async(). If multiprocessing,
            use retry_until_max_delay=True.""")
            self._raise_bucket_full_if_necessary(bucket, item)

        return re_acquire

    if isawaitable(delay):

        async def _handle_async():
            nonlocal delay
            delay = await delay
            assert isinstance(delay, int), "Delay not integer"

            total_delay = 0
            delay += self.buffer_ms

            # NOTE(review): unlike the sync loop below, this loop never
            # recomputes bucket.waiting() -- every retry re-sleeps the
            # same initial delay. Confirm whether that is intentional.
            while True:
                total_delay += delay

                if self.retry_until_max_delay:
                    # Bound the cumulative delay across retries.
                    if self.max_delay is not None and total_delay > self.max_delay:
                        logger.error("Total delay exceeded max_delay: total_delay=%s, max_delay=%s",
                                     total_delay, self.max_delay)
                        self._raise_delay_exception_if_necessary(bucket, item, total_delay)
                        return False
                else:
                    # Single-shot mode: only the one required delay matters.
                    if self.max_delay is not None and delay > self.max_delay:
                        logger.error(
                            "Required delay too large: actual=%s, expected=%s",
                            delay,
                            self.max_delay,
                        )
                        self._raise_delay_exception_if_necessary(bucket, item, delay)
                        return False

                await asyncio.sleep(delay / 1000)
                # Advance the item's timestamp to "now" after sleeping.
                item.timestamp += delay
                re_acquire = bucket.put(item)

                if isawaitable(re_acquire):
                    re_acquire = await re_acquire

                if not self.retry_until_max_delay:
                    return _handle_reacquire(re_acquire)
                elif re_acquire:
                    return True

        return _handle_async()

    assert isinstance(delay, int)

    if delay < 0:
        # waiting() signals "can never fit" (e.g. weight > rate limit).
        logger.error(
            "Cannot fit item into bucket: item=%s, rate=%s, bucket=%s",
            item,
            bucket.failing_rate,
            bucket,
        )
        self._raise_bucket_full_if_necessary(bucket, item)
        return False

    total_delay = 0

    while True:
        logger.debug("delay=%d, total_delay=%s", delay, total_delay)
        # Sync path recomputes the required wait on every iteration.
        delay = bucket.waiting(item)
        assert isinstance(delay, int)

        delay += self.buffer_ms
        total_delay += delay

        if self.max_delay is not None and total_delay > self.max_delay:
            logger.error(
                "Required delay too large: actual=%s, expected=%s",
                delay,
                self.max_delay,
            )

            # Report cumulative vs. single delay depending on mode.
            if self.retry_until_max_delay:
                self._raise_delay_exception_if_necessary(bucket, item, total_delay)
            else:
                self._raise_delay_exception_if_necessary(bucket, item, delay)

            return False

        sleep(delay / 1000)
        item.timestamp += delay
        re_acquire = bucket.put(item)
        # NOTE: if delay is not Awaitable, then `bucket.put` is not Awaitable
        assert isinstance(re_acquire, bool)

        if not self.retry_until_max_delay:
            return _handle_reacquire(re_acquire)
        elif re_acquire:
            return True
|
||||
|
||||
def handle_bucket_put(
    self,
    bucket: AbstractBucket,
    item: RateItem,
) -> Union[bool, Awaitable[bool]]:
    """Putting item into bucket.

    On a failed put, falls through to delay_or_raise (which may sleep,
    raise, or return False). Sync buckets yield a plain bool; async
    buckets yield a coroutine resolving to one.
    """

    def _handle_result(is_success: bool):
        if not is_success:
            # delay_or_raise may itself return an awaitable.
            return self.delay_or_raise(bucket, item)

        return True

    acquire = bucket.put(item)

    if isawaitable(acquire):

        async def _put_async():
            nonlocal acquire
            acquire = await acquire
            result = _handle_result(acquire)

            # Unwrap possibly-nested awaitables from delay_or_raise.
            while isawaitable(result):
                result = await result

            return result

        return _put_async()

    return _handle_result(acquire)  # type: ignore
|
||||
|
||||
def _get_async_lock(self):
    """Must be called before first try_acquire_async for each thread"""
    # One asyncio.Lock per thread, created lazily on first use.
    lock = getattr(self._thread_local, "async_lock", None)
    if lock is None:
        lock = asyncio.Lock()
        self._thread_local.async_lock = lock
    return lock
|
||||
|
||||
async def try_acquire_async(self, name: str, weight: int = 1) -> bool:
    """
    async version of try_acquire.

    This uses a top level, thread-local async lock to ensure that the async loop doesn't block

    This does not make the entire code async: use an async bucket for that.
    """
    async with self._get_async_lock():
        acquired = self.try_acquire(name=name, weight=weight)

        if isawaitable(acquired):
            return await acquired
        else:
            # Sync bucket: try_acquire may have blocked the event loop
            # (sleep/lock) -- warn so the user switches to an async bucket.
            logger.warning("async call made without an async bucket.")
            return acquired
|
||||
|
||||
def try_acquire(self, name: str, weight: int = 1) -> Union[bool, Awaitable[bool]]:
    """Try acquiring an item with name & weight
    Return true on success, false on failure

    Runs entirely under the limiter lock (or the combined bucket+limiter
    locks). Any of wrap_item / factory.get / bucket put may be
    awaitable; each possibility branches into a coroutine wrapper so the
    caller receives either a bool or a single awaitable of one.
    """
    with self.lock if not isinstance(self.lock, Iterable) else combined_lock(self.lock):
        assert weight >= 0, "item's weight must be >= 0"

        if weight == 0:
            # NOTE: if item is weightless, just let it go through
            # NOTE: this might change in the future
            return True

        item = self.bucket_factory.wrap_item(name, weight)

        if isawaitable(item):

            async def _handle_async():
                # Async clock: resolve item first, then bucket, then put.
                nonlocal item
                item = await item
                bucket = self.bucket_factory.get(item)
                if isawaitable(bucket):
                    bucket = await bucket
                assert isinstance(bucket, AbstractBucket), f"Invalid bucket: item: {name}"
                result = self.handle_bucket_put(bucket, item)

                while isawaitable(result):
                    result = await result

                return result

            return _handle_async()

        assert isinstance(item, RateItem)  # NOTE: this is to silence mypy warning
        bucket = self.bucket_factory.get(item)
        if isawaitable(bucket):

            async def _handle_async_bucket():
                # Sync item but async bucket lookup.
                nonlocal bucket
                bucket = await bucket
                assert isinstance(bucket, AbstractBucket), f"Invalid bucket: item: {name}"
                result = self.handle_bucket_put(bucket, item)

                while isawaitable(result):
                    result = await result

                return result

            return _handle_async_bucket()

        assert isinstance(bucket, AbstractBucket), f"Invalid bucket: item: {name}"
        result = self.handle_bucket_put(bucket, item)

        if isawaitable(result):

            async def _handle_async_result():
                # Sync item & bucket, but the put itself went async.
                nonlocal result
                while isawaitable(result):
                    result = await result

                return result

            return _handle_async_result()

        return result
|
||||
|
||||
def as_decorator(self) -> Callable[[ItemMapping], DecoratorWrapper]:
    """Use limiter decorator
    Use with both sync & async function

    Usage: ``@limiter.as_decorator()(mapping)`` where *mapping* maps the
    wrapped call's (*args, **kwargs) to an (item_name, weight) pair.
    """
    def with_mapping_func(mapping: ItemMapping) -> DecoratorWrapper:
        def decorator_wrapper(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
            """Actual function wrapper"""

            if asyncio.iscoroutinefunction(func):

                @wraps(func)
                async def wrapper_async(*args, **kwargs):
                    (name, weight) = mapping(*args, **kwargs)
                    assert isinstance(name, str), "Mapping name is expected but not found"
                    assert isinstance(weight, int), "Mapping weight is expected but not found"
                    # Always await acquisition before invoking the coroutine.
                    accquire_ok = self.try_acquire_async(name, weight)

                    if isawaitable(accquire_ok):
                        await accquire_ok

                    return await func(*args, **kwargs)

                return wrapper_async
            else:
                @wraps(func)
                def wrapper(*args, **kwargs):
                    (name, weight) = mapping(*args, **kwargs)
                    assert isinstance(name, str), "Mapping name is expected but not found"
                    assert isinstance(weight, int), "Mapping weight is expected but not found"
                    accquire_ok = self.try_acquire(name, weight)

                    # Sync bucket: call through directly.
                    if not isawaitable(accquire_ok):
                        return func(*args, **kwargs)

                    # Async bucket behind a sync function: the wrapper
                    # returns a coroutine the caller must await.
                    async def _handle_accquire_async():
                        nonlocal accquire_ok
                        accquire_ok = await accquire_ok

                        result = func(*args, **kwargs)

                        if isawaitable(result):
                            return await result

                        return result

                    return _handle_accquire_async()

                return wrapper

        return decorator_wrapper

    return with_mapping_func
|
||||
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
A collection of common use cases and patterns for pyrate_limiter
|
||||
"""
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from pyrate_limiter import AbstractBucket
|
||||
from pyrate_limiter import BucketAsyncWrapper
|
||||
from pyrate_limiter import Duration
|
||||
from pyrate_limiter import InMemoryBucket
|
||||
from pyrate_limiter import Limiter
|
||||
from pyrate_limiter import Rate
|
||||
from pyrate_limiter import SQLiteBucket
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Global for convenience in multiprocessing, populated by init_mp_limiter.
|
||||
# Intended to be called by a ProcessPoolExecutor's initializer
|
||||
LIMITER: Optional[Limiter] = None
|
||||
|
||||
|
||||
def create_sqlite_bucket(
    rates: List[Rate],
    db_path: Optional[str],
    table_name: str = "pyrate_limiter",
    use_file_lock: bool = False,
):
    """
    Create and initialize a SQLite bucket for rate limiting.

    Args:
        rates: List of rate limit configurations.
        db_path: Path to the SQLite database file (a fresh temp file is
            used when None).
        table_name: Name of the table to store rate bucket data.
        use_file_lock: Enable file locking for multi-process synchronization.

    Returns:
        SQLiteBucket: Initialized SQLite-backed bucket.
    """
    # Lazy %-style logging instead of an eager f-string.
    logger.info("table_name=%r", table_name)
    # BUGFIX: `str(db_path)` turned None into the literal string "None",
    # defeating init_from_file's temp-file fallback (and then failing
    # its ".sqlite" suffix assert). Preserve None explicitly.
    bucket = SQLiteBucket.init_from_file(
        rates,
        db_path=str(db_path) if db_path is not None else None,
        table=table_name,
        create_new_table=True,
        use_file_lock=use_file_lock,
    )

    return bucket
|
||||
|
||||
|
||||
def create_sqlite_limiter(
    rate_per_duration: int = 3,
    duration: Union[int, Duration] = Duration.SECOND,
    db_path: Optional[str] = None,
    table_name: str = "rate_bucket",
    max_delay: Union[int, Duration] = Duration.DAY,
    buffer_ms: int = 50,
    use_file_lock: bool = False,
    async_wrapper: bool = False,
) -> Limiter:
    """
    Create a SQLite-backed rate limiter with configurable rate, persistence, and optional async support.

    Args:
        rate_per_duration: Number of allowed requests per duration.
        duration: Time window for the rate limit.
        db_path: Path to the SQLite database file (a fresh temp file is
            used when None).
        table_name: Name of the table used for rate buckets.
        max_delay: Maximum delay before failing requests.
        buffer_ms: Extra wait time in milliseconds to account for clock drift.
        use_file_lock: Enable file locking for multi-process synchronization.
        async_wrapper: Whether to wrap the bucket for async usage.

    Returns:
        Limiter: Configured SQLite-backed limiter instance.
    """
    rate = Rate(rate_per_duration, duration)
    rate_limits = [rate]

    # BUGFIX: with the default db_path=None, `str(db_path)` produced the
    # literal "None", which broke init_from_file's temp-file fallback.
    bucket: AbstractBucket = SQLiteBucket.init_from_file(
        rate_limits,
        db_path=str(db_path) if db_path is not None else None,
        table=table_name,
        create_new_table=True,
        use_file_lock=use_file_lock,
    )

    if async_wrapper:
        bucket = BucketAsyncWrapper(bucket)

    limiter = Limiter(
        bucket, raise_when_fail=False, max_delay=max_delay, retry_until_max_delay=True, buffer_ms=buffer_ms
    )

    return limiter
|
||||
|
||||
|
||||
def create_inmemory_limiter(
    rate_per_duration: int = 3,
    duration: Union[int, Duration] = Duration.SECOND,
    max_delay: Union[int, Duration] = Duration.DAY,
    buffer_ms: int = 50,
    async_wrapper: bool = False,
) -> Limiter:
    """
    Create an in-memory rate limiter with configurable rate, duration, delay, and optional async support.

    Args:
        rate_per_duration: Number of allowed requests per duration.
        duration: Time window for the rate limit.
        max_delay: Maximum delay before failing requests.
        buffer_ms: Extra wait time in milliseconds to account for clock drift.
        async_wrapper: Whether to wrap the bucket for async usage.

    Returns:
        Limiter: Configured in-memory limiter instance.
    """
    rate = Rate(rate_per_duration, duration)
    rate_limits = [rate]
    bucket: AbstractBucket = InMemoryBucket(rate_limits)

    if async_wrapper:
        # Fix: wrap the bucket created above instead of constructing a second,
        # throwaway InMemoryBucket (keeps behavior consistent with
        # create_sqlite_limiter).
        bucket = BucketAsyncWrapper(bucket)

    # raise_when_fail=False + retry_until_max_delay=True: block (up to
    # max_delay) instead of raising when the bucket is full.
    limiter = Limiter(
        bucket, raise_when_fail=False, max_delay=max_delay, retry_until_max_delay=True, buffer_ms=buffer_ms
    )

    return limiter
|
||||
|
||||
|
||||
def init_global_limiter(bucket: AbstractBucket,
                        max_delay: Union[int, Duration] = Duration.HOUR,
                        raise_when_fail: bool = False,
                        retry_until_max_delay: bool = True,
                        buffer_ms: int = 50):
    """
    Build a Limiter around *bucket* and publish it as the module-global ``LIMITER``.

    Intended as the ``initializer`` callback of a ProcessPoolExecutor, so each
    worker process constructs its own Limiter over the shared bucket.

    Args:
        bucket: The rate-limiting bucket to be used.
        max_delay: Maximum delay before failing requests.
        raise_when_fail: Whether to raise an exception when a request fails.
        retry_until_max_delay: Retry until the maximum delay is reached.
        buffer_ms: Additional buffer time in milliseconds for retries.
    """
    global LIMITER
    LIMITER = Limiter(
        bucket,
        raise_when_fail=raise_when_fail,
        max_delay=max_delay,
        retry_until_max_delay=retry_until_max_delay,
        buffer_ms=buffer_ms,
    )
|
||||
83
venv/lib/python3.10/site-packages/pyrate_limiter/utils.py
Normal file
83
venv/lib/python3.10/site-packages/pyrate_limiter/utils.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import random
|
||||
import sqlite3
|
||||
import string
|
||||
from pathlib import Path
|
||||
from tempfile import gettempdir
|
||||
from typing import List
|
||||
|
||||
from .abstracts import Rate
|
||||
from .abstracts import RateItem
|
||||
|
||||
|
||||
def binary_search(items: List[RateItem], value: int) -> int:
    """Find the index marking the lower boundary of the window at *value*.

    Assuming *items* is sorted by timestamp, return the smallest index ``i``
    such that ``items[i].timestamp >= value`` — equivalently the position
    where ``items[i - 1].timestamp < value <= items[i].timestamp``.  This is
    used to measure the size of the window stretching from now back to the
    lower boundary *value*.

    Special cases: an empty list yields ``0``; a *value* greater than every
    stored timestamp yields ``-1``.
    """
    if not items:
        return 0

    if value > items[-1].timestamp:
        return -1

    # Classic lower-bound (bisect_left-style) search over [0, len - 1].
    # The -1 case above guarantees the answer exists inside this range.
    lo, hi = 0, len(items) - 1
    while lo < hi:
        probe = (lo + hi) // 2
        if items[probe].timestamp < value:
            lo = probe + 1
        else:
            hi = probe
    return lo
|
||||
|
||||
|
||||
def validate_rate_list(rates: List[Rate]) -> bool:
    """Return False if *rates* is empty or incorrectly ordered, else True.

    Each successive rate must have a strictly larger interval, a strictly
    larger limit, and a request density (limit / interval) no greater than
    the rate before it.
    """
    if not rates:
        return False

    for prev_rate, current_rate in zip(rates, rates[1:]):
        # Intervals and limits must both be strictly increasing.
        if current_rate.interval <= prev_rate.interval:
            return False

        if current_rate.limit <= prev_rate.limit:
            return False

        # A longer window must not allow a higher request density.
        if (current_rate.limit / current_rate.interval) > (prev_rate.limit / prev_rate.interval):
            return False

    return True
|
||||
|
||||
|
||||
def id_generator(
    size=6,
    chars=string.ascii_uppercase + string.digits + string.ascii_lowercase,
) -> str:
    """Return a random identifier of *size* characters drawn from *chars*."""
    # random.choices samples with replacement, matching the original
    # per-character random.choice behavior.
    return "".join(random.choices(chars, k=size))
|
||||
|
||||
|
||||
def dedicated_sqlite_clock_connection():
    """Open a connection to the shared clock-only SQLite database in the temp dir.

    The connection uses EXCLUSIVE isolation and allows cross-thread use
    (check_same_thread=False).
    """
    clock_db = Path(gettempdir()) / "pyrate_limiter_clock_only.sqlite"
    return sqlite3.connect(
        clock_db,
        isolation_level="EXCLUSIVE",
        check_same_thread=False,
    )
|
||||
Reference in New Issue
Block a user