Files
The-Ouroboros/src/strategy/models.py
agentson 60a22d6cd4
Some checks failed
CI / test (pull_request) Has been cancelled
feat: add position-aware conditions to StockCondition (#171)
- Add unrealized_pnl_pct_above/below and holding_days_above/below fields
  to StockCondition so AI can generate rules like 'P&L > 3% → SELL'
- Evaluate new fields in ScenarioEngine.evaluate_condition() with same
  AND-combining logic as existing technical indicator fields
- Include position fields in _build_match_details() for audit logging
- Parse new fields from AI JSON response in PreMarketPlanner._parse_scenario()
- Update prompt schema example to show new position-aware condition fields
- Add 13 tests covering all new condition combinations and edge cases

Closes #171

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-02-20 08:27:44 +09:00

185 lines
6.1 KiB
Python

"""Pydantic models for pre-market scenario planning.
Defines the data contracts for the proactive strategy system:
- AI generates DayPlaybook before market open (structured JSON scenarios)
- Local ScenarioEngine matches conditions during market hours (no API calls)
"""
from __future__ import annotations
from datetime import UTC, date, datetime
from enum import Enum
from pydantic import BaseModel, Field, field_validator
class ScenarioAction(str, Enum):
    """Trading actions a scenario or rule may prescribe.

    The ``str`` mix-in makes every member compare equal to its plain
    string value (e.g. ``ScenarioAction.BUY == "BUY"``).
    """

    BUY = "BUY"
    SELL = "SELL"
    HOLD = "HOLD"
    REDUCE_ALL = "REDUCE_ALL"
class MarketOutlook(str, Enum):
    """Directional market assessment produced by the AI.

    Members are declared from most bullish to most bearish; values are
    lowercase strings and, via the ``str`` mix-in, compare equal to them.
    """

    BULLISH = "bullish"
    NEUTRAL_TO_BULLISH = "neutral_to_bullish"
    NEUTRAL = "neutral"
    NEUTRAL_TO_BEARISH = "neutral_to_bearish"
    BEARISH = "bearish"
class PlaybookStatus(str, Enum):
    """Lifecycle states a playbook can be in.

    Values are lowercase strings; the ``str`` mix-in lets members compare
    equal to those strings directly.
    """

    PENDING = "pending"
    READY = "ready"
    FAILED = "failed"
    EXPIRED = "expired"
class StockCondition(BaseModel):
    """Condition fields for scenario matching (all optional, AND-combined).

    The ScenarioEngine evaluates all non-None fields as AND conditions.
    A condition matches only if ALL specified fields are satisfied.

    Technical indicator fields:
        rsi_below / rsi_above — RSI threshold
        volume_ratio_above / volume_ratio_below — volume vs previous day
        price_above / price_below — absolute price level
        price_change_pct_above / price_change_pct_below — intraday % change

    Position-aware fields (require market_data enrichment from open position):
        unrealized_pnl_pct_above — matches if unrealized P&L > threshold (e.g. 3.0 → +3%)
        unrealized_pnl_pct_below — matches if unrealized P&L < threshold (e.g. -2.0 → -2%)
        holding_days_above — matches if position held for more than N days
        holding_days_below — matches if position held for fewer than N days
    """

    # Technical indicator thresholds.
    rsi_below: float | None = None
    rsi_above: float | None = None
    volume_ratio_above: float | None = None
    volume_ratio_below: float | None = None
    price_above: float | None = None
    price_below: float | None = None
    price_change_pct_above: float | None = None
    price_change_pct_below: float | None = None

    # Position-aware thresholds.
    unrealized_pnl_pct_above: float | None = None
    unrealized_pnl_pct_below: float | None = None
    holding_days_above: int | None = None
    holding_days_below: int | None = None

    def has_any_condition(self) -> bool:
        """Check if at least one condition field is set.

        A field counts as set when its value is not ``None``, so a
        threshold of ``0`` or ``0.0`` still counts.

        Returns:
            True if any declared field holds a non-None value, else False.
        """
        # Enumerate the declared fields from the model class instead of
        # repeating them by hand, so a newly added condition field can
        # never be forgotten here.
        return any(
            getattr(self, field_name) is not None
            for field_name in type(self).model_fields
        )
class StockScenario(BaseModel):
    """One condition → action rule for a single stock."""

    # What must hold, and what to do when it does.
    condition: StockCondition
    action: ScenarioAction

    # Required; validated to lie within 0..100 inclusive.
    confidence: int = Field(ge=0, le=100)

    # Validated to 0..100; defaults to 10.0.
    allocation_pct: float = Field(default=10.0, ge=0, le=100)
    # Must be <= 0; defaults to -2.0.
    stop_loss_pct: float = Field(default=-2.0, le=0)
    # Must be >= 0; defaults to 3.0.
    take_profit_pct: float = Field(default=3.0, ge=0)

    rationale: str = ""
class StockPlaybook(BaseModel):
    """Every scenario defined for one stock, listed in priority order."""

    stock_code: str
    stock_name: str = ""

    # Required; validation rejects an empty list.
    scenarios: list[StockScenario] = Field(min_length=1)
class GlobalRule(BaseModel):
    """Rule at portfolio level, checked ahead of the stock-level scenarios."""

    # Free-form condition text, e.g. "portfolio_pnl_pct < -2.0".
    condition: str
    action: ScenarioAction
    rationale: str = ""
class CrossMarketContext(BaseModel):
    """Snapshot of another market's state, kept for cross-market awareness."""

    # Market identifier, e.g. "US" or "KR".
    market: str
    # Stored as a plain string in this model (not a ``datetime.date``).
    date: str

    total_pnl: float = 0.0
    win_rate: float = 0.0
    # Index move for that market, e.g. KOSPI or S&P500 change.
    index_change_pct: float = 0.0

    # Each instance gets its own fresh empty list.
    key_events: list[str] = Field(default_factory=list)
    lessons: list[str] = Field(default_factory=list)
class DayPlaybook(BaseModel):
    """Full playbook covering one market on one trading day.

    Produced by PreMarketPlanner (1 Gemini call per market per day) and
    consumed by ScenarioEngine during market hours (0 API calls).
    """

    date: date
    market: str  # "KR" or "US"
    market_outlook: MarketOutlook = MarketOutlook.NEUTRAL
    # ISO timestamp; model_post_init fills it in when left empty.
    generated_at: str = ""
    gemini_model: str = ""
    token_count: int = 0
    global_rules: list[GlobalRule] = Field(default_factory=list)
    stock_playbooks: list[StockPlaybook] = Field(default_factory=list)
    default_action: ScenarioAction = ScenarioAction.HOLD
    context_summary: dict = Field(default_factory=dict)
    cross_market: CrossMarketContext | None = None

    @field_validator("stock_playbooks")
    @classmethod
    def validate_unique_stocks(cls, v: list[StockPlaybook]) -> list[StockPlaybook]:
        """Reject a playbook list in which any stock code appears twice.

        Raises:
            ValueError: If two entries share the same ``stock_code``.
        """
        distinct_codes = {entry.stock_code for entry in v}
        # Fewer distinct codes than entries means at least one repeat.
        if len(distinct_codes) != len(v):
            raise ValueError("Duplicate stock codes in playbook")
        return v

    def get_stock_playbook(self, stock_code: str) -> StockPlaybook | None:
        """Return the first playbook whose code equals *stock_code*, or None."""
        return next(
            (entry for entry in self.stock_playbooks if entry.stock_code == stock_code),
            None,
        )

    @property
    def scenario_count(self) -> int:
        """Number of scenarios summed over every stock playbook."""
        total = 0
        for entry in self.stock_playbooks:
            total += len(entry.scenarios)
        return total

    @property
    def stock_count(self) -> int:
        """How many stock playbooks this day playbook holds."""
        return len(self.stock_playbooks)

    def model_post_init(self, __context: object) -> None:
        """Stamp ``generated_at`` with the current UTC time if it is empty."""
        if self.generated_at:
            return
        self.generated_at = datetime.now(UTC).isoformat()