Merge main into feature/issue-22-data-driven
Some checks failed
CI / test (pull_request) Has been cancelled

This commit is contained in:
agentson
2026-02-04 18:41:44 +09:00
8 changed files with 1566 additions and 8 deletions

View File

@@ -11,15 +11,15 @@ from __future__ import annotations
import json
import sqlite3
import tempfile
from datetime import UTC, datetime, timedelta
from datetime import UTC, datetime
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from unittest.mock import AsyncMock, Mock, patch
import pytest
from src.config import Settings
from src.db import init_db, log_trade
from src.evolution.ab_test import ABTester, ABTestResult, StrategyPerformance
from src.evolution.ab_test import ABTester
from src.evolution.optimizer import EvolutionOptimizer
from src.evolution.performance_tracker import (
PerformanceDashboard,
@@ -28,7 +28,6 @@ from src.evolution.performance_tracker import (
)
from src.logging.decision_logger import DecisionLogger
# ------------------------------------------------------------------
# Fixtures
# ------------------------------------------------------------------

View File

@@ -0,0 +1,663 @@
"""Tests for token efficiency optimization components.
Tests cover:
- Prompt compression and optimization
- Context selection logic
- Summarization
- Caching
- Token reduction metrics
"""
from __future__ import annotations
import sqlite3
import time
import pytest
from src.brain.cache import DecisionCache
from src.brain.context_selector import ContextSelector, DecisionType
from src.brain.gemini_client import TradeDecision
from src.brain.prompt_optimizer import PromptOptimizer, TokenMetrics
from src.context.layer import ContextLayer
from src.context.store import ContextStore
from src.context.summarizer import ContextSummarizer, SummaryStats
# ============================================================================
# Prompt Optimizer Tests
# ============================================================================
class TestPromptOptimizer:
    """Exercise PromptOptimizer's token counting and compression helpers."""

    def test_estimate_tokens(self):
        """Token estimates are 0 for empty text and len(text) // 4 otherwise."""
        opt = PromptOptimizer()

        # An empty string has nothing to count.
        assert opt.estimate_tokens("") == 0

        # Four characters are expected to map to exactly one token.
        assert opt.estimate_tokens("test") == 1

        # A full sentence is expected to follow the same rule.
        sentence = "This is a longer piece of text for testing token estimation."
        estimate = opt.estimate_tokens(sentence)
        assert estimate > 0
        assert estimate == len(sentence) // 4

    def test_count_tokens(self):
        """count_tokens reports char, word and estimated-token counts."""
        sample = "Hello world, this is a test."
        result = PromptOptimizer().count_tokens(sample)

        assert isinstance(result, TokenMetrics)
        assert result.char_count == len(sample)
        assert result.word_count == 6
        assert result.estimated_tokens > 0

    def test_compress_json(self):
        """compress_json emits compact JSON that round-trips to the input."""
        import json

        payload = {
            "action": "BUY",
            "confidence": 85,
            "rationale": "Strong uptrend",
        }
        packed = PromptOptimizer().compress_json(payload)

        # No line breaks in the compact form.
        assert "\n" not in packed

        # Values may legitimately contain spaces (e.g. "Strong uptrend"),
        # but the key/value and item separators must not be padded.
        assert ": " not in packed
        assert ", " not in packed

        # The compact form must still decode to the original mapping.
        assert json.loads(packed) == payload

    def test_abbreviate_text(self):
        """abbreviate_text never lengthens the text and leaves an abbreviation."""
        source = "The current price is high and volume is increasing."
        shortened = PromptOptimizer().abbreviate_text(source)

        # NOTE(review): the input word "current" already contains "cur", so
        # this check also passes if nothing was abbreviated; consider
        # tightening it once the abbreviation table is confirmed.
        assert any(marker in shortened for marker in ("cur", "P"))
        assert len(shortened) <= len(source)

    def test_abbreviate_text_aggressive(self):
        """Aggressive abbreviation shortens the text and removes 'the'."""
        source = "The price is increasing and the volume is high."
        shortened = PromptOptimizer().abbreviate_text(source, aggressive=True)

        # Strictly shorter than the input.
        assert len(shortened) < len(source)

        # No article "the" left, in any letter case.
        assert "the" not in shortened.lower()

    def test_build_compressed_prompt(self):
        """The compressed prompt is short yet keeps the code and price."""
        snapshot = {
            "stock_code": "005930",
            "current_price": 75000,
            "market_name": "Korean stock market",
        }
        built = PromptOptimizer().build_compressed_prompt(snapshot)

        # Expected to be far shorter than an uncompressed prompt.
        assert len(built) < 300
        assert "005930" in built
        assert "75000" in built

    def test_build_compressed_prompt_no_instructions(self):
        """With include_instructions=False only the data part is emitted."""
        snapshot = {
            "stock_code": "AAPL",
            "current_price": 150.5,
            "market_name": "United States",
        }
        built = PromptOptimizer().build_compressed_prompt(
            snapshot, include_instructions=False
        )

        # Data-only prompts are expected to be very short.
        assert len(built) < 100
        assert "AAPL" in built

    def test_truncate_context(self):
        """truncate_context never returns more keys than it was given."""
        full_context = {
            "price": 100.5,
            "volume": 1000000,
            "sentiment": 0.8,
            "extra_data": "Some long text that should be truncated",
        }

        # Use a deliberately tiny token budget.
        trimmed = PromptOptimizer().truncate_context(full_context, max_tokens=10)

        assert len(trimmed) <= len(full_context)

    def test_truncate_context_with_priority(self):
        """Priority keys are retained when the context is truncated."""
        full_context = {
            "price": 100.5,
            "volume": 1000000,
            "sentiment": 0.8,
            "extra_data": "Some data",
        }
        must_keep = ["price", "sentiment"]

        trimmed = PromptOptimizer().truncate_context(
            full_context, max_tokens=20, priority_keys=must_keep
        )

        # Both priority keys must be present in the result.
        assert "price" in trimmed
        assert "sentiment" in trimmed

    def test_calculate_compression_ratio(self):
        """The ratio exceeds 1.0 when the original is longer than the result."""
        long_text = "This is a very long piece of text that should be compressed significantly."
        short_text = "Short text"

        ratio = PromptOptimizer().calculate_compression_ratio(long_text, short_text)

        assert ratio > 1.0
# ============================================================================
# Context Selector Tests
# ============================================================================
class TestContextSelector:
    """Tests for ContextSelector layer selection, scoring and budgeting."""

    @pytest.fixture
    def store(self):
        """Yield a ContextStore backed by a fresh in-memory SQLite database.

        The connection is closed during fixture teardown so every test
        releases its database handle instead of leaving it open until
        garbage collection.
        """
        conn = sqlite3.connect(":memory:")
        try:
            # Create the tables these tests rely on.
            conn.execute(
                """
                CREATE TABLE context_metadata (
                    layer TEXT PRIMARY KEY,
                    description TEXT,
                    retention_days INTEGER,
                    aggregation_source TEXT
                )
                """
            )
            conn.execute(
                """
                CREATE TABLE contexts (
                    layer TEXT,
                    timeframe TEXT,
                    key TEXT,
                    value TEXT,
                    created_at TEXT,
                    updated_at TEXT,
                    PRIMARY KEY (layer, timeframe, key)
                )
                """
            )
            conn.commit()
            yield ContextStore(conn)
        finally:
            # Runs after the test finishes, and also if setup itself fails.
            conn.close()

    def test_select_layers_normal(self, store):
        """Normal decisions select only the real-time layer."""
        selector = ContextSelector(store)

        layers = selector.select_layers(DecisionType.NORMAL)

        # Should only select L7 (real-time)
        assert layers == [ContextLayer.L7_REALTIME]

    def test_select_layers_strategic(self, store):
        """Strategic decisions select exactly L7, L6 and L5."""
        selector = ContextSelector(store)

        layers = selector.select_layers(DecisionType.STRATEGIC)

        # Should select L7 + L6 + L5
        assert ContextLayer.L7_REALTIME in layers
        assert ContextLayer.L6_DAILY in layers
        assert ContextLayer.L5_WEEKLY in layers
        assert len(layers) == 3

    def test_select_layers_major_event(self, store):
        """Major events select all seven layers."""
        selector = ContextSelector(store)

        layers = selector.select_layers(DecisionType.MAJOR_EVENT)

        # Should select all layers
        assert len(layers) == 7
        assert ContextLayer.L1_LEGACY in layers
        assert ContextLayer.L7_REALTIME in layers

    def test_score_layer_relevance(self, store):
        """Relevance scores depend on both the layer and the decision type."""
        selector = ContextSelector(store)

        # Add some data first so scores aren't penalized
        store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
        store.set_context(ContextLayer.L1_LEGACY, "legacy", "lesson", "test")

        # L7 should have high score for normal decisions
        score = selector.score_layer_relevance(
            ContextLayer.L7_REALTIME, DecisionType.NORMAL
        )
        assert score == 1.0

        # L1 should have low score for normal decisions
        score = selector.score_layer_relevance(
            ContextLayer.L1_LEGACY, DecisionType.NORMAL
        )
        assert score == 0.0

        # L1 should have high score for major events
        score = selector.score_layer_relevance(
            ContextLayer.L1_LEGACY, DecisionType.MAJOR_EVENT
        )
        assert score == 1.0

    def test_select_with_scoring(self, store):
        """Scored selection returns only layers at or above min_score."""
        selector = ContextSelector(store)

        # Add data so layers aren't penalized
        store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)

        selection = selector.select_with_scoring(DecisionType.NORMAL, min_score=0.5)

        # Should only select high-relevance layers
        assert len(selection.layers) >= 1
        assert ContextLayer.L7_REALTIME in selection.layers
        assert all(
            selection.relevance_scores[layer] >= 0.5 for layer in selection.layers
        )

    def test_get_context_data(self, store):
        """Stored values come back keyed by layer name and then by key."""
        selector = ContextSelector(store)

        # Add some test data
        store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
        store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "volume", 1000000)

        context_data = selector.get_context_data([ContextLayer.L7_REALTIME])

        # Should retrieve data
        assert "L7_REALTIME" in context_data
        assert "price" in context_data["L7_REALTIME"]
        assert context_data["L7_REALTIME"]["price"] == 100.5

    def test_estimate_context_tokens(self, store):
        """Non-empty context data produces a positive token estimate."""
        selector = ContextSelector(store)

        context_data = {
            "L7_REALTIME": {"price": 100.5, "volume": 1000000},
            "L6_DAILY": {"avg_price": 99.8, "avg_volume": 950000},
        }

        tokens = selector.estimate_context_tokens(context_data)

        # Should estimate tokens
        assert tokens > 0

    def test_optimize_context_for_budget(self, store):
        """Optimized context stays within the requested token budget."""
        selector = ContextSelector(store)

        # Add test data
        store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)

        # Get optimized context within budget
        context = selector.optimize_context_for_budget(
            DecisionType.NORMAL, max_tokens=50
        )

        # Should return data within budget
        tokens = selector.estimate_context_tokens(context)
        assert tokens <= 50
