feat: feed staged-exit with ATR/RSI runtime features (#325)
This commit is contained in:
@@ -61,6 +61,8 @@ class Settings(BaseSettings):
|
|||||||
PAPER_OVERSEAS_CASH: float = Field(default=50000.0, ge=0.0)
|
PAPER_OVERSEAS_CASH: float = Field(default=50000.0, ge=0.0)
|
||||||
USD_BUFFER_MIN: float = Field(default=1000.0, ge=0.0)
|
USD_BUFFER_MIN: float = Field(default=1000.0, ge=0.0)
|
||||||
US_MIN_PRICE: float = Field(default=5.0, ge=0.0)
|
US_MIN_PRICE: float = Field(default=5.0, ge=0.0)
|
||||||
|
STAGED_EXIT_BE_ARM_PCT: float = Field(default=1.2, gt=0.0, le=30.0)
|
||||||
|
STAGED_EXIT_ARM_PCT: float = Field(default=3.0, gt=0.0, le=100.0)
|
||||||
STOPLOSS_REENTRY_COOLDOWN_MINUTES: int = Field(default=120, ge=1, le=1440)
|
STOPLOSS_REENTRY_COOLDOWN_MINUTES: int = Field(default=120, ge=1, le=1440)
|
||||||
OVERNIGHT_EXCEPTION_ENABLED: bool = True
|
OVERNIGHT_EXCEPTION_ENABLED: bool = True
|
||||||
|
|
||||||
|
|||||||
117
src/main.py
117
src/main.py
@@ -71,6 +71,7 @@ _SESSION_CLOSE_WINDOWS = {"NXT_AFTER", "US_AFTER"}
|
|||||||
_RUNTIME_EXIT_STATES: dict[str, PositionState] = {}
|
_RUNTIME_EXIT_STATES: dict[str, PositionState] = {}
|
||||||
_RUNTIME_EXIT_PEAKS: dict[str, float] = {}
|
_RUNTIME_EXIT_PEAKS: dict[str, float] = {}
|
||||||
_STOPLOSS_REENTRY_COOLDOWN_UNTIL: dict[str, float] = {}
|
_STOPLOSS_REENTRY_COOLDOWN_UNTIL: dict[str, float] = {}
|
||||||
|
_VOLATILITY_ANALYZER = VolatilityAnalyzer()
|
||||||
|
|
||||||
|
|
||||||
def safe_float(value: str | float | None, default: float = 0.0) -> float:
|
def safe_float(value: str | float | None, default: float = 0.0) -> float:
|
||||||
@@ -129,6 +130,90 @@ def _stoploss_cooldown_minutes(settings: Settings | None) -> int:
|
|||||||
return max(1, int(getattr(settings, "STOPLOSS_REENTRY_COOLDOWN_MINUTES", 120)))
|
return max(1, int(getattr(settings, "STOPLOSS_REENTRY_COOLDOWN_MINUTES", 120)))
|
||||||
|
|
||||||
|
|
||||||
|
def _estimate_pred_down_prob_from_rsi(rsi: float | str | None) -> float:
|
||||||
|
"""Estimate downside probability from RSI using a simple linear mapping."""
|
||||||
|
if rsi is None:
|
||||||
|
return 0.5
|
||||||
|
rsi_value = max(0.0, min(100.0, safe_float(rsi, 50.0)))
|
||||||
|
return rsi_value / 100.0
|
||||||
|
|
||||||
|
|
||||||
|
async def _compute_kr_atr_value(
|
||||||
|
*,
|
||||||
|
broker: KISBroker,
|
||||||
|
stock_code: str,
|
||||||
|
period: int = 14,
|
||||||
|
) -> float:
|
||||||
|
"""Compute ATR(period) for KR stocks using daily OHLC."""
|
||||||
|
days = max(period + 1, 30)
|
||||||
|
try:
|
||||||
|
daily_prices = await _retry_connection(
|
||||||
|
broker.get_daily_prices,
|
||||||
|
stock_code,
|
||||||
|
days=days,
|
||||||
|
label=f"daily_prices:{stock_code}",
|
||||||
|
)
|
||||||
|
except ConnectionError as exc:
|
||||||
|
logger.warning("ATR source unavailable for %s: %s", stock_code, exc)
|
||||||
|
return 0.0
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Unexpected ATR fetch failure for %s: %s", stock_code, exc)
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
if not isinstance(daily_prices, list):
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
highs: list[float] = []
|
||||||
|
lows: list[float] = []
|
||||||
|
closes: list[float] = []
|
||||||
|
for row in daily_prices:
|
||||||
|
if not isinstance(row, dict):
|
||||||
|
continue
|
||||||
|
high = safe_float(row.get("high"), 0.0)
|
||||||
|
low = safe_float(row.get("low"), 0.0)
|
||||||
|
close = safe_float(row.get("close"), 0.0)
|
||||||
|
if high <= 0 or low <= 0 or close <= 0:
|
||||||
|
continue
|
||||||
|
highs.append(high)
|
||||||
|
lows.append(low)
|
||||||
|
closes.append(close)
|
||||||
|
|
||||||
|
if len(highs) < period + 1 or len(lows) < period + 1 or len(closes) < period + 1:
|
||||||
|
return 0.0
|
||||||
|
return max(0.0, _VOLATILITY_ANALYZER.calculate_atr(highs, lows, closes, period=period))
|
||||||
|
|
||||||
|
|
||||||
|
async def _inject_staged_exit_features(
|
||||||
|
*,
|
||||||
|
market: MarketInfo,
|
||||||
|
stock_code: str,
|
||||||
|
open_position: dict[str, Any] | None,
|
||||||
|
market_data: dict[str, Any],
|
||||||
|
broker: KISBroker | None,
|
||||||
|
) -> None:
|
||||||
|
"""Inject ATR/pred_down_prob used by staged exit evaluation."""
|
||||||
|
if not open_position:
|
||||||
|
return
|
||||||
|
|
||||||
|
if "pred_down_prob" not in market_data:
|
||||||
|
market_data["pred_down_prob"] = _estimate_pred_down_prob_from_rsi(
|
||||||
|
market_data.get("rsi")
|
||||||
|
)
|
||||||
|
|
||||||
|
existing_atr = safe_float(market_data.get("atr_value"), 0.0)
|
||||||
|
if existing_atr > 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
if market.is_domestic and broker is not None:
|
||||||
|
market_data["atr_value"] = await _compute_kr_atr_value(
|
||||||
|
broker=broker,
|
||||||
|
stock_code=stock_code,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
market_data["atr_value"] = 0.0
|
||||||
|
|
||||||
|
|
||||||
async def _retry_connection(coro_factory: Any, *args: Any, label: str = "", **kwargs: Any) -> Any:
|
async def _retry_connection(coro_factory: Any, *args: Any, label: str = "", **kwargs: Any) -> Any:
|
||||||
"""Call an async function retrying on ConnectionError with exponential backoff.
|
"""Call an async function retrying on ConnectionError with exponential backoff.
|
||||||
|
|
||||||
@@ -518,6 +603,7 @@ def _apply_staged_exit_override_for_hold(
|
|||||||
open_position: dict[str, Any] | None,
|
open_position: dict[str, Any] | None,
|
||||||
market_data: dict[str, Any],
|
market_data: dict[str, Any],
|
||||||
stock_playbook: Any | None,
|
stock_playbook: Any | None,
|
||||||
|
settings: Settings | None = None,
|
||||||
) -> TradeDecision:
|
) -> TradeDecision:
|
||||||
"""Apply v2 staged exit semantics for HOLD positions using runtime state."""
|
"""Apply v2 staged exit semantics for HOLD positions using runtime state."""
|
||||||
if decision.action != "HOLD" or not open_position:
|
if decision.action != "HOLD" or not open_position:
|
||||||
@@ -533,6 +619,15 @@ def _apply_staged_exit_override_for_hold(
|
|||||||
if stock_playbook and stock_playbook.scenarios:
|
if stock_playbook and stock_playbook.scenarios:
|
||||||
stop_loss_threshold = stock_playbook.scenarios[0].stop_loss_pct
|
stop_loss_threshold = stock_playbook.scenarios[0].stop_loss_pct
|
||||||
take_profit_threshold = stock_playbook.scenarios[0].take_profit_pct
|
take_profit_threshold = stock_playbook.scenarios[0].take_profit_pct
|
||||||
|
if settings is None:
|
||||||
|
be_arm_pct = max(0.5, take_profit_threshold * 0.4)
|
||||||
|
arm_pct = take_profit_threshold
|
||||||
|
else:
|
||||||
|
be_arm_pct = max(0.1, float(getattr(settings, "STAGED_EXIT_BE_ARM_PCT", 1.2)))
|
||||||
|
arm_pct = max(
|
||||||
|
be_arm_pct,
|
||||||
|
float(getattr(settings, "STAGED_EXIT_ARM_PCT", 3.0)),
|
||||||
|
)
|
||||||
|
|
||||||
runtime_key = _build_runtime_position_key(
|
runtime_key = _build_runtime_position_key(
|
||||||
market_code=market.code,
|
market_code=market.code,
|
||||||
@@ -551,8 +646,8 @@ def _apply_staged_exit_override_for_hold(
|
|||||||
current_state=current_state,
|
current_state=current_state,
|
||||||
config=ExitRuleConfig(
|
config=ExitRuleConfig(
|
||||||
hard_stop_pct=stop_loss_threshold,
|
hard_stop_pct=stop_loss_threshold,
|
||||||
be_arm_pct=max(0.5, take_profit_threshold * 0.4),
|
be_arm_pct=be_arm_pct,
|
||||||
arm_pct=take_profit_threshold,
|
arm_pct=arm_pct,
|
||||||
),
|
),
|
||||||
inp=ExitRuleInput(
|
inp=ExitRuleInput(
|
||||||
current_price=current_price,
|
current_price=current_price,
|
||||||
@@ -578,7 +673,7 @@ def _apply_staged_exit_override_for_hold(
|
|||||||
elif exit_eval.reason == "arm_take_profit":
|
elif exit_eval.reason == "arm_take_profit":
|
||||||
rationale = (
|
rationale = (
|
||||||
f"Take-profit triggered ({pnl_pct:.2f}% >= "
|
f"Take-profit triggered ({pnl_pct:.2f}% >= "
|
||||||
f"{take_profit_threshold:.2f}%)"
|
f"{arm_pct:.2f}%)"
|
||||||
)
|
)
|
||||||
elif exit_eval.reason == "atr_trailing_stop":
|
elif exit_eval.reason == "atr_trailing_stop":
|
||||||
rationale = "ATR trailing-stop triggered"
|
rationale = "ATR trailing-stop triggered"
|
||||||
@@ -1368,6 +1463,13 @@ async def trading_cycle(
|
|||||||
market_code=market.code,
|
market_code=market.code,
|
||||||
stock_code=stock_code,
|
stock_code=stock_code,
|
||||||
)
|
)
|
||||||
|
await _inject_staged_exit_features(
|
||||||
|
market=market,
|
||||||
|
stock_code=stock_code,
|
||||||
|
open_position=open_position,
|
||||||
|
market_data=market_data,
|
||||||
|
broker=broker,
|
||||||
|
)
|
||||||
decision = _apply_staged_exit_override_for_hold(
|
decision = _apply_staged_exit_override_for_hold(
|
||||||
decision=decision,
|
decision=decision,
|
||||||
market=market,
|
market=market,
|
||||||
@@ -1375,6 +1477,7 @@ async def trading_cycle(
|
|||||||
open_position=open_position,
|
open_position=open_position,
|
||||||
market_data=market_data,
|
market_data=market_data,
|
||||||
stock_playbook=stock_playbook,
|
stock_playbook=stock_playbook,
|
||||||
|
settings=settings,
|
||||||
)
|
)
|
||||||
if open_position and decision.action == "HOLD" and _should_force_exit_for_overnight(
|
if open_position and decision.action == "HOLD" and _should_force_exit_for_overnight(
|
||||||
market=market,
|
market=market,
|
||||||
@@ -2575,6 +2678,13 @@ async def run_daily_session(
|
|||||||
market_code=market.code,
|
market_code=market.code,
|
||||||
stock_code=stock_code,
|
stock_code=stock_code,
|
||||||
)
|
)
|
||||||
|
await _inject_staged_exit_features(
|
||||||
|
market=market,
|
||||||
|
stock_code=stock_code,
|
||||||
|
open_position=daily_open,
|
||||||
|
market_data=stock_data,
|
||||||
|
broker=broker,
|
||||||
|
)
|
||||||
decision = _apply_staged_exit_override_for_hold(
|
decision = _apply_staged_exit_override_for_hold(
|
||||||
decision=decision,
|
decision=decision,
|
||||||
market=market,
|
market=market,
|
||||||
@@ -2582,6 +2692,7 @@ async def run_daily_session(
|
|||||||
open_position=daily_open,
|
open_position=daily_open,
|
||||||
market_data=stock_data,
|
market_data=stock_data,
|
||||||
stock_playbook=stock_playbook,
|
stock_playbook=stock_playbook,
|
||||||
|
settings=settings,
|
||||||
)
|
)
|
||||||
if daily_open and decision.action == "HOLD" and _should_force_exit_for_overnight(
|
if daily_open and decision.action == "HOLD" and _should_force_exit_for_overnight(
|
||||||
market=market,
|
market=market,
|
||||||
|
|||||||
@@ -16,6 +16,10 @@ from src.logging.decision_logger import DecisionLogger
|
|||||||
from src.main import (
|
from src.main import (
|
||||||
KILL_SWITCH,
|
KILL_SWITCH,
|
||||||
_STOPLOSS_REENTRY_COOLDOWN_UNTIL,
|
_STOPLOSS_REENTRY_COOLDOWN_UNTIL,
|
||||||
|
_apply_staged_exit_override_for_hold,
|
||||||
|
_compute_kr_atr_value,
|
||||||
|
_estimate_pred_down_prob_from_rsi,
|
||||||
|
_inject_staged_exit_features,
|
||||||
_RUNTIME_EXIT_PEAKS,
|
_RUNTIME_EXIT_PEAKS,
|
||||||
_RUNTIME_EXIT_STATES,
|
_RUNTIME_EXIT_STATES,
|
||||||
_should_force_exit_for_overnight,
|
_should_force_exit_for_overnight,
|
||||||
@@ -135,6 +139,99 @@ def test_resolve_sell_qty_for_pnl_uses_buy_qty_fallback_when_sell_qty_missing()
|
|||||||
def test_resolve_sell_qty_for_pnl_returns_zero_when_both_missing() -> None:
|
def test_resolve_sell_qty_for_pnl_returns_zero_when_both_missing() -> None:
|
||||||
assert _resolve_sell_qty_for_pnl(sell_qty=None, buy_qty=None) == 0
|
assert _resolve_sell_qty_for_pnl(sell_qty=None, buy_qty=None) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_estimate_pred_down_prob_from_rsi_uses_linear_mapping() -> None:
|
||||||
|
assert _estimate_pred_down_prob_from_rsi(None) == 0.5
|
||||||
|
assert _estimate_pred_down_prob_from_rsi(0.0) == 0.0
|
||||||
|
assert _estimate_pred_down_prob_from_rsi(50.0) == 0.5
|
||||||
|
assert _estimate_pred_down_prob_from_rsi(100.0) == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_compute_kr_atr_value_returns_zero_on_short_series() -> None:
|
||||||
|
broker = MagicMock()
|
||||||
|
broker.get_daily_prices = AsyncMock(
|
||||||
|
return_value=[{"high": 101.0, "low": 99.0, "close": 100.0}] * 10
|
||||||
|
)
|
||||||
|
|
||||||
|
atr = await _compute_kr_atr_value(broker=broker, stock_code="005930")
|
||||||
|
assert atr == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_inject_staged_exit_features_sets_pred_down_prob_and_atr_for_kr() -> None:
|
||||||
|
market = MagicMock()
|
||||||
|
market.is_domestic = True
|
||||||
|
stock_data: dict[str, float] = {"rsi": 65.0}
|
||||||
|
|
||||||
|
broker = MagicMock()
|
||||||
|
broker.get_daily_prices = AsyncMock(
|
||||||
|
return_value=[
|
||||||
|
{"high": 102.0 + i, "low": 98.0 + i, "close": 100.0 + i}
|
||||||
|
for i in range(40)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
await _inject_staged_exit_features(
|
||||||
|
market=market,
|
||||||
|
stock_code="005930",
|
||||||
|
open_position={"price": 100.0, "quantity": 1},
|
||||||
|
market_data=stock_data,
|
||||||
|
broker=broker,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert stock_data["pred_down_prob"] == pytest.approx(0.65)
|
||||||
|
assert stock_data["atr_value"] > 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_staged_exit_uses_independent_arm_threshold_settings() -> None:
|
||||||
|
market = MagicMock()
|
||||||
|
market.code = "KR"
|
||||||
|
market.name = "Korea"
|
||||||
|
|
||||||
|
decision = MagicMock()
|
||||||
|
decision.action = "HOLD"
|
||||||
|
decision.confidence = 70
|
||||||
|
decision.rationale = "hold"
|
||||||
|
|
||||||
|
settings = Settings(
|
||||||
|
KIS_APP_KEY="k",
|
||||||
|
KIS_APP_SECRET="s",
|
||||||
|
KIS_ACCOUNT_NO="12345678-01",
|
||||||
|
GEMINI_API_KEY="g",
|
||||||
|
STAGED_EXIT_BE_ARM_PCT=2.2,
|
||||||
|
STAGED_EXIT_ARM_PCT=5.4,
|
||||||
|
)
|
||||||
|
|
||||||
|
captured: dict[str, float] = {}
|
||||||
|
|
||||||
|
def _fake_eval(**kwargs): # type: ignore[no-untyped-def]
|
||||||
|
cfg = kwargs["config"]
|
||||||
|
captured["be_arm_pct"] = cfg.be_arm_pct
|
||||||
|
captured["arm_pct"] = cfg.arm_pct
|
||||||
|
|
||||||
|
class _Out:
|
||||||
|
should_exit = False
|
||||||
|
reason = "none"
|
||||||
|
state = PositionState.HOLDING
|
||||||
|
|
||||||
|
return _Out()
|
||||||
|
|
||||||
|
with patch("src.main.evaluate_exit", side_effect=_fake_eval):
|
||||||
|
out = _apply_staged_exit_override_for_hold(
|
||||||
|
decision=decision,
|
||||||
|
market=market,
|
||||||
|
stock_code="005930",
|
||||||
|
open_position={"price": 100.0, "quantity": 1, "decision_id": "d1", "timestamp": "t1"},
|
||||||
|
market_data={"current_price": 101.0, "rsi": 60.0, "pred_down_prob": 0.6},
|
||||||
|
stock_playbook=None,
|
||||||
|
settings=settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert out is decision
|
||||||
|
assert captured["be_arm_pct"] == pytest.approx(2.2)
|
||||||
|
assert captured["arm_pct"] == pytest.approx(5.4)
|
||||||
|
|
||||||
def test_returns_zero_when_field_empty_string(self) -> None:
|
def test_returns_zero_when_field_empty_string(self) -> None:
|
||||||
"""Returns 0.0 when pchs_avg_pric is an empty string."""
|
"""Returns 0.0 when pchs_avg_pric is an empty string."""
|
||||||
balance = {"output1": [{"pdno": "005930", "pchs_avg_pric": ""}]}
|
balance = {"output1": [{"pdno": "005930", "pchs_avg_pric": ""}]}
|
||||||
|
|||||||
Reference in New Issue
Block a user