fix: KR session-aware exchange routing (#409) #411

Merged
jihoson merged 8 commits from feature/issue-409-kr-session-exchange-routing into main 2026-03-04 23:06:09 +09:00
6 changed files with 179 additions and 8 deletions
Showing only changes of commit c80f3daad7 - Show all commits

View File

@@ -68,6 +68,7 @@ class SmartVolatilityScanner:
self, self,
market: MarketInfo | None = None, market: MarketInfo | None = None,
fallback_stocks: list[str] | None = None, fallback_stocks: list[str] | None = None,
domestic_session_id: str | None = None,
) -> list[ScanCandidate]: ) -> list[ScanCandidate]:
"""Execute smart scan and return qualified candidates. """Execute smart scan and return qualified candidates.
@@ -81,11 +82,12 @@ class SmartVolatilityScanner:
if market and not market.is_domestic: if market and not market.is_domestic:
return await self._scan_overseas(market, fallback_stocks) return await self._scan_overseas(market, fallback_stocks)
return await self._scan_domestic(fallback_stocks) return await self._scan_domestic(fallback_stocks, session_id=domestic_session_id)
async def _scan_domestic( async def _scan_domestic(
self, self,
fallback_stocks: list[str] | None = None, fallback_stocks: list[str] | None = None,
session_id: str | None = None,
) -> list[ScanCandidate]: ) -> list[ScanCandidate]:
"""Scan domestic market using volatility-first ranking + liquidity bonus.""" """Scan domestic market using volatility-first ranking + liquidity bonus."""
# 1) Primary universe from fluctuation ranking. # 1) Primary universe from fluctuation ranking.
@@ -93,6 +95,7 @@ class SmartVolatilityScanner:
fluct_rows = await self.broker.fetch_market_rankings( fluct_rows = await self.broker.fetch_market_rankings(
ranking_type="fluctuation", ranking_type="fluctuation",
limit=50, limit=50,
session_id=session_id,
) )
except ConnectionError as exc: except ConnectionError as exc:
logger.warning("Domestic fluctuation ranking failed: %s", exc) logger.warning("Domestic fluctuation ranking failed: %s", exc)
@@ -103,6 +106,7 @@ class SmartVolatilityScanner:
volume_rows = await self.broker.fetch_market_rankings( volume_rows = await self.broker.fetch_market_rankings(
ranking_type="volume", ranking_type="volume",
limit=50, limit=50,
session_id=session_id,
) )
except ConnectionError as exc: except ConnectionError as exc:
logger.warning("Domestic volume ranking failed: %s", exc) logger.warning("Domestic volume ranking failed: %s", exc)

View File

