Initial commit

This commit is contained in:
kdusek
2025-12-09 12:13:01 +01:00
commit 8e654ed209
13332 changed files with 2695056 additions and 0 deletions

View File

@@ -0,0 +1,7 @@
# flake8: noqa
from .abstracts import *
from .buckets import *
from .clocks import *
from .exceptions import *
from .limiter import *
from .utils import *

View File

@@ -0,0 +1,4 @@
from .bucket import * # noqa
from .clock import * # noqa
from .rate import * # noqa
from .wrappers import * # noqa

View File

@@ -0,0 +1,304 @@
""" Implement this class to create
a workable bucket for Limiter to use
"""
import asyncio
import logging
from abc import ABC
from abc import abstractmethod
from collections import defaultdict
from inspect import isawaitable
from inspect import iscoroutine
from threading import Thread
from typing import Awaitable
from typing import Dict
from typing import List
from typing import Optional
from typing import Type
from typing import Union
from .clock import AbstractClock
from .rate import Rate
from .rate import RateItem
logger = logging.getLogger("pyrate_limiter")
class AbstractBucket(ABC):
    """Base bucket interface.

    Assumption: len(rates) always > 0
    TODO: allow empty rates
    """

    # Rates enforced by this bucket; concrete buckets keep them sorted by interval.
    rates: List[Rate]
    # The rate that caused the most recent `put` to fail, or None.
    failing_rate: Optional[Rate] = None

    @abstractmethod
    def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]:
        """Put an item (typically the current time) in the bucket
        return true if successful, otherwise false
        """

    @abstractmethod
    def leak(
        self,
        current_timestamp: Optional[int] = None,
    ) -> Union[int, Awaitable[int]]:
        """leaking bucket - removing items that are outdated"""

    @abstractmethod
    def flush(self) -> Union[None, Awaitable[None]]:
        """Flush the whole bucket
        - Must remove `failing-rate` after flushing
        """

    @abstractmethod
    def count(self) -> Union[int, Awaitable[int]]:
        """Count number of items in the bucket"""

    @abstractmethod
    def peek(self, index: int) -> Union[Optional[RateItem], Awaitable[Optional[RateItem]]]:
        """Peek at the rate-item at a specific index in latest-to-earliest order
        NOTE: The reason we cannot peek from the start of the queue(earliest-to-latest) is
        we can't really tell how many outdated items are still in the queue
        """

    def waiting(self, item: RateItem) -> Union[int, Awaitable[int]]:
        """Calculate time until the bucket becomes available to consume an item again.

        Returns 0 when no wait is needed, -1 when the item can never fit
        (its weight exceeds the failing rate's limit), otherwise the wait
        amount (same unit as item timestamps). Returns an awaitable when
        the underlying `peek` is asynchronous.
        """
        if self.failing_rate is None:
            return 0
        assert item.weight > 0, "Item's weight must > 0"
        if item.weight > self.failing_rate.limit:
            # Heavier than the whole window allows: can never be accepted.
            return -1
        # The item occupying the slot that must expire before `item` fits.
        bound_item = self.peek(self.failing_rate.limit - item.weight)
        if bound_item is None:
            # NOTE: No waiting, bucket is immediately ready
            return 0

        def _calc_waiting(inner_bound_item: RateItem) -> int:
            assert self.failing_rate is not None  # NOTE: silence mypy
            lower_time_bound = item.timestamp - self.failing_rate.interval
            upper_time_bound = inner_bound_item.timestamp
            return upper_time_bound - lower_time_bound

        async def _calc_waiting_async() -> int:
            nonlocal bound_item
            # Resolve chained awaitables until a concrete value remains.
            while isawaitable(bound_item):
                bound_item = await bound_item
            if bound_item is None:
                # NOTE: No waiting, bucket is immediately ready
                return 0
            assert isinstance(bound_item, RateItem)
            return _calc_waiting(bound_item)

        if isawaitable(bound_item):
            # Async bucket: defer the computation to a coroutine.
            return _calc_waiting_async()
        assert isinstance(bound_item, RateItem)
        return _calc_waiting(bound_item)

    def limiter_lock(self) -> Optional[object]:  # type: ignore
        """An additional lock to be used by Limiter in-front of the thread lock.
        Intended for multiprocessing environments where a thread lock is insufficient.
        """
        return None
class Leaker(Thread):
    """Responsible for scheduling buckets' leaking at the background either
    through a daemon task(for sync buckets) or a task using asyncio.Task
    """

    daemon = True
    name = "PyrateLimiter's Leaker"
    # Buckets whose leak() is synchronous, keyed by id(bucket).
    sync_buckets: Optional[Dict[int, AbstractBucket]] = None
    # Buckets whose leak() returns a coroutine, keyed by id(bucket).
    async_buckets: Optional[Dict[int, AbstractBucket]] = None
    # Clock associated with each registered bucket, keyed by id(bucket).
    clocks: Optional[Dict[int, AbstractClock]] = None
    # Delay between leak passes, in milliseconds.
    leak_interval: int = 10_000
    # Task that leaks async buckets on the caller's event loop.
    aio_leak_task: Optional[asyncio.Task] = None

    def __init__(self, leak_interval: int):
        self.sync_buckets = defaultdict()
        self.async_buckets = defaultdict()
        self.clocks = defaultdict()
        self.leak_interval = leak_interval
        super().__init__()

    def register(self, bucket: AbstractBucket, clock: AbstractClock):
        """Register a new bucket with its associated clock"""
        assert self.sync_buckets is not None
        assert self.clocks is not None
        assert self.async_buckets is not None
        # Probe leak() once to detect whether the bucket is sync or async.
        try_leak = bucket.leak(0)
        bucket_id = id(bucket)
        if iscoroutine(try_leak):
            # Close the probe coroutine so it is never awaited accidentally.
            try_leak.close()
            self.async_buckets[bucket_id] = bucket
        else:
            self.sync_buckets[bucket_id] = bucket
        self.clocks[bucket_id] = clock

    def deregister(self, bucket_id: int) -> bool:
        """Deregister a bucket; return True if it was found and removed."""
        if self.sync_buckets and bucket_id in self.sync_buckets:
            del self.sync_buckets[bucket_id]
            assert self.clocks
            del self.clocks[bucket_id]
            return True
        if self.async_buckets and bucket_id in self.async_buckets:
            del self.async_buckets[bucket_id]
            assert self.clocks
            del self.clocks[bucket_id]
            # Cancel the asyncio task once no async bucket remains.
            if not self.async_buckets and self.aio_leak_task:
                self.aio_leak_task.cancel()
                self.aio_leak_task = None
            return True
        return False

    async def _leak(self, buckets: Dict[int, AbstractBucket]) -> None:
        # Shared leak loop: runs until `buckets` is empty, sleeping
        # `leak_interval` ms between passes over all registered buckets.
        assert self.clocks
        while buckets:
            try:
                for bucket_id, bucket in list(buckets.items()):
                    clock = self.clocks[bucket_id]
                    now = clock.now()
                    while isawaitable(now):
                        now = await now
                    assert isinstance(now, int)
                    leak = bucket.leak(now)
                    while isawaitable(leak):
                        leak = await leak
                    assert isinstance(leak, int)
                await asyncio.sleep(self.leak_interval / 1000)
            except RuntimeError as e:
                # Event loop is shutting down; exit quietly.
                logger.info("Leak task stopped due to event loop shutdown. %s", e)
                return

    def leak_async(self):
        # Schedule the async-bucket leak task (requires a running event loop).
        if self.async_buckets and not self.aio_leak_task:
            self.aio_leak_task = asyncio.create_task(self._leak(self.async_buckets))

    def run(self) -> None:
        """ Override the original method of Thread
        Not meant to be called directly
        """
        # Sync buckets are leaked on this daemon thread's private event loop.
        assert self.sync_buckets
        asyncio.run(self._leak(self.sync_buckets))

    def start(self) -> None:
        """ Override the original method of Thread
        Call to run leaking sync buckets
        """
        if self.sync_buckets and not self.is_alive():
            super().start()
