feat: feed staged-exit with ATR/RSI runtime features (#325)
This commit is contained in:
114
src/main.py
114
src/main.py
@@ -71,6 +71,7 @@ _SESSION_CLOSE_WINDOWS = {"NXT_AFTER", "US_AFTER"}
|
||||
_RUNTIME_EXIT_STATES: dict[str, PositionState] = {}
|
||||
_RUNTIME_EXIT_PEAKS: dict[str, float] = {}
|
||||
_STOPLOSS_REENTRY_COOLDOWN_UNTIL: dict[str, float] = {}
|
||||
_VOLATILITY_ANALYZER = VolatilityAnalyzer()
|
||||
|
||||
|
||||
def safe_float(value: str | float | None, default: float = 0.0) -> float:
|
||||
@@ -150,6 +151,90 @@ def _stoploss_cooldown_minutes(settings: Settings | None) -> int:
|
||||
return max(1, int(getattr(settings, "STOPLOSS_REENTRY_COOLDOWN_MINUTES", 120)))
|
||||
|
||||
|
||||
def _estimate_pred_down_prob_from_rsi(rsi: float | str | None) -> float:
|
||||
"""Estimate downside probability from RSI using a simple linear mapping."""
|
||||
if rsi is None:
|
||||
return 0.5
|
||||
rsi_value = max(0.0, min(100.0, safe_float(rsi, 50.0)))
|
||||
return rsi_value / 100.0
|
||||
|
||||
|
||||
async def _compute_kr_atr_value(
|
||||
*,
|
||||
broker: KISBroker,
|
||||
stock_code: str,
|
||||
period: int = 14,
|
||||
) -> float:
|
||||
"""Compute ATR(period) for KR stocks using daily OHLC."""
|
||||
days = max(period + 1, 30)
|
||||
try:
|
||||
daily_prices = await _retry_connection(
|
||||
broker.get_daily_prices,
|
||||
stock_code,
|
||||
days=days,
|
||||
label=f"daily_prices:{stock_code}",
|
||||
)
|
||||
except ConnectionError as exc:
|
||||
logger.warning("ATR source unavailable for %s: %s", stock_code, exc)
|
||||
return 0.0
|
||||
except Exception as exc:
|
||||
logger.warning("Unexpected ATR fetch failure for %s: %s", stock_code, exc)
|
||||
return 0.0
|
||||
|
||||
if not isinstance(daily_prices, list):
|
||||
return 0.0
|
||||
|
||||
highs: list[float] = []
|
||||
lows: list[float] = []
|
||||
closes: list[float] = []
|
||||
for row in daily_prices:
|
||||
if not isinstance(row, dict):
|
||||
continue
|
||||
high = safe_float(row.get("high"), 0.0)
|
||||
low = safe_float(row.get("low"), 0.0)
|
||||
close = safe_float(row.get("close"), 0.0)
|
||||
if high <= 0 or low <= 0 or close <= 0:
|
||||
continue
|
||||
highs.append(high)
|
||||
lows.append(low)
|
||||
closes.append(close)
|
||||
|
||||
if len(highs) < period + 1 or len(lows) < period + 1 or len(closes) < period + 1:
|
||||
return 0.0
|
||||
return max(0.0, _VOLATILITY_ANALYZER.calculate_atr(highs, lows, closes, period=period))
|
||||
|
||||
|
||||
async def _inject_staged_exit_features(
|
||||
*,
|
||||
market: MarketInfo,
|
||||
stock_code: str,
|
||||
open_position: dict[str, Any] | None,
|
||||
market_data: dict[str, Any],
|
||||
broker: KISBroker | None,
|
||||
) -> None:
|
||||
"""Inject ATR/pred_down_prob used by staged exit evaluation."""
|
||||
if not open_position:
|
||||
return
|
||||
|
||||
if "pred_down_prob" not in market_data:
|
||||
market_data["pred_down_prob"] = _estimate_pred_down_prob_from_rsi(
|
||||
market_data.get("rsi")
|
||||
)
|
||||
|
||||
existing_atr = safe_float(market_data.get("atr_value"), 0.0)
|
||||
if existing_atr > 0:
|
||||
return
|
||||
|
||||
if market.is_domestic and broker is not None:
|
||||
market_data["atr_value"] = await _compute_kr_atr_value(
|
||||
broker=broker,
|
||||
stock_code=stock_code,
|
||||
)
|
||||
return
|
||||
|
||||
market_data["atr_value"] = 0.0
|
||||
|
||||
|
||||
async def _retry_connection(coro_factory: Any, *args: Any, label: str = "", **kwargs: Any) -> Any:
|
||||
"""Call an async function retrying on ConnectionError with exponential backoff.
|
||||
|
||||
@@ -563,6 +648,15 @@ def _apply_staged_exit_override_for_hold(
|
||||
fallback_stop_loss_pct=stop_loss_threshold,
|
||||
settings=settings,
|
||||
)
|
||||
if settings is None:
|
||||
be_arm_pct = max(0.5, take_profit_threshold * 0.4)
|
||||
arm_pct = take_profit_threshold
|
||||
else:
|
||||
be_arm_pct = max(0.1, float(getattr(settings, "STAGED_EXIT_BE_ARM_PCT", 1.2)))
|
||||
arm_pct = max(
|
||||
be_arm_pct,
|
||||
float(getattr(settings, "STAGED_EXIT_ARM_PCT", 3.0)),
|
||||
)
|
||||
|
||||
runtime_key = _build_runtime_position_key(
|
||||
market_code=market.code,
|
||||
@@ -581,8 +675,8 @@ def _apply_staged_exit_override_for_hold(
|
||||
current_state=current_state,
|
||||
config=ExitRuleConfig(
|
||||
hard_stop_pct=stop_loss_threshold,
|
||||
be_arm_pct=max(0.5, take_profit_threshold * 0.4),
|
||||
arm_pct=take_profit_threshold,
|
||||
be_arm_pct=be_arm_pct,
|
||||
arm_pct=arm_pct,
|
||||
),
|
||||
inp=ExitRuleInput(
|
||||
current_price=current_price,
|
||||
@@ -608,7 +702,7 @@ def _apply_staged_exit_override_for_hold(
|
||||
elif exit_eval.reason == "arm_take_profit":
|
||||
rationale = (
|
||||
f"Take-profit triggered ({pnl_pct:.2f}% >= "
|
||||
f"{take_profit_threshold:.2f}%)"
|
||||
f"{arm_pct:.2f}%)"
|
||||
)
|
||||
elif exit_eval.reason == "atr_trailing_stop":
|
||||
rationale = "ATR trailing-stop triggered"
|
||||
@@ -1398,6 +1492,13 @@ async def trading_cycle(
|
||||
market_code=market.code,
|
||||
stock_code=stock_code,
|
||||
)
|
||||
await _inject_staged_exit_features(
|
||||
market=market,
|
||||
stock_code=stock_code,
|
||||
open_position=open_position,
|
||||
market_data=market_data,
|
||||
broker=broker,
|
||||
)
|
||||
decision = _apply_staged_exit_override_for_hold(
|
||||
decision=decision,
|
||||
market=market,
|
||||
@@ -2606,6 +2707,13 @@ async def run_daily_session(
|
||||
market_code=market.code,
|
||||
stock_code=stock_code,
|
||||
)
|
||||
await _inject_staged_exit_features(
|
||||
market=market,
|
||||
stock_code=stock_code,
|
||||
open_position=daily_open,
|
||||
market_data=stock_data,
|
||||
broker=broker,
|
||||
)
|
||||
decision = _apply_staged_exit_override_for_hold(
|
||||
decision=decision,
|
||||
market=market,
|
||||
|
||||
Reference in New Issue
Block a user