267 lines
7.7 KiB
Python
267 lines
7.7 KiB
Python
import re
|
|
from typing import Callable, List, Tuple
|
|
|
|
import pandas as pd
|
|
from rich import box
|
|
from rich.console import Group
|
|
from rich.markdown import Markdown
|
|
from rich.panel import Panel
|
|
|
|
from edgar._markdown import convert_table
|
|
from edgar.richtools import repr_rich
|
|
|
|
# Characters stripped from tokens. '$', '%', '.' and '_' are deliberately absent,
# so amounts, percentages, decimals and joined tokens such as 'item_4' keep them.
_PUNCTUATION_CHARS = r"""!"#&'()*+,-/:;<=>?@[\]^`{|}~"""
PUNCTUATION = re.compile("[" + re.escape(_PUNCTUATION_CHARS) + "]")
|
|
|
|
# Names exported by ``from <module> import *``.
__all__ = ['SimilaritySearchIndex', 'SearchResults', 'BM25Search', 'RegexSearch', 'preprocess']
|
|
|
|
|
|
class SimilaritySearchIndex:
    """Fuzzy lookup of a query string against one string column of a DataFrame.

    Rows are scored with the Jaro similarity from the optional ``textdistance``
    package, which is imported only when a search runs.
    """

    def __init__(self,
                 data: pd.DataFrame,
                 search_column: str):
        self.data: pd.DataFrame = data
        self._search_column = search_column

    def similar(self,
                query: str,
                threshold=0.6,
                topn=20):
        """Return the rows whose search column is most similar to ``query``.

        Args:
            query: Text to match. It is lower-cased before scoring; the column
                values are compared as-is.
            threshold: Rows whose rounded score is not strictly greater than this
                value are discarded.
            topn: Maximum number of rows returned.

        Returns:
            The ``topn`` best-scoring rows, without the search column, plus a
            ``match`` column holding the Jaro score rounded to two decimals.
            Within those rows, values that start with the query's first
            character are listed first, then by descending score.
        """
        import textdistance

        query = query.lower()
        scores = self.data[self._search_column].apply(textdistance.jaro, s2=query).round(2)
        scored = self.data.assign(match=scores)
        candidates = scored[scored['match'] > threshold]
        # .assign returns a new frame instead of writing into a view of `scored`
        # (avoids SettingWithCopyWarning). query[:1] is '' for an empty query,
        # where query[0] would raise IndexError.
        candidates = candidates.assign(
            matches_start=candidates[self._search_column].str.startswith(query[:1])
        )
        ranked = (candidates.sort_values(['match'], ascending=[False])
                  .head(topn)
                  .sort_values(['matches_start', 'match'], ascending=[False, False]))
        cols = [col for col in ranked if col not in [self._search_column, 'matches_start']]
        return ranked[cols]

    def __repr__(self):
        return f"SimilaritySearchIndex(search_column='{self._search_column}')"
|
|
|
|
|
|
# A tokenized corpus: one list of string tokens per document.
Corpus = List[List[str]]
|
|
|
|
|
|
def tokenize(text):
    """Split ``text`` on runs of whitespace; leading/trailing whitespace yields no empty tokens."""
    words = text.split(None)
    return words
|
|
|
|
|
|
def lowercase_filter(tokens):
    """Return a new list containing each token lower-cased, in the original order."""
    return list(map(lambda tok: tok.lower(), tokens))
|
|
|
|
|
|
def punctuation_filter(tokens):
    """Remove every character matched by ``PUNCTUATION`` from each token.

    Tokens made only of punctuation come back as empty strings; they are kept
    in the list, not removed.
    """
    strip = PUNCTUATION.sub
    return [strip('', tok) for tok in tokens]
|
|
|
|
|
|
# Common English words dropped by ``stopword_filter`` (compared against lower-cased tokens).
STOPWORDS = set(
    "the be to of and a in that have i it for not on with he "
    "as you do at this but his by from".split()
)
|
|
|
|
|
|
def stopword_filter(tokens):
    """Return the tokens that are not in ``STOPWORDS``, preserving their order."""
    return list(filter(lambda tok: tok not in STOPWORDS, tokens))
|
|
|
|
|
|
# 'item', whitespace, then a number such as '4' or '5.02', then an optional period.
_ITEM_REFERENCE = re.compile(r"item\s+(\d+\.\d+|\d+)\.?", flags=re.IGNORECASE)


def convert_items_to_tokens(text: str):
    """Join item references into single tokens, e.g. 'Item 4.' -> 'item_4', 'ITEM 5.02' -> 'item_5.02'.

    The match is case-insensitive and the output always uses lower-case 'item_'.
    One trailing period after the number is consumed. Joining keeps the number
    attached to the word 'item', so it survives into the final token text.
    """
    return _ITEM_REFERENCE.sub(r"item_\1", text)