class BucketFactory(ABC):
    """Abstract bucket factory.

    Subclass this to supply custom bucket-routing/creation logic. Buckets
    registered through the factory are automatically scheduled for periodic
    leaking via an internally managed `Leaker`.
    """

    _leaker: Optional[Leaker] = None
    _leak_interval: int = 10_000

    @property
    def leak_interval(self) -> int:
        """Leak interval in ms, read from the inner Leaker when it exists."""
        return self._leaker.leak_interval if self._leaker else self._leak_interval

    @leak_interval.setter
    def leak_interval(self, value: int):
        """Update the leak interval, propagating to the inner Leaker."""
        if self._leaker:
            self._leaker.leak_interval = value
        self._leak_interval = value

    @abstractmethod
    def wrap_item(
        self,
        name: str,
        weight: int = 1,
    ) -> Union[RateItem, Awaitable[RateItem]]:
        """Stamp the item with the current timestamp using any clock backend
        - Turn it into a RateItem
        - Can return either a coroutine or a RateItem instance
        """

    @abstractmethod
    def get(self, item: RateItem) -> Union[AbstractBucket, Awaitable[AbstractBucket]]:
        """Get the corresponding bucket to this item"""

    def create(
        self,
        clock: AbstractClock,
        bucket_class: Type[AbstractBucket],
        *args,
        **kwargs,
    ) -> AbstractBucket:
        """Instantiate a bucket dynamically and schedule its leaking."""
        new_bucket = bucket_class(*args, **kwargs)
        self.schedule_leak(new_bucket, clock)
        return new_bucket

    def schedule_leak(self, new_bucket: AbstractBucket, associated_clock: AbstractClock) -> None:
        """Register the bucket with the (lazily created) Leaker and start it."""
        assert new_bucket.rates, "Bucket rates are not set"
        if self._leaker is None:
            self._leaker = Leaker(self.leak_interval)
        leaker = self._leaker
        leaker.register(new_bucket, associated_clock)
        leaker.start()
        leaker.leak_async()

    def get_buckets(self) -> List[AbstractBucket]:
        """Return every bucket (sync and async) known to the factory."""
        leaker = self._leaker
        if leaker is None:
            return []
        found: List[AbstractBucket] = []
        for registry in (leaker.sync_buckets, leaker.async_buckets):
            if registry:
                found.extend(registry.values())
        return found

    def dispose(self, bucket: Union[int, AbstractBucket]) -> bool:
        """Remove a bucket (given as an instance or its id) from the factory."""
        bucket_id = id(bucket) if isinstance(bucket, AbstractBucket) else bucket
        assert isinstance(bucket_id, int), "not valid bucket id"
        return self._leaker.deregister(bucket_id) if self._leaker else False

View File

@@ -0,0 +1,12 @@
from abc import ABC
from abc import abstractmethod
from typing import Awaitable
from typing import Union
class AbstractClock(ABC):
    """Clock that returns a timestamp for `now`"""

    @abstractmethod
    def now(self) -> Union[int, Awaitable[int]]:
        """Get time as of now, in milliseconds (may be sync or async)"""

View File

@@ -0,0 +1,96 @@
"""Unit classes that deals with rate, item & duration
"""
from enum import Enum
from typing import Union
class Duration(Enum):
    """Interval helper class; each member's value is the interval in milliseconds."""

    SECOND = 1000
    MINUTE = 1000 * 60
    HOUR = 1000 * 60 * 60
    DAY = 1000 * 60 * 60 * 24
    WEEK = 1000 * 60 * 60 * 24 * 7

    def __mul__(self, multiplier: float) -> int:
        """Scale the duration, e.g. ``Duration.SECOND * 2 == 2000``."""
        return int(self.value * multiplier)

    def __rmul__(self, multiplier: float) -> int:
        return self.__mul__(multiplier)

    def __add__(self, another_duration: Union["Duration", int]) -> int:
        return self.value + int(another_duration)

    def __radd__(self, another_duration: Union["Duration", int]) -> int:
        return self.__add__(another_duration)

    def __int__(self) -> int:
        return self.value

    def __eq__(self, duration: object) -> bool:
        if not isinstance(duration, (Duration, int)):
            return NotImplemented
        return self.value == int(duration)

    # BUGFIX: defining __eq__ implicitly sets __hash__ to None, which made
    # Duration members unhashable (unusable as dict keys / in sets).
    # Restore Enum's default hash behavior.
    __hash__ = Enum.__hash__

    @staticmethod
    def readable(value: int) -> str:
        """Render a millisecond value using the largest fitting unit, e.g. '1.5m'.

        Values below one second fall back to a raw '<n>ms' representation.
        """
        notes = [
            (Duration.WEEK, "w"),
            (Duration.DAY, "d"),
            (Duration.HOUR, "h"),
            (Duration.MINUTE, "m"),
            (Duration.SECOND, "s"),
        ]
        for note, shorten in notes:
            if value >= note.value:
                decimal_value = value / note.value
                return f"{decimal_value:0.1f}{shorten}"  # noqa: E231
        return f"{value}ms"
class RateItem:
    """A named, weighted entry carrying the timestamp that buckets operate on."""

    name: str
    weight: int
    timestamp: int

    def __init__(self, name: str, timestamp: int, weight: int = 1):
        self.name = name
        self.weight = weight
        self.timestamp = timestamp

    def __str__(self) -> str:
        return "RateItem(name={0}, weight={1}, timestamp={2})".format(
            self.name, self.weight, self.timestamp
        )
class Rate:
    """Rate definition.

    Args:
        limit: Number of requests allowed within ``interval``
        interval: Time interval, in milliseconds
    """

    limit: int
    interval: int

    def __init__(
        self,
        limit: int,
        interval: Union[int, Duration],
    ):
        self.limit = limit
        # Duration members coerce to their millisecond value via __int__.
        self.interval = int(interval)
        # Both fields must be non-zero to define a meaningful rate.
        assert self.interval
        assert self.limit

    def __str__(self) -> str:
        return "limit={0}/{1}".format(self.limit, Duration.readable(self.interval))

    def __repr__(self) -> str:
        return "limit={0}/{1}".format(self.limit, self.interval)

View File

@@ -0,0 +1,77 @@
""" Wrappers over different abstract types
"""
from inspect import isawaitable
from typing import Optional
from .bucket import AbstractBucket
from .rate import RateItem
class BucketAsyncWrapper(AbstractBucket):
    """Adapter exposing any (sync or async) bucket through an async-only API.

    Every method resolves the underlying result by awaiting until a plain
    value remains, so callers can uniformly ``await`` regardless of the
    wrapped backend.
    """

    def __init__(self, bucket: AbstractBucket):
        assert isinstance(bucket, AbstractBucket)
        self.bucket = bucket

    async def put(self, item: RateItem):
        outcome = self.bucket.put(item)
        while isawaitable(outcome):
            outcome = await outcome
        return outcome

    async def count(self):
        total = self.bucket.count()
        while isawaitable(total):
            total = await total
        return total

    async def leak(self, current_timestamp: Optional[int] = None) -> int:
        leaked = self.bucket.leak(current_timestamp)
        while isawaitable(leaked):
            leaked = await leaked
        assert isinstance(leaked, int)
        return leaked

    async def flush(self) -> None:
        outcome = self.bucket.flush()
        while isawaitable(outcome):
            # TODO: AbstractBucket.flush() may not have correct type annotation?
            outcome = await outcome  # type: ignore
        return None

    async def peek(self, index: int) -> Optional[RateItem]:
        found = self.bucket.peek(index)
        while isawaitable(found):
            found = await found
        assert found is None or isinstance(found, RateItem)
        return found

    async def waiting(self, item: RateItem) -> int:
        # Reuse the base-class computation, resolving it if asynchronous.
        delay = super().waiting(item)
        if isawaitable(delay):
            delay = await delay
        assert isinstance(delay, int)
        return delay

    @property
    def failing_rate(self):
        # Delegate so the wrapped bucket stays the single source of truth.
        return self.bucket.failing_rate

    @property
    def rates(self):
        return self.bucket.rates

View File

@@ -0,0 +1,10 @@
# flake8: noqa
"""Conrete bucket implementations
"""
from .in_memory_bucket import InMemoryBucket
from .mp_bucket import MultiprocessBucket
from .postgres import PostgresBucket
from .postgres import Queries as PgQueries
from .redis_bucket import RedisBucket
from .sqlite_bucket import Queries as SQLiteQueries
from .sqlite_bucket import SQLiteBucket

View File

@@ -0,0 +1,92 @@
"""Naive bucket implementation using built-in list
"""
from typing import List
from typing import Optional
from ..abstracts import AbstractBucket
from ..abstracts import Rate
from ..abstracts import RateItem
from ..utils import binary_search
class InMemoryBucket(AbstractBucket):
    """Simple in-memory bucket backed by a native list.

    Clock can be either `time.time` or `time.monotonic`; a timestamp is
    required when leaking.
    Pros: fast, safe, and precise
    Cons: data lives in local memory only, so it is neither persistent nor scalable
    Usecase: small applications, simple logic
    """

    items: List[RateItem]
    failing_rate: Optional[Rate]

    def __init__(self, rates: List[Rate]):
        # Keep rates ordered by interval so the widest window sits last.
        self.rates = sorted(rates, key=lambda r: r.interval)
        self.items = []

    def put(self, item: RateItem) -> bool:
        if item.weight == 0:
            # Zero-weight items occupy no space and always succeed.
            return True
        projected_size = len(self.items) + item.weight
        for rate in self.rates:
            if projected_size < rate.limit:
                # Fast path: even the whole list plus this item fits under
                # the limit, so no window check is needed from here on
                # (presumes limits grow with interval — TODO confirm).
                break
            window_start = item.timestamp - rate.interval
            start_idx = binary_search(self.items, window_start)
            if start_idx >= 0:
                in_window = len(self.items) - start_idx
                room_left = rate.limit - in_window
            else:
                room_left = rate.limit
            if room_left < item.weight:
                self.failing_rate = rate
                return False
        self.failing_rate = None
        if item.weight > 1:
            self.items.extend([item] * item.weight)
        else:
            self.items.append(item)
        return True

    def leak(self, current_timestamp: Optional[int] = None) -> int:
        assert current_timestamp is not None
        if not self.items:
            return 0
        cutoff = current_timestamp - self.rates[-1].interval
        if cutoff > self.items[-1].timestamp:
            # Everything predates the widest window: drop it all.
            removed = len(self.items)
            del self.items[:]
            return removed
        if cutoff < self.items[0].timestamp:
            return 0
        idx = binary_search(self.items, cutoff)
        del self.items[:idx]
        return idx

    def flush(self) -> None:
        self.failing_rate = None
        del self.items[:]

    def count(self) -> int:
        return len(self.items)

    def peek(self, index: int) -> Optional[RateItem]:
        if not self.items:
            return None
        if abs(index) >= self.count():
            return None
        return self.items[-1 - index]

