Initial commit
2
venv/lib/python3.10/site-packages/edgar/search/__init__.py
Normal file
@@ -0,0 +1,2 @@
from edgar.search.datasearch import FastSearch, company_ticker_preprocess, company_ticker_score, create_search_index, search
from edgar.search.textsearch import BM25Search, RegexSearch, SearchResults, SimilaritySearchIndex, preprocess
134
venv/lib/python3.10/site-packages/edgar/search/datasearch.py
Normal file
@@ -0,0 +1,134 @@
import hashlib
import re
from functools import lru_cache
from typing import Any, Callable, Dict, List

import pyarrow as pa
from rapidfuzz import fuzz
from unidecode import unidecode

class FastSearch:
    def __init__(self, data: pa.Table, columns: List[str], preprocess_func: Callable[[str], str] = None,
                 score_func: Callable[[str, str, str], float] = None):
        self.data = data
        self.columns = columns
        self.preprocess = preprocess_func or self._default_preprocess
        self.calculate_score = score_func or self._default_calculate_score
        self.indices = {column: self._build_index(column) for column in columns}

        # Calculate and store the hash of the data structure
        self._data_hash = self._compute_data_hash()

    @staticmethod
    def _default_preprocess(text: str) -> str:
        # Normalize to lowercase ASCII, drop punctuation, and collapse whitespace
        text = unidecode(text.lower())
        text = re.sub(r'[^\w\s]', '', text)
        text = re.sub(r'\s+', ' ', text).strip()
        return text

    def _build_index(self, column: str) -> Dict[str, List[int]]:
        # Inverted index: each preprocessed word maps to the row indices containing it
        index = {}
        for i, value in enumerate(self.data[column].to_pylist()):
            processed_value = self.preprocess(str(value))
            for word in processed_value.split():
                if word not in index:
                    index[word] = []
                index[word].append(i)
        return index

    @staticmethod
    def _default_calculate_score(query: str, value: str, column: str = None) -> float:
        # Accepts (and ignores) the column argument so the default is compatible
        # with the Callable[[str, str, str], float] signature that search() expects
        return fuzz.ratio(query, value)

    def search(self, query: str, top_n: int = 10, threshold: float = 60) -> List[Dict[str, Any]]:
        processed_query = self.preprocess(query)
        query_words = processed_query.split()

        # Collect candidate rows that share at least one word with the query
        candidate_indices = set()
        for column in self.columns:
            for word in query_words:
                candidate_indices.update(self.indices[column].get(word, []))

            if len(query) <= 5:  # Assume a short query may be a ticker prefix
                for indexed_word in self.indices[column]:
                    if indexed_word.startswith(query.lower()):
                        candidate_indices.update(self.indices[column][indexed_word])

        # Score each candidate and keep those at or above the threshold
        scores = []
        for idx in candidate_indices:
            record = {column: self.data[column][idx].as_py() for column in self.data.schema.names}
            best_score = max(
                self.calculate_score(processed_query, self.preprocess(str(record[column])), column)
                for column in self.columns)
            if best_score >= threshold:
                record['score'] = best_score
                scores.append(record)

        return sorted(scores, key=lambda x: x['score'], reverse=True)[:top_n]

    def _compute_data_hash(self) -> int:
        # Create a string representation of the data structure
        data_repr = f"Shape: {self.data.shape}, "
        data_repr += f"Columns: {','.join(self.data.column_names)}, "
        data_repr += f"Types: {','.join(str(field.type) for field in self.data.schema)}, "
        data_repr += f"Index Columns: {','.join(self.columns)}"

        # Use SHA256 to create a hash of the data representation
        return int(hashlib.sha256(data_repr.encode()).hexdigest(), 16)

    def __hash__(self) -> int:
        return self._data_hash

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, FastSearch):
            return NotImplemented
        return self._data_hash == other._data_hash


def create_search_index(data: pa.Table, columns: List[str], preprocess_func: Callable[[str], str] = None,
                        score_func: Callable[[str, str, str], float] = None) -> FastSearch:
    return FastSearch(data, columns, preprocess_func, score_func)


def search(index: FastSearch, query: str, top_n: int = 10) -> List[Dict[str, Any]]:
    return index.search(query, top_n)


@lru_cache(maxsize=128)
def cached_search(index: FastSearch, query: str, top_n: int = 10):
    # FastSearch is hashable via _data_hash, so instances can serve as lru_cache
    # keys; repeated queries against the same index are answered from the cache
    return index.search(query, top_n)

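# A minimal usage sketch (illustrative only, not part of this module; the table
# contents are made-up sample data):
#
#   companies = pa.table({'company': ['Apple Inc', 'Alphabet Inc'],
#                         'ticker': ['AAPL', 'GOOGL']})
#   index = create_search_index(companies, columns=['company', 'ticker'])
#   search(index, 'apple')  # -> [{'company': 'Apple Inc', 'ticker': 'AAPL', 'score': ...}]

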
# Example usage for company and ticker search
def company_ticker_preprocess(text: str) -> str:
    # Strip common corporate suffixes so 'Apple Inc' matches 'Apple'
    text = FastSearch._default_preprocess(text)
    common_terms = ['llc', 'inc', 'corp', 'ltd', 'limited', 'company']
    return ' '.join(word for word in text.split() if word not in common_terms)


def company_ticker_score(query: str, value: str, column: str) -> float:
    query = query.upper()
    value = value.upper()

    # Check if it's likely a ticker (5 characters or less)
    if len(query) <= 5 and column == 'ticker':
        if query == value:
            return 100  # Exact match
        elif value.startswith(query):
            return 90 + (10 * len(query) / len(value))  # Partial match, scored by completeness
        else:
            return 0  # No match for tickers
    else:
        # For company names, use the default scoring method
        return FastSearch._default_calculate_score(query, value)

def preprocess_company_name(company_name: str) -> str:
    # Same normalization steps as FastSearch._default_preprocess
    company_name = unidecode(company_name.lower())
    company_name = re.sub(r'[^\w\s]', '', company_name)
    company_name = re.sub(r'\s+', ' ', company_name).strip()
    return company_name

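# A sketch of wiring the helpers above into an index (illustrative only; assumes
# a pa.Table `companies` with 'company' and 'ticker' columns):
#
#   index = create_search_index(companies, columns=['company', 'ticker'],
#                               preprocess_func=company_ticker_preprocess,
#                               score_func=company_ticker_score)
#   index.search('AAPL')  # an exact ticker match scores 100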
266
venv/lib/python3.10/site-packages/edgar/search/textsearch.py
Normal file
@@ -0,0 +1,266 @@
import re
from typing import Callable, List, Tuple

import pandas as pd
from rich import box
from rich.console import Group
from rich.markdown import Markdown
from rich.panel import Panel

from edgar._markdown import convert_table
from edgar.richtools import repr_rich

PUNCTUATION = re.compile('[%s]' % re.escape(r"""!"#&'()*+,-/:;<=>?@[\]^`{|}~"""))

__all__ = [
    'SimilaritySearchIndex',
    'SearchResults',
    'BM25Search',
    'RegexSearch',
    'preprocess'
]


class SimilaritySearchIndex:

    def __init__(self,
                 data: pd.DataFrame,
                 search_column: str):
        self.data: pd.DataFrame = data
        self._search_column = search_column

    def similar(self,
                query: str,
                threshold=0.6,
                topn=20):
        import textdistance
        query = query.lower()
        # Score every row by Jaro similarity between the search column and the query
        df = (self
              .data.assign(match=self.data[self._search_column].apply(textdistance.jaro, s2=query).round(2))
              )
        df = df[df.match > threshold]
        # Prefer matches that start with the same letter as the query
        df['matches_start'] = df[self._search_column].str.startswith(query[0])
        df = (df.sort_values(['match'], ascending=[False]).head(topn)
              .sort_values(['matches_start', 'match'], ascending=[False, False]))
        cols = [col for col in df if col not in [self._search_column, 'matches_start']]
        return df[cols]

    def __repr__(self):
        return f"SimilaritySearchIndex(search_column='{self._search_column}')"


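# A minimal usage sketch (illustrative only; the DataFrame below is sample data,
# and the search column is assumed to hold lowercase names):
#
#   df = pd.DataFrame({'company': ['apple inc', 'alphabet inc'], 'cik': [320193, 1652044]})
#   SimilaritySearchIndex(df, search_column='company').similar('appel')
#   # -> rows with Jaro similarity above the threshold, best matches first

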
Corpus = List[List[str]]


def tokenize(text):
    return text.split()


def lowercase_filter(tokens):
    return [token.lower() for token in tokens]


def punctuation_filter(tokens):
    return [PUNCTUATION.sub('', token) for token in tokens]


STOPWORDS = {'the', 'be', 'to', 'of', 'and', 'a', 'in', 'that', 'have', 'i', 'it', 'for', 'not', 'on', 'with', 'he',
             'as', 'you', 'do', 'at', 'this', 'but', 'his', 'by', 'from'}


def stopword_filter(tokens):
    return [token for token in tokens if token not in STOPWORDS]


def convert_items_to_tokens(text: str):
    """Change 'Item 4.' to item_4. This keeps the item reference as a single token in the final text"""
    return re.sub(r"item\s+(\d+\.\d+|\d+)\.?", r"item_\1", text, flags=re.IGNORECASE)

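# Illustrative example (not from the source): the underscore keeps the item
# reference together through tokenization and punctuation stripping:
#
#   convert_items_to_tokens("see Item 7. for details")  # -> 'see item_7 for details'

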
def numeric_shape(tokens: List[str]) -> List[str]:
    """Replace the digits in numeric tokens with 'x', e.g. '12.5%' -> 'xx.x%'"""
    toks = []
    for token in tokens:
        if re.fullmatch(r"(\d+[\d,.]*)%?|([,.]\d+)%?", token):
            toks.append(re.sub(r'\d', 'x', token))
        else:
            toks.append(token)
    return toks


def return_spaces_to_items(tokens: List[str]) -> List[str]:
    # Undo convert_items_to_tokens: 'item_7' becomes the two tokens 'item' and '7'
    toks = []
    pattern = r"item_(\d+(\.\d+)?)"
    for token in tokens:
        if re.fullmatch(pattern, token):
            toks += re.sub(pattern, r"item \1", token).split(" ")
        else:
            toks.append(token)
    return toks

def preprocess(text: str):
    # Full pipeline: lowercase, protect item references, tokenize, strip
    # punctuation and stopwords, mask digits, then restore item references
    text = text.lower()
    text = convert_items_to_tokens(text)
    tokens = tokenize(text)
    tokens = punctuation_filter(tokens)
    tokens = stopword_filter(tokens)
    tokens = numeric_shape(tokens)
    tokens = return_spaces_to_items(tokens)
    return tokens


def preprocess_documents(documents: List[str]) -> Corpus:
    return [preprocess(document) for document in documents]

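# A worked example of the pipeline (illustrative only):
#
#   preprocess("Revenue increased 12.5% from Item 7.")
#   # -> ['revenue', 'increased', 'xx.x%', 'item', '7']
#   # 'from' is a stopword, the digits are masked, and 'item_7' is split back out

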
LocAndDoc = Tuple[int, str]


class DocSection:

    def __init__(self,
                 loc: int,
                 doc: str,
                 score: float = 0.0
                 ):
        self.loc: int = loc
        self.doc: str = doc
        self.score: float = score

    # Make this class sortable by loc
    def __lt__(self, other):
        return self.loc < other.loc

    def __hash__(self):
        return hash(self.doc)

    def json(self):
        return {
            'loc': self.loc,
            'doc': self.doc,
            'score': self.score
        }

    def __repr__(self):
        return f"{self.loc}\n{self.doc}"


class SearchResults:

    def __init__(self,
                 query: str,
                 sections: List[DocSection],
                 tables: bool = True
                 ):
        self.query: str = query
        self.sections: List[DocSection] = sections
        self._show_tables = tables

    def __len__(self):
        return len(self.sections)

    @property
    def empty(self):
        return len(self) == 0

    def __getitem__(self, item):
        # Return None for out-of-range indices instead of raising IndexError
        if item < 0 or item >= len(self.sections):
            return None
        return self.sections[item]

    def json(self):
        return {
            'query': self.query,
            'sections': [section.json() for section in self.sections],
            'tables': self._show_tables
        }

    def __rich__(self):
        renderables = []
        title = f"Searching for '{self.query}'"
        subtitle = f"{len(self)} result(s)" if not self.empty else "No results"
        # Display the highest-scoring sections first
        sorted_sections = sorted(self.sections, key=lambda s: s.score, reverse=True)
        for i, doc_section in enumerate(sorted_sections):
            if doc_section.doc.startswith("| |") and self._show_tables:
                # Render markdown tables as rich tables
                section = convert_table(doc_section.doc)
            else:
                section = Markdown(doc_section.doc + "\n\n---")
            renderables.append(Panel(section, box=box.ROUNDED, title=f"{i}"))
        return Panel(
            Group(*renderables), title=title, subtitle=subtitle, box=box.SIMPLE
        )

    def __repr__(self):
        return repr_rich(self.__rich__())


class BM25Search:

    def __init__(self,
                 document_objs: List[str],
                 text_fn: Callable = None):
        from rank_bm25 import BM25Okapi
        # Optionally extract the searchable text from each document object first
        if text_fn:
            self.corpus: Corpus = [BM25Search.preprocess(text_fn(doc)) for doc in document_objs]
        else:
            self.corpus: Corpus = [BM25Search.preprocess(doc) for doc in document_objs]
        self.document_objs = document_objs
        self.bm25: BM25Okapi = BM25Okapi(self.corpus)

    def __len__(self):
        return len(self.document_objs)

    @staticmethod
    def preprocess(text: str):
        text = text.lower()
        text = convert_items_to_tokens(text)
        tokens = tokenize(text)
        tokens = punctuation_filter(tokens)
        tokens = stopword_filter(tokens)
        tokens = numeric_shape(tokens)
        tokens = return_spaces_to_items(tokens)
        return tokens

    def search(self,
               query: str,
               tables: bool = True):
        preprocessed_query = preprocess(query)
        scores = self.bm25.get_scores(preprocessed_query)
        doc_scores = zip(self.document_objs, scores, strict=False)
        # Keep only documents with a positive BM25 score, preserving document order
        return SearchResults(query=query,
                             sections=[DocSection(loc=loc, doc=doc_and_score[0], score=doc_and_score[1])
                                       for loc, doc_and_score in enumerate(doc_scores)
                                       if doc_and_score[1] > 0],
                             tables=tables)

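# A minimal usage sketch (illustrative only; `sections` stands for a list of
# markdown strings, e.g. chunks of a filing):
#
#   engine = BM25Search(sections)
#   results = engine.search("revenue growth")
#   len(results)  # sections with a positive BM25 score, kept in document order
#   results[0]    # the first matching DocSection (the rich display sorts by score)

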
class RegexSearch:

    def __init__(self,
                 documents: List[str]):
        self.document_objs = [RegexSearch.preprocess(document) for document in documents]

    def __len__(self):
        return len(self.document_objs)

    @staticmethod
    def preprocess(text: str):
        # Normalize non-breaking spaces to regular spaces
        text = text.replace("\xa0", " ")
        return text

    def search(self,
               query: str,
               tables: bool = True):
        # Treat the query as a regular expression and match case-insensitively
        return SearchResults(
            query=query,
            sections=[DocSection(loc=loc, doc=doc)
                      for loc, doc in enumerate(self.document_objs)
                      if re.search(query, doc, flags=re.IGNORECASE)],
            tables=tables
        )
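# A minimal usage sketch (illustrative only; `sections` as above):
#
#   RegexSearch(sections).search(r"item\s+7")
#   # -> SearchResults holding every section matching the pattern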