feat: implement Token Efficiency - Context optimization (issue #24) #28
293
src/brain/cache.py
Normal file
293
src/brain/cache.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""Response caching system for reducing redundant LLM calls.
|
||||
|
||||
This module provides caching for common trading scenarios:
|
||||
- TTL-based cache invalidation
|
||||
- Cache key based on market conditions
|
||||
- Cache hit rate monitoring
|
||||
- Special handling for HOLD decisions in quiet markets
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.brain.gemini_client import TradeDecision
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class CacheEntry:
    """A cached trade decision together with its bookkeeping data."""

    # The decision object held by this entry.
    decision: "TradeDecision"
    # Unix timestamp taken when the entry was stored.
    cached_at: float
    # How many times the entry has been returned as a cache hit.
    hit_count: int = 0
    # Hash of the market data the decision was cached for.
    market_data_hash: str = ""
|
||||
|
||||
|
||||
@dataclass
class CacheMetrics:
    """Counters used to monitor how well the cache is performing."""

    total_requests: int = 0
    cache_hits: int = 0
    cache_misses: int = 0
    evictions: int = 0
    total_entries: int = 0

    @property
    def hit_rate(self) -> float:
        """Return hits divided by requests, or 0.0 when nothing was requested yet."""
        requests = self.total_requests
        return self.cache_hits / requests if requests != 0 else 0.0

    def to_dict(self) -> dict[str, Any]:
        """Return the counters and the derived hit rate as a plain dictionary."""
        exported = (
            "total_requests",
            "cache_hits",
            "cache_misses",
            "hit_rate",
            "evictions",
            "total_entries",
        )
        # getattr also resolves the ``hit_rate`` property, so one loop covers all keys.
        return {name: getattr(self, name) for name in exported}
|
||||
|
||||
|
||||
class DecisionCache:
    """TTL-based cache for trade decisions.

    Entries are keyed by a coarse fingerprint of the market data (stock code,
    rounded price and, when available, the top-of-book spread) and are treated
    as expired once they are older than ``ttl_seconds``.
    """

    def __init__(self, ttl_seconds: int = 300, max_size: int = 1000) -> None:
        """Initialize the decision cache.

        Args:
            ttl_seconds: Time-to-live for cache entries in seconds (default: 5 minutes)
            max_size: Maximum number of cache entries; a non-positive value
                means nothing is ever stored
        """
        self.ttl_seconds = ttl_seconds
        self.max_size = max_size
        self._cache: dict[str, CacheEntry] = {}
        self._metrics = CacheMetrics()

    def _generate_cache_key(self, market_data: dict[str, Any]) -> str:
        """Generate cache key from market data.

        Key is based on:
        - Stock code
        - Current price (rounded to reduce sensitivity)
        - Market conditions (best ask minus best bid, when an orderbook is present)

        Args:
            market_data: Market data dictionary

        Returns:
            Cache key string
        """
        stock_code = market_data.get("stock_code", "UNKNOWN")
        # A missing *or* None price is keyed as 0; comparing None with an int
        # below would otherwise raise TypeError.
        current_price = market_data.get("current_price") or 0

        # Round price to reduce sensitivity (cache hits for similar prices):
        # above 1000 round to the nearest 10, otherwise to the nearest 1.
        if current_price > 1000:
            price_rounded = round(current_price / 10) * 10
        else:
            price_rounded = round(current_price)

        # Only the spread between the first bid and first ask is used as the
        # orderbook indicator.
        orderbook_key = ""
        ob = market_data.get("orderbook")
        if ob and ob.get("bid") and ob.get("ask"):
            bid_price = ob["bid"][0].get("price", 0)
            ask_price = ob["ask"][0].get("price", 0)
            orderbook_key = f"_spread{ask_price - bid_price}"

        return f"{stock_code}_{price_rounded}{orderbook_key}"

    def _generate_market_hash(self, market_data: dict[str, Any]) -> str:
        """Generate hash of the full market data.

        The hash is stored on the cache entry; nothing in this class compares
        it on lookup.

        Args:
            market_data: Market data dictionary (must be JSON-serializable)

        Returns:
            Hex digest string
        """
        # sort_keys makes the serialization independent of dict insertion order.
        stable_json = json.dumps(market_data, sort_keys=True, ensure_ascii=False)
        return hashlib.md5(stable_json.encode()).hexdigest()

    def get(self, market_data: dict[str, Any]) -> TradeDecision | None:
        """Retrieve cached decision if valid.

        Args:
            market_data: Market data dictionary

        Returns:
            Cached TradeDecision if present and not expired, None otherwise
        """
        self._metrics.total_requests += 1

        cache_key = self._generate_cache_key(market_data)
        entry = self._cache.get(cache_key)

        if entry is None:
            self._metrics.cache_misses += 1
            return None

        # Check TTL
        if time.time() - entry.cached_at > self.ttl_seconds:
            del self._cache[cache_key]
            self._metrics.cache_misses += 1
            self._metrics.evictions += 1
            # Keep the entry gauge in sync, as every other removal path does.
            self._metrics.total_entries = len(self._cache)
            logger.debug("Cache expired for key: %s", cache_key)
            return None

        # Cache hit
        entry.hit_count += 1
        self._metrics.cache_hits += 1
        logger.debug("Cache hit for key: %s (hits: %d)", cache_key, entry.hit_count)

        return entry.decision

    def set(
        self,
        market_data: dict[str, Any],
        decision: TradeDecision,
    ) -> None:
        """Store decision in cache.

        Args:
            market_data: Market data dictionary
            decision: TradeDecision to cache
        """
        # With no capacity there is nothing to evict and nothing may be kept.
        if self.max_size <= 0:
            return

        cache_key = self._generate_cache_key(market_data)
        market_hash = self._generate_market_hash(market_data)

        # Enforce max size (evict oldest if full). Overwriting an existing key
        # does not grow the cache, so it must not push out another entry.
        if cache_key not in self._cache and len(self._cache) >= self.max_size:
            oldest_key = min(self._cache, key=lambda k: self._cache[k].cached_at)
            del self._cache[oldest_key]
            self._metrics.evictions += 1
            logger.debug("Cache full, evicted key: %s", oldest_key)

        self._cache[cache_key] = CacheEntry(
            decision=decision,
            cached_at=time.time(),
            market_data_hash=market_hash,
        )
        self._metrics.total_entries = len(self._cache)

        logger.debug("Cached decision for key: %s", cache_key)

    def invalidate(self, stock_code: str | None = None) -> int:
        """Invalidate cache entries.

        Args:
            stock_code: Specific stock code to invalidate, or None for all

        Returns:
            Number of entries invalidated
        """
        if stock_code is None:
            # Clear all
            count = len(self._cache)
            self._cache.clear()
            self._metrics.evictions += count
            self._metrics.total_entries = 0
            logger.info("Invalidated all cache entries (%d)", count)
            return count

        # Keys start with "<stock_code>_", so a prefix match selects one stock.
        prefix = f"{stock_code}_"
        keys_to_remove = [k for k in self._cache if k.startswith(prefix)]
        count = len(keys_to_remove)

        for key in keys_to_remove:
            del self._cache[key]

        self._metrics.evictions += count
        self._metrics.total_entries = len(self._cache)
        logger.info("Invalidated %d cache entries for stock: %s", count, stock_code)

        return count

    def cleanup_expired(self) -> int:
        """Remove expired entries from cache.

        Returns:
            Number of entries removed
        """
        current_time = time.time()
        # Collect first, delete afterwards: the dict must not change size
        # while it is being iterated.
        expired_keys = [
            k
            for k, v in self._cache.items()
            if current_time - v.cached_at > self.ttl_seconds
        ]

        count = len(expired_keys)
        for key in expired_keys:
            del self._cache[key]

        self._metrics.evictions += count
        self._metrics.total_entries = len(self._cache)

        if count > 0:
            logger.debug("Cleaned up %d expired cache entries", count)

        return count

    def get_metrics(self) -> CacheMetrics:
        """Get current cache metrics.

        Returns:
            The live CacheMetrics object (not a copy)
        """
        return self._metrics

    def reset_metrics(self) -> None:
        """Reset cache metrics, keeping only the current entry count."""
        self._metrics = CacheMetrics(total_entries=len(self._cache))
        logger.info("Cache metrics reset")

    def should_cache_decision(self, decision: TradeDecision) -> bool:
        """Determine if a decision should be cached.

        HOLD decisions are always cached (they tend to recur in quiet
        markets); any other action is cached only with confidence >= 90.

        Args:
            decision: TradeDecision to evaluate

        Returns:
            True if decision should be cached
        """
        # Cache HOLD decisions (common in quiet markets)
        if decision.action == "HOLD":
            return True

        # Cache high-confidence decisions (stable signals)
        if decision.confidence >= 90:
            return True

        # Don't cache low-confidence BUY/SELL (volatile signals)
        return False
|
||||
296
src/brain/context_selector.py
Normal file
296
src/brain/context_selector.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""Smart context selection for optimizing token usage.
|
||||
|
||||
This module implements intelligent selection of context layers (L1-L7) based on
|
||||
decision type and market conditions:
|
||||
- L7 (real-time) for normal trading decisions
|
||||
- L6-L5 (daily/weekly) for strategic decisions
|
||||
- L4-L1 (monthly/legacy) only for major events or policy changes
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from src.context.layer import ContextLayer
|
||||
from src.context.store import ContextStore
|
||||
|
||||
|
||||
class DecisionType(str, Enum):
    """Kind of trading decision for which context is being selected."""

    # Regular trade decision
    NORMAL = "normal"
    # Strategy adjustment
    STRATEGIC = "strategic"
    # Portfolio rebalancing, policy change
    MAJOR_EVENT = "major_event"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ContextSelection:
    """Immutable result of a scored context-layer selection."""

    # Layers that were selected.
    layers: list[ContextLayer]
    # Relevance score per layer.
    relevance_scores: dict[ContextLayer, float]
    # Sum of the scores of the selected layers.
    total_score: float
|
||||
|
||||
|
||||
class ContextSelector:
    """Chooses which context layers to send so that token usage stays low."""

    def __init__(self, store: ContextStore) -> None:
        """Create a selector backed by *store*.

        Args:
            store: ContextStore instance used to look up context data
        """
        self.store = store

    def select_layers(
        self,
        decision_type: DecisionType = DecisionType.NORMAL,
        include_realtime: bool = True,
    ) -> list[ContextLayer]:
        """Pick context layers purely from the decision type.

        Strategy:
        - NORMAL: L7 (real-time) only
        - STRATEGIC: L7 + L6 + L5 (real-time + daily + weekly)
        - anything else (MAJOR_EVENT): all layers L7 down to L1

        Args:
            decision_type: Type of decision being made
            include_realtime: Whether to include L7 real-time data

        Returns:
            List of context layers to use (ordered by priority)
        """
        if decision_type == DecisionType.NORMAL:
            history: list[ContextLayer] = []
        elif decision_type == DecisionType.STRATEGIC:
            history = [ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY]
        else:
            history = [
                ContextLayer.L6_DAILY,
                ContextLayer.L5_WEEKLY,
                ContextLayer.L4_MONTHLY,
                ContextLayer.L3_QUARTERLY,
                ContextLayer.L2_ANNUAL,
                ContextLayer.L1_LEGACY,
            ]

        # The real-time layer, when requested, always comes first.
        realtime = [ContextLayer.L7_REALTIME] if include_realtime else []
        return realtime + history

    def score_layer_relevance(
        self,
        layer: ContextLayer,
        decision_type: DecisionType,
        current_time: datetime | None = None,
    ) -> float:
        """Compute a relevance score for one context layer.

        The score starts from a fixed per-decision-type weight and is
        multiplied by 0.1 when the store reports no timeframe for the layer.

        Args:
            layer: Context layer to score
            decision_type: Type of decision being made
            current_time: Current time (defaults to now); the value does not
                currently influence the score

        Returns:
            Relevance score (0.0 to 1.0)
        """
        if current_time is None:
            current_time = datetime.now(UTC)

        recency_order = (
            ContextLayer.L7_REALTIME,
            ContextLayer.L6_DAILY,
            ContextLayer.L5_WEEKLY,
            ContextLayer.L4_MONTHLY,
            ContextLayer.L3_QUARTERLY,
            ContextLayer.L2_ANNUAL,
            ContextLayer.L1_LEGACY,
        )
        # One weight per layer, aligned with ``recency_order``.
        weights = {
            DecisionType.NORMAL: (1.0, 0.1, 0.05, 0.01, 0.0, 0.0, 0.0),
            DecisionType.STRATEGIC: (0.9, 0.8, 0.7, 0.3, 0.2, 0.1, 0.05),
            DecisionType.MAJOR_EVENT: (0.7, 0.7, 0.7, 0.8, 0.8, 0.9, 1.0),
        }
        base_for_type = dict(zip(recency_order, weights[decision_type]))
        score = base_for_type.get(layer, 0.0)

        # A layer without any stored timeframe is heavily penalised.
        if self.store.get_latest_timeframe(layer) is None:
            score *= 0.1

        return score

    def select_with_scoring(
        self,
        decision_type: DecisionType = DecisionType.NORMAL,
        min_score: float = 0.5,
    ) -> ContextSelection:
        """Score every layer and keep those reaching *min_score*.

        Args:
            decision_type: Type of decision being made
            min_score: Minimum relevance score to include a layer

        Returns:
            ContextSelection with the kept layers (highest score first), the
            scores of all layers, and the summed score of the kept layers
        """
        candidates = (
            ContextLayer.L7_REALTIME,
            ContextLayer.L6_DAILY,
            ContextLayer.L5_WEEKLY,
            ContextLayer.L4_MONTHLY,
            ContextLayer.L3_QUARTERLY,
            ContextLayer.L2_ANNUAL,
            ContextLayer.L1_LEGACY,
        )
        scores = {
            candidate: self.score_layer_relevance(candidate, decision_type)
            for candidate in candidates
        }

        # sorted() is stable, so layers with equal scores keep recency order.
        kept = sorted(
            (candidate for candidate, value in scores.items() if value >= min_score),
            key=scores.__getitem__,
            reverse=True,
        )

        return ContextSelection(
            layers=kept,
            relevance_scores=scores,
            total_score=sum(scores[candidate] for candidate in kept),
        )

    def get_context_data(
        self,
        layers: list[ContextLayer],
        max_items_per_layer: int = 10,
    ) -> dict[str, Any]:
        """Fetch the latest-timeframe contexts for each requested layer.

        Layers for which the store reports no latest timeframe are omitted.

        Args:
            layers: List of context layers to retrieve
            max_items_per_layer: Maximum number of items per layer

        Returns:
            Dictionary keyed by ``layer.value`` holding that layer's contexts
        """
        collected: dict[str, Any] = {}

        for layer in layers:
            timeframe = self.store.get_latest_timeframe(layer)
            if not timeframe:
                continue

            contexts = self.store.get_all_contexts(layer, timeframe)
            if len(contexts) > max_items_per_layer:
                # Keep only the first N items in the mapping's own order.
                head = list(contexts.items())[:max_items_per_layer]
                contexts = dict(head)

            collected[layer.value] = contexts

        return collected

    def estimate_context_tokens(self, context_data: dict[str, Any]) -> int:
        """Estimate the token cost of *context_data* once serialized as JSON.

        Args:
            context_data: Context data dictionary

        Returns:
            Estimated token count
        """
        import json

        from src.brain.prompt_optimizer import PromptOptimizer

        serialized = json.dumps(context_data, ensure_ascii=False)
        return PromptOptimizer.estimate_tokens(serialized)

    def optimize_context_for_budget(
        self,
        decision_type: DecisionType,
        max_tokens: int,
    ) -> dict[str, Any]:
        """Select and retrieve context data, shrinking it towards a token budget.

        Candidates are tried from richest to leanest; the first one whose
        estimate fits is returned. If none fits, L7 data limited to one item
        is returned without a further budget check.

        Args:
            decision_type: Type of decision being made
            max_tokens: Maximum token budget for context

        Returns:
            Context data dictionary
        """

        def candidates():
            # Baseline: layers scoring >= 0.5 with the default item limit.
            selection = self.select_with_scoring(decision_type, min_score=0.5)
            yield self.get_context_data(selection.layers)

            # 1. Same layers, progressively fewer items per layer.
            for max_items in (5, 3, 1):
                yield self.get_context_data(selection.layers, max_items)

            # 2. One item per layer, progressively stricter score threshold.
            for min_score in (0.6, 0.7, 0.8, 0.9):
                selection = self.select_with_scoring(decision_type, min_score=min_score)
                yield self.get_context_data(selection.layers, max_items_per_layer=1)

        # The generator is lazy, so store lookups stop at the first fit.
        for context_data in candidates():
            if self.estimate_context_tokens(context_data) <= max_tokens:
                return context_data

        # Last resort: return only L7 with minimal data
        return self.get_context_data([ContextLayer.L7_REALTIME], max_items_per_layer=1)
|
||||
@@ -2,6 +2,11 @@
|
||||
|
||||
Constructs prompts from market data, calls Gemini, and parses structured
|
||||
JSON responses into validated TradeDecision objects.
|
||||
|
||||
Includes token efficiency optimizations:
|
||||
- Prompt compression and abbreviation
|
||||
- Response caching for common scenarios
|
||||
- Token usage tracking and metrics
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -15,6 +20,8 @@ from typing import Any
|
||||
from google import genai
|
||||
|
||||
from src.config import Settings
|
||||
from src.brain.cache import DecisionCache
|
||||
from src.brain.prompt_optimizer import PromptOptimizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -28,17 +35,35 @@ class TradeDecision:
|
||||
action: str # "BUY" | "SELL" | "HOLD"
|
||||
confidence: int # 0-100
|
||||
rationale: str
|
||||
token_count: int = 0 # Estimated tokens used
|
||||
cached: bool = False # Whether decision came from cache
|
||||
|
||||
|
||||
class GeminiClient:
|
||||
"""Wraps the Gemini API for trade decision-making."""
|
||||
|
||||
def __init__(
    self,
    settings: Settings,
    enable_cache: bool = True,
    enable_optimization: bool = True,
) -> None:
    """Set up the Gemini client plus the token-efficiency helpers.

    Args:
        settings: Application settings; CONFIDENCE_THRESHOLD, GEMINI_API_KEY
            and GEMINI_MODEL are read from it
        enable_cache: When True a DecisionCache with a 300-second TTL is
            created; otherwise no cache is used
        enable_optimization: Stored flag selecting the compressed prompt
            builder over the regular one
    """
    self._settings = settings
    self._confidence_threshold = settings.CONFIDENCE_THRESHOLD
    self._client = genai.Client(api_key=settings.GEMINI_API_KEY)
    self._model_name = settings.GEMINI_MODEL

    # Token efficiency features
    self._enable_cache = enable_cache
    self._enable_optimization = enable_optimization
    self._cache = DecisionCache(ttl_seconds=300) if enable_cache else None
    # The optimizer is always created: it is also used to estimate tokens
    # when prompt compression is switched off.
    self._optimizer = PromptOptimizer()

    # Token usage metrics
    self._total_tokens_used = 0
    self._total_decisions = 0
    self._total_cached_decisions = 0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Prompt Construction
|
||||
# ------------------------------------------------------------------
|
||||
@@ -154,26 +179,141 @@ class GeminiClient:
|
||||
|
||||
async def decide(self, market_data: dict[str, Any]) -> TradeDecision:
    """Build prompt, call Gemini, and return a parsed decision.

    A cached decision is returned (flagged ``cached=True`` with a token count
    of 0) when the cache holds one for this market data. Otherwise the prompt
    is built, its token count estimated, and Gemini is called.

    Args:
        market_data: Market data dictionary

    Returns:
        TradeDecision carrying the estimated token count. If the API call
        raises, a HOLD decision with confidence 0 is returned instead.
    """
    # Check cache first. Compare against None explicitly so that neither the
    # cache nor a cached decision is skipped merely for being falsy.
    if self._cache is not None:
        cached_decision = self._cache.get(market_data)
        if cached_decision is not None:
            self._total_cached_decisions += 1
            self._total_decisions += 1
            logger.info(
                "Cache hit for decision",
                extra={
                    "action": cached_decision.action,
                    "confidence": cached_decision.confidence,
                    "cache_hit_rate": self.get_cache_hit_rate(),
                },
            )
            # Return a copy flagged as cached; no tokens were spent on it.
            return TradeDecision(
                action=cached_decision.action,
                confidence=cached_decision.confidence,
                rationale=cached_decision.rationale,
                token_count=0,
                cached=True,
            )

    # Build optimized prompt
    if self._enable_optimization:
        prompt = self._optimizer.build_compressed_prompt(market_data)
    else:
        prompt = self.build_prompt(market_data)

    # Estimate tokens. They are added to the running total before the call,
    # so a failed request still counts towards token usage.
    token_count = self._optimizer.estimate_tokens(prompt)
    self._total_tokens_used += token_count

    logger.info(
        "Requesting trade decision from Gemini",
        extra={"estimated_tokens": token_count, "optimized": self._enable_optimization},
    )

    try:
        response = await self._client.aio.models.generate_content(
            model=self._model_name,
            contents=prompt,
        )
        raw = response.text
    except Exception as exc:
        # Deliberate best-effort boundary: any API failure becomes a HOLD.
        logger.error("Gemini API error: %s", exc)
        return TradeDecision(
            action="HOLD", confidence=0, rationale=f"API error: {exc}", token_count=token_count
        )

    decision = self.parse_response(raw)
    self._total_decisions += 1

    # Add token count to decision
    decision_with_tokens = TradeDecision(
        action=decision.action,
        confidence=decision.confidence,
        rationale=decision.rationale,
        token_count=token_count,
        cached=False,
    )

    # Cache if appropriate.
    # NOTE(review): if parse_response returns a fallback HOLD for malformed
    # output, that fallback would be cached too — confirm this is intended.
    if self._cache is not None and self._cache.should_cache_decision(decision):
        self._cache.set(market_data, decision)

    logger.info(
        "Gemini decision",
        extra={
            "action": decision.action,
            "confidence": decision.confidence,
            "tokens": token_count,
            "avg_tokens": self.get_avg_tokens_per_decision(),
        },
    )

    return decision_with_tokens
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Token Efficiency Metrics
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_token_metrics(self) -> dict[str, Any]:
    """Collect the token usage counters into one dictionary.

    Returns:
        Dictionary with token usage statistics; a ``cache_metrics`` entry is
        added when a cache is in use
    """
    report: dict[str, Any] = dict(
        total_tokens_used=self._total_tokens_used,
        total_decisions=self._total_decisions,
        total_cached_decisions=self._total_cached_decisions,
        avg_tokens_per_decision=self.get_avg_tokens_per_decision(),
        cache_hit_rate=self.get_cache_hit_rate(),
    )

    if self._cache:
        report["cache_metrics"] = self._cache.get_metrics().to_dict()

    return report
|
||||
|
||||
def get_avg_tokens_per_decision(self) -> float:
    """Return total tokens used divided by total decisions.

    Returns:
        Average tokens per decision, or 0.0 when no decision was recorded
    """
    decisions = self._total_decisions
    return self._total_tokens_used / decisions if decisions != 0 else 0.0
|
||||
|
||||
def get_cache_hit_rate(self) -> float:
    """Return the share of decisions that were served from the cache.

    Returns:
        Cache hit rate (0.0 to 1.0); 0.0 when no decision was recorded
    """
    decisions = self._total_decisions
    return self._total_cached_decisions / decisions if decisions != 0 else 0.0
|
||||
|
||||
def reset_metrics(self) -> None:
    """Zero the token usage counters and reset the cache's own metrics."""
    self._total_tokens_used = self._total_decisions = self._total_cached_decisions = 0
    if self._cache:
        self._cache.reset_metrics()
    logger.info("Token metrics reset")
|
||||
|
||||
def get_cache(self) -> DecisionCache | None:
    """Expose the decision cache used by this client.

    Returns:
        DecisionCache instance or None if caching disabled
    """
    return self._cache
|
||||
|
||||
268
src/brain/prompt_optimizer.py
Normal file
268
src/brain/prompt_optimizer.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""Prompt optimization utilities for reducing token usage.
|
||||
|
||||
This module provides tools to compress prompts while maintaining decision quality:
|
||||
- Token counting
|
||||
- Text compression and abbreviation
|
||||
- Template-based prompts with variable slots
|
||||
- Priority-based context truncation
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
# Abbreviation mapping for common terms
|
||||
ABBREVIATIONS = {
    # Market data vocabulary
    "price": "P",
    "volume": "V",
    "current": "cur",
    "previous": "prev",
    "change": "chg",
    "percentage": "pct",
    "market": "mkt",
    "orderbook": "ob",
    "foreigner": "fgn",
    # Actions
    "buy": "B",
    "sell": "S",
    "hold": "H",
    # Response fields
    "confidence": "conf",
    "rationale": "reason",
    "action": "act",
    # Maps to itself, so it is left unchanged
    "net": "net",
}

# Reverse mapping for decompression
REVERSE_ABBREVIATIONS = {short: full for full, short in ABBREVIATIONS.items()}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TokenMetrics:
    """Immutable size measurements for a piece of prompt text."""

    # Number of characters in the text.
    char_count: int
    # Number of whitespace-separated words.
    word_count: int
    # Rough estimate: ~4 chars per token
    estimated_tokens: int
    # Original / Compressed
    compression_ratio: float = 1.0
|
||||
|
||||
|
||||
class PromptOptimizer:
|
||||
"""Optimizes prompts to reduce token usage while maintaining quality."""
|
||||
|
||||
@staticmethod
|
||||
def estimate_tokens(text: str) -> int:
|
||||
"""Estimate token count for text.
|
||||
|
||||
Uses a simple heuristic: ~4 characters per token for English.
|
||||
This is approximate but sufficient for optimization purposes.
|
||||
|
||||
Args:
|
||||
text: Input text to estimate tokens for
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
# Simple estimate: 1 token ≈ 4 characters
|
||||
return max(1, len(text) // 4)
|
||||
|
||||
@staticmethod
def count_tokens(text: str) -> TokenMetrics:
    """Measure *text* by characters, words and estimated tokens.

    Args:
        text: Input text to analyze

    Returns:
        TokenMetrics with character, word, and estimated token counts
    """
    return TokenMetrics(
        char_count=len(text),
        # Words are runs of non-whitespace characters.
        word_count=len(text.split()),
        estimated_tokens=PromptOptimizer.estimate_tokens(text),
    )
|
||||
|
||||
@staticmethod
|
||||
def compress_json(data: dict[str, Any]) -> str:
|
||||
"""Compress JSON by removing whitespace.
|
||||
|
||||
Args:
|
||||
data: Dictionary to serialize
|
||||
|
||||
Returns:
|
||||
Compact JSON string without whitespace
|
||||
"""
|
||||
return json.dumps(data, separators=(",", ":"), ensure_ascii=False)
|
||||
|
||||
@staticmethod
def abbreviate_text(text: str, aggressive: bool = False) -> str:
    """Shorten *text* by swapping known terms for their abbreviations.

    Args:
        text: Input text to abbreviate
        aggressive: If True, additionally drop articles and "to be" forms
            and collapse runs of whitespace

    Returns:
        Abbreviated text with leading and trailing whitespace stripped
    """
    shortened = text

    # Whole-word, case-insensitive replacement of every known term.
    for term, short_form in ABBREVIATIONS.items():
        whole_word = r"\b" + re.escape(term) + r"\b"
        shortened = re.sub(whole_word, short_form, shortened, flags=re.IGNORECASE)

    if aggressive:
        # Remove articles and filler words
        for filler in (r"\b(a|an|the)\b", r"\b(is|are|was|were)\b"):
            shortened = re.sub(filler, "", shortened, flags=re.IGNORECASE)
        # Removing words leaves gaps behind; squeeze them to single spaces.
        shortened = re.sub(r"\s+", " ", shortened)

    return shortened.strip()
|
||||
|
||||
@staticmethod
|
||||
def build_compressed_prompt(
|
||||
market_data: dict[str, Any],
|
||||
include_instructions: bool = True,
|
||||
max_length: int | None = None,
|
||||
) -> str:
|
||||
"""Build a compressed prompt from market data.
|
||||
|
||||
Args:
|
||||
market_data: Market data dictionary with stock info
|
||||
include_instructions: Whether to include full instructions
|
||||
max_length: Maximum character length (truncates if needed)
|
||||
|
||||
Returns:
|
||||
Compressed prompt string
|
||||
"""
|
||||
# Abbreviated market name
|
||||
market_name = market_data.get("market_name", "KR")
|
||||
if "Korea" in market_name:
|
||||
market_name = "KR"
|
||||
elif "United States" in market_name or "US" in market_name:
|
||||
market_name = "US"
|
||||
|
||||
# Core data - always included
|
||||
core_info = {
|
||||
"mkt": market_name,
|
||||
"code": market_data["stock_code"],
|
||||
"P": market_data["current_price"],
|
||||
}
|
||||
|
||||
# Optional fields
|
||||
if "orderbook" in market_data and market_data["orderbook"]:
|
||||
ob = market_data["orderbook"]
|
||||
# Compress orderbook: keep only top 3 levels
|
||||
compressed_ob = {
|
||||
"bid": ob.get("bid", [])[:3],
|
||||
"ask": ob.get("ask", [])[:3],
|
||||
}
|
||||
core_info["ob"] = compressed_ob
|
||||
|
||||
if market_data.get("foreigner_net", 0) != 0:
|
||||
core_info["fgn_net"] = market_data["foreigner_net"]
|
||||
|
||||
# Compress to JSON
|
||||
data_str = PromptOptimizer.compress_json(core_info)
|
||||
|
||||
if include_instructions:
|
||||
# Minimal instructions
|
||||
prompt = (
|
||||
f"{market_name} trader. Analyze:\n{data_str}\n\n"
|
||||
'Return JSON: {"act":"BUY"|"SELL"|"HOLD","conf":<0-100>,"reason":"<text>"}\n'
|
||||
"Rules: act=BUY/SELL/HOLD, conf=0-100, reason=concise. No markdown."
|
||||
)
|
||||
else:
|
||||
# Data only (for cached contexts where instructions are known)
|
||||
prompt = data_str
|
||||
|
||||
# Truncate if needed
|
||||
if max_length and len(prompt) > max_length:
|
||||
prompt = prompt[:max_length] + "..."
|
||||
|
||||
return prompt
|
||||
|
||||
@staticmethod
def truncate_context(
    context: dict[str, Any],
    max_tokens: int,
    priority_keys: list[str] | None = None,
) -> dict[str, Any]:
    """Truncate context data to fit within a token budget.

    Entries are considered in a fixed order: ``priority_keys`` first (in the
    order given), then the remaining keys in ``context`` order. An entry is
    kept when its JSON-encoded value still fits into the remaining budget.
    An entry that does not fit is skipped and the scan continues, so a single
    oversized value does not discard every smaller entry that follows it.

    Only values count against the budget (keys do not); sizes come from
    ``PromptOptimizer.estimate_tokens``.

    Args:
        context: Context dictionary to truncate
        max_tokens: Maximum token budget
        priority_keys: Keys to try first, in order of priority. Keys missing
            from ``context`` are ignored; repeated keys are considered once.

    Returns:
        New dictionary with the entries that fit, priority keys first
    """
    if not context:
        return {}

    # dict.fromkeys de-duplicates while preserving first-seen order, which
    # yields "priority keys first, then everything else" in a single pass.
    ordered_keys = dict.fromkeys(
        [
            *(key for key in (priority_keys or []) if key in context),
            *context,
        ]
    )

    result: dict[str, Any] = {}
    used_tokens = 0

    for key in ordered_keys:
        value = context[key]
        cost = PromptOptimizer.estimate_tokens(json.dumps(value))

        if used_tokens + cost > max_tokens:
            # Too large for what is left; later entries may still fit.
            continue

        result[key] = value
        used_tokens += cost

    return result
|
||||
|
||||
@staticmethod
def calculate_compression_ratio(original: str, compressed: str) -> float:
    """Return how many times smaller the compressed text is.

    Both texts are measured with ``PromptOptimizer.estimate_tokens``.

    Args:
        original: Text before compression
        compressed: Text after compression

    Returns:
        ``original_tokens / compressed_tokens``, or 1.0 when the compressed
        text is estimated at zero tokens (guards the division)
    """
    tokens_before = PromptOptimizer.estimate_tokens(original)
    tokens_after = PromptOptimizer.estimate_tokens(compressed)

    return tokens_before / tokens_after if tokens_after != 0 else 1.0
|
||||
331
src/context/summarizer.py
Normal file
331
src/context/summarizer.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""Context summarization for efficient historical data representation.
|
||||
|
||||
This module summarizes old context data instead of including raw details:
|
||||
- Key metrics only (averages, trends, not details)
|
||||
- Rolling window (keep last N days detailed, summarize older)
|
||||
- Aggregate historical data efficiently
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from src.context.layer import ContextLayer
|
||||
from src.context.store import ContextStore
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SummaryStats:
|
||||
"""Statistical summary of historical data."""
|
||||
|
||||
count: int
|
||||
mean: float | None = None
|
||||
min: float | None = None
|
||||
max: float | None = None
|
||||
std: float | None = None
|
||||
trend: str | None = None # "up", "down", "flat"
|
||||
|
||||
|
||||
class ContextSummarizer:
    """Summarizes historical context data to reduce token usage.

    Context is read through ``store.get_all_contexts(layer)`` and handled as
    a mapping of key to value. Numeric values (top-level, or one level deep
    inside a dict value) are reduced to compact statistics.
    """

    def __init__(self, store: ContextStore) -> None:
        """Initialize the context summarizer.

        Args:
            store: ContextStore instance for retrieving context data
        """
        self.store = store

    def summarize_numeric_values(self, values: list[float]) -> SummaryStats:
        """Summarize a list of numeric values.

        The trend compares the average of the first third of the series with
        the average of the last third; a difference of more than 5% of the
        first average (or 0.01 when that average is zero) counts as a move.
        Series shorter than three values are always reported as "flat".

        Args:
            values: List of numeric values to summarize, in chronological order

        Returns:
            SummaryStats with mean, min, max, std (sample standard deviation,
            0.0 for a single value) rounded to 4 decimals, and the trend.
            For an empty list only ``count=0`` is set.
        """
        if not values:
            return SummaryStats(count=0)

        count = len(values)
        mean = sum(values) / count

        # Sample standard deviation (n - 1); undefined for one value, so 0.0.
        if count > 1:
            variance = sum((x - mean) ** 2 for x in values) / (count - 1)
            std = variance**0.5
        else:
            std = 0.0

        trend = "flat"
        if count >= 3:
            third = count // 3  # >= 1 because count >= 3
            first_avg = sum(values[:third]) / third
            last_avg = sum(values[-third:]) / third

            # Trend threshold: 5% change, with a fixed floor for a zero base
            threshold = 0.05 * abs(first_avg) if first_avg != 0 else 0.01

            if last_avg > first_avg + threshold:
                trend = "up"
            elif last_avg < first_avg - threshold:
                trend = "down"

        return SummaryStats(
            count=count,
            mean=round(mean, 4),
            min=round(min(values), 4),
            max=round(max(values), 4),
            std=round(std, 4),
            trend=trend,
        )

    def summarize_layer(
        self,
        layer: ContextLayer,
        start_date: datetime | None = None,
        end_date: datetime | None = None,
    ) -> dict[str, Any]:
        """Summarize all context data stored for a layer.

        Numeric values are grouped by key (``"key.subkey"`` for numbers found
        one level inside a dict value) and reduced with
        ``summarize_numeric_values``. String values are only counted.

        Args:
            layer: Context layer to summarize
            start_date: Accepted for API compatibility. Not applied: the
                key/value mapping read from the store carries no timestamps
                to filter on.
            end_date: Accepted for API compatibility. Not applied (see
                ``start_date``).

        Returns:
            Dictionary with one ``{"count", "avg", "range", "trend"}`` entry
            per numeric key, a ``"<key>_count"`` entry per string key, and
            ``"total_entries"``. When the layer has no data the result is
            ``{"summary": "No data available", "count": 0}``.
        """
        all_contexts = self.store.get_all_contexts(layer)

        if not all_contexts:
            return {"summary": "No data available", "count": 0}

        numeric_data: dict[str, list[float]] = {}
        text_counts: dict[str, int] = {}

        for key, value in all_contexts.items():
            if isinstance(value, (int, float)):
                numeric_data.setdefault(key, []).append(float(value))
            elif isinstance(value, dict):
                # Only the first nesting level is inspected.
                for subkey, subvalue in value.items():
                    if isinstance(subvalue, (int, float)):
                        numeric_data.setdefault(f"{key}.{subkey}", []).append(float(subvalue))
            elif isinstance(value, str):
                text_counts[key] = text_counts.get(key, 0) + 1

        summary: dict[str, Any] = {}

        for key, values in numeric_data.items():
            stats = self.summarize_numeric_values(values)
            summary[key] = {
                "count": stats.count,
                "avg": stats.mean,
                "range": [stats.min, stats.max],
                "trend": stats.trend,
            }

        # Text data is reduced to occurrence counts only
        for key, occurrences in text_counts.items():
            summary[f"{key}_count"] = occurrences

        summary["total_entries"] = len(all_contexts)

        return summary

    def rolling_window_summary(
        self,
        layer: ContextLayer,
        window_days: int = 30,
        summarize_older: bool = True,
    ) -> dict[str, Any]:
        """Create a rolling window summary.

        Intended contract: data inside the window stays detailed, older data
        is reduced to key metrics. This is a simplified implementation: the
        key/value mapping read from the store has no timestamps, so every
        top-level numeric value is treated as recent and
        ``historical_summary`` is currently always empty.

        Args:
            layer: Context layer to summarize
            window_days: Number of days to keep detailed (echoed in the
                result; not yet used for filtering)
            summarize_older: Whether to summarize data older than window

        Returns:
            Dictionary with ``window_days``, ``recent_data`` (up to the last
            10 values per key) and ``historical_summary``
        """
        recent_values: dict[str, list[float]] = {}
        # TODO: split entries at (now - window_days) once timestamps are
        # available; until then nothing is classified as historical.
        historical_values: dict[str, list[float]] = {}

        for key, value in self.store.get_all_contexts(layer).items():
            if isinstance(value, (int, float)):
                recent_values.setdefault(key, []).append(float(value))

        result: dict[str, Any] = {
            "window_days": window_days,
            "recent_data": {key: values[-10:] for key, values in recent_values.items()},
            "historical_summary": {},
        }

        if summarize_older:
            for key, values in historical_values.items():
                stats = self.summarize_numeric_values(values)
                result["historical_summary"][key] = {
                    "count": stats.count,
                    "avg": stats.mean,
                    "trend": stats.trend,
                }

        return result

    def aggregate_to_higher_layer(
        self,
        source_layer: ContextLayer,
        target_layer: ContextLayer,
        metric_key: str,
        aggregation_func: str = "mean",
    ) -> float | None:
        """Aggregate one metric across a source layer.

        A value contributes when its key equals ``metric_key`` and it is a
        number, or when it is a dict holding a numeric ``metric_key`` entry.
        The aggregate is only returned; nothing is written to the store.

        Args:
            source_layer: Source context layer (more granular)
            target_layer: Target context layer (less granular); currently not
                used by the computation
            metric_key: Key of metric to aggregate
            aggregation_func: Aggregation function ("mean", "sum", "max",
                "min"); any other name falls back to "mean"

        Returns:
            Aggregated value, or None if no data available
        """
        source_contexts = self.store.get_all_contexts(source_layer)

        values: list[float] = []
        for key, value in source_contexts.items():
            if key == metric_key and isinstance(value, (int, float)):
                values.append(float(value))
            elif isinstance(value, dict) and metric_key in value:
                subvalue = value[metric_key]
                if isinstance(subvalue, (int, float)):
                    values.append(float(subvalue))

        if not values:
            return None

        reducers = {"sum": sum, "max": max, "min": min}
        reducer = reducers.get(aggregation_func)
        if reducer is None:
            # "mean" and every unrecognised name: default to the mean
            return sum(values) / len(values)
        return reducer(values)

    def create_compact_summary(
        self,
        layers: list[ContextLayer],
        top_n_metrics: int = 5,
    ) -> dict[str, Any]:
        """Create a compact summary across multiple layers.

        Args:
            layers: List of context layers to summarize
            top_n_metrics: Number of top metrics to include per layer, ranked
                by their ``count`` (ties keep summary order). Values below
                zero are treated as zero.

        Returns:
            Dictionary keyed by ``layer.value`` holding the selected metric
            summaries for that layer
        """
        # A negative slice bound would silently drop entries from the end
        # instead of selecting none, so clamp it.
        limit = max(top_n_metrics, 0)
        summary: dict[str, Any] = {}

        for layer in layers:
            layer_summary = self.summarize_layer(layer)

            # Only per-metric dicts qualify; scalar bookkeeping entries such
            # as "total_entries" are left out.
            metrics = [
                (key, value)
                for key, value in layer_summary.items()
                if isinstance(value, dict) and "count" in value
            ]
            metrics.sort(key=lambda item: item[1]["count"], reverse=True)

            summary[layer.value] = dict(metrics[:limit])

        return summary

    def format_summary_for_prompt(self, summary: dict[str, Any]) -> str:
        """Format summary for inclusion in a prompt.

        Each non-empty section becomes a ``"<name>:"`` header followed by one
        line per metric. Dict metrics are rendered as ``avg=<2 decimals>``
        and/or ``trend=<value>`` (and omitted when neither is present); other
        metric values are printed as-is. A top-level entry that is not a
        dict (for example a ``total_entries`` counter) is rendered on a
        single ``"<name>: <value>"`` line instead of raising.

        Args:
            summary: Summary dictionary

        Returns:
            Formatted string for prompt
        """
        lines: list[str] = []

        for layer, metrics in summary.items():
            if not metrics:
                continue

            if not isinstance(metrics, dict):
                lines.append(f"{layer}: {metrics}")
                continue

            lines.append(f"{layer}:")
            for key, value in metrics.items():
                if not isinstance(value, dict):
                    lines.append(f" {key}: {value}")
                    continue

                # Format as: key: avg=X, trend=Y
                parts = []
                if "avg" in value and value["avg"] is not None:
                    parts.append(f"avg={value['avg']:.2f}")
                if "trend" in value and value["trend"]:
                    parts.append(f"trend={value['trend']}")
                if parts:
                    lines.append(f" {key}: {', '.join(parts)}")

        return "\n".join(lines)
|
||||
665
tests/test_token_efficiency.py
Normal file
665
tests/test_token_efficiency.py
Normal file
@@ -0,0 +1,665 @@
|
||||
"""Tests for token efficiency optimization components.
|
||||
|
||||
Tests cover:
|
||||
- Prompt compression and optimization
|
||||
- Context selection logic
|
||||
- Summarization
|
||||
- Caching
|
||||
- Token reduction metrics
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import time
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from src.brain.cache import CacheMetrics, DecisionCache
|
||||
from src.brain.context_selector import ContextSelector, DecisionType
|
||||
from src.brain.gemini_client import TradeDecision
|
||||
from src.brain.prompt_optimizer import PromptOptimizer, TokenMetrics
|
||||
from src.context.layer import ContextLayer
|
||||
from src.context.store import ContextStore
|
||||
from src.context.summarizer import ContextSummarizer, SummaryStats
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Prompt Optimizer Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestPromptOptimizer:
    """Unit tests covering the PromptOptimizer helpers."""

    def test_estimate_tokens(self):
        """Estimates are zero for empty text and len // 4 otherwise."""
        opt = PromptOptimizer()

        assert opt.estimate_tokens("") == 0  # empty text
        assert opt.estimate_tokens("test") == 1  # four characters -> one token

        sample = "This is a longer piece of text for testing token estimation."
        estimate = opt.estimate_tokens(sample)
        assert estimate > 0
        assert estimate == len(sample) // 4

    def test_count_tokens(self):
        """count_tokens reports character, word and token figures."""
        opt = PromptOptimizer()
        sentence = "Hello world, this is a test."

        result = opt.count_tokens(sentence)

        assert isinstance(result, TokenMetrics)
        assert result.char_count == len(sentence)
        assert result.word_count == 6
        assert result.estimated_tokens > 0

    def test_compress_json(self):
        """compress_json emits separator-tight JSON that round-trips."""
        import json

        opt = PromptOptimizer()
        payload = {
            "action": "BUY",
            "confidence": 85,
            "rationale": "Strong uptrend",
        }

        packed = opt.compress_json(payload)

        # Values may still contain spaces ("Strong uptrend"); separators and
        # line breaks must not add any.
        for forbidden in ("\n", ": ", ", "):
            assert forbidden not in packed

        # The compact form must still decode to the original data
        assert json.loads(packed) == payload

    def test_abbreviate_text(self):
        """Abbreviation introduces short forms and never lengthens text."""
        opt = PromptOptimizer()
        sentence = "The current price is high and volume is increasing."

        shortened = opt.abbreviate_text(sentence)

        assert any(marker in shortened for marker in ("cur", "P"))
        assert len(shortened) <= len(sentence)

    def test_abbreviate_text_aggressive(self):
        """Aggressive mode shortens the text and strips articles."""
        opt = PromptOptimizer()
        sentence = "The price is increasing and the volume is high."

        shortened = opt.abbreviate_text(sentence, aggressive=True)

        assert len(shortened) < len(sentence)
        assert "the" not in shortened.lower()

    def test_build_compressed_prompt(self):
        """The compressed prompt stays short and keeps the core data."""
        opt = PromptOptimizer()
        quote = {
            "stock_code": "005930",
            "current_price": 75000,
            "market_name": "Korean stock market",
        }

        prompt = opt.build_compressed_prompt(quote)

        assert len(prompt) < 300
        for expected in ("005930", "75000"):
            assert expected in prompt

    def test_build_compressed_prompt_no_instructions(self):
        """Without instructions only the compact data is emitted."""
        opt = PromptOptimizer()
        quote = {
            "stock_code": "AAPL",
            "current_price": 150.5,
            "market_name": "United States",
        }

        prompt = opt.build_compressed_prompt(quote, include_instructions=False)

        assert len(prompt) < 100
        assert "AAPL" in prompt

    def test_truncate_context(self):
        """Truncation under a tight budget never adds keys."""
        opt = PromptOptimizer()
        full_context = {
            "price": 100.5,
            "volume": 1000000,
            "sentiment": 0.8,
            "extra_data": "Some long text that should be truncated",
        }

        reduced = opt.truncate_context(full_context, max_tokens=10)

        assert len(reduced) <= len(full_context)

    def test_truncate_context_with_priority(self):
        """Priority keys survive truncation."""
        opt = PromptOptimizer()
        full_context = {
            "price": 100.5,
            "volume": 1000000,
            "sentiment": 0.8,
            "extra_data": "Some data",
        }
        must_keep = ["price", "sentiment"]

        reduced = opt.truncate_context(full_context, max_tokens=20, priority_keys=must_keep)

        assert "price" in reduced
        assert "sentiment" in reduced

    def test_calculate_compression_ratio(self):
        """A shorter compressed text yields a ratio above one."""
        opt = PromptOptimizer()
        long_text = "This is a very long piece of text that should be compressed significantly."
        short_text = "Short text"

        ratio = opt.calculate_compression_ratio(long_text, short_text)

        assert ratio > 1.0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Context Selector Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestContextSelector:
    """Unit tests covering ContextSelector."""

    @pytest.fixture
    def store(self):
        """Provide a ContextStore backed by a fresh in-memory SQLite database."""
        schema = (
            """
            CREATE TABLE context_metadata (
                layer TEXT PRIMARY KEY,
                description TEXT,
                retention_days INTEGER,
                aggregation_source TEXT
            )
            """,
            """
            CREATE TABLE contexts (
                layer TEXT,
                timeframe TEXT,
                key TEXT,
                value TEXT,
                created_at TEXT,
                updated_at TEXT,
                PRIMARY KEY (layer, timeframe, key)
            )
            """,
        )
        db = sqlite3.connect(":memory:")
        for ddl in schema:
            db.execute(ddl)
        db.commit()
        return ContextStore(db)

    def test_select_layers_normal(self, store):
        """Normal decisions use the real-time layer only."""
        chosen = ContextSelector(store).select_layers(DecisionType.NORMAL)

        assert chosen == [ContextLayer.L7_REALTIME]

    def test_select_layers_strategic(self, store):
        """Strategic decisions use exactly L7, L6 and L5."""
        chosen = ContextSelector(store).select_layers(DecisionType.STRATEGIC)

        for expected in (
            ContextLayer.L7_REALTIME,
            ContextLayer.L6_DAILY,
            ContextLayer.L5_WEEKLY,
        ):
            assert expected in chosen
        assert len(chosen) == 3

    def test_select_layers_major_event(self, store):
        """Major events pull in all seven layers."""
        chosen = ContextSelector(store).select_layers(DecisionType.MAJOR_EVENT)

        assert len(chosen) == 7
        assert ContextLayer.L1_LEGACY in chosen
        assert ContextLayer.L7_REALTIME in chosen

    def test_score_layer_relevance(self, store):
        """Relevance scores depend on both layer and decision type."""
        selector = ContextSelector(store)

        # Populate both layers so the scores are not penalized
        store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
        store.set_context(ContextLayer.L1_LEGACY, "legacy", "lesson", "test")

        # L7 is fully relevant for normal decisions
        realtime_normal = selector.score_layer_relevance(
            ContextLayer.L7_REALTIME, DecisionType.NORMAL
        )
        assert realtime_normal == 1.0

        # L1 is irrelevant for normal decisions
        legacy_normal = selector.score_layer_relevance(
            ContextLayer.L1_LEGACY, DecisionType.NORMAL
        )
        assert legacy_normal == 0.0

        # L1 is fully relevant for major events
        legacy_major = selector.score_layer_relevance(
            ContextLayer.L1_LEGACY, DecisionType.MAJOR_EVENT
        )
        assert legacy_major == 1.0

    def test_select_with_scoring(self, store):
        """Scored selection keeps only layers at or above the threshold."""
        selector = ContextSelector(store)

        # Populate L7 so it is not penalized
        store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)

        picked = selector.select_with_scoring(DecisionType.NORMAL, min_score=0.5)

        assert len(picked.layers) >= 1
        assert ContextLayer.L7_REALTIME in picked.layers
        assert all(picked.relevance_scores[layer] >= 0.5 for layer in picked.layers)

    def test_get_context_data(self, store):
        """Stored values come back grouped under the layer name."""
        selector = ContextSelector(store)

        store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
        store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "volume", 1000000)

        fetched = selector.get_context_data([ContextLayer.L7_REALTIME])

        assert "L7_REALTIME" in fetched
        assert "price" in fetched["L7_REALTIME"]
        assert fetched["L7_REALTIME"]["price"] == 100.5

    def test_estimate_context_tokens(self, store):
        """A non-empty context yields a positive token estimate."""
        selector = ContextSelector(store)
        sample = {
            "L7_REALTIME": {"price": 100.5, "volume": 1000000},
            "L6_DAILY": {"avg_price": 99.8, "avg_volume": 950000},
        }

        assert selector.estimate_context_tokens(sample) > 0

    def test_optimize_context_for_budget(self, store):
        """The optimized context respects the requested token budget."""
        selector = ContextSelector(store)
        budget = 50

        store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)

        optimized = selector.optimize_context_for_budget(DecisionType.NORMAL, max_tokens=budget)

        assert selector.estimate_context_tokens(optimized) <= budget
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Context Summarizer Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestContextSummarizer:
    """Unit tests covering ContextSummarizer."""

    @pytest.fixture
    def store(self):
        """Provide a ContextStore backed by a fresh in-memory SQLite database."""
        schema = (
            """
            CREATE TABLE context_metadata (
                layer TEXT PRIMARY KEY,
                description TEXT,
                retention_days INTEGER,
                aggregation_source TEXT
            )
            """,
            """
            CREATE TABLE contexts (
                layer TEXT,
                timeframe TEXT,
                key TEXT,
                value TEXT,
                created_at TEXT,
                updated_at TEXT,
                PRIMARY KEY (layer, timeframe, key)
            )
            """,
        )
        db = sqlite3.connect(":memory:")
        for ddl in schema:
            db.execute(ddl)
        db.commit()
        return ContextStore(db)

    def test_summarize_numeric_values(self, store):
        """Basic statistics are computed for a simple series."""
        series = [10.0, 20.0, 30.0, 40.0, 50.0]

        result = ContextSummarizer(store).summarize_numeric_values(series)

        assert isinstance(result, SummaryStats)
        assert result.count == 5
        assert result.mean == 30.0
        assert result.min == 10.0
        assert result.max == 50.0
        assert result.std is not None

    def test_summarize_numeric_values_trend(self, store):
        """Rising, falling and steady series get the matching trend label."""
        summarizer = ContextSummarizer(store)

        rising = summarizer.summarize_numeric_values([10.0, 15.0, 20.0, 25.0, 30.0, 35.0])
        assert rising.trend == "up"

        falling = summarizer.summarize_numeric_values([35.0, 30.0, 25.0, 20.0, 15.0, 10.0])
        assert falling.trend == "down"

        steady = summarizer.summarize_numeric_values([20.0, 20.1, 19.9, 20.0, 20.1, 19.9])
        assert steady.trend == "flat"

    def test_summarize_layer(self, store):
        """A populated layer reports a positive entry total."""
        summarizer = ContextSummarizer(store)

        store.set_context(ContextLayer.L6_DAILY, "2026-02-04", "price", 100.5)
        store.set_context(ContextLayer.L6_DAILY, "2026-02-04", "volume", 1000000)

        layer_summary = summarizer.summarize_layer(ContextLayer.L6_DAILY)

        assert "total_entries" in layer_summary
        assert layer_summary["total_entries"] > 0

    def test_create_compact_summary(self, store):
        """The compact summary contains a section per requested layer."""
        summarizer = ContextSummarizer(store)

        store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)

        requested = [ContextLayer.L7_REALTIME, ContextLayer.L6_DAILY]
        compact = summarizer.create_compact_summary(requested, top_n_metrics=3)

        assert "L7_REALTIME" in compact

    def test_format_summary_for_prompt(self, store):
        """Formatting yields text naming the layer and its average."""
        summarizer = ContextSummarizer(store)
        sections = {
            "L7_REALTIME": {
                "price": {"avg": 100.5, "trend": "up"},
                "volume": {"avg": 1000000, "trend": "flat"},
            }
        }

        rendered = summarizer.format_summary_for_prompt(sections)

        assert isinstance(rendered, str)
        assert "L7_REALTIME" in rendered
        assert any(figure in rendered for figure in ("100.5", "100.50"))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Decision Cache Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestDecisionCache:
