Initial commit
This commit is contained in:
122
venv/lib/python3.10/site-packages/hishel/_s3.py
Normal file
122
venv/lib/python3.10/site-packages/hishel/_s3.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import time
|
||||
import typing as tp
|
||||
|
||||
from anyio import to_thread
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
|
||||
def get_timestamp_in_ms() -> float:
    """Return the current wall-clock time as milliseconds since the epoch."""
    seconds_since_epoch = time.time()
    return seconds_since_epoch * 1000
|
||||
|
||||
|
||||
class S3Manager:
    """Store, read and expire cache entries kept as objects in an S3 bucket.

    Every key passed to the public methods is prefixed with ``path_prefix``
    before it is sent to the client. Each stored object carries a
    ``created_at`` metadata value (milliseconds since the epoch, as a string)
    which is what ``remove_expired`` compares against the TTL.
    """

    # Error codes that mean "the object does not exist": ``head_object`` reports
    # the bare HTTP status, ``get_object`` reports the symbolic name.
    _MISSING_KEY_CODES = frozenset({"404", "NoSuchKey"})

    def __init__(
        self,
        client: tp.Any,
        bucket_name: str,
        check_ttl_every: tp.Union[int, float],
        is_binary: bool = False,
        path_prefix: str = "hishel-",
    ):
        """Set up the manager.

        Args:
            client: S3 client object. It must provide ``get_object``,
                ``head_object``, ``put_object``, ``delete_object`` and
                ``list_objects`` (presumably a boto3 S3 client, given the
                botocore ``ClientError`` handling).
            bucket_name: Bucket that holds the entries.
            check_ttl_every: Minimum number of seconds (monotonic clock)
                between two full sweeps of the bucket.
            is_binary: When true, ``read_from`` returns raw bytes instead of
                UTF-8 decoded text.
            path_prefix: String prepended to every key.
        """
        self._client = client
        self._bucket_name = bucket_name
        self._is_binary = is_binary
        self._last_cleaned = time.monotonic()
        self._check_ttl_every = check_ttl_every
        self._path_prefix = path_prefix

    def write_to(self, path: str, data: tp.Union[bytes, str], only_metadata: bool = False) -> None:
        """Store ``data`` under ``path``.

        Args:
            path: Entry key, without the prefix.
            data: Payload; ``str`` values are encoded as UTF-8.
            only_metadata: When true, the ``created_at`` value of an already
                stored object is carried over, so rewriting an entry does not
                reset its age. If it cannot be looked up, the current time is
                used instead.
        """
        path = self._path_prefix + path
        if isinstance(data, str):
            data = data.encode("utf-8")

        created_at = None
        if only_metadata:
            try:
                # head_object returns the metadata without opening a body
                # stream that would then have to be drained or closed.
                response = self._client.head_object(
                    Bucket=self._bucket_name,
                    Key=path,
                )
                created_at = response["Metadata"]["created_at"]
            except (ClientError, KeyError):
                # Best effort: a missing object or missing metadata simply
                # means the entry gets a fresh timestamp below.
                pass

        self._client.put_object(
            Bucket=self._bucket_name,
            Key=path,
            Body=data,
            Metadata={"created_at": created_at or str(get_timestamp_in_ms())},
        )

    def read_from(self, path: str) -> tp.Union[bytes, str]:
        """Return the stored payload for ``path``.

        Returns:
            Raw bytes when the manager is binary, otherwise the payload
            decoded as UTF-8.

        Raises:
            Whatever the client's ``get_object`` raises, e.g. when the key
            does not exist.
        """
        path = self._path_prefix + path
        response = self._client.get_object(
            Bucket=self._bucket_name,
            Key=path,
        )

        content = response["Body"].read()

        if self._is_binary:  # pragma: no cover
            return tp.cast(bytes, content)

        return tp.cast(str, content.decode("utf-8"))

    def remove_expired(self, ttl: int, key: str) -> None:
        """Delete entries older than ``ttl`` milliseconds.

        While the last full sweep is more recent than ``check_ttl_every``
        seconds, only the entry for ``key`` is checked. Otherwise every object
        under the prefix is checked, across all listing pages.

        Args:
            ttl: Maximum entry age in milliseconds.
            key: Entry key, without the prefix.

        Raises:
            ClientError: For any client error other than "object not found".
        """
        path = self._path_prefix + key

        if time.monotonic() - self._last_cleaned < self._check_ttl_every:
            self._remove_if_expired(path, ttl)
            return

        self._last_cleaned = time.monotonic()
        for object_key in self._iter_prefixed_keys():
            self._remove_if_expired(object_key, ttl)

    def remove_entry(self, key: str) -> None:
        """Delete the entry stored under ``key`` (prefix is added here)."""
        path = self._path_prefix + key
        self._client.delete_object(Bucket=self._bucket_name, Key=path)

    def _remove_if_expired(self, object_key: str, ttl: tp.Union[int, float]) -> None:
        """Delete the object at the full key ``object_key`` if it is older than ``ttl`` ms."""
        try:
            response = self._client.head_object(Bucket=self._bucket_name, Key=object_key)
        except ClientError as e:
            if e.response["Error"]["Code"] in self._MISSING_KEY_CODES:
                # Already gone (e.g. removed concurrently): nothing to expire.
                return
            # Any other failure must surface; swallowing it would leave the
            # caller acting on no metadata, or on another object's metadata.
            raise

        created_at = (response.get("Metadata") or {}).get("created_at")
        if created_at is None:
            # Objects without a creation timestamp are left untouched.
            return

        if get_timestamp_in_ms() - float(created_at) > ttl:
            self._client.delete_object(Bucket=self._bucket_name, Key=object_key)

    def _iter_prefixed_keys(self) -> tp.Iterator[str]:
        """Yield every key in the bucket that starts with the prefix, following pagination."""
        request: tp.Dict[str, tp.Any] = {"Bucket": self._bucket_name}
        if self._path_prefix:
            # Let the service filter, so unrelated keys are not transferred.
            request["Prefix"] = self._path_prefix

        marker: tp.Optional[str] = None
        while True:
            if marker is not None:
                request["Marker"] = marker
            page = self._client.list_objects(**request)
            contents = page.get("Contents", [])

            for obj in contents:
                # Kept as a safety net for clients that ignore ``Prefix``.
                if obj["Key"].startswith(self._path_prefix):
                    yield obj["Key"]

            if not page.get("IsTruncated") or not contents:
                return

            # ``NextMarker`` is only sent when a delimiter is used; otherwise
            # the last key of the page is the marker for the next one.
            next_marker = page.get("NextMarker") or contents[-1]["Key"]
            if next_marker == marker:
                # Listing is not advancing; stop rather than loop forever.
                return
            marker = next_marker
