From ac4fb00644617e7ba92b1ee704d0a3fe30af6f27 Mon Sep 17 00:00:00 2001 From: agentson Date: Mon, 23 Feb 2026 12:51:15 +0900 Subject: [PATCH] =?UTF-8?q?feat:=20Daily=20=EB=AA=A8=EB=93=9C=20Connection?= =?UTF-8?q?Error=20=EC=9E=AC=EC=8B=9C=EB=8F=84=20=EB=A1=9C=EC=A7=81=20?= =?UTF-8?q?=EC=B6=94=EA=B0=80=20(issue=20#209)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - _retry_connection() 헬퍼 추가: MAX_CONNECTION_RETRIES(3회) 지수 백오프 (2^attempt 초) 재시도, 읽기 전용 API 호출에만 적용 (주문 제외) - run_daily_session(): get_current_price / get_overseas_price 호출에 적용 - run_daily_session(): get_balance / get_overseas_balance 호출에 적용 - 잔고 조회 전체 실패 시 해당 마켓을 skip하고 다른 마켓은 계속 처리 - 테스트 5개 추가: TestRetryConnection 클래스 Co-Authored-By: Claude Sonnet 4.6 --- src/main.py | 77 ++++++++++++++++++++++++++++++++++++---- tests/test_main.py | 88 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 159 insertions(+), 6 deletions(-) diff --git a/src/main.py b/src/main.py index 35acfa0..e76bca1 100644 --- a/src/main.py +++ b/src/main.py @@ -88,6 +88,47 @@ DAILY_TRADE_SESSIONS = 4 # Number of trading sessions per day TRADE_SESSION_INTERVAL_HOURS = 6 # Hours between sessions +async def _retry_connection(coro_factory: Any, *args: Any, label: str = "", **kwargs: Any) -> Any: + """Call an async function retrying on ConnectionError with exponential backoff. + + Retries up to MAX_CONNECTION_RETRIES times (exclusive of the first attempt), + sleeping 2^attempt seconds between attempts. Use only for idempotent read + operations — never for order submission. + + Args: + coro_factory: Async callable (method or function) to invoke. + *args: Positional arguments forwarded to coro_factory. + label: Human-readable label for log messages. + **kwargs: Keyword arguments forwarded to coro_factory. + + Raises: + ConnectionError: If all retries are exhausted. + """ + for attempt in range(1, MAX_CONNECTION_RETRIES + 1): + try: + return await coro_factory(*args, **kwargs) + except ConnectionError as exc: + if attempt < MAX_CONNECTION_RETRIES: + wait_secs = 2 ** attempt + logger.warning( + "Connection error %s (attempt %d/%d), retrying in %ds: %s", + label, + attempt, + MAX_CONNECTION_RETRIES, + wait_secs, + exc, + ) + await asyncio.sleep(wait_secs) + else: + logger.error( + "Connection error %s — all %d retries exhausted: %s", + label, + MAX_CONNECTION_RETRIES, + exc, + ) + raise + + def _extract_symbol_from_holding(item: dict[str, Any]) -> str: """Extract symbol from overseas holding payload variants.""" for key in ( @@ -964,11 +1005,18 @@ async def run_daily_session( try: if market.is_domestic: current_price, price_change_pct, foreigner_net = ( - await broker.get_current_price(stock_code) + await _retry_connection( + broker.get_current_price, + stock_code, + label=stock_code, + ) ) else: - price_data = await overseas_broker.get_overseas_price( - market.exchange_code, stock_code + price_data = await _retry_connection( + overseas_broker.get_overseas_price, + market.exchange_code, + stock_code, + label=f"{stock_code}@{market.exchange_code}", ) current_price = safe_float( price_data.get("output", {}).get("last", "0") @@ -1019,9 +1067,27 @@ async def run_daily_session( logger.warning("No valid stock data for market %s", market.code) continue - # Get balance data once for the market + # Get balance data once for the market (read-only — safe to retry) + try: + if market.is_domestic: + balance_data = await _retry_connection( + broker.get_balance, label=f"balance:{market.code}" + ) + else: + balance_data = await _retry_connection( + overseas_broker.get_overseas_balance, + market.exchange_code, + label=f"overseas_balance:{market.exchange_code}", + ) + except ConnectionError as exc: + logger.error( + "Balance fetch failed for market %s after all retries — skipping market: %s", + market.code, + exc, + ) + continue + if market.is_domestic: - balance_data = await broker.get_balance() output2 = balance_data.get("output2", [{}]) total_eval = safe_float( output2[0].get("tot_evlu_amt", "0") @@ -1033,7 +1099,6 @@ async def run_daily_session( output2[0].get("pchs_amt_smtl_amt", "0") ) if output2 else 0 else: - balance_data = await overseas_broker.get_overseas_balance(market.exchange_code) output2 = balance_data.get("output2", [{}]) if isinstance(output2, list) and output2: balance_info = output2[0] diff --git a/tests/test_main.py b/tests/test_main.py index f00467f..76d7e13 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -18,6 +18,7 @@ from src.main import ( _extract_held_codes_from_balance, _extract_held_qty_from_balance, _handle_market_close, + _retry_connection, _run_context_scheduler, _run_evolution_loop, _start_dashboard_server, @@ -3183,3 +3184,90 @@ class TestOverseasBrokerIntegration: # DB도 브로커도 보유 없음 → BUY 주문이 실행되어야 함 (회귀 테스트) overseas_broker.send_overseas_order.assert_called_once() + + +# --------------------------------------------------------------------------- +# _retry_connection — unit tests (issue #209) +# --------------------------------------------------------------------------- + + +class TestRetryConnection: + """Unit tests for the _retry_connection helper (issue #209).""" + + @pytest.mark.asyncio + async def test_success_on_first_attempt(self) -> None: + """Returns the result immediately when the first call succeeds.""" + async def ok() -> str: + return "data" + + result = await _retry_connection(ok, label="test") + assert result == "data" + + @pytest.mark.asyncio + async def test_succeeds_after_one_connection_error(self) -> None: + """Retries once on ConnectionError and returns result on 2nd attempt.""" + call_count = 0 + + async def flaky() -> str: + nonlocal call_count + call_count += 1 + if call_count < 2: + raise ConnectionError("timeout") + return "ok" + + with patch("src.main.asyncio.sleep") as mock_sleep: + mock_sleep.return_value = None + result = await _retry_connection(flaky, label="flaky") + + assert result == "ok" + assert call_count == 2 + mock_sleep.assert_called_once() + + @pytest.mark.asyncio + async def test_raises_after_all_retries_exhausted(self) -> None: + """Raises ConnectionError after MAX_CONNECTION_RETRIES attempts.""" + from src.main import MAX_CONNECTION_RETRIES + + call_count = 0 + + async def always_fail() -> None: + nonlocal call_count + call_count += 1 + raise ConnectionError("unreachable") + + with patch("src.main.asyncio.sleep") as mock_sleep: + mock_sleep.return_value = None + with pytest.raises(ConnectionError, match="unreachable"): + await _retry_connection(always_fail, label="always_fail") + + assert call_count == MAX_CONNECTION_RETRIES + + @pytest.mark.asyncio + async def test_passes_args_and_kwargs_to_factory(self) -> None: + """Forwards positional and keyword arguments to the callable.""" + received: dict = {} + + async def capture(a: int, b: int, *, key: str) -> str: + received["a"] = a + received["b"] = b + received["key"] = key + return "captured" + + result = await _retry_connection(capture, 1, 2, key="val", label="test") + assert result == "captured" + assert received == {"a": 1, "b": 2, "key": "val"} + + @pytest.mark.asyncio + async def test_non_connection_error_not_retried(self) -> None: + """Non-ConnectionError exceptions propagate immediately without retry.""" + call_count = 0 + + async def bad_input() -> None: + nonlocal call_count + call_count += 1 + raise ValueError("bad data") + + with pytest.raises(ValueError, match="bad data"): + await _retry_connection(bad_input, label="bad") + + assert call_count == 1 # No retry for non-ConnectionError -- 2.49.1