Some checks failed
CI / test (pull_request) Has been cancelled
decide() ignored market_data["prompt_override"], always building a generic trade-decision prompt. This caused pre_market_planner playbook generation to fail with JSONDecodeError on every market, falling back to defensive playbooks. Now prompt_override takes priority over both optimization and standard prompt building. Closes #143 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
760 lines
27 KiB
Python
760 lines
27 KiB
Python
"""Decision engine powered by Google Gemini.
|
|
|
|
Constructs prompts from market data, calls Gemini, and parses structured
|
|
JSON responses into validated TradeDecision objects.
|
|
|
|
Includes token efficiency optimizations:
|
|
- Prompt compression and abbreviation
|
|
- Response caching for common scenarios
|
|
- Smart context selection
|
|
- Token usage tracking and metrics
|
|
|
|
Includes external data integration:
|
|
- News sentiment analysis
|
|
- Economic calendar events
|
|
- Market indicators
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
from google import genai
|
|
|
|
from src.config import Settings
|
|
from src.data.news_api import NewsAPI, NewsSentiment
|
|
from src.data.economic_calendar import EconomicCalendar
|
|
from src.data.market_data import MarketData
|
|
from src.brain.cache import DecisionCache
|
|
from src.brain.prompt_optimizer import PromptOptimizer
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
VALID_ACTIONS = {"BUY", "SELL", "HOLD"}
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class TradeDecision:
|
|
"""Validated decision from the AI brain."""
|
|
|
|
action: str # "BUY" | "SELL" | "HOLD"
|
|
confidence: int # 0-100
|
|
rationale: str
|
|
token_count: int = 0 # Estimated tokens used
|
|
cached: bool = False # Whether decision came from cache
|
|
|
|
|
|
class GeminiClient:
|
|
"""Wraps the Gemini API for trade decision-making."""
|
|
|
|
def __init__(
|
|
self,
|
|
settings: Settings,
|
|
news_api: NewsAPI | None = None,
|
|
economic_calendar: EconomicCalendar | None = None,
|
|
market_data: MarketData | None = None,
|
|
enable_cache: bool = True,
|
|
enable_optimization: bool = True,
|
|
) -> None:
|
|
self._settings = settings
|
|
self._confidence_threshold = settings.CONFIDENCE_THRESHOLD
|
|
self._client = genai.Client(api_key=settings.GEMINI_API_KEY)
|
|
self._model_name = settings.GEMINI_MODEL
|
|
|
|
# External data sources (optional)
|
|
self._news_api = news_api
|
|
self._economic_calendar = economic_calendar
|
|
self._market_data = market_data
|
|
|
|
# Token efficiency features
|
|
self._enable_cache = enable_cache
|
|
self._enable_optimization = enable_optimization
|
|
self._cache = DecisionCache(ttl_seconds=300) if enable_cache else None
|
|
self._optimizer = PromptOptimizer()
|
|
|
|
# Token usage metrics
|
|
self._total_tokens_used = 0
|
|
self._total_decisions = 0
|
|
self._total_cached_decisions = 0
|
|
|
|
# ------------------------------------------------------------------
|
|
# External Data Integration
|
|
# ------------------------------------------------------------------
|
|
|
|
async def _build_external_context(
|
|
self, stock_code: str, news_sentiment: NewsSentiment | None = None
|
|
) -> str:
|
|
"""Build external data context for the prompt.
|
|
|
|
Args:
|
|
stock_code: Stock ticker symbol
|
|
news_sentiment: Optional pre-fetched news sentiment
|
|
|
|
Returns:
|
|
Formatted string with external data context
|
|
"""
|
|
context_parts: list[str] = []
|
|
|
|
# News sentiment
|
|
if news_sentiment is not None:
|
|
sentiment_str = self._format_news_sentiment(news_sentiment)
|
|
if sentiment_str:
|
|
context_parts.append(sentiment_str)
|
|
elif self._news_api is not None:
|
|
# Fetch news sentiment if not provided
|
|
try:
|
|
sentiment = await self._news_api.get_news_sentiment(stock_code)
|
|
if sentiment is not None:
|
|
sentiment_str = self._format_news_sentiment(sentiment)
|
|
if sentiment_str:
|
|
context_parts.append(sentiment_str)
|
|
except Exception as exc:
|
|
logger.warning("Failed to fetch news sentiment: %s", exc)
|
|
|
|
# Economic events
|
|
if self._economic_calendar is not None:
|
|
events_str = self._format_economic_events(stock_code)
|
|
if events_str:
|
|
context_parts.append(events_str)
|
|
|
|
# Market indicators
|
|
if self._market_data is not None:
|
|
indicators_str = self._format_market_indicators()
|
|
if indicators_str:
|
|
context_parts.append(indicators_str)
|
|
|
|
if not context_parts:
|
|
return ""
|
|
|
|
return "EXTERNAL DATA:\n" + "\n\n".join(context_parts)
|
|
|
|
def _format_news_sentiment(self, sentiment: NewsSentiment) -> str:
|
|
"""Format news sentiment for prompt."""
|
|
if sentiment.article_count == 0:
|
|
return ""
|
|
|
|
# Select top 3 most relevant articles
|
|
top_articles = sentiment.articles[:3]
|
|
|
|
lines = [
|
|
f"News Sentiment: {sentiment.avg_sentiment:.2f} "
|
|
f"(from {sentiment.article_count} articles)",
|
|
]
|
|
|
|
for i, article in enumerate(top_articles, 1):
|
|
lines.append(
|
|
f" {i}. [{article.source}] {article.title} "
|
|
f"(sentiment: {article.sentiment_score:.2f})"
|
|
)
|
|
|
|
return "\n".join(lines)
|
|
|
|
def _format_economic_events(self, stock_code: str) -> str:
|
|
"""Format upcoming economic events for prompt."""
|
|
if self._economic_calendar is None:
|
|
return ""
|
|
|
|
# Check for upcoming high-impact events
|
|
upcoming = self._economic_calendar.get_upcoming_events(
|
|
days_ahead=7, min_impact="HIGH"
|
|
)
|
|
|
|
if upcoming.high_impact_count == 0:
|
|
return ""
|
|
|
|
lines = [
|
|
f"Upcoming High-Impact Events: {upcoming.high_impact_count} in next 7 days"
|
|
]
|
|
|
|
if upcoming.next_major_event is not None:
|
|
event = upcoming.next_major_event
|
|
lines.append(
|
|
f" Next: {event.name} ({event.event_type}) "
|
|
f"on {event.datetime.strftime('%Y-%m-%d')}"
|
|
)
|
|
|
|
# Check for earnings
|
|
earnings_date = self._economic_calendar.get_earnings_date(stock_code)
|
|
if earnings_date is not None:
|
|
lines.append(
|
|
f" Earnings: {stock_code} on {earnings_date.strftime('%Y-%m-%d')}"
|
|
)
|
|
|
|
return "\n".join(lines)
|
|
|
|
def _format_market_indicators(self) -> str:
|
|
"""Format market indicators for prompt."""
|
|
if self._market_data is None:
|
|
return ""
|
|
|
|
try:
|
|
indicators = self._market_data.get_market_indicators()
|
|
lines = [f"Market Sentiment: {indicators.sentiment.name}"]
|
|
|
|
# Add breadth if meaningful
|
|
if indicators.breadth.advance_decline_ratio != 1.0:
|
|
lines.append(
|
|
f"Advance/Decline Ratio: {indicators.breadth.advance_decline_ratio:.2f}"
|
|
)
|
|
|
|
return "\n".join(lines)
|
|
except Exception as exc:
|
|
logger.warning("Failed to get market indicators: %s", exc)
|
|
return ""
|
|
|
|
# ------------------------------------------------------------------
|
|
# Prompt Construction
|
|
# ------------------------------------------------------------------
|
|
|
|
async def build_prompt(
|
|
self, market_data: dict[str, Any], news_sentiment: NewsSentiment | None = None
|
|
) -> str:
|
|
"""Build a structured prompt from market data and external sources.
|
|
|
|
The prompt instructs Gemini to return valid JSON with action,
|
|
confidence, and rationale fields.
|
|
"""
|
|
market_name = market_data.get("market_name", "Korean stock market")
|
|
|
|
# Build market data section dynamically based on available fields
|
|
market_info_lines = [
|
|
f"Market: {market_name}",
|
|
f"Stock Code: {market_data['stock_code']}",
|
|
f"Current Price: {market_data['current_price']}",
|
|
]
|
|
|
|
# Add orderbook if available (domestic markets)
|
|
if "orderbook" in market_data:
|
|
market_info_lines.append(
|
|
f"Orderbook: {json.dumps(market_data['orderbook'], ensure_ascii=False)}"
|
|
)
|
|
|
|
# Add foreigner net if non-zero
|
|
if market_data.get("foreigner_net", 0) != 0:
|
|
market_info_lines.append(
|
|
f"Foreigner Net Buy/Sell: {market_data['foreigner_net']}"
|
|
)
|
|
|
|
market_info = "\n".join(market_info_lines)
|
|
|
|
# Add external data context if available
|
|
external_context = await self._build_external_context(
|
|
market_data["stock_code"], news_sentiment
|
|
)
|
|
if external_context:
|
|
market_info += f"\n\n{external_context}"
|
|
|
|
json_format = (
|
|
'{"action": "BUY"|"SELL"|"HOLD", '
|
|
'"confidence": <int 0-100>, "rationale": "<string>"}'
|
|
)
|
|
return (
|
|
f"You are a professional {market_name} trading analyst.\n"
|
|
"Analyze the following market data and decide whether to "
|
|
"BUY, SELL, or HOLD.\n\n"
|
|
f"{market_info}\n\n"
|
|
"You MUST respond with ONLY valid JSON in the following format:\n"
|
|
f"{json_format}\n\n"
|
|
"Rules:\n"
|
|
"- action must be exactly one of: BUY, SELL, HOLD\n"
|
|
"- confidence must be an integer from 0 to 100\n"
|
|
"- rationale must explain your reasoning concisely\n"
|
|
"- Do NOT wrap the JSON in markdown code blocks\n"
|
|
)
|
|
|
|
def build_prompt_sync(self, market_data: dict[str, Any]) -> str:
|
|
"""Synchronous version of build_prompt (for backward compatibility).
|
|
|
|
This version does NOT include external data integration.
|
|
Use async build_prompt() for full functionality.
|
|
"""
|
|
market_name = market_data.get("market_name", "Korean stock market")
|
|
|
|
# Build market data section dynamically based on available fields
|
|
market_info_lines = [
|
|
f"Market: {market_name}",
|
|
f"Stock Code: {market_data['stock_code']}",
|
|
f"Current Price: {market_data['current_price']}",
|
|
]
|
|
|
|
# Add orderbook if available (domestic markets)
|
|
if "orderbook" in market_data:
|
|
market_info_lines.append(
|
|
f"Orderbook: {json.dumps(market_data['orderbook'], ensure_ascii=False)}"
|
|
)
|
|
|
|
# Add foreigner net if non-zero
|
|
if market_data.get("foreigner_net", 0) != 0:
|
|
market_info_lines.append(
|
|
f"Foreigner Net Buy/Sell: {market_data['foreigner_net']}"
|
|
)
|
|
|
|
market_info = "\n".join(market_info_lines)
|
|
|
|
json_format = (
|
|
'{"action": "BUY"|"SELL"|"HOLD", '
|
|
'"confidence": <int 0-100>, "rationale": "<string>"}'
|
|
)
|
|
return (
|
|
f"You are a professional {market_name} trading analyst.\n"
|
|
"Analyze the following market data and decide whether to "
|
|
"BUY, SELL, or HOLD.\n\n"
|
|
f"{market_info}\n\n"
|
|
"You MUST respond with ONLY valid JSON in the following format:\n"
|
|
f"{json_format}\n\n"
|
|
"Rules:\n"
|
|
"- action must be exactly one of: BUY, SELL, HOLD\n"
|
|
"- confidence must be an integer from 0 to 100\n"
|
|
"- rationale must explain your reasoning concisely\n"
|
|
"- Do NOT wrap the JSON in markdown code blocks\n"
|
|
)
|
|
|
|
# ------------------------------------------------------------------
|
|
# Response Parsing
|
|
# ------------------------------------------------------------------
|
|
|
|
def parse_response(self, raw: str) -> TradeDecision:
|
|
"""Parse a raw Gemini response into a TradeDecision.
|
|
|
|
Handles: valid JSON, JSON wrapped in markdown code blocks,
|
|
malformed JSON, missing fields, and invalid action values.
|
|
|
|
On any failure, returns a safe HOLD with confidence 0.
|
|
"""
|
|
if not raw or not raw.strip():
|
|
logger.warning("Empty response from Gemini — defaulting to HOLD")
|
|
return TradeDecision(action="HOLD", confidence=0, rationale="Empty response")
|
|
|
|
# Strip markdown code fences if present
|
|
cleaned = raw.strip()
|
|
match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", cleaned, re.DOTALL)
|
|
if match:
|
|
cleaned = match.group(1).strip()
|
|
|
|
try:
|
|
data = json.loads(cleaned)
|
|
except json.JSONDecodeError:
|
|
logger.warning("Malformed JSON from Gemini — defaulting to HOLD")
|
|
return TradeDecision(
|
|
action="HOLD", confidence=0, rationale="Malformed JSON response"
|
|
)
|
|
|
|
# Validate required fields
|
|
if not all(k in data for k in ("action", "confidence", "rationale")):
|
|
logger.warning("Missing fields in Gemini response — defaulting to HOLD")
|
|
return TradeDecision(
|
|
action="HOLD", confidence=0, rationale="Missing required fields"
|
|
)
|
|
|
|
action = str(data["action"]).upper()
|
|
if action not in VALID_ACTIONS:
|
|
logger.warning("Invalid action '%s' from Gemini — defaulting to HOLD", action)
|
|
return TradeDecision(
|
|
action="HOLD", confidence=0, rationale=f"Invalid action: {action}"
|
|
)
|
|
|
|
confidence = int(data["confidence"])
|
|
rationale = str(data["rationale"])
|
|
|
|
# Enforce confidence threshold
|
|
if confidence < self._confidence_threshold:
|
|
logger.info(
|
|
"Confidence %d < threshold %d — forcing HOLD",
|
|
confidence,
|
|
self._confidence_threshold,
|
|
)
|
|
action = "HOLD"
|
|
|
|
return TradeDecision(action=action, confidence=confidence, rationale=rationale)
|
|
|
|
# ------------------------------------------------------------------
|
|
# API Call
|
|
# ------------------------------------------------------------------
|
|
|
|
async def decide(
|
|
self, market_data: dict[str, Any], news_sentiment: NewsSentiment | None = None
|
|
) -> TradeDecision:
|
|
"""Build prompt, call Gemini, and return a parsed decision.
|
|
|
|
Args:
|
|
market_data: Market data dictionary with price, orderbook, etc.
|
|
news_sentiment: Optional pre-fetched news sentiment
|
|
|
|
Returns:
|
|
Parsed TradeDecision
|
|
"""
|
|
# Check cache first
|
|
if self._cache:
|
|
cached_decision = self._cache.get(market_data)
|
|
if cached_decision:
|
|
self._total_cached_decisions += 1
|
|
self._total_decisions += 1
|
|
logger.info(
|
|
"Cache hit for decision",
|
|
extra={
|
|
"action": cached_decision.action,
|
|
"confidence": cached_decision.confidence,
|
|
"cache_hit_rate": self.get_cache_hit_rate(),
|
|
},
|
|
)
|
|
# Return cached decision with cached flag
|
|
return TradeDecision(
|
|
action=cached_decision.action,
|
|
confidence=cached_decision.confidence,
|
|
rationale=cached_decision.rationale,
|
|
token_count=0,
|
|
cached=True,
|
|
)
|
|
|
|
# Build prompt (prompt_override takes priority for callers like pre_market_planner)
|
|
if "prompt_override" in market_data:
|
|
prompt = market_data["prompt_override"]
|
|
elif self._enable_optimization:
|
|
prompt = self._optimizer.build_compressed_prompt(market_data)
|
|
else:
|
|
prompt = await self.build_prompt(market_data, news_sentiment)
|
|
|
|
# Estimate tokens
|
|
token_count = self._optimizer.estimate_tokens(prompt)
|
|
self._total_tokens_used += token_count
|
|
|
|
logger.info(
|
|
"Requesting trade decision from Gemini",
|
|
extra={"estimated_tokens": token_count, "optimized": self._enable_optimization},
|
|
)
|
|
|
|
try:
|
|
response = await self._client.aio.models.generate_content(
|
|
model=self._model_name,
|
|
contents=prompt,
|
|
)
|
|
raw = response.text
|
|
except Exception as exc:
|
|
logger.error("Gemini API error: %s", exc)
|
|
return TradeDecision(
|
|
action="HOLD", confidence=0, rationale=f"API error: {exc}", token_count=token_count
|
|
)
|
|
|
|
decision = self.parse_response(raw)
|
|
self._total_decisions += 1
|
|
|
|
# Add token count to decision
|
|
decision_with_tokens = TradeDecision(
|
|
action=decision.action,
|
|
confidence=decision.confidence,
|
|
rationale=decision.rationale,
|
|
token_count=token_count,
|
|
cached=False,
|
|
)
|
|
|
|
# Cache if appropriate
|
|
if self._cache and self._cache.should_cache_decision(decision):
|
|
self._cache.set(market_data, decision)
|
|
|
|
logger.info(
|
|
"Gemini decision",
|
|
extra={
|
|
"action": decision.action,
|
|
"confidence": decision.confidence,
|
|
"tokens": token_count,
|
|
"avg_tokens": self.get_avg_tokens_per_decision(),
|
|
},
|
|
)
|
|
|
|
return decision_with_tokens
|
|
|
|
# ------------------------------------------------------------------
|
|
# Token Efficiency Metrics
|
|
# ------------------------------------------------------------------
|
|
|
|
def get_token_metrics(self) -> dict[str, Any]:
|
|
"""Get token usage metrics.
|
|
|
|
Returns:
|
|
Dictionary with token usage statistics
|
|
"""
|
|
metrics = {
|
|
"total_tokens_used": self._total_tokens_used,
|
|
"total_decisions": self._total_decisions,
|
|
"total_cached_decisions": self._total_cached_decisions,
|
|
"avg_tokens_per_decision": self.get_avg_tokens_per_decision(),
|
|
"cache_hit_rate": self.get_cache_hit_rate(),
|
|
}
|
|
|
|
if self._cache:
|
|
cache_metrics = self._cache.get_metrics()
|
|
metrics["cache_metrics"] = cache_metrics.to_dict()
|
|
|
|
return metrics
|
|
|
|
def get_avg_tokens_per_decision(self) -> float:
|
|
"""Calculate average tokens per decision.
|
|
|
|
Returns:
|
|
Average tokens per decision
|
|
"""
|
|
if self._total_decisions == 0:
|
|
return 0.0
|
|
return self._total_tokens_used / self._total_decisions
|
|
|
|
def get_cache_hit_rate(self) -> float:
|
|
"""Calculate cache hit rate.
|
|
|
|
Returns:
|
|
Cache hit rate (0.0 to 1.0)
|
|
"""
|
|
if self._total_decisions == 0:
|
|
return 0.0
|
|
return self._total_cached_decisions / self._total_decisions
|
|
|
|
def reset_metrics(self) -> None:
|
|
"""Reset token usage metrics."""
|
|
self._total_tokens_used = 0
|
|
self._total_decisions = 0
|
|
self._total_cached_decisions = 0
|
|
if self._cache:
|
|
self._cache.reset_metrics()
|
|
logger.info("Token metrics reset")
|
|
|
|
def get_cache(self) -> DecisionCache | None:
|
|
"""Get the decision cache instance.
|
|
|
|
Returns:
|
|
DecisionCache instance or None if caching disabled
|
|
"""
|
|
return self._cache
|
|
|
|
# ------------------------------------------------------------------
|
|
# Batch Decision Making (for daily trading mode)
|
|
# ------------------------------------------------------------------
|
|
|
|
async def decide_batch(
|
|
self, stocks_data: list[dict[str, Any]]
|
|
) -> dict[str, TradeDecision]:
|
|
"""Make decisions for multiple stocks in a single API call.
|
|
|
|
This is designed for daily trading mode to minimize API usage
|
|
when working with Gemini Free tier (20 calls/day limit).
|
|
|
|
Args:
|
|
stocks_data: List of market data dictionaries, each with:
|
|
- stock_code: Stock ticker
|
|
- current_price: Current price
|
|
- market_name: Market name (optional)
|
|
- foreigner_net: Foreigner net buy/sell (optional)
|
|
|
|
Returns:
|
|
Dictionary mapping stock_code to TradeDecision
|
|
|
|
Example:
|
|
>>> stocks_data = [
|
|
... {"stock_code": "AAPL", "current_price": 185.5},
|
|
... {"stock_code": "MSFT", "current_price": 420.0},
|
|
... ]
|
|
>>> decisions = await client.decide_batch(stocks_data)
|
|
>>> decisions["AAPL"].action
|
|
'BUY'
|
|
"""
|
|
if not stocks_data:
|
|
return {}
|
|
|
|
# Build compressed batch prompt
|
|
market_name = stocks_data[0].get("market_name", "stock market")
|
|
|
|
# Format stock data as compact JSON array
|
|
compact_stocks = []
|
|
for stock in stocks_data:
|
|
compact = {
|
|
"code": stock["stock_code"],
|
|
"price": stock["current_price"],
|
|
}
|
|
if stock.get("foreigner_net", 0) != 0:
|
|
compact["frgn"] = stock["foreigner_net"]
|
|
compact_stocks.append(compact)
|
|
|
|
data_str = json.dumps(compact_stocks, ensure_ascii=False)
|
|
|
|
prompt = (
|
|
f"You are a professional {market_name} trading analyst.\n"
|
|
"Analyze the following stocks and decide whether to BUY, SELL, or HOLD each one.\n\n"
|
|
f"Stock Data: {data_str}\n\n"
|
|
"You MUST respond with ONLY a valid JSON array in this format:\n"
|
|
'[{"code": "AAPL", "action": "BUY", "confidence": 85, "rationale": "..."},\n'
|
|
' {"code": "MSFT", "action": "HOLD", "confidence": 50, "rationale": "..."}, ...]\n\n'
|
|
"Rules:\n"
|
|
"- Return one decision object per stock\n"
|
|
"- action must be exactly: BUY, SELL, or HOLD\n"
|
|
"- confidence must be 0-100\n"
|
|
"- rationale should be concise (1-2 sentences)\n"
|
|
"- Do NOT wrap JSON in markdown code blocks\n"
|
|
)
|
|
|
|
# Estimate tokens
|
|
token_count = self._optimizer.estimate_tokens(prompt)
|
|
self._total_tokens_used += token_count
|
|
|
|
logger.info(
|
|
"Requesting batch decision for %d stocks from Gemini",
|
|
len(stocks_data),
|
|
extra={"estimated_tokens": token_count},
|
|
)
|
|
|
|
try:
|
|
response = await self._client.aio.models.generate_content(
|
|
model=self._model_name,
|
|
contents=prompt,
|
|
)
|
|
raw = response.text
|
|
except Exception as exc:
|
|
logger.error("Gemini API error in batch decision: %s", exc)
|
|
# Return HOLD for all stocks on API error
|
|
return {
|
|
stock["stock_code"]: TradeDecision(
|
|
action="HOLD",
|
|
confidence=0,
|
|
rationale=f"API error: {exc}",
|
|
token_count=token_count,
|
|
cached=False,
|
|
)
|
|
for stock in stocks_data
|
|
}
|
|
|
|
# Parse batch response
|
|
return self._parse_batch_response(raw, stocks_data, token_count)
|
|
|
|
def _parse_batch_response(
|
|
self, raw: str, stocks_data: list[dict[str, Any]], token_count: int
|
|
) -> dict[str, TradeDecision]:
|
|
"""Parse batch response into a dictionary of decisions.
|
|
|
|
Args:
|
|
raw: Raw response from Gemini
|
|
stocks_data: Original stock data list
|
|
token_count: Token count for the request
|
|
|
|
Returns:
|
|
Dictionary mapping stock_code to TradeDecision
|
|
"""
|
|
if not raw or not raw.strip():
|
|
logger.warning("Empty batch response from Gemini — defaulting all to HOLD")
|
|
return {
|
|
stock["stock_code"]: TradeDecision(
|
|
action="HOLD",
|
|
confidence=0,
|
|
rationale="Empty response",
|
|
token_count=0,
|
|
cached=False,
|
|
)
|
|
for stock in stocks_data
|
|
}
|
|
|
|
# Strip markdown code fences if present
|
|
cleaned = raw.strip()
|
|
match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", cleaned, re.DOTALL)
|
|
if match:
|
|
cleaned = match.group(1).strip()
|
|
|
|
try:
|
|
data = json.loads(cleaned)
|
|
except json.JSONDecodeError:
|
|
logger.warning("Malformed JSON in batch response — defaulting all to HOLD")
|
|
return {
|
|
stock["stock_code"]: TradeDecision(
|
|
action="HOLD",
|
|
confidence=0,
|
|
rationale="Malformed JSON response",
|
|
token_count=0,
|
|
cached=False,
|
|
)
|
|
for stock in stocks_data
|
|
}
|
|
|
|
if not isinstance(data, list):
|
|
logger.warning("Batch response is not a JSON array — defaulting all to HOLD")
|
|
return {
|
|
stock["stock_code"]: TradeDecision(
|
|
action="HOLD",
|
|
confidence=0,
|
|
rationale="Invalid response format",
|
|
token_count=0,
|
|
cached=False,
|
|
)
|
|
for stock in stocks_data
|
|
}
|
|
|
|
# Build decision map
|
|
decisions: dict[str, TradeDecision] = {}
|
|
stock_codes = {stock["stock_code"] for stock in stocks_data}
|
|
|
|
for item in data:
|
|
if not isinstance(item, dict):
|
|
continue
|
|
|
|
code = item.get("code")
|
|
if not code or code not in stock_codes:
|
|
continue
|
|
|
|
# Validate required fields
|
|
if not all(k in item for k in ("action", "confidence", "rationale")):
|
|
logger.warning("Missing fields for %s — using HOLD", code)
|
|
decisions[code] = TradeDecision(
|
|
action="HOLD",
|
|
confidence=0,
|
|
rationale="Missing required fields",
|
|
token_count=0,
|
|
cached=False,
|
|
)
|
|
continue
|
|
|
|
action = str(item["action"]).upper()
|
|
if action not in VALID_ACTIONS:
|
|
logger.warning("Invalid action '%s' for %s — forcing HOLD", action, code)
|
|
action = "HOLD"
|
|
|
|
confidence = int(item["confidence"])
|
|
rationale = str(item["rationale"])
|
|
|
|
# Enforce confidence threshold
|
|
if confidence < self._confidence_threshold:
|
|
logger.info(
|
|
"Confidence %d < threshold %d for %s — forcing HOLD",
|
|
confidence,
|
|
self._confidence_threshold,
|
|
code,
|
|
)
|
|
action = "HOLD"
|
|
|
|
decisions[code] = TradeDecision(
|
|
action=action,
|
|
confidence=confidence,
|
|
rationale=rationale,
|
|
token_count=token_count // len(stocks_data), # Split token cost
|
|
cached=False,
|
|
)
|
|
self._total_decisions += 1
|
|
|
|
# Fill in missing stocks with HOLD
|
|
for stock in stocks_data:
|
|
code = stock["stock_code"]
|
|
if code not in decisions:
|
|
logger.warning("No decision for %s in batch response — using HOLD", code)
|
|
decisions[code] = TradeDecision(
|
|
action="HOLD",
|
|
confidence=0,
|
|
rationale="Not found in batch response",
|
|
token_count=0,
|
|
cached=False,
|
|
)
|
|
|
|
logger.info(
|
|
"Batch decision completed for %d stocks",
|
|
len(decisions),
|
|
extra={"tokens": token_count},
|
|
)
|
|
|
|
return decisions
|