|
||||
"""Tests for DecisionCache."""
|
||||
|
||||
def test_cache_init(self):
    """Constructor arguments are exposed as attributes."""
    settings = {"ttl_seconds": 60, "max_size": 100}

    cache = DecisionCache(**settings)

    assert cache.ttl_seconds == 60
    assert cache.max_size == 100
|
||||
|
||||
def test_cache_miss(self):
    """Looking up an unknown key returns None and records one miss."""
    cache = DecisionCache()
    snapshot = {"stock_code": "005930", "current_price": 75000}

    # Nothing was stored, so this lookup must miss
    assert cache.get(snapshot) is None

    stats = cache.get_metrics()
    assert stats.cache_misses == 1
    assert stats.cache_hits == 0
|
||||
|
||||
def test_cache_hit(self):
    """A stored decision is returned on lookup and counted as a hit."""
    cache = DecisionCache()
    snapshot = {"stock_code": "005930", "current_price": 75000}
    stored = TradeDecision(action="HOLD", confidence=50, rationale="Test")

    cache.set(snapshot, stored)
    fetched = cache.get(snapshot)

    assert fetched is not None
    assert fetched.action == "HOLD"
    assert fetched.confidence == 50

    assert cache.get_metrics().cache_hits == 1
|
||||
|
||||
def test_cache_ttl_expiration(self):
    """Entries are served within the TTL and dropped once it has elapsed."""
    cache = DecisionCache(ttl_seconds=1)  # one-second TTL
    snapshot = {"stock_code": "005930", "current_price": 75000}
    stored = TradeDecision(action="HOLD", confidence=50, rationale="Test")

    cache.set(snapshot, stored)

    # Fresh entry: served from cache
    assert cache.get(snapshot) is not None

    # Let the TTL run out
    time.sleep(1.1)

    # Stale entry: no longer served
    assert cache.get(snapshot) is None
