"""Bucket implementation using Redis """ from __future__ import annotations from inspect import isawaitable from typing import Awaitable from typing import List from typing import Optional from typing import Tuple from typing import TYPE_CHECKING from typing import Union from ..abstracts import AbstractBucket from ..abstracts import Rate from ..abstracts import RateItem from ..utils import id_generator if TYPE_CHECKING: from redis import Redis from redis.asyncio import Redis as AsyncRedis class LuaScript: """Scripts that deal with bucket operations""" PUT_ITEM = """ local bucket = KEYS[1] local now = ARGV[1] local space_required = tonumber(ARGV[2]) local item_name = ARGV[3] local rates_count = tonumber(ARGV[4]) for i=1,rates_count do local offset = (i - 1) * 2 local interval = tonumber(ARGV[5 + offset]) local limit = tonumber(ARGV[5 + offset + 1]) local count = redis.call('ZCOUNT', bucket, now - interval, now) local space_available = limit - tonumber(count) if space_available < space_required then return i - 1 end end for i=1,space_required do redis.call('ZADD', bucket, now, item_name..i) end return -1 """ class RedisBucket(AbstractBucket): """A bucket using redis for storing data - We are not using redis' built-in TIME since it is non-deterministic - In distributed context, use local server time or a remote time server - Each bucket instance use a dedicated connection to avoid race-condition - can be either sync or async """ rates: List[Rate] failing_rate: Optional[Rate] bucket_key: str script_hash: str redis: Union[Redis, AsyncRedis] def __init__( self, rates: List[Rate], redis: Union[Redis, AsyncRedis], bucket_key: str, script_hash: str, ): self.rates = rates self.redis = redis self.bucket_key = bucket_key self.script_hash = script_hash self.failing_rate = None @classmethod def init( cls, rates: List[Rate], redis: Union[Redis, AsyncRedis], bucket_key: str, ): script_hash = redis.script_load(LuaScript.PUT_ITEM) if isawaitable(script_hash): async def _async_init(): nonlocal script_hash script_hash = await script_hash return cls(rates, redis, bucket_key, script_hash) return _async_init() return cls(rates, redis, bucket_key, script_hash) def _check_and_insert(self, item: RateItem) -> Union[Rate, None, Awaitable[Optional[Rate]]]: keys = [self.bucket_key] args = [ item.timestamp, item.weight, # NOTE: this is to avoid key collision since we are using ZSET f"{item.name}:{id_generator()}:", # noqa: E231 len(self.rates), *[value for rate in self.rates for value in (rate.interval, rate.limit)], ] idx = self.redis.evalsha(self.script_hash, len(keys), *keys, *args) def _handle_sync(returned_idx: int): assert isinstance(returned_idx, int), "Not int" if returned_idx < 0: return None return self.rates[returned_idx] async def _handle_async(returned_idx: Awaitable[int]): assert isawaitable(returned_idx), "Not corotine" awaited_idx = await returned_idx return _handle_sync(awaited_idx) return _handle_async(idx) if isawaitable(idx) else _handle_sync(idx) def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]: """Add item to key""" failing_rate = self._check_and_insert(item) if isawaitable(failing_rate): async def _handle_async(): self.failing_rate = await failing_rate return not bool(self.failing_rate) return _handle_async() assert isinstance(failing_rate, Rate) or failing_rate is None self.failing_rate = failing_rate return not bool(self.failing_rate) def leak(self, current_timestamp: Optional[int] = None) -> Union[int, Awaitable[int]]: assert current_timestamp is not None return self.redis.zremrangebyscore( self.bucket_key, 0, current_timestamp - self.rates[-1].interval, ) def flush(self): self.failing_rate = None return self.redis.delete(self.bucket_key) def count(self): return self.redis.zcard(self.bucket_key) def peek(self, index: int) -> Union[RateItem, None, Awaitable[Optional[RateItem]]]: items = self.redis.zrange( self.bucket_key, -1 - index, -1 - index, withscores=True, score_cast_func=int, ) if not items: return None def _handle_items(received_items: List[Tuple[str, int]]): if not received_items: return None item = received_items[0] rate_item = RateItem(name=str(item[0]), timestamp=item[1]) return rate_item if isawaitable(items): async def _awaiting(): nonlocal items items = await items return _handle_items(items) return _awaiting() assert isinstance(items, list) return _handle_items(items)