Compare commits

...

2 Commits

Author SHA1 Message Date
agentson
42c06929ea test: add session-risk reload edge-case coverage (#327)
Some checks are pending
Gitea CI / test (push) Waiting to run
Gitea CI / test (pull_request) Waiting to run
2026-02-28 22:20:59 +09:00
agentson
5facd22ef9 feat: reload session risk profile on session transitions (#327)
Some checks failed
Gitea CI / test (pull_request) Waiting to run
Gitea CI / test (push) Has been cancelled
2026-02-28 21:04:06 +09:00
3 changed files with 332 additions and 15 deletions

View File

@@ -68,6 +68,8 @@ class Settings(BaseSettings):
KR_ATR_STOP_MIN_PCT: float = Field(default=-2.0, le=0.0)
KR_ATR_STOP_MAX_PCT: float = Field(default=-7.0, le=0.0)
OVERNIGHT_EXCEPTION_ENABLED: bool = True
SESSION_RISK_RELOAD_ENABLED: bool = True
SESSION_RISK_PROFILES_JSON: str = "{}"
# Trading frequency mode (daily = batch API calls, realtime = per-stock calls)
TRADE_MODE: str = Field(default="daily", pattern="^(daily|realtime)$")

View File

@@ -72,6 +72,10 @@ _RUNTIME_EXIT_STATES: dict[str, PositionState] = {}
_RUNTIME_EXIT_PEAKS: dict[str, float] = {}
_STOPLOSS_REENTRY_COOLDOWN_UNTIL: dict[str, float] = {}
_VOLATILITY_ANALYZER = VolatilityAnalyzer()
_SESSION_RISK_PROFILES_RAW = "{}"
_SESSION_RISK_PROFILES_MAP: dict[str, dict[str, Any]] = {}
_SESSION_RISK_LAST_BY_MARKET: dict[str, str] = {}
_SESSION_RISK_OVERRIDES_BY_MARKET: dict[str, dict[str, Any]] = {}
def safe_float(value: str | float | None, default: float = 0.0) -> float:
@@ -122,6 +126,7 @@ def _resolve_sell_qty_for_pnl(*, sell_qty: int | None, buy_qty: int | None) -> i
def _compute_kr_dynamic_stop_loss_pct(
*,
market: MarketInfo | None = None,
entry_price: float,
atr_value: float,
fallback_stop_loss_pct: float,
@@ -131,9 +136,24 @@ def _compute_kr_dynamic_stop_loss_pct(
if entry_price <= 0 or atr_value <= 0:
return fallback_stop_loss_pct
k = float(getattr(settings, "KR_ATR_STOP_MULTIPLIER_K", 2.0) if settings else 2.0)
min_pct = float(getattr(settings, "KR_ATR_STOP_MIN_PCT", -2.0) if settings else -2.0)
max_pct = float(getattr(settings, "KR_ATR_STOP_MAX_PCT", -7.0) if settings else -7.0)
k = _resolve_market_setting(
market=market,
settings=settings,
key="KR_ATR_STOP_MULTIPLIER_K",
default=2.0,
)
min_pct = _resolve_market_setting(
market=market,
settings=settings,
key="KR_ATR_STOP_MIN_PCT",
default=-2.0,
)
max_pct = _resolve_market_setting(
market=market,
settings=settings,
key="KR_ATR_STOP_MAX_PCT",
default=-7.0,
)
if max_pct > min_pct:
min_pct, max_pct = max_pct, min_pct
@@ -145,10 +165,123 @@ def _stoploss_cooldown_key(*, market: MarketInfo, stock_code: str) -> str:
return f"{market.code}:{stock_code}"
def _stoploss_cooldown_minutes(settings: Settings | None) -> int:
def _parse_session_risk_profiles(settings: Settings | None) -> dict[str, dict[str, Any]]:
if settings is None:
return 120
return max(1, int(getattr(settings, "STOPLOSS_REENTRY_COOLDOWN_MINUTES", 120)))
return {}
global _SESSION_RISK_PROFILES_RAW, _SESSION_RISK_PROFILES_MAP
raw = str(getattr(settings, "SESSION_RISK_PROFILES_JSON", "{}") or "{}")
if raw == _SESSION_RISK_PROFILES_RAW:
return _SESSION_RISK_PROFILES_MAP
parsed_map: dict[str, dict[str, Any]] = {}
try:
decoded = json.loads(raw)
if isinstance(decoded, dict):
for session_id, session_values in decoded.items():
if isinstance(session_id, str) and isinstance(session_values, dict):
parsed_map[session_id] = session_values
except (ValueError, TypeError) as exc:
logger.warning("Invalid SESSION_RISK_PROFILES_JSON; using defaults: %s", exc)
parsed_map = {}
_SESSION_RISK_PROFILES_RAW = raw
_SESSION_RISK_PROFILES_MAP = parsed_map
return _SESSION_RISK_PROFILES_MAP
def _coerce_setting_value(*, value: Any, default: Any) -> Any:
if isinstance(default, bool):
if isinstance(value, bool):
return value
if isinstance(value, str):
return value.strip().lower() in {"1", "true", "yes", "on"}
if isinstance(value, (int, float)):
return value != 0
return default
if isinstance(default, int) and not isinstance(default, bool):
try:
return int(value)
except (ValueError, TypeError):
return default
if isinstance(default, float):
return safe_float(value, float(default))
if isinstance(default, str):
return str(value)
return value
def _session_risk_overrides(
*,
market: MarketInfo | None,
settings: Settings | None,
) -> dict[str, Any]:
if market is None or settings is None:
return {}
if not bool(getattr(settings, "SESSION_RISK_RELOAD_ENABLED", True)):
return {}
session_id = get_session_info(market).session_id
previous_session = _SESSION_RISK_LAST_BY_MARKET.get(market.code)
if previous_session == session_id:
return _SESSION_RISK_OVERRIDES_BY_MARKET.get(market.code, {})
profile_map = _parse_session_risk_profiles(settings)
merged: dict[str, Any] = {}
default_profile = profile_map.get("default")
if isinstance(default_profile, dict):
merged.update(default_profile)
session_profile = profile_map.get(session_id)
if isinstance(session_profile, dict):
merged.update(session_profile)
_SESSION_RISK_LAST_BY_MARKET[market.code] = session_id
_SESSION_RISK_OVERRIDES_BY_MARKET[market.code] = merged
if previous_session is None:
logger.info(
"Session risk profile initialized for %s: %s (overrides=%s)",
market.code,
session_id,
",".join(sorted(merged.keys())) if merged else "none",
)
else:
logger.info(
"Session risk profile reloaded for %s: %s -> %s (overrides=%s)",
market.code,
previous_session,
session_id,
",".join(sorted(merged.keys())) if merged else "none",
)
return merged
def _resolve_market_setting(
*,
market: MarketInfo | None,
settings: Settings | None,
key: str,
default: Any,
) -> Any:
if settings is None:
return default
fallback = getattr(settings, key, default)
overrides = _session_risk_overrides(market=market, settings=settings)
if key not in overrides:
return fallback
return _coerce_setting_value(value=overrides[key], default=fallback)
def _stoploss_cooldown_minutes(
settings: Settings | None,
market: MarketInfo | None = None,
) -> int:
minutes = _resolve_market_setting(
market=market,
settings=settings,
key="STOPLOSS_REENTRY_COOLDOWN_MINUTES",
default=120,
)
return max(1, int(minutes))
def _estimate_pred_down_prob_from_rsi(rsi: float | str | None) -> float:
@@ -578,7 +711,14 @@ def _should_block_overseas_buy_for_fx_buffer(
):
return False, total_cash - order_amount, 0.0
remaining = total_cash - order_amount
required = settings.USD_BUFFER_MIN
required = float(
_resolve_market_setting(
market=market,
settings=settings,
key="USD_BUFFER_MIN",
default=1000.0,
)
)
return remaining < required, remaining, required
@@ -594,7 +734,13 @@ def _should_force_exit_for_overnight(
return True
if settings is None:
return False
return not settings.OVERNIGHT_EXCEPTION_ENABLED
overnight_enabled = _resolve_market_setting(
market=market,
settings=settings,
key="OVERNIGHT_EXCEPTION_ENABLED",
default=True,
)
return not bool(overnight_enabled)
def _build_runtime_position_key(
@@ -643,6 +789,7 @@ def _apply_staged_exit_override_for_hold(
atr_value = safe_float(market_data.get("atr_value"), 0.0)
if market.code == "KR":
stop_loss_threshold = _compute_kr_dynamic_stop_loss_pct(
market=market,
entry_price=entry_price,
atr_value=atr_value,
fallback_stop_loss_pct=stop_loss_threshold,
@@ -652,10 +799,27 @@ def _apply_staged_exit_override_for_hold(
be_arm_pct = max(0.5, take_profit_threshold * 0.4)
arm_pct = take_profit_threshold
else:
be_arm_pct = max(0.1, float(getattr(settings, "STAGED_EXIT_BE_ARM_PCT", 1.2)))
be_arm_pct = max(
0.1,
float(
_resolve_market_setting(
market=market,
settings=settings,
key="STAGED_EXIT_BE_ARM_PCT",
default=1.2,
)
),
)
arm_pct = max(
be_arm_pct,
float(getattr(settings, "STAGED_EXIT_ARM_PCT", 3.0)),
float(
_resolve_market_setting(
market=market,
settings=settings,
key="STAGED_EXIT_ARM_PCT",
default=3.0,
)
),
)
runtime_key = _build_runtime_position_key(
@@ -1148,6 +1312,7 @@ async def trading_cycle(
) -> None:
"""Execute one trading cycle for a single stock."""
cycle_start_time = asyncio.get_event_loop().time()
_session_risk_overrides(market=market, settings=settings)
# 1. Fetch market data
price_output: dict[str, Any] = {} # Populated for overseas markets; used for fallback metrics
@@ -1397,7 +1562,14 @@ async def trading_cycle(
# 2.1. Apply market_outlook-based BUY confidence threshold
if decision.action == "BUY":
base_threshold = (settings.CONFIDENCE_THRESHOLD if settings else 80)
base_threshold = int(
_resolve_market_setting(
market=market,
settings=settings,
key="CONFIDENCE_THRESHOLD",
default=80,
)
)
outlook = playbook.market_outlook
if outlook == MarketOutlook.BEARISH:
min_confidence = 90
@@ -1450,7 +1622,14 @@ async def trading_cycle(
market.name,
)
elif market.code.startswith("US"):
min_price = float(getattr(settings, "US_MIN_PRICE", 5.0) if settings else 5.0)
min_price = float(
_resolve_market_setting(
market=market,
settings=settings,
key="US_MIN_PRICE",
default=5.0,
)
)
if current_price <= min_price:
decision = TradeDecision(
action="HOLD",
@@ -1877,7 +2056,7 @@ async def trading_cycle(
)
if trade_pnl < 0:
cooldown_key = _stoploss_cooldown_key(market=market, stock_code=stock_code)
cooldown_minutes = _stoploss_cooldown_minutes(settings)
cooldown_minutes = _stoploss_cooldown_minutes(settings, market=market)
_STOPLOSS_REENTRY_COOLDOWN_UNTIL[cooldown_key] = (
datetime.now(UTC).timestamp() + cooldown_minutes * 60
)
@@ -2329,6 +2508,7 @@ async def run_daily_session(
# Process each open market
for market in open_markets:
_session_risk_overrides(market=market, settings=settings)
await process_blackout_recovery_orders(
broker=broker,
overseas_broker=overseas_broker,
@@ -2666,7 +2846,14 @@ async def run_daily_session(
market.name,
)
elif market.code.startswith("US"):
min_price = float(getattr(settings, "US_MIN_PRICE", 5.0))
min_price = float(
_resolve_market_setting(
market=market,
settings=settings,
key="US_MIN_PRICE",
default=5.0,
)
)
if stock_data["current_price"] <= min_price:
decision = TradeDecision(
action="HOLD",
@@ -3041,7 +3228,10 @@ async def run_daily_session(
)
if trade_pnl < 0:
cooldown_key = _stoploss_cooldown_key(market=market, stock_code=stock_code)
cooldown_minutes = _stoploss_cooldown_minutes(settings)
cooldown_minutes = _stoploss_cooldown_minutes(
settings,
market=market,
)
_STOPLOSS_REENTRY_COOLDOWN_UNTIL[cooldown_key] = (
datetime.now(UTC).timestamp() + cooldown_minutes * 60
)
@@ -3849,6 +4039,7 @@ async def run(settings: Settings) -> None:
break
session_info = get_session_info(market)
_session_risk_overrides(market=market, settings=settings)
logger.info(
"Market session active: %s (%s) session=%s",
market.code,

View File

@@ -4,6 +4,7 @@ from datetime import UTC, date, datetime
from unittest.mock import ANY, AsyncMock, MagicMock, patch
import pytest
import src.main as main_module
from src.config import Settings
from src.context.layer import ContextLayer
@@ -15,6 +16,9 @@ from src.evolution.scorecard import DailyScorecard
from src.logging.decision_logger import DecisionLogger
from src.main import (
KILL_SWITCH,
_SESSION_RISK_LAST_BY_MARKET,
_SESSION_RISK_OVERRIDES_BY_MARKET,
_SESSION_RISK_PROFILES_MAP,
_STOPLOSS_REENTRY_COOLDOWN_UNTIL,
_apply_staged_exit_override_for_hold,
_compute_kr_atr_value,
@@ -32,10 +36,12 @@ from src.main import (
_extract_held_qty_from_balance,
_handle_market_close,
_retry_connection,
_resolve_market_setting,
_resolve_sell_qty_for_pnl,
_run_context_scheduler,
_run_evolution_loop,
_start_dashboard_server,
_stoploss_cooldown_minutes,
_compute_kr_dynamic_stop_loss_pct,
handle_domestic_pending_orders,
handle_overseas_pending_orders,
@@ -99,11 +105,19 @@ def _reset_kill_switch_state() -> None:
KILL_SWITCH.clear_block()
_RUNTIME_EXIT_STATES.clear()
_RUNTIME_EXIT_PEAKS.clear()
_SESSION_RISK_LAST_BY_MARKET.clear()
_SESSION_RISK_OVERRIDES_BY_MARKET.clear()
_SESSION_RISK_PROFILES_MAP.clear()
main_module._SESSION_RISK_PROFILES_RAW = "__reset__"
_STOPLOSS_REENTRY_COOLDOWN_UNTIL.clear()
yield
KILL_SWITCH.clear_block()
_RUNTIME_EXIT_STATES.clear()
_RUNTIME_EXIT_PEAKS.clear()
_SESSION_RISK_LAST_BY_MARKET.clear()
_SESSION_RISK_OVERRIDES_BY_MARKET.clear()
_SESSION_RISK_PROFILES_MAP.clear()
main_module._SESSION_RISK_PROFILES_RAW = "__reset__"
_STOPLOSS_REENTRY_COOLDOWN_UNTIL.clear()
@@ -186,6 +200,116 @@ def test_compute_kr_dynamic_stop_loss_pct_uses_settings_values() -> None:
assert out == -3.0
def test_resolve_market_setting_uses_session_profile_override() -> None:
settings = Settings(
KIS_APP_KEY="k",
KIS_APP_SECRET="s",
KIS_ACCOUNT_NO="12345678-01",
GEMINI_API_KEY="g",
SESSION_RISK_PROFILES_JSON='{"US_PRE": {"US_MIN_PRICE": 7.5}}',
)
market = MagicMock()
market.code = "US_NASDAQ"
with patch("src.main.get_session_info", return_value=MagicMock(session_id="US_PRE")):
value = _resolve_market_setting(
market=market,
settings=settings,
key="US_MIN_PRICE",
default=5.0,
)
assert value == pytest.approx(7.5)
def test_stoploss_cooldown_minutes_uses_session_override() -> None:
settings = Settings(
KIS_APP_KEY="k",
KIS_APP_SECRET="s",
KIS_ACCOUNT_NO="12345678-01",
GEMINI_API_KEY="g",
STOPLOSS_REENTRY_COOLDOWN_MINUTES=120,
SESSION_RISK_PROFILES_JSON='{"NXT_AFTER": {"STOPLOSS_REENTRY_COOLDOWN_MINUTES": 45}}',
)
market = MagicMock()
market.code = "KR"
with patch("src.main.get_session_info", return_value=MagicMock(session_id="NXT_AFTER")):
value = _stoploss_cooldown_minutes(settings, market=market)
assert value == 45
def test_resolve_market_setting_ignores_profile_when_reload_disabled() -> None:
settings = Settings(
KIS_APP_KEY="k",
KIS_APP_SECRET="s",
KIS_ACCOUNT_NO="12345678-01",
GEMINI_API_KEY="g",
US_MIN_PRICE=5.0,
SESSION_RISK_RELOAD_ENABLED=False,
SESSION_RISK_PROFILES_JSON='{"US_PRE": {"US_MIN_PRICE": 9.5}}',
)
market = MagicMock()
market.code = "US_NASDAQ"
with patch("src.main.get_session_info", return_value=MagicMock(session_id="US_PRE")):
value = _resolve_market_setting(
market=market,
settings=settings,
key="US_MIN_PRICE",
default=5.0,
)
assert value == pytest.approx(5.0)
def test_resolve_market_setting_falls_back_on_invalid_profile_json() -> None:
settings = Settings(
KIS_APP_KEY="k",
KIS_APP_SECRET="s",
KIS_ACCOUNT_NO="12345678-01",
GEMINI_API_KEY="g",
US_MIN_PRICE=5.0,
SESSION_RISK_PROFILES_JSON="{invalid-json",
)
market = MagicMock()
market.code = "US_NASDAQ"
with patch("src.main.get_session_info", return_value=MagicMock(session_id="US_PRE")):
value = _resolve_market_setting(
market=market,
settings=settings,
key="US_MIN_PRICE",
default=5.0,
)
assert value == pytest.approx(5.0)
def test_resolve_market_setting_coerces_bool_string_override() -> None:
settings = Settings(
KIS_APP_KEY="k",
KIS_APP_SECRET="s",
KIS_ACCOUNT_NO="12345678-01",
GEMINI_API_KEY="g",
OVERNIGHT_EXCEPTION_ENABLED=True,
SESSION_RISK_PROFILES_JSON='{"US_AFTER": {"OVERNIGHT_EXCEPTION_ENABLED": "false"}}',
)
market = MagicMock()
market.code = "US_NASDAQ"
with patch("src.main.get_session_info", return_value=MagicMock(session_id="US_AFTER")):
value = _resolve_market_setting(
market=market,
settings=settings,
key="OVERNIGHT_EXCEPTION_ENABLED",
default=True,
)
assert value is False
def test_estimate_pred_down_prob_from_rsi_uses_linear_mapping() -> None:
assert _estimate_pred_down_prob_from_rsi(None) == 0.5
assert _estimate_pred_down_prob_from_rsi(0.0) == 0.0