Compare commits
2 Commits
feature/is
...
feature/is
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
95f540e5df | ||
| 0087a6b20a |
@@ -55,6 +55,7 @@ class KISBroker:
|
|||||||
self._session: aiohttp.ClientSession | None = None
|
self._session: aiohttp.ClientSession | None = None
|
||||||
self._access_token: str | None = None
|
self._access_token: str | None = None
|
||||||
self._token_expires_at: float = 0.0
|
self._token_expires_at: float = 0.0
|
||||||
|
self._token_lock = asyncio.Lock()
|
||||||
self._rate_limiter = LeakyBucket(settings.RATE_LIMIT_RPS)
|
self._rate_limiter = LeakyBucket(settings.RATE_LIMIT_RPS)
|
||||||
|
|
||||||
def _get_session(self) -> aiohttp.ClientSession:
|
def _get_session(self) -> aiohttp.ClientSession:
|
||||||
@@ -80,30 +81,42 @@ class KISBroker:
|
|||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
async def _ensure_token(self) -> str:
|
async def _ensure_token(self) -> str:
|
||||||
"""Return a valid access token, refreshing if expired."""
|
"""Return a valid access token, refreshing if expired.
|
||||||
|
|
||||||
|
Uses a lock to prevent concurrent token refresh attempts that would
|
||||||
|
hit the API's 1-per-minute rate limit (EGW00133).
|
||||||
|
"""
|
||||||
|
# Fast path: check without lock
|
||||||
now = asyncio.get_event_loop().time()
|
now = asyncio.get_event_loop().time()
|
||||||
if self._access_token and now < self._token_expires_at:
|
if self._access_token and now < self._token_expires_at:
|
||||||
return self._access_token
|
return self._access_token
|
||||||
|
|
||||||
logger.info("Refreshing KIS access token")
|
# Slow path: acquire lock and refresh
|
||||||
session = self._get_session()
|
async with self._token_lock:
|
||||||
url = f"{self._base_url}/oauth2/tokenP"
|
# Re-check after acquiring lock (another coroutine may have refreshed)
|
||||||
body = {
|
now = asyncio.get_event_loop().time()
|
||||||
"grant_type": "client_credentials",
|
if self._access_token and now < self._token_expires_at:
|
||||||
"appkey": self._app_key,
|
return self._access_token
|
||||||
"appsecret": self._app_secret,
|
|
||||||
}
|
|
||||||
|
|
||||||
async with session.post(url, json=body) as resp:
|
logger.info("Refreshing KIS access token")
|
||||||
if resp.status != 200:
|
session = self._get_session()
|
||||||
text = await resp.text()
|
url = f"{self._base_url}/oauth2/tokenP"
|
||||||
raise ConnectionError(f"Token refresh failed ({resp.status}): {text}")
|
body = {
|
||||||
data = await resp.json()
|
"grant_type": "client_credentials",
|
||||||
|
"appkey": self._app_key,
|
||||||
|
"appsecret": self._app_secret,
|
||||||
|
}
|
||||||
|
|
||||||
self._access_token = data["access_token"]
|
async with session.post(url, json=body) as resp:
|
||||||
self._token_expires_at = now + data.get("expires_in", 86400) - 60 # 1-min buffer
|
if resp.status != 200:
|
||||||
logger.info("Token refreshed successfully")
|
text = await resp.text()
|
||||||
return self._access_token
|
raise ConnectionError(f"Token refresh failed ({resp.status}): {text}")
|
||||||
|
data = await resp.json()
|
||||||
|
|
||||||
|
self._access_token = data["access_token"]
|
||||||
|
self._token_expires_at = now + data.get("expires_in", 86400) - 60 # 1-min buffer
|
||||||
|
logger.info("Token refreshed successfully")
|
||||||
|
return self._access_token
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Hash Key (required for POST bodies)
|
# Hash Key (required for POST bodies)
|
||||||
|
|||||||
@@ -49,6 +49,46 @@ class TestTokenManagement:
|
|||||||
|
|
||||||
await broker.close()
|
await broker.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_token_refresh_calls_api_once(self, settings):
|
||||||
|
"""Multiple concurrent token requests should only call API once."""
|
||||||
|
broker = KISBroker(settings)
|
||||||
|
|
||||||
|
# Track how many times the mock API is called
|
||||||
|
call_count = [0]
|
||||||
|
|
||||||
|
def create_mock_resp():
|
||||||
|
call_count[0] += 1
|
||||||
|
mock_resp = AsyncMock()
|
||||||
|
mock_resp.status = 200
|
||||||
|
mock_resp.json = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"access_token": "tok_concurrent",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": 86400,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||||
|
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
return mock_resp
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession.post", return_value=create_mock_resp()):
|
||||||
|
# Launch 5 concurrent token requests
|
||||||
|
tokens = await asyncio.gather(
|
||||||
|
broker._ensure_token(),
|
||||||
|
broker._ensure_token(),
|
||||||
|
broker._ensure_token(),
|
||||||
|
broker._ensure_token(),
|
||||||
|
broker._ensure_token(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# All should get the same token
|
||||||
|
assert all(t == "tok_concurrent" for t in tokens)
|
||||||
|
# API should be called only once (due to lock)
|
||||||
|
assert call_count[0] == 1
|
||||||
|
|
||||||
|
await broker.close()
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Network Error Handling
|
# Network Error Handling
|
||||||
|
|||||||
Reference in New Issue
Block a user