feat: feed staged-exit with ATR/RSI runtime features (#325)
Some checks failed
Gitea CI / test (push) Failing after 3s
Gitea CI / test (pull_request) Failing after 3s

This commit is contained in:
agentson
2026-02-28 18:35:32 +09:00
parent dd8549b912
commit 62cd8a81a4
3 changed files with 210 additions and 3 deletions

View File

@@ -71,6 +71,7 @@ _SESSION_CLOSE_WINDOWS = {"NXT_AFTER", "US_AFTER"}
_RUNTIME_EXIT_STATES: dict[str, PositionState] = {}
_RUNTIME_EXIT_PEAKS: dict[str, float] = {}
_STOPLOSS_REENTRY_COOLDOWN_UNTIL: dict[str, float] = {}
_VOLATILITY_ANALYZER = VolatilityAnalyzer()
def safe_float(value: str | float | None, default: float = 0.0) -> float:
@@ -150,6 +151,90 @@ def _stoploss_cooldown_minutes(settings: Settings | None) -> int:
return max(1, int(getattr(settings, "STOPLOSS_REENTRY_COOLDOWN_MINUTES", 120)))
def _estimate_pred_down_prob_from_rsi(rsi: float | str | None) -> float:
"""Estimate downside probability from RSI using a simple linear mapping."""
if rsi is None:
return 0.5
rsi_value = max(0.0, min(100.0, safe_float(rsi, 50.0)))
return rsi_value / 100.0
async def _compute_kr_atr_value(
*,
broker: KISBroker,
stock_code: str,
period: int = 14,
) -> float:
"""Compute ATR(period) for KR stocks using daily OHLC."""
days = max(period + 1, 30)
try:
daily_prices = await _retry_connection(
broker.get_daily_prices,
stock_code,
days=days,
label=f"daily_prices:{stock_code}",
)
except ConnectionError as exc:
logger.warning("ATR source unavailable for %s: %s", stock_code, exc)
return 0.0
except Exception as exc:
logger.warning("Unexpected ATR fetch failure for %s: %s", stock_code, exc)
return 0.0
if not isinstance(daily_prices, list):
return 0.0
highs: list[float] = []
lows: list[float] = []
closes: list[float] = []
for row in daily_prices:
if not isinstance(row, dict):
continue
high = safe_float(row.get("high"), 0.0)
low = safe_float(row.get("low"), 0.0)
close = safe_float(row.get("close"), 0.0)
if high <= 0 or low <= 0 or close <= 0:
continue
highs.append(high)
lows.append(low)
closes.append(close)
if len(highs) < period + 1 or len(lows) < period + 1 or len(closes) < period + 1:
return 0.0
return max(0.0, _VOLATILITY_ANALYZER.calculate_atr(highs, lows, closes, period=period))
async def _inject_staged_exit_features(
*,
market: MarketInfo,
stock_code: str,
open_position: dict[str, Any] | None,
market_data: dict[str, Any],
broker: KISBroker | None,
) -> None:
"""Inject ATR/pred_down_prob used by staged exit evaluation."""
if not open_position:
return
if "pred_down_prob" not in market_data:
market_data["pred_down_prob"] = _estimate_pred_down_prob_from_rsi(
market_data.get("rsi")
)
existing_atr = safe_float(market_data.get("atr_value"), 0.0)
if existing_atr > 0:
return
if market.is_domestic and broker is not None:
market_data["atr_value"] = await _compute_kr_atr_value(
broker=broker,
stock_code=stock_code,
)
return
market_data["atr_value"] = 0.0
async def _retry_connection(coro_factory: Any, *args: Any, label: str = "", **kwargs: Any) -> Any:
"""Call an async function retrying on ConnectionError with exponential backoff.
@@ -563,6 +648,15 @@ def _apply_staged_exit_override_for_hold(
fallback_stop_loss_pct=stop_loss_threshold,
settings=settings,
)
if settings is None:
be_arm_pct = max(0.5, take_profit_threshold * 0.4)
arm_pct = take_profit_threshold
else:
be_arm_pct = max(0.1, float(getattr(settings, "STAGED_EXIT_BE_ARM_PCT", 1.2)))
arm_pct = max(
be_arm_pct,
float(getattr(settings, "STAGED_EXIT_ARM_PCT", 3.0)),
)
runtime_key = _build_runtime_position_key(
market_code=market.code,
@@ -581,8 +675,8 @@ def _apply_staged_exit_override_for_hold(
current_state=current_state,
config=ExitRuleConfig(
hard_stop_pct=stop_loss_threshold,
be_arm_pct=max(0.5, take_profit_threshold * 0.4),
arm_pct=take_profit_threshold,
be_arm_pct=be_arm_pct,
arm_pct=arm_pct,
),
inp=ExitRuleInput(
current_price=current_price,
@@ -608,7 +702,7 @@ def _apply_staged_exit_override_for_hold(
elif exit_eval.reason == "arm_take_profit":
rationale = (
f"Take-profit triggered ({pnl_pct:.2f}% >= "
f"{take_profit_threshold:.2f}%)"
f"{arm_pct:.2f}%)"
)
elif exit_eval.reason == "atr_trailing_stop":
rationale = "ATR trailing-stop triggered"
@@ -1398,6 +1492,13 @@ async def trading_cycle(
market_code=market.code,
stock_code=stock_code,
)
await _inject_staged_exit_features(
market=market,
stock_code=stock_code,
open_position=open_position,
market_data=market_data,
broker=broker,
)
decision = _apply_staged_exit_override_for_hold(
decision=decision,
market=market,
@@ -2606,6 +2707,13 @@ async def run_daily_session(
market_code=market.code,
stock_code=stock_code,
)
await _inject_staged_exit_features(
market=market,
stock_code=stock_code,
open_position=daily_open,
market_data=stock_data,
broker=broker,
)
decision = _apply_staged_exit_override_for_hold(
decision=decision,
market=market,