View File

@@ -0,0 +1,53 @@
"""multiprocessing In-memory Bucket using a multiprocessing.Manager.ListProxy
and a multiprocessing.Lock.
"""
from multiprocessing import Manager
from multiprocessing import RLock
from multiprocessing.managers import ListProxy
from multiprocessing.synchronize import RLock as LockType
from typing import List
from typing import Optional
from ..abstracts import Rate
from ..abstracts import RateItem
from pyrate_limiter.buckets import InMemoryBucket
class MultiprocessBucket(InMemoryBucket):
    """InMemoryBucket variant that can be shared across processes.

    Items live in a ``multiprocessing`` ListProxy and every mutation is
    serialized through a shared multiprocessing RLock.
    """

    items: List[RateItem]  # ListProxy
    mp_lock: LockType

    def __init__(self, rates: List[Rate], items: List[RateItem], mp_lock: LockType):
        if not isinstance(items, ListProxy):
            raise ValueError("items must be a ListProxy")
        self.rates = sorted(rates, key=lambda r: r.interval)
        self.items = items
        self.mp_lock = mp_lock

    def put(self, item: RateItem) -> bool:
        with self.mp_lock:
            return super().put(item)

    def leak(self, current_timestamp: Optional[int] = None) -> int:
        with self.mp_lock:
            return super().leak(current_timestamp)

    def limiter_lock(self):
        # Expose the process-wide lock so Limiter can guard with it as well.
        return self.mp_lock

    @classmethod
    def init(
        cls,
        rates: List[Rate],
    ):
        """Build a bucket backed by a fresh Manager-backed list and RLock so
        it can be shared across multiple processes.
        """
        shared_items: List[RateItem] = Manager().list()  # type: ignore[assignment]
        shared_lock: LockType = RLock()
        return cls(rates=rates, items=shared_items, mp_lock=shared_lock)

View File

@@ -0,0 +1,166 @@
"""A bucket using PostgreSQL as backend
"""
from __future__ import annotations
from contextlib import contextmanager
from typing import Awaitable
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from ..abstracts import AbstractBucket
from ..abstracts import Rate
from ..abstracts import RateItem
if TYPE_CHECKING:
from psycopg_pool import ConnectionPool # type: ignore[import-untyped]
class Queries:
    # SQL templates for PostgresBucket. `{table}`, `{index}`, `{offset}` and
    # `{timestamp}` are filled in with str.format before execution; row values
    # in PUT are bound through %s placeholders. Timestamps are stored as
    # TIMESTAMP and converted from/to epoch seconds/milliseconds at the edges.
    CREATE_BUCKET_TABLE = """
    CREATE TABLE IF NOT EXISTS {table} (
        name VARCHAR,
        weight SMALLINT,
        item_timestamp TIMESTAMP
    )
    """
    CREATE_INDEX_ON_TIMESTAMP = """
    CREATE INDEX IF NOT EXISTS {index} ON {table} (item_timestamp)
    """
    COUNT = """
    SELECT COUNT(*) FROM {table}
    """
    PUT = """
    INSERT INTO {table} (name, weight, item_timestamp) VALUES (%s, %s, TO_TIMESTAMP(%s))
    """
    FLUSH = """
    DELETE FROM {table}
    """
    # Latest-to-earliest lookup; OFFSET selects the n-th newest row.
    PEEK = """
    SELECT name, weight, (extract(EPOCH FROM item_timestamp) * 1000) as item_timestamp
    FROM {table}
    ORDER BY item_timestamp DESC
    LIMIT 1
    OFFSET {offset}
    """
    LEAK = """
    DELETE FROM {table} WHERE item_timestamp < TO_TIMESTAMP({timestamp})
    """
    LEAK_COUNT = """
    SELECT COUNT(*) FROM {table} WHERE item_timestamp < TO_TIMESTAMP({timestamp})
    """
class PostgresBucket(AbstractBucket):
    """Bucket storing its items in a PostgreSQL table via a psycopg pool."""

    table: str
    pool: ConnectionPool

    def __init__(self, pool: ConnectionPool, table: str, rates: List[Rate]):
        self.table = table.lower()
        self.pool = pool
        assert rates
        self.rates = rates
        # Actual table name is prefixed to avoid clashing with user tables.
        self._full_tbl = f'ratelimit___{self.table}'
        self._create_table()

    @contextmanager
    def _get_conn(self):
        # Borrow a pooled connection for the duration of one operation.
        with self.pool.connection() as conn:
            yield conn

    def _create_table(self):
        # Idempotent setup: CREATE TABLE/INDEX IF NOT EXISTS.
        with self._get_conn() as conn:
            conn.execute(Queries.CREATE_BUCKET_TABLE.format(table=self._full_tbl))
            index_name = f'timestampIndex_{self.table}'
            conn.execute(Queries.CREATE_INDEX_ON_TIMESTAMP.format(table=self._full_tbl, index=index_name))

    def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]:
        """Put an item (typically the current time) in the bucket
        return true if successful, otherwise false
        """
        if item.weight == 0:
            return True
        with self._get_conn() as conn:
            for rate in self.rates:
                # Count rows inside this rate's window. Timestamps are stored
                # as TIMESTAMP, so millisecond inputs are divided by 1000.
                bound = f"SELECT TO_TIMESTAMP({item.timestamp / 1000}) - INTERVAL '{rate.interval} milliseconds'"
                query = f'SELECT COUNT(*) FROM {self._full_tbl} WHERE item_timestamp >= ({bound})'
                cur = conn.execute(query)
                count = int(cur.fetchone()[0])
                cur.close()
                if rate.limit - count < item.weight:
                    self.failing_rate = rate
                    return False
            self.failing_rate = None
            query = Queries.PUT_ITEM.format(table=self._full_tbl)
            # https://www.psycopg.org/docs/extras.html#fast-exec
            # NOTE(review): one row per unit of weight is inserted, each row
            # carrying the full weight value — consistent with the COUNT(*)
            # window check above.
            for _ in range(item.weight):
                conn.execute(query, (item.name, item.weight, item.timestamp / 1000))
        return True

    def leak(
        self,
        current_timestamp: Optional[int] = None,
    ) -> Union[int, Awaitable[int]]:
        """leaking bucket - removing items that are outdated"""
        assert current_timestamp is not None, "current-time must be passed on for leak"
        # Anything older than the widest rate window is expired.
        lower_bound = current_timestamp - self.rates[-1].interval
        if lower_bound <= 0:
            return 0
        count = 0
        with self._get_conn() as conn:
            # NOTE(review): `conn` is rebound to the cursor that
            # Connection.execute returns (psycopg 3 behavior) — confirm.
            conn = conn.execute(Queries.LEAK_COUNT.format(table=self._full_tbl, timestamp=lower_bound / 1000))
            result = conn.fetchone()
            if result:
                conn.execute(Queries.LEAK.format(table=self._full_tbl, timestamp=lower_bound / 1000))
                count = int(result[0])
        return count

    def flush(self) -> Union[None, Awaitable[None]]:
        """Flush the whole bucket
        - Must remove `failing-rate` after flushing
        """
        with self._get_conn() as conn:
            conn.execute(Queries.FLUSH.format(table=self._full_tbl))
            self.failing_rate = None
        return None

    def count(self) -> Union[int, Awaitable[int]]:
        """Count number of items in the bucket"""
        count = 0
        with self._get_conn() as conn:
            conn = conn.execute(Queries.COUNT.format(table=self._full_tbl))
            result = conn.fetchone()
            assert result
            count = int(result[0])
        return count

    def peek(self, index: int) -> Union[Optional[RateItem], Awaitable[Optional[RateItem]]]:
        """Peek at the rate-item at a specific index in latest-to-earliest order
        NOTE: The reason we cannot peek from the start of the queue(earliest-to-latest) is
        we can't really tell how many outdated items are still in the queue
        """
        item = None
        with self._get_conn() as conn:
            conn = conn.execute(Queries.PEEK.format(table=self._full_tbl, offset=index))
            result = conn.fetchone()
            if result:
                # PEEK returns the timestamp already converted to milliseconds.
                name, weight, timestamp = result[0], int(result[1]), int(result[2])
                item = RateItem(name=name, weight=weight, timestamp=timestamp)
        return item

