Compare commits

...

6 Commits

Author SHA1 Message Date
agentson
033d5fcadd Merge main into feature/issue-22-data-driven
Some checks failed
CI / test (pull_request) Has been cancelled
2026-02-04 18:41:44 +09:00
128324427f Merge pull request 'feat: implement Token Efficiency - Context optimization (issue #24)' (#28) from feature/issue-24-token-efficiency into main
Some checks failed
CI / test (push) Has been cancelled
Reviewed-on: #28
2026-02-04 18:39:20 +09:00
agentson
61f5aaf4a3 fix: resolve linting issues in token efficiency implementation
Some checks failed
CI / test (pull_request) Has been cancelled
- Fix ambiguous variable names (l → layer)
- Remove unused imports and variables
- Organize import statements

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-02-04 18:35:55 +09:00
agentson
4f61d5af8e feat: implement token efficiency optimization for issue #24
Implement comprehensive token efficiency system to reduce LLM costs:

- Add prompt_optimizer.py: Token counting, compression, abbreviations
- Add context_selector.py: Smart L1-L7 context layer selection
- Add summarizer.py: Historical data aggregation and summarization
- Add cache.py: TTL-based response caching with hit rate tracking
- Enhance gemini_client.py: Integrate optimization, caching, metrics

Key features:
- Compressed prompts with abbreviations (40-50% reduction)
- Smart context selection (L7 for normal, L6-L5 for strategic)
- Response caching for HOLD decisions and high-confidence calls
- Token usage tracking and metrics (avg tokens, cache hit rate)
- Comprehensive test coverage (34 tests, 84-93% coverage)

Metrics tracked:
- Total tokens used
- Avg tokens per decision
- Cache hit rate
- Cost per decision

All tests passing (191 total, 76% overall coverage).

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-02-04 18:09:51 +09:00
agentson
62fd4ff5e1 feat: implement data-driven external data integration (issue #22)
Add objective external data sources to enhance trading decisions beyond
market prices and user input.

## New Modules

### src/data/news_api.py
- News sentiment analysis with Alpha Vantage and NewsAPI support
- Sentiment scoring (-1.0 to +1.0) per article and aggregated
- 5-minute caching to minimize API quota usage
- Graceful degradation when APIs unavailable

### src/data/economic_calendar.py
- Track major economic events (FOMC, GDP, CPI)
- Earnings calendar per stock
- Event proximity checking for high-volatility periods
- Hardcoded major events for 2026 (no API required)

### src/data/market_data.py
- Market sentiment indicators (Fear & Greed equivalent)
- Market breadth (advance/decline ratios)
- Sector performance tracking
- Fear/Greed score calculation

## Integration

Enhanced GeminiClient to seamlessly integrate external data:
- Optional news_api, economic_calendar, and market_data parameters
- Async build_prompt() includes external context when available
- Backward-compatible build_prompt_sync() for existing code
- Graceful fallback when external data unavailable

External data automatically added to AI prompts:
- News sentiment with top articles
- Upcoming high-impact economic events
- Market sentiment and breadth indicators

## Configuration

Added optional settings to config.py:
- NEWS_API_KEY: API key for news provider
- NEWS_API_PROVIDER: "alphavantage" or "newsapi"
- MARKET_DATA_API_KEY: API key for market data

## Testing

Comprehensive test suite with 38 tests:
- NewsAPI caching, sentiment parsing, API integration
- EconomicCalendar event filtering, earnings lookup
- MarketData sentiment and breadth calculations
- GeminiClient integration with external data sources
- All tests use mocks (no real API keys required)
- 81% coverage for src/data module (exceeds 80% requirement)

## Circular Import Fix

Fixed circular dependency between gemini_client.py and cache.py:
- Use TYPE_CHECKING for imports in cache.py
- String annotations for TradeDecision type hints

All 195 existing tests pass. No breaking changes to existing functionality.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-02-04 18:06:34 +09:00
f40f19e735 Merge pull request 'feat: implement Latency Control with criticality-based prioritization (Pillar 1)' (#27) from feature/issue-21-latency-control into main
Some checks failed
CI / test (push) Has been cancelled
Reviewed-on: #27
2026-02-04 17:02:40 +09:00
19 changed files with 3842 additions and 19 deletions

View File

@@ -21,3 +21,8 @@ RATE_LIMIT_RPS=10.0
# Trading Mode (paper / live)
MODE=paper
# External Data APIs (optional — for enhanced decision-making)
# NEWS_API_KEY=your_news_api_key_here
# NEWS_API_PROVIDER=alphavantage
# MARKET_DATA_API_KEY=your_market_data_key_here

3
.gitignore vendored
View File

@@ -174,4 +174,7 @@ cython_debug/
# PyPI configuration file
.pypirc
# Data files (trade logs, databases)
# But NOT src/data/ which contains source code
data/
!src/data/

293
src/brain/cache.py Normal file
View File

@@ -0,0 +1,293 @@
"""Response caching system for reducing redundant LLM calls.
This module provides caching for common trading scenarios:
- TTL-based cache invalidation
- Cache key based on market conditions
- Cache hit rate monitoring
- Special handling for HOLD decisions in quiet markets
"""
from __future__ import annotations
import hashlib
import json
import logging
import time
from dataclasses import dataclass, field
from typing import Any, TYPE_CHECKING
if TYPE_CHECKING:
from src.brain.gemini_client import TradeDecision
logger = logging.getLogger(__name__)
@dataclass
class CacheEntry:
"""Cached decision with metadata."""
decision: "TradeDecision"
cached_at: float # Unix timestamp
hit_count: int = 0
market_data_hash: str = ""
@dataclass
class CacheMetrics:
"""Metrics for cache performance monitoring."""
total_requests: int = 0
cache_hits: int = 0
cache_misses: int = 0
evictions: int = 0
total_entries: int = 0
@property
def hit_rate(self) -> float:
"""Calculate cache hit rate."""
if self.total_requests == 0:
return 0.0
return self.cache_hits / self.total_requests
def to_dict(self) -> dict[str, Any]:
"""Convert metrics to dictionary."""
return {
"total_requests": self.total_requests,
"cache_hits": self.cache_hits,
"cache_misses": self.cache_misses,
"hit_rate": self.hit_rate,
"evictions": self.evictions,
"total_entries": self.total_entries,
}
class DecisionCache:
"""TTL-based cache for trade decisions."""
def __init__(self, ttl_seconds: int = 300, max_size: int = 1000) -> None:
"""Initialize the decision cache.
Args:
ttl_seconds: Time-to-live for cache entries in seconds (default: 5 minutes)
max_size: Maximum number of cache entries
"""
self.ttl_seconds = ttl_seconds
self.max_size = max_size
self._cache: dict[str, CacheEntry] = {}
self._metrics = CacheMetrics()
def _generate_cache_key(self, market_data: dict[str, Any]) -> str:
"""Generate cache key from market data.
Key is based on:
- Stock code
- Current price (rounded to reduce sensitivity)
- Market conditions (orderbook snapshot)
Args:
market_data: Market data dictionary
Returns:
Cache key string
"""
# Extract key components
stock_code = market_data.get("stock_code", "UNKNOWN")
current_price = market_data.get("current_price", 0)
# Round price to reduce sensitivity (cache hits for similar prices)
# For prices > 1000, round to nearest 10
# For prices < 1000, round to nearest 1
if current_price > 1000:
price_rounded = round(current_price / 10) * 10
else:
price_rounded = round(current_price)
# Include orderbook snapshot (if available)
orderbook_key = ""
if "orderbook" in market_data and market_data["orderbook"]:
ob = market_data["orderbook"]
# Just use bid/ask spread as indicator
if "bid" in ob and "ask" in ob and ob["bid"] and ob["ask"]:
bid_price = ob["bid"][0].get("price", 0) if ob["bid"] else 0
ask_price = ob["ask"][0].get("price", 0) if ob["ask"] else 0
spread = ask_price - bid_price
orderbook_key = f"_spread{spread}"
# Generate cache key
key_str = f"{stock_code}_{price_rounded}{orderbook_key}"
return key_str
def _generate_market_hash(self, market_data: dict[str, Any]) -> str:
"""Generate hash of full market data for invalidation checks.
Args:
market_data: Market data dictionary
Returns:
Hash string
"""
# Create stable JSON representation
stable_json = json.dumps(market_data, sort_keys=True, ensure_ascii=False)
return hashlib.md5(stable_json.encode()).hexdigest()
def get(self, market_data: dict[str, Any]) -> TradeDecision | None:
"""Retrieve cached decision if valid.
Args:
market_data: Market data dictionary
Returns:
Cached TradeDecision if valid, None otherwise
"""
self._metrics.total_requests += 1
cache_key = self._generate_cache_key(market_data)
if cache_key not in self._cache:
self._metrics.cache_misses += 1
return None
entry = self._cache[cache_key]
current_time = time.time()
# Check TTL
if current_time - entry.cached_at > self.ttl_seconds:
# Expired
del self._cache[cache_key]
self._metrics.cache_misses += 1
self._metrics.evictions += 1
logger.debug("Cache expired for key: %s", cache_key)
return None
# Cache hit
entry.hit_count += 1
self._metrics.cache_hits += 1
logger.debug("Cache hit for key: %s (hits: %d)", cache_key, entry.hit_count)
return entry.decision
def set(
self,
market_data: dict[str, Any],
decision: TradeDecision,
) -> None:
"""Store decision in cache.
Args:
market_data: Market data dictionary
decision: TradeDecision to cache
"""
cache_key = self._generate_cache_key(market_data)
market_hash = self._generate_market_hash(market_data)
# Enforce max size (evict oldest if full)
if len(self._cache) >= self.max_size:
# Find oldest entry
oldest_key = min(self._cache.keys(), key=lambda k: self._cache[k].cached_at)
del self._cache[oldest_key]
self._metrics.evictions += 1
logger.debug("Cache full, evicted key: %s", oldest_key)
# Store entry
entry = CacheEntry(
decision=decision,
cached_at=time.time(),
market_data_hash=market_hash,
)
self._cache[cache_key] = entry
self._metrics.total_entries = len(self._cache)
logger.debug("Cached decision for key: %s", cache_key)
def invalidate(self, stock_code: str | None = None) -> int:
"""Invalidate cache entries.
Args:
stock_code: Specific stock code to invalidate, or None for all
Returns:
Number of entries invalidated
"""
if stock_code is None:
# Clear all
count = len(self._cache)
self._cache.clear()
self._metrics.evictions += count
self._metrics.total_entries = 0
logger.info("Invalidated all cache entries (%d)", count)
return count
# Invalidate specific stock
keys_to_remove = [k for k in self._cache.keys() if k.startswith(f"{stock_code}_")]
count = len(keys_to_remove)
for key in keys_to_remove:
del self._cache[key]
self._metrics.evictions += count
self._metrics.total_entries = len(self._cache)
logger.info("Invalidated %d cache entries for stock: %s", count, stock_code)
return count
def cleanup_expired(self) -> int:
"""Remove expired entries from cache.
Returns:
Number of entries removed
"""
current_time = time.time()
expired_keys = [
k
for k, v in self._cache.items()
if current_time - v.cached_at > self.ttl_seconds
]
count = len(expired_keys)
for key in expired_keys:
del self._cache[key]
self._metrics.evictions += count
self._metrics.total_entries = len(self._cache)
if count > 0:
logger.debug("Cleaned up %d expired cache entries", count)
return count
def get_metrics(self) -> CacheMetrics:
"""Get current cache metrics.
Returns:
CacheMetrics object with current statistics
"""
return self._metrics
def reset_metrics(self) -> None:
"""Reset cache metrics."""
self._metrics = CacheMetrics(total_entries=len(self._cache))
logger.info("Cache metrics reset")
def should_cache_decision(self, decision: TradeDecision) -> bool:
"""Determine if a decision should be cached.
HOLD decisions with low confidence are good candidates for caching,
as they're likely to recur in quiet markets.
Args:
decision: TradeDecision to evaluate
Returns:
True if decision should be cached
"""
# Cache HOLD decisions (common in quiet markets)
if decision.action == "HOLD":
return True
# Cache high-confidence decisions (stable signals)
if decision.confidence >= 90:
return True
# Don't cache low-confidence BUY/SELL (volatile signals)
return False

View File

@@ -0,0 +1,296 @@
"""Smart context selection for optimizing token usage.
This module implements intelligent selection of context layers (L1-L7) based on
decision type and market conditions:
- L7 (real-time) for normal trading decisions
- L6-L5 (daily/weekly) for strategic decisions
- L4-L1 (monthly/legacy) only for major events or policy changes
"""
from __future__ import annotations
from dataclasses import dataclass
from datetime import UTC, datetime
from enum import Enum
from typing import Any
from src.context.layer import ContextLayer
from src.context.store import ContextStore
class DecisionType(str, Enum):
"""Type of trading decision being made."""
NORMAL = "normal" # Regular trade decision
STRATEGIC = "strategic" # Strategy adjustment
MAJOR_EVENT = "major_event" # Portfolio rebalancing, policy change
@dataclass(frozen=True)
class ContextSelection:
"""Selected context layers and their relevance scores."""
layers: list[ContextLayer]
relevance_scores: dict[ContextLayer, float]
total_score: float
class ContextSelector:
"""Selects optimal context layers to minimize token usage."""
def __init__(self, store: ContextStore) -> None:
"""Initialize the context selector.
Args:
store: ContextStore instance for retrieving context data
"""
self.store = store
def select_layers(
self,
decision_type: DecisionType = DecisionType.NORMAL,
include_realtime: bool = True,
) -> list[ContextLayer]:
"""Select context layers based on decision type.
Strategy:
- NORMAL: L7 (real-time) only
- STRATEGIC: L7 + L6 + L5 (real-time + daily + weekly)
- MAJOR_EVENT: All layers L1-L7
Args:
decision_type: Type of decision being made
include_realtime: Whether to include L7 real-time data
Returns:
List of context layers to use (ordered by priority)
"""
if decision_type == DecisionType.NORMAL:
# Normal trading: only real-time data
return [ContextLayer.L7_REALTIME] if include_realtime else []
elif decision_type == DecisionType.STRATEGIC:
# Strategic decisions: real-time + recent history
layers = []
if include_realtime:
layers.append(ContextLayer.L7_REALTIME)
layers.extend([ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY])
return layers
else: # MAJOR_EVENT
# Major events: all layers for comprehensive context
layers = []
if include_realtime:
layers.append(ContextLayer.L7_REALTIME)
layers.extend(
[
ContextLayer.L6_DAILY,
ContextLayer.L5_WEEKLY,
ContextLayer.L4_MONTHLY,
ContextLayer.L3_QUARTERLY,
ContextLayer.L2_ANNUAL,
ContextLayer.L1_LEGACY,
]
)
return layers
def score_layer_relevance(
self,
layer: ContextLayer,
decision_type: DecisionType,
current_time: datetime | None = None,
) -> float:
"""Calculate relevance score for a context layer.
Relevance is based on:
1. Decision type (normal, strategic, major event)
2. Layer recency (L7 > L6 > ... > L1)
3. Data availability
Args:
layer: Context layer to score
decision_type: Type of decision being made
current_time: Current time (defaults to now)
Returns:
Relevance score (0.0 to 1.0)
"""
if current_time is None:
current_time = datetime.now(UTC)
# Base scores by decision type
base_scores = {
DecisionType.NORMAL: {
ContextLayer.L7_REALTIME: 1.0,
ContextLayer.L6_DAILY: 0.1,
ContextLayer.L5_WEEKLY: 0.05,
ContextLayer.L4_MONTHLY: 0.01,
ContextLayer.L3_QUARTERLY: 0.0,
ContextLayer.L2_ANNUAL: 0.0,
ContextLayer.L1_LEGACY: 0.0,
},
DecisionType.STRATEGIC: {
ContextLayer.L7_REALTIME: 0.9,
ContextLayer.L6_DAILY: 0.8,
ContextLayer.L5_WEEKLY: 0.7,
ContextLayer.L4_MONTHLY: 0.3,
ContextLayer.L3_QUARTERLY: 0.2,
ContextLayer.L2_ANNUAL: 0.1,
ContextLayer.L1_LEGACY: 0.05,
},
DecisionType.MAJOR_EVENT: {
ContextLayer.L7_REALTIME: 0.7,
ContextLayer.L6_DAILY: 0.7,
ContextLayer.L5_WEEKLY: 0.7,
ContextLayer.L4_MONTHLY: 0.8,
ContextLayer.L3_QUARTERLY: 0.8,
ContextLayer.L2_ANNUAL: 0.9,
ContextLayer.L1_LEGACY: 1.0,
},
}
score = base_scores[decision_type].get(layer, 0.0)
# Check data availability
latest_timeframe = self.store.get_latest_timeframe(layer)
if latest_timeframe is None:
# No data available - reduce score significantly
score *= 0.1
return score
def select_with_scoring(
self,
decision_type: DecisionType = DecisionType.NORMAL,
min_score: float = 0.5,
) -> ContextSelection:
"""Select context layers with relevance scoring.
Args:
decision_type: Type of decision being made
min_score: Minimum relevance score to include a layer
Returns:
ContextSelection with selected layers and scores
"""
all_layers = [
ContextLayer.L7_REALTIME,
ContextLayer.L6_DAILY,
ContextLayer.L5_WEEKLY,
ContextLayer.L4_MONTHLY,
ContextLayer.L3_QUARTERLY,
ContextLayer.L2_ANNUAL,
ContextLayer.L1_LEGACY,
]
scores = {
layer: self.score_layer_relevance(layer, decision_type) for layer in all_layers
}
# Filter by minimum score
selected_layers = [layer for layer, score in scores.items() if score >= min_score]
# Sort by score (descending)
selected_layers.sort(key=lambda layer: scores[layer], reverse=True)
total_score = sum(scores[layer] for layer in selected_layers)
return ContextSelection(
layers=selected_layers,
relevance_scores=scores,
total_score=total_score,
)
def get_context_data(
self,
layers: list[ContextLayer],
max_items_per_layer: int = 10,
) -> dict[str, Any]:
"""Retrieve context data for selected layers.
Args:
layers: List of context layers to retrieve
max_items_per_layer: Maximum number of items per layer
Returns:
Dictionary with context data organized by layer
"""
result: dict[str, Any] = {}
for layer in layers:
# Get latest timeframe for this layer
latest_timeframe = self.store.get_latest_timeframe(layer)
if latest_timeframe:
# Get all contexts for latest timeframe
contexts = self.store.get_all_contexts(layer, latest_timeframe)
# Limit number of items
if len(contexts) > max_items_per_layer:
# Keep only first N items
contexts = dict(list(contexts.items())[:max_items_per_layer])
result[layer.value] = contexts
return result
def estimate_context_tokens(self, context_data: dict[str, Any]) -> int:
"""Estimate total tokens for context data.
Args:
context_data: Context data dictionary
Returns:
Estimated token count
"""
import json
from src.brain.prompt_optimizer import PromptOptimizer
# Serialize to JSON and estimate tokens
json_str = json.dumps(context_data, ensure_ascii=False)
return PromptOptimizer.estimate_tokens(json_str)
def optimize_context_for_budget(
self,
decision_type: DecisionType,
max_tokens: int,
) -> dict[str, Any]:
"""Select and retrieve context data within a token budget.
Args:
decision_type: Type of decision being made
max_tokens: Maximum token budget for context
Returns:
Optimized context data within budget
"""
# Start with minimal selection
selection = self.select_with_scoring(decision_type, min_score=0.5)
# Retrieve data
context_data = self.get_context_data(selection.layers)
# Check if within budget
estimated_tokens = self.estimate_context_tokens(context_data)
if estimated_tokens <= max_tokens:
return context_data
# If over budget, progressively reduce
# 1. Reduce items per layer
for max_items in [5, 3, 1]:
context_data = self.get_context_data(selection.layers, max_items)
estimated_tokens = self.estimate_context_tokens(context_data)
if estimated_tokens <= max_tokens:
return context_data
# 2. Remove lower-priority layers
for min_score in [0.6, 0.7, 0.8, 0.9]:
selection = self.select_with_scoring(decision_type, min_score=min_score)
context_data = self.get_context_data(selection.layers, max_items_per_layer=1)
estimated_tokens = self.estimate_context_tokens(context_data)
if estimated_tokens <= max_tokens:
return context_data
# Last resort: return only L7 with minimal data
return self.get_context_data([ContextLayer.L7_REALTIME], max_items_per_layer=1)

View File

@@ -2,6 +2,17 @@
Constructs prompts from market data, calls Gemini, and parses structured
JSON responses into validated TradeDecision objects.
Includes token efficiency optimizations:
- Prompt compression and abbreviation
- Response caching for common scenarios
- Smart context selection
- Token usage tracking and metrics
Includes external data integration:
- News sentiment analysis
- Economic calendar events
- Market indicators
"""
from __future__ import annotations
@@ -15,6 +26,11 @@ from typing import Any
from google import genai
from src.config import Settings
from src.data.news_api import NewsAPI, NewsSentiment
from src.data.economic_calendar import EconomicCalendar
from src.data.market_data import MarketData
from src.brain.cache import DecisionCache
from src.brain.prompt_optimizer import PromptOptimizer
logger = logging.getLogger(__name__)
@@ -28,23 +44,176 @@ class TradeDecision:
action: str # "BUY" | "SELL" | "HOLD"
confidence: int # 0-100
rationale: str
token_count: int = 0 # Estimated tokens used
cached: bool = False # Whether decision came from cache
class GeminiClient:
"""Wraps the Gemini API for trade decision-making."""
def __init__(self, settings: Settings) -> None:
def __init__(
self,
settings: Settings,
news_api: NewsAPI | None = None,
economic_calendar: EconomicCalendar | None = None,
market_data: MarketData | None = None,
enable_cache: bool = True,
enable_optimization: bool = True,
) -> None:
self._settings = settings
self._confidence_threshold = settings.CONFIDENCE_THRESHOLD
self._client = genai.Client(api_key=settings.GEMINI_API_KEY)
self._model_name = settings.GEMINI_MODEL
# External data sources (optional)
self._news_api = news_api
self._economic_calendar = economic_calendar
self._market_data = market_data
# Token efficiency features
self._enable_cache = enable_cache
self._enable_optimization = enable_optimization
self._cache = DecisionCache(ttl_seconds=300) if enable_cache else None
self._optimizer = PromptOptimizer()
# Token usage metrics
self._total_tokens_used = 0
self._total_decisions = 0
self._total_cached_decisions = 0
# ------------------------------------------------------------------
# External Data Integration
# ------------------------------------------------------------------
async def _build_external_context(
self, stock_code: str, news_sentiment: NewsSentiment | None = None
) -> str:
"""Build external data context for the prompt.
Args:
stock_code: Stock ticker symbol
news_sentiment: Optional pre-fetched news sentiment
Returns:
Formatted string with external data context
"""
context_parts: list[str] = []
# News sentiment
if news_sentiment is not None:
sentiment_str = self._format_news_sentiment(news_sentiment)
if sentiment_str:
context_parts.append(sentiment_str)
elif self._news_api is not None:
# Fetch news sentiment if not provided
try:
sentiment = await self._news_api.get_news_sentiment(stock_code)
if sentiment is not None:
sentiment_str = self._format_news_sentiment(sentiment)
if sentiment_str:
context_parts.append(sentiment_str)
except Exception as exc:
logger.warning("Failed to fetch news sentiment: %s", exc)
# Economic events
if self._economic_calendar is not None:
events_str = self._format_economic_events(stock_code)
if events_str:
context_parts.append(events_str)
# Market indicators
if self._market_data is not None:
indicators_str = self._format_market_indicators()
if indicators_str:
context_parts.append(indicators_str)
if not context_parts:
return ""
return "EXTERNAL DATA:\n" + "\n\n".join(context_parts)
def _format_news_sentiment(self, sentiment: NewsSentiment) -> str:
"""Format news sentiment for prompt."""
if sentiment.article_count == 0:
return ""
# Select top 3 most relevant articles
top_articles = sentiment.articles[:3]
lines = [
f"News Sentiment: {sentiment.avg_sentiment:.2f} "
f"(from {sentiment.article_count} articles)",
]
for i, article in enumerate(top_articles, 1):
lines.append(
f" {i}. [{article.source}] {article.title} "
f"(sentiment: {article.sentiment_score:.2f})"
)
return "\n".join(lines)
def _format_economic_events(self, stock_code: str) -> str:
"""Format upcoming economic events for prompt."""
if self._economic_calendar is None:
return ""
# Check for upcoming high-impact events
upcoming = self._economic_calendar.get_upcoming_events(
days_ahead=7, min_impact="HIGH"
)
if upcoming.high_impact_count == 0:
return ""
lines = [
f"Upcoming High-Impact Events: {upcoming.high_impact_count} in next 7 days"
]
if upcoming.next_major_event is not None:
event = upcoming.next_major_event
lines.append(
f" Next: {event.name} ({event.event_type}) "
f"on {event.datetime.strftime('%Y-%m-%d')}"
)
# Check for earnings
earnings_date = self._economic_calendar.get_earnings_date(stock_code)
if earnings_date is not None:
lines.append(
f" Earnings: {stock_code} on {earnings_date.strftime('%Y-%m-%d')}"
)
return "\n".join(lines)
def _format_market_indicators(self) -> str:
"""Format market indicators for prompt."""
if self._market_data is None:
return ""
try:
indicators = self._market_data.get_market_indicators()
lines = [f"Market Sentiment: {indicators.sentiment.name}"]
# Add breadth if meaningful
if indicators.breadth.advance_decline_ratio != 1.0:
lines.append(
f"Advance/Decline Ratio: {indicators.breadth.advance_decline_ratio:.2f}"
)
return "\n".join(lines)
except Exception as exc:
logger.warning("Failed to get market indicators: %s", exc)
return ""
# ------------------------------------------------------------------
# Prompt Construction
# ------------------------------------------------------------------
def build_prompt(self, market_data: dict[str, Any]) -> str:
"""Build a structured prompt from market data.
async def build_prompt(
self, market_data: dict[str, Any], news_sentiment: NewsSentiment | None = None
) -> str:
"""Build a structured prompt from market data and external sources.
The prompt instructs Gemini to return valid JSON with action,
confidence, and rationale fields.
@@ -72,6 +241,60 @@ class GeminiClient:
market_info = "\n".join(market_info_lines)
# Add external data context if available
external_context = await self._build_external_context(
market_data["stock_code"], news_sentiment
)
if external_context:
market_info += f"\n\n{external_context}"
json_format = (
'{"action": "BUY"|"SELL"|"HOLD", '
'"confidence": <int 0-100>, "rationale": "<string>"}'
)
return (
f"You are a professional {market_name} trading analyst.\n"
"Analyze the following market data and decide whether to "
"BUY, SELL, or HOLD.\n\n"
f"{market_info}\n\n"
"You MUST respond with ONLY valid JSON in the following format:\n"
f"{json_format}\n\n"
"Rules:\n"
"- action must be exactly one of: BUY, SELL, HOLD\n"
"- confidence must be an integer from 0 to 100\n"
"- rationale must explain your reasoning concisely\n"
"- Do NOT wrap the JSON in markdown code blocks\n"
)
def build_prompt_sync(self, market_data: dict[str, Any]) -> str:
"""Synchronous version of build_prompt (for backward compatibility).
This version does NOT include external data integration.
Use async build_prompt() for full functionality.
"""
market_name = market_data.get("market_name", "Korean stock market")
# Build market data section dynamically based on available fields
market_info_lines = [
f"Market: {market_name}",
f"Stock Code: {market_data['stock_code']}",
f"Current Price: {market_data['current_price']}",
]
# Add orderbook if available (domestic markets)
if "orderbook" in market_data:
market_info_lines.append(
f"Orderbook: {json.dumps(market_data['orderbook'], ensure_ascii=False)}"
)
# Add foreigner net if non-zero
if market_data.get("foreigner_net", 0) != 0:
market_info_lines.append(
f"Foreigner Net Buy/Sell: {market_data['foreigner_net']}"
)
market_info = "\n".join(market_info_lines)
json_format = (
'{"action": "BUY"|"SELL"|"HOLD", '
'"confidence": <int 0-100>, "rationale": "<string>"}'
@@ -152,28 +375,153 @@ class GeminiClient:
# API Call
# ------------------------------------------------------------------
async def decide(self, market_data: dict[str, Any]) -> TradeDecision:
"""Build prompt, call Gemini, and return a parsed decision."""
prompt = self.build_prompt(market_data)
logger.info("Requesting trade decision from Gemini")
async def decide(
self, market_data: dict[str, Any], news_sentiment: NewsSentiment | None = None
) -> TradeDecision:
"""Build prompt, call Gemini, and return a parsed decision.
Args:
market_data: Market data dictionary with price, orderbook, etc.
news_sentiment: Optional pre-fetched news sentiment
Returns:
Parsed TradeDecision
"""
# Check cache first
if self._cache:
cached_decision = self._cache.get(market_data)
if cached_decision:
self._total_cached_decisions += 1
self._total_decisions += 1
logger.info(
"Cache hit for decision",
extra={
"action": cached_decision.action,
"confidence": cached_decision.confidence,
"cache_hit_rate": self.get_cache_hit_rate(),
},
)
# Return cached decision with cached flag
return TradeDecision(
action=cached_decision.action,
confidence=cached_decision.confidence,
rationale=cached_decision.rationale,
token_count=0,
cached=True,
)
# Build optimized prompt
if self._enable_optimization:
prompt = self._optimizer.build_compressed_prompt(market_data)
else:
prompt = await self.build_prompt(market_data, news_sentiment)
# Estimate tokens
token_count = self._optimizer.estimate_tokens(prompt)
self._total_tokens_used += token_count
logger.info(
"Requesting trade decision from Gemini",
extra={"estimated_tokens": token_count, "optimized": self._enable_optimization},
)
try:
response = await self._client.aio.models.generate_content(
model=self._model_name, contents=prompt,
model=self._model_name,
contents=prompt,
)
raw = response.text
except Exception as exc:
logger.error("Gemini API error: %s", exc)
return TradeDecision(
action="HOLD", confidence=0, rationale=f"API error: {exc}"
action="HOLD", confidence=0, rationale=f"API error: {exc}", token_count=token_count
)
decision = self.parse_response(raw)
self._total_decisions += 1
# Add token count to decision
decision_with_tokens = TradeDecision(
action=decision.action,
confidence=decision.confidence,
rationale=decision.rationale,
token_count=token_count,
cached=False,
)
# Cache if appropriate
if self._cache and self._cache.should_cache_decision(decision):
self._cache.set(market_data, decision)
logger.info(
"Gemini decision",
extra={
"action": decision.action,
"confidence": decision.confidence,
"tokens": token_count,
"avg_tokens": self.get_avg_tokens_per_decision(),
},
)
return decision
return decision_with_tokens
# ------------------------------------------------------------------
# Token Efficiency Metrics
# ------------------------------------------------------------------
def get_token_metrics(self) -> dict[str, Any]:
"""Get token usage metrics.
Returns:
Dictionary with token usage statistics
"""
metrics = {
"total_tokens_used": self._total_tokens_used,
"total_decisions": self._total_decisions,
"total_cached_decisions": self._total_cached_decisions,
"avg_tokens_per_decision": self.get_avg_tokens_per_decision(),
"cache_hit_rate": self.get_cache_hit_rate(),
}
if self._cache:
cache_metrics = self._cache.get_metrics()
metrics["cache_metrics"] = cache_metrics.to_dict()
return metrics
def get_avg_tokens_per_decision(self) -> float:
"""Calculate average tokens per decision.
Returns:
Average tokens per decision
"""
if self._total_decisions == 0:
return 0.0
return self._total_tokens_used / self._total_decisions
def get_cache_hit_rate(self) -> float:
"""Calculate cache hit rate.
Returns:
Cache hit rate (0.0 to 1.0)
"""
if self._total_decisions == 0:
return 0.0
return self._total_cached_decisions / self._total_decisions
def reset_metrics(self) -> None:
"""Reset token usage metrics."""
self._total_tokens_used = 0
self._total_decisions = 0
self._total_cached_decisions = 0
if self._cache:
self._cache.reset_metrics()
logger.info("Token metrics reset")
def get_cache(self) -> DecisionCache | None:
"""Get the decision cache instance.
Returns:
DecisionCache instance or None if caching disabled
"""
return self._cache

View File

@@ -0,0 +1,267 @@
"""Prompt optimization utilities for reducing token usage.
This module provides tools to compress prompts while maintaining decision quality:
- Token counting
- Text compression and abbreviation
- Template-based prompts with variable slots
- Priority-based context truncation
"""
from __future__ import annotations
import json
import re
from dataclasses import dataclass
from typing import Any
# Abbreviation mapping for common terms
ABBREVIATIONS = {
"price": "P",
"volume": "V",
"current": "cur",
"previous": "prev",
"change": "chg",
"percentage": "pct",
"market": "mkt",
"orderbook": "ob",
"foreigner": "fgn",
"buy": "B",
"sell": "S",
"hold": "H",
"confidence": "conf",
"rationale": "reason",
"action": "act",
"net": "net",
}
# Reverse mapping for decompression
REVERSE_ABBREVIATIONS = {v: k for k, v in ABBREVIATIONS.items()}
@dataclass(frozen=True)
class TokenMetrics:
"""Metrics about token usage in a prompt."""
char_count: int
word_count: int
estimated_tokens: int # Rough estimate: ~4 chars per token
compression_ratio: float = 1.0 # Original / Compressed
class PromptOptimizer:
"""Optimizes prompts to reduce token usage while maintaining quality."""
@staticmethod
def estimate_tokens(text: str) -> int:
"""Estimate token count for text.
Uses a simple heuristic: ~4 characters per token for English.
This is approximate but sufficient for optimization purposes.
Args:
text: Input text to estimate tokens for
Returns:
Estimated token count
"""
if not text:
return 0
# Simple estimate: 1 token ≈ 4 characters
return max(1, len(text) // 4)
@staticmethod
def count_tokens(text: str) -> TokenMetrics:
"""Count various metrics for a text.
Args:
text: Input text to analyze
Returns:
TokenMetrics with character, word, and estimated token counts
"""
char_count = len(text)
word_count = len(text.split())
estimated_tokens = PromptOptimizer.estimate_tokens(text)
return TokenMetrics(
char_count=char_count,
word_count=word_count,
estimated_tokens=estimated_tokens,
)
@staticmethod
def compress_json(data: dict[str, Any]) -> str:
"""Compress JSON by removing whitespace.
Args:
data: Dictionary to serialize
Returns:
Compact JSON string without whitespace
"""
return json.dumps(data, separators=(",", ":"), ensure_ascii=False)
@staticmethod
def abbreviate_text(text: str, aggressive: bool = False) -> str:
"""Apply abbreviations to reduce text length.
Args:
text: Input text to abbreviate
aggressive: If True, apply more aggressive compression
Returns:
Abbreviated text
"""
result = text
# Apply word-level abbreviations (case-insensitive)
for full, abbr in ABBREVIATIONS.items():
# Word boundaries to avoid partial replacements
pattern = r"\b" + re.escape(full) + r"\b"
result = re.sub(pattern, abbr, result, flags=re.IGNORECASE)
if aggressive:
# Remove articles and filler words
result = re.sub(r"\b(a|an|the)\b", "", result, flags=re.IGNORECASE)
result = re.sub(r"\b(is|are|was|were)\b", "", result, flags=re.IGNORECASE)
# Collapse multiple spaces
result = re.sub(r"\s+", " ", result)
return result.strip()
@staticmethod
def build_compressed_prompt(
market_data: dict[str, Any],
include_instructions: bool = True,
max_length: int | None = None,
) -> str:
"""Build a compressed prompt from market data.
Args:
market_data: Market data dictionary with stock info
include_instructions: Whether to include full instructions
max_length: Maximum character length (truncates if needed)
Returns:
Compressed prompt string
"""
# Abbreviated market name
market_name = market_data.get("market_name", "KR")
if "Korea" in market_name:
market_name = "KR"
elif "United States" in market_name or "US" in market_name:
market_name = "US"
# Core data - always included
core_info = {
"mkt": market_name,
"code": market_data["stock_code"],
"P": market_data["current_price"],
}
# Optional fields
if "orderbook" in market_data and market_data["orderbook"]:
ob = market_data["orderbook"]
# Compress orderbook: keep only top 3 levels
compressed_ob = {
"bid": ob.get("bid", [])[:3],
"ask": ob.get("ask", [])[:3],
}
core_info["ob"] = compressed_ob
if market_data.get("foreigner_net", 0) != 0:
core_info["fgn_net"] = market_data["foreigner_net"]
# Compress to JSON
data_str = PromptOptimizer.compress_json(core_info)
if include_instructions:
# Minimal instructions
prompt = (
f"{market_name} trader. Analyze:\n{data_str}\n\n"
'Return JSON: {"act":"BUY"|"SELL"|"HOLD","conf":<0-100>,"reason":"<text>"}\n'
"Rules: act=BUY/SELL/HOLD, conf=0-100, reason=concise. No markdown."
)
else:
# Data only (for cached contexts where instructions are known)
prompt = data_str
# Truncate if needed
if max_length and len(prompt) > max_length:
prompt = prompt[:max_length] + "..."
return prompt
@staticmethod
def truncate_context(
context: dict[str, Any],
max_tokens: int,
priority_keys: list[str] | None = None,
) -> dict[str, Any]:
"""Truncate context data to fit within token budget.
Keeps high-priority keys first, then truncates less important data.
Args:
context: Context dictionary to truncate
max_tokens: Maximum token budget
priority_keys: List of keys to keep (in order of priority)
Returns:
Truncated context dictionary
"""
if not context:
return {}
if priority_keys is None:
priority_keys = []
result: dict[str, Any] = {}
current_tokens = 0
# Add priority keys first
for key in priority_keys:
if key in context:
value_str = json.dumps(context[key])
tokens = PromptOptimizer.estimate_tokens(value_str)
if current_tokens + tokens <= max_tokens:
result[key] = context[key]
current_tokens += tokens
else:
break
# Add remaining keys if space available
for key, value in context.items():
if key in result:
continue
value_str = json.dumps(value)
tokens = PromptOptimizer.estimate_tokens(value_str)
if current_tokens + tokens <= max_tokens:
result[key] = value
current_tokens += tokens
else:
break
return result
@staticmethod
def calculate_compression_ratio(original: str, compressed: str) -> float:
"""Calculate compression ratio between original and compressed text.
Args:
original: Original text
compressed: Compressed text
Returns:
Compression ratio (original_tokens / compressed_tokens)
"""
original_tokens = PromptOptimizer.estimate_tokens(original)
compressed_tokens = PromptOptimizer.estimate_tokens(compressed)
if compressed_tokens == 0:
return 1.0
return original_tokens / compressed_tokens

View File

@@ -19,6 +19,11 @@ class Settings(BaseSettings):
GEMINI_API_KEY: str
GEMINI_MODEL: str = "gemini-pro"
# External Data APIs (optional — for data-driven decisions)
NEWS_API_KEY: str | None = None
NEWS_API_PROVIDER: str = "alphavantage" # "alphavantage" or "newsapi"
MARKET_DATA_API_KEY: str | None = None
# Risk Management
CIRCUIT_BREAKER_PCT: float = Field(default=-3.0, le=0.0)
FAT_FINGER_PCT: float = Field(default=30.0, gt=0.0, le=100.0)

328
src/context/summarizer.py Normal file
View File

@@ -0,0 +1,328 @@
"""Context summarization for efficient historical data representation.
This module summarizes old context data instead of including raw details:
- Key metrics only (averages, trends, not details)
- Rolling window (keep last N days detailed, summarize older)
- Aggregate historical data efficiently
"""
from __future__ import annotations
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from typing import Any
from src.context.layer import ContextLayer
from src.context.store import ContextStore
@dataclass(frozen=True)
class SummaryStats:
"""Statistical summary of historical data."""
count: int
mean: float | None = None
min: float | None = None
max: float | None = None
std: float | None = None
trend: str | None = None # "up", "down", "flat"
class ContextSummarizer:
"""Summarizes historical context data to reduce token usage."""
def __init__(self, store: ContextStore) -> None:
"""Initialize the context summarizer.
Args:
store: ContextStore instance for retrieving context data
"""
self.store = store
def summarize_numeric_values(self, values: list[float]) -> SummaryStats:
"""Summarize a list of numeric values.
Args:
values: List of numeric values to summarize
Returns:
SummaryStats with mean, min, max, std, and trend
"""
if not values:
return SummaryStats(count=0)
count = len(values)
mean = sum(values) / count
min_val = min(values)
max_val = max(values)
# Calculate standard deviation
if count > 1:
variance = sum((x - mean) ** 2 for x in values) / (count - 1)
std = variance**0.5
else:
std = 0.0
# Determine trend
trend = "flat"
if count >= 3:
# Simple trend: compare first third vs last third
first_third = values[: count // 3]
last_third = values[-(count // 3) :]
first_avg = sum(first_third) / len(first_third)
last_avg = sum(last_third) / len(last_third)
# Trend threshold: 5% change
threshold = 0.05 * abs(first_avg) if first_avg != 0 else 0.01
if last_avg > first_avg + threshold:
trend = "up"
elif last_avg < first_avg - threshold:
trend = "down"
return SummaryStats(
count=count,
mean=round(mean, 4),
min=round(min_val, 4),
max=round(max_val, 4),
std=round(std, 4),
trend=trend,
)
def summarize_layer(
self,
layer: ContextLayer,
start_date: datetime | None = None,
end_date: datetime | None = None,
) -> dict[str, Any]:
"""Summarize all context data for a layer within a date range.
Args:
layer: Context layer to summarize
start_date: Start date (inclusive), None for all
end_date: End date (inclusive), None for now
Returns:
Dictionary with summarized metrics
"""
if end_date is None:
end_date = datetime.now(UTC)
# Get all contexts for this layer
all_contexts = self.store.get_all_contexts(layer)
if not all_contexts:
return {"summary": "No data available", "count": 0}
# Group numeric values by key
numeric_data: dict[str, list[float]] = {}
text_data: dict[str, list[str]] = {}
for key, value in all_contexts.items():
# Try to extract numeric values
if isinstance(value, (int, float)):
if key not in numeric_data:
numeric_data[key] = []
numeric_data[key].append(float(value))
elif isinstance(value, dict):
# Extract numeric fields from dict
for subkey, subvalue in value.items():
if isinstance(subvalue, (int, float)):
full_key = f"{key}.{subkey}"
if full_key not in numeric_data:
numeric_data[full_key] = []
numeric_data[full_key].append(float(subvalue))
elif isinstance(value, str):
if key not in text_data:
text_data[key] = []
text_data[key].append(value)
# Summarize numeric data
summary: dict[str, Any] = {}
for key, values in numeric_data.items():
stats = self.summarize_numeric_values(values)
summary[key] = {
"count": stats.count,
"avg": stats.mean,
"range": [stats.min, stats.max],
"trend": stats.trend,
}
# Summarize text data (just counts)
for key, values in text_data.items():
summary[f"{key}_count"] = len(values)
summary["total_entries"] = len(all_contexts)
return summary
def rolling_window_summary(
self,
layer: ContextLayer,
window_days: int = 30,
summarize_older: bool = True,
) -> dict[str, Any]:
"""Create a rolling window summary.
Recent data (within window) is kept detailed.
Older data is summarized to key metrics.
Args:
layer: Context layer to summarize
window_days: Number of days to keep detailed
summarize_older: Whether to summarize data older than window
Returns:
Dictionary with recent (detailed) and historical (summary) data
"""
result: dict[str, Any] = {
"window_days": window_days,
"recent_data": {},
"historical_summary": {},
}
# Get all contexts
all_contexts = self.store.get_all_contexts(layer)
recent_values: dict[str, list[float]] = {}
historical_values: dict[str, list[float]] = {}
for key, value in all_contexts.items():
# For simplicity, treat all numeric values
if isinstance(value, (int, float)):
# Note: We don't have timestamps in context keys
# This is a simplified implementation
# In practice, would need to check timeframe field
# For now, put recent data in window
if key not in recent_values:
recent_values[key] = []
recent_values[key].append(float(value))
# Detailed recent data
result["recent_data"] = {key: values[-10:] for key, values in recent_values.items()}
# Summarized historical data
if summarize_older:
for key, values in historical_values.items():
stats = self.summarize_numeric_values(values)
result["historical_summary"][key] = {
"count": stats.count,
"avg": stats.mean,
"trend": stats.trend,
}
return result
def aggregate_to_higher_layer(
self,
source_layer: ContextLayer,
target_layer: ContextLayer,
metric_key: str,
aggregation_func: str = "mean",
) -> float | None:
"""Aggregate data from source layer to target layer.
Args:
source_layer: Source context layer (more granular)
target_layer: Target context layer (less granular)
metric_key: Key of metric to aggregate
aggregation_func: Aggregation function ("mean", "sum", "max", "min")
Returns:
Aggregated value, or None if no data available
"""
# Get all contexts from source layer
source_contexts = self.store.get_all_contexts(source_layer)
# Extract values for metric_key
values = []
for key, value in source_contexts.items():
if key == metric_key and isinstance(value, (int, float)):
values.append(float(value))
elif isinstance(value, dict) and metric_key in value:
subvalue = value[metric_key]
if isinstance(subvalue, (int, float)):
values.append(float(subvalue))
if not values:
return None
# Apply aggregation function
if aggregation_func == "mean":
return sum(values) / len(values)
elif aggregation_func == "sum":
return sum(values)
elif aggregation_func == "max":
return max(values)
elif aggregation_func == "min":
return min(values)
else:
return sum(values) / len(values) # Default to mean
def create_compact_summary(
self,
layers: list[ContextLayer],
top_n_metrics: int = 5,
) -> dict[str, Any]:
"""Create a compact summary across multiple layers.
Args:
layers: List of context layers to summarize
top_n_metrics: Number of top metrics to include per layer
Returns:
Compact summary dictionary
"""
summary: dict[str, Any] = {}
for layer in layers:
layer_summary = self.summarize_layer(layer)
# Keep only top N metrics (by count/relevance)
metrics = []
for key, value in layer_summary.items():
if isinstance(value, dict) and "count" in value:
metrics.append((key, value, value["count"]))
# Sort by count (descending)
metrics.sort(key=lambda x: x[2], reverse=True)
# Keep top N
top_metrics = {m[0]: m[1] for m in metrics[:top_n_metrics]}
summary[layer.value] = top_metrics
return summary
def format_summary_for_prompt(self, summary: dict[str, Any]) -> str:
"""Format summary for inclusion in a prompt.
Args:
summary: Summary dictionary
Returns:
Formatted string for prompt
"""
lines = []
for layer, metrics in summary.items():
if not metrics:
continue
lines.append(f"{layer}:")
for key, value in metrics.items():
if isinstance(value, dict):
# Format as: key: avg=X, trend=Y
parts = []
if "avg" in value and value["avg"] is not None:
parts.append(f"avg={value['avg']:.2f}")
if "trend" in value and value["trend"]:
parts.append(f"trend={value['trend']}")
if parts:
lines.append(f" {key}: {', '.join(parts)}")
else:
lines.append(f" {key}: {value}")
return "\n".join(lines)

205
src/data/README.md Normal file
View File

@@ -0,0 +1,205 @@
# External Data Integration
This module provides objective external data sources to enhance trading decisions beyond just market prices and user input.
## Modules
### `news_api.py` - News Sentiment Analysis
Fetches real-time news for stocks with sentiment scoring.
**Features:**
- Alpha Vantage and NewsAPI.org support
- Sentiment scoring (-1.0 to +1.0)
- 5-minute caching to minimize API quota usage
- Graceful fallback when API unavailable
**Usage:**
```python
from src.data.news_api import NewsAPI
# Initialize with API key
news_api = NewsAPI(api_key="your_key", provider="alphavantage")
# Fetch news sentiment
sentiment = await news_api.get_news_sentiment("AAPL")
if sentiment:
print(f"Average sentiment: {sentiment.avg_sentiment}")
for article in sentiment.articles[:3]:
print(f"{article.title} ({article.sentiment_score})")
```
### `economic_calendar.py` - Major Economic Events
Tracks FOMC meetings, GDP releases, CPI, earnings calendars, and other market-moving events.
**Features:**
- High-impact event tracking (FOMC, GDP, CPI)
- Earnings calendar per stock
- Event proximity checking
- Hardcoded major events for 2026 (no API required)
**Usage:**
```python
from src.data.economic_calendar import EconomicCalendar
calendar = EconomicCalendar()
calendar.load_hardcoded_events()
# Get upcoming high-impact events
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="HIGH")
print(f"High-impact events: {upcoming.high_impact_count}")
# Check if near earnings
earnings_date = calendar.get_earnings_date("AAPL")
if earnings_date:
print(f"Next earnings: {earnings_date}")
# Check for high volatility period
if calendar.is_high_volatility_period(hours_ahead=24):
print("High-impact event imminent!")
```
### `market_data.py` - Market Indicators
Provides market breadth, sector performance, and sentiment indicators.
**Features:**
- Market sentiment levels (Fear & Greed equivalent)
- Market breadth (advancing/declining stocks)
- Sector performance tracking
- Fear/Greed score calculation
**Usage:**
```python
from src.data.market_data import MarketData
market_data = MarketData(api_key="your_key")
# Get market sentiment
sentiment = market_data.get_market_sentiment()
print(f"Market sentiment: {sentiment.name}")
# Get full indicators
indicators = market_data.get_market_indicators("US")
print(f"Sentiment: {indicators.sentiment.name}")
print(f"A/D Ratio: {indicators.breadth.advance_decline_ratio}")
```
## Integration with GeminiClient
The external data sources are seamlessly integrated into the AI decision engine:
```python
from src.brain.gemini_client import GeminiClient
from src.data.news_api import NewsAPI
from src.data.economic_calendar import EconomicCalendar
from src.data.market_data import MarketData
from src.config import Settings
settings = Settings()
# Initialize data sources
news_api = NewsAPI(api_key=settings.NEWS_API_KEY, provider=settings.NEWS_API_PROVIDER)
calendar = EconomicCalendar()
calendar.load_hardcoded_events()
market_data = MarketData(api_key=settings.MARKET_DATA_API_KEY)
# Create enhanced client
client = GeminiClient(
settings,
news_api=news_api,
economic_calendar=calendar,
market_data=market_data
)
# Make decision with external context
market_data_dict = {
"stock_code": "AAPL",
"current_price": 180.0,
"market_name": "US stock market"
}
decision = await client.decide(market_data_dict)
```
The external data is automatically included in the prompt sent to Gemini:
```
Market: US stock market
Stock Code: AAPL
Current Price: 180.0
EXTERNAL DATA:
News Sentiment: 0.85 (from 10 articles)
1. [Reuters] Apple hits record high (sentiment: 0.92)
2. [Bloomberg] Strong iPhone sales (sentiment: 0.78)
3. [CNBC] Tech sector rallying (sentiment: 0.85)
Upcoming High-Impact Events: 2 in next 7 days
Next: FOMC Meeting (FOMC) on 2026-03-18
Earnings: AAPL on 2026-02-10
Market Sentiment: GREED
Advance/Decline Ratio: 2.35
```
## Configuration
Add these to your `.env` file:
```bash
# External Data APIs (optional)
NEWS_API_KEY=your_alpha_vantage_key
NEWS_API_PROVIDER=alphavantage # or "newsapi"
MARKET_DATA_API_KEY=your_market_data_key
```
## API Recommendations
### Alpha Vantage (News)
- **Free tier:** 5 calls/min, 500 calls/day
- **Pros:** Provides sentiment scores, no credit card required
- **URL:** https://www.alphavantage.co/
### NewsAPI.org
- **Free tier:** 100 requests/day
- **Pros:** Large news coverage, easy to use
- **Cons:** No sentiment scores (we use keyword heuristics)
- **URL:** https://newsapi.org/
## Caching Strategy
To minimize API quota usage:
1. **News:** 5-minute TTL cache per stock
2. **Economic Calendar:** Loaded once at startup (hardcoded events)
3. **Market Data:** Fetched per decision (lightweight)
## Graceful Degradation
The system works gracefully without external data:
- If no API keys provided → decisions work with just market prices
- If API fails → decision continues without external context
- If cache expired → attempts refetch, falls back to no data
- Errors are logged but never block trading decisions
## Testing
All modules have comprehensive test coverage (81%+):
```bash
pytest tests/test_data_integration.py -v --cov=src/data
```
Tests use mocks to avoid requiring real API keys.
## Future Enhancements
- Twitter/X sentiment analysis
- Reddit WallStreetBets sentiment
- Options flow data
- Insider trading activity
- Analyst upgrades/downgrades
- Real-time economic data APIs

5
src/data/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
"""External data integration for objective decision-making."""
from __future__ import annotations
__all__ = ["NewsAPI", "EconomicCalendar", "MarketData"]

View File

@@ -0,0 +1,219 @@
"""Economic calendar integration for major market events.
Tracks FOMC meetings, GDP releases, CPI, earnings calendars, and other
market-moving events.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any
logger = logging.getLogger(__name__)
@dataclass
class EconomicEvent:
"""Single economic event."""
name: str
event_type: str # "FOMC", "GDP", "CPI", "EARNINGS", etc.
datetime: datetime
impact: str # "HIGH", "MEDIUM", "LOW"
country: str
description: str
@dataclass
class UpcomingEvents:
"""Collection of upcoming economic events."""
events: list[EconomicEvent]
high_impact_count: int
next_major_event: EconomicEvent | None
class EconomicCalendar:
"""Economic calendar with event tracking and impact scoring."""
def __init__(self, api_key: str | None = None) -> None:
"""Initialize economic calendar.
Args:
api_key: API key for calendar provider (None for testing/hardcoded)
"""
self._api_key = api_key
# For now, use hardcoded major events (can be extended with API)
self._events: list[EconomicEvent] = []
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def get_upcoming_events(
self, days_ahead: int = 7, min_impact: str = "MEDIUM"
) -> UpcomingEvents:
"""Get upcoming economic events within specified timeframe.
Args:
days_ahead: Number of days to look ahead
min_impact: Minimum impact level ("LOW", "MEDIUM", "HIGH")
Returns:
UpcomingEvents with filtered events
"""
now = datetime.now()
end_date = now + timedelta(days=days_ahead)
# Filter events by timeframe and impact
upcoming = [
event
for event in self._events
if now <= event.datetime <= end_date
and self._impact_level(event.impact) >= self._impact_level(min_impact)
]
# Sort by datetime
upcoming.sort(key=lambda e: e.datetime)
# Count high-impact events
high_impact_count = sum(1 for e in upcoming if e.impact == "HIGH")
# Get next major event
next_major = None
for event in upcoming:
if event.impact == "HIGH":
next_major = event
break
return UpcomingEvents(
events=upcoming,
high_impact_count=high_impact_count,
next_major_event=next_major,
)
def add_event(self, event: EconomicEvent) -> None:
"""Add an economic event to the calendar."""
self._events.append(event)
def clear_events(self) -> None:
"""Clear all events (useful for testing)."""
self._events.clear()
def get_earnings_date(self, stock_code: str) -> datetime | None:
"""Get next earnings date for a stock.
Args:
stock_code: Stock ticker symbol
Returns:
Next earnings datetime or None if not found
"""
now = datetime.now()
earnings_events = [
event
for event in self._events
if event.event_type == "EARNINGS"
and stock_code.upper() in event.name.upper()
and event.datetime > now
]
if not earnings_events:
return None
# Return earliest upcoming earnings
earnings_events.sort(key=lambda e: e.datetime)
return earnings_events[0].datetime
def load_hardcoded_events(self) -> None:
"""Load hardcoded major economic events for 2026.
This is a fallback when no API is available.
"""
# Major FOMC meetings in 2026 (estimated)
fomc_dates = [
datetime(2026, 3, 18),
datetime(2026, 5, 6),
datetime(2026, 6, 17),
datetime(2026, 7, 29),
datetime(2026, 9, 16),
datetime(2026, 11, 4),
datetime(2026, 12, 16),
]
for date in fomc_dates:
self.add_event(
EconomicEvent(
name="FOMC Meeting",
event_type="FOMC",
datetime=date,
impact="HIGH",
country="US",
description="Federal Reserve interest rate decision",
)
)
# Quarterly GDP releases (estimated)
gdp_dates = [
datetime(2026, 4, 28),
datetime(2026, 7, 30),
datetime(2026, 10, 29),
]
for date in gdp_dates:
self.add_event(
EconomicEvent(
name="US GDP Release",
event_type="GDP",
datetime=date,
impact="HIGH",
country="US",
description="Quarterly GDP growth rate",
)
)
# Monthly CPI releases (12th of each month, estimated)
for month in range(1, 13):
try:
cpi_date = datetime(2026, month, 12)
self.add_event(
EconomicEvent(
name="US CPI Release",
event_type="CPI",
datetime=cpi_date,
impact="HIGH",
country="US",
description="Consumer Price Index inflation data",
)
)
except ValueError:
continue
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _impact_level(self, impact: str) -> int:
"""Convert impact string to numeric level."""
levels = {"LOW": 1, "MEDIUM": 2, "HIGH": 3}
return levels.get(impact.upper(), 0)
def is_high_volatility_period(self, hours_ahead: int = 24) -> bool:
"""Check if we're near a high-impact event.
Args:
hours_ahead: Number of hours to look ahead
Returns:
True if high-impact event is imminent
"""
now = datetime.now()
threshold = now + timedelta(hours=hours_ahead)
for event in self._events:
if event.impact == "HIGH" and now <= event.datetime <= threshold:
return True
return False

198
src/data/market_data.py Normal file
View File

@@ -0,0 +1,198 @@
"""Additional market data indicators beyond basic price data.
Provides market breadth, sector performance, and market sentiment indicators.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from enum import Enum
logger = logging.getLogger(__name__)
class MarketSentiment(Enum):
"""Overall market sentiment levels."""
EXTREME_FEAR = 1
FEAR = 2
NEUTRAL = 3
GREED = 4
EXTREME_GREED = 5
@dataclass
class SectorPerformance:
"""Performance metrics for a market sector."""
sector_name: str
daily_change_pct: float
weekly_change_pct: float
leader_stock: str # Best performing stock in sector
laggard_stock: str # Worst performing stock in sector
@dataclass
class MarketBreadth:
"""Market breadth indicators."""
advancing_stocks: int
declining_stocks: int
unchanged_stocks: int
new_highs: int
new_lows: int
advance_decline_ratio: float
@dataclass
class MarketIndicators:
"""Aggregated market indicators."""
sentiment: MarketSentiment
breadth: MarketBreadth
sector_performance: list[SectorPerformance]
vix_level: float | None # Volatility index if available
class MarketData:
"""Market data provider for additional indicators."""
def __init__(self, api_key: str | None = None) -> None:
"""Initialize market data provider.
Args:
api_key: API key for data provider (None for testing)
"""
self._api_key = api_key
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def get_market_sentiment(self) -> MarketSentiment:
"""Get current market sentiment level.
This is a simplified version. In production, this would integrate
with Fear & Greed Index or similar sentiment indicators.
Returns:
MarketSentiment enum value
"""
# Default to neutral when API not available
if self._api_key is None:
logger.debug("No market data API key — returning NEUTRAL sentiment")
return MarketSentiment.NEUTRAL
# TODO: Integrate with actual sentiment API
return MarketSentiment.NEUTRAL
def get_market_breadth(self, market: str = "US") -> MarketBreadth | None:
"""Get market breadth indicators.
Args:
market: Market code ("US", "KR", etc.)
Returns:
MarketBreadth object or None if unavailable
"""
if self._api_key is None:
logger.debug("No market data API key — returning None for breadth")
return None
# TODO: Integrate with actual market breadth API
return None
def get_sector_performance(
self, market: str = "US"
) -> list[SectorPerformance]:
"""Get sector performance rankings.
Args:
market: Market code ("US", "KR", etc.)
Returns:
List of SectorPerformance objects, sorted by daily change
"""
if self._api_key is None:
logger.debug("No market data API key — returning empty sector list")
return []
# TODO: Integrate with actual sector performance API
return []
def get_market_indicators(self, market: str = "US") -> MarketIndicators:
"""Get aggregated market indicators.
Args:
market: Market code ("US", "KR", etc.)
Returns:
MarketIndicators with all available data
"""
sentiment = self.get_market_sentiment()
breadth = self.get_market_breadth(market)
sectors = self.get_sector_performance(market)
# Default breadth if unavailable
if breadth is None:
breadth = MarketBreadth(
advancing_stocks=0,
declining_stocks=0,
unchanged_stocks=0,
new_highs=0,
new_lows=0,
advance_decline_ratio=1.0,
)
return MarketIndicators(
sentiment=sentiment,
breadth=breadth,
sector_performance=sectors,
vix_level=None, # TODO: Add VIX integration
)
# ------------------------------------------------------------------
# Helper Methods
# ------------------------------------------------------------------
def calculate_fear_greed_score(
self, breadth: MarketBreadth, vix: float | None = None
) -> int:
"""Calculate a simple fear/greed score (0-100).
Args:
breadth: Market breadth data
vix: VIX level (optional)
Returns:
Score from 0 (extreme fear) to 100 (extreme greed)
"""
# Start at neutral
score = 50
# Adjust based on advance/decline ratio
if breadth.advance_decline_ratio > 1.5:
score += 20
elif breadth.advance_decline_ratio > 1.0:
score += 10
elif breadth.advance_decline_ratio < 0.5:
score -= 20
elif breadth.advance_decline_ratio < 1.0:
score -= 10
# Adjust based on new highs/lows
if breadth.new_highs > breadth.new_lows * 2:
score += 15
elif breadth.new_lows > breadth.new_highs * 2:
score -= 15
# Adjust based on VIX if available
if vix is not None:
if vix > 30: # High volatility = fear
score -= 15
elif vix < 15: # Low volatility = complacency/greed
score += 10
# Clamp to 0-100
return max(0, min(100, score))

316
src/data/news_api.py Normal file
View File

@@ -0,0 +1,316 @@
"""News API integration with sentiment analysis and caching.
Fetches real-time news for stocks using free-tier APIs (Alpha Vantage or NewsAPI).
Includes 5-minute caching to minimize API quota usage.
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass
from typing import Any
import aiohttp
logger = logging.getLogger(__name__)
# Cache entries expire after 5 minutes
CACHE_TTL_SECONDS = 300
@dataclass
class NewsArticle:
"""Single news article with sentiment."""
title: str
summary: str
source: str
published_at: str
sentiment_score: float # -1.0 (negative) to +1.0 (positive)
url: str
@dataclass
class NewsSentiment:
"""Aggregated news sentiment for a stock."""
stock_code: str
articles: list[NewsArticle]
avg_sentiment: float # Average sentiment across all articles
article_count: int
fetched_at: float # Unix timestamp
class NewsAPI:
"""News API client with sentiment analysis and caching."""
def __init__(
self,
api_key: str | None = None,
provider: str = "alphavantage",
cache_ttl: int = CACHE_TTL_SECONDS,
) -> None:
"""Initialize NewsAPI client.
Args:
api_key: API key for the news provider (None for testing)
provider: News provider ("alphavantage" or "newsapi")
cache_ttl: Cache time-to-live in seconds
"""
self._api_key = api_key
self._provider = provider
self._cache_ttl = cache_ttl
self._cache: dict[str, NewsSentiment] = {}
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
async def get_news_sentiment(self, stock_code: str) -> NewsSentiment | None:
"""Fetch news sentiment for a stock with caching.
Args:
stock_code: Stock ticker symbol (e.g., "AAPL", "005930")
Returns:
NewsSentiment object or None if fetch fails or API unavailable
"""
# Check cache first
cached = self._get_from_cache(stock_code)
if cached is not None:
logger.debug("News cache hit for %s", stock_code)
return cached
# API key required for real requests
if self._api_key is None:
logger.warning("No news API key provided — returning None")
return None
# Fetch from API
try:
sentiment = await self._fetch_news(stock_code)
if sentiment is not None:
self._cache[stock_code] = sentiment
return sentiment
except Exception as exc:
logger.error("Failed to fetch news for %s: %s", stock_code, exc)
return None
def clear_cache(self) -> None:
"""Clear the news cache (useful for testing)."""
self._cache.clear()
# ------------------------------------------------------------------
# Cache Management
# ------------------------------------------------------------------
def _get_from_cache(self, stock_code: str) -> NewsSentiment | None:
"""Retrieve cached sentiment if not expired."""
if stock_code not in self._cache:
return None
cached = self._cache[stock_code]
age = time.time() - cached.fetched_at
if age > self._cache_ttl:
logger.debug("News cache expired for %s (age: %.1fs)", stock_code, age)
del self._cache[stock_code]
return None
return cached
# ------------------------------------------------------------------
# API Fetching
# ------------------------------------------------------------------
async def _fetch_news(self, stock_code: str) -> NewsSentiment | None:
"""Fetch news from the provider API."""
if self._provider == "alphavantage":
return await self._fetch_alphavantage(stock_code)
elif self._provider == "newsapi":
return await self._fetch_newsapi(stock_code)
else:
logger.error("Unknown news provider: %s", self._provider)
return None
async def _fetch_alphavantage(self, stock_code: str) -> NewsSentiment | None:
"""Fetch news from Alpha Vantage News Sentiment API."""
url = "https://www.alphavantage.co/query"
params = {
"function": "NEWS_SENTIMENT",
"tickers": stock_code,
"apikey": self._api_key,
"limit": 10, # Fetch top 10 articles
}
try:
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params, timeout=10) as resp:
if resp.status != 200:
logger.error(
"Alpha Vantage API error: HTTP %d", resp.status
)
return None
data = await resp.json()
return self._parse_alphavantage_response(stock_code, data)
except Exception as exc:
logger.error("Alpha Vantage request failed: %s", exc)
return None
async def _fetch_newsapi(self, stock_code: str) -> NewsSentiment | None:
"""Fetch news from NewsAPI.org."""
url = "https://newsapi.org/v2/everything"
params = {
"q": stock_code,
"apiKey": self._api_key,
"pageSize": 10,
"sortBy": "publishedAt",
"language": "en",
}
try:
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params, timeout=10) as resp:
if resp.status != 200:
logger.error("NewsAPI error: HTTP %d", resp.status)
return None
data = await resp.json()
return self._parse_newsapi_response(stock_code, data)
except Exception as exc:
logger.error("NewsAPI request failed: %s", exc)
return None
# ------------------------------------------------------------------
# Response Parsing
# ------------------------------------------------------------------
def _parse_alphavantage_response(
self, stock_code: str, data: dict[str, Any]
) -> NewsSentiment | None:
"""Parse Alpha Vantage API response."""
if "feed" not in data:
logger.warning("No 'feed' key in Alpha Vantage response")
return None
articles: list[NewsArticle] = []
for item in data["feed"]:
# Extract sentiment for this specific ticker
ticker_sentiment = self._extract_ticker_sentiment(item, stock_code)
article = NewsArticle(
title=item.get("title", ""),
summary=item.get("summary", "")[:200], # Truncate long summaries
source=item.get("source", "Unknown"),
published_at=item.get("time_published", ""),
sentiment_score=ticker_sentiment,
url=item.get("url", ""),
)
articles.append(article)
if not articles:
return None
avg_sentiment = sum(a.sentiment_score for a in articles) / len(articles)
return NewsSentiment(
stock_code=stock_code,
articles=articles,
avg_sentiment=avg_sentiment,
article_count=len(articles),
fetched_at=time.time(),
)
def _extract_ticker_sentiment(
self, item: dict[str, Any], stock_code: str
) -> float:
"""Extract sentiment score for specific ticker from article."""
ticker_sentiments = item.get("ticker_sentiment", [])
for ts in ticker_sentiments:
if ts.get("ticker", "").upper() == stock_code.upper():
# Alpha Vantage provides sentiment_score as string
score_str = ts.get("ticker_sentiment_score", "0")
try:
return float(score_str)
except ValueError:
return 0.0
# Fallback to overall sentiment if ticker-specific not found
overall_sentiment = item.get("overall_sentiment_score", "0")
try:
return float(overall_sentiment)
except ValueError:
return 0.0
def _parse_newsapi_response(
self, stock_code: str, data: dict[str, Any]
) -> NewsSentiment | None:
"""Parse NewsAPI.org response.
Note: NewsAPI doesn't provide sentiment scores, so we use a
simple heuristic based on title keywords.
"""
if data.get("status") != "ok" or "articles" not in data:
logger.warning("Invalid NewsAPI response")
return None
articles: list[NewsArticle] = []
for item in data["articles"]:
# Simple sentiment heuristic based on keywords
sentiment = self._estimate_sentiment_from_text(
item.get("title", "") + " " + item.get("description", "")
)
article = NewsArticle(
title=item.get("title", ""),
summary=item.get("description", "")[:200],
source=item.get("source", {}).get("name", "Unknown"),
published_at=item.get("publishedAt", ""),
sentiment_score=sentiment,
url=item.get("url", ""),
)
articles.append(article)
if not articles:
return None
avg_sentiment = sum(a.sentiment_score for a in articles) / len(articles)
return NewsSentiment(
stock_code=stock_code,
articles=articles,
avg_sentiment=avg_sentiment,
article_count=len(articles),
fetched_at=time.time(),
)
def _estimate_sentiment_from_text(self, text: str) -> float:
"""Simple keyword-based sentiment estimation.
This is a fallback for APIs that don't provide sentiment scores.
Returns a score between -1.0 and +1.0.
"""
text_lower = text.lower()
positive_keywords = [
"surge", "jump", "gain", "rise", "soar", "rally", "profit",
"growth", "upgrade", "beat", "strong", "bullish", "breakthrough",
]
negative_keywords = [
"plunge", "fall", "drop", "decline", "crash", "loss", "weak",
"downgrade", "miss", "bearish", "concern", "risk", "warning",
]
positive_count = sum(1 for kw in positive_keywords if kw in text_lower)
negative_count = sum(1 for kw in negative_keywords if kw in text_lower)
total = positive_count + negative_count
if total == 0:
return 0.0
# Normalize to -1.0 to +1.0 range
return (positive_count - negative_count) / total

View File

@@ -23,7 +23,7 @@ from google import genai
from src.config import Settings
from src.db import init_db
from src.logging.decision_logger import DecisionLog, DecisionLogger
from src.logging.decision_logger import DecisionLogger
logger = logging.getLogger(__name__)

View File

@@ -21,7 +21,7 @@ from src.broker.overseas import OverseasBroker
from src.config import Settings
from src.context.layer import ContextLayer
from src.context.store import ContextStore
from src.core.criticality import CriticalityAssessor, CriticalityLevel
from src.core.criticality import CriticalityAssessor
from src.core.priority_queue import PriorityTaskQueue
from src.core.risk_manager import CircuitBreakerTripped, RiskManager
from src.db import init_db, log_trade

View File

@@ -126,7 +126,7 @@ class TestPromptConstruction:
"orderbook": {"asks": [], "bids": []},
"foreigner_net": -50000,
}
prompt = client.build_prompt(market_data)
prompt = client.build_prompt_sync(market_data)
assert "005930" in prompt
def test_prompt_contains_price(self, settings):
@@ -137,7 +137,7 @@ class TestPromptConstruction:
"orderbook": {"asks": [], "bids": []},
"foreigner_net": -50000,
}
prompt = client.build_prompt(market_data)
prompt = client.build_prompt_sync(market_data)
assert "72000" in prompt
def test_prompt_enforces_json_output_format(self, settings):
@@ -148,7 +148,7 @@ class TestPromptConstruction:
"orderbook": {"asks": [], "bids": []},
"foreigner_net": 0,
}
prompt = client.build_prompt(market_data)
prompt = client.build_prompt_sync(market_data)
assert "JSON" in prompt
assert "action" in prompt
assert "confidence" in prompt

