Compare commits
10 Commits
feature/is
...
feature/is
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
db0d966a6a | ||
| 33b5ff5e54 | |||
| 3923d03650 | |||
|
|
c57ccc4bca | ||
|
|
cb2e3fae57 | ||
| 5e4c68c9d8 | |||
|
|
95f540e5df | ||
| 0087a6b20a | |||
|
|
3dfd7c0935 | ||
| 4b2bb25d03 |
@@ -16,8 +16,9 @@ CONFIDENCE_THRESHOLD=80
|
||||
# Database
|
||||
DB_PATH=data/trade_logs.db
|
||||
|
||||
# Rate Limiting
|
||||
RATE_LIMIT_RPS=10.0
|
||||
# Rate Limiting (requests per second for KIS API)
|
||||
# Reduced to 5.0 to avoid "초당 거래건수 초과" errors (EGW00201)
|
||||
RATE_LIMIT_RPS=5.0
|
||||
|
||||
# Trading Mode (paper / live)
|
||||
MODE=paper
|
||||
|
||||
@@ -55,6 +55,7 @@ class KISBroker:
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._access_token: str | None = None
|
||||
self._token_expires_at: float = 0.0
|
||||
self._token_lock = asyncio.Lock()
|
||||
self._rate_limiter = LeakyBucket(settings.RATE_LIMIT_RPS)
|
||||
|
||||
def _get_session(self) -> aiohttp.ClientSession:
|
||||
@@ -80,30 +81,42 @@ class KISBroker:
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _ensure_token(self) -> str:
|
||||
"""Return a valid access token, refreshing if expired."""
|
||||
"""Return a valid access token, refreshing if expired.
|
||||
|
||||
Uses a lock to prevent concurrent token refresh attempts that would
|
||||
hit the API's 1-per-minute rate limit (EGW00133).
|
||||
"""
|
||||
# Fast path: check without lock
|
||||
now = asyncio.get_event_loop().time()
|
||||
if self._access_token and now < self._token_expires_at:
|
||||
return self._access_token
|
||||
|
||||
logger.info("Refreshing KIS access token")
|
||||
session = self._get_session()
|
||||
url = f"{self._base_url}/oauth2/tokenP"
|
||||
body = {
|
||||
"grant_type": "client_credentials",
|
||||
"appkey": self._app_key,
|
||||
"appsecret": self._app_secret,
|
||||
}
|
||||
# Slow path: acquire lock and refresh
|
||||
async with self._token_lock:
|
||||
# Re-check after acquiring lock (another coroutine may have refreshed)
|
||||
now = asyncio.get_event_loop().time()
|
||||
if self._access_token and now < self._token_expires_at:
|
||||
return self._access_token
|
||||
|
||||
async with session.post(url, json=body) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ConnectionError(f"Token refresh failed ({resp.status}): {text}")
|
||||
data = await resp.json()
|
||||
logger.info("Refreshing KIS access token")
|
||||
session = self._get_session()
|
||||
url = f"{self._base_url}/oauth2/tokenP"
|
||||
body = {
|
||||
"grant_type": "client_credentials",
|
||||
"appkey": self._app_key,
|
||||
"appsecret": self._app_secret,
|
||||
}
|
||||
|
||||
self._access_token = data["access_token"]
|
||||
self._token_expires_at = now + data.get("expires_in", 86400) - 60 # 1-min buffer
|
||||
logger.info("Token refreshed successfully")
|
||||
return self._access_token
|
||||
async with session.post(url, json=body) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ConnectionError(f"Token refresh failed ({resp.status}): {text}")
|
||||
data = await resp.json()
|
||||
|
||||
self._access_token = data["access_token"]
|
||||
self._token_expires_at = now + data.get("expires_in", 86400) - 60 # 1-min buffer
|
||||
logger.info("Token refreshed successfully")
|
||||
return self._access_token
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Hash Key (required for POST bodies)
|
||||
|
||||
@@ -37,7 +37,8 @@ class Settings(BaseSettings):
|
||||
DB_PATH: str = "data/trade_logs.db"
|
||||
|
||||
# Rate Limiting (requests per second for KIS API)
|
||||
RATE_LIMIT_RPS: float = 10.0
|
||||
# Reduced to 5.0 to avoid EGW00201 "초당 거래건수 초과" errors
|
||||
RATE_LIMIT_RPS: float = 5.0
|
||||
|
||||
# Trading mode
|
||||
MODE: str = Field(default="paper", pattern="^(paper|live)$")
|
||||
|
||||
57
src/main.py
57
src/main.py
@@ -33,6 +33,35 @@ from src.notifications.telegram_client import TelegramClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def safe_float(value: str | float | None, default: float = 0.0) -> float:
|
||||
"""Convert to float, handling empty strings and None.
|
||||
|
||||
Args:
|
||||
value: Value to convert (string, float, or None)
|
||||
default: Default value if conversion fails
|
||||
|
||||
Returns:
|
||||
Converted float or default value
|
||||
|
||||
Examples:
|
||||
>>> safe_float("123.45")
|
||||
123.45
|
||||
>>> safe_float("")
|
||||
0.0
|
||||
>>> safe_float(None)
|
||||
0.0
|
||||
>>> safe_float("invalid", 99.0)
|
||||
99.0
|
||||
"""
|
||||
if value is None or value == "":
|
||||
return default
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
return default
|
||||
|
||||
|
||||
# Target stock codes to monitor per market
|
||||
WATCHLISTS = {
|
||||
"KR": ["005930", "000660", "035420"], # Samsung, SK Hynix, NAVER
|
||||
@@ -77,16 +106,16 @@ async def trading_cycle(
|
||||
balance_data = await broker.get_balance()
|
||||
|
||||
output2 = balance_data.get("output2", [{}])
|
||||
total_eval = float(output2[0].get("tot_evlu_amt", "0")) if output2 else 0
|
||||
total_cash = float(
|
||||
total_eval = safe_float(output2[0].get("tot_evlu_amt", "0")) if output2 else 0
|
||||
total_cash = safe_float(
|
||||
balance_data.get("output2", [{}])[0].get("dnca_tot_amt", "0")
|
||||
if output2
|
||||
else "0"
|
||||
)
|
||||
purchase_total = float(output2[0].get("pchs_amt_smtl_amt", "0")) if output2 else 0
|
||||
purchase_total = safe_float(output2[0].get("pchs_amt_smtl_amt", "0")) if output2 else 0
|
||||
|
||||
current_price = float(orderbook.get("output1", {}).get("stck_prpr", "0"))
|
||||
foreigner_net = float(orderbook.get("output1", {}).get("frgn_ntby_qty", "0"))
|
||||
current_price = safe_float(orderbook.get("output1", {}).get("stck_prpr", "0"))
|
||||
foreigner_net = safe_float(orderbook.get("output1", {}).get("frgn_ntby_qty", "0"))
|
||||
else:
|
||||
# Overseas market
|
||||
price_data = await overseas_broker.get_overseas_price(
|
||||
@@ -95,11 +124,19 @@ async def trading_cycle(
|
||||
balance_data = await overseas_broker.get_overseas_balance(market.exchange_code)
|
||||
|
||||
output2 = balance_data.get("output2", [{}])
|
||||
total_eval = float(output2[0].get("frcr_evlu_tota", "0")) if output2 else 0
|
||||
total_cash = float(output2[0].get("frcr_dncl_amt_2", "0")) if output2 else 0
|
||||
purchase_total = float(output2[0].get("frcr_buy_amt_smtl", "0")) if output2 else 0
|
||||
# Handle both list and dict response formats
|
||||
if isinstance(output2, list) and output2:
|
||||
balance_info = output2[0]
|
||||
elif isinstance(output2, dict):
|
||||
balance_info = output2
|
||||
else:
|
||||
balance_info = {}
|
||||
|
||||
current_price = float(price_data.get("output", {}).get("last", "0"))
|
||||
total_eval = safe_float(balance_info.get("frcr_evlu_tota", "0") or "0")
|
||||
total_cash = safe_float(balance_info.get("frcr_dncl_amt_2", "0") or "0")
|
||||
purchase_total = safe_float(balance_info.get("frcr_buy_amt_smtl", "0") or "0")
|
||||
|
||||
current_price = safe_float(price_data.get("output", {}).get("last", "0"))
|
||||
foreigner_net = 0.0 # Not available for overseas
|
||||
|
||||
# Calculate daily P&L %
|
||||
@@ -512,7 +549,9 @@ async def run(settings: Settings) -> None:
|
||||
except TimeoutError:
|
||||
pass # Normal — timeout means it's time for next cycle
|
||||
finally:
|
||||
# Clean up resources
|
||||
await broker.close()
|
||||
await telegram.close()
|
||||
db_conn.close()
|
||||
logger.info("The Ouroboros rests.")
|
||||
|
||||
|
||||
@@ -49,6 +49,46 @@ class TestTokenManagement:
|
||||
|
||||
await broker.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_token_refresh_calls_api_once(self, settings):
|
||||
"""Multiple concurrent token requests should only call API once."""
|
||||
broker = KISBroker(settings)
|
||||
|
||||
# Track how many times the mock API is called
|
||||
call_count = [0]
|
||||
|
||||
def create_mock_resp():
|
||||
call_count[0] += 1
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(
|
||||
return_value={
|
||||
"access_token": "tok_concurrent",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 86400,
|
||||
}
|
||||
)
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
return mock_resp
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=create_mock_resp()):
|
||||
# Launch 5 concurrent token requests
|
||||
tokens = await asyncio.gather(
|
||||
broker._ensure_token(),
|
||||
broker._ensure_token(),
|
||||
broker._ensure_token(),
|
||||
broker._ensure_token(),
|
||||
broker._ensure_token(),
|
||||
)
|
||||
|
||||
# All should get the same token
|
||||
assert all(t == "tok_concurrent" for t in tokens)
|
||||
# API should be called only once (due to lock)
|
||||
assert call_count[0] == 1
|
||||
|
||||
await broker.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Network Error Handling
|
||||
|
||||
@@ -6,7 +6,43 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from src.core.risk_manager import CircuitBreakerTripped, FatFingerRejected
|
||||
from src.main import trading_cycle
|
||||
from src.main import safe_float, trading_cycle
|
||||
|
||||
|
||||
class TestSafeFloat:
|
||||
"""Test safe_float() helper function."""
|
||||
|
||||
def test_converts_valid_string(self):
|
||||
"""Test conversion of valid numeric string."""
|
||||
assert safe_float("123.45") == 123.45
|
||||
assert safe_float("0") == 0.0
|
||||
assert safe_float("-99.9") == -99.9
|
||||
|
||||
def test_handles_empty_string(self):
|
||||
"""Test empty string returns default."""
|
||||
assert safe_float("") == 0.0
|
||||
assert safe_float("", 99.0) == 99.0
|
||||
|
||||
def test_handles_none(self):
|
||||
"""Test None returns default."""
|
||||
assert safe_float(None) == 0.0
|
||||
assert safe_float(None, 42.0) == 42.0
|
||||
|
||||
def test_handles_invalid_string(self):
|
||||
"""Test invalid string returns default."""
|
||||
assert safe_float("invalid") == 0.0
|
||||
assert safe_float("not_a_number", 100.0) == 100.0
|
||||
assert safe_float("12.34.56") == 0.0
|
||||
|
||||
def test_handles_float_input(self):
|
||||
"""Test float input passes through."""
|
||||
assert safe_float(123.45) == 123.45
|
||||
assert safe_float(0.0) == 0.0
|
||||
|
||||
def test_custom_default(self):
|
||||
"""Test custom default value."""
|
||||
assert safe_float("", -1.0) == -1.0
|
||||
assert safe_float(None, 999.0) == 999.0
|
||||
|
||||
|
||||
class TestTradingCycleTelegramIntegration:
|
||||
@@ -341,3 +377,221 @@ class TestRunFunctionTelegramIntegration:
|
||||
pnl_pct=-3.5,
|
||||
threshold=-3.0,
|
||||
)
|
||||
|
||||
|
||||
class TestOverseasBalanceParsing:
|
||||
"""Test overseas balance output2 parsing handles different formats."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_overseas_broker_with_list(self) -> MagicMock:
|
||||
"""Create mock overseas broker returning list format."""
|
||||
broker = MagicMock()
|
||||
broker.get_overseas_price = AsyncMock(
|
||||
return_value={"output": {"last": "150.50"}}
|
||||
)
|
||||
broker.get_overseas_balance = AsyncMock(
|
||||
return_value={
|
||||
"output2": [
|
||||
{
|
||||
"frcr_evlu_tota": "10000.00",
|
||||
"frcr_dncl_amt_2": "5000.00",
|
||||
"frcr_buy_amt_smtl": "4500.00",
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
return broker
|
||||
|
||||
@pytest.fixture
|
||||
def mock_overseas_broker_with_dict(self) -> MagicMock:
|
||||
"""Create mock overseas broker returning dict format."""
|
||||
broker = MagicMock()
|
||||
broker.get_overseas_price = AsyncMock(
|
||||
return_value={"output": {"last": "150.50"}}
|
||||
)
|
||||
broker.get_overseas_balance = AsyncMock(
|
||||
return_value={
|
||||
"output2": {
|
||||
"frcr_evlu_tota": "10000.00",
|
||||
"frcr_dncl_amt_2": "5000.00",
|
||||
"frcr_buy_amt_smtl": "4500.00",
|
||||
}
|
||||
}
|
||||
)
|
||||
return broker
|
||||
|
||||
@pytest.fixture
|
||||
def mock_overseas_broker_with_empty(self) -> MagicMock:
|
||||
"""Create mock overseas broker returning empty output2."""
|
||||
broker = MagicMock()
|
||||
broker.get_overseas_price = AsyncMock(
|
||||
return_value={"output": {"last": "150.50"}}
|
||||
)
|
||||
broker.get_overseas_balance = AsyncMock(return_value={"output2": []})
|
||||
return broker
|
||||
|
||||
@pytest.fixture
|
||||
def mock_domestic_broker(self) -> MagicMock:
|
||||
"""Create minimal mock domestic broker."""
|
||||
broker = MagicMock()
|
||||
return broker
|
||||
|
||||
@pytest.fixture
|
||||
def mock_overseas_market(self) -> MagicMock:
|
||||
"""Create mock overseas market info."""
|
||||
market = MagicMock()
|
||||
market.name = "NASDAQ"
|
||||
market.code = "US_NASDAQ"
|
||||
market.exchange_code = "NASD"
|
||||
market.is_domestic = False
|
||||
return market
|
||||
|
||||
@pytest.fixture
|
||||
def mock_brain_hold(self) -> MagicMock:
|
||||
"""Create mock brain that always holds."""
|
||||
brain = MagicMock()
|
||||
decision = MagicMock()
|
||||
decision.action = "HOLD"
|
||||
decision.confidence = 50
|
||||
decision.rationale = "Testing balance parsing"
|
||||
brain.decide = AsyncMock(return_value=decision)
|
||||
return brain
|
||||
|
||||
@pytest.fixture
|
||||
def mock_risk(self) -> MagicMock:
|
||||
"""Create mock risk manager."""
|
||||
return MagicMock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db(self) -> MagicMock:
|
||||
"""Create mock database."""
|
||||
return MagicMock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_decision_logger(self) -> MagicMock:
|
||||
"""Create mock decision logger."""
|
||||
return MagicMock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context_store(self) -> MagicMock:
|
||||
"""Create mock context store."""
|
||||
store = MagicMock()
|
||||
store.get_latest_timeframe = MagicMock(return_value=None)
|
||||
return store
|
||||
|
||||
@pytest.fixture
|
||||
def mock_criticality_assessor(self) -> MagicMock:
|
||||
"""Create mock criticality assessor."""
|
||||
assessor = MagicMock()
|
||||
assessor.assess_market_conditions = MagicMock(
|
||||
return_value=MagicMock(value="NORMAL")
|
||||
)
|
||||
assessor.get_timeout = MagicMock(return_value=5.0)
|
||||
return assessor
|
||||
|
||||
@pytest.fixture
|
||||
def mock_telegram(self) -> MagicMock:
|
||||
"""Create mock telegram client."""
|
||||
return MagicMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_overseas_balance_list_format(
|
||||
self,
|
||||
mock_domestic_broker: MagicMock,
|
||||
mock_overseas_broker_with_list: MagicMock,
|
||||
mock_brain_hold: MagicMock,
|
||||
mock_risk: MagicMock,
|
||||
mock_db: MagicMock,
|
||||
mock_decision_logger: MagicMock,
|
||||
mock_context_store: MagicMock,
|
||||
mock_criticality_assessor: MagicMock,
|
||||
mock_telegram: MagicMock,
|
||||
mock_overseas_market: MagicMock,
|
||||
) -> None:
|
||||
"""Test overseas balance parsing with list format (output2=[{...}])."""
|
||||
with patch("src.main.log_trade"):
|
||||
# Should not raise KeyError
|
||||
await trading_cycle(
|
||||
broker=mock_domestic_broker,
|
||||
overseas_broker=mock_overseas_broker_with_list,
|
||||
brain=mock_brain_hold,
|
||||
risk=mock_risk,
|
||||
db_conn=mock_db,
|
||||
decision_logger=mock_decision_logger,
|
||||
context_store=mock_context_store,
|
||||
criticality_assessor=mock_criticality_assessor,
|
||||
telegram=mock_telegram,
|
||||
market=mock_overseas_market,
|
||||
stock_code="AAPL",
|
||||
)
|
||||
|
||||
# Verify balance API was called
|
||||
mock_overseas_broker_with_list.get_overseas_balance.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_overseas_balance_dict_format(
|
||||
self,
|
||||
mock_domestic_broker: MagicMock,
|
||||
mock_overseas_broker_with_dict: MagicMock,
|
||||
mock_brain_hold: MagicMock,
|
||||
mock_risk: MagicMock,
|
||||
mock_db: MagicMock,
|
||||
mock_decision_logger: MagicMock,
|
||||
mock_context_store: MagicMock,
|
||||
mock_criticality_assessor: MagicMock,
|
||||
mock_telegram: MagicMock,
|
||||
mock_overseas_market: MagicMock,
|
||||
) -> None:
|
||||
"""Test overseas balance parsing with dict format (output2={...})."""
|
||||
with patch("src.main.log_trade"):
|
||||
# Should not raise KeyError
|
||||
await trading_cycle(
|
||||
broker=mock_domestic_broker,
|
||||
overseas_broker=mock_overseas_broker_with_dict,
|
||||
brain=mock_brain_hold,
|
||||
risk=mock_risk,
|
||||
db_conn=mock_db,
|
||||
decision_logger=mock_decision_logger,
|
||||
context_store=mock_context_store,
|
||||
criticality_assessor=mock_criticality_assessor,
|
||||
telegram=mock_telegram,
|
||||
market=mock_overseas_market,
|
||||
stock_code="AAPL",
|
||||
)
|
||||
|
||||
# Verify balance API was called
|
||||
mock_overseas_broker_with_dict.get_overseas_balance.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_overseas_balance_empty_format(
|
||||
self,
|
||||
mock_domestic_broker: MagicMock,
|
||||
mock_overseas_broker_with_empty: MagicMock,
|
||||
mock_brain_hold: MagicMock,
|
||||
mock_risk: MagicMock,
|
||||
mock_db: MagicMock,
|
||||
mock_decision_logger: MagicMock,
|
||||
mock_context_store: MagicMock,
|
||||
mock_criticality_assessor: MagicMock,
|
||||
mock_telegram: MagicMock,
|
||||
mock_overseas_market: MagicMock,
|
||||
) -> None:
|
||||
"""Test overseas balance parsing with empty output2."""
|
||||
with patch("src.main.log_trade"):
|
||||
# Should not raise KeyError, should default to 0
|
||||
await trading_cycle(
|
||||
broker=mock_domestic_broker,
|
||||
overseas_broker=mock_overseas_broker_with_empty,
|
||||
brain=mock_brain_hold,
|
||||
risk=mock_risk,
|
||||
db_conn=mock_db,
|
||||
decision_logger=mock_decision_logger,
|
||||
context_store=mock_context_store,
|
||||
criticality_assessor=mock_criticality_assessor,
|
||||
telegram=mock_telegram,
|
||||
market=mock_overseas_market,
|
||||
stock_code="AAPL",
|
||||
)
|
||||
|
||||
# Verify balance API was called
|
||||
mock_overseas_broker_with_empty.get_overseas_balance.assert_called_once()
|
||||
|
||||
Reference in New Issue
Block a user