Compare commits
1 Commits
c737d5009a
...
feature/is
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ac4fb00644 |
77
src/main.py
77
src/main.py
@@ -88,6 +88,47 @@ DAILY_TRADE_SESSIONS = 4 # Number of trading sessions per day
|
|||||||
TRADE_SESSION_INTERVAL_HOURS = 6 # Hours between sessions
|
TRADE_SESSION_INTERVAL_HOURS = 6 # Hours between sessions
|
||||||
|
|
||||||
|
|
||||||
|
async def _retry_connection(coro_factory: Any, *args: Any, label: str = "", **kwargs: Any) -> Any:
|
||||||
|
"""Call an async function retrying on ConnectionError with exponential backoff.
|
||||||
|
|
||||||
|
Retries up to MAX_CONNECTION_RETRIES times (exclusive of the first attempt),
|
||||||
|
sleeping 2^attempt seconds between attempts. Use only for idempotent read
|
||||||
|
operations — never for order submission.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
coro_factory: Async callable (method or function) to invoke.
|
||||||
|
*args: Positional arguments forwarded to coro_factory.
|
||||||
|
label: Human-readable label for log messages.
|
||||||
|
**kwargs: Keyword arguments forwarded to coro_factory.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ConnectionError: If all retries are exhausted.
|
||||||
|
"""
|
||||||
|
for attempt in range(1, MAX_CONNECTION_RETRIES + 1):
|
||||||
|
try:
|
||||||
|
return await coro_factory(*args, **kwargs)
|
||||||
|
except ConnectionError as exc:
|
||||||
|
if attempt < MAX_CONNECTION_RETRIES:
|
||||||
|
wait_secs = 2 ** attempt
|
||||||
|
logger.warning(
|
||||||
|
"Connection error %s (attempt %d/%d), retrying in %ds: %s",
|
||||||
|
label,
|
||||||
|
attempt,
|
||||||
|
MAX_CONNECTION_RETRIES,
|
||||||
|
wait_secs,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(wait_secs)
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
"Connection error %s — all %d retries exhausted: %s",
|
||||||
|
label,
|
||||||
|
MAX_CONNECTION_RETRIES,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
def _extract_symbol_from_holding(item: dict[str, Any]) -> str:
|
def _extract_symbol_from_holding(item: dict[str, Any]) -> str:
|
||||||
"""Extract symbol from overseas holding payload variants."""
|
"""Extract symbol from overseas holding payload variants."""
|
||||||
for key in (
|
for key in (
|
||||||
@@ -964,11 +1005,18 @@ async def run_daily_session(
|
|||||||
try:
|
try:
|
||||||
if market.is_domestic:
|
if market.is_domestic:
|
||||||
current_price, price_change_pct, foreigner_net = (
|
current_price, price_change_pct, foreigner_net = (
|
||||||
await broker.get_current_price(stock_code)
|
await _retry_connection(
|
||||||
|
broker.get_current_price,
|
||||||
|
stock_code,
|
||||||
|
label=stock_code,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
price_data = await overseas_broker.get_overseas_price(
|
price_data = await _retry_connection(
|
||||||
market.exchange_code, stock_code
|
overseas_broker.get_overseas_price,
|
||||||
|
market.exchange_code,
|
||||||
|
stock_code,
|
||||||
|
label=f"{stock_code}@{market.exchange_code}",
|
||||||
)
|
)
|
||||||
current_price = safe_float(
|
current_price = safe_float(
|
||||||
price_data.get("output", {}).get("last", "0")
|
price_data.get("output", {}).get("last", "0")
|
||||||
@@ -1019,9 +1067,27 @@ async def run_daily_session(
|
|||||||
logger.warning("No valid stock data for market %s", market.code)
|
logger.warning("No valid stock data for market %s", market.code)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Get balance data once for the market
|
# Get balance data once for the market (read-only — safe to retry)
|
||||||
|
try:
|
||||||
|
if market.is_domestic:
|
||||||
|
balance_data = await _retry_connection(
|
||||||
|
broker.get_balance, label=f"balance:{market.code}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
balance_data = await _retry_connection(
|
||||||
|
overseas_broker.get_overseas_balance,
|
||||||
|
market.exchange_code,
|
||||||
|
label=f"overseas_balance:{market.exchange_code}",
|
||||||
|
)
|
||||||
|
except ConnectionError as exc:
|
||||||
|
logger.error(
|
||||||
|
"Balance fetch failed for market %s after all retries — skipping market: %s",
|
||||||
|
market.code,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
if market.is_domestic:
|
if market.is_domestic:
|
||||||
balance_data = await broker.get_balance()
|
|
||||||
output2 = balance_data.get("output2", [{}])
|
output2 = balance_data.get("output2", [{}])
|
||||||
total_eval = safe_float(
|
total_eval = safe_float(
|
||||||
output2[0].get("tot_evlu_amt", "0")
|
output2[0].get("tot_evlu_amt", "0")
|
||||||
@@ -1033,7 +1099,6 @@ async def run_daily_session(
|
|||||||
output2[0].get("pchs_amt_smtl_amt", "0")
|
output2[0].get("pchs_amt_smtl_amt", "0")
|
||||||
) if output2 else 0
|
) if output2 else 0
|
||||||
else:
|
else:
|
||||||
balance_data = await overseas_broker.get_overseas_balance(market.exchange_code)
|
|
||||||
output2 = balance_data.get("output2", [{}])
|
output2 = balance_data.get("output2", [{}])
|
||||||
if isinstance(output2, list) and output2:
|
if isinstance(output2, list) and output2:
|
||||||
balance_info = output2[0]
|
balance_info = output2[0]
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from src.main import (
|
|||||||
_extract_held_codes_from_balance,
|
_extract_held_codes_from_balance,
|
||||||
_extract_held_qty_from_balance,
|
_extract_held_qty_from_balance,
|
||||||
_handle_market_close,
|
_handle_market_close,
|
||||||
|
_retry_connection,
|
||||||
_run_context_scheduler,
|
_run_context_scheduler,
|
||||||
_run_evolution_loop,
|
_run_evolution_loop,
|
||||||
_start_dashboard_server,
|
_start_dashboard_server,
|
||||||
@@ -3183,3 +3184,90 @@ class TestOverseasBrokerIntegration:
|
|||||||
|
|
||||||
# DB도 브로커도 보유 없음 → BUY 주문이 실행되어야 함 (회귀 테스트)
|
# DB도 브로커도 보유 없음 → BUY 주문이 실행되어야 함 (회귀 테스트)
|
||||||
overseas_broker.send_overseas_order.assert_called_once()
|
overseas_broker.send_overseas_order.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _retry_connection — unit tests (issue #209)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRetryConnection:
|
||||||
|
"""Unit tests for the _retry_connection helper (issue #209)."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_success_on_first_attempt(self) -> None:
|
||||||
|
"""Returns the result immediately when the first call succeeds."""
|
||||||
|
async def ok() -> str:
|
||||||
|
return "data"
|
||||||
|
|
||||||
|
result = await _retry_connection(ok, label="test")
|
||||||
|
assert result == "data"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_succeeds_after_one_connection_error(self) -> None:
|
||||||
|
"""Retries once on ConnectionError and returns result on 2nd attempt."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def flaky() -> str:
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count < 2:
|
||||||
|
raise ConnectionError("timeout")
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
with patch("src.main.asyncio.sleep") as mock_sleep:
|
||||||
|
mock_sleep.return_value = None
|
||||||
|
result = await _retry_connection(flaky, label="flaky")
|
||||||
|
|
||||||
|
assert result == "ok"
|
||||||
|
assert call_count == 2
|
||||||
|
mock_sleep.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_raises_after_all_retries_exhausted(self) -> None:
|
||||||
|
"""Raises ConnectionError after MAX_CONNECTION_RETRIES attempts."""
|
||||||
|
from src.main import MAX_CONNECTION_RETRIES
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def always_fail() -> None:
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
raise ConnectionError("unreachable")
|
||||||
|
|
||||||
|
with patch("src.main.asyncio.sleep") as mock_sleep:
|
||||||
|
mock_sleep.return_value = None
|
||||||
|
with pytest.raises(ConnectionError, match="unreachable"):
|
||||||
|
await _retry_connection(always_fail, label="always_fail")
|
||||||
|
|
||||||
|
assert call_count == MAX_CONNECTION_RETRIES
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_passes_args_and_kwargs_to_factory(self) -> None:
|
||||||
|
"""Forwards positional and keyword arguments to the callable."""
|
||||||
|
received: dict = {}
|
||||||
|
|
||||||
|
async def capture(a: int, b: int, *, key: str) -> str:
|
||||||
|
received["a"] = a
|
||||||
|
received["b"] = b
|
||||||
|
received["key"] = key
|
||||||
|
return "captured"
|
||||||
|
|
||||||
|
result = await _retry_connection(capture, 1, 2, key="val", label="test")
|
||||||
|
assert result == "captured"
|
||||||
|
assert received == {"a": 1, "b": 2, "key": "val"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_non_connection_error_not_retried(self) -> None:
|
||||||
|
"""Non-ConnectionError exceptions propagate immediately without retry."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def bad_input() -> None:
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
raise ValueError("bad data")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="bad data"):
|
||||||
|
await _retry_connection(bad_input, label="bad")
|
||||||
|
|
||||||
|
assert call_count == 1 # No retry for non-ConnectionError
|
||||||
|
|||||||
Reference in New Issue
Block a user