Initial commit

This commit is contained in:
kdusek
2025-12-09 12:13:01 +01:00
commit 8e654ed209
13332 changed files with 2695056 additions and 0 deletions

View File

@@ -0,0 +1,134 @@
import hashlib
import re
from functools import lru_cache
from typing import Any, Callable, Dict, List
import pyarrow as pa
from rapidfuzz import fuzz
from unidecode import unidecode
class FastSearch:
def __init__(self, data: pa.Table, columns: List[str], preprocess_func: Callable[[str], str] = None,
score_func: Callable[[str, str, str], float] = None):
self.data = data
self.columns = columns
self.preprocess = preprocess_func or self._default_preprocess
self.calculate_score = score_func or self._default_calculate_score
self.indices = {column: self._build_index(column) for column in columns}
# Calculate and store the hash of the data structure
self._data_hash = self._compute_data_hash()
@staticmethod
def _default_preprocess(text: str) -> str:
text = unidecode(text.lower())
text = re.sub(r'[^\w\s]', '', text)
text = re.sub(r'\s+', ' ', text).strip()
return text
def _build_index(self, column: str) -> Dict[str, List[int]]:
index = {}
for i, value in enumerate(self.data[column].to_pylist()):
processed_value = self.preprocess(str(value))
for word in processed_value.split():
if word not in index:
index[word] = []
index[word].append(i)
return index
@staticmethod
def _default_calculate_score(query: str, value: str) -> float:
return fuzz.ratio(query, value)
def search(self, query: str, top_n: int = 10, threshold: float = 60) -> List[Dict[str, Any]]:
processed_query = self.preprocess(query)
query_words = processed_query.split()
candidate_indices = set()
for column in self.columns:
for word in query_words:
candidate_indices.update(self.indices[column].get(word, []))
if len(query) <= 5: # Assume it's a ticker query
for indexed_word in self.indices[column]:
if indexed_word.startswith(query.lower()):
candidate_indices.update(self.indices[column][indexed_word])
scores = []
for idx in candidate_indices:
record = {column: self.data[column][idx].as_py() for column in self.data.schema.names}
best_score = max(
self.calculate_score(processed_query, self.preprocess(str(record[column])), column) for column in
self.columns)
if best_score >= threshold:
record['score'] = best_score
scores.append(record)
return sorted(scores, key=lambda x: x['score'], reverse=True)[:top_n]
def _compute_data_hash(self) -> int:
# Create a string representation of the data structure
data_repr = f"Shape: {self.data.shape}, "
data_repr += f"Columns: {','.join(self.data.column_names)}, "
data_repr += f"Types: {','.join(str(field.type) for field in self.data.schema)}, "
data_repr += f"Index Columns: {','.join(self.columns)}"
# Use SHA256 to create a hash of the data representation
return int(hashlib.sha256(data_repr.encode()).hexdigest(), 16)
def __hash__(self) -> int:
return self._data_hash
def __eq__(self, other: object) -> bool:
if not isinstance(other, FastSearch):
return NotImplemented
return self._data_hash == other._data_hash
def create_search_index(data: pa.Table, columns: List[str], preprocess_func: Callable[[str], str] = None,
score_func: Callable[[str, str, str], float] = None) -> FastSearch:
return FastSearch(data, columns, preprocess_func, score_func)
def search(index: FastSearch, query: str, top_n: int = 10) -> List[Dict[str, str]]:
return index.search(query, top_n)
@lru_cache(maxsize=128)
def cached_search(index: FastSearch, query: str, top_n: int = 10):
return index.search(query, top_n)
# Example usage for company and ticker search
def company_ticker_preprocess(text: str) -> str:
text = FastSearch._default_preprocess(text)
common_terms = ['llc', 'inc', 'corp', 'ltd', 'limited', 'company']
return ' '.join(word for word in text.split() if word not in common_terms)
def company_ticker_score(query: str, value: str, column: str) -> float:
query = query.upper()
value = value.upper()
# Check if it's likely a ticker (5 characters or less)
if len(query) <= 5 and column == 'ticker':
if query == value:
return 100 # Exact match
elif value.startswith(query):
return 90 + (10 * len(query) / len(value)) # Partial match, score based on completeness
else:
return 0 # No match for tickers
else:
# For company names, use the default scoring method
return FastSearch._default_calculate_score(query, value)
def preprocess_company_name(company_name: str) -> str:
company_name = unidecode(company_name.lower())
company_name = re.sub(r'[^\w\s]', '', company_name)
company_name = re.sub(r'\s+', ' ', company_name).strip()
return company_name