Initial commit
@@ -0,0 +1,363 @@
"""An alternative cache using:

- Flat files

"""

import calendar
import json
import logging
import os
import time
from pathlib import Path
from typing import Callable, Iterator, Optional, Union
from urllib.parse import quote, unquote

import aiofiles
import httpx
from filelock import AsyncFileLock, FileLock
from httpx._types import SyncByteStream  # protocol type

from ..controller import get_rule_for_request

logger = logging.getLogger(__name__)


class AlreadyLockedError(Exception):
    pass


class DualFileStream(SyncByteStream, httpx.AsyncByteStream):
    """A file-backed byte stream readable from both sync and async code."""

    def __init__(
        self,
        path: Path,
        chunk_size: int = 1024 * 1024,
        on_close: Optional[Callable[[], None]] = None,
        async_on_close: Optional[Callable[[], None]] = None,
    ):
        self.path, self.chunk_size = Path(path), chunk_size
        self.on_close, self.async_on_close = on_close, async_on_close

    def __iter__(self):
        with open(self.path, "rb") as f:
            while True:
                b = f.read(self.chunk_size)
                if not b:
                    break
                yield b

    def close(self) -> None:
        if self.on_close:  # pragma: no cover
            self.on_close()

    async def __aiter__(self):
        async with aiofiles.open(self.path, "rb") as f:
            while True:
                b = await f.read(self.chunk_size)
                if not b:
                    break
                yield b

    async def aclose(self) -> None:
        if self.async_on_close:  # pragma: no cover
            await self.async_on_close()
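
# ``DualFileStream`` implements both iteration protocols, so a single object can
# back an ``httpx.Response`` for sync and async consumers alike. A small sketch
# (``cached_file`` is a hypothetical path):
#
#     stream = DualFileStream(cached_file)
#     for chunk in stream:  # sync consumption
#         ...
#     # or, from async code: ``async for chunk in DualFileStream(cached_file)``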


class FileCache:
    def __init__(self, cache_dir: Union[str, Path], locking: bool = True):
        self.cache_dir = Path(cache_dir)
        logger.info("cache_dir=%s", self.cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.locking = locking

    def _meta_path(self, p: Path) -> Path:
        return p.with_suffix(p.suffix + ".meta")

    def _load_meta(self, p: Path) -> dict:
        try:
            return json.loads(self._meta_path(p).read_text())
        except FileNotFoundError:  # pragma: no cover
            return {}

    def to_path(self, host: str, path: str, query: str) -> Path:
        site = host.lower().rstrip(".")
        (self.cache_dir / site).mkdir(parents=True, exist_ok=True)
        name = unquote(path).strip("/").replace("/", "-") or "index"
        if query:
            name += "-" + unquote(query).replace("&", "-").replace("=", "-")
        return self.cache_dir / site / quote(name, safe="._-~")
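
    # Worked example of the mapping above:
    #   to_path("Example.COM.", "/a/b", "x=1&y=2")
    #   -> cache_dir / "example.com" / "a-b-x-1-y-2"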

    def get_if_fresh(
        self, host: str, path: str, query: str, cache_rules: dict[str, dict[str, Union[bool, int]]]
    ) -> tuple[bool, Optional[Path]]:
        cached = get_rule_for_request(request_host=host, target=path, cache_rules=cache_rules)

        if not cached:
            logger.info("No cache policy for %s%s, not retrieving from cache", host, path)
            return False, None

        p = self.to_path(host=host, path=path, query=query)
        if not p.exists():
            logger.info("Cache file for %s doesn't exist at %s", path, p)
            return False, None

        meta = self._load_meta(p)
        fetched = meta.get("fetched")
        if not fetched:
            return False, p  # pragma: no cover

        if cached is True:
            logger.info("Cache policy allows unlimited caching, returning %s", p)
            return True, p

        age = round(time.time() - fetched)
        if age < 0:  # pragma: no cover
            raise ValueError(f"Age is less than 0, impossible {age=}, file {path=}")
        logger.info("File is %s seconds old, policy allows caching for up to %s seconds", age, cached)
        return (age <= cached, p)
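
    # Freshness semantics, as implemented above (a hedged reading: rule values come
    # from ``get_rule_for_request``, which is not shown here): a rule of ``True``
    # means cache forever, an integer rule is a maximum age in seconds. A stale
    # file's path is still returned so the caller can attempt a conditional
    # revalidation instead of a full refetch.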


class _TeeCore:
    def __init__(self, resp: httpx.Response, path: Path, locking: bool, last_modified: str, access_date: str):
        assert path is not None

        self.resp = resp
        self.path = path
        self.tmp = path.with_name(path.name + ".tmp")
        self.lock = FileLock(str(path) + ".lock") if locking else None
        self.fh = None
        if last_modified:
            self.mtime = calendar.timegm(time.strptime(last_modified, "%a, %d %b %Y %H:%M:%S GMT"))
        else:
            self.mtime = None

        if access_date:
            self.atime = calendar.timegm(time.strptime(access_date, "%a, %d %b %Y %H:%M:%S GMT"))
        else:
            self.atime = None  # pragma: no cover

    def acquire(self):
        if self.lock:
            self.lock.acquire()

    def open_tmp(self):
        self.fh = open(self.tmp, "wb")

    def write(self, chunk: bytes):
        self.fh.write(chunk)  # pyright: ignore[reportOptionalMemberAccess]

    def finalize(self):
        try:
            if self.fh and not self.fh.closed:
                self.fh.flush()
                os.fsync(self.fh.fileno())
                self.fh.close()
                os.replace(self.tmp, self.path)
                try:
                    meta_path = self.path.with_suffix(self.path.suffix + ".meta")
                    headers = {
                        "content-type": self.resp.headers.get("content-type"),
                        "content-encoding": self.resp.headers.get("content-encoding"),
                    }

                    meta_path.write_text(json.dumps({"fetched": self.atime, "origin_lm": self.mtime, "headers": headers}))
                except FileNotFoundError:  # pragma: no cover
                    pass
        finally:
            if self.lock and getattr(self.lock, "is_locked", False):
                self.lock.release()
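
    # Note: ``finalize`` is safe to call twice. A second call finds the handle
    # already closed (so the rename and meta write are skipped) and the lock
    # already released, which lets both the streaming ``finally`` block and
    # ``_TeeToDisk.close()`` invoke it without error.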


class _TeeToDisk(SyncByteStream):
    def __init__(self, resp: httpx.Response, path: Path, locking: bool, last_modified: str, access_date: str) -> None:
        self.core = _TeeCore(resp, path, locking, last_modified, access_date)

    def __iter__(self) -> Iterator[bytes]:
        self.core.acquire()
        try:
            self.core.open_tmp()
            for chunk in self.core.resp.iter_raw():
                self.core.write(chunk)
                yield chunk
        finally:
            self.core.finalize()

    def close(self) -> None:
        try:
            self.core.resp.close()
        finally:
            self.core.finalize()


class _AsyncTeeToDisk(httpx.AsyncByteStream):
    def __init__(self, resp, path, locking, last_modified, access_date):
        self.resp = resp
        self.path = path
        self.tmp = path.with_name(path.name + ".tmp")
        self.lock = AsyncFileLock(str(path) + ".lock") if locking else None
        if last_modified:
            self.mtime = calendar.timegm(time.strptime(last_modified, "%a, %d %b %Y %H:%M:%S GMT"))
        else:
            self.mtime = None

        if access_date:
            self.atime = calendar.timegm(time.strptime(access_date, "%a, %d %b %Y %H:%M:%S GMT"))
        else:
            self.atime = None  # pragma: no cover

    async def __aiter__(self):
        if self.lock:
            await self.lock.acquire()
        try:
            async with aiofiles.open(self.tmp, "wb") as f:
                async for chunk in self.resp.aiter_raw():
                    await f.write(chunk)
                    yield chunk
            os.replace(self.tmp, self.path)
            async with aiofiles.open(self.path.with_suffix(self.path.suffix + ".meta"), "w") as m:
                headers = {
                    "content-type": self.resp.headers.get("content-type"),
                    "content-encoding": self.resp.headers.get("content-encoding"),
                }
                await m.write(json.dumps({"fetched": self.atime, "origin_lm": self.mtime, "headers": headers}))
        finally:
            if self.lock:
                await self.lock.release()

    async def aclose(self):
        try:
            await self.resp.aclose()
        finally:
            if self.lock:
                await self.lock.release()


class CachingTransport(httpx.BaseTransport, httpx.AsyncBaseTransport):
    cache_rules: dict[str, dict[str, Union[bool, int]]]
    streaming_cutoff: int = 8 * 1024 * 1024

    def __init__(
        self,
        cache_dir: Union[str, Path],
        cache_rules: dict[str, dict[str, Union[bool, int]]],
        transport: Optional[httpx.BaseTransport] = None,
    ):
        self._cache = FileCache(cache_dir=cache_dir, locking=True)
        self.transport = transport or httpx.HTTPTransport()
        self.cache_rules = cache_rules

    def _cache_hit_response(self, req, path: Path, status_code: int = 200):
        """
        TODO: More carefully consider async here; ``read_text`` and ``read_bytes`` are both blocking.

        Large files are streamed, so the only blocking calls here are reads of small(ish) files.
        """
        meta = json.loads(path.with_suffix(path.suffix + ".meta").read_text())
        date = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime(meta["fetched"]))
        last_modified = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime(meta["origin_lm"]))

        ct = meta.get("headers", {}).get("content-type", "application/octet-stream")
        ce = meta.get("headers", {}).get("content-encoding")
        size = os.path.getsize(path)

        headers = [
            ("x-cache", "HIT"),
            ("content-length", str(size)),
            ("Date", date),
            ("Last-Modified", last_modified),
        ]
        if ce:
            headers.append(("content-encoding", ce))
        if ct:
            headers.append(("content-type", ct))

        if size < self.streaming_cutoff:
            # Small files are read into memory and returned directly
            return httpx.Response(
                status_code=status_code,
                headers=headers,
                content=path.read_bytes(),
                request=req,
            )
        else:
            # Large files are streamed from disk
            return httpx.Response(
                status_code=status_code,
                headers=headers,
                stream=DualFileStream(path),
                request=req,
            )

    def _cache_miss_response(self, req, net, path, tee_factory):
        if net.status_code != 200:
            return net

        miss_headers = [
            (k, v)
            for k, v in net.headers.items()
            if k.lower() not in ("transfer-encoding",)  # "content-encoding", "content-length", "transfer-encoding")
        ]
        miss_headers.append(("x-cache", "MISS"))
        return httpx.Response(
            status_code=net.status_code,
            headers=miss_headers,
            stream=tee_factory(
                net, path, self._cache.locking, net.headers.get("Last-Modified"), net.headers.get("Date")
            ),
            request=req,
            extensions={**net.extensions, "decode_content": False},
        )

    def return_if_fresh(self, request):
        host = request.url.host
        path = request.url.path
        query = request.url.query.decode() if request.url.query else ""

        fresh, path = self._cache.get_if_fresh(host, path, query, self.cache_rules)

        if path:
            if fresh:
                return self._cache_hit_response(request, path), path
            else:
                # Stale but present: attach a validator so the origin can answer 304
                lm = json.loads(path.with_suffix(path.suffix + ".meta").read_text()).get("origin_lm")
                if lm:
                    request.headers["If-Modified-Since"] = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime(lm))
                return None, path
        else:
            return None, None

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        if request.method != "GET":
            return self.transport.handle_request(request)

        response, path = self.return_if_fresh(request)
        if response:
            return response

        net = self.transport.handle_request(request)
        if net.status_code == 304:
            logger.info("304 for %s", request)
            assert path is not None  # must be true
            return self._cache_hit_response(request, path, status_code=304)

        host = request.url.host
        path = request.url.path
        query = request.url.query.decode() if request.url.query else ""

        path = self._cache.to_path(host, path, query)
        return self._cache_miss_response(request, net, path, _TeeToDisk)

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        if request.method != "GET":
            return await self.transport.handle_async_request(request)  # type: ignore[attr-defined]

        response, path = self.return_if_fresh(request)
        if response:
            return response

        net = await self.transport.handle_async_request(request)
        if net.status_code == 304:
            assert path is not None  # must be true
            logger.info("304 for %s", request)
            return self._cache_hit_response(request, path, status_code=304)

        path = self._cache.to_path(request.url.host, request.url.path, request.url.query.decode())
        return self._cache_miss_response(request, net, path, _AsyncTeeToDisk)
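

# An async usage sketch. Hedged note: the default wrapped transport is the sync
# ``httpx.HTTPTransport``, so an async-capable transport should be passed
# explicitly for ``handle_async_request`` (hence the ``type: ignore`` above):
#
#     transport = CachingTransport(
#         cache_dir="/tmp/http-cache",
#         cache_rules={"example.com": {"/data": 3600}},
#         transport=httpx.AsyncHTTPTransport(),
#     )
#     async with httpx.AsyncClient(transport=transport) as client:
#         await client.get("https://example.com/data")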