|
|
|
|
|
|
_NUMBER_TOKEN = re.compile(r"(\d+[\d,.]*)%?|([,.]\d+)%?")
_DIGIT = re.compile(r'\d')


def numeric_shape(tokens: List[str]) -> List[str]:
    """Replace each digit of purely numeric tokens with 'x', e.g. '12.5%' -> 'xx.x%'.

    A token counts as numeric when the whole token is digits, optionally mixed
    with ',' and '.' and optionally followed by '%', or starts with ',' or '.'
    followed by digits (e.g. '.75'). All other tokens are returned unchanged.
    """
    return [_DIGIT.sub('x', tok) if _NUMBER_TOKEN.fullmatch(tok) else tok
            for tok in tokens]
|
|
|
|
|
|
def return_spaces_to_items(tokens: List[str]) -> List[str]:
    """Split joined item tokens back into two tokens: 'item_4' -> 'item', '4'.

    Only tokens that are exactly 'item_' followed by a number with an optional
    decimal part (e.g. 'item_5.02') are split; every other token is copied
    through unchanged.
    """
    pattern = r"item_(\d+(\.\d+)?)"
    result = []
    for token in tokens:
        found = re.fullmatch(pattern, token)
        if found is None:
            result.append(token)
        else:
            result.extend(("item", found.group(1)))
    return result
|
|
|
|
|
|
def preprocess(text: str):
    """Normalize ``text`` into the token list used for BM25 documents and queries.

    Steps, in order:
    1. lower-case the text;
    2. join item references ('item 4' -> 'item_4');
    3. split on whitespace;
    4. strip punctuation and drop tokens that became empty;
    5. drop stop words;
    6. replace digits in numeric tokens with 'x';
    7. split item tokens back into 'item', '<number>'.
    """
    text = text.lower()
    text = convert_items_to_tokens(text)
    tokens = tokenize(text)
    # Punctuation-only tokens (e.g. '-', '&') become '' here. Dropping them
    # prevents the empty string from acting as a search term.
    tokens = [token for token in punctuation_filter(tokens) if token]
    tokens = stopword_filter(tokens)
    tokens = numeric_shape(tokens)
    tokens = return_spaces_to_items(tokens)
    return tokens
|
|
|
|
|
|
def preprocess_documents(documents: List[str]) -> Corpus:
    """Tokenize every document with ``preprocess``: one token list per document, in input order."""
    return list(map(preprocess, documents))
|
|
|
|
|
|
# An (int, str) pair — presumably a document's position and its text.
# NOTE(review): not referenced by the code visible here; confirm it is still used.
LocAndDoc = Tuple[int, str]
|
|
|
|
|
|
class DocSection:
    """One matched piece of text, with its position and its relevance score."""

    def __init__(self,
                 loc: int,
                 doc: str,
                 score: float = 0.0
                 ):
        # Position of the section; the search classes in this module pass the
        # document's index in their document list.
        self.loc: int = loc
        self.doc: str = doc  # the section text
        self.score: float = score  # relevance score; 0.0 when the search assigns none

    # Make this class sortable by loc
    def __lt__(self, other):
        # Returning NotImplemented (rather than failing on `other.loc`) lets Python
        # try the reflected comparison and otherwise raise a proper TypeError.
        if not isinstance(other, DocSection):
            return NotImplemented
        return self.loc < other.loc

    def __hash__(self):
        # NOTE(review): the hash depends only on the text, but no __eq__ is defined,
        # so equality is still identity. Sections with the same text are NOT
        # deduplicated in sets/dicts. Confirm this is intended.
        return hash(self.doc)

    def json(self):
        """Return the section as a plain dict with 'loc', 'doc' and 'score' keys."""
        return {
            'loc': self.loc,
            'doc': self.doc,
            'score': self.score
        }

    def __repr__(self):
        return f"{self.loc}\n{self.doc}"
