feat: implement Token Efficiency - Context optimization (issue #24) #28

Merged
jihoson merged 2 commits from feature/issue-24-token-efficiency into main 2026-02-04 18:39:20 +09:00
9 changed files with 1998 additions and 12 deletions

293
src/brain/cache.py Normal file
View File

@@ -0,0 +1,293 @@
"""Response caching system for reducing redundant LLM calls.
This module provides caching for common trading scenarios:
- TTL-based cache invalidation
- Cache key based on market conditions
- Cache hit rate monitoring
- Special handling for HOLD decisions in quiet markets
"""
from __future__ import annotations
import hashlib
import json
import logging
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from src.brain.gemini_client import TradeDecision
logger = logging.getLogger(__name__)
@dataclass
class CacheEntry:
"""Cached decision with metadata."""
decision: TradeDecision
cached_at: float # Unix timestamp
hit_count: int = 0
market_data_hash: str = ""
@dataclass
class CacheMetrics:
"""Metrics for cache performance monitoring."""
total_requests: int = 0
cache_hits: int = 0
cache_misses: int = 0
evictions: int = 0
total_entries: int = 0
@property
def hit_rate(self) -> float:
"""Calculate cache hit rate."""
if self.total_requests == 0:
return 0.0
return self.cache_hits / self.total_requests
def to_dict(self) -> dict[str, Any]:
"""Convert metrics to dictionary."""
return {
"total_requests": self.total_requests,
"cache_hits": self.cache_hits,
"cache_misses": self.cache_misses,
"hit_rate": self.hit_rate,
"evictions": self.evictions,
"total_entries": self.total_entries,
}
class DecisionCache:
"""TTL-based cache for trade decisions."""
def __init__(self, ttl_seconds: int = 300, max_size: int = 1000) -> None:
"""Initialize the decision cache.
Args:
ttl_seconds: Time-to-live for cache entries in seconds (default: 5 minutes)
max_size: Maximum number of cache entries
"""
self.ttl_seconds = ttl_seconds
self.max_size = max_size
self._cache: dict[str, CacheEntry] = {}
self._metrics = CacheMetrics()
def _generate_cache_key(self, market_data: dict[str, Any]) -> str:
"""Generate cache key from market data.
Key is based on:
- Stock code
- Current price (rounded to reduce sensitivity)
- Market conditions (orderbook snapshot)
Args:
market_data: Market data dictionary
Returns:
Cache key string
"""
# Extract key components
stock_code = market_data.get("stock_code", "UNKNOWN")
current_price = market_data.get("current_price", 0)
# Round price to reduce sensitivity (cache hits for similar prices)
# For prices > 1000, round to nearest 10
# For prices < 1000, round to nearest 1
if current_price > 1000:
price_rounded = round(current_price / 10) * 10
else:
price_rounded = round(current_price)
# Include orderbook snapshot (if available)
orderbook_key = ""
if "orderbook" in market_data and market_data["orderbook"]:
ob = market_data["orderbook"]
# Just use bid/ask spread as indicator
if "bid" in ob and "ask" in ob and ob["bid"] and ob["ask"]:
bid_price = ob["bid"][0].get("price", 0) if ob["bid"] else 0
ask_price = ob["ask"][0].get("price", 0) if ob["ask"] else 0
spread = ask_price - bid_price
orderbook_key = f"_spread{spread}"
# Generate cache key
key_str = f"{stock_code}_{price_rounded}{orderbook_key}"
return key_str
def _generate_market_hash(self, market_data: dict[str, Any]) -> str:
"""Generate hash of full market data for invalidation checks.
Args:
market_data: Market data dictionary
Returns:
Hash string
"""
# Create stable JSON representation
stable_json = json.dumps(market_data, sort_keys=True, ensure_ascii=False)
return hashlib.md5(stable_json.encode()).hexdigest()
def get(self, market_data: dict[str, Any]) -> TradeDecision | None:
"""Retrieve cached decision if valid.
Args:
market_data: Market data dictionary
Returns:
Cached TradeDecision if valid, None otherwise
"""
self._metrics.total_requests += 1
cache_key = self._generate_cache_key(market_data)
if cache_key not in self._cache:
self._metrics.cache_misses += 1
return None
entry = self._cache[cache_key]
current_time = time.time()
# Check TTL
if current_time - entry.cached_at > self.ttl_seconds:
# Expired
del self._cache[cache_key]
self._metrics.cache_misses += 1
self._metrics.evictions += 1
logger.debug("Cache expired for key: %s", cache_key)
return None
# Cache hit
entry.hit_count += 1
self._metrics.cache_hits += 1
logger.debug("Cache hit for key: %s (hits: %d)", cache_key, entry.hit_count)
return entry.decision
def set(
self,
market_data: dict[str, Any],
decision: TradeDecision,
) -> None:
"""Store decision in cache.
Args:
market_data: Market data dictionary
decision: TradeDecision to cache
"""
cache_key = self._generate_cache_key(market_data)
market_hash = self._generate_market_hash(market_data)
# Enforce max size (evict oldest if full)
if len(self._cache) >= self.max_size:
# Find oldest entry
oldest_key = min(self._cache.keys(), key=lambda k: self._cache[k].cached_at)
del self._cache[oldest_key]
self._metrics.evictions += 1
logger.debug("Cache full, evicted key: %s", oldest_key)
# Store entry
entry = CacheEntry(
decision=decision,
cached_at=time.time(),
market_data_hash=market_hash,
)
self._cache[cache_key] = entry
self._metrics.total_entries = len(self._cache)
logger.debug("Cached decision for key: %s", cache_key)
def invalidate(self, stock_code: str | None = None) -> int:
"""Invalidate cache entries.
Args:
stock_code: Specific stock code to invalidate, or None for all
Returns:
Number of entries invalidated
"""
if stock_code is None:
# Clear all
count = len(self._cache)
self._cache.clear()
self._metrics.evictions += count
self._metrics.total_entries = 0
logger.info("Invalidated all cache entries (%d)", count)
return count
# Invalidate specific stock
keys_to_remove = [k for k in self._cache.keys() if k.startswith(f"{stock_code}_")]
count = len(keys_to_remove)
for key in keys_to_remove:
del self._cache[key]
self._metrics.evictions += count
self._metrics.total_entries = len(self._cache)
logger.info("Invalidated %d cache entries for stock: %s", count, stock_code)
return count
def cleanup_expired(self) -> int:
"""Remove expired entries from cache.
Returns:
Number of entries removed
"""
current_time = time.time()
expired_keys = [
k
for k, v in self._cache.items()
if current_time - v.cached_at > self.ttl_seconds
]
count = len(expired_keys)
for key in expired_keys:
del self._cache[key]
self._metrics.evictions += count
self._metrics.total_entries = len(self._cache)
if count > 0:
logger.debug("Cleaned up %d expired cache entries", count)
return count
def get_metrics(self) -> CacheMetrics:
"""Get current cache metrics.
Returns:
CacheMetrics object with current statistics
"""
return self._metrics
def reset_metrics(self) -> None:
"""Reset cache metrics."""
self._metrics = CacheMetrics(total_entries=len(self._cache))
logger.info("Cache metrics reset")
def should_cache_decision(self, decision: TradeDecision) -> bool:
"""Determine if a decision should be cached.
HOLD decisions with low confidence are good candidates for caching,
as they're likely to recur in quiet markets.
Args:
decision: TradeDecision to evaluate
Returns:
True if decision should be cached
"""
# Cache HOLD decisions (common in quiet markets)
if decision.action == "HOLD":
return True
# Cache high-confidence decisions (stable signals)
if decision.confidence >= 90:
return True
# Don't cache low-confidence BUY/SELL (volatile signals)
return False

View File

@@ -0,0 +1,296 @@
"""Smart context selection for optimizing token usage.
This module implements intelligent selection of context layers (L1-L7) based on
decision type and market conditions:
- L7 (real-time) for normal trading decisions
- L6-L5 (daily/weekly) for strategic decisions
- L4-L1 (monthly/legacy) only for major events or policy changes
"""
from __future__ import annotations
from dataclasses import dataclass
from datetime import UTC, datetime
from enum import Enum
from typing import Any
from src.context.layer import ContextLayer
from src.context.store import ContextStore
class DecisionType(str, Enum):
"""Type of trading decision being made."""
NORMAL = "normal" # Regular trade decision
STRATEGIC = "strategic" # Strategy adjustment
MAJOR_EVENT = "major_event" # Portfolio rebalancing, policy change
@dataclass(frozen=True)
class ContextSelection:
"""Selected context layers and their relevance scores."""
layers: list[ContextLayer]
relevance_scores: dict[ContextLayer, float]
total_score: float
class ContextSelector:
"""Selects optimal context layers to minimize token usage."""
def __init__(self, store: ContextStore) -> None:
"""Initialize the context selector.
Args:
store: ContextStore instance for retrieving context data
"""
self.store = store
def select_layers(
self,
decision_type: DecisionType = DecisionType.NORMAL,
include_realtime: bool = True,
) -> list[ContextLayer]:
"""Select context layers based on decision type.
Strategy:
- NORMAL: L7 (real-time) only
- STRATEGIC: L7 + L6 + L5 (real-time + daily + weekly)
- MAJOR_EVENT: All layers L1-L7
Args:
decision_type: Type of decision being made
include_realtime: Whether to include L7 real-time data
Returns:
List of context layers to use (ordered by priority)
"""
if decision_type == DecisionType.NORMAL:
# Normal trading: only real-time data
return [ContextLayer.L7_REALTIME] if include_realtime else []
elif decision_type == DecisionType.STRATEGIC:
# Strategic decisions: real-time + recent history
layers = []
if include_realtime:
layers.append(ContextLayer.L7_REALTIME)
layers.extend([ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY])
return layers
else: # MAJOR_EVENT
# Major events: all layers for comprehensive context
layers = []
if include_realtime:
layers.append(ContextLayer.L7_REALTIME)
layers.extend(
[
ContextLayer.L6_DAILY,
ContextLayer.L5_WEEKLY,
ContextLayer.L4_MONTHLY,
ContextLayer.L3_QUARTERLY,
ContextLayer.L2_ANNUAL,
ContextLayer.L1_LEGACY,
]
)
return layers
def score_layer_relevance(
self,
layer: ContextLayer,
decision_type: DecisionType,
current_time: datetime | None = None,
) -> float:
"""Calculate relevance score for a context layer.
Relevance is based on:
1. Decision type (normal, strategic, major event)
2. Layer recency (L7 > L6 > ... > L1)
3. Data availability
Args:
layer: Context layer to score
decision_type: Type of decision being made
current_time: Current time (defaults to now)
Returns:
Relevance score (0.0 to 1.0)
"""
if current_time is None:
current_time = datetime.now(UTC)
# Base scores by decision type
base_scores = {
DecisionType.NORMAL: {
ContextLayer.L7_REALTIME: 1.0,
ContextLayer.L6_DAILY: 0.1,
ContextLayer.L5_WEEKLY: 0.05,
ContextLayer.L4_MONTHLY: 0.01,
ContextLayer.L3_QUARTERLY: 0.0,
ContextLayer.L2_ANNUAL: 0.0,
ContextLayer.L1_LEGACY: 0.0,
},
DecisionType.STRATEGIC: {
ContextLayer.L7_REALTIME: 0.9,
ContextLayer.L6_DAILY: 0.8,
ContextLayer.L5_WEEKLY: 0.7,
ContextLayer.L4_MONTHLY: 0.3,
ContextLayer.L3_QUARTERLY: 0.2,
ContextLayer.L2_ANNUAL: 0.1,
ContextLayer.L1_LEGACY: 0.05,
},
DecisionType.MAJOR_EVENT: {
ContextLayer.L7_REALTIME: 0.7,
ContextLayer.L6_DAILY: 0.7,
ContextLayer.L5_WEEKLY: 0.7,
ContextLayer.L4_MONTHLY: 0.8,
ContextLayer.L3_QUARTERLY: 0.8,
ContextLayer.L2_ANNUAL: 0.9,
ContextLayer.L1_LEGACY: 1.0,
},
}
score = base_scores[decision_type].get(layer, 0.0)
# Check data availability
latest_timeframe = self.store.get_latest_timeframe(layer)
if latest_timeframe is None:
# No data available - reduce score significantly
score *= 0.1
return score
def select_with_scoring(
self,
decision_type: DecisionType = DecisionType.NORMAL,
min_score: float = 0.5,
) -> ContextSelection:
"""Select context layers with relevance scoring.
Args:
decision_type: Type of decision being made
min_score: Minimum relevance score to include a layer
Returns:
ContextSelection with selected layers and scores
"""
all_layers = [
ContextLayer.L7_REALTIME,
ContextLayer.L6_DAILY,
ContextLayer.L5_WEEKLY,
ContextLayer.L4_MONTHLY,
ContextLayer.L3_QUARTERLY,
ContextLayer.L2_ANNUAL,
ContextLayer.L1_LEGACY,
]
scores = {
layer: self.score_layer_relevance(layer, decision_type) for layer in all_layers
}
# Filter by minimum score
selected_layers = [layer for layer, score in scores.items() if score >= min_score]
# Sort by score (descending)
selected_layers.sort(key=lambda layer: scores[layer], reverse=True)
total_score = sum(scores[layer] for layer in selected_layers)
return ContextSelection(
layers=selected_layers,
relevance_scores=scores,
total_score=total_score,
)
def get_context_data(
self,
layers: list[ContextLayer],
max_items_per_layer: int = 10,
) -> dict[str, Any]:
"""Retrieve context data for selected layers.
Args:
layers: List of context layers to retrieve
max_items_per_layer: Maximum number of items per layer
Returns:
Dictionary with context data organized by layer
"""
result: dict[str, Any] = {}
for layer in layers:
# Get latest timeframe for this layer
latest_timeframe = self.store.get_latest_timeframe(layer)
if latest_timeframe:
# Get all contexts for latest timeframe
contexts = self.store.get_all_contexts(layer, latest_timeframe)
# Limit number of items
if len(contexts) > max_items_per_layer:
# Keep only first N items
contexts = dict(list(contexts.items())[:max_items_per_layer])
result[layer.value] = contexts
return result
def estimate_context_tokens(self, context_data: dict[str, Any]) -> int:
"""Estimate total tokens for context data.
Args:
context_data: Context data dictionary
Returns:
Estimated token count
"""
import json
from src.brain.prompt_optimizer import PromptOptimizer
# Serialize to JSON and estimate tokens
json_str = json.dumps(context_data, ensure_ascii=False)
return PromptOptimizer.estimate_tokens(json_str)
def optimize_context_for_budget(
self,
decision_type: DecisionType,
max_tokens: int,
) -> dict[str, Any]:
"""Select and retrieve context data within a token budget.
Args:
decision_type: Type of decision being made
max_tokens: Maximum token budget for context
Returns:
Optimized context data within budget
"""
# Start with minimal selection
selection = self.select_with_scoring(decision_type, min_score=0.5)
# Retrieve data
context_data = self.get_context_data(selection.layers)
# Check if within budget
estimated_tokens = self.estimate_context_tokens(context_data)
if estimated_tokens <= max_tokens:
return context_data
# If over budget, progressively reduce
# 1. Reduce items per layer
for max_items in [5, 3, 1]:
context_data = self.get_context_data(selection.layers, max_items)
estimated_tokens = self.estimate_context_tokens(context_data)
if estimated_tokens <= max_tokens:
return context_data
# 2. Remove lower-priority layers
for min_score in [0.6, 0.7, 0.8, 0.9]:
selection = self.select_with_scoring(decision_type, min_score=min_score)
context_data = self.get_context_data(selection.layers, max_items_per_layer=1)
estimated_tokens = self.estimate_context_tokens(context_data)
if estimated_tokens <= max_tokens:
return context_data
# Last resort: return only L7 with minimal data
return self.get_context_data([ContextLayer.L7_REALTIME], max_items_per_layer=1)

View File

@@ -2,6 +2,11 @@
Constructs prompts from market data, calls Gemini, and parses structured
JSON responses into validated TradeDecision objects.
Includes token efficiency optimizations:
- Prompt compression and abbreviation
- Response caching for common scenarios
- Token usage tracking and metrics
"""
from __future__ import annotations
@@ -14,6 +19,8 @@ from typing import Any
from google import genai
from src.brain.cache import DecisionCache
from src.brain.prompt_optimizer import PromptOptimizer
from src.config import Settings
logger = logging.getLogger(__name__)
@@ -28,17 +35,35 @@ class TradeDecision:
action: str # "BUY" | "SELL" | "HOLD"
confidence: int # 0-100
rationale: str
token_count: int = 0 # Estimated tokens used
cached: bool = False # Whether decision came from cache
class GeminiClient:
"""Wraps the Gemini API for trade decision-making."""
def __init__(self, settings: Settings) -> None:
def __init__(
self,
settings: Settings,
enable_cache: bool = True,
enable_optimization: bool = True,
) -> None:
self._settings = settings
self._confidence_threshold = settings.CONFIDENCE_THRESHOLD
self._client = genai.Client(api_key=settings.GEMINI_API_KEY)
self._model_name = settings.GEMINI_MODEL
# Token efficiency features
self._enable_cache = enable_cache
self._enable_optimization = enable_optimization
self._cache = DecisionCache(ttl_seconds=300) if enable_cache else None
self._optimizer = PromptOptimizer()
# Token usage metrics
self._total_tokens_used = 0
self._total_decisions = 0
self._total_cached_decisions = 0
# ------------------------------------------------------------------
# Prompt Construction
# ------------------------------------------------------------------
@@ -154,26 +179,141 @@ class GeminiClient:
async def decide(self, market_data: dict[str, Any]) -> TradeDecision:
"""Build prompt, call Gemini, and return a parsed decision."""
prompt = self.build_prompt(market_data)
logger.info("Requesting trade decision from Gemini")
# Check cache first
if self._cache:
cached_decision = self._cache.get(market_data)
if cached_decision:
self._total_cached_decisions += 1
self._total_decisions += 1
logger.info(
"Cache hit for decision",
extra={
"action": cached_decision.action,
"confidence": cached_decision.confidence,
"cache_hit_rate": self.get_cache_hit_rate(),
},
)
# Return cached decision with cached flag
return TradeDecision(
action=cached_decision.action,
confidence=cached_decision.confidence,
rationale=cached_decision.rationale,
token_count=0,
cached=True,
)
# Build optimized prompt
if self._enable_optimization:
prompt = self._optimizer.build_compressed_prompt(market_data)
else:
prompt = self.build_prompt(market_data)
# Estimate tokens
token_count = self._optimizer.estimate_tokens(prompt)
self._total_tokens_used += token_count
logger.info(
"Requesting trade decision from Gemini",
extra={"estimated_tokens": token_count, "optimized": self._enable_optimization},
)
try:
response = await self._client.aio.models.generate_content(
model=self._model_name, contents=prompt,
model=self._model_name,
contents=prompt,
)
raw = response.text
except Exception as exc:
logger.error("Gemini API error: %s", exc)
return TradeDecision(
action="HOLD", confidence=0, rationale=f"API error: {exc}"
action="HOLD", confidence=0, rationale=f"API error: {exc}", token_count=token_count
)
decision = self.parse_response(raw)
self._total_decisions += 1
# Add token count to decision
decision_with_tokens = TradeDecision(
action=decision.action,
confidence=decision.confidence,
rationale=decision.rationale,
token_count=token_count,
cached=False,
)
# Cache if appropriate
if self._cache and self._cache.should_cache_decision(decision):
self._cache.set(market_data, decision)
logger.info(
"Gemini decision",
extra={
"action": decision.action,
"confidence": decision.confidence,
"tokens": token_count,
"avg_tokens": self.get_avg_tokens_per_decision(),
},
)
return decision
return decision_with_tokens
# ------------------------------------------------------------------
# Token Efficiency Metrics
# ------------------------------------------------------------------
def get_token_metrics(self) -> dict[str, Any]:
"""Get token usage metrics.
Returns:
Dictionary with token usage statistics
"""
metrics = {
"total_tokens_used": self._total_tokens_used,
"total_decisions": self._total_decisions,
"total_cached_decisions": self._total_cached_decisions,
"avg_tokens_per_decision": self.get_avg_tokens_per_decision(),
"cache_hit_rate": self.get_cache_hit_rate(),
}
if self._cache:
cache_metrics = self._cache.get_metrics()
metrics["cache_metrics"] = cache_metrics.to_dict()
return metrics
def get_avg_tokens_per_decision(self) -> float:
"""Calculate average tokens per decision.
Returns:
Average tokens per decision
"""
if self._total_decisions == 0:
return 0.0
return self._total_tokens_used / self._total_decisions
def get_cache_hit_rate(self) -> float:
"""Calculate cache hit rate.
Returns:
Cache hit rate (0.0 to 1.0)
"""
if self._total_decisions == 0:
return 0.0
return self._total_cached_decisions / self._total_decisions
def reset_metrics(self) -> None:
"""Reset token usage metrics."""
self._total_tokens_used = 0
self._total_decisions = 0
self._total_cached_decisions = 0
if self._cache:
self._cache.reset_metrics()
logger.info("Token metrics reset")
def get_cache(self) -> DecisionCache | None:
"""Get the decision cache instance.
Returns:
DecisionCache instance or None if caching disabled
"""
return self._cache

View File

@@ -0,0 +1,267 @@
"""Prompt optimization utilities for reducing token usage.
This module provides tools to compress prompts while maintaining decision quality:
- Token counting
- Text compression and abbreviation
- Template-based prompts with variable slots
- Priority-based context truncation
"""
from __future__ import annotations
import json
import re
from dataclasses import dataclass
from typing import Any
# Abbreviation mapping for common terms
ABBREVIATIONS = {
"price": "P",
"volume": "V",
"current": "cur",
"previous": "prev",
"change": "chg",
"percentage": "pct",
"market": "mkt",
"orderbook": "ob",
"foreigner": "fgn",
"buy": "B",
"sell": "S",
"hold": "H",
"confidence": "conf",
"rationale": "reason",
"action": "act",
"net": "net",
}
# Reverse mapping for decompression
REVERSE_ABBREVIATIONS = {v: k for k, v in ABBREVIATIONS.items()}
@dataclass(frozen=True)
class TokenMetrics:
"""Metrics about token usage in a prompt."""
char_count: int
word_count: int
estimated_tokens: int # Rough estimate: ~4 chars per token
compression_ratio: float = 1.0 # Original / Compressed
class PromptOptimizer:
"""Optimizes prompts to reduce token usage while maintaining quality."""
@staticmethod
def estimate_tokens(text: str) -> int:
"""Estimate token count for text.
Uses a simple heuristic: ~4 characters per token for English.
This is approximate but sufficient for optimization purposes.
Args:
text: Input text to estimate tokens for
Returns:
Estimated token count
"""
if not text:
return 0
# Simple estimate: 1 token ≈ 4 characters
return max(1, len(text) // 4)
@staticmethod
def count_tokens(text: str) -> TokenMetrics:
"""Count various metrics for a text.
Args:
text: Input text to analyze
Returns:
TokenMetrics with character, word, and estimated token counts
"""
char_count = len(text)
word_count = len(text.split())
estimated_tokens = PromptOptimizer.estimate_tokens(text)
return TokenMetrics(
char_count=char_count,
word_count=word_count,
estimated_tokens=estimated_tokens,
)
@staticmethod
def compress_json(data: dict[str, Any]) -> str:
"""Compress JSON by removing whitespace.
Args:
data: Dictionary to serialize
Returns:
Compact JSON string without whitespace
"""
return json.dumps(data, separators=(",", ":"), ensure_ascii=False)
@staticmethod
def abbreviate_text(text: str, aggressive: bool = False) -> str:
"""Apply abbreviations to reduce text length.
Args:
text: Input text to abbreviate
aggressive: If True, apply more aggressive compression
Returns:
Abbreviated text
"""
result = text
# Apply word-level abbreviations (case-insensitive)
for full, abbr in ABBREVIATIONS.items():
# Word boundaries to avoid partial replacements
pattern = r"\b" + re.escape(full) + r"\b"
result = re.sub(pattern, abbr, result, flags=re.IGNORECASE)
if aggressive:
# Remove articles and filler words
result = re.sub(r"\b(a|an|the)\b", "", result, flags=re.IGNORECASE)
result = re.sub(r"\b(is|are|was|were)\b", "", result, flags=re.IGNORECASE)
# Collapse multiple spaces
result = re.sub(r"\s+", " ", result)
return result.strip()
@staticmethod
def build_compressed_prompt(
market_data: dict[str, Any],
include_instructions: bool = True,
max_length: int | None = None,
) -> str:
"""Build a compressed prompt from market data.
Args:
market_data: Market data dictionary with stock info
include_instructions: Whether to include full instructions
max_length: Maximum character length (truncates if needed)
Returns:
Compressed prompt string
"""
# Abbreviated market name
market_name = market_data.get("market_name", "KR")
if "Korea" in market_name:
market_name = "KR"
elif "United States" in market_name or "US" in market_name:
market_name = "US"
# Core data - always included
core_info = {
"mkt": market_name,
"code": market_data["stock_code"],
"P": market_data["current_price"],
}
# Optional fields
if "orderbook" in market_data and market_data["orderbook"]:
ob = market_data["orderbook"]
# Compress orderbook: keep only top 3 levels
compressed_ob = {
"bid": ob.get("bid", [])[:3],
"ask": ob.get("ask", [])[:3],
}
core_info["ob"] = compressed_ob
if market_data.get("foreigner_net", 0) != 0:
core_info["fgn_net"] = market_data["foreigner_net"]
# Compress to JSON
data_str = PromptOptimizer.compress_json(core_info)
if include_instructions:
# Minimal instructions
prompt = (
f"{market_name} trader. Analyze:\n{data_str}\n\n"
'Return JSON: {"act":"BUY"|"SELL"|"HOLD","conf":<0-100>,"reason":"<text>"}\n'
"Rules: act=BUY/SELL/HOLD, conf=0-100, reason=concise. No markdown."
)
else:
# Data only (for cached contexts where instructions are known)
prompt = data_str
# Truncate if needed
if max_length and len(prompt) > max_length:
prompt = prompt[:max_length] + "..."
return prompt
@staticmethod
def truncate_context(
context: dict[str, Any],
max_tokens: int,
priority_keys: list[str] | None = None,
) -> dict[str, Any]:
"""Truncate context data to fit within token budget.
Keeps high-priority keys first, then truncates less important data.
Args:
context: Context dictionary to truncate
max_tokens: Maximum token budget
priority_keys: List of keys to keep (in order of priority)
Returns:
Truncated context dictionary
"""
if not context:
return {}
if priority_keys is None:
priority_keys = []
result: dict[str, Any] = {}
current_tokens = 0
# Add priority keys first
for key in priority_keys:
if key in context:
value_str = json.dumps(context[key])
tokens = PromptOptimizer.estimate_tokens(value_str)
if current_tokens + tokens <= max_tokens:
result[key] = context[key]
current_tokens += tokens
else:
break
# Add remaining keys if space available
for key, value in context.items():
if key in result:
continue
value_str = json.dumps(value)
tokens = PromptOptimizer.estimate_tokens(value_str)
if current_tokens + tokens <= max_tokens:
result[key] = value
current_tokens += tokens
else:
break
return result
@staticmethod
def calculate_compression_ratio(original: str, compressed: str) -> float:
"""Calculate compression ratio between original and compressed text.
Args:
original: Original text
compressed: Compressed text
Returns:
Compression ratio (original_tokens / compressed_tokens)
"""
original_tokens = PromptOptimizer.estimate_tokens(original)
compressed_tokens = PromptOptimizer.estimate_tokens(compressed)
if compressed_tokens == 0:
return 1.0
return original_tokens / compressed_tokens

328
src/context/summarizer.py Normal file
View File

@@ -0,0 +1,328 @@
"""Context summarization for efficient historical data representation.
This module summarizes old context data instead of including raw details:
- Key metrics only (averages, trends, not details)
- Rolling window (keep last N days detailed, summarize older)
- Aggregate historical data efficiently
"""
from __future__ import annotations
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from typing import Any
from src.context.layer import ContextLayer
from src.context.store import ContextStore
@dataclass(frozen=True)
class SummaryStats:
"""Statistical summary of historical data."""
count: int
mean: float | None = None
min: float | None = None
max: float | None = None
std: float | None = None
trend: str | None = None # "up", "down", "flat"
class ContextSummarizer:
"""Summarizes historical context data to reduce token usage."""
def __init__(self, store: ContextStore) -> None:
"""Initialize the context summarizer.
Args:
store: ContextStore instance for retrieving context data
"""
self.store = store
def summarize_numeric_values(self, values: list[float]) -> SummaryStats:
"""Summarize a list of numeric values.
Args:
values: List of numeric values to summarize
Returns:
SummaryStats with mean, min, max, std, and trend
"""
if not values:
return SummaryStats(count=0)
count = len(values)
mean = sum(values) / count
min_val = min(values)
max_val = max(values)
# Calculate standard deviation
if count > 1:
variance = sum((x - mean) ** 2 for x in values) / (count - 1)
std = variance**0.5
else:
std = 0.0
# Determine trend
trend = "flat"
if count >= 3:
# Simple trend: compare first third vs last third
first_third = values[: count // 3]
last_third = values[-(count // 3) :]
first_avg = sum(first_third) / len(first_third)
last_avg = sum(last_third) / len(last_third)
# Trend threshold: 5% change
threshold = 0.05 * abs(first_avg) if first_avg != 0 else 0.01
if last_avg > first_avg + threshold:
trend = "up"
elif last_avg < first_avg - threshold:
trend = "down"
return SummaryStats(
count=count,
mean=round(mean, 4),
min=round(min_val, 4),
max=round(max_val, 4),
std=round(std, 4),
trend=trend,
)
def summarize_layer(
self,
layer: ContextLayer,
start_date: datetime | None = None,
end_date: datetime | None = None,
) -> dict[str, Any]:
"""Summarize all context data for a layer within a date range.
Args:
layer: Context layer to summarize
start_date: Start date (inclusive), None for all
end_date: End date (inclusive), None for now
Returns:
Dictionary with summarized metrics
"""
if end_date is None:
end_date = datetime.now(UTC)
# Get all contexts for this layer
all_contexts = self.store.get_all_contexts(layer)
if not all_contexts:
return {"summary": "No data available", "count": 0}
# Group numeric values by key
numeric_data: dict[str, list[float]] = {}
text_data: dict[str, list[str]] = {}
for key, value in all_contexts.items():
# Try to extract numeric values
if isinstance(value, (int, float)):
if key not in numeric_data:
numeric_data[key] = []
numeric_data[key].append(float(value))
elif isinstance(value, dict):
# Extract numeric fields from dict
for subkey, subvalue in value.items():
if isinstance(subvalue, (int, float)):
full_key = f"{key}.{subkey}"
if full_key not in numeric_data:
numeric_data[full_key] = []
numeric_data[full_key].append(float(subvalue))
elif isinstance(value, str):
if key not in text_data:
text_data[key] = []
text_data[key].append(value)
# Summarize numeric data
summary: dict[str, Any] = {}
for key, values in numeric_data.items():
stats = self.summarize_numeric_values(values)
summary[key] = {
"count": stats.count,
"avg": stats.mean,
"range": [stats.min, stats.max],
"trend": stats.trend,
}
# Summarize text data (just counts)
for key, values in text_data.items():
summary[f"{key}_count"] = len(values)
summary["total_entries"] = len(all_contexts)
return summary
def rolling_window_summary(
self,
layer: ContextLayer,
window_days: int = 30,
summarize_older: bool = True,
) -> dict[str, Any]:
"""Create a rolling window summary.
Recent data (within window) is kept detailed.
Older data is summarized to key metrics.
Args:
layer: Context layer to summarize
window_days: Number of days to keep detailed
summarize_older: Whether to summarize data older than window
Returns:
Dictionary with recent (detailed) and historical (summary) data
"""
result: dict[str, Any] = {
"window_days": window_days,
"recent_data": {},
"historical_summary": {},
}
# Get all contexts
all_contexts = self.store.get_all_contexts(layer)
recent_values: dict[str, list[float]] = {}
historical_values: dict[str, list[float]] = {}
for key, value in all_contexts.items():
# For simplicity, treat all numeric values
if isinstance(value, (int, float)):
# Note: We don't have timestamps in context keys
# This is a simplified implementation
# In practice, would need to check timeframe field
# For now, put recent data in window
if key not in recent_values:
recent_values[key] = []
recent_values[key].append(float(value))
# Detailed recent data
result["recent_data"] = {key: values[-10:] for key, values in recent_values.items()}
# Summarized historical data
if summarize_older:
for key, values in historical_values.items():
stats = self.summarize_numeric_values(values)
result["historical_summary"][key] = {
"count": stats.count,
"avg": stats.mean,
"trend": stats.trend,
}
return result
def aggregate_to_higher_layer(
self,
source_layer: ContextLayer,
target_layer: ContextLayer,
metric_key: str,
aggregation_func: str = "mean",
) -> float | None:
"""Aggregate data from source layer to target layer.
Args:
source_layer: Source context layer (more granular)
target_layer: Target context layer (less granular)
metric_key: Key of metric to aggregate
aggregation_func: Aggregation function ("mean", "sum", "max", "min")
Returns:
Aggregated value, or None if no data available
"""
# Get all contexts from source layer
source_contexts = self.store.get_all_contexts(source_layer)
# Extract values for metric_key
values = []
for key, value in source_contexts.items():
if key == metric_key and isinstance(value, (int, float)):
values.append(float(value))
elif isinstance(value, dict) and metric_key in value:
subvalue = value[metric_key]
if isinstance(subvalue, (int, float)):
values.append(float(subvalue))
if not values:
return None
# Apply aggregation function
if aggregation_func == "mean":
return sum(values) / len(values)
elif aggregation_func == "sum":
return sum(values)
elif aggregation_func == "max":
return max(values)
elif aggregation_func == "min":
return min(values)
else:
return sum(values) / len(values) # Default to mean
def create_compact_summary(
self,
layers: list[ContextLayer],
top_n_metrics: int = 5,
) -> dict[str, Any]:
"""Create a compact summary across multiple layers.
Args:
layers: List of context layers to summarize
top_n_metrics: Number of top metrics to include per layer
Returns:
Compact summary dictionary
"""
summary: dict[str, Any] = {}
for layer in layers:
layer_summary = self.summarize_layer(layer)
# Keep only top N metrics (by count/relevance)
metrics = []
for key, value in layer_summary.items():
if isinstance(value, dict) and "count" in value:
metrics.append((key, value, value["count"]))
# Sort by count (descending)
metrics.sort(key=lambda x: x[2], reverse=True)
# Keep top N
top_metrics = {m[0]: m[1] for m in metrics[:top_n_metrics]}
summary[layer.value] = top_metrics
return summary
def format_summary_for_prompt(self, summary: dict[str, Any]) -> str:
"""Format summary for inclusion in a prompt.
Args:
summary: Summary dictionary
Returns:
Formatted string for prompt
"""
lines = []
for layer, metrics in summary.items():
if not metrics:
continue
lines.append(f"{layer}:")
for key, value in metrics.items():
if isinstance(value, dict):
# Format as: key: avg=X, trend=Y
parts = []
if "avg" in value and value["avg"] is not None:
parts.append(f"avg={value['avg']:.2f}")
if "trend" in value and value["trend"]:
parts.append(f"trend={value['trend']}")
if parts:
lines.append(f" {key}: {', '.join(parts)}")
else:
lines.append(f" {key}: {value}")
return "\n".join(lines)

View File

@@ -23,7 +23,7 @@ from google import genai
from src.config import Settings
from src.db import init_db
from src.logging.decision_logger import DecisionLog, DecisionLogger
from src.logging.decision_logger import DecisionLogger
logger = logging.getLogger(__name__)

View File

@@ -21,7 +21,7 @@ from src.broker.overseas import OverseasBroker
from src.config import Settings
from src.context.layer import ContextLayer
from src.context.store import ContextStore
from src.core.criticality import CriticalityAssessor, CriticalityLevel
from src.core.criticality import CriticalityAssessor
from src.core.priority_queue import PriorityTaskQueue
from src.core.risk_manager import CircuitBreakerTripped, RiskManager
from src.db import init_db, log_trade

View File

@@ -11,15 +11,15 @@ from __future__ import annotations
import json
import sqlite3
import tempfile
from datetime import UTC, datetime, timedelta
from datetime import UTC, datetime
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from unittest.mock import AsyncMock, Mock, patch
import pytest
from src.config import Settings
from src.db import init_db, log_trade
from src.evolution.ab_test import ABTester, ABTestResult, StrategyPerformance
from src.evolution.ab_test import ABTester
from src.evolution.optimizer import EvolutionOptimizer
from src.evolution.performance_tracker import (
PerformanceDashboard,
@@ -28,7 +28,6 @@ from src.evolution.performance_tracker import (
)
from src.logging.decision_logger import DecisionLogger
# ------------------------------------------------------------------
# Fixtures
# ------------------------------------------------------------------

View File

@@ -0,0 +1,663 @@
"""Tests for token efficiency optimization components.
Tests cover:
- Prompt compression and optimization
- Context selection logic
- Summarization
- Caching
- Token reduction metrics
"""
from __future__ import annotations
import sqlite3
import time
import pytest
from src.brain.cache import DecisionCache
from src.brain.context_selector import ContextSelector, DecisionType
from src.brain.gemini_client import TradeDecision
from src.brain.prompt_optimizer import PromptOptimizer, TokenMetrics
from src.context.layer import ContextLayer
from src.context.store import ContextStore
from src.context.summarizer import ContextSummarizer, SummaryStats
# ============================================================================
# Prompt Optimizer Tests
# ============================================================================
class TestPromptOptimizer:
"""Tests for PromptOptimizer."""
def test_estimate_tokens(self):
"""Test token estimation."""
optimizer = PromptOptimizer()
# Empty text
assert optimizer.estimate_tokens("") == 0
# Short text (4 chars = 1 token estimate)
assert optimizer.estimate_tokens("test") == 1
# Longer text
text = "This is a longer piece of text for testing token estimation."
tokens = optimizer.estimate_tokens(text)
assert tokens > 0
assert tokens == len(text) // 4
def test_count_tokens(self):
"""Test token counting metrics."""
optimizer = PromptOptimizer()
text = "Hello world, this is a test."
metrics = optimizer.count_tokens(text)
assert isinstance(metrics, TokenMetrics)
assert metrics.char_count == len(text)
assert metrics.word_count == 6
assert metrics.estimated_tokens > 0
def test_compress_json(self):
"""Test JSON compression."""
optimizer = PromptOptimizer()
data = {
"action": "BUY",
"confidence": 85,
"rationale": "Strong uptrend",
}
compressed = optimizer.compress_json(data)
# Should have no newlines and minimal whitespace
assert "\n" not in compressed
# Note: JSON values may contain spaces (e.g., "Strong uptrend")
# but there should be no spaces around separators
assert ": " not in compressed
assert ", " not in compressed
# Should be valid JSON
import json
parsed = json.loads(compressed)
assert parsed == data
def test_abbreviate_text(self):
"""Test text abbreviation."""
optimizer = PromptOptimizer()
text = "The current price is high and volume is increasing."
abbreviated = optimizer.abbreviate_text(text)
# Should contain abbreviations
assert "cur" in abbreviated or "P" in abbreviated
assert len(abbreviated) <= len(text)
def test_abbreviate_text_aggressive(self):
"""Test aggressive text abbreviation."""
optimizer = PromptOptimizer()
text = "The price is increasing and the volume is high."
abbreviated = optimizer.abbreviate_text(text, aggressive=True)
# Should be shorter
assert len(abbreviated) < len(text)
# Should have removed articles
assert "the" not in abbreviated.lower()
def test_build_compressed_prompt(self):
"""Test compressed prompt building."""
optimizer = PromptOptimizer()
market_data = {
"stock_code": "005930",
"current_price": 75000,
"market_name": "Korean stock market",
}
prompt = optimizer.build_compressed_prompt(market_data)
# Should be much shorter than original
assert len(prompt) < 300
assert "005930" in prompt
assert "75000" in prompt
def test_build_compressed_prompt_no_instructions(self):
"""Test compressed prompt without instructions."""
optimizer = PromptOptimizer()
market_data = {
"stock_code": "AAPL",
"current_price": 150.5,
"market_name": "United States",
}
prompt = optimizer.build_compressed_prompt(market_data, include_instructions=False)
# Should be very short (data only)
assert len(prompt) < 100
assert "AAPL" in prompt
def test_truncate_context(self):
"""Test context truncation."""
optimizer = PromptOptimizer()
context = {
"price": 100.5,
"volume": 1000000,
"sentiment": 0.8,
"extra_data": "Some long text that should be truncated",
}
# Truncate to small budget
truncated = optimizer.truncate_context(context, max_tokens=10)
# Should have fewer keys
assert len(truncated) <= len(context)
def test_truncate_context_with_priority(self):
"""Test context truncation with priority keys."""
optimizer = PromptOptimizer()
context = {
"price": 100.5,
"volume": 1000000,
"sentiment": 0.8,
"extra_data": "Some data",
}
priority_keys = ["price", "sentiment"]
truncated = optimizer.truncate_context(context, max_tokens=20, priority_keys=priority_keys)
# Priority keys should be included
assert "price" in truncated
assert "sentiment" in truncated
def test_calculate_compression_ratio(self):
"""Test compression ratio calculation."""
optimizer = PromptOptimizer()
original = "This is a very long piece of text that should be compressed significantly."
compressed = "Short text"
ratio = optimizer.calculate_compression_ratio(original, compressed)
# Ratio should be > 1 (original is longer)
assert ratio > 1.0
# ============================================================================
# Context Selector Tests
# ============================================================================
class TestContextSelector:
"""Tests for ContextSelector."""
@pytest.fixture
def store(self):
"""Create in-memory ContextStore."""
conn = sqlite3.connect(":memory:")
# Create tables
conn.execute(
"""
CREATE TABLE context_metadata (
layer TEXT PRIMARY KEY,
description TEXT,
retention_days INTEGER,
aggregation_source TEXT
)
"""
)
conn.execute(
"""
CREATE TABLE contexts (
layer TEXT,
timeframe TEXT,
key TEXT,
value TEXT,
created_at TEXT,
updated_at TEXT,
PRIMARY KEY (layer, timeframe, key)
)
"""
)
conn.commit()
return ContextStore(conn)
def test_select_layers_normal(self, store):
"""Test layer selection for normal decisions."""
selector = ContextSelector(store)
layers = selector.select_layers(DecisionType.NORMAL)
# Should only select L7 (real-time)
assert layers == [ContextLayer.L7_REALTIME]
def test_select_layers_strategic(self, store):
"""Test layer selection for strategic decisions."""
selector = ContextSelector(store)
layers = selector.select_layers(DecisionType.STRATEGIC)
# Should select L7 + L6 + L5
assert ContextLayer.L7_REALTIME in layers
assert ContextLayer.L6_DAILY in layers
assert ContextLayer.L5_WEEKLY in layers
assert len(layers) == 3
def test_select_layers_major_event(self, store):
"""Test layer selection for major events."""
selector = ContextSelector(store)
layers = selector.select_layers(DecisionType.MAJOR_EVENT)
# Should select all layers
assert len(layers) == 7
assert ContextLayer.L1_LEGACY in layers
assert ContextLayer.L7_REALTIME in layers
def test_score_layer_relevance(self, store):
"""Test layer relevance scoring."""
selector = ContextSelector(store)
# Add some data first so scores aren't penalized
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
store.set_context(ContextLayer.L1_LEGACY, "legacy", "lesson", "test")
# L7 should have high score for normal decisions
score = selector.score_layer_relevance(ContextLayer.L7_REALTIME, DecisionType.NORMAL)
assert score == 1.0
# L1 should have low score for normal decisions
score = selector.score_layer_relevance(ContextLayer.L1_LEGACY, DecisionType.NORMAL)
assert score == 0.0
# L1 should have high score for major events
score = selector.score_layer_relevance(ContextLayer.L1_LEGACY, DecisionType.MAJOR_EVENT)
assert score == 1.0
def test_select_with_scoring(self, store):
"""Test selection with relevance scoring."""
selector = ContextSelector(store)
# Add data so layers aren't penalized
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
selection = selector.select_with_scoring(DecisionType.NORMAL, min_score=0.5)
# Should only select high-relevance layers
assert len(selection.layers) >= 1
assert ContextLayer.L7_REALTIME in selection.layers
assert all(selection.relevance_scores[layer] >= 0.5 for layer in selection.layers)
def test_get_context_data(self, store):
"""Test context data retrieval."""
selector = ContextSelector(store)
# Add some test data
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "volume", 1000000)
context_data = selector.get_context_data([ContextLayer.L7_REALTIME])
# Should retrieve data
assert "L7_REALTIME" in context_data
assert "price" in context_data["L7_REALTIME"]
assert context_data["L7_REALTIME"]["price"] == 100.5
def test_estimate_context_tokens(self, store):
"""Test context token estimation."""
selector = ContextSelector(store)
context_data = {
"L7_REALTIME": {"price": 100.5, "volume": 1000000},
"L6_DAILY": {"avg_price": 99.8, "avg_volume": 950000},
}
tokens = selector.estimate_context_tokens(context_data)
# Should estimate tokens
assert tokens > 0
def test_optimize_context_for_budget(self, store):
"""Test context optimization for token budget."""
selector = ContextSelector(store)
# Add test data
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
# Get optimized context within budget
context = selector.optimize_context_for_budget(DecisionType.NORMAL, max_tokens=50)
# Should return data within budget
tokens = selector.estimate_context_tokens(context)
assert tokens <= 50
# ============================================================================
# Context Summarizer Tests
# ============================================================================
class TestContextSummarizer:
"""Tests for ContextSummarizer."""
@pytest.fixture
def store(self):
"""Create in-memory ContextStore."""
conn = sqlite3.connect(":memory:")
conn.execute(
"""
CREATE TABLE context_metadata (
layer TEXT PRIMARY KEY,
description TEXT,
retention_days INTEGER,
aggregation_source TEXT
)
"""
)
conn.execute(
"""
CREATE TABLE contexts (
layer TEXT,
timeframe TEXT,
key TEXT,
value TEXT,
created_at TEXT,
updated_at TEXT,
PRIMARY KEY (layer, timeframe, key)
)
"""
)
conn.commit()
return ContextStore(conn)
def test_summarize_numeric_values(self, store):
"""Test numeric value summarization."""
summarizer = ContextSummarizer(store)
values = [10.0, 20.0, 30.0, 40.0, 50.0]
stats = summarizer.summarize_numeric_values(values)
assert isinstance(stats, SummaryStats)
assert stats.count == 5
assert stats.mean == 30.0
assert stats.min == 10.0
assert stats.max == 50.0
assert stats.std is not None
def test_summarize_numeric_values_trend(self, store):
"""Test trend detection in numeric values."""
summarizer = ContextSummarizer(store)
# Uptrend
values_up = [10.0, 15.0, 20.0, 25.0, 30.0, 35.0]
stats_up = summarizer.summarize_numeric_values(values_up)
assert stats_up.trend == "up"
# Downtrend
values_down = [35.0, 30.0, 25.0, 20.0, 15.0, 10.0]
stats_down = summarizer.summarize_numeric_values(values_down)
assert stats_down.trend == "down"
# Flat
values_flat = [20.0, 20.1, 19.9, 20.0, 20.1, 19.9]
stats_flat = summarizer.summarize_numeric_values(values_flat)
assert stats_flat.trend == "flat"
def test_summarize_layer(self, store):
"""Test layer summarization."""
summarizer = ContextSummarizer(store)
# Add test data
store.set_context(ContextLayer.L6_DAILY, "2026-02-04", "price", 100.5)
store.set_context(ContextLayer.L6_DAILY, "2026-02-04", "volume", 1000000)
summary = summarizer.summarize_layer(ContextLayer.L6_DAILY)
# Should have summary
assert "total_entries" in summary
assert summary["total_entries"] > 0
def test_create_compact_summary(self, store):
"""Test compact summary creation."""
summarizer = ContextSummarizer(store)
# Add test data
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
layers = [ContextLayer.L7_REALTIME, ContextLayer.L6_DAILY]
summary = summarizer.create_compact_summary(layers, top_n_metrics=3)
# Should have summaries for layers
assert "L7_REALTIME" in summary
def test_format_summary_for_prompt(self, store):
"""Test summary formatting for prompt."""
summarizer = ContextSummarizer(store)
summary = {
"L7_REALTIME": {
"price": {"avg": 100.5, "trend": "up"},
"volume": {"avg": 1000000, "trend": "flat"},
}
}
formatted = summarizer.format_summary_for_prompt(summary)
# Should be formatted string
assert isinstance(formatted, str)
assert "L7_REALTIME" in formatted
assert "100.5" in formatted or "100.50" in formatted
# ============================================================================
# Decision Cache Tests
# ============================================================================
class TestDecisionCache:
"""Tests for DecisionCache."""
def test_cache_init(self):
"""Test cache initialization."""
cache = DecisionCache(ttl_seconds=60, max_size=100)
assert cache.ttl_seconds == 60
assert cache.max_size == 100
def test_cache_miss(self):
"""Test cache miss."""
cache = DecisionCache()
market_data = {"stock_code": "005930", "current_price": 75000}
decision = cache.get(market_data)
# Should be None (cache miss)
assert decision is None
metrics = cache.get_metrics()
assert metrics.cache_misses == 1
assert metrics.cache_hits == 0
def test_cache_hit(self):
"""Test cache hit."""
cache = DecisionCache()
market_data = {"stock_code": "005930", "current_price": 75000}
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
# Set cache
cache.set(market_data, decision)
# Get from cache
cached = cache.get(market_data)
assert cached is not None
assert cached.action == "HOLD"
assert cached.confidence == 50
metrics = cache.get_metrics()
assert metrics.cache_hits == 1
def test_cache_ttl_expiration(self):
"""Test cache TTL expiration."""
cache = DecisionCache(ttl_seconds=1) # 1 second TTL
market_data = {"stock_code": "005930", "current_price": 75000}
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
# Set cache
cache.set(market_data, decision)
# Should hit immediately
cached = cache.get(market_data)
assert cached is not None
# Wait for expiration
time.sleep(1.1)
# Should miss after expiration
cached = cache.get(market_data)
assert cached is None
def test_cache_max_size(self):
"""Test cache max size eviction."""
cache = DecisionCache(max_size=2)
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
# Add 3 entries (exceeds max_size)
for i in range(3):
market_data = {"stock_code": f"00{i}", "current_price": 1000 * i}
cache.set(market_data, decision)
metrics = cache.get_metrics()
# Should have evicted 1 entry
assert metrics.total_entries == 2
assert metrics.evictions == 1
def test_invalidate_all(self):
"""Test invalidate all cache entries."""
cache = DecisionCache()
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
# Add entries
for i in range(3):
market_data = {"stock_code": f"00{i}", "current_price": 1000}
cache.set(market_data, decision)
# Invalidate all
count = cache.invalidate()
assert count == 3
metrics = cache.get_metrics()
assert metrics.total_entries == 0
def test_invalidate_by_stock(self):
"""Test invalidate cache by stock code."""
cache = DecisionCache()
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
# Add entries for different stocks
cache.set({"stock_code": "005930", "current_price": 75000}, decision)
cache.set({"stock_code": "000660", "current_price": 50000}, decision)
# Invalidate specific stock
count = cache.invalidate("005930")
assert count >= 1
# Other stock should still be cached
cached = cache.get({"stock_code": "000660", "current_price": 50000})
assert cached is not None
def test_cleanup_expired(self):
"""Test cleanup of expired entries."""
cache = DecisionCache(ttl_seconds=1)
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
# Add entry
cache.set({"stock_code": "005930", "current_price": 75000}, decision)
# Wait for expiration
time.sleep(1.1)
# Cleanup
count = cache.cleanup_expired()
assert count == 1
metrics = cache.get_metrics()
assert metrics.total_entries == 0
def test_should_cache_decision(self):
"""Test decision caching criteria."""
cache = DecisionCache()
# HOLD decisions should be cached
hold_decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
assert cache.should_cache_decision(hold_decision) is True
# High confidence BUY should be cached
buy_decision = TradeDecision(action="BUY", confidence=95, rationale="Test")
assert cache.should_cache_decision(buy_decision) is True
# Low confidence BUY should not be cached
low_conf_buy = TradeDecision(action="BUY", confidence=60, rationale="Test")
assert cache.should_cache_decision(low_conf_buy) is False
def test_cache_hit_rate(self):
"""Test cache hit rate calculation."""
cache = DecisionCache()
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
market_data = {"stock_code": "005930", "current_price": 75000}
# First request (miss)
cache.get(market_data)
# Set cache
cache.set(market_data, decision)
# Second request (hit)
cache.get(market_data)
# Third request (hit)
cache.get(market_data)
metrics = cache.get_metrics()
# 1 miss, 2 hits out of 3 requests
assert metrics.total_requests == 3
assert metrics.cache_hits == 2
assert metrics.cache_misses == 1
assert metrics.hit_rate == pytest.approx(2 / 3)
def test_reset_metrics(self):
"""Test metrics reset."""
cache = DecisionCache()
market_data = {"stock_code": "005930", "current_price": 75000}
# Generate some activity
cache.get(market_data)
cache.get(market_data)
# Reset
cache.reset_metrics()
metrics = cache.get_metrics()
assert metrics.total_requests == 0
assert metrics.cache_hits == 0
assert metrics.cache_misses == 0