diff --git a/src/strategy/scenario_engine.py b/src/strategy/scenario_engine.py index 59aadf1..85164b1 100644 --- a/src/strategy/scenario_engine.py +++ b/src/strategy/scenario_engine.py @@ -43,6 +43,9 @@ class ScenarioEngine: Callers must normalize data source keys to match this contract. """ + def __init__(self) -> None: + self._warned_keys: set[str] = set() + @staticmethod def _safe_float(value: Any) -> float | None: """Safely cast a value to float. Returns None on failure.""" @@ -53,6 +56,12 @@ class ScenarioEngine: except (ValueError, TypeError): return None + def _warn_missing_key(self, key: str) -> None: + """Log a missing-key warning once per key per engine instance.""" + if key not in self._warned_keys: + self._warned_keys.add(key) + logger.warning("Condition requires '%s' but key missing from market_data", key) + def evaluate( self, playbook: DayPlaybook, @@ -164,7 +173,7 @@ class ScenarioEngine: rsi = self._safe_float(market_data.get("rsi")) if condition.rsi_below is not None or condition.rsi_above is not None: if "rsi" not in market_data: - logger.warning("Condition requires 'rsi' but key missing from market_data") + self._warn_missing_key("rsi") if condition.rsi_below is not None: checks.append(rsi is not None and rsi < condition.rsi_below) if condition.rsi_above is not None: @@ -173,7 +182,7 @@ class ScenarioEngine: volume_ratio = self._safe_float(market_data.get("volume_ratio")) if condition.volume_ratio_above is not None or condition.volume_ratio_below is not None: if "volume_ratio" not in market_data: - logger.warning("Condition requires 'volume_ratio' but key missing from market_data") + self._warn_missing_key("volume_ratio") if condition.volume_ratio_above is not None: checks.append(volume_ratio is not None and volume_ratio > condition.volume_ratio_above) if condition.volume_ratio_below is not None: @@ -182,7 +191,7 @@ class ScenarioEngine: price = self._safe_float(market_data.get("current_price")) if condition.price_above is not None or condition.price_below is not None: if "current_price" not in market_data: - logger.warning("Condition requires 'current_price' but key missing from market_data") + self._warn_missing_key("current_price") if condition.price_above is not None: checks.append(price is not None and price > condition.price_above) if condition.price_below is not None: @@ -191,7 +200,7 @@ class ScenarioEngine: price_change_pct = self._safe_float(market_data.get("price_change_pct")) if condition.price_change_pct_above is not None or condition.price_change_pct_below is not None: if "price_change_pct" not in market_data: - logger.warning("Condition requires 'price_change_pct' but key missing from market_data") + self._warn_missing_key("price_change_pct") if condition.price_change_pct_above is not None: checks.append(price_change_pct is not None and price_change_pct > condition.price_change_pct_above) if condition.price_change_pct_below is not None: @@ -246,16 +255,16 @@ class ScenarioEngine: condition: StockCondition, market_data: dict[str, Any], ) -> dict[str, Any]: - """Build a summary of which conditions matched and their values.""" + """Build a summary of which conditions matched and their normalized values.""" details: dict[str, Any] = {} if condition.rsi_below is not None or condition.rsi_above is not None: - details["rsi"] = market_data.get("rsi") + details["rsi"] = self._safe_float(market_data.get("rsi")) if condition.volume_ratio_above is not None or condition.volume_ratio_below is not None: - details["volume_ratio"] = market_data.get("volume_ratio") + details["volume_ratio"] = self._safe_float(market_data.get("volume_ratio")) if condition.price_above is not None or condition.price_below is not None: - details["current_price"] = market_data.get("current_price") + details["current_price"] = self._safe_float(market_data.get("current_price")) if condition.price_change_pct_above is not None or condition.price_change_pct_below is not None: - details["price_change_pct"] = market_data.get("price_change_pct") + details["price_change_pct"] = self._safe_float(market_data.get("price_change_pct")) return details diff --git a/tests/test_scenario_engine.py b/tests/test_scenario_engine.py index 59163eb..4d8acfe 100644 --- a/tests/test_scenario_engine.py +++ b/tests/test_scenario_engine.py @@ -174,14 +174,18 @@ class TestEvaluateCondition: # Should return False (invalid types → None → False), never raise assert not engine.evaluate_condition(cond, data) - def test_missing_key_logs_warning(self, engine: ScenarioEngine, caplog) -> None: - """Missing market_data key should log a warning.""" + def test_missing_key_logs_warning_once(self, caplog) -> None: + """Missing key warning should fire only once per key per engine instance.""" import logging + eng = ScenarioEngine() cond = StockCondition(rsi_below=30.0) with caplog.at_level(logging.WARNING): - engine.evaluate_condition(cond, {}) - assert "key missing from market_data" in caplog.text + eng.evaluate_condition(cond, {}) + eng.evaluate_condition(cond, {}) + eng.evaluate_condition(cond, {}) + # Warning should appear exactly once despite 3 calls + assert caplog.text.count("'rsi' but key missing") == 1 # --------------------------------------------------------------------------- @@ -427,3 +431,12 @@ class TestEvaluate: pb = _playbook() result = engine.evaluate(pb, "005930", {"rsi": 25.0}, {}) assert result.stock_code == "005930" + + def test_match_details_normalized(self, engine: ScenarioEngine) -> None: + """match_details should contain _safe_float normalized values, not raw.""" + pb = _playbook(scenarios=[_scenario(rsi_below=30.0)]) + # Pass string value — should be normalized to float in match_details + result = engine.evaluate(pb, "005930", {"rsi": "25.0"}, {}) + assert result.action == ScenarioAction.BUY + assert result.match_details["rsi"] == 25.0 + assert isinstance(result.match_details["rsi"], float)