diff --git a/src/strategy/scenario_engine.py b/src/strategy/scenario_engine.py index bf84740..59aadf1 100644 --- a/src/strategy/scenario_engine.py +++ b/src/strategy/scenario_engine.py @@ -38,8 +38,21 @@ class ScenarioEngine: """Evaluates playbook scenarios against real-time market data. No API calls — pure Python condition matching. + + Expected market_data keys: "rsi", "volume_ratio", "current_price", "price_change_pct". + Callers must normalize data source keys to match this contract. """ + @staticmethod + def _safe_float(value: Any) -> float | None: + """Safely cast a value to float. Returns None on failure.""" + if value is None: + return None + try: + return float(value) + except (ValueError, TypeError): + return None + def evaluate( self, playbook: DayPlaybook, @@ -148,25 +161,37 @@ class ScenarioEngine: checks: list[bool] = [] - rsi = market_data.get("rsi") + rsi = self._safe_float(market_data.get("rsi")) + if condition.rsi_below is not None or condition.rsi_above is not None: + if "rsi" not in market_data: + logger.warning("Condition requires 'rsi' but key missing from market_data") if condition.rsi_below is not None: checks.append(rsi is not None and rsi < condition.rsi_below) if condition.rsi_above is not None: checks.append(rsi is not None and rsi > condition.rsi_above) - volume_ratio = market_data.get("volume_ratio") + volume_ratio = self._safe_float(market_data.get("volume_ratio")) + if condition.volume_ratio_above is not None or condition.volume_ratio_below is not None: + if "volume_ratio" not in market_data: + logger.warning("Condition requires 'volume_ratio' but key missing from market_data") if condition.volume_ratio_above is not None: checks.append(volume_ratio is not None and volume_ratio > condition.volume_ratio_above) if condition.volume_ratio_below is not None: checks.append(volume_ratio is not None and volume_ratio < condition.volume_ratio_below) - price = market_data.get("current_price") + price = self._safe_float(market_data.get("current_price")) + if condition.price_above is not None or condition.price_below is not None: + if "current_price" not in market_data: + logger.warning("Condition requires 'current_price' but key missing from market_data") if condition.price_above is not None: checks.append(price is not None and price > condition.price_above) if condition.price_below is not None: checks.append(price is not None and price < condition.price_below) - price_change_pct = market_data.get("price_change_pct") + price_change_pct = self._safe_float(market_data.get("price_change_pct")) + if condition.price_change_pct_above is not None or condition.price_change_pct_below is not None: + if "price_change_pct" not in market_data: + logger.warning("Condition requires 'price_change_pct' but key missing from market_data") if condition.price_change_pct_above is not None: checks.append(price_change_pct is not None and price_change_pct > condition.price_change_pct_above) if condition.price_change_pct_below is not None: diff --git a/tests/test_scenario_engine.py b/tests/test_scenario_engine.py index e440fa0..59163eb 100644 --- a/tests/test_scenario_engine.py +++ b/tests/test_scenario_engine.py @@ -139,6 +139,50 @@ class TestEvaluateCondition: cond = StockCondition(rsi_above=70.0) assert not engine.evaluate_condition(cond, {"rsi": 70.0}) + def test_string_value_no_exception(self, engine: ScenarioEngine) -> None: + """String numeric value should not raise TypeError.""" + cond = StockCondition(rsi_below=30.0) + # "25" can be cast to float → should match + assert engine.evaluate_condition(cond, {"rsi": "25"}) + # "35" → should not match + assert not engine.evaluate_condition(cond, {"rsi": "35"}) + + def test_percent_string_returns_false(self, engine: ScenarioEngine) -> None: + """Percent string like '30%' cannot be cast to float → False, no exception.""" + cond = StockCondition(rsi_below=30.0) + assert not engine.evaluate_condition(cond, {"rsi": "30%"}) + + def test_decimal_value_no_exception(self, engine: ScenarioEngine) -> None: + """Decimal values should be safely handled.""" + from decimal import Decimal + + cond = StockCondition(rsi_below=30.0) + assert engine.evaluate_condition(cond, {"rsi": Decimal("25.0")}) + + def test_mixed_invalid_types_no_exception(self, engine: ScenarioEngine) -> None: + """Various invalid types should not raise exceptions.""" + cond = StockCondition( + rsi_below=30.0, volume_ratio_above=2.0, + price_above=100, price_change_pct_below=-1.0, + ) + data = { + "rsi": [25], # list + "volume_ratio": "bad", # non-numeric string + "current_price": {}, # dict + "price_change_pct": object(), # arbitrary object + } + # Should return False (invalid types → None → False), never raise + assert not engine.evaluate_condition(cond, data) + + def test_missing_key_logs_warning(self, engine: ScenarioEngine, caplog) -> None: + """Missing market_data key should log a warning.""" + import logging + + cond = StockCondition(rsi_below=30.0) + with caplog.at_level(logging.WARNING): + engine.evaluate_condition(cond, {}) + assert "key missing from market_data" in caplog.text + # --------------------------------------------------------------------------- # check_global_rules