Merge main into feature/issue-22-data-driven
Some checks failed
CI / test (pull_request) Has been cancelled
Some checks failed
CI / test (pull_request) Has been cancelled
This commit is contained in:
296
src/brain/context_selector.py
Normal file
296
src/brain/context_selector.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""Smart context selection for optimizing token usage.
|
||||
|
||||
This module implements intelligent selection of context layers (L1-L7) based on
|
||||
decision type and market conditions:
|
||||
- L7 (real-time) for normal trading decisions
|
||||
- L6-L5 (daily/weekly) for strategic decisions
|
||||
- L4-L1 (monthly/legacy) only for major events or policy changes
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from src.context.layer import ContextLayer
|
||||
from src.context.store import ContextStore
|
||||
|
||||
|
||||
class DecisionType(str, Enum):
|
||||
"""Type of trading decision being made."""
|
||||
|
||||
NORMAL = "normal" # Regular trade decision
|
||||
STRATEGIC = "strategic" # Strategy adjustment
|
||||
MAJOR_EVENT = "major_event" # Portfolio rebalancing, policy change
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ContextSelection:
|
||||
"""Selected context layers and their relevance scores."""
|
||||
|
||||
layers: list[ContextLayer]
|
||||
relevance_scores: dict[ContextLayer, float]
|
||||
total_score: float
|
||||
|
||||
|
||||
class ContextSelector:
|
||||
"""Selects optimal context layers to minimize token usage."""
|
||||
|
||||
def __init__(self, store: ContextStore) -> None:
|
||||
"""Initialize the context selector.
|
||||
|
||||
Args:
|
||||
store: ContextStore instance for retrieving context data
|
||||
"""
|
||||
self.store = store
|
||||
|
||||
def select_layers(
|
||||
self,
|
||||
decision_type: DecisionType = DecisionType.NORMAL,
|
||||
include_realtime: bool = True,
|
||||
) -> list[ContextLayer]:
|
||||
"""Select context layers based on decision type.
|
||||
|
||||
Strategy:
|
||||
- NORMAL: L7 (real-time) only
|
||||
- STRATEGIC: L7 + L6 + L5 (real-time + daily + weekly)
|
||||
- MAJOR_EVENT: All layers L1-L7
|
||||
|
||||
Args:
|
||||
decision_type: Type of decision being made
|
||||
include_realtime: Whether to include L7 real-time data
|
||||
|
||||
Returns:
|
||||
List of context layers to use (ordered by priority)
|
||||
"""
|
||||
if decision_type == DecisionType.NORMAL:
|
||||
# Normal trading: only real-time data
|
||||
return [ContextLayer.L7_REALTIME] if include_realtime else []
|
||||
|
||||
elif decision_type == DecisionType.STRATEGIC:
|
||||
# Strategic decisions: real-time + recent history
|
||||
layers = []
|
||||
if include_realtime:
|
||||
layers.append(ContextLayer.L7_REALTIME)
|
||||
layers.extend([ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY])
|
||||
return layers
|
||||
|
||||
else: # MAJOR_EVENT
|
||||
# Major events: all layers for comprehensive context
|
||||
layers = []
|
||||
if include_realtime:
|
||||
layers.append(ContextLayer.L7_REALTIME)
|
||||
layers.extend(
|
||||
[
|
||||
ContextLayer.L6_DAILY,
|
||||
ContextLayer.L5_WEEKLY,
|
||||
ContextLayer.L4_MONTHLY,
|
||||
ContextLayer.L3_QUARTERLY,
|
||||
ContextLayer.L2_ANNUAL,
|
||||
ContextLayer.L1_LEGACY,
|
||||
]
|
||||
)
|
||||
return layers
|
||||
|
||||
def score_layer_relevance(
|
||||
self,
|
||||
layer: ContextLayer,
|
||||
decision_type: DecisionType,
|
||||
current_time: datetime | None = None,
|
||||
) -> float:
|
||||
"""Calculate relevance score for a context layer.
|
||||
|
||||
Relevance is based on:
|
||||
1. Decision type (normal, strategic, major event)
|
||||
2. Layer recency (L7 > L6 > ... > L1)
|
||||
3. Data availability
|
||||
|
||||
Args:
|
||||
layer: Context layer to score
|
||||
decision_type: Type of decision being made
|
||||
current_time: Current time (defaults to now)
|
||||
|
||||
Returns:
|
||||
Relevance score (0.0 to 1.0)
|
||||
"""
|
||||
if current_time is None:
|
||||
current_time = datetime.now(UTC)
|
||||
|
||||
# Base scores by decision type
|
||||
base_scores = {
|
||||
DecisionType.NORMAL: {
|
||||
ContextLayer.L7_REALTIME: 1.0,
|
||||
ContextLayer.L6_DAILY: 0.1,
|
||||
ContextLayer.L5_WEEKLY: 0.05,
|
||||
ContextLayer.L4_MONTHLY: 0.01,
|
||||
ContextLayer.L3_QUARTERLY: 0.0,
|
||||
ContextLayer.L2_ANNUAL: 0.0,
|
||||
ContextLayer.L1_LEGACY: 0.0,
|
||||
},
|
||||
DecisionType.STRATEGIC: {
|
||||
ContextLayer.L7_REALTIME: 0.9,
|
||||
ContextLayer.L6_DAILY: 0.8,
|
||||
ContextLayer.L5_WEEKLY: 0.7,
|
||||
ContextLayer.L4_MONTHLY: 0.3,
|
||||
ContextLayer.L3_QUARTERLY: 0.2,
|
||||
ContextLayer.L2_ANNUAL: 0.1,
|
||||
ContextLayer.L1_LEGACY: 0.05,
|
||||
},
|
||||
DecisionType.MAJOR_EVENT: {
|
||||
ContextLayer.L7_REALTIME: 0.7,
|
||||
ContextLayer.L6_DAILY: 0.7,
|
||||
ContextLayer.L5_WEEKLY: 0.7,
|
||||
ContextLayer.L4_MONTHLY: 0.8,
|
||||
ContextLayer.L3_QUARTERLY: 0.8,
|
||||
ContextLayer.L2_ANNUAL: 0.9,
|
||||
ContextLayer.L1_LEGACY: 1.0,
|
||||
},
|
||||
}
|
||||
|
||||
score = base_scores[decision_type].get(layer, 0.0)
|
||||
|
||||
# Check data availability
|
||||
latest_timeframe = self.store.get_latest_timeframe(layer)
|
||||
if latest_timeframe is None:
|
||||
# No data available - reduce score significantly
|
||||
score *= 0.1
|
||||
|
||||
return score
|
||||
|
||||
def select_with_scoring(
|
||||
self,
|
||||
decision_type: DecisionType = DecisionType.NORMAL,
|
||||
min_score: float = 0.5,
|
||||
) -> ContextSelection:
|
||||
"""Select context layers with relevance scoring.
|
||||
|
||||
Args:
|
||||
decision_type: Type of decision being made
|
||||
min_score: Minimum relevance score to include a layer
|
||||
|
||||
Returns:
|
||||
ContextSelection with selected layers and scores
|
||||
"""
|
||||
all_layers = [
|
||||
ContextLayer.L7_REALTIME,
|
||||
ContextLayer.L6_DAILY,
|
||||
ContextLayer.L5_WEEKLY,
|
||||
ContextLayer.L4_MONTHLY,
|
||||
ContextLayer.L3_QUARTERLY,
|
||||
ContextLayer.L2_ANNUAL,
|
||||
ContextLayer.L1_LEGACY,
|
||||
]
|
||||
|
||||
scores = {
|
||||
layer: self.score_layer_relevance(layer, decision_type) for layer in all_layers
|
||||
}
|
||||
|
||||
# Filter by minimum score
|
||||
selected_layers = [layer for layer, score in scores.items() if score >= min_score]
|
||||
|
||||
# Sort by score (descending)
|
||||
selected_layers.sort(key=lambda layer: scores[layer], reverse=True)
|
||||
|
||||
total_score = sum(scores[layer] for layer in selected_layers)
|
||||
|
||||
return ContextSelection(
|
||||
layers=selected_layers,
|
||||
relevance_scores=scores,
|
||||
total_score=total_score,
|
||||
)
|
||||
|
||||
def get_context_data(
|
||||
self,
|
||||
layers: list[ContextLayer],
|
||||
max_items_per_layer: int = 10,
|
||||
) -> dict[str, Any]:
|
||||
"""Retrieve context data for selected layers.
|
||||
|
||||
Args:
|
||||
layers: List of context layers to retrieve
|
||||
max_items_per_layer: Maximum number of items per layer
|
||||
|
||||
Returns:
|
||||
Dictionary with context data organized by layer
|
||||
"""
|
||||
result: dict[str, Any] = {}
|
||||
|
||||
for layer in layers:
|
||||
# Get latest timeframe for this layer
|
||||
latest_timeframe = self.store.get_latest_timeframe(layer)
|
||||
if latest_timeframe:
|
||||
# Get all contexts for latest timeframe
|
||||
contexts = self.store.get_all_contexts(layer, latest_timeframe)
|
||||
|
||||
# Limit number of items
|
||||
if len(contexts) > max_items_per_layer:
|
||||
# Keep only first N items
|
||||
contexts = dict(list(contexts.items())[:max_items_per_layer])
|
||||
|
||||
result[layer.value] = contexts
|
||||
|
||||
return result
|
||||
|
||||
def estimate_context_tokens(self, context_data: dict[str, Any]) -> int:
|
||||
"""Estimate total tokens for context data.
|
||||
|
||||
Args:
|
||||
context_data: Context data dictionary
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
import json
|
||||
|
||||
from src.brain.prompt_optimizer import PromptOptimizer
|
||||
|
||||
# Serialize to JSON and estimate tokens
|
||||
json_str = json.dumps(context_data, ensure_ascii=False)
|
||||
return PromptOptimizer.estimate_tokens(json_str)
|
||||
|
||||
def optimize_context_for_budget(
|
||||
self,
|
||||
decision_type: DecisionType,
|
||||
max_tokens: int,
|
||||
) -> dict[str, Any]:
|
||||
"""Select and retrieve context data within a token budget.
|
||||
|
||||
Args:
|
||||
decision_type: Type of decision being made
|
||||
max_tokens: Maximum token budget for context
|
||||
|
||||
Returns:
|
||||
Optimized context data within budget
|
||||
"""
|
||||
# Start with minimal selection
|
||||
selection = self.select_with_scoring(decision_type, min_score=0.5)
|
||||
|
||||
# Retrieve data
|
||||
context_data = self.get_context_data(selection.layers)
|
||||
|
||||
# Check if within budget
|
||||
estimated_tokens = self.estimate_context_tokens(context_data)
|
||||
|
||||
if estimated_tokens <= max_tokens:
|
||||
return context_data
|
||||
|
||||
# If over budget, progressively reduce
|
||||
# 1. Reduce items per layer
|
||||
for max_items in [5, 3, 1]:
|
||||
context_data = self.get_context_data(selection.layers, max_items)
|
||||
estimated_tokens = self.estimate_context_tokens(context_data)
|
||||
if estimated_tokens <= max_tokens:
|
||||
return context_data
|
||||
|
||||
# 2. Remove lower-priority layers
|
||||
for min_score in [0.6, 0.7, 0.8, 0.9]:
|
||||
selection = self.select_with_scoring(decision_type, min_score=min_score)
|
||||
context_data = self.get_context_data(selection.layers, max_items_per_layer=1)
|
||||
estimated_tokens = self.estimate_context_tokens(context_data)
|
||||
if estimated_tokens <= max_tokens:
|
||||
return context_data
|
||||
|
||||
# Last resort: return only L7 with minimal data
|
||||
return self.get_context_data([ContextLayer.L7_REALTIME], max_items_per_layer=1)
|
||||
@@ -7,7 +7,12 @@ Includes token efficiency optimizations:
|
||||
- Prompt compression and abbreviation
|
||||
- Response caching for common scenarios
|
||||
- Smart context selection
|
||||
- Token usage tracking
|
||||
- Token usage tracking and metrics
|
||||
|
||||
Includes external data integration:
|
||||
- News sentiment analysis
|
||||
- Economic calendar events
|
||||
- Market indicators
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -15,7 +20,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from google import genai
|
||||
|
||||
267
src/brain/prompt_optimizer.py
Normal file
267
src/brain/prompt_optimizer.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""Prompt optimization utilities for reducing token usage.
|
||||
|
||||
This module provides tools to compress prompts while maintaining decision quality:
|
||||
- Token counting
|
||||
- Text compression and abbreviation
|
||||
- Template-based prompts with variable slots
|
||||
- Priority-based context truncation
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
# Abbreviation mapping for common terms
|
||||
ABBREVIATIONS = {
|
||||
"price": "P",
|
||||
"volume": "V",
|
||||
"current": "cur",
|
||||
"previous": "prev",
|
||||
"change": "chg",
|
||||
"percentage": "pct",
|
||||
"market": "mkt",
|
||||
"orderbook": "ob",
|
||||
"foreigner": "fgn",
|
||||
"buy": "B",
|
||||
"sell": "S",
|
||||
"hold": "H",
|
||||
"confidence": "conf",
|
||||
"rationale": "reason",
|
||||
"action": "act",
|
||||
"net": "net",
|
||||
}
|
||||
|
||||
# Reverse mapping for decompression
|
||||
REVERSE_ABBREVIATIONS = {v: k for k, v in ABBREVIATIONS.items()}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TokenMetrics:
|
||||
"""Metrics about token usage in a prompt."""
|
||||
|
||||
char_count: int
|
||||
word_count: int
|
||||
estimated_tokens: int # Rough estimate: ~4 chars per token
|
||||
compression_ratio: float = 1.0 # Original / Compressed
|
||||
|
||||
|
||||
class PromptOptimizer:
|
||||
"""Optimizes prompts to reduce token usage while maintaining quality."""
|
||||
|
||||
@staticmethod
|
||||
def estimate_tokens(text: str) -> int:
|
||||
"""Estimate token count for text.
|
||||
|
||||
Uses a simple heuristic: ~4 characters per token for English.
|
||||
This is approximate but sufficient for optimization purposes.
|
||||
|
||||
Args:
|
||||
text: Input text to estimate tokens for
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
# Simple estimate: 1 token ≈ 4 characters
|
||||
return max(1, len(text) // 4)
|
||||
|
||||
@staticmethod
|
||||
def count_tokens(text: str) -> TokenMetrics:
|
||||
"""Count various metrics for a text.
|
||||
|
||||
Args:
|
||||
text: Input text to analyze
|
||||
|
||||
Returns:
|
||||
TokenMetrics with character, word, and estimated token counts
|
||||
"""
|
||||
char_count = len(text)
|
||||
word_count = len(text.split())
|
||||
estimated_tokens = PromptOptimizer.estimate_tokens(text)
|
||||
|
||||
return TokenMetrics(
|
||||
char_count=char_count,
|
||||
word_count=word_count,
|
||||
estimated_tokens=estimated_tokens,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def compress_json(data: dict[str, Any]) -> str:
|
||||
"""Compress JSON by removing whitespace.
|
||||
|
||||
Args:
|
||||
data: Dictionary to serialize
|
||||
|
||||
Returns:
|
||||
Compact JSON string without whitespace
|
||||
"""
|
||||
return json.dumps(data, separators=(",", ":"), ensure_ascii=False)
|
||||
|
||||
@staticmethod
|
||||
def abbreviate_text(text: str, aggressive: bool = False) -> str:
|
||||
"""Apply abbreviations to reduce text length.
|
||||
|
||||
Args:
|
||||
text: Input text to abbreviate
|
||||
aggressive: If True, apply more aggressive compression
|
||||
|
||||
Returns:
|
||||
Abbreviated text
|
||||
"""
|
||||
result = text
|
||||
|
||||
# Apply word-level abbreviations (case-insensitive)
|
||||
for full, abbr in ABBREVIATIONS.items():
|
||||
# Word boundaries to avoid partial replacements
|
||||
pattern = r"\b" + re.escape(full) + r"\b"
|
||||
result = re.sub(pattern, abbr, result, flags=re.IGNORECASE)
|
||||
|
||||
if aggressive:
|
||||
# Remove articles and filler words
|
||||
result = re.sub(r"\b(a|an|the)\b", "", result, flags=re.IGNORECASE)
|
||||
result = re.sub(r"\b(is|are|was|were)\b", "", result, flags=re.IGNORECASE)
|
||||
# Collapse multiple spaces
|
||||
result = re.sub(r"\s+", " ", result)
|
||||
|
||||
return result.strip()
|
||||
|
||||
@staticmethod
|
||||
def build_compressed_prompt(
|
||||
market_data: dict[str, Any],
|
||||
include_instructions: bool = True,
|
||||
max_length: int | None = None,
|
||||
) -> str:
|
||||
"""Build a compressed prompt from market data.
|
||||
|
||||
Args:
|
||||
market_data: Market data dictionary with stock info
|
||||
include_instructions: Whether to include full instructions
|
||||
max_length: Maximum character length (truncates if needed)
|
||||
|
||||
Returns:
|
||||
Compressed prompt string
|
||||
"""
|
||||
# Abbreviated market name
|
||||
market_name = market_data.get("market_name", "KR")
|
||||
if "Korea" in market_name:
|
||||
market_name = "KR"
|
||||
elif "United States" in market_name or "US" in market_name:
|
||||
market_name = "US"
|
||||
|
||||
# Core data - always included
|
||||
core_info = {
|
||||
"mkt": market_name,
|
||||
"code": market_data["stock_code"],
|
||||
"P": market_data["current_price"],
|
||||
}
|
||||
|
||||
# Optional fields
|
||||
if "orderbook" in market_data and market_data["orderbook"]:
|
||||
ob = market_data["orderbook"]
|
||||
# Compress orderbook: keep only top 3 levels
|
||||
compressed_ob = {
|
||||
"bid": ob.get("bid", [])[:3],
|
||||
"ask": ob.get("ask", [])[:3],
|
||||
}
|
||||
core_info["ob"] = compressed_ob
|
||||
|
||||
if market_data.get("foreigner_net", 0) != 0:
|
||||
core_info["fgn_net"] = market_data["foreigner_net"]
|
||||
|
||||
# Compress to JSON
|
||||
data_str = PromptOptimizer.compress_json(core_info)
|
||||
|
||||
if include_instructions:
|
||||
# Minimal instructions
|
||||
prompt = (
|
||||
f"{market_name} trader. Analyze:\n{data_str}\n\n"
|
||||
'Return JSON: {"act":"BUY"|"SELL"|"HOLD","conf":<0-100>,"reason":"<text>"}\n'
|
||||
"Rules: act=BUY/SELL/HOLD, conf=0-100, reason=concise. No markdown."
|
||||
)
|
||||
else:
|
||||
# Data only (for cached contexts where instructions are known)
|
||||
prompt = data_str
|
||||
|
||||
# Truncate if needed
|
||||
if max_length and len(prompt) > max_length:
|
||||
prompt = prompt[:max_length] + "..."
|
||||
|
||||
return prompt
|
||||
|
||||
@staticmethod
|
||||
def truncate_context(
|
||||
context: dict[str, Any],
|
||||
max_tokens: int,
|
||||
priority_keys: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Truncate context data to fit within token budget.
|
||||
|
||||
Keeps high-priority keys first, then truncates less important data.
|
||||
|
||||
Args:
|
||||
context: Context dictionary to truncate
|
||||
max_tokens: Maximum token budget
|
||||
priority_keys: List of keys to keep (in order of priority)
|
||||
|
||||
Returns:
|
||||
Truncated context dictionary
|
||||
"""
|
||||
if not context:
|
||||
return {}
|
||||
|
||||
if priority_keys is None:
|
||||
priority_keys = []
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
current_tokens = 0
|
||||
|
||||
# Add priority keys first
|
||||
for key in priority_keys:
|
||||
if key in context:
|
||||
value_str = json.dumps(context[key])
|
||||
tokens = PromptOptimizer.estimate_tokens(value_str)
|
||||
|
||||
if current_tokens + tokens <= max_tokens:
|
||||
result[key] = context[key]
|
||||
current_tokens += tokens
|
||||
else:
|
||||
break
|
||||
|
||||
# Add remaining keys if space available
|
||||
for key, value in context.items():
|
||||
if key in result:
|
||||
continue
|
||||
|
||||
value_str = json.dumps(value)
|
||||
tokens = PromptOptimizer.estimate_tokens(value_str)
|
||||
|
||||
if current_tokens + tokens <= max_tokens:
|
||||
result[key] = value
|
||||
current_tokens += tokens
|
||||
else:
|
||||
break
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def calculate_compression_ratio(original: str, compressed: str) -> float:
|
||||
"""Calculate compression ratio between original and compressed text.
|
||||
|
||||
Args:
|
||||
original: Original text
|
||||
compressed: Compressed text
|
||||
|
||||
Returns:
|
||||
Compression ratio (original_tokens / compressed_tokens)
|
||||
"""
|
||||
original_tokens = PromptOptimizer.estimate_tokens(original)
|
||||
compressed_tokens = PromptOptimizer.estimate_tokens(compressed)
|
||||
|
||||
if compressed_tokens == 0:
|
||||
return 1.0
|
||||
|
||||
return original_tokens / compressed_tokens
|
||||
Reference in New Issue
Block a user