|
||||
|
||||
|
||||
class AsyncS3Manager:  # pragma: no cover
    """Async counterpart of ``S3Manager``.

    Holds a synchronous ``S3Manager`` and hands each call to
    ``to_thread.run_sync``, forwarding the arguments unchanged.
    """

    def __init__(
        self,
        client: tp.Any,
        bucket_name: str,
        check_ttl_every: tp.Union[int, float],
        is_binary: bool = False,
        path_prefix: str = "hishel-",
    ):
        """Build the wrapped ``S3Manager`` from the same arguments."""
        self._sync_manager = S3Manager(
            client=client,
            bucket_name=bucket_name,
            check_ttl_every=check_ttl_every,
            is_binary=is_binary,
            path_prefix=path_prefix,
        )

    async def write_to(self, path: str, data: tp.Union[bytes, str], only_metadata: bool = False) -> None:
        """Await ``S3Manager.write_to(path, data, only_metadata)``."""
        blocking_write = self._sync_manager.write_to
        await to_thread.run_sync(blocking_write, path, data, only_metadata)

    async def read_from(self, path: str) -> tp.Union[bytes, str]:
        """Await ``S3Manager.read_from(path)`` and return its result."""
        blocking_read = self._sync_manager.read_from
        content = await to_thread.run_sync(blocking_read, path)
        return content

    async def remove_expired(self, ttl: int, key: str) -> None:
        """Await ``S3Manager.remove_expired(ttl, key)``."""
        blocking_cleanup = self._sync_manager.remove_expired
        await to_thread.run_sync(blocking_cleanup, ttl, key)

    async def remove_entry(self, key: str) -> None:
        """Await ``S3Manager.remove_entry(key)``."""
        blocking_remove = self._sync_manager.remove_entry
        await to_thread.run_sync(blocking_remove, key)
|
||||
Reference in New Issue
Block a user