"""
This module provides functions to handle HTTP requests with retry logic, throttling, and identity management.
"""
import gzip
import logging
import os
import shutil
import tarfile
import tempfile
import uuid
import zipfile
from functools import wraps
from io import BytesIO
from pathlib import Path
from typing import Optional, Union
import httpcore
import orjson as json
from httpx import AsyncClient, ConnectError, HTTPError, ReadTimeout, RequestError, Response, Timeout, TimeoutException
from stamina import retry
from tqdm import tqdm
from edgar.core import get_edgar_data_directory, text_extensions
from edgar.httpclient import async_http_client, http_client
"""
This module provides functions to handle HTTP requests with retry logic, throttling, and identity management.
"""
__all__ = [
"get_with_retry",
"get_with_retry_async",
"stream_with_retry",
"post_with_retry",
"post_with_retry_async",
"download_file",
"download_file_async",
"download_json",
"download_json_async",
"stream_file",
"download_text",
"download_text_between_tags",
"download_bulk_data",
"download_datafile",
"SSLVerificationError",
]
attempts = 6
max_requests_per_second = 8
throttle_disabled = False
TIMEOUT = Timeout(30.0, connect=10.0)
RETRY_WAIT_INITIAL = 1 # Initial retry delay (seconds)
RETRY_WAIT_MAX = 60 # Max retry delay (seconds)
# Quick API requests - fail fast for responsive UX
QUICK_RETRY_ATTEMPTS = 5
QUICK_WAIT_MAX = 16 # max 16s delay
# Bulk downloads - very persistent for large files
BULK_RETRY_ATTEMPTS = 8
BULK_RETRY_TIMEOUT = None # unlimited
BULK_WAIT_MAX = 120 # max 2min delay
# SSL errors - retry once in case of transient issues, then fail with helpful message
SSL_RETRY_ATTEMPTS = 2 # 1 retry = 2 total attempts
SSL_WAIT_MAX = 5 # Short delay for SSL retries
# Exception types to retry on - includes both httpx and httpcore exceptions
RETRYABLE_EXCEPTIONS = (
# HTTPX exceptions
RequestError, HTTPError, TimeoutException, ConnectError, ReadTimeout,
# HTTPCORE exceptions that can slip through
httpcore.ReadTimeout, httpcore.WriteTimeout, httpcore.ConnectTimeout,
httpcore.PoolTimeout, httpcore.ConnectError, httpcore.NetworkError,
httpcore.TimeoutException
)
def is_ssl_error(exc: Exception) -> bool:
"""
Detect if exception is SSL certificate verification related.
Checks both the exception chain for SSL errors and error messages
for SSL-related keywords.
"""
import ssl
# Check if any httpx/httpcore exception wraps an SSL error
if isinstance(exc, (ConnectError, httpcore.ConnectError,
httpcore.NetworkError, httpcore.ProxyError)):
cause = exc.__cause__
while cause:
if isinstance(cause, ssl.SSLError):
return True
cause = cause.__cause__
# Check error message for SSL indicators
error_msg = str(exc).lower()
ssl_indicators = ['ssl', 'certificate', 'verify failed', 'certificate_verify_failed']
return any(indicator in error_msg for indicator in ssl_indicators)
def should_retry(exc: Exception) -> bool:
"""
Determine if an exception should be retried.
SSL errors are not retried because they are deterministic failures
that won't succeed on retry. All other retryable exceptions are retried.
Args:
exc: The exception to check
Returns:
True if the exception should be retried, False otherwise
"""
# Don't retry SSL errors - fail fast with helpful message
if isinstance(exc, (ConnectError, httpcore.ConnectError)):
if is_ssl_error(exc):
return False
# Retry all other exceptions in the retryable list
return isinstance(exc, RETRYABLE_EXCEPTIONS)
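# Illustration (hedged sketch) of the retry predicate: transient timeouts are
# retried, while SSL verification failures fail fast. The messages below are
# fabricated for illustration.
#
#     assert should_retry(ReadTimeout("read timed out"))
#     assert not should_retry(ConnectError("certificate verify failed"))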
class TooManyRequestsError(Exception):
def __init__(self, url, message="Too Many Requests"):
self.url = url
self.message = message
super().__init__(self.message)
class IdentityNotSetException(Exception):
pass
class SSLVerificationError(Exception):
"""Raised when SSL certificate verification fails"""
def __init__(self, original_error, url):
self.original_error = original_error
self.url = url
message = f"""
SSL Certificate Verification Failed
====================================
URL: {self.url}
Error: {str(self.original_error)}
Common Causes:
• Corporate network with SSL inspection proxy
• Self-signed certificates in development environments
• Custom certificate authorities
Solution:
---------
If you trust this network, disable SSL verification:
export EDGAR_VERIFY_SSL="false"
Or in Python:
import os
os.environ['EDGAR_VERIFY_SSL'] = 'false'
from edgar import Company # Import after setting
⚠️ WARNING: Only disable in trusted environments.
This makes connections vulnerable to attacks.
Alternative Solutions:
---------------------
• Install your organization's root CA certificate
• Contact IT for proper certificate configuration
• Use a network without SSL inspection
For details: https://github.com/dgunning/edgartools/blob/main/docs/guides/ssl_verification.md
"""
super().__init__(message)
def is_redirect(response):
return response.status_code in [301, 302]
def with_identity(func):
@wraps(func)
def wrapper(url, identity=None, identity_callable=None, *args, **kwargs):
if identity is None:
if identity_callable is not None:
identity = identity_callable()
else:
identity = os.environ.get("EDGAR_IDENTITY")
if identity is None:
raise IdentityNotSetException("User-Agent identity is not set")
headers = kwargs.get("headers", {})
headers["User-Agent"] = identity
kwargs["headers"] = headers
return func(url, identity=identity, identity_callable=identity_callable, *args, **kwargs)
return wrapper
def async_with_identity(func):
@wraps(func)
def wrapper(client, url, identity=None, identity_callable=None, *args, **kwargs):
if identity is None:
if identity_callable is not None:
identity = identity_callable()
else:
identity = os.environ.get("EDGAR_IDENTITY")
if identity is None:
raise IdentityNotSetException("User-Agent identity is not set")
headers = kwargs.get("headers", {})
headers["User-Agent"] = identity
kwargs["headers"] = headers
return func(client, url, identity=identity, identity_callable=identity_callable, *args, **kwargs)
return wrapper
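# Usage sketch: functions decorated with @with_identity / @async_with_identity
# resolve the User-Agent in priority order -- explicit `identity` argument,
# then `identity_callable`, then the EDGAR_IDENTITY environment variable.
# "Jane Doe jane.doe@example.com" is a placeholder identity.
#
#     os.environ["EDGAR_IDENTITY"] = "Jane Doe jane.doe@example.com"
#     response = get_with_retry("https://www.sec.gov/files/company_tickers.json")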
@retry(
on=should_retry,
attempts=QUICK_RETRY_ATTEMPTS,
wait_initial=RETRY_WAIT_INITIAL,
wait_max=QUICK_WAIT_MAX,
wait_jitter=0.5,
wait_exp_base=2
)
@with_identity
def get_with_retry(url, identity=None, identity_callable=None, **kwargs):
"""
Sends a GET request with retry functionality and identity handling.
Args:
url (str): The URL to send the GET request to.
identity (str, optional): The identity to use for the request. Defaults to None.
identity_callable (callable, optional): A callable that returns the identity. Defaults to None.
**kwargs: Additional keyword arguments to pass to the underlying httpx.Client.get() method.
Returns:
httpx.Response: The response object returned by the GET request.
Raises:
TooManyRequestsError: If the response status code is 429 (Too Many Requests).
SSLVerificationError: If SSL certificate verification fails.
"""
try:
with http_client() as client:
response = client.get(url, **kwargs)
if response.status_code == 429:
raise TooManyRequestsError(url)
elif is_redirect(response):
return get_with_retry(url=response.headers["Location"], identity=identity, identity_callable=identity_callable, **kwargs)
return response
except ConnectError as e:
if is_ssl_error(e):
raise SSLVerificationError(e, url) from e
raise
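# Usage sketch (assumes EDGAR_IDENTITY is set; company_tickers.json is a real
# SEC endpoint):
#
#     response = get_with_retry("https://www.sec.gov/files/company_tickers.json")
#     inspect_response(response)
#     tickers = response.json()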
@retry(
on=should_retry,
attempts=QUICK_RETRY_ATTEMPTS,
wait_initial=RETRY_WAIT_INITIAL,
wait_max=QUICK_WAIT_MAX,
wait_jitter=0.5,
wait_exp_base=2
)
@async_with_identity
async def get_with_retry_async(client: AsyncClient, url, identity=None, identity_callable=None, **kwargs):
"""
Sends an asynchronous GET request with retry functionality and identity handling.
    Args:
        client (httpx.AsyncClient): The client to use for the request.
        url (str): The URL to send the GET request to.
identity (str, optional): The identity to use for the request. Defaults to None.
identity_callable (callable, optional): A callable that returns the identity. Defaults to None.
**kwargs: Additional keyword arguments to pass to the underlying httpx.AsyncClient.get() method.
Returns:
httpx.Response: The response object returned by the GET request.
Raises:
TooManyRequestsError: If the response status code is 429 (Too Many Requests).
SSLVerificationError: If SSL certificate verification fails.
"""
try:
response = await client.get(url, **kwargs)
if response.status_code == 429:
raise TooManyRequestsError(url)
elif is_redirect(response):
return await get_with_retry_async(
client=client, url=response.headers["Location"], identity=identity, identity_callable=identity_callable, **kwargs
)
return response
except ConnectError as e:
if is_ssl_error(e):
raise SSLVerificationError(e, url) from e
raise
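# Usage sketch (asyncio; assumes EDGAR_IDENTITY is set). The caller owns the
# AsyncClient, so one connection pool can serve many requests:
#
#     import asyncio
#
#     async def fetch_tickers():
#         async with AsyncClient(timeout=TIMEOUT) as client:
#             response = await get_with_retry_async(client, "https://www.sec.gov/files/company_tickers.json")
#             return response.json()
#
#     tickers = asyncio.run(fetch_tickers())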
@retry(
on=should_retry,
attempts=BULK_RETRY_ATTEMPTS,
timeout=BULK_RETRY_TIMEOUT,
wait_initial=RETRY_WAIT_INITIAL,
wait_max=BULK_WAIT_MAX,
wait_jitter=0.5,
wait_exp_base=2
)
@with_identity
def stream_with_retry(url, identity=None, identity_callable=None, **kwargs):
"""
Sends a streaming GET request with retry functionality and identity handling.
Args:
url (str): The URL to send the streaming GET request to.
identity (str, optional): The identity to use for the request. Defaults to None.
identity_callable (callable, optional): A callable that returns the identity. Defaults to None.
**kwargs: Additional keyword arguments to pass to the underlying httpx.Client.stream() method.
    Yields:
        httpx.Response: The open streaming response object.
Raises:
TooManyRequestsError: If the response status code is 429 (Too Many Requests).
SSLVerificationError: If SSL certificate verification fails.
"""
try:
with http_client() as client:
with client.stream("GET", url, **kwargs) as response:
if response.status_code == 429:
raise TooManyRequestsError(url)
elif is_redirect(response):
response = stream_with_retry(response.headers["Location"], identity=identity, identity_callable=identity_callable, **kwargs)
yield from response
else:
yield response
except ConnectError as e:
if is_ssl_error(e):
raise SSLVerificationError(e, url) from e
raise
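# Usage sketch: the generator yields the open streaming response so the body
# can be consumed incrementally. process() is a placeholder.
#
#     for response in stream_with_retry("https://www.sec.gov/Archives/edgar/full-index/2024/QTR1/form.idx"):
#         for line in response.iter_lines():
#             process(line)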
@retry(
on=should_retry,
attempts=QUICK_RETRY_ATTEMPTS,
wait_initial=RETRY_WAIT_INITIAL,
wait_max=QUICK_WAIT_MAX,
wait_jitter=0.5,
wait_exp_base=2
)
@with_identity
def post_with_retry(url, data=None, json=None, identity=None, identity_callable=None, **kwargs):
"""
Sends a POST request with retry functionality and identity handling.
Args:
url (str): The URL to send the POST request to.
data (dict, optional): The data to include in the request body. Defaults to None.
json (dict, optional): The JSON data to include in the request body. Defaults to None.
identity (str, optional): The identity to use for the request. Defaults to None.
identity_callable (callable, optional): A callable that returns the identity. Defaults to None.
**kwargs: Additional keyword arguments to pass to the underlying httpx.Client.post() method.
Returns:
httpx.Response: The response object returned by the POST request.
Raises:
TooManyRequestsError: If the response status code is 429 (Too Many Requests).
SSLVerificationError: If SSL certificate verification fails.
"""
try:
with http_client() as client:
response = client.post(url, data=data, json=json, **kwargs)
if response.status_code == 429:
raise TooManyRequestsError(url)
elif is_redirect(response):
return post_with_retry(
response.headers["Location"], data=data, json=json, identity=identity, identity_callable=identity_callable, **kwargs
)
return response
except ConnectError as e:
if is_ssl_error(e):
raise SSLVerificationError(e, url) from e
raise
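# Usage sketch: POST a JSON body. The endpoint and payload here are
# placeholders, not a real SEC API.
#
#     response = post_with_retry(
#         "https://example.com/api/search",
#         json={"q": "10-K"},
#     )
#     inspect_response(response)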
@retry(
on=should_retry,
attempts=QUICK_RETRY_ATTEMPTS,
wait_initial=RETRY_WAIT_INITIAL,
wait_max=QUICK_WAIT_MAX,
wait_jitter=0.5,
wait_exp_base=2
)
@async_with_identity
async def post_with_retry_async(client: AsyncClient, url, data=None, json=None, identity=None, identity_callable=None, **kwargs):
"""
Sends an asynchronous POST request with retry functionality and identity handling.
    Args:
        client (httpx.AsyncClient): The client to use for the request.
        url (str): The URL to send the POST request to.
data (dict, optional): The data to include in the request body. Defaults to None.
json (dict, optional): The JSON data to include in the request body. Defaults to None.
identity (str, optional): The identity to use for the request. Defaults to None.
identity_callable (callable, optional): A callable that returns the identity. Defaults to None.
**kwargs: Additional keyword arguments to pass to the underlying httpx.AsyncClient.post() method.
Returns:
httpx.Response: The response object returned by the POST request.
Raises:
TooManyRequestsError: If the response status code is 429 (Too Many Requests).
SSLVerificationError: If SSL certificate verification fails.
"""
try:
response = await client.post(url, data=data, json=json, **kwargs)
if response.status_code == 429:
raise TooManyRequestsError(url)
elif is_redirect(response):
return await post_with_retry_async(
client, response.headers["Location"], data=data, json=json, identity=identity, identity_callable=identity_callable, **kwargs
)
return response
except ConnectError as e:
if is_ssl_error(e):
raise SSLVerificationError(e, url) from e
raise
def inspect_response(response: Response):
"""
Check if the response is successful and raise an exception if not.
"""
if response.status_code != 200:
response.raise_for_status()
def decode_content(content: bytes) -> str:
"""
Decode the content of a file.
"""
try:
return content.decode("utf-8")
except UnicodeDecodeError:
return content.decode("latin-1")
def save_or_return_content(content: Union[str, bytes], path: Optional[Union[str, Path]]) -> Union[str, bytes, None]:
"""
Save the content to a specified path or return the content.
Args:
content (str or bytes): The content to save or return.
path (str or Path, optional): The path where the content should be saved. If None, return the content.
Returns:
str or bytes or None: The content if not saved, or None if saved.
"""
if path is not None:
path = Path(path)
# Determine if the path is a directory or a file
if path.is_dir():
file_name = "downloaded_file" # Replace with logic to extract file name from URL if available
file_path = path / file_name
else:
file_path = path
# Save the file
if isinstance(content, bytes):
file_path.write_bytes(content)
else:
file_path.write_text(content)
return None
return content
def download_file(url: str, as_text: bool = None, path: Optional[Union[str, Path]] = None) -> Union[str, bytes, None]:
"""
Download a file from a URL.
Args:
url (str): The URL of the file to download.
        as_text (bool, optional): Whether to download the file as text or binary.
            If None, the default is determined based on the file extension. Defaults to None.
        path (str or Path, optional): The path where the file should be saved.
            If None, the content is returned instead. Defaults to None.
Returns:
str or bytes: The content of the downloaded file, either as text or binary data.
"""
if as_text is None:
# Set the default based on the file extension
as_text = url.endswith(text_extensions)
response = get_with_retry(url=url)
inspect_response(response)
# Check if the content is gzip-compressed
if url.endswith("gz"):
binary_file = BytesIO(response.content)
with gzip.open(binary_file, "rb") as f:
file_content = f.read()
if as_text:
file_content = decode_content(file_content)
else:
# If we explicitly asked for text or there is an encoding, try to return text
if as_text:
file_content = response.text
# Should get here for jpg and PDFs
else:
file_content = response.content
path = Path(path) if path else None
if path and path.is_dir():
path = path / os.path.basename(url)
return save_or_return_content(file_content, path)
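# Usage sketch: text vs binary is inferred from the extension (assuming ".json"
# is among text_extensions); pass a path to save instead of returning content.
#
#     content = download_file("https://www.sec.gov/files/company_tickers.json")
#     download_file("https://www.sec.gov/files/company_tickers.json", path="/tmp")  # saves, returns None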
async def download_file_async(
client: AsyncClient, url: str, as_text: bool = None, path: Optional[Union[str, Path]] = None
) -> Union[str, bytes, None]:
"""
Download a file from a URL asynchronously.
    Args:
        client (httpx.AsyncClient): The client to use for the request.
        url (str): The URL of the file to download.
as_text (bool, optional): Whether to download the file as text or binary.
If None, the default is determined based on the file extension. Defaults to None.
path (str or Path, optional): The path where the file should be saved.
Returns:
str or bytes: The content of the downloaded file, either as text or binary data.
"""
if as_text is None:
# Set the default based on the file extension
as_text = url.endswith(text_extensions)
response = await get_with_retry_async(client, url)
inspect_response(response)
    path = Path(path) if path else None
    if path and path.is_dir():
        path = path / os.path.basename(url)
    if as_text:
        # Download as text (saved to path if one was given)
        return save_or_return_content(response.text, path)
    # Download as binary
    content = response.content
    # Check if the content is gzip-compressed
    if response.headers.get("Content-Encoding") == "gzip":
        content = gzip.decompress(content)
    return save_or_return_content(content, path)
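# Usage sketch (asyncio), mirroring download_file but reusing a caller-supplied
# client:
#
#     async def grab(url):
#         async with AsyncClient(timeout=TIMEOUT) as client:
#             return await download_file_async(client, url, as_text=True)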
CHUNK_SIZE = 4 * 1024 * 1024 # 4MB
CHUNK_SIZE_LARGE = 8 * 1024 * 1024 # 8MB for files > 500MB
CHUNK_SIZE_MEDIUM = 4 * 1024 * 1024 # 4MB for files > 100MB
CHUNK_SIZE_SMALL = 2 * 1024 * 1024 # 2MB for files <= 100MB
CHUNK_SIZE_DEFAULT = CHUNK_SIZE
@retry(
on=should_retry,
attempts=BULK_RETRY_ATTEMPTS,
timeout=BULK_RETRY_TIMEOUT,
wait_initial=RETRY_WAIT_INITIAL,
wait_max=BULK_WAIT_MAX,
wait_jitter=0.5, # Add jitter to avoid synchronized retries
wait_exp_base=2 # Exponential backoff (doubles delay each retry)
)
@with_identity
async def stream_file(
    url: str, as_text: bool = None, path: Optional[Union[str, Path]] = None, client: Optional[AsyncClient] = None,
    identity=None, identity_callable=None, **kwargs
) -> Union[str, bytes, None]:
"""
Download a file from a URL asynchronously with progress bar using httpx.
Args:
url (str): The URL of the file to download.
as_text (bool, optional): Whether to download the file as text or binary.
If None, the default is determined based on the file extension. Defaults to None.
path (str or Path, optional): The path where the file should be saved.
client: The httpx.AsyncClient instance
Returns:
str or bytes: The content of the downloaded file, either as text or binary data.
"""
if as_text is None:
# Set the default based on the file extension
as_text = url.endswith(text_extensions)
# Create temporary directory for atomic downloads
temp_dir = tempfile.mkdtemp(prefix="edgar_")
temp_file = Path(temp_dir) / f"download_{uuid.uuid1()}"
try:
async with async_http_client(client, timeout=TIMEOUT, bypass_cache=True) as async_client:
async with async_client.stream("GET", url) as response:
inspect_response(response)
total_size = int(response.headers.get("Content-Length", 0))
                if as_text:
                    # Download as text; httpx exposes .text as a property,
                    # so read the streamed body first
                    await response.aread()
                    return response.text
else:
# Download as binary - select optimal chunk size first
if total_size > 0:
if total_size > 500 * 1024 * 1024: # > 500MB
chunk_size = CHUNK_SIZE_LARGE
elif total_size > 100 * 1024 * 1024: # > 100MB
chunk_size = CHUNK_SIZE_MEDIUM
else: # <= 100MB
chunk_size = CHUNK_SIZE_SMALL
else:
# Unknown size, use default
chunk_size = CHUNK_SIZE_DEFAULT
progress_bar = tqdm(
total=total_size / (1024 * 1024),
unit="MB",
unit_scale=True,
unit_divisor=1024,
leave=False, # Force horizontal display
position=0, # Lock the position
dynamic_ncols=True, # Adapt to terminal width
bar_format="{l_bar}{bar}| {n:.2f}/{total:.2f}MB [{elapsed}<{remaining}, {rate_fmt}]",
desc=f"Downloading {os.path.basename(url)}",
ascii=False,
)
# Always stream to temporary file
try:
with open(temp_file, "wb") as f:
# For large files, update progress less frequently to reduce overhead
update_threshold = 1.0 if total_size > 500 * 1024 * 1024 else 0.1 # MB
accumulated_mb = 0.0
async for chunk in response.aiter_bytes(chunk_size=chunk_size):
f.write(chunk)
chunk_mb = len(chunk) / (1024 * 1024)
accumulated_mb += chunk_mb
# Update progress bar only when threshold is reached
if accumulated_mb >= update_threshold:
progress_bar.update(accumulated_mb)
accumulated_mb = 0.0
# Update any remaining progress
if accumulated_mb > 0:
progress_bar.update(accumulated_mb)
finally:
progress_bar.close()
# Handle the result based on whether path was provided
if path is not None:
# Atomic move to final destination
final_path = Path(path)
if final_path.is_dir():
final_path = final_path / os.path.basename(url)
# Ensure parent directory exists
final_path.parent.mkdir(parents=True, exist_ok=True)
# Atomic move from temp to final location
shutil.move(str(temp_file), str(final_path))
return None
else:
with open(temp_file, 'rb') as f:
content = f.read()
return content
finally:
# Clean up temporary directory
try:
shutil.rmtree(temp_dir, ignore_errors=True)
except Exception as e:
logger.warning("Failed to clean up temporary directory %s: %s", temp_dir, e)
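# Usage sketch: stream a multi-gigabyte archive to disk with a progress bar.
# The file is written to a temp location first and moved into place atomically.
#
#     import asyncio
#
#     asyncio.run(stream_file(
#         "https://www.sec.gov/Archives/edgar/daily-index/xbrl/companyfacts.zip",
#         path=get_edgar_data_directory(),
#     ))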
def download_json(data_url: str) -> dict:
"""
Download JSON data from a URL.
Args:
data_url (str): The URL of the JSON data to download.
Returns:
dict: The parsed JSON data.
"""
content = download_file(data_url, as_text=True)
return json.loads(content)
def download_text(url: str) -> Optional[str]:
return download_file(url, as_text=True)
async def download_json_async(client: AsyncClient, data_url: str) -> dict:
"""
Download JSON data from a URL asynchronously.
    Args:
        client (httpx.AsyncClient): The client to use for the request.
        data_url (str): The URL of the JSON data to download.
Returns:
dict: The parsed JSON data.
"""
content = await download_file_async(client=client, url=data_url, as_text=True)
return json.loads(content)
def download_text_between_tags(url: str, tag: str) -> str:
    """
    Download the content of a URL and extract the text between the given tags.
    This is mainly for reading the header of a filing.
    :param url: The URL to download
    :param tag: The tag to extract the content from
    """
tag_start = f"<{tag}>"
tag_end = f"</{tag}>"
is_header = False
content = ""
    for response in stream_with_retry(url):
        for line in response.iter_lines():
            if line:
                # Start capturing after the opening tag
                if line.startswith(tag_start):
                    is_header = True
                    continue  # Skip the opening tag itself
                # Stop at the closing tag - nothing further is needed
                elif line.startswith(tag_end):
                    return content
                # If within the tags, accumulate the line
                elif is_header:
                    content += line + "\n"  # Preserve original line breaks
    return content
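# Usage sketch: read only the SEC header block of a filing. The accession URL
# below is a placeholder pattern, not a real filing.
#
#     header = download_text_between_tags(
#         "https://www.sec.gov/Archives/edgar/data/0000000000/0000000000-00-000000.txt",
#         "SEC-HEADER",
#     )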
@retry(
on=should_retry,
attempts=BULK_RETRY_ATTEMPTS,
timeout=BULK_RETRY_TIMEOUT,
wait_initial=RETRY_WAIT_INITIAL,
wait_max=BULK_WAIT_MAX,
wait_jitter=0.5,
wait_exp_base=2
)
async def download_bulk_data(
    url: str,
    data_directory: Optional[Path] = None,
    client: Optional[AsyncClient] = None,
) -> Path:
"""
Download and extract bulk data from zip or tar.gz archives
    Args:
        url: URL to download from (e.g. "https://www.sec.gov/Archives/edgar/daily-index/xbrl/companyfacts.zip")
        data_directory: Base directory for downloads. Defaults to the EDGAR data directory.
        client: The httpx.AsyncClient instance to use. Defaults to None.
Returns:
Path to the directory containing the extracted data
Raises:
ValueError: If the URL or filename is invalid
IOError: If there are file system operation failures
zipfile.BadZipFile: If the downloaded zip file is corrupted
tarfile.TarError: If the downloaded tar.gz file is corrupted
"""
    if not url:
        raise ValueError("Data URL cannot be empty")
    if data_directory is None:
        # Resolve at call time so a changed EDGAR data directory is respected
        data_directory = get_edgar_data_directory()
    filename = os.path.basename(url)
    if not filename:
        raise ValueError("Invalid URL - cannot extract filename")
local_dir = filename.split(".")[0]
download_path = data_directory / local_dir
download_filename = download_path / filename
try:
# Create the directory with parents=True and exist_ok=True to avoid race conditions
download_path.mkdir(parents=True, exist_ok=True)
# Download the file
try:
await stream_file(url, client=client, path=download_path)
except Exception as e:
raise IOError(f"Failed to download file: {e}") from e
# Extract based on file extension
try:
if filename.endswith(".zip"):
with zipfile.ZipFile(download_filename, "r") as z:
# Calculate total size for progress bar
total_size = sum(info.file_size for info in z.filelist)
extracted_size = 0
with tqdm(total=total_size, unit="B", unit_scale=True, desc="Extracting") as pbar:
for info in z.filelist:
z.extract(info, download_path)
extracted_size += info.file_size
pbar.update(info.file_size)
elif any(filename.endswith(ext) for ext in (".tar.gz", ".tgz")):
with tarfile.open(download_filename, "r:gz") as tar:
# Security check for tar files to prevent path traversal
def is_within_directory(directory: Path, target: Path) -> bool:
try:
return os.path.commonpath([directory, target]) == str(directory)
except ValueError:
return False
members = tar.getmembers()
total_size = sum(member.size for member in members)
with tqdm(total=total_size, unit="B", unit_scale=True, desc="Extracting") as pbar:
for member in members:
# Check for path traversal
member_path = os.path.join(str(download_path), member.name)
if not is_within_directory(Path(str(download_path)), Path(member_path)):
raise ValueError(f"Attempted path traversal in tar file: {member.name}")
# Extract file and update progress
try:
tar.extract(member, str(download_path), filter="tar")
except TypeError:
tar.extract(member, str(download_path))
pbar.update(member.size)
else:
raise ValueError(f"Unsupported file format: {filename}")
except (zipfile.BadZipFile, tarfile.TarError) as e:
raise type(e)(f"Failed to extract archive {filename}: {e}") from e
finally:
# Always try to clean up the archive file, but don't fail if we can't
try:
if download_filename.exists():
download_filename.unlink()
except Exception as e:
logger.warning("Failed to delete archive file %s: %s", download_filename, e)
return download_path
except Exception:
# Clean up the download directory in case of any errors
try:
if download_path.exists():
shutil.rmtree(download_path)
except Exception as cleanup_error:
logger.error("Failed to clean up after error: %s", cleanup_error)
raise
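# Usage sketch: download and unpack the SEC company facts archive (several GB;
# the bulk retry profile tolerates long transfers):
#
#     import asyncio
#
#     facts_dir = asyncio.run(download_bulk_data(
#         "https://www.sec.gov/Archives/edgar/daily-index/xbrl/companyfacts.zip"
#     ))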
def download_datafile(data_url: str, local_directory: Optional[Path] = None) -> Path:
    """Download a file to the local storage directory"""
    filename = os.path.basename(data_url)
    # Create the directory (and any missing parents) if it doesn't exist
    local_directory = local_directory or get_edgar_data_directory()
    local_directory.mkdir(parents=True, exist_ok=True)
download_filename = local_directory / filename
download_file(data_url, path=download_filename)
return download_filename
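# Usage sketch: fetch a single reference file into local storage and get back
# the path it was saved to.
#
#     path = download_datafile("https://www.sec.gov/files/company_tickers.json")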