View File

@@ -0,0 +1,187 @@
"""Bucket implementation using Redis
"""
from __future__ import annotations
from inspect import isawaitable
from typing import Awaitable
from typing import List
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from ..abstracts import AbstractBucket
from ..abstracts import Rate
from ..abstracts import RateItem
from ..utils import id_generator
if TYPE_CHECKING:
from redis import Redis
from redis.asyncio import Redis as AsyncRedis
class LuaScript:
    """Scripts that deal with bucket operations

    PUT_ITEM arguments: now, space_required, item_name, rates_count, then
    one (interval, limit) pair per rate. For each rate it ZCOUNTs members
    within [now - interval, now]; if any rate lacks room it returns that
    rate's 0-based index. Otherwise it ZADDs `space_required` members
    (item_name suffixed 1..n) at score `now` and returns -1.
    """
    PUT_ITEM = """
    local bucket = KEYS[1]
    local now = ARGV[1]
    local space_required = tonumber(ARGV[2])
    local item_name = ARGV[3]
    local rates_count = tonumber(ARGV[4])
    for i=1,rates_count do
        local offset = (i - 1) * 2
        local interval = tonumber(ARGV[5 + offset])
        local limit = tonumber(ARGV[5 + offset + 1])
        local count = redis.call('ZCOUNT', bucket, now - interval, now)
        local space_available = limit - tonumber(count)
        if space_available < space_required then
            return i - 1
        end
    end
    for i=1,space_required do
        redis.call('ZADD', bucket, now, item_name..i)
    end
    return -1
    """
class RedisBucket(AbstractBucket):
    """A bucket using redis for storing data
    - We are not using redis' built-in TIME since it is non-deterministic
    - In distributed context, use local server time or a remote time server
    - Each bucket instance use a dedicated connection to avoid race-condition
    - can be either sync or async
    """

    rates: List[Rate]
    failing_rate: Optional[Rate]
    # ZSET key holding the bucket's items.
    bucket_key: str
    # SHA of the pre-loaded Lua PUT_ITEM script.
    script_hash: str
    redis: Union[Redis, AsyncRedis]

    def __init__(
        self,
        rates: List[Rate],
        redis: Union[Redis, AsyncRedis],
        bucket_key: str,
        script_hash: str,
    ):
        self.rates = rates
        self.redis = redis
        self.bucket_key = bucket_key
        self.script_hash = script_hash
        self.failing_rate = None

    @classmethod
    def init(
        cls,
        rates: List[Rate],
        redis: Union[Redis, AsyncRedis],
        bucket_key: str,
    ):
        """Load the Lua script then build the bucket.

        Returns the bucket directly for a sync client, or an awaitable
        resolving to the bucket for an async client.
        """
        script_hash = redis.script_load(LuaScript.PUT_ITEM)
        if isawaitable(script_hash):
            async def _async_init():
                nonlocal script_hash
                script_hash = await script_hash
                return cls(rates, redis, bucket_key, script_hash)
            return _async_init()
        return cls(rates, redis, bucket_key, script_hash)

    def _check_and_insert(self, item: RateItem) -> Union[Rate, None, Awaitable[Optional[Rate]]]:
        # Run the atomic check-and-insert Lua script; it returns the index
        # of the first failing rate, or -1 when the item was inserted.
        keys = [self.bucket_key]
        args = [
            item.timestamp,
            item.weight,
            # NOTE: this is to avoid key collision since we are using ZSET
            f"{item.name}:{id_generator()}:",  # noqa: E231
            len(self.rates),
            *[value for rate in self.rates for value in (rate.interval, rate.limit)],
        ]
        idx = self.redis.evalsha(self.script_hash, len(keys), *keys, *args)

        def _handle_sync(returned_idx: int):
            assert isinstance(returned_idx, int), "Not int"
            if returned_idx < 0:
                # Negative index means the insert succeeded.
                return None
            return self.rates[returned_idx]

        async def _handle_async(returned_idx: Awaitable[int]):
            assert isawaitable(returned_idx), "Not corotine"
            awaited_idx = await returned_idx
            return _handle_sync(awaited_idx)

        return _handle_async(idx) if isawaitable(idx) else _handle_sync(idx)

    def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]:
        """Add item to key"""
        failing_rate = self._check_and_insert(item)
        if isawaitable(failing_rate):
            async def _handle_async():
                self.failing_rate = await failing_rate
                return not bool(self.failing_rate)
            return _handle_async()
        assert isinstance(failing_rate, Rate) or failing_rate is None
        self.failing_rate = failing_rate
        return not bool(self.failing_rate)

    def leak(self, current_timestamp: Optional[int] = None) -> Union[int, Awaitable[int]]:
        """Drop every member older than the widest rate window."""
        assert current_timestamp is not None
        return self.redis.zremrangebyscore(
            self.bucket_key,
            0,
            current_timestamp - self.rates[-1].interval,
        )

    def flush(self):
        self.failing_rate = None
        return self.redis.delete(self.bucket_key)

    def count(self):
        return self.redis.zcard(self.bucket_key)

    def peek(self, index: int) -> Union[RateItem, None, Awaitable[Optional[RateItem]]]:
        """Item `index` places back from the newest ZSET member, or None."""
        items = self.redis.zrange(
            self.bucket_key,
            -1 - index,
            -1 - index,
            withscores=True,
            score_cast_func=int,
        )
        if not items:
            return None

        def _handle_items(received_items: List[Tuple[str, int]]):
            if not received_items:
                return None
            item = received_items[0]
            # ZSET score is the timestamp; weight is not stored per member,
            # so the rebuilt RateItem uses the default weight.
            rate_item = RateItem(name=str(item[0]), timestamp=item[1])
            return rate_item

        if isawaitable(items):
            async def _awaiting():
                nonlocal items
                items = await items
                return _handle_items(items)
            return _awaiting()
        assert isinstance(items, list)
        return _handle_items(items)

View File

@@ -0,0 +1,250 @@
"""Bucket implementation using SQLite"""
import logging
import sqlite3
from contextlib import nullcontext
from pathlib import Path
from tempfile import gettempdir
from threading import RLock
from time import time
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from ..abstracts import AbstractBucket
from ..abstracts import Rate
from ..abstracts import RateItem
logger = logging.getLogger(__name__)
class Queries:
    # SQL templates for SQLiteBucket; table/index names and numeric values
    # are interpolated with str.format before execution. Timestamps are
    # stored as raw INTEGER milliseconds.
    CREATE_BUCKET_TABLE = """
    CREATE TABLE IF NOT EXISTS '{table}' (
        name VARCHAR,
        item_timestamp INTEGER
    )
    """
    CREATE_INDEX_ON_TIMESTAMP = """
    CREATE INDEX IF NOT EXISTS '{index_name}' ON '{table_name}' (item_timestamp)
    """
    # One SELECT per rate; UNION-ed together by _build_full_count_query.
    COUNT_BEFORE_INSERT = """
    SELECT :interval{index} as interval, COUNT(*) FROM '{table}'
    WHERE item_timestamp >= :current_timestamp - :interval{index}
    """
    # `%s` is substituted with the VALUES payload at execution time.
    PUT_ITEM = """
    INSERT INTO '{table}' (name, item_timestamp) VALUES %s
    """
    # Delete the `count` oldest rows.
    LEAK = """
    DELETE FROM "{table}" WHERE rowid IN (
    SELECT rowid FROM "{table}" ORDER BY item_timestamp ASC LIMIT {count});
    """.strip()
    COUNT_BEFORE_LEAK = """SELECT COUNT(*) FROM '{table}' WHERE item_timestamp < {current_timestamp} - {interval}"""
    FLUSH = """DELETE FROM '{table}'"""
    # The below sqls are for testing only
    DROP_TABLE = "DROP TABLE IF EXISTS '{table}'"
    DROP_INDEX = "DROP INDEX IF EXISTS '{index}'"
    COUNT_ALL = "SELECT COUNT(*) FROM '{table}'"
    GET_ALL_ITEM = "SELECT * FROM '{table}' ORDER BY item_timestamp ASC"
    GET_FIRST_ITEM = (
        "SELECT name, item_timestamp FROM '{table}' ORDER BY item_timestamp ASC"
    )
    # Milliseconds elapsed since the oldest stored item, per SQLite's clock.
    GET_LAG = """
    SELECT (strftime ('%s', 'now') || substr(strftime ('%f', 'now'), 4)) - (
        SELECT item_timestamp
        FROM '{table}'
        ORDER BY item_timestamp
        ASC
        LIMIT 1
    )
    """
    PEEK = 'SELECT * FROM "{table}" ORDER BY item_timestamp DESC LIMIT 1 OFFSET {count}'
