Initial commit
@@ -0,0 +1,10 @@
# flake8: noqa
"""Concrete bucket implementations
"""
from .in_memory_bucket import InMemoryBucket
from .mp_bucket import MultiprocessBucket
from .postgres import PostgresBucket
from .postgres import Queries as PgQueries
from .redis_bucket import RedisBucket
from .sqlite_bucket import Queries as SQLiteQueries
from .sqlite_bucket import SQLiteBucket
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,92 @@
"""Naive bucket implementation using built-in list
"""
from typing import List
from typing import Optional

from ..abstracts import AbstractBucket
from ..abstracts import Rate
from ..abstracts import RateItem
from ..utils import binary_search


class InMemoryBucket(AbstractBucket):
    """Simple In-memory Bucket using native list
    Clock can be either `time.time` or `time.monotonic`
    When leaking, a clock value is required
    Pros: fast, safe, and precise
    Cons: since it resides in local memory, the data is neither persistent nor scalable
    Use case: small applications, simple logic
    """

    items: List[RateItem]
    failing_rate: Optional[Rate]

    def __init__(self, rates: List[Rate]):
        self.rates = sorted(rates, key=lambda r: r.interval)
        self.items = []

    def put(self, item: RateItem) -> bool:
        if item.weight == 0:
            return True

        current_length = len(self.items)
        after_length = item.weight + current_length

        for rate in self.rates:
            if after_length < rate.limit:
                break

            lower_bound_value = item.timestamp - rate.interval
            lower_bound_idx = binary_search(self.items, lower_bound_value)

            if lower_bound_idx >= 0:
                count_existing_items = len(self.items) - lower_bound_idx
                space_available = rate.limit - count_existing_items
            else:
                space_available = rate.limit

            if space_available < item.weight:
                self.failing_rate = rate
                return False

        self.failing_rate = None

        if item.weight > 1:
            self.items.extend([item for _ in range(item.weight)])
        else:
            self.items.append(item)

        return True

    def leak(self, current_timestamp: Optional[int] = None) -> int:
        assert current_timestamp is not None
        if self.items:
            max_interval = self.rates[-1].interval
            lower_bound = current_timestamp - max_interval

            if lower_bound > self.items[-1].timestamp:
                remove_count = len(self.items)
                del self.items[:]
                return remove_count

            if lower_bound < self.items[0].timestamp:
                return 0

            idx = binary_search(self.items, lower_bound)
            del self.items[:idx]
            return idx

        return 0

    def flush(self) -> None:
        self.failing_rate = None
        del self.items[:]

    def count(self) -> int:
        return len(self.items)

    def peek(self, index: int) -> Optional[RateItem]:
        if not self.items:
            return None

        return self.items[-1 - index] if abs(index) < self.count() else None
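# --- Usage sketch (illustrative, not part of this commit) ---
# Shows how InMemoryBucket is meant to be driven: `put` rejects an item once a rate
# would be exceeded (recording it in `failing_rate`), and `leak` needs an explicit
# clock value since the bucket has no clock of its own. The Rate/RateItem constructor
# arguments and import paths below are assumptions based on the imports above.
from time import time

from pyrate_limiter.abstracts import Rate, RateItem
from pyrate_limiter.buckets import InMemoryBucket

bucket = InMemoryBucket([Rate(5, 1000)])  # at most 5 items per 1000ms
now_ms = int(time() * 1000)

for _ in range(6):
    accepted = bucket.put(RateItem("demo", now_ms))
    # the sixth call is expected to return False and set bucket.failing_rate

removed = bucket.leak(now_ms + 2000)  # drop items older than the largest interval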
@@ -0,0 +1,53 @@
"""multiprocessing In-memory Bucket using a multiprocessing.Manager.ListProxy
and a multiprocessing.Lock.
"""
from multiprocessing import Manager
from multiprocessing import RLock
from multiprocessing.managers import ListProxy
from multiprocessing.synchronize import RLock as LockType
from typing import List
from typing import Optional

from ..abstracts import Rate
from ..abstracts import RateItem
from pyrate_limiter.buckets import InMemoryBucket


class MultiprocessBucket(InMemoryBucket):

    items: List[RateItem]  # ListProxy
    mp_lock: LockType

    def __init__(self, rates: List[Rate], items: List[RateItem], mp_lock: LockType):

        if not isinstance(items, ListProxy):
            raise ValueError("items must be a ListProxy")

        self.rates = sorted(rates, key=lambda r: r.interval)
        self.items = items
        self.mp_lock = mp_lock

    def put(self, item: RateItem) -> bool:
        with self.mp_lock:
            return super().put(item)

    def leak(self, current_timestamp: Optional[int] = None) -> int:
        with self.mp_lock:
            return super().leak(current_timestamp)

    def limiter_lock(self):
        return self.mp_lock

    @classmethod
    def init(
        cls,
        rates: List[Rate],
    ):
        """
        Creates a single ListProxy so that this bucket can be shared across multiple processes.
        """
        shared_items: List[RateItem] = Manager().list()  # type: ignore[assignment]

        mp_lock: LockType = RLock()

        return cls(rates=rates, items=shared_items, mp_lock=mp_lock)
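# --- Usage sketch (illustrative, not part of this commit) ---
# MultiprocessBucket.init() creates the shared Manager().list() and RLock once in the
# parent process; the same bucket object is then handed to each worker, and every
# put()/leak() is serialized through mp_lock. Rate/RateItem arguments and import paths
# are assumptions, as above.
from multiprocessing import Process
from time import time

from pyrate_limiter.abstracts import Rate, RateItem
from pyrate_limiter.buckets import MultiprocessBucket


def worker(bucket: MultiprocessBucket) -> None:
    # all workers append into the same ListProxy under the shared lock
    bucket.put(RateItem("job", int(time() * 1000)))


if __name__ == "__main__":
    shared_bucket = MultiprocessBucket.init([Rate(10, 1000)])
    workers = [Process(target=worker, args=(shared_bucket,)) for _ in range(4)]
    for p in workers:
        p.start()
    for p in workers:
        p.join()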
@@ -0,0 +1,166 @@
"""A bucket using PostgreSQL as backend
"""
from __future__ import annotations

from contextlib import contextmanager
from typing import Awaitable
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union

from ..abstracts import AbstractBucket
from ..abstracts import Rate
from ..abstracts import RateItem

if TYPE_CHECKING:
    from psycopg_pool import ConnectionPool  # type: ignore[import-untyped]


class Queries:
    CREATE_BUCKET_TABLE = """
    CREATE TABLE IF NOT EXISTS {table} (
        name VARCHAR,
        weight SMALLINT,
        item_timestamp TIMESTAMP
    )
    """
    CREATE_INDEX_ON_TIMESTAMP = """
    CREATE INDEX IF NOT EXISTS {index} ON {table} (item_timestamp)
    """
    COUNT = """
    SELECT COUNT(*) FROM {table}
    """
    PUT = """
    INSERT INTO {table} (name, weight, item_timestamp) VALUES (%s, %s, TO_TIMESTAMP(%s))
    """
    FLUSH = """
    DELETE FROM {table}
    """
    PEEK = """
    SELECT name, weight, (extract(EPOCH FROM item_timestamp) * 1000) as item_timestamp
    FROM {table}
    ORDER BY item_timestamp DESC
    LIMIT 1
    OFFSET {offset}
    """
    LEAK = """
    DELETE FROM {table} WHERE item_timestamp < TO_TIMESTAMP({timestamp})
    """
    LEAK_COUNT = """
    SELECT COUNT(*) FROM {table} WHERE item_timestamp < TO_TIMESTAMP({timestamp})
    """


class PostgresBucket(AbstractBucket):
    table: str
    pool: ConnectionPool

    def __init__(self, pool: ConnectionPool, table: str, rates: List[Rate]):
        self.table = table.lower()
        self.pool = pool
        assert rates
        self.rates = rates
        self._full_tbl = f'ratelimit___{self.table}'
        self._create_table()

    @contextmanager
    def _get_conn(self):
        with self.pool.connection() as conn:
            yield conn

    def _create_table(self):
        with self._get_conn() as conn:
            conn.execute(Queries.CREATE_BUCKET_TABLE.format(table=self._full_tbl))
            index_name = f'timestampIndex_{self.table}'
            conn.execute(Queries.CREATE_INDEX_ON_TIMESTAMP.format(table=self._full_tbl, index=index_name))

    def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]:
        """Put an item (typically the current time) in the bucket
        Return True if successful, otherwise False
        """
        if item.weight == 0:
            return True

        with self._get_conn() as conn:
            for rate in self.rates:
                bound = f"SELECT TO_TIMESTAMP({item.timestamp / 1000}) - INTERVAL '{rate.interval} milliseconds'"
                query = f'SELECT COUNT(*) FROM {self._full_tbl} WHERE item_timestamp >= ({bound})'
                cur = conn.execute(query)
                count = int(cur.fetchone()[0])
                cur.close()

                if rate.limit - count < item.weight:
                    self.failing_rate = rate
                    return False

            self.failing_rate = None

            query = Queries.PUT.format(table=self._full_tbl)

            # https://www.psycopg.org/docs/extras.html#fast-exec

            for _ in range(item.weight):
                conn.execute(query, (item.name, item.weight, item.timestamp / 1000))

        return True

    def leak(
        self,
        current_timestamp: Optional[int] = None,
    ) -> Union[int, Awaitable[int]]:
        """Leaking bucket - removing items that are outdated"""
        assert current_timestamp is not None, "current-time must be passed on for leak"
        lower_bound = current_timestamp - self.rates[-1].interval

        if lower_bound <= 0:
            return 0

        count = 0

        with self._get_conn() as conn:
            cur = conn.execute(Queries.LEAK_COUNT.format(table=self._full_tbl, timestamp=lower_bound / 1000))
            result = cur.fetchone()

            if result:
                cur.execute(Queries.LEAK.format(table=self._full_tbl, timestamp=lower_bound / 1000))
                count = int(result[0])

        return count

    def flush(self) -> Union[None, Awaitable[None]]:
        """Flush the whole bucket
        - Must remove `failing-rate` after flushing
        """
        with self._get_conn() as conn:
            conn.execute(Queries.FLUSH.format(table=self._full_tbl))
            self.failing_rate = None

        return None

    def count(self) -> Union[int, Awaitable[int]]:
        """Count number of items in the bucket"""
        count = 0
        with self._get_conn() as conn:
            cur = conn.execute(Queries.COUNT.format(table=self._full_tbl))
            result = cur.fetchone()
            assert result
            count = int(result[0])

        return count

    def peek(self, index: int) -> Union[Optional[RateItem], Awaitable[Optional[RateItem]]]:
        """Peek at the rate-item at a specific index in latest-to-earliest order
        NOTE: The reason we cannot peek from the start of the queue (earliest-to-latest) is that
        we can't really tell how many outdated items are still in the queue
        """
        item = None

        with self._get_conn() as conn:
            cur = conn.execute(Queries.PEEK.format(table=self._full_tbl, offset=index))
            result = cur.fetchone()
            if result:
                name, weight, timestamp = result[0], int(result[1]), int(result[2])
                item = RateItem(name=name, weight=weight, timestamp=timestamp)

        return item
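# --- Usage sketch (illustrative, not part of this commit) ---
# PostgresBucket expects an already-configured psycopg_pool.ConnectionPool; the table
# (prefixed with `ratelimit___`) and its timestamp index are created on construction,
# and item timestamps are passed in milliseconds. The connection string, table name and
# rate values are placeholders; import paths are assumptions, as above.
from time import time

from psycopg_pool import ConnectionPool

from pyrate_limiter.abstracts import Rate, RateItem
from pyrate_limiter.buckets import PostgresBucket

pool = ConnectionPool("postgresql://user:password@localhost:5432/postgres")
bucket = PostgresBucket(pool, "api_calls", [Rate(100, 60_000)])  # 100 items per minute

now_ms = int(time() * 1000)
accepted = bucket.put(RateItem("api_calls", now_ms))
bucket.leak(now_ms)  # delete rows older than the largest rate interval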
@@ -0,0 +1,187 @@
"""Bucket implementation using Redis
"""
from __future__ import annotations

from inspect import isawaitable
from typing import Awaitable
from typing import List
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union

from ..abstracts import AbstractBucket
from ..abstracts import Rate
from ..abstracts import RateItem
from ..utils import id_generator


if TYPE_CHECKING:
    from redis import Redis
    from redis.asyncio import Redis as AsyncRedis


class LuaScript:
    """Scripts that deal with bucket operations"""

    PUT_ITEM = """
    local bucket = KEYS[1]
    local now = ARGV[1]
    local space_required = tonumber(ARGV[2])
    local item_name = ARGV[3]
    local rates_count = tonumber(ARGV[4])

    for i=1,rates_count do
        local offset = (i - 1) * 2
        local interval = tonumber(ARGV[5 + offset])
        local limit = tonumber(ARGV[5 + offset + 1])
        local count = redis.call('ZCOUNT', bucket, now - interval, now)
        local space_available = limit - tonumber(count)
        if space_available < space_required then
            return i - 1
        end
    end

    for i=1,space_required do
        redis.call('ZADD', bucket, now, item_name..i)
    end
    return -1
    """


class RedisBucket(AbstractBucket):
    """A bucket using redis for storing data
    - We are not using redis' built-in TIME since it is non-deterministic
    - In distributed context, use local server time or a remote time server
    - Each bucket instance uses a dedicated connection to avoid race conditions
    - Can be either sync or async
    """

    rates: List[Rate]
    failing_rate: Optional[Rate]
    bucket_key: str
    script_hash: str
    redis: Union[Redis, AsyncRedis]

    def __init__(
        self,
        rates: List[Rate],
        redis: Union[Redis, AsyncRedis],
        bucket_key: str,
        script_hash: str,
    ):
        self.rates = rates
        self.redis = redis
        self.bucket_key = bucket_key
        self.script_hash = script_hash
        self.failing_rate = None

    @classmethod
    def init(
        cls,
        rates: List[Rate],
        redis: Union[Redis, AsyncRedis],
        bucket_key: str,
    ):
        script_hash = redis.script_load(LuaScript.PUT_ITEM)

        if isawaitable(script_hash):

            async def _async_init():
                nonlocal script_hash
                script_hash = await script_hash
                return cls(rates, redis, bucket_key, script_hash)

            return _async_init()

        return cls(rates, redis, bucket_key, script_hash)

    def _check_and_insert(self, item: RateItem) -> Union[Rate, None, Awaitable[Optional[Rate]]]:
        keys = [self.bucket_key]

        args = [
            item.timestamp,
            item.weight,
            # NOTE: this is to avoid key collision since we are using ZSET
            f"{item.name}:{id_generator()}:",  # noqa: E231
            len(self.rates),
            *[value for rate in self.rates for value in (rate.interval, rate.limit)],
        ]

        idx = self.redis.evalsha(self.script_hash, len(keys), *keys, *args)

        def _handle_sync(returned_idx: int):
            assert isinstance(returned_idx, int), "Not int"
            if returned_idx < 0:
                return None

            return self.rates[returned_idx]

        async def _handle_async(returned_idx: Awaitable[int]):
            assert isawaitable(returned_idx), "Not coroutine"
            awaited_idx = await returned_idx
            return _handle_sync(awaited_idx)

        return _handle_async(idx) if isawaitable(idx) else _handle_sync(idx)

    def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]:
        """Add item to key"""
        failing_rate = self._check_and_insert(item)
        if isawaitable(failing_rate):

            async def _handle_async():
                self.failing_rate = await failing_rate
                return not bool(self.failing_rate)

            return _handle_async()

        assert isinstance(failing_rate, Rate) or failing_rate is None
        self.failing_rate = failing_rate
        return not bool(self.failing_rate)

    def leak(self, current_timestamp: Optional[int] = None) -> Union[int, Awaitable[int]]:
        assert current_timestamp is not None
        return self.redis.zremrangebyscore(
            self.bucket_key,
            0,
            current_timestamp - self.rates[-1].interval,
        )

    def flush(self):
        self.failing_rate = None
        return self.redis.delete(self.bucket_key)

    def count(self):
        return self.redis.zcard(self.bucket_key)

    def peek(self, index: int) -> Union[RateItem, None, Awaitable[Optional[RateItem]]]:
        items = self.redis.zrange(
            self.bucket_key,
            -1 - index,
            -1 - index,
            withscores=True,
            score_cast_func=int,
        )

        if not items:
            return None

        def _handle_items(received_items: List[Tuple[str, int]]):
            if not received_items:
                return None

            item = received_items[0]
            rate_item = RateItem(name=str(item[0]), timestamp=item[1])
            return rate_item

        if isawaitable(items):

            async def _awaiting():
                nonlocal items
                items = await items
                return _handle_items(items)

            return _awaiting()

        assert isinstance(items, list)
        return _handle_items(items)
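# --- Usage sketch (illustrative, not part of this commit) ---
# RedisBucket.init() loads the PUT_ITEM Lua script once via SCRIPT LOAD and returns
# either a bucket (sync client) or an awaitable of one (redis.asyncio client). The key
# name and rate values are placeholders; import paths are assumptions, as above.
from time import time

from redis import Redis

from pyrate_limiter.abstracts import Rate, RateItem
from pyrate_limiter.buckets import RedisBucket

redis_client = Redis(host="localhost", port=6379)
bucket = RedisBucket.init([Rate(10, 1000)], redis_client, "demo-bucket-key")

accepted = bucket.put(RateItem("req", int(time() * 1000)))

# With redis.asyncio.Redis, init() and every bucket method return awaitables instead:
#     bucket = await RedisBucket.init(rates, async_client, "demo-bucket-key")
#     accepted = await bucket.put(item)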
@@ -0,0 +1,250 @@
"""Bucket implementation using SQLite"""
import logging
import sqlite3
from contextlib import nullcontext
from pathlib import Path
from tempfile import gettempdir
from threading import RLock
from time import time
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

from ..abstracts import AbstractBucket
from ..abstracts import Rate
from ..abstracts import RateItem

logger = logging.getLogger(__name__)


class Queries:
    CREATE_BUCKET_TABLE = """
    CREATE TABLE IF NOT EXISTS '{table}' (
        name VARCHAR,
        item_timestamp INTEGER
    )
    """
    CREATE_INDEX_ON_TIMESTAMP = """
    CREATE INDEX IF NOT EXISTS '{index_name}' ON '{table_name}' (item_timestamp)
    """
    COUNT_BEFORE_INSERT = """
    SELECT :interval{index} as interval, COUNT(*) FROM '{table}'
    WHERE item_timestamp >= :current_timestamp - :interval{index}
    """
    PUT_ITEM = """
    INSERT INTO '{table}' (name, item_timestamp) VALUES %s
    """
    LEAK = """
    DELETE FROM "{table}" WHERE rowid IN (
    SELECT rowid FROM "{table}" ORDER BY item_timestamp ASC LIMIT {count});
    """.strip()
    COUNT_BEFORE_LEAK = """SELECT COUNT(*) FROM '{table}' WHERE item_timestamp < {current_timestamp} - {interval}"""
    FLUSH = """DELETE FROM '{table}'"""
    # The SQL statements below are for testing only
    DROP_TABLE = "DROP TABLE IF EXISTS '{table}'"
    DROP_INDEX = "DROP INDEX IF EXISTS '{index}'"
    COUNT_ALL = "SELECT COUNT(*) FROM '{table}'"
    GET_ALL_ITEM = "SELECT * FROM '{table}' ORDER BY item_timestamp ASC"
    GET_FIRST_ITEM = (
        "SELECT name, item_timestamp FROM '{table}' ORDER BY item_timestamp ASC"
    )
    GET_LAG = """
    SELECT (strftime ('%s', 'now') || substr(strftime ('%f', 'now'), 4)) - (
        SELECT item_timestamp
        FROM '{table}'
        ORDER BY item_timestamp
        ASC
        LIMIT 1
    )
    """
    PEEK = 'SELECT * FROM "{table}" ORDER BY item_timestamp DESC LIMIT 1 OFFSET {count}'


class SQLiteBucket(AbstractBucket):
    """For the sqlite bucket, we are using the sql time function as the clock,
    so the item's timestamp won't matter here
    """

    rates: List[Rate]
    failing_rate: Optional[Rate]
    conn: sqlite3.Connection
    table: str
    full_count_query: str
    lock: RLock
    use_limiter_lock: bool

    def __init__(
        self, rates: List[Rate], conn: sqlite3.Connection, table: str, lock=None
    ):
        self.conn = conn
        self.table = table
        self.rates = rates

        if not lock:
            self.use_limiter_lock = False
            self.lock = RLock()
        else:
            self.use_limiter_lock = True
            self.lock = lock

    def limiter_lock(self):
        if self.use_limiter_lock:
            return self.lock
        else:
            return None

    def _build_full_count_query(self, current_timestamp: int) -> Tuple[str, dict]:
        full_query: List[str] = []

        parameters = {"current_timestamp": current_timestamp}

        for index, rate in enumerate(self.rates):
            parameters[f"interval{index}"] = rate.interval
            query = Queries.COUNT_BEFORE_INSERT.format(table=self.table, index=index)
            full_query.append(query)

        join_full_query = (
            " union ".join(full_query) if len(full_query) > 1 else full_query[0]
        )
        return join_full_query, parameters

    def put(self, item: RateItem) -> bool:
        with self.lock:
            query, parameters = self._build_full_count_query(item.timestamp)
            cur = self.conn.execute(query, parameters)
            rate_limit_counts = cur.fetchall()
            cur.close()

            for idx, result in enumerate(rate_limit_counts):
                interval, count = result
                rate = self.rates[idx]
                assert interval == rate.interval
                space_available = rate.limit - count

                if space_available < item.weight:
                    self.failing_rate = rate
                    return False

            items = ", ".join(
                [f"('{name}', {item.timestamp})" for name in [item.name] * item.weight]
            )
            query = (Queries.PUT_ITEM.format(table=self.table)) % items
            self.conn.execute(query).close()
            self.conn.commit()
            return True

    def leak(self, current_timestamp: Optional[int] = None) -> int:
        """Leaking/cleaning up the bucket"""
        with self.lock:
            assert current_timestamp is not None
            query = Queries.COUNT_BEFORE_LEAK.format(
                table=self.table,
                interval=self.rates[-1].interval,
                current_timestamp=current_timestamp,
            )
            cur = self.conn.execute(query)
            count = cur.fetchone()[0]
            query = Queries.LEAK.format(table=self.table, count=count)
            cur.execute(query)
            cur.close()
            self.conn.commit()
            return count

    def flush(self) -> None:
        with self.lock:
            self.conn.execute(Queries.FLUSH.format(table=self.table)).close()
            self.conn.commit()
            self.failing_rate = None

    def count(self) -> int:
        with self.lock:
            cur = self.conn.execute(
                Queries.COUNT_ALL.format(table=self.table)
            )
            ret = cur.fetchone()[0]
            cur.close()
            return ret

    def peek(self, index: int) -> Optional[RateItem]:
        with self.lock:
            query = Queries.PEEK.format(table=self.table, count=index)
            cur = self.conn.execute(query)
            item = cur.fetchone()
            cur.close()

            if not item:
                return None

            return RateItem(item[0], item[1])

    @classmethod
    def init_from_file(
        cls,
        rates: List[Rate],
        table: str = "rate_bucket",
        db_path: Optional[str] = None,
        create_new_table: bool = True,
        use_file_lock: bool = False
    ) -> "SQLiteBucket":

        if db_path is None and use_file_lock:
            raise ValueError("db_path must be specified when using use_file_lock")

        if db_path is None:
            temp_dir = Path(gettempdir())
            db_path = str(temp_dir / f"pyrate_limiter_{time()}.sqlite")

        # TBD: FileLock switched to a thread-local FileLock in 3.11.0.
        # Should we set FileLock's thread_local to False, for cases where the user is both multiprocessing & threading?
        # As is, the file lock should be Multi Process - Single Thread and non-filelock is Single Process - Multi Thread
        # A hybrid lock may be needed to gracefully handle both cases
        file_lock = None
        file_lock_ctx = nullcontext()

        if use_file_lock:
            try:
                from filelock import FileLock  # type: ignore[import-untyped]
                file_lock = FileLock(db_path + ".lock")  # type: ignore[no-redef]
                file_lock_ctx: Union[nullcontext, FileLock] = file_lock  # type: ignore[no-redef]
            except ImportError:
                raise ImportError(
                    "filelock is required for file locking. "
                    "Please install it as an optional dependency"
                )

        with file_lock_ctx:
            assert db_path is not None
            assert db_path.endswith(".sqlite"), (
                "Please provide a valid sqlite file path"
            )

            sqlite_connection = sqlite3.connect(
                db_path,
                isolation_level="DEFERRED",
                check_same_thread=False,
            )

            cur = sqlite_connection.cursor()
            if use_file_lock:
                # https://www.sqlite.org/wal.html
                cur.execute("PRAGMA journal_mode=WAL;")

                # https://www.sqlite.org/pragma.html#pragma_synchronous
                cur.execute("PRAGMA synchronous=NORMAL;")

            if create_new_table:
                cur.execute(
                    Queries.CREATE_BUCKET_TABLE.format(table=table)
                )

            create_idx_query = Queries.CREATE_INDEX_ON_TIMESTAMP.format(
                index_name=f"idx_{table}_rate_item_timestamp",
                table_name=table,
            )

            cur.execute(create_idx_query)
            cur.close()
            sqlite_connection.commit()

            return cls(rates, sqlite_connection, table=table, lock=file_lock)
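# --- Usage sketch (illustrative, not part of this commit) ---
# init_from_file() creates (or reuses) the sqlite file, bucket table and timestamp
# index, and can wrap the setup in a filelock.FileLock for multi-process use, in which
# case WAL journal mode is enabled. The db_path and rate values are placeholders;
# import paths are assumptions, as above.
from time import time

from pyrate_limiter.abstracts import Rate, RateItem
from pyrate_limiter.buckets import SQLiteBucket

bucket = SQLiteBucket.init_from_file(
    rates=[Rate(3, 1000)],
    table="rate_bucket",
    db_path="/tmp/demo_rate_limits.sqlite",
    use_file_lock=True,  # requires the optional `filelock` dependency
)

now_ms = int(time() * 1000)
bucket.put(RateItem("demo", now_ms))
bucket.leak(now_ms)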