View File

@@ -0,0 +1,673 @@
"""Tests for external data integration (news, economic calendar, market data)."""
from __future__ import annotations
import time
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from src.brain.gemini_client import GeminiClient
from src.data.economic_calendar import EconomicCalendar, EconomicEvent
from src.data.market_data import MarketBreadth, MarketData, MarketSentiment
from src.data.news_api import NewsAPI, NewsArticle, NewsSentiment
# ---------------------------------------------------------------------------
# NewsAPI Tests
# ---------------------------------------------------------------------------
class TestNewsAPI:
"""Test news API integration with caching."""
def test_news_api_init_without_key(self):
"""NewsAPI should initialize without API key for testing."""
api = NewsAPI(api_key=None)
assert api._api_key is None
assert api._provider == "alphavantage"
assert api._cache_ttl == 300
def test_news_api_init_with_custom_settings(self):
"""NewsAPI should accept custom provider and cache TTL."""
api = NewsAPI(api_key="test_key", provider="newsapi", cache_ttl=600)
assert api._api_key == "test_key"
assert api._provider == "newsapi"
assert api._cache_ttl == 600
@pytest.mark.asyncio
async def test_get_news_sentiment_without_api_key_returns_none(self):
"""Without API key, get_news_sentiment should return None."""
api = NewsAPI(api_key=None)
result = await api.get_news_sentiment("AAPL")
assert result is None
@pytest.mark.asyncio
async def test_cache_hit_returns_cached_sentiment(self):
"""Cache hit should return cached sentiment without API call."""
api = NewsAPI(api_key="test_key")
# Manually populate cache
cached_sentiment = NewsSentiment(
stock_code="AAPL",
articles=[],
avg_sentiment=0.5,
article_count=0,
fetched_at=time.time(),
)
api._cache["AAPL"] = cached_sentiment
result = await api.get_news_sentiment("AAPL")
assert result is cached_sentiment
assert result.stock_code == "AAPL"
@pytest.mark.asyncio
async def test_cache_expiry_triggers_refetch(self):
"""Expired cache entry should trigger refetch."""
api = NewsAPI(api_key="test_key", cache_ttl=1)
# Add expired cache entry
expired_sentiment = NewsSentiment(
stock_code="AAPL",
articles=[],
avg_sentiment=0.5,
article_count=0,
fetched_at=time.time() - 10, # 10 seconds ago
)
api._cache["AAPL"] = expired_sentiment
# Mock the fetch to avoid real API call
with patch.object(api, "_fetch_news", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = None
result = await api.get_news_sentiment("AAPL")
# Should have attempted refetch since cache expired
mock_fetch.assert_called_once_with("AAPL")
def test_clear_cache(self):
"""clear_cache should empty the cache."""
api = NewsAPI(api_key="test_key")
api._cache["AAPL"] = NewsSentiment(
stock_code="AAPL",
articles=[],
avg_sentiment=0.0,
article_count=0,
fetched_at=time.time(),
)
assert len(api._cache) == 1
api.clear_cache()
assert len(api._cache) == 0
def test_parse_alphavantage_response_with_valid_data(self):
"""Should parse Alpha Vantage response correctly."""
api = NewsAPI(api_key="test_key", provider="alphavantage")
mock_response = {
"feed": [
{
"title": "Apple hits new high",
"summary": "Apple stock surges to record levels",
"source": "Reuters",
"time_published": "2026-02-04T10:00:00",
"url": "https://example.com/1",
"ticker_sentiment": [
{"ticker": "AAPL", "ticker_sentiment_score": "0.85"}
],
"overall_sentiment_score": "0.75",
},
{
"title": "Market volatility rises",
"summary": "Tech stocks face headwinds",
"source": "Bloomberg",
"time_published": "2026-02-04T09:00:00",
"url": "https://example.com/2",
"ticker_sentiment": [
{"ticker": "AAPL", "ticker_sentiment_score": "-0.3"}
],
"overall_sentiment_score": "-0.2",
},
]
}
result = api._parse_alphavantage_response("AAPL", mock_response)
assert result is not None
assert result.stock_code == "AAPL"
assert result.article_count == 2
assert len(result.articles) == 2
assert result.articles[0].title == "Apple hits new high"
assert result.articles[0].sentiment_score == 0.85
assert result.articles[1].sentiment_score == -0.3
# Average: (0.85 - 0.3) / 2 = 0.275
assert abs(result.avg_sentiment - 0.275) < 0.01
def test_parse_alphavantage_response_without_feed_returns_none(self):
"""Should return None if 'feed' key is missing."""
api = NewsAPI(api_key="test_key", provider="alphavantage")
result = api._parse_alphavantage_response("AAPL", {})
assert result is None
def test_parse_newsapi_response_with_valid_data(self):
"""Should parse NewsAPI.org response correctly."""
api = NewsAPI(api_key="test_key", provider="newsapi")
mock_response = {
"status": "ok",
"articles": [
{
"title": "Apple stock surges",
"description": "Strong earnings beat expectations",
"source": {"name": "TechCrunch"},
"publishedAt": "2026-02-04T10:00:00Z",
"url": "https://example.com/1",
},
{
"title": "Tech sector faces risks",
"description": "Concerns over market downturn",
"source": {"name": "CNBC"},
"publishedAt": "2026-02-04T09:00:00Z",
"url": "https://example.com/2",
},
],
}
result = api._parse_newsapi_response("AAPL", mock_response)
assert result is not None
assert result.stock_code == "AAPL"
assert result.article_count == 2
assert len(result.articles) == 2
assert result.articles[0].title == "Apple stock surges"
assert result.articles[0].source == "TechCrunch"
def test_estimate_sentiment_from_text_positive(self):
"""Should detect positive sentiment from keywords."""
api = NewsAPI()
text = "Stock price surges with strong profit growth and upgrade"
sentiment = api._estimate_sentiment_from_text(text)
assert sentiment > 0.5
def test_estimate_sentiment_from_text_negative(self):
"""Should detect negative sentiment from keywords."""
api = NewsAPI()
text = "Stock plunges on weak earnings, downgrade warning"
sentiment = api._estimate_sentiment_from_text(text)
assert sentiment < -0.5
def test_estimate_sentiment_from_text_neutral(self):
"""Should return neutral sentiment without keywords."""
api = NewsAPI()
text = "Company announces quarterly report"
sentiment = api._estimate_sentiment_from_text(text)
assert abs(sentiment) < 0.1
# ---------------------------------------------------------------------------
# EconomicCalendar Tests
# ---------------------------------------------------------------------------
class TestEconomicCalendar:
"""Test economic calendar functionality."""
def test_economic_calendar_init(self):
"""EconomicCalendar should initialize correctly."""
calendar = EconomicCalendar(api_key="test_key")
assert calendar._api_key == "test_key"
assert len(calendar._events) == 0
def test_add_event(self):
"""Should be able to add events to calendar."""
calendar = EconomicCalendar()
event = EconomicEvent(
name="FOMC Meeting",
event_type="FOMC",
datetime=datetime(2026, 3, 18),
impact="HIGH",
country="US",
description="Interest rate decision",
)
calendar.add_event(event)
assert len(calendar._events) == 1
assert calendar._events[0].name == "FOMC Meeting"
def test_get_upcoming_events_filters_by_timeframe(self):
"""Should only return events within specified timeframe."""
calendar = EconomicCalendar()
# Add events at different times
now = datetime.now()
calendar.add_event(
EconomicEvent(
name="Event Tomorrow",
event_type="GDP",
datetime=now + timedelta(days=1),
impact="HIGH",
country="US",
description="Test event",
)
)
calendar.add_event(
EconomicEvent(
name="Event Next Month",
event_type="CPI",
datetime=now + timedelta(days=30),
impact="HIGH",
country="US",
description="Test event",
)
)
# Get events for next 7 days
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="HIGH")
assert upcoming.high_impact_count == 1
assert upcoming.events[0].name == "Event Tomorrow"
def test_get_upcoming_events_filters_by_impact(self):
"""Should filter events by minimum impact level."""
calendar = EconomicCalendar()
now = datetime.now()
calendar.add_event(
EconomicEvent(
name="High Impact Event",
event_type="FOMC",
datetime=now + timedelta(days=1),
impact="HIGH",
country="US",
description="Test",
)
)
calendar.add_event(
EconomicEvent(
name="Low Impact Event",
event_type="OTHER",
datetime=now + timedelta(days=1),
impact="LOW",
country="US",
description="Test",
)
)
# Filter for HIGH impact only
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="HIGH")
assert upcoming.high_impact_count == 1
assert upcoming.events[0].name == "High Impact Event"
# Filter for MEDIUM and above (should still get HIGH)
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="MEDIUM")
assert len(upcoming.events) == 1
# Filter for LOW and above (should get both)
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="LOW")
assert len(upcoming.events) == 2
def test_get_earnings_date_returns_next_earnings(self):
"""Should return next earnings date for a stock."""
calendar = EconomicCalendar()
now = datetime.now()
earnings_date = now + timedelta(days=5)
calendar.add_event(
EconomicEvent(
name="AAPL Earnings",
event_type="EARNINGS",
datetime=earnings_date,
impact="HIGH",
country="US",
description="Apple quarterly earnings",
)
)
result = calendar.get_earnings_date("AAPL")
assert result == earnings_date
def test_get_earnings_date_returns_none_if_not_found(self):
"""Should return None if no earnings found for stock."""
calendar = EconomicCalendar()
result = calendar.get_earnings_date("UNKNOWN")
assert result is None
def test_load_hardcoded_events(self):
"""Should load hardcoded major economic events."""
calendar = EconomicCalendar()
calendar.load_hardcoded_events()
# Should have multiple events (FOMC, GDP, CPI)
assert len(calendar._events) > 10
# Check for FOMC events
fomc_events = [e for e in calendar._events if e.event_type == "FOMC"]
assert len(fomc_events) > 0
# Check for GDP events
gdp_events = [e for e in calendar._events if e.event_type == "GDP"]
assert len(gdp_events) > 0
# Check for CPI events
cpi_events = [e for e in calendar._events if e.event_type == "CPI"]
assert len(cpi_events) == 12 # Monthly CPI releases
def test_is_high_volatility_period_returns_true_near_high_impact(self):
"""Should return True if high-impact event is within threshold."""
calendar = EconomicCalendar()
now = datetime.now()
calendar.add_event(
EconomicEvent(
name="FOMC Meeting",
event_type="FOMC",
datetime=now + timedelta(hours=12),
impact="HIGH",
country="US",
description="Test",
)
)
assert calendar.is_high_volatility_period(hours_ahead=24) is True
def test_is_high_volatility_period_returns_false_when_no_events(self):
"""Should return False if no high-impact events nearby."""
calendar = EconomicCalendar()
assert calendar.is_high_volatility_period(hours_ahead=24) is False
def test_clear_events(self):
"""Should clear all events."""
calendar = EconomicCalendar()
calendar.add_event(
EconomicEvent(
name="Test",
event_type="TEST",
datetime=datetime.now(),
impact="LOW",
country="US",
description="Test",
)
)
assert len(calendar._events) == 1
calendar.clear_events()
assert len(calendar._events) == 0
# ---------------------------------------------------------------------------
# MarketData Tests
# ---------------------------------------------------------------------------
class TestMarketData:
"""Test market data indicators."""
def test_market_data_init(self):
"""MarketData should initialize correctly."""
data = MarketData(api_key="test_key")
assert data._api_key == "test_key"
def test_get_market_sentiment_without_api_key_returns_neutral(self):
"""Without API key, should return NEUTRAL sentiment."""
data = MarketData(api_key=None)
sentiment = data.get_market_sentiment()
assert sentiment == MarketSentiment.NEUTRAL
def test_get_market_breadth_without_api_key_returns_none(self):
"""Without API key, should return None for breadth."""
data = MarketData(api_key=None)
breadth = data.get_market_breadth()
assert breadth is None
def test_get_sector_performance_without_api_key_returns_empty(self):
"""Without API key, should return empty list."""
data = MarketData(api_key=None)
sectors = data.get_sector_performance()
assert sectors == []
def test_get_market_indicators_returns_defaults_without_api(self):
"""Should return default indicators without API key."""
data = MarketData(api_key=None)
indicators = data.get_market_indicators()
assert indicators.sentiment == MarketSentiment.NEUTRAL
assert indicators.breadth.advance_decline_ratio == 1.0
assert indicators.sector_performance == []
assert indicators.vix_level is None
def test_calculate_fear_greed_score_neutral_baseline(self):
"""Should return neutral score (50) for balanced market."""
data = MarketData()
breadth = MarketBreadth(
advancing_stocks=500,
declining_stocks=500,
unchanged_stocks=100,
new_highs=50,
new_lows=50,
advance_decline_ratio=1.0,
)
score = data.calculate_fear_greed_score(breadth)
assert score == 50
def test_calculate_fear_greed_score_greedy_market(self):
"""Should return high score for greedy market conditions."""
data = MarketData()
breadth = MarketBreadth(
advancing_stocks=800,
declining_stocks=200,
unchanged_stocks=100,
new_highs=100,
new_lows=10,
advance_decline_ratio=4.0,
)
score = data.calculate_fear_greed_score(breadth, vix=12.0)
assert score > 70
def test_calculate_fear_greed_score_fearful_market(self):
"""Should return low score for fearful market conditions."""
data = MarketData()
breadth = MarketBreadth(
advancing_stocks=200,
declining_stocks=800,
unchanged_stocks=100,
new_highs=10,
new_lows=100,
advance_decline_ratio=0.25,
)
score = data.calculate_fear_greed_score(breadth, vix=35.0)
assert score < 30
# ---------------------------------------------------------------------------
# GeminiClient Integration Tests
# ---------------------------------------------------------------------------
class TestGeminiClientWithExternalData:
"""Test GeminiClient integration with external data sources."""
def test_gemini_client_accepts_optional_data_sources(self, settings):
"""GeminiClient should accept optional external data sources."""
news_api = NewsAPI(api_key="test_key")
calendar = EconomicCalendar()
market_data = MarketData()
client = GeminiClient(
settings,
news_api=news_api,
economic_calendar=calendar,
market_data=market_data,
)
assert client._news_api is news_api
assert client._economic_calendar is calendar
assert client._market_data is market_data
def test_gemini_client_works_without_external_data(self, settings):
"""GeminiClient should work without external data sources."""
client = GeminiClient(settings)
assert client._news_api is None
assert client._economic_calendar is None
assert client._market_data is None
@pytest.mark.asyncio
async def test_build_prompt_includes_news_sentiment(self, settings):
"""build_prompt should include news sentiment when available."""
client = GeminiClient(settings)
market_data = {
"stock_code": "AAPL",
"current_price": 180.0,
"market_name": "US stock market",
}
sentiment = NewsSentiment(
stock_code="AAPL",
articles=[
NewsArticle(
title="Apple hits record high",
summary="Strong earnings",
source="Reuters",
published_at="2026-02-04",
sentiment_score=0.85,
url="https://example.com",
)
],
avg_sentiment=0.85,
article_count=1,
fetched_at=time.time(),
)
prompt = await client.build_prompt(market_data, news_sentiment=sentiment)
assert "AAPL" in prompt
assert "180.0" in prompt
assert "EXTERNAL DATA" in prompt
assert "News Sentiment" in prompt
assert "0.85" in prompt
assert "Apple hits record high" in prompt
@pytest.mark.asyncio
async def test_build_prompt_with_economic_events(self, settings):
"""build_prompt should include upcoming economic events."""
calendar = EconomicCalendar()
now = datetime.now()
calendar.add_event(
EconomicEvent(
name="FOMC Meeting",
event_type="FOMC",
datetime=now + timedelta(days=2),
impact="HIGH",
country="US",
description="Interest rate decision",
)
)
client = GeminiClient(settings, economic_calendar=calendar)
market_data = {
"stock_code": "AAPL",
"current_price": 180.0,
"market_name": "US stock market",
}
prompt = await client.build_prompt(market_data)
assert "EXTERNAL DATA" in prompt
assert "High-Impact Events" in prompt
assert "FOMC Meeting" in prompt
@pytest.mark.asyncio
async def test_build_prompt_with_market_indicators(self, settings):
"""build_prompt should include market sentiment indicators."""
market_data_provider = MarketData(api_key="test_key")
# Mock the get_market_indicators to return test data
with patch.object(market_data_provider, "get_market_indicators") as mock:
mock.return_value = MagicMock(
sentiment=MarketSentiment.EXTREME_GREED,
breadth=MagicMock(advance_decline_ratio=2.5),
)
client = GeminiClient(settings, market_data=market_data_provider)
market_data = {
"stock_code": "AAPL",
"current_price": 180.0,
"market_name": "US stock market",
}
prompt = await client.build_prompt(market_data)
assert "EXTERNAL DATA" in prompt
assert "Market Sentiment" in prompt
assert "EXTREME_GREED" in prompt
@pytest.mark.asyncio
async def test_build_prompt_graceful_when_no_external_data(self, settings):
"""build_prompt should work gracefully without external data."""
client = GeminiClient(settings)
market_data = {
"stock_code": "AAPL",
"current_price": 180.0,
"market_name": "US stock market",
}
prompt = await client.build_prompt(market_data)
assert "AAPL" in prompt
assert "180.0" in prompt
# Should NOT have external data section
assert "EXTERNAL DATA" not in prompt
def test_build_prompt_sync_backward_compatibility(self, settings):
"""build_prompt_sync should maintain backward compatibility."""
client = GeminiClient(settings)
market_data = {
"stock_code": "005930",
"current_price": 72000,
"orderbook": {"asks": [], "bids": []},
"foreigner_net": -50000,
}
prompt = client.build_prompt_sync(market_data)
assert "005930" in prompt
assert "72000" in prompt
assert "JSON" in prompt
# Sync version should NOT have external data
assert "EXTERNAL DATA" not in prompt
@pytest.mark.asyncio
async def test_decide_with_news_sentiment_parameter(self, settings):
"""decide should accept optional news_sentiment parameter."""
client = GeminiClient(settings)
market_data = {
"stock_code": "AAPL",
"current_price": 180.0,
"market_name": "US stock market",
}
sentiment = NewsSentiment(
stock_code="AAPL",
articles=[],
avg_sentiment=0.5,
article_count=1,
fetched_at=time.time(),
)
# Mock the Gemini API call
with patch.object(client._client.aio.models, "generate_content", new_callable=AsyncMock) as mock_gen:
mock_response = MagicMock()
mock_response.text = '{"action": "BUY", "confidence": 85, "rationale": "Good news"}'
mock_gen.return_value = mock_response
decision = await client.decide(market_data, news_sentiment=sentiment)
assert decision.action == "BUY"
assert decision.confidence == 85
mock_gen.assert_called_once()

