123 lines
4.2 KiB
Python
123 lines
4.2 KiB
Python
import time
|
|
import typing as tp
|
|
|
|
from anyio import to_thread
|
|
from botocore.exceptions import ClientError
|
|
|
|
|
|
def get_timestamp_in_ms() -> float:
    """Return the current wall-clock time (``time.time()``) expressed in milliseconds."""
    seconds_now = time.time()
    return seconds_now * 1000
|
|
|
|
|
|
class S3Manager:
    """Synchronous helper that writes, reads and expires entries in one S3 bucket.

    Every key handled here is stored under ``path_prefix + key``. Each object
    written by ``write_to`` carries a ``created_at`` metadata value (the result
    of ``get_timestamp_in_ms()`` rendered as a string); expiry decisions in
    ``remove_expired`` are based on that value.
    """

    def __init__(
        self,
        client: tp.Any,
        bucket_name: str,
        check_ttl_every: tp.Union[int, float],
        is_binary: bool = False,
        path_prefix: str = "hishel-",
    ):
        """Remember the client and settings; no request is issued here.

        Args:
            client: Object exposing ``get_object``, ``put_object``,
                ``head_object``, ``list_objects`` and ``delete_object``
                (presumably a boto3 S3 client -- confirm against callers).
            bucket_name: Bucket passed as ``Bucket=`` to every client call.
            check_ttl_every: Minimum number of ``time.monotonic()`` seconds
                between two full-bucket sweeps in ``remove_expired``.
            is_binary: When true, ``read_from`` returns the raw bytes instead
                of UTF-8 decoded text.
            path_prefix: String prepended to every key before it is sent to
                the client.
        """
        self._client = client
        self._bucket_name = bucket_name
        self._is_binary = is_binary
        # The sweep timer starts at construction, so the first full sweep
        # happens only after ``check_ttl_every`` seconds have elapsed.
        self._last_cleaned = time.monotonic()
        self._check_ttl_every = check_ttl_every
        self._path_prefix = path_prefix

    def write_to(self, path: str, data: tp.Union[bytes, str], only_metadata: bool = False) -> None:
        """Upload ``data`` under the prefixed ``path``.

        Args:
            path: Key without the prefix.
            data: Payload; ``str`` values are encoded as UTF-8 first.
            only_metadata: When true, try to reuse the ``created_at`` value of
                the object already stored under this key so that rewriting it
                does not reset its age. If that lookup fails for any reason a
                fresh timestamp is used instead.
        """
        path = self._path_prefix + path
        if isinstance(data, str):
            data = data.encode("utf-8")

        created_at = None
        if only_metadata:
            try:
                response = self._client.get_object(
                    Bucket=self._bucket_name,
                    Key=path,
                )
                created_at = response["Metadata"]["created_at"]
            except Exception:
                # Deliberate best-effort: a missing object or missing metadata
                # simply means the entry is stamped with the current time below.
                pass

        self._client.put_object(
            Bucket=self._bucket_name,
            Key=path,
            Body=data,
            Metadata={"created_at": created_at or str(get_timestamp_in_ms())},
        )

    def read_from(self, path: str) -> tp.Union[bytes, str]:
        """Fetch the object stored under the prefixed ``path``.

        Returns:
            The raw bytes when the manager was created with ``is_binary=True``,
            otherwise the body decoded as UTF-8.

        Any exception raised by the client's ``get_object`` propagates unchanged.
        """
        path = self._path_prefix + path
        response = self._client.get_object(
            Bucket=self._bucket_name,
            Key=path,
        )

        content = response["Body"].read()

        if self._is_binary:  # pragma: no cover
            return tp.cast(bytes, content)

        return tp.cast(str, content.decode("utf-8"))

    def remove_expired(self, ttl: int, key: str) -> None:
        """Delete expired entries.

        While fewer than ``check_ttl_every`` seconds have passed since the last
        full sweep, only the entry for ``key`` is checked. Otherwise every
        object returned by ``list_objects`` whose key starts with the prefix is
        checked.

        Args:
            ttl: Maximum allowed age. It is compared against a difference of
                ``get_timestamp_in_ms()`` values, i.e. it is in milliseconds.
            key: Key (without the prefix) checked in the single-entry case.

        Raises:
            ClientError: For any client error other than the "object is gone"
                codes handled below (``NoSuchKey`` / ``404``).
        """
        path = self._path_prefix + key

        if time.monotonic() - self._last_cleaned < self._check_ttl_every:
            try:
                response = self._client.get_object(Bucket=self._bucket_name, Key=path)
                if get_timestamp_in_ms() - float(response["Metadata"]["created_at"]) > ttl:
                    self._client.delete_object(Bucket=self._bucket_name, Key=path)
                return
            except ClientError as e:
                # Nothing stored under this key: nothing to expire.
                if e.response["Error"]["Code"] == "NoSuchKey":
                    return
                raise

        self._last_cleaned = time.monotonic()
        # NOTE(review): a single ``list_objects`` call may not return every key
        # of a large bucket -- confirm whether pagination is needed here.
        listing = self._client.list_objects(Bucket=self._bucket_name)
        for obj in listing.get("Contents", []):
            object_key = obj["Key"]
            if not object_key.startswith(self._path_prefix):  # pragma: no cover
                continue

            try:
                metadata = self._client.head_object(Bucket=self._bucket_name, Key=object_key).get("Metadata", {})
            except ClientError as e:
                # The object vanished between listing and inspection: skip it.
                if e.response["Error"]["Code"] == "404":
                    continue
                # Any other failure must not be swallowed: falling through
                # would leave ``metadata`` unbound, or reuse the previous
                # object's metadata and decide expiry for the wrong key.
                raise

            if not metadata or "created_at" not in metadata:
                continue

            if get_timestamp_in_ms() - float(metadata["created_at"]) > ttl:
                self._client.delete_object(Bucket=self._bucket_name, Key=object_key)

    def remove_entry(self, key: str) -> None:
        """Delete the object stored under the prefixed ``key``."""
        path = self._path_prefix + key
        self._client.delete_object(Bucket=self._bucket_name, Key=path)