class SQLiteBucket(AbstractBucket):
    """Bucket persisted in SQLite.

    For the sqlite bucket we are using the sql time function as the clock,
    so the item's own timestamp is stored as-is. Every operation runs under
    `self.lock`, which is either a private thread RLock or an externally
    supplied lock (e.g. a FileLock for multi-process usage).
    """

    rates: List[Rate]
    failing_rate: Optional[Rate]
    conn: sqlite3.Connection
    table: str
    full_count_query: str
    lock: RLock
    use_limiter_lock: bool

    def __init__(
        self, rates: List[Rate], conn: sqlite3.Connection, table: str, lock=None
    ):
        """Wrap an existing connection/table. When `lock` is given it is
        also exposed through `limiter_lock()` so Limiter can share it.
        """
        self.conn = conn
        self.table = table
        self.rates = rates
        if not lock:
            self.use_limiter_lock = False
            self.lock = RLock()
        else:
            self.use_limiter_lock = True
            self.lock = lock

    def limiter_lock(self):
        # Only expose the lock when it was supplied externally (shared lock).
        if self.use_limiter_lock:
            return self.lock
        else:
            return None

    def _build_full_count_query(self, current_timestamp: int) -> Tuple[str, dict]:
        """Build one UNION-ed query counting in-window rows for every rate."""
        full_query: List[str] = []
        parameters = {"current_timestamp": current_timestamp}
        for index, rate in enumerate(self.rates):
            parameters[f"interval{index}"] = rate.interval
            query = Queries.COUNT_BEFORE_INSERT.format(table=self.table, index=index)
            full_query.append(query)
        join_full_query = (
            " union ".join(full_query) if len(full_query) > 1 else full_query[0]
        )
        return join_full_query, parameters

    def put(self, item: RateItem) -> bool:
        """Insert `item.weight` rows if every rate window has room;
        otherwise record the failing rate and return False.
        """
        with self.lock:
            query, parameters = self._build_full_count_query(item.timestamp)
            cur = self.conn.execute(query, parameters)
            rate_limit_counts = cur.fetchall()
            cur.close()
            for idx, result in enumerate(rate_limit_counts):
                interval, count = result
                rate = self.rates[idx]
                assert interval == rate.interval
                space_available = rate.limit - count
                if space_available < item.weight:
                    self.failing_rate = rate
                    return False
            # Consistent with the other buckets: a successful put clears any
            # previously recorded failing rate.
            self.failing_rate = None
            # BUGFIX: the previous version interpolated item.name directly
            # into the SQL text, which was injectable and broke for names
            # containing a single quote. Use parameter binding instead.
            query = Queries.PUT_ITEM.format(table=self.table) % "(?, ?)"
            self.conn.executemany(
                query, [(item.name, item.timestamp)] * item.weight
            ).close()
            self.conn.commit()
            return True

    def leak(self, current_timestamp: Optional[int] = None) -> int:
        """Leaking/clean up bucket: remove rows older than the widest rate
        window and return how many were removed."""
        with self.lock:
            assert current_timestamp is not None
            query = Queries.COUNT_BEFORE_LEAK.format(
                table=self.table,
                interval=self.rates[-1].interval,
                current_timestamp=current_timestamp,
            )
            cur = self.conn.execute(query)
            count = cur.fetchone()[0]
            # Delete exactly the `count` oldest rows (those counted above).
            query = Queries.LEAK.format(table=self.table, count=count)
            cur.execute(query)
            cur.close()
            self.conn.commit()
            return count

    def flush(self) -> None:
        """Delete all rows and reset the failing rate."""
        with self.lock:
            self.conn.execute(Queries.FLUSH.format(table=self.table)).close()
            self.conn.commit()
            self.failing_rate = None

    def count(self) -> int:
        """Number of rows currently stored."""
        with self.lock:
            cur = self.conn.execute(
                Queries.COUNT_ALL.format(table=self.table)
            )
            ret = cur.fetchone()[0]
            cur.close()
            return ret

    def peek(self, index: int) -> Optional[RateItem]:
        """Row at `index` counting back from the newest, or None."""
        with self.lock:
            query = Queries.PEEK.format(table=self.table, count=index)
            cur = self.conn.execute(query)
            item = cur.fetchone()
            cur.close()
            if not item:
                return None
            return RateItem(item[0], item[1])

    @classmethod
    def init_from_file(
        cls,
        rates: List[Rate],
        table: str = "rate_bucket",
        db_path: Optional[str] = None,
        create_new_table: bool = True,
        use_file_lock: bool = False
    ) -> "SQLiteBucket":
        """Create (or open) a bucket backed by an SQLite file.

        Args:
            rates: rate definitions for the bucket.
            table: table name to use/create.
            db_path: path to a `.sqlite` file; a temp file when omitted.
            create_new_table: create the table when missing.
            use_file_lock: guard the DB with a `filelock.FileLock`
                (for multi-process usage; requires an explicit `db_path`).

        Raises:
            ValueError: when `use_file_lock` is set without `db_path`.
            ImportError: when `filelock` is missing but requested.
        """
        if db_path is None and use_file_lock:
            raise ValueError("db_path must be specified when using use_file_lock")
        if db_path is None:
            temp_dir = Path(gettempdir())
            db_path = str(temp_dir / f"pyrate_limiter_{time()}.sqlite")
        # TBD: FileLock switched to a thread-local FileLock in 3.11.0.
        # Should we set FileLock's thread_local to False, for cases where user is both multiprocessing & threading?
        # As is, the file lock should be Multi Process - Single Thread and non-filelock is Single Process - Multi Thread
        # A hybrid lock may be needed to gracefully handle both cases
        file_lock = None
        file_lock_ctx = nullcontext()
        if use_file_lock:
            try:
                from filelock import FileLock  # type: ignore[import-untyped]
                file_lock = FileLock(db_path + ".lock")  # type: ignore[no-redef]
                file_lock_ctx: Union[nullcontext, FileLock] = file_lock  # type: ignore[no-redef]
            except ImportError:
                raise ImportError(
                    "filelock is required for file locking. "
                    "Please install it as optional dependency"
                )
        with file_lock_ctx:
            assert db_path is not None
            assert db_path.endswith(".sqlite"), (
                "Please provide a valid sqlite file path"
            )
            sqlite_connection = sqlite3.connect(
                db_path,
                isolation_level="DEFERRED",
                check_same_thread=False,
            )
            cur = sqlite_connection.cursor()
            if use_file_lock:
                # WAL + NORMAL synchronous for concurrent multi-process access.
                # https://www.sqlite.org/wal.html
                cur.execute("PRAGMA journal_mode=WAL;")
                # https://www.sqlite.org/pragma.html#pragma_synchronous
                cur.execute("PRAGMA synchronous=NORMAL;")
            if create_new_table:
                cur.execute(
                    Queries.CREATE_BUCKET_TABLE.format(table=table)
                )
            create_idx_query = Queries.CREATE_INDEX_ON_TIMESTAMP.format(
                index_name=f"idx_{table}_rate_item_timestamp",
                table_name=table,
            )
            cur.execute(create_idx_query)
            cur.close()
            sqlite_connection.commit()
        return cls(rates, sqlite_connection, table=table, lock=file_lock)

View File

@@ -0,0 +1,90 @@
"""Clock implementation using different backend"""
from __future__ import annotations
import sqlite3
from contextlib import nullcontext
from time import monotonic
from time import time
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from .abstracts import AbstractClock
from .buckets import SQLiteBucket
from .utils import dedicated_sqlite_clock_connection
if TYPE_CHECKING:
from psycopg_pool import ConnectionPool
from threading import RLock
class MonotonicClock(AbstractClock):
    """Clock based on ``time.monotonic``, reported in whole milliseconds."""

    def __init__(self):
        # Touch the monotonic clock once so it is warmed up before first use.
        monotonic()

    def now(self):
        """Return the current monotonic time in milliseconds."""
        return int(monotonic() * 1000)
class TimeClock(AbstractClock):
    """Wall-clock (epoch-based) time source, reported in whole milliseconds."""

    def now(self):
        """Return the current epoch time in milliseconds."""
        return int(time() * 1000)
class TimeAsyncClock(AbstractClock):
    """Asynchronous wall clock, intended for testing purposes only."""

    async def now(self) -> int:
        """Return the current epoch time in milliseconds."""
        return int(time() * 1000)
class SQLiteClock(AbstractClock):
    """Clock that reads the current time from a SQLite backend."""

    # Converts SQLite's julian-day "now" into unix epoch milliseconds.
    time_query = (
        "SELECT CAST(ROUND((julianday('now') - 2440587.5)*86400000) As INTEGER)"
    )

    def __init__(self, conn: Union[sqlite3.Connection, SQLiteBucket]):
        """Accept either a raw connection or a SQLiteBucket.

        In multiprocessing cases pass the bucket, so that its shared lock
        guards clock reads as well.
        """
        self.lock: Optional[RLock] = None
        if isinstance(conn, SQLiteBucket):
            self.conn = conn.conn
            self.lock = conn.lock
        else:
            self.conn = conn

    @classmethod
    def default(cls):
        """Build a clock backed by a dedicated temp-file SQLite connection."""
        return cls(dedicated_sqlite_clock_connection())

    def now(self) -> int:
        """Query the SQLite backend for the current epoch time in milliseconds."""
        guard = self.lock if self.lock else nullcontext()
        with guard:
            cursor = self.conn.execute(self.time_query)
            timestamp = cursor.fetchone()[0]
            cursor.close()
        return int(timestamp)
