Compare commits
3 Commits
3923d03650
...
feature/is
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a56adcd342 | ||
| 33b5ff5e54 | |||
|
|
c57ccc4bca |
@@ -56,6 +56,8 @@ class KISBroker:
|
||||
self._access_token: str | None = None
|
||||
self._token_expires_at: float = 0.0
|
||||
self._token_lock = asyncio.Lock()
|
||||
self._last_refresh_attempt: float = 0.0
|
||||
self._refresh_cooldown: float = 60.0 # Seconds (matches KIS 1/minute limit)
|
||||
self._rate_limiter = LeakyBucket(settings.RATE_LIMIT_RPS)
|
||||
|
||||
def _get_session(self) -> aiohttp.ClientSession:
|
||||
@@ -98,7 +100,19 @@ class KISBroker:
|
||||
if self._access_token and now < self._token_expires_at:
|
||||
return self._access_token
|
||||
|
||||
# Check cooldown period (prevents hitting EGW00133: 1/minute limit)
|
||||
time_since_last_attempt = now - self._last_refresh_attempt
|
||||
if time_since_last_attempt < self._refresh_cooldown:
|
||||
remaining = self._refresh_cooldown - time_since_last_attempt
|
||||
error_msg = (
|
||||
f"Token refresh on cooldown. "
|
||||
f"Retry in {remaining:.1f}s (KIS allows 1/minute)"
|
||||
)
|
||||
logger.warning(error_msg)
|
||||
raise ConnectionError(error_msg)
|
||||
|
||||
logger.info("Refreshing KIS access token")
|
||||
self._last_refresh_attempt = now
|
||||
session = self._get_session()
|
||||
url = f"{self._base_url}/oauth2/tokenP"
|
||||
body = {
|
||||
|
||||
47
src/main.py
47
src/main.py
@@ -33,6 +33,35 @@ from src.notifications.telegram_client import TelegramClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def safe_float(value: str | float | None, default: float = 0.0) -> float:
|
||||
"""Convert to float, handling empty strings and None.
|
||||
|
||||
Args:
|
||||
value: Value to convert (string, float, or None)
|
||||
default: Default value if conversion fails
|
||||
|
||||
Returns:
|
||||
Converted float or default value
|
||||
|
||||
Examples:
|
||||
>>> safe_float("123.45")
|
||||
123.45
|
||||
>>> safe_float("")
|
||||
0.0
|
||||
>>> safe_float(None)
|
||||
0.0
|
||||
>>> safe_float("invalid", 99.0)
|
||||
99.0
|
||||
"""
|
||||
if value is None or value == "":
|
||||
return default
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
return default
|
||||
|
||||
|
||||
# Target stock codes to monitor per market
|
||||
WATCHLISTS = {
|
||||
"KR": ["005930", "000660", "035420"], # Samsung, SK Hynix, NAVER
|
||||
@@ -77,16 +106,16 @@ async def trading_cycle(
|
||||
balance_data = await broker.get_balance()
|
||||
|
||||
output2 = balance_data.get("output2", [{}])
|
||||
total_eval = float(output2[0].get("tot_evlu_amt", "0")) if output2 else 0
|
||||
total_cash = float(
|
||||
total_eval = safe_float(output2[0].get("tot_evlu_amt", "0")) if output2 else 0
|
||||
total_cash = safe_float(
|
||||
balance_data.get("output2", [{}])[0].get("dnca_tot_amt", "0")
|
||||
if output2
|
||||
else "0"
|
||||
)
|
||||
purchase_total = float(output2[0].get("pchs_amt_smtl_amt", "0")) if output2 else 0
|
||||
purchase_total = safe_float(output2[0].get("pchs_amt_smtl_amt", "0")) if output2 else 0
|
||||
|
||||
current_price = float(orderbook.get("output1", {}).get("stck_prpr", "0"))
|
||||
foreigner_net = float(orderbook.get("output1", {}).get("frgn_ntby_qty", "0"))
|
||||
current_price = safe_float(orderbook.get("output1", {}).get("stck_prpr", "0"))
|
||||
foreigner_net = safe_float(orderbook.get("output1", {}).get("frgn_ntby_qty", "0"))
|
||||
else:
|
||||
# Overseas market
|
||||
price_data = await overseas_broker.get_overseas_price(
|
||||
@@ -103,11 +132,11 @@ async def trading_cycle(
|
||||
else:
|
||||
balance_info = {}
|
||||
|
||||
total_eval = float(balance_info.get("frcr_evlu_tota", "0") or "0")
|
||||
total_cash = float(balance_info.get("frcr_dncl_amt_2", "0") or "0")
|
||||
purchase_total = float(balance_info.get("frcr_buy_amt_smtl", "0") or "0")
|
||||
total_eval = safe_float(balance_info.get("frcr_evlu_tota", "0") or "0")
|
||||
total_cash = safe_float(balance_info.get("frcr_dncl_amt_2", "0") or "0")
|
||||
purchase_total = safe_float(balance_info.get("frcr_buy_amt_smtl", "0") or "0")
|
||||
|
||||
current_price = float(price_data.get("output", {}).get("last", "0"))
|
||||
current_price = safe_float(price_data.get("output", {}).get("last", "0"))
|
||||
foreigner_net = 0.0 # Not available for overseas
|
||||
|
||||
# Calculate daily P&L %
|
||||
|
||||
@@ -89,6 +89,70 @@ class TestTokenManagement:
|
||||
|
||||
await broker.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_refresh_cooldown_prevents_rapid_retries(self, settings):
|
||||
"""Token refresh should enforce cooldown after failure (issue #54)."""
|
||||
broker = KISBroker(settings)
|
||||
broker._refresh_cooldown = 2.0 # Short cooldown for testing
|
||||
|
||||
# First refresh attempt fails with 403 (EGW00133)
|
||||
mock_resp_403 = AsyncMock()
|
||||
mock_resp_403.status = 403
|
||||
mock_resp_403.text = AsyncMock(
|
||||
return_value='{"error_code":"EGW00133","error_description":"접근토큰 발급 잠시 후 다시 시도하세요(1분당 1회)"}'
|
||||
)
|
||||
mock_resp_403.__aenter__ = AsyncMock(return_value=mock_resp_403)
|
||||
mock_resp_403.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp_403):
|
||||
# First attempt should fail with 403
|
||||
with pytest.raises(ConnectionError, match="Token refresh failed"):
|
||||
await broker._ensure_token()
|
||||
|
||||
# Second attempt within cooldown should fail with cooldown error
|
||||
with pytest.raises(ConnectionError, match="Token refresh on cooldown"):
|
||||
await broker._ensure_token()
|
||||
|
||||
await broker.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_refresh_allowed_after_cooldown(self, settings):
|
||||
"""Token refresh should be allowed after cooldown period expires."""
|
||||
broker = KISBroker(settings)
|
||||
broker._refresh_cooldown = 0.1 # Very short cooldown for testing
|
||||
|
||||
# First attempt fails
|
||||
mock_resp_403 = AsyncMock()
|
||||
mock_resp_403.status = 403
|
||||
mock_resp_403.text = AsyncMock(return_value='{"error_code":"EGW00133"}')
|
||||
mock_resp_403.__aenter__ = AsyncMock(return_value=mock_resp_403)
|
||||
mock_resp_403.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
# Second attempt succeeds
|
||||
mock_resp_200 = AsyncMock()
|
||||
mock_resp_200.status = 200
|
||||
mock_resp_200.json = AsyncMock(
|
||||
return_value={
|
||||
"access_token": "tok_after_cooldown",
|
||||
"expires_in": 86400,
|
||||
}
|
||||
)
|
||||
mock_resp_200.__aenter__ = AsyncMock(return_value=mock_resp_200)
|
||||
mock_resp_200.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp_403):
|
||||
with pytest.raises(ConnectionError, match="Token refresh failed"):
|
||||
await broker._ensure_token()
|
||||
|
||||
# Wait for cooldown to expire
|
||||
await asyncio.sleep(0.15)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp_200):
|
||||
token = await broker._ensure_token()
|
||||
assert token == "tok_after_cooldown"
|
||||
|
||||
await broker.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Network Error Handling
|
||||
|
||||
@@ -6,7 +6,43 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from src.core.risk_manager import CircuitBreakerTripped, FatFingerRejected
|
||||
from src.main import trading_cycle
|
||||
from src.main import safe_float, trading_cycle
|
||||
|
||||
|
||||
class TestSafeFloat:
|
||||
"""Test safe_float() helper function."""
|
||||
|
||||
def test_converts_valid_string(self):
|
||||
"""Test conversion of valid numeric string."""
|
||||
assert safe_float("123.45") == 123.45
|
||||
assert safe_float("0") == 0.0
|
||||
assert safe_float("-99.9") == -99.9
|
||||
|
||||
def test_handles_empty_string(self):
|
||||
"""Test empty string returns default."""
|
||||
assert safe_float("") == 0.0
|
||||
assert safe_float("", 99.0) == 99.0
|
||||
|
||||
def test_handles_none(self):
|
||||
"""Test None returns default."""
|
||||
assert safe_float(None) == 0.0
|
||||
assert safe_float(None, 42.0) == 42.0
|
||||
|
||||
def test_handles_invalid_string(self):
|
||||
"""Test invalid string returns default."""
|
||||
assert safe_float("invalid") == 0.0
|
||||
assert safe_float("not_a_number", 100.0) == 100.0
|
||||
assert safe_float("12.34.56") == 0.0
|
||||
|
||||
def test_handles_float_input(self):
|
||||
"""Test float input passes through."""
|
||||
assert safe_float(123.45) == 123.45
|
||||
assert safe_float(0.0) == 0.0
|
||||
|
||||
def test_custom_default(self):
|
||||
"""Test custom default value."""
|
||||
assert safe_float("", -1.0) == -1.0
|
||||
assert safe_float(None, 999.0) == 999.0
|
||||
|
||||
|
||||
class TestTradingCycleTelegramIntegration:
|
||||
|
||||
Reference in New Issue
Block a user