|
|
|
|
|
|
class AsyncS3Manager:  # pragma: no cover
    """Async facade over ``S3Manager``.

    Each coroutine hands the matching ``S3Manager`` method to
    ``to_thread.run_sync`` and returns whatever that call produces.
    """

    def __init__(
        self,
        client: tp.Any,
        bucket_name: str,
        check_ttl_every: tp.Union[int, float],
        is_binary: bool = False,
        path_prefix: str = "hishel-",
    ):
        """Build the wrapped synchronous manager from the same arguments."""
        self._sync_manager = S3Manager(
            client=client,
            bucket_name=bucket_name,
            check_ttl_every=check_ttl_every,
            is_binary=is_binary,
            path_prefix=path_prefix,
        )

    async def write_to(self, path: str, data: tp.Union[bytes, str], only_metadata: bool = False) -> None:
        """Run ``S3Manager.write_to`` through ``to_thread.run_sync``."""
        sync_write = self._sync_manager.write_to
        return await to_thread.run_sync(sync_write, path, data, only_metadata)

    async def read_from(self, path: str) -> tp.Union[bytes, str]:
        """Run ``S3Manager.read_from`` through ``to_thread.run_sync`` and return its result."""
        sync_read = self._sync_manager.read_from
        return await to_thread.run_sync(sync_read, path)

    async def remove_expired(self, ttl: int, key: str) -> None:
        """Run ``S3Manager.remove_expired`` through ``to_thread.run_sync``."""
        sync_expire = self._sync_manager.remove_expired
        return await to_thread.run_sync(sync_expire, ttl, key)

    async def remove_entry(self, key: str) -> None:
        """Run ``S3Manager.remove_entry`` through ``to_thread.run_sync``."""
        sync_remove = self._sync_manager.remove_entry
        return await to_thread.run_sync(sync_remove, key)
|