Compare commits
4 Commits
feature/is
...
feature/is
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c57ccc4bca | ||
| 5e4c68c9d8 | |||
|
|
95f540e5df | ||
| 0087a6b20a |
@@ -55,6 +55,7 @@ class KISBroker:
|
|||||||
self._session: aiohttp.ClientSession | None = None
|
self._session: aiohttp.ClientSession | None = None
|
||||||
self._access_token: str | None = None
|
self._access_token: str | None = None
|
||||||
self._token_expires_at: float = 0.0
|
self._token_expires_at: float = 0.0
|
||||||
|
self._token_lock = asyncio.Lock()
|
||||||
self._rate_limiter = LeakyBucket(settings.RATE_LIMIT_RPS)
|
self._rate_limiter = LeakyBucket(settings.RATE_LIMIT_RPS)
|
||||||
|
|
||||||
def _get_session(self) -> aiohttp.ClientSession:
|
def _get_session(self) -> aiohttp.ClientSession:
|
||||||
@@ -80,7 +81,19 @@ class KISBroker:
|
|||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
async def _ensure_token(self) -> str:
|
async def _ensure_token(self) -> str:
|
||||||
"""Return a valid access token, refreshing if expired."""
|
"""Return a valid access token, refreshing if expired.
|
||||||
|
|
||||||
|
Uses a lock to prevent concurrent token refresh attempts that would
|
||||||
|
hit the API's 1-per-minute rate limit (EGW00133).
|
||||||
|
"""
|
||||||
|
# Fast path: check without lock
|
||||||
|
now = asyncio.get_event_loop().time()
|
||||||
|
if self._access_token and now < self._token_expires_at:
|
||||||
|
return self._access_token
|
||||||
|
|
||||||
|
# Slow path: acquire lock and refresh
|
||||||
|
async with self._token_lock:
|
||||||
|
# Re-check after acquiring lock (another coroutine may have refreshed)
|
||||||
now = asyncio.get_event_loop().time()
|
now = asyncio.get_event_loop().time()
|
||||||
if self._access_token and now < self._token_expires_at:
|
if self._access_token and now < self._token_expires_at:
|
||||||
return self._access_token
|
return self._access_token
|
||||||
|
|||||||
47
src/main.py
47
src/main.py
@@ -33,6 +33,35 @@ from src.notifications.telegram_client import TelegramClient
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def safe_float(value: str | float | None, default: float = 0.0) -> float:
|
||||||
|
"""Convert to float, handling empty strings and None.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Value to convert (string, float, or None)
|
||||||
|
default: Default value if conversion fails
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Converted float or default value
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> safe_float("123.45")
|
||||||
|
123.45
|
||||||
|
>>> safe_float("")
|
||||||
|
0.0
|
||||||
|
>>> safe_float(None)
|
||||||
|
0.0
|
||||||
|
>>> safe_float("invalid", 99.0)
|
||||||
|
99.0
|
||||||
|
"""
|
||||||
|
if value is None or value == "":
|
||||||
|
return default
|
||||||
|
try:
|
||||||
|
return float(value)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
# Target stock codes to monitor per market
|
# Target stock codes to monitor per market
|
||||||
WATCHLISTS = {
|
WATCHLISTS = {
|
||||||
"KR": ["005930", "000660", "035420"], # Samsung, SK Hynix, NAVER
|
"KR": ["005930", "000660", "035420"], # Samsung, SK Hynix, NAVER
|
||||||
@@ -77,16 +106,16 @@ async def trading_cycle(
|
|||||||
balance_data = await broker.get_balance()
|
balance_data = await broker.get_balance()
|
||||||
|
|
||||||
output2 = balance_data.get("output2", [{}])
|
output2 = balance_data.get("output2", [{}])
|
||||||
total_eval = float(output2[0].get("tot_evlu_amt", "0")) if output2 else 0
|
total_eval = safe_float(output2[0].get("tot_evlu_amt", "0")) if output2 else 0
|
||||||
total_cash = float(
|
total_cash = safe_float(
|
||||||
balance_data.get("output2", [{}])[0].get("dnca_tot_amt", "0")
|
balance_data.get("output2", [{}])[0].get("dnca_tot_amt", "0")
|
||||||
if output2
|
if output2
|
||||||
else "0"
|
else "0"
|
||||||
)
|
)
|
||||||
purchase_total = float(output2[0].get("pchs_amt_smtl_amt", "0")) if output2 else 0
|
purchase_total = safe_float(output2[0].get("pchs_amt_smtl_amt", "0")) if output2 else 0
|
||||||
|
|
||||||
current_price = float(orderbook.get("output1", {}).get("stck_prpr", "0"))
|
current_price = safe_float(orderbook.get("output1", {}).get("stck_prpr", "0"))
|
||||||
foreigner_net = float(orderbook.get("output1", {}).get("frgn_ntby_qty", "0"))
|
foreigner_net = safe_float(orderbook.get("output1", {}).get("frgn_ntby_qty", "0"))
|
||||||
else:
|
else:
|
||||||
# Overseas market
|
# Overseas market
|
||||||
price_data = await overseas_broker.get_overseas_price(
|
price_data = await overseas_broker.get_overseas_price(
|
||||||
@@ -103,11 +132,11 @@ async def trading_cycle(
|
|||||||
else:
|
else:
|
||||||
balance_info = {}
|
balance_info = {}
|
||||||
|
|
||||||
total_eval = float(balance_info.get("frcr_evlu_tota", "0") or "0")
|
total_eval = safe_float(balance_info.get("frcr_evlu_tota", "0") or "0")
|
||||||
total_cash = float(balance_info.get("frcr_dncl_amt_2", "0") or "0")
|
total_cash = safe_float(balance_info.get("frcr_dncl_amt_2", "0") or "0")
|
||||||
purchase_total = float(balance_info.get("frcr_buy_amt_smtl", "0") or "0")
|
purchase_total = safe_float(balance_info.get("frcr_buy_amt_smtl", "0") or "0")
|
||||||
|
|
||||||
current_price = float(price_data.get("output", {}).get("last", "0"))
|
current_price = safe_float(price_data.get("output", {}).get("last", "0"))
|
||||||
foreigner_net = 0.0 # Not available for overseas
|
foreigner_net = 0.0 # Not available for overseas
|
||||||
|
|
||||||
# Calculate daily P&L %
|
# Calculate daily P&L %
|
||||||
|
|||||||
@@ -49,6 +49,46 @@ class TestTokenManagement:
|
|||||||
|
|
||||||
await broker.close()
|
await broker.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_token_refresh_calls_api_once(self, settings):
|
||||||
|
"""Multiple concurrent token requests should only call API once."""
|
||||||
|
broker = KISBroker(settings)
|
||||||
|
|
||||||
|
# Track how many times the mock API is called
|
||||||
|
call_count = [0]
|
||||||
|
|
||||||
|
def create_mock_resp():
|
||||||
|
call_count[0] += 1
|
||||||
|
mock_resp = AsyncMock()
|
||||||
|
mock_resp.status = 200
|
||||||
|
mock_resp.json = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"access_token": "tok_concurrent",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": 86400,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||||
|
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
return mock_resp
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession.post", return_value=create_mock_resp()):
|
||||||
|
# Launch 5 concurrent token requests
|
||||||
|
tokens = await asyncio.gather(
|
||||||
|
broker._ensure_token(),
|
||||||
|
broker._ensure_token(),
|
||||||
|
broker._ensure_token(),
|
||||||
|
broker._ensure_token(),
|
||||||
|
broker._ensure_token(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# All should get the same token
|
||||||
|
assert all(t == "tok_concurrent" for t in tokens)
|
||||||
|
# API should be called only once (due to lock)
|
||||||
|
assert call_count[0] == 1
|
||||||
|
|
||||||
|
await broker.close()
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Network Error Handling
|
# Network Error Handling
|
||||||
|
|||||||
@@ -6,7 +6,43 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from src.core.risk_manager import CircuitBreakerTripped, FatFingerRejected
|
from src.core.risk_manager import CircuitBreakerTripped, FatFingerRejected
|
||||||
from src.main import trading_cycle
|
from src.main import safe_float, trading_cycle
|
||||||
|
|
||||||
|
|
||||||
|
class TestSafeFloat:
|
||||||
|
"""Test safe_float() helper function."""
|
||||||
|
|
||||||
|
def test_converts_valid_string(self):
|
||||||
|
"""Test conversion of valid numeric string."""
|
||||||
|
assert safe_float("123.45") == 123.45
|
||||||
|
assert safe_float("0") == 0.0
|
||||||
|
assert safe_float("-99.9") == -99.9
|
||||||
|
|
||||||
|
def test_handles_empty_string(self):
|
||||||
|
"""Test empty string returns default."""
|
||||||
|
assert safe_float("") == 0.0
|
||||||
|
assert safe_float("", 99.0) == 99.0
|
||||||
|
|
||||||
|
def test_handles_none(self):
|
||||||
|
"""Test None returns default."""
|
||||||
|
assert safe_float(None) == 0.0
|
||||||
|
assert safe_float(None, 42.0) == 42.0
|
||||||
|
|
||||||
|
def test_handles_invalid_string(self):
|
||||||
|
"""Test invalid string returns default."""
|
||||||
|
assert safe_float("invalid") == 0.0
|
||||||
|
assert safe_float("not_a_number", 100.0) == 100.0
|
||||||
|
assert safe_float("12.34.56") == 0.0
|
||||||
|
|
||||||
|
def test_handles_float_input(self):
|
||||||
|
"""Test float input passes through."""
|
||||||
|
assert safe_float(123.45) == 123.45
|
||||||
|
assert safe_float(0.0) == 0.0
|
||||||
|
|
||||||
|
def test_custom_default(self):
|
||||||
|
"""Test custom default value."""
|
||||||
|
assert safe_float("", -1.0) == -1.0
|
||||||
|
assert safe_float(None, 999.0) == 999.0
|
||||||
|
|
||||||
|
|
||||||
class TestTradingCycleTelegramIntegration:
|
class TestTradingCycleTelegramIntegration:
|
||||||
|
|||||||
Reference in New Issue
Block a user