From 62cd8a81a4faf8930a5f09e21869eaa32a1dcb7c Mon Sep 17 00:00:00 2001 From: agentson Date: Sat, 28 Feb 2026 18:35:32 +0900 Subject: [PATCH] feat: feed staged-exit with ATR/RSI runtime features (#325) --- src/config.py | 2 + src/main.py | 114 +++++++++++++++++++++++++++++++++++++++++++-- tests/test_main.py | 97 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 210 insertions(+), 3 deletions(-) diff --git a/src/config.py b/src/config.py index 7f0a367..eeb4f1f 100644 --- a/src/config.py +++ b/src/config.py @@ -61,6 +61,8 @@ class Settings(BaseSettings): PAPER_OVERSEAS_CASH: float = Field(default=50000.0, ge=0.0) USD_BUFFER_MIN: float = Field(default=1000.0, ge=0.0) US_MIN_PRICE: float = Field(default=5.0, ge=0.0) + STAGED_EXIT_BE_ARM_PCT: float = Field(default=1.2, gt=0.0, le=30.0) + STAGED_EXIT_ARM_PCT: float = Field(default=3.0, gt=0.0, le=100.0) STOPLOSS_REENTRY_COOLDOWN_MINUTES: int = Field(default=120, ge=1, le=1440) KR_ATR_STOP_MULTIPLIER_K: float = Field(default=2.0, ge=0.1, le=10.0) KR_ATR_STOP_MIN_PCT: float = Field(default=-2.0, le=0.0) diff --git a/src/main.py b/src/main.py index bc9a926..1349e5c 100644 --- a/src/main.py +++ b/src/main.py @@ -71,6 +71,7 @@ _SESSION_CLOSE_WINDOWS = {"NXT_AFTER", "US_AFTER"} _RUNTIME_EXIT_STATES: dict[str, PositionState] = {} _RUNTIME_EXIT_PEAKS: dict[str, float] = {} _STOPLOSS_REENTRY_COOLDOWN_UNTIL: dict[str, float] = {} +_VOLATILITY_ANALYZER = VolatilityAnalyzer() def safe_float(value: str | float | None, default: float = 0.0) -> float: @@ -150,6 +151,90 @@ def _stoploss_cooldown_minutes(settings: Settings | None) -> int: return max(1, int(getattr(settings, "STOPLOSS_REENTRY_COOLDOWN_MINUTES", 120))) +def _estimate_pred_down_prob_from_rsi(rsi: float | str | None) -> float: + """Estimate downside probability from RSI using a simple linear mapping.""" + if rsi is None: + return 0.5 + rsi_value = max(0.0, min(100.0, safe_float(rsi, 50.0))) + return rsi_value / 100.0 + + +async def _compute_kr_atr_value( + *, + broker: KISBroker, + stock_code: str, + period: int = 14, +) -> float: + """Compute ATR(period) for KR stocks using daily OHLC.""" + days = max(period + 1, 30) + try: + daily_prices = await _retry_connection( + broker.get_daily_prices, + stock_code, + days=days, + label=f"daily_prices:{stock_code}", + ) + except ConnectionError as exc: + logger.warning("ATR source unavailable for %s: %s", stock_code, exc) + return 0.0 + except Exception as exc: + logger.warning("Unexpected ATR fetch failure for %s: %s", stock_code, exc) + return 0.0 + + if not isinstance(daily_prices, list): + return 0.0 + + highs: list[float] = [] + lows: list[float] = [] + closes: list[float] = [] + for row in daily_prices: + if not isinstance(row, dict): + continue + high = safe_float(row.get("high"), 0.0) + low = safe_float(row.get("low"), 0.0) + close = safe_float(row.get("close"), 0.0) + if high <= 0 or low <= 0 or close <= 0: + continue + highs.append(high) + lows.append(low) + closes.append(close) + + if len(highs) < period + 1 or len(lows) < period + 1 or len(closes) < period + 1: + return 0.0 + return max(0.0, _VOLATILITY_ANALYZER.calculate_atr(highs, lows, closes, period=period)) + + +async def _inject_staged_exit_features( + *, + market: MarketInfo, + stock_code: str, + open_position: dict[str, Any] | None, + market_data: dict[str, Any], + broker: KISBroker | None, +) -> None: + """Inject ATR/pred_down_prob used by staged exit evaluation.""" + if not open_position: + return + + if "pred_down_prob" not in market_data: + market_data["pred_down_prob"] = _estimate_pred_down_prob_from_rsi( + market_data.get("rsi") + ) + + existing_atr = safe_float(market_data.get("atr_value"), 0.0) + if existing_atr > 0: + return + + if market.is_domestic and broker is not None: + market_data["atr_value"] = await _compute_kr_atr_value( + broker=broker, + stock_code=stock_code, + ) + return + + market_data["atr_value"] = 0.0 + + async def _retry_connection(coro_factory: Any, *args: Any, label: str = "", **kwargs: Any) -> Any: """Call an async function retrying on ConnectionError with exponential backoff. @@ -563,6 +648,15 @@ def _apply_staged_exit_override_for_hold( fallback_stop_loss_pct=stop_loss_threshold, settings=settings, ) + if settings is None: + be_arm_pct = max(0.5, take_profit_threshold * 0.4) + arm_pct = take_profit_threshold + else: + be_arm_pct = max(0.1, float(getattr(settings, "STAGED_EXIT_BE_ARM_PCT", 1.2))) + arm_pct = max( + be_arm_pct, + float(getattr(settings, "STAGED_EXIT_ARM_PCT", 3.0)), + ) runtime_key = _build_runtime_position_key( market_code=market.code, @@ -581,8 +675,8 @@ def _apply_staged_exit_override_for_hold( current_state=current_state, config=ExitRuleConfig( hard_stop_pct=stop_loss_threshold, - be_arm_pct=max(0.5, take_profit_threshold * 0.4), - arm_pct=take_profit_threshold, + be_arm_pct=be_arm_pct, + arm_pct=arm_pct, ), inp=ExitRuleInput( current_price=current_price, @@ -608,7 +702,7 @@ def _apply_staged_exit_override_for_hold( elif exit_eval.reason == "arm_take_profit": rationale = ( f"Take-profit triggered ({pnl_pct:.2f}% >= " - f"{take_profit_threshold:.2f}%)" + f"{arm_pct:.2f}%)" ) elif exit_eval.reason == "atr_trailing_stop": rationale = "ATR trailing-stop triggered" @@ -1398,6 +1492,13 @@ async def trading_cycle( market_code=market.code, stock_code=stock_code, ) + await _inject_staged_exit_features( + market=market, + stock_code=stock_code, + open_position=open_position, + market_data=market_data, + broker=broker, + ) decision = _apply_staged_exit_override_for_hold( decision=decision, market=market, @@ -2606,6 +2707,13 @@ async def run_daily_session( market_code=market.code, stock_code=stock_code, ) + await _inject_staged_exit_features( + market=market, + stock_code=stock_code, + open_position=daily_open, + market_data=stock_data, + broker=broker, + ) decision = _apply_staged_exit_override_for_hold( decision=decision, market=market, diff --git a/tests/test_main.py b/tests/test_main.py index e98a659..bd2ea2b 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -16,6 +16,10 @@ from src.logging.decision_logger import DecisionLogger from src.main import ( KILL_SWITCH, _STOPLOSS_REENTRY_COOLDOWN_UNTIL, + _apply_staged_exit_override_for_hold, + _compute_kr_atr_value, + _estimate_pred_down_prob_from_rsi, + _inject_staged_exit_features, _RUNTIME_EXIT_PEAKS, _RUNTIME_EXIT_STATES, _should_force_exit_for_overnight, @@ -181,6 +185,99 @@ def test_compute_kr_dynamic_stop_loss_pct_uses_settings_values() -> None: ) assert out == -3.0 + +def test_estimate_pred_down_prob_from_rsi_uses_linear_mapping() -> None: + assert _estimate_pred_down_prob_from_rsi(None) == 0.5 + assert _estimate_pred_down_prob_from_rsi(0.0) == 0.0 + assert _estimate_pred_down_prob_from_rsi(50.0) == 0.5 + assert _estimate_pred_down_prob_from_rsi(100.0) == 1.0 + + +@pytest.mark.asyncio +async def test_compute_kr_atr_value_returns_zero_on_short_series() -> None: + broker = MagicMock() + broker.get_daily_prices = AsyncMock( + return_value=[{"high": 101.0, "low": 99.0, "close": 100.0}] * 10 + ) + + atr = await _compute_kr_atr_value(broker=broker, stock_code="005930") + assert atr == 0.0 + + +@pytest.mark.asyncio +async def test_inject_staged_exit_features_sets_pred_down_prob_and_atr_for_kr() -> None: + market = MagicMock() + market.is_domestic = True + stock_data: dict[str, float] = {"rsi": 65.0} + + broker = MagicMock() + broker.get_daily_prices = AsyncMock( + return_value=[ + {"high": 102.0 + i, "low": 98.0 + i, "close": 100.0 + i} + for i in range(40) + ] + ) + + await _inject_staged_exit_features( + market=market, + stock_code="005930", + open_position={"price": 100.0, "quantity": 1}, + market_data=stock_data, + broker=broker, + ) + + assert stock_data["pred_down_prob"] == pytest.approx(0.65) + assert stock_data["atr_value"] > 0.0 + + +def test_apply_staged_exit_uses_independent_arm_threshold_settings() -> None: + market = MagicMock() + market.code = "KR" + market.name = "Korea" + + decision = MagicMock() + decision.action = "HOLD" + decision.confidence = 70 + decision.rationale = "hold" + + settings = Settings( + KIS_APP_KEY="k", + KIS_APP_SECRET="s", + KIS_ACCOUNT_NO="12345678-01", + GEMINI_API_KEY="g", + STAGED_EXIT_BE_ARM_PCT=2.2, + STAGED_EXIT_ARM_PCT=5.4, + ) + + captured: dict[str, float] = {} + + def _fake_eval(**kwargs): # type: ignore[no-untyped-def] + cfg = kwargs["config"] + captured["be_arm_pct"] = cfg.be_arm_pct + captured["arm_pct"] = cfg.arm_pct + + class _Out: + should_exit = False + reason = "none" + state = PositionState.HOLDING + + return _Out() + + with patch("src.main.evaluate_exit", side_effect=_fake_eval): + out = _apply_staged_exit_override_for_hold( + decision=decision, + market=market, + stock_code="005930", + open_position={"price": 100.0, "quantity": 1, "decision_id": "d1", "timestamp": "t1"}, + market_data={"current_price": 101.0, "rsi": 60.0, "pred_down_prob": 0.6}, + stock_playbook=None, + settings=settings, + ) + + assert out is decision + assert captured["be_arm_pct"] == pytest.approx(2.2) + assert captured["arm_pct"] == pytest.approx(5.4) + def test_returns_zero_when_field_empty_string(self) -> None: """Returns 0.0 when pchs_avg_pric is an empty string.""" balance = {"output1": [{"pdno": "005930", "pchs_avg_pric": ""}]}