|
|
|
|
|
|
class SearchResults:
    """The sections matched by a query; indexable, iterable and renderable with rich."""

    def __init__(self,
                 query: str,
                 sections: "List[DocSection]",
                 tables: bool = True
                 ):
        self.query: str = query
        self.sections: List[DocSection] = sections
        # When True, sections whose text starts with "| |" are rendered via convert_table.
        self._show_tables = tables

    def __len__(self):
        return len(self.sections)

    def __iter__(self):
        # Explicit iteration. Without __iter__, Python falls back to calling
        # __getitem__(0), (1), ... until IndexError, which __getitem__ no longer
        # raises.
        return iter(self.sections)

    @property
    def empty(self):
        return len(self) == 0

    def __getitem__(self, item):
        """Return the section at ``item`` (index or slice); None if the index is out of range."""
        # The previous guard `0 > item >= len(...)` could never be true, so it
        # never returned None. It also failed on slices.
        try:
            return self.sections[item]
        except IndexError:
            return None

    def json(self):
        """Return the query, the sections (via their own ``json()``) and the tables flag as a dict."""
        return {
            'query': self.query,
            'sections': [section.json() for section in self.sections],
            'tables': self._show_tables
        }

    def __rich__(self):
        """Render the sections, highest score first, as numbered panels inside one panel."""
        renderables = []
        title = f"Searching for '{self.query}'"
        subtitle = f"{len(self)} result(s)" if not self.empty else "No results"
        sorted_sections = sorted(self.sections, key=lambda s: s.score, reverse=True)
        for i, doc_section in enumerate(sorted_sections):
            # Text starting with "| |" is treated as a table (presumably a markdown
            # table) and converted, unless table display is switched off.
            if doc_section.doc.startswith("| |") and self._show_tables:
                section = convert_table(doc_section.doc)
            else:
                section = Markdown(doc_section.doc + "\n\n---")
            renderables.append(Panel(section, box=box.ROUNDED, title=f"{i}"))
        return Panel(
            Group(*renderables), title=title, subtitle=subtitle, box=box.SIMPLE
        )

    def __repr__(self):
        return repr_rich(self.__rich__())
|
|
|
|
|
|
class BM25Search:
    """Rank documents against a free-text query with BM25 Okapi (``rank_bm25``)."""

    def __init__(self,
                 document_objs: List[str],
                 text_fn: Callable = None):
        """Tokenize the documents and build the BM25 index.

        Args:
            document_objs: The documents to search. They are returned as-is in
                the results, so without ``text_fn`` they must be strings.
            text_fn: Optional callable mapping each document object to the text
                that is indexed.
        """
        from rank_bm25 import BM25Okapi
        texts = (text_fn(doc) for doc in document_objs) if text_fn else document_objs
        self.corpus: Corpus = [BM25Search.preprocess(text) for text in texts]
        self.document_objs = document_objs
        # rank_bm25 divides by the number of documents while building its index,
        # so an empty corpus cannot be indexed. search() then returns no results.
        self.bm25: BM25Okapi = BM25Okapi(self.corpus) if self.corpus else None

    def __len__(self):
        return len(self.document_objs)

    @staticmethod
    def preprocess(text: str):
        """Tokenize ``text`` with the module-level ``preprocess``.

        Delegating (instead of duplicating the pipeline) guarantees that indexed
        documents and queries are always normalized the same way.
        """
        return preprocess(text)

    def search(self,
               query: str,
               tables: bool = True):
        """Score every document against ``query``.

        Returns:
            SearchResults with one DocSection per document whose score is above
            zero. ``loc`` is the document's index in ``document_objs``.
        """
        if self.bm25 is None:
            return SearchResults(query=query, sections=[], tables=tables)
        scores = self.bm25.get_scores(BM25Search.preprocess(query))
        sections = [DocSection(loc=loc, doc=doc, score=score)
                    for loc, (doc, score) in enumerate(zip(self.document_objs, scores, strict=False))
                    if score > 0]
        return SearchResults(query=query, sections=sections, tables=tables)
|
|
|
|
|
|
class RegexSearch:
    """Case-insensitive regular-expression search over a list of text documents."""

    def __init__(self,
                 documents: List[str]):
        self.document_objs = [RegexSearch.preprocess(document) for document in documents]

    def __len__(self):
        return len(self.document_objs)

    @staticmethod
    def preprocess(text: str):
        """Replace non-breaking spaces (U+00A0) with ordinary spaces.

        This lets patterns written with normal spaces match text that uses
        non-breaking spaces.
        """
        # NOTE(review): the original literal read as replace(" ", " "). It is written
        # here as an explicit escape on the assumption that it was meant to
        # normalize non-breaking spaces.
        return text.replace("\u00a0", " ")

    def search(self,
               query: str,
               tables: bool = True):
        """Return every document that ``query`` (a regex, matched case-insensitively) finds.

        ``loc`` in each DocSection is the document's index in ``document_objs``.
        Raises ``re.error`` if ``query`` is not a valid pattern.
        """
        pattern = re.compile(query, flags=re.IGNORECASE)  # compile once, not per document
        return SearchResults(
            query=query,
            sections=[DocSection(loc=loc, doc=doc)
                      for loc, doc in enumerate(self.document_objs)
                      if pattern.search(doc)],
            tables=tables
        )
|