Some checks failed
CI / test (pull_request) Has been cancelled
- Add unrealized_pnl_pct_above/below and holding_days_above/below fields to StockCondition so the AI can generate rules like 'P&L > 3% → SELL'
- Evaluate the new fields in ScenarioEngine.evaluate_condition() with the same AND-combining logic as the existing technical indicator fields
- Include position fields in _build_match_details() for audit logging
- Parse the new fields from the AI JSON response in PreMarketPlanner._parse_scenario()
- Update the prompt schema example to show the new position-aware condition fields
- Add 13 tests covering all new condition combinations and edge cases

Closes #171

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
185 lines
6.1 KiB
Python
185 lines
6.1 KiB
Python
"""Pydantic models for pre-market scenario planning.
|
|
|
|
Defines the data contracts for the proactive strategy system:
|
|
- AI generates DayPlaybook before market open (structured JSON scenarios)
|
|
- Local ScenarioEngine matches conditions during market hours (no API calls)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from datetime import UTC, date, datetime
|
|
from enum import Enum
|
|
|
|
from pydantic import BaseModel, Field, field_validator
|
|
|
|
|
|
class ScenarioAction(str, Enum):
    """Actions a stock scenario or a global rule can prescribe.

    Because the enum mixes in ``str``, each member compares equal to its
    raw value, e.g. ``ScenarioAction.SELL == "SELL"``.
    """

    # Values are the literal strings accepted when looking up / validating.
    BUY = "BUY"
    SELL = "SELL"
    HOLD = "HOLD"
    REDUCE_ALL = "REDUCE_ALL"
|
|
|
|
|
|
class MarketOutlook(str, Enum):
    """The AI's view of overall market direction for the trading day.

    Members are ``str`` subclasses and compare equal to their lowercase
    values, e.g. ``MarketOutlook.NEUTRAL == "neutral"``.
    """

    # Listed from most optimistic to most pessimistic.
    BULLISH = "bullish"
    NEUTRAL_TO_BULLISH = "neutral_to_bullish"
    NEUTRAL = "neutral"
    NEUTRAL_TO_BEARISH = "neutral_to_bearish"
    BEARISH = "bearish"
|
|
|
|
|
|
class PlaybookStatus(str, Enum):
    """Where a playbook currently is in its lifecycle.

    Members are ``str`` subclasses and compare equal to their lowercase
    values, e.g. ``PlaybookStatus.READY == "ready"``.
    """

    PENDING = "pending"
    READY = "ready"
    FAILED = "failed"
    EXPIRED = "expired"
|
|
|
|
|
|
class StockCondition(BaseModel):
    """Condition fields for scenario matching (all optional, AND-combined).

    The ScenarioEngine evaluates all non-None fields as AND conditions.
    A condition matches only if ALL specified fields are satisfied; fields
    left as None are ignored.

    Technical indicator fields:
        rsi_below / rsi_above — RSI threshold
        volume_ratio_above / volume_ratio_below — volume vs previous day
        price_above / price_below — absolute price level
        price_change_pct_above / price_change_pct_below — intraday % change

    Position-aware fields (require market_data enrichment from open position):
        unrealized_pnl_pct_above — matches if unrealized P&L > threshold (e.g. 3.0 → +3%)
        unrealized_pnl_pct_below — matches if unrealized P&L < threshold (e.g. -2.0 → -2%)
        holding_days_above — matches if position held for more than N days
        holding_days_below — matches if position held for fewer than N days
    """

    # --- Technical indicator fields ---
    rsi_below: float | None = None
    rsi_above: float | None = None
    volume_ratio_above: float | None = None
    volume_ratio_below: float | None = None
    price_above: float | None = None
    price_below: float | None = None
    price_change_pct_above: float | None = None
    price_change_pct_below: float | None = None

    # --- Position-aware fields ---
    unrealized_pnl_pct_above: float | None = None
    unrealized_pnl_pct_below: float | None = None
    holding_days_above: int | None = None
    holding_days_below: int | None = None

    def has_any_condition(self) -> bool:
        """Return True if at least one condition field is set (not None).

        Iterates over every declared model field instead of a hand-maintained
        list, so condition fields added to this model later are picked up
        automatically rather than being silently ignored here.
        """
        # Read model_fields from the class: instance access is deprecated in
        # recent Pydantic v2 releases.
        return any(
            getattr(self, name) is not None for name in type(self).model_fields
        )
|
|
|
|
|
|
class StockScenario(BaseModel):
    """One condition → action rule attached to a single stock.

    ``condition`` decides when the rule applies; ``action`` and the numeric
    fields below describe what the rule prescribes.
    """

    condition: StockCondition
    action: ScenarioAction

    # Required; validated to the inclusive range 0..100.
    confidence: int = Field(ge=0, le=100)
    # Validated to the inclusive range 0..100.
    allocation_pct: float = Field(default=10.0, ge=0, le=100)
    # Must be zero or negative.
    stop_loss_pct: float = Field(default=-2.0, le=0)
    # Must be zero or positive.
    take_profit_pct: float = Field(default=3.0, ge=0)

    rationale: str = Field(default="")
|
|
|
|
|
|
class StockPlaybook(BaseModel):
    """Every scenario planned for one stock, listed in priority order.

    ``scenarios`` is required and must contain at least one entry.
    """

    stock_code: str
    stock_name: str = Field(default="")
    # Ellipsis marks the field as required, same as omitting a default.
    scenarios: list[StockScenario] = Field(..., min_length=1)
|
|
|
|
|
|
class GlobalRule(BaseModel):
    """Portfolio-wide rule, evaluated ahead of the per-stock scenarios."""

    # Free-form expression string, for example "portfolio_pnl_pct < -2.0".
    condition: str
    action: ScenarioAction
    rationale: str = Field(default="")
|
|
|
|
|
|
class CrossMarketContext(BaseModel):
    """Snapshot of a different market's state, used for cross-market awareness."""

    # Market identifier, for example "US" or "KR".
    market: str
    # Kept as a plain string here (not a datetime.date).
    date: str

    total_pnl: float = Field(default=0.0)
    win_rate: float = Field(default=0.0)
    # Benchmark index move, for example KOSPI or S&P500 change.
    index_change_pct: float = Field(default=0.0)

    # default_factory gives each instance its own fresh list.
    key_events: list[str] = Field(default_factory=list)
    lessons: list[str] = Field(default_factory=list)
