Initial commit
edgar/documents/ranking/__init__.py
@@ -0,0 +1,34 @@
"""
Advanced ranking functionality for edgar.documents.

This package provides BM25-based ranking with semantic structure awareness
and intelligent index caching for performance optimization.
"""

from edgar.documents.ranking.ranking import (
    RankingAlgorithm,
    RankingEngine,
    BM25Engine,
    HybridEngine,
    SemanticEngine,
    RankedResult,
)
from edgar.documents.ranking.cache import (
    SearchIndexCache,
    CacheEntry,
    get_search_cache,
    set_search_cache,
)

__all__ = [
    'RankingAlgorithm',
    'RankingEngine',
    'BM25Engine',
    'HybridEngine',
    'SemanticEngine',
    'RankedResult',
    'SearchIndexCache',
    'CacheEntry',
    'get_search_cache',
    'set_search_cache',
]
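A minimal usage sketch of how the re-exported pieces fit together, based only on the names above; `nodes` is a placeholder for whatever the edgar.documents parser produces for a filing, which is not part of this commit.

from typing import List
from edgar.documents.ranking import HybridEngine, RankedResult

def rank_filing_nodes(query: str, nodes) -> List[RankedResult]:
    """Rank parsed document nodes for a query and print the top hits."""
    engine = HybridEngine(bm25_weight=0.8, semantic_weight=0.2)
    results = engine.rank(query, nodes)  # best-first list of RankedResult
    for result in results[:5]:
        print(result.rank, round(result.score, 3), result.snippet)
    return results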
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
edgar/documents/ranking/cache.py
@@ -0,0 +1,311 @@
"""
Search index caching for performance optimization.

Provides memory and disk caching with LRU eviction and TTL expiration.
"""

from dataclasses import dataclass, field
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional, Dict, Any, List
import hashlib
import pickle
import logging

logger = logging.getLogger(__name__)


@dataclass
class CacheEntry:
    """
    Cached search index entry.

    Stores pre-built search indices for a document along with metadata
    for cache management (access tracking, TTL).
    """
    document_hash: str
    index_data: Dict[str, Any]  # Serialized BM25 index data
    created_at: datetime
    access_count: int = 0
    last_accessed: Optional[datetime] = None
    metadata: Dict[str, Any] = field(default_factory=dict)


class SearchIndexCache:
    """
    Manages search index caching with memory + disk storage.

    Features:
    - In-memory LRU cache for fast access
    - Optional disk persistence for reuse across sessions
    - TTL-based expiration
    - Access statistics tracking

    Parameters:
        memory_cache_size: Maximum entries in memory (default: 10)
        disk_cache_enabled: Enable disk persistence (default: True)
        cache_dir: Directory for disk cache (default: ~/.edgar_cache/search)
        ttl_hours: Time-to-live for cached entries (default: 24)
    """

    def __init__(self,
                 memory_cache_size: int = 10,
                 disk_cache_enabled: bool = True,
                 cache_dir: Optional[Path] = None,
                 ttl_hours: int = 24):
        """Initialize cache."""
        self.memory_cache_size = memory_cache_size
        self.disk_cache_enabled = disk_cache_enabled
        self.cache_dir = cache_dir or Path.home() / ".edgar_cache" / "search"
        self.ttl = timedelta(hours=ttl_hours)

        # In-memory cache (LRU)
        self._memory_cache: Dict[str, CacheEntry] = {}
        self._access_order: List[str] = []

        # Statistics
        self._hits = 0
        self._misses = 0

        # Create cache directory
        if disk_cache_enabled:
            self.cache_dir.mkdir(parents=True, exist_ok=True)

    def compute_document_hash(self, document_id: str, content_sample: str) -> str:
        """
        Compute cache key from document identifiers.

        Uses document ID (e.g., accession number) and a content sample
        to create a unique, stable hash.

        Args:
            document_id: Unique document identifier
            content_sample: Sample of document content for verification

        Returns:
            16-character hex hash
        """
        content = f"{document_id}:{content_sample}"
        return hashlib.sha256(content.encode()).hexdigest()[:16]

    def get(self, document_hash: str) -> Optional[CacheEntry]:
        """
        Get cached entry.

        Tries memory cache first, then disk cache. Updates LRU order
        and access statistics.

        Args:
            document_hash: Cache key

        Returns:
            CacheEntry if found and valid, None otherwise
        """
        # Try memory cache first
        if document_hash in self._memory_cache:
            entry = self._memory_cache[document_hash]

            # Check TTL
            if datetime.now() - entry.created_at > self.ttl:
                # Expired - remove from cache
                self._evict_memory(document_hash)
                self._misses += 1
                return None

            # Update access tracking
            entry.access_count += 1
            entry.last_accessed = datetime.now()

            # Update LRU order
            if document_hash in self._access_order:
                self._access_order.remove(document_hash)
            self._access_order.append(document_hash)

            self._hits += 1
            logger.debug(f"Cache hit (memory): {document_hash}")
            return entry

        # Try disk cache
        if self.disk_cache_enabled:
            entry = self._load_from_disk(document_hash)
            if entry:
                # Check TTL
                if datetime.now() - entry.created_at > self.ttl:
                    # Expired - delete file
                    self._delete_from_disk(document_hash)
                    self._misses += 1
                    return None

                # Load into memory cache
                self._put_memory(document_hash, entry)
                self._hits += 1
                logger.debug(f"Cache hit (disk): {document_hash}")
                return entry

        self._misses += 1
        logger.debug(f"Cache miss: {document_hash}")
        return None

    def put(self, document_hash: str, entry: CacheEntry) -> None:
        """
        Cache entry in memory and optionally on disk.

        Args:
            document_hash: Cache key
            entry: Entry to cache
        """
        # Put in memory cache
        self._put_memory(document_hash, entry)

        # Put in disk cache
        if self.disk_cache_enabled:
            self._save_to_disk(document_hash, entry)

        logger.debug(f"Cached entry: {document_hash}")

    def _put_memory(self, document_hash: str, entry: CacheEntry) -> None:
        """Put entry in memory cache with LRU eviction."""
        # Evict if cache full
        while len(self._memory_cache) >= self.memory_cache_size:
            if self._access_order:
                oldest = self._access_order.pop(0)
                self._evict_memory(oldest)
            else:
                break

        self._memory_cache[document_hash] = entry
        self._access_order.append(document_hash)

    def _evict_memory(self, document_hash: str) -> None:
        """Evict entry from memory cache."""
        if document_hash in self._memory_cache:
            del self._memory_cache[document_hash]
            logger.debug(f"Evicted from memory: {document_hash}")

    def _load_from_disk(self, document_hash: str) -> Optional[CacheEntry]:
        """Load entry from disk cache."""
        cache_file = self.cache_dir / f"{document_hash}.pkl"
        if not cache_file.exists():
            return None

        try:
            with open(cache_file, 'rb') as f:
                entry = pickle.load(f)
            return entry
        except Exception as e:
            logger.warning(f"Failed to load cache from disk: {e}")
            # Delete corrupted file
            try:
                cache_file.unlink()
            except Exception:
                pass
            return None

    def _save_to_disk(self, document_hash: str, entry: CacheEntry) -> None:
        """Save entry to disk cache."""
        cache_file = self.cache_dir / f"{document_hash}.pkl"
        try:
            with open(cache_file, 'wb') as f:
                pickle.dump(entry, f)
        except Exception as e:
            logger.warning(f"Failed to save cache to disk: {e}")

    def _delete_from_disk(self, document_hash: str) -> None:
        """Delete entry from disk cache."""
        cache_file = self.cache_dir / f"{document_hash}.pkl"
        try:
            if cache_file.exists():
                cache_file.unlink()
        except Exception as e:
            logger.warning(f"Failed to delete cache file: {e}")

    def clear(self, memory_only: bool = False) -> None:
        """
        Clear cache.

        Args:
            memory_only: If True, only clear memory cache (keep disk)
        """
        self._memory_cache.clear()
        self._access_order.clear()
        logger.info("Cleared memory cache")

        if not memory_only and self.disk_cache_enabled:
            try:
                for cache_file in self.cache_dir.glob("*.pkl"):
                    cache_file.unlink()
                logger.info("Cleared disk cache")
            except Exception as e:
                logger.warning(f"Failed to clear disk cache: {e}")

    def get_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dictionary with cache statistics
        """
        disk_entries = 0
        if self.disk_cache_enabled:
            try:
                disk_entries = len(list(self.cache_dir.glob("*.pkl")))
            except Exception:
                pass

        total_requests = self._hits + self._misses
        hit_rate = self._hits / total_requests if total_requests > 0 else 0.0

        return {
            "memory_entries": len(self._memory_cache),
            "disk_entries": disk_entries,
            "total_accesses": sum(e.access_count for e in self._memory_cache.values()),
            "cache_hits": self._hits,
            "cache_misses": self._misses,
            "hit_rate": hit_rate,
            "memory_size_mb": self._estimate_cache_size()
        }

    def _estimate_cache_size(self) -> float:
        """Estimate memory cache size in MB."""
        try:
            import sys
            total_bytes = sum(
                sys.getsizeof(entry.index_data)
                for entry in self._memory_cache.values()
            )
            return total_bytes / (1024 * 1024)
        except Exception:
            # Rough estimate if sys.getsizeof fails
            return len(self._memory_cache) * 5.0  # Assume ~5MB per entry


# Global cache instance
_global_cache: Optional[SearchIndexCache] = None


def get_search_cache() -> SearchIndexCache:
    """
    Get global search cache instance.

    Creates a singleton cache instance on first call.

    Returns:
        Global SearchIndexCache instance
    """
    global _global_cache
    if _global_cache is None:
        _global_cache = SearchIndexCache()
    return _global_cache


def set_search_cache(cache: Optional[SearchIndexCache]) -> None:
    """
    Set global search cache instance.

    Useful for testing or custom cache configuration.

    Args:
        cache: Cache instance to use globally (None to disable)
    """
    global _global_cache
    _global_cache = cache
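A short usage sketch for the cache above. The accession number, content sample, and index payload are illustrative placeholders, not values produced by any real filing; a real caller would get the index data from BM25Engine.get_index_data().

from datetime import datetime
from edgar.documents.ranking.cache import CacheEntry, SearchIndexCache

# Memory-only cache for this sketch; disk persistence is the default.
cache = SearchIndexCache(memory_cache_size=5, disk_cache_enabled=False)

# Key is derived from a document ID plus a content sample (both placeholders here).
key = cache.compute_document_hash("0000000000-24-000000", "first few hundred chars of the filing")

if cache.get(key) is None:
    # On a miss, build the index elsewhere and store its serialized form.
    entry = CacheEntry(
        document_hash=key,
        index_data={'tokenized_corpus': [['revenue', 'increased']], 'k1': 1.5, 'b': 0.75, 'algorithm': 'BM25'},
        created_at=datetime.now(),
    )
    cache.put(key, entry)

print(cache.get_stats())  # hit/miss counts, entry counts, estimated size in MB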
edgar/documents/ranking/preprocessing.py
@@ -0,0 +1,187 @@
"""
Text preprocessing for search.

Provides tokenization and text normalization for BM25 and semantic analysis.
"""

import re
from typing import List, Set


# Common English stopwords (minimal set for financial documents)
# We keep many financial terms that might be stopwords in other contexts
STOPWORDS: Set[str] = {
    'a', 'an', 'and', 'are', 'as', 'at', 'be', 'by', 'for',
    'from', 'has', 'he', 'in', 'is', 'it', 'its', 'of', 'on',
    'that', 'the', 'to', 'was', 'will', 'with'
}


def preprocess_text(text: str,
                    lowercase: bool = True,
                    remove_punctuation: bool = False) -> str:
    """
    Preprocess text for search.

    Args:
        text: Raw text
        lowercase: Convert to lowercase
        remove_punctuation: Remove punctuation (keep for financial data)

    Returns:
        Preprocessed text
    """
    if not text:
        return ""

    # Normalize whitespace
    text = ' '.join(text.split())

    # Lowercase (important for BM25 matching)
    if lowercase:
        text = text.lower()

    # Optionally remove punctuation (usually keep for "$5B", "Item 1A", etc.)
    if remove_punctuation:
        text = re.sub(r'[^\w\s]', ' ', text)
        text = ' '.join(text.split())  # Clean up extra spaces

    return text


def tokenize(text: str,
             remove_stopwords: bool = False,
             min_token_length: int = 2) -> List[str]:
    """
    Tokenize text for BM25 indexing.

    Args:
        text: Text to tokenize
        remove_stopwords: Remove common stopwords
        min_token_length: Minimum token length to keep

    Returns:
        List of tokens
    """
    if not text:
        return []

    # Split on whitespace and punctuation boundaries
    # Keep alphanumeric + some special chars for financial terms
    tokens = re.findall(r'\b[\w$%]+\b', text.lower())

    # Filter by length
    tokens = [t for t in tokens if len(t) >= min_token_length]

    # Optionally remove stopwords
    if remove_stopwords:
        tokens = [t for t in tokens if t not in STOPWORDS]

    return tokens


def extract_query_terms(query: str) -> List[str]:
    """
    Extract important terms from query for boosting.

    Identifies key financial terms, numbers, and important phrases.

    Args:
        query: Search query

    Returns:
        List of important query terms
    """
    # Tokenize
    tokens = tokenize(query, remove_stopwords=True)

    # Extract important patterns
    important = []

    # Financial amounts: $5B, $1.2M, etc.
    amounts = re.findall(r'\$[\d,.]+[BMK]?', query, re.IGNORECASE)
    important.extend(amounts)

    # Percentages: 15%, 3.5%
    percentages = re.findall(r'\d+\.?\d*%', query)
    important.extend(percentages)

    # Years: 2023, 2024
    years = re.findall(r'\b(?:19|20)\d{2}\b', query)
    important.extend(years)

    # Item references: Item 1A, Item 7
    items = re.findall(r'item\s+\d+[a-z]?', query, re.IGNORECASE)
    important.extend(items)

    # Add all tokens
    important.extend(tokens)

    # Remove duplicates while preserving order
    seen = set()
    result = []
    for term in important:
        term_lower = term.lower()
        if term_lower not in seen:
            seen.add(term_lower)
            result.append(term)

    return result


def normalize_financial_term(term: str) -> str:
    """
    Normalize financial terms for consistent matching.

    Examples:
        "$5 billion" -> "$5b"
        "5,000,000" -> "5000000"
        "Item 1A" -> "item1a"

    Args:
        term: Financial term

    Returns:
        Normalized term
    """
    term = term.lower().strip()

    # Remove commas from numbers
    term = term.replace(',', '')

    # Normalize billion/million/thousand
    term = re.sub(r'\s*billion\b', 'b', term)
    term = re.sub(r'\s*million\b', 'm', term)
    term = re.sub(r'\s*thousand\b', 'k', term)

    # Remove spaces in compound terms
    term = re.sub(r'(item|section|part)\s+(\d+[a-z]?)', r'\1\2', term)

    # Remove extra whitespace
    term = ' '.join(term.split())

    return term


def get_ngrams(tokens: List[str], n: int = 2) -> List[str]:
    """
    Generate n-grams from tokens.

    Useful for phrase matching in BM25.

    Args:
        tokens: List of tokens
        n: N-gram size

    Returns:
        List of n-grams as strings
    """
    if len(tokens) < n:
        return []

    ngrams = []
    for i in range(len(tokens) - n + 1):
        ngram = ' '.join(tokens[i:i + n])
        ngrams.append(ngram)

    return ngrams
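A quick sketch of the helpers above on a made-up query string; the expected values in the comments follow from the docstring examples and the regexes as written, and the query itself is only illustrative.

from edgar.documents.ranking.preprocessing import (
    tokenize, extract_query_terms, normalize_financial_term, get_ngrams,
)

query = "What were revenue risk factors in Item 1A for 2023?"

tokens = tokenize(query, remove_stopwords=True)   # lowercased terms, stopwords dropped
terms = extract_query_terms(query)                # adds boosted patterns such as the "Item 1A" reference and "2023"
bigrams = get_ngrams(tokens, n=2)                 # adjacent-token phrases for phrase matching

print(normalize_financial_term("$5 billion"))     # -> "$5b"
print(normalize_financial_term("Item 1A"))        # -> "item1a"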
edgar/documents/ranking/ranking.py
@@ -0,0 +1,401 @@
"""
Ranking engines for document search.

Provides BM25-based ranking with optional semantic structure boosting.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import List, Optional, Dict, Any, TYPE_CHECKING

from rank_bm25 import BM25Okapi

if TYPE_CHECKING:
    from edgar.documents.nodes import Node


class RankingAlgorithm(Enum):
    """Supported ranking algorithms."""
    BM25 = auto()      # Classic BM25 (Okapi variant)
    HYBRID = auto()    # BM25 + semantic structure boosting
    SEMANTIC = auto()  # Pure structure-aware scoring


@dataclass
class RankedResult:
    """
    A search result with ranking score.

    Attributes:
        node: Document node containing the match
        score: Relevance score (higher is better)
        rank: Position in results (1-indexed)
        text: Matched text content
        bm25_score: Raw BM25 score (if applicable)
        semantic_score: Semantic boost score (if applicable)
        metadata: Additional result metadata
    """
    node: 'Node'
    score: float
    rank: int
    text: str
    bm25_score: Optional[float] = None
    semantic_score: Optional[float] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def snippet(self) -> str:
        """Get text snippet (first 200 chars)."""
        if len(self.text) <= 200:
            return self.text
        return self.text[:197] + "..."


class RankingEngine(ABC):
    """Abstract base class for ranking engines."""

    @abstractmethod
    def rank(self, query: str, nodes: List['Node']) -> List[RankedResult]:
        """
        Rank nodes by relevance to query.

        Args:
            query: Search query
            nodes: Nodes to rank

        Returns:
            List of ranked results sorted by relevance (best first)
        """
        pass

    @abstractmethod
    def get_algorithm_name(self) -> str:
        """Get name of ranking algorithm."""
        pass


class BM25Engine(RankingEngine):
    """
    BM25 ranking engine using the Okapi variant.

    BM25 is a probabilistic retrieval function that ranks documents based on
    query term frequency and inverse document frequency. Well-suited for
    financial documents where exact term matching is important.

    Parameters:
        k1: Term frequency saturation parameter (default: 1.5)
            Controls how quickly term frequency impact plateaus.
        b: Length normalization parameter (default: 0.75)
            0 = no normalization, 1 = full normalization.
    """

    def __init__(self, k1: float = 1.5, b: float = 0.75):
        """
        Initialize BM25 engine.

        Args:
            k1: Term frequency saturation (1.2-2.0 typical)
            b: Length normalization (0.75 is standard)
        """
        self.k1 = k1
        self.b = b
        self._bm25: Optional[BM25Okapi] = None
        self._corpus_nodes: Optional[List['Node']] = None
        self._tokenized_corpus: Optional[List[List[str]]] = None

    def rank(self, query: str, nodes: List['Node']) -> List[RankedResult]:
        """
        Rank nodes using the BM25 algorithm.

        Args:
            query: Search query
            nodes: Nodes to rank

        Returns:
            Ranked results sorted by BM25 score
        """
        if not nodes:
            return []

        # Import preprocessing here to avoid circular dependency
        from edgar.documents.ranking.preprocessing import preprocess_text, tokenize

        # Build index if needed or if nodes changed
        if self._corpus_nodes != nodes:
            self._build_index(nodes)

        # Tokenize and preprocess query
        query_tokens = tokenize(preprocess_text(query))

        if not query_tokens:
            return []

        # Get BM25 scores
        scores = self._bm25.get_scores(query_tokens)

        # Create ranked results
        results = []
        for node, score in zip(nodes, scores):
            if score > 0:  # Only include nodes with positive scores
                text = node.text() if hasattr(node, 'text') else str(node)
                results.append(RankedResult(
                    node=node,
                    score=float(score),
                    rank=0,  # Will be set after sorting
                    text=text,
                    bm25_score=float(score),
                    metadata={'algorithm': 'BM25'}
                ))

        # Sort by score (highest first) and assign ranks
        results.sort(key=lambda r: r.score, reverse=True)
        for rank, result in enumerate(results, start=1):
            result.rank = rank

        return results

    def _build_index(self, nodes: List['Node']):
        """Build BM25 index from nodes."""
        from edgar.documents.ranking.preprocessing import preprocess_text, tokenize

        # Store corpus
        self._corpus_nodes = nodes

        # Tokenize all nodes
        self._tokenized_corpus = []
        for node in nodes:
            text = node.text() if hasattr(node, 'text') else str(node)
            processed = preprocess_text(text)
            tokens = tokenize(processed)
            self._tokenized_corpus.append(tokens)

        # Build BM25 index with custom parameters
        self._bm25 = BM25Okapi(
            self._tokenized_corpus,
            k1=self.k1,
            b=self.b
        )

    def get_index_data(self) -> Dict[str, Any]:
        """
        Serialize index data for caching.

        Returns:
            Dictionary with serializable index data
        """
        return {
            'tokenized_corpus': self._tokenized_corpus,
            'k1': self.k1,
            'b': self.b,
            'algorithm': 'BM25'
        }

    def load_index_data(self, index_data: Dict[str, Any], nodes: List['Node']) -> None:
        """
        Load index from cached data.

        Args:
            index_data: Serialized index data
            nodes: Nodes corresponding to the index
        """
        self._corpus_nodes = nodes
        self._tokenized_corpus = index_data['tokenized_corpus']
        self.k1 = index_data['k1']
        self.b = index_data['b']

        # Rebuild BM25 index from tokenized corpus
        self._bm25 = BM25Okapi(
            self._tokenized_corpus,
            k1=self.k1,
            b=self.b
        )

    def get_algorithm_name(self) -> str:
        """Get algorithm name."""
        return "BM25"


class HybridEngine(RankingEngine):
    """
    Hybrid ranking engine: BM25 + semantic structure boosting.

    Combines classic BM25 text matching with semantic structure awareness:
    - BM25 provides strong exact-match ranking for financial terms
    - Semantic scoring boosts results based on document structure:
      * Headings and section markers
      * Cross-references ("See Item X")
      * Gateway content (summaries, overviews)
      * Table and XBRL importance

    This approach is agent-friendly: it surfaces starting points for
    investigation rather than fragmented chunks.

    Parameters:
        bm25_weight: Weight for BM25 score (default: 0.8)
        semantic_weight: Weight for semantic score (default: 0.2)
        k1: BM25 term frequency saturation
        b: BM25 length normalization
        boost_sections: Section names to boost (e.g., ["Risk Factors"])
    """

    def __init__(self,
                 bm25_weight: float = 0.8,
                 semantic_weight: float = 0.2,
                 k1: float = 1.5,
                 b: float = 0.75,
                 boost_sections: Optional[List[str]] = None):
        """
        Initialize hybrid engine.

        Args:
            bm25_weight: Weight for BM25 component (0-1)
            semantic_weight: Weight for semantic component (0-1)
            k1: BM25 k1 parameter
            b: BM25 b parameter
            boost_sections: Section names to boost (e.g., ["Risk Factors"])
        """
        self.bm25_engine = BM25Engine(k1=k1, b=b)
        self.bm25_weight = bm25_weight
        self.semantic_weight = semantic_weight
        self.boost_sections = boost_sections or []

        # Validate weights
        total_weight = bm25_weight + semantic_weight
        if not (0.99 <= total_weight <= 1.01):  # Allow small floating point error
            raise ValueError(f"Weights must sum to 1.0, got {total_weight}")

    def rank(self, query: str, nodes: List['Node']) -> List[RankedResult]:
        """
        Rank nodes using the hybrid approach.

        Args:
            query: Search query
            nodes: Nodes to rank

        Returns:
            Ranked results with combined BM25 + semantic scores
        """
        if not nodes:
            return []

        # Get BM25 results
        bm25_results = self.bm25_engine.rank(query, nodes)

        if not bm25_results:
            return []

        # Import semantic scoring
        from edgar.documents.ranking.semantic import compute_semantic_scores

        # Get semantic scores for all nodes
        semantic_scores_dict = compute_semantic_scores(
            nodes=nodes,
            query=query,
            boost_sections=self.boost_sections
        )

        # Normalize BM25 scores to 0-1 range
        max_bm25 = max(r.bm25_score for r in bm25_results)
        if max_bm25 > 0:
            for result in bm25_results:
                result.bm25_score = result.bm25_score / max_bm25

        # Combine scores
        for result in bm25_results:
            semantic_score = semantic_scores_dict.get(id(result.node), 0.0)
            result.semantic_score = semantic_score

            # Weighted combination
            result.score = (
                self.bm25_weight * result.bm25_score +
                self.semantic_weight * semantic_score
            )

            result.metadata['algorithm'] = 'Hybrid'
            result.metadata['bm25_weight'] = self.bm25_weight
            result.metadata['semantic_weight'] = self.semantic_weight

        # Re-sort by combined score
        bm25_results.sort(key=lambda r: r.score, reverse=True)

        # Update ranks
        for rank, result in enumerate(bm25_results, start=1):
            result.rank = rank

        return bm25_results

    def get_algorithm_name(self) -> str:
        """Get algorithm name."""
        return "Hybrid"


class SemanticEngine(RankingEngine):
    """
    Pure semantic/structure-based ranking (no text matching).

    Ranks nodes purely by structural importance:
    - Section headings
    - Cross-references
    - Gateway content
    - Document structure position

    Useful for understanding document organization without specific queries.
    """

    def __init__(self, boost_sections: Optional[List[str]] = None):
        """
        Initialize semantic engine.

        Args:
            boost_sections: Section names to boost
        """
        self.boost_sections = boost_sections or []

    def rank(self, query: str, nodes: List['Node']) -> List[RankedResult]:
        """
        Rank nodes by semantic importance.

        Args:
            query: Search query (used for context)
            nodes: Nodes to rank

        Returns:
            Ranked results by structural importance
        """
        if not nodes:
            return []

        from edgar.documents.ranking.semantic import compute_semantic_scores

        # Get semantic scores
        semantic_scores = compute_semantic_scores(
            nodes=nodes,
            query=query,
            boost_sections=self.boost_sections
        )

        # Create results
        results = []
        for node in nodes:
            score = semantic_scores.get(id(node), 0.0)
            if score > 0:
                text = node.text() if hasattr(node, 'text') else str(node)
                results.append(RankedResult(
                    node=node,
                    score=score,
                    rank=0,
                    text=text,
                    semantic_score=score,
                    metadata={'algorithm': 'Semantic'}
                ))

        # Sort and rank
        results.sort(key=lambda r: r.score, reverse=True)
        for rank, result in enumerate(results, start=1):
            result.rank = rank

        return results

    def get_algorithm_name(self) -> str:
        """Get algorithm name."""
        return "Semantic"
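A sketch of the BM25 path above, including the index-serialization hooks the cache consumes. Plain strings stand in for document nodes; they only work because rank() falls back to str(node) when a node has no text() method, which keeps the example self-contained but is not the intended input type.

from edgar.documents.ranking.ranking import BM25Engine

# Stand-in corpus; real usage passes Node objects from a parsed filing.
corpus = [
    "Revenue increased 12% driven by services growth.",
    "Item 1A. Risk Factors: competition and supply chain risk.",
    "See Item 7 for management's discussion and analysis.",
]

engine = BM25Engine(k1=1.5, b=0.75)
results = engine.rank("risk factors", corpus)
for r in results:
    print(r.rank, round(r.score, 2), r.snippet)

# The tokenized corpus can be cached and later restored without re-tokenizing.
cached = engine.get_index_data()
restored = BM25Engine()
restored.load_index_data(cached, corpus)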
edgar/documents/ranking/semantic.py
@@ -0,0 +1,333 @@
"""
Semantic scoring for document structure awareness.

Provides structure-based boosting without ML/embeddings:
- Node type importance (headings, tables, XBRL)
- Cross-reference detection (gateway content)
- Section importance
- Text quality signals

This is NOT embedding-based semantic search. It's structure-aware ranking
that helps agents find investigation starting points.
"""

import re
from typing import List, Dict, Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from edgar.documents.nodes import Node

from edgar.documents.types import NodeType, SemanticType


# Gateway terms that indicate summary/overview content
GATEWAY_TERMS = [
    'summary', 'overview', 'introduction', 'highlights',
    'key points', 'executive summary', 'in summary',
    'table of contents', 'index'
]

# Cross-reference patterns
CROSS_REFERENCE_PATTERNS = [
    r'\bsee\s+item\s+\d+[a-z]?\b',                   # "See Item 1A"
    r'\bsee\s+(?:part|section)\s+\d+\b',             # "See Part II"
    r'\brefer\s+to\s+item\s+\d+[a-z]?\b',            # "Refer to Item 7"
    r'\bas\s+discussed\s+in\s+item\s+\d+\b',         # "As discussed in Item 1"
    r'\bfor\s+(?:more|additional)\s+information\b',  # "For more information"
]

# Section importance weights
SECTION_IMPORTANCE = {
    'risk factors': 1.5,
    'management discussion': 1.4,
    'md&a': 1.4,
    'business': 1.3,
    'financial statements': 1.2,
    'controls and procedures': 1.2,
}


def compute_semantic_scores(nodes: List['Node'],
                            query: str,
                            boost_sections: Optional[List[str]] = None) -> Dict[int, float]:
    """
    Compute semantic/structure scores for nodes.

    This provides structure-aware boosting based on:
    1. Node type (headings > tables > paragraphs)
    2. Cross-references (gateway content)
    3. Section importance
    4. Gateway terms (summaries, overviews)
    5. XBRL presence
    6. Text quality

    Args:
        nodes: Nodes to score
        query: Search query (for context-aware boosting)
        boost_sections: Additional sections to boost

    Returns:
        Dictionary mapping node id to semantic score (0-1 range)
    """
    scores = {}
    boost_sections = boost_sections or []

    # Get query context
    query_lower = query.lower()
    is_item_query = bool(re.search(r'item\s+\d+[a-z]?', query_lower))

    for node in nodes:
        score = 0.0

        # 1. Node type boosting
        score += _get_node_type_boost(node)

        # 2. Semantic type boosting
        score += _get_semantic_type_boost(node)

        # 3. Cross-reference detection (gateway content)
        score += _detect_cross_references(node)

        # 4. Gateway content detection
        score += _detect_gateway_content(node, query_lower)

        # 5. Section importance boosting
        score += _get_section_boost(node, boost_sections)

        # 6. XBRL fact boosting (for financial queries)
        score += _get_xbrl_boost(node)

        # 7. Text quality signals
        score += _get_quality_boost(node)

        # 8. Query-specific boosting
        if is_item_query:
            score += _get_item_header_boost(node)

        # Normalize to 0-1 range (max possible score is ~7.0)
        normalized_score = min(score / 7.0, 1.0)

        scores[id(node)] = normalized_score

    return scores


def _get_node_type_boost(node: 'Node') -> float:
    """
    Boost based on node type.

    Headings and structural elements are more important for navigation.
    """
    type_boosts = {
        NodeType.HEADING: 2.0,    # Headings are key navigation points
        NodeType.SECTION: 1.5,    # Section markers
        NodeType.TABLE: 1.0,      # Tables contain structured data
        NodeType.XBRL_FACT: 0.8,  # Financial facts
        NodeType.LIST: 0.5,       # Lists
        NodeType.PARAGRAPH: 0.3,  # Regular text
        NodeType.TEXT: 0.1,       # Plain text nodes
    }

    return type_boosts.get(node.type, 0.0)


def _get_semantic_type_boost(node: 'Node') -> float:
    """
    Boost based on semantic type.

    Section headers and items are important for SEC filings.
    """
    if not hasattr(node, 'semantic_type') or node.semantic_type is None:
        return 0.0

    semantic_boosts = {
        SemanticType.ITEM_HEADER: 2.0,          # Item headers are critical
        SemanticType.SECTION_HEADER: 1.5,       # Section headers
        SemanticType.FINANCIAL_STATEMENT: 1.2,  # Financial statements
        SemanticType.TABLE_OF_CONTENTS: 1.0,    # TOC is a gateway
        SemanticType.TITLE: 0.8,
        SemanticType.HEADER: 0.6,
    }

    return semantic_boosts.get(node.semantic_type, 0.0)


def _detect_cross_references(node: 'Node') -> float:
    """
    Detect cross-references that indicate gateway content.

    Content that points to other sections is useful for navigation.
    """
    text = node.text() if hasattr(node, 'text') else ''
    if not text:
        return 0.0

    text_lower = text.lower()

    # Check each pattern
    matches = 0
    for pattern in CROSS_REFERENCE_PATTERNS:
        if re.search(pattern, text_lower):
            matches += 1

    # Boost increases with number of cross-references
    return min(matches * 0.5, 1.5)  # Cap at 1.5


def _detect_gateway_content(node: 'Node', query_lower: str) -> float:
    """
    Detect gateway content (summaries, overviews, introductions).

    These are excellent starting points for investigation.
    """
    text = node.text() if hasattr(node, 'text') else ''
    if not text:
        return 0.0

    text_lower = text.lower()

    # Check for gateway terms in text
    for term in GATEWAY_TERMS:
        if term in text_lower:
            return 1.0

    # Check for short introductory paragraphs (roughly 20-200 chars)
    if 20 < len(text) < 200:
        # Short intro paragraphs are often summaries
        if any(word in text_lower for word in ['provides', 'describes', 'includes', 'contains']):
            return 0.5

    return 0.0


def _get_section_boost(node: 'Node', boost_sections: List[str]) -> float:
    """
    Boost nodes in important sections.

    Some SEC sections are more relevant for certain queries.
    """
    # Try to determine section from node or ancestors
    section_name = _get_node_section(node)
    if not section_name:
        return 0.0

    section_lower = section_name.lower()

    # Check built-in importance
    for key, boost in SECTION_IMPORTANCE.items():
        if key in section_lower:
            return boost

    # Check user-specified sections
    for boost_section in boost_sections:
        if boost_section.lower() in section_lower:
            return 1.5

    return 0.0


def _get_xbrl_boost(node: 'Node') -> float:
    """
    Boost XBRL facts and tables with XBRL data.

    Financial data is important for financial queries.
    """
    if node.type == NodeType.XBRL_FACT:
        return 0.8

    # Check if table contains XBRL facts
    if node.type == NodeType.TABLE:
        # Check metadata for XBRL indicator
        if hasattr(node, 'metadata') and node.metadata.get('has_xbrl'):
            return 0.6

    return 0.0


def _get_quality_boost(node: 'Node') -> float:
    """
    Boost based on text quality signals.

    Higher quality content tends to be more useful:
    - Appropriate length (not too short, not too long)
    - Good structure (sentences, punctuation)
    - Substantive content (not just formatting)
    """
    text = node.text() if hasattr(node, 'text') else ''
    if not text:
        return 0.0

    score = 0.0

    # Length signal
    text_len = len(text)
    if 50 <= text_len <= 1000:
        score += 0.3  # Good length
    elif text_len > 1000:
        score += 0.1  # Long but might be comprehensive
    else:
        score += 0.0  # Too short, likely not substantive

    # Sentence structure
    sentence_count = text.count('.') + text.count('?') + text.count('!')
    if sentence_count >= 2:
        score += 0.2  # Multiple sentences indicate substantive content

    # Avoid pure formatting/navigation
    if text.strip() in ['...', '—', '-', 'Table of Contents', 'Page', '']:
        return 0.0  # Skip pure formatting

    return score


def _get_item_header_boost(node: 'Node') -> float:
    """
    Boost Item headers when query is about items.

    "Item 1A" queries should prioritize Item 1A headers.
    """
    if node.type != NodeType.HEADING:
        return 0.0

    text = node.text() if hasattr(node, 'text') else ''
    if not text:
        return 0.0

    # Check if this is an Item header
    if re.match(r'^\s*item\s+\d+[a-z]?[:\.\s]', text, re.IGNORECASE):
        return 1.5

    return 0.0


def _get_node_section(node: 'Node') -> Optional[str]:
    """
    Get section name for a node by walking up the tree.

    Returns:
        Section name if found, None otherwise
    """
    # Check if node has section in metadata
    if hasattr(node, 'metadata') and 'section' in node.metadata:
        return node.metadata['section']

    # Walk up tree looking for section marker
    current = node
    while current:
        if hasattr(current, 'semantic_type'):
            if current.semantic_type in (SemanticType.SECTION_HEADER, SemanticType.ITEM_HEADER):
                return current.text() if hasattr(current, 'text') else None

        current = current.parent if hasattr(current, 'parent') else None

    return None


def get_section_importance_names() -> List[str]:
    """
    Get list of important section names for reference.

    Returns:
        List of section names with built-in importance boosts
    """
    return list(SECTION_IMPORTANCE.keys())
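A sketch of how the scorer above reads node attributes. FakeNode is a stand-in defined only for this example; the real Node class lives in edgar.documents.nodes and is not shown in this commit, while NodeType and SemanticType are the enums imported by the module above.

from edgar.documents.ranking.semantic import compute_semantic_scores
from edgar.documents.types import NodeType, SemanticType

class FakeNode:
    """Minimal stand-in exposing the attributes the scorer checks."""
    def __init__(self, node_type, text, semantic_type=None):
        self.type = node_type
        self.semantic_type = semantic_type
        self.metadata = {}
        self.parent = None
        self._text = text

    def text(self):
        return self._text

nodes = [
    FakeNode(NodeType.HEADING, "Item 1A. Risk Factors", SemanticType.ITEM_HEADER),
    FakeNode(NodeType.PARAGRAPH, "See Item 7 for additional information on liquidity."),
]

# Heading gets type, semantic-type, section, and item-header boosts; the
# paragraph mostly earns cross-reference and quality credit.
scores = compute_semantic_scores(nodes, query="Item 1A risk factors")
for node in nodes:
    print(round(scores[id(node)], 2), node.text()[:40])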