|
||||
|
||||
def test_cache_max_size(self):
    """Storing more entries than max_size evicts the surplus."""
    cache = DecisionCache(max_size=2)
    stored = TradeDecision(action="HOLD", confidence=50, rationale="Test")

    # Three distinct entries: one more than the cache may hold
    snapshots = [{"stock_code": f"00{i}", "current_price": 1000 * i} for i in range(3)]
    for snapshot in snapshots:
        cache.set(snapshot, stored)

    stats = cache.get_metrics()

    # Exactly one entry had to go
    assert stats.total_entries == 2
    assert stats.evictions == 1
|
||||
|
||||
def test_invalidate_all(self):
    """invalidate() without arguments clears the cache and returns the count."""
    cache = DecisionCache()
    stored = TradeDecision(action="HOLD", confidence=50, rationale="Test")

    snapshots = [{"stock_code": f"00{i}", "current_price": 1000} for i in range(3)]
    for snapshot in snapshots:
        cache.set(snapshot, stored)

    removed = cache.invalidate()

    assert removed == 3
    assert cache.get_metrics().total_entries == 0
|
||||
|
||||
def test_invalidate_by_stock(self):
    """Invalidating one stock code leaves other stocks cached."""
    cache = DecisionCache()
    stored = TradeDecision(action="HOLD", confidence=50, rationale="Test")
    target = {"stock_code": "005930", "current_price": 75000}
    bystander = {"stock_code": "000660", "current_price": 50000}

    cache.set(target, stored)
    cache.set(bystander, stored)

    removed = cache.invalidate("005930")
    assert removed >= 1

    # The unrelated stock must still be served from cache
    assert cache.get({"stock_code": "000660", "current_price": 50000}) is not None
