Compare commits
3 Commits
feature/is
...
5e4c68c9d8
| Author | SHA1 | Date | |
|---|---|---|---|
| 5e4c68c9d8 | |||
|
|
95f540e5df | ||
| 0087a6b20a |
@@ -55,6 +55,7 @@ class KISBroker:
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._access_token: str | None = None
|
||||
self._token_expires_at: float = 0.0
|
||||
self._token_lock = asyncio.Lock()
|
||||
self._rate_limiter = LeakyBucket(settings.RATE_LIMIT_RPS)
|
||||
|
||||
def _get_session(self) -> aiohttp.ClientSession:
|
||||
@@ -80,7 +81,19 @@ class KISBroker:
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _ensure_token(self) -> str:
|
||||
"""Return a valid access token, refreshing if expired."""
|
||||
"""Return a valid access token, refreshing if expired.
|
||||
|
||||
Uses a lock to prevent concurrent token refresh attempts that would
|
||||
hit the API's 1-per-minute rate limit (EGW00133).
|
||||
"""
|
||||
# Fast path: check without lock
|
||||
now = asyncio.get_event_loop().time()
|
||||
if self._access_token and now < self._token_expires_at:
|
||||
return self._access_token
|
||||
|
||||
# Slow path: acquire lock and refresh
|
||||
async with self._token_lock:
|
||||
# Re-check after acquiring lock (another coroutine may have refreshed)
|
||||
now = asyncio.get_event_loop().time()
|
||||
if self._access_token and now < self._token_expires_at:
|
||||
return self._access_token
|
||||
|
||||
@@ -49,6 +49,46 @@ class TestTokenManagement:
|
||||
|
||||
await broker.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_token_refresh_calls_api_once(self, settings):
|
||||
"""Multiple concurrent token requests should only call API once."""
|
||||
broker = KISBroker(settings)
|
||||
|
||||
# Track how many times the mock API is called
|
||||
call_count = [0]
|
||||
|
||||
def create_mock_resp():
|
||||
call_count[0] += 1
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(
|
||||
return_value={
|
||||
"access_token": "tok_concurrent",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 86400,
|
||||
}
|
||||
)
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
return mock_resp
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=create_mock_resp()):
|
||||
# Launch 5 concurrent token requests
|
||||
tokens = await asyncio.gather(
|
||||
broker._ensure_token(),
|
||||
broker._ensure_token(),
|
||||
broker._ensure_token(),
|
||||
broker._ensure_token(),
|
||||
broker._ensure_token(),
|
||||
)
|
||||
|
||||
# All should get the same token
|
||||
assert all(t == "tok_concurrent" for t in tokens)
|
||||
# API should be called only once (due to lock)
|
||||
assert call_count[0] == 1
|
||||
|
||||
await broker.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Network Error Handling
|
||||
|
||||
Reference in New Issue
Block a user