fix: add safe type casting and missing-key warnings in ScenarioEngine
Some checks failed
CI / test (pull_request) Has been cancelled
Some checks failed
CI / test (pull_request) Has been cancelled
Addresses PR #102 review findings: - _safe_float() prevents TypeError from str/Decimal/invalid market_data values - Warning logs when condition references a key missing from market_data - 5 new tests: string, percent string, Decimal, mixed invalid types, log check Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -38,8 +38,21 @@ class ScenarioEngine:
|
|||||||
"""Evaluates playbook scenarios against real-time market data.
|
"""Evaluates playbook scenarios against real-time market data.
|
||||||
|
|
||||||
No API calls — pure Python condition matching.
|
No API calls — pure Python condition matching.
|
||||||
|
|
||||||
|
Expected market_data keys: "rsi", "volume_ratio", "current_price", "price_change_pct".
|
||||||
|
Callers must normalize data source keys to match this contract.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _safe_float(value: Any) -> float | None:
|
||||||
|
"""Safely cast a value to float. Returns None on failure."""
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return float(value)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return None
|
||||||
|
|
||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
playbook: DayPlaybook,
|
playbook: DayPlaybook,
|
||||||
@@ -148,25 +161,37 @@ class ScenarioEngine:
|
|||||||
|
|
||||||
checks: list[bool] = []
|
checks: list[bool] = []
|
||||||
|
|
||||||
rsi = market_data.get("rsi")
|
rsi = self._safe_float(market_data.get("rsi"))
|
||||||
|
if condition.rsi_below is not None or condition.rsi_above is not None:
|
||||||
|
if "rsi" not in market_data:
|
||||||
|
logger.warning("Condition requires 'rsi' but key missing from market_data")
|
||||||
if condition.rsi_below is not None:
|
if condition.rsi_below is not None:
|
||||||
checks.append(rsi is not None and rsi < condition.rsi_below)
|
checks.append(rsi is not None and rsi < condition.rsi_below)
|
||||||
if condition.rsi_above is not None:
|
if condition.rsi_above is not None:
|
||||||
checks.append(rsi is not None and rsi > condition.rsi_above)
|
checks.append(rsi is not None and rsi > condition.rsi_above)
|
||||||
|
|
||||||
volume_ratio = market_data.get("volume_ratio")
|
volume_ratio = self._safe_float(market_data.get("volume_ratio"))
|
||||||
|
if condition.volume_ratio_above is not None or condition.volume_ratio_below is not None:
|
||||||
|
if "volume_ratio" not in market_data:
|
||||||
|
logger.warning("Condition requires 'volume_ratio' but key missing from market_data")
|
||||||
if condition.volume_ratio_above is not None:
|
if condition.volume_ratio_above is not None:
|
||||||
checks.append(volume_ratio is not None and volume_ratio > condition.volume_ratio_above)
|
checks.append(volume_ratio is not None and volume_ratio > condition.volume_ratio_above)
|
||||||
if condition.volume_ratio_below is not None:
|
if condition.volume_ratio_below is not None:
|
||||||
checks.append(volume_ratio is not None and volume_ratio < condition.volume_ratio_below)
|
checks.append(volume_ratio is not None and volume_ratio < condition.volume_ratio_below)
|
||||||
|
|
||||||
price = market_data.get("current_price")
|
price = self._safe_float(market_data.get("current_price"))
|
||||||
|
if condition.price_above is not None or condition.price_below is not None:
|
||||||
|
if "current_price" not in market_data:
|
||||||
|
logger.warning("Condition requires 'current_price' but key missing from market_data")
|
||||||
if condition.price_above is not None:
|
if condition.price_above is not None:
|
||||||
checks.append(price is not None and price > condition.price_above)
|
checks.append(price is not None and price > condition.price_above)
|
||||||
if condition.price_below is not None:
|
if condition.price_below is not None:
|
||||||
checks.append(price is not None and price < condition.price_below)
|
checks.append(price is not None and price < condition.price_below)
|
||||||
|
|
||||||
price_change_pct = market_data.get("price_change_pct")
|
price_change_pct = self._safe_float(market_data.get("price_change_pct"))
|
||||||
|
if condition.price_change_pct_above is not None or condition.price_change_pct_below is not None:
|
||||||
|
if "price_change_pct" not in market_data:
|
||||||
|
logger.warning("Condition requires 'price_change_pct' but key missing from market_data")
|
||||||
if condition.price_change_pct_above is not None:
|
if condition.price_change_pct_above is not None:
|
||||||
checks.append(price_change_pct is not None and price_change_pct > condition.price_change_pct_above)
|
checks.append(price_change_pct is not None and price_change_pct > condition.price_change_pct_above)
|
||||||
if condition.price_change_pct_below is not None:
|
if condition.price_change_pct_below is not None:
|
||||||
|
|||||||
@@ -139,6 +139,50 @@ class TestEvaluateCondition:
|
|||||||
cond = StockCondition(rsi_above=70.0)
|
cond = StockCondition(rsi_above=70.0)
|
||||||
assert not engine.evaluate_condition(cond, {"rsi": 70.0})
|
assert not engine.evaluate_condition(cond, {"rsi": 70.0})
|
||||||
|
|
||||||
|
def test_string_value_no_exception(self, engine: ScenarioEngine) -> None:
|
||||||
|
"""String numeric value should not raise TypeError."""
|
||||||
|
cond = StockCondition(rsi_below=30.0)
|
||||||
|
# "25" can be cast to float → should match
|
||||||
|
assert engine.evaluate_condition(cond, {"rsi": "25"})
|
||||||
|
# "35" → should not match
|
||||||
|
assert not engine.evaluate_condition(cond, {"rsi": "35"})
|
||||||
|
|
||||||
|
def test_percent_string_returns_false(self, engine: ScenarioEngine) -> None:
|
||||||
|
"""Percent string like '30%' cannot be cast to float → False, no exception."""
|
||||||
|
cond = StockCondition(rsi_below=30.0)
|
||||||
|
assert not engine.evaluate_condition(cond, {"rsi": "30%"})
|
||||||
|
|
||||||
|
def test_decimal_value_no_exception(self, engine: ScenarioEngine) -> None:
|
||||||
|
"""Decimal values should be safely handled."""
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
cond = StockCondition(rsi_below=30.0)
|
||||||
|
assert engine.evaluate_condition(cond, {"rsi": Decimal("25.0")})
|
||||||
|
|
||||||
|
def test_mixed_invalid_types_no_exception(self, engine: ScenarioEngine) -> None:
|
||||||
|
"""Various invalid types should not raise exceptions."""
|
||||||
|
cond = StockCondition(
|
||||||
|
rsi_below=30.0, volume_ratio_above=2.0,
|
||||||
|
price_above=100, price_change_pct_below=-1.0,
|
||||||
|
)
|
||||||
|
data = {
|
||||||
|
"rsi": [25], # list
|
||||||
|
"volume_ratio": "bad", # non-numeric string
|
||||||
|
"current_price": {}, # dict
|
||||||
|
"price_change_pct": object(), # arbitrary object
|
||||||
|
}
|
||||||
|
# Should return False (invalid types → None → False), never raise
|
||||||
|
assert not engine.evaluate_condition(cond, data)
|
||||||
|
|
||||||
|
def test_missing_key_logs_warning(self, engine: ScenarioEngine, caplog) -> None:
|
||||||
|
"""Missing market_data key should log a warning."""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
cond = StockCondition(rsi_below=30.0)
|
||||||
|
with caplog.at_level(logging.WARNING):
|
||||||
|
engine.evaluate_condition(cond, {})
|
||||||
|
assert "key missing from market_data" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# check_global_rules
|
# check_global_rules
|
||||||
|
|||||||
Reference in New Issue
Block a user