Files
edgartools/venv/lib/python3.10/site-packages/hishel/_s3.py
2025-12-09 12:13:01 +01:00

123 lines
4.2 KiB
Python

import time
import typing as tp
from anyio import to_thread
from botocore.exceptions import ClientError
def get_timestamp_in_ms() -> float:
    """Return the current wall-clock time in milliseconds since the epoch."""
    seconds_since_epoch = time.time()
    return seconds_since_epoch * 1000
class S3Manager:
    """Store, read and expire cache entries kept as objects in an S3 bucket.

    Every key handled by this class is prefixed with ``path_prefix``. Each
    object written here carries a ``created_at`` user-metadata value (a
    millisecond timestamp rendered as a string); expiry checks compare the
    current time against that value.
    """

    # Error codes that mean "the object does not exist". ``get_object``
    # reports ``NoSuchKey`` while ``head_object`` reports a bare ``404``.
    _MISSING_CODES = ("404", "NoSuchKey")

    def __init__(
        self,
        client: tp.Any,
        bucket_name: str,
        check_ttl_every: tp.Union[int, float],
        is_binary: bool = False,
        path_prefix: str = "hishel-",
    ):
        """Remember the client and the settings used by every operation.

        Args:
            client: S3 client object exposing ``put_object``, ``get_object``,
                ``head_object``, ``delete_object`` and ``list_objects``.
            bucket_name: Bucket that holds the cache objects.
            check_ttl_every: Minimum interval, in seconds of
                ``time.monotonic()``, between two full expiry sweeps.
            is_binary: When true, ``read_from`` returns raw bytes instead of
                UTF-8 decoded text.
            path_prefix: String prepended to every key.
        """
        self._client = client
        self._bucket_name = bucket_name
        self._is_binary = is_binary
        self._last_cleaned = time.monotonic()
        self._check_ttl_every = check_ttl_every
        self._path_prefix = path_prefix

    def write_to(self, path: str, data: tp.Union[bytes, str], only_metadata: bool = False) -> None:
        """Upload ``data`` under the prefixed ``path``.

        Args:
            path: Key of the entry, without the prefix.
            data: Payload; text is encoded as UTF-8 before upload.
            only_metadata: When true, the ``created_at`` value of an already
                stored object is reused, so rewriting the entry does not
                reset its age.
        """
        path = self._path_prefix + path
        if isinstance(data, str):
            data = data.encode("utf-8")

        created_at = None
        if only_metadata:
            try:
                # A HEAD request returns the metadata without opening a body
                # stream that would have to be drained or closed.
                response = self._client.head_object(Bucket=self._bucket_name, Key=path)
                created_at = response["Metadata"]["created_at"]
            except Exception:
                # Deliberately best-effort: if the previous timestamp cannot
                # be read (missing object, missing metadata, ...) the entry
                # simply gets a fresh one below.
                pass

        self._client.put_object(
            Bucket=self._bucket_name,
            Key=path,
            Body=data,
            Metadata={"created_at": created_at or str(get_timestamp_in_ms())},
        )

    def read_from(self, path: str) -> tp.Union[bytes, str]:
        """Download the object stored under the prefixed ``path``.

        Returns:
            The raw bytes when the manager is binary, otherwise the content
            decoded as UTF-8.
        """
        path = self._path_prefix + path
        response = self._client.get_object(
            Bucket=self._bucket_name,
            Key=path,
        )
        content = response["Body"].read()
        if self._is_binary:  # pragma: no cover
            return tp.cast(bytes, content)
        return tp.cast(str, content.decode("utf-8"))

    def remove_expired(self, ttl: int, key: str) -> None:
        """Delete expired entries.

        While the last full sweep is more recent than ``check_ttl_every``
        only the entry named by ``key`` is checked. Otherwise every object
        under the prefix is checked.

        Args:
            ttl: Maximum age in milliseconds (it is compared against the
                difference of two millisecond timestamps).
            key: Key of the entry, without the prefix.

        Raises:
            ClientError: Any client error other than "object not found".
        """
        path = self._path_prefix + key

        if time.monotonic() - self._last_cleaned < self._check_ttl_every:
            try:
                metadata = self._client.head_object(Bucket=self._bucket_name, Key=path).get("Metadata", {})
                if self._is_expired(metadata, ttl):
                    self._client.delete_object(Bucket=self._bucket_name, Key=path)
                return
            except ClientError as e:
                if e.response["Error"]["Code"] in self._MISSING_CODES:
                    return
                raise

        self._last_cleaned = time.monotonic()
        for object_key in self._iter_prefixed_keys():
            try:
                metadata = self._client.head_object(Bucket=self._bucket_name, Key=object_key).get("Metadata", {})
            except ClientError as e:
                if e.response["Error"]["Code"] in self._MISSING_CODES:
                    # The object vanished between listing and inspection.
                    continue
                # Anything else must surface; continuing would act on
                # metadata that belongs to a different object.
                raise

            if self._is_expired(metadata, ttl):
                self._client.delete_object(Bucket=self._bucket_name, Key=object_key)

    def remove_entry(self, key: str) -> None:
        """Delete the object stored under the prefixed ``key``."""
        path = self._path_prefix + key
        self._client.delete_object(Bucket=self._bucket_name, Key=path)

    @staticmethod
    def _is_expired(metadata: tp.Optional[tp.Mapping[str, str]], ttl: tp.Union[int, float]) -> bool:
        """Tell whether ``metadata`` holds a ``created_at`` older than ``ttl`` ms."""
        if not metadata or "created_at" not in metadata:
            # Objects without a timestamp are never treated as expired.
            return False
        return get_timestamp_in_ms() - float(metadata["created_at"]) > ttl

    def _iter_prefixed_keys(self) -> tp.Iterator[str]:
        """Yield every key in the bucket that starts with the path prefix."""
        request: tp.Dict[str, tp.Any] = {"Bucket": self._bucket_name}
        if self._path_prefix:
            request["Prefix"] = self._path_prefix

        while True:
            page = self._client.list_objects(**request)
            contents = page.get("Contents", [])
            for obj in contents:
                # Kept as a safeguard even though the listing is filtered.
                if obj["Key"].startswith(self._path_prefix):
                    yield obj["Key"]

            if not page.get("IsTruncated") or not contents:
                return
            # ``NextMarker`` is only present when a delimiter is used; the
            # last key of the page is the documented fallback.
            request["Marker"] = page.get("NextMarker") or contents[-1]["Key"]