# ============================================================================
# Context Summarizer Tests
# ============================================================================
class TestContextSummarizer:
    """Tests for ContextSummarizer statistics, summaries and formatting."""

    @pytest.fixture
    def store(self):
        """Yield a ContextStore backed by a fresh in-memory SQLite database.

        The connection is closed during fixture teardown so every test
        releases its database handle instead of leaving it open until
        garbage collection.
        """
        conn = sqlite3.connect(":memory:")
        try:
            conn.execute(
                """
                CREATE TABLE context_metadata (
                    layer TEXT PRIMARY KEY,
                    description TEXT,
                    retention_days INTEGER,
                    aggregation_source TEXT
                )
                """
            )
            conn.execute(
                """
                CREATE TABLE contexts (
                    layer TEXT,
                    timeframe TEXT,
                    key TEXT,
                    value TEXT,
                    created_at TEXT,
                    updated_at TEXT,
                    PRIMARY KEY (layer, timeframe, key)
                )
                """
            )
            conn.commit()
            yield ContextStore(conn)
        finally:
            # Runs after the test finishes, and also if setup itself fails.
            conn.close()

    def test_summarize_numeric_values(self, store):
        """Summary statistics report count, mean, min, max and a std value."""
        summarizer = ContextSummarizer(store)

        values = [10.0, 20.0, 30.0, 40.0, 50.0]
        stats = summarizer.summarize_numeric_values(values)

        assert isinstance(stats, SummaryStats)
        assert stats.count == 5
        assert stats.mean == 30.0
        assert stats.min == 10.0
        assert stats.max == 50.0
        assert stats.std is not None

    def test_summarize_numeric_values_trend(self, store):
        """Rising, falling and near-constant series map to up/down/flat."""
        summarizer = ContextSummarizer(store)

        # Uptrend
        values_up = [10.0, 15.0, 20.0, 25.0, 30.0, 35.0]
        stats_up = summarizer.summarize_numeric_values(values_up)
        assert stats_up.trend == "up"

        # Downtrend
        values_down = [35.0, 30.0, 25.0, 20.0, 15.0, 10.0]
        stats_down = summarizer.summarize_numeric_values(values_down)
        assert stats_down.trend == "down"

        # Flat
        values_flat = [20.0, 20.1, 19.9, 20.0, 20.1, 19.9]
        stats_flat = summarizer.summarize_numeric_values(values_flat)
        assert stats_flat.trend == "flat"

    def test_summarize_layer(self, store):
        """A populated layer yields a summary with a positive entry count."""
        summarizer = ContextSummarizer(store)

        # Add test data
        store.set_context(ContextLayer.L6_DAILY, "2026-02-04", "price", 100.5)
        store.set_context(ContextLayer.L6_DAILY, "2026-02-04", "volume", 1000000)

        summary = summarizer.summarize_layer(ContextLayer.L6_DAILY)

        # Should have summary
        assert "total_entries" in summary
        assert summary["total_entries"] > 0

    def test_create_compact_summary(self, store):
        """The compact summary contains an entry for a populated layer."""
        summarizer = ContextSummarizer(store)

        # Add test data
        store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)

        layers = [ContextLayer.L7_REALTIME, ContextLayer.L6_DAILY]
        summary = summarizer.create_compact_summary(layers, top_n_metrics=3)

        # Should have summaries for layers
        assert "L7_REALTIME" in summary

    def test_format_summary_for_prompt(self, store):
        """Formatting produces a string naming the layer and its average."""
        summarizer = ContextSummarizer(store)

        summary = {
            "L7_REALTIME": {
                "price": {"avg": 100.5, "trend": "up"},
                "volume": {"avg": 1000000, "trend": "flat"},
            }
        }

        formatted = summarizer.format_summary_for_prompt(summary)

        # Should be formatted string
        assert isinstance(formatted, str)
        assert "L7_REALTIME" in formatted
        assert "100.5" in formatted or "100.50" in formatted