class PostgresClock(AbstractClock):
    """Clock that reads the current time from a Postgres server."""

    def __init__(self, pool: "ConnectionPool"):
        self.pool = pool

    def now(self) -> int:
        """Fetch `current_timestamp` from Postgres as epoch milliseconds."""
        with self.pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute("SELECT EXTRACT(epoch FROM current_timestamp) * 1000")
                row = cur.fetchone()
                assert row, "unable to get current-timestamp from postgres"
                return int(row[0])

View File

@@ -0,0 +1,47 @@
# pylint: disable=C0114,C0115
from typing import Dict
from typing import Union
from .abstracts.rate import Rate
from .abstracts.rate import RateItem
class BucketFullException(Exception):
    """Raised when an item does not fit into a bucket under a given rate."""

    def __init__(self, item: RateItem, rate: Rate):
        message = f"Bucket for item={item.name} with Rate {rate} is already full"
        super().__init__(message)
        self.item = item
        self.rate = rate
        self.meta_info: Dict[str, Union[str, float]] = {
            "error": message,
            "name": item.name,
            "weight": item.weight,
            "rate": str(rate),
        }

    def __reduce__(self):
        # Support pickling (e.g. across process boundaries) by re-invoking
        # the constructor with the stored arguments.
        return (self.__class__, (self.item, self.rate))
class LimiterDelayException(Exception):
    """Raised when the wait needed to admit an item exceeds the allowed maximum."""

    def __init__(self, item: RateItem, rate: Rate, actual_delay: int, max_delay: int):
        self.item = item
        self.rate = rate
        self.actual_delay = actual_delay
        self.max_delay = max_delay
        error = f"""
        Actual delay exceeded allowance: actual={actual_delay}, allowed={max_delay}
        Bucket for {item.name} with Rate {rate} is already full
        """
        super().__init__(error)
        self.meta_info: Dict[str, Union[str, float]] = {
            "error": error,
            "name": item.name,
            "weight": item.weight,
            "rate": str(rate),
            "max_delay": max_delay,
            "actual_delay": actual_delay,
        }

    def __reduce__(self):
        # Support pickling by re-invoking the constructor with stored args.
        return (self.__class__, (self.item, self.rate, self.actual_delay, self.max_delay))

View File

@@ -0,0 +1,482 @@
"""Limiter class implementation
"""
import asyncio
import logging
from contextlib import contextmanager
from functools import wraps
from inspect import isawaitable
from threading import local
from threading import RLock
from time import sleep
from typing import Any
from typing import Awaitable
from typing import Callable
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from .abstracts import AbstractBucket
from .abstracts import AbstractClock
from .abstracts import BucketFactory
from .abstracts import Duration
from .abstracts import Rate
from .abstracts import RateItem
from .buckets import InMemoryBucket
from .clocks import TimeClock
from .exceptions import BucketFullException
from .exceptions import LimiterDelayException
logger = logging.getLogger("pyrate_limiter")

# Maps the decorated function's call arguments to an (item_name, weight) pair.
ItemMapping = Callable[[Any], Tuple[str, int]]
# Shape of the decorator produced by `Limiter.as_decorator`.
DecoratorWrapper = Callable[[Callable[[Any], Any]], Callable[[Any], Any]]
class SingleBucketFactory(BucketFactory):
    """Factory that routes every item to one shared bucket, for quick Limiter use."""

    bucket: AbstractBucket
    clock: AbstractClock

    def __init__(self, bucket: AbstractBucket, clock: AbstractClock):
        self.clock = clock
        self.bucket = bucket
        self.schedule_leak(bucket, clock)

    def wrap_item(self, name: str, weight: int = 1):
        """Stamp an item with the current time; returns a coroutine when the clock is async."""
        timestamp = self.clock.now()
        if isawaitable(timestamp):
            async def _wrap_async():
                return RateItem(name, await timestamp, weight=weight)
            return _wrap_async()
        return RateItem(name, timestamp, weight=weight)

    def get(self, _: RateItem) -> AbstractBucket:
        return self.bucket
@contextmanager
def combined_lock(locks: Iterable, timeout_sec: Optional[float] = None):
    """Acquire several locks together and release them all on exit.

    Intended for multiprocessing, where a cross-process lock is combined
    with in-process thread RLocks. Locks are released in reverse
    acquisition order; if acquiring one of them times out, only the locks
    already held are released.
    """
    held = []
    try:
        for current in locks:
            if timeout_sec is None:
                current.acquire()
            elif not current.acquire(timeout=timeout_sec):
                raise TimeoutError("Timeout while acquiring combined lock.")
            held.append(current)
        yield
    finally:
        while held:
            held.pop().release()