@@ -12,7 +12,10 @@ from typing import Any, cast
import aiohttp import aiohttp
from src.broker.kr_exchange_router import KRExchangeRouter
from src.config import Settings from src.config import Settings
from src.core.order_policy import classify_session_id
from src.markets.schedule import MARKETS
# KIS virtual trading server has a known SSL certificate hostname mismatch. # KIS virtual trading server has a known SSL certificate hostname mismatch.
_KIS_VTS_HOST = "openapivts.koreainvestment.com" _KIS_VTS_HOST = "openapivts.koreainvestment.com"
@@ -92,6 +95,7 @@ class KISBroker:
self._last_refresh_attempt: float = 0.0 self._last_refresh_attempt: float = 0.0
self._refresh_cooldown: float = 60.0 # Seconds (matches KIS 1/minute limit) self._refresh_cooldown: float = 60.0 # Seconds (matches KIS 1/minute limit)
self._rate_limiter = LeakyBucket(settings.RATE_LIMIT_RPS) self._rate_limiter = LeakyBucket(settings.RATE_LIMIT_RPS)
self._kr_router = KRExchangeRouter()
def _get_session(self) -> aiohttp.ClientSession: def _get_session(self) -> aiohttp.ClientSession:
if self._session is None or self._session.closed: if self._session is None or self._session.closed:
@@ -187,9 +191,12 @@ class KISBroker:
if resp.status != 200: if resp.status != 200:
text = await resp.text() text = await resp.text()
raise ConnectionError(f"Hash key request failed ({resp.status}): {text}") raise ConnectionError(f"Hash key request failed ({resp.status}): {text}")
data = await resp.json() data = cast(dict[str, Any], await resp.json())
return data["HASH"] hash_value = data.get("HASH")
if not isinstance(hash_value, str):
raise ConnectionError("Hash key response missing HASH")
return hash_value
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Common Headers # Common Headers
@@ -226,7 +233,7 @@ class KISBroker:
if resp.status != 200: if resp.status != 200:
text = await resp.text() text = await resp.text()
raise ConnectionError(f"get_orderbook failed ({resp.status}): {text}") raise ConnectionError(f"get_orderbook failed ({resp.status}): {text}")
return await resp.json() return cast(dict[str, Any], await resp.json())
except (TimeoutError, aiohttp.ClientError) as exc: except (TimeoutError, aiohttp.ClientError) as exc:
raise ConnectionError(f"Network error fetching orderbook: {exc}") from exc raise ConnectionError(f"Network error fetching orderbook: {exc}") from exc
@@ -302,7 +309,7 @@ class KISBroker:
if resp.status != 200: if resp.status != 200:
text = await resp.text() text = await resp.text()
raise ConnectionError(f"get_balance failed ({resp.status}): {text}") raise ConnectionError(f"get_balance failed ({resp.status}): {text}")
return await resp.json() return cast(dict[str, Any], await resp.json())
except (TimeoutError, aiohttp.ClientError) as exc: except (TimeoutError, aiohttp.ClientError) as exc:
raise ConnectionError(f"Network error fetching balance: {exc}") from exc raise ConnectionError(f"Network error fetching balance: {exc}") from exc
@@ -312,6 +319,7 @@ class KISBroker:
order_type: str, # "BUY" or "SELL" order_type: str, # "BUY" or "SELL"
quantity: int, quantity: int,
price: int = 0, price: int = 0,
session_id: str | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Submit a buy or sell order. """Submit a buy or sell order.
@@ -341,10 +349,17 @@ class KISBroker:
ord_dvsn = "01" # 시장가 ord_dvsn = "01" # 시장가
ord_price = 0 ord_price = 0
resolved_session = session_id or classify_session_id(MARKETS["KR"])
resolution = self._kr_router.resolve_for_order(
stock_code=stock_code,
session_id=resolved_session,
)
body = { body = {
"CANO": self._account_no, "CANO": self._account_no,
"ACNT_PRDT_CD": self._product_cd, "ACNT_PRDT_CD": self._product_cd,
"PDNO": stock_code, "PDNO": stock_code,
"EXCG_ID_DVSN_CD": resolution.exchange_code,
"ORD_DVSN": ord_dvsn, "ORD_DVSN": ord_dvsn,
"ORD_QTY": str(quantity), "ORD_QTY": str(quantity),
"ORD_UNPR": str(ord_price), "ORD_UNPR": str(ord_price),
@@ -361,12 +376,15 @@ class KISBroker:
if resp.status != 200: if resp.status != 200:
text = await resp.text() text = await resp.text()
raise ConnectionError(f"send_order failed ({resp.status}): {text}") raise ConnectionError(f"send_order failed ({resp.status}): {text}")
data = await resp.json() data = cast(dict[str, Any], await resp.json())
logger.info( logger.info(
"Order submitted", "Order submitted",
extra={ extra={
"stock_code": stock_code, "stock_code": stock_code,
"action": order_type, "action": order_type,
"session_id": resolved_session,
"exchange": resolution.exchange_code,
"routing_reason": resolution.reason,
}, },
) )
return data return data
@@ -377,6 +395,7 @@ class KISBroker:
self, self,
ranking_type: str = "volume", ranking_type: str = "volume",
limit: int = 30, limit: int = 30,
session_id: str | None = None,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Fetch market rankings from KIS API. """Fetch market rankings from KIS API.
@@ -394,12 +413,15 @@ class KISBroker:
await self._rate_limiter.acquire() await self._rate_limiter.acquire()
session = self._get_session() session = self._get_session()
resolved_session = session_id or classify_session_id(MARKETS["KR"])
ranking_market_code = self._kr_router.resolve_for_ranking(resolved_session)
if ranking_type == "volume": if ranking_type == "volume":
# 거래량순위: FHPST01710000 / /quotations/volume-rank # 거래량순위: FHPST01710000 / /quotations/volume-rank
tr_id = "FHPST01710000" tr_id = "FHPST01710000"
url = f"{self._base_url}/uapi/domestic-stock/v1/quotations/volume-rank" url = f"{self._base_url}/uapi/domestic-stock/v1/quotations/volume-rank"
params: dict[str, str] = { params: dict[str, str] = {
"FID_COND_MRKT_DIV_CODE": "J", "FID_COND_MRKT_DIV_CODE": ranking_market_code,
"FID_COND_SCR_DIV_CODE": "20171", "FID_COND_SCR_DIV_CODE": "20171",
"FID_INPUT_ISCD": "0000", "FID_INPUT_ISCD": "0000",
"FID_DIV_CLS_CODE": "0", "FID_DIV_CLS_CODE": "0",
@@ -416,7 +438,7 @@ class KISBroker:
tr_id = "FHPST01700000" tr_id = "FHPST01700000"
url = f"{self._base_url}/uapi/domestic-stock/v1/ranking/fluctuation" url = f"{self._base_url}/uapi/domestic-stock/v1/ranking/fluctuation"
params = { params = {
"fid_cond_mrkt_div_code": "J", "fid_cond_mrkt_div_code": ranking_market_code,
"fid_cond_scr_div_code": "20170", "fid_cond_scr_div_code": "20170",
"fid_input_iscd": "0000", "fid_input_iscd": "0000",
"fid_rank_sort_cls_code": "0", "fid_rank_sort_cls_code": "0",

View File

@@ -0,0 +1,48 @@
from __future__ import annotations
from dataclasses import dataclass
@dataclass(frozen=True)
class ExchangeResolution:
exchange_code: str
reason: str
class KRExchangeRouter:
"""Resolve domestic exchange routing for KR sessions."""
def resolve_for_ranking(self, session_id: str) -> str:
if session_id in {"NXT_PRE", "NXT_AFTER"}:
return "NX"
return "J"
def resolve_for_order(
self,
*,
stock_code: str,
session_id: str,
is_dual_listed: bool = False,
spread_krx: float | None = None,
spread_nxt: float | None = None,
liquidity_krx: float | None = None,
liquidity_nxt: float | None = None,
) -> ExchangeResolution:
del stock_code
default_exchange = "NXT" if session_id in {"NXT_PRE", "NXT_AFTER"} else "KRX"
default_reason = "session_default"
if not is_dual_listed:
return ExchangeResolution(default_exchange, default_reason)
if spread_krx is not None and spread_nxt is not None:
if spread_nxt < spread_krx:
return ExchangeResolution("NXT", "dual_listing_spread")
return ExchangeResolution("KRX", "dual_listing_spread")
if liquidity_krx is not None and liquidity_nxt is not None:
if liquidity_nxt > liquidity_krx:
return ExchangeResolution("NXT", "dual_listing_liquidity")
return ExchangeResolution("KRX", "dual_listing_liquidity")
return ExchangeResolution(default_exchange, "fallback_data_unavailable")

View File

@@ -400,6 +400,15 @@ class TestFetchMarketRankings:
assert result[0]["stock_code"] == "015260" assert result[0]["stock_code"] == "015260"
assert result[0]["change_rate"] == 29.74 assert result[0]["change_rate"] == 29.74
@pytest.mark.asyncio
async def test_volume_uses_nx_market_code_in_nxt_session(self, broker: KISBroker) -> None:
mock_resp = _make_ranking_mock([])
with patch("aiohttp.ClientSession.get", return_value=mock_resp) as mock_get:
await broker.fetch_market_rankings(ranking_type="volume", session_id="NXT_PRE")
params = mock_get.call_args[1].get("params", {})
assert params.get("FID_COND_MRKT_DIV_CODE") == "NX"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# KRX tick unit / round-down helpers (issue #157) # KRX tick unit / round-down helpers (issue #157)
@@ -591,6 +600,27 @@ class TestSendOrderTickRounding:
body = order_call[1].get("json", {}) body = order_call[1].get("json", {})
assert body["ORD_DVSN"] == "01" assert body["ORD_DVSN"] == "01"
@pytest.mark.asyncio
async def test_send_order_sets_exchange_field_from_session(self, broker: KISBroker) -> None:
mock_hash = AsyncMock()
mock_hash.status = 200
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
mock_hash.__aexit__ = AsyncMock(return_value=False)
mock_order = AsyncMock()
mock_order.status = 200
mock_order.json = AsyncMock(return_value={"rt_cd": "0"})
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
mock_order.__aexit__ = AsyncMock(return_value=False)
with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post:
await broker.send_order("005930", "BUY", 1, price=50000, session_id="NXT_PRE")
order_call = mock_post.call_args_list[1]
body = order_call[1].get("json", {})
assert body["EXCG_ID_DVSN_CD"] == "NXT"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# TR_ID live/paper branching (issues #201, #202, #203) # TR_ID live/paper branching (issues #201, #202, #203)

View File

@@ -0,0 +1,40 @@
from __future__ import annotations
from src.broker.kr_exchange_router import KRExchangeRouter
def test_ranking_market_code_by_session() -> None:
router = KRExchangeRouter()
assert router.resolve_for_ranking("KRX_REG") == "J"
assert router.resolve_for_ranking("NXT_PRE") == "NX"
assert router.resolve_for_ranking("NXT_AFTER") == "NX"
def test_order_exchange_falls_back_to_session_default_on_missing_data() -> None:
router = KRExchangeRouter()
resolved = router.resolve_for_order(
stock_code="0001A0",
session_id="NXT_PRE",
is_dual_listed=True,
spread_krx=None,
spread_nxt=None,
liquidity_krx=None,
liquidity_nxt=None,
)
assert resolved.exchange_code == "NXT"
assert resolved.reason == "fallback_data_unavailable"
def test_order_exchange_uses_spread_preference_for_dual_listing() -> None:
router = KRExchangeRouter()
resolved = router.resolve_for_order(
stock_code="0001A0",
session_id="KRX_REG",
is_dual_listed=True,
spread_krx=0.005,
spread_nxt=0.003,
liquidity_krx=100000.0,
liquidity_nxt=90000.0,
)
assert resolved.exchange_code == "NXT"
assert resolved.reason == "dual_listing_spread"

View File

@@ -103,6 +103,33 @@ class TestSmartVolatilityScanner:
assert candidates[0].stock_code == "005930" assert candidates[0].stock_code == "005930"
assert candidates[0].signal == "oversold" assert candidates[0].signal == "oversold"
@pytest.mark.asyncio
async def test_scan_domestic_passes_session_id_to_rankings(
self, scanner: SmartVolatilityScanner, mock_broker: MagicMock
) -> None:
fluctuation_rows = [
{
"stock_code": "005930",
"name": "Samsung",
"price": 70000,
"volume": 5000000,
"change_rate": 1.0,
"volume_increase_rate": 120,
},
]
mock_broker.fetch_market_rankings.side_effect = [fluctuation_rows, fluctuation_rows]
mock_broker.get_daily_prices.return_value = [
{"open": 1, "high": 71000, "low": 69000, "close": 70000, "volume": 1000000},
{"open": 1, "high": 70000, "low": 68000, "close": 69000, "volume": 900000},
]
await scanner.scan(domestic_session_id="NXT_PRE")
first_call = mock_broker.fetch_market_rankings.call_args_list[0]
second_call = mock_broker.fetch_market_rankings.call_args_list[1]
assert first_call.kwargs["session_id"] == "NXT_PRE"
assert second_call.kwargs["session_id"] == "NXT_PRE"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_scan_domestic_finds_momentum_candidate( async def test_scan_domestic_finds_momentum_candidate(
self, scanner: SmartVolatilityScanner, mock_broker: MagicMock self, scanner: SmartVolatilityScanner, mock_broker: MagicMock