Files
2025-12-09 12:13:01 +01:00

167 lines
5.3 KiB
Python

"""A bucket using PostgreSQL as backend
"""
from __future__ import annotations
from contextlib import contextmanager
from typing import Awaitable
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from ..abstracts import AbstractBucket
from ..abstracts import Rate
from ..abstracts import RateItem
if TYPE_CHECKING:
from psycopg_pool import ConnectionPool # type: ignore[import-untyped]
class Queries:
    """SQL templates used by the PostgreSQL-backed bucket.

    ``{table}``, ``{index}``, ``{offset}`` and ``{timestamp}`` are
    ``str.format`` fields that the caller fills in before executing the
    statement; ``%s`` markers are left in place for the database driver to
    bind. Values handed to ``TO_TIMESTAMP`` are epoch *seconds*, while
    ``PEEK`` reports ``item_timestamp`` in epoch *milliseconds*.
    """

    # Storage table: one row per stored entry.
    CREATE_BUCKET_TABLE = """
    CREATE TABLE IF NOT EXISTS {table} (
        name VARCHAR,
        weight SMALLINT,
        item_timestamp TIMESTAMP
    )
    """

    # Index supporting the time-window filters on item_timestamp.
    CREATE_INDEX_ON_TIMESTAMP = """
    CREATE INDEX IF NOT EXISTS {index} ON {table} (item_timestamp)
    """

    # Total number of rows currently stored.
    COUNT = """
    SELECT COUNT(*) FROM {table}
    """

    # Insert one row; driver-bound parameters are (name, weight, epoch seconds).
    PUT = """
    INSERT INTO {table} (name, weight, item_timestamp) VALUES (%s, %s, TO_TIMESTAMP(%s))
    """

    # Remove every row.
    FLUSH = """
    DELETE FROM {table}
    """

    # Fetch the row at position {offset}, newest first.
    # item_timestamp is a zone-less TIMESTAMP filled from TO_TIMESTAMP(), which
    # yields a zone-aware value converted to the session's local time on
    # assignment. EXTRACT(EPOCH ...) on a zone-less value returns the *nominal*
    # epoch, so the column is cast back to TIMESTAMPTZ first; otherwise the
    # result is shifted by the zone offset whenever the session TimeZone is
    # not UTC.
    PEEK = """
    SELECT name, weight,
           (extract(EPOCH FROM CAST(item_timestamp AS TIMESTAMPTZ)) * 1000) as item_timestamp
    FROM {table}
    ORDER BY item_timestamp DESC
    LIMIT 1
    OFFSET {offset}
    """

    # Delete rows strictly older than {timestamp} (epoch seconds).
    LEAK = """
    DELETE FROM {table} WHERE item_timestamp < TO_TIMESTAMP({timestamp})
    """

    # Count rows strictly older than {timestamp} (epoch seconds).
    LEAK_COUNT = """
    SELECT COUNT(*) FROM {table} WHERE item_timestamp < TO_TIMESTAMP({timestamp})
    """
