diff --git a/src/strategy/scenario_engine.py b/src/strategy/scenario_engine.py new file mode 100644 index 0000000..bf84740 --- /dev/null +++ b/src/strategy/scenario_engine.py @@ -0,0 +1,236 @@ +"""Local scenario engine for playbook execution. + +Matches real-time market conditions against pre-defined scenarios +without any API calls. Designed for sub-100ms execution. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import Any + +from src.strategy.models import ( + DayPlaybook, + GlobalRule, + ScenarioAction, + StockCondition, + StockScenario, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class ScenarioMatch: + """Result of matching market conditions against scenarios.""" + + stock_code: str + matched_scenario: StockScenario | None + action: ScenarioAction + confidence: int + rationale: str + global_rule_triggered: GlobalRule | None = None + match_details: dict[str, Any] = field(default_factory=dict) + + +class ScenarioEngine: + """Evaluates playbook scenarios against real-time market data. + + No API calls — pure Python condition matching. + """ + + def evaluate( + self, + playbook: DayPlaybook, + stock_code: str, + market_data: dict[str, Any], + portfolio_data: dict[str, Any], + ) -> ScenarioMatch: + """Match market conditions to scenarios and return a decision. + + Algorithm: + 1. Check global rules first (portfolio-level circuit breakers) + 2. Find the StockPlaybook for the given stock_code + 3. Iterate scenarios in order (first match wins) + 4. If no match, return playbook.default_action (HOLD) + + Args: + playbook: Today's DayPlaybook for this market + stock_code: Stock ticker to evaluate + market_data: Real-time market data (price, rsi, volume_ratio, etc.) + portfolio_data: Portfolio state (pnl_pct, total_cash, etc.) + + Returns: + ScenarioMatch with the decision + """ + # 1. Check global rules + triggered_rule = self.check_global_rules(playbook, portfolio_data) + if triggered_rule is not None: + logger.info( + "Global rule triggered for %s: %s -> %s", + stock_code, + triggered_rule.condition, + triggered_rule.action.value, + ) + return ScenarioMatch( + stock_code=stock_code, + matched_scenario=None, + action=triggered_rule.action, + confidence=100, + rationale=f"Global rule: {triggered_rule.rationale or triggered_rule.condition}", + global_rule_triggered=triggered_rule, + ) + + # 2. Find stock playbook + stock_pb = playbook.get_stock_playbook(stock_code) + if stock_pb is None: + logger.debug("No playbook for %s — defaulting to %s", stock_code, playbook.default_action) + return ScenarioMatch( + stock_code=stock_code, + matched_scenario=None, + action=playbook.default_action, + confidence=0, + rationale=f"No scenarios defined for {stock_code}", + ) + + # 3. Iterate scenarios (first match wins) + for scenario in stock_pb.scenarios: + if self.evaluate_condition(scenario.condition, market_data): + logger.info( + "Scenario matched for %s: %s (confidence=%d)", + stock_code, + scenario.action.value, + scenario.confidence, + ) + return ScenarioMatch( + stock_code=stock_code, + matched_scenario=scenario, + action=scenario.action, + confidence=scenario.confidence, + rationale=scenario.rationale, + match_details=self._build_match_details(scenario.condition, market_data), + ) + + # 4. No match — default action + logger.debug("No scenario matched for %s — defaulting to %s", stock_code, playbook.default_action) + return ScenarioMatch( + stock_code=stock_code, + matched_scenario=None, + action=playbook.default_action, + confidence=0, + rationale="No scenario conditions met — holding position", + ) + + def check_global_rules( + self, + playbook: DayPlaybook, + portfolio_data: dict[str, Any], + ) -> GlobalRule | None: + """Check portfolio-level rules. Returns first triggered rule or None.""" + for rule in playbook.global_rules: + if self._evaluate_global_condition(rule.condition, portfolio_data): + return rule + return None + + def evaluate_condition( + self, + condition: StockCondition, + market_data: dict[str, Any], + ) -> bool: + """Evaluate all non-None fields in condition as AND. + + Returns True only if ALL specified conditions are met. + Empty condition (no fields set) returns False for safety. + """ + if not condition.has_any_condition(): + return False + + checks: list[bool] = [] + + rsi = market_data.get("rsi") + if condition.rsi_below is not None: + checks.append(rsi is not None and rsi < condition.rsi_below) + if condition.rsi_above is not None: + checks.append(rsi is not None and rsi > condition.rsi_above) + + volume_ratio = market_data.get("volume_ratio") + if condition.volume_ratio_above is not None: + checks.append(volume_ratio is not None and volume_ratio > condition.volume_ratio_above) + if condition.volume_ratio_below is not None: + checks.append(volume_ratio is not None and volume_ratio < condition.volume_ratio_below) + + price = market_data.get("current_price") + if condition.price_above is not None: + checks.append(price is not None and price > condition.price_above) + if condition.price_below is not None: + checks.append(price is not None and price < condition.price_below) + + price_change_pct = market_data.get("price_change_pct") + if condition.price_change_pct_above is not None: + checks.append(price_change_pct is not None and price_change_pct > condition.price_change_pct_above) + if condition.price_change_pct_below is not None: + checks.append(price_change_pct is not None and price_change_pct < condition.price_change_pct_below) + + return len(checks) > 0 and all(checks) + + def _evaluate_global_condition( + self, + condition_str: str, + portfolio_data: dict[str, Any], + ) -> bool: + """Evaluate a simple global condition string against portfolio data. + + Supports: "field < value", "field > value", "field <= value", "field >= value" + """ + parts = condition_str.strip().split() + if len(parts) != 3: + logger.warning("Invalid global condition format: %s", condition_str) + return False + + field_name, operator, value_str = parts + try: + threshold = float(value_str) + except ValueError: + logger.warning("Invalid threshold in condition: %s", condition_str) + return False + + actual = portfolio_data.get(field_name) + if actual is None: + return False + + try: + actual_val = float(actual) + except (ValueError, TypeError): + return False + + if operator == "<": + return actual_val < threshold + elif operator == ">": + return actual_val > threshold + elif operator == "<=": + return actual_val <= threshold + elif operator == ">=": + return actual_val >= threshold + else: + logger.warning("Unknown operator in condition: %s", operator) + return False + + def _build_match_details( + self, + condition: StockCondition, + market_data: dict[str, Any], + ) -> dict[str, Any]: + """Build a summary of which conditions matched and their values.""" + details: dict[str, Any] = {} + + if condition.rsi_below is not None or condition.rsi_above is not None: + details["rsi"] = market_data.get("rsi") + if condition.volume_ratio_above is not None or condition.volume_ratio_below is not None: + details["volume_ratio"] = market_data.get("volume_ratio") + if condition.price_above is not None or condition.price_below is not None: + details["current_price"] = market_data.get("current_price") + if condition.price_change_pct_above is not None or condition.price_change_pct_below is not None: + details["price_change_pct"] = market_data.get("price_change_pct") + + return details diff --git a/tests/test_scenario_engine.py b/tests/test_scenario_engine.py new file mode 100644 index 0000000..e440fa0 --- /dev/null +++ b/tests/test_scenario_engine.py @@ -0,0 +1,385 @@ +"""Tests for the local scenario engine.""" + +from __future__ import annotations + +from datetime import date + +import pytest + +from src.strategy.models import ( + DayPlaybook, + GlobalRule, + ScenarioAction, + StockCondition, + StockPlaybook, + StockScenario, +) +from src.strategy.scenario_engine import ScenarioEngine, ScenarioMatch + + +@pytest.fixture +def engine() -> ScenarioEngine: + return ScenarioEngine() + + +def _scenario( + rsi_below: float | None = None, + rsi_above: float | None = None, + volume_ratio_above: float | None = None, + action: ScenarioAction = ScenarioAction.BUY, + confidence: int = 85, + **kwargs, +) -> StockScenario: + return StockScenario( + condition=StockCondition( + rsi_below=rsi_below, + rsi_above=rsi_above, + volume_ratio_above=volume_ratio_above, + **kwargs, + ), + action=action, + confidence=confidence, + rationale=f"Test scenario: {action.value}", + ) + + +def _playbook( + stock_code: str = "005930", + scenarios: list[StockScenario] | None = None, + global_rules: list[GlobalRule] | None = None, + default_action: ScenarioAction = ScenarioAction.HOLD, +) -> DayPlaybook: + if scenarios is None: + scenarios = [_scenario(rsi_below=30.0)] + return DayPlaybook( + date=date(2026, 2, 7), + market="KR", + stock_playbooks=[StockPlaybook(stock_code=stock_code, scenarios=scenarios)], + global_rules=global_rules or [], + default_action=default_action, + ) + + +# --------------------------------------------------------------------------- +# evaluate_condition +# --------------------------------------------------------------------------- + + +class TestEvaluateCondition: + def test_rsi_below_match(self, engine: ScenarioEngine) -> None: + cond = StockCondition(rsi_below=30.0) + assert engine.evaluate_condition(cond, {"rsi": 25.0}) + + def test_rsi_below_no_match(self, engine: ScenarioEngine) -> None: + cond = StockCondition(rsi_below=30.0) + assert not engine.evaluate_condition(cond, {"rsi": 35.0}) + + def test_rsi_above_match(self, engine: ScenarioEngine) -> None: + cond = StockCondition(rsi_above=70.0) + assert engine.evaluate_condition(cond, {"rsi": 75.0}) + + def test_rsi_above_no_match(self, engine: ScenarioEngine) -> None: + cond = StockCondition(rsi_above=70.0) + assert not engine.evaluate_condition(cond, {"rsi": 65.0}) + + def test_volume_ratio_above_match(self, engine: ScenarioEngine) -> None: + cond = StockCondition(volume_ratio_above=3.0) + assert engine.evaluate_condition(cond, {"volume_ratio": 4.5}) + + def test_volume_ratio_below_match(self, engine: ScenarioEngine) -> None: + cond = StockCondition(volume_ratio_below=1.0) + assert engine.evaluate_condition(cond, {"volume_ratio": 0.5}) + + def test_price_above_match(self, engine: ScenarioEngine) -> None: + cond = StockCondition(price_above=50000) + assert engine.evaluate_condition(cond, {"current_price": 55000}) + + def test_price_below_match(self, engine: ScenarioEngine) -> None: + cond = StockCondition(price_below=50000) + assert engine.evaluate_condition(cond, {"current_price": 45000}) + + def test_price_change_pct_above_match(self, engine: ScenarioEngine) -> None: + cond = StockCondition(price_change_pct_above=2.0) + assert engine.evaluate_condition(cond, {"price_change_pct": 3.5}) + + def test_price_change_pct_below_match(self, engine: ScenarioEngine) -> None: + cond = StockCondition(price_change_pct_below=-3.0) + assert engine.evaluate_condition(cond, {"price_change_pct": -4.0}) + + def test_multiple_conditions_and_logic(self, engine: ScenarioEngine) -> None: + cond = StockCondition(rsi_below=30.0, volume_ratio_above=3.0) + # Both met + assert engine.evaluate_condition(cond, {"rsi": 25.0, "volume_ratio": 4.0}) + # Only RSI met + assert not engine.evaluate_condition(cond, {"rsi": 25.0, "volume_ratio": 2.0}) + # Only volume met + assert not engine.evaluate_condition(cond, {"rsi": 35.0, "volume_ratio": 4.0}) + # Neither met + assert not engine.evaluate_condition(cond, {"rsi": 35.0, "volume_ratio": 2.0}) + + def test_empty_condition_returns_false(self, engine: ScenarioEngine) -> None: + cond = StockCondition() + assert not engine.evaluate_condition(cond, {"rsi": 25.0}) + + def test_missing_data_returns_false(self, engine: ScenarioEngine) -> None: + cond = StockCondition(rsi_below=30.0) + assert not engine.evaluate_condition(cond, {}) + + def test_none_data_returns_false(self, engine: ScenarioEngine) -> None: + cond = StockCondition(rsi_below=30.0) + assert not engine.evaluate_condition(cond, {"rsi": None}) + + def test_boundary_value_not_matched(self, engine: ScenarioEngine) -> None: + """rsi_below=30 should NOT match rsi=30 (strict less than).""" + cond = StockCondition(rsi_below=30.0) + assert not engine.evaluate_condition(cond, {"rsi": 30.0}) + + def test_boundary_value_above_not_matched(self, engine: ScenarioEngine) -> None: + """rsi_above=70 should NOT match rsi=70 (strict greater than).""" + cond = StockCondition(rsi_above=70.0) + assert not engine.evaluate_condition(cond, {"rsi": 70.0}) + + +# --------------------------------------------------------------------------- +# check_global_rules +# --------------------------------------------------------------------------- + + +class TestCheckGlobalRules: + def test_no_rules(self, engine: ScenarioEngine) -> None: + pb = _playbook(global_rules=[]) + result = engine.check_global_rules(pb, {"portfolio_pnl_pct": -1.0}) + assert result is None + + def test_rule_triggered(self, engine: ScenarioEngine) -> None: + pb = _playbook( + global_rules=[ + GlobalRule( + condition="portfolio_pnl_pct < -2.0", + action=ScenarioAction.REDUCE_ALL, + rationale="Near circuit breaker", + ), + ] + ) + result = engine.check_global_rules(pb, {"portfolio_pnl_pct": -2.5}) + assert result is not None + assert result.action == ScenarioAction.REDUCE_ALL + + def test_rule_not_triggered(self, engine: ScenarioEngine) -> None: + pb = _playbook( + global_rules=[ + GlobalRule( + condition="portfolio_pnl_pct < -2.0", + action=ScenarioAction.REDUCE_ALL, + ), + ] + ) + result = engine.check_global_rules(pb, {"portfolio_pnl_pct": -1.0}) + assert result is None + + def test_first_rule_wins(self, engine: ScenarioEngine) -> None: + pb = _playbook( + global_rules=[ + GlobalRule(condition="portfolio_pnl_pct < -2.0", action=ScenarioAction.REDUCE_ALL), + GlobalRule(condition="portfolio_pnl_pct < -1.0", action=ScenarioAction.HOLD), + ] + ) + result = engine.check_global_rules(pb, {"portfolio_pnl_pct": -2.5}) + assert result is not None + assert result.action == ScenarioAction.REDUCE_ALL + + def test_greater_than_operator(self, engine: ScenarioEngine) -> None: + pb = _playbook( + global_rules=[ + GlobalRule(condition="volatility_index > 30", action=ScenarioAction.HOLD), + ] + ) + result = engine.check_global_rules(pb, {"volatility_index": 35}) + assert result is not None + + def test_missing_field_not_triggered(self, engine: ScenarioEngine) -> None: + pb = _playbook( + global_rules=[ + GlobalRule(condition="unknown_field < -2.0", action=ScenarioAction.REDUCE_ALL), + ] + ) + result = engine.check_global_rules(pb, {"portfolio_pnl_pct": -5.0}) + assert result is None + + def test_invalid_condition_format(self, engine: ScenarioEngine) -> None: + pb = _playbook( + global_rules=[ + GlobalRule(condition="bad format", action=ScenarioAction.HOLD), + ] + ) + result = engine.check_global_rules(pb, {}) + assert result is None + + def test_le_operator(self, engine: ScenarioEngine) -> None: + pb = _playbook( + global_rules=[ + GlobalRule(condition="portfolio_pnl_pct <= -2.0", action=ScenarioAction.REDUCE_ALL), + ] + ) + assert engine.check_global_rules(pb, {"portfolio_pnl_pct": -2.0}) is not None + assert engine.check_global_rules(pb, {"portfolio_pnl_pct": -1.9}) is None + + def test_ge_operator(self, engine: ScenarioEngine) -> None: + pb = _playbook( + global_rules=[ + GlobalRule(condition="volatility >= 80.0", action=ScenarioAction.HOLD), + ] + ) + assert engine.check_global_rules(pb, {"volatility": 80.0}) is not None + assert engine.check_global_rules(pb, {"volatility": 79.9}) is None + + +# --------------------------------------------------------------------------- +# evaluate (full pipeline) +# --------------------------------------------------------------------------- + + +class TestEvaluate: + def test_scenario_match(self, engine: ScenarioEngine) -> None: + pb = _playbook(scenarios=[_scenario(rsi_below=30.0)]) + result = engine.evaluate(pb, "005930", {"rsi": 25.0}, {}) + assert result.action == ScenarioAction.BUY + assert result.confidence == 85 + assert result.matched_scenario is not None + + def test_no_scenario_match_returns_default(self, engine: ScenarioEngine) -> None: + pb = _playbook(scenarios=[_scenario(rsi_below=30.0)]) + result = engine.evaluate(pb, "005930", {"rsi": 50.0}, {}) + assert result.action == ScenarioAction.HOLD + assert result.confidence == 0 + assert result.matched_scenario is None + + def test_stock_not_in_playbook(self, engine: ScenarioEngine) -> None: + pb = _playbook(stock_code="005930") + result = engine.evaluate(pb, "AAPL", {"rsi": 25.0}, {}) + assert result.action == ScenarioAction.HOLD + assert result.confidence == 0 + + def test_global_rule_takes_priority(self, engine: ScenarioEngine) -> None: + pb = _playbook( + scenarios=[_scenario(rsi_below=30.0)], + global_rules=[ + GlobalRule( + condition="portfolio_pnl_pct < -2.0", + action=ScenarioAction.REDUCE_ALL, + rationale="Loss limit", + ), + ], + ) + result = engine.evaluate( + pb, + "005930", + {"rsi": 25.0}, # Would match scenario + {"portfolio_pnl_pct": -2.5}, # But global rule triggers first + ) + assert result.action == ScenarioAction.REDUCE_ALL + assert result.global_rule_triggered is not None + assert result.matched_scenario is None + + def test_first_scenario_wins(self, engine: ScenarioEngine) -> None: + pb = _playbook( + scenarios=[ + _scenario(rsi_below=30.0, action=ScenarioAction.BUY, confidence=90), + _scenario(rsi_below=25.0, action=ScenarioAction.BUY, confidence=95), + ] + ) + result = engine.evaluate(pb, "005930", {"rsi": 20.0}, {}) + # Both match, but first wins + assert result.confidence == 90 + + def test_sell_scenario(self, engine: ScenarioEngine) -> None: + pb = _playbook( + scenarios=[ + _scenario(rsi_above=75.0, action=ScenarioAction.SELL, confidence=80), + ] + ) + result = engine.evaluate(pb, "005930", {"rsi": 80.0}, {}) + assert result.action == ScenarioAction.SELL + + def test_empty_playbook(self, engine: ScenarioEngine) -> None: + pb = DayPlaybook(date=date(2026, 2, 7), market="KR", stock_playbooks=[]) + result = engine.evaluate(pb, "005930", {"rsi": 25.0}, {}) + assert result.action == ScenarioAction.HOLD + + def test_match_details_populated(self, engine: ScenarioEngine) -> None: + pb = _playbook(scenarios=[_scenario(rsi_below=30.0, volume_ratio_above=2.0)]) + result = engine.evaluate( + pb, "005930", {"rsi": 25.0, "volume_ratio": 3.0}, {} + ) + assert result.match_details.get("rsi") == 25.0 + assert result.match_details.get("volume_ratio") == 3.0 + + def test_custom_default_action(self, engine: ScenarioEngine) -> None: + pb = _playbook( + scenarios=[_scenario(rsi_below=10.0)], # Very unlikely to match + default_action=ScenarioAction.SELL, + ) + result = engine.evaluate(pb, "005930", {"rsi": 50.0}, {}) + assert result.action == ScenarioAction.SELL + + def test_multiple_stocks_in_playbook(self, engine: ScenarioEngine) -> None: + pb = DayPlaybook( + date=date(2026, 2, 7), + market="US", + stock_playbooks=[ + StockPlaybook( + stock_code="AAPL", + scenarios=[_scenario(rsi_below=25.0, confidence=90)], + ), + StockPlaybook( + stock_code="MSFT", + scenarios=[_scenario(rsi_above=75.0, action=ScenarioAction.SELL, confidence=80)], + ), + ], + ) + aapl = engine.evaluate(pb, "AAPL", {"rsi": 20.0}, {}) + assert aapl.action == ScenarioAction.BUY + assert aapl.confidence == 90 + + msft = engine.evaluate(pb, "MSFT", {"rsi": 80.0}, {}) + assert msft.action == ScenarioAction.SELL + + def test_complex_multi_condition(self, engine: ScenarioEngine) -> None: + pb = _playbook( + scenarios=[ + _scenario( + rsi_below=30.0, + volume_ratio_above=3.0, + price_change_pct_below=-2.0, + confidence=95, + ), + ] + ) + # All conditions met + result = engine.evaluate( + pb, + "005930", + {"rsi": 22.0, "volume_ratio": 4.0, "price_change_pct": -3.0}, + {}, + ) + assert result.action == ScenarioAction.BUY + assert result.confidence == 95 + + # One condition not met + result2 = engine.evaluate( + pb, + "005930", + {"rsi": 22.0, "volume_ratio": 4.0, "price_change_pct": -1.0}, + {}, + ) + assert result2.action == ScenarioAction.HOLD + + def test_scenario_match_returns_rationale(self, engine: ScenarioEngine) -> None: + pb = _playbook(scenarios=[_scenario(rsi_below=30.0)]) + result = engine.evaluate(pb, "005930", {"rsi": 25.0}, {}) + assert result.rationale != "" + + def test_result_stock_code(self, engine: ScenarioEngine) -> None: + pb = _playbook() + result = engine.evaluate(pb, "005930", {"rsi": 25.0}, {}) + assert result.stock_code == "005930"