Files
2025-12-09 12:13:01 +01:00

117 lines
3.5 KiB
Python

"""
Search functionality for SEC entities.
This module provides functions and classes for searching for SEC entities.
"""
from functools import lru_cache
from typing import Any, Dict, List
import pandas as pd
from rich import box
from rich.table import Column, Table
from edgar.entity import Company
from edgar.entity.tickers import get_company_tickers
from edgar.richtools import repr_rich
from edgar.search.datasearch import FastSearch, company_ticker_preprocess, company_ticker_score
__all__ = [
'find_company',
'CompanySearchResults',
'CompanySearchIndex'
]
class CompanySearchResults:
"""
Results from a company search.
"""
def __init__(self, query: str,
search_results: List[Dict[str, Any]]):
self.query: str = query
self.results: pd.DataFrame = pd.DataFrame(search_results, columns=['cik', 'ticker', 'company', 'score'])
@property
def tickers(self):
return self.results.ticker.tolist()
@property
def ciks(self):
return self.results.cik.tolist()
@property
def empty(self):
return self.results.empty
def __len__(self):
return len(self.results)
def __getitem__(self, item):
if 0 <= item < len(self):
row = self.results.iloc[item]
cik: int = int(row.cik)
return Company(cik)
def __rich__(self):
table = Table(Column(""),
Column("Ticker", justify="left"),
Column("Name", justify="left"),
Column("Score", justify="left"),
title=f"Search results for '{self.query}'",
box=box.SIMPLE)
for index, row in enumerate(self.results.itertuples()):
table.add_row(str(index), row.ticker.rjust(6), row.company, f"{int(row.score)}%")
return table
def __repr__(self):
return repr_rich(self.__rich__())
class CompanySearchIndex(FastSearch):
"""
Search index for companies.
"""
def __init__(self):
data = get_company_tickers(as_dataframe=False)
super().__init__(data, ['company', 'ticker'],
preprocess_func=company_ticker_preprocess,
score_func=company_ticker_score)
def search(self, query: str, top_n: int = 10, threshold: float = 60) -> CompanySearchResults:
results = super().search(query, top_n, threshold)
return CompanySearchResults(query=query, search_results=results)
def __len__(self):
return len(self.data)
def __hash__(self):
# Combine column names and last 10 values in the 'company' column to create a hash
column_names = tuple(self.data[0].keys())
last_10_companies = tuple(entry['company'] for entry in self.data[-10:])
return hash((column_names, last_10_companies))
def __eq__(self, other):
if not isinstance(other, CompanySearchIndex):
return False
return (self.data[-10:], tuple(self.data[0].keys())) == (other.data[-10:], tuple(other.data[0].keys()))
@lru_cache(maxsize=1)
def _get_company_search_index():
"""Get the company search index."""
return CompanySearchIndex()
@lru_cache(maxsize=16)
def find_company(company: str, top_n: int = 10):
"""
Find a company by name.
Args:
company: The company name or ticker to search for
top_n: The maximum number of results to return
Returns:
CompanySearchResults: The search results
"""
return _get_company_search_index().search(company, top_n=top_n)