117 lines
3.5 KiB
Python
117 lines
3.5 KiB
Python
"""
|
|
Search functionality for SEC entities.
|
|
This module provides functions and classes for searching for SEC entities.
|
|
"""
|
|
from functools import lru_cache
|
|
from typing import Any, Dict, List
|
|
|
|
import pandas as pd
|
|
from rich import box
|
|
from rich.table import Column, Table
|
|
|
|
from edgar.entity import Company
|
|
from edgar.entity.tickers import get_company_tickers
|
|
from edgar.richtools import repr_rich
|
|
from edgar.search.datasearch import FastSearch, company_ticker_preprocess, company_ticker_score
|
|
|
|
__all__ = [
|
|
'find_company',
|
|
'CompanySearchResults',
|
|
'CompanySearchIndex'
|
|
]
|
|
|
|
|
|
class CompanySearchResults:
|
|
"""
|
|
Results from a company search.
|
|
"""
|
|
def __init__(self, query: str,
|
|
search_results: List[Dict[str, Any]]):
|
|
self.query: str = query
|
|
self.results: pd.DataFrame = pd.DataFrame(search_results, columns=['cik', 'ticker', 'company', 'score'])
|
|
|
|
@property
|
|
def tickers(self):
|
|
return self.results.ticker.tolist()
|
|
|
|
@property
|
|
def ciks(self):
|
|
return self.results.cik.tolist()
|
|
|
|
@property
|
|
def empty(self):
|
|
return self.results.empty
|
|
|
|
def __len__(self):
|
|
return len(self.results)
|
|
|
|
def __getitem__(self, item):
|
|
if 0 <= item < len(self):
|
|
row = self.results.iloc[item]
|
|
cik: int = int(row.cik)
|
|
return Company(cik)
|
|
|
|
def __rich__(self):
|
|
table = Table(Column(""),
|
|
Column("Ticker", justify="left"),
|
|
Column("Name", justify="left"),
|
|
Column("Score", justify="left"),
|
|
title=f"Search results for '{self.query}'",
|
|
box=box.SIMPLE)
|
|
for index, row in enumerate(self.results.itertuples()):
|
|
table.add_row(str(index), row.ticker.rjust(6), row.company, f"{int(row.score)}%")
|
|
return table
|
|
|
|
def __repr__(self):
|
|
return repr_rich(self.__rich__())
|
|
|
|
|
|
class CompanySearchIndex(FastSearch):
|
|
"""
|
|
Search index for companies.
|
|
"""
|
|
def __init__(self):
|
|
data = get_company_tickers(as_dataframe=False)
|
|
super().__init__(data, ['company', 'ticker'],
|
|
preprocess_func=company_ticker_preprocess,
|
|
score_func=company_ticker_score)
|
|
|
|
def search(self, query: str, top_n: int = 10, threshold: float = 60) -> CompanySearchResults:
|
|
results = super().search(query, top_n, threshold)
|
|
return CompanySearchResults(query=query, search_results=results)
|
|
|
|
def __len__(self):
|
|
return len(self.data)
|
|
|
|
def __hash__(self):
|
|
# Combine column names and last 10 values in the 'company' column to create a hash
|
|
column_names = tuple(self.data[0].keys())
|
|
last_10_companies = tuple(entry['company'] for entry in self.data[-10:])
|
|
return hash((column_names, last_10_companies))
|
|
|
|
def __eq__(self, other):
|
|
if not isinstance(other, CompanySearchIndex):
|
|
return False
|
|
return (self.data[-10:], tuple(self.data[0].keys())) == (other.data[-10:], tuple(other.data[0].keys()))
|
|
|
|
|
|
@lru_cache(maxsize=1)
|
|
def _get_company_search_index():
|
|
"""Get the company search index."""
|
|
return CompanySearchIndex()
|
|
|
|
|
|
@lru_cache(maxsize=16)
|
|
def find_company(company: str, top_n: int = 10):
|
|
"""
|
|
Find a company by name.
|
|
|
|
Args:
|
|
company: The company name or ticker to search for
|
|
top_n: The maximum number of results to return
|
|
|
|
Returns:
|
|
CompanySearchResults: The search results
|
|
"""
|
|
return _get_company_search_index().search(company, top_n=top_n)
|