Compare commits
9 Commits
feature/is
...
feature/is
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
73e1d0a54e | ||
|
|
87556b145e | ||
| 645c761238 | |||
|
|
033d5fcadd | ||
| 128324427f | |||
|
|
61f5aaf4a3 | ||
|
|
4f61d5af8e | ||
|
|
62fd4ff5e1 | ||
| f40f19e735 |
@@ -21,3 +21,8 @@ RATE_LIMIT_RPS=10.0
|
|||||||
|
|
||||||
# Trading Mode (paper / live)
|
# Trading Mode (paper / live)
|
||||||
MODE=paper
|
MODE=paper
|
||||||
|
|
||||||
|
# External Data APIs (optional — for enhanced decision-making)
|
||||||
|
# NEWS_API_KEY=your_news_api_key_here
|
||||||
|
# NEWS_API_PROVIDER=alphavantage
|
||||||
|
# MARKET_DATA_API_KEY=your_market_data_key_here
|
||||||
|
|||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -174,4 +174,7 @@ cython_debug/
|
|||||||
# PyPI configuration file
|
# PyPI configuration file
|
||||||
.pypirc
|
.pypirc
|
||||||
|
|
||||||
|
# Data files (trade logs, databases)
|
||||||
|
# But NOT src/data/ which contains source code
|
||||||
data/
|
data/
|
||||||
|
!src/data/
|
||||||
|
|||||||
293
src/brain/cache.py
Normal file
293
src/brain/cache.py
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
"""Response caching system for reducing redundant LLM calls.
|
||||||
|
|
||||||
|
This module provides caching for common trading scenarios:
|
||||||
|
- TTL-based cache invalidation
|
||||||
|
- Cache key based on market conditions
|
||||||
|
- Cache hit rate monitoring
|
||||||
|
- Special handling for HOLD decisions in quiet markets
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.brain.gemini_client import TradeDecision
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CacheEntry:
    """One stored trade decision plus the bookkeeping kept alongside it."""

    # Decision object returned to callers when this entry is served.
    decision: "TradeDecision"
    # Unix timestamp recorded when the entry was stored.
    cached_at: float
    # How many times this entry has been served.
    hit_count: int = 0
    # Hash of the market data the decision was stored for.
    market_data_hash: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CacheMetrics:
    """Counters used to monitor how the cache is performing."""

    total_requests: int = 0
    cache_hits: int = 0
    cache_misses: int = 0
    evictions: int = 0
    total_entries: int = 0

    @property
    def hit_rate(self) -> float:
        """Return hits divided by requests, or 0.0 before any request."""
        requests = self.total_requests
        return self.cache_hits / requests if requests != 0 else 0.0

    def to_dict(self) -> dict[str, Any]:
        """Return every counter, plus the derived hit rate, as a dict."""
        exported = (
            "total_requests",
            "cache_hits",
            "cache_misses",
            "hit_rate",
            "evictions",
            "total_entries",
        )
        return {name: getattr(self, name) for name in exported}
|
||||||
|
|
||||||
|
|
||||||
|
class DecisionCache:
    """TTL-based, size-bounded cache for trade decisions.

    Entries are keyed by a coarse summary of the market data (stock code,
    rounded price and, when an orderbook is present, the top-of-book
    spread), so similar market snapshots share one cached decision.
    """

    def __init__(self, ttl_seconds: int = 300, max_size: int = 1000) -> None:
        """Initialize the decision cache.

        Args:
            ttl_seconds: Time-to-live for cache entries in seconds
                (default: 5 minutes).
            max_size: Maximum number of cache entries. A value of zero or
                less disables storing altogether.
        """
        self.ttl_seconds = ttl_seconds
        self.max_size = max_size
        self._cache: dict[str, CacheEntry] = {}
        self._metrics = CacheMetrics()

    def _generate_cache_key(self, market_data: dict[str, Any]) -> str:
        """Generate the cache key for a market data snapshot.

        The key combines:
        - the stock code ("UNKNOWN" when absent),
        - the current price, rounded to the nearest 10 above 1000 and to
          the nearest 1 otherwise, so nearby prices share a key,
        - the spread between the first ask and first bid price, when the
          orderbook provides both sides.

        Args:
            market_data: Market data dictionary.

        Returns:
            Cache key string.
        """
        stock_code = market_data.get("stock_code", "UNKNOWN")
        # An explicit None price is treated like a missing one; comparing
        # None with an int below would otherwise raise TypeError.
        current_price = market_data.get("current_price") or 0

        if current_price > 1000:
            price_rounded = round(current_price / 10) * 10
        else:
            price_rounded = round(current_price)

        orderbook_key = ""
        orderbook = market_data.get("orderbook")
        if orderbook:
            bids = orderbook.get("bid")
            asks = orderbook.get("ask")
            # Only the top-of-book spread is used as a market indicator.
            if bids and asks:
                spread = asks[0].get("price", 0) - bids[0].get("price", 0)
                orderbook_key = f"_spread{spread}"

        return f"{stock_code}_{price_rounded}{orderbook_key}"

    def _generate_market_hash(self, market_data: dict[str, Any]) -> str:
        """Return an MD5 hex digest of the full market data.

        The digest is stored on the entry; nothing in this class compares
        it afterwards.

        Args:
            market_data: Market data dictionary.

        Returns:
            Hash string.

        Raises:
            TypeError: If ``market_data`` holds a value ``json.dumps``
                cannot serialize.
        """
        # sort_keys makes the serialization independent of insertion order.
        stable_json = json.dumps(market_data, sort_keys=True, ensure_ascii=False)
        return hashlib.md5(stable_json.encode()).hexdigest()

    def _is_expired(self, entry: CacheEntry, now: float) -> bool:
        """Return True when ``entry`` is older than the configured TTL."""
        return now - entry.cached_at > self.ttl_seconds

    def get(self, market_data: dict[str, Any]) -> TradeDecision | None:
        """Retrieve the cached decision for ``market_data`` if still valid.

        Every call counts as a request. An expired entry is removed and
        counted as both a miss and an eviction.

        Args:
            market_data: Market data dictionary.

        Returns:
            Cached TradeDecision if valid, None otherwise.
        """
        self._metrics.total_requests += 1
        cache_key = self._generate_cache_key(market_data)

        entry = self._cache.get(cache_key)
        if entry is None:
            self._metrics.cache_misses += 1
            return None

        if self._is_expired(entry, time.time()):
            del self._cache[cache_key]
            self._metrics.cache_misses += 1
            self._metrics.evictions += 1
            # Keep the entry gauge in step with the removal.
            self._metrics.total_entries = len(self._cache)
            logger.debug("Cache expired for key: %s", cache_key)
            return None

        entry.hit_count += 1
        self._metrics.cache_hits += 1
        logger.debug("Cache hit for key: %s (hits: %d)", cache_key, entry.hit_count)
        return entry.decision

    def set(
        self,
        market_data: dict[str, Any],
        decision: TradeDecision,
    ) -> None:
        """Store ``decision`` under the key derived from ``market_data``.

        When a new key would exceed ``max_size``, the entry with the oldest
        ``cached_at`` is evicted first. Overwriting an existing key never
        evicts another entry.

        Args:
            market_data: Market data dictionary.
            decision: TradeDecision to cache.
        """
        if self.max_size <= 0:
            # A zero-capacity cache stores nothing (and has nothing to evict).
            return

        cache_key = self._generate_cache_key(market_data)
        market_hash = self._generate_market_hash(market_data)

        # Replacing an existing key does not grow the cache, so eviction is
        # only needed when a genuinely new key arrives at capacity.
        if cache_key not in self._cache and len(self._cache) >= self.max_size:
            oldest_key = min(self._cache, key=lambda k: self._cache[k].cached_at)
            del self._cache[oldest_key]
            self._metrics.evictions += 1
            logger.debug("Cache full, evicted key: %s", oldest_key)

        self._cache[cache_key] = CacheEntry(
            decision=decision,
            cached_at=time.time(),
            market_data_hash=market_hash,
        )
        self._metrics.total_entries = len(self._cache)
        logger.debug("Cached decision for key: %s", cache_key)

    def invalidate(self, stock_code: str | None = None) -> int:
        """Invalidate cache entries.

        Args:
            stock_code: Specific stock code to invalidate, or None for all.

        Returns:
            Number of entries invalidated.
        """
        if stock_code is None:
            count = len(self._cache)
            self._cache.clear()
            self._metrics.evictions += count
            self._metrics.total_entries = 0
            logger.info("Invalidated all cache entries (%d)", count)
            return count

        # Keys start with "<stock_code>_", so a prefix match selects them.
        prefix = f"{stock_code}_"
        keys_to_remove = [key for key in self._cache if key.startswith(prefix)]
        for key in keys_to_remove:
            del self._cache[key]

        count = len(keys_to_remove)
        self._metrics.evictions += count
        self._metrics.total_entries = len(self._cache)
        logger.info("Invalidated %d cache entries for stock: %s", count, stock_code)
        return count

    def cleanup_expired(self) -> int:
        """Remove every entry older than the TTL.

        Returns:
            Number of entries removed.
        """
        now = time.time()
        # Collect first; deleting while iterating the dict would raise.
        expired_keys = [
            key for key, entry in self._cache.items() if self._is_expired(entry, now)
        ]
        for key in expired_keys:
            del self._cache[key]

        count = len(expired_keys)
        self._metrics.evictions += count
        self._metrics.total_entries = len(self._cache)

        if count > 0:
            logger.debug("Cleaned up %d expired cache entries", count)
        return count

    def get_metrics(self) -> CacheMetrics:
        """Return the live CacheMetrics object (not a copy)."""
        return self._metrics

    def reset_metrics(self) -> None:
        """Reset all counters while keeping the current entry count."""
        self._metrics = CacheMetrics(total_entries=len(self._cache))
        logger.info("Cache metrics reset")

    def should_cache_decision(self, decision: TradeDecision) -> bool:
        """Determine if a decision should be cached.

        HOLD decisions are always cached; any other action is cached only
        when its confidence is at least 90.

        Args:
            decision: TradeDecision to evaluate.

        Returns:
            True if decision should be cached.
        """
        return decision.action == "HOLD" or decision.confidence >= 90
|
||||||
296
src/brain/context_selector.py
Normal file
296
src/brain/context_selector.py
Normal file
@@ -0,0 +1,296 @@
|
|||||||
|
"""Smart context selection for optimizing token usage.
|
||||||
|
|
||||||
|
This module implements intelligent selection of context layers (L1-L7) based on
|
||||||
|
decision type and market conditions:
|
||||||
|
- L7 (real-time) for normal trading decisions
|
||||||
|
- L6-L5 (daily/weekly) for strategic decisions
|
||||||
|
- L4-L1 (monthly/legacy) only for major events or policy changes
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from src.context.layer import ContextLayer
|
||||||
|
from src.context.store import ContextStore
|
||||||
|
|
||||||
|
|
||||||
|
class DecisionType(str, Enum):
    """Kind of trading decision being made."""

    # Regular trade decision.
    NORMAL = "normal"
    # Strategy adjustment.
    STRATEGIC = "strategic"
    # Portfolio rebalancing or policy change.
    MAJOR_EVENT = "major_event"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ContextSelection:
    """Immutable record of which context layers were chosen and how they scored."""

    # Chosen layers.
    layers: list[ContextLayer]
    # Relevance score keyed by layer.
    relevance_scores: dict[ContextLayer, float]
    # Aggregate score for the selection.
    total_score: float
|
||||||
|
|
||||||
|
|
||||||
|
class ContextSelector:
    """Selects optimal context layers to minimize token usage."""

    # Base relevance of each layer per decision type. Built once at class
    # creation instead of on every scoring call.
    _BASE_SCORES = {
        DecisionType.NORMAL: {
            ContextLayer.L7_REALTIME: 1.0,
            ContextLayer.L6_DAILY: 0.1,
            ContextLayer.L5_WEEKLY: 0.05,
            ContextLayer.L4_MONTHLY: 0.01,
            ContextLayer.L3_QUARTERLY: 0.0,
            ContextLayer.L2_ANNUAL: 0.0,
            ContextLayer.L1_LEGACY: 0.0,
        },
        DecisionType.STRATEGIC: {
            ContextLayer.L7_REALTIME: 0.9,
            ContextLayer.L6_DAILY: 0.8,
            ContextLayer.L5_WEEKLY: 0.7,
            ContextLayer.L4_MONTHLY: 0.3,
            ContextLayer.L3_QUARTERLY: 0.2,
            ContextLayer.L2_ANNUAL: 0.1,
            ContextLayer.L1_LEGACY: 0.05,
        },
        DecisionType.MAJOR_EVENT: {
            ContextLayer.L7_REALTIME: 0.7,
            ContextLayer.L6_DAILY: 0.7,
            ContextLayer.L5_WEEKLY: 0.7,
            ContextLayer.L4_MONTHLY: 0.8,
            ContextLayer.L3_QUARTERLY: 0.8,
            ContextLayer.L2_ANNUAL: 0.9,
            ContextLayer.L1_LEGACY: 1.0,
        },
    }

    def __init__(self, store: ContextStore) -> None:
        """Initialize the context selector.

        Args:
            store: ContextStore instance for retrieving context data.
        """
        self.store = store

    @staticmethod
    def _layers_newest_first() -> list[ContextLayer]:
        """Return all seven layers ordered from L7 (real-time) to L1 (legacy)."""
        return [
            ContextLayer.L7_REALTIME,
            ContextLayer.L6_DAILY,
            ContextLayer.L5_WEEKLY,
            ContextLayer.L4_MONTHLY,
            ContextLayer.L3_QUARTERLY,
            ContextLayer.L2_ANNUAL,
            ContextLayer.L1_LEGACY,
        ]

    def select_layers(
        self,
        decision_type: DecisionType = DecisionType.NORMAL,
        include_realtime: bool = True,
    ) -> list[ContextLayer]:
        """Select context layers based on decision type.

        Strategy:
        - NORMAL: L7 (real-time) only
        - STRATEGIC: L7 + L6 + L5 (real-time + daily + weekly)
        - any other type (MAJOR_EVENT): all layers L7 down to L1

        Args:
            decision_type: Type of decision being made.
            include_realtime: Whether to include L7 real-time data.

        Returns:
            List of context layers to use (ordered by priority).
        """
        if decision_type == DecisionType.NORMAL:
            history: list[ContextLayer] = []
        elif decision_type == DecisionType.STRATEGIC:
            history = [ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY]
        else:
            # Everything below L7; L7 itself is governed by include_realtime.
            history = self._layers_newest_first()[1:]

        realtime = [ContextLayer.L7_REALTIME] if include_realtime else []
        return realtime + history

    def score_layer_relevance(
        self,
        layer: ContextLayer,
        decision_type: DecisionType,
        current_time: datetime | None = None,
    ) -> float:
        """Calculate the relevance score for a context layer.

        The score is the base value for ``decision_type`` and ``layer``
        (0.0 when the layer has no base value), multiplied by 0.1 when the
        store reports no timeframe for the layer.

        Args:
            layer: Context layer to score.
            decision_type: Type of decision being made.
            current_time: Accepted for interface compatibility; the score
                does not currently depend on it.

        Returns:
            Relevance score (0.0 to 1.0).

        Raises:
            KeyError: If ``decision_type`` has no base score table.
        """
        score = self._BASE_SCORES[decision_type].get(layer, 0.0)

        if self.store.get_latest_timeframe(layer) is None:
            # No data available - reduce score significantly.
            score *= 0.1

        return score

    def select_with_scoring(
        self,
        decision_type: DecisionType = DecisionType.NORMAL,
        min_score: float = 0.5,
    ) -> ContextSelection:
        """Select context layers with relevance scoring.

        Args:
            decision_type: Type of decision being made.
            min_score: Minimum relevance score to include a layer.

        Returns:
            ContextSelection holding the qualifying layers (highest score
            first, ties kept in L7-to-L1 order), the scores of all seven
            layers, and the summed score of the qualifying layers.
        """
        scores = {
            layer: self.score_layer_relevance(layer, decision_type)
            for layer in self._layers_newest_first()
        }

        # sorted() is stable, so equal scores keep their L7-to-L1 order.
        selected_layers = sorted(
            (layer for layer, score in scores.items() if score >= min_score),
            key=lambda layer: scores[layer],
            reverse=True,
        )

        return ContextSelection(
            layers=selected_layers,
            relevance_scores=scores,
            total_score=sum(scores[layer] for layer in selected_layers),
        )

    def get_context_data(
        self,
        layers: list[ContextLayer],
        max_items_per_layer: int = 10,
    ) -> dict[str, Any]:
        """Retrieve context data for selected layers.

        Layers for which the store reports no latest timeframe are left
        out of the result.

        Args:
            layers: List of context layers to retrieve.
            max_items_per_layer: Maximum number of items per layer.

        Returns:
            Dictionary keyed by ``layer.value`` holding each layer's
            contexts for its latest timeframe.
        """
        result: dict[str, Any] = {}

        for layer in layers:
            latest_timeframe = self.store.get_latest_timeframe(layer)
            if not latest_timeframe:
                continue

            contexts = self.store.get_all_contexts(layer, latest_timeframe)
            if len(contexts) > max_items_per_layer:
                # Keep only the first N items in the store's iteration order.
                contexts = dict(list(contexts.items())[:max_items_per_layer])

            result[layer.value] = contexts

        return result

    def estimate_context_tokens(self, context_data: dict[str, Any]) -> int:
        """Estimate total tokens for context data.

        Args:
            context_data: Context data dictionary.

        Returns:
            Estimated token count of the JSON serialization.
        """
        import json

        # NOTE(review): imported here rather than at module top, presumably
        # to avoid an import cycle - confirm before moving.
        from src.brain.prompt_optimizer import PromptOptimizer

        json_str = json.dumps(context_data, ensure_ascii=False)
        return PromptOptimizer.estimate_tokens(json_str)

    def optimize_context_for_budget(
        self,
        decision_type: DecisionType,
        max_tokens: int,
    ) -> dict[str, Any]:
        """Select and retrieve context data within a token budget.

        Reduction proceeds in stages and stops at the first result that
        fits: the default selection, then fewer items per layer (5, 3, 1),
        then fewer layers (score thresholds 0.6 to 0.9 at one item each).
        If nothing fits, L7 with one item is returned without a further
        budget check.

        Args:
            decision_type: Type of decision being made.
            max_tokens: Maximum token budget for context.

        Returns:
            Optimized context data, within budget whenever a stage fits.
        """
        selection = self.select_with_scoring(decision_type, min_score=0.5)

        context_data = self.get_context_data(selection.layers)
        if self.estimate_context_tokens(context_data) <= max_tokens:
            return context_data

        # 1. Reduce items per layer.
        for max_items in (5, 3, 1):
            context_data = self.get_context_data(selection.layers, max_items)
            if self.estimate_context_tokens(context_data) <= max_tokens:
                return context_data

        # 2. Remove lower-priority layers. The first selection already holds
        # every layer's score and is sorted by score, so raising the
        # threshold is a filter over it rather than a fresh round of store
        # lookups per threshold.
        scores = selection.relevance_scores
        for min_score in (0.6, 0.7, 0.8, 0.9):
            layers = [layer for layer in selection.layers if scores[layer] >= min_score]
            context_data = self.get_context_data(layers, max_items_per_layer=1)
            if self.estimate_context_tokens(context_data) <= max_tokens:
                return context_data

        # Last resort: return only L7 with minimal data.
        return self.get_context_data([ContextLayer.L7_REALTIME], max_items_per_layer=1)
|
||||||
@@ -2,6 +2,17 @@
|
|||||||
|
|
||||||
Constructs prompts from market data, calls Gemini, and parses structured
|
Constructs prompts from market data, calls Gemini, and parses structured
|
||||||
JSON responses into validated TradeDecision objects.
|
JSON responses into validated TradeDecision objects.
|
||||||
|
|
||||||
|
Includes token efficiency optimizations:
|
||||||
|
- Prompt compression and abbreviation
|
||||||
|
- Response caching for common scenarios
|
||||||
|
- Smart context selection
|
||||||
|
- Token usage tracking and metrics
|
||||||
|
|
||||||
|
Includes external data integration:
|
||||||
|
- News sentiment analysis
|
||||||
|
- Economic calendar events
|
||||||
|
- Market indicators
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -15,6 +26,11 @@ from typing import Any
|
|||||||
from google import genai
|
from google import genai
|
||||||
|
|
||||||
from src.config import Settings
|
from src.config import Settings
|
||||||
|
from src.data.news_api import NewsAPI, NewsSentiment
|
||||||
|
from src.data.economic_calendar import EconomicCalendar
|
||||||
|
from src.data.market_data import MarketData
|
||||||
|
from src.brain.cache import DecisionCache
|
||||||
|
from src.brain.prompt_optimizer import PromptOptimizer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -28,23 +44,176 @@ class TradeDecision:
|
|||||||
action: str # "BUY" | "SELL" | "HOLD"
|
action: str # "BUY" | "SELL" | "HOLD"
|
||||||
confidence: int # 0-100
|
confidence: int # 0-100
|
||||||
rationale: str
|
rationale: str
|
||||||
|
token_count: int = 0 # Estimated tokens used
|
||||||
|
cached: bool = False # Whether decision came from cache
|
||||||
|
|
||||||
|
|
||||||
class GeminiClient:
|
class GeminiClient:
|
||||||
"""Wraps the Gemini API for trade decision-making."""
|
"""Wraps the Gemini API for trade decision-making."""
|
||||||
|
|
||||||
def __init__(self, settings: Settings) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
settings: Settings,
|
||||||
|
news_api: NewsAPI | None = None,
|
||||||
|
economic_calendar: EconomicCalendar | None = None,
|
||||||
|
market_data: MarketData | None = None,
|
||||||
|
enable_cache: bool = True,
|
||||||
|
enable_optimization: bool = True,
|
||||||
|
) -> None:
|
||||||
self._settings = settings
|
self._settings = settings
|
||||||
self._confidence_threshold = settings.CONFIDENCE_THRESHOLD
|
self._confidence_threshold = settings.CONFIDENCE_THRESHOLD
|
||||||
self._client = genai.Client(api_key=settings.GEMINI_API_KEY)
|
self._client = genai.Client(api_key=settings.GEMINI_API_KEY)
|
||||||
self._model_name = settings.GEMINI_MODEL
|
self._model_name = settings.GEMINI_MODEL
|
||||||
|
|
||||||
|
# External data sources (optional)
|
||||||
|
self._news_api = news_api
|
||||||
|
self._economic_calendar = economic_calendar
|
||||||
|
self._market_data = market_data
|
||||||
|
|
||||||
|
# Token efficiency features
|
||||||
|
self._enable_cache = enable_cache
|
||||||
|
self._enable_optimization = enable_optimization
|
||||||
|
self._cache = DecisionCache(ttl_seconds=300) if enable_cache else None
|
||||||
|
self._optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
# Token usage metrics
|
||||||
|
self._total_tokens_used = 0
|
||||||
|
self._total_decisions = 0
|
||||||
|
self._total_cached_decisions = 0
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# External Data Integration
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _build_external_context(
|
||||||
|
self, stock_code: str, news_sentiment: NewsSentiment | None = None
|
||||||
|
) -> str:
|
||||||
|
"""Build external data context for the prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stock_code: Stock ticker symbol
|
||||||
|
news_sentiment: Optional pre-fetched news sentiment
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted string with external data context
|
||||||
|
"""
|
||||||
|
context_parts: list[str] = []
|
||||||
|
|
||||||
|
# News sentiment
|
||||||
|
if news_sentiment is not None:
|
||||||
|
sentiment_str = self._format_news_sentiment(news_sentiment)
|
||||||
|
if sentiment_str:
|
||||||
|
context_parts.append(sentiment_str)
|
||||||
|
elif self._news_api is not None:
|
||||||
|
# Fetch news sentiment if not provided
|
||||||
|
try:
|
||||||
|
sentiment = await self._news_api.get_news_sentiment(stock_code)
|
||||||
|
if sentiment is not None:
|
||||||
|
sentiment_str = self._format_news_sentiment(sentiment)
|
||||||
|
if sentiment_str:
|
||||||
|
context_parts.append(sentiment_str)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Failed to fetch news sentiment: %s", exc)
|
||||||
|
|
||||||
|
# Economic events
|
||||||
|
if self._economic_calendar is not None:
|
||||||
|
events_str = self._format_economic_events(stock_code)
|
||||||
|
if events_str:
|
||||||
|
context_parts.append(events_str)
|
||||||
|
|
||||||
|
# Market indicators
|
||||||
|
if self._market_data is not None:
|
||||||
|
indicators_str = self._format_market_indicators()
|
||||||
|
if indicators_str:
|
||||||
|
context_parts.append(indicators_str)
|
||||||
|
|
||||||
|
if not context_parts:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
return "EXTERNAL DATA:\n" + "\n\n".join(context_parts)
|
||||||
|
|
||||||
|
def _format_news_sentiment(self, sentiment: NewsSentiment) -> str:
|
||||||
|
"""Format news sentiment for prompt."""
|
||||||
|
if sentiment.article_count == 0:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Select top 3 most relevant articles
|
||||||
|
top_articles = sentiment.articles[:3]
|
||||||
|
|
||||||
|
lines = [
|
||||||
|
f"News Sentiment: {sentiment.avg_sentiment:.2f} "
|
||||||
|
f"(from {sentiment.article_count} articles)",
|
||||||
|
]
|
||||||
|
|
||||||
|
for i, article in enumerate(top_articles, 1):
|
||||||
|
lines.append(
|
||||||
|
f" {i}. [{article.source}] {article.title} "
|
||||||
|
f"(sentiment: {article.sentiment_score:.2f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
def _format_economic_events(self, stock_code: str) -> str:
|
||||||
|
"""Format upcoming economic events for prompt."""
|
||||||
|
if self._economic_calendar is None:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Check for upcoming high-impact events
|
||||||
|
upcoming = self._economic_calendar.get_upcoming_events(
|
||||||
|
days_ahead=7, min_impact="HIGH"
|
||||||
|
)
|
||||||
|
|
||||||
|
if upcoming.high_impact_count == 0:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
lines = [
|
||||||
|
f"Upcoming High-Impact Events: {upcoming.high_impact_count} in next 7 days"
|
||||||
|
]
|
||||||
|
|
||||||
|
if upcoming.next_major_event is not None:
|
||||||
|
event = upcoming.next_major_event
|
||||||
|
lines.append(
|
||||||
|
f" Next: {event.name} ({event.event_type}) "
|
||||||
|
f"on {event.datetime.strftime('%Y-%m-%d')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for earnings
|
||||||
|
earnings_date = self._economic_calendar.get_earnings_date(stock_code)
|
||||||
|
if earnings_date is not None:
|
||||||
|
lines.append(
|
||||||
|
f" Earnings: {stock_code} on {earnings_date.strftime('%Y-%m-%d')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
def _format_market_indicators(self) -> str:
|
||||||
|
"""Format market indicators for prompt."""
|
||||||
|
if self._market_data is None:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
indicators = self._market_data.get_market_indicators()
|
||||||
|
lines = [f"Market Sentiment: {indicators.sentiment.name}"]
|
||||||
|
|
||||||
|
# Add breadth if meaningful
|
||||||
|
if indicators.breadth.advance_decline_ratio != 1.0:
|
||||||
|
lines.append(
|
||||||
|
f"Advance/Decline Ratio: {indicators.breadth.advance_decline_ratio:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Failed to get market indicators: %s", exc)
|
||||||
|
return ""
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Prompt Construction
|
# Prompt Construction
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def build_prompt(self, market_data: dict[str, Any]) -> str:
|
async def build_prompt(
|
||||||
"""Build a structured prompt from market data.
|
self, market_data: dict[str, Any], news_sentiment: NewsSentiment | None = None
|
||||||
|
) -> str:
|
||||||
|
"""Build a structured prompt from market data and external sources.
|
||||||
|
|
||||||
The prompt instructs Gemini to return valid JSON with action,
|
The prompt instructs Gemini to return valid JSON with action,
|
||||||
confidence, and rationale fields.
|
confidence, and rationale fields.
|
||||||
@@ -72,6 +241,60 @@ class GeminiClient:
|
|||||||
|
|
||||||
market_info = "\n".join(market_info_lines)
|
market_info = "\n".join(market_info_lines)
|
||||||
|
|
||||||
|
# Add external data context if available
|
||||||
|
external_context = await self._build_external_context(
|
||||||
|
market_data["stock_code"], news_sentiment
|
||||||
|
)
|
||||||
|
if external_context:
|
||||||
|
market_info += f"\n\n{external_context}"
|
||||||
|
|
||||||
|
json_format = (
|
||||||
|
'{"action": "BUY"|"SELL"|"HOLD", '
|
||||||
|
'"confidence": <int 0-100>, "rationale": "<string>"}'
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
f"You are a professional {market_name} trading analyst.\n"
|
||||||
|
"Analyze the following market data and decide whether to "
|
||||||
|
"BUY, SELL, or HOLD.\n\n"
|
||||||
|
f"{market_info}\n\n"
|
||||||
|
"You MUST respond with ONLY valid JSON in the following format:\n"
|
||||||
|
f"{json_format}\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
"- action must be exactly one of: BUY, SELL, HOLD\n"
|
||||||
|
"- confidence must be an integer from 0 to 100\n"
|
||||||
|
"- rationale must explain your reasoning concisely\n"
|
||||||
|
"- Do NOT wrap the JSON in markdown code blocks\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
def build_prompt_sync(self, market_data: dict[str, Any]) -> str:
|
||||||
|
"""Synchronous version of build_prompt (for backward compatibility).
|
||||||
|
|
||||||
|
This version does NOT include external data integration.
|
||||||
|
Use async build_prompt() for full functionality.
|
||||||
|
"""
|
||||||
|
market_name = market_data.get("market_name", "Korean stock market")
|
||||||
|
|
||||||
|
# Build market data section dynamically based on available fields
|
||||||
|
market_info_lines = [
|
||||||
|
f"Market: {market_name}",
|
||||||
|
f"Stock Code: {market_data['stock_code']}",
|
||||||
|
f"Current Price: {market_data['current_price']}",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add orderbook if available (domestic markets)
|
||||||
|
if "orderbook" in market_data:
|
||||||
|
market_info_lines.append(
|
||||||
|
f"Orderbook: {json.dumps(market_data['orderbook'], ensure_ascii=False)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add foreigner net if non-zero
|
||||||
|
if market_data.get("foreigner_net", 0) != 0:
|
||||||
|
market_info_lines.append(
|
||||||
|
f"Foreigner Net Buy/Sell: {market_data['foreigner_net']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
market_info = "\n".join(market_info_lines)
|
||||||
|
|
||||||
json_format = (
|
json_format = (
|
||||||
'{"action": "BUY"|"SELL"|"HOLD", '
|
'{"action": "BUY"|"SELL"|"HOLD", '
|
||||||
'"confidence": <int 0-100>, "rationale": "<string>"}'
|
'"confidence": <int 0-100>, "rationale": "<string>"}'
|
||||||
@@ -152,28 +375,153 @@ class GeminiClient:
|
|||||||
# API Call
|
# API Call
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
async def decide(
    self, market_data: dict[str, Any], news_sentiment: NewsSentiment | None = None
) -> TradeDecision:
    """Build prompt, call Gemini, and return a parsed decision.

    The decision cache (when configured) is consulted first; a hit is
    returned with ``token_count=0`` and ``cached=True`` without calling the
    API. Otherwise a prompt is built, its token count is estimated and
    added to the running total, and the model response is parsed.

    Args:
        market_data: Market data dictionary with price, orderbook, etc.
        news_sentiment: Optional pre-fetched news sentiment. It is only
            forwarded to ``build_prompt``; the compressed prompt used when
            optimization is enabled does not receive it.

    Returns:
        Parsed TradeDecision. If the API call raises, a HOLD decision with
        confidence 0 and the error text as rationale is returned instead.
    """
    # Check cache first. Compare against None rather than relying on
    # truthiness, so a cache object that evaluates as falsy (for example an
    # empty container defining __len__) is still consulted and populated.
    if self._cache is not None:
        cached_decision = self._cache.get(market_data)
        if cached_decision is not None:
            self._total_cached_decisions += 1
            self._total_decisions += 1
            logger.info(
                "Cache hit for decision",
                extra={
                    "action": cached_decision.action,
                    "confidence": cached_decision.confidence,
                    "cache_hit_rate": self.get_cache_hit_rate(),
                },
            )
            # Return cached decision with cached flag
            return TradeDecision(
                action=cached_decision.action,
                confidence=cached_decision.confidence,
                rationale=cached_decision.rationale,
                token_count=0,
                cached=True,
            )

    # Build optimized prompt
    if self._enable_optimization:
        prompt = self._optimizer.build_compressed_prompt(market_data)
    else:
        prompt = await self.build_prompt(market_data, news_sentiment)

    # Estimate tokens (counted even if the API call below fails)
    token_count = self._optimizer.estimate_tokens(prompt)
    self._total_tokens_used += token_count

    logger.info(
        "Requesting trade decision from Gemini",
        extra={"estimated_tokens": token_count, "optimized": self._enable_optimization},
    )

    try:
        response = await self._client.aio.models.generate_content(
            model=self._model_name,
            contents=prompt,
        )
        raw = response.text
    except Exception as exc:
        # Fail safe: never trade on an API failure.
        logger.error("Gemini API error: %s", exc)
        return TradeDecision(
            action="HOLD", confidence=0, rationale=f"API error: {exc}", token_count=token_count
        )

    decision = self.parse_response(raw)
    self._total_decisions += 1

    # Add token count to decision
    decision_with_tokens = TradeDecision(
        action=decision.action,
        confidence=decision.confidence,
        rationale=decision.rationale,
        token_count=token_count,
        cached=False,
    )

    # Cache if appropriate
    if self._cache is not None and self._cache.should_cache_decision(decision):
        self._cache.set(market_data, decision)

    logger.info(
        "Gemini decision",
        extra={
            "action": decision.action,
            "confidence": decision.confidence,
            "tokens": token_count,
            "avg_tokens": self.get_avg_tokens_per_decision(),
        },
    )

    return decision_with_tokens
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Token Efficiency Metrics
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def get_token_metrics(self) -> dict[str, Any]:
    """Get token usage metrics.

    Returns:
        Dictionary with the running totals (tokens, decisions, cached
        decisions), the derived average tokens per decision and cache hit
        rate, and, when a cache is configured, its own metrics under
        ``"cache_metrics"``.
    """
    metrics: dict[str, Any] = {
        "total_tokens_used": self._total_tokens_used,
        "total_decisions": self._total_decisions,
        "total_cached_decisions": self._total_cached_decisions,
        "avg_tokens_per_decision": self.get_avg_tokens_per_decision(),
        "cache_hit_rate": self.get_cache_hit_rate(),
    }

    # Explicit None check: a configured cache that evaluates as falsy
    # (e.g. empty and defining __len__) must still report its metrics.
    if self._cache is not None:
        cache_metrics = self._cache.get_metrics()
        metrics["cache_metrics"] = cache_metrics.to_dict()

    return metrics
|
||||||
|
|
||||||
|
def get_avg_tokens_per_decision(self) -> float:
    """Return the mean number of estimated tokens spent per decision.

    Returns:
        ``_total_tokens_used / _total_decisions``, or 0.0 when no decision
        has been counted yet (avoids a division by zero).
    """
    decisions = self._total_decisions
    return self._total_tokens_used / decisions if decisions != 0 else 0.0
|
||||||
|
|
||||||
|
def get_cache_hit_rate(self) -> float:
    """Return the fraction of counted decisions that were served from cache.

    Returns:
        ``_total_cached_decisions / _total_decisions``, or 0.0 when no
        decision has been counted yet.
    """
    decisions = self._total_decisions
    return self._total_cached_decisions / decisions if decisions != 0 else 0.0
|
||||||
|
|
||||||
|
def reset_metrics(self) -> None:
    """Reset token usage metrics.

    Zeroes the token, decision and cached-decision counters and, when a
    cache is configured, resets the cache's own metrics as well.
    """
    self._total_tokens_used = 0
    self._total_decisions = 0
    self._total_cached_decisions = 0
    # Explicit None check so a configured-but-falsy cache (e.g. empty and
    # defining __len__) still has its metrics reset.
    if self._cache is not None:
        self._cache.reset_metrics()
    logger.info("Token metrics reset")
|
||||||
|
|
||||||
|
def get_cache(self) -> DecisionCache | None:
    """Expose the decision cache used by this client.

    Returns:
        The stored ``_cache`` attribute: a DecisionCache instance, or None
        if caching is disabled.
    """
    cache = self._cache
    return cache
|
||||||
|
|||||||
267
src/brain/prompt_optimizer.py
Normal file
267
src/brain/prompt_optimizer.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
"""Prompt optimization utilities for reducing token usage.
|
||||||
|
|
||||||
|
This module provides tools to compress prompts while maintaining decision quality:
|
||||||
|
- Token counting
|
||||||
|
- Text compression and abbreviation
|
||||||
|
- Template-based prompts with variable slots
|
||||||
|
- Priority-based context truncation
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
# Full word -> short form, applied by PromptOptimizer.abbreviate_text.
# Insertion order is the order in which the replacements are applied.
ABBREVIATIONS = dict(
    price="P",
    volume="V",
    current="cur",
    previous="prev",
    change="chg",
    percentage="pct",
    market="mkt",
    orderbook="ob",
    foreigner="fgn",
    buy="B",
    sell="S",
    hold="H",
    confidence="conf",
    rationale="reason",
    action="act",
    net="net",
)

# Short form -> full word, for decompression.
REVERSE_ABBREVIATIONS = dict(zip(ABBREVIATIONS.values(), ABBREVIATIONS.keys()))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class TokenMetrics:
    """Immutable size measurements for one prompt text.

    Attributes are documented inline next to each field.
    """

    # Number of characters in the text.
    char_count: int
    # Number of whitespace-separated words in the text.
    word_count: int
    # Rough estimate: ~4 chars per token
    estimated_tokens: int
    # Original / Compressed; 1.0 means "no compression measured".
    compression_ratio: float = 1.0
|
||||||
|
|
||||||
|
|
||||||
|
class PromptOptimizer:
    """Optimizes prompts to reduce token usage while maintaining quality.

    Every helper is a stateless static method, so the class is used as a
    namespace and never needs to be instantiated.
    """

    @staticmethod
    def estimate_tokens(text: str) -> int:
        """Estimate token count for text.

        Uses a simple heuristic: ~4 characters per token for English.
        This is approximate but sufficient for optimization purposes.

        Args:
            text: Input text to estimate tokens for

        Returns:
            Estimated token count: 0 for empty text, otherwise at least 1
        """
        if not text:
            return 0
        # Simple estimate: 1 token ≈ 4 characters
        return max(1, len(text) // 4)

    @staticmethod
    def count_tokens(text: str) -> TokenMetrics:
        """Count various metrics for a text.

        Args:
            text: Input text to analyze

        Returns:
            TokenMetrics with character, word, and estimated token counts
        """
        return TokenMetrics(
            char_count=len(text),
            word_count=len(text.split()),
            estimated_tokens=PromptOptimizer.estimate_tokens(text),
        )

    @staticmethod
    def compress_json(data: dict[str, Any]) -> str:
        """Compress JSON by removing whitespace.

        Args:
            data: Dictionary to serialize

        Returns:
            Compact JSON string without whitespace (non-ASCII kept as-is)
        """
        return json.dumps(data, separators=(",", ":"), ensure_ascii=False)

    @staticmethod
    def abbreviate_text(text: str, aggressive: bool = False) -> str:
        """Apply abbreviations to reduce text length.

        Args:
            text: Input text to abbreviate
            aggressive: If True, also drop articles and "to be" forms and
                collapse all whitespace runs (including newlines) to one
                space

        Returns:
            Abbreviated text with leading/trailing whitespace stripped
        """
        result = text

        # Apply word-level abbreviations (case-insensitive)
        for full, abbr in ABBREVIATIONS.items():
            # Word boundaries to avoid partial replacements
            pattern = r"\b" + re.escape(full) + r"\b"
            result = re.sub(pattern, abbr, result, flags=re.IGNORECASE)

        if aggressive:
            # Remove articles and filler words
            result = re.sub(r"\b(a|an|the)\b", "", result, flags=re.IGNORECASE)
            result = re.sub(r"\b(is|are|was|were)\b", "", result, flags=re.IGNORECASE)
            # Collapse multiple spaces
            result = re.sub(r"\s+", " ", result)

        return result.strip()

    @staticmethod
    def build_compressed_prompt(
        market_data: dict[str, Any],
        include_instructions: bool = True,
        max_length: int | None = None,
    ) -> str:
        """Build a compressed prompt from market data.

        Args:
            market_data: Market data dictionary with stock info. Must
                contain ``stock_code`` and ``current_price``; ``market_name``,
                ``orderbook`` and ``foreigner_net`` are optional.
            include_instructions: Whether to include the response-format
                instructions, or emit the JSON data only
            max_length: Maximum character length of the returned prompt.
                Longer prompts are cut and end with "..." so that the total
                length never exceeds this value. None or a non-positive
                value disables truncation.

        Returns:
            Compressed prompt string

        Raises:
            KeyError: If ``stock_code`` or ``current_price`` is missing.
        """
        # Abbreviated market name
        market_name = market_data.get("market_name", "KR")
        if "Korea" in market_name:
            market_name = "KR"
        elif "United States" in market_name or "US" in market_name:
            market_name = "US"

        # Core data - always included
        core_info: dict[str, Any] = {
            "mkt": market_name,
            "code": market_data["stock_code"],
            "P": market_data["current_price"],
        }

        # Optional fields
        ob = market_data.get("orderbook")
        if ob:
            # Compress orderbook: keep only top 3 levels
            core_info["ob"] = {
                "bid": ob.get("bid", [])[:3],
                "ask": ob.get("ask", [])[:3],
            }

        if market_data.get("foreigner_net", 0) != 0:
            core_info["fgn_net"] = market_data["foreigner_net"]

        # Compress to JSON
        data_str = PromptOptimizer.compress_json(core_info)

        if include_instructions:
            # Minimal instructions
            # NOTE(review): this asks the model for "act"/"conf"/"reason"
            # keys, while the uncompressed prompt asks for
            # "action"/"confidence"/"rationale" - confirm the response
            # parser accepts both spellings.
            prompt = (
                f"{market_name} trader. Analyze:\n{data_str}\n\n"
                'Return JSON: {"act":"BUY"|"SELL"|"HOLD","conf":<0-100>,"reason":"<text>"}\n'
                "Rules: act=BUY/SELL/HOLD, conf=0-100, reason=concise. No markdown."
            )
        else:
            # Data only (for cached contexts where instructions are known)
            prompt = data_str

        # Truncate if needed. The ellipsis is counted against max_length;
        # previously it was appended after slicing, so the result exceeded
        # the requested maximum by three characters.
        if max_length is not None and 0 < max_length < len(prompt):
            marker = "..."
            if max_length > len(marker):
                prompt = prompt[: max_length - len(marker)] + marker
            else:
                # No room for text plus marker: hard cut.
                prompt = prompt[:max_length]

        return prompt

    @staticmethod
    def truncate_context(
        context: dict[str, Any],
        max_tokens: int,
        priority_keys: list[str] | None = None,
    ) -> dict[str, Any]:
        """Truncate context data to fit within token budget.

        Priority keys are added first, in the given order, then the
        remaining keys in the dictionary's own order. Each pass stops at the
        first value that would exceed the budget. Only the JSON-serialized
        value (not the key) is counted towards the budget.

        Args:
            context: Context dictionary to truncate
            max_tokens: Maximum token budget
            priority_keys: List of keys to keep (in order of priority)

        Returns:
            Truncated context dictionary (a new dict; values are not copied)
        """
        if not context:
            return {}

        kept: dict[str, Any] = {}
        used = 0

        def cost(value: Any) -> int:
            """Estimated tokens of one value once serialized as JSON."""
            return PromptOptimizer.estimate_tokens(json.dumps(value))

        # Add priority keys first
        for key in priority_keys or []:
            if key not in context:
                continue
            tokens = cost(context[key])
            if used + tokens > max_tokens:
                break
            kept[key] = context[key]
            used += tokens

        # Add remaining keys if space available
        for key, value in context.items():
            if key in kept:
                continue
            tokens = cost(value)
            if used + tokens > max_tokens:
                break
            kept[key] = value
            used += tokens

        return kept

    @staticmethod
    def calculate_compression_ratio(original: str, compressed: str) -> float:
        """Calculate compression ratio between original and compressed text.

        Args:
            original: Original text
            compressed: Compressed text

        Returns:
            Compression ratio (original_tokens / compressed_tokens), or 1.0
            when the compressed text is estimated at zero tokens
        """
        original_tokens = PromptOptimizer.estimate_tokens(original)
        compressed_tokens = PromptOptimizer.estimate_tokens(compressed)

        if compressed_tokens == 0:
            return 1.0

        return original_tokens / compressed_tokens
|
||||||
@@ -19,6 +19,15 @@ class Settings(BaseSettings):
|
|||||||
GEMINI_API_KEY: str
|
GEMINI_API_KEY: str
|
||||||
GEMINI_MODEL: str = "gemini-pro"
|
GEMINI_MODEL: str = "gemini-pro"
|
||||||
|
|
||||||
|
# External Data APIs (optional — for data-driven decisions)
|
||||||
|
NEWS_API_KEY: str | None = None
|
||||||
|
NEWS_API_PROVIDER: str = "alphavantage" # "alphavantage" or "newsapi"
|
||||||
|
MARKET_DATA_API_KEY: str | None = None
|
||||||
|
|
||||||
|
# Legacy field names (for backward compatibility)
|
||||||
|
ALPHA_VANTAGE_API_KEY: str | None = None
|
||||||
|
NEWSAPI_KEY: str | None = None
|
||||||
|
|
||||||
# Risk Management
|
# Risk Management
|
||||||
CIRCUIT_BREAKER_PCT: float = Field(default=-3.0, le=0.0)
|
CIRCUIT_BREAKER_PCT: float = Field(default=-3.0, le=0.0)
|
||||||
FAT_FINGER_PCT: float = Field(default=30.0, gt=0.0, le=100.0)
|
FAT_FINGER_PCT: float = Field(default=30.0, gt=0.0, le=100.0)
|
||||||
|
|||||||
328
src/context/summarizer.py
Normal file
328
src/context/summarizer.py
Normal file
@@ -0,0 +1,328 @@
|
|||||||
|
"""Context summarization for efficient historical data representation.
|
||||||
|
|
||||||
|
This module summarizes old context data instead of including raw details:
|
||||||
|
- Key metrics only (averages, trends, not details)
|
||||||
|
- Rolling window (keep last N days detailed, summarize older)
|
||||||
|
- Aggregate historical data efficiently
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from src.context.layer import ContextLayer
|
||||||
|
from src.context.store import ContextStore
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class SummaryStats:
    """Immutable statistical digest of a series of numeric values.

    Only ``count`` is required; every other field defaults to None so an
    empty series can be represented as ``SummaryStats(count=0)``.
    """

    # Number of values summarized.
    count: int
    # Arithmetic mean of the values.
    mean: float | None = None
    # Smallest value.
    min: float | None = None
    # Largest value.
    max: float | None = None
    # Standard deviation of the values.
    std: float | None = None
    # "up", "down", "flat"
    trend: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ContextSummarizer:
    """Summarizes historical context data to reduce token usage."""

    def __init__(self, store: ContextStore) -> None:
        """Initialize the context summarizer.

        Args:
            store: ContextStore instance for retrieving context data
        """
        self.store = store

    def summarize_numeric_values(self, values: list[float]) -> SummaryStats:
        """Summarize a list of numeric values.

        The trend compares the average of the first third of the series
        with the average of the last third; a difference of more than 5% of
        the first average (or 0.01 when that average is zero) counts as
        "up"/"down". Series shorter than three values are always "flat".

        Args:
            values: List of numeric values to summarize

        Returns:
            SummaryStats with mean, min, max, sample std (n - 1 divisor) and
            trend, all rounded to 4 decimals; ``SummaryStats(count=0)`` for
            an empty list
        """
        if not values:
            return SummaryStats(count=0)

        count = len(values)
        mean = sum(values) / count

        # Sample standard deviation; a single value has no spread.
        if count > 1:
            variance = sum((x - mean) ** 2 for x in values) / (count - 1)
            std = variance**0.5
        else:
            std = 0.0

        # Determine trend
        trend = "flat"
        if count >= 3:
            # Simple trend: compare first third vs last third
            third = count // 3
            first_avg = sum(values[:third]) / third
            last_avg = sum(values[-third:]) / third

            # Trend threshold: 5% change
            threshold = 0.05 * abs(first_avg) if first_avg != 0 else 0.01

            if last_avg > first_avg + threshold:
                trend = "up"
            elif last_avg < first_avg - threshold:
                trend = "down"

        return SummaryStats(
            count=count,
            mean=round(mean, 4),
            min=round(min(values), 4),
            max=round(max(values), 4),
            std=round(std, 4),
            trend=trend,
        )

    def summarize_layer(
        self,
        layer: ContextLayer,
        start_date: datetime | None = None,
        end_date: datetime | None = None,
    ) -> dict[str, Any]:
        """Summarize all context data stored for a layer.

        Top-level numeric values are summarized under their own key, numeric
        fields of dict values under ``"<key>.<subkey>"``, and string values
        are reduced to a ``"<key>_count"`` entry.

        Args:
            layer: Context layer to summarize
            start_date: Accepted for interface compatibility; currently not
                applied (no date filtering is performed)
            end_date: Accepted for interface compatibility; currently not
                applied (no date filtering is performed)

        Returns:
            Dictionary with summarized metrics plus ``"total_entries"``, or
            ``{"summary": "No data available", "count": 0}`` when the layer
            has no contexts
        """
        # NOTE(review): start_date/end_date are ignored - the mapping
        # returned by the store exposes no timestamps here. Filtering would
        # need timestamp support from the store.

        # Get all contexts for this layer
        all_contexts = self.store.get_all_contexts(layer)

        if not all_contexts:
            return {"summary": "No data available", "count": 0}

        # Group numeric values by key
        numeric_data: dict[str, list[float]] = {}
        text_data: dict[str, list[str]] = {}

        for key, value in all_contexts.items():
            if isinstance(value, (int, float)):
                numeric_data.setdefault(key, []).append(float(value))
            elif isinstance(value, dict):
                # Extract numeric fields from dict
                for subkey, subvalue in value.items():
                    if isinstance(subvalue, (int, float)):
                        numeric_data.setdefault(f"{key}.{subkey}", []).append(float(subvalue))
            elif isinstance(value, str):
                text_data.setdefault(key, []).append(value)

        summary: dict[str, Any] = {}

        # Summarize numeric data
        for key, values in numeric_data.items():
            stats = self.summarize_numeric_values(values)
            summary[key] = {
                "count": stats.count,
                "avg": stats.mean,
                "range": [stats.min, stats.max],
                "trend": stats.trend,
            }

        # Summarize text data (just counts)
        for key, texts in text_data.items():
            summary[f"{key}_count"] = len(texts)

        summary["total_entries"] = len(all_contexts)

        return summary

    def rolling_window_summary(
        self,
        layer: ContextLayer,
        window_days: int = 30,
        summarize_older: bool = True,
    ) -> dict[str, Any]:
        """Create a rolling window summary.

        Intended behavior: recent data (within the window) is kept detailed
        and older data is summarized to key metrics.

        Args:
            layer: Context layer to summarize
            window_days: Number of days to keep detailed (currently only
                echoed back in the result)
            summarize_older: Whether to summarize data older than window

        Returns:
            Dictionary with ``"window_days"``, ``"recent_data"`` (up to the
            last 10 values per numeric key) and ``"historical_summary"``
        """
        result: dict[str, Any] = {
            "window_days": window_days,
            "recent_data": {},
            "historical_summary": {},
        }

        # Get all contexts
        all_contexts = self.store.get_all_contexts(layer)

        recent_values: dict[str, list[float]] = {}
        # NOTE(review): nothing is ever added to historical_values, so
        # "historical_summary" is always empty. Splitting by age needs
        # timestamps (the timeframe field), which are not read here.
        historical_values: dict[str, list[float]] = {}

        for key, value in all_contexts.items():
            # Simplified implementation: every top-level numeric value is
            # treated as recent.
            if isinstance(value, (int, float)):
                recent_values.setdefault(key, []).append(float(value))

        # Detailed recent data
        result["recent_data"] = {key: values[-10:] for key, values in recent_values.items()}

        # Summarized historical data
        if summarize_older:
            for key, values in historical_values.items():
                stats = self.summarize_numeric_values(values)
                result["historical_summary"][key] = {
                    "count": stats.count,
                    "avg": stats.mean,
                    "trend": stats.trend,
                }

        return result

    def aggregate_to_higher_layer(
        self,
        source_layer: ContextLayer,
        target_layer: ContextLayer,
        metric_key: str,
        aggregation_func: str = "mean",
    ) -> float | None:
        """Aggregate one metric across all contexts of the source layer.

        The result is only returned; nothing is written to the target layer
        by this method.

        Args:
            source_layer: Source context layer (more granular)
            target_layer: Target context layer (less granular); not used by
                the computation
            metric_key: Key of metric to aggregate. Matched against
                top-level keys and against fields of dict values.
            aggregation_func: Aggregation function ("mean", "sum", "max",
                "min"); any other name falls back to "mean"

        Returns:
            Aggregated value, or None if no data available
        """
        # Get all contexts from source layer
        source_contexts = self.store.get_all_contexts(source_layer)

        # Extract values for metric_key
        values: list[float] = []
        for key, value in source_contexts.items():
            if key == metric_key and isinstance(value, (int, float)):
                values.append(float(value))
            elif isinstance(value, dict) and metric_key in value:
                subvalue = value[metric_key]
                if isinstance(subvalue, (int, float)):
                    values.append(float(subvalue))

        if not values:
            return None

        # Apply aggregation function; "mean" and unknown names share the
        # mean fallback.
        reducers = {"sum": sum, "max": max, "min": min}
        reducer = reducers.get(aggregation_func)
        if reducer is None:
            return sum(values) / len(values)
        return reducer(values)

    def create_compact_summary(
        self,
        layers: list[ContextLayer],
        top_n_metrics: int = 5,
    ) -> dict[str, Any]:
        """Create a compact summary across multiple layers.

        Args:
            layers: List of context layers to summarize
            top_n_metrics: Number of top metrics to include per layer

        Returns:
            Compact summary dictionary keyed by ``layer.value``; each entry
            holds at most ``top_n_metrics`` metric summaries, highest count
            first
        """
        summary: dict[str, Any] = {}

        for layer in layers:
            layer_summary = self.summarize_layer(layer)

            # Only dict entries carrying a "count" are metric summaries;
            # bookkeeping entries such as "total_entries" are skipped.
            metrics = [
                (key, value)
                for key, value in layer_summary.items()
                if isinstance(value, dict) and "count" in value
            ]

            # Sort by count (descending); the sort is stable, so ties keep
            # their original order.
            metrics.sort(key=lambda item: item[1]["count"], reverse=True)

            # Keep top N
            summary[layer.value] = dict(metrics[:top_n_metrics])

        return summary

    def format_summary_for_prompt(self, summary: dict[str, Any]) -> str:
        """Format summary for inclusion in a prompt.

        Args:
            summary: Summary dictionary mapping a layer name to its metrics

        Returns:
            Formatted string for prompt: one "<layer>:" header per non-empty
            layer followed by one line per metric
        """
        lines: list[str] = []

        for layer, metrics in summary.items():
            if not metrics:
                continue

            lines.append(f"{layer}:")
            for key, value in metrics.items():
                if not isinstance(value, dict):
                    lines.append(f" {key}: {value}")
                    continue

                # Format as: key: avg=X, trend=Y
                parts = []
                if "avg" in value and value["avg"] is not None:
                    parts.append(f"avg={value['avg']:.2f}")
                if "trend" in value and value["trend"]:
                    parts.append(f"trend={value['trend']}")
                if parts:
                    lines.append(f" {key}: {', '.join(parts)}")

        return "\n".join(lines)
|
||||||
205
src/data/README.md
Normal file
205
src/data/README.md
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
# External Data Integration
|
||||||
|
|
||||||
|
This module provides objective external data sources to enhance trading decisions beyond just market prices and user input.
|
||||||
|
|
||||||
|
## Modules
|
||||||
|
|
||||||
|
### `news_api.py` - News Sentiment Analysis
|
||||||
|
|
||||||
|
Fetches real-time news for stocks with sentiment scoring.
|
||||||
|
|
||||||
|
**Features:**
|
||||||
|
- Alpha Vantage and NewsAPI.org support
|
||||||
|
- Sentiment scoring (-1.0 to +1.0)
|
||||||
|
- 5-minute caching to minimize API quota usage
|
||||||
|
- Graceful fallback when API unavailable
|
||||||
|
|
||||||
|
**Usage:**
|
||||||
|
```python
|
||||||
|
from src.data.news_api import NewsAPI
|
||||||
|
|
||||||
|
# Initialize with API key
|
||||||
|
news_api = NewsAPI(api_key="your_key", provider="alphavantage")
|
||||||
|
|
||||||
|
# Fetch news sentiment (get_news_sentiment is a coroutine — await it inside an async function)
|
||||||
|
sentiment = await news_api.get_news_sentiment("AAPL")
|
||||||
|
if sentiment:
|
||||||
|
print(f"Average sentiment: {sentiment.avg_sentiment}")
|
||||||
|
for article in sentiment.articles[:3]:
|
||||||
|
print(f"{article.title} ({article.sentiment_score})")
|
||||||
|
```
|
||||||
|
|
||||||
|
### `economic_calendar.py` - Major Economic Events
|
||||||
|
|
||||||
|
Tracks FOMC meetings, GDP releases, CPI, earnings calendars, and other market-moving events.
|
||||||
|
|
||||||
|
**Features:**
|
||||||
|
- High-impact event tracking (FOMC, GDP, CPI)
|
||||||
|
- Earnings calendar per stock
|
||||||
|
- Event proximity checking
|
||||||
|
- Hardcoded major events for 2026 (no API required)
|
||||||
|
|
||||||
|
**Usage:**
|
||||||
|
```python
|
||||||
|
from src.data.economic_calendar import EconomicCalendar
|
||||||
|
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
calendar.load_hardcoded_events()
|
||||||
|
|
||||||
|
# Get upcoming high-impact events
|
||||||
|
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="HIGH")
|
||||||
|
print(f"High-impact events: {upcoming.high_impact_count}")
|
||||||
|
|
||||||
|
# Check if near earnings
|
||||||
|
earnings_date = calendar.get_earnings_date("AAPL")
|
||||||
|
if earnings_date:
|
||||||
|
print(f"Next earnings: {earnings_date}")
|
||||||
|
|
||||||
|
# Check for high volatility period
|
||||||
|
if calendar.is_high_volatility_period(hours_ahead=24):
|
||||||
|
print("High-impact event imminent!")
|
||||||
|
```
|
||||||
|
|
||||||
|
### `market_data.py` - Market Indicators
|
||||||
|
|
||||||
|
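Provides market breadth, sector performance, and sentiment indicators. The current implementation is a stub: until a real data provider is integrated it returns neutral sentiment, no breadth data, and an empty sector list.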
|
||||||
|
|
||||||
|
**Features:**
|
||||||
|
- Market sentiment levels (Fear & Greed equivalent)
|
||||||
|
- Market breadth (advancing/declining stocks)
|
||||||
|
- Sector performance tracking
|
||||||
|
- Fear/Greed score calculation
|
||||||
|
|
||||||
|
**Usage:**
|
||||||
|
```python
|
||||||
|
from src.data.market_data import MarketData
|
||||||
|
|
||||||
|
market_data = MarketData(api_key="your_key")
|
||||||
|
|
||||||
|
# Get market sentiment
|
||||||
|
sentiment = market_data.get_market_sentiment()
|
||||||
|
print(f"Market sentiment: {sentiment.name}")
|
||||||
|
|
||||||
|
# Get full indicators
|
||||||
|
indicators = market_data.get_market_indicators("US")
|
||||||
|
print(f"Sentiment: {indicators.sentiment.name}")
|
||||||
|
print(f"A/D Ratio: {indicators.breadth.advance_decline_ratio}")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Integration with GeminiClient
|
||||||
|
|
||||||
|
The external data sources are seamlessly integrated into the AI decision engine:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.brain.gemini_client import GeminiClient
|
||||||
|
from src.data.news_api import NewsAPI
|
||||||
|
from src.data.economic_calendar import EconomicCalendar
|
||||||
|
from src.data.market_data import MarketData
|
||||||
|
from src.config import Settings
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
|
|
||||||
|
# Initialize data sources
|
||||||
|
news_api = NewsAPI(api_key=settings.NEWS_API_KEY, provider=settings.NEWS_API_PROVIDER)
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
calendar.load_hardcoded_events()
|
||||||
|
market_data = MarketData(api_key=settings.MARKET_DATA_API_KEY)
|
||||||
|
|
||||||
|
# Create enhanced client
|
||||||
|
client = GeminiClient(
|
||||||
|
settings,
|
||||||
|
news_api=news_api,
|
||||||
|
economic_calendar=calendar,
|
||||||
|
market_data=market_data
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make decision with external context
|
||||||
|
market_data_dict = {
|
||||||
|
"stock_code": "AAPL",
|
||||||
|
"current_price": 180.0,
|
||||||
|
"market_name": "US stock market"
|
||||||
|
}
|
||||||
|
|
||||||
|
decision = await client.decide(market_data_dict)
|
||||||
|
```
|
||||||
|
|
||||||
|
The external data is automatically included in the prompt sent to Gemini:
|
||||||
|
|
||||||
|
```
|
||||||
|
Market: US stock market
|
||||||
|
Stock Code: AAPL
|
||||||
|
Current Price: 180.0
|
||||||
|
|
||||||
|
EXTERNAL DATA:
|
||||||
|
News Sentiment: 0.85 (from 10 articles)
|
||||||
|
1. [Reuters] Apple hits record high (sentiment: 0.92)
|
||||||
|
2. [Bloomberg] Strong iPhone sales (sentiment: 0.78)
|
||||||
|
3. [CNBC] Tech sector rallying (sentiment: 0.85)
|
||||||
|
|
||||||
|
Upcoming High-Impact Events: 2 in next 7 days
|
||||||
|
Next: FOMC Meeting (FOMC) on 2026-03-18
|
||||||
|
Earnings: AAPL on 2026-02-10
|
||||||
|
|
||||||
|
Market Sentiment: GREED
|
||||||
|
Advance/Decline Ratio: 2.35
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Add these to your `.env` file:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# External Data APIs (optional)
|
||||||
|
NEWS_API_KEY=your_alpha_vantage_key
|
||||||
|
NEWS_API_PROVIDER=alphavantage # or "newsapi"
|
||||||
|
MARKET_DATA_API_KEY=your_market_data_key
|
||||||
|
```
|
||||||
|
|
||||||
|
## API Recommendations
|
||||||
|
|
||||||
|
### Alpha Vantage (News)
|
||||||
|
- **Free tier:** 5 calls/min, 500 calls/day
|
||||||
|
- **Pros:** Provides sentiment scores, no credit card required
|
||||||
|
- **URL:** https://www.alphavantage.co/
|
||||||
|
|
||||||
|
### NewsAPI.org
|
||||||
|
- **Free tier:** 100 requests/day
|
||||||
|
- **Pros:** Large news coverage, easy to use
|
||||||
|
- **Cons:** No sentiment scores (we use keyword heuristics)
|
||||||
|
- **URL:** https://newsapi.org/
|
||||||
|
|
||||||
|
## Caching Strategy
|
||||||
|
|
||||||
|
To minimize API quota usage:
|
||||||
|
|
||||||
|
1. **News:** 5-minute TTL cache per stock
|
||||||
|
2. **Economic Calendar:** Loaded once at startup (hardcoded events)
|
||||||
|
3. **Market Data:** Fetched per decision (lightweight)
|
||||||
|
|
||||||
|
## Graceful Degradation
|
||||||
|
|
||||||
|
The system works gracefully without external data:
|
||||||
|
|
||||||
|
- If no API keys provided → decisions work with just market prices
|
||||||
|
- If API fails → decision continues without external context
|
||||||
|
- If cache expired → attempts refetch, falls back to no data
|
||||||
|
- Errors are logged but never block trading decisions
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
All modules have comprehensive test coverage (81%+):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pytest tests/test_data_integration.py -v --cov=src/data
|
||||||
|
```
|
||||||
|
|
||||||
|
Tests use mocks to avoid requiring real API keys.
|
||||||
|
|
||||||
|
## Future Enhancements
|
||||||
|
|
||||||
|
- Twitter/X sentiment analysis
|
||||||
|
- Reddit WallStreetBets sentiment
|
||||||
|
- Options flow data
|
||||||
|
- Insider trading activity
|
||||||
|
- Analyst upgrades/downgrades
|
||||||
|
- Real-time economic data APIs
|
||||||
5
src/data/__init__.py
Normal file
5
src/data/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""External data integration for objective decision-making."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
__all__ = ["NewsAPI", "EconomicCalendar", "MarketData"]

# Public name -> submodule of this package that defines it.
_EXPORTS = {
    "NewsAPI": "news_api",
    "EconomicCalendar": "economic_calendar",
    "MarketData": "market_data",
}


def __getattr__(name: str) -> object:
    """Lazily resolve the names advertised in ``__all__`` (PEP 562).

    ``__all__`` promises names that this module does not bind itself, so
    ``from <package> import *`` or ``<package>.NewsAPI`` would otherwise fail
    with AttributeError. The owning submodule is imported on first access
    only, which keeps package import cheap and avoids pulling in optional
    third-party dependencies until they are actually needed.

    Args:
        name: Attribute being looked up on the package

    Returns:
        The requested class

    Raises:
        AttributeError: If *name* is not one of the exported names
    """
    try:
        submodule = _EXPORTS[name]
    except KeyError:
        raise AttributeError(
            f"module {__name__!r} has no attribute {name!r}"
        ) from None

    from importlib import import_module

    value = getattr(import_module(f"{__name__}.{submodule}"), name)
    # Cache on the module so later lookups bypass __getattr__.
    globals()[name] = value
    return value
|
||||||
219
src/data/economic_calendar.py
Normal file
219
src/data/economic_calendar.py
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
"""Economic calendar integration for major market events.
|
||||||
|
|
||||||
|
Tracks FOMC meetings, GDP releases, CPI, earnings calendars, and other
|
||||||
|
market-moving events.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class EconomicEvent:
    """Single economic event.

    Instances compare field-by-field (dataclass equality); the calendar uses
    that to avoid loading the same hardcoded event twice.
    """

    name: str
    event_type: str  # "FOMC", "GDP", "CPI", "EARNINGS", etc.
    # Compared against the naive ``datetime.now()`` by EconomicCalendar, so
    # it must be a naive datetime as well (mixing in aware values raises).
    datetime: datetime
    impact: str  # "HIGH", "MEDIUM", "LOW"
    country: str
    description: str


@dataclass
class UpcomingEvents:
    """Collection of upcoming economic events."""

    events: list[EconomicEvent]  # sorted by event time, earliest first
    high_impact_count: int  # number of HIGH-impact entries in ``events``
    next_major_event: EconomicEvent | None  # earliest HIGH-impact entry


class EconomicCalendar:
    """Economic calendar with event tracking and impact scoring.

    Events are held in memory; they are added with ``add_event`` or loaded
    in bulk with ``load_hardcoded_events``.
    """

    # Numeric rank per impact label; unknown labels rank 0.
    _IMPACT_LEVELS = {"LOW": 1, "MEDIUM": 2, "HIGH": 3}

    def __init__(self, api_key: str | None = None) -> None:
        """Initialize economic calendar.

        Args:
            api_key: API key for calendar provider (None for testing/hardcoded)
        """
        self._api_key = api_key
        # For now, use hardcoded major events (can be extended with API)
        self._events: list[EconomicEvent] = []

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get_upcoming_events(
        self, days_ahead: int = 7, min_impact: str = "MEDIUM"
    ) -> UpcomingEvents:
        """Get upcoming economic events within specified timeframe.

        Args:
            days_ahead: Number of days to look ahead
            min_impact: Minimum impact level ("LOW", "MEDIUM", "HIGH");
                an unrecognised label ranks lowest and filters nothing out

        Returns:
            UpcomingEvents with the matching events sorted by time
        """
        now = datetime.now()
        end_date = now + timedelta(days=days_ahead)
        # Resolve the threshold once instead of once per event.
        min_level = self._impact_level(min_impact)

        upcoming = sorted(
            (
                event
                for event in self._events
                if now <= event.datetime <= end_date
                and self._impact_level(event.impact) >= min_level
            ),
            key=lambda e: e.datetime,
        )

        # Already time-ordered, so the first entry is the next major event.
        high_impact = [e for e in upcoming if self._is_high_impact(e)]

        return UpcomingEvents(
            events=upcoming,
            high_impact_count=len(high_impact),
            next_major_event=high_impact[0] if high_impact else None,
        )

    def add_event(self, event: EconomicEvent) -> None:
        """Add an economic event to the calendar."""
        self._events.append(event)

    def clear_events(self) -> None:
        """Clear all events (useful for testing)."""
        self._events.clear()

    def get_earnings_date(self, stock_code: str) -> datetime | None:
        """Get next earnings date for a stock.

        The ticker must appear as a whole word in the event name
        (case-insensitive), so "A" does not match "AAPL Earnings".

        Args:
            stock_code: Stock ticker symbol

        Returns:
            Next earnings datetime or None if not found (also None for an
            empty ticker)
        """
        ticker = stock_code.strip().upper()
        if not ticker:
            # An empty string is a substring of every name; never match it.
            return None

        now = datetime.now()
        candidates = [
            event.datetime
            for event in self._events
            if event.event_type == "EARNINGS"
            and event.datetime > now
            and ticker in self._name_tokens(event.name)
        ]

        # Earliest upcoming earnings, or None when there is none.
        return min(candidates, default=None)

    def load_hardcoded_events(self) -> None:
        """Load hardcoded major economic events for 2026.

        This is a fallback when no API is available. Calling it more than
        once is safe: events already present are not added again.

        NOTE: every date is a midnight timestamp, and the time-window queries
        require ``now <= event.datetime``, so an event stops counting as
        upcoming as soon as its calendar day begins.
        """
        # Major FOMC meetings in 2026 (estimated)
        fomc_dates = [
            datetime(2026, 3, 18),
            datetime(2026, 5, 6),
            datetime(2026, 6, 17),
            datetime(2026, 7, 29),
            datetime(2026, 9, 16),
            datetime(2026, 11, 4),
            datetime(2026, 12, 16),
        ]

        # Quarterly GDP releases (estimated)
        gdp_dates = [
            datetime(2026, 4, 28),
            datetime(2026, 7, 30),
            datetime(2026, 10, 29),
        ]

        # Monthly CPI releases (12th of each month, estimated)
        cpi_dates = [datetime(2026, month, 12) for month in range(1, 13)]

        schedule = [
            (
                "FOMC Meeting",
                "FOMC",
                "Federal Reserve interest rate decision",
                fomc_dates,
            ),
            ("US GDP Release", "GDP", "Quarterly GDP growth rate", gdp_dates),
            (
                "US CPI Release",
                "CPI",
                "Consumer Price Index inflation data",
                cpi_dates,
            ),
        ]

        for name, event_type, description, dates in schedule:
            for when in dates:
                event = EconomicEvent(
                    name=name,
                    event_type=event_type,
                    datetime=when,
                    impact="HIGH",
                    country="US",
                    description=description,
                )
                # Dataclass equality makes a repeated load a no-op.
                if event not in self._events:
                    self.add_event(event)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _impact_level(self, impact: str) -> int:
        """Convert impact string to numeric level (0 for unknown labels)."""
        return self._IMPACT_LEVELS.get(impact.upper(), 0)

    def _is_high_impact(self, event: EconomicEvent) -> bool:
        """Return True for HIGH-impact events, ignoring label case."""
        return self._impact_level(event.impact) == self._IMPACT_LEVELS["HIGH"]

    @staticmethod
    def _name_tokens(name: str) -> set[str]:
        """Split an event name into upper-cased word tokens.

        "." and "-" are kept inside a token so symbols such as "BRK.B"
        survive, but are stripped from token edges ("AAPL." -> "AAPL").
        """
        normalized = "".join(
            ch if ch.isalnum() or ch in ".-" else " " for ch in name.upper()
        )
        tokens = {token.strip(".-") for token in normalized.split()}
        tokens.discard("")
        return tokens

    def is_high_volatility_period(self, hours_ahead: int = 24) -> bool:
        """Check if we're near a high-impact event.

        Args:
            hours_ahead: Number of hours to look ahead

        Returns:
            True if high-impact event is imminent
        """
        now = datetime.now()
        threshold = now + timedelta(hours=hours_ahead)

        return any(
            self._is_high_impact(event) and now <= event.datetime <= threshold
            for event in self._events
        )
|
||||||
198
src/data/market_data.py
Normal file
198
src/data/market_data.py
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
"""Additional market data indicators beyond basic price data.
|
||||||
|
|
||||||
|
Provides market breadth, sector performance, and market sentiment indicators.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MarketSentiment(Enum):
    """Overall market mood on a five-step scale.

    Values run from 1 (most fearful) to 5 (most greedy), with 3 as the
    neutral midpoint.
    """

    # Fearful side of the scale
    EXTREME_FEAR = 1
    FEAR = 2

    # Midpoint
    NEUTRAL = 3

    # Greedy side of the scale
    GREED = 4
    EXTREME_GREED = 5
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SectorPerformance:
    """Snapshot of how one market sector is performing."""

    sector_name: str
    # Price change figures; presumably percentages, as the names suggest.
    daily_change_pct: float
    weekly_change_pct: float
    leader_stock: str  # Best performing stock in sector
    laggard_stock: str  # Worst performing stock in sector
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MarketBreadth:
    """Breadth statistics describing how widely a market move is shared."""

    # Counts of stocks by direction of movement
    advancing_stocks: int
    declining_stocks: int
    unchanged_stocks: int
    # Counts of stocks making new highs / new lows
    new_highs: int
    new_lows: int
    # Supplied by the caller; this class does not derive it from the counts.
    advance_decline_ratio: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MarketIndicators:
    """Bundle of the market-wide indicators gathered for one market."""

    sentiment: MarketSentiment
    breadth: MarketBreadth
    sector_performance: list[SectorPerformance]
    vix_level: float | None  # Volatility index if available
|
||||||
|
|
||||||
|
|
||||||
|
class MarketData:
    """Source of market-wide indicators that go beyond raw prices.

    The provider integrations are not implemented yet: every getter returns
    a fixed placeholder, with or without an API key.
    """

    def __init__(self, api_key: str | None = None) -> None:
        """Store the provider credentials.

        Args:
            api_key: API key for data provider (None for testing)
        """
        self._api_key = api_key

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get_market_sentiment(self) -> MarketSentiment:
        """Report the current overall market sentiment.

        Simplified placeholder; a production version would consult a
        Fear & Greed Index or a similar sentiment source.

        Returns:
            MarketSentiment enum value (currently always NEUTRAL)
        """
        if self._api_key is not None:
            # TODO: Integrate with actual sentiment API
            return MarketSentiment.NEUTRAL

        # Without credentials there is nothing to query.
        logger.debug("No market data API key — returning NEUTRAL sentiment")
        return MarketSentiment.NEUTRAL

    def get_market_breadth(self, market: str = "US") -> MarketBreadth | None:
        """Look up breadth statistics for a market.

        Args:
            market: Market code ("US", "KR", etc.)

        Returns:
            MarketBreadth object or None if unavailable (currently always
            None)
        """
        if self._api_key is not None:
            # TODO: Integrate with actual market breadth API
            return None

        logger.debug("No market data API key — returning None for breadth")
        return None

    def get_sector_performance(
        self, market: str = "US"
    ) -> list[SectorPerformance]:
        """Look up per-sector performance for a market.

        Args:
            market: Market code ("US", "KR", etc.)

        Returns:
            List of SectorPerformance objects (currently always empty)
        """
        if self._api_key is not None:
            # TODO: Integrate with actual sector performance API
            return []

        logger.debug("No market data API key — returning empty sector list")
        return []

    def get_market_indicators(self, market: str = "US") -> MarketIndicators:
        """Collect every available indicator for a market into one object.

        Args:
            market: Market code ("US", "KR", etc.)

        Returns:
            MarketIndicators; when breadth is unavailable a zeroed
            placeholder with an advance/decline ratio of 1.0 is used
        """
        sentiment = self.get_market_sentiment()
        fetched_breadth = self.get_market_breadth(market)
        sectors = self.get_sector_performance(market)

        if fetched_breadth is not None:
            breadth = fetched_breadth
        else:
            # Placeholder so callers always receive a MarketBreadth.
            breadth = MarketBreadth(
                advancing_stocks=0,
                declining_stocks=0,
                unchanged_stocks=0,
                new_highs=0,
                new_lows=0,
                advance_decline_ratio=1.0,
            )

        return MarketIndicators(
            sentiment=sentiment,
            breadth=breadth,
            sector_performance=sectors,
            vix_level=None,  # TODO: Add VIX integration
        )

    # ------------------------------------------------------------------
    # Helper Methods
    # ------------------------------------------------------------------

    def calculate_fear_greed_score(
        self, breadth: MarketBreadth, vix: float | None = None
    ) -> int:
        """Derive a simple fear/greed score from breadth and volatility.

        Starts from a neutral 50 and adds three independent adjustments:
        advance/decline ratio (+/-10 or +/-20), new highs versus new lows
        (+/-15) and, when given, the VIX level (-15 above 30, +10 below 15).

        Args:
            breadth: Market breadth data
            vix: VIX level (optional)

        Returns:
            Score from 0 (extreme fear) to 100 (extreme greed)
        """
        # Advance/decline ratio: exactly 1.0 leaves the score untouched.
        ratio = breadth.advance_decline_ratio
        if ratio > 1.5:
            ratio_shift = 20
        elif ratio > 1.0:
            ratio_shift = 10
        elif ratio < 0.5:
            ratio_shift = -20
        elif ratio < 1.0:
            ratio_shift = -10
        else:
            ratio_shift = 0

        # New highs vs. new lows: one side must exceed double the other.
        highs, lows = breadth.new_highs, breadth.new_lows
        if highs > lows * 2:
            extremes_shift = 15
        elif lows > highs * 2:
            extremes_shift = -15
        else:
            extremes_shift = 0

        # Volatility: high VIX reads as fear, low VIX as complacency/greed.
        volatility_shift = 0
        if vix is not None:
            if vix > 30:
                volatility_shift = -15
            elif vix < 15:
                volatility_shift = 10

        raw_score = 50 + ratio_shift + extremes_shift + volatility_shift
        return min(100, max(0, raw_score))
|
||||||
316
src/data/news_api.py
Normal file
316
src/data/news_api.py
Normal file
@@ -0,0 +1,316 @@
|
|||||||
|
"""News API integration with sentiment analysis and caching.
|
||||||
|
|
||||||
|
Fetches real-time news for stocks using free-tier APIs (Alpha Vantage or NewsAPI).
|
||||||
|
Includes 5-minute caching to minimize API quota usage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Cache entries expire after 5 minutes
CACHE_TTL_SECONDS = 300


@dataclass
class NewsArticle:
    """Single news article with sentiment."""

    title: str
    summary: str  # truncated to 200 characters by the parsers
    source: str
    published_at: str  # provider-formatted timestamp string, stored as-is
    sentiment_score: float  # -1.0 (negative) to +1.0 (positive)
    url: str


@dataclass
class NewsSentiment:
    """Aggregated news sentiment for a stock."""

    stock_code: str
    articles: list[NewsArticle]
    avg_sentiment: float  # Average sentiment across all articles
    article_count: int
    fetched_at: float  # Unix timestamp; used to expire cache entries


class NewsAPI:
    """News API client with sentiment analysis and caching."""

    # Total time budget for one provider request, in seconds.
    _REQUEST_TIMEOUT_SECONDS = 10

    # Word stems for the keyword heuristic (matched at the start of a word).
    _POSITIVE_KEYWORDS = (
        "surge", "jump", "gain", "rise", "soar", "rally", "profit",
        "growth", "upgrade", "beat", "strong", "bullish", "breakthrough",
    )
    _NEGATIVE_KEYWORDS = (
        "plunge", "fall", "drop", "decline", "crash", "loss", "weak",
        "downgrade", "miss", "bearish", "concern", "risk", "warning",
    )

    def __init__(
        self,
        api_key: str | None = None,
        provider: str = "alphavantage",
        cache_ttl: int = CACHE_TTL_SECONDS,
    ) -> None:
        """Initialize NewsAPI client.

        Args:
            api_key: API key for the news provider (None for testing)
            provider: News provider ("alphavantage" or "newsapi")
            cache_ttl: Cache time-to-live in seconds
        """
        self._api_key = api_key
        self._provider = provider
        self._cache_ttl = cache_ttl
        self._cache: dict[str, NewsSentiment] = {}

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def get_news_sentiment(self, stock_code: str) -> NewsSentiment | None:
        """Fetch news sentiment for a stock with caching.

        Args:
            stock_code: Stock ticker symbol (e.g., "AAPL", "005930")

        Returns:
            NewsSentiment object or None if fetch fails or API unavailable
        """
        # Check cache first
        cached = self._get_from_cache(stock_code)
        if cached is not None:
            logger.debug("News cache hit for %s", stock_code)
            return cached

        # API key required for real requests
        if self._api_key is None:
            logger.warning("No news API key provided — returning None")
            return None

        # Fetch from API. News is optional context, so any failure here is
        # logged and reported as "no data" rather than propagated.
        try:
            sentiment = await self._fetch_news(stock_code)
        except Exception as exc:
            logger.error("Failed to fetch news for %s: %s", stock_code, exc)
            return None

        # Only successful results are cached; failures are retried next call.
        if sentiment is not None:
            self._cache[stock_code] = sentiment
        return sentiment

    def clear_cache(self) -> None:
        """Clear the news cache (useful for testing)."""
        self._cache.clear()

    # ------------------------------------------------------------------
    # Cache Management
    # ------------------------------------------------------------------

    def _get_from_cache(self, stock_code: str) -> NewsSentiment | None:
        """Retrieve cached sentiment if not expired.

        An expired entry is evicted as a side effect.
        """
        cached = self._cache.get(stock_code)
        if cached is None:
            return None

        age = time.time() - cached.fetched_at
        if age > self._cache_ttl:
            logger.debug("News cache expired for %s (age: %.1fs)", stock_code, age)
            del self._cache[stock_code]
            return None

        return cached

    # ------------------------------------------------------------------
    # API Fetching
    # ------------------------------------------------------------------

    async def _fetch_news(self, stock_code: str) -> NewsSentiment | None:
        """Fetch news from the provider API."""
        if self._provider == "alphavantage":
            return await self._fetch_alphavantage(stock_code)
        if self._provider == "newsapi":
            return await self._fetch_newsapi(stock_code)

        logger.error("Unknown news provider: %s", self._provider)
        return None

    async def _fetch_alphavantage(self, stock_code: str) -> NewsSentiment | None:
        """Fetch news from Alpha Vantage News Sentiment API."""
        url = "https://www.alphavantage.co/query"
        params = {
            "function": "NEWS_SENTIMENT",
            "tickers": stock_code,
            "apikey": self._api_key,
            "limit": 10,  # Fetch top 10 articles
        }
        # aiohttp expects a ClientTimeout; a bare number is deprecated.
        timeout = aiohttp.ClientTimeout(total=self._REQUEST_TIMEOUT_SECONDS)

        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url, params=params, timeout=timeout) as resp:
                    if resp.status != 200:
                        logger.error(
                            "Alpha Vantage API error: HTTP %d", resp.status
                        )
                        return None

                    data = await resp.json()
                    return self._parse_alphavantage_response(stock_code, data)

        except Exception as exc:
            logger.error("Alpha Vantage request failed: %s", exc)
            return None

    async def _fetch_newsapi(self, stock_code: str) -> NewsSentiment | None:
        """Fetch news from NewsAPI.org."""
        url = "https://newsapi.org/v2/everything"
        params = {
            "q": stock_code,
            "apiKey": self._api_key,
            "pageSize": 10,
            "sortBy": "publishedAt",
            "language": "en",
        }
        # aiohttp expects a ClientTimeout; a bare number is deprecated.
        timeout = aiohttp.ClientTimeout(total=self._REQUEST_TIMEOUT_SECONDS)

        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url, params=params, timeout=timeout) as resp:
                    if resp.status != 200:
                        logger.error("NewsAPI error: HTTP %d", resp.status)
                        return None

                    data = await resp.json()
                    return self._parse_newsapi_response(stock_code, data)

        except Exception as exc:
            logger.error("NewsAPI request failed: %s", exc)
            return None

    # ------------------------------------------------------------------
    # Response Parsing
    # ------------------------------------------------------------------

    def _parse_alphavantage_response(
        self, stock_code: str, data: dict[str, Any]
    ) -> NewsSentiment | None:
        """Parse Alpha Vantage API response.

        Returns None when the payload has no "feed" key or the feed is empty.
        """
        if "feed" not in data:
            logger.warning("No 'feed' key in Alpha Vantage response")
            return None

        articles = [
            NewsArticle(
                title=self._text(item.get("title")),
                # Truncate long summaries
                summary=self._text(item.get("summary"))[:200],
                source=self._text(item.get("source"), "Unknown"),
                published_at=self._text(item.get("time_published")),
                # Sentiment for this specific ticker
                sentiment_score=self._extract_ticker_sentiment(item, stock_code),
                url=self._text(item.get("url")),
            )
            for item in data["feed"]
        ]
        return self._build_sentiment(stock_code, articles)

    def _extract_ticker_sentiment(
        self, item: dict[str, Any], stock_code: str
    ) -> float:
        """Extract sentiment score for specific ticker from article.

        Falls back to the article's overall score when the ticker is not
        listed; an unparseable score counts as neutral (0.0).
        """
        wanted = stock_code.upper()
        for entry in item.get("ticker_sentiment") or []:
            if self._text(entry.get("ticker")).upper() == wanted:
                # Alpha Vantage provides sentiment_score as string
                return self._to_float(entry.get("ticker_sentiment_score", "0"))

        # Fallback to overall sentiment if ticker-specific not found
        return self._to_float(item.get("overall_sentiment_score", "0"))

    def _parse_newsapi_response(
        self, stock_code: str, data: dict[str, Any]
    ) -> NewsSentiment | None:
        """Parse NewsAPI.org response.

        Note: NewsAPI doesn't provide sentiment scores, so we use a
        simple heuristic based on title and description keywords. Article
        fields may be null in the payload; they are treated as empty.
        """
        if data.get("status") != "ok" or "articles" not in data:
            logger.warning("Invalid NewsAPI response")
            return None

        articles: list[NewsArticle] = []
        for item in data["articles"]:
            title = self._text(item.get("title"))
            description = self._text(item.get("description"))
            source = item.get("source") or {}

            articles.append(
                NewsArticle(
                    title=title,
                    summary=description[:200],
                    source=self._text(source.get("name"), "Unknown"),
                    published_at=self._text(item.get("publishedAt")),
                    # Simple sentiment heuristic based on keywords
                    sentiment_score=self._estimate_sentiment_from_text(
                        title + " " + description
                    ),
                    url=self._text(item.get("url")),
                )
            )

        return self._build_sentiment(stock_code, articles)

    def _build_sentiment(
        self, stock_code: str, articles: list[NewsArticle]
    ) -> NewsSentiment | None:
        """Aggregate parsed articles; None when there are no articles."""
        if not articles:
            return None

        avg_sentiment = sum(a.sentiment_score for a in articles) / len(articles)

        return NewsSentiment(
            stock_code=stock_code,
            articles=articles,
            avg_sentiment=avg_sentiment,
            article_count=len(articles),
            fetched_at=time.time(),
        )

    @staticmethod
    def _text(value: Any, default: str = "") -> str:
        """Return *value* when it is a non-empty string, else *default*."""
        return value if isinstance(value, str) and value else default

    @staticmethod
    def _to_float(value: Any) -> float:
        """Parse a provider score, treating missing/invalid values as 0.0."""
        try:
            return float(value)
        except (TypeError, ValueError):
            # TypeError covers a JSON null, ValueError a malformed string.
            return 0.0

    def _estimate_sentiment_from_text(self, text: str) -> float:
        """Simple keyword-based sentiment estimation.

        This is a fallback for APIs that don't provide sentiment scores.
        Each keyword counts once if some word in the text starts with it,
        so "surges" matches "surge" but "enterprise" does not match "rise".

        Returns a score between -1.0 and +1.0 (0.0 when no keyword occurs).
        """
        # Split into alphabetic words; everything else acts as a separator.
        words = "".join(
            ch if ch.isalpha() else " " for ch in text.lower()
        ).split()

        def count_hits(keywords: tuple[str, ...]) -> int:
            return sum(
                1 for kw in keywords if any(w.startswith(kw) for w in words)
            )

        positive_count = count_hits(self._POSITIVE_KEYWORDS)
        negative_count = count_hits(self._NEGATIVE_KEYWORDS)

        total = positive_count + negative_count
        if total == 0:
            return 0.0

        # Normalize to -1.0 to +1.0 range
        return (positive_count - negative_count) / total
|
||||||
@@ -23,7 +23,7 @@ from google import genai
|
|||||||
|
|
||||||
from src.config import Settings
|
from src.config import Settings
|
||||||
from src.db import init_db
|
from src.db import init_db
|
||||||
from src.logging.decision_logger import DecisionLog, DecisionLogger
|
from src.logging.decision_logger import DecisionLogger
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from src.broker.overseas import OverseasBroker
|
|||||||
from src.config import Settings
|
from src.config import Settings
|
||||||
from src.context.layer import ContextLayer
|
from src.context.layer import ContextLayer
|
||||||
from src.context.store import ContextStore
|
from src.context.store import ContextStore
|
||||||
from src.core.criticality import CriticalityAssessor, CriticalityLevel
|
from src.core.criticality import CriticalityAssessor
|
||||||
from src.core.priority_queue import PriorityTaskQueue
|
from src.core.priority_queue import PriorityTaskQueue
|
||||||
from src.core.risk_manager import CircuitBreakerTripped, RiskManager
|
from src.core.risk_manager import CircuitBreakerTripped, RiskManager
|
||||||
from src.db import init_db, log_trade
|
from src.db import init_db, log_trade
|
||||||
|
|||||||
213
src/notifications/README.md
Normal file
213
src/notifications/README.md
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
# Telegram Notifications
|
||||||
|
|
||||||
|
Real-time trading event notifications via Telegram Bot API.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
### 1. Create a Telegram Bot
|
||||||
|
|
||||||
|
1. Open Telegram and message [@BotFather](https://t.me/BotFather)
|
||||||
|
2. Send `/newbot` command
|
||||||
|
3. Follow prompts to name your bot
|
||||||
|
4. Save the **bot token** (looks like `1234567890:ABCdefGHIjklMNOpqrsTUVwxyz`)
|
||||||
|
|
||||||
|
### 2. Get Your Chat ID
|
||||||
|
|
||||||
|
**Option A: Using @userinfobot**
|
||||||
|
1. Message [@userinfobot](https://t.me/userinfobot) on Telegram
|
||||||
|
2. Send `/start`
|
||||||
|
3. Save your numeric **chat ID** (e.g., `123456789`)
|
||||||
|
|
||||||
|
**Option B: Using @RawDataBot**
|
||||||
|
1. Message [@RawDataBot](https://t.me/rawdatabot) on Telegram
|
||||||
|
2. Look for `"id":` in the JSON response
|
||||||
|
3. Save your numeric **chat ID**
|
||||||
|
|
||||||
|
### 3. Configure Environment
|
||||||
|
|
||||||
|
Add to your `.env` file:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
TELEGRAM_BOT_TOKEN=1234567890:ABCdefGHIjklMNOpqrsTUVwxyz
|
||||||
|
TELEGRAM_CHAT_ID=123456789
|
||||||
|
TELEGRAM_ENABLED=true
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Test the Bot
|
||||||
|
|
||||||
|
Start a conversation with your bot on Telegram first (send `/start`), then run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m src.main --mode=paper
|
||||||
|
```
|
||||||
|
|
||||||
|
You should receive a startup notification.
|
||||||
|
|
||||||
|
## Message Examples
|
||||||
|
|
||||||
|
### Trade Execution
|
||||||
|
```
|
||||||
|
📊 🟢 BUY
|
||||||
|
Symbol: AAPL (United States)
|
||||||
|
Quantity: 10 shares
|
||||||
|
Price: 150.25
|
||||||
|
Confidence: 85%
|
||||||
|
```
|
||||||
|
|
||||||
|
### Circuit Breaker
|
||||||
|
```
|
||||||
|
🚨 CIRCUIT BREAKER TRIPPED
|
||||||
|
P&L: -3.15% (threshold: -3.0%)
|
||||||
|
Trading halted for safety
|
||||||
|
```
|
||||||
|
|
||||||
|
### Fat-Finger Protection
|
||||||
|
```
|
||||||
|
⚠️ Fat-Finger Protection
|
||||||
|
Order rejected: TSLA
|
||||||
|
Attempted: 45.0% of cash
|
||||||
|
Max allowed: 30%
|
||||||
|
Amount: 45,000 / 100,000
|
||||||
|
```
|
||||||
|
|
||||||
|
### Market Open/Close
|
||||||
|
```
|
||||||
|
ℹ️ Market Open
|
||||||
|
Korea trading session started
|
||||||
|
|
||||||
|
ℹ️ Market Close
|
||||||
|
Korea trading session ended
|
||||||
|
📈 P&L: +1.25%
|
||||||
|
```
|
||||||
|
|
||||||
|
### System Status
|
||||||
|
```
|
||||||
|
📊 📝 System Started
|
||||||
|
Mode: PAPER
|
||||||
|
Markets: KRX, NASDAQ
|
||||||
|
|
||||||
|
📊 System Shutdown
|
||||||
|
Normal shutdown
|
||||||
|
```
|
||||||
|
|
||||||
|
## Notification Priorities
|
||||||
|
|
||||||
|
| Priority | Emoji | Use Case |
|
||||||
|
|----------|-------|----------|
|
||||||
|
| LOW | ℹ️ | Market open/close |
|
||||||
|
| MEDIUM | 📊 | Trade execution, system start/stop |
|
||||||
|
| HIGH | ⚠️ | Fat-finger protection, errors |
|
||||||
|
| CRITICAL | 🚨 | Circuit breaker trips |
|
||||||
|
|
||||||
|
## Rate Limiting
|
||||||
|
|
||||||
|
- Default: 1 message per second
|
||||||
|
- Prevents hitting Telegram's global rate limits
|
||||||
|
- Configurable via `rate_limit` parameter
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### No notifications received
|
||||||
|
|
||||||
|
1. **Check bot configuration**
|
||||||
|
```bash
|
||||||
|
# Verify env variables are set
|
||||||
|
grep TELEGRAM .env
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Start conversation with bot**
|
||||||
|
- Open bot in Telegram
|
||||||
|
- Send `/start` command
|
||||||
|
- Bot cannot message users who haven't started a conversation
|
||||||
|
|
||||||
|
3. **Check logs**
|
||||||
|
```bash
|
||||||
|
# Look for Telegram-related errors
|
||||||
|
python -m src.main --mode=paper 2>&1 | grep -i telegram
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **Verify bot token**
|
||||||
|
```bash
|
||||||
|
curl https://api.telegram.org/bot<YOUR_TOKEN>/getMe
|
||||||
|
# Should return bot info (not 401 error)
|
||||||
|
```
|
||||||
|
|
||||||
|
5. **Verify chat ID**
|
||||||
|
```bash
|
||||||
|
curl -X POST https://api.telegram.org/bot<YOUR_TOKEN>/sendMessage \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{"chat_id": "<YOUR_CHAT_ID>", "text": "Test"}'
|
||||||
|
# Should send a test message
|
||||||
|
```
|
||||||
|
|
||||||
|
### Notifications delayed
|
||||||
|
|
||||||
|
- Check rate limiter settings
|
||||||
|
- Verify network connection
|
||||||
|
- Look for timeout errors in logs
|
||||||
|
|
||||||
|
### "Chat not found" error
|
||||||
|
|
||||||
|
- Incorrect chat ID
|
||||||
|
- Bot blocked by user
|
||||||
|
- Need to send `/start` to bot first
|
||||||
|
|
||||||
|
### "Unauthorized" error
|
||||||
|
|
||||||
|
- Invalid bot token
|
||||||
|
- Token revoked (regenerate with @BotFather)
|
||||||
|
|
||||||
|
## Graceful Degradation
|
||||||
|
|
||||||
|
The system works without Telegram notifications:
|
||||||
|
|
||||||
|
- Missing credentials → notifications disabled automatically
|
||||||
|
- API errors → logged but trading continues
|
||||||
|
- Network timeouts → trading loop unaffected
|
||||||
|
- Rate limiting → messages queued, trading proceeds
|
||||||
|
|
||||||
|
**Notifications never crash the trading system.**
|
||||||
|
|
||||||
|
## Security Notes
|
||||||
|
|
||||||
|
- Never commit `.env` file with credentials
|
||||||
|
- Bot token grants full bot control
|
||||||
|
- Chat ID is not sensitive (just a number)
|
||||||
|
- Messages are sent over HTTPS
|
||||||
|
- No trading credentials in notifications
|
||||||
|
|
||||||
|
## Advanced Usage
|
||||||
|
|
||||||
|
### Group Notifications
|
||||||
|
|
||||||
|
1. Add bot to Telegram group
|
||||||
|
2. Get group chat ID (negative number like `-123456789`)
|
||||||
|
3. Use group chat ID in `TELEGRAM_CHAT_ID`
|
||||||
|
|
||||||
|
### Multiple Recipients
|
||||||
|
|
||||||
|
Create multiple bots or use a broadcast group with multiple members.
|
||||||
|
|
||||||
|
### Custom Rate Limits
|
||||||
|
|
||||||
|
Not currently exposed in config, but can be modified in code:
|
||||||
|
|
||||||
|
```python
|
||||||
|
telegram = TelegramClient(
|
||||||
|
bot_token=settings.TELEGRAM_BOT_TOKEN,
|
||||||
|
chat_id=settings.TELEGRAM_CHAT_ID,
|
||||||
|
rate_limit=2.0, # 2 messages per second
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## API Reference
|
||||||
|
|
||||||
|
See `telegram_client.py` for full API documentation.
|
||||||
|
|
||||||
|
Key methods:
|
||||||
|
- `notify_trade_execution()` - Trade alerts
|
||||||
|
- `notify_circuit_breaker()` - Emergency stops
|
||||||
|
- `notify_fat_finger()` - Order rejections
|
||||||
|
- `notify_market_open/close()` - Session tracking
|
||||||
|
- `notify_system_start/shutdown()` - Lifecycle events
|
||||||
|
- `notify_error()` - Error alerts
|
||||||
5
src/notifications/__init__.py
Normal file
5
src/notifications/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Real-time notifications for trading events."""
|
||||||
|
|
||||||
|
from src.notifications.telegram_client import TelegramClient
|
||||||
|
|
||||||
|
__all__ = ["TelegramClient"]
|
||||||
325
src/notifications/telegram_client.py
Normal file
325
src/notifications/telegram_client.py
Normal file
@@ -0,0 +1,325 @@
|
|||||||
|
"""Telegram notification client for real-time trading alerts."""
|
||||||
|
|
||||||
|
import asyncio
import html
import logging
import time
from dataclasses import dataclass
from enum import Enum

import aiohttp
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class NotificationPriority(Enum):
    """Notification priority; each member carries an emoji and a text label.

    A member's value is the ``(emoji, label)`` tuple, which ``__init__``
    unpacks into the ``emoji`` and ``label`` attributes.
    """

    LOW = ("ℹ️", "info")
    MEDIUM = ("📊", "medium")
    HIGH = ("⚠️", "warning")
    CRITICAL = ("🚨", "critical")

    def __init__(self, emoji: str, label: str) -> None:
        # Enum passes the member's tuple value as positional arguments.
        self.emoji, self.label = emoji, label
|
||||||
|
|
||||||
|
|
||||||
|
class LeakyBucket:
    """Async rate limiter: callers wait in ``acquire`` until a token is free.

    Tokens refill continuously at ``rate`` per second up to ``capacity``.
    The internal lock is held while waiting, so concurrent callers are
    released one at a time.
    """

    def __init__(self, rate: float, capacity: int = 1) -> None:
        """
        Initialize rate limiter.

        Args:
            rate: Maximum requests per second
            capacity: Bucket capacity (burst size)
        """
        self._rate = rate
        self._capacity = capacity
        self._tokens = float(capacity)  # start full so the first call never waits
        self._last_update = time.monotonic()
        self._lock = asyncio.Lock()

    async def acquire(self) -> None:
        """Wait until a token is available, then consume it."""
        async with self._lock:
            now = time.monotonic()
            elapsed = now - self._last_update
            self._tokens = min(self._capacity, self._tokens + elapsed * self._rate)
            self._last_update = now

            if self._tokens < 1.0:
                wait_time = (1.0 - self._tokens) / self._rate
                await asyncio.sleep(wait_time)
                # The token earned during the sleep is spent right away.
                # Restart the refill clock as well; otherwise the next
                # caller is credited for the time just slept and passes
                # without waiting, doubling the effective rate.
                self._tokens = 0.0
                self._last_update = time.monotonic()
            else:
                self._tokens -= 1.0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class NotificationMessage:
    """A message body paired with the priority it is sent at."""

    # Supplies the emoji that is prefixed to the text when it is sent.
    priority: NotificationPriority
    # Message body.
    message: str
|
||||||
|
|
||||||
|
|
||||||
|
class TelegramClient:
    """Telegram Bot API client for sending trading notifications.

    Messages are sent with ``parse_mode="HTML"``, so every caller-supplied
    value is HTML-escaped before being interpolated into a message.
    Delivery is best-effort: failures are logged and never raised.
    """

    API_BASE = "https://api.telegram.org/bot{token}"
    DEFAULT_TIMEOUT = 5.0  # seconds
    DEFAULT_RATE = 1.0  # messages per second

    def __init__(
        self,
        bot_token: str | None = None,
        chat_id: str | None = None,
        enabled: bool = True,
        rate_limit: float = DEFAULT_RATE,
    ) -> None:
        """
        Initialize Telegram client.

        Args:
            bot_token: Telegram bot token from @BotFather
            chat_id: Target chat ID (user or group)
            enabled: Enable/disable notifications globally
            rate_limit: Maximum messages per second
        """
        self._bot_token = bot_token
        self._chat_id = chat_id
        self._enabled = enabled
        self._rate_limiter = LeakyBucket(rate=rate_limit)
        self._session: aiohttp.ClientSession | None = None

        if not enabled:
            logger.info("Telegram notifications disabled via configuration")
        elif not bot_token or not chat_id:
            # Empty strings (e.g. a blank ``TELEGRAM_BOT_TOKEN=`` entry) are
            # as unusable as None, so both disable notifications.
            logger.warning(
                "Telegram notifications disabled (missing bot_token or chat_id)"
            )
            self._enabled = False
        else:
            logger.info("Telegram notifications enabled for chat_id=%s", chat_id)

    @staticmethod
    def _escape(value: object) -> str:
        """Escape ``&``, ``<`` and ``>`` so *value* is safe in HTML parse mode."""
        return html.escape(str(value), quote=False)

    def _get_session(self) -> aiohttp.ClientSession:
        """Get or create aiohttp session."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=self.DEFAULT_TIMEOUT)
            )
        return self._session

    async def close(self) -> None:
        """Close HTTP session."""
        if self._session is not None and not self._session.closed:
            await self._session.close()

    async def _send_notification(self, msg: NotificationMessage) -> None:
        """
        Send notification to Telegram with graceful degradation.

        Does nothing when notifications are disabled. Every failure
        (timeout, HTTP error, unexpected exception) is logged and swallowed.

        Args:
            msg: Notification message to send; ``msg.message`` must already
                be valid Telegram HTML.
        """
        if not self._enabled:
            return

        try:
            await self._rate_limiter.acquire()

            formatted_message = f"{msg.priority.emoji} {msg.message}"
            url = f"{self.API_BASE.format(token=self._bot_token)}/sendMessage"

            payload = {
                "chat_id": self._chat_id,
                "text": formatted_message,
                "parse_mode": "HTML",
            }

            session = self._get_session()
            async with session.post(url, json=payload) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(
                        "Telegram API error (status=%d): %s", resp.status, error_text
                    )
                else:
                    logger.debug("Telegram notification sent: %s", msg.message[:50])

        except asyncio.TimeoutError:
            logger.error("Telegram notification timeout")
        except aiohttp.ClientError as exc:
            logger.error("Telegram notification failed: %s", exc)
        except Exception as exc:
            # Deliberate catch-all: a notification failure must not propagate
            # to the caller.
            logger.error("Unexpected error sending notification: %s", exc)

    async def notify_trade_execution(
        self,
        stock_code: str,
        market: str,
        action: str,
        quantity: int,
        price: float,
        confidence: float,
    ) -> None:
        """
        Notify trade execution.

        Args:
            stock_code: Stock ticker symbol
            market: Market name (e.g., "Korea", "United States")
            action: "BUY" or "SELL"
            quantity: Number of shares
            price: Execution price
            confidence: AI confidence level (0-100)
        """
        emoji = "🟢" if action == "BUY" else "🔴"
        message = (
            f"<b>{emoji} {self._escape(action)}</b>\n"
            f"Symbol: <code>{self._escape(stock_code)}</code> "
            f"({self._escape(market)})\n"
            f"Quantity: {quantity:,} shares\n"
            f"Price: {price:,.2f}\n"
            f"Confidence: {confidence:.0f}%"
        )
        await self._send_notification(
            NotificationMessage(priority=NotificationPriority.MEDIUM, message=message)
        )

    async def notify_market_open(self, market_name: str) -> None:
        """
        Notify market opening.

        Args:
            market_name: Name of the market (e.g., "Korea", "United States")
        """
        message = (
            f"<b>Market Open</b>\n{self._escape(market_name)} trading session started"
        )
        await self._send_notification(
            NotificationMessage(priority=NotificationPriority.LOW, message=message)
        )

    async def notify_market_close(self, market_name: str, pnl_pct: float) -> None:
        """
        Notify market closing.

        Args:
            market_name: Name of the market
            pnl_pct: Final P&L percentage for the session
        """
        pnl_sign = "+" if pnl_pct >= 0 else ""  # negatives carry their own "-"
        pnl_emoji = "📈" if pnl_pct >= 0 else "📉"
        message = (
            f"<b>Market Close</b>\n"
            f"{self._escape(market_name)} trading session ended\n"
            f"{pnl_emoji} P&L: {pnl_sign}{pnl_pct:.2f}%"
        )
        await self._send_notification(
            NotificationMessage(priority=NotificationPriority.LOW, message=message)
        )

    async def notify_circuit_breaker(
        self, pnl_pct: float, threshold: float
    ) -> None:
        """
        Notify circuit breaker activation.

        Args:
            pnl_pct: Current P&L percentage
            threshold: Circuit breaker threshold
        """
        message = (
            f"<b>CIRCUIT BREAKER TRIPPED</b>\n"
            f"P&L: {pnl_pct:.2f}% (threshold: {threshold:.1f}%)\n"
            f"Trading halted for safety"
        )
        await self._send_notification(
            NotificationMessage(priority=NotificationPriority.CRITICAL, message=message)
        )

    async def notify_fat_finger(
        self,
        stock_code: str,
        order_amount: float,
        total_cash: float,
        max_pct: float,
    ) -> None:
        """
        Notify fat-finger protection rejection.

        Args:
            stock_code: Stock ticker symbol
            order_amount: Attempted order amount
            total_cash: Total available cash
            max_pct: Maximum allowed percentage
        """
        # Guard against division by zero when no cash is available.
        attempted_pct = (order_amount / total_cash) * 100 if total_cash > 0 else 0
        message = (
            f"<b>Fat-Finger Protection</b>\n"
            f"Order rejected: <code>{self._escape(stock_code)}</code>\n"
            f"Attempted: {attempted_pct:.1f}% of cash\n"
            f"Max allowed: {max_pct:.0f}%\n"
            f"Amount: {order_amount:,.0f} / {total_cash:,.0f}"
        )
        await self._send_notification(
            NotificationMessage(priority=NotificationPriority.HIGH, message=message)
        )

    async def notify_system_start(
        self, mode: str, enabled_markets: list[str]
    ) -> None:
        """
        Notify system startup.

        Args:
            mode: Trading mode ("paper" or "live")
            enabled_markets: List of enabled market codes
        """
        mode_emoji = "📝" if mode == "paper" else "💰"
        markets_str = ", ".join(enabled_markets)
        message = (
            f"<b>{mode_emoji} System Started</b>\n"
            f"Mode: {self._escape(mode.upper())}\n"
            f"Markets: {self._escape(markets_str)}"
        )
        await self._send_notification(
            NotificationMessage(priority=NotificationPriority.MEDIUM, message=message)
        )

    async def notify_system_shutdown(self, reason: str) -> None:
        """
        Notify system shutdown.

        Sent as CRITICAL when the reason mentions "circuit breaker"
        (case-insensitive), otherwise as MEDIUM.

        Args:
            reason: Reason for shutdown (e.g., "Normal shutdown", "Circuit breaker")
        """
        message = f"<b>System Shutdown</b>\n{self._escape(reason)}"
        priority = (
            NotificationPriority.CRITICAL
            if "circuit breaker" in reason.lower()
            else NotificationPriority.MEDIUM
        )
        await self._send_notification(
            NotificationMessage(priority=priority, message=message)
        )

    async def notify_error(
        self, error_type: str, error_msg: str, context: str
    ) -> None:
        """
        Notify system error.

        Args:
            error_type: Type of error (e.g., "Connection Error")
            error_msg: Error message
            context: Error context (e.g., stock code, market)
        """
        # Truncate long errors first, then escape, so an HTML entity is
        # never cut in half.
        message = (
            f"<b>Error: {self._escape(error_type)}</b>\n"
            f"Context: {self._escape(context)}\n"
            f"Message: {self._escape(error_msg[:200])}"
        )
        await self._send_notification(
            NotificationMessage(priority=NotificationPriority.HIGH, message=message)
        )
|
||||||
@@ -126,7 +126,7 @@ class TestPromptConstruction:
|
|||||||
"orderbook": {"asks": [], "bids": []},
|
"orderbook": {"asks": [], "bids": []},
|
||||||
"foreigner_net": -50000,
|
"foreigner_net": -50000,
|
||||||
}
|
}
|
||||||
prompt = client.build_prompt(market_data)
|
prompt = client.build_prompt_sync(market_data)
|
||||||
assert "005930" in prompt
|
assert "005930" in prompt
|
||||||
|
|
||||||
def test_prompt_contains_price(self, settings):
|
def test_prompt_contains_price(self, settings):
|
||||||
@@ -137,7 +137,7 @@ class TestPromptConstruction:
|
|||||||
"orderbook": {"asks": [], "bids": []},
|
"orderbook": {"asks": [], "bids": []},
|
||||||
"foreigner_net": -50000,
|
"foreigner_net": -50000,
|
||||||
}
|
}
|
||||||
prompt = client.build_prompt(market_data)
|
prompt = client.build_prompt_sync(market_data)
|
||||||
assert "72000" in prompt
|
assert "72000" in prompt
|
||||||
|
|
||||||
def test_prompt_enforces_json_output_format(self, settings):
|
def test_prompt_enforces_json_output_format(self, settings):
|
||||||
@@ -148,7 +148,7 @@ class TestPromptConstruction:
|
|||||||
"orderbook": {"asks": [], "bids": []},
|
"orderbook": {"asks": [], "bids": []},
|
||||||
"foreigner_net": 0,
|
"foreigner_net": 0,
|
||||||
}
|
}
|
||||||
prompt = client.build_prompt(market_data)
|
prompt = client.build_prompt_sync(market_data)
|
||||||
assert "JSON" in prompt
|
assert "JSON" in prompt
|
||||||
assert "action" in prompt
|
assert "action" in prompt
|
||||||
assert "confidence" in prompt
|
assert "confidence" in prompt
|
||||||
|
|||||||
673
tests/test_data_integration.py
Normal file
673
tests/test_data_integration.py
Normal file
@@ -0,0 +1,673 @@
|
|||||||
|
"""Tests for external data integration (news, economic calendar, market data)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.brain.gemini_client import GeminiClient
|
||||||
|
from src.data.economic_calendar import EconomicCalendar, EconomicEvent
|
||||||
|
from src.data.market_data import MarketBreadth, MarketData, MarketSentiment
|
||||||
|
from src.data.news_api import NewsAPI, NewsArticle, NewsSentiment
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# NewsAPI Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestNewsAPI:
    """Cover NewsAPI construction, caching, parsing and keyword sentiment."""

    @staticmethod
    def _aapl_sentiment(avg: float, age_seconds: float = 0.0) -> NewsSentiment:
        """Build an article-less AAPL sentiment fetched *age_seconds* ago."""
        return NewsSentiment(
            stock_code="AAPL",
            articles=[],
            avg_sentiment=avg,
            article_count=0,
            fetched_at=time.time() - age_seconds,
        )

    def test_news_api_init_without_key(self):
        """NewsAPI should initialize without API key for testing."""
        client = NewsAPI(api_key=None)
        assert client._api_key is None
        assert client._provider == "alphavantage"
        assert client._cache_ttl == 300

    def test_news_api_init_with_custom_settings(self):
        """NewsAPI should accept custom provider and cache TTL."""
        client = NewsAPI(api_key="test_key", provider="newsapi", cache_ttl=600)
        assert client._api_key == "test_key"
        assert client._provider == "newsapi"
        assert client._cache_ttl == 600

    @pytest.mark.asyncio
    async def test_get_news_sentiment_without_api_key_returns_none(self):
        """Without API key, get_news_sentiment should return None."""
        client = NewsAPI(api_key=None)
        assert await client.get_news_sentiment("AAPL") is None

    @pytest.mark.asyncio
    async def test_cache_hit_returns_cached_sentiment(self):
        """Cache hit should return cached sentiment without API call."""
        client = NewsAPI(api_key="test_key")

        # Seed the cache directly with a fresh entry.
        fresh = self._aapl_sentiment(avg=0.5)
        client._cache["AAPL"] = fresh

        result = await client.get_news_sentiment("AAPL")
        assert result is fresh
        assert result.stock_code == "AAPL"

    @pytest.mark.asyncio
    async def test_cache_expiry_triggers_refetch(self):
        """Expired cache entry should trigger refetch."""
        client = NewsAPI(api_key="test_key", cache_ttl=1)

        # Entry is 10 seconds old against a 1 second TTL.
        client._cache["AAPL"] = self._aapl_sentiment(avg=0.5, age_seconds=10)

        # Stub the fetch so no real API call is made.
        with patch.object(client, "_fetch_news", new_callable=AsyncMock) as fetch_stub:
            fetch_stub.return_value = None
            await client.get_news_sentiment("AAPL")

        # The expired entry must have forced exactly one refetch.
        fetch_stub.assert_called_once_with("AAPL")

    def test_clear_cache(self):
        """clear_cache should empty the cache."""
        client = NewsAPI(api_key="test_key")
        client._cache["AAPL"] = self._aapl_sentiment(avg=0.0)
        assert len(client._cache) == 1

        client.clear_cache()
        assert len(client._cache) == 0

    def test_parse_alphavantage_response_with_valid_data(self):
        """Should parse Alpha Vantage response correctly."""
        client = NewsAPI(api_key="test_key", provider="alphavantage")

        bullish_item = {
            "title": "Apple hits new high",
            "summary": "Apple stock surges to record levels",
            "source": "Reuters",
            "time_published": "2026-02-04T10:00:00",
            "url": "https://example.com/1",
            "ticker_sentiment": [
                {"ticker": "AAPL", "ticker_sentiment_score": "0.85"}
            ],
            "overall_sentiment_score": "0.75",
        }
        bearish_item = {
            "title": "Market volatility rises",
            "summary": "Tech stocks face headwinds",
            "source": "Bloomberg",
            "time_published": "2026-02-04T09:00:00",
            "url": "https://example.com/2",
            "ticker_sentiment": [
                {"ticker": "AAPL", "ticker_sentiment_score": "-0.3"}
            ],
            "overall_sentiment_score": "-0.2",
        }

        result = client._parse_alphavantage_response(
            "AAPL", {"feed": [bullish_item, bearish_item]}
        )

        assert result is not None
        assert result.stock_code == "AAPL"
        assert result.article_count == 2
        assert len(result.articles) == 2
        assert result.articles[0].title == "Apple hits new high"
        assert result.articles[0].sentiment_score == 0.85
        assert result.articles[1].sentiment_score == -0.3
        # Average: (0.85 - 0.3) / 2 = 0.275
        assert abs(result.avg_sentiment - 0.275) < 0.01

    def test_parse_alphavantage_response_without_feed_returns_none(self):
        """Should return None if 'feed' key is missing."""
        client = NewsAPI(api_key="test_key", provider="alphavantage")
        assert client._parse_alphavantage_response("AAPL", {}) is None

    def test_parse_newsapi_response_with_valid_data(self):
        """Should parse NewsAPI.org response correctly."""
        client = NewsAPI(api_key="test_key", provider="newsapi")

        payload = {
            "status": "ok",
            "articles": [
                {
                    "title": "Apple stock surges",
                    "description": "Strong earnings beat expectations",
                    "source": {"name": "TechCrunch"},
                    "publishedAt": "2026-02-04T10:00:00Z",
                    "url": "https://example.com/1",
                },
                {
                    "title": "Tech sector faces risks",
                    "description": "Concerns over market downturn",
                    "source": {"name": "CNBC"},
                    "publishedAt": "2026-02-04T09:00:00Z",
                    "url": "https://example.com/2",
                },
            ],
        }

        result = client._parse_newsapi_response("AAPL", payload)

        assert result is not None
        assert result.stock_code == "AAPL"
        assert result.article_count == 2
        assert len(result.articles) == 2
        assert result.articles[0].title == "Apple stock surges"
        assert result.articles[0].source == "TechCrunch"

    def test_estimate_sentiment_from_text_positive(self):
        """Should detect positive sentiment from keywords."""
        score = NewsAPI()._estimate_sentiment_from_text(
            "Stock price surges with strong profit growth and upgrade"
        )
        assert score > 0.5

    def test_estimate_sentiment_from_text_negative(self):
        """Should detect negative sentiment from keywords."""
        score = NewsAPI()._estimate_sentiment_from_text(
            "Stock plunges on weak earnings, downgrade warning"
        )
        assert score < -0.5

    def test_estimate_sentiment_from_text_neutral(self):
        """Should return neutral sentiment without keywords."""
        score = NewsAPI()._estimate_sentiment_from_text(
            "Company announces quarterly report"
        )
        assert abs(score) < 0.1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# EconomicCalendar Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEconomicCalendar:
|
||||||
|
"""Test economic calendar functionality."""
|
||||||
|
|
||||||
|
def test_economic_calendar_init(self):
|
||||||
|
"""EconomicCalendar should initialize correctly."""
|
||||||
|
calendar = EconomicCalendar(api_key="test_key")
|
||||||
|
assert calendar._api_key == "test_key"
|
||||||
|
assert len(calendar._events) == 0
|
||||||
|
|
||||||
|
def test_add_event(self):
|
||||||
|
"""Should be able to add events to calendar."""
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
event = EconomicEvent(
|
||||||
|
name="FOMC Meeting",
|
||||||
|
event_type="FOMC",
|
||||||
|
datetime=datetime(2026, 3, 18),
|
||||||
|
impact="HIGH",
|
||||||
|
country="US",
|
||||||
|
description="Interest rate decision",
|
||||||
|
)
|
||||||
|
calendar.add_event(event)
|
||||||
|
assert len(calendar._events) == 1
|
||||||
|
assert calendar._events[0].name == "FOMC Meeting"
|
||||||
|
|
||||||
|
def test_get_upcoming_events_filters_by_timeframe(self):
|
||||||
|
"""Should only return events within specified timeframe."""
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
|
||||||
|
# Add events at different times
|
||||||
|
now = datetime.now()
|
||||||
|
calendar.add_event(
|
||||||
|
EconomicEvent(
|
||||||
|
name="Event Tomorrow",
|
||||||
|
event_type="GDP",
|
||||||
|
datetime=now + timedelta(days=1),
|
||||||
|
impact="HIGH",
|
||||||
|
country="US",
|
||||||
|
description="Test event",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
calendar.add_event(
|
||||||
|
EconomicEvent(
|
||||||
|
name="Event Next Month",
|
||||||
|
event_type="CPI",
|
||||||
|
datetime=now + timedelta(days=30),
|
||||||
|
impact="HIGH",
|
||||||
|
country="US",
|
||||||
|
description="Test event",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get events for next 7 days
|
||||||
|
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="HIGH")
|
||||||
|
assert upcoming.high_impact_count == 1
|
||||||
|
assert upcoming.events[0].name == "Event Tomorrow"
|
||||||
|
|
||||||
|
def test_get_upcoming_events_filters_by_impact(self):
|
||||||
|
"""Should filter events by minimum impact level."""
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
|
||||||
|
now = datetime.now()
|
||||||
|
calendar.add_event(
|
||||||
|
EconomicEvent(
|
||||||
|
name="High Impact Event",
|
||||||
|
event_type="FOMC",
|
||||||
|
datetime=now + timedelta(days=1),
|
||||||
|
impact="HIGH",
|
||||||
|
country="US",
|
||||||
|
description="Test",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
calendar.add_event(
|
||||||
|
EconomicEvent(
|
||||||
|
name="Low Impact Event",
|
||||||
|
event_type="OTHER",
|
||||||
|
datetime=now + timedelta(days=1),
|
||||||
|
impact="LOW",
|
||||||
|
country="US",
|
||||||
|
description="Test",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter for HIGH impact only
|
||||||
|
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="HIGH")
|
||||||
|
assert upcoming.high_impact_count == 1
|
||||||
|
assert upcoming.events[0].name == "High Impact Event"
|
||||||
|
|
||||||
|
# Filter for MEDIUM and above (should still get HIGH)
|
||||||
|
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="MEDIUM")
|
||||||
|
assert len(upcoming.events) == 1
|
||||||
|
|
||||||
|
# Filter for LOW and above (should get both)
|
||||||
|
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="LOW")
|
||||||
|
assert len(upcoming.events) == 2
|
||||||
|
|
||||||
|
def test_get_earnings_date_returns_next_earnings(self):
|
||||||
|
"""Should return next earnings date for a stock."""
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
|
||||||
|
now = datetime.now()
|
||||||
|
earnings_date = now + timedelta(days=5)
|
||||||
|
|
||||||
|
calendar.add_event(
|
||||||
|
EconomicEvent(
|
||||||
|
name="AAPL Earnings",
|
||||||
|
event_type="EARNINGS",
|
||||||
|
datetime=earnings_date,
|
||||||
|
impact="HIGH",
|
||||||
|
country="US",
|
||||||
|
description="Apple quarterly earnings",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = calendar.get_earnings_date("AAPL")
|
||||||
|
assert result == earnings_date
|
||||||
|
|
||||||
|
def test_get_earnings_date_returns_none_if_not_found(self):
|
||||||
|
"""Should return None if no earnings found for stock."""
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
result = calendar.get_earnings_date("UNKNOWN")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_load_hardcoded_events(self):
|
||||||
|
"""Should load hardcoded major economic events."""
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
calendar.load_hardcoded_events()
|
||||||
|
|
||||||
|
# Should have multiple events (FOMC, GDP, CPI)
|
||||||
|
assert len(calendar._events) > 10
|
||||||
|
|
||||||
|
# Check for FOMC events
|
||||||
|
fomc_events = [e for e in calendar._events if e.event_type == "FOMC"]
|
||||||
|
assert len(fomc_events) > 0
|
||||||
|
|
||||||
|
# Check for GDP events
|
||||||
|
gdp_events = [e for e in calendar._events if e.event_type == "GDP"]
|
||||||
|
assert len(gdp_events) > 0
|
||||||
|
|
||||||
|
# Check for CPI events
|
||||||
|
cpi_events = [e for e in calendar._events if e.event_type == "CPI"]
|
||||||
|
assert len(cpi_events) == 12 # Monthly CPI releases
|
||||||
|
|
||||||
|
def test_is_high_volatility_period_returns_true_near_high_impact(self):
|
||||||
|
"""Should return True if high-impact event is within threshold."""
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
|
||||||
|
now = datetime.now()
|
||||||
|
calendar.add_event(
|
||||||
|
EconomicEvent(
|
||||||
|
name="FOMC Meeting",
|
||||||
|
event_type="FOMC",
|
||||||
|
datetime=now + timedelta(hours=12),
|
||||||
|
impact="HIGH",
|
||||||
|
country="US",
|
||||||
|
description="Test",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert calendar.is_high_volatility_period(hours_ahead=24) is True
|
||||||
|
|
||||||
|
def test_is_high_volatility_period_returns_false_when_no_events(self):
|
||||||
|
"""Should return False if no high-impact events nearby."""
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
assert calendar.is_high_volatility_period(hours_ahead=24) is False
|
||||||
|
|
||||||
|
def test_clear_events(self):
|
||||||
|
"""Should clear all events."""
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
calendar.add_event(
|
||||||
|
EconomicEvent(
|
||||||
|
name="Test",
|
||||||
|
event_type="TEST",
|
||||||
|
datetime=datetime.now(),
|
||||||
|
impact="LOW",
|
||||||
|
country="US",
|
||||||
|
description="Test",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert len(calendar._events) == 1
|
||||||
|
|
||||||
|
calendar.clear_events()
|
||||||
|
assert len(calendar._events) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# MarketData Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMarketData:
|
||||||
|
"""Test market data indicators."""
|
||||||
|
|
||||||
|
def test_market_data_init(self):
|
||||||
|
"""MarketData should initialize correctly."""
|
||||||
|
data = MarketData(api_key="test_key")
|
||||||
|
assert data._api_key == "test_key"
|
||||||
|
|
||||||
|
def test_get_market_sentiment_without_api_key_returns_neutral(self):
|
||||||
|
"""Without API key, should return NEUTRAL sentiment."""
|
||||||
|
data = MarketData(api_key=None)
|
||||||
|
sentiment = data.get_market_sentiment()
|
||||||
|
assert sentiment == MarketSentiment.NEUTRAL
|
||||||
|
|
||||||
|
def test_get_market_breadth_without_api_key_returns_none(self):
|
||||||
|
"""Without API key, should return None for breadth."""
|
||||||
|
data = MarketData(api_key=None)
|
||||||
|
breadth = data.get_market_breadth()
|
||||||
|
assert breadth is None
|
||||||
|
|
||||||
|
def test_get_sector_performance_without_api_key_returns_empty(self):
|
||||||
|
"""Without API key, should return empty list."""
|
||||||
|
data = MarketData(api_key=None)
|
||||||
|
sectors = data.get_sector_performance()
|
||||||
|
assert sectors == []
|
||||||
|
|
||||||
|
def test_get_market_indicators_returns_defaults_without_api(self):
|
||||||
|
"""Should return default indicators without API key."""
|
||||||
|
data = MarketData(api_key=None)
|
||||||
|
indicators = data.get_market_indicators()
|
||||||
|
|
||||||
|
assert indicators.sentiment == MarketSentiment.NEUTRAL
|
||||||
|
assert indicators.breadth.advance_decline_ratio == 1.0
|
||||||
|
assert indicators.sector_performance == []
|
||||||
|
assert indicators.vix_level is None
|
||||||
|
|
||||||
|
def test_calculate_fear_greed_score_neutral_baseline(self):
|
||||||
|
"""Should return neutral score (50) for balanced market."""
|
||||||
|
data = MarketData()
|
||||||
|
breadth = MarketBreadth(
|
||||||
|
advancing_stocks=500,
|
||||||
|
declining_stocks=500,
|
||||||
|
unchanged_stocks=100,
|
||||||
|
new_highs=50,
|
||||||
|
new_lows=50,
|
||||||
|
advance_decline_ratio=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
score = data.calculate_fear_greed_score(breadth)
|
||||||
|
assert score == 50
|
||||||
|
|
||||||
|
def test_calculate_fear_greed_score_greedy_market(self):
|
||||||
|
"""Should return high score for greedy market conditions."""
|
||||||
|
data = MarketData()
|
||||||
|
breadth = MarketBreadth(
|
||||||
|
advancing_stocks=800,
|
||||||
|
declining_stocks=200,
|
||||||
|
unchanged_stocks=100,
|
||||||
|
new_highs=100,
|
||||||
|
new_lows=10,
|
||||||
|
advance_decline_ratio=4.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
score = data.calculate_fear_greed_score(breadth, vix=12.0)
|
||||||
|
assert score > 70
|
||||||
|
|
||||||
|
def test_calculate_fear_greed_score_fearful_market(self):
|
||||||
|
"""Should return low score for fearful market conditions."""
|
||||||
|
data = MarketData()
|
||||||
|
breadth = MarketBreadth(
|
||||||
|
advancing_stocks=200,
|
||||||
|
declining_stocks=800,
|
||||||
|
unchanged_stocks=100,
|
||||||
|
new_highs=10,
|
||||||
|
new_lows=100,
|
||||||
|
advance_decline_ratio=0.25,
|
||||||
|
)
|
||||||
|
|
||||||
|
score = data.calculate_fear_greed_score(breadth, vix=35.0)
|
||||||
|
assert score < 30
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GeminiClient Integration Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestGeminiClientWithExternalData:
|
||||||
|
"""Test GeminiClient integration with external data sources."""
|
||||||
|
|
||||||
|
def test_gemini_client_accepts_optional_data_sources(self, settings):
|
||||||
|
"""GeminiClient should accept optional external data sources."""
|
||||||
|
news_api = NewsAPI(api_key="test_key")
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
market_data = MarketData()
|
||||||
|
|
||||||
|
client = GeminiClient(
|
||||||
|
settings,
|
||||||
|
news_api=news_api,
|
||||||
|
economic_calendar=calendar,
|
||||||
|
market_data=market_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert client._news_api is news_api
|
||||||
|
assert client._economic_calendar is calendar
|
||||||
|
assert client._market_data is market_data
|
||||||
|
|
||||||
|
def test_gemini_client_works_without_external_data(self, settings):
|
||||||
|
"""GeminiClient should work without external data sources."""
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
assert client._news_api is None
|
||||||
|
assert client._economic_calendar is None
|
||||||
|
assert client._market_data is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_build_prompt_includes_news_sentiment(self, settings):
|
||||||
|
"""build_prompt should include news sentiment when available."""
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
|
||||||
|
market_data = {
|
||||||
|
"stock_code": "AAPL",
|
||||||
|
"current_price": 180.0,
|
||||||
|
"market_name": "US stock market",
|
||||||
|
}
|
||||||
|
|
||||||
|
sentiment = NewsSentiment(
|
||||||
|
stock_code="AAPL",
|
||||||
|
articles=[
|
||||||
|
NewsArticle(
|
||||||
|
title="Apple hits record high",
|
||||||
|
summary="Strong earnings",
|
||||||
|
source="Reuters",
|
||||||
|
published_at="2026-02-04",
|
||||||
|
sentiment_score=0.85,
|
||||||
|
url="https://example.com",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
avg_sentiment=0.85,
|
||||||
|
article_count=1,
|
||||||
|
fetched_at=time.time(),
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = await client.build_prompt(market_data, news_sentiment=sentiment)
|
||||||
|
|
||||||
|
assert "AAPL" in prompt
|
||||||
|
assert "180.0" in prompt
|
||||||
|
assert "EXTERNAL DATA" in prompt
|
||||||
|
assert "News Sentiment" in prompt
|
||||||
|
assert "0.85" in prompt
|
||||||
|
assert "Apple hits record high" in prompt
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_build_prompt_with_economic_events(self, settings):
|
||||||
|
"""build_prompt should include upcoming economic events."""
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
now = datetime.now()
|
||||||
|
calendar.add_event(
|
||||||
|
EconomicEvent(
|
||||||
|
name="FOMC Meeting",
|
||||||
|
event_type="FOMC",
|
||||||
|
datetime=now + timedelta(days=2),
|
||||||
|
impact="HIGH",
|
||||||
|
country="US",
|
||||||
|
description="Interest rate decision",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
client = GeminiClient(settings, economic_calendar=calendar)
|
||||||
|
|
||||||
|
market_data = {
|
||||||
|
"stock_code": "AAPL",
|
||||||
|
"current_price": 180.0,
|
||||||
|
"market_name": "US stock market",
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt = await client.build_prompt(market_data)
|
||||||
|
|
||||||
|
assert "EXTERNAL DATA" in prompt
|
||||||
|
assert "High-Impact Events" in prompt
|
||||||
|
assert "FOMC Meeting" in prompt
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_build_prompt_with_market_indicators(self, settings):
|
||||||
|
"""build_prompt should include market sentiment indicators."""
|
||||||
|
market_data_provider = MarketData(api_key="test_key")
|
||||||
|
|
||||||
|
# Mock the get_market_indicators to return test data
|
||||||
|
with patch.object(market_data_provider, "get_market_indicators") as mock:
|
||||||
|
mock.return_value = MagicMock(
|
||||||
|
sentiment=MarketSentiment.EXTREME_GREED,
|
||||||
|
breadth=MagicMock(advance_decline_ratio=2.5),
|
||||||
|
)
|
||||||
|
|
||||||
|
client = GeminiClient(settings, market_data=market_data_provider)
|
||||||
|
|
||||||
|
market_data = {
|
||||||
|
"stock_code": "AAPL",
|
||||||
|
"current_price": 180.0,
|
||||||
|
"market_name": "US stock market",
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt = await client.build_prompt(market_data)
|
||||||
|
|
||||||
|
assert "EXTERNAL DATA" in prompt
|
||||||
|
assert "Market Sentiment" in prompt
|
||||||
|
assert "EXTREME_GREED" in prompt
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_build_prompt_graceful_when_no_external_data(self, settings):
|
||||||
|
"""build_prompt should work gracefully without external data."""
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
|
||||||
|
market_data = {
|
||||||
|
"stock_code": "AAPL",
|
||||||
|
"current_price": 180.0,
|
||||||
|
"market_name": "US stock market",
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt = await client.build_prompt(market_data)
|
||||||
|
|
||||||
|
assert "AAPL" in prompt
|
||||||
|
assert "180.0" in prompt
|
||||||
|
# Should NOT have external data section
|
||||||
|
assert "EXTERNAL DATA" not in prompt
|
||||||
|
|
||||||
|
def test_build_prompt_sync_backward_compatibility(self, settings):
|
||||||
|
"""build_prompt_sync should maintain backward compatibility."""
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
|
||||||
|
market_data = {
|
||||||
|
"stock_code": "005930",
|
||||||
|
"current_price": 72000,
|
||||||
|
"orderbook": {"asks": [], "bids": []},
|
||||||
|
"foreigner_net": -50000,
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt = client.build_prompt_sync(market_data)
|
||||||
|
|
||||||
|
assert "005930" in prompt
|
||||||
|
assert "72000" in prompt
|
||||||
|
assert "JSON" in prompt
|
||||||
|
# Sync version should NOT have external data
|
||||||
|
assert "EXTERNAL DATA" not in prompt
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decide_with_news_sentiment_parameter(self, settings):
|
||||||
|
"""decide should accept optional news_sentiment parameter."""
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
|
||||||
|
market_data = {
|
||||||
|
"stock_code": "AAPL",
|
||||||
|
"current_price": 180.0,
|
||||||
|
"market_name": "US stock market",
|
||||||
|
}
|
||||||
|
|
||||||
|
sentiment = NewsSentiment(
|
||||||
|
stock_code="AAPL",
|
||||||
|
articles=[],
|
||||||
|
avg_sentiment=0.5,
|
||||||
|
article_count=1,
|
||||||
|
fetched_at=time.time(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock the Gemini API call
|
||||||
|
with patch.object(client._client.aio.models, "generate_content", new_callable=AsyncMock) as mock_gen:
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.text = '{"action": "BUY", "confidence": 85, "rationale": "Good news"}'
|
||||||
|
mock_gen.return_value = mock_response
|
||||||
|
|
||||||
|
decision = await client.decide(market_data, news_sentiment=sentiment)
|
||||||
|
|
||||||
|
assert decision.action == "BUY"
|
||||||
|
assert decision.confidence == 85
|
||||||
|
mock_gen.assert_called_once()
|
||||||
@@ -11,15 +11,15 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import tempfile
|
import tempfile
|
||||||
from datetime import UTC, datetime, timedelta
|
from datetime import UTC, datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from src.config import Settings
|
from src.config import Settings
|
||||||
from src.db import init_db, log_trade
|
from src.db import init_db, log_trade
|
||||||
from src.evolution.ab_test import ABTester, ABTestResult, StrategyPerformance
|
from src.evolution.ab_test import ABTester
|
||||||
from src.evolution.optimizer import EvolutionOptimizer
|
from src.evolution.optimizer import EvolutionOptimizer
|
||||||
from src.evolution.performance_tracker import (
|
from src.evolution.performance_tracker import (
|
||||||
PerformanceDashboard,
|
PerformanceDashboard,
|
||||||
@@ -28,7 +28,6 @@ from src.evolution.performance_tracker import (
|
|||||||
)
|
)
|
||||||
from src.logging.decision_logger import DecisionLogger
|
from src.logging.decision_logger import DecisionLogger
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Fixtures
|
# Fixtures
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|||||||
663
tests/test_token_efficiency.py
Normal file
663
tests/test_token_efficiency.py
Normal file
@@ -0,0 +1,663 @@
|
|||||||
|
"""Tests for token efficiency optimization components.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- Prompt compression and optimization
|
||||||
|
- Context selection logic
|
||||||
|
- Summarization
|
||||||
|
- Caching
|
||||||
|
- Token reduction metrics
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.brain.cache import DecisionCache
|
||||||
|
from src.brain.context_selector import ContextSelector, DecisionType
|
||||||
|
from src.brain.gemini_client import TradeDecision
|
||||||
|
from src.brain.prompt_optimizer import PromptOptimizer, TokenMetrics
|
||||||
|
from src.context.layer import ContextLayer
|
||||||
|
from src.context.store import ContextStore
|
||||||
|
from src.context.summarizer import ContextSummarizer, SummaryStats
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Prompt Optimizer Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptOptimizer:
|
||||||
|
"""Tests for PromptOptimizer."""
|
||||||
|
|
||||||
|
def test_estimate_tokens(self):
|
||||||
|
"""Test token estimation."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
# Empty text
|
||||||
|
assert optimizer.estimate_tokens("") == 0
|
||||||
|
|
||||||
|
# Short text (4 chars = 1 token estimate)
|
||||||
|
assert optimizer.estimate_tokens("test") == 1
|
||||||
|
|
||||||
|
# Longer text
|
||||||
|
text = "This is a longer piece of text for testing token estimation."
|
||||||
|
tokens = optimizer.estimate_tokens(text)
|
||||||
|
assert tokens > 0
|
||||||
|
assert tokens == len(text) // 4
|
||||||
|
|
||||||
|
def test_count_tokens(self):
|
||||||
|
"""Test token counting metrics."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
text = "Hello world, this is a test."
|
||||||
|
metrics = optimizer.count_tokens(text)
|
||||||
|
|
||||||
|
assert isinstance(metrics, TokenMetrics)
|
||||||
|
assert metrics.char_count == len(text)
|
||||||
|
assert metrics.word_count == 6
|
||||||
|
assert metrics.estimated_tokens > 0
|
||||||
|
|
||||||
|
def test_compress_json(self):
|
||||||
|
"""Test JSON compression."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"action": "BUY",
|
||||||
|
"confidence": 85,
|
||||||
|
"rationale": "Strong uptrend",
|
||||||
|
}
|
||||||
|
|
||||||
|
compressed = optimizer.compress_json(data)
|
||||||
|
|
||||||
|
# Should have no newlines and minimal whitespace
|
||||||
|
assert "\n" not in compressed
|
||||||
|
# Note: JSON values may contain spaces (e.g., "Strong uptrend")
|
||||||
|
# but there should be no spaces around separators
|
||||||
|
assert ": " not in compressed
|
||||||
|
assert ", " not in compressed
|
||||||
|
|
||||||
|
# Should be valid JSON
|
||||||
|
import json
|
||||||
|
|
||||||
|
parsed = json.loads(compressed)
|
||||||
|
assert parsed == data
|
||||||
|
|
||||||
|
def test_abbreviate_text(self):
|
||||||
|
"""Test text abbreviation."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
text = "The current price is high and volume is increasing."
|
||||||
|
abbreviated = optimizer.abbreviate_text(text)
|
||||||
|
|
||||||
|
# Should contain abbreviations
|
||||||
|
assert "cur" in abbreviated or "P" in abbreviated
|
||||||
|
assert len(abbreviated) <= len(text)
|
||||||
|
|
||||||
|
def test_abbreviate_text_aggressive(self):
|
||||||
|
"""Test aggressive text abbreviation."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
text = "The price is increasing and the volume is high."
|
||||||
|
abbreviated = optimizer.abbreviate_text(text, aggressive=True)
|
||||||
|
|
||||||
|
# Should be shorter
|
||||||
|
assert len(abbreviated) < len(text)
|
||||||
|
|
||||||
|
# Should have removed articles
|
||||||
|
assert "the" not in abbreviated.lower()
|
||||||
|
|
||||||
|
def test_build_compressed_prompt(self):
|
||||||
|
"""Test compressed prompt building."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
market_data = {
|
||||||
|
"stock_code": "005930",
|
||||||
|
"current_price": 75000,
|
||||||
|
"market_name": "Korean stock market",
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt = optimizer.build_compressed_prompt(market_data)
|
||||||
|
|
||||||
|
# Should be much shorter than original
|
||||||
|
assert len(prompt) < 300
|
||||||
|
assert "005930" in prompt
|
||||||
|
assert "75000" in prompt
|
||||||
|
|
||||||
|
def test_build_compressed_prompt_no_instructions(self):
|
||||||
|
"""Test compressed prompt without instructions."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
market_data = {
|
||||||
|
"stock_code": "AAPL",
|
||||||
|
"current_price": 150.5,
|
||||||
|
"market_name": "United States",
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt = optimizer.build_compressed_prompt(market_data, include_instructions=False)
|
||||||
|
|
||||||
|
# Should be very short (data only)
|
||||||
|
assert len(prompt) < 100
|
||||||
|
assert "AAPL" in prompt
|
||||||
|
|
||||||
|
def test_truncate_context(self):
|
||||||
|
"""Test context truncation."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
context = {
|
||||||
|
"price": 100.5,
|
||||||
|
"volume": 1000000,
|
||||||
|
"sentiment": 0.8,
|
||||||
|
"extra_data": "Some long text that should be truncated",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Truncate to small budget
|
||||||
|
truncated = optimizer.truncate_context(context, max_tokens=10)
|
||||||
|
|
||||||
|
# Should have fewer keys
|
||||||
|
assert len(truncated) <= len(context)
|
||||||
|
|
||||||
|
def test_truncate_context_with_priority(self):
|
||||||
|
"""Test context truncation with priority keys."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
context = {
|
||||||
|
"price": 100.5,
|
||||||
|
"volume": 1000000,
|
||||||
|
"sentiment": 0.8,
|
||||||
|
"extra_data": "Some data",
|
||||||
|
}
|
||||||
|
|
||||||
|
priority_keys = ["price", "sentiment"]
|
||||||
|
truncated = optimizer.truncate_context(context, max_tokens=20, priority_keys=priority_keys)
|
||||||
|
|
||||||
|
# Priority keys should be included
|
||||||
|
assert "price" in truncated
|
||||||
|
assert "sentiment" in truncated
|
||||||
|
|
||||||
|
def test_calculate_compression_ratio(self):
|
||||||
|
"""Test compression ratio calculation."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
original = "This is a very long piece of text that should be compressed significantly."
|
||||||
|
compressed = "Short text"
|
||||||
|
|
||||||
|
ratio = optimizer.calculate_compression_ratio(original, compressed)
|
||||||
|
|
||||||
|
# Ratio should be > 1 (original is longer)
|
||||||
|
assert ratio > 1.0
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Context Selector Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextSelector:
|
||||||
|
"""Tests for ContextSelector."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def store(self):
|
||||||
|
"""Create in-memory ContextStore."""
|
||||||
|
conn = sqlite3.connect(":memory:")
|
||||||
|
# Create tables
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE context_metadata (
|
||||||
|
layer TEXT PRIMARY KEY,
|
||||||
|
description TEXT,
|
||||||
|
retention_days INTEGER,
|
||||||
|
aggregation_source TEXT
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE contexts (
|
||||||
|
layer TEXT,
|
||||||
|
timeframe TEXT,
|
||||||
|
key TEXT,
|
||||||
|
value TEXT,
|
||||||
|
created_at TEXT,
|
||||||
|
updated_at TEXT,
|
||||||
|
PRIMARY KEY (layer, timeframe, key)
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
return ContextStore(conn)
|
||||||
|
|
||||||
|
def test_select_layers_normal(self, store):
|
||||||
|
"""Test layer selection for normal decisions."""
|
||||||
|
selector = ContextSelector(store)
|
||||||
|
|
||||||
|
layers = selector.select_layers(DecisionType.NORMAL)
|
||||||
|
|
||||||
|
# Should only select L7 (real-time)
|
||||||
|
assert layers == [ContextLayer.L7_REALTIME]
|
||||||
|
|
||||||
|
def test_select_layers_strategic(self, store):
|
||||||
|
"""Test layer selection for strategic decisions."""
|
||||||
|
selector = ContextSelector(store)
|
||||||
|
|
||||||
|
layers = selector.select_layers(DecisionType.STRATEGIC)
|
||||||
|
|
||||||
|
# Should select L7 + L6 + L5
|
||||||
|
assert ContextLayer.L7_REALTIME in layers
|
||||||
|
assert ContextLayer.L6_DAILY in layers
|
||||||
|
assert ContextLayer.L5_WEEKLY in layers
|
||||||
|
assert len(layers) == 3
|
||||||
|
|
||||||
|
def test_select_layers_major_event(self, store):
|
||||||
|
"""Test layer selection for major events."""
|
||||||
|
selector = ContextSelector(store)
|
||||||
|
|
||||||
|
layers = selector.select_layers(DecisionType.MAJOR_EVENT)
|
||||||
|
|
||||||
|
# Should select all layers
|
||||||
|
assert len(layers) == 7
|
||||||
|
assert ContextLayer.L1_LEGACY in layers
|
||||||
|
assert ContextLayer.L7_REALTIME in layers
|
||||||
|
|
||||||
|
def test_score_layer_relevance(self, store):
|
||||||
|
"""Test layer relevance scoring."""
|
||||||
|
selector = ContextSelector(store)
|
||||||
|
|
||||||
|
# Add some data first so scores aren't penalized
|
||||||
|
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
|
||||||
|
store.set_context(ContextLayer.L1_LEGACY, "legacy", "lesson", "test")
|
||||||
|
|
||||||
|
# L7 should have high score for normal decisions
|
||||||
|
score = selector.score_layer_relevance(ContextLayer.L7_REALTIME, DecisionType.NORMAL)
|
||||||
|
assert score == 1.0
|
||||||
|
|
||||||
|
# L1 should have low score for normal decisions
|
||||||
|
score = selector.score_layer_relevance(ContextLayer.L1_LEGACY, DecisionType.NORMAL)
|
||||||
|
assert score == 0.0
|
||||||
|
|
||||||
|
# L1 should have high score for major events
|
||||||
|
score = selector.score_layer_relevance(ContextLayer.L1_LEGACY, DecisionType.MAJOR_EVENT)
|
||||||
|
assert score == 1.0
|
||||||
|
|
||||||
|
def test_select_with_scoring(self, store):
|
||||||
|
"""Test selection with relevance scoring."""
|
||||||
|
selector = ContextSelector(store)
|
||||||
|
|
||||||
|
# Add data so layers aren't penalized
|
||||||
|
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
|
||||||
|
|
||||||
|
selection = selector.select_with_scoring(DecisionType.NORMAL, min_score=0.5)
|
||||||
|
|
||||||
|
# Should only select high-relevance layers
|
||||||
|
assert len(selection.layers) >= 1
|
||||||
|
assert ContextLayer.L7_REALTIME in selection.layers
|
||||||
|
assert all(selection.relevance_scores[layer] >= 0.5 for layer in selection.layers)
|
||||||
|
|
||||||
|
def test_get_context_data(self, store):
|
||||||
|
"""Test context data retrieval."""
|
||||||
|
selector = ContextSelector(store)
|
||||||
|
|
||||||
|
# Add some test data
|
||||||
|
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
|
||||||
|
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "volume", 1000000)
|
||||||
|
|
||||||
|
context_data = selector.get_context_data([ContextLayer.L7_REALTIME])
|
||||||
|
|
||||||
|
# Should retrieve data
|
||||||
|
assert "L7_REALTIME" in context_data
|
||||||
|
assert "price" in context_data["L7_REALTIME"]
|
||||||
|
assert context_data["L7_REALTIME"]["price"] == 100.5
|
||||||
|
|
||||||
|
def test_estimate_context_tokens(self, store):
|
||||||
|
"""Test context token estimation."""
|
||||||
|
selector = ContextSelector(store)
|
||||||
|
|
||||||
|
context_data = {
|
||||||
|
"L7_REALTIME": {"price": 100.5, "volume": 1000000},
|
||||||
|
"L6_DAILY": {"avg_price": 99.8, "avg_volume": 950000},
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens = selector.estimate_context_tokens(context_data)
|
||||||
|
|
||||||
|
# Should estimate tokens
|
||||||
|
assert tokens > 0
|
||||||
|
|
||||||
|
def test_optimize_context_for_budget(self, store):
|
||||||
|
"""Test context optimization for token budget."""
|
||||||
|
selector = ContextSelector(store)
|
||||||
|
|
||||||
|
# Add test data
|
||||||
|
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
|
||||||
|
|
||||||
|
# Get optimized context within budget
|
||||||
|
context = selector.optimize_context_for_budget(DecisionType.NORMAL, max_tokens=50)
|
||||||
|
|
||||||
|
# Should return data within budget
|
||||||
|
tokens = selector.estimate_context_tokens(context)
|
||||||
|
assert tokens <= 50
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Context Summarizer Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextSummarizer:
    """Tests for ContextSummarizer."""

    @pytest.fixture
    def store(self):
        """Yield a ContextStore backed by a fresh in-memory SQLite database.

        The ``context_metadata`` and ``contexts`` tables are created by hand
        before the store is built. The connection is closed once the test
        finishes (or if setup fails) so no SQLite handle outlives the test;
        Python 3.13+ reports unclosed connections with a ResourceWarning.
        """
        conn = sqlite3.connect(":memory:")
        try:
            conn.execute(
                """
                CREATE TABLE context_metadata (
                    layer TEXT PRIMARY KEY,
                    description TEXT,
                    retention_days INTEGER,
                    aggregation_source TEXT
                )
                """
            )
            conn.execute(
                """
                CREATE TABLE contexts (
                    layer TEXT,
                    timeframe TEXT,
                    key TEXT,
                    value TEXT,
                    created_at TEXT,
                    updated_at TEXT,
                    PRIMARY KEY (layer, timeframe, key)
                )
                """
            )
            conn.commit()
            yield ContextStore(conn)
        finally:
            conn.close()

    def test_summarize_numeric_values(self, store):
        """Summarizing five evenly spaced values yields the expected stats."""
        summarizer = ContextSummarizer(store)

        values = [10.0, 20.0, 30.0, 40.0, 50.0]
        stats = summarizer.summarize_numeric_values(values)

        assert isinstance(stats, SummaryStats)
        assert stats.count == 5
        assert stats.mean == 30.0
        assert stats.min == 10.0
        assert stats.max == 50.0
        # Only presence of a standard deviation is checked, not its value.
        assert stats.std is not None

    def test_summarize_numeric_values_trend(self, store):
        """Rising, falling and near-constant series map to up/down/flat."""
        summarizer = ContextSummarizer(store)

        # Uptrend
        values_up = [10.0, 15.0, 20.0, 25.0, 30.0, 35.0]
        stats_up = summarizer.summarize_numeric_values(values_up)
        assert stats_up.trend == "up"

        # Downtrend
        values_down = [35.0, 30.0, 25.0, 20.0, 15.0, 10.0]
        stats_down = summarizer.summarize_numeric_values(values_down)
        assert stats_down.trend == "down"

        # Flat: values oscillate by +/-0.1 around 20.0
        values_flat = [20.0, 20.1, 19.9, 20.0, 20.1, 19.9]
        stats_flat = summarizer.summarize_numeric_values(values_flat)
        assert stats_flat.trend == "flat"

    def test_summarize_layer(self, store):
        """A layer holding stored entries reports a positive entry count."""
        summarizer = ContextSummarizer(store)

        # Add test data
        store.set_context(ContextLayer.L6_DAILY, "2026-02-04", "price", 100.5)
        store.set_context(ContextLayer.L6_DAILY, "2026-02-04", "volume", 1000000)

        summary = summarizer.summarize_layer(ContextLayer.L6_DAILY)

        # Should have summary
        assert "total_entries" in summary
        assert summary["total_entries"] > 0

    def test_create_compact_summary(self, store):
        """The compact summary contains a section for the populated layer."""
        summarizer = ContextSummarizer(store)

        # Add test data (only the real-time layer is populated)
        store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)

        layers = [ContextLayer.L7_REALTIME, ContextLayer.L6_DAILY]
        summary = summarizer.create_compact_summary(layers, top_n_metrics=3)

        # Should have summaries for layers
        assert "L7_REALTIME" in summary

    def test_format_summary_for_prompt(self, store):
        """Formatting a summary yields text naming the layer and its values."""
        summarizer = ContextSummarizer(store)

        summary = {
            "L7_REALTIME": {
                "price": {"avg": 100.5, "trend": "up"},
                "volume": {"avg": 1000000, "trend": "flat"},
            }
        }

        formatted = summarizer.format_summary_for_prompt(summary)

        # Should be formatted string
        assert isinstance(formatted, str)
        assert "L7_REALTIME" in formatted
        # "100.5" is a prefix of "100.50", so this single substring check
        # accepts both the one- and two-decimal renderings of the price.
        assert "100.5" in formatted
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Decision Cache Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestDecisionCache:
    """Exercise DecisionCache lookups, expiry, eviction, invalidation, metrics."""

    @staticmethod
    def _hold_decision():
        """Build the HOLD decision that most tests store in the cache."""
        return TradeDecision(action="HOLD", confidence=50, rationale="Test")

    @staticmethod
    def _quote(stock_code="005930", current_price=75000):
        """Build the minimal market-data dict used as cache lookup input."""
        return {"stock_code": stock_code, "current_price": current_price}

    def test_cache_init(self):
        """Constructor arguments are exposed as attributes."""
        cache = DecisionCache(ttl_seconds=60, max_size=100)

        assert (cache.ttl_seconds, cache.max_size) == (60, 100)

    def test_cache_miss(self):
        """Looking up unseen market data returns None and counts one miss."""
        cache = DecisionCache()

        assert cache.get(self._quote()) is None

        stats = cache.get_metrics()
        assert stats.cache_misses == 1
        assert stats.cache_hits == 0

    def test_cache_hit(self):
        """A stored decision is returned for the same market data."""
        cache = DecisionCache()
        quote = self._quote()

        cache.set(quote, self._hold_decision())
        hit = cache.get(quote)

        assert hit is not None
        assert hit.action == "HOLD"
        assert hit.confidence == 50
        assert cache.get_metrics().cache_hits == 1

    def test_cache_ttl_expiration(self):
        """An entry is served before its TTL elapses and missed afterwards."""
        cache = DecisionCache(ttl_seconds=1)  # 1 second TTL
        quote = self._quote()

        cache.set(quote, self._hold_decision())

        # Immediately after the write the entry is still fresh.
        assert cache.get(quote) is not None

        # Sleep slightly longer than the TTL so the entry ages out.
        time.sleep(1.1)

        assert cache.get(quote) is None

    def test_cache_max_size(self):
        """Writing more entries than max_size evicts the surplus."""
        cache = DecisionCache(max_size=2)
        decision = self._hold_decision()

        # Three distinct inputs against a capacity of two.
        for idx in range(3):
            cache.set(self._quote(f"00{idx}", 1000 * idx), decision)

        stats = cache.get_metrics()
        assert stats.total_entries == 2
        assert stats.evictions == 1

    def test_invalidate_all(self):
        """invalidate() without arguments drops every entry."""
        cache = DecisionCache()
        decision = self._hold_decision()

        for idx in range(3):
            cache.set(self._quote(f"00{idx}", 1000), decision)

        removed = cache.invalidate()

        assert removed == 3
        assert cache.get_metrics().total_entries == 0

    def test_invalidate_by_stock(self):
        """invalidate(stock_code) leaves entries for other stocks in place."""
        cache = DecisionCache()
        decision = self._hold_decision()

        cache.set(self._quote("005930", 75000), decision)
        cache.set(self._quote("000660", 50000), decision)

        removed = cache.invalidate("005930")

        assert removed >= 1
        # The untouched stock must still be retrievable.
        assert cache.get(self._quote("000660", 50000)) is not None

    def test_cleanup_expired(self):
        """cleanup_expired() removes entries older than the TTL."""
        cache = DecisionCache(ttl_seconds=1)

        cache.set(self._quote(), self._hold_decision())

        # Sleep slightly longer than the TTL so the entry ages out.
        time.sleep(1.1)

        assert cache.cleanup_expired() == 1
        assert cache.get_metrics().total_entries == 0

    def test_should_cache_decision(self):
        """HOLD and high-confidence BUY are cacheable; low-confidence BUY is not."""
        cache = DecisionCache()

        expectations = [
            ("HOLD", 50, True),
            ("BUY", 95, True),
            ("BUY", 60, False),
        ]
        for action, confidence, expected in expectations:
            decision = TradeDecision(
                action=action, confidence=confidence, rationale="Test"
            )
            assert cache.should_cache_decision(decision) is expected

    def test_cache_hit_rate(self):
        """One miss followed by two hits gives a hit rate of 2/3."""
        cache = DecisionCache()
        quote = self._quote()

        cache.get(quote)  # miss: nothing stored yet
        cache.set(quote, self._hold_decision())
        cache.get(quote)  # hit
        cache.get(quote)  # hit

        stats = cache.get_metrics()
        assert stats.total_requests == 3
        assert stats.cache_hits == 2
        assert stats.cache_misses == 1
        assert stats.hit_rate == pytest.approx(2 / 3)

    def test_reset_metrics(self):
        """reset_metrics() zeroes the request, hit and miss counters."""
        cache = DecisionCache()
        quote = self._quote()

        # Two lookups so the counters are non-zero before the reset.
        for _ in range(2):
            cache.get(quote)

        cache.reset_metrics()

        stats = cache.get_metrics()
        assert stats.total_requests == 0
        assert stats.cache_hits == 0
        assert stats.cache_misses == 0
|
||||||
Reference in New Issue
Block a user