|
||||
|
||||
def test_cleanup_expired(self):
    """Verify that ``cleanup_expired()`` purges stale entries and counts them."""
    short_lived = DecisionCache(ttl_seconds=1)
    hold = TradeDecision(action="HOLD", confidence=50, rationale="Test")

    short_lived.set({"stock_code": "005930", "current_price": 75000}, hold)

    # Sleep slightly past the one-second TTL so the single entry goes stale.
    time.sleep(1.1)

    purged = short_lived.cleanup_expired()
    assert purged == 1

    # The cache must be empty after the only entry was purged.
    stats = short_lived.get_metrics()
    assert stats.total_entries == 0
|
||||
|
||||
def test_should_cache_decision(self):
    """Verify which decisions ``should_cache_decision`` accepts for caching."""
    cache = DecisionCache()

    # (action, confidence, expected verdict):
    #   HOLD is cacheable, a high-confidence BUY is cacheable,
    #   a low-confidence BUY is not.
    scenarios = (
        ("HOLD", 50, True),
        ("BUY", 95, True),
        ("BUY", 60, False),
    )

    for action, confidence, expected in scenarios:
        candidate = TradeDecision(
            action=action, confidence=confidence, rationale="Test"
        )
        # Identity check: the method must return a real bool, not a truthy value.
        assert cache.should_cache_decision(candidate) is expected