class PostgresBucket(AbstractBucket):
    """Rate-limit bucket that keeps its items in a PostgreSQL table.

    Items live in the table ``ratelimit___<table>`` (``<table>`` lower-cased),
    which is created together with a timestamp index when the bucket is
    constructed. Connections are borrowed from ``pool`` for each operation.
    """

    table: str
    pool: ConnectionPool

    def __init__(self, pool: ConnectionPool, table: str, rates: List[Rate]):
        """Bind the bucket to ``pool`` and make sure its table exists.

        Args:
            pool: connection pool used for every query.
            table: bucket name; lower-cased and prefixed with ``ratelimit___``.
            rates: non-empty list of rates checked on every ``put``.
        """
        self.table = table.lower()
        self.pool = pool
        assert rates
        self.rates = rates
        self._full_tbl = f'ratelimit___{self.table}'
        self._create_table()

    @contextmanager
    def _get_conn(self):
        """Yield a connection borrowed from the pool for the ``with`` body."""
        with self.pool.connection() as conn:
            yield conn

    def _create_table(self):
        """Create the bucket table and its timestamp index if missing."""
        with self._get_conn() as conn:
            conn.execute(Queries.CREATE_BUCKET_TABLE.format(table=self._full_tbl))
            index_name = f'timestampIndex_{self.table}'
            conn.execute(Queries.CREATE_INDEX_ON_TIMESTAMP.format(table=self._full_tbl, index=index_name))

    def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]:
        """Put an item (typically the current time) in the bucket.

        For every rate, the rows whose timestamp falls within the rate's
        interval ending at ``item.timestamp`` are counted; if the remaining
        capacity is smaller than ``item.weight`` the rate is recorded in
        ``failing_rate`` and nothing is inserted.

        Returns:
            True if the item was stored (or has zero weight), otherwise False.
        """
        if item.weight == 0:
            return True

        with self._get_conn() as conn:
            for rate in self.rates:
                bound = f"SELECT TO_TIMESTAMP({item.timestamp / 1000}) - INTERVAL '{rate.interval} milliseconds'"
                query = f'SELECT COUNT(*) FROM {self._full_tbl} WHERE item_timestamp >= ({bound})'
                # The ``with`` closes the cursor even if fetching fails.
                with conn.execute(query) as cur:
                    count = int(cur.fetchone()[0])

                if rate.limit - count < item.weight:
                    self.failing_rate = rate
                    return False

            self.failing_rate = None

            query = Queries.PUT.format(table=self._full_tbl)
            # One row is stored per unit of weight, which is what the
            # COUNT(*) checks above rely on. Timestamps are milliseconds;
            # TO_TIMESTAMP expects seconds.
            row = (item.name, item.weight, item.timestamp / 1000)
            # Send all rows in one batch instead of one round trip per row.
            with conn.cursor() as cur:
                cur.executemany(query, [row] * item.weight)

        return True

    def leak(
        self,
        current_timestamp: Optional[int] = None,
    ) -> Union[int, Awaitable[int]]:
        """Leaking bucket - remove items that are outdated.

        Rows older than ``current_timestamp`` minus the interval of the last
        configured rate are deleted.

        Args:
            current_timestamp: current time in milliseconds; required.

        Returns:
            Number of rows removed.
        """
        assert current_timestamp is not None, "current-time must be passed on for leak"
        lower_bound = current_timestamp - self.rates[-1].interval

        if lower_bound <= 0:
            return 0

        query = Queries.LEAK.format(table=self._full_tbl, timestamp=lower_bound / 1000)

        with self._get_conn() as conn:
            with conn.execute(query) as cur:
                # Report what the DELETE itself removed: a separate COUNT
                # costs an extra round trip and can disagree with the DELETE
                # when other clients modify the table in between.
                removed = max(cur.rowcount, 0)

        return removed

    def flush(self) -> Union[None, Awaitable[None]]:
        """Flush the whole bucket.

        - Must remove `failing-rate` after flushing
        """
        with self._get_conn() as conn:
            conn.execute(Queries.FLUSH.format(table=self._full_tbl))
            self.failing_rate = None

        return None

    def count(self) -> Union[int, Awaitable[int]]:
        """Count number of items in the bucket."""
        with self._get_conn() as conn:
            with conn.execute(Queries.COUNT.format(table=self._full_tbl)) as cur:
                result = cur.fetchone()
            assert result
            total = int(result[0])

        return total

    def peek(self, index: int) -> Union[Optional[RateItem], Awaitable[Optional[RateItem]]]:
        """Peek at the rate-item at a specific index in latest-to-earliest order.

        NOTE: The reason we cannot peek from the start of the queue
        (earliest-to-latest) is we can't really tell how many outdated items
        are still in the queue.

        Returns:
            The item at ``index`` (0 is the newest), or None if there is no
            row at that position.
        """
        item = None

        with self._get_conn() as conn:
            with conn.execute(Queries.PEEK.format(table=self._full_tbl, offset=index)) as cur:
                result = cur.fetchone()

            if result:
                name, weight = result[0], int(result[1])
                # Round instead of truncating: when the server returns the
                # epoch as a float, seconds * 1000 can land a hair below the
                # integer millisecond and int() alone would drop one.
                timestamp = int(round(result[2]))
                item = RateItem(name=name, weight=weight, timestamp=timestamp)

        return item