|
|
|
|
|
|
class DayPlaybook(BaseModel):
    """Complete playbook for a single trading day in a single market.

    Generated by PreMarketPlanner (1 Gemini call per market per day).
    Consumed by ScenarioEngine during market hours (0 API calls).
    """

    date: date
    market: str  # "KR" or "US"
    market_outlook: MarketOutlook = MarketOutlook.NEUTRAL
    generated_at: str = ""  # ISO timestamp; filled by model_post_init when empty
    gemini_model: str = ""
    token_count: int = 0
    global_rules: list[GlobalRule] = Field(default_factory=list)
    stock_playbooks: list[StockPlaybook] = Field(default_factory=list)
    default_action: ScenarioAction = ScenarioAction.HOLD
    context_summary: dict = Field(default_factory=dict)
    cross_market: CrossMarketContext | None = None

    @field_validator("stock_playbooks")
    @classmethod
    def validate_unique_stocks(cls, v: list[StockPlaybook]) -> list[StockPlaybook]:
        """Reject playbooks that contain the same stock code more than once.

        Args:
            v: The validated list of per-stock playbooks.

        Returns:
            ``v`` unchanged when every ``stock_code`` is unique.

        Raises:
            ValueError: If any ``stock_code`` repeats. The message keeps the
                original "Duplicate stock codes in playbook" prefix and
                appends the offending codes to make AI output easy to debug.
        """
        seen: set[str] = set()
        duplicates: set[str] = set()
        for pb in v:
            if pb.stock_code in seen:
                duplicates.add(pb.stock_code)
            seen.add(pb.stock_code)
        if duplicates:
            raise ValueError(
                f"Duplicate stock codes in playbook: {sorted(duplicates)}"
            )
        return v

    def get_stock_playbook(self, stock_code: str) -> StockPlaybook | None:
        """Return the playbook for ``stock_code``, or None if it has none."""
        return next(
            (pb for pb in self.stock_playbooks if pb.stock_code == stock_code),
            None,
        )

    @property
    def scenario_count(self) -> int:
        """Total number of scenarios across all stocks."""
        return sum(len(pb.scenarios) for pb in self.stock_playbooks)

    @property
    def stock_count(self) -> int:
        """Number of stocks with scenarios."""
        return len(self.stock_playbooks)

    def model_post_init(self, __context: object) -> None:
        """Stamp ``generated_at`` with the current UTC time if it is empty."""
        if not self.generated_at:
            self.generated_at = datetime.now(UTC).isoformat()
|