|
||||
|
||||
def test_cache_hit_rate(self):
    """Verify hit-rate bookkeeping: one miss then two hits gives 2/3."""
    tracker = DecisionCache()
    hold = TradeDecision(action="HOLD", confidence=50, rationale="Test")
    snapshot = {"stock_code": "005930", "current_price": 75000}

    # Lookup before anything is stored: counted as a miss.
    tracker.get(snapshot)

    # Store the decision so subsequent lookups can be served from the cache.
    tracker.set(snapshot, hold)

    # Two lookups after storing: both counted as hits.
    for _ in range(2):
        tracker.get(snapshot)

    stats = tracker.get_metrics()

    # Three requests in total, split into one miss and two hits.
    assert stats.total_requests == 3
    assert stats.cache_hits == 2
    assert stats.cache_misses == 1
    assert stats.hit_rate == pytest.approx(2 / 3)
|
||||
|
||||
def test_reset_metrics(self):
    """Verify that ``reset_metrics()`` zeroes request, hit and miss counters."""
    tracker = DecisionCache()
    snapshot = {"stock_code": "005930", "current_price": 75000}

    # Two lookups with nothing stored, so the counters are non-zero before reset.
    for _ in range(2):
        tracker.get(snapshot)

    tracker.reset_metrics()

    stats = tracker.get_metrics()
    assert stats.total_requests == 0
    assert stats.cache_hits == 0
    assert stats.cache_misses == 0
|
||||
Reference in New Issue
Block a user