Files
edgartools/venv/lib/python3.10/site-packages/pyrate_limiter/buckets/redis_bucket.py
2025-12-09 12:13:01 +01:00

188 lines
5.4 KiB
Python

"""Bucket implementation using Redis
"""
from __future__ import annotations
from inspect import isawaitable
from typing import Awaitable
from typing import List
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from ..abstracts import AbstractBucket
from ..abstracts import Rate
from ..abstracts import RateItem
from ..utils import id_generator
if TYPE_CHECKING:
from redis import Redis
from redis.asyncio import Redis as AsyncRedis
class LuaScript:
    """Lua scripts that implement bucket operations on the Redis server.

    The script text is sent with SCRIPT LOAD (``RedisBucket.init``) and later
    executed by its SHA through EVALSHA (``RedisBucket._check_and_insert``).
    """

    # PUT_ITEM: check every rate against the bucket, then insert the item only
    # if all rates have room for it.
    #   KEYS[1]   -- sorted-set key of the bucket
    #   ARGV[1]   -- item timestamp; used as the window end and as the ZSET score
    #   ARGV[2]   -- item weight = number of members the item occupies
    #   ARGV[3]   -- member-name prefix; the script appends 1..weight to it so
    #                every unit of weight is a distinct ZSET member
    #   ARGV[4]   -- number of rates that follow
    #   ARGV[5..] -- flattened (interval, limit) pairs, one pair per rate
    # For each rate, members with score in [now - interval, now] (inclusive,
    # per ZCOUNT) are counted. If limit - count is smaller than the weight, the
    # script returns the 0-based index of that rate and inserts nothing.
    # Otherwise it adds `weight` members scored at `now` and returns -1.
    # NOTE: the string below is executed as-is by Redis; edits change the
    # script (and therefore its SHA).
    PUT_ITEM = """
local bucket = KEYS[1]
local now = ARGV[1]
local space_required = tonumber(ARGV[2])
local item_name = ARGV[3]
local rates_count = tonumber(ARGV[4])
for i=1,rates_count do
local offset = (i - 1) * 2
local interval = tonumber(ARGV[5 + offset])
local limit = tonumber(ARGV[5 + offset + 1])
local count = redis.call('ZCOUNT', bucket, now - interval, now)
local space_available = limit - tonumber(count)
if space_available < space_required then
return i - 1
end
end
for i=1,space_required do
redis.call('ZADD', bucket, now, item_name..i)
end
return -1
"""
class RedisBucket(AbstractBucket):
    """A bucket using redis for storing data

    - We are not using redis' built-in TIME since it is non-deterministic
    - In distributed context, use local server time or a remote time server
    - Each bucket instance uses a dedicated connection to avoid race conditions
    - Can be either sync or async: every method that talks to Redis returns the
      plain result for a sync client, or an awaitable for an async client

    Storage layout: one sorted set at ``bucket_key``. Each unit of an item's
    weight is a separate member named ``"<item.name>:<random id>:<n>"`` whose
    score is the item's timestamp.
    """

    rates: List[Rate]
    failing_rate: Optional[Rate]
    bucket_key: str
    script_hash: str
    redis: Union[Redis, AsyncRedis]

    def __init__(
        self,
        rates: List[Rate],
        redis: Union[Redis, AsyncRedis],
        bucket_key: str,
        script_hash: str,
    ):
        """Store the bucket configuration.

        Args:
            rates: Rates enforced on every ``put``.
            redis: Sync or async redis-py client.
            bucket_key: Key of the sorted set that holds this bucket's items.
            script_hash: SHA of the already-loaded ``LuaScript.PUT_ITEM``
                script. ``init`` loads the script and passes this in.
        """
        self.rates = rates
        self.redis = redis
        self.bucket_key = bucket_key
        self.script_hash = script_hash
        self.failing_rate = None

    @classmethod
    def init(
        cls,
        rates: List[Rate],
        redis: Union[Redis, AsyncRedis],
        bucket_key: str,
    ):
        """Load the PUT_ITEM Lua script into Redis and build a bucket.

        Returns:
            The bucket for a sync client. If ``redis.script_load`` returns an
            awaitable (async client), returns a coroutine that resolves to the
            bucket once the script SHA is available.
        """
        script_hash = redis.script_load(LuaScript.PUT_ITEM)

        if isawaitable(script_hash):
            pending_hash = script_hash

            async def _async_init():
                loaded_hash = await pending_hash
                return cls(rates, redis, bucket_key, loaded_hash)

            return _async_init()

        return cls(rates, redis, bucket_key, script_hash)

    def _check_and_insert(self, item: RateItem) -> Union[Rate, None, Awaitable[Optional[Rate]]]:
        """Check all rates and insert ``item`` in one server-side script call.

        The PUT_ITEM script runs via EVALSHA. It returns the 0-based index of
        the first rate that lacks room, or -1 after inserting the item.

        Returns:
            The failing ``Rate``, ``None`` if the item was inserted, or an
            awaitable of either for an async client.
        """
        keys = [self.bucket_key]
        args = [
            item.timestamp,
            item.weight,
            # NOTE: ZSET members must be unique. The random id (plus the per-unit
            # suffix the script appends) keeps items with equal names apart.
            f"{item.name}:{id_generator()}:",  # noqa: E231
            len(self.rates),
            # Flattened (interval, limit) pairs, read pairwise by the script.
            *[value for rate in self.rates for value in (rate.interval, rate.limit)],
        ]
        idx = self.redis.evalsha(self.script_hash, len(keys), *keys, *args)

        def _handle_sync(returned_idx: int) -> Optional[Rate]:
            assert isinstance(returned_idx, int), "Not int"
            if returned_idx < 0:
                return None
            return self.rates[returned_idx]

        async def _handle_async(returned_idx: Awaitable[int]) -> Optional[Rate]:
            assert isawaitable(returned_idx), "Not coroutine"
            return _handle_sync(await returned_idx)

        return _handle_async(idx) if isawaitable(idx) else _handle_sync(idx)

    def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]:
        """Try to add ``item`` to the bucket.

        Returns:
            True if the item was inserted. False if a rate has no room; that
            rate is then stored in ``self.failing_rate``. For an async client,
            returns an awaitable of the same bool.
        """
        failing_rate = self._check_and_insert(item)

        if isawaitable(failing_rate):
            async def _handle_async():
                self.failing_rate = await failing_rate
                return self.failing_rate is None

            return _handle_async()

        assert isinstance(failing_rate, Rate) or failing_rate is None
        self.failing_rate = failing_rate
        return self.failing_rate is None

    def leak(self, current_timestamp: Optional[int] = None) -> Union[int, Awaitable[int]]:
        """Remove items whose score is <= ``current_timestamp - rates[-1].interval``.

        NOTE(review): this assumes ``self.rates`` is ordered by increasing
        interval, so that ``rates[-1]`` is the widest window. Ordering is
        presumably enforced by the caller; confirm.

        Returns:
            The number of removed members (ZREMRANGEBYSCORE result), or an
            awaitable of it for an async client.
        """
        assert current_timestamp is not None
        return self.redis.zremrangebyscore(
            self.bucket_key,
            0,
            current_timestamp - self.rates[-1].interval,
        )

    def flush(self):
        """Clear ``failing_rate`` and delete the bucket key.

        Returns the client's DELETE result, or an awaitable of it for an async client.
        """
        self.failing_rate = None
        return self.redis.delete(self.bucket_key)

    def count(self):
        """Return the number of members in the bucket (ZCARD).

        Each unit of an item's weight is stored as its own member, so this is
        the total stored weight. Returns an awaitable for an async client.
        """
        return self.redis.zcard(self.bucket_key)

    def peek(self, index: int) -> Union[RateItem, None, Awaitable[Optional[RateItem]]]:
        """Return the ``index``-th most recent member (0 = highest score).

        The returned ``RateItem`` carries the stored member name (item name plus
        the random-id/unit suffix) and its timestamp score. Returns ``None``
        when no member exists at that position, or an awaitable of the result
        for an async client.
        """
        items = self.redis.zrange(
            self.bucket_key,
            -1 - index,
            -1 - index,
            withscores=True,
            score_cast_func=int,
        )

        if not items:
            return None

        def _handle_items(received_items: List[Tuple[Union[bytes, str], int]]) -> Optional[RateItem]:
            if not received_items:
                return None

            member, timestamp = received_items[0]
            # redis-py returns members as bytes unless decode_responses=True.
            # str(b"x") would give "b'x'", so decode instead. This assumes the
            # client's default utf-8 encoding; "replace" keeps it non-raising.
            if isinstance(member, bytes):
                name = member.decode("utf-8", errors="replace")
            else:
                name = str(member)
            return RateItem(name=name, timestamp=timestamp)

        if isawaitable(items):
            pending_items = items

            async def _awaiting():
                return _handle_items(await pending_items)

            return _awaiting()

        assert isinstance(items, list)
        return _handle_items(items)