Compare commits
7 Commits
feature/is
...
feature/is
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
61f5aaf4a3 | ||
|
|
4f61d5af8e | ||
| f40f19e735 | |||
|
|
ce952d97b2 | ||
| 53d3637b3e | |||
|
|
ae7195c829 | ||
| ad1f17bb56 |
@@ -8,6 +8,7 @@ dependencies = [
|
|||||||
"pydantic>=2.5,<3",
|
"pydantic>=2.5,<3",
|
||||||
"pydantic-settings>=2.1,<3",
|
"pydantic-settings>=2.1,<3",
|
||||||
"google-genai>=1.0,<2",
|
"google-genai>=1.0,<2",
|
||||||
|
"scipy>=1.11,<2",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|||||||
293
src/brain/cache.py
Normal file
293
src/brain/cache.py
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
"""Response caching system for reducing redundant LLM calls.
|
||||||
|
|
||||||
|
This module provides caching for common trading scenarios:
|
||||||
|
- TTL-based cache invalidation
|
||||||
|
- Cache key based on market conditions
|
||||||
|
- Cache hit rate monitoring
|
||||||
|
- Special handling for HOLD decisions in quiet markets
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.brain.gemini_client import TradeDecision
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CacheEntry:
|
||||||
|
"""Cached decision with metadata."""
|
||||||
|
|
||||||
|
decision: TradeDecision
|
||||||
|
cached_at: float # Unix timestamp
|
||||||
|
hit_count: int = 0
|
||||||
|
market_data_hash: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CacheMetrics:
|
||||||
|
"""Metrics for cache performance monitoring."""
|
||||||
|
|
||||||
|
total_requests: int = 0
|
||||||
|
cache_hits: int = 0
|
||||||
|
cache_misses: int = 0
|
||||||
|
evictions: int = 0
|
||||||
|
total_entries: int = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hit_rate(self) -> float:
|
||||||
|
"""Calculate cache hit rate."""
|
||||||
|
if self.total_requests == 0:
|
||||||
|
return 0.0
|
||||||
|
return self.cache_hits / self.total_requests
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert metrics to dictionary."""
|
||||||
|
return {
|
||||||
|
"total_requests": self.total_requests,
|
||||||
|
"cache_hits": self.cache_hits,
|
||||||
|
"cache_misses": self.cache_misses,
|
||||||
|
"hit_rate": self.hit_rate,
|
||||||
|
"evictions": self.evictions,
|
||||||
|
"total_entries": self.total_entries,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class DecisionCache:
|
||||||
|
"""TTL-based cache for trade decisions."""
|
||||||
|
|
||||||
|
def __init__(self, ttl_seconds: int = 300, max_size: int = 1000) -> None:
|
||||||
|
"""Initialize the decision cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ttl_seconds: Time-to-live for cache entries in seconds (default: 5 minutes)
|
||||||
|
max_size: Maximum number of cache entries
|
||||||
|
"""
|
||||||
|
self.ttl_seconds = ttl_seconds
|
||||||
|
self.max_size = max_size
|
||||||
|
self._cache: dict[str, CacheEntry] = {}
|
||||||
|
self._metrics = CacheMetrics()
|
||||||
|
|
||||||
|
def _generate_cache_key(self, market_data: dict[str, Any]) -> str:
|
||||||
|
"""Generate cache key from market data.
|
||||||
|
|
||||||
|
Key is based on:
|
||||||
|
- Stock code
|
||||||
|
- Current price (rounded to reduce sensitivity)
|
||||||
|
- Market conditions (orderbook snapshot)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
market_data: Market data dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cache key string
|
||||||
|
"""
|
||||||
|
# Extract key components
|
||||||
|
stock_code = market_data.get("stock_code", "UNKNOWN")
|
||||||
|
current_price = market_data.get("current_price", 0)
|
||||||
|
|
||||||
|
# Round price to reduce sensitivity (cache hits for similar prices)
|
||||||
|
# For prices > 1000, round to nearest 10
|
||||||
|
# For prices < 1000, round to nearest 1
|
||||||
|
if current_price > 1000:
|
||||||
|
price_rounded = round(current_price / 10) * 10
|
||||||
|
else:
|
||||||
|
price_rounded = round(current_price)
|
||||||
|
|
||||||
|
# Include orderbook snapshot (if available)
|
||||||
|
orderbook_key = ""
|
||||||
|
if "orderbook" in market_data and market_data["orderbook"]:
|
||||||
|
ob = market_data["orderbook"]
|
||||||
|
# Just use bid/ask spread as indicator
|
||||||
|
if "bid" in ob and "ask" in ob and ob["bid"] and ob["ask"]:
|
||||||
|
bid_price = ob["bid"][0].get("price", 0) if ob["bid"] else 0
|
||||||
|
ask_price = ob["ask"][0].get("price", 0) if ob["ask"] else 0
|
||||||
|
spread = ask_price - bid_price
|
||||||
|
orderbook_key = f"_spread{spread}"
|
||||||
|
|
||||||
|
# Generate cache key
|
||||||
|
key_str = f"{stock_code}_{price_rounded}{orderbook_key}"
|
||||||
|
|
||||||
|
return key_str
|
||||||
|
|
||||||
|
def _generate_market_hash(self, market_data: dict[str, Any]) -> str:
|
||||||
|
"""Generate hash of full market data for invalidation checks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
market_data: Market data dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Hash string
|
||||||
|
"""
|
||||||
|
# Create stable JSON representation
|
||||||
|
stable_json = json.dumps(market_data, sort_keys=True, ensure_ascii=False)
|
||||||
|
return hashlib.md5(stable_json.encode()).hexdigest()
|
||||||
|
|
||||||
|
def get(self, market_data: dict[str, Any]) -> TradeDecision | None:
|
||||||
|
"""Retrieve cached decision if valid.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
market_data: Market data dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached TradeDecision if valid, None otherwise
|
||||||
|
"""
|
||||||
|
self._metrics.total_requests += 1
|
||||||
|
|
||||||
|
cache_key = self._generate_cache_key(market_data)
|
||||||
|
|
||||||
|
if cache_key not in self._cache:
|
||||||
|
self._metrics.cache_misses += 1
|
||||||
|
return None
|
||||||
|
|
||||||
|
entry = self._cache[cache_key]
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
# Check TTL
|
||||||
|
if current_time - entry.cached_at > self.ttl_seconds:
|
||||||
|
# Expired
|
||||||
|
del self._cache[cache_key]
|
||||||
|
self._metrics.cache_misses += 1
|
||||||
|
self._metrics.evictions += 1
|
||||||
|
logger.debug("Cache expired for key: %s", cache_key)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Cache hit
|
||||||
|
entry.hit_count += 1
|
||||||
|
self._metrics.cache_hits += 1
|
||||||
|
logger.debug("Cache hit for key: %s (hits: %d)", cache_key, entry.hit_count)
|
||||||
|
|
||||||
|
return entry.decision
|
||||||
|
|
||||||
|
def set(
|
||||||
|
self,
|
||||||
|
market_data: dict[str, Any],
|
||||||
|
decision: TradeDecision,
|
||||||
|
) -> None:
|
||||||
|
"""Store decision in cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
market_data: Market data dictionary
|
||||||
|
decision: TradeDecision to cache
|
||||||
|
"""
|
||||||
|
cache_key = self._generate_cache_key(market_data)
|
||||||
|
market_hash = self._generate_market_hash(market_data)
|
||||||
|
|
||||||
|
# Enforce max size (evict oldest if full)
|
||||||
|
if len(self._cache) >= self.max_size:
|
||||||
|
# Find oldest entry
|
||||||
|
oldest_key = min(self._cache.keys(), key=lambda k: self._cache[k].cached_at)
|
||||||
|
del self._cache[oldest_key]
|
||||||
|
self._metrics.evictions += 1
|
||||||
|
logger.debug("Cache full, evicted key: %s", oldest_key)
|
||||||
|
|
||||||
|
# Store entry
|
||||||
|
entry = CacheEntry(
|
||||||
|
decision=decision,
|
||||||
|
cached_at=time.time(),
|
||||||
|
market_data_hash=market_hash,
|
||||||
|
)
|
||||||
|
self._cache[cache_key] = entry
|
||||||
|
self._metrics.total_entries = len(self._cache)
|
||||||
|
|
||||||
|
logger.debug("Cached decision for key: %s", cache_key)
|
||||||
|
|
||||||
|
def invalidate(self, stock_code: str | None = None) -> int:
|
||||||
|
"""Invalidate cache entries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stock_code: Specific stock code to invalidate, or None for all
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of entries invalidated
|
||||||
|
"""
|
||||||
|
if stock_code is None:
|
||||||
|
# Clear all
|
||||||
|
count = len(self._cache)
|
||||||
|
self._cache.clear()
|
||||||
|
self._metrics.evictions += count
|
||||||
|
self._metrics.total_entries = 0
|
||||||
|
logger.info("Invalidated all cache entries (%d)", count)
|
||||||
|
return count
|
||||||
|
|
||||||
|
# Invalidate specific stock
|
||||||
|
keys_to_remove = [k for k in self._cache.keys() if k.startswith(f"{stock_code}_")]
|
||||||
|
count = len(keys_to_remove)
|
||||||
|
|
||||||
|
for key in keys_to_remove:
|
||||||
|
del self._cache[key]
|
||||||
|
|
||||||
|
self._metrics.evictions += count
|
||||||
|
self._metrics.total_entries = len(self._cache)
|
||||||
|
logger.info("Invalidated %d cache entries for stock: %s", count, stock_code)
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
def cleanup_expired(self) -> int:
|
||||||
|
"""Remove expired entries from cache.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of entries removed
|
||||||
|
"""
|
||||||
|
current_time = time.time()
|
||||||
|
expired_keys = [
|
||||||
|
k
|
||||||
|
for k, v in self._cache.items()
|
||||||
|
if current_time - v.cached_at > self.ttl_seconds
|
||||||
|
]
|
||||||
|
|
||||||
|
count = len(expired_keys)
|
||||||
|
for key in expired_keys:
|
||||||
|
del self._cache[key]
|
||||||
|
|
||||||
|
self._metrics.evictions += count
|
||||||
|
self._metrics.total_entries = len(self._cache)
|
||||||
|
|
||||||
|
if count > 0:
|
||||||
|
logger.debug("Cleaned up %d expired cache entries", count)
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
def get_metrics(self) -> CacheMetrics:
|
||||||
|
"""Get current cache metrics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CacheMetrics object with current statistics
|
||||||
|
"""
|
||||||
|
return self._metrics
|
||||||
|
|
||||||
|
def reset_metrics(self) -> None:
|
||||||
|
"""Reset cache metrics."""
|
||||||
|
self._metrics = CacheMetrics(total_entries=len(self._cache))
|
||||||
|
logger.info("Cache metrics reset")
|
||||||
|
|
||||||
|
def should_cache_decision(self, decision: TradeDecision) -> bool:
|
||||||
|
"""Determine if a decision should be cached.
|
||||||
|
|
||||||
|
HOLD decisions with low confidence are good candidates for caching,
|
||||||
|
as they're likely to recur in quiet markets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decision: TradeDecision to evaluate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if decision should be cached
|
||||||
|
"""
|
||||||
|
# Cache HOLD decisions (common in quiet markets)
|
||||||
|
if decision.action == "HOLD":
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Cache high-confidence decisions (stable signals)
|
||||||
|
if decision.confidence >= 90:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Don't cache low-confidence BUY/SELL (volatile signals)
|
||||||
|
return False
|
||||||
296
src/brain/context_selector.py
Normal file
296
src/brain/context_selector.py
Normal file
@@ -0,0 +1,296 @@
|
|||||||
|
"""Smart context selection for optimizing token usage.
|
||||||
|
|
||||||
|
This module implements intelligent selection of context layers (L1-L7) based on
|
||||||
|
decision type and market conditions:
|
||||||
|
- L7 (real-time) for normal trading decisions
|
||||||
|
- L6-L5 (daily/weekly) for strategic decisions
|
||||||
|
- L4-L1 (monthly/legacy) only for major events or policy changes
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from src.context.layer import ContextLayer
|
||||||
|
from src.context.store import ContextStore
|
||||||
|
|
||||||
|
|
||||||
|
class DecisionType(str, Enum):
|
||||||
|
"""Type of trading decision being made."""
|
||||||
|
|
||||||
|
NORMAL = "normal" # Regular trade decision
|
||||||
|
STRATEGIC = "strategic" # Strategy adjustment
|
||||||
|
MAJOR_EVENT = "major_event" # Portfolio rebalancing, policy change
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ContextSelection:
|
||||||
|
"""Selected context layers and their relevance scores."""
|
||||||
|
|
||||||
|
layers: list[ContextLayer]
|
||||||
|
relevance_scores: dict[ContextLayer, float]
|
||||||
|
total_score: float
|
||||||
|
|
||||||
|
|
||||||
|
class ContextSelector:
|
||||||
|
"""Selects optimal context layers to minimize token usage."""
|
||||||
|
|
||||||
|
def __init__(self, store: ContextStore) -> None:
|
||||||
|
"""Initialize the context selector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
store: ContextStore instance for retrieving context data
|
||||||
|
"""
|
||||||
|
self.store = store
|
||||||
|
|
||||||
|
def select_layers(
|
||||||
|
self,
|
||||||
|
decision_type: DecisionType = DecisionType.NORMAL,
|
||||||
|
include_realtime: bool = True,
|
||||||
|
) -> list[ContextLayer]:
|
||||||
|
"""Select context layers based on decision type.
|
||||||
|
|
||||||
|
Strategy:
|
||||||
|
- NORMAL: L7 (real-time) only
|
||||||
|
- STRATEGIC: L7 + L6 + L5 (real-time + daily + weekly)
|
||||||
|
- MAJOR_EVENT: All layers L1-L7
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decision_type: Type of decision being made
|
||||||
|
include_realtime: Whether to include L7 real-time data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of context layers to use (ordered by priority)
|
||||||
|
"""
|
||||||
|
if decision_type == DecisionType.NORMAL:
|
||||||
|
# Normal trading: only real-time data
|
||||||
|
return [ContextLayer.L7_REALTIME] if include_realtime else []
|
||||||
|
|
||||||
|
elif decision_type == DecisionType.STRATEGIC:
|
||||||
|
# Strategic decisions: real-time + recent history
|
||||||
|
layers = []
|
||||||
|
if include_realtime:
|
||||||
|
layers.append(ContextLayer.L7_REALTIME)
|
||||||
|
layers.extend([ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY])
|
||||||
|
return layers
|
||||||
|
|
||||||
|
else: # MAJOR_EVENT
|
||||||
|
# Major events: all layers for comprehensive context
|
||||||
|
layers = []
|
||||||
|
if include_realtime:
|
||||||
|
layers.append(ContextLayer.L7_REALTIME)
|
||||||
|
layers.extend(
|
||||||
|
[
|
||||||
|
ContextLayer.L6_DAILY,
|
||||||
|
ContextLayer.L5_WEEKLY,
|
||||||
|
ContextLayer.L4_MONTHLY,
|
||||||
|
ContextLayer.L3_QUARTERLY,
|
||||||
|
ContextLayer.L2_ANNUAL,
|
||||||
|
ContextLayer.L1_LEGACY,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return layers
|
||||||
|
|
||||||
|
def score_layer_relevance(
|
||||||
|
self,
|
||||||
|
layer: ContextLayer,
|
||||||
|
decision_type: DecisionType,
|
||||||
|
current_time: datetime | None = None,
|
||||||
|
) -> float:
|
||||||
|
"""Calculate relevance score for a context layer.
|
||||||
|
|
||||||
|
Relevance is based on:
|
||||||
|
1. Decision type (normal, strategic, major event)
|
||||||
|
2. Layer recency (L7 > L6 > ... > L1)
|
||||||
|
3. Data availability
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer: Context layer to score
|
||||||
|
decision_type: Type of decision being made
|
||||||
|
current_time: Current time (defaults to now)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Relevance score (0.0 to 1.0)
|
||||||
|
"""
|
||||||
|
if current_time is None:
|
||||||
|
current_time = datetime.now(UTC)
|
||||||
|
|
||||||
|
# Base scores by decision type
|
||||||
|
base_scores = {
|
||||||
|
DecisionType.NORMAL: {
|
||||||
|
ContextLayer.L7_REALTIME: 1.0,
|
||||||
|
ContextLayer.L6_DAILY: 0.1,
|
||||||
|
ContextLayer.L5_WEEKLY: 0.05,
|
||||||
|
ContextLayer.L4_MONTHLY: 0.01,
|
||||||
|
ContextLayer.L3_QUARTERLY: 0.0,
|
||||||
|
ContextLayer.L2_ANNUAL: 0.0,
|
||||||
|
ContextLayer.L1_LEGACY: 0.0,
|
||||||
|
},
|
||||||
|
DecisionType.STRATEGIC: {
|
||||||
|
ContextLayer.L7_REALTIME: 0.9,
|
||||||
|
ContextLayer.L6_DAILY: 0.8,
|
||||||
|
ContextLayer.L5_WEEKLY: 0.7,
|
||||||
|
ContextLayer.L4_MONTHLY: 0.3,
|
||||||
|
ContextLayer.L3_QUARTERLY: 0.2,
|
||||||
|
ContextLayer.L2_ANNUAL: 0.1,
|
||||||
|
ContextLayer.L1_LEGACY: 0.05,
|
||||||
|
},
|
||||||
|
DecisionType.MAJOR_EVENT: {
|
||||||
|
ContextLayer.L7_REALTIME: 0.7,
|
||||||
|
ContextLayer.L6_DAILY: 0.7,
|
||||||
|
ContextLayer.L5_WEEKLY: 0.7,
|
||||||
|
ContextLayer.L4_MONTHLY: 0.8,
|
||||||
|
ContextLayer.L3_QUARTERLY: 0.8,
|
||||||
|
ContextLayer.L2_ANNUAL: 0.9,
|
||||||
|
ContextLayer.L1_LEGACY: 1.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
score = base_scores[decision_type].get(layer, 0.0)
|
||||||
|
|
||||||
|
# Check data availability
|
||||||
|
latest_timeframe = self.store.get_latest_timeframe(layer)
|
||||||
|
if latest_timeframe is None:
|
||||||
|
# No data available - reduce score significantly
|
||||||
|
score *= 0.1
|
||||||
|
|
||||||
|
return score
|
||||||
|
|
||||||
|
def select_with_scoring(
|
||||||
|
self,
|
||||||
|
decision_type: DecisionType = DecisionType.NORMAL,
|
||||||
|
min_score: float = 0.5,
|
||||||
|
) -> ContextSelection:
|
||||||
|
"""Select context layers with relevance scoring.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decision_type: Type of decision being made
|
||||||
|
min_score: Minimum relevance score to include a layer
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ContextSelection with selected layers and scores
|
||||||
|
"""
|
||||||
|
all_layers = [
|
||||||
|
ContextLayer.L7_REALTIME,
|
||||||
|
ContextLayer.L6_DAILY,
|
||||||
|
ContextLayer.L5_WEEKLY,
|
||||||
|
ContextLayer.L4_MONTHLY,
|
||||||
|
ContextLayer.L3_QUARTERLY,
|
||||||
|
ContextLayer.L2_ANNUAL,
|
||||||
|
ContextLayer.L1_LEGACY,
|
||||||
|
]
|
||||||
|
|
||||||
|
scores = {
|
||||||
|
layer: self.score_layer_relevance(layer, decision_type) for layer in all_layers
|
||||||
|
}
|
||||||
|
|
||||||
|
# Filter by minimum score
|
||||||
|
selected_layers = [layer for layer, score in scores.items() if score >= min_score]
|
||||||
|
|
||||||
|
# Sort by score (descending)
|
||||||
|
selected_layers.sort(key=lambda layer: scores[layer], reverse=True)
|
||||||
|
|
||||||
|
total_score = sum(scores[layer] for layer in selected_layers)
|
||||||
|
|
||||||
|
return ContextSelection(
|
||||||
|
layers=selected_layers,
|
||||||
|
relevance_scores=scores,
|
||||||
|
total_score=total_score,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_context_data(
|
||||||
|
self,
|
||||||
|
layers: list[ContextLayer],
|
||||||
|
max_items_per_layer: int = 10,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Retrieve context data for selected layers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layers: List of context layers to retrieve
|
||||||
|
max_items_per_layer: Maximum number of items per layer
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with context data organized by layer
|
||||||
|
"""
|
||||||
|
result: dict[str, Any] = {}
|
||||||
|
|
||||||
|
for layer in layers:
|
||||||
|
# Get latest timeframe for this layer
|
||||||
|
latest_timeframe = self.store.get_latest_timeframe(layer)
|
||||||
|
if latest_timeframe:
|
||||||
|
# Get all contexts for latest timeframe
|
||||||
|
contexts = self.store.get_all_contexts(layer, latest_timeframe)
|
||||||
|
|
||||||
|
# Limit number of items
|
||||||
|
if len(contexts) > max_items_per_layer:
|
||||||
|
# Keep only first N items
|
||||||
|
contexts = dict(list(contexts.items())[:max_items_per_layer])
|
||||||
|
|
||||||
|
result[layer.value] = contexts
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def estimate_context_tokens(self, context_data: dict[str, Any]) -> int:
|
||||||
|
"""Estimate total tokens for context data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_data: Context data dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Estimated token count
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
|
||||||
|
from src.brain.prompt_optimizer import PromptOptimizer
|
||||||
|
|
||||||
|
# Serialize to JSON and estimate tokens
|
||||||
|
json_str = json.dumps(context_data, ensure_ascii=False)
|
||||||
|
return PromptOptimizer.estimate_tokens(json_str)
|
||||||
|
|
||||||
|
def optimize_context_for_budget(
|
||||||
|
self,
|
||||||
|
decision_type: DecisionType,
|
||||||
|
max_tokens: int,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Select and retrieve context data within a token budget.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decision_type: Type of decision being made
|
||||||
|
max_tokens: Maximum token budget for context
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optimized context data within budget
|
||||||
|
"""
|
||||||
|
# Start with minimal selection
|
||||||
|
selection = self.select_with_scoring(decision_type, min_score=0.5)
|
||||||
|
|
||||||
|
# Retrieve data
|
||||||
|
context_data = self.get_context_data(selection.layers)
|
||||||
|
|
||||||
|
# Check if within budget
|
||||||
|
estimated_tokens = self.estimate_context_tokens(context_data)
|
||||||
|
|
||||||
|
if estimated_tokens <= max_tokens:
|
||||||
|
return context_data
|
||||||
|
|
||||||
|
# If over budget, progressively reduce
|
||||||
|
# 1. Reduce items per layer
|
||||||
|
for max_items in [5, 3, 1]:
|
||||||
|
context_data = self.get_context_data(selection.layers, max_items)
|
||||||
|
estimated_tokens = self.estimate_context_tokens(context_data)
|
||||||
|
if estimated_tokens <= max_tokens:
|
||||||
|
return context_data
|
||||||
|
|
||||||
|
# 2. Remove lower-priority layers
|
||||||
|
for min_score in [0.6, 0.7, 0.8, 0.9]:
|
||||||
|
selection = self.select_with_scoring(decision_type, min_score=min_score)
|
||||||
|
context_data = self.get_context_data(selection.layers, max_items_per_layer=1)
|
||||||
|
estimated_tokens = self.estimate_context_tokens(context_data)
|
||||||
|
if estimated_tokens <= max_tokens:
|
||||||
|
return context_data
|
||||||
|
|
||||||
|
# Last resort: return only L7 with minimal data
|
||||||
|
return self.get_context_data([ContextLayer.L7_REALTIME], max_items_per_layer=1)
|
||||||
@@ -2,6 +2,11 @@
|
|||||||
|
|
||||||
Constructs prompts from market data, calls Gemini, and parses structured
|
Constructs prompts from market data, calls Gemini, and parses structured
|
||||||
JSON responses into validated TradeDecision objects.
|
JSON responses into validated TradeDecision objects.
|
||||||
|
|
||||||
|
Includes token efficiency optimizations:
|
||||||
|
- Prompt compression and abbreviation
|
||||||
|
- Response caching for common scenarios
|
||||||
|
- Token usage tracking and metrics
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -14,6 +19,8 @@ from typing import Any
|
|||||||
|
|
||||||
from google import genai
|
from google import genai
|
||||||
|
|
||||||
|
from src.brain.cache import DecisionCache
|
||||||
|
from src.brain.prompt_optimizer import PromptOptimizer
|
||||||
from src.config import Settings
|
from src.config import Settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -28,17 +35,35 @@ class TradeDecision:
|
|||||||
action: str # "BUY" | "SELL" | "HOLD"
|
action: str # "BUY" | "SELL" | "HOLD"
|
||||||
confidence: int # 0-100
|
confidence: int # 0-100
|
||||||
rationale: str
|
rationale: str
|
||||||
|
token_count: int = 0 # Estimated tokens used
|
||||||
|
cached: bool = False # Whether decision came from cache
|
||||||
|
|
||||||
|
|
||||||
class GeminiClient:
|
class GeminiClient:
|
||||||
"""Wraps the Gemini API for trade decision-making."""
|
"""Wraps the Gemini API for trade decision-making."""
|
||||||
|
|
||||||
def __init__(self, settings: Settings) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
settings: Settings,
|
||||||
|
enable_cache: bool = True,
|
||||||
|
enable_optimization: bool = True,
|
||||||
|
) -> None:
|
||||||
self._settings = settings
|
self._settings = settings
|
||||||
self._confidence_threshold = settings.CONFIDENCE_THRESHOLD
|
self._confidence_threshold = settings.CONFIDENCE_THRESHOLD
|
||||||
self._client = genai.Client(api_key=settings.GEMINI_API_KEY)
|
self._client = genai.Client(api_key=settings.GEMINI_API_KEY)
|
||||||
self._model_name = settings.GEMINI_MODEL
|
self._model_name = settings.GEMINI_MODEL
|
||||||
|
|
||||||
|
# Token efficiency features
|
||||||
|
self._enable_cache = enable_cache
|
||||||
|
self._enable_optimization = enable_optimization
|
||||||
|
self._cache = DecisionCache(ttl_seconds=300) if enable_cache else None
|
||||||
|
self._optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
# Token usage metrics
|
||||||
|
self._total_tokens_used = 0
|
||||||
|
self._total_decisions = 0
|
||||||
|
self._total_cached_decisions = 0
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Prompt Construction
|
# Prompt Construction
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -154,26 +179,141 @@ class GeminiClient:
|
|||||||
|
|
||||||
async def decide(self, market_data: dict[str, Any]) -> TradeDecision:
|
async def decide(self, market_data: dict[str, Any]) -> TradeDecision:
|
||||||
"""Build prompt, call Gemini, and return a parsed decision."""
|
"""Build prompt, call Gemini, and return a parsed decision."""
|
||||||
|
# Check cache first
|
||||||
|
if self._cache:
|
||||||
|
cached_decision = self._cache.get(market_data)
|
||||||
|
if cached_decision:
|
||||||
|
self._total_cached_decisions += 1
|
||||||
|
self._total_decisions += 1
|
||||||
|
logger.info(
|
||||||
|
"Cache hit for decision",
|
||||||
|
extra={
|
||||||
|
"action": cached_decision.action,
|
||||||
|
"confidence": cached_decision.confidence,
|
||||||
|
"cache_hit_rate": self.get_cache_hit_rate(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Return cached decision with cached flag
|
||||||
|
return TradeDecision(
|
||||||
|
action=cached_decision.action,
|
||||||
|
confidence=cached_decision.confidence,
|
||||||
|
rationale=cached_decision.rationale,
|
||||||
|
token_count=0,
|
||||||
|
cached=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build optimized prompt
|
||||||
|
if self._enable_optimization:
|
||||||
|
prompt = self._optimizer.build_compressed_prompt(market_data)
|
||||||
|
else:
|
||||||
prompt = self.build_prompt(market_data)
|
prompt = self.build_prompt(market_data)
|
||||||
logger.info("Requesting trade decision from Gemini")
|
|
||||||
|
# Estimate tokens
|
||||||
|
token_count = self._optimizer.estimate_tokens(prompt)
|
||||||
|
self._total_tokens_used += token_count
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Requesting trade decision from Gemini",
|
||||||
|
extra={"estimated_tokens": token_count, "optimized": self._enable_optimization},
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await self._client.aio.models.generate_content(
|
response = await self._client.aio.models.generate_content(
|
||||||
model=self._model_name, contents=prompt,
|
model=self._model_name,
|
||||||
|
contents=prompt,
|
||||||
)
|
)
|
||||||
raw = response.text
|
raw = response.text
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("Gemini API error: %s", exc)
|
logger.error("Gemini API error: %s", exc)
|
||||||
return TradeDecision(
|
return TradeDecision(
|
||||||
action="HOLD", confidence=0, rationale=f"API error: {exc}"
|
action="HOLD", confidence=0, rationale=f"API error: {exc}", token_count=token_count
|
||||||
)
|
)
|
||||||
|
|
||||||
decision = self.parse_response(raw)
|
decision = self.parse_response(raw)
|
||||||
|
self._total_decisions += 1
|
||||||
|
|
||||||
|
# Add token count to decision
|
||||||
|
decision_with_tokens = TradeDecision(
|
||||||
|
action=decision.action,
|
||||||
|
confidence=decision.confidence,
|
||||||
|
rationale=decision.rationale,
|
||||||
|
token_count=token_count,
|
||||||
|
cached=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cache if appropriate
|
||||||
|
if self._cache and self._cache.should_cache_decision(decision):
|
||||||
|
self._cache.set(market_data, decision)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Gemini decision",
|
"Gemini decision",
|
||||||
extra={
|
extra={
|
||||||
"action": decision.action,
|
"action": decision.action,
|
||||||
"confidence": decision.confidence,
|
"confidence": decision.confidence,
|
||||||
|
"tokens": token_count,
|
||||||
|
"avg_tokens": self.get_avg_tokens_per_decision(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return decision
|
|
||||||
|
return decision_with_tokens
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Token Efficiency Metrics
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def get_token_metrics(self) -> dict[str, Any]:
|
||||||
|
"""Get token usage metrics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with token usage statistics
|
||||||
|
"""
|
||||||
|
metrics = {
|
||||||
|
"total_tokens_used": self._total_tokens_used,
|
||||||
|
"total_decisions": self._total_decisions,
|
||||||
|
"total_cached_decisions": self._total_cached_decisions,
|
||||||
|
"avg_tokens_per_decision": self.get_avg_tokens_per_decision(),
|
||||||
|
"cache_hit_rate": self.get_cache_hit_rate(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if self._cache:
|
||||||
|
cache_metrics = self._cache.get_metrics()
|
||||||
|
metrics["cache_metrics"] = cache_metrics.to_dict()
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
def get_avg_tokens_per_decision(self) -> float:
|
||||||
|
"""Calculate average tokens per decision.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Average tokens per decision
|
||||||
|
"""
|
||||||
|
if self._total_decisions == 0:
|
||||||
|
return 0.0
|
||||||
|
return self._total_tokens_used / self._total_decisions
|
||||||
|
|
||||||
|
def get_cache_hit_rate(self) -> float:
|
||||||
|
"""Calculate cache hit rate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cache hit rate (0.0 to 1.0)
|
||||||
|
"""
|
||||||
|
if self._total_decisions == 0:
|
||||||
|
return 0.0
|
||||||
|
return self._total_cached_decisions / self._total_decisions
|
||||||
|
|
||||||
|
def reset_metrics(self) -> None:
|
||||||
|
"""Reset token usage metrics."""
|
||||||
|
self._total_tokens_used = 0
|
||||||
|
self._total_decisions = 0
|
||||||
|
self._total_cached_decisions = 0
|
||||||
|
if self._cache:
|
||||||
|
self._cache.reset_metrics()
|
||||||
|
logger.info("Token metrics reset")
|
||||||
|
|
||||||
|
def get_cache(self) -> DecisionCache | None:
|
||||||
|
"""Get the decision cache instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DecisionCache instance or None if caching disabled
|
||||||
|
"""
|
||||||
|
return self._cache
|
||||||
|
|||||||
267
src/brain/prompt_optimizer.py
Normal file
267
src/brain/prompt_optimizer.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
"""Prompt optimization utilities for reducing token usage.
|
||||||
|
|
||||||
|
This module provides tools to compress prompts while maintaining decision quality:
|
||||||
|
- Token counting
|
||||||
|
- Text compression and abbreviation
|
||||||
|
- Template-based prompts with variable slots
|
||||||
|
- Priority-based context truncation
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
# Abbreviation mapping for common terms
|
||||||
|
ABBREVIATIONS = {
|
||||||
|
"price": "P",
|
||||||
|
"volume": "V",
|
||||||
|
"current": "cur",
|
||||||
|
"previous": "prev",
|
||||||
|
"change": "chg",
|
||||||
|
"percentage": "pct",
|
||||||
|
"market": "mkt",
|
||||||
|
"orderbook": "ob",
|
||||||
|
"foreigner": "fgn",
|
||||||
|
"buy": "B",
|
||||||
|
"sell": "S",
|
||||||
|
"hold": "H",
|
||||||
|
"confidence": "conf",
|
||||||
|
"rationale": "reason",
|
||||||
|
"action": "act",
|
||||||
|
"net": "net",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Reverse mapping for decompression
|
||||||
|
REVERSE_ABBREVIATIONS = {v: k for k, v in ABBREVIATIONS.items()}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TokenMetrics:
|
||||||
|
"""Metrics about token usage in a prompt."""
|
||||||
|
|
||||||
|
char_count: int
|
||||||
|
word_count: int
|
||||||
|
estimated_tokens: int # Rough estimate: ~4 chars per token
|
||||||
|
compression_ratio: float = 1.0 # Original / Compressed
|
||||||
|
|
||||||
|
|
||||||
|
class PromptOptimizer:
|
||||||
|
"""Optimizes prompts to reduce token usage while maintaining quality."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def estimate_tokens(text: str) -> int:
|
||||||
|
"""Estimate token count for text.
|
||||||
|
|
||||||
|
Uses a simple heuristic: ~4 characters per token for English.
|
||||||
|
This is approximate but sufficient for optimization purposes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Input text to estimate tokens for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Estimated token count
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return 0
|
||||||
|
# Simple estimate: 1 token ≈ 4 characters
|
||||||
|
return max(1, len(text) // 4)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def count_tokens(text: str) -> TokenMetrics:
|
||||||
|
"""Count various metrics for a text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Input text to analyze
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TokenMetrics with character, word, and estimated token counts
|
||||||
|
"""
|
||||||
|
char_count = len(text)
|
||||||
|
word_count = len(text.split())
|
||||||
|
estimated_tokens = PromptOptimizer.estimate_tokens(text)
|
||||||
|
|
||||||
|
return TokenMetrics(
|
||||||
|
char_count=char_count,
|
||||||
|
word_count=word_count,
|
||||||
|
estimated_tokens=estimated_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def compress_json(data: dict[str, Any]) -> str:
|
||||||
|
"""Compress JSON by removing whitespace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Dictionary to serialize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Compact JSON string without whitespace
|
||||||
|
"""
|
||||||
|
return json.dumps(data, separators=(",", ":"), ensure_ascii=False)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def abbreviate_text(text: str, aggressive: bool = False) -> str:
|
||||||
|
"""Apply abbreviations to reduce text length.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Input text to abbreviate
|
||||||
|
aggressive: If True, apply more aggressive compression
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Abbreviated text
|
||||||
|
"""
|
||||||
|
result = text
|
||||||
|
|
||||||
|
# Apply word-level abbreviations (case-insensitive)
|
||||||
|
for full, abbr in ABBREVIATIONS.items():
|
||||||
|
# Word boundaries to avoid partial replacements
|
||||||
|
pattern = r"\b" + re.escape(full) + r"\b"
|
||||||
|
result = re.sub(pattern, abbr, result, flags=re.IGNORECASE)
|
||||||
|
|
||||||
|
if aggressive:
|
||||||
|
# Remove articles and filler words
|
||||||
|
result = re.sub(r"\b(a|an|the)\b", "", result, flags=re.IGNORECASE)
|
||||||
|
result = re.sub(r"\b(is|are|was|were)\b", "", result, flags=re.IGNORECASE)
|
||||||
|
# Collapse multiple spaces
|
||||||
|
result = re.sub(r"\s+", " ", result)
|
||||||
|
|
||||||
|
return result.strip()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_compressed_prompt(
|
||||||
|
market_data: dict[str, Any],
|
||||||
|
include_instructions: bool = True,
|
||||||
|
max_length: int | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Build a compressed prompt from market data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
market_data: Market data dictionary with stock info
|
||||||
|
include_instructions: Whether to include full instructions
|
||||||
|
max_length: Maximum character length (truncates if needed)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Compressed prompt string
|
||||||
|
"""
|
||||||
|
# Abbreviated market name
|
||||||
|
market_name = market_data.get("market_name", "KR")
|
||||||
|
if "Korea" in market_name:
|
||||||
|
market_name = "KR"
|
||||||
|
elif "United States" in market_name or "US" in market_name:
|
||||||
|
market_name = "US"
|
||||||
|
|
||||||
|
# Core data - always included
|
||||||
|
core_info = {
|
||||||
|
"mkt": market_name,
|
||||||
|
"code": market_data["stock_code"],
|
||||||
|
"P": market_data["current_price"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Optional fields
|
||||||
|
if "orderbook" in market_data and market_data["orderbook"]:
|
||||||
|
ob = market_data["orderbook"]
|
||||||
|
# Compress orderbook: keep only top 3 levels
|
||||||
|
compressed_ob = {
|
||||||
|
"bid": ob.get("bid", [])[:3],
|
||||||
|
"ask": ob.get("ask", [])[:3],
|
||||||
|
}
|
||||||
|
core_info["ob"] = compressed_ob
|
||||||
|
|
||||||
|
if market_data.get("foreigner_net", 0) != 0:
|
||||||
|
core_info["fgn_net"] = market_data["foreigner_net"]
|
||||||
|
|
||||||
|
# Compress to JSON
|
||||||
|
data_str = PromptOptimizer.compress_json(core_info)
|
||||||
|
|
||||||
|
if include_instructions:
|
||||||
|
# Minimal instructions
|
||||||
|
prompt = (
|
||||||
|
f"{market_name} trader. Analyze:\n{data_str}\n\n"
|
||||||
|
'Return JSON: {"act":"BUY"|"SELL"|"HOLD","conf":<0-100>,"reason":"<text>"}\n'
|
||||||
|
"Rules: act=BUY/SELL/HOLD, conf=0-100, reason=concise. No markdown."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Data only (for cached contexts where instructions are known)
|
||||||
|
prompt = data_str
|
||||||
|
|
||||||
|
# Truncate if needed
|
||||||
|
if max_length and len(prompt) > max_length:
|
||||||
|
prompt = prompt[:max_length] + "..."
|
||||||
|
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def truncate_context(
|
||||||
|
context: dict[str, Any],
|
||||||
|
max_tokens: int,
|
||||||
|
priority_keys: list[str] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Truncate context data to fit within token budget.
|
||||||
|
|
||||||
|
Keeps high-priority keys first, then truncates less important data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context dictionary to truncate
|
||||||
|
max_tokens: Maximum token budget
|
||||||
|
priority_keys: List of keys to keep (in order of priority)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Truncated context dictionary
|
||||||
|
"""
|
||||||
|
if not context:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if priority_keys is None:
|
||||||
|
priority_keys = []
|
||||||
|
|
||||||
|
result: dict[str, Any] = {}
|
||||||
|
current_tokens = 0
|
||||||
|
|
||||||
|
# Add priority keys first
|
||||||
|
for key in priority_keys:
|
||||||
|
if key in context:
|
||||||
|
value_str = json.dumps(context[key])
|
||||||
|
tokens = PromptOptimizer.estimate_tokens(value_str)
|
||||||
|
|
||||||
|
if current_tokens + tokens <= max_tokens:
|
||||||
|
result[key] = context[key]
|
||||||
|
current_tokens += tokens
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Add remaining keys if space available
|
||||||
|
for key, value in context.items():
|
||||||
|
if key in result:
|
||||||
|
continue
|
||||||
|
|
||||||
|
value_str = json.dumps(value)
|
||||||
|
tokens = PromptOptimizer.estimate_tokens(value_str)
|
||||||
|
|
||||||
|
if current_tokens + tokens <= max_tokens:
|
||||||
|
result[key] = value
|
||||||
|
current_tokens += tokens
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def calculate_compression_ratio(original: str, compressed: str) -> float:
|
||||||
|
"""Calculate compression ratio between original and compressed text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
original: Original text
|
||||||
|
compressed: Compressed text
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Compression ratio (original_tokens / compressed_tokens)
|
||||||
|
"""
|
||||||
|
original_tokens = PromptOptimizer.estimate_tokens(original)
|
||||||
|
compressed_tokens = PromptOptimizer.estimate_tokens(compressed)
|
||||||
|
|
||||||
|
if compressed_tokens == 0:
|
||||||
|
return 1.0
|
||||||
|
|
||||||
|
return original_tokens / compressed_tokens
|
||||||
328
src/context/summarizer.py
Normal file
328
src/context/summarizer.py
Normal file
@@ -0,0 +1,328 @@
|
|||||||
|
"""Context summarization for efficient historical data representation.
|
||||||
|
|
||||||
|
This module summarizes old context data instead of including raw details:
|
||||||
|
- Key metrics only (averages, trends, not details)
|
||||||
|
- Rolling window (keep last N days detailed, summarize older)
|
||||||
|
- Aggregate historical data efficiently
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from src.context.layer import ContextLayer
|
||||||
|
from src.context.store import ContextStore
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SummaryStats:
|
||||||
|
"""Statistical summary of historical data."""
|
||||||
|
|
||||||
|
count: int
|
||||||
|
mean: float | None = None
|
||||||
|
min: float | None = None
|
||||||
|
max: float | None = None
|
||||||
|
std: float | None = None
|
||||||
|
trend: str | None = None # "up", "down", "flat"
|
||||||
|
|
||||||
|
|
||||||
|
class ContextSummarizer:
|
||||||
|
"""Summarizes historical context data to reduce token usage."""
|
||||||
|
|
||||||
|
def __init__(self, store: ContextStore) -> None:
|
||||||
|
"""Initialize the context summarizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
store: ContextStore instance for retrieving context data
|
||||||
|
"""
|
||||||
|
self.store = store
|
||||||
|
|
||||||
|
def summarize_numeric_values(self, values: list[float]) -> SummaryStats:
|
||||||
|
"""Summarize a list of numeric values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
values: List of numeric values to summarize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SummaryStats with mean, min, max, std, and trend
|
||||||
|
"""
|
||||||
|
if not values:
|
||||||
|
return SummaryStats(count=0)
|
||||||
|
|
||||||
|
count = len(values)
|
||||||
|
mean = sum(values) / count
|
||||||
|
min_val = min(values)
|
||||||
|
max_val = max(values)
|
||||||
|
|
||||||
|
# Calculate standard deviation
|
||||||
|
if count > 1:
|
||||||
|
variance = sum((x - mean) ** 2 for x in values) / (count - 1)
|
||||||
|
std = variance**0.5
|
||||||
|
else:
|
||||||
|
std = 0.0
|
||||||
|
|
||||||
|
# Determine trend
|
||||||
|
trend = "flat"
|
||||||
|
if count >= 3:
|
||||||
|
# Simple trend: compare first third vs last third
|
||||||
|
first_third = values[: count // 3]
|
||||||
|
last_third = values[-(count // 3) :]
|
||||||
|
first_avg = sum(first_third) / len(first_third)
|
||||||
|
last_avg = sum(last_third) / len(last_third)
|
||||||
|
|
||||||
|
# Trend threshold: 5% change
|
||||||
|
threshold = 0.05 * abs(first_avg) if first_avg != 0 else 0.01
|
||||||
|
|
||||||
|
if last_avg > first_avg + threshold:
|
||||||
|
trend = "up"
|
||||||
|
elif last_avg < first_avg - threshold:
|
||||||
|
trend = "down"
|
||||||
|
|
||||||
|
return SummaryStats(
|
||||||
|
count=count,
|
||||||
|
mean=round(mean, 4),
|
||||||
|
min=round(min_val, 4),
|
||||||
|
max=round(max_val, 4),
|
||||||
|
std=round(std, 4),
|
||||||
|
trend=trend,
|
||||||
|
)
|
||||||
|
|
||||||
|
def summarize_layer(
|
||||||
|
self,
|
||||||
|
layer: ContextLayer,
|
||||||
|
start_date: datetime | None = None,
|
||||||
|
end_date: datetime | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Summarize all context data for a layer within a date range.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer: Context layer to summarize
|
||||||
|
start_date: Start date (inclusive), None for all
|
||||||
|
end_date: End date (inclusive), None for now
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with summarized metrics
|
||||||
|
"""
|
||||||
|
if end_date is None:
|
||||||
|
end_date = datetime.now(UTC)
|
||||||
|
|
||||||
|
# Get all contexts for this layer
|
||||||
|
all_contexts = self.store.get_all_contexts(layer)
|
||||||
|
|
||||||
|
if not all_contexts:
|
||||||
|
return {"summary": "No data available", "count": 0}
|
||||||
|
|
||||||
|
# Group numeric values by key
|
||||||
|
numeric_data: dict[str, list[float]] = {}
|
||||||
|
text_data: dict[str, list[str]] = {}
|
||||||
|
|
||||||
|
for key, value in all_contexts.items():
|
||||||
|
# Try to extract numeric values
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
if key not in numeric_data:
|
||||||
|
numeric_data[key] = []
|
||||||
|
numeric_data[key].append(float(value))
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
# Extract numeric fields from dict
|
||||||
|
for subkey, subvalue in value.items():
|
||||||
|
if isinstance(subvalue, (int, float)):
|
||||||
|
full_key = f"{key}.{subkey}"
|
||||||
|
if full_key not in numeric_data:
|
||||||
|
numeric_data[full_key] = []
|
||||||
|
numeric_data[full_key].append(float(subvalue))
|
||||||
|
elif isinstance(value, str):
|
||||||
|
if key not in text_data:
|
||||||
|
text_data[key] = []
|
||||||
|
text_data[key].append(value)
|
||||||
|
|
||||||
|
# Summarize numeric data
|
||||||
|
summary: dict[str, Any] = {}
|
||||||
|
|
||||||
|
for key, values in numeric_data.items():
|
||||||
|
stats = self.summarize_numeric_values(values)
|
||||||
|
summary[key] = {
|
||||||
|
"count": stats.count,
|
||||||
|
"avg": stats.mean,
|
||||||
|
"range": [stats.min, stats.max],
|
||||||
|
"trend": stats.trend,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Summarize text data (just counts)
|
||||||
|
for key, values in text_data.items():
|
||||||
|
summary[f"{key}_count"] = len(values)
|
||||||
|
|
||||||
|
summary["total_entries"] = len(all_contexts)
|
||||||
|
|
||||||
|
return summary
|
||||||
|
|
||||||
|
def rolling_window_summary(
|
||||||
|
self,
|
||||||
|
layer: ContextLayer,
|
||||||
|
window_days: int = 30,
|
||||||
|
summarize_older: bool = True,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Create a rolling window summary.
|
||||||
|
|
||||||
|
Recent data (within window) is kept detailed.
|
||||||
|
Older data is summarized to key metrics.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer: Context layer to summarize
|
||||||
|
window_days: Number of days to keep detailed
|
||||||
|
summarize_older: Whether to summarize data older than window
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with recent (detailed) and historical (summary) data
|
||||||
|
"""
|
||||||
|
result: dict[str, Any] = {
|
||||||
|
"window_days": window_days,
|
||||||
|
"recent_data": {},
|
||||||
|
"historical_summary": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get all contexts
|
||||||
|
all_contexts = self.store.get_all_contexts(layer)
|
||||||
|
|
||||||
|
recent_values: dict[str, list[float]] = {}
|
||||||
|
historical_values: dict[str, list[float]] = {}
|
||||||
|
|
||||||
|
for key, value in all_contexts.items():
|
||||||
|
# For simplicity, treat all numeric values
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
# Note: We don't have timestamps in context keys
|
||||||
|
# This is a simplified implementation
|
||||||
|
# In practice, would need to check timeframe field
|
||||||
|
|
||||||
|
# For now, put recent data in window
|
||||||
|
if key not in recent_values:
|
||||||
|
recent_values[key] = []
|
||||||
|
recent_values[key].append(float(value))
|
||||||
|
|
||||||
|
# Detailed recent data
|
||||||
|
result["recent_data"] = {key: values[-10:] for key, values in recent_values.items()}
|
||||||
|
|
||||||
|
# Summarized historical data
|
||||||
|
if summarize_older:
|
||||||
|
for key, values in historical_values.items():
|
||||||
|
stats = self.summarize_numeric_values(values)
|
||||||
|
result["historical_summary"][key] = {
|
||||||
|
"count": stats.count,
|
||||||
|
"avg": stats.mean,
|
||||||
|
"trend": stats.trend,
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def aggregate_to_higher_layer(
|
||||||
|
self,
|
||||||
|
source_layer: ContextLayer,
|
||||||
|
target_layer: ContextLayer,
|
||||||
|
metric_key: str,
|
||||||
|
aggregation_func: str = "mean",
|
||||||
|
) -> float | None:
|
||||||
|
"""Aggregate data from source layer to target layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source_layer: Source context layer (more granular)
|
||||||
|
target_layer: Target context layer (less granular)
|
||||||
|
metric_key: Key of metric to aggregate
|
||||||
|
aggregation_func: Aggregation function ("mean", "sum", "max", "min")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Aggregated value, or None if no data available
|
||||||
|
"""
|
||||||
|
# Get all contexts from source layer
|
||||||
|
source_contexts = self.store.get_all_contexts(source_layer)
|
||||||
|
|
||||||
|
# Extract values for metric_key
|
||||||
|
values = []
|
||||||
|
for key, value in source_contexts.items():
|
||||||
|
if key == metric_key and isinstance(value, (int, float)):
|
||||||
|
values.append(float(value))
|
||||||
|
elif isinstance(value, dict) and metric_key in value:
|
||||||
|
subvalue = value[metric_key]
|
||||||
|
if isinstance(subvalue, (int, float)):
|
||||||
|
values.append(float(subvalue))
|
||||||
|
|
||||||
|
if not values:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Apply aggregation function
|
||||||
|
if aggregation_func == "mean":
|
||||||
|
return sum(values) / len(values)
|
||||||
|
elif aggregation_func == "sum":
|
||||||
|
return sum(values)
|
||||||
|
elif aggregation_func == "max":
|
||||||
|
return max(values)
|
||||||
|
elif aggregation_func == "min":
|
||||||
|
return min(values)
|
||||||
|
else:
|
||||||
|
return sum(values) / len(values) # Default to mean
|
||||||
|
|
||||||
|
def create_compact_summary(
|
||||||
|
self,
|
||||||
|
layers: list[ContextLayer],
|
||||||
|
top_n_metrics: int = 5,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Create a compact summary across multiple layers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layers: List of context layers to summarize
|
||||||
|
top_n_metrics: Number of top metrics to include per layer
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Compact summary dictionary
|
||||||
|
"""
|
||||||
|
summary: dict[str, Any] = {}
|
||||||
|
|
||||||
|
for layer in layers:
|
||||||
|
layer_summary = self.summarize_layer(layer)
|
||||||
|
|
||||||
|
# Keep only top N metrics (by count/relevance)
|
||||||
|
metrics = []
|
||||||
|
for key, value in layer_summary.items():
|
||||||
|
if isinstance(value, dict) and "count" in value:
|
||||||
|
metrics.append((key, value, value["count"]))
|
||||||
|
|
||||||
|
# Sort by count (descending)
|
||||||
|
metrics.sort(key=lambda x: x[2], reverse=True)
|
||||||
|
|
||||||
|
# Keep top N
|
||||||
|
top_metrics = {m[0]: m[1] for m in metrics[:top_n_metrics]}
|
||||||
|
|
||||||
|
summary[layer.value] = top_metrics
|
||||||
|
|
||||||
|
return summary
|
||||||
|
|
||||||
|
def format_summary_for_prompt(self, summary: dict[str, Any]) -> str:
|
||||||
|
"""Format summary for inclusion in a prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
summary: Summary dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted string for prompt
|
||||||
|
"""
|
||||||
|
lines = []
|
||||||
|
|
||||||
|
for layer, metrics in summary.items():
|
||||||
|
if not metrics:
|
||||||
|
continue
|
||||||
|
|
||||||
|
lines.append(f"{layer}:")
|
||||||
|
for key, value in metrics.items():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
# Format as: key: avg=X, trend=Y
|
||||||
|
parts = []
|
||||||
|
if "avg" in value and value["avg"] is not None:
|
||||||
|
parts.append(f"avg={value['avg']:.2f}")
|
||||||
|
if "trend" in value and value["trend"]:
|
||||||
|
parts.append(f"trend={value['trend']}")
|
||||||
|
if parts:
|
||||||
|
lines.append(f" {key}: {', '.join(parts)}")
|
||||||
|
else:
|
||||||
|
lines.append(f" {key}: {value}")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
110
src/core/criticality.py
Normal file
110
src/core/criticality.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
"""Criticality assessment for urgency-based response system.
|
||||||
|
|
||||||
|
Evaluates market conditions to determine response urgency and enable
|
||||||
|
faster reactions in critical situations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
|
||||||
|
class CriticalityLevel(StrEnum):
|
||||||
|
"""Urgency levels for market conditions and trading decisions."""
|
||||||
|
|
||||||
|
CRITICAL = "CRITICAL" # <5s timeout - Emergency response required
|
||||||
|
HIGH = "HIGH" # <30s timeout - Elevated priority
|
||||||
|
NORMAL = "NORMAL" # <60s timeout - Standard processing
|
||||||
|
LOW = "LOW" # No timeout - Batch processing
|
||||||
|
|
||||||
|
|
||||||
|
class CriticalityAssessor:
|
||||||
|
"""Assesses market conditions to determine response criticality level."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
critical_pnl_threshold: float = -2.5,
|
||||||
|
critical_price_change_threshold: float = 5.0,
|
||||||
|
critical_volume_surge_threshold: float = 10.0,
|
||||||
|
high_volatility_threshold: float = 70.0,
|
||||||
|
low_volatility_threshold: float = 30.0,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the criticality assessor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
critical_pnl_threshold: P&L % that triggers CRITICAL (default -2.5%)
|
||||||
|
critical_price_change_threshold: Price change % that triggers CRITICAL
|
||||||
|
(default 5.0% in 1 minute)
|
||||||
|
critical_volume_surge_threshold: Volume surge ratio that triggers CRITICAL
|
||||||
|
(default 10x average)
|
||||||
|
high_volatility_threshold: Volatility score that triggers HIGH
|
||||||
|
(default 70.0)
|
||||||
|
low_volatility_threshold: Volatility score below which is LOW
|
||||||
|
(default 30.0)
|
||||||
|
"""
|
||||||
|
self.critical_pnl_threshold = critical_pnl_threshold
|
||||||
|
self.critical_price_change_threshold = critical_price_change_threshold
|
||||||
|
self.critical_volume_surge_threshold = critical_volume_surge_threshold
|
||||||
|
self.high_volatility_threshold = high_volatility_threshold
|
||||||
|
self.low_volatility_threshold = low_volatility_threshold
|
||||||
|
|
||||||
|
def assess_market_conditions(
|
||||||
|
self,
|
||||||
|
pnl_pct: float,
|
||||||
|
volatility_score: float,
|
||||||
|
volume_surge: float,
|
||||||
|
price_change_1m: float = 0.0,
|
||||||
|
is_market_open: bool = True,
|
||||||
|
) -> CriticalityLevel:
|
||||||
|
"""Assess criticality level based on market conditions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pnl_pct: Current P&L percentage
|
||||||
|
volatility_score: Momentum score from VolatilityAnalyzer (0-100)
|
||||||
|
volume_surge: Volume surge ratio (current / average)
|
||||||
|
price_change_1m: 1-minute price change percentage
|
||||||
|
is_market_open: Whether the market is currently open
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CriticalityLevel indicating required response urgency
|
||||||
|
"""
|
||||||
|
# Market closed or very quiet → LOW priority (batch processing)
|
||||||
|
if not is_market_open or volatility_score < self.low_volatility_threshold:
|
||||||
|
return CriticalityLevel.LOW
|
||||||
|
|
||||||
|
# CRITICAL conditions: immediate action required
|
||||||
|
# 1. P&L near circuit breaker (-2.5% is close to -3.0% breaker)
|
||||||
|
if pnl_pct <= self.critical_pnl_threshold:
|
||||||
|
return CriticalityLevel.CRITICAL
|
||||||
|
|
||||||
|
# 2. Large sudden price movement (>5% in 1 minute)
|
||||||
|
if abs(price_change_1m) >= self.critical_price_change_threshold:
|
||||||
|
return CriticalityLevel.CRITICAL
|
||||||
|
|
||||||
|
# 3. Extreme volume surge (>10x average) indicates major event
|
||||||
|
if volume_surge >= self.critical_volume_surge_threshold:
|
||||||
|
return CriticalityLevel.CRITICAL
|
||||||
|
|
||||||
|
# HIGH priority: elevated volatility requires faster response
|
||||||
|
if volatility_score >= self.high_volatility_threshold:
|
||||||
|
return CriticalityLevel.HIGH
|
||||||
|
|
||||||
|
# NORMAL: standard trading conditions
|
||||||
|
return CriticalityLevel.NORMAL
|
||||||
|
|
||||||
|
def get_timeout(self, level: CriticalityLevel) -> float | None:
|
||||||
|
"""Get timeout in seconds for a given criticality level.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: Criticality level
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Timeout in seconds, or None for no timeout (LOW priority)
|
||||||
|
"""
|
||||||
|
timeout_map = {
|
||||||
|
CriticalityLevel.CRITICAL: 5.0,
|
||||||
|
CriticalityLevel.HIGH: 30.0,
|
||||||
|
CriticalityLevel.NORMAL: 60.0,
|
||||||
|
CriticalityLevel.LOW: None,
|
||||||
|
}
|
||||||
|
return timeout_map[level]
|
||||||
291
src/core/priority_queue.py
Normal file
291
src/core/priority_queue.py
Normal file
@@ -0,0 +1,291 @@
|
|||||||
|
"""Priority-based task queue for latency control.
|
||||||
|
|
||||||
|
Implements a thread-safe priority queue with timeout enforcement and metrics tracking.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import heapq
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from src.core.criticality import CriticalityLevel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(order=True)
|
||||||
|
class PriorityTask:
|
||||||
|
"""Task with priority and timestamp for queue ordering."""
|
||||||
|
|
||||||
|
# Lower priority value = higher urgency (CRITICAL=0, HIGH=1, NORMAL=2, LOW=3)
|
||||||
|
priority: int
|
||||||
|
timestamp: float
|
||||||
|
# Task data not used in comparison
|
||||||
|
task_id: str = field(compare=False)
|
||||||
|
task_data: dict[str, Any] = field(compare=False, default_factory=dict)
|
||||||
|
callback: Callable[[], Coroutine[Any, Any, Any]] | None = field(
|
||||||
|
compare=False, default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class QueueMetrics:
|
||||||
|
"""Metrics for priority queue performance monitoring."""
|
||||||
|
|
||||||
|
total_enqueued: int = 0
|
||||||
|
total_dequeued: int = 0
|
||||||
|
total_timeouts: int = 0
|
||||||
|
total_errors: int = 0
|
||||||
|
current_size: int = 0
|
||||||
|
# Average wait time per criticality level (in seconds)
|
||||||
|
avg_wait_time: dict[CriticalityLevel, float] = field(default_factory=dict)
|
||||||
|
# P95 wait time per criticality level
|
||||||
|
p95_wait_time: dict[CriticalityLevel, float] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class PriorityTaskQueue:
|
||||||
|
"""Thread-safe priority queue with timeout enforcement."""
|
||||||
|
|
||||||
|
# Priority mapping for criticality levels
|
||||||
|
PRIORITY_MAP = {
|
||||||
|
CriticalityLevel.CRITICAL: 0,
|
||||||
|
CriticalityLevel.HIGH: 1,
|
||||||
|
CriticalityLevel.NORMAL: 2,
|
||||||
|
CriticalityLevel.LOW: 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, max_size: int = 1000) -> None:
|
||||||
|
"""Initialize the priority task queue.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_size: Maximum queue size (default 1000)
|
||||||
|
"""
|
||||||
|
self._queue: list[PriorityTask] = []
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
self._max_size = max_size
|
||||||
|
self._metrics = QueueMetrics()
|
||||||
|
# Track wait times for metrics
|
||||||
|
self._wait_times: dict[CriticalityLevel, list[float]] = {
|
||||||
|
level: [] for level in CriticalityLevel
|
||||||
|
}
|
||||||
|
|
||||||
|
async def enqueue(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
criticality: CriticalityLevel,
|
||||||
|
task_data: dict[str, Any],
|
||||||
|
callback: Callable[[], Coroutine[Any, Any, Any]] | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Add a task to the priority queue.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: Unique identifier for the task
|
||||||
|
criticality: Criticality level determining priority
|
||||||
|
task_data: Data associated with the task
|
||||||
|
callback: Optional async callback to execute
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if enqueued successfully, False if queue is full
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
if len(self._queue) >= self._max_size:
|
||||||
|
logger.warning(
|
||||||
|
"Priority queue full (size=%d), rejecting task %s",
|
||||||
|
len(self._queue),
|
||||||
|
task_id,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
priority = self.PRIORITY_MAP[criticality]
|
||||||
|
timestamp = time.time()
|
||||||
|
|
||||||
|
task = PriorityTask(
|
||||||
|
priority=priority,
|
||||||
|
timestamp=timestamp,
|
||||||
|
task_id=task_id,
|
||||||
|
task_data=task_data,
|
||||||
|
callback=callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
heapq.heappush(self._queue, task)
|
||||||
|
self._metrics.total_enqueued += 1
|
||||||
|
self._metrics.current_size = len(self._queue)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Enqueued task %s with criticality %s (priority=%d, queue_size=%d)",
|
||||||
|
task_id,
|
||||||
|
criticality.value,
|
||||||
|
priority,
|
||||||
|
len(self._queue),
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def dequeue(self, timeout: float | None = None) -> PriorityTask | None:
|
||||||
|
"""Remove and return the highest priority task from the queue.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: Maximum time to wait for a task (seconds)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PriorityTask if available, None if queue is empty or timeout
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
deadline = start_time + timeout if timeout else None
|
||||||
|
|
||||||
|
while True:
|
||||||
|
async with self._lock:
|
||||||
|
if self._queue:
|
||||||
|
task = heapq.heappop(self._queue)
|
||||||
|
self._metrics.total_dequeued += 1
|
||||||
|
self._metrics.current_size = len(self._queue)
|
||||||
|
|
||||||
|
# Calculate wait time
|
||||||
|
wait_time = time.time() - task.timestamp
|
||||||
|
criticality = self._get_criticality_from_priority(task.priority)
|
||||||
|
self._wait_times[criticality].append(wait_time)
|
||||||
|
self._update_wait_time_metrics()
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Dequeued task %s (priority=%d, wait_time=%.2fs, queue_size=%d)",
|
||||||
|
task.task_id,
|
||||||
|
task.priority,
|
||||||
|
wait_time,
|
||||||
|
len(self._queue),
|
||||||
|
)
|
||||||
|
|
||||||
|
return task
|
||||||
|
|
||||||
|
# Queue is empty
|
||||||
|
if deadline and time.time() >= deadline:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Wait a bit before checking again
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
async def execute_with_timeout(
|
||||||
|
self,
|
||||||
|
task: PriorityTask,
|
||||||
|
timeout: float | None,
|
||||||
|
) -> Any:
|
||||||
|
"""Execute a task with timeout enforcement.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: Task to execute
|
||||||
|
timeout: Timeout in seconds (None = no timeout)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Result from task callback
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
asyncio.TimeoutError: If task exceeds timeout
|
||||||
|
Exception: Any exception raised by the task callback
|
||||||
|
"""
|
||||||
|
if not task.callback:
|
||||||
|
logger.warning("Task %s has no callback, skipping execution", task.task_id)
|
||||||
|
return None
|
||||||
|
|
||||||
|
criticality = self._get_criticality_from_priority(task.priority)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if timeout:
|
||||||
|
result = await asyncio.wait_for(task.callback(), timeout=timeout)
|
||||||
|
else:
|
||||||
|
result = await task.callback()
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Task %s completed successfully (criticality=%s)",
|
||||||
|
task.task_id,
|
||||||
|
criticality.value,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
except TimeoutError:
|
||||||
|
self._metrics.total_timeouts += 1
|
||||||
|
logger.error(
|
||||||
|
"Task %s timed out after %.2fs (criticality=%s)",
|
||||||
|
task.task_id,
|
||||||
|
timeout or 0.0,
|
||||||
|
criticality.value,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
self._metrics.total_errors += 1
|
||||||
|
logger.exception(
|
||||||
|
"Task %s failed with error (criticality=%s): %s",
|
||||||
|
task.task_id,
|
||||||
|
criticality.value,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _get_criticality_from_priority(self, priority: int) -> CriticalityLevel:
|
||||||
|
"""Convert priority back to criticality level."""
|
||||||
|
for level, prio in self.PRIORITY_MAP.items():
|
||||||
|
if prio == priority:
|
||||||
|
return level
|
||||||
|
return CriticalityLevel.NORMAL
|
||||||
|
|
||||||
|
def _update_wait_time_metrics(self) -> None:
|
||||||
|
"""Update average and p95 wait time metrics."""
|
||||||
|
for level, times in self._wait_times.items():
|
||||||
|
if not times:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Keep only last 1000 measurements to avoid memory bloat
|
||||||
|
if len(times) > 1000:
|
||||||
|
self._wait_times[level] = times[-1000:]
|
||||||
|
times = self._wait_times[level]
|
||||||
|
|
||||||
|
# Calculate average
|
||||||
|
self._metrics.avg_wait_time[level] = sum(times) / len(times)
|
||||||
|
|
||||||
|
# Calculate P95
|
||||||
|
sorted_times = sorted(times)
|
||||||
|
p95_idx = int(len(sorted_times) * 0.95)
|
||||||
|
self._metrics.p95_wait_time[level] = sorted_times[p95_idx]
|
||||||
|
|
||||||
|
async def get_metrics(self) -> QueueMetrics:
|
||||||
|
"""Get current queue metrics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
QueueMetrics with current statistics
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
return QueueMetrics(
|
||||||
|
total_enqueued=self._metrics.total_enqueued,
|
||||||
|
total_dequeued=self._metrics.total_dequeued,
|
||||||
|
total_timeouts=self._metrics.total_timeouts,
|
||||||
|
total_errors=self._metrics.total_errors,
|
||||||
|
current_size=self._metrics.current_size,
|
||||||
|
avg_wait_time=dict(self._metrics.avg_wait_time),
|
||||||
|
p95_wait_time=dict(self._metrics.p95_wait_time),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def size(self) -> int:
|
||||||
|
"""Get current queue size.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of tasks in queue
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
return len(self._queue)
|
||||||
|
|
||||||
|
async def clear(self) -> int:
|
||||||
|
"""Clear all tasks from the queue.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of tasks cleared
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
count = len(self._queue)
|
||||||
|
self._queue.clear()
|
||||||
|
self._metrics.current_size = 0
|
||||||
|
logger.info("Cleared %d tasks from priority queue", count)
|
||||||
|
return count
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
"""Evolution engine for self-improving trading strategies."""
|
||||||
|
|
||||||
|
from src.evolution.ab_test import ABTester, ABTestResult, StrategyPerformance
|
||||||
|
from src.evolution.optimizer import EvolutionOptimizer
|
||||||
|
from src.evolution.performance_tracker import (
|
||||||
|
PerformanceDashboard,
|
||||||
|
PerformanceTracker,
|
||||||
|
StrategyMetrics,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"EvolutionOptimizer",
|
||||||
|
"ABTester",
|
||||||
|
"ABTestResult",
|
||||||
|
"StrategyPerformance",
|
||||||
|
"PerformanceTracker",
|
||||||
|
"PerformanceDashboard",
|
||||||
|
"StrategyMetrics",
|
||||||
|
]
|
||||||
|
|||||||
220
src/evolution/ab_test.py
Normal file
220
src/evolution/ab_test.py
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
"""A/B Testing framework for strategy comparison.
|
||||||
|
|
||||||
|
Runs multiple strategies in parallel, tracks their performance,
|
||||||
|
and uses statistical significance testing to determine winners.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import scipy.stats as stats
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StrategyPerformance:
|
||||||
|
"""Performance metrics for a single strategy."""
|
||||||
|
|
||||||
|
strategy_name: str
|
||||||
|
total_trades: int
|
||||||
|
wins: int
|
||||||
|
losses: int
|
||||||
|
total_pnl: float
|
||||||
|
avg_pnl: float
|
||||||
|
win_rate: float
|
||||||
|
sharpe_ratio: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ABTestResult:
|
||||||
|
"""Result of an A/B test between two strategies."""
|
||||||
|
|
||||||
|
strategy_a: str
|
||||||
|
strategy_b: str
|
||||||
|
winner: str | None
|
||||||
|
p_value: float
|
||||||
|
confidence_level: float
|
||||||
|
is_significant: bool
|
||||||
|
performance_a: StrategyPerformance
|
||||||
|
performance_b: StrategyPerformance
|
||||||
|
|
||||||
|
|
||||||
|
class ABTester:
|
||||||
|
"""A/B testing framework for comparing trading strategies."""
|
||||||
|
|
||||||
|
def __init__(self, significance_level: float = 0.05) -> None:
|
||||||
|
"""Initialize A/B tester.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
significance_level: P-value threshold for statistical significance (default 0.05)
|
||||||
|
"""
|
||||||
|
self._significance_level = significance_level
|
||||||
|
|
||||||
|
def calculate_performance(
|
||||||
|
self, trades: list[dict[str, Any]], strategy_name: str
|
||||||
|
) -> StrategyPerformance:
|
||||||
|
"""Calculate performance metrics for a strategy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trades: List of trade records with pnl values
|
||||||
|
strategy_name: Name of the strategy
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StrategyPerformance object with calculated metrics
|
||||||
|
"""
|
||||||
|
if not trades:
|
||||||
|
return StrategyPerformance(
|
||||||
|
strategy_name=strategy_name,
|
||||||
|
total_trades=0,
|
||||||
|
wins=0,
|
||||||
|
losses=0,
|
||||||
|
total_pnl=0.0,
|
||||||
|
avg_pnl=0.0,
|
||||||
|
win_rate=0.0,
|
||||||
|
sharpe_ratio=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
total_trades = len(trades)
|
||||||
|
wins = sum(1 for t in trades if t.get("pnl", 0) > 0)
|
||||||
|
losses = sum(1 for t in trades if t.get("pnl", 0) < 0)
|
||||||
|
pnls = [t.get("pnl", 0.0) for t in trades]
|
||||||
|
total_pnl = sum(pnls)
|
||||||
|
avg_pnl = total_pnl / total_trades if total_trades > 0 else 0.0
|
||||||
|
win_rate = (wins / total_trades * 100) if total_trades > 0 else 0.0
|
||||||
|
|
||||||
|
# Calculate Sharpe ratio (risk-adjusted return)
|
||||||
|
sharpe_ratio = None
|
||||||
|
if len(pnls) > 1:
|
||||||
|
mean_return = avg_pnl
|
||||||
|
std_return = (
|
||||||
|
sum((p - mean_return) ** 2 for p in pnls) / (len(pnls) - 1)
|
||||||
|
) ** 0.5
|
||||||
|
if std_return > 0:
|
||||||
|
sharpe_ratio = mean_return / std_return
|
||||||
|
|
||||||
|
return StrategyPerformance(
|
||||||
|
strategy_name=strategy_name,
|
||||||
|
total_trades=total_trades,
|
||||||
|
wins=wins,
|
||||||
|
losses=losses,
|
||||||
|
total_pnl=round(total_pnl, 2),
|
||||||
|
avg_pnl=round(avg_pnl, 2),
|
||||||
|
win_rate=round(win_rate, 2),
|
||||||
|
sharpe_ratio=round(sharpe_ratio, 4) if sharpe_ratio else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def compare_strategies(
|
||||||
|
self,
|
||||||
|
trades_a: list[dict[str, Any]],
|
||||||
|
trades_b: list[dict[str, Any]],
|
||||||
|
strategy_a_name: str = "Strategy A",
|
||||||
|
strategy_b_name: str = "Strategy B",
|
||||||
|
) -> ABTestResult:
|
||||||
|
"""Compare two strategies using statistical testing.
|
||||||
|
|
||||||
|
Uses a two-sample t-test to determine if performance difference is significant.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trades_a: List of trades from strategy A
|
||||||
|
trades_b: List of trades from strategy B
|
||||||
|
strategy_a_name: Name of strategy A
|
||||||
|
strategy_b_name: Name of strategy B
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ABTestResult with comparison details
|
||||||
|
"""
|
||||||
|
perf_a = self.calculate_performance(trades_a, strategy_a_name)
|
||||||
|
perf_b = self.calculate_performance(trades_b, strategy_b_name)
|
||||||
|
|
||||||
|
# Extract PnL arrays for statistical testing
|
||||||
|
pnls_a = [t.get("pnl", 0.0) for t in trades_a]
|
||||||
|
pnls_b = [t.get("pnl", 0.0) for t in trades_b]
|
||||||
|
|
||||||
|
# Perform two-sample t-test
|
||||||
|
if len(pnls_a) > 1 and len(pnls_b) > 1:
|
||||||
|
t_stat, p_value = stats.ttest_ind(pnls_a, pnls_b, equal_var=False)
|
||||||
|
is_significant = p_value < self._significance_level
|
||||||
|
confidence_level = (1 - p_value) * 100
|
||||||
|
else:
|
||||||
|
# Not enough data for statistical test
|
||||||
|
p_value = 1.0
|
||||||
|
is_significant = False
|
||||||
|
confidence_level = 0.0
|
||||||
|
|
||||||
|
# Determine winner based on average PnL
|
||||||
|
winner = None
|
||||||
|
if is_significant:
|
||||||
|
if perf_a.avg_pnl > perf_b.avg_pnl:
|
||||||
|
winner = strategy_a_name
|
||||||
|
elif perf_b.avg_pnl > perf_a.avg_pnl:
|
||||||
|
winner = strategy_b_name
|
||||||
|
|
||||||
|
return ABTestResult(
|
||||||
|
strategy_a=strategy_a_name,
|
||||||
|
strategy_b=strategy_b_name,
|
||||||
|
winner=winner,
|
||||||
|
p_value=round(p_value, 4),
|
||||||
|
confidence_level=round(confidence_level, 2),
|
||||||
|
is_significant=is_significant,
|
||||||
|
performance_a=perf_a,
|
||||||
|
performance_b=perf_b,
|
||||||
|
)
|
||||||
|
|
||||||
|
def should_deploy(
|
||||||
|
self,
|
||||||
|
result: ABTestResult,
|
||||||
|
min_win_rate: float = 60.0,
|
||||||
|
min_trades: int = 20,
|
||||||
|
) -> bool:
|
||||||
|
"""Determine if a winning strategy should be deployed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: A/B test result
|
||||||
|
min_win_rate: Minimum win rate percentage for deployment (default 60%)
|
||||||
|
min_trades: Minimum number of trades required (default 20)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the winning strategy meets deployment criteria
|
||||||
|
"""
|
||||||
|
if not result.is_significant or result.winner is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Get performance of winning strategy
|
||||||
|
if result.winner == result.strategy_a:
|
||||||
|
winning_perf = result.performance_a
|
||||||
|
else:
|
||||||
|
winning_perf = result.performance_b
|
||||||
|
|
||||||
|
# Check deployment criteria
|
||||||
|
has_enough_trades = winning_perf.total_trades >= min_trades
|
||||||
|
has_good_win_rate = winning_perf.win_rate >= min_win_rate
|
||||||
|
is_profitable = winning_perf.avg_pnl > 0
|
||||||
|
|
||||||
|
meets_criteria = has_enough_trades and has_good_win_rate and is_profitable
|
||||||
|
|
||||||
|
if meets_criteria:
|
||||||
|
logger.info(
|
||||||
|
"Strategy '%s' meets deployment criteria: "
|
||||||
|
"win_rate=%.2f%%, trades=%d, avg_pnl=%.2f",
|
||||||
|
result.winner,
|
||||||
|
winning_perf.win_rate,
|
||||||
|
winning_perf.total_trades,
|
||||||
|
winning_perf.avg_pnl,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Strategy '%s' does NOT meet deployment criteria: "
|
||||||
|
"win_rate=%.2f%% (min %.2f%%), trades=%d (min %d), avg_pnl=%.2f",
|
||||||
|
result.winner if result.winner else "unknown",
|
||||||
|
winning_perf.win_rate if result.winner else 0.0,
|
||||||
|
min_win_rate,
|
||||||
|
winning_perf.total_trades if result.winner else 0,
|
||||||
|
min_trades,
|
||||||
|
winning_perf.avg_pnl if result.winner else 0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
return meets_criteria
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
"""Evolution Engine — analyzes trade logs and generates new strategies.
|
"""Evolution Engine — analyzes trade logs and generates new strategies.
|
||||||
|
|
||||||
This module:
|
This module:
|
||||||
1. Reads trade_logs.db to identify failing patterns
|
1. Uses DecisionLogger.get_losing_decisions() to identify failing patterns
|
||||||
2. Asks Gemini to generate a new strategy class
|
2. Analyzes failure patterns by time, market conditions, stock characteristics
|
||||||
3. Runs pytest on the generated file
|
3. Asks Gemini to generate improved strategy recommendations
|
||||||
4. Creates a simulated PR if tests pass
|
4. Generates new strategy classes with enhanced decision-making logic
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -14,6 +14,7 @@ import logging
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
import subprocess
|
import subprocess
|
||||||
import textwrap
|
import textwrap
|
||||||
|
from collections import Counter
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -21,6 +22,8 @@ from typing import Any
|
|||||||
from google import genai
|
from google import genai
|
||||||
|
|
||||||
from src.config import Settings
|
from src.config import Settings
|
||||||
|
from src.db import init_db
|
||||||
|
from src.logging.decision_logger import DecisionLogger
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -53,29 +56,105 @@ class EvolutionOptimizer:
|
|||||||
self._db_path = settings.DB_PATH
|
self._db_path = settings.DB_PATH
|
||||||
self._client = genai.Client(api_key=settings.GEMINI_API_KEY)
|
self._client = genai.Client(api_key=settings.GEMINI_API_KEY)
|
||||||
self._model_name = settings.GEMINI_MODEL
|
self._model_name = settings.GEMINI_MODEL
|
||||||
|
self._conn = init_db(self._db_path)
|
||||||
|
self._decision_logger = DecisionLogger(self._conn)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Analysis
|
# Analysis
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def analyze_failures(self, limit: int = 50) -> list[dict[str, Any]]:
|
def analyze_failures(self, limit: int = 50) -> list[dict[str, Any]]:
|
||||||
"""Find trades where high confidence led to losses."""
|
"""Find high-confidence decisions that resulted in losses.
|
||||||
conn = sqlite3.connect(self._db_path)
|
|
||||||
conn.row_factory = sqlite3.Row
|
Uses DecisionLogger.get_losing_decisions() to retrieve failures.
|
||||||
try:
|
|
||||||
rows = conn.execute(
|
|
||||||
"""
|
"""
|
||||||
SELECT stock_code, action, confidence, pnl, rationale, timestamp
|
losing_decisions = self._decision_logger.get_losing_decisions(
|
||||||
FROM trades
|
min_confidence=80, min_loss=-100.0
|
||||||
WHERE confidence >= 80 AND pnl < 0
|
)
|
||||||
ORDER BY pnl ASC
|
|
||||||
LIMIT ?
|
# Limit results
|
||||||
""",
|
if len(losing_decisions) > limit:
|
||||||
(limit,),
|
losing_decisions = losing_decisions[:limit]
|
||||||
).fetchall()
|
|
||||||
return [dict(r) for r in rows]
|
# Convert to dict format for analysis
|
||||||
finally:
|
failures = []
|
||||||
conn.close()
|
for decision in losing_decisions:
|
||||||
|
failures.append({
|
||||||
|
"decision_id": decision.decision_id,
|
||||||
|
"timestamp": decision.timestamp,
|
||||||
|
"stock_code": decision.stock_code,
|
||||||
|
"market": decision.market,
|
||||||
|
"exchange_code": decision.exchange_code,
|
||||||
|
"action": decision.action,
|
||||||
|
"confidence": decision.confidence,
|
||||||
|
"rationale": decision.rationale,
|
||||||
|
"outcome_pnl": decision.outcome_pnl,
|
||||||
|
"outcome_accuracy": decision.outcome_accuracy,
|
||||||
|
"context_snapshot": decision.context_snapshot,
|
||||||
|
"input_data": decision.input_data,
|
||||||
|
})
|
||||||
|
|
||||||
|
return failures
|
||||||
|
|
||||||
|
def identify_failure_patterns(
|
||||||
|
self, failures: list[dict[str, Any]]
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Identify patterns in losing decisions.
|
||||||
|
|
||||||
|
Analyzes:
|
||||||
|
- Time patterns (hour of day, day of week)
|
||||||
|
- Market conditions (volatility, volume)
|
||||||
|
- Stock characteristics (price range, market)
|
||||||
|
- Common failure modes in rationale
|
||||||
|
"""
|
||||||
|
if not failures:
|
||||||
|
return {"pattern_count": 0, "patterns": {}}
|
||||||
|
|
||||||
|
patterns = {
|
||||||
|
"markets": Counter(),
|
||||||
|
"actions": Counter(),
|
||||||
|
"hours": Counter(),
|
||||||
|
"avg_confidence": 0.0,
|
||||||
|
"avg_loss": 0.0,
|
||||||
|
"total_failures": len(failures),
|
||||||
|
}
|
||||||
|
|
||||||
|
total_confidence = 0
|
||||||
|
total_loss = 0.0
|
||||||
|
|
||||||
|
for failure in failures:
|
||||||
|
# Market distribution
|
||||||
|
patterns["markets"][failure.get("market", "UNKNOWN")] += 1
|
||||||
|
|
||||||
|
# Action distribution
|
||||||
|
patterns["actions"][failure.get("action", "UNKNOWN")] += 1
|
||||||
|
|
||||||
|
# Time pattern (extract hour from ISO timestamp)
|
||||||
|
timestamp = failure.get("timestamp", "")
|
||||||
|
if timestamp:
|
||||||
|
try:
|
||||||
|
dt = datetime.fromisoformat(timestamp)
|
||||||
|
patterns["hours"][dt.hour] += 1
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Aggregate metrics
|
||||||
|
total_confidence += failure.get("confidence", 0)
|
||||||
|
total_loss += failure.get("outcome_pnl", 0.0)
|
||||||
|
|
||||||
|
patterns["avg_confidence"] = (
|
||||||
|
round(total_confidence / len(failures), 2) if failures else 0.0
|
||||||
|
)
|
||||||
|
patterns["avg_loss"] = (
|
||||||
|
round(total_loss / len(failures), 2) if failures else 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert Counters to regular dicts for JSON serialization
|
||||||
|
patterns["markets"] = dict(patterns["markets"])
|
||||||
|
patterns["actions"] = dict(patterns["actions"])
|
||||||
|
patterns["hours"] = dict(patterns["hours"])
|
||||||
|
|
||||||
|
return patterns
|
||||||
|
|
||||||
def get_performance_summary(self) -> dict[str, Any]:
|
def get_performance_summary(self) -> dict[str, Any]:
|
||||||
"""Return aggregate performance metrics from trade logs."""
|
"""Return aggregate performance metrics from trade logs."""
|
||||||
@@ -109,14 +188,25 @@ class EvolutionOptimizer:
|
|||||||
async def generate_strategy(self, failures: list[dict[str, Any]]) -> Path | None:
|
async def generate_strategy(self, failures: list[dict[str, Any]]) -> Path | None:
|
||||||
"""Ask Gemini to generate a new strategy based on failure analysis.
|
"""Ask Gemini to generate a new strategy based on failure analysis.
|
||||||
|
|
||||||
|
Integrates failure patterns and market conditions to create improved strategies.
|
||||||
Returns the path to the generated strategy file, or None on failure.
|
Returns the path to the generated strategy file, or None on failure.
|
||||||
"""
|
"""
|
||||||
|
# Identify failure patterns first
|
||||||
|
patterns = self.identify_failure_patterns(failures)
|
||||||
|
|
||||||
prompt = (
|
prompt = (
|
||||||
"You are a quantitative trading strategy developer.\n"
|
"You are a quantitative trading strategy developer.\n"
|
||||||
"Analyze these failed trades and generate an improved strategy.\n\n"
|
"Analyze these failed trades and their patterns, then generate an improved strategy.\n\n"
|
||||||
f"Failed trades:\n{json.dumps(failures, indent=2, default=str)}\n\n"
|
f"Failure Patterns:\n{json.dumps(patterns, indent=2)}\n\n"
|
||||||
"Generate a Python class that inherits from BaseStrategy.\n"
|
f"Sample Failed Trades (first 5):\n"
|
||||||
"The class must have an `evaluate(self, market_data: dict) -> dict` method.\n"
|
f"{json.dumps(failures[:5], indent=2, default=str)}\n\n"
|
||||||
|
"Based on these patterns, generate an improved trading strategy.\n"
|
||||||
|
"The strategy should:\n"
|
||||||
|
"1. Avoid the identified failure patterns\n"
|
||||||
|
"2. Consider market-specific conditions\n"
|
||||||
|
"3. Adjust confidence based on historical performance\n\n"
|
||||||
|
"Generate a Python method body that inherits from BaseStrategy.\n"
|
||||||
|
"The method signature is: evaluate(self, market_data: dict) -> dict\n"
|
||||||
"The method must return a dict with keys: action, confidence, rationale.\n"
|
"The method must return a dict with keys: action, confidence, rationale.\n"
|
||||||
"Respond with ONLY the method body (Python code), no class definition.\n"
|
"Respond with ONLY the method body (Python code), no class definition.\n"
|
||||||
)
|
)
|
||||||
@@ -147,10 +237,15 @@ class EvolutionOptimizer:
|
|||||||
# Indent the body for the class method
|
# Indent the body for the class method
|
||||||
indented_body = textwrap.indent(body, " ")
|
indented_body = textwrap.indent(body, " ")
|
||||||
|
|
||||||
|
# Generate rationale from patterns
|
||||||
|
rationale = f"Auto-evolved from {len(failures)} failures. "
|
||||||
|
rationale += f"Primary failure markets: {list(patterns.get('markets', {}).keys())}. "
|
||||||
|
rationale += f"Average loss: {patterns.get('avg_loss', 0.0)}"
|
||||||
|
|
||||||
content = STRATEGY_TEMPLATE.format(
|
content = STRATEGY_TEMPLATE.format(
|
||||||
name=version,
|
name=version,
|
||||||
timestamp=datetime.now(UTC).isoformat(),
|
timestamp=datetime.now(UTC).isoformat(),
|
||||||
rationale="Auto-evolved from failure analysis",
|
rationale=rationale,
|
||||||
class_name=class_name,
|
class_name=class_name,
|
||||||
body=indented_body.strip(),
|
body=indented_body.strip(),
|
||||||
)
|
)
|
||||||
|
|||||||
303
src/evolution/performance_tracker.py
Normal file
303
src/evolution/performance_tracker.py
Normal file
@@ -0,0 +1,303 @@
|
|||||||
|
"""Performance tracking system for strategy monitoring.
|
||||||
|
|
||||||
|
Tracks win rates, monitors improvement over time,
|
||||||
|
and provides performance metrics dashboard.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sqlite3
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StrategyMetrics:
|
||||||
|
"""Performance metrics for a strategy over a time period."""
|
||||||
|
|
||||||
|
strategy_name: str
|
||||||
|
period_start: str
|
||||||
|
period_end: str
|
||||||
|
total_trades: int
|
||||||
|
wins: int
|
||||||
|
losses: int
|
||||||
|
holds: int
|
||||||
|
win_rate: float
|
||||||
|
avg_pnl: float
|
||||||
|
total_pnl: float
|
||||||
|
best_trade: float
|
||||||
|
worst_trade: float
|
||||||
|
avg_confidence: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PerformanceDashboard:
|
||||||
|
"""Comprehensive performance dashboard."""
|
||||||
|
|
||||||
|
generated_at: str
|
||||||
|
overall_metrics: StrategyMetrics
|
||||||
|
daily_metrics: list[StrategyMetrics]
|
||||||
|
weekly_metrics: list[StrategyMetrics]
|
||||||
|
improvement_trend: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class PerformanceTracker:
|
||||||
|
"""Tracks and monitors strategy performance over time."""
|
||||||
|
|
||||||
|
def __init__(self, db_path: str) -> None:
|
||||||
|
"""Initialize performance tracker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: Path to the trade logs database
|
||||||
|
"""
|
||||||
|
self._db_path = db_path
|
||||||
|
|
||||||
|
def get_strategy_metrics(
|
||||||
|
self,
|
||||||
|
strategy_name: str | None = None,
|
||||||
|
start_date: str | None = None,
|
||||||
|
end_date: str | None = None,
|
||||||
|
) -> StrategyMetrics:
|
||||||
|
"""Get performance metrics for a strategy over a time period.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
strategy_name: Name of the strategy (None = all strategies)
|
||||||
|
start_date: Start date in ISO format (None = beginning of time)
|
||||||
|
end_date: End date in ISO format (None = now)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StrategyMetrics object with performance data
|
||||||
|
"""
|
||||||
|
conn = sqlite3.connect(self._db_path)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Build query with optional filters
|
||||||
|
query = """
|
||||||
|
SELECT
|
||||||
|
COUNT(*) as total_trades,
|
||||||
|
SUM(CASE WHEN pnl > 0 THEN 1 ELSE 0 END) as wins,
|
||||||
|
SUM(CASE WHEN pnl < 0 THEN 1 ELSE 0 END) as losses,
|
||||||
|
SUM(CASE WHEN action = 'HOLD' THEN 1 ELSE 0 END) as holds,
|
||||||
|
COALESCE(AVG(CASE WHEN pnl IS NOT NULL THEN pnl END), 0) as avg_pnl,
|
||||||
|
COALESCE(SUM(CASE WHEN pnl IS NOT NULL THEN pnl ELSE 0 END), 0) as total_pnl,
|
||||||
|
COALESCE(MAX(pnl), 0) as best_trade,
|
||||||
|
COALESCE(MIN(pnl), 0) as worst_trade,
|
||||||
|
COALESCE(AVG(confidence), 0) as avg_confidence,
|
||||||
|
MIN(timestamp) as period_start,
|
||||||
|
MAX(timestamp) as period_end
|
||||||
|
FROM trades
|
||||||
|
WHERE 1=1
|
||||||
|
"""
|
||||||
|
params: list[Any] = []
|
||||||
|
|
||||||
|
if start_date:
|
||||||
|
query += " AND timestamp >= ?"
|
||||||
|
params.append(start_date)
|
||||||
|
|
||||||
|
if end_date:
|
||||||
|
query += " AND timestamp <= ?"
|
||||||
|
params.append(end_date)
|
||||||
|
|
||||||
|
# Note: Currently trades table doesn't have strategy_name column
|
||||||
|
# This is a placeholder for future extension
|
||||||
|
|
||||||
|
row = conn.execute(query, params).fetchone()
|
||||||
|
|
||||||
|
total_trades = row["total_trades"] or 0
|
||||||
|
wins = row["wins"] or 0
|
||||||
|
win_rate = (wins / total_trades * 100) if total_trades > 0 else 0.0
|
||||||
|
|
||||||
|
return StrategyMetrics(
|
||||||
|
strategy_name=strategy_name or "default",
|
||||||
|
period_start=row["period_start"] or "",
|
||||||
|
period_end=row["period_end"] or "",
|
||||||
|
total_trades=total_trades,
|
||||||
|
wins=wins,
|
||||||
|
losses=row["losses"] or 0,
|
||||||
|
holds=row["holds"] or 0,
|
||||||
|
win_rate=round(win_rate, 2),
|
||||||
|
avg_pnl=round(row["avg_pnl"], 2),
|
||||||
|
total_pnl=round(row["total_pnl"], 2),
|
||||||
|
best_trade=round(row["best_trade"], 2),
|
||||||
|
worst_trade=round(row["worst_trade"], 2),
|
||||||
|
avg_confidence=round(row["avg_confidence"], 2),
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def get_daily_metrics(
|
||||||
|
self, days: int = 7, strategy_name: str | None = None
|
||||||
|
) -> list[StrategyMetrics]:
|
||||||
|
"""Get daily performance metrics for the last N days.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
days: Number of days to retrieve (default 7)
|
||||||
|
strategy_name: Name of the strategy (None = all strategies)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of StrategyMetrics, one per day
|
||||||
|
"""
|
||||||
|
metrics = []
|
||||||
|
end_date = datetime.now(UTC)
|
||||||
|
|
||||||
|
for i in range(days):
|
||||||
|
day_end = end_date - timedelta(days=i)
|
||||||
|
day_start = day_end - timedelta(days=1)
|
||||||
|
|
||||||
|
day_metrics = self.get_strategy_metrics(
|
||||||
|
strategy_name=strategy_name,
|
||||||
|
start_date=day_start.isoformat(),
|
||||||
|
end_date=day_end.isoformat(),
|
||||||
|
)
|
||||||
|
metrics.append(day_metrics)
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
def get_weekly_metrics(
|
||||||
|
self, weeks: int = 4, strategy_name: str | None = None
|
||||||
|
) -> list[StrategyMetrics]:
|
||||||
|
"""Get weekly performance metrics for the last N weeks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weeks: Number of weeks to retrieve (default 4)
|
||||||
|
strategy_name: Name of the strategy (None = all strategies)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of StrategyMetrics, one per week
|
||||||
|
"""
|
||||||
|
metrics = []
|
||||||
|
end_date = datetime.now(UTC)
|
||||||
|
|
||||||
|
for i in range(weeks):
|
||||||
|
week_end = end_date - timedelta(weeks=i)
|
||||||
|
week_start = week_end - timedelta(weeks=1)
|
||||||
|
|
||||||
|
week_metrics = self.get_strategy_metrics(
|
||||||
|
strategy_name=strategy_name,
|
||||||
|
start_date=week_start.isoformat(),
|
||||||
|
end_date=week_end.isoformat(),
|
||||||
|
)
|
||||||
|
metrics.append(week_metrics)
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
def calculate_improvement_trend(
|
||||||
|
self, metrics_history: list[StrategyMetrics]
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Calculate improvement trend from historical metrics.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metrics_history: List of StrategyMetrics ordered from oldest to newest
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with trend analysis
|
||||||
|
"""
|
||||||
|
if len(metrics_history) < 2:
|
||||||
|
return {
|
||||||
|
"trend": "insufficient_data",
|
||||||
|
"win_rate_change": 0.0,
|
||||||
|
"pnl_change": 0.0,
|
||||||
|
"confidence_change": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
oldest = metrics_history[0]
|
||||||
|
newest = metrics_history[-1]
|
||||||
|
|
||||||
|
win_rate_change = newest.win_rate - oldest.win_rate
|
||||||
|
pnl_change = newest.avg_pnl - oldest.avg_pnl
|
||||||
|
confidence_change = newest.avg_confidence - oldest.avg_confidence
|
||||||
|
|
||||||
|
# Determine overall trend
|
||||||
|
if win_rate_change > 5.0 and pnl_change > 0:
|
||||||
|
trend = "improving"
|
||||||
|
elif win_rate_change < -5.0 or pnl_change < 0:
|
||||||
|
trend = "declining"
|
||||||
|
else:
|
||||||
|
trend = "stable"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"trend": trend,
|
||||||
|
"win_rate_change": round(win_rate_change, 2),
|
||||||
|
"pnl_change": round(pnl_change, 2),
|
||||||
|
"confidence_change": round(confidence_change, 2),
|
||||||
|
"period_count": len(metrics_history),
|
||||||
|
}
|
||||||
|
|
||||||
|
def generate_dashboard(
|
||||||
|
self, strategy_name: str | None = None
|
||||||
|
) -> PerformanceDashboard:
|
||||||
|
"""Generate a comprehensive performance dashboard.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
strategy_name: Name of the strategy (None = all strategies)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PerformanceDashboard with all metrics
|
||||||
|
"""
|
||||||
|
# Get overall metrics
|
||||||
|
overall_metrics = self.get_strategy_metrics(strategy_name=strategy_name)
|
||||||
|
|
||||||
|
# Get daily metrics (last 7 days)
|
||||||
|
daily_metrics = self.get_daily_metrics(days=7, strategy_name=strategy_name)
|
||||||
|
|
||||||
|
# Get weekly metrics (last 4 weeks)
|
||||||
|
weekly_metrics = self.get_weekly_metrics(weeks=4, strategy_name=strategy_name)
|
||||||
|
|
||||||
|
# Calculate improvement trend
|
||||||
|
improvement_trend = self.calculate_improvement_trend(weekly_metrics[::-1])
|
||||||
|
|
||||||
|
return PerformanceDashboard(
|
||||||
|
generated_at=datetime.now(UTC).isoformat(),
|
||||||
|
overall_metrics=overall_metrics,
|
||||||
|
daily_metrics=daily_metrics,
|
||||||
|
weekly_metrics=weekly_metrics,
|
||||||
|
improvement_trend=improvement_trend,
|
||||||
|
)
|
||||||
|
|
||||||
|
def export_dashboard_json(
|
||||||
|
self, dashboard: PerformanceDashboard
|
||||||
|
) -> str:
|
||||||
|
"""Export dashboard as JSON string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dashboard: PerformanceDashboard object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON string representation
|
||||||
|
"""
|
||||||
|
data = {
|
||||||
|
"generated_at": dashboard.generated_at,
|
||||||
|
"overall_metrics": asdict(dashboard.overall_metrics),
|
||||||
|
"daily_metrics": [asdict(m) for m in dashboard.daily_metrics],
|
||||||
|
"weekly_metrics": [asdict(m) for m in dashboard.weekly_metrics],
|
||||||
|
"improvement_trend": dashboard.improvement_trend,
|
||||||
|
}
|
||||||
|
return json.dumps(data, indent=2)
|
||||||
|
|
||||||
|
def log_dashboard(self, dashboard: PerformanceDashboard) -> None:
|
||||||
|
"""Log dashboard summary to logger.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dashboard: PerformanceDashboard object
|
||||||
|
"""
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info("PERFORMANCE DASHBOARD")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info("Generated: %s", dashboard.generated_at)
|
||||||
|
logger.info("")
|
||||||
|
logger.info("Overall Performance:")
|
||||||
|
logger.info(" Total Trades: %d", dashboard.overall_metrics.total_trades)
|
||||||
|
logger.info(" Win Rate: %.2f%%", dashboard.overall_metrics.win_rate)
|
||||||
|
logger.info(" Average P&L: %.2f", dashboard.overall_metrics.avg_pnl)
|
||||||
|
logger.info(" Total P&L: %.2f", dashboard.overall_metrics.total_pnl)
|
||||||
|
logger.info("")
|
||||||
|
logger.info("Improvement Trend (%s):", dashboard.improvement_trend["trend"])
|
||||||
|
logger.info(" Win Rate Change: %+.2f%%", dashboard.improvement_trend["win_rate_change"])
|
||||||
|
logger.info(" P&L Change: %+.2f", dashboard.improvement_trend["pnl_change"])
|
||||||
|
logger.info("=" * 60)
|
||||||
88
src/main.py
88
src/main.py
@@ -19,7 +19,10 @@ from src.brain.gemini_client import GeminiClient
|
|||||||
from src.broker.kis_api import KISBroker
|
from src.broker.kis_api import KISBroker
|
||||||
from src.broker.overseas import OverseasBroker
|
from src.broker.overseas import OverseasBroker
|
||||||
from src.config import Settings
|
from src.config import Settings
|
||||||
|
from src.context.layer import ContextLayer
|
||||||
from src.context.store import ContextStore
|
from src.context.store import ContextStore
|
||||||
|
from src.core.criticality import CriticalityAssessor
|
||||||
|
from src.core.priority_queue import PriorityTaskQueue
|
||||||
from src.core.risk_manager import CircuitBreakerTripped, RiskManager
|
from src.core.risk_manager import CircuitBreakerTripped, RiskManager
|
||||||
from src.db import init_db, log_trade
|
from src.db import init_db, log_trade
|
||||||
from src.logging.decision_logger import DecisionLogger
|
from src.logging.decision_logger import DecisionLogger
|
||||||
@@ -57,10 +60,14 @@ async def trading_cycle(
|
|||||||
risk: RiskManager,
|
risk: RiskManager,
|
||||||
db_conn: Any,
|
db_conn: Any,
|
||||||
decision_logger: DecisionLogger,
|
decision_logger: DecisionLogger,
|
||||||
|
context_store: ContextStore,
|
||||||
|
criticality_assessor: CriticalityAssessor,
|
||||||
market: MarketInfo,
|
market: MarketInfo,
|
||||||
stock_code: str,
|
stock_code: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute one trading cycle for a single stock."""
|
"""Execute one trading cycle for a single stock."""
|
||||||
|
cycle_start_time = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
# 1. Fetch market data
|
# 1. Fetch market data
|
||||||
if market.is_domestic:
|
if market.is_domestic:
|
||||||
orderbook = await broker.get_orderbook(stock_code)
|
orderbook = await broker.get_orderbook(stock_code)
|
||||||
@@ -106,6 +113,42 @@ async def trading_cycle(
|
|||||||
"foreigner_net": foreigner_net,
|
"foreigner_net": foreigner_net,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 1.5. Get volatility metrics from context store (L7_REALTIME)
|
||||||
|
latest_timeframe = context_store.get_latest_timeframe(ContextLayer.L7_REALTIME)
|
||||||
|
volatility_score = 50.0 # Default normal volatility
|
||||||
|
volume_surge = 1.0
|
||||||
|
price_change_1m = 0.0
|
||||||
|
|
||||||
|
if latest_timeframe:
|
||||||
|
volatility_data = context_store.get_context(
|
||||||
|
ContextLayer.L7_REALTIME,
|
||||||
|
latest_timeframe,
|
||||||
|
f"volatility_{stock_code}",
|
||||||
|
)
|
||||||
|
if volatility_data:
|
||||||
|
volatility_score = volatility_data.get("momentum_score", 50.0)
|
||||||
|
volume_surge = volatility_data.get("volume_surge", 1.0)
|
||||||
|
price_change_1m = volatility_data.get("price_change_1m", 0.0)
|
||||||
|
|
||||||
|
# 1.6. Assess criticality based on market conditions
|
||||||
|
criticality = criticality_assessor.assess_market_conditions(
|
||||||
|
pnl_pct=pnl_pct,
|
||||||
|
volatility_score=volatility_score,
|
||||||
|
volume_surge=volume_surge,
|
||||||
|
price_change_1m=price_change_1m,
|
||||||
|
is_market_open=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Criticality for %s (%s): %s (pnl=%.2f%%, volatility=%.1f, volume_surge=%.1fx)",
|
||||||
|
stock_code,
|
||||||
|
market.name,
|
||||||
|
criticality.value,
|
||||||
|
pnl_pct,
|
||||||
|
volatility_score,
|
||||||
|
volume_surge,
|
||||||
|
)
|
||||||
|
|
||||||
# 2. Ask the brain for a decision
|
# 2. Ask the brain for a decision
|
||||||
decision = await brain.decide(market_data)
|
decision = await brain.decide(market_data)
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -191,6 +234,27 @@ async def trading_cycle(
|
|||||||
exchange_code=market.exchange_code,
|
exchange_code=market.exchange_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 7. Latency monitoring
|
||||||
|
cycle_end_time = asyncio.get_event_loop().time()
|
||||||
|
cycle_latency = cycle_end_time - cycle_start_time
|
||||||
|
timeout = criticality_assessor.get_timeout(criticality)
|
||||||
|
|
||||||
|
if timeout and cycle_latency > timeout:
|
||||||
|
logger.warning(
|
||||||
|
"Trading cycle exceeded timeout for %s (criticality=%s, latency=%.2fs, timeout=%.2fs)",
|
||||||
|
stock_code,
|
||||||
|
criticality.value,
|
||||||
|
cycle_latency,
|
||||||
|
timeout,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
"Trading cycle completed within timeout for %s (criticality=%s, latency=%.2fs)",
|
||||||
|
stock_code,
|
||||||
|
criticality.value,
|
||||||
|
cycle_latency,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def run(settings: Settings) -> None:
|
async def run(settings: Settings) -> None:
|
||||||
"""Main async loop — iterate over open markets on a timer."""
|
"""Main async loop — iterate over open markets on a timer."""
|
||||||
@@ -212,6 +276,16 @@ async def run(settings: Settings) -> None:
|
|||||||
top_n=5,
|
top_n=5,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Initialize latency control system
|
||||||
|
criticality_assessor = CriticalityAssessor(
|
||||||
|
critical_pnl_threshold=-2.5, # Near circuit breaker at -3.0%
|
||||||
|
critical_price_change_threshold=5.0, # 5% in 1 minute
|
||||||
|
critical_volume_surge_threshold=10.0, # 10x average
|
||||||
|
high_volatility_threshold=70.0,
|
||||||
|
low_volatility_threshold=30.0,
|
||||||
|
)
|
||||||
|
priority_queue = PriorityTaskQueue(max_size=1000)
|
||||||
|
|
||||||
# Track last scan time for each market
|
# Track last scan time for each market
|
||||||
last_scan_time: dict[str, float] = {}
|
last_scan_time: dict[str, float] = {}
|
||||||
|
|
||||||
@@ -315,6 +389,8 @@ async def run(settings: Settings) -> None:
|
|||||||
risk,
|
risk,
|
||||||
db_conn,
|
db_conn,
|
||||||
decision_logger,
|
decision_logger,
|
||||||
|
context_store,
|
||||||
|
criticality_assessor,
|
||||||
market,
|
market,
|
||||||
stock_code,
|
stock_code,
|
||||||
)
|
)
|
||||||
@@ -343,6 +419,18 @@ async def run(settings: Settings) -> None:
|
|||||||
logger.exception("Unexpected error for %s: %s", stock_code, exc)
|
logger.exception("Unexpected error for %s: %s", stock_code, exc)
|
||||||
break # Don't retry on unexpected errors
|
break # Don't retry on unexpected errors
|
||||||
|
|
||||||
|
# Log priority queue metrics periodically
|
||||||
|
metrics = await priority_queue.get_metrics()
|
||||||
|
if metrics.total_enqueued > 0:
|
||||||
|
logger.info(
|
||||||
|
"Priority queue metrics: enqueued=%d, dequeued=%d, size=%d, timeouts=%d, errors=%d",
|
||||||
|
metrics.total_enqueued,
|
||||||
|
metrics.total_dequeued,
|
||||||
|
metrics.current_size,
|
||||||
|
metrics.total_timeouts,
|
||||||
|
metrics.total_errors,
|
||||||
|
)
|
||||||
|
|
||||||
# Wait for next cycle or shutdown
|
# Wait for next cycle or shutdown
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(shutdown.wait(), timeout=TRADE_INTERVAL_SECONDS)
|
await asyncio.wait_for(shutdown.wait(), timeout=TRADE_INTERVAL_SECONDS)
|
||||||
|
|||||||
685
tests/test_evolution.py
Normal file
685
tests/test_evolution.py
Normal file
@@ -0,0 +1,685 @@
|
|||||||
|
"""Tests for the Evolution Engine components.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- EvolutionOptimizer: failure analysis and strategy generation
|
||||||
|
- ABTester: A/B testing and statistical comparison
|
||||||
|
- PerformanceTracker: metrics tracking and dashboard
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sqlite3
|
||||||
|
import tempfile
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.config import Settings
|
||||||
|
from src.db import init_db, log_trade
|
||||||
|
from src.evolution.ab_test import ABTester
|
||||||
|
from src.evolution.optimizer import EvolutionOptimizer
|
||||||
|
from src.evolution.performance_tracker import (
|
||||||
|
PerformanceDashboard,
|
||||||
|
PerformanceTracker,
|
||||||
|
StrategyMetrics,
|
||||||
|
)
|
||||||
|
from src.logging.decision_logger import DecisionLogger
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def db_conn() -> sqlite3.Connection:
|
||||||
|
"""Provide an in-memory database with initialized schema."""
|
||||||
|
return init_db(":memory:")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def settings() -> Settings:
|
||||||
|
"""Provide test settings."""
|
||||||
|
return Settings(
|
||||||
|
KIS_APP_KEY="test_key",
|
||||||
|
KIS_APP_SECRET="test_secret",
|
||||||
|
KIS_ACCOUNT_NO="12345678-01",
|
||||||
|
GEMINI_API_KEY="test_gemini_key",
|
||||||
|
GEMINI_MODEL="gemini-pro",
|
||||||
|
DB_PATH=":memory:",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def optimizer(settings: Settings) -> EvolutionOptimizer:
|
||||||
|
"""Provide an EvolutionOptimizer instance."""
|
||||||
|
return EvolutionOptimizer(settings)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def decision_logger(db_conn: sqlite3.Connection) -> DecisionLogger:
|
||||||
|
"""Provide a DecisionLogger instance."""
|
||||||
|
return DecisionLogger(db_conn)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def ab_tester() -> ABTester:
|
||||||
|
"""Provide an ABTester instance."""
|
||||||
|
return ABTester(significance_level=0.05)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def performance_tracker(settings: Settings) -> PerformanceTracker:
|
||||||
|
"""Provide a PerformanceTracker instance."""
|
||||||
|
return PerformanceTracker(db_path=":memory:")
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# EvolutionOptimizer Tests
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_analyze_failures_uses_decision_logger(optimizer: EvolutionOptimizer) -> None:
|
||||||
|
"""Test that analyze_failures uses DecisionLogger.get_losing_decisions()."""
|
||||||
|
# Add some losing decisions to the database
|
||||||
|
logger = optimizer._decision_logger
|
||||||
|
|
||||||
|
# High-confidence loss
|
||||||
|
id1 = logger.log_decision(
|
||||||
|
stock_code="005930",
|
||||||
|
market="KR",
|
||||||
|
exchange_code="KRX",
|
||||||
|
action="BUY",
|
||||||
|
confidence=85,
|
||||||
|
rationale="Expected growth",
|
||||||
|
context_snapshot={"L1": {"price": 70000}},
|
||||||
|
input_data={"price": 70000, "volume": 1000},
|
||||||
|
)
|
||||||
|
logger.update_outcome(id1, pnl=-2000.0, accuracy=0)
|
||||||
|
|
||||||
|
# Another high-confidence loss
|
||||||
|
id2 = logger.log_decision(
|
||||||
|
stock_code="000660",
|
||||||
|
market="KR",
|
||||||
|
exchange_code="KRX",
|
||||||
|
action="SELL",
|
||||||
|
confidence=90,
|
||||||
|
rationale="Expected drop",
|
||||||
|
context_snapshot={"L1": {"price": 100000}},
|
||||||
|
input_data={"price": 100000, "volume": 500},
|
||||||
|
)
|
||||||
|
logger.update_outcome(id2, pnl=-1500.0, accuracy=0)
|
||||||
|
|
||||||
|
# Low-confidence loss (should be ignored)
|
||||||
|
id3 = logger.log_decision(
|
||||||
|
stock_code="035420",
|
||||||
|
market="KR",
|
||||||
|
exchange_code="KRX",
|
||||||
|
action="HOLD",
|
||||||
|
confidence=70,
|
||||||
|
rationale="Uncertain",
|
||||||
|
context_snapshot={},
|
||||||
|
input_data={},
|
||||||
|
)
|
||||||
|
logger.update_outcome(id3, pnl=-500.0, accuracy=0)
|
||||||
|
|
||||||
|
# Analyze failures
|
||||||
|
failures = optimizer.analyze_failures(limit=10)
|
||||||
|
|
||||||
|
# Should get 2 failures (confidence >= 80)
|
||||||
|
assert len(failures) == 2
|
||||||
|
assert all(f["confidence"] >= 80 for f in failures)
|
||||||
|
assert all(f["outcome_pnl"] <= -100.0 for f in failures)
|
||||||
|
|
||||||
|
|
||||||
|
def test_analyze_failures_empty_database(optimizer: EvolutionOptimizer) -> None:
|
||||||
|
"""Test analyze_failures with no losing decisions."""
|
||||||
|
failures = optimizer.analyze_failures()
|
||||||
|
assert failures == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_identify_failure_patterns(optimizer: EvolutionOptimizer) -> None:
|
||||||
|
"""Test identification of failure patterns."""
|
||||||
|
failures = [
|
||||||
|
{
|
||||||
|
"decision_id": "1",
|
||||||
|
"timestamp": "2024-01-15T09:30:00+00:00",
|
||||||
|
"stock_code": "005930",
|
||||||
|
"market": "KR",
|
||||||
|
"exchange_code": "KRX",
|
||||||
|
"action": "BUY",
|
||||||
|
"confidence": 85,
|
||||||
|
"rationale": "Test",
|
||||||
|
"outcome_pnl": -1000.0,
|
||||||
|
"outcome_accuracy": 0,
|
||||||
|
"context_snapshot": {},
|
||||||
|
"input_data": {},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"decision_id": "2",
|
||||||
|
"timestamp": "2024-01-15T14:30:00+00:00",
|
||||||
|
"stock_code": "000660",
|
||||||
|
"market": "KR",
|
||||||
|
"exchange_code": "KRX",
|
||||||
|
"action": "SELL",
|
||||||
|
"confidence": 90,
|
||||||
|
"rationale": "Test",
|
||||||
|
"outcome_pnl": -2000.0,
|
||||||
|
"outcome_accuracy": 0,
|
||||||
|
"context_snapshot": {},
|
||||||
|
"input_data": {},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"decision_id": "3",
|
||||||
|
"timestamp": "2024-01-15T09:45:00+00:00",
|
||||||
|
"stock_code": "035420",
|
||||||
|
"market": "US_NASDAQ",
|
||||||
|
"exchange_code": "NASDAQ",
|
||||||
|
"action": "BUY",
|
||||||
|
"confidence": 80,
|
||||||
|
"rationale": "Test",
|
||||||
|
"outcome_pnl": -500.0,
|
||||||
|
"outcome_accuracy": 0,
|
||||||
|
"context_snapshot": {},
|
||||||
|
"input_data": {},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
patterns = optimizer.identify_failure_patterns(failures)
|
||||||
|
|
||||||
|
assert patterns["total_failures"] == 3
|
||||||
|
assert patterns["markets"]["KR"] == 2
|
||||||
|
assert patterns["markets"]["US_NASDAQ"] == 1
|
||||||
|
assert patterns["actions"]["BUY"] == 2
|
||||||
|
assert patterns["actions"]["SELL"] == 1
|
||||||
|
assert 9 in patterns["hours"] # 09:30 and 09:45
|
||||||
|
assert 14 in patterns["hours"] # 14:30
|
||||||
|
assert patterns["avg_confidence"] == 85.0
|
||||||
|
assert patterns["avg_loss"] == -1166.67
|
||||||
|
|
||||||
|
|
||||||
|
def test_identify_failure_patterns_empty(optimizer: EvolutionOptimizer) -> None:
|
||||||
|
"""Test pattern identification with no failures."""
|
||||||
|
patterns = optimizer.identify_failure_patterns([])
|
||||||
|
assert patterns["pattern_count"] == 0
|
||||||
|
assert patterns["patterns"] == {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_strategy_creates_file(optimizer: EvolutionOptimizer, tmp_path: Path) -> None:
|
||||||
|
"""Test that generate_strategy creates a strategy file."""
|
||||||
|
failures = [
|
||||||
|
{
|
||||||
|
"decision_id": "1",
|
||||||
|
"timestamp": "2024-01-15T09:30:00+00:00",
|
||||||
|
"stock_code": "005930",
|
||||||
|
"market": "KR",
|
||||||
|
"action": "BUY",
|
||||||
|
"confidence": 85,
|
||||||
|
"outcome_pnl": -1000.0,
|
||||||
|
"context_snapshot": {},
|
||||||
|
"input_data": {},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Mock Gemini response
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.text = """
|
||||||
|
# Simple strategy
|
||||||
|
price = market_data.get("current_price", 0)
|
||||||
|
if price > 50000:
|
||||||
|
return {"action": "BUY", "confidence": 70, "rationale": "Price above threshold"}
|
||||||
|
return {"action": "HOLD", "confidence": 50, "rationale": "Waiting"}
|
||||||
|
"""
|
||||||
|
|
||||||
|
with patch.object(optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response)):
|
||||||
|
with patch("src.evolution.optimizer.STRATEGIES_DIR", tmp_path):
|
||||||
|
strategy_path = await optimizer.generate_strategy(failures)
|
||||||
|
|
||||||
|
assert strategy_path is not None
|
||||||
|
assert strategy_path.exists()
|
||||||
|
assert strategy_path.suffix == ".py"
|
||||||
|
assert "class Strategy_" in strategy_path.read_text()
|
||||||
|
assert "def evaluate" in strategy_path.read_text()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_strategy_handles_api_error(optimizer: EvolutionOptimizer) -> None:
|
||||||
|
"""Test that generate_strategy handles Gemini API errors gracefully."""
|
||||||
|
failures = [{"decision_id": "1", "timestamp": "2024-01-15T09:30:00+00:00"}]
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
optimizer._client.aio.models,
|
||||||
|
"generate_content",
|
||||||
|
side_effect=Exception("API Error"),
|
||||||
|
):
|
||||||
|
strategy_path = await optimizer.generate_strategy(failures)
|
||||||
|
|
||||||
|
assert strategy_path is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_performance_summary() -> None:
|
||||||
|
"""Test getting performance summary from trades table."""
|
||||||
|
# Create a temporary database with trades
|
||||||
|
import tempfile
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
|
||||||
|
tmp_path = tmp.name
|
||||||
|
|
||||||
|
conn = init_db(tmp_path)
|
||||||
|
log_trade(conn, "005930", "BUY", 85, "Test win", quantity=10, price=70000, pnl=1000.0)
|
||||||
|
log_trade(conn, "000660", "SELL", 90, "Test loss", quantity=5, price=100000, pnl=-500.0)
|
||||||
|
log_trade(conn, "035420", "BUY", 80, "Test win", quantity=8, price=50000, pnl=800.0)
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
# Create settings with temp database path
|
||||||
|
settings = Settings(
|
||||||
|
KIS_APP_KEY="test_key",
|
||||||
|
KIS_APP_SECRET="test_secret",
|
||||||
|
KIS_ACCOUNT_NO="12345678-01",
|
||||||
|
GEMINI_API_KEY="test_gemini_key",
|
||||||
|
GEMINI_MODEL="gemini-pro",
|
||||||
|
DB_PATH=tmp_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer = EvolutionOptimizer(settings)
|
||||||
|
summary = optimizer.get_performance_summary()
|
||||||
|
|
||||||
|
assert summary["total_trades"] == 3
|
||||||
|
assert summary["wins"] == 2
|
||||||
|
assert summary["losses"] == 1
|
||||||
|
assert summary["total_pnl"] == 1300.0
|
||||||
|
assert summary["avg_pnl"] == 433.33
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
Path(tmp_path).unlink()
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_strategy_success(optimizer: EvolutionOptimizer, tmp_path: Path) -> None:
|
||||||
|
"""Test strategy validation when tests pass."""
|
||||||
|
strategy_file = tmp_path / "test_strategy.py"
|
||||||
|
strategy_file.write_text("# Valid strategy file")
|
||||||
|
|
||||||
|
with patch("subprocess.run") as mock_run:
|
||||||
|
mock_run.return_value = Mock(returncode=0, stdout="", stderr="")
|
||||||
|
result = optimizer.validate_strategy(strategy_file)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert strategy_file.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_strategy_failure(optimizer: EvolutionOptimizer, tmp_path: Path) -> None:
|
||||||
|
"""Test strategy validation when tests fail."""
|
||||||
|
strategy_file = tmp_path / "test_strategy.py"
|
||||||
|
strategy_file.write_text("# Invalid strategy file")
|
||||||
|
|
||||||
|
with patch("subprocess.run") as mock_run:
|
||||||
|
mock_run.return_value = Mock(returncode=1, stdout="FAILED", stderr="")
|
||||||
|
result = optimizer.validate_strategy(strategy_file)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
# File should be deleted on failure
|
||||||
|
assert not strategy_file.exists()
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# ABTester Tests
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_calculate_performance_basic(ab_tester: ABTester) -> None:
|
||||||
|
"""Test basic performance calculation."""
|
||||||
|
trades = [
|
||||||
|
{"pnl": 1000.0},
|
||||||
|
{"pnl": -500.0},
|
||||||
|
{"pnl": 800.0},
|
||||||
|
{"pnl": 200.0},
|
||||||
|
]
|
||||||
|
|
||||||
|
perf = ab_tester.calculate_performance(trades, "TestStrategy")
|
||||||
|
|
||||||
|
assert perf.strategy_name == "TestStrategy"
|
||||||
|
assert perf.total_trades == 4
|
||||||
|
assert perf.wins == 3
|
||||||
|
assert perf.losses == 1
|
||||||
|
assert perf.total_pnl == 1500.0
|
||||||
|
assert perf.avg_pnl == 375.0
|
||||||
|
assert perf.win_rate == 75.0
|
||||||
|
assert perf.sharpe_ratio is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_calculate_performance_empty(ab_tester: ABTester) -> None:
|
||||||
|
"""Test performance calculation with no trades."""
|
||||||
|
perf = ab_tester.calculate_performance([], "EmptyStrategy")
|
||||||
|
|
||||||
|
assert perf.total_trades == 0
|
||||||
|
assert perf.wins == 0
|
||||||
|
assert perf.losses == 0
|
||||||
|
assert perf.total_pnl == 0.0
|
||||||
|
assert perf.avg_pnl == 0.0
|
||||||
|
assert perf.win_rate == 0.0
|
||||||
|
assert perf.sharpe_ratio is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_compare_strategies_significant_difference(ab_tester: ABTester) -> None:
|
||||||
|
"""Test strategy comparison with significant performance difference."""
|
||||||
|
# Strategy A: consistently profitable
|
||||||
|
trades_a = [{"pnl": 1000.0} for _ in range(30)]
|
||||||
|
|
||||||
|
# Strategy B: consistently losing
|
||||||
|
trades_b = [{"pnl": -500.0} for _ in range(30)]
|
||||||
|
|
||||||
|
result = ab_tester.compare_strategies(trades_a, trades_b, "Strategy A", "Strategy B")
|
||||||
|
|
||||||
|
# scipy returns np.True_ instead of Python bool
|
||||||
|
assert bool(result.is_significant) is True
|
||||||
|
assert result.winner == "Strategy A"
|
||||||
|
assert result.p_value < 0.05
|
||||||
|
assert result.performance_a.avg_pnl > result.performance_b.avg_pnl
|
||||||
|
|
||||||
|
|
||||||
|
def test_compare_strategies_no_difference(ab_tester: ABTester) -> None:
|
||||||
|
"""Test strategy comparison with no significant difference."""
|
||||||
|
# Both strategies have similar performance
|
||||||
|
trades_a = [{"pnl": 100.0}, {"pnl": -50.0}, {"pnl": 80.0}]
|
||||||
|
trades_b = [{"pnl": 90.0}, {"pnl": -60.0}, {"pnl": 85.0}]
|
||||||
|
|
||||||
|
result = ab_tester.compare_strategies(trades_a, trades_b, "Strategy A", "Strategy B")
|
||||||
|
|
||||||
|
# With small samples and similar performance, likely not significant
|
||||||
|
assert result.winner is None or not result.is_significant
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_deploy_meets_criteria(ab_tester: ABTester) -> None:
|
||||||
|
"""Test deployment decision when criteria are met."""
|
||||||
|
# Create a winning result that meets criteria
|
||||||
|
trades_a = [{"pnl": 1000.0} for _ in range(25)] # 100% win rate
|
||||||
|
trades_b = [{"pnl": -500.0} for _ in range(25)]
|
||||||
|
|
||||||
|
result = ab_tester.compare_strategies(trades_a, trades_b, "Winner", "Loser")
|
||||||
|
|
||||||
|
should_deploy = ab_tester.should_deploy(result, min_win_rate=60.0, min_trades=20)
|
||||||
|
|
||||||
|
assert should_deploy is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_deploy_insufficient_trades(ab_tester: ABTester) -> None:
|
||||||
|
"""Test deployment decision with insufficient trades."""
|
||||||
|
trades_a = [{"pnl": 1000.0} for _ in range(10)] # Only 10 trades
|
||||||
|
trades_b = [{"pnl": -500.0} for _ in range(10)]
|
||||||
|
|
||||||
|
result = ab_tester.compare_strategies(trades_a, trades_b, "Winner", "Loser")
|
||||||
|
|
||||||
|
should_deploy = ab_tester.should_deploy(result, min_win_rate=60.0, min_trades=20)
|
||||||
|
|
||||||
|
assert should_deploy is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_deploy_low_win_rate(ab_tester: ABTester) -> None:
|
||||||
|
"""Test deployment decision with low win rate."""
|
||||||
|
# Mix of wins and losses, below 60% win rate
|
||||||
|
trades_a = [{"pnl": 100.0}] * 10 + [{"pnl": -100.0}] * 15 # 40% win rate
|
||||||
|
trades_b = [{"pnl": -500.0} for _ in range(25)]
|
||||||
|
|
||||||
|
result = ab_tester.compare_strategies(trades_a, trades_b, "LowWinner", "Loser")
|
||||||
|
|
||||||
|
should_deploy = ab_tester.should_deploy(result, min_win_rate=60.0, min_trades=20)
|
||||||
|
|
||||||
|
assert should_deploy is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_deploy_not_significant(ab_tester: ABTester) -> None:
|
||||||
|
"""Test deployment decision when difference is not significant."""
|
||||||
|
# Use more varied data to ensure statistical insignificance
|
||||||
|
trades_a = [{"pnl": 100.0}, {"pnl": -50.0}] * 12 + [{"pnl": 100.0}]
|
||||||
|
trades_b = [{"pnl": 95.0}, {"pnl": -45.0}] * 12 + [{"pnl": 95.0}]
|
||||||
|
|
||||||
|
result = ab_tester.compare_strategies(trades_a, trades_b, "A", "B")
|
||||||
|
|
||||||
|
should_deploy = ab_tester.should_deploy(result, min_win_rate=60.0, min_trades=20)
|
||||||
|
|
||||||
|
# Not significant or not profitable enough
|
||||||
|
# Even if significant, win rate is 50% which is below 60% threshold
|
||||||
|
assert should_deploy is False
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# PerformanceTracker Tests
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_strategy_metrics(db_conn: sqlite3.Connection) -> None:
|
||||||
|
"""Test getting strategy metrics."""
|
||||||
|
# Add some trades
|
||||||
|
log_trade(db_conn, "005930", "BUY", 85, "Win 1", quantity=10, price=70000, pnl=1000.0)
|
||||||
|
log_trade(db_conn, "000660", "SELL", 90, "Loss 1", quantity=5, price=100000, pnl=-500.0)
|
||||||
|
log_trade(db_conn, "035420", "BUY", 80, "Win 2", quantity=8, price=50000, pnl=800.0)
|
||||||
|
log_trade(db_conn, "005930", "HOLD", 75, "Hold", quantity=0, price=70000, pnl=0.0)
|
||||||
|
|
||||||
|
tracker = PerformanceTracker(db_path=":memory:")
|
||||||
|
# Manually set connection for testing
|
||||||
|
tracker._db_path = db_conn
|
||||||
|
|
||||||
|
# Need to use the same connection
|
||||||
|
with patch("sqlite3.connect", return_value=db_conn):
|
||||||
|
metrics = tracker.get_strategy_metrics()
|
||||||
|
|
||||||
|
assert metrics.total_trades == 4
|
||||||
|
assert metrics.wins == 2
|
||||||
|
assert metrics.losses == 1
|
||||||
|
assert metrics.holds == 1
|
||||||
|
assert metrics.win_rate == 50.0
|
||||||
|
assert metrics.total_pnl == 1300.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_calculate_improvement_trend_improving(performance_tracker: PerformanceTracker) -> None:
|
||||||
|
"""Test improvement trend calculation for improving strategy."""
|
||||||
|
metrics = [
|
||||||
|
StrategyMetrics(
|
||||||
|
strategy_name="test",
|
||||||
|
period_start="2024-01-01",
|
||||||
|
period_end="2024-01-07",
|
||||||
|
total_trades=10,
|
||||||
|
wins=5,
|
||||||
|
losses=5,
|
||||||
|
holds=0,
|
||||||
|
win_rate=50.0,
|
||||||
|
avg_pnl=100.0,
|
||||||
|
total_pnl=1000.0,
|
||||||
|
best_trade=500.0,
|
||||||
|
worst_trade=-300.0,
|
||||||
|
avg_confidence=75.0,
|
||||||
|
),
|
||||||
|
StrategyMetrics(
|
||||||
|
strategy_name="test",
|
||||||
|
period_start="2024-01-08",
|
||||||
|
period_end="2024-01-14",
|
||||||
|
total_trades=10,
|
||||||
|
wins=7,
|
||||||
|
losses=3,
|
||||||
|
holds=0,
|
||||||
|
win_rate=70.0,
|
||||||
|
avg_pnl=200.0,
|
||||||
|
total_pnl=2000.0,
|
||||||
|
best_trade=600.0,
|
||||||
|
worst_trade=-200.0,
|
||||||
|
avg_confidence=80.0,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
trend = performance_tracker.calculate_improvement_trend(metrics)
|
||||||
|
|
||||||
|
assert trend["trend"] == "improving"
|
||||||
|
assert trend["win_rate_change"] == 20.0
|
||||||
|
assert trend["pnl_change"] == 100.0
|
||||||
|
assert trend["confidence_change"] == 5.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_calculate_improvement_trend_declining(performance_tracker: PerformanceTracker) -> None:
|
||||||
|
"""Test improvement trend calculation for declining strategy."""
|
||||||
|
metrics = [
|
||||||
|
StrategyMetrics(
|
||||||
|
strategy_name="test",
|
||||||
|
period_start="2024-01-01",
|
||||||
|
period_end="2024-01-07",
|
||||||
|
total_trades=10,
|
||||||
|
wins=7,
|
||||||
|
losses=3,
|
||||||
|
holds=0,
|
||||||
|
win_rate=70.0,
|
||||||
|
avg_pnl=200.0,
|
||||||
|
total_pnl=2000.0,
|
||||||
|
best_trade=600.0,
|
||||||
|
worst_trade=-200.0,
|
||||||
|
avg_confidence=80.0,
|
||||||
|
),
|
||||||
|
StrategyMetrics(
|
||||||
|
strategy_name="test",
|
||||||
|
period_start="2024-01-08",
|
||||||
|
period_end="2024-01-14",
|
||||||
|
total_trades=10,
|
||||||
|
wins=4,
|
||||||
|
losses=6,
|
||||||
|
holds=0,
|
||||||
|
win_rate=40.0,
|
||||||
|
avg_pnl=-50.0,
|
||||||
|
total_pnl=-500.0,
|
||||||
|
best_trade=300.0,
|
||||||
|
worst_trade=-400.0,
|
||||||
|
avg_confidence=70.0,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
trend = performance_tracker.calculate_improvement_trend(metrics)
|
||||||
|
|
||||||
|
assert trend["trend"] == "declining"
|
||||||
|
assert trend["win_rate_change"] == -30.0
|
||||||
|
assert trend["pnl_change"] == -250.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_calculate_improvement_trend_insufficient_data(performance_tracker: PerformanceTracker) -> None:
|
||||||
|
"""Test improvement trend with insufficient data."""
|
||||||
|
metrics = [
|
||||||
|
StrategyMetrics(
|
||||||
|
strategy_name="test",
|
||||||
|
period_start="2024-01-01",
|
||||||
|
period_end="2024-01-07",
|
||||||
|
total_trades=10,
|
||||||
|
wins=5,
|
||||||
|
losses=5,
|
||||||
|
holds=0,
|
||||||
|
win_rate=50.0,
|
||||||
|
avg_pnl=100.0,
|
||||||
|
total_pnl=1000.0,
|
||||||
|
best_trade=500.0,
|
||||||
|
worst_trade=-300.0,
|
||||||
|
avg_confidence=75.0,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
trend = performance_tracker.calculate_improvement_trend(metrics)
|
||||||
|
|
||||||
|
assert trend["trend"] == "insufficient_data"
|
||||||
|
assert trend["win_rate_change"] == 0.0
|
||||||
|
assert trend["pnl_change"] == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_export_dashboard_json(performance_tracker: PerformanceTracker) -> None:
|
||||||
|
"""Test exporting dashboard as JSON."""
|
||||||
|
overall_metrics = StrategyMetrics(
|
||||||
|
strategy_name="test",
|
||||||
|
period_start="2024-01-01",
|
||||||
|
period_end="2024-01-31",
|
||||||
|
total_trades=100,
|
||||||
|
wins=60,
|
||||||
|
losses=40,
|
||||||
|
holds=10,
|
||||||
|
win_rate=60.0,
|
||||||
|
avg_pnl=150.0,
|
||||||
|
total_pnl=15000.0,
|
||||||
|
best_trade=1000.0,
|
||||||
|
worst_trade=-500.0,
|
||||||
|
avg_confidence=80.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
dashboard = PerformanceDashboard(
|
||||||
|
generated_at=datetime.now(UTC).isoformat(),
|
||||||
|
overall_metrics=overall_metrics,
|
||||||
|
daily_metrics=[],
|
||||||
|
weekly_metrics=[],
|
||||||
|
improvement_trend={"trend": "improving", "win_rate_change": 10.0},
|
||||||
|
)
|
||||||
|
|
||||||
|
json_output = performance_tracker.export_dashboard_json(dashboard)
|
||||||
|
|
||||||
|
# Verify it's valid JSON
|
||||||
|
data = json.loads(json_output)
|
||||||
|
assert "generated_at" in data
|
||||||
|
assert "overall_metrics" in data
|
||||||
|
assert data["overall_metrics"]["total_trades"] == 100
|
||||||
|
assert data["overall_metrics"]["win_rate"] == 60.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_dashboard() -> None:
|
||||||
|
"""Test generating a complete dashboard."""
|
||||||
|
# Create tracker with temp database
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
|
||||||
|
tmp_path = tmp.name
|
||||||
|
|
||||||
|
# Initialize with data
|
||||||
|
conn = init_db(tmp_path)
|
||||||
|
log_trade(conn, "005930", "BUY", 85, "Win", quantity=10, price=70000, pnl=1000.0)
|
||||||
|
log_trade(conn, "000660", "SELL", 90, "Loss", quantity=5, price=100000, pnl=-500.0)
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
tracker = PerformanceTracker(db_path=tmp_path)
|
||||||
|
dashboard = tracker.generate_dashboard()
|
||||||
|
|
||||||
|
assert isinstance(dashboard, PerformanceDashboard)
|
||||||
|
assert dashboard.overall_metrics.total_trades == 2
|
||||||
|
assert len(dashboard.daily_metrics) == 7
|
||||||
|
assert len(dashboard.weekly_metrics) == 4
|
||||||
|
assert "trend" in dashboard.improvement_trend
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
Path(tmp_path).unlink()
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Integration Tests
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_evolution_pipeline(optimizer: EvolutionOptimizer, tmp_path: Path) -> None:
|
||||||
|
"""Test the complete evolution pipeline."""
|
||||||
|
# Add losing decisions
|
||||||
|
logger = optimizer._decision_logger
|
||||||
|
id1 = logger.log_decision(
|
||||||
|
stock_code="005930",
|
||||||
|
market="KR",
|
||||||
|
exchange_code="KRX",
|
||||||
|
action="BUY",
|
||||||
|
confidence=85,
|
||||||
|
rationale="Expected growth",
|
||||||
|
context_snapshot={},
|
||||||
|
input_data={},
|
||||||
|
)
|
||||||
|
logger.update_outcome(id1, pnl=-2000.0, accuracy=0)
|
||||||
|
|
||||||
|
# Mock Gemini and subprocess
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.text = 'return {"action": "HOLD", "confidence": 50, "rationale": "Test"}'
|
||||||
|
|
||||||
|
with patch.object(optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response)):
|
||||||
|
with patch("src.evolution.optimizer.STRATEGIES_DIR", tmp_path):
|
||||||
|
with patch("subprocess.run") as mock_run:
|
||||||
|
mock_run.return_value = Mock(returncode=0, stdout="", stderr="")
|
||||||
|
|
||||||
|
result = await optimizer.evolve()
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert "title" in result
|
||||||
|
assert "branch" in result
|
||||||
|
assert "status" in result
|
||||||
558
tests/test_latency_control.py
Normal file
558
tests/test_latency_control.py
Normal file
@@ -0,0 +1,558 @@
|
|||||||
|
"""Tests for latency control system (criticality assessment and priority queue)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.core.criticality import CriticalityAssessor, CriticalityLevel
|
||||||
|
from src.core.priority_queue import PriorityTask, PriorityTaskQueue
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CriticalityAssessor Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCriticalityAssessor:
|
||||||
|
"""Test suite for criticality assessment logic."""
|
||||||
|
|
||||||
|
def test_market_closed_returns_low(self) -> None:
|
||||||
|
"""Market closed should return LOW priority."""
|
||||||
|
assessor = CriticalityAssessor()
|
||||||
|
level = assessor.assess_market_conditions(
|
||||||
|
pnl_pct=0.0,
|
||||||
|
volatility_score=50.0,
|
||||||
|
volume_surge=1.0,
|
||||||
|
is_market_open=False,
|
||||||
|
)
|
||||||
|
assert level == CriticalityLevel.LOW
|
||||||
|
|
||||||
|
def test_very_low_volatility_returns_low(self) -> None:
|
||||||
|
"""Very low volatility should return LOW priority."""
|
||||||
|
assessor = CriticalityAssessor()
|
||||||
|
level = assessor.assess_market_conditions(
|
||||||
|
pnl_pct=0.0,
|
||||||
|
volatility_score=20.0, # Below 30.0 threshold
|
||||||
|
volume_surge=1.0,
|
||||||
|
is_market_open=True,
|
||||||
|
)
|
||||||
|
assert level == CriticalityLevel.LOW
|
||||||
|
|
||||||
|
def test_critical_pnl_threshold_triggered(self) -> None:
|
||||||
|
"""P&L below -2.5% should trigger CRITICAL."""
|
||||||
|
assessor = CriticalityAssessor()
|
||||||
|
level = assessor.assess_market_conditions(
|
||||||
|
pnl_pct=-2.6, # Below -2.5% threshold
|
||||||
|
volatility_score=50.0,
|
||||||
|
volume_surge=1.0,
|
||||||
|
is_market_open=True,
|
||||||
|
)
|
||||||
|
assert level == CriticalityLevel.CRITICAL
|
||||||
|
|
||||||
|
def test_critical_pnl_at_circuit_breaker_proximity(self) -> None:
|
||||||
|
"""P&L at exactly -2.5% (near -3.0% breaker) should be CRITICAL."""
|
||||||
|
assessor = CriticalityAssessor()
|
||||||
|
level = assessor.assess_market_conditions(
|
||||||
|
pnl_pct=-2.5,
|
||||||
|
volatility_score=50.0,
|
||||||
|
volume_surge=1.0,
|
||||||
|
is_market_open=True,
|
||||||
|
)
|
||||||
|
assert level == CriticalityLevel.CRITICAL
|
||||||
|
|
||||||
|
def test_critical_price_change_positive(self) -> None:
|
||||||
|
"""Large positive price change (>5%) should trigger CRITICAL."""
|
||||||
|
assessor = CriticalityAssessor()
|
||||||
|
level = assessor.assess_market_conditions(
|
||||||
|
pnl_pct=0.0,
|
||||||
|
volatility_score=50.0,
|
||||||
|
volume_surge=1.0,
|
||||||
|
price_change_1m=5.5, # Above 5.0% threshold
|
||||||
|
is_market_open=True,
|
||||||
|
)
|
||||||
|
assert level == CriticalityLevel.CRITICAL
|
||||||
|
|
||||||
|
def test_critical_price_change_negative(self) -> None:
|
||||||
|
"""Large negative price change (<-5%) should trigger CRITICAL."""
|
||||||
|
assessor = CriticalityAssessor()
|
||||||
|
level = assessor.assess_market_conditions(
|
||||||
|
pnl_pct=0.0,
|
||||||
|
volatility_score=50.0,
|
||||||
|
volume_surge=1.0,
|
||||||
|
price_change_1m=-6.0, # Below -5.0% threshold
|
||||||
|
is_market_open=True,
|
||||||
|
)
|
||||||
|
assert level == CriticalityLevel.CRITICAL
|
||||||
|
|
||||||
|
def test_critical_volume_surge(self) -> None:
|
||||||
|
"""Extreme volume surge (>10x) should trigger CRITICAL."""
|
||||||
|
assessor = CriticalityAssessor()
|
||||||
|
level = assessor.assess_market_conditions(
|
||||||
|
pnl_pct=0.0,
|
||||||
|
volatility_score=50.0,
|
||||||
|
volume_surge=12.0, # Above 10.0x threshold
|
||||||
|
is_market_open=True,
|
||||||
|
)
|
||||||
|
assert level == CriticalityLevel.CRITICAL
|
||||||
|
|
||||||
|
def test_high_volatility_returns_high(self) -> None:
|
||||||
|
"""High volatility score should return HIGH priority."""
|
||||||
|
assessor = CriticalityAssessor()
|
||||||
|
level = assessor.assess_market_conditions(
|
||||||
|
pnl_pct=0.0,
|
||||||
|
volatility_score=75.0, # Above 70.0 threshold
|
||||||
|
volume_surge=1.0,
|
||||||
|
is_market_open=True,
|
||||||
|
)
|
||||||
|
assert level == CriticalityLevel.HIGH
|
||||||
|
|
||||||
|
def test_normal_conditions_return_normal(self) -> None:
|
||||||
|
"""Normal market conditions should return NORMAL priority."""
|
||||||
|
assessor = CriticalityAssessor()
|
||||||
|
level = assessor.assess_market_conditions(
|
||||||
|
pnl_pct=0.5,
|
||||||
|
volatility_score=50.0, # Between 30-70
|
||||||
|
volume_surge=1.5,
|
||||||
|
price_change_1m=1.0,
|
||||||
|
is_market_open=True,
|
||||||
|
)
|
||||||
|
assert level == CriticalityLevel.NORMAL
|
||||||
|
|
||||||
|
def test_custom_thresholds(self) -> None:
|
||||||
|
"""Custom thresholds should be respected."""
|
||||||
|
assessor = CriticalityAssessor(
|
||||||
|
critical_pnl_threshold=-1.0,
|
||||||
|
critical_price_change_threshold=3.0,
|
||||||
|
critical_volume_surge_threshold=5.0,
|
||||||
|
high_volatility_threshold=60.0,
|
||||||
|
low_volatility_threshold=20.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test custom P&L threshold
|
||||||
|
level = assessor.assess_market_conditions(
|
||||||
|
pnl_pct=-1.1,
|
||||||
|
volatility_score=50.0,
|
||||||
|
volume_surge=1.0,
|
||||||
|
is_market_open=True,
|
||||||
|
)
|
||||||
|
assert level == CriticalityLevel.CRITICAL
|
||||||
|
|
||||||
|
# Test custom price change threshold
|
||||||
|
level = assessor.assess_market_conditions(
|
||||||
|
pnl_pct=0.0,
|
||||||
|
volatility_score=50.0,
|
||||||
|
volume_surge=1.0,
|
||||||
|
price_change_1m=3.5,
|
||||||
|
is_market_open=True,
|
||||||
|
)
|
||||||
|
assert level == CriticalityLevel.CRITICAL
|
||||||
|
|
||||||
|
def test_get_timeout_returns_correct_values(self) -> None:
|
||||||
|
"""Timeout values should match specification."""
|
||||||
|
assessor = CriticalityAssessor()
|
||||||
|
|
||||||
|
assert assessor.get_timeout(CriticalityLevel.CRITICAL) == 5.0
|
||||||
|
assert assessor.get_timeout(CriticalityLevel.HIGH) == 30.0
|
||||||
|
assert assessor.get_timeout(CriticalityLevel.NORMAL) == 60.0
|
||||||
|
assert assessor.get_timeout(CriticalityLevel.LOW) is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# PriorityTaskQueue Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPriorityTaskQueue:
|
||||||
|
"""Test suite for priority queue implementation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enqueue_task(self) -> None:
|
||||||
|
"""Tasks should be enqueued successfully."""
|
||||||
|
queue = PriorityTaskQueue()
|
||||||
|
|
||||||
|
success = await queue.enqueue(
|
||||||
|
task_id="test-1",
|
||||||
|
criticality=CriticalityLevel.NORMAL,
|
||||||
|
task_data={"action": "test"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert success is True
|
||||||
|
assert await queue.size() == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enqueue_rejects_when_full(self) -> None:
|
||||||
|
"""Queue should reject tasks when full."""
|
||||||
|
queue = PriorityTaskQueue(max_size=2)
|
||||||
|
|
||||||
|
# Fill the queue
|
||||||
|
await queue.enqueue("task-1", CriticalityLevel.NORMAL, {})
|
||||||
|
await queue.enqueue("task-2", CriticalityLevel.NORMAL, {})
|
||||||
|
|
||||||
|
# Third task should be rejected
|
||||||
|
success = await queue.enqueue("task-3", CriticalityLevel.NORMAL, {})
|
||||||
|
assert success is False
|
||||||
|
assert await queue.size() == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dequeue_returns_highest_priority(self) -> None:
|
||||||
|
"""Dequeue should return highest priority task first."""
|
||||||
|
queue = PriorityTaskQueue()
|
||||||
|
|
||||||
|
# Enqueue tasks in reverse priority order
|
||||||
|
await queue.enqueue("low", CriticalityLevel.LOW, {"priority": 3})
|
||||||
|
await queue.enqueue("normal", CriticalityLevel.NORMAL, {"priority": 2})
|
||||||
|
await queue.enqueue("high", CriticalityLevel.HIGH, {"priority": 1})
|
||||||
|
await queue.enqueue("critical", CriticalityLevel.CRITICAL, {"priority": 0})
|
||||||
|
|
||||||
|
# Dequeue should return CRITICAL first
|
||||||
|
task = await queue.dequeue(timeout=1.0)
|
||||||
|
assert task is not None
|
||||||
|
assert task.task_id == "critical"
|
||||||
|
assert task.priority == 0
|
||||||
|
|
||||||
|
# Then HIGH
|
||||||
|
task = await queue.dequeue(timeout=1.0)
|
||||||
|
assert task is not None
|
||||||
|
assert task.task_id == "high"
|
||||||
|
assert task.priority == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dequeue_fifo_within_same_priority(self) -> None:
|
||||||
|
"""Tasks with same priority should be FIFO."""
|
||||||
|
queue = PriorityTaskQueue()
|
||||||
|
|
||||||
|
# Enqueue multiple tasks with same priority
|
||||||
|
await queue.enqueue("task-1", CriticalityLevel.NORMAL, {})
|
||||||
|
await asyncio.sleep(0.01) # Small delay to ensure different timestamps
|
||||||
|
await queue.enqueue("task-2", CriticalityLevel.NORMAL, {})
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
await queue.enqueue("task-3", CriticalityLevel.NORMAL, {})
|
||||||
|
|
||||||
|
# Should dequeue in FIFO order
|
||||||
|
task1 = await queue.dequeue(timeout=1.0)
|
||||||
|
task2 = await queue.dequeue(timeout=1.0)
|
||||||
|
task3 = await queue.dequeue(timeout=1.0)
|
||||||
|
|
||||||
|
assert task1 is not None and task1.task_id == "task-1"
|
||||||
|
assert task2 is not None and task2.task_id == "task-2"
|
||||||
|
assert task3 is not None and task3.task_id == "task-3"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dequeue_returns_none_when_empty(self) -> None:
|
||||||
|
"""Dequeue should return None when queue is empty after timeout."""
|
||||||
|
queue = PriorityTaskQueue()
|
||||||
|
|
||||||
|
task = await queue.dequeue(timeout=0.1)
|
||||||
|
assert task is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_with_timeout_success(self) -> None:
|
||||||
|
"""Task execution should succeed within timeout."""
|
||||||
|
queue = PriorityTaskQueue()
|
||||||
|
|
||||||
|
# Create a simple async callback
|
||||||
|
async def test_callback() -> str:
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
task = PriorityTask(
|
||||||
|
priority=0,
|
||||||
|
timestamp=0.0,
|
||||||
|
task_id="test",
|
||||||
|
task_data={},
|
||||||
|
callback=test_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await queue.execute_with_timeout(task, timeout=1.0)
|
||||||
|
assert result == "success"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_with_timeout_raises_timeout_error(self) -> None:
|
||||||
|
"""Task execution should raise TimeoutError if exceeds timeout."""
|
||||||
|
queue = PriorityTaskQueue()
|
||||||
|
|
||||||
|
# Create a slow async callback
|
||||||
|
async def slow_callback() -> str:
|
||||||
|
await asyncio.sleep(1.0)
|
||||||
|
return "too slow"
|
||||||
|
|
||||||
|
task = PriorityTask(
|
||||||
|
priority=0,
|
||||||
|
timestamp=0.0,
|
||||||
|
task_id="test",
|
||||||
|
task_data={},
|
||||||
|
callback=slow_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(asyncio.TimeoutError):
|
||||||
|
await queue.execute_with_timeout(task, timeout=0.1)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_with_timeout_propagates_exceptions(self) -> None:
|
||||||
|
"""Task execution should propagate exceptions from callback."""
|
||||||
|
queue = PriorityTaskQueue()
|
||||||
|
|
||||||
|
# Create a failing async callback
|
||||||
|
async def failing_callback() -> None:
|
||||||
|
raise ValueError("Test error")
|
||||||
|
|
||||||
|
task = PriorityTask(
|
||||||
|
priority=0,
|
||||||
|
timestamp=0.0,
|
||||||
|
task_id="test",
|
||||||
|
task_data={},
|
||||||
|
callback=failing_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Test error"):
|
||||||
|
await queue.execute_with_timeout(task, timeout=1.0)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_without_timeout(self) -> None:
|
||||||
|
"""Task execution should work without timeout (LOW priority)."""
|
||||||
|
queue = PriorityTaskQueue()
|
||||||
|
|
||||||
|
async def test_callback() -> str:
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
task = PriorityTask(
|
||||||
|
priority=3,
|
||||||
|
timestamp=0.0,
|
||||||
|
task_id="test",
|
||||||
|
task_data={},
|
||||||
|
callback=test_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await queue.execute_with_timeout(task, timeout=None)
|
||||||
|
assert result == "success"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_metrics(self) -> None:
|
||||||
|
"""Queue should track metrics correctly."""
|
||||||
|
queue = PriorityTaskQueue()
|
||||||
|
|
||||||
|
# Enqueue and dequeue some tasks
|
||||||
|
await queue.enqueue("task-1", CriticalityLevel.CRITICAL, {})
|
||||||
|
await queue.enqueue("task-2", CriticalityLevel.HIGH, {})
|
||||||
|
await queue.enqueue("task-3", CriticalityLevel.NORMAL, {})
|
||||||
|
|
||||||
|
await queue.dequeue(timeout=1.0)
|
||||||
|
await queue.dequeue(timeout=1.0)
|
||||||
|
|
||||||
|
metrics = await queue.get_metrics()
|
||||||
|
|
||||||
|
assert metrics.total_enqueued == 3
|
||||||
|
assert metrics.total_dequeued == 2
|
||||||
|
assert metrics.current_size == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_wait_time_metrics(self) -> None:
|
||||||
|
"""Queue should track wait times per criticality level."""
|
||||||
|
queue = PriorityTaskQueue()
|
||||||
|
|
||||||
|
# Enqueue tasks with different criticality
|
||||||
|
await queue.enqueue("critical-1", CriticalityLevel.CRITICAL, {})
|
||||||
|
await asyncio.sleep(0.05) # Add some wait time
|
||||||
|
|
||||||
|
await queue.dequeue(timeout=1.0)
|
||||||
|
|
||||||
|
metrics = await queue.get_metrics()
|
||||||
|
|
||||||
|
# Should have wait time metrics for CRITICAL
|
||||||
|
assert CriticalityLevel.CRITICAL in metrics.avg_wait_time
|
||||||
|
assert metrics.avg_wait_time[CriticalityLevel.CRITICAL] > 0.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_clear_queue(self) -> None:
|
||||||
|
"""Clear should remove all tasks from queue."""
|
||||||
|
queue = PriorityTaskQueue()
|
||||||
|
|
||||||
|
await queue.enqueue("task-1", CriticalityLevel.NORMAL, {})
|
||||||
|
await queue.enqueue("task-2", CriticalityLevel.NORMAL, {})
|
||||||
|
await queue.enqueue("task-3", CriticalityLevel.NORMAL, {})
|
||||||
|
|
||||||
|
cleared = await queue.clear()
|
||||||
|
|
||||||
|
assert cleared == 3
|
||||||
|
assert await queue.size() == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_enqueue_dequeue(self) -> None:
|
||||||
|
"""Queue should handle concurrent operations safely."""
|
||||||
|
queue = PriorityTaskQueue()
|
||||||
|
|
||||||
|
# Concurrent enqueue operations
|
||||||
|
async def enqueue_tasks() -> None:
|
||||||
|
for i in range(10):
|
||||||
|
await queue.enqueue(
|
||||||
|
f"task-{i}",
|
||||||
|
CriticalityLevel.NORMAL,
|
||||||
|
{"index": i},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Concurrent dequeue operations
|
||||||
|
async def dequeue_tasks() -> list[str]:
|
||||||
|
tasks = []
|
||||||
|
for _ in range(10):
|
||||||
|
task = await queue.dequeue(timeout=1.0)
|
||||||
|
if task:
|
||||||
|
tasks.append(task.task_id)
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
return tasks
|
||||||
|
|
||||||
|
# Run both concurrently
|
||||||
|
enqueue_task = asyncio.create_task(enqueue_tasks())
|
||||||
|
dequeue_task = asyncio.create_task(dequeue_tasks())
|
||||||
|
|
||||||
|
await enqueue_task
|
||||||
|
dequeued_ids = await dequeue_task
|
||||||
|
|
||||||
|
# All tasks should be processed
|
||||||
|
assert len(dequeued_ids) == 10
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_timeout_metric_tracking(self) -> None:
|
||||||
|
"""Queue should track timeout occurrences."""
|
||||||
|
queue = PriorityTaskQueue()
|
||||||
|
|
||||||
|
async def slow_callback() -> str:
|
||||||
|
await asyncio.sleep(1.0)
|
||||||
|
return "too slow"
|
||||||
|
|
||||||
|
task = PriorityTask(
|
||||||
|
priority=0,
|
||||||
|
timestamp=0.0,
|
||||||
|
task_id="test",
|
||||||
|
task_data={},
|
||||||
|
callback=slow_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await queue.execute_with_timeout(task, timeout=0.1)
|
||||||
|
except TimeoutError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
metrics = await queue.get_metrics()
|
||||||
|
assert metrics.total_timeouts == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_error_metric_tracking(self) -> None:
|
||||||
|
"""Queue should track execution errors."""
|
||||||
|
queue = PriorityTaskQueue()
|
||||||
|
|
||||||
|
async def failing_callback() -> None:
|
||||||
|
raise ValueError("Test error")
|
||||||
|
|
||||||
|
task = PriorityTask(
|
||||||
|
priority=0,
|
||||||
|
timestamp=0.0,
|
||||||
|
task_id="test",
|
||||||
|
task_data={},
|
||||||
|
callback=failing_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await queue.execute_with_timeout(task, timeout=1.0)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
metrics = await queue.get_metrics()
|
||||||
|
assert metrics.total_errors == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Integration Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestLatencyControlIntegration:
|
||||||
|
"""Integration tests for criticality assessment and priority queue."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_critical_task_bypass_queue(self) -> None:
|
||||||
|
"""CRITICAL tasks should bypass lower priority tasks."""
|
||||||
|
queue = PriorityTaskQueue()
|
||||||
|
|
||||||
|
# Add normal priority tasks
|
||||||
|
await queue.enqueue("normal-1", CriticalityLevel.NORMAL, {})
|
||||||
|
await queue.enqueue("normal-2", CriticalityLevel.NORMAL, {})
|
||||||
|
|
||||||
|
# Add critical task (should jump to front)
|
||||||
|
await queue.enqueue("critical", CriticalityLevel.CRITICAL, {})
|
||||||
|
|
||||||
|
# Dequeue should return critical first
|
||||||
|
task = await queue.dequeue(timeout=1.0)
|
||||||
|
assert task is not None
|
||||||
|
assert task.task_id == "critical"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_timeout_enforcement_by_criticality(self) -> None:
|
||||||
|
"""Timeout enforcement should match criticality level."""
|
||||||
|
assessor = CriticalityAssessor()
|
||||||
|
|
||||||
|
# CRITICAL should have 5s timeout
|
||||||
|
critical_timeout = assessor.get_timeout(CriticalityLevel.CRITICAL)
|
||||||
|
assert critical_timeout == 5.0
|
||||||
|
|
||||||
|
# HIGH should have 30s timeout
|
||||||
|
high_timeout = assessor.get_timeout(CriticalityLevel.HIGH)
|
||||||
|
assert high_timeout == 30.0
|
||||||
|
|
||||||
|
# NORMAL should have 60s timeout
|
||||||
|
normal_timeout = assessor.get_timeout(CriticalityLevel.NORMAL)
|
||||||
|
assert normal_timeout == 60.0
|
||||||
|
|
||||||
|
# LOW should have no timeout
|
||||||
|
low_timeout = assessor.get_timeout(CriticalityLevel.LOW)
|
||||||
|
assert low_timeout is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fast_path_execution_for_critical(self) -> None:
|
||||||
|
"""CRITICAL tasks should complete quickly."""
|
||||||
|
queue = PriorityTaskQueue()
|
||||||
|
|
||||||
|
# Create a fast callback simulating fast-path execution
|
||||||
|
async def fast_path_callback() -> str:
|
||||||
|
# Simulate simplified decision flow
|
||||||
|
await asyncio.sleep(0.01) # Very fast execution
|
||||||
|
return "fast_path_complete"
|
||||||
|
|
||||||
|
task = PriorityTask(
|
||||||
|
priority=0, # CRITICAL
|
||||||
|
timestamp=0.0,
|
||||||
|
task_id="critical-fast",
|
||||||
|
task_data={},
|
||||||
|
callback=fast_path_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
result = await queue.execute_with_timeout(task, timeout=5.0)
|
||||||
|
elapsed = time.time() - start
|
||||||
|
|
||||||
|
assert result == "fast_path_complete"
|
||||||
|
assert elapsed < 5.0 # Should complete well under CRITICAL timeout
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_graceful_degradation_when_queue_full(self) -> None:
|
||||||
|
"""System should gracefully handle full queue."""
|
||||||
|
queue = PriorityTaskQueue(max_size=2)
|
||||||
|
|
||||||
|
# Fill the queue
|
||||||
|
await queue.enqueue("task-1", CriticalityLevel.NORMAL, {})
|
||||||
|
await queue.enqueue("task-2", CriticalityLevel.NORMAL, {})
|
||||||
|
|
||||||
|
# Try to add more tasks
|
||||||
|
success = await queue.enqueue("task-3", CriticalityLevel.NORMAL, {})
|
||||||
|
assert success is False
|
||||||
|
|
||||||
|
# Queue should still function
|
||||||
|
task = await queue.dequeue(timeout=1.0)
|
||||||
|
assert task is not None
|
||||||
|
|
||||||
|
# Now we can add another task
|
||||||
|
success = await queue.enqueue("task-4", CriticalityLevel.NORMAL, {})
|
||||||
|
assert success is True
|
||||||
663
tests/test_token_efficiency.py
Normal file
663
tests/test_token_efficiency.py
Normal file
@@ -0,0 +1,663 @@
|
|||||||
|
"""Tests for token efficiency optimization components.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- Prompt compression and optimization
|
||||||
|
- Context selection logic
|
||||||
|
- Summarization
|
||||||
|
- Caching
|
||||||
|
- Token reduction metrics
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.brain.cache import DecisionCache
|
||||||
|
from src.brain.context_selector import ContextSelector, DecisionType
|
||||||
|
from src.brain.gemini_client import TradeDecision
|
||||||
|
from src.brain.prompt_optimizer import PromptOptimizer, TokenMetrics
|
||||||
|
from src.context.layer import ContextLayer
|
||||||
|
from src.context.store import ContextStore
|
||||||
|
from src.context.summarizer import ContextSummarizer, SummaryStats
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Prompt Optimizer Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptOptimizer:
|
||||||
|
"""Tests for PromptOptimizer."""
|
||||||
|
|
||||||
|
def test_estimate_tokens(self):
|
||||||
|
"""Test token estimation."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
# Empty text
|
||||||
|
assert optimizer.estimate_tokens("") == 0
|
||||||
|
|
||||||
|
# Short text (4 chars = 1 token estimate)
|
||||||
|
assert optimizer.estimate_tokens("test") == 1
|
||||||
|
|
||||||
|
# Longer text
|
||||||
|
text = "This is a longer piece of text for testing token estimation."
|
||||||
|
tokens = optimizer.estimate_tokens(text)
|
||||||
|
assert tokens > 0
|
||||||
|
assert tokens == len(text) // 4
|
||||||
|
|
||||||
|
def test_count_tokens(self):
|
||||||
|
"""Test token counting metrics."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
text = "Hello world, this is a test."
|
||||||
|
metrics = optimizer.count_tokens(text)
|
||||||
|
|
||||||
|
assert isinstance(metrics, TokenMetrics)
|
||||||
|
assert metrics.char_count == len(text)
|
||||||
|
assert metrics.word_count == 6
|
||||||
|
assert metrics.estimated_tokens > 0
|
||||||
|
|
||||||
|
def test_compress_json(self):
|
||||||
|
"""Test JSON compression."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"action": "BUY",
|
||||||
|
"confidence": 85,
|
||||||
|
"rationale": "Strong uptrend",
|
||||||
|
}
|
||||||
|
|
||||||
|
compressed = optimizer.compress_json(data)
|
||||||
|
|
||||||
|
# Should have no newlines and minimal whitespace
|
||||||
|
assert "\n" not in compressed
|
||||||
|
# Note: JSON values may contain spaces (e.g., "Strong uptrend")
|
||||||
|
# but there should be no spaces around separators
|
||||||
|
assert ": " not in compressed
|
||||||
|
assert ", " not in compressed
|
||||||
|
|
||||||
|
# Should be valid JSON
|
||||||
|
import json
|
||||||
|
|
||||||
|
parsed = json.loads(compressed)
|
||||||
|
assert parsed == data
|
||||||
|
|
||||||
|
def test_abbreviate_text(self):
|
||||||
|
"""Test text abbreviation."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
text = "The current price is high and volume is increasing."
|
||||||
|
abbreviated = optimizer.abbreviate_text(text)
|
||||||
|
|
||||||
|
# Should contain abbreviations
|
||||||
|
assert "cur" in abbreviated or "P" in abbreviated
|
||||||
|
assert len(abbreviated) <= len(text)
|
||||||
|
|
||||||
|
def test_abbreviate_text_aggressive(self):
|
||||||
|
"""Test aggressive text abbreviation."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
text = "The price is increasing and the volume is high."
|
||||||
|
abbreviated = optimizer.abbreviate_text(text, aggressive=True)
|
||||||
|
|
||||||
|
# Should be shorter
|
||||||
|
assert len(abbreviated) < len(text)
|
||||||
|
|
||||||
|
# Should have removed articles
|
||||||
|
assert "the" not in abbreviated.lower()
|
||||||
|
|
||||||
|
def test_build_compressed_prompt(self):
|
||||||
|
"""Test compressed prompt building."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
market_data = {
|
||||||
|
"stock_code": "005930",
|
||||||
|
"current_price": 75000,
|
||||||
|
"market_name": "Korean stock market",
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt = optimizer.build_compressed_prompt(market_data)
|
||||||
|
|
||||||
|
# Should be much shorter than original
|
||||||
|
assert len(prompt) < 300
|
||||||
|
assert "005930" in prompt
|
||||||
|
assert "75000" in prompt
|
||||||
|
|
||||||
|
def test_build_compressed_prompt_no_instructions(self):
|
||||||
|
"""Test compressed prompt without instructions."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
market_data = {
|
||||||
|
"stock_code": "AAPL",
|
||||||
|
"current_price": 150.5,
|
||||||
|
"market_name": "United States",
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt = optimizer.build_compressed_prompt(market_data, include_instructions=False)
|
||||||
|
|
||||||
|
# Should be very short (data only)
|
||||||
|
assert len(prompt) < 100
|
||||||
|
assert "AAPL" in prompt
|
||||||
|
|
||||||
|
def test_truncate_context(self):
|
||||||
|
"""Test context truncation."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
context = {
|
||||||
|
"price": 100.5,
|
||||||
|
"volume": 1000000,
|
||||||
|
"sentiment": 0.8,
|
||||||
|
"extra_data": "Some long text that should be truncated",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Truncate to small budget
|
||||||
|
truncated = optimizer.truncate_context(context, max_tokens=10)
|
||||||
|
|
||||||
|
# Should have fewer keys
|
||||||
|
assert len(truncated) <= len(context)
|
||||||
|
|
||||||
|
def test_truncate_context_with_priority(self):
|
||||||
|
"""Test context truncation with priority keys."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
context = {
|
||||||
|
"price": 100.5,
|
||||||
|
"volume": 1000000,
|
||||||
|
"sentiment": 0.8,
|
||||||
|
"extra_data": "Some data",
|
||||||
|
}
|
||||||
|
|
||||||
|
priority_keys = ["price", "sentiment"]
|
||||||
|
truncated = optimizer.truncate_context(context, max_tokens=20, priority_keys=priority_keys)
|
||||||
|
|
||||||
|
# Priority keys should be included
|
||||||
|
assert "price" in truncated
|
||||||
|
assert "sentiment" in truncated
|
||||||
|
|
||||||
|
def test_calculate_compression_ratio(self):
|
||||||
|
"""Test compression ratio calculation."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
original = "This is a very long piece of text that should be compressed significantly."
|
||||||
|
compressed = "Short text"
|
||||||
|
|
||||||
|
ratio = optimizer.calculate_compression_ratio(original, compressed)
|
||||||
|
|
||||||
|
# Ratio should be > 1 (original is longer)
|
||||||
|
assert ratio > 1.0
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Context Selector Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextSelector:
|
||||||
|
"""Tests for ContextSelector."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def store(self):
|
||||||
|
"""Create in-memory ContextStore."""
|
||||||
|
conn = sqlite3.connect(":memory:")
|
||||||
|
# Create tables
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE context_metadata (
|
||||||
|
layer TEXT PRIMARY KEY,
|
||||||
|
description TEXT,
|
||||||
|
retention_days INTEGER,
|
||||||
|
aggregation_source TEXT
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE contexts (
|
||||||
|
layer TEXT,
|
||||||
|
timeframe TEXT,
|
||||||
|
key TEXT,
|
||||||
|
value TEXT,
|
||||||
|
created_at TEXT,
|
||||||
|
updated_at TEXT,
|
||||||
|
PRIMARY KEY (layer, timeframe, key)
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
return ContextStore(conn)
|
||||||
|
|
||||||
|
def test_select_layers_normal(self, store):
|
||||||
|
"""Test layer selection for normal decisions."""
|
||||||
|
selector = ContextSelector(store)
|
||||||
|
|
||||||
|
layers = selector.select_layers(DecisionType.NORMAL)
|
||||||
|
|
||||||
|
# Should only select L7 (real-time)
|
||||||
|
assert layers == [ContextLayer.L7_REALTIME]
|
||||||
|
|
||||||
|
def test_select_layers_strategic(self, store):
|
||||||
|
"""Test layer selection for strategic decisions."""
|
||||||
|
selector = ContextSelector(store)
|
||||||
|
|
||||||
|
layers = selector.select_layers(DecisionType.STRATEGIC)
|
||||||
|
|
||||||
|
# Should select L7 + L6 + L5
|
||||||
|
assert ContextLayer.L7_REALTIME in layers
|
||||||
|
assert ContextLayer.L6_DAILY in layers
|
||||||
|
assert ContextLayer.L5_WEEKLY in layers
|
||||||
|
assert len(layers) == 3
|
||||||
|
|
||||||
|
def test_select_layers_major_event(self, store):
|
||||||
|
"""Test layer selection for major events."""
|
||||||
|
selector = ContextSelector(store)
|
||||||
|
|
||||||
|
layers = selector.select_layers(DecisionType.MAJOR_EVENT)
|
||||||
|
|
||||||
|
# Should select all layers
|
||||||
|
assert len(layers) == 7
|
||||||
|
assert ContextLayer.L1_LEGACY in layers
|
||||||
|
assert ContextLayer.L7_REALTIME in layers
|
||||||
|
|
||||||
|
def test_score_layer_relevance(self, store):
|
||||||
|
"""Test layer relevance scoring."""
|
||||||
|
selector = ContextSelector(store)
|
||||||
|
|
||||||
|
# Add some data first so scores aren't penalized
|
||||||
|
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
|
||||||
|
store.set_context(ContextLayer.L1_LEGACY, "legacy", "lesson", "test")
|
||||||
|
|
||||||
|
# L7 should have high score for normal decisions
|
||||||
|
score = selector.score_layer_relevance(ContextLayer.L7_REALTIME, DecisionType.NORMAL)
|
||||||
|
assert score == 1.0
|
||||||
|
|
||||||
|
# L1 should have low score for normal decisions
|
||||||
|
score = selector.score_layer_relevance(ContextLayer.L1_LEGACY, DecisionType.NORMAL)
|
||||||
|
assert score == 0.0
|
||||||
|
|
||||||
|
# L1 should have high score for major events
|
||||||
|
score = selector.score_layer_relevance(ContextLayer.L1_LEGACY, DecisionType.MAJOR_EVENT)
|
||||||
|
assert score == 1.0
|
||||||
|
|
||||||
|
def test_select_with_scoring(self, store):
|
||||||
|
"""Test selection with relevance scoring."""
|
||||||
|
selector = ContextSelector(store)
|
||||||
|
|
||||||
|
# Add data so layers aren't penalized
|
||||||
|
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
|
||||||
|
|
||||||
|
selection = selector.select_with_scoring(DecisionType.NORMAL, min_score=0.5)
|
||||||
|
|
||||||
|
# Should only select high-relevance layers
|
||||||
|
assert len(selection.layers) >= 1
|
||||||
|
assert ContextLayer.L7_REALTIME in selection.layers
|
||||||
|
assert all(selection.relevance_scores[layer] >= 0.5 for layer in selection.layers)
|
||||||
|
|
||||||
|
def test_get_context_data(self, store):
|
||||||
|
"""Test context data retrieval."""
|
||||||
|
selector = ContextSelector(store)
|
||||||
|
|
||||||
|
# Add some test data
|
||||||
|
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
|
||||||
|
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "volume", 1000000)
|
||||||
|
|
||||||
|
context_data = selector.get_context_data([ContextLayer.L7_REALTIME])
|
||||||
|
|
||||||
|
# Should retrieve data
|
||||||
|
assert "L7_REALTIME" in context_data
|
||||||
|
assert "price" in context_data["L7_REALTIME"]
|
||||||
|
assert context_data["L7_REALTIME"]["price"] == 100.5
|
||||||
|
|
||||||
|
def test_estimate_context_tokens(self, store):
|
||||||
|
"""Test context token estimation."""
|
||||||
|
selector = ContextSelector(store)
|
||||||
|
|
||||||
|
context_data = {
|
||||||
|
"L7_REALTIME": {"price": 100.5, "volume": 1000000},
|
||||||
|
"L6_DAILY": {"avg_price": 99.8, "avg_volume": 950000},
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens = selector.estimate_context_tokens(context_data)
|
||||||
|
|
||||||
|
# Should estimate tokens
|
||||||
|
assert tokens > 0
|
||||||
|
|
||||||
|
def test_optimize_context_for_budget(self, store):
|
||||||
|
"""Test context optimization for token budget."""
|
||||||
|
selector = ContextSelector(store)
|
||||||
|
|
||||||
|
# Add test data
|
||||||
|
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
|
||||||
|
|
||||||
|
# Get optimized context within budget
|
||||||
|
context = selector.optimize_context_for_budget(DecisionType.NORMAL, max_tokens=50)
|
||||||
|
|
||||||
|
# Should return data within budget
|
||||||
|
tokens = selector.estimate_context_tokens(context)
|
||||||
|
assert tokens <= 50
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Context Summarizer Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextSummarizer:
|
||||||
|
"""Tests for ContextSummarizer."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def store(self):
|
||||||
|
"""Create in-memory ContextStore."""
|
||||||
|
conn = sqlite3.connect(":memory:")
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE context_metadata (
|
||||||
|
layer TEXT PRIMARY KEY,
|
||||||
|
description TEXT,
|
||||||
|
retention_days INTEGER,
|
||||||
|
aggregation_source TEXT
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE contexts (
|
||||||
|
layer TEXT,
|
||||||
|
timeframe TEXT,
|
||||||
|
key TEXT,
|
||||||
|
value TEXT,
|
||||||
|
created_at TEXT,
|
||||||
|
updated_at TEXT,
|
||||||
|
PRIMARY KEY (layer, timeframe, key)
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
return ContextStore(conn)
|
||||||
|
|
||||||
|
def test_summarize_numeric_values(self, store):
|
||||||
|
"""Test numeric value summarization."""
|
||||||
|
summarizer = ContextSummarizer(store)
|
||||||
|
|
||||||
|
values = [10.0, 20.0, 30.0, 40.0, 50.0]
|
||||||
|
stats = summarizer.summarize_numeric_values(values)
|
||||||
|
|
||||||
|
assert isinstance(stats, SummaryStats)
|
||||||
|
assert stats.count == 5
|
||||||
|
assert stats.mean == 30.0
|
||||||
|
assert stats.min == 10.0
|
||||||
|
assert stats.max == 50.0
|
||||||
|
assert stats.std is not None
|
||||||
|
|
||||||
|
def test_summarize_numeric_values_trend(self, store):
|
||||||
|
"""Test trend detection in numeric values."""
|
||||||
|
summarizer = ContextSummarizer(store)
|
||||||
|
|
||||||
|
# Uptrend
|
||||||
|
values_up = [10.0, 15.0, 20.0, 25.0, 30.0, 35.0]
|
||||||
|
stats_up = summarizer.summarize_numeric_values(values_up)
|
||||||
|
assert stats_up.trend == "up"
|
||||||
|
|
||||||
|
# Downtrend
|
||||||
|
values_down = [35.0, 30.0, 25.0, 20.0, 15.0, 10.0]
|
||||||
|
stats_down = summarizer.summarize_numeric_values(values_down)
|
||||||
|
assert stats_down.trend == "down"
|
||||||
|
|
||||||
|
# Flat
|
||||||
|
values_flat = [20.0, 20.1, 19.9, 20.0, 20.1, 19.9]
|
||||||
|
stats_flat = summarizer.summarize_numeric_values(values_flat)
|
||||||
|
assert stats_flat.trend == "flat"
|
||||||
|
|
||||||
|
def test_summarize_layer(self, store):
|
||||||
|
"""Test layer summarization."""
|
||||||
|
summarizer = ContextSummarizer(store)
|
||||||
|
|
||||||
|
# Add test data
|
||||||
|
store.set_context(ContextLayer.L6_DAILY, "2026-02-04", "price", 100.5)
|
||||||
|
store.set_context(ContextLayer.L6_DAILY, "2026-02-04", "volume", 1000000)
|
||||||
|
|
||||||
|
summary = summarizer.summarize_layer(ContextLayer.L6_DAILY)
|
||||||
|
|
||||||
|
# Should have summary
|
||||||
|
assert "total_entries" in summary
|
||||||
|
assert summary["total_entries"] > 0
|
||||||
|
|
||||||
|
def test_create_compact_summary(self, store):
|
||||||
|
"""Test compact summary creation."""
|
||||||
|
summarizer = ContextSummarizer(store)
|
||||||
|
|
||||||
|
# Add test data
|
||||||
|
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
|
||||||
|
|
||||||
|
layers = [ContextLayer.L7_REALTIME, ContextLayer.L6_DAILY]
|
||||||
|
summary = summarizer.create_compact_summary(layers, top_n_metrics=3)
|
||||||
|
|
||||||
|
# Should have summaries for layers
|
||||||
|
assert "L7_REALTIME" in summary
|
||||||
|
|
||||||
|
def test_format_summary_for_prompt(self, store):
|
||||||
|
"""Test summary formatting for prompt."""
|
||||||
|
summarizer = ContextSummarizer(store)
|
||||||
|
|
||||||
|
summary = {
|
||||||
|
"L7_REALTIME": {
|
||||||
|
"price": {"avg": 100.5, "trend": "up"},
|
||||||
|
"volume": {"avg": 1000000, "trend": "flat"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
formatted = summarizer.format_summary_for_prompt(summary)
|
||||||
|
|
||||||
|
# Should be formatted string
|
||||||
|
assert isinstance(formatted, str)
|
||||||
|
assert "L7_REALTIME" in formatted
|
||||||
|
assert "100.5" in formatted or "100.50" in formatted
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Decision Cache Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestDecisionCache:
|
||||||
|
"""Tests for DecisionCache."""
|
||||||
|
|
||||||
|
def test_cache_init(self):
|
||||||
|
"""Test cache initialization."""
|
||||||
|
cache = DecisionCache(ttl_seconds=60, max_size=100)
|
||||||
|
|
||||||
|
assert cache.ttl_seconds == 60
|
||||||
|
assert cache.max_size == 100
|
||||||
|
|
||||||
|
def test_cache_miss(self):
|
||||||
|
"""Test cache miss."""
|
||||||
|
cache = DecisionCache()
|
||||||
|
|
||||||
|
market_data = {"stock_code": "005930", "current_price": 75000}
|
||||||
|
|
||||||
|
decision = cache.get(market_data)
|
||||||
|
|
||||||
|
# Should be None (cache miss)
|
||||||
|
assert decision is None
|
||||||
|
|
||||||
|
metrics = cache.get_metrics()
|
||||||
|
assert metrics.cache_misses == 1
|
||||||
|
assert metrics.cache_hits == 0
|
||||||
|
|
||||||
|
def test_cache_hit(self):
|
||||||
|
"""Test cache hit."""
|
||||||
|
cache = DecisionCache()
|
||||||
|
|
||||||
|
market_data = {"stock_code": "005930", "current_price": 75000}
|
||||||
|
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
|
||||||
|
|
||||||
|
# Set cache
|
||||||
|
cache.set(market_data, decision)
|
||||||
|
|
||||||
|
# Get from cache
|
||||||
|
cached = cache.get(market_data)
|
||||||
|
|
||||||
|
assert cached is not None
|
||||||
|
assert cached.action == "HOLD"
|
||||||
|
assert cached.confidence == 50
|
||||||
|
|
||||||
|
metrics = cache.get_metrics()
|
||||||
|
assert metrics.cache_hits == 1
|
||||||
|
|
||||||
|
def test_cache_ttl_expiration(self):
|
||||||
|
"""Test cache TTL expiration."""
|
||||||
|
cache = DecisionCache(ttl_seconds=1) # 1 second TTL
|
||||||
|
|
||||||
|
market_data = {"stock_code": "005930", "current_price": 75000}
|
||||||
|
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
|
||||||
|
|
||||||
|
# Set cache
|
||||||
|
cache.set(market_data, decision)
|
||||||
|
|
||||||
|
# Should hit immediately
|
||||||
|
cached = cache.get(market_data)
|
||||||
|
assert cached is not None
|
||||||
|
|
||||||
|
# Wait for expiration
|
||||||
|
time.sleep(1.1)
|
||||||
|
|
||||||
|
# Should miss after expiration
|
||||||
|
cached = cache.get(market_data)
|
||||||
|
assert cached is None
|
||||||
|
|
||||||
|
def test_cache_max_size(self):
|
||||||
|
"""Test cache max size eviction."""
|
||||||
|
cache = DecisionCache(max_size=2)
|
||||||
|
|
||||||
|
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
|
||||||
|
|
||||||
|
# Add 3 entries (exceeds max_size)
|
||||||
|
for i in range(3):
|
||||||
|
market_data = {"stock_code": f"00{i}", "current_price": 1000 * i}
|
||||||
|
cache.set(market_data, decision)
|
||||||
|
|
||||||
|
metrics = cache.get_metrics()
|
||||||
|
|
||||||
|
# Should have evicted 1 entry
|
||||||
|
assert metrics.total_entries == 2
|
||||||
|
assert metrics.evictions == 1
|
||||||
|
|
||||||
|
def test_invalidate_all(self):
|
||||||
|
"""Test invalidate all cache entries."""
|
||||||
|
cache = DecisionCache()
|
||||||
|
|
||||||
|
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
|
||||||
|
|
||||||
|
# Add entries
|
||||||
|
for i in range(3):
|
||||||
|
market_data = {"stock_code": f"00{i}", "current_price": 1000}
|
||||||
|
cache.set(market_data, decision)
|
||||||
|
|
||||||
|
# Invalidate all
|
||||||
|
count = cache.invalidate()
|
||||||
|
|
||||||
|
assert count == 3
|
||||||
|
|
||||||
|
metrics = cache.get_metrics()
|
||||||
|
assert metrics.total_entries == 0
|
||||||
|
|
||||||
|
def test_invalidate_by_stock(self):
|
||||||
|
"""Test invalidate cache by stock code."""
|
||||||
|
cache = DecisionCache()
|
||||||
|
|
||||||
|
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
|
||||||
|
|
||||||
|
# Add entries for different stocks
|
||||||
|
cache.set({"stock_code": "005930", "current_price": 75000}, decision)
|
||||||
|
cache.set({"stock_code": "000660", "current_price": 50000}, decision)
|
||||||
|
|
||||||
|
# Invalidate specific stock
|
||||||
|
count = cache.invalidate("005930")
|
||||||
|
|
||||||
|
assert count >= 1
|
||||||
|
|
||||||
|
# Other stock should still be cached
|
||||||
|
cached = cache.get({"stock_code": "000660", "current_price": 50000})
|
||||||
|
assert cached is not None
|
||||||
|
|
||||||
|
def test_cleanup_expired(self):
|
||||||
|
"""Test cleanup of expired entries."""
|
||||||
|
cache = DecisionCache(ttl_seconds=1)
|
||||||
|
|
||||||
|
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
|
||||||
|
|
||||||
|
# Add entry
|
||||||
|
cache.set({"stock_code": "005930", "current_price": 75000}, decision)
|
||||||
|
|
||||||
|
# Wait for expiration
|
||||||
|
time.sleep(1.1)
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
count = cache.cleanup_expired()
|
||||||
|
|
||||||
|
assert count == 1
|
||||||
|
|
||||||
|
metrics = cache.get_metrics()
|
||||||
|
assert metrics.total_entries == 0
|
||||||
|
|
||||||
|
def test_should_cache_decision(self):
|
||||||
|
"""Test decision caching criteria."""
|
||||||
|
cache = DecisionCache()
|
||||||
|
|
||||||
|
# HOLD decisions should be cached
|
||||||
|
hold_decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
|
||||||
|
assert cache.should_cache_decision(hold_decision) is True
|
||||||
|
|
||||||
|
# High confidence BUY should be cached
|
||||||
|
buy_decision = TradeDecision(action="BUY", confidence=95, rationale="Test")
|
||||||
|
assert cache.should_cache_decision(buy_decision) is True
|
||||||
|
|
||||||
|
# Low confidence BUY should not be cached
|
||||||
|
low_conf_buy = TradeDecision(action="BUY", confidence=60, rationale="Test")
|
||||||
|
assert cache.should_cache_decision(low_conf_buy) is False
|
||||||
|
|
||||||
|
def test_cache_hit_rate(self):
|
||||||
|
"""Test cache hit rate calculation."""
|
||||||
|
cache = DecisionCache()
|
||||||
|
|
||||||
|
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
|
||||||
|
market_data = {"stock_code": "005930", "current_price": 75000}
|
||||||
|
|
||||||
|
# First request (miss)
|
||||||
|
cache.get(market_data)
|
||||||
|
|
||||||
|
# Set cache
|
||||||
|
cache.set(market_data, decision)
|
||||||
|
|
||||||
|
# Second request (hit)
|
||||||
|
cache.get(market_data)
|
||||||
|
|
||||||
|
# Third request (hit)
|
||||||
|
cache.get(market_data)
|
||||||
|
|
||||||
|
metrics = cache.get_metrics()
|
||||||
|
|
||||||
|
# 1 miss, 2 hits out of 3 requests
|
||||||
|
assert metrics.total_requests == 3
|
||||||
|
assert metrics.cache_hits == 2
|
||||||
|
assert metrics.cache_misses == 1
|
||||||
|
assert metrics.hit_rate == pytest.approx(2 / 3)
|
||||||
|
|
||||||
|
def test_reset_metrics(self):
|
||||||
|
"""Test metrics reset."""
|
||||||
|
cache = DecisionCache()
|
||||||
|
|
||||||
|
market_data = {"stock_code": "005930", "current_price": 75000}
|
||||||
|
|
||||||
|
# Generate some activity
|
||||||
|
cache.get(market_data)
|
||||||
|
cache.get(market_data)
|
||||||
|
|
||||||
|
# Reset
|
||||||
|
cache.reset_metrics()
|
||||||
|
|
||||||
|
metrics = cache.get_metrics()
|
||||||
|
assert metrics.total_requests == 0
|
||||||
|
assert metrics.cache_hits == 0
|
||||||
|
assert metrics.cache_misses == 0
|
||||||
Reference in New Issue
Block a user