class Limiter:
    """Facade tying together clock, bucket-factory and delay handling,
    usable with both sync and async buckets/clocks.
    """
    # Resolves each incoming item to the bucket that should hold it
    bucket_factory: BucketFactory
    # When True, failures raise; when False, acquire calls return False
    raise_when_fail: bool
    # When True, keep re-trying the put until max_delay is exhausted
    retry_until_max_delay: bool
    # Upper bound (ms) on how long an acquire may wait; None = no waiting
    max_delay: Optional[int] = None
    # Either a single RLock, or an iterable of locks acquired together
    # (cross-process lock + in-process RLock) -- see __init__/try_acquire
    lock: Union[RLock, Iterable]
    # Extra milliseconds added to each computed delay to absorb clock drift
    buffer_ms: int
    # async_lock is thread local, created on first use
    _thread_local: local
    def __init__(
        self,
        argument: Union[BucketFactory, AbstractBucket, Rate, List[Rate]],
        clock: AbstractClock = TimeClock(),
        raise_when_fail: bool = True,
        max_delay: Optional[Union[int, Duration]] = None,
        retry_until_max_delay: bool = False,
        buffer_ms: int = 50
    ):
        """Init Limiter using either a single bucket / multiple-bucket factory
        / single rate / rate list.
        Parameters:
            argument (Union[BucketFactory, AbstractBucket, Rate, List[Rate]]): The bucket or rate configuration.
            clock (AbstractClock, optional): The clock instance to use for rate limiting. Defaults to TimeClock().
            raise_when_fail (bool, optional): Whether to raise an exception when rate limiting fails. Defaults to True.
            max_delay (Optional[Union[int, Duration]], optional): The maximum delay allowed for rate limiting.
                Defaults to None.
            retry_until_max_delay (bool, optional): If True, retry operations until the maximum delay is reached.
                Useful for ensuring operations eventually succeed within the allowed delay window. Defaults to False.
            buffer_ms (int, optional): Extra milliseconds added to every computed delay
                to compensate for clock drift. Defaults to 50.
        """
        self.bucket_factory = self._init_bucket_factory(argument, clock=clock)
        self.raise_when_fail = raise_when_fail
        self.retry_until_max_delay = retry_until_max_delay
        self.buffer_ms = buffer_ms
        if max_delay is not None:
            if isinstance(max_delay, Duration):
                max_delay = int(max_delay)
            assert max_delay >= 0, "Max-delay must not be negative"
        self.max_delay = max_delay
        self.lock = RLock()
        self._thread_local = local()
        if isinstance(argument, AbstractBucket):
            # If the bucket brings its own (e.g. cross-process) lock, pair it
            # with our in-process RLock; try_acquire then uses combined_lock.
            limiter_lock = argument.limiter_lock()
            if limiter_lock is not None:
                self.lock = (limiter_lock, self.lock)
    def buckets(self) -> List[AbstractBucket]:
        """Get list of active buckets
        """
        return self.bucket_factory.get_buckets()
    def dispose(self, bucket: Union[int, AbstractBucket]) -> bool:
        """Dispose/Remove a specific bucket,
        using bucket-id or bucket object as param
        """
        return self.bucket_factory.dispose(bucket)
    def _init_bucket_factory(
        self,
        argument: Union[BucketFactory, AbstractBucket, Rate, List[Rate]],
        clock: AbstractClock,
    ) -> BucketFactory:
        """Normalize the constructor argument into a BucketFactory.

        Rate -> [Rate] -> InMemoryBucket -> SingleBucketFactory; a plain
        bucket is wrapped into SingleBucketFactory; a factory passes through.
        """
        if isinstance(argument, Rate):
            argument = [argument]
        if isinstance(argument, list):
            assert len(argument) > 0, "Rates must not be empty"
            assert isinstance(argument[0], Rate), "Not valid rates list"
            rates = argument
            logger.info("Initializing default bucket(InMemoryBucket) with rates: %s", rates)
            argument = InMemoryBucket(rates)
        if isinstance(argument, AbstractBucket):
            argument = SingleBucketFactory(argument, clock)
        assert isinstance(argument, BucketFactory), "Not a valid bucket/bucket-factory"
        return argument
    def _raise_bucket_full_if_necessary(
        self,
        bucket: AbstractBucket,
        item: RateItem,
    ):
        """Raise BucketFullException for `item`, unless raise_when_fail is off."""
        if self.raise_when_fail:
            assert bucket.failing_rate is not None  # NOTE: silence mypy
            raise BucketFullException(item, bucket.failing_rate)
    def _raise_delay_exception_if_necessary(
        self,
        bucket: AbstractBucket,
        item: RateItem,
        delay: int,
    ):
        """Raise LimiterDelayException for `item`, unless raise_when_fail is off."""
        if self.raise_when_fail:
            assert bucket.failing_rate is not None  # NOTE: silence mypy
            assert isinstance(self.max_delay, int)
            raise LimiterDelayException(
                item,
                bucket.failing_rate,
                delay,
                self.max_delay,
            )
    def delay_or_raise(
        self,
        bucket: AbstractBucket,
        item: RateItem,
    ) -> Union[bool, Awaitable[bool]]:
        """On `try_acquire` failed, handle delay or raise error immediately"""
        assert bucket.failing_rate is not None
        if self.max_delay is None:
            # No waiting allowed: fail (or raise) straight away.
            self._raise_bucket_full_if_necessary(bucket, item)
            return False
        delay = bucket.waiting(item)
        def _handle_reacquire(re_acquire: bool) -> bool:
            # Shared failure reporting for the single-attempt (non-retry) path.
            if not re_acquire:
                logger.error("""Failed to re-acquire after the expected delay. If it failed,
                either clock or bucket is unstable.
                If asyncio, use try_acquire_async(). If multiprocessing,
                use retry_until_max_delay=True.""")
                self._raise_bucket_full_if_necessary(bucket, item)
            return re_acquire
        if isawaitable(delay):
            # Async bucket: wait/retry inside a coroutine the caller awaits.
            async def _handle_async():
                nonlocal delay
                delay = await delay
                assert isinstance(delay, int), "Delay not integer"
                total_delay = 0
                # buffer_ms pads the wait to absorb clock drift
                delay += self.buffer_ms
                while True:
                    total_delay += delay
                    if self.retry_until_max_delay:
                        # Budget check is on the accumulated wait across retries
                        if self.max_delay is not None and total_delay > self.max_delay:
                            logger.error("Total delay exceeded max_delay: total_delay=%s, max_delay=%s",
                                         total_delay, self.max_delay)
                            self._raise_delay_exception_if_necessary(bucket, item, total_delay)
                            return False
                    else:
                        # Single attempt: only the immediate delay is bounded
                        if self.max_delay is not None and delay > self.max_delay:
                            logger.error(
                                "Required delay too large: actual=%s, expected=%s",
                                delay,
                                self.max_delay,
                            )
                            self._raise_delay_exception_if_necessary(bucket, item, delay)
                            return False
                    await asyncio.sleep(delay / 1000)
                    # Re-stamp the item as if it arrived after the wait
                    item.timestamp += delay
                    re_acquire = bucket.put(item)
                    if isawaitable(re_acquire):
                        re_acquire = await re_acquire
                    if not self.retry_until_max_delay:
                        return _handle_reacquire(re_acquire)
                    elif re_acquire:
                        return True
            return _handle_async()
        assert isinstance(delay, int)
        if delay < 0:
            # Negative delay signals the item can never fit (e.g. weight > limit)
            logger.error(
                "Cannot fit item into bucket: item=%s, rate=%s, bucket=%s",
                item,
                bucket.failing_rate,
                bucket,
            )
            self._raise_bucket_full_if_necessary(bucket, item)
            return False
        total_delay = 0
        while True:
            logger.debug("delay=%d, total_delay=%s", delay, total_delay)
            # Recompute the wait each round: the bucket state may have changed
            delay = bucket.waiting(item)
            assert isinstance(delay, int)
            delay += self.buffer_ms
            total_delay += delay
            if self.max_delay is not None and total_delay > self.max_delay:
                logger.error(
                    "Required delay too large: actual=%s, expected=%s",
                    delay,
                    self.max_delay,
                )
                if self.retry_until_max_delay:
                    self._raise_delay_exception_if_necessary(bucket, item, total_delay)
                else:
                    self._raise_delay_exception_if_necessary(bucket, item, delay)
                return False
            sleep(delay / 1000)
            item.timestamp += delay
            re_acquire = bucket.put(item)
            # NOTE: if delay is not Awaitable, then `bucket.put` is not Awaitable
            assert isinstance(re_acquire, bool)
            if not self.retry_until_max_delay:
                return _handle_reacquire(re_acquire)
            elif re_acquire:
                return True
    def handle_bucket_put(
        self,
        bucket: AbstractBucket,
        item: RateItem,
    ) -> Union[bool, Awaitable[bool]]:
        """Putting item into bucket"""
        def _handle_result(is_success: bool):
            # On failure fall through to the delay/raise machinery
            if not is_success:
                return self.delay_or_raise(bucket, item)
            return True
        acquire = bucket.put(item)
        if isawaitable(acquire):
            async def _put_async():
                nonlocal acquire
                acquire = await acquire
                result = _handle_result(acquire)
                # delay_or_raise may itself return an awaitable; drain it
                while isawaitable(result):
                    result = await result
                return result
            return _put_async()
        return _handle_result(acquire)  # type: ignore
    def _get_async_lock(self):
        """Must be called before first try_acquire_async for each thread"""
        # Lazily create one asyncio.Lock per thread (asyncio.Lock is not
        # safe to share across event loops / threads).
        try:
            return self._thread_local.async_lock
        except AttributeError:
            lock = asyncio.Lock()
            self._thread_local.async_lock = lock
            return lock
    async def try_acquire_async(self, name: str, weight: int = 1) -> bool:
        """
        async version of try_acquire.
        This uses a top level, thread-local async lock to ensure that the async loop doesn't block
        This does not make the entire code async: use an async bucket for that.
        """
        async with self._get_async_lock():
            acquired = self.try_acquire(name=name, weight=weight)
            if isawaitable(acquired):
                return await acquired
            else:
                logger.warning("async call made without an async bucket.")
                return acquired
    def try_acquire(self, name: str, weight: int = 1) -> Union[bool, Awaitable[bool]]:
        """Try acquiring an item with name & weight
        Return true on success, false on failure
        """
        # self.lock is either a single RLock or a (cross-process, in-process)
        # pair that must be acquired together via combined_lock
        with self.lock if not isinstance(self.lock, Iterable) else combined_lock(self.lock):
            assert weight >= 0, "item's weight must be >= 0"
            if weight == 0:
                # NOTE: if item is weightless, just let it go through
                # NOTE: this might change in the future
                return True
            item = self.bucket_factory.wrap_item(name, weight)
            if isawaitable(item):
                # Async clock: resolve item, bucket and put inside one coroutine
                async def _handle_async():
                    nonlocal item
                    item = await item
                    bucket = self.bucket_factory.get(item)
                    if isawaitable(bucket):
                        bucket = await bucket
                    assert isinstance(bucket, AbstractBucket), f"Invalid bucket: item: {name}"
                    result = self.handle_bucket_put(bucket, item)
                    while isawaitable(result):
                        result = await result
                    return result
                return _handle_async()
            assert isinstance(item, RateItem)  # NOTE: this is to silence mypy warning
            bucket = self.bucket_factory.get(item)
            if isawaitable(bucket):
                # Sync clock but async factory lookup
                async def _handle_async_bucket():
                    nonlocal bucket
                    bucket = await bucket
                    assert isinstance(bucket, AbstractBucket), f"Invalid bucket: item: {name}"
                    result = self.handle_bucket_put(bucket, item)
                    while isawaitable(result):
                        result = await result
                    return result
                return _handle_async_bucket()
            assert isinstance(bucket, AbstractBucket), f"Invalid bucket: item: {name}"
            result = self.handle_bucket_put(bucket, item)
            if isawaitable(result):
                # Sync clock & factory, but the bucket put is async
                async def _handle_async_result():
                    nonlocal result
                    while isawaitable(result):
                        result = await result
                    return result
                return _handle_async_result()
            return result
    def as_decorator(self) -> Callable[[ItemMapping], DecoratorWrapper]:
        """Use limiter decorator
        Use with both sync & async function
        """
        def with_mapping_func(mapping: ItemMapping) -> DecoratorWrapper:
            def decorator_wrapper(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
                """Actual function wrapper"""
                if asyncio.iscoroutinefunction(func):
                    @wraps(func)
                    async def wrapper_async(*args, **kwargs):
                        (name, weight) = mapping(*args, **kwargs)
                        assert isinstance(name, str), "Mapping name is expected but not found"
                        assert isinstance(weight, int), "Mapping weight is expected but not found"
                        accquire_ok = self.try_acquire_async(name, weight)
                        if isawaitable(accquire_ok):
                            await accquire_ok
                        return await func(*args, **kwargs)
                    return wrapper_async
                else:
                    @wraps(func)
                    def wrapper(*args, **kwargs):
                        (name, weight) = mapping(*args, **kwargs)
                        assert isinstance(name, str), "Mapping name is expected but not found"
                        assert isinstance(weight, int), "Mapping weight is expected but not found"
                        accquire_ok = self.try_acquire(name, weight)
                        if not isawaitable(accquire_ok):
                            return func(*args, **kwargs)
                        # Sync function over an async bucket: caller gets a
                        # coroutine that awaits the acquire, then calls func
                        async def _handle_accquire_async():
                            nonlocal accquire_ok
                            accquire_ok = await accquire_ok
                            result = func(*args, **kwargs)
                            if isawaitable(result):
                                return await result
                            return result
                        return _handle_accquire_async()
                    return wrapper
            return decorator_wrapper
        return with_mapping_func

View File

@@ -0,0 +1,156 @@
"""
A collection of common use cases and patterns for pyrate_limiter
"""
import logging
from typing import List
from typing import Optional
from typing import Union
from pyrate_limiter import AbstractBucket
from pyrate_limiter import BucketAsyncWrapper
from pyrate_limiter import Duration
from pyrate_limiter import InMemoryBucket
from pyrate_limiter import Limiter
from pyrate_limiter import Rate
from pyrate_limiter import SQLiteBucket
logger = logging.getLogger(__name__)

# Global limiter for convenience in multiprocessing, populated by
# init_global_limiter, which is intended to be called by a
# ProcessPoolExecutor's initializer in each worker process.
LIMITER: Optional[Limiter] = None
def create_sqlite_bucket(
    rates: List[Rate],
    db_path: Optional[str],
    table_name: str = "pyrate_limiter",
    use_file_lock: bool = False,
):
    """
    Create and initialize a SQLite bucket for rate limiting.
    Args:
        rates: List of rate limit configurations.
        db_path: Path to the SQLite database file. When None, the path is left
            to SQLiteBucket.init_from_file's default (a temp-dir file).
        table_name: Name of the table to store rate bucket data.
        use_file_lock: Enable file locking for multi-process synchronization.
    Returns:
        SQLiteBucket: Initialized SQLite-backed bucket.
    """
    logger.info(f"{table_name=}")
    # BUGFIX: str(None) would yield the literal path "None"; keep None so
    # SQLiteBucket.init_from_file can fall back to its own default db path.
    bucket = SQLiteBucket.init_from_file(
        rates,
        db_path=str(db_path) if db_path is not None else None,
        table=table_name,
        create_new_table=True,
        use_file_lock=use_file_lock,
    )
    return bucket
def create_sqlite_limiter(
    rate_per_duration: int = 3,
    duration: Union[int, Duration] = Duration.SECOND,
    db_path: Optional[str] = None,
    table_name: str = "rate_bucket",
    max_delay: Union[int, Duration] = Duration.DAY,
    buffer_ms: int = 50,
    use_file_lock: bool = False,
    async_wrapper: bool = False,
) -> Limiter:
    """
    Create a SQLite-backed rate limiter with configurable rate, persistence, and optional async support.
    Args:
        rate_per_duration: Number of allowed requests per duration.
        duration: Time window for the rate limit.
        db_path: Path to the SQLite database file. When None, the path is left
            to SQLiteBucket.init_from_file's default (a temp-dir file).
        table_name: Name of the table used for rate buckets.
        max_delay: Maximum delay before failing requests.
        buffer_ms: Extra wait time in milliseconds to account for clock drift.
        use_file_lock: Enable file locking for multi-process synchronization.
        async_wrapper: Whether to wrap the bucket for async usage.
    Returns:
        Limiter: Configured SQLite-backed limiter instance.
    """
    rate_limits = [Rate(rate_per_duration, duration)]
    # BUGFIX: str(None) would yield the literal path "None"; keep None so
    # SQLiteBucket.init_from_file can fall back to its own default db path.
    bucket: AbstractBucket = SQLiteBucket.init_from_file(
        rate_limits,
        db_path=str(db_path) if db_path is not None else None,
        table=table_name,
        create_new_table=True,
        use_file_lock=use_file_lock,
    )
    if async_wrapper:
        bucket = BucketAsyncWrapper(bucket)
    # Fail softly (return False) instead of raising, retrying up to max_delay.
    limiter = Limiter(
        bucket, raise_when_fail=False, max_delay=max_delay, retry_until_max_delay=True, buffer_ms=buffer_ms
    )
    return limiter
def create_inmemory_limiter(
    rate_per_duration: int = 3,
    duration: Union[int, Duration] = Duration.SECOND,
    max_delay: Union[int, Duration] = Duration.DAY,
    buffer_ms: int = 50,
    async_wrapper: bool = False,
) -> Limiter:
    """
    Create an in-memory rate limiter with configurable rate, duration, delay, and optional async support.
    Args:
        rate_per_duration: Number of allowed requests per duration.
        duration: Time window for the rate limit.
        max_delay: Maximum delay before failing requests.
        buffer_ms: Extra wait time in milliseconds to account for clock drift.
        async_wrapper: Whether to wrap the bucket for async usage.
    Returns:
        Limiter: Configured in-memory limiter instance.
    """
    rate_limits = [Rate(rate_per_duration, duration)]
    bucket: AbstractBucket = InMemoryBucket(rate_limits)
    if async_wrapper:
        # Wrap the bucket created above; the original allocated a second,
        # discarded InMemoryBucket here.
        bucket = BucketAsyncWrapper(bucket)
    # Fail softly (return False) instead of raising, retrying up to max_delay.
    limiter = Limiter(
        bucket, raise_when_fail=False, max_delay=max_delay, retry_until_max_delay=True, buffer_ms=buffer_ms
    )
    return limiter
def init_global_limiter(bucket: AbstractBucket,
                        max_delay: Union[int, Duration] = Duration.HOUR,
                        raise_when_fail: bool = False,
                        retry_until_max_delay: bool = True,
                        buffer_ms: int = 50):
    """Populate the module-level LIMITER with a Limiter built on `bucket`.

    Intended for use as an initializer for ProcessPoolExecutor, so each
    worker process constructs its own limiter over the shared bucket.

    Args:
        bucket: The rate-limiting bucket to be used.
        max_delay: Maximum delay before failing requests.
        raise_when_fail: Whether to raise an exception when a request fails.
        retry_until_max_delay: Retry until the maximum delay is reached.
        buffer_ms: Additional buffer time in milliseconds for retries.
    """
    global LIMITER
    LIMITER = Limiter(
        bucket,
        raise_when_fail=raise_when_fail,
        max_delay=max_delay,
        retry_until_max_delay=retry_until_max_delay,
        buffer_ms=buffer_ms,
    )

View File

@@ -0,0 +1,83 @@
import random
import sqlite3
import string
from pathlib import Path
from tempfile import gettempdir
from typing import List
from .abstracts import Rate
from .abstracts import RateItem
def binary_search(items: List[RateItem], value: int) -> int:
    """Locate the boundary index for `value` in a timestamp-sorted list.

    Returns index ``i`` such that ``items[i-1].timestamp < value <=
    items[i].timestamp``: 0 when the list is empty or `value` precedes all
    items, -1 when `value` is newer than every item. Used to measure the
    window stretching from now back to the lower boundary `value`.
    """
    if not items:
        return 0
    # Newer than everything in the list -> sentinel -1
    if value > items[-1].timestamp:
        return -1
    if value <= items[0].timestamp:
        return 0
    if len(items) == 2:
        return 1
    lo, hi = 0, len(items) - 1
    idx = -2  # overwritten on the first loop iteration
    while lo <= hi:
        idx = (lo + hi) // 2
        prev_ts = items[idx - 1].timestamp
        curr_ts = items[idx].timestamp
        if prev_ts < value <= curr_ts:
            break
        if prev_ts >= value:
            hi = idx
        if curr_ts < value:
            lo = idx + 1
    return idx
def validate_rate_list(rates: List[Rate]) -> bool:
    """Return False if the rate list is empty or incorrectly ordered.

    Rates must be sorted by strictly increasing interval and limit, while
    the allowed density (limit / interval) must not increase.
    """
    if not rates:
        return False
    for prev_rate, current_rate in zip(rates, rates[1:]):
        if current_rate.interval <= prev_rate.interval:
            return False
        if current_rate.limit <= prev_rate.limit:
            return False
        density_grew = (current_rate.limit / current_rate.interval) > (prev_rate.limit / prev_rate.interval)
        if density_grew:
            return False
    return True
def id_generator(
    size=6,
    chars=string.ascii_uppercase + string.digits + string.ascii_lowercase,
) -> str:
    """Return a random identifier of `size` characters drawn from `chars`."""
    picks = (random.choice(chars) for _ in range(size))
    return "".join(picks)
def dedicated_sqlite_clock_connection():
    """Open a SQLite connection to the temp-dir database reserved for clock reads.

    Uses EXCLUSIVE isolation and permits use across threads.
    """
    db_file = Path(gettempdir()) / "pyrate_limiter_clock_only.sqlite"
    return sqlite3.connect(
        db_file,
        isolation_level="EXCLUSIVE",
        check_same_thread=False,
    )