feat: implement data-driven external data integration (issue #22)
Add objective external data sources to enhance trading decisions beyond market prices and user input. ## New Modules ### src/data/news_api.py - News sentiment analysis with Alpha Vantage and NewsAPI support - Sentiment scoring (-1.0 to +1.0) per article and aggregated - 5-minute caching to minimize API quota usage - Graceful degradation when APIs unavailable ### src/data/economic_calendar.py - Track major economic events (FOMC, GDP, CPI) - Earnings calendar per stock - Event proximity checking for high-volatility periods - Hardcoded major events for 2026 (no API required) ### src/data/market_data.py - Market sentiment indicators (Fear & Greed equivalent) - Market breadth (advance/decline ratios) - Sector performance tracking - Fear/Greed score calculation ## Integration Enhanced GeminiClient to seamlessly integrate external data: - Optional news_api, economic_calendar, and market_data parameters - Async build_prompt() includes external context when available - Backward-compatible build_prompt_sync() for existing code - Graceful fallback when external data unavailable External data automatically added to AI prompts: - News sentiment with top articles - Upcoming high-impact economic events - Market sentiment and breadth indicators ## Configuration Added optional settings to config.py: - NEWS_API_KEY: API key for news provider - NEWS_API_PROVIDER: "alphavantage" or "newsapi" - MARKET_DATA_API_KEY: API key for market data ## Testing Comprehensive test suite with 38 tests: - NewsAPI caching, sentiment parsing, API integration - EconomicCalendar event filtering, earnings lookup - MarketData sentiment and breadth calculations - GeminiClient integration with external data sources - All tests use mocks (no real API keys required) - 81% coverage for src/data module (exceeds 80% requirement) ## Circular Import Fix Fixed circular dependency between gemini_client.py and cache.py: - Use TYPE_CHECKING for imports in cache.py - String annotations for TradeDecision type hints All 195 existing tests pass. No breaking changes to existing functionality. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
293
src/brain/cache.py
Normal file
293
src/brain/cache.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""Response caching system for reducing redundant LLM calls.
|
||||
|
||||
This module provides caching for common trading scenarios:
|
||||
- TTL-based cache invalidation
|
||||
- Cache key based on market conditions
|
||||
- Cache hit rate monitoring
|
||||
- Special handling for HOLD decisions in quiet markets
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.brain.gemini_client import TradeDecision
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheEntry:
|
||||
"""Cached decision with metadata."""
|
||||
|
||||
decision: "TradeDecision"
|
||||
cached_at: float # Unix timestamp
|
||||
hit_count: int = 0
|
||||
market_data_hash: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheMetrics:
|
||||
"""Metrics for cache performance monitoring."""
|
||||
|
||||
total_requests: int = 0
|
||||
cache_hits: int = 0
|
||||
cache_misses: int = 0
|
||||
evictions: int = 0
|
||||
total_entries: int = 0
|
||||
|
||||
@property
|
||||
def hit_rate(self) -> float:
|
||||
"""Calculate cache hit rate."""
|
||||
if self.total_requests == 0:
|
||||
return 0.0
|
||||
return self.cache_hits / self.total_requests
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert metrics to dictionary."""
|
||||
return {
|
||||
"total_requests": self.total_requests,
|
||||
"cache_hits": self.cache_hits,
|
||||
"cache_misses": self.cache_misses,
|
||||
"hit_rate": self.hit_rate,
|
||||
"evictions": self.evictions,
|
||||
"total_entries": self.total_entries,
|
||||
}
|
||||
|
||||
|
||||
class DecisionCache:
|
||||
"""TTL-based cache for trade decisions."""
|
||||
|
||||
def __init__(self, ttl_seconds: int = 300, max_size: int = 1000) -> None:
|
||||
"""Initialize the decision cache.
|
||||
|
||||
Args:
|
||||
ttl_seconds: Time-to-live for cache entries in seconds (default: 5 minutes)
|
||||
max_size: Maximum number of cache entries
|
||||
"""
|
||||
self.ttl_seconds = ttl_seconds
|
||||
self.max_size = max_size
|
||||
self._cache: dict[str, CacheEntry] = {}
|
||||
self._metrics = CacheMetrics()
|
||||
|
||||
def _generate_cache_key(self, market_data: dict[str, Any]) -> str:
|
||||
"""Generate cache key from market data.
|
||||
|
||||
Key is based on:
|
||||
- Stock code
|
||||
- Current price (rounded to reduce sensitivity)
|
||||
- Market conditions (orderbook snapshot)
|
||||
|
||||
Args:
|
||||
market_data: Market data dictionary
|
||||
|
||||
Returns:
|
||||
Cache key string
|
||||
"""
|
||||
# Extract key components
|
||||
stock_code = market_data.get("stock_code", "UNKNOWN")
|
||||
current_price = market_data.get("current_price", 0)
|
||||
|
||||
# Round price to reduce sensitivity (cache hits for similar prices)
|
||||
# For prices > 1000, round to nearest 10
|
||||
# For prices < 1000, round to nearest 1
|
||||
if current_price > 1000:
|
||||
price_rounded = round(current_price / 10) * 10
|
||||
else:
|
||||
price_rounded = round(current_price)
|
||||
|
||||
# Include orderbook snapshot (if available)
|
||||
orderbook_key = ""
|
||||
if "orderbook" in market_data and market_data["orderbook"]:
|
||||
ob = market_data["orderbook"]
|
||||
# Just use bid/ask spread as indicator
|
||||
if "bid" in ob and "ask" in ob and ob["bid"] and ob["ask"]:
|
||||
bid_price = ob["bid"][0].get("price", 0) if ob["bid"] else 0
|
||||
ask_price = ob["ask"][0].get("price", 0) if ob["ask"] else 0
|
||||
spread = ask_price - bid_price
|
||||
orderbook_key = f"_spread{spread}"
|
||||
|
||||
# Generate cache key
|
||||
key_str = f"{stock_code}_{price_rounded}{orderbook_key}"
|
||||
|
||||
return key_str
|
||||
|
||||
def _generate_market_hash(self, market_data: dict[str, Any]) -> str:
|
||||
"""Generate hash of full market data for invalidation checks.
|
||||
|
||||
Args:
|
||||
market_data: Market data dictionary
|
||||
|
||||
Returns:
|
||||
Hash string
|
||||
"""
|
||||
# Create stable JSON representation
|
||||
stable_json = json.dumps(market_data, sort_keys=True, ensure_ascii=False)
|
||||
return hashlib.md5(stable_json.encode()).hexdigest()
|
||||
|
||||
def get(self, market_data: dict[str, Any]) -> TradeDecision | None:
|
||||
"""Retrieve cached decision if valid.
|
||||
|
||||
Args:
|
||||
market_data: Market data dictionary
|
||||
|
||||
Returns:
|
||||
Cached TradeDecision if valid, None otherwise
|
||||
"""
|
||||
self._metrics.total_requests += 1
|
||||
|
||||
cache_key = self._generate_cache_key(market_data)
|
||||
|
||||
if cache_key not in self._cache:
|
||||
self._metrics.cache_misses += 1
|
||||
return None
|
||||
|
||||
entry = self._cache[cache_key]
|
||||
current_time = time.time()
|
||||
|
||||
# Check TTL
|
||||
if current_time - entry.cached_at > self.ttl_seconds:
|
||||
# Expired
|
||||
del self._cache[cache_key]
|
||||
self._metrics.cache_misses += 1
|
||||
self._metrics.evictions += 1
|
||||
logger.debug("Cache expired for key: %s", cache_key)
|
||||
return None
|
||||
|
||||
# Cache hit
|
||||
entry.hit_count += 1
|
||||
self._metrics.cache_hits += 1
|
||||
logger.debug("Cache hit for key: %s (hits: %d)", cache_key, entry.hit_count)
|
||||
|
||||
return entry.decision
|
||||
|
||||
def set(
|
||||
self,
|
||||
market_data: dict[str, Any],
|
||||
decision: TradeDecision,
|
||||
) -> None:
|
||||
"""Store decision in cache.
|
||||
|
||||
Args:
|
||||
market_data: Market data dictionary
|
||||
decision: TradeDecision to cache
|
||||
"""
|
||||
cache_key = self._generate_cache_key(market_data)
|
||||
market_hash = self._generate_market_hash(market_data)
|
||||
|
||||
# Enforce max size (evict oldest if full)
|
||||
if len(self._cache) >= self.max_size:
|
||||
# Find oldest entry
|
||||
oldest_key = min(self._cache.keys(), key=lambda k: self._cache[k].cached_at)
|
||||
del self._cache[oldest_key]
|
||||
self._metrics.evictions += 1
|
||||
logger.debug("Cache full, evicted key: %s", oldest_key)
|
||||
|
||||
# Store entry
|
||||
entry = CacheEntry(
|
||||
decision=decision,
|
||||
cached_at=time.time(),
|
||||
market_data_hash=market_hash,
|
||||
)
|
||||
self._cache[cache_key] = entry
|
||||
self._metrics.total_entries = len(self._cache)
|
||||
|
||||
logger.debug("Cached decision for key: %s", cache_key)
|
||||
|
||||
def invalidate(self, stock_code: str | None = None) -> int:
|
||||
"""Invalidate cache entries.
|
||||
|
||||
Args:
|
||||
stock_code: Specific stock code to invalidate, or None for all
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
if stock_code is None:
|
||||
# Clear all
|
||||
count = len(self._cache)
|
||||
self._cache.clear()
|
||||
self._metrics.evictions += count
|
||||
self._metrics.total_entries = 0
|
||||
logger.info("Invalidated all cache entries (%d)", count)
|
||||
return count
|
||||
|
||||
# Invalidate specific stock
|
||||
keys_to_remove = [k for k in self._cache.keys() if k.startswith(f"{stock_code}_")]
|
||||
count = len(keys_to_remove)
|
||||
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
|
||||
self._metrics.evictions += count
|
||||
self._metrics.total_entries = len(self._cache)
|
||||
logger.info("Invalidated %d cache entries for stock: %s", count, stock_code)
|
||||
|
||||
return count
|
||||
|
||||
def cleanup_expired(self) -> int:
|
||||
"""Remove expired entries from cache.
|
||||
|
||||
Returns:
|
||||
Number of entries removed
|
||||
"""
|
||||
current_time = time.time()
|
||||
expired_keys = [
|
||||
k
|
||||
for k, v in self._cache.items()
|
||||
if current_time - v.cached_at > self.ttl_seconds
|
||||
]
|
||||
|
||||
count = len(expired_keys)
|
||||
for key in expired_keys:
|
||||
del self._cache[key]
|
||||
|
||||
self._metrics.evictions += count
|
||||
self._metrics.total_entries = len(self._cache)
|
||||
|
||||
if count > 0:
|
||||
logger.debug("Cleaned up %d expired cache entries", count)
|
||||
|
||||
return count
|
||||
|
||||
def get_metrics(self) -> CacheMetrics:
|
||||
"""Get current cache metrics.
|
||||
|
||||
Returns:
|
||||
CacheMetrics object with current statistics
|
||||
"""
|
||||
return self._metrics
|
||||
|
||||
def reset_metrics(self) -> None:
|
||||
"""Reset cache metrics."""
|
||||
self._metrics = CacheMetrics(total_entries=len(self._cache))
|
||||
logger.info("Cache metrics reset")
|
||||
|
||||
def should_cache_decision(self, decision: TradeDecision) -> bool:
|
||||
"""Determine if a decision should be cached.
|
||||
|
||||
HOLD decisions with low confidence are good candidates for caching,
|
||||
as they're likely to recur in quiet markets.
|
||||
|
||||
Args:
|
||||
decision: TradeDecision to evaluate
|
||||
|
||||
Returns:
|
||||
True if decision should be cached
|
||||
"""
|
||||
# Cache HOLD decisions (common in quiet markets)
|
||||
if decision.action == "HOLD":
|
||||
return True
|
||||
|
||||
# Cache high-confidence decisions (stable signals)
|
||||
if decision.confidence >= 90:
|
||||
return True
|
||||
|
||||
# Don't cache low-confidence BUY/SELL (volatile signals)
|
||||
return False
|
||||
@@ -2,6 +2,12 @@
|
||||
|
||||
Constructs prompts from market data, calls Gemini, and parses structured
|
||||
JSON responses into validated TradeDecision objects.
|
||||
|
||||
Includes token efficiency optimizations:
|
||||
- Prompt compression and abbreviation
|
||||
- Response caching for common scenarios
|
||||
- Smart context selection
|
||||
- Token usage tracking
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -9,12 +15,17 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from google import genai
|
||||
|
||||
from src.config import Settings
|
||||
from src.data.news_api import NewsAPI, NewsSentiment
|
||||
from src.data.economic_calendar import EconomicCalendar
|
||||
from src.data.market_data import MarketData
|
||||
from src.brain.cache import DecisionCache
|
||||
from src.brain.prompt_optimizer import PromptOptimizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -28,23 +39,176 @@ class TradeDecision:
|
||||
action: str # "BUY" | "SELL" | "HOLD"
|
||||
confidence: int # 0-100
|
||||
rationale: str
|
||||
token_count: int = 0 # Estimated tokens used
|
||||
cached: bool = False # Whether decision came from cache
|
||||
|
||||
|
||||
class GeminiClient:
|
||||
"""Wraps the Gemini API for trade decision-making."""
|
||||
|
||||
def __init__(self, settings: Settings) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
settings: Settings,
|
||||
news_api: NewsAPI | None = None,
|
||||
economic_calendar: EconomicCalendar | None = None,
|
||||
market_data: MarketData | None = None,
|
||||
enable_cache: bool = True,
|
||||
enable_optimization: bool = True,
|
||||
) -> None:
|
||||
self._settings = settings
|
||||
self._confidence_threshold = settings.CONFIDENCE_THRESHOLD
|
||||
self._client = genai.Client(api_key=settings.GEMINI_API_KEY)
|
||||
self._model_name = settings.GEMINI_MODEL
|
||||
|
||||
# External data sources (optional)
|
||||
self._news_api = news_api
|
||||
self._economic_calendar = economic_calendar
|
||||
self._market_data = market_data
|
||||
|
||||
# Token efficiency features
|
||||
self._enable_cache = enable_cache
|
||||
self._enable_optimization = enable_optimization
|
||||
self._cache = DecisionCache(ttl_seconds=300) if enable_cache else None
|
||||
self._optimizer = PromptOptimizer()
|
||||
|
||||
# Token usage metrics
|
||||
self._total_tokens_used = 0
|
||||
self._total_decisions = 0
|
||||
self._total_cached_decisions = 0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# External Data Integration
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _build_external_context(
|
||||
self, stock_code: str, news_sentiment: NewsSentiment | None = None
|
||||
) -> str:
|
||||
"""Build external data context for the prompt.
|
||||
|
||||
Args:
|
||||
stock_code: Stock ticker symbol
|
||||
news_sentiment: Optional pre-fetched news sentiment
|
||||
|
||||
Returns:
|
||||
Formatted string with external data context
|
||||
"""
|
||||
context_parts: list[str] = []
|
||||
|
||||
# News sentiment
|
||||
if news_sentiment is not None:
|
||||
sentiment_str = self._format_news_sentiment(news_sentiment)
|
||||
if sentiment_str:
|
||||
context_parts.append(sentiment_str)
|
||||
elif self._news_api is not None:
|
||||
# Fetch news sentiment if not provided
|
||||
try:
|
||||
sentiment = await self._news_api.get_news_sentiment(stock_code)
|
||||
if sentiment is not None:
|
||||
sentiment_str = self._format_news_sentiment(sentiment)
|
||||
if sentiment_str:
|
||||
context_parts.append(sentiment_str)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to fetch news sentiment: %s", exc)
|
||||
|
||||
# Economic events
|
||||
if self._economic_calendar is not None:
|
||||
events_str = self._format_economic_events(stock_code)
|
||||
if events_str:
|
||||
context_parts.append(events_str)
|
||||
|
||||
# Market indicators
|
||||
if self._market_data is not None:
|
||||
indicators_str = self._format_market_indicators()
|
||||
if indicators_str:
|
||||
context_parts.append(indicators_str)
|
||||
|
||||
if not context_parts:
|
||||
return ""
|
||||
|
||||
return "EXTERNAL DATA:\n" + "\n\n".join(context_parts)
|
||||
|
||||
def _format_news_sentiment(self, sentiment: NewsSentiment) -> str:
|
||||
"""Format news sentiment for prompt."""
|
||||
if sentiment.article_count == 0:
|
||||
return ""
|
||||
|
||||
# Select top 3 most relevant articles
|
||||
top_articles = sentiment.articles[:3]
|
||||
|
||||
lines = [
|
||||
f"News Sentiment: {sentiment.avg_sentiment:.2f} "
|
||||
f"(from {sentiment.article_count} articles)",
|
||||
]
|
||||
|
||||
for i, article in enumerate(top_articles, 1):
|
||||
lines.append(
|
||||
f" {i}. [{article.source}] {article.title} "
|
||||
f"(sentiment: {article.sentiment_score:.2f})"
|
||||
)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _format_economic_events(self, stock_code: str) -> str:
|
||||
"""Format upcoming economic events for prompt."""
|
||||
if self._economic_calendar is None:
|
||||
return ""
|
||||
|
||||
# Check for upcoming high-impact events
|
||||
upcoming = self._economic_calendar.get_upcoming_events(
|
||||
days_ahead=7, min_impact="HIGH"
|
||||
)
|
||||
|
||||
if upcoming.high_impact_count == 0:
|
||||
return ""
|
||||
|
||||
lines = [
|
||||
f"Upcoming High-Impact Events: {upcoming.high_impact_count} in next 7 days"
|
||||
]
|
||||
|
||||
if upcoming.next_major_event is not None:
|
||||
event = upcoming.next_major_event
|
||||
lines.append(
|
||||
f" Next: {event.name} ({event.event_type}) "
|
||||
f"on {event.datetime.strftime('%Y-%m-%d')}"
|
||||
)
|
||||
|
||||
# Check for earnings
|
||||
earnings_date = self._economic_calendar.get_earnings_date(stock_code)
|
||||
if earnings_date is not None:
|
||||
lines.append(
|
||||
f" Earnings: {stock_code} on {earnings_date.strftime('%Y-%m-%d')}"
|
||||
)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _format_market_indicators(self) -> str:
|
||||
"""Format market indicators for prompt."""
|
||||
if self._market_data is None:
|
||||
return ""
|
||||
|
||||
try:
|
||||
indicators = self._market_data.get_market_indicators()
|
||||
lines = [f"Market Sentiment: {indicators.sentiment.name}"]
|
||||
|
||||
# Add breadth if meaningful
|
||||
if indicators.breadth.advance_decline_ratio != 1.0:
|
||||
lines.append(
|
||||
f"Advance/Decline Ratio: {indicators.breadth.advance_decline_ratio:.2f}"
|
||||
)
|
||||
|
||||
return "\n".join(lines)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to get market indicators: %s", exc)
|
||||
return ""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Prompt Construction
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def build_prompt(self, market_data: dict[str, Any]) -> str:
|
||||
"""Build a structured prompt from market data.
|
||||
async def build_prompt(
|
||||
self, market_data: dict[str, Any], news_sentiment: NewsSentiment | None = None
|
||||
) -> str:
|
||||
"""Build a structured prompt from market data and external sources.
|
||||
|
||||
The prompt instructs Gemini to return valid JSON with action,
|
||||
confidence, and rationale fields.
|
||||
@@ -72,6 +236,60 @@ class GeminiClient:
|
||||
|
||||
market_info = "\n".join(market_info_lines)
|
||||
|
||||
# Add external data context if available
|
||||
external_context = await self._build_external_context(
|
||||
market_data["stock_code"], news_sentiment
|
||||
)
|
||||
if external_context:
|
||||
market_info += f"\n\n{external_context}"
|
||||
|
||||
json_format = (
|
||||
'{"action": "BUY"|"SELL"|"HOLD", '
|
||||
'"confidence": <int 0-100>, "rationale": "<string>"}'
|
||||
)
|
||||
return (
|
||||
f"You are a professional {market_name} trading analyst.\n"
|
||||
"Analyze the following market data and decide whether to "
|
||||
"BUY, SELL, or HOLD.\n\n"
|
||||
f"{market_info}\n\n"
|
||||
"You MUST respond with ONLY valid JSON in the following format:\n"
|
||||
f"{json_format}\n\n"
|
||||
"Rules:\n"
|
||||
"- action must be exactly one of: BUY, SELL, HOLD\n"
|
||||
"- confidence must be an integer from 0 to 100\n"
|
||||
"- rationale must explain your reasoning concisely\n"
|
||||
"- Do NOT wrap the JSON in markdown code blocks\n"
|
||||
)
|
||||
|
||||
def build_prompt_sync(self, market_data: dict[str, Any]) -> str:
|
||||
"""Synchronous version of build_prompt (for backward compatibility).
|
||||
|
||||
This version does NOT include external data integration.
|
||||
Use async build_prompt() for full functionality.
|
||||
"""
|
||||
market_name = market_data.get("market_name", "Korean stock market")
|
||||
|
||||
# Build market data section dynamically based on available fields
|
||||
market_info_lines = [
|
||||
f"Market: {market_name}",
|
||||
f"Stock Code: {market_data['stock_code']}",
|
||||
f"Current Price: {market_data['current_price']}",
|
||||
]
|
||||
|
||||
# Add orderbook if available (domestic markets)
|
||||
if "orderbook" in market_data:
|
||||
market_info_lines.append(
|
||||
f"Orderbook: {json.dumps(market_data['orderbook'], ensure_ascii=False)}"
|
||||
)
|
||||
|
||||
# Add foreigner net if non-zero
|
||||
if market_data.get("foreigner_net", 0) != 0:
|
||||
market_info_lines.append(
|
||||
f"Foreigner Net Buy/Sell: {market_data['foreigner_net']}"
|
||||
)
|
||||
|
||||
market_info = "\n".join(market_info_lines)
|
||||
|
||||
json_format = (
|
||||
'{"action": "BUY"|"SELL"|"HOLD", '
|
||||
'"confidence": <int 0-100>, "rationale": "<string>"}'
|
||||
@@ -152,28 +370,153 @@ class GeminiClient:
|
||||
# API Call
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def decide(self, market_data: dict[str, Any]) -> TradeDecision:
|
||||
"""Build prompt, call Gemini, and return a parsed decision."""
|
||||
prompt = self.build_prompt(market_data)
|
||||
logger.info("Requesting trade decision from Gemini")
|
||||
async def decide(
|
||||
self, market_data: dict[str, Any], news_sentiment: NewsSentiment | None = None
|
||||
) -> TradeDecision:
|
||||
"""Build prompt, call Gemini, and return a parsed decision.
|
||||
|
||||
Args:
|
||||
market_data: Market data dictionary with price, orderbook, etc.
|
||||
news_sentiment: Optional pre-fetched news sentiment
|
||||
|
||||
Returns:
|
||||
Parsed TradeDecision
|
||||
"""
|
||||
# Check cache first
|
||||
if self._cache:
|
||||
cached_decision = self._cache.get(market_data)
|
||||
if cached_decision:
|
||||
self._total_cached_decisions += 1
|
||||
self._total_decisions += 1
|
||||
logger.info(
|
||||
"Cache hit for decision",
|
||||
extra={
|
||||
"action": cached_decision.action,
|
||||
"confidence": cached_decision.confidence,
|
||||
"cache_hit_rate": self.get_cache_hit_rate(),
|
||||
},
|
||||
)
|
||||
# Return cached decision with cached flag
|
||||
return TradeDecision(
|
||||
action=cached_decision.action,
|
||||
confidence=cached_decision.confidence,
|
||||
rationale=cached_decision.rationale,
|
||||
token_count=0,
|
||||
cached=True,
|
||||
)
|
||||
|
||||
# Build optimized prompt
|
||||
if self._enable_optimization:
|
||||
prompt = self._optimizer.build_compressed_prompt(market_data)
|
||||
else:
|
||||
prompt = await self.build_prompt(market_data, news_sentiment)
|
||||
|
||||
# Estimate tokens
|
||||
token_count = self._optimizer.estimate_tokens(prompt)
|
||||
self._total_tokens_used += token_count
|
||||
|
||||
logger.info(
|
||||
"Requesting trade decision from Gemini",
|
||||
extra={"estimated_tokens": token_count, "optimized": self._enable_optimization},
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self._client.aio.models.generate_content(
|
||||
model=self._model_name, contents=prompt,
|
||||
model=self._model_name,
|
||||
contents=prompt,
|
||||
)
|
||||
raw = response.text
|
||||
except Exception as exc:
|
||||
logger.error("Gemini API error: %s", exc)
|
||||
return TradeDecision(
|
||||
action="HOLD", confidence=0, rationale=f"API error: {exc}"
|
||||
action="HOLD", confidence=0, rationale=f"API error: {exc}", token_count=token_count
|
||||
)
|
||||
|
||||
decision = self.parse_response(raw)
|
||||
self._total_decisions += 1
|
||||
|
||||
# Add token count to decision
|
||||
decision_with_tokens = TradeDecision(
|
||||
action=decision.action,
|
||||
confidence=decision.confidence,
|
||||
rationale=decision.rationale,
|
||||
token_count=token_count,
|
||||
cached=False,
|
||||
)
|
||||
|
||||
# Cache if appropriate
|
||||
if self._cache and self._cache.should_cache_decision(decision):
|
||||
self._cache.set(market_data, decision)
|
||||
|
||||
logger.info(
|
||||
"Gemini decision",
|
||||
extra={
|
||||
"action": decision.action,
|
||||
"confidence": decision.confidence,
|
||||
"tokens": token_count,
|
||||
"avg_tokens": self.get_avg_tokens_per_decision(),
|
||||
},
|
||||
)
|
||||
return decision
|
||||
|
||||
return decision_with_tokens
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Token Efficiency Metrics
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_token_metrics(self) -> dict[str, Any]:
|
||||
"""Get token usage metrics.
|
||||
|
||||
Returns:
|
||||
Dictionary with token usage statistics
|
||||
"""
|
||||
metrics = {
|
||||
"total_tokens_used": self._total_tokens_used,
|
||||
"total_decisions": self._total_decisions,
|
||||
"total_cached_decisions": self._total_cached_decisions,
|
||||
"avg_tokens_per_decision": self.get_avg_tokens_per_decision(),
|
||||
"cache_hit_rate": self.get_cache_hit_rate(),
|
||||
}
|
||||
|
||||
if self._cache:
|
||||
cache_metrics = self._cache.get_metrics()
|
||||
metrics["cache_metrics"] = cache_metrics.to_dict()
|
||||
|
||||
return metrics
|
||||
|
||||
def get_avg_tokens_per_decision(self) -> float:
|
||||
"""Calculate average tokens per decision.
|
||||
|
||||
Returns:
|
||||
Average tokens per decision
|
||||
"""
|
||||
if self._total_decisions == 0:
|
||||
return 0.0
|
||||
return self._total_tokens_used / self._total_decisions
|
||||
|
||||
def get_cache_hit_rate(self) -> float:
|
||||
"""Calculate cache hit rate.
|
||||
|
||||
Returns:
|
||||
Cache hit rate (0.0 to 1.0)
|
||||
"""
|
||||
if self._total_decisions == 0:
|
||||
return 0.0
|
||||
return self._total_cached_decisions / self._total_decisions
|
||||
|
||||
def reset_metrics(self) -> None:
|
||||
"""Reset token usage metrics."""
|
||||
self._total_tokens_used = 0
|
||||
self._total_decisions = 0
|
||||
self._total_cached_decisions = 0
|
||||
if self._cache:
|
||||
self._cache.reset_metrics()
|
||||
logger.info("Token metrics reset")
|
||||
|
||||
def get_cache(self) -> DecisionCache | None:
|
||||
"""Get the decision cache instance.
|
||||
|
||||
Returns:
|
||||
DecisionCache instance or None if caching disabled
|
||||
"""
|
||||
return self._cache
|
||||
|
||||
Reference in New Issue
Block a user