# ============================================================================
# Decision Cache Tests
# ============================================================================
class TestDecisionCache:
    """Exercise DecisionCache lookups, expiry, eviction and metrics."""

    def test_cache_init(self):
        """Constructor arguments are exposed as attributes."""
        cache = DecisionCache(ttl_seconds=60, max_size=100)

        assert cache.ttl_seconds == 60
        assert cache.max_size == 100

    def test_cache_miss(self):
        """Looking up an unknown key returns None and counts one miss."""
        cache = DecisionCache()
        lookup_key = {"stock_code": "005930", "current_price": 75000}

        # Nothing was stored, so the lookup must come back empty.
        assert cache.get(lookup_key) is None

        stats = cache.get_metrics()
        assert stats.cache_misses == 1
        assert stats.cache_hits == 0

    def test_cache_hit(self):
        """A stored decision is returned intact and counted as a hit."""
        cache = DecisionCache()
        lookup_key = {"stock_code": "005930", "current_price": 75000}
        stored = TradeDecision(action="HOLD", confidence=50, rationale="Test")

        # Store, then read back under the same market data.
        cache.set(lookup_key, stored)
        fetched = cache.get(lookup_key)

        assert fetched is not None
        assert fetched.action == "HOLD"
        assert fetched.confidence == 50

        assert cache.get_metrics().cache_hits == 1

    def test_cache_ttl_expiration(self):
        """An entry is served before its TTL elapses and dropped afterwards."""
        # One-second TTL keeps the real-time wait below as short as possible.
        cache = DecisionCache(ttl_seconds=1)
        lookup_key = {"stock_code": "005930", "current_price": 75000}
        stored = TradeDecision(action="HOLD", confidence=50, rationale="Test")

        cache.set(lookup_key, stored)

        # Immediately after storing, the entry must still be available.
        assert cache.get(lookup_key) is not None

        # Sleep just past the TTL so the entry expires.
        time.sleep(1.1)

        # The same lookup must now miss.
        assert cache.get(lookup_key) is None

    def test_cache_max_size(self):
        """Inserting beyond max_size evicts entries to stay within the cap."""
        cache = DecisionCache(max_size=2)
        stored = TradeDecision(action="HOLD", confidence=50, rationale="Test")

        # Three distinct keys against a capacity of two.
        for idx in range(3):
            cache.set({"stock_code": f"00{idx}", "current_price": 1000 * idx}, stored)

        stats = cache.get_metrics()

        # Exactly one entry must have been evicted.
        assert stats.total_entries == 2
        assert stats.evictions == 1

    def test_invalidate_all(self):
        """invalidate() without arguments clears and counts every entry."""
        cache = DecisionCache()
        stored = TradeDecision(action="HOLD", confidence=50, rationale="Test")

        # Populate with three different stock codes.
        for idx in range(3):
            cache.set({"stock_code": f"00{idx}", "current_price": 1000}, stored)

        removed = cache.invalidate()
        assert removed == 3

        assert cache.get_metrics().total_entries == 0

    def test_invalidate_by_stock(self):
        """Invalidating one stock code leaves other stocks cached."""
        cache = DecisionCache()
        stored = TradeDecision(action="HOLD", confidence=50, rationale="Test")

        # Two entries belonging to different stocks.
        cache.set({"stock_code": "005930", "current_price": 75000}, stored)
        cache.set({"stock_code": "000660", "current_price": 50000}, stored)

        # Drop only the first stock's entries.
        removed = cache.invalidate("005930")
        assert removed >= 1

        # The second stock must still be retrievable.
        survivor = cache.get({"stock_code": "000660", "current_price": 50000})
        assert survivor is not None

    def test_cleanup_expired(self):
        """cleanup_expired removes and counts entries past their TTL."""
        cache = DecisionCache(ttl_seconds=1)
        stored = TradeDecision(action="HOLD", confidence=50, rationale="Test")

        cache.set({"stock_code": "005930", "current_price": 75000}, stored)

        # Sleep just past the TTL so the entry expires.
        time.sleep(1.1)

        removed = cache.cleanup_expired()
        assert removed == 1

        assert cache.get_metrics().total_entries == 0

    def test_should_cache_decision(self):
        """HOLD and high-confidence BUY are cacheable; low-confidence BUY is not."""
        cache = DecisionCache()

        # HOLD is expected to be cacheable.
        hold = TradeDecision(action="HOLD", confidence=50, rationale="Test")
        assert cache.should_cache_decision(hold) is True

        # A BUY at confidence 95 is expected to be cacheable.
        confident_buy = TradeDecision(action="BUY", confidence=95, rationale="Test")
        assert cache.should_cache_decision(confident_buy) is True

        # A BUY at confidence 60 is expected to be rejected.
        hesitant_buy = TradeDecision(action="BUY", confidence=60, rationale="Test")
        assert cache.should_cache_decision(hesitant_buy) is False

    def test_cache_hit_rate(self):
        """hit_rate equals hits divided by total requests."""
        cache = DecisionCache()
        stored = TradeDecision(action="HOLD", confidence=50, rationale="Test")
        lookup_key = {"stock_code": "005930", "current_price": 75000}

        # Request 1: nothing stored yet, so this is a miss.
        cache.get(lookup_key)

        cache.set(lookup_key, stored)

        # Requests 2 and 3: both served from the cache.
        cache.get(lookup_key)
        cache.get(lookup_key)

        stats = cache.get_metrics()

        # One miss plus two hits across three requests.
        assert stats.total_requests == 3
        assert stats.cache_hits == 2
        assert stats.cache_misses == 1
        assert stats.hit_rate == pytest.approx(2 / 3)

    def test_reset_metrics(self):
        """reset_metrics zeroes the request, hit and miss counters."""
        cache = DecisionCache()
        lookup_key = {"stock_code": "005930", "current_price": 75000}

        # Two lookups to produce some counter activity.
        for _ in range(2):
            cache.get(lookup_key)

        cache.reset_metrics()

        stats = cache.get_metrics()
        assert stats.total_requests == 0
        assert stats.cache_hits == 0
        assert stats.cache_misses == 0