"""Bucket implementation using SQLite"""
|
|
import logging
|
|
import sqlite3
|
|
from contextlib import nullcontext
|
|
from pathlib import Path
|
|
from tempfile import gettempdir
|
|
from threading import RLock
|
|
from time import time
|
|
from typing import List
|
|
from typing import Optional
|
|
from typing import Tuple
|
|
from typing import Union
|
|
|
|
from ..abstracts import AbstractBucket
|
|
from ..abstracts import Rate
|
|
from ..abstracts import RateItem
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Queries:
    CREATE_BUCKET_TABLE = """
    CREATE TABLE IF NOT EXISTS '{table}' (
        name VARCHAR,
        item_timestamp INTEGER
    )
    """
    CREATE_INDEX_ON_TIMESTAMP = """
    CREATE INDEX IF NOT EXISTS '{index_name}' ON '{table_name}' (item_timestamp)
    """
    COUNT_BEFORE_INSERT = """
    SELECT :interval{index} as interval, COUNT(*) FROM '{table}'
    WHERE item_timestamp >= :current_timestamp - :interval{index}
    """
    PUT_ITEM = """
    INSERT INTO '{table}' (name, item_timestamp) VALUES %s
    """
    LEAK = """
    DELETE FROM "{table}" WHERE rowid IN (
    SELECT rowid FROM "{table}" ORDER BY item_timestamp ASC LIMIT {count});
    """.strip()
    COUNT_BEFORE_LEAK = """SELECT COUNT(*) FROM '{table}' WHERE item_timestamp < {current_timestamp} - {interval}"""
    FLUSH = """DELETE FROM '{table}'"""
    # The queries below are for testing only
    DROP_TABLE = "DROP TABLE IF EXISTS '{table}'"
    DROP_INDEX = "DROP INDEX IF EXISTS '{index}'"
    COUNT_ALL = "SELECT COUNT(*) FROM '{table}'"
    GET_ALL_ITEM = "SELECT * FROM '{table}' ORDER BY item_timestamp ASC"
    GET_FIRST_ITEM = (
        "SELECT name, item_timestamp FROM '{table}' ORDER BY item_timestamp ASC"
    )
    # Millisecond gap between now (per the SQL clock) and the oldest item in the bucket
    GET_LAG = """
    SELECT (strftime('%s', 'now') || substr(strftime('%f', 'now'), 4)) - (
        SELECT item_timestamp
        FROM '{table}'
        ORDER BY item_timestamp ASC
        LIMIT 1
    )
    """
    # OFFSET skips the newest rows, so count=0 targets the most recent item
    PEEK = 'SELECT * FROM "{table}" ORDER BY item_timestamp DESC LIMIT 1 OFFSET {count}'
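
    # For illustration: PUT_ITEM is completed in two steps by SQLiteBucket.put,
    # which first formats the table name and then interpolates the VALUES tuples.
    # With table="rate_bucket" and an item of weight 2 (hypothetical values),
    # the final statement looks like:
    #
    #   INSERT INTO 'rate_bucket' (name, item_timestamp) VALUES
    #   ('api_call', 1700000000000), ('api_call', 1700000000000)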


class SQLiteBucket(AbstractBucket):
    """For the SQLite bucket, the SQL time function serves as the clock,
    so an item's own timestamp won't matter here.
    """

    rates: List[Rate]
    failing_rate: Optional[Rate]
    conn: sqlite3.Connection
    table: str
    full_count_query: str
    lock: RLock
    use_limiter_lock: bool

    def __init__(
        self, rates: List[Rate], conn: sqlite3.Connection, table: str, lock=None
    ):
        self.conn = conn
        self.table = table
        self.rates = rates

        if not lock:
            # No external lock given: guard this bucket with its own RLock
            self.use_limiter_lock = False
            self.lock = RLock()
        else:
            # Reuse the caller-provided lock (e.g. a FileLock from init_from_file)
            self.use_limiter_lock = True
            self.lock = lock

    def limiter_lock(self):
        if self.use_limiter_lock:
            return self.lock
        else:
            return None

    def _build_full_count_query(self, current_timestamp: int) -> Tuple[str, dict]:
        full_query: List[str] = []

        parameters = {"current_timestamp": current_timestamp}

        for index, rate in enumerate(self.rates):
            parameters[f"interval{index}"] = rate.interval
            query = Queries.COUNT_BEFORE_INSERT.format(table=self.table, index=index)
            full_query.append(query)

        join_full_query = (
            " union ".join(full_query) if len(full_query) > 1 else full_query[0]
        )
        return join_full_query, parameters
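
    # For illustration, with two rates the query built above expands to:
    #
    #   SELECT :interval0 as interval, COUNT(*) FROM '<table>'
    #   WHERE item_timestamp >= :current_timestamp - :interval0
    #   union
    #   SELECT :interval1 as interval, COUNT(*) FROM '<table>'
    #   WHERE item_timestamp >= :current_timestamp - :interval1
    #
    # while `parameters` carries current_timestamp plus one interval<N> per rate.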

    def put(self, item: RateItem) -> bool:
        with self.lock:
            query, parameters = self._build_full_count_query(item.timestamp)
            cur = self.conn.execute(query, parameters)
            rate_limit_counts = cur.fetchall()
            cur.close()

            # Check the current count within each rate's window and reject
            # the item as soon as one rate has no room left for its weight
            for idx, result in enumerate(rate_limit_counts):
                interval, count = result
                rate = self.rates[idx]
                assert interval == rate.interval
                space_available = rate.limit - count

                if space_available < item.weight:
                    self.failing_rate = rate
                    return False

            # Insert one row per unit of weight
            items = ", ".join(
                [f"('{name}', {item.timestamp})" for name in [item.name] * item.weight]
            )
            query = (Queries.PUT_ITEM.format(table=self.table)) % items
            self.conn.execute(query).close()
            self.conn.commit()
            return True
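
    # A minimal usage sketch for put() (hypothetical names; assumes
    # RateItem(name, timestamp, weight=1) as defined in ..abstracts, with
    # millisecond timestamps):
    #
    #   item = RateItem("api_call", int(time() * 1000))
    #   if not bucket.put(item):
    #       print("Blocked by rate:", bucket.failing_rate)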

    def leak(self, current_timestamp: Optional[int] = None) -> int:
        """Leak/clean up the bucket: drop items older than the largest rate interval"""
        with self.lock:
            assert current_timestamp is not None
            # Count how many items fall outside the widest rate window...
            query = Queries.COUNT_BEFORE_LEAK.format(
                table=self.table,
                interval=self.rates[-1].interval,
                current_timestamp=current_timestamp,
            )
            cur = self.conn.execute(query)
            count = cur.fetchone()[0]
            # ...then delete exactly that many of the oldest rows
            query = Queries.LEAK.format(table=self.table, count=count)
            cur.execute(query)
            cur.close()
            self.conn.commit()
            return count
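
    # For example, with rates=[Rate(5, 1000)] (a hypothetical single rate),
    # leak(now_ms) deletes every row older than now_ms - 1000 and returns
    # the number of rows removed:
    #
    #   removed = bucket.leak(int(time() * 1000))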

    def flush(self) -> None:
        with self.lock:
            self.conn.execute(Queries.FLUSH.format(table=self.table)).close()
            self.conn.commit()
            self.failing_rate = None

    def count(self) -> int:
        with self.lock:
            cur = self.conn.execute(Queries.COUNT_ALL.format(table=self.table))
            ret = cur.fetchone()[0]
            cur.close()
            return ret

    def peek(self, index: int) -> Optional[RateItem]:
        """Return the (index + 1)-th most recent item, newest first at index 0"""
        with self.lock:
            query = Queries.PEEK.format(table=self.table, count=index)
            cur = self.conn.execute(query)
            item = cur.fetchone()
            cur.close()

            if not item:
                return None

            return RateItem(item[0], item[1])
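
    # For example, after putting "a", "b", "c" in that order, peek(0)
    # returns "c" (the newest) and peek(2) returns "a" (the oldest).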

    @classmethod
    def init_from_file(
        cls,
        rates: List[Rate],
        table: str = "rate_bucket",
        db_path: Optional[str] = None,
        create_new_table: bool = True,
        use_file_lock: bool = False,
    ) -> "SQLiteBucket":
        if db_path is None and use_file_lock:
            raise ValueError("db_path must be specified when using use_file_lock")

        if db_path is None:
            temp_dir = Path(gettempdir())
            db_path = str(temp_dir / f"pyrate_limiter_{time()}.sqlite")

        # TBD: FileLock switched to a thread-local FileLock in 3.11.0.
        # Should we set FileLock's thread_local to False, for cases where the user
        # is both multiprocessing & threading?
        # As is, the file lock is Multi Process - Single Thread, and the
        # non-filelock path is Single Process - Multi Thread.
        # A hybrid lock may be needed to gracefully handle both cases.
        file_lock = None
        file_lock_ctx = nullcontext()

        if use_file_lock:
            try:
                from filelock import FileLock  # type: ignore[import-untyped]

                file_lock = FileLock(db_path + ".lock")  # type: ignore[no-redef]
                file_lock_ctx: Union[nullcontext, FileLock] = file_lock  # type: ignore[no-redef]
            except ImportError:
                raise ImportError(
                    "filelock is required for file locking. "
                    "Please install it as an optional dependency"
                )

        with file_lock_ctx:
            assert db_path is not None
            assert db_path.endswith(".sqlite"), "Please provide a valid sqlite file path"

            sqlite_connection = sqlite3.connect(
                db_path,
                isolation_level="DEFERRED",
                check_same_thread=False,
            )

            cur = sqlite_connection.cursor()
            if use_file_lock:
                # https://www.sqlite.org/wal.html
                cur.execute("PRAGMA journal_mode=WAL;")

                # https://www.sqlite.org/pragma.html#pragma_synchronous
                cur.execute("PRAGMA synchronous=NORMAL;")

            if create_new_table:
                cur.execute(Queries.CREATE_BUCKET_TABLE.format(table=table))

            create_idx_query = Queries.CREATE_INDEX_ON_TIMESTAMP.format(
                index_name=f"idx_{table}_rate_item_timestamp",
                table_name=table,
            )

            cur.execute(create_idx_query)
            cur.close()
            sqlite_connection.commit()

            return cls(rates, sqlite_connection, table=table, lock=file_lock)
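

# A minimal construction sketch (hypothetical values; Rate(limit, interval)
# comes from ..abstracts, with the interval in milliseconds):
#
#   rates = [Rate(5, 1000)]  # at most 5 items per 1000ms
#   bucket = SQLiteBucket.init_from_file(rates, db_path="/tmp/bucket.sqlite")
#
#   # Multi-process usage requires the optional `filelock` dependency:
#   shared = SQLiteBucket.init_from_file(
#       rates, db_path="/tmp/bucket.sqlite", use_file_lock=True
#   )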