View File

@@ -11,15 +11,15 @@ from __future__ import annotations
import json
import sqlite3
import tempfile
from datetime import UTC, datetime, timedelta
from datetime import UTC, datetime
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from unittest.mock import AsyncMock, Mock, patch
import pytest
from src.config import Settings
from src.db import init_db, log_trade
from src.evolution.ab_test import ABTester, ABTestResult, StrategyPerformance
from src.evolution.ab_test import ABTester
from src.evolution.optimizer import EvolutionOptimizer
from src.evolution.performance_tracker import (
PerformanceDashboard,
@@ -28,7 +28,6 @@ from src.evolution.performance_tracker import (
)
from src.logging.decision_logger import DecisionLogger
# ------------------------------------------------------------------
# Fixtures
# ------------------------------------------------------------------

View File

@@ -0,0 +1,663 @@
"""Tests for token efficiency optimization components.
Tests cover:
- Prompt compression and optimization
- Context selection logic
- Summarization
- Caching
- Token reduction metrics
"""
from __future__ import annotations
import sqlite3
import time
import pytest
from src.brain.cache import DecisionCache
from src.brain.context_selector import ContextSelector, DecisionType
from src.brain.gemini_client import TradeDecision
from src.brain.prompt_optimizer import PromptOptimizer, TokenMetrics
from src.context.layer import ContextLayer
from src.context.store import ContextStore
from src.context.summarizer import ContextSummarizer, SummaryStats
# ============================================================================
# Prompt Optimizer Tests
# ============================================================================
class TestPromptOptimizer:
"""Tests for PromptOptimizer."""
def test_estimate_tokens(self):
"""Test token estimation."""
optimizer = PromptOptimizer()
# Empty text
assert optimizer.estimate_tokens("") == 0
# Short text (4 chars = 1 token estimate)
assert optimizer.estimate_tokens("test") == 1
# Longer text
text = "This is a longer piece of text for testing token estimation."
tokens = optimizer.estimate_tokens(text)
assert tokens > 0
assert tokens == len(text) // 4
def test_count_tokens(self):
"""Test token counting metrics."""
optimizer = PromptOptimizer()
text = "Hello world, this is a test."
metrics = optimizer.count_tokens(text)
assert isinstance(metrics, TokenMetrics)
assert metrics.char_count == len(text)
assert metrics.word_count == 6
assert metrics.estimated_tokens > 0
def test_compress_json(self):
"""Test JSON compression."""
optimizer = PromptOptimizer()
data = {
"action": "BUY",
"confidence": 85,
"rationale": "Strong uptrend",
}
compressed = optimizer.compress_json(data)
# Should have no newlines and minimal whitespace
assert "\n" not in compressed
# Note: JSON values may contain spaces (e.g., "Strong uptrend")
# but there should be no spaces around separators
assert ": " not in compressed
assert ", " not in compressed
# Should be valid JSON
import json
parsed = json.loads(compressed)
assert parsed == data
def test_abbreviate_text(self):
"""Test text abbreviation."""
optimizer = PromptOptimizer()
text = "The current price is high and volume is increasing."
abbreviated = optimizer.abbreviate_text(text)
# Should contain abbreviations
assert "cur" in abbreviated or "P" in abbreviated
assert len(abbreviated) <= len(text)
def test_abbreviate_text_aggressive(self):
"""Test aggressive text abbreviation."""
optimizer = PromptOptimizer()
text = "The price is increasing and the volume is high."
abbreviated = optimizer.abbreviate_text(text, aggressive=True)
# Should be shorter
assert len(abbreviated) < len(text)
# Should have removed articles
assert "the" not in abbreviated.lower()
def test_build_compressed_prompt(self):
"""Test compressed prompt building."""
optimizer = PromptOptimizer()
market_data = {
"stock_code": "005930",
"current_price": 75000,
"market_name": "Korean stock market",
}
prompt = optimizer.build_compressed_prompt(market_data)
# Should be much shorter than original
assert len(prompt) < 300
assert "005930" in prompt
assert "75000" in prompt
def test_build_compressed_prompt_no_instructions(self):
"""Test compressed prompt without instructions."""
optimizer = PromptOptimizer()
market_data = {
"stock_code": "AAPL",
"current_price": 150.5,
"market_name": "United States",
}
prompt = optimizer.build_compressed_prompt(market_data, include_instructions=False)
# Should be very short (data only)
assert len(prompt) < 100
assert "AAPL" in prompt
def test_truncate_context(self):
"""Test context truncation."""
optimizer = PromptOptimizer()
context = {
"price": 100.5,
"volume": 1000000,
"sentiment": 0.8,
"extra_data": "Some long text that should be truncated",
}
# Truncate to small budget
truncated = optimizer.truncate_context(context, max_tokens=10)
# Should have fewer keys
assert len(truncated) <= len(context)
def test_truncate_context_with_priority(self):
"""Test context truncation with priority keys."""
optimizer = PromptOptimizer()
context = {
"price": 100.5,
"volume": 1000000,
"sentiment": 0.8,
"extra_data": "Some data",
}
priority_keys = ["price", "sentiment"]
truncated = optimizer.truncate_context(context, max_tokens=20, priority_keys=priority_keys)
# Priority keys should be included
assert "price" in truncated
assert "sentiment" in truncated
def test_calculate_compression_ratio(self):
"""Test compression ratio calculation."""
optimizer = PromptOptimizer()
original = "This is a very long piece of text that should be compressed significantly."
compressed = "Short text"
ratio = optimizer.calculate_compression_ratio(original, compressed)
# Ratio should be > 1 (original is longer)
assert ratio > 1.0
# ============================================================================
# Context Selector Tests
# ============================================================================
class TestContextSelector:
"""Tests for ContextSelector."""
@pytest.fixture
def store(self):
"""Create in-memory ContextStore."""
conn = sqlite3.connect(":memory:")
# Create tables
conn.execute(
"""
CREATE TABLE context_metadata (
layer TEXT PRIMARY KEY,
description TEXT,
retention_days INTEGER,
aggregation_source TEXT
)
"""
)
conn.execute(
"""
CREATE TABLE contexts (
layer TEXT,
timeframe TEXT,
key TEXT,
value TEXT,
created_at TEXT,
updated_at TEXT,
PRIMARY KEY (layer, timeframe, key)
)
"""
)
conn.commit()
return ContextStore(conn)
def test_select_layers_normal(self, store):
"""Test layer selection for normal decisions."""
selector = ContextSelector(store)
layers = selector.select_layers(DecisionType.NORMAL)
# Should only select L7 (real-time)
assert layers == [ContextLayer.L7_REALTIME]
def test_select_layers_strategic(self, store):
"""Test layer selection for strategic decisions."""
selector = ContextSelector(store)
layers = selector.select_layers(DecisionType.STRATEGIC)
# Should select L7 + L6 + L5
assert ContextLayer.L7_REALTIME in layers
assert ContextLayer.L6_DAILY in layers
assert ContextLayer.L5_WEEKLY in layers
assert len(layers) == 3
def test_select_layers_major_event(self, store):
"""Test layer selection for major events."""
selector = ContextSelector(store)
layers = selector.select_layers(DecisionType.MAJOR_EVENT)
# Should select all layers
assert len(layers) == 7
assert ContextLayer.L1_LEGACY in layers
assert ContextLayer.L7_REALTIME in layers
def test_score_layer_relevance(self, store):
"""Test layer relevance scoring."""
selector = ContextSelector(store)
# Add some data first so scores aren't penalized
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
store.set_context(ContextLayer.L1_LEGACY, "legacy", "lesson", "test")
# L7 should have high score for normal decisions
score = selector.score_layer_relevance(ContextLayer.L7_REALTIME, DecisionType.NORMAL)
assert score == 1.0
# L1 should have low score for normal decisions
score = selector.score_layer_relevance(ContextLayer.L1_LEGACY, DecisionType.NORMAL)
assert score == 0.0
# L1 should have high score for major events
score = selector.score_layer_relevance(ContextLayer.L1_LEGACY, DecisionType.MAJOR_EVENT)
assert score == 1.0
def test_select_with_scoring(self, store):
"""Test selection with relevance scoring."""
selector = ContextSelector(store)
# Add data so layers aren't penalized
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
selection = selector.select_with_scoring(DecisionType.NORMAL, min_score=0.5)
# Should only select high-relevance layers
assert len(selection.layers) >= 1
assert ContextLayer.L7_REALTIME in selection.layers
assert all(selection.relevance_scores[layer] >= 0.5 for layer in selection.layers)
def test_get_context_data(self, store):
"""Test context data retrieval."""
selector = ContextSelector(store)
# Add some test data
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "volume", 1000000)
context_data = selector.get_context_data([ContextLayer.L7_REALTIME])
# Should retrieve data
assert "L7_REALTIME" in context_data
assert "price" in context_data["L7_REALTIME"]
assert context_data["L7_REALTIME"]["price"] == 100.5
def test_estimate_context_tokens(self, store):
"""Test context token estimation."""
selector = ContextSelector(store)
context_data = {
"L7_REALTIME": {"price": 100.5, "volume": 1000000},
"L6_DAILY": {"avg_price": 99.8, "avg_volume": 950000},
}
tokens = selector.estimate_context_tokens(context_data)
# Should estimate tokens
assert tokens > 0
def test_optimize_context_for_budget(self, store):
"""Test context optimization for token budget."""
selector = ContextSelector(store)
# Add test data
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
# Get optimized context within budget
context = selector.optimize_context_for_budget(DecisionType.NORMAL, max_tokens=50)
# Should return data within budget
tokens = selector.estimate_context_tokens(context)
assert tokens <= 50
# ============================================================================
# Context Summarizer Tests
# ============================================================================
class TestContextSummarizer:
"""Tests for ContextSummarizer."""
@pytest.fixture
def store(self):
"""Create in-memory ContextStore."""
conn = sqlite3.connect(":memory:")
conn.execute(
"""
CREATE TABLE context_metadata (
layer TEXT PRIMARY KEY,
description TEXT,
retention_days INTEGER,
aggregation_source TEXT
)
"""
)
conn.execute(
"""
CREATE TABLE contexts (
layer TEXT,
timeframe TEXT,
key TEXT,
value TEXT,
created_at TEXT,
updated_at TEXT,
PRIMARY KEY (layer, timeframe, key)
)
"""
)
conn.commit()
return ContextStore(conn)
def test_summarize_numeric_values(self, store):
"""Test numeric value summarization."""
summarizer = ContextSummarizer(store)
values = [10.0, 20.0, 30.0, 40.0, 50.0]
stats = summarizer.summarize_numeric_values(values)
assert isinstance(stats, SummaryStats)
assert stats.count == 5
assert stats.mean == 30.0
assert stats.min == 10.0
assert stats.max == 50.0
assert stats.std is not None
def test_summarize_numeric_values_trend(self, store):
"""Test trend detection in numeric values."""
summarizer = ContextSummarizer(store)
# Uptrend
values_up = [10.0, 15.0, 20.0, 25.0, 30.0, 35.0]
stats_up = summarizer.summarize_numeric_values(values_up)
assert stats_up.trend == "up"
# Downtrend
values_down = [35.0, 30.0, 25.0, 20.0, 15.0, 10.0]
stats_down = summarizer.summarize_numeric_values(values_down)
assert stats_down.trend == "down"
# Flat
values_flat = [20.0, 20.1, 19.9, 20.0, 20.1, 19.9]
stats_flat = summarizer.summarize_numeric_values(values_flat)
assert stats_flat.trend == "flat"
def test_summarize_layer(self, store):
"""Test layer summarization."""
summarizer = ContextSummarizer(store)
# Add test data
store.set_context(ContextLayer.L6_DAILY, "2026-02-04", "price", 100.5)
store.set_context(ContextLayer.L6_DAILY, "2026-02-04", "volume", 1000000)
summary = summarizer.summarize_layer(ContextLayer.L6_DAILY)
# Should have summary
assert "total_entries" in summary
assert summary["total_entries"] > 0
def test_create_compact_summary(self, store):
"""Test compact summary creation."""
summarizer = ContextSummarizer(store)
# Add test data
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
layers = [ContextLayer.L7_REALTIME, ContextLayer.L6_DAILY]
summary = summarizer.create_compact_summary(layers, top_n_metrics=3)
# Should have summaries for layers
assert "L7_REALTIME" in summary
def test_format_summary_for_prompt(self, store):
"""Test summary formatting for prompt."""
summarizer = ContextSummarizer(store)
summary = {
"L7_REALTIME": {
"price": {"avg": 100.5, "trend": "up"},
"volume": {"avg": 1000000, "trend": "flat"},
}
}
formatted = summarizer.format_summary_for_prompt(summary)
# Should be formatted string
assert isinstance(formatted, str)
assert "L7_REALTIME" in formatted
assert "100.5" in formatted or "100.50" in formatted
# ============================================================================
# Decision Cache Tests
# ============================================================================
class TestDecisionCache:
"""Tests for DecisionCache."""
def test_cache_init(self):
"""Test cache initialization."""
cache = DecisionCache(ttl_seconds=60, max_size=100)
assert cache.ttl_seconds == 60
assert cache.max_size == 100
def test_cache_miss(self):
"""Test cache miss."""
cache = DecisionCache()
market_data = {"stock_code": "005930", "current_price": 75000}
decision = cache.get(market_data)
# Should be None (cache miss)
assert decision is None
metrics = cache.get_metrics()
assert metrics.cache_misses == 1
assert metrics.cache_hits == 0
def test_cache_hit(self):
"""Test cache hit."""
cache = DecisionCache()
market_data = {"stock_code": "005930", "current_price": 75000}
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
# Set cache
cache.set(market_data, decision)
# Get from cache
cached = cache.get(market_data)
assert cached is not None
assert cached.action == "HOLD"
assert cached.confidence == 50
metrics = cache.get_metrics()
assert metrics.cache_hits == 1
def test_cache_ttl_expiration(self):
"""Test cache TTL expiration."""
cache = DecisionCache(ttl_seconds=1) # 1 second TTL
market_data = {"stock_code": "005930", "current_price": 75000}
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
# Set cache
cache.set(market_data, decision)
# Should hit immediately
cached = cache.get(market_data)
assert cached is not None
# Wait for expiration
time.sleep(1.1)
# Should miss after expiration
cached = cache.get(market_data)
assert cached is None
def test_cache_max_size(self):
"""Test cache max size eviction."""
cache = DecisionCache(max_size=2)
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
# Add 3 entries (exceeds max_size)
for i in range(3):
market_data = {"stock_code": f"00{i}", "current_price": 1000 * i}
cache.set(market_data, decision)
metrics = cache.get_metrics()
# Should have evicted 1 entry
assert metrics.total_entries == 2
assert metrics.evictions == 1
def test_invalidate_all(self):
"""Test invalidate all cache entries."""
cache = DecisionCache()
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
# Add entries
for i in range(3):
market_data = {"stock_code": f"00{i}", "current_price": 1000}
cache.set(market_data, decision)
# Invalidate all
count = cache.invalidate()
assert count == 3
metrics = cache.get_metrics()
assert metrics.total_entries == 0
def test_invalidate_by_stock(self):
"""Test invalidate cache by stock code."""
cache = DecisionCache()
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
# Add entries for different stocks
cache.set({"stock_code": "005930", "current_price": 75000}, decision)
cache.set({"stock_code": "000660", "current_price": 50000}, decision)
# Invalidate specific stock
count = cache.invalidate("005930")
assert count >= 1
# Other stock should still be cached
cached = cache.get({"stock_code": "000660", "current_price": 50000})
assert cached is not None
def test_cleanup_expired(self):
"""Test cleanup of expired entries."""
cache = DecisionCache(ttl_seconds=1)
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
# Add entry
cache.set({"stock_code": "005930", "current_price": 75000}, decision)
# Wait for expiration
time.sleep(1.1)
# Cleanup
count = cache.cleanup_expired()
assert count == 1
metrics = cache.get_metrics()
assert metrics.total_entries == 0
def test_should_cache_decision(self):
"""Test decision caching criteria."""
cache = DecisionCache()
# HOLD decisions should be cached
hold_decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
assert cache.should_cache_decision(hold_decision) is True
# High confidence BUY should be cached
buy_decision = TradeDecision(action="BUY", confidence=95, rationale="Test")
assert cache.should_cache_decision(buy_decision) is True
# Low confidence BUY should not be cached
low_conf_buy = TradeDecision(action="BUY", confidence=60, rationale="Test")
assert cache.should_cache_decision(low_conf_buy) is False
def test_cache_hit_rate(self):
"""Test cache hit rate calculation."""
cache = DecisionCache()
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
market_data = {"stock_code": "005930", "current_price": 75000}
# First request (miss)
cache.get(market_data)
# Set cache
cache.set(market_data, decision)
# Second request (hit)
cache.get(market_data)
# Third request (hit)
cache.get(market_data)
metrics = cache.get_metrics()
# 1 miss, 2 hits out of 3 requests
assert metrics.total_requests == 3
assert metrics.cache_hits == 2
assert metrics.cache_misses == 1
assert metrics.hit_rate == pytest.approx(2 / 3)
def test_reset_metrics(self):
"""Test metrics reset."""
cache = DecisionCache()
market_data = {"stock_code": "005930", "current_price": 75000}
# Generate some activity
cache.get(market_data)
cache.get(market_data)
# Reset
cache.reset_metrics()
metrics = cache.get_metrics()
assert metrics.total_requests == 0
assert metrics.cache_hits == 0
assert metrics.cache_misses == 0