167 lines
5.3 KiB
Python
167 lines
5.3 KiB
Python
"""A bucket using PostgreSQL as backend
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
from contextlib import contextmanager
|
|
from typing import Awaitable
|
|
from typing import List
|
|
from typing import Optional
|
|
from typing import TYPE_CHECKING
|
|
from typing import Union
|
|
|
|
from ..abstracts import AbstractBucket
|
|
from ..abstracts import Rate
|
|
from ..abstracts import RateItem
|
|
|
|
if TYPE_CHECKING:
|
|
from psycopg_pool import ConnectionPool # type: ignore[import-untyped]
|
|
|
|
|
|
class Queries:
|
|
CREATE_BUCKET_TABLE = """
|
|
CREATE TABLE IF NOT EXISTS {table} (
|
|
name VARCHAR,
|
|
weight SMALLINT,
|
|
item_timestamp TIMESTAMP
|
|
)
|
|
"""
|
|
CREATE_INDEX_ON_TIMESTAMP = """
|
|
CREATE INDEX IF NOT EXISTS {index} ON {table} (item_timestamp)
|
|
"""
|
|
COUNT = """
|
|
SELECT COUNT(*) FROM {table}
|
|
"""
|
|
PUT = """
|
|
INSERT INTO {table} (name, weight, item_timestamp) VALUES (%s, %s, TO_TIMESTAMP(%s))
|
|
"""
|
|
FLUSH = """
|
|
DELETE FROM {table}
|
|
"""
|
|
PEEK = """
|
|
SELECT name, weight, (extract(EPOCH FROM item_timestamp) * 1000) as item_timestamp
|
|
FROM {table}
|
|
ORDER BY item_timestamp DESC
|
|
LIMIT 1
|
|
OFFSET {offset}
|
|
"""
|
|
LEAK = """
|
|
DELETE FROM {table} WHERE item_timestamp < TO_TIMESTAMP({timestamp})
|
|
"""
|
|
LEAK_COUNT = """
|
|
SELECT COUNT(*) FROM {table} WHERE item_timestamp < TO_TIMESTAMP({timestamp})
|
|
"""
|
|
|
|
|
|
class PostgresBucket(AbstractBucket):
|
|
table: str
|
|
pool: ConnectionPool
|
|
|
|
def __init__(self, pool: ConnectionPool, table: str, rates: List[Rate]):
|
|
self.table = table.lower()
|
|
self.pool = pool
|
|
assert rates
|
|
self.rates = rates
|
|
self._full_tbl = f'ratelimit___{self.table}'
|
|
self._create_table()
|
|
|
|
@contextmanager
|
|
def _get_conn(self):
|
|
with self.pool.connection() as conn:
|
|
yield conn
|
|
|
|
def _create_table(self):
|
|
with self._get_conn() as conn:
|
|
conn.execute(Queries.CREATE_BUCKET_TABLE.format(table=self._full_tbl))
|
|
index_name = f'timestampIndex_{self.table}'
|
|
conn.execute(Queries.CREATE_INDEX_ON_TIMESTAMP.format(table=self._full_tbl, index=index_name))
|
|
|
|
def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]:
|
|
"""Put an item (typically the current time) in the bucket
|
|
return true if successful, otherwise false
|
|
"""
|
|
if item.weight == 0:
|
|
return True
|
|
|
|
with self._get_conn() as conn:
|
|
for rate in self.rates:
|
|
bound = f"SELECT TO_TIMESTAMP({item.timestamp / 1000}) - INTERVAL '{rate.interval} milliseconds'"
|
|
query = f'SELECT COUNT(*) FROM {self._full_tbl} WHERE item_timestamp >= ({bound})'
|
|
cur = conn.execute(query)
|
|
count = int(cur.fetchone()[0])
|
|
cur.close()
|
|
|
|
if rate.limit - count < item.weight:
|
|
self.failing_rate = rate
|
|
return False
|
|
|
|
self.failing_rate = None
|
|
|
|
query = Queries.PUT.format(table=self._full_tbl)
|
|
|
|
# https://www.psycopg.org/docs/extras.html#fast-exec
|
|
|
|
for _ in range(item.weight):
|
|
conn.execute(query, (item.name, item.weight, item.timestamp / 1000))
|
|
|
|
return True
|
|
|
|
def leak(
|
|
self,
|
|
current_timestamp: Optional[int] = None,
|
|
) -> Union[int, Awaitable[int]]:
|
|
"""leaking bucket - removing items that are outdated"""
|
|
assert current_timestamp is not None, "current-time must be passed on for leak"
|
|
lower_bound = current_timestamp - self.rates[-1].interval
|
|
|
|
if lower_bound <= 0:
|
|
return 0
|
|
|
|
count = 0
|
|
|
|
with self._get_conn() as conn:
|
|
conn = conn.execute(Queries.LEAK_COUNT.format(table=self._full_tbl, timestamp=lower_bound / 1000))
|
|
result = conn.fetchone()
|
|
|
|
if result:
|
|
conn.execute(Queries.LEAK.format(table=self._full_tbl, timestamp=lower_bound / 1000))
|
|
count = int(result[0])
|
|
|
|
return count
|
|
|
|
def flush(self) -> Union[None, Awaitable[None]]:
|
|
"""Flush the whole bucket
|
|
- Must remove `failing-rate` after flushing
|
|
"""
|
|
with self._get_conn() as conn:
|
|
conn.execute(Queries.FLUSH.format(table=self._full_tbl))
|
|
self.failing_rate = None
|
|
|
|
return None
|
|
|
|
def count(self) -> Union[int, Awaitable[int]]:
|
|
"""Count number of items in the bucket"""
|
|
count = 0
|
|
with self._get_conn() as conn:
|
|
conn = conn.execute(Queries.COUNT.format(table=self._full_tbl))
|
|
result = conn.fetchone()
|
|
assert result
|
|
count = int(result[0])
|
|
|
|
return count
|
|
|
|
def peek(self, index: int) -> Union[Optional[RateItem], Awaitable[Optional[RateItem]]]:
|
|
"""Peek at the rate-item at a specific index in latest-to-earliest order
|
|
NOTE: The reason we cannot peek from the start of the queue(earliest-to-latest) is
|
|
we can't really tell how many outdated items are still in the queue
|
|
"""
|
|
item = None
|
|
|
|
with self._get_conn() as conn:
|
|
conn = conn.execute(Queries.PEEK.format(table=self._full_tbl, offset=index))
|
|
result = conn.fetchone()
|
|
if result:
|
|
name, weight, timestamp = result[0], int(result[1]), int(result[2])
|
|
item = RateItem(name=name, weight=weight, timestamp=timestamp)
|
|
|
|
return item
|