Files
2025-12-09 12:13:01 +01:00

641 lines
23 KiB
Python

"""
XBRL Statement Stitching - Query Functionality
This module provides query functionality for stitched XBRL facts, allowing
users to query standardized, multi-period financial data.
"""
import logging
import re
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import pandas as pd
from rich import box
from rich.console import Group
from rich.markdown import Markdown
from rich.panel import Panel
from rich.table import Table
from rich.text import Text

from edgar.richtools import repr_rich
from edgar.xbrl.facts import FactQuery

if TYPE_CHECKING:
    from edgar.xbrl.stitching.xbrls import XBRLS
class StitchedFactsView:
    """
    A view over stitched facts from multiple XBRL filings.

    This class extracts facts from stitched statements rather than raw XBRL facts,
    ensuring that queries operate on standardized, post-processed data.

    The result of the most recent ``get_facts`` call is cached per argument
    combination. The cached list is handed out directly, so callers should
    treat it (and the fact dictionaries inside it) as read-only.
    """

    # Statement types that are stitched when the caller does not name any.
    _DEFAULT_STATEMENT_TYPES = (
        'IncomeStatement', 'BalanceSheet', 'CashFlowStatement',
        'StatementOfEquity', 'ComprehensiveIncome',
    )

    def __init__(self, xbrls: 'XBRLS'):
        self.xbrls = xbrls
        # None means "nothing cached yet"; an empty list is a valid cached result.
        self._facts_cache: Optional[List[Dict[str, Any]]] = None
        self._last_cache_key = None

    def __len__(self):
        return len(self.get_facts())

    @property
    def entity_name(self):
        """Get the entity name from the first XBRL filing in the list."""
        if self.xbrls.xbrl_list:
            return getattr(self.xbrls.xbrl_list[0], 'entity_name', 'Unknown Entity')
        return 'Unknown Entity'

    @property
    def document_type(self):
        """Get document type from entity info."""
        return self.xbrls.entity_info.get('document_type', 'Multi-Period Stitched')

    def get_facts(self,
                  max_periods: int = 8,
                  standard: bool = True,
                  statement_types: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """
        Extract facts from stitched statements.

        Args:
            max_periods: Maximum periods to include
            standard: Whether to use standardized labels
            statement_types: List of statement types to include

        Returns:
            List of fact dictionaries with stitched/standardized data
        """
        cache_key = (max_periods, standard, tuple(statement_types or []))
        # Compare against None rather than relying on truthiness: an empty
        # result is still a result and must not trigger a full re-stitch.
        if self._facts_cache is not None and self._last_cache_key == cache_key:
            return self._facts_cache

        statement_types = statement_types or list(self._DEFAULT_STATEMENT_TYPES)

        all_facts: List[Dict[str, Any]] = []
        for statement_type in statement_types:
            try:
                # Get stitched statement data (this applies standardization)
                stitched_data = self.xbrls.get_statement(
                    statement_type=statement_type,
                    max_periods=max_periods,
                    standard=standard
                )
                facts = self._extract_facts_from_stitched_data(
                    stitched_data, statement_type
                )
            except Exception as exc:
                # Best effort: a statement that cannot be stitched is skipped,
                # but the reason is recorded so failures are diagnosable.
                logging.getLogger(__name__).debug(
                    "Skipping statement type %s: %s", statement_type, exc
                )
                continue
            all_facts.extend(facts)

        self._facts_cache = all_facts
        self._last_cache_key = cache_key
        return all_facts

    def _extract_facts_from_stitched_data(self,
                                          stitched_data: Dict[str, Any],
                                          statement_type: str) -> List[Dict[str, Any]]:
        """
        Convert stitched statement data back to fact-like records for querying.

        Args:
            stitched_data: Output from StatementStitcher
            statement_type: Type of statement

        Returns:
            List of fact dictionaries
        """
        facts: List[Dict[str, Any]] = []
        periods = stitched_data.get('periods', [])
        filing_count = len(self.xbrls.xbrl_list)
        # Period metadata depends only on the period id, so parse each id once
        # instead of once per line item.
        period_info_by_id: Dict[str, Dict[str, Any]] = {}

        for item in stitched_data.get('statement_data', []):
            # Skip abstract items without values
            if item.get('is_abstract', False) and not item.get('has_values', False):
                continue

            concept = item.get('concept', '')
            label = item.get('label', '')
            original_label = item.get('original_label', label)
            # Tolerate an explicit None for either mapping.
            decimals = item.get('decimals') or {}
            values = item.get('values') or {}

            # Create a fact record for each period with data
            for period_id, value in values.items():
                if value is None:
                    continue

                period_info = period_info_by_id.get(period_id)
                if period_info is None:
                    period_info = self._get_period_info(period_id, periods)
                    period_info_by_id[period_id] = period_info

                fact = {
                    # Core identification
                    'concept': concept,
                    'label': label,  # Standardized label
                    'original_label': original_label,  # Original company label
                    'statement_type': statement_type,
                    # Value information
                    'value': value,
                    'numeric_value': self._convert_to_numeric(value),
                    'decimals': decimals.get(period_id, 0),
                    # Period information
                    'period_key': period_id,
                    'period_type': period_info.get('period_type', 'duration'),
                    'period_start': period_info.get('period_start'),
                    'period_end': period_info.get('period_end'),
                    'period_instant': period_info.get('period_instant'),
                    'period_label': period_info.get('period_label', ''),
                    # Statement context
                    'level': item.get('level', 0),
                    'is_abstract': item.get('is_abstract', False),
                    'is_total': item.get('is_total', False),
                    # Multi-filing context
                    'filing_count': filing_count,
                    'standardized': True,  # Mark as coming from standardized data
                    # Source attribution (which XBRL filing this came from)
                    'source_filing_index': self._determine_source_filing(period_id),
                }
                # Add fiscal period info if available
                fact.update(self._extract_fiscal_info(period_id))
                facts.append(fact)

        return facts

    def _get_period_info(self, period_id: str, periods: List[tuple]) -> Dict[str, Any]:
        """
        Extract period metadata from period_id and periods list.

        Args:
            period_id: Period key such as ``instant_<date>`` or
                ``duration_<start>_<end>``
            periods: ``(period_id, label)`` pairs from the stitched statement

        Returns:
            Dictionary with whichever of ``period_label``, ``period_type``,
            ``period_start``, ``period_end`` and ``period_instant`` could be
            determined
        """
        period_info: Dict[str, Any] = {}

        # Find matching period label
        for pid, label in periods:
            if pid == period_id:
                period_info['period_label'] = label
                break

        # Parse period_id to extract dates and type. The prefix is sliced off
        # rather than removed with str.replace, which would also strip any
        # later occurrence of the same text.
        instant_prefix = 'instant_'
        duration_prefix = 'duration_'
        if period_id.startswith(instant_prefix):
            date_str = period_id[len(instant_prefix):]
            period_info['period_type'] = 'instant'
            period_info['period_instant'] = date_str
            period_info['period_end'] = date_str
        elif period_id.startswith(duration_prefix):
            period_info['period_type'] = 'duration'
            parts = period_id[len(duration_prefix):].split('_')
            if len(parts) >= 2:
                period_info['period_start'] = parts[0]
                period_info['period_end'] = parts[1]

        return period_info

    def _convert_to_numeric(self, value: Any) -> Optional[float]:
        """
        Convert value to numeric if possible.

        Strings may contain thousands separators and a dollar sign; a value
        wrapped in parentheses is read as a negative number (accounting
        notation).

        Args:
            value: Raw fact value

        Returns:
            The value as a float, or None when it cannot be converted
        """
        if value is None:
            return None
        try:
            if isinstance(value, (int, float)):
                return float(value)
            if isinstance(value, str):
                cleaned = value.replace(',', '').replace('$', '').strip()
                # Accounting notation: "(1,234)" means -1234
                negative = len(cleaned) > 2 and cleaned.startswith('(') and cleaned.endswith(')')
                if negative:
                    cleaned = cleaned[1:-1].strip()
                number = float(cleaned)
                return -number if negative else number
        except (ValueError, TypeError):
            pass
        return None

    def _determine_source_filing(self, period_id: str) -> Optional[int]:
        """Determine which filing this period came from."""
        # This would require enhanced tracking in the stitching process.
        # Until then every fact reports None.
        return None

    def _extract_fiscal_info(self, period_id: str) -> Dict[str, Any]:
        """
        Extract fiscal period and year information.

        This is a simplified approach: the values come from the entity info of
        the first XBRL in the list and do not currently vary with ``period_id``.
        """
        fiscal_info: Dict[str, Any] = {}
        if self.xbrls.xbrl_list:
            entity_info = self.xbrls.xbrl_list[0].entity_info
            if entity_info:
                fiscal_info['fiscal_period'] = entity_info.get('fiscal_period')
                fiscal_info['fiscal_year'] = entity_info.get('fiscal_year')
        return fiscal_info

    def query(self, **kwargs) -> 'StitchedFactQuery':
        """Create a new query for stitched facts."""
        return StitchedFactQuery(self, **kwargs)
class StitchedFactQuery(FactQuery):
    """
    Enhanced fact query for stitched/standardized multi-filing data.

    Extends the base FactQuery with capabilities specific to multi-period,
    standardized financial data. Builder methods return ``self`` so calls can
    be chained; nothing is evaluated until ``execute`` (or one of the
    DataFrame helpers) is called.
    """

    def __init__(self, stitched_facts_view: StitchedFactsView, **kwargs):
        """
        Create a query over a stitched facts view.

        Args:
            stitched_facts_view: View that supplies the stitched facts
            **kwargs: Optional ``max_periods``, ``standard`` and
                ``statement_types`` forwarded to the view's ``get_facts``
        """
        # Initialize with stitched facts view instead of regular facts view
        self._stitched_facts_view = stitched_facts_view

        # Initialize base FactQuery attributes manually since we're not calling super().__init__
        self._facts_view = stitched_facts_view  # For compatibility with base class
        self._filters = []
        self._transformations = []
        self._aggregations = []
        self._include_dimensions = True
        self._include_contexts = True
        self._include_element_info = True
        self._sort_by = None
        self._sort_ascending = True
        self._limit = None
        self._statement_type = None

        # Multi-filing specific options
        self._cross_period_only = False
        self._min_periods = 2  # Only used when _cross_period_only is set
        self._trend_analysis = False
        self._require_all_periods = False

        # Store query-specific parameters for get_facts
        self._max_periods = kwargs.get('max_periods', 8)
        self._standard = kwargs.get('standard', True)
        self._statement_types = kwargs.get('statement_types', None)

    def __str__(self):
        return f"StitchedFactQuery(filters={len(self._filters)})"

    # Enhanced filtering methods for multi-filing scenarios

    def by_standardized_concept(self, concept_name: str) -> 'StitchedFactQuery':
        """
        Filter by standardized concept name (e.g., 'Revenue', 'Net Income').

        A fact matches when its label equals the name, or when the name occurs
        (case-insensitively) in either the label or the concept.

        Args:
            concept_name: Standardized concept name

        Returns:
            Self for method chaining
        """
        needle = concept_name.lower()

        def matches(f):
            # `or ''` guards against facts whose label/concept is None
            label = f.get('label') or ''
            concept = f.get('concept') or ''
            return (label == concept_name
                    or needle in label.lower()
                    or needle in concept.lower())

        self._filters.append(matches)
        return self

    def by_original_label(self, pattern: str, exact: bool = False) -> 'StitchedFactQuery':
        """
        Filter by original company-specific labels before standardization.

        Args:
            pattern: Pattern to match against original labels; a
                case-insensitive regular expression unless ``exact`` is set
            exact: Whether to require exact match

        Returns:
            Self for method chaining
        """
        if exact:
            self._filters.append(lambda f: f.get('original_label') == pattern)
        else:
            regex = re.compile(pattern, re.IGNORECASE)
            self._filters.append(
                lambda f: f.get('original_label') and regex.search(f['original_label'])
            )
        return self

    def across_periods(self, min_periods: int = 2) -> 'StitchedFactQuery':
        """
        Filter to concepts that appear across multiple periods.

        Args:
            min_periods: Minimum number of periods the concept must appear in

        Returns:
            Self for method chaining
        """
        self._cross_period_only = True
        self._min_periods = min_periods
        return self

    def by_fiscal_period(self, fiscal_period: str) -> 'StitchedFactQuery':
        """
        Filter by fiscal period (FY, Q1, Q2, Q3, Q4).

        Args:
            fiscal_period: Fiscal period identifier

        Returns:
            Self for method chaining
        """
        self._filters.append(
            lambda f: f.get('fiscal_period') == fiscal_period
        )
        return self

    def by_filing_index(self, filing_index: int) -> 'StitchedFactQuery':
        """
        Filter facts by which filing they originated from.

        Args:
            filing_index: Index of the filing (0 = most recent)

        Returns:
            Self for method chaining
        """
        self._filters.append(
            lambda f: f.get('source_filing_index') == filing_index
        )
        return self

    def trend_analysis(self, concept: str) -> 'StitchedFactQuery':
        """
        Set up for trend analysis of a specific concept across periods.

        Args:
            concept: Concept to analyze trends for

        Returns:
            Self for method chaining
        """
        self._trend_analysis = True
        self.by_standardized_concept(concept)
        return self

    def complete_periods_only(self) -> 'StitchedFactQuery':
        """
        Only return concepts that have values in all available periods.

        Returns:
            Self for method chaining
        """
        self._require_all_periods = True
        return self

    def execute(self) -> List[Dict[str, Any]]:
        """
        Execute the query with enhanced multi-period processing.

        Steps run in this order: filters, transformations, aggregations,
        cross-period filter, complete-period filter, trend ordering, sorting,
        limit. The facts held by the view are never modified.

        Returns:
            List of fact dictionaries
        """
        # Copy the list so later in-place operations never reorder the
        # view's cached list.
        results = list(self._stitched_facts_view.get_facts(
            max_periods=self._max_periods,
            standard=self._standard,
            statement_types=self._statement_types
        ))

        # Apply standard filters
        for filter_func in self._filters:
            results = [f for f in results if filter_func(f)]

        # Apply transformations
        if self._transformations:
            # The fact dicts are shared with the view's cache; transform copies
            # so that repeated executions do not compound the transformation.
            results = [dict(f) for f in results]
            for transform_fn in self._transformations:
                for fact in results:
                    if fact.get('value') is not None:
                        fact['value'] = transform_fn(fact['value'])

        # Apply aggregations
        if self._aggregations:
            aggregated_results = {}
            for agg in self._aggregations:
                dimension = agg['dimension']
                func = agg['function']

                # Group fact values by dimension member
                groups = defaultdict(list)
                for fact in results:
                    dim_value = fact.get(f'dim_{dimension}')
                    if dim_value and fact.get('value') is not None:
                        groups[dim_value].append(fact['value'])

                # Apply aggregation function
                for dim_value, values in groups.items():
                    if func == 'sum':
                        agg_value = sum(values)
                    elif func == 'average':
                        agg_value = sum(values) / len(values)
                    else:
                        agg_value = 0.0  # Unknown functions fall back to 0.0
                    entry = aggregated_results.setdefault(
                        (dimension, dim_value),
                        {'dimension': dimension, 'value': dim_value, 'values': {}}
                    )
                    entry['values'][func] = agg_value
            results = list(aggregated_results.values())

        # Apply cross-period filtering if requested
        if self._cross_period_only:
            results = self._filter_cross_period_concepts(results)

        # Apply complete periods filtering if requested
        if self._require_all_periods:
            results = self._filter_complete_periods(results)

        # Apply trend analysis if requested
        if self._trend_analysis:
            results = self._prepare_trend_data(results)

        # Apply sorting if specified
        if results and self._sort_by and self._sort_by in results[0]:
            sort_key = self._sort_by
            # Facts with no value for the sort key (e.g. period_start on
            # instant facts) cannot be compared with real values, so they are
            # kept in their existing order after the sorted ones.
            present = [f for f in results if f.get(sort_key) is not None]
            missing = [f for f in results if f.get(sort_key) is None]
            present.sort(key=lambda f: f[sort_key],
                         reverse=not self._sort_ascending)
            results = present + missing

        # Apply limit if specified
        if self._limit is not None:
            results = results[:self._limit]

        return results

    def _filter_cross_period_concepts(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Filter to concepts that appear in at least ``_min_periods`` periods."""
        concept_periods = defaultdict(set)
        for fact in results:
            concept_key = (fact.get('concept', ''), fact.get('label', ''))
            concept_periods[concept_key].add(fact.get('period_key', ''))

        # Keep concepts with the minimum period count
        valid_concepts = {
            concept for concept, periods in concept_periods.items()
            if len(periods) >= self._min_periods
        }
        return [
            fact for fact in results
            if (fact.get('concept', ''), fact.get('label', '')) in valid_concepts
        ]

    def _filter_complete_periods(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Filter to concepts that have values in every period present in ``results``."""
        all_periods = {fact.get('period_key', '') for fact in results}

        concept_periods = defaultdict(set)
        for fact in results:
            concept_key = (fact.get('concept', ''), fact.get('label', ''))
            concept_periods[concept_key].add(fact.get('period_key', ''))

        # Keep concepts with complete period coverage
        complete_concepts = {
            concept for concept, periods in concept_periods.items()
            if periods == all_periods
        }
        return [
            fact for fact in results
            if (fact.get('concept', ''), fact.get('label', '')) in complete_concepts
        ]

    def _prepare_trend_data(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Prepare data for trend analysis by sorting on period end date."""
        return sorted(results, key=lambda f: f.get('period_end', ''))

    def to_trend_dataframe(self) -> pd.DataFrame:
        """
        Create a DataFrame optimized for trend analysis.

        Returns:
            DataFrame with concepts as rows and periods as columns. When the
            results lack any column needed for the pivot, the un-pivoted
            DataFrame is returned instead.
        """
        results = self.execute()
        if not results:
            return pd.DataFrame()

        df = pd.DataFrame(results)

        # Every column used by the pivot must exist, including 'label'
        required = {'label', 'concept', 'period_end', 'numeric_value'}
        if required.issubset(df.columns):
            return df.pivot_table(
                index=['label', 'concept'],
                columns='period_end',
                values='numeric_value',
                aggfunc='first'
            )
        return df

    def to_dataframe(self, *columns) -> pd.DataFrame:
        """
        Execute the query and return results as a DataFrame.

        Args:
            columns: Names of the columns to include in the DataFrame; all
                non-empty columns are included when none are given

        Returns:
            pandas DataFrame with query results

        Raises:
            KeyError: If a requested column does not exist in the results
        """
        results = self.execute()
        if not results:
            return pd.DataFrame()

        df = pd.DataFrame(results)
        df['value'] = df['value'].astype(str)  # Ensure value is string for display

        # Filter columns based on inclusion flags
        if not self._include_dimensions:
            df = df.loc[:, [col for col in df.columns if not col.startswith('dim_')]]
        if not self._include_contexts:
            context_cols = ['context_ref', 'entity_identifier', 'entity_scheme',
                            'period_type']
            df = df.loc[:, [col for col in df.columns if col not in context_cols]]
        if not self._include_element_info:
            element_cols = ['element_id', 'element_name', 'element_type', 'element_period_type',
                            'element_balance', 'element_label']
            df = df.loc[:, [col for col in df.columns if col not in element_cols]]

        if columns:
            # Select before dropping empty columns: an explicitly requested
            # column must survive even when it holds no values (e.g.
            # period_start for instant-only facts).
            df = df[list(columns)]
        else:
            # Drop empty columns
            df = df.dropna(axis=1, how='all')

        # Skip these columns
        skip_columns = ['fact_key', 'period_key']

        # Order columns
        first_columns = [col for col in
                         ['concept', 'label', 'original_label', 'value', 'numeric_value',
                          'period_start', 'period_end', 'decimals', 'statement_type', 'fiscal_period']
                         if col in df.columns]
        ordered_columns = first_columns + [col for col in df.columns
                                           if col not in first_columns
                                           and col not in skip_columns]
        return df[ordered_columns]

    def __rich__(self):
        """Build a rich panel with usage help and a preview of the first 10 result rows."""
        title = Text.assemble(("Stitched Facts Query"),
                              )
        subtitle = Text.assemble((self._stitched_facts_view.entity_name, "bold deep_sky_blue1"),
                                 " - ",
                                 (self._stitched_facts_view.document_type)
                                 )
        df = self.to_dataframe().fillna('')
        columns = df.columns.tolist()

        description = Markdown(
            f"""
Use *to_dataframe(columns)* to get a DataFrame of the results.
e.g. `query.to_dataframe('concept', 'value', 'period_end')`
Available columns:
'{', '.join(columns)}'
**Enhanced Multi-Period Methods:**
- `across_periods(min_periods=2)` - Filter to concepts across multiple periods
- `by_standardized_concept('Revenue')` - Filter by standardized labels
- `by_original_label('Net sales')` - Filter by original company labels
- `trend_analysis('Revenue')` - Set up trend analysis
- `to_trend_dataframe()` - Get trend-optimized DataFrame
"""
        )

        display_columns = [col for col in ['label', 'concept', 'value', 'period_start', 'period_end', 'statement_type']
                           if col in columns]

        if not df.empty:
            df_display = df[display_columns].head(10)  # Show first 10 rows
            table = Table(*display_columns, show_header=True, header_style="bold", box=box.SIMPLE)
            for row_values in df_display.itertuples(index=False):
                # Truncate long values so the table stays readable
                table.add_row(*[str(cell)[:50] for cell in row_values])
        else:
            table = Table("No results found", box=box.SIMPLE)

        panel = Panel(Group(description, table), title=title, subtitle=subtitle, box=box.ROUNDED)
        return panel

    def __repr__(self):
        return repr_rich(self.__rich__())