Initial commit

kdusek
2025-12-09 12:13:01 +01:00
commit 8e654ed209
13332 changed files with 2695056 additions and 0 deletions

View File

@@ -0,0 +1,10 @@
# flake8: noqa
"""Conrete bucket implementations
"""
from .in_memory_bucket import InMemoryBucket
from .mp_bucket import MultiprocessBucket
from .postgres import PostgresBucket
from .postgres import Queries as PgQueries
from .redis_bucket import RedisBucket
from .sqlite_bucket import Queries as SQLiteQueries
from .sqlite_bucket import SQLiteBucket

View File

@@ -0,0 +1,92 @@
"""Naive bucket implementation using built-in list
"""
from typing import List
from typing import Optional
from ..abstracts import AbstractBucket
from ..abstracts import Rate
from ..abstracts import RateItem
from ..utils import binary_search
class InMemoryBucket(AbstractBucket):
"""Simple In-memory Bucket using native list
Clock can be either `time.time` or `time.monotonic`
When leaking, a clock is required
Pros: fast, safe, and precise
Cons: since it resides in local memory, the data is not persistent, nor scalable
Use case: small applications, simple logic
"""
items: List[RateItem]
failing_rate: Optional[Rate]
def __init__(self, rates: List[Rate]):
self.rates = sorted(rates, key=lambda r: r.interval)
self.items = []
def put(self, item: RateItem) -> bool:
if item.weight == 0:
return True
current_length = len(self.items)
after_length = item.weight + current_length
for rate in self.rates:
if after_length < rate.limit:
break
lower_bound_value = item.timestamp - rate.interval
lower_bound_idx = binary_search(self.items, lower_bound_value)
if lower_bound_idx >= 0:
count_existing_items = len(self.items) - lower_bound_idx
space_available = rate.limit - count_existing_items
else:
space_available = rate.limit
if space_available < item.weight:
self.failing_rate = rate
return False
self.failing_rate = None
if item.weight > 1:
self.items.extend([item for _ in range(item.weight)])
else:
self.items.append(item)
return True
def leak(self, current_timestamp: Optional[int] = None) -> int:
assert current_timestamp is not None
if self.items:
max_interval = self.rates[-1].interval
lower_bound = current_timestamp - max_interval
if lower_bound > self.items[-1].timestamp:
remove_count = len(self.items)
del self.items[:]
return remove_count
if lower_bound < self.items[0].timestamp:
return 0
idx = binary_search(self.items, lower_bound)
del self.items[:idx]
return idx
return 0
def flush(self) -> None:
self.failing_rate = None
del self.items[:]
def count(self) -> int:
return len(self.items)
def peek(self, index: int) -> Optional[RateItem]:
if not self.items:
return None
return self.items[-1 - index] if abs(index) < self.count() else None
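
A minimal usage sketch (not part of the committed file), assuming a Rate(limit, interval) constructor with intervals in milliseconds and the RateItem(name, timestamp, weight=...) signature implied by the code above; the import paths mirror the package layout shown in this commit:

from time import time
from pyrate_limiter.abstracts import Rate, RateItem
from pyrate_limiter.buckets import InMemoryBucket

# 5 hits per second and 30 hits per minute (intervals assumed to be in milliseconds)
rates = [Rate(5, 1_000), Rate(30, 60_000)]
bucket = InMemoryBucket(rates)

now = int(time() * 1000)
assert bucket.put(RateItem("demo", now))  # True while both rate windows still have room
bucket.leak(now)                          # drop items older than the largest interval
print(bucket.count(), bucket.peek(0))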

View File

@@ -0,0 +1,53 @@
"""multiprocessing In-memory Bucket using a multiprocessing.Manager.ListProxy
and a multiprocessing.Lock.
"""
from multiprocessing import Manager
from multiprocessing import RLock
from multiprocessing.managers import ListProxy
from multiprocessing.synchronize import RLock as LockType
from typing import List
from typing import Optional
from ..abstracts import Rate
from ..abstracts import RateItem
from pyrate_limiter.buckets import InMemoryBucket
class MultiprocessBucket(InMemoryBucket):
items: List[RateItem] # ListProxy
mp_lock: LockType
def __init__(self, rates: List[Rate], items: List[RateItem], mp_lock: LockType):
if not isinstance(items, ListProxy):
raise ValueError("items must be a ListProxy")
self.rates = sorted(rates, key=lambda r: r.interval)
self.items = items
self.mp_lock = mp_lock
def put(self, item: RateItem) -> bool:
with self.mp_lock:
return super().put(item)
def leak(self, current_timestamp: Optional[int] = None) -> int:
with self.mp_lock:
return super().leak(current_timestamp)
def limiter_lock(self):
return self.mp_lock
@classmethod
def init(
cls,
rates: List[Rate],
):
"""
Creates a single ListProxy so that this bucket can be shared across multiple processes.
"""
shared_items: List[RateItem] = Manager().list() # type: ignore[assignment]
mp_lock: LockType = RLock()
return cls(rates=rates, items=shared_items, mp_lock=mp_lock)
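
A usage sketch (not part of the committed file): MultiprocessBucket.init() builds the shared ListProxy and RLock, and the resulting bucket is handed to worker processes at startup. Rate, RateItem, and the import paths are the same assumptions as in the in-memory example above:

from concurrent.futures import ProcessPoolExecutor
from time import time
from pyrate_limiter.abstracts import Rate, RateItem
from pyrate_limiter.buckets import MultiprocessBucket

BUCKET = None

def init_worker(shared_bucket):
    # runs once per worker process and stores the shared bucket globally
    global BUCKET
    BUCKET = shared_bucket

def work(n):
    # every put() is serialized through the shared multiprocessing RLock
    return BUCKET.put(RateItem(f"task-{n}", int(time() * 1000)))

if __name__ == "__main__":
    bucket = MultiprocessBucket.init([Rate(100, 60_000)])  # 100 hits per minute
    with ProcessPoolExecutor(initializer=init_worker, initargs=(bucket,)) as pool:
        print(list(pool.map(work, range(10))))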

View File

@@ -0,0 +1,166 @@
"""A bucket using PostgreSQL as backend
"""
from __future__ import annotations
from contextlib import contextmanager
from typing import Awaitable
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from ..abstracts import AbstractBucket
from ..abstracts import Rate
from ..abstracts import RateItem
if TYPE_CHECKING:
from psycopg_pool import ConnectionPool # type: ignore[import-untyped]
class Queries:
CREATE_BUCKET_TABLE = """
CREATE TABLE IF NOT EXISTS {table} (
name VARCHAR,
weight SMALLINT,
item_timestamp TIMESTAMP
)
"""
CREATE_INDEX_ON_TIMESTAMP = """
CREATE INDEX IF NOT EXISTS {index} ON {table} (item_timestamp)
"""
COUNT = """
SELECT COUNT(*) FROM {table}
"""
PUT = """
INSERT INTO {table} (name, weight, item_timestamp) VALUES (%s, %s, TO_TIMESTAMP(%s))
"""
FLUSH = """
DELETE FROM {table}
"""
PEEK = """
SELECT name, weight, (extract(EPOCH FROM item_timestamp) * 1000) as item_timestamp
FROM {table}
ORDER BY item_timestamp DESC
LIMIT 1
OFFSET {offset}
"""
LEAK = """
DELETE FROM {table} WHERE item_timestamp < TO_TIMESTAMP({timestamp})
"""
LEAK_COUNT = """
SELECT COUNT(*) FROM {table} WHERE item_timestamp < TO_TIMESTAMP({timestamp})
"""
class PostgresBucket(AbstractBucket):
table: str
pool: ConnectionPool
def __init__(self, pool: ConnectionPool, table: str, rates: List[Rate]):
self.table = table.lower()
self.pool = pool
assert rates
self.rates = rates
self._full_tbl = f'ratelimit___{self.table}'
self._create_table()
@contextmanager
def _get_conn(self):
with self.pool.connection() as conn:
yield conn
def _create_table(self):
with self._get_conn() as conn:
conn.execute(Queries.CREATE_BUCKET_TABLE.format(table=self._full_tbl))
index_name = f'timestampIndex_{self.table}'
conn.execute(Queries.CREATE_INDEX_ON_TIMESTAMP.format(table=self._full_tbl, index=index_name))
def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]:
"""Put an item (typically the current time) in the bucket
return true if successful, otherwise false
"""
if item.weight == 0:
return True
with self._get_conn() as conn:
for rate in self.rates:
bound = f"SELECT TO_TIMESTAMP({item.timestamp / 1000}) - INTERVAL '{rate.interval} milliseconds'"
query = f'SELECT COUNT(*) FROM {self._full_tbl} WHERE item_timestamp >= ({bound})'
cur = conn.execute(query)
count = int(cur.fetchone()[0])
cur.close()
if rate.limit - count < item.weight:
self.failing_rate = rate
return False
self.failing_rate = None
query = Queries.PUT.format(table=self._full_tbl)
# https://www.psycopg.org/docs/extras.html#fast-exec
for _ in range(item.weight):
conn.execute(query, (item.name, item.weight, item.timestamp / 1000))
return True
def leak(
self,
current_timestamp: Optional[int] = None,
) -> Union[int, Awaitable[int]]:
"""leaking bucket - removing items that are outdated"""
assert current_timestamp is not None, "current-time must be passed on for leak"
lower_bound = current_timestamp - self.rates[-1].interval
if lower_bound <= 0:
return 0
count = 0
with self._get_conn() as conn:
conn = conn.execute(Queries.LEAK_COUNT.format(table=self._full_tbl, timestamp=lower_bound / 1000))
result = conn.fetchone()
if result:
conn.execute(Queries.LEAK.format(table=self._full_tbl, timestamp=lower_bound / 1000))
count = int(result[0])
return count
def flush(self) -> Union[None, Awaitable[None]]:
"""Flush the whole bucket
- Must remove `failing-rate` after flushing
"""
with self._get_conn() as conn:
conn.execute(Queries.FLUSH.format(table=self._full_tbl))
self.failing_rate = None
return None
def count(self) -> Union[int, Awaitable[int]]:
"""Count number of items in the bucket"""
count = 0
with self._get_conn() as conn:
conn = conn.execute(Queries.COUNT.format(table=self._full_tbl))
result = conn.fetchone()
assert result
count = int(result[0])
return count
def peek(self, index: int) -> Union[Optional[RateItem], Awaitable[Optional[RateItem]]]:
"""Peek at the rate-item at a specific index in latest-to-earliest order
NOTE: The reason we cannot peek from the start of the queue(earliest-to-latest) is
we can't really tell how many outdated items are still in the queue
"""
item = None
with self._get_conn() as conn:
conn = conn.execute(Queries.PEEK.format(table=self._full_tbl, offset=index))
result = conn.fetchone()
if result:
name, weight, timestamp = result[0], int(result[1]), int(result[2])
item = RateItem(name=name, weight=weight, timestamp=timestamp)
return item
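
A usage sketch (not part of the committed file), assuming a reachable PostgreSQL server; the DSN is a placeholder, and Rate/RateItem and the import paths are the same assumptions as above. The psycopg_pool ConnectionPool supplies the connections that _get_conn() borrows:

from time import time
from psycopg_pool import ConnectionPool
from pyrate_limiter.abstracts import Rate, RateItem
from pyrate_limiter.buckets import PostgresBucket

pool = ConnectionPool("postgresql://user:password@localhost:5432/postgres")  # placeholder DSN
bucket = PostgresBucket(pool, "demo", [Rate(500, 60_000)])  # creates table ratelimit___demo

now = int(time() * 1000)
print(bucket.put(RateItem("demo", now)), bucket.count())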

View File

@@ -0,0 +1,187 @@
"""Bucket implementation using Redis
"""
from __future__ import annotations
from inspect import isawaitable
from typing import Awaitable
from typing import List
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from ..abstracts import AbstractBucket
from ..abstracts import Rate
from ..abstracts import RateItem
from ..utils import id_generator
if TYPE_CHECKING:
from redis import Redis
from redis.asyncio import Redis as AsyncRedis
class LuaScript:
"""Scripts that deal with bucket operations"""
PUT_ITEM = """
local bucket = KEYS[1]
local now = ARGV[1]
local space_required = tonumber(ARGV[2])
local item_name = ARGV[3]
local rates_count = tonumber(ARGV[4])
for i=1,rates_count do
local offset = (i - 1) * 2
local interval = tonumber(ARGV[5 + offset])
local limit = tonumber(ARGV[5 + offset + 1])
local count = redis.call('ZCOUNT', bucket, now - interval, now)
local space_available = limit - tonumber(count)
if space_available < space_required then
return i - 1
end
end
for i=1,space_required do
redis.call('ZADD', bucket, now, item_name..i)
end
return -1
"""
class RedisBucket(AbstractBucket):
"""A bucket using redis for storing data
- We are not using redis' built-in TIME since it is non-deterministic
- In distributed context, use local server time or a remote time server
- Each bucket instance use a dedicated connection to avoid race-condition
- can be either sync or async
"""
rates: List[Rate]
failing_rate: Optional[Rate]
bucket_key: str
script_hash: str
redis: Union[Redis, AsyncRedis]
def __init__(
self,
rates: List[Rate],
redis: Union[Redis, AsyncRedis],
bucket_key: str,
script_hash: str,
):
self.rates = rates
self.redis = redis
self.bucket_key = bucket_key
self.script_hash = script_hash
self.failing_rate = None
@classmethod
def init(
cls,
rates: List[Rate],
redis: Union[Redis, AsyncRedis],
bucket_key: str,
):
script_hash = redis.script_load(LuaScript.PUT_ITEM)
if isawaitable(script_hash):
async def _async_init():
nonlocal script_hash
script_hash = await script_hash
return cls(rates, redis, bucket_key, script_hash)
return _async_init()
return cls(rates, redis, bucket_key, script_hash)
def _check_and_insert(self, item: RateItem) -> Union[Rate, None, Awaitable[Optional[Rate]]]:
keys = [self.bucket_key]
args = [
item.timestamp,
item.weight,
# NOTE: this is to avoid key collision since we are using ZSET
f"{item.name}:{id_generator()}:", # noqa: E231
len(self.rates),
*[value for rate in self.rates for value in (rate.interval, rate.limit)],
]
idx = self.redis.evalsha(self.script_hash, len(keys), *keys, *args)
def _handle_sync(returned_idx: int):
assert isinstance(returned_idx, int), "Not int"
if returned_idx < 0:
return None
return self.rates[returned_idx]
async def _handle_async(returned_idx: Awaitable[int]):
assert isawaitable(returned_idx), "Not coroutine"
awaited_idx = await returned_idx
return _handle_sync(awaited_idx)
return _handle_async(idx) if isawaitable(idx) else _handle_sync(idx)
def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]:
"""Add item to key"""
failing_rate = self._check_and_insert(item)
if isawaitable(failing_rate):
async def _handle_async():
self.failing_rate = await failing_rate
return not bool(self.failing_rate)
return _handle_async()
assert isinstance(failing_rate, Rate) or failing_rate is None
self.failing_rate = failing_rate
return not bool(self.failing_rate)
def leak(self, current_timestamp: Optional[int] = None) -> Union[int, Awaitable[int]]:
assert current_timestamp is not None
return self.redis.zremrangebyscore(
self.bucket_key,
0,
current_timestamp - self.rates[-1].interval,
)
def flush(self):
self.failing_rate = None
return self.redis.delete(self.bucket_key)
def count(self):
return self.redis.zcard(self.bucket_key)
def peek(self, index: int) -> Union[RateItem, None, Awaitable[Optional[RateItem]]]:
items = self.redis.zrange(
self.bucket_key,
-1 - index,
-1 - index,
withscores=True,
score_cast_func=int,
)
if not items:
return None
def _handle_items(received_items: List[Tuple[str, int]]):
if not received_items:
return None
item = received_items[0]
rate_item = RateItem(name=str(item[0]), timestamp=item[1])
return rate_item
if isawaitable(items):
async def _awaiting():
nonlocal items
items = await items
return _handle_items(items)
return _awaiting()
assert isinstance(items, list)
return _handle_items(items)
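
A usage sketch (not part of the committed file), assuming a Redis server on localhost and the same Rate/RateItem and import-path assumptions as above. init() registers the Lua script once via script_load; with redis.asyncio.Redis it returns an awaitable instead:

from time import time
from redis import Redis
from pyrate_limiter.abstracts import Rate, RateItem
from pyrate_limiter.buckets import RedisBucket

redis_client = Redis(host="localhost", port=6379)  # assumed local instance
bucket = RedisBucket.init([Rate(10, 1_000)], redis_client, "demo-bucket")

now = int(time() * 1000)
print(bucket.put(RateItem("demo", now)), bucket.count())

# Async variant (sketch): with redis.asyncio.Redis, both init() and put() must be awaited:
#     bucket = await RedisBucket.init(rates, async_redis, "demo-bucket")
#     ok = await bucket.put(RateItem("demo", now))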

View File

@@ -0,0 +1,250 @@
"""Bucket implementation using SQLite"""
import logging
import sqlite3
from contextlib import nullcontext
from pathlib import Path
from tempfile import gettempdir
from threading import RLock
from time import time
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from ..abstracts import AbstractBucket
from ..abstracts import Rate
from ..abstracts import RateItem
logger = logging.getLogger(__name__)
class Queries:
CREATE_BUCKET_TABLE = """
CREATE TABLE IF NOT EXISTS '{table}' (
name VARCHAR,
item_timestamp INTEGER
)
"""
CREATE_INDEX_ON_TIMESTAMP = """
CREATE INDEX IF NOT EXISTS '{index_name}' ON '{table_name}' (item_timestamp)
"""
COUNT_BEFORE_INSERT = """
SELECT :interval{index} as interval, COUNT(*) FROM '{table}'
WHERE item_timestamp >= :current_timestamp - :interval{index}
"""
PUT_ITEM = """
INSERT INTO '{table}' (name, item_timestamp) VALUES %s
"""
LEAK = """
DELETE FROM "{table}" WHERE rowid IN (
SELECT rowid FROM "{table}" ORDER BY item_timestamp ASC LIMIT {count});
""".strip()
COUNT_BEFORE_LEAK = """SELECT COUNT(*) FROM '{table}' WHERE item_timestamp < {current_timestamp} - {interval}"""
FLUSH = """DELETE FROM '{table}'"""
# The SQL statements below are for testing only
DROP_TABLE = "DROP TABLE IF EXISTS '{table}'"
DROP_INDEX = "DROP INDEX IF EXISTS '{index}'"
COUNT_ALL = "SELECT COUNT(*) FROM '{table}'"
GET_ALL_ITEM = "SELECT * FROM '{table}' ORDER BY item_timestamp ASC"
GET_FIRST_ITEM = (
"SELECT name, item_timestamp FROM '{table}' ORDER BY item_timestamp ASC"
)
GET_LAG = """
SELECT (strftime ('%s', 'now') || substr(strftime ('%f', 'now'), 4)) - (
SELECT item_timestamp
FROM '{table}'
ORDER BY item_timestamp
ASC
LIMIT 1
)
"""
PEEK = 'SELECT * FROM "{table}" ORDER BY item_timestamp DESC LIMIT 1 OFFSET {count}'
class SQLiteBucket(AbstractBucket):
"""For sqlite bucket, we are using the sql time function as the clock
item's timestamp wont matter here
"""
rates: List[Rate]
failing_rate: Optional[Rate]
conn: sqlite3.Connection
table: str
full_count_query: str
lock: RLock
use_limiter_lock: bool
def __init__(
self, rates: List[Rate], conn: sqlite3.Connection, table: str, lock=None
):
self.conn = conn
self.table = table
self.rates = rates
if not lock:
self.use_limiter_lock = False
self.lock = RLock()
else:
self.use_limiter_lock = True
self.lock = lock
def limiter_lock(self):
if self.use_limiter_lock:
return self.lock
else:
return None
def _build_full_count_query(self, current_timestamp: int) -> Tuple[str, dict]:
full_query: List[str] = []
parameters = {"current_timestamp": current_timestamp}
for index, rate in enumerate(self.rates):
parameters[f"interval{index}"] = rate.interval
query = Queries.COUNT_BEFORE_INSERT.format(table=self.table, index=index)
full_query.append(query)
join_full_query = (
" union ".join(full_query) if len(full_query) > 1 else full_query[0]
)
return join_full_query, parameters
def put(self, item: RateItem) -> bool:
with self.lock:
query, parameters = self._build_full_count_query(item.timestamp)
cur = self.conn.execute(query, parameters)
rate_limit_counts = cur.fetchall()
cur.close()
for idx, result in enumerate(rate_limit_counts):
interval, count = result
rate = self.rates[idx]
assert interval == rate.interval
space_available = rate.limit - count
if space_available < item.weight:
self.failing_rate = rate
return False
items = ", ".join(
[f"('{name}', {item.timestamp})" for name in [item.name] * item.weight]
)
query = (Queries.PUT_ITEM.format(table=self.table)) % items
self.conn.execute(query).close()
self.conn.commit()
return True
def leak(self, current_timestamp: Optional[int] = None) -> int:
"""Leaking/clean up bucket"""
with self.lock:
assert current_timestamp is not None
query = Queries.COUNT_BEFORE_LEAK.format(
table=self.table,
interval=self.rates[-1].interval,
current_timestamp=current_timestamp,
)
cur = self.conn.execute(query)
count = cur.fetchone()[0]
query = Queries.LEAK.format(table=self.table, count=count)
cur.execute(query)
cur.close()
self.conn.commit()
return count
def flush(self) -> None:
with self.lock:
self.conn.execute(Queries.FLUSH.format(table=self.table)).close()
self.conn.commit()
self.failing_rate = None
def count(self) -> int:
with self.lock:
cur = self.conn.execute(
Queries.COUNT_ALL.format(table=self.table)
)
ret = cur.fetchone()[0]
cur.close()
return ret
def peek(self, index: int) -> Optional[RateItem]:
with self.lock:
query = Queries.PEEK.format(table=self.table, count=index)
cur = self.conn.execute(query)
item = cur.fetchone()
cur.close()
if not item:
return None
return RateItem(item[0], item[1])
@classmethod
def init_from_file(
cls,
rates: List[Rate],
table: str = "rate_bucket",
db_path: Optional[str] = None,
create_new_table: bool = True,
use_file_lock: bool = False
) -> "SQLiteBucket":
if db_path is None and use_file_lock:
raise ValueError("db_path must be specified when using use_file_lock")
if db_path is None:
temp_dir = Path(gettempdir())
db_path = str(temp_dir / f"pyrate_limiter_{time()}.sqlite")
# TBD: filelock's FileLock switched to being thread-local by default in 3.11.0.
# Should we set FileLock's thread_local to False for cases where the user is both multiprocessing and threading?
# As is, the file lock covers Multi Process - Single Thread, and the non-filelock path covers Single Process - Multi Thread.
# A hybrid lock may be needed to gracefully handle both cases.
file_lock = None
file_lock_ctx = nullcontext()
if use_file_lock:
try:
from filelock import FileLock # type: ignore[import-untyped]
file_lock = FileLock(db_path + ".lock") # type: ignore[no-redef]
file_lock_ctx: Union[nullcontext, FileLock] = file_lock # type: ignore[no-redef]
except ImportError:
raise ImportError(
"filelock is required for file locking. "
"Please install it as optional dependency"
)
with file_lock_ctx:
assert db_path is not None
assert db_path.endswith(".sqlite"), (
"Please provide a valid sqlite file path"
)
sqlite_connection = sqlite3.connect(
db_path,
isolation_level="DEFERRED",
check_same_thread=False,
)
cur = sqlite_connection.cursor()
if use_file_lock:
# https://www.sqlite.org/wal.html
cur.execute("PRAGMA journal_mode=WAL;")
# https://www.sqlite.org/pragma.html#pragma_synchronous
cur.execute("PRAGMA synchronous=NORMAL;")
if create_new_table:
cur.execute(
Queries.CREATE_BUCKET_TABLE.format(table=table)
)
create_idx_query = Queries.CREATE_INDEX_ON_TIMESTAMP.format(
index_name=f"idx_{table}_rate_item_timestamp",
table_name=table,
)
cur.execute(create_idx_query)
cur.close()
sqlite_connection.commit()
return cls(rates, sqlite_connection, table=table, lock=file_lock)
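
A usage sketch (not part of the committed file), with the same Rate/RateItem and import-path assumptions as above. init_from_file() creates the table and index; use_file_lock=True additionally requires the optional filelock package and a db_path shared by all processes:

from time import time
from pyrate_limiter.abstracts import Rate, RateItem
from pyrate_limiter.buckets import SQLiteBucket

bucket = SQLiteBucket.init_from_file(
    [Rate(1_000, 60_000)],
    table="rate_bucket",
    db_path="/tmp/demo_limiter.sqlite",  # must end with .sqlite
    use_file_lock=False,                 # True switches SQLite to WAL mode and locks via filelock
)

now = int(time() * 1000)
print(bucket.put(RateItem("demo", now)), bucket.count())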