class AsyncS3Manager:  # pragma: no cover
    """Async facade over ``S3Manager``.

    Each coroutine hands the matching synchronous method to
    ``to_thread.run_sync`` and awaits the outcome.
    """

    def __init__(
        self,
        client: tp.Any,
        bucket_name: str,
        check_ttl_every: tp.Union[int, float],
        is_binary: bool = False,
        path_prefix: str = "hishel-",
    ):
        """Build the wrapped synchronous manager from the same arguments."""
        self._sync_manager = S3Manager(
            client=client,
            bucket_name=bucket_name,
            check_ttl_every=check_ttl_every,
            is_binary=is_binary,
            path_prefix=path_prefix,
        )

    async def write_to(self, path: str, data: tp.Union[bytes, str], only_metadata: bool = False) -> None:
        """Await ``S3Manager.write_to`` dispatched through ``to_thread``."""
        sync_write = self._sync_manager.write_to
        outcome = await to_thread.run_sync(sync_write, path, data, only_metadata)
        return outcome

    async def read_from(self, path: str) -> tp.Union[bytes, str]:
        """Await ``S3Manager.read_from`` and return what it produced."""
        sync_read = self._sync_manager.read_from
        content = await to_thread.run_sync(sync_read, path)
        return content

    async def remove_expired(self, ttl: int, key: str) -> None:
        """Await ``S3Manager.remove_expired`` dispatched through ``to_thread``."""
        sync_expire = self._sync_manager.remove_expired
        outcome = await to_thread.run_sync(sync_expire, ttl, key)
        return outcome

    async def remove_entry(self, key: str) -> None:
        """Await ``S3Manager.remove_entry`` dispatched through ``to_thread``."""
        sync_remove = self._sync_manager.remove_entry
        outcome = await to_thread.run_sync(sync_remove, key)
        return outcome