Compare commits

18 Commits
feature/is...feature/is

| Author | SHA1 | Date |
|---|---|---|
| | 0057de4d12 | |
| | 71ac59794e | |
| | be04820b00 | |
| | 10b6e34d44 | |
| | 58f1106dbd | |
| | cf5072cced | |
| | 702653e52e | |
| | db0d966a6a | |
| | a56adcd342 | |
| | eaf509a895 | |
| | 854931bed2 | |
| | 33b5ff5e54 | |
| | 3923d03650 | |
| | c57ccc4bca | |
| | cb2e3fae57 | |
| | 5e4c68c9d8 | |
| | 95f540e5df | |
| | 0087a6b20a | |
```diff
@@ -16,8 +16,9 @@ CONFIDENCE_THRESHOLD=80
 # Database
 DB_PATH=data/trade_logs.db
 
-# Rate Limiting
-RATE_LIMIT_RPS=10.0
+# Rate Limiting (requests per second for KIS API)
+# Reduced to 5.0 to avoid "초당 거래건수 초과" errors (EGW00201)
+RATE_LIMIT_RPS=5.0
 
 # Trading Mode (paper / live)
 MODE=paper
```
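The lowered `RATE_LIMIT_RPS` caps the request rate client-side. As an illustration only (the project's actual `LeakyBucket` implementation is not shown in this diff), a minimal limiter that enforces the same requests-per-second budget looks like this:

```python
import time


class MinIntervalLimiter:
    """Enforce a minimum gap between requests (illustrative stand-in
    for the project's LeakyBucket; rps=5.0 mirrors RATE_LIMIT_RPS)."""

    def __init__(self, rps: float) -> None:
        self.min_interval = 1.0 / rps  # seconds between requests
        self._last = 0.0

    def acquire(self) -> None:
        # Sleep just long enough to keep the observed rate under `rps`
        now = time.monotonic()
        wait = self._last + self.min_interval - now
        if wait > 0:
            time.sleep(wait)
        self._last = time.monotonic()


limiter = MinIntervalLimiter(rps=5.0)
start = time.monotonic()
for _ in range(3):
    limiter.acquire()
elapsed = time.monotonic() - start
# After the first call, acquires are spaced at least 0.2 s apart
```

The same budget applies regardless of how many coroutines share the broker, which is why the broker owns the limiter rather than each caller.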
CLAUDE.md (10 changes)
````diff
@@ -53,6 +53,7 @@ Get real-time alerts for trades, circuit breakers, and system events via Telegra
 - **[Context Tree](docs/context-tree.md)** — L1-L7 hierarchical memory system
 - **[Testing](docs/testing.md)** — Test structure, coverage requirements, writing tests
 - **[Agent Policies](docs/agents.md)** — Prime directives, constraints, prohibited actions
+- **[Requirements Log](docs/requirements-log.md)** — User requirements and feedback tracking
 
 ## Core Principles
 
@@ -61,6 +62,15 @@ Get real-time alerts for trades, circuit breakers, and system events via Telegra
 3. **Issue-Driven Development** — All work goes through Gitea issues → feature branches → PRs
 4. **Agent Specialization** — Use dedicated agents for design, coding, testing, docs, review
 
+## Requirements Management
+
+User requirements and feedback are tracked in [docs/requirements-log.md](docs/requirements-log.md):
+
+- New requirements are added chronologically with dates
+- Code changes should reference related requirements
+- Helps maintain project evolution aligned with user needs
+- Preserves context across conversations and development cycles
+
 ## Project Structure
 
 ```
````
````diff
@@ -2,7 +2,42 @@
 
 ## Overview
 
-Self-evolving AI trading agent for global stock markets via KIS (Korea Investment & Securities) API. The main loop in `src/main.py` orchestrates four components in a 60-second cycle per stock across multiple markets.
+Self-evolving AI trading agent for global stock markets via KIS (Korea Investment & Securities) API. The main loop in `src/main.py` orchestrates four components across multiple markets with two trading modes: daily (batch API calls) or realtime (per-stock decisions).
+
+## Trading Modes
+
+The system supports two trading frequency modes controlled by the `TRADE_MODE` environment variable:
+
+### Daily Mode (default)
+
+Optimized for Gemini Free tier API limits (20 calls/day):
+
+- **Batch decisions**: 1 API call per market per session
+- **Fixed schedule**: 4 sessions per day at 6-hour intervals (configurable)
+- **API efficiency**: Processes all stocks in a market simultaneously
+- **Use case**: Free tier users, cost-conscious deployments
+- **Configuration**:
+
+```bash
+TRADE_MODE=daily
+DAILY_SESSIONS=4          # Sessions per day (1-10)
+SESSION_INTERVAL_HOURS=6  # Hours between sessions (1-24)
+```
+
+**Example**: With 2 markets (US, KR) and 4 sessions/day = 8 API calls/day (within 20 call limit)
+
+### Realtime Mode
+
+High-frequency trading with individual stock analysis:
+
+- **Per-stock decisions**: 1 API call per stock per cycle
+- **60-second interval**: Continuous monitoring
+- **Use case**: Production deployments with Gemini paid tier
+- **Configuration**:
+
+```bash
+TRADE_MODE=realtime
+```
+
+**Note**: Realtime mode requires Gemini API subscription due to high call volume.
 
 ## Core Components
 
@@ -192,6 +227,11 @@ MAX_LOSS_PCT=3.0
 MAX_ORDER_PCT=30.0
 ENABLED_MARKETS=KR,US_NASDAQ # Comma-separated market codes
 
+# Trading Mode (API efficiency)
+TRADE_MODE=daily          # daily | realtime
+DAILY_SESSIONS=4          # Sessions per day (daily mode only)
+SESSION_INTERVAL_HOURS=6  # Hours between sessions (daily mode only)
+
 # Telegram Notifications (optional)
 TELEGRAM_BOT_TOKEN=1234567890:ABCdefGHIjklMNOpqrsTUVwxyz
 TELEGRAM_CHAT_ID=123456789
````
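The free-tier arithmetic behind daily mode (one batch call per market per session, kept under 20 calls/day) can be sanity-checked with a short sketch; the helper names are hypothetical, but the numbers come straight from the docs above:

```python
def daily_api_calls(num_markets: int, sessions_per_day: int) -> int:
    """Daily mode: one batch Gemini call per market per session."""
    return num_markets * sessions_per_day


def fits_free_tier(num_markets: int, sessions_per_day: int, limit: int = 20) -> bool:
    """True if the schedule stays within the Gemini Free tier daily quota."""
    return daily_api_calls(num_markets, sessions_per_day) <= limit


# The documented example: 2 markets (US, KR) x 4 sessions = 8 calls/day
assert daily_api_calls(2, 4) == 8
assert fits_free_tier(2, 4)
# A heavier schedule overshoots the quota
assert not fits_free_tier(5, 10)  # 50 calls/day > 20
```

Realtime mode, by contrast, spends one call per stock every 60-second cycle, which exhausts 20 calls within minutes for any non-trivial watchlist.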
docs/requirements-log.md (28 changes, new file)
```diff
@@ -0,0 +1,28 @@
+# Requirements Log
+
+A record of user requirements guiding the project's evolution.
+
+This document records, in chronological order, requirements and feedback from user conversations.
+Add new requirements together with their date.
+
+---
+
+## 2026-02-05
+
+### API Efficiency
+- The Gemini API is a precious resource; batch calls are needed instead of per-stock individual calls
+- Switch to a few-sessions-per-day trading mode to respect the Free tier limit (20 calls/day)
+- Analyze multiple stocks at once with batch API calls
+
+### Trading Modes
+- **Daily Mode**: 4 trading sessions per day (6-hour intervals) - Free tier compatible
+- **Realtime Mode**: real-time trading at 60-second intervals - paid subscription required
+- Select the mode with the `TRADE_MODE` environment variable
+
+### Evolution System
+- Record user conversations in documentation so intent is preserved going forward
+- Prompt quality validation will be handled in a separate issue
+
+### Documentation
+- Always pay attention to code documentation (system structure, per-feature descriptions)
+- Updating related documentation is mandatory when adding new features
```
```diff
@@ -42,6 +42,7 @@ class MarketScanner:
         volatility_analyzer: VolatilityAnalyzer,
         context_store: ContextStore,
         top_n: int = 5,
+        max_concurrent_scans: int = 1,
     ) -> None:
         """Initialize the market scanner.
 
@@ -51,12 +52,14 @@ class MarketScanner:
             volatility_analyzer: Volatility analyzer instance
             context_store: Context store for L7 real-time data
             top_n: Number of top movers to return per market (default 5)
+            max_concurrent_scans: Max concurrent stock scans (default 1, fully serialized)
         """
         self.broker = broker
         self.overseas_broker = overseas_broker
         self.analyzer = volatility_analyzer
         self.context_store = context_store
         self.top_n = top_n
+        self._scan_semaphore = asyncio.Semaphore(max_concurrent_scans)
 
     async def scan_stock(
         self,
@@ -83,8 +86,8 @@ class MarketScanner:
         # Convert to orderbook-like structure
         orderbook = {
             "output1": {
-                "stck_prpr": price_data.get("output", {}).get("last", "0"),
-                "acml_vol": price_data.get("output", {}).get("tvol", "0"),
+                "stck_prpr": price_data.get("output", {}).get("last", "0") or "0",
+                "acml_vol": price_data.get("output", {}).get("tvol", "0") or "0",
             }
         }
 
@@ -139,8 +142,12 @@ class MarketScanner:
 
         logger.info("Scanning %s market (%d stocks)", market.name, len(stock_codes))
 
-        # Scan all stocks concurrently (with rate limiting handled by broker)
-        tasks = [self.scan_stock(code, market) for code in stock_codes]
+        # Scan stocks with bounded concurrency to prevent API rate limit burst
+        async def _bounded_scan(code: str) -> VolatilityMetrics | None:
+            async with self._scan_semaphore:
+                return await self.scan_stock(code, market)
+
+        tasks = [_bounded_scan(code) for code in stock_codes]
         results = await asyncio.gather(*tasks)
 
         # Filter out failures and sort by momentum score
```
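The semaphore pattern introduced above (wrap each task in a bounded coroutine, then `gather` the wrappers) generalizes to any fan-out. A self-contained sketch, with a hypothetical `fetch` standing in for `scan_stock`:

```python
import asyncio


async def fetch(code: str) -> str:
    # Stand-in for scan_stock: pretend to call the broker API
    await asyncio.sleep(0.01)
    return f"{code}:ok"


async def scan_all(codes: list[str], max_concurrent: int = 1) -> list[str]:
    # Bound concurrency so gather() cannot burst past the API rate limit
    sem = asyncio.Semaphore(max_concurrent)

    async def bounded(code: str) -> str:
        async with sem:
            return await fetch(code)

    # gather preserves input order regardless of completion order
    return await asyncio.gather(*(bounded(c) for c in codes))


results = asyncio.run(scan_all(["005930", "000660", "035420"], max_concurrent=1))
# → ['005930:ok', '000660:ok', '035420:ok']
```

With `max_concurrent=1` the scans are fully serialized, matching the new default; raising it trades rate-limit headroom for speed.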
```diff
@@ -525,3 +525,233 @@
             DecisionCache instance or None if caching disabled
         """
         return self._cache
+
+    # ------------------------------------------------------------------
+    # Batch Decision Making (for daily trading mode)
+    # ------------------------------------------------------------------
+
+    async def decide_batch(
+        self, stocks_data: list[dict[str, Any]]
+    ) -> dict[str, TradeDecision]:
+        """Make decisions for multiple stocks in a single API call.
+
+        This is designed for daily trading mode to minimize API usage
+        when working with Gemini Free tier (20 calls/day limit).
+
+        Args:
+            stocks_data: List of market data dictionaries, each with:
+                - stock_code: Stock ticker
+                - current_price: Current price
+                - market_name: Market name (optional)
+                - foreigner_net: Foreigner net buy/sell (optional)
+
+        Returns:
+            Dictionary mapping stock_code to TradeDecision
+
+        Example:
+            >>> stocks_data = [
+            ...     {"stock_code": "AAPL", "current_price": 185.5},
+            ...     {"stock_code": "MSFT", "current_price": 420.0},
+            ... ]
+            >>> decisions = await client.decide_batch(stocks_data)
+            >>> decisions["AAPL"].action
+            'BUY'
+        """
+        if not stocks_data:
+            return {}
+
+        # Build compressed batch prompt
+        market_name = stocks_data[0].get("market_name", "stock market")
+
+        # Format stock data as compact JSON array
+        compact_stocks = []
+        for stock in stocks_data:
+            compact = {
+                "code": stock["stock_code"],
+                "price": stock["current_price"],
+            }
+            if stock.get("foreigner_net", 0) != 0:
+                compact["frgn"] = stock["foreigner_net"]
+            compact_stocks.append(compact)
+
+        data_str = json.dumps(compact_stocks, ensure_ascii=False)
+
+        prompt = (
+            f"You are a professional {market_name} trading analyst.\n"
+            "Analyze the following stocks and decide whether to BUY, SELL, or HOLD each one.\n\n"
+            f"Stock Data: {data_str}\n\n"
+            "You MUST respond with ONLY a valid JSON array in this format:\n"
+            '[{"code": "AAPL", "action": "BUY", "confidence": 85, "rationale": "..."},\n'
+            ' {"code": "MSFT", "action": "HOLD", "confidence": 50, "rationale": "..."}, ...]\n\n'
+            "Rules:\n"
+            "- Return one decision object per stock\n"
+            "- action must be exactly: BUY, SELL, or HOLD\n"
+            "- confidence must be 0-100\n"
+            "- rationale should be concise (1-2 sentences)\n"
+            "- Do NOT wrap JSON in markdown code blocks\n"
+        )
+
+        # Estimate tokens
+        token_count = self._optimizer.estimate_tokens(prompt)
+        self._total_tokens_used += token_count
+
+        logger.info(
+            "Requesting batch decision for %d stocks from Gemini",
+            len(stocks_data),
+            extra={"estimated_tokens": token_count},
+        )
+
+        try:
+            response = await self._client.aio.models.generate_content(
+                model=self._model_name,
+                contents=prompt,
+            )
+            raw = response.text
+        except Exception as exc:
+            logger.error("Gemini API error in batch decision: %s", exc)
+            # Return HOLD for all stocks on API error
+            return {
+                stock["stock_code"]: TradeDecision(
+                    action="HOLD",
+                    confidence=0,
+                    rationale=f"API error: {exc}",
+                    token_count=token_count,
+                    cached=False,
+                )
+                for stock in stocks_data
+            }
+
+        # Parse batch response
+        return self._parse_batch_response(raw, stocks_data, token_count)
+
+    def _parse_batch_response(
+        self, raw: str, stocks_data: list[dict[str, Any]], token_count: int
+    ) -> dict[str, TradeDecision]:
+        """Parse batch response into a dictionary of decisions.
+
+        Args:
+            raw: Raw response from Gemini
+            stocks_data: Original stock data list
+            token_count: Token count for the request
+
+        Returns:
+            Dictionary mapping stock_code to TradeDecision
+        """
+        if not raw or not raw.strip():
+            logger.warning("Empty batch response from Gemini — defaulting all to HOLD")
+            return {
+                stock["stock_code"]: TradeDecision(
+                    action="HOLD",
+                    confidence=0,
+                    rationale="Empty response",
+                    token_count=0,
+                    cached=False,
+                )
+                for stock in stocks_data
+            }
+
+        # Strip markdown code fences if present
+        cleaned = raw.strip()
+        match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", cleaned, re.DOTALL)
+        if match:
+            cleaned = match.group(1).strip()
+
+        try:
+            data = json.loads(cleaned)
+        except json.JSONDecodeError:
+            logger.warning("Malformed JSON in batch response — defaulting all to HOLD")
+            return {
+                stock["stock_code"]: TradeDecision(
+                    action="HOLD",
+                    confidence=0,
+                    rationale="Malformed JSON response",
+                    token_count=0,
+                    cached=False,
+                )
+                for stock in stocks_data
+            }
+
+        if not isinstance(data, list):
+            logger.warning("Batch response is not a JSON array — defaulting all to HOLD")
+            return {
+                stock["stock_code"]: TradeDecision(
+                    action="HOLD",
+                    confidence=0,
+                    rationale="Invalid response format",
+                    token_count=0,
+                    cached=False,
+                )
+                for stock in stocks_data
+            }
+
+        # Build decision map
+        decisions: dict[str, TradeDecision] = {}
+        stock_codes = {stock["stock_code"] for stock in stocks_data}
+
+        for item in data:
+            if not isinstance(item, dict):
+                continue
+
+            code = item.get("code")
+            if not code or code not in stock_codes:
+                continue
+
+            # Validate required fields
+            if not all(k in item for k in ("action", "confidence", "rationale")):
+                logger.warning("Missing fields for %s — using HOLD", code)
+                decisions[code] = TradeDecision(
+                    action="HOLD",
+                    confidence=0,
+                    rationale="Missing required fields",
+                    token_count=0,
+                    cached=False,
+                )
+                continue
+
+            action = str(item["action"]).upper()
+            if action not in VALID_ACTIONS:
+                logger.warning("Invalid action '%s' for %s — forcing HOLD", action, code)
+                action = "HOLD"
+
+            confidence = int(item["confidence"])
+            rationale = str(item["rationale"])
+
+            # Enforce confidence threshold
+            if confidence < self._confidence_threshold:
+                logger.info(
+                    "Confidence %d < threshold %d for %s — forcing HOLD",
+                    confidence,
+                    self._confidence_threshold,
+                    code,
+                )
+                action = "HOLD"
+
+            decisions[code] = TradeDecision(
+                action=action,
+                confidence=confidence,
+                rationale=rationale,
+                token_count=token_count // len(stocks_data),  # Split token cost
+                cached=False,
+            )
+            self._total_decisions += 1
+
+        # Fill in missing stocks with HOLD
+        for stock in stocks_data:
+            code = stock["stock_code"]
+            if code not in decisions:
+                logger.warning("No decision for %s in batch response — using HOLD", code)
+                decisions[code] = TradeDecision(
+                    action="HOLD",
+                    confidence=0,
+                    rationale="Not found in batch response",
+                    token_count=0,
+                    cached=False,
+                )
+
+        logger.info(
+            "Batch decision completed for %d stocks",
+            len(decisions),
+            extra={"tokens": token_count},
+        )
+
+        return decisions
```
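The fence-stripping step in `_parse_batch_response` is the part most worth seeing in isolation, since models often ignore the "no markdown" instruction. A minimal reproduction of that parsing step, using the same regex but simplified decision records:

```python
import json
import re


def parse_batch(raw: str) -> dict[str, dict]:
    """Strip optional ```json fences, then map code -> decision dict."""
    cleaned = raw.strip()
    # Same pattern as _parse_batch_response: grab the body between fences
    match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", cleaned, re.DOTALL)
    if match:
        cleaned = match.group(1).strip()
    data = json.loads(cleaned)
    return {item["code"]: item for item in data if isinstance(item, dict)}


# A fenced response, as Gemini sometimes returns despite instructions
raw = '```json\n[{"code": "AAPL", "action": "BUY", "confidence": 85}]\n```'
decisions = parse_batch(raw)
# → decisions["AAPL"]["action"] == "BUY"
```

Bare (unfenced) JSON passes through unchanged because the regex simply fails to match, so both response shapes are handled by one code path.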
```diff
@@ -55,6 +55,9 @@ class KISBroker:
         self._session: aiohttp.ClientSession | None = None
         self._access_token: str | None = None
         self._token_expires_at: float = 0.0
+        self._token_lock = asyncio.Lock()
+        self._last_refresh_attempt: float = 0.0
+        self._refresh_cooldown: float = 60.0  # Seconds (matches KIS 1/minute limit)
         self._rate_limiter = LeakyBucket(settings.RATE_LIMIT_RPS)
 
     def _get_session(self) -> aiohttp.ClientSession:
@@ -80,30 +83,54 @@ class KISBroker:
     # ------------------------------------------------------------------
 
     async def _ensure_token(self) -> str:
-        """Return a valid access token, refreshing if expired."""
+        """Return a valid access token, refreshing if expired.
+
+        Uses a lock to prevent concurrent token refresh attempts that would
+        hit the API's 1-per-minute rate limit (EGW00133).
+        """
+        # Fast path: check without lock
         now = asyncio.get_event_loop().time()
         if self._access_token and now < self._token_expires_at:
             return self._access_token
 
-        logger.info("Refreshing KIS access token")
-        session = self._get_session()
-        url = f"{self._base_url}/oauth2/tokenP"
-        body = {
-            "grant_type": "client_credentials",
-            "appkey": self._app_key,
-            "appsecret": self._app_secret,
-        }
-
-        async with session.post(url, json=body) as resp:
-            if resp.status != 200:
-                text = await resp.text()
-                raise ConnectionError(f"Token refresh failed ({resp.status}): {text}")
-            data = await resp.json()
-
-        self._access_token = data["access_token"]
-        self._token_expires_at = now + data.get("expires_in", 86400) - 60  # 1-min buffer
-        logger.info("Token refreshed successfully")
-        return self._access_token
+        # Slow path: acquire lock and refresh
+        async with self._token_lock:
+            # Re-check after acquiring lock (another coroutine may have refreshed)
+            now = asyncio.get_event_loop().time()
+            if self._access_token and now < self._token_expires_at:
+                return self._access_token
+
+            # Check cooldown period (prevents hitting EGW00133: 1/minute limit)
+            time_since_last_attempt = now - self._last_refresh_attempt
+            if time_since_last_attempt < self._refresh_cooldown:
+                remaining = self._refresh_cooldown - time_since_last_attempt
+                error_msg = (
+                    f"Token refresh on cooldown. "
+                    f"Retry in {remaining:.1f}s (KIS allows 1/minute)"
+                )
+                logger.warning(error_msg)
+                raise ConnectionError(error_msg)
+
+            logger.info("Refreshing KIS access token")
+            self._last_refresh_attempt = now
+            session = self._get_session()
+            url = f"{self._base_url}/oauth2/tokenP"
+            body = {
+                "grant_type": "client_credentials",
+                "appkey": self._app_key,
+                "appsecret": self._app_secret,
+            }
+
+            async with session.post(url, json=body) as resp:
+                if resp.status != 200:
+                    text = await resp.text()
+                    raise ConnectionError(f"Token refresh failed ({resp.status}): {text}")
+                data = await resp.json()
+
+            self._access_token = data["access_token"]
+            self._token_expires_at = now + data.get("expires_in", 86400) - 60  # 1-min buffer
+            logger.info("Token refreshed successfully")
+            return self._access_token
 
     # ------------------------------------------------------------------
     # Hash Key (required for POST bodies)
@@ -111,6 +138,7 @@ class KISBroker:
 
     async def _get_hash_key(self, body: dict[str, Any]) -> str:
         """Request a hash key from KIS for POST request body signing."""
+        await self._rate_limiter.acquire()
         session = self._get_session()
         url = f"{self._base_url}/uapi/hashkey"
         headers = {
```
```diff
@@ -37,11 +37,18 @@ class Settings(BaseSettings):
     DB_PATH: str = "data/trade_logs.db"
 
     # Rate Limiting (requests per second for KIS API)
-    RATE_LIMIT_RPS: float = 10.0
+    # Conservative limit to avoid EGW00201 "초당 거래건수 초과" errors.
+    # KIS API real limit is ~2 RPS; 2.0 provides maximum safety.
+    RATE_LIMIT_RPS: float = 2.0
 
     # Trading mode
     MODE: str = Field(default="paper", pattern="^(paper|live)$")
 
+    # Trading frequency mode (daily = batch API calls, realtime = per-stock calls)
+    TRADE_MODE: str = Field(default="daily", pattern="^(daily|realtime)$")
+    DAILY_SESSIONS: int = Field(default=4, ge=1, le=10)
+    SESSION_INTERVAL_HOURS: int = Field(default=6, ge=1, le=24)
+
     # Market selection (comma-separated market codes)
     ENABLED_MARKETS: str = "KR"
```
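Given `DAILY_SESSIONS` and `SESSION_INTERVAL_HOURS`, the fixed-interval schedule these settings imply can be sketched as follows (illustrative only; the actual scheduler lives in `src/main.py` and is not shown in full here):

```python
from datetime import datetime, timedelta


def session_times(start: datetime, sessions: int, interval_hours: int) -> list[datetime]:
    """Fixed-interval schedule: N sessions spaced interval_hours apart."""
    return [start + timedelta(hours=interval_hours * i) for i in range(sessions)]


# Defaults from Settings: DAILY_SESSIONS=4, SESSION_INTERVAL_HOURS=6
start = datetime(2026, 2, 5, 0, 0)
times = session_times(start, sessions=4, interval_hours=6)
# → sessions at 00:00, 06:00, 12:00, 18:00
assert [t.hour for t in times] == [0, 6, 12, 18]
```

Note that the `Field` bounds (`ge=1, le=10` sessions, `ge=1, le=24` hours) permit schedules longer than 24 hours, e.g. 10 sessions at 6-hour spacing; the validators constrain each value independently, not their product.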
src/main.py (628 changes)
```diff
@@ -33,6 +33,35 @@ from src.notifications.telegram_client import TelegramClient
 
 logger = logging.getLogger(__name__)
 
+
+def safe_float(value: str | float | None, default: float = 0.0) -> float:
+    """Convert to float, handling empty strings and None.
+
+    Args:
+        value: Value to convert (string, float, or None)
+        default: Default value if conversion fails
+
+    Returns:
+        Converted float or default value
+
+    Examples:
+        >>> safe_float("123.45")
+        123.45
+        >>> safe_float("")
+        0.0
+        >>> safe_float(None)
+        0.0
+        >>> safe_float("invalid", 99.0)
+        99.0
+    """
+    if value is None or value == "":
+        return default
+    try:
+        return float(value)
+    except (ValueError, TypeError):
+        return default
+
+
 # Target stock codes to monitor per market
 WATCHLISTS = {
     "KR": ["005930", "000660", "035420"],  # Samsung, SK Hynix, NAVER
@@ -45,6 +74,10 @@ TRADE_INTERVAL_SECONDS = 60
 SCAN_INTERVAL_SECONDS = 60  # Scan markets every 60 seconds
 MAX_CONNECTION_RETRIES = 3
 
+# Daily trading mode constants (for Free tier API efficiency)
+DAILY_TRADE_SESSIONS = 4  # Number of trading sessions per day
+TRADE_SESSION_INTERVAL_HOURS = 6  # Hours between sessions
+
 # Full stock universe per market (for scanning)
 # In production, this would be loaded from a database or API
 STOCK_UNIVERSE = {
@@ -77,16 +110,16 @@ async def trading_cycle(
         balance_data = await broker.get_balance()
 
         output2 = balance_data.get("output2", [{}])
-        total_eval = float(output2[0].get("tot_evlu_amt", "0")) if output2 else 0
-        total_cash = float(
+        total_eval = safe_float(output2[0].get("tot_evlu_amt", "0")) if output2 else 0
+        total_cash = safe_float(
             balance_data.get("output2", [{}])[0].get("dnca_tot_amt", "0")
             if output2
             else "0"
         )
-        purchase_total = float(output2[0].get("pchs_amt_smtl_amt", "0")) if output2 else 0
+        purchase_total = safe_float(output2[0].get("pchs_amt_smtl_amt", "0")) if output2 else 0
 
-        current_price = float(orderbook.get("output1", {}).get("stck_prpr", "0"))
-        foreigner_net = float(orderbook.get("output1", {}).get("frgn_ntby_qty", "0"))
+        current_price = safe_float(orderbook.get("output1", {}).get("stck_prpr", "0"))
+        foreigner_net = safe_float(orderbook.get("output1", {}).get("frgn_ntby_qty", "0"))
     else:
         # Overseas market
         price_data = await overseas_broker.get_overseas_price(
@@ -103,11 +136,11 @@ async def trading_cycle(
         else:
             balance_info = {}
 
-        total_eval = float(balance_info.get("frcr_evlu_tota", "0") or "0")
-        total_cash = float(balance_info.get("frcr_dncl_amt_2", "0") or "0")
-        purchase_total = float(balance_info.get("frcr_buy_amt_smtl", "0") or "0")
+        total_eval = safe_float(balance_info.get("frcr_evlu_tota", "0") or "0")
+        total_cash = safe_float(balance_info.get("frcr_dncl_amt_2", "0") or "0")
+        purchase_total = safe_float(balance_info.get("frcr_buy_amt_smtl", "0") or "0")
 
-        current_price = float(price_data.get("output", {}).get("last", "0"))
+        current_price = safe_float(price_data.get("output", {}).get("last", "0"))
         foreigner_net = 0.0  # Not available for overseas
 
     # Calculate daily P&L %
@@ -292,6 +325,239 @@ async def trading_cycle(
     )
 
 
+async def run_daily_session(
+    broker: KISBroker,
+    overseas_broker: OverseasBroker,
+    brain: GeminiClient,
+    risk: RiskManager,
+    db_conn: Any,
+    decision_logger: DecisionLogger,
+    context_store: ContextStore,
+    criticality_assessor: CriticalityAssessor,
+    telegram: TelegramClient,
+    settings: Settings,
+) -> None:
+    """Execute one daily trading session.
+
+    Designed for API efficiency with Gemini Free tier:
+    - Batch decision making (1 API call per market)
+    - Runs N times per day at fixed intervals
+    - Minimizes API usage while maintaining trading capability
+    """
+    # Get currently open markets
+    open_markets = get_open_markets(settings.enabled_market_list)
+
+    if not open_markets:
+        logger.info("No markets open for this session")
+        return
+
+    logger.info("Starting daily trading session for %d markets", len(open_markets))
+
+    # Process each open market
+    for market in open_markets:
+        # Get watchlist for this market
+        watchlist = WATCHLISTS.get(market.code, [])
+        if not watchlist:
+            logger.debug("No watchlist for market %s", market.code)
+            continue
+
+        logger.info("Processing market: %s (%d stocks)", market.name, len(watchlist))
+
+        # Collect market data for all stocks in the watchlist
+        stocks_data = []
+        for stock_code in watchlist:
+            try:
+                if market.is_domestic:
+                    orderbook = await broker.get_orderbook(stock_code)
+                    current_price = safe_float(orderbook.get("output1", {}).get("stck_prpr", "0"))
+                    foreigner_net = safe_float(
+                        orderbook.get("output1", {}).get("frgn_ntby_qty", "0")
+                    )
+                else:
+                    price_data = await overseas_broker.get_overseas_price(
+                        market.exchange_code, stock_code
+                    )
+                    current_price = safe_float(price_data.get("output", {}).get("last", "0"))
+                    foreigner_net = 0.0
+
+                stocks_data.append(
+                    {
+                        "stock_code": stock_code,
+                        "market_name": market.name,
+                        "current_price": current_price,
+                        "foreigner_net": foreigner_net,
+                    }
+                )
+            except Exception as exc:
+                logger.error("Failed to fetch data for %s: %s", stock_code, exc)
+                continue
+
+        if not stocks_data:
+            logger.warning("No valid stock data for market %s", market.code)
+            continue
+
+        # Get batch decisions (1 API call for all stocks in this market)
+        logger.info("Requesting batch decision for %d stocks in %s", len(stocks_data), market.name)
+        decisions = await brain.decide_batch(stocks_data)
+
+        # Get balance data once for the market
+        if market.is_domestic:
+            balance_data = await broker.get_balance()
+            output2 = balance_data.get("output2", [{}])
+            total_eval = safe_float(output2[0].get("tot_evlu_amt", "0")) if output2 else 0
+            total_cash = safe_float(output2[0].get("dnca_tot_amt", "0")) if output2 else 0
+            purchase_total = safe_float(output2[0].get("pchs_amt_smtl_amt", "0")) if output2 else 0
+        else:
+            balance_data = await overseas_broker.get_overseas_balance(market.exchange_code)
+            output2 = balance_data.get("output2", [{}])
+            if isinstance(output2, list) and output2:
+                balance_info = output2[0]
+            elif isinstance(output2, dict):
+                balance_info = output2
+            else:
+                balance_info = {}
+
+            total_eval = safe_float(balance_info.get("frcr_evlu_tota", "0") or "0")
+            total_cash = safe_float(balance_info.get("frcr_dncl_amt_2", "0") or "0")
+            purchase_total = safe_float(balance_info.get("frcr_buy_amt_smtl", "0") or "0")
+
+        # Calculate daily P&L %
+        pnl_pct = (
+            ((total_eval - purchase_total) / purchase_total * 100) if purchase_total > 0 else 0.0
+        )
+
+        # Execute decisions for each stock
+        for stock_data in stocks_data:
+            stock_code = stock_data["stock_code"]
+            decision = decisions.get(stock_code)
+
+            if not decision:
+                logger.warning("No decision for %s — skipping", stock_code)
+                continue
+
+            logger.info(
+                "Decision for %s (%s): %s (confidence=%d)",
+                stock_code,
+                market.name,
+                decision.action,
+                decision.confidence,
+            )
+
+            # Log decision
+            context_snapshot = {
+                "L1": {
+                    "current_price": stock_data["current_price"],
+                    "foreigner_net": stock_data["foreigner_net"],
+                },
+                "L2": {
+                    "total_eval": total_eval,
+                    "total_cash": total_cash,
+                    "purchase_total": purchase_total,
+                    "pnl_pct": pnl_pct,
+                },
+            }
+            input_data = {
+                "current_price": stock_data["current_price"],
```
|
||||||
|
"foreigner_net": stock_data["foreigner_net"],
|
||||||
|
"total_eval": total_eval,
|
||||||
|
"total_cash": total_cash,
|
||||||
|
"pnl_pct": pnl_pct,
|
||||||
|
}
|
||||||
|
|
||||||
|
decision_logger.log_decision(
|
||||||
|
stock_code=stock_code,
|
||||||
|
market=market.code,
|
||||||
|
exchange_code=market.exchange_code,
|
||||||
|
action=decision.action,
|
||||||
|
confidence=decision.confidence,
|
||||||
|
rationale=decision.rationale,
|
||||||
|
context_snapshot=context_snapshot,
|
||||||
|
input_data=input_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute if actionable
|
||||||
|
if decision.action in ("BUY", "SELL"):
|
||||||
|
quantity = 1
|
||||||
|
order_amount = stock_data["current_price"] * quantity
|
||||||
|
|
||||||
|
# Risk check
|
||||||
|
try:
|
||||||
|
risk.validate_order(
|
||||||
|
current_pnl_pct=pnl_pct,
|
||||||
|
order_amount=order_amount,
|
||||||
|
total_cash=total_cash,
|
||||||
|
)
|
||||||
|
except FatFingerRejected as exc:
|
||||||
|
try:
|
||||||
|
await telegram.notify_fat_finger(
|
||||||
|
stock_code=stock_code,
|
||||||
|
order_amount=exc.order_amount,
|
||||||
|
total_cash=exc.total_cash,
|
||||||
|
max_pct=exc.max_pct,
|
||||||
|
)
|
||||||
|
except Exception as notify_exc:
|
||||||
|
logger.warning("Fat finger notification failed: %s", notify_exc)
|
||||||
|
continue # Skip this order
|
||||||
|
except CircuitBreakerTripped as exc:
|
||||||
|
logger.critical("Circuit breaker tripped — stopping session")
|
||||||
|
try:
|
||||||
|
await telegram.notify_circuit_breaker(
|
||||||
|
pnl_pct=exc.pnl_pct,
|
||||||
|
threshold=exc.threshold,
|
||||||
|
)
|
||||||
|
except Exception as notify_exc:
|
||||||
|
logger.warning("Circuit breaker notification failed: %s", notify_exc)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Send order
|
||||||
|
try:
|
||||||
|
if market.is_domestic:
|
||||||
|
result = await broker.send_order(
|
||||||
|
stock_code=stock_code,
|
||||||
|
order_type=decision.action,
|
||||||
|
quantity=quantity,
|
||||||
|
price=0, # market order
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = await overseas_broker.send_overseas_order(
|
||||||
|
exchange_code=market.exchange_code,
|
||||||
|
stock_code=stock_code,
|
||||||
|
order_type=decision.action,
|
||||||
|
quantity=quantity,
|
||||||
|
price=0.0, # market order
|
||||||
|
)
|
||||||
|
logger.info("Order result: %s", result.get("msg1", "OK"))
|
||||||
|
|
||||||
|
# Notify trade execution
|
||||||
|
try:
|
||||||
|
await telegram.notify_trade_execution(
|
||||||
|
stock_code=stock_code,
|
||||||
|
market=market.name,
|
||||||
|
action=decision.action,
|
||||||
|
quantity=quantity,
|
||||||
|
price=stock_data["current_price"],
|
||||||
|
confidence=decision.confidence,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Telegram notification failed: %s", exc)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Order execution failed for %s: %s", stock_code, exc)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Log trade
|
||||||
|
log_trade(
|
||||||
|
conn=db_conn,
|
||||||
|
stock_code=stock_code,
|
||||||
|
action=decision.action,
|
||||||
|
confidence=decision.confidence,
|
||||||
|
rationale=decision.rationale,
|
||||||
|
market=market.code,
|
||||||
|
exchange_code=market.exchange_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Daily trading session completed")
|
||||||
|
|
||||||
|
|
||||||
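The session's daily P&L line reduces to a guarded percentage: (total_eval - purchase_total) / purchase_total * 100, with 0.0 whenever nothing has been purchased. A standalone sketch of just that arithmetic (the function name is illustrative, not the project's API):

```python
def daily_pnl_pct(total_eval: float, purchase_total: float) -> float:
    """Return P&L as a percentage of purchase cost; 0.0 when nothing was bought."""
    if purchase_total <= 0:
        return 0.0
    return (total_eval - purchase_total) / purchase_total * 100


print(daily_pnl_pct(1_050_000, 1_000_000))  # a 5% gain
print(daily_pnl_pct(0, 0))                  # no purchases, so 0.0
```

The `purchase_total <= 0` guard matters: the KIS balance endpoints return "0" strings before any position exists, and dividing by that would crash the whole session.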
async def run(settings: Settings) -> None:
    """Main async loop — iterate over open markets on a timer."""
    broker = KISBroker(settings)
@@ -317,6 +583,7 @@ async def run(settings: Settings) -> None:
        volatility_analyzer=volatility_analyzer,
        context_store=context_store,
        top_n=5,
        max_concurrent_scans=1,  # Fully serialized to avoid EGW00201
    )

    # Initialize latency control system
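`max_concurrent_scans=1` fully serializes market scans so request bursts never exceed the KIS per-second quota (the EGW00201 error). The usual asyncio shape for such a cap is a semaphore; a minimal illustrative sketch (the helper names are not the project's API, and the peak counter exists only to demonstrate the bound):

```python
import asyncio


async def main() -> int:
    sem = asyncio.Semaphore(1)  # max_concurrent_scans=1
    state = {"active": 0, "peak": 0}

    async def scan(market: str) -> None:
        async with sem:  # at most one scan holds the permit at a time
            state["active"] += 1
            state["peak"] = max(state["peak"], state["active"])
            await asyncio.sleep(0.01)  # stand-in for the actual API calls
            state["active"] -= 1

    await asyncio.gather(*(scan(m) for m in ["KR", "US", "JP"]))
    return state["peak"]


print(asyncio.run(main()))  # 1
```

Raising the semaphore's initial value would allow that many scans in flight, which is exactly the knob `max_concurrent_scans` exposes.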
@@ -345,7 +612,7 @@ async def run(settings: Settings) -> None:
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, _signal_handler)

    logger.info("The Ouroboros is alive. Mode: %s, Trading: %s", settings.MODE, settings.TRADE_MODE)
    logger.info("Enabled markets: %s", settings.enabled_market_list)

    # Notify system startup
@@ -355,172 +622,217 @@ async def run(settings: Settings) -> None:
        logger.warning("System startup notification failed: %s", exc)

    try:
        # Branch based on trading mode
        if settings.TRADE_MODE == "daily":
            # Daily trading mode: batch decisions at fixed intervals
            logger.info(
                "Daily trading mode: %d sessions every %d hours",
                settings.DAILY_SESSIONS,
                settings.SESSION_INTERVAL_HOURS,
            )

            session_interval = settings.SESSION_INTERVAL_HOURS * 3600  # Convert to seconds

            while not shutdown.is_set():
                try:
                    await run_daily_session(
                        broker,
                        overseas_broker,
                        brain,
                        risk,
                        db_conn,
                        decision_logger,
                        context_store,
                        criticality_assessor,
                        telegram,
                        settings,
                    )
                except CircuitBreakerTripped:
                    logger.critical("Circuit breaker tripped — shutting down")
                    shutdown.set()
                    break
                except Exception as exc:
                    logger.exception("Daily session error: %s", exc)

                # Wait for next session or shutdown
                logger.info("Next session in %.1f hours", session_interval / 3600)
                try:
                    await asyncio.wait_for(shutdown.wait(), timeout=session_interval)
                except TimeoutError:
                    pass  # Normal — time for next session

        else:
            # Realtime trading mode: original per-stock loop
            logger.info("Realtime trading mode: 60s interval per stock")

            while not shutdown.is_set():
                # Get currently open markets
                open_markets = get_open_markets(settings.enabled_market_list)

                if not open_markets:
                    # Notify market close for any markets that were open
                    for market_code, is_open in list(_market_states.items()):
                        if is_open:
                            try:
                                from src.markets.schedule import MARKETS

                                market_info = MARKETS.get(market_code)
                                if market_info:
                                    await telegram.notify_market_close(market_info.name, 0.0)
                            except Exception as exc:
                                logger.warning("Market close notification failed: %s", exc)
                            _market_states[market_code] = False

                    # No markets open — wait until next market opens
                    try:
                        next_market, next_open_time = get_next_market_open(
                            settings.enabled_market_list
                        )
                        now = datetime.now(UTC)
                        wait_seconds = (next_open_time - now).total_seconds()
                        logger.info(
                            "No markets open. Next market: %s, opens in %.1f hours",
                            next_market.name,
                            wait_seconds / 3600,
                        )
                        await asyncio.wait_for(shutdown.wait(), timeout=wait_seconds)
                    except TimeoutError:
                        continue  # Market should be open now
                    except ValueError as exc:
                        logger.error("Failed to find next market open: %s", exc)
                        await asyncio.sleep(TRADE_INTERVAL_SECONDS)
                    continue

                # Process each open market
                for market in open_markets:
                    if shutdown.is_set():
                        break

                    # Notify market open if it just opened
                    if not _market_states.get(market.code, False):
                        try:
                            await telegram.notify_market_open(market.name)
                        except Exception as exc:
                            logger.warning("Market open notification failed: %s", exc)
                        _market_states[market.code] = True

                    # Volatility Hunter: Scan market periodically to update watchlist
                    now_timestamp = asyncio.get_event_loop().time()
                    last_scan = last_scan_time.get(market.code, 0.0)
                    if now_timestamp - last_scan >= SCAN_INTERVAL_SECONDS:
                        try:
                            # Scan all stocks in the universe
                            stock_universe = STOCK_UNIVERSE.get(market.code, [])
                            if stock_universe:
                                logger.info("Volatility Hunter: Scanning %s market", market.name)
                                scan_result = await market_scanner.scan_market(
                                    market, stock_universe
                                )

                                # Update watchlist with top movers
                                current_watchlist = WATCHLISTS.get(market.code, [])
                                updated_watchlist = market_scanner.get_updated_watchlist(
                                    current_watchlist,
                                    scan_result,
                                    max_replacements=2,
                                )
                                WATCHLISTS[market.code] = updated_watchlist

                                logger.info(
                                    "Volatility Hunter: Watchlist updated for %s (%d top movers, %d breakouts)",
                                    market.name,
                                    len(scan_result.top_movers),
                                    len(scan_result.breakouts),
                                )

                            last_scan_time[market.code] = now_timestamp
                        except Exception as exc:
                            logger.error("Volatility Hunter scan failed for %s: %s", market.name, exc)

                    # Get watchlist for this market
                    watchlist = WATCHLISTS.get(market.code, [])
                    if not watchlist:
                        logger.debug("No watchlist for market %s", market.code)
                        continue

                    logger.info("Processing market: %s (%d stocks)", market.name, len(watchlist))

                    # Process each stock in the watchlist
                    for stock_code in watchlist:
                        if shutdown.is_set():
                            break

                        # Retry logic for connection errors
                        for attempt in range(1, MAX_CONNECTION_RETRIES + 1):
                            try:
                                await trading_cycle(
                                    broker,
                                    overseas_broker,
                                    brain,
                                    risk,
                                    db_conn,
                                    decision_logger,
                                    context_store,
                                    criticality_assessor,
                                    telegram,
                                    market,
                                    stock_code,
                                )
                                break  # Success — exit retry loop
                            except CircuitBreakerTripped as exc:
                                logger.critical("Circuit breaker tripped — shutting down")
                                try:
                                    await telegram.notify_circuit_breaker(
                                        pnl_pct=exc.pnl_pct,
                                        threshold=exc.threshold,
                                    )
                                except Exception as notify_exc:
                                    logger.warning(
                                        "Circuit breaker notification failed: %s", notify_exc
                                    )
                                raise
                            except ConnectionError as exc:
                                if attempt < MAX_CONNECTION_RETRIES:
                                    logger.warning(
                                        "Connection error for %s (attempt %d/%d): %s",
                                        stock_code,
                                        attempt,
                                        MAX_CONNECTION_RETRIES,
                                        exc,
                                    )
                                    await asyncio.sleep(2**attempt)  # Exponential backoff
                                else:
                                    logger.error(
                                        "Connection error for %s (all retries exhausted): %s",
                                        stock_code,
                                        exc,
                                    )
                                    break  # Give up on this stock
                            except Exception as exc:
                                logger.exception("Unexpected error for %s: %s", stock_code, exc)
                                break  # Don't retry on unexpected errors

                # Log priority queue metrics periodically
                metrics = await priority_queue.get_metrics()
                if metrics.total_enqueued > 0:
                    logger.info(
                        "Priority queue metrics: enqueued=%d, dequeued=%d, size=%d, timeouts=%d, errors=%d",
                        metrics.total_enqueued,
                        metrics.total_dequeued,
                        metrics.current_size,
                        metrics.total_timeouts,
                        metrics.total_errors,
                    )

                # Wait for next cycle or shutdown
                try:
                    await asyncio.wait_for(shutdown.wait(), timeout=TRADE_INTERVAL_SECONDS)
                except TimeoutError:
                    pass  # Normal — timeout means it's time for next cycle
    finally:
        # Clean up resources
        await broker.close()
        await telegram.close()
        db_conn.close()
        logger.info("The Ouroboros rests.")

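The realtime loop retries connection errors with `await asyncio.sleep(2**attempt)`, i.e. delays of 2 then 4 seconds before the third and final attempt. The same retry shape can be sketched standalone, with the sleep injectable so it runs instantly under test (all names here are illustrative, not the project's API):

```python
import asyncio

MAX_CONNECTION_RETRIES = 3  # mirrors the loop's retry cap


async def retry_with_backoff(op, sleep=asyncio.sleep):
    """Run op(); on ConnectionError retry with 2**attempt backoff, else give up."""
    for attempt in range(1, MAX_CONNECTION_RETRIES + 1):
        try:
            return await op()
        except ConnectionError:
            if attempt < MAX_CONNECTION_RETRIES:
                await sleep(2**attempt)  # 2s, then 4s between attempts
            else:
                return None  # all retries exhausted


async def demo():
    calls = {"n": 0}
    delays = []

    async def flaky():
        calls["n"] += 1
        if calls["n"] < 3:
            raise ConnectionError("transient")
        return "ok"

    async def fake_sleep(seconds):  # record delays instead of waiting
        delays.append(seconds)

    result = await retry_with_backoff(flaky, sleep=fake_sleep)
    return result, delays


print(asyncio.run(demo()))  # ('ok', [2, 4])
```

Note the loop above deliberately does not retry `CircuitBreakerTripped` or unexpected exceptions; backoff is reserved for transient network faults only.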
@@ -152,3 +152,121 @@ class TestPromptConstruction:
        assert "JSON" in prompt
        assert "action" in prompt
        assert "confidence" in prompt


# ---------------------------------------------------------------------------
# Batch Decision Making
# ---------------------------------------------------------------------------


class TestBatchDecisionParsing:
    """Batch response parser must handle JSON arrays correctly."""

    def test_parse_valid_batch_response(self, settings):
        client = GeminiClient(settings)
        stocks_data = [
            {"stock_code": "AAPL", "current_price": 185.5},
            {"stock_code": "MSFT", "current_price": 420.0},
        ]
        raw = """[
        {"code": "AAPL", "action": "BUY", "confidence": 85, "rationale": "Strong momentum"},
        {"code": "MSFT", "action": "HOLD", "confidence": 50, "rationale": "Wait for earnings"}
        ]"""

        decisions = client._parse_batch_response(raw, stocks_data, token_count=100)

        assert len(decisions) == 2
        assert decisions["AAPL"].action == "BUY"
        assert decisions["AAPL"].confidence == 85
        assert decisions["MSFT"].action == "HOLD"
        assert decisions["MSFT"].confidence == 50

    def test_parse_batch_with_markdown_wrapper(self, settings):
        client = GeminiClient(settings)
        stocks_data = [{"stock_code": "AAPL", "current_price": 185.5}]
        raw = """```json
        [{"code": "AAPL", "action": "BUY", "confidence": 90, "rationale": "Good"}]
        ```"""

        decisions = client._parse_batch_response(raw, stocks_data, token_count=100)

        assert decisions["AAPL"].action == "BUY"
        assert decisions["AAPL"].confidence == 90

    def test_parse_batch_empty_response_returns_hold_for_all(self, settings):
        client = GeminiClient(settings)
        stocks_data = [
            {"stock_code": "AAPL", "current_price": 185.5},
            {"stock_code": "MSFT", "current_price": 420.0},
        ]

        decisions = client._parse_batch_response("", stocks_data, token_count=100)

        assert len(decisions) == 2
        assert decisions["AAPL"].action == "HOLD"
        assert decisions["AAPL"].confidence == 0
        assert decisions["MSFT"].action == "HOLD"

    def test_parse_batch_malformed_json_returns_hold_for_all(self, settings):
        client = GeminiClient(settings)
        stocks_data = [{"stock_code": "AAPL", "current_price": 185.5}]
        raw = "This is not JSON"

        decisions = client._parse_batch_response(raw, stocks_data, token_count=100)

        assert decisions["AAPL"].action == "HOLD"
        assert decisions["AAPL"].confidence == 0

    def test_parse_batch_not_array_returns_hold_for_all(self, settings):
        client = GeminiClient(settings)
        stocks_data = [{"stock_code": "AAPL", "current_price": 185.5}]
        raw = '{"code": "AAPL", "action": "BUY", "confidence": 90, "rationale": "Good"}'

        decisions = client._parse_batch_response(raw, stocks_data, token_count=100)

        assert decisions["AAPL"].action == "HOLD"
        assert decisions["AAPL"].confidence == 0

    def test_parse_batch_missing_stock_gets_hold(self, settings):
        client = GeminiClient(settings)
        stocks_data = [
            {"stock_code": "AAPL", "current_price": 185.5},
            {"stock_code": "MSFT", "current_price": 420.0},
        ]
        # Response only has AAPL, MSFT is missing
        raw = '[{"code": "AAPL", "action": "BUY", "confidence": 85, "rationale": "Good"}]'

        decisions = client._parse_batch_response(raw, stocks_data, token_count=100)

        assert decisions["AAPL"].action == "BUY"
        assert decisions["MSFT"].action == "HOLD"
        assert decisions["MSFT"].confidence == 0

    def test_parse_batch_invalid_action_becomes_hold(self, settings):
        client = GeminiClient(settings)
        stocks_data = [{"stock_code": "AAPL", "current_price": 185.5}]
        raw = '[{"code": "AAPL", "action": "YOLO", "confidence": 90, "rationale": "Moon"}]'

        decisions = client._parse_batch_response(raw, stocks_data, token_count=100)

        assert decisions["AAPL"].action == "HOLD"

    def test_parse_batch_low_confidence_becomes_hold(self, settings):
        client = GeminiClient(settings)
        stocks_data = [{"stock_code": "AAPL", "current_price": 185.5}]
        raw = '[{"code": "AAPL", "action": "BUY", "confidence": 65, "rationale": "Weak"}]'

        decisions = client._parse_batch_response(raw, stocks_data, token_count=100)

        assert decisions["AAPL"].action == "HOLD"
        assert decisions["AAPL"].confidence == 65

    def test_parse_batch_missing_fields_gets_hold(self, settings):
        client = GeminiClient(settings)
        stocks_data = [{"stock_code": "AAPL", "current_price": 185.5}]
        raw = '[{"code": "AAPL", "action": "BUY"}]'  # Missing confidence and rationale

        decisions = client._parse_batch_response(raw, stocks_data, token_count=100)

        assert decisions["AAPL"].action == "HOLD"
        assert decisions["AAPL"].confidence == 0
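Taken together, these tests pin down the parser contract: strip an optional fenced json wrapper, require a JSON array, fall back to HOLD with confidence 0 per stock on any malformed or missing entry, and demote BUY/SELL below the confidence threshold to HOLD while keeping the reported confidence. The real parser is `GeminiClient._parse_batch_response`; the following is only a standalone sketch of that contract as the tests encode it (the `Decision` tuple is a stand-in, and the threshold of 80 is taken from `CONFIDENCE_THRESHOLD` in `.env.example`):

```python
import json
import re
from typing import NamedTuple


class Decision(NamedTuple):  # stand-in for the project's decision type
    action: str
    confidence: int
    rationale: str


CONFIDENCE_THRESHOLD = 80  # matches CONFIDENCE_THRESHOLD in .env.example


def parse_batch_response(raw: str, stocks_data: list[dict]) -> dict[str, Decision]:
    """Map stock_code to a Decision; default to HOLD/0 on anything malformed."""
    hold = Decision("HOLD", 0, "fallback")
    decisions = {s["stock_code"]: hold for s in stocks_data}

    # Strip an optional fenced json wrapper around the array
    cleaned = re.sub(r"^```(?:json)?\s*|\s*```$", "", raw.strip())
    try:
        items = json.loads(cleaned)
    except json.JSONDecodeError:
        return decisions
    if not isinstance(items, list):
        return decisions

    for item in items:
        code = item.get("code") if isinstance(item, dict) else None
        if code not in decisions:
            continue  # unknown stock in the response
        action = item.get("action")
        confidence = item.get("confidence")
        rationale = item.get("rationale")
        if action not in ("BUY", "SELL", "HOLD") or confidence is None or rationale is None:
            continue  # keep the HOLD fallback for invalid or incomplete entries
        if action in ("BUY", "SELL") and confidence < CONFIDENCE_THRESHOLD:
            action = "HOLD"  # demote low-confidence trades, keep reported confidence
        decisions[code] = Decision(action, int(confidence), rationale)
    return decisions
```

Defaulting every stock to HOLD up front means a partially valid model response degrades safely: only entries that pass every check can trigger a trade.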
@@ -49,6 +49,110 @@ class TestTokenManagement:

        await broker.close()

    @pytest.mark.asyncio
    async def test_concurrent_token_refresh_calls_api_once(self, settings):
        """Multiple concurrent token requests should only call API once."""
        broker = KISBroker(settings)

        # Track how many times the mock API is called
        call_count = [0]

        def create_mock_resp():
            call_count[0] += 1
            mock_resp = AsyncMock()
            mock_resp.status = 200
            mock_resp.json = AsyncMock(
                return_value={
                    "access_token": "tok_concurrent",
                    "token_type": "Bearer",
                    "expires_in": 86400,
                }
            )
            mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
            mock_resp.__aexit__ = AsyncMock(return_value=False)
            return mock_resp

        with patch("aiohttp.ClientSession.post", return_value=create_mock_resp()):
            # Launch 5 concurrent token requests
            tokens = await asyncio.gather(
                broker._ensure_token(),
                broker._ensure_token(),
                broker._ensure_token(),
                broker._ensure_token(),
                broker._ensure_token(),
            )

            # All should get the same token
            assert all(t == "tok_concurrent" for t in tokens)
            # API should be called only once (due to lock)
            assert call_count[0] == 1

        await broker.close()

    @pytest.mark.asyncio
    async def test_token_refresh_cooldown_prevents_rapid_retries(self, settings):
        """Token refresh should enforce cooldown after failure (issue #54)."""
        broker = KISBroker(settings)
        broker._refresh_cooldown = 2.0  # Short cooldown for testing

        # First refresh attempt fails with 403 (EGW00133)
        mock_resp_403 = AsyncMock()
        mock_resp_403.status = 403
        mock_resp_403.text = AsyncMock(
            return_value='{"error_code":"EGW00133","error_description":"접근토큰 발급 잠시 후 다시 시도하세요(1분당 1회)"}'
        )
        mock_resp_403.__aenter__ = AsyncMock(return_value=mock_resp_403)
        mock_resp_403.__aexit__ = AsyncMock(return_value=False)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp_403):
            # First attempt should fail with 403
            with pytest.raises(ConnectionError, match="Token refresh failed"):
                await broker._ensure_token()

            # Second attempt within cooldown should fail with cooldown error
            with pytest.raises(ConnectionError, match="Token refresh on cooldown"):
                await broker._ensure_token()

        await broker.close()

    @pytest.mark.asyncio
    async def test_token_refresh_allowed_after_cooldown(self, settings):
        """Token refresh should be allowed after cooldown period expires."""
        broker = KISBroker(settings)
        broker._refresh_cooldown = 0.1  # Very short cooldown for testing

        # First attempt fails
        mock_resp_403 = AsyncMock()
        mock_resp_403.status = 403
        mock_resp_403.text = AsyncMock(return_value='{"error_code":"EGW00133"}')
        mock_resp_403.__aenter__ = AsyncMock(return_value=mock_resp_403)
        mock_resp_403.__aexit__ = AsyncMock(return_value=False)

        # Second attempt succeeds
        mock_resp_200 = AsyncMock()
        mock_resp_200.status = 200
        mock_resp_200.json = AsyncMock(
            return_value={
                "access_token": "tok_after_cooldown",
                "expires_in": 86400,
            }
        )
        mock_resp_200.__aenter__ = AsyncMock(return_value=mock_resp_200)
        mock_resp_200.__aexit__ = AsyncMock(return_value=False)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp_403):
            with pytest.raises(ConnectionError, match="Token refresh failed"):
                await broker._ensure_token()

        # Wait for cooldown to expire
        await asyncio.sleep(0.15)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp_200):
            token = await broker._ensure_token()
            assert token == "tok_after_cooldown"

        await broker.close()


# ---------------------------------------------------------------------------
# Network Error Handling
@@ -107,6 +211,38 @@ class TestRateLimiter:
        await broker._rate_limiter.acquire()
        await broker.close()

    @pytest.mark.asyncio
    async def test_send_order_acquires_rate_limiter_twice(self, settings):
        """send_order must acquire rate limiter for both hash key and order call."""
        broker = KISBroker(settings)
        broker._access_token = "tok"
        broker._token_expires_at = asyncio.get_event_loop().time() + 3600

        # Mock hash key response
        mock_hash_resp = AsyncMock()
        mock_hash_resp.status = 200
        mock_hash_resp.json = AsyncMock(return_value={"HASH": "abc123"})
        mock_hash_resp.__aenter__ = AsyncMock(return_value=mock_hash_resp)
        mock_hash_resp.__aexit__ = AsyncMock(return_value=False)

        # Mock order response
        mock_order_resp = AsyncMock()
        mock_order_resp.status = 200
        mock_order_resp.json = AsyncMock(return_value={"rt_cd": "0"})
        mock_order_resp.__aenter__ = AsyncMock(return_value=mock_order_resp)
        mock_order_resp.__aexit__ = AsyncMock(return_value=False)

        with patch(
            "aiohttp.ClientSession.post", side_effect=[mock_hash_resp, mock_order_resp]
        ):
            with patch.object(
                broker._rate_limiter, "acquire", new_callable=AsyncMock
            ) as mock_acquire:
                await broker.send_order("005930", "BUY", 1, 50000)
                assert mock_acquire.call_count == 2

        await broker.close()

# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Hash Key Generation
|
# Hash Key Generation
|
||||||
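Both the hash-key call and the order call above go through `broker._rate_limiter`, which is what keeps total request rate under the `RATE_LIMIT_RPS` setting and avoids the EGW00201 "초당 거래건수 초과" (transactions-per-second exceeded) error. A minimal interval-spacing limiter consistent with that setting could look like this; it is a sketch of the pattern, not the project's actual `RateLimiter` class:

```python
import asyncio
import time


class RateLimiter:
    """Sketch: space acquire() calls at least 1/rps seconds apart."""

    def __init__(self, rps: float) -> None:
        self._interval = 1.0 / rps
        self._last = 0.0
        self._lock = asyncio.Lock()

    async def acquire(self) -> None:
        async with self._lock:  # serialize concurrent acquirers
            now = time.monotonic()
            wait = self._last + self._interval - now
            if wait > 0:
                await asyncio.sleep(wait)
            self._last = time.monotonic()


async def demo() -> float:
    limiter = RateLimiter(rps=50.0)  # 50 req/s -> 20 ms minimum spacing
    start = time.monotonic()
    for _ in range(5):
        await limiter.acquire()
    return time.monotonic() - start
```

With this shape, `send_order` calling `acquire()` once for the hash key and once for the order naturally counts as two requests against the budget, which is exactly what the test asserts.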
@@ -136,3 +272,27 @@ class TestHashKey:
         assert len(hash_key) > 0

         await broker.close()
+
+    @pytest.mark.asyncio
+    async def test_hash_key_acquires_rate_limiter(self, settings):
+        """_get_hash_key must go through the rate limiter to prevent burst."""
+        broker = KISBroker(settings)
+        broker._access_token = "tok"
+        broker._token_expires_at = asyncio.get_event_loop().time() + 3600
+
+        body = {"CANO": "12345678", "ACNT_PRDT_CD": "01"}
+
+        mock_resp = AsyncMock()
+        mock_resp.status = 200
+        mock_resp.json = AsyncMock(return_value={"HASH": "abc123hash"})
+        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
+        mock_resp.__aexit__ = AsyncMock(return_value=False)
+
+        with patch("aiohttp.ClientSession.post", return_value=mock_resp):
+            with patch.object(
+                broker._rate_limiter, "acquire", new_callable=AsyncMock
+            ) as mock_acquire:
+                await broker._get_hash_key(body)
+                mock_acquire.assert_called_once()
+
+        await broker.close()
@@ -6,7 +6,43 @@ from unittest.mock import AsyncMock, MagicMock, patch
 import pytest

 from src.core.risk_manager import CircuitBreakerTripped, FatFingerRejected
-from src.main import trading_cycle
+from src.main import safe_float, trading_cycle
+
+
+class TestSafeFloat:
+    """Test safe_float() helper function."""
+
+    def test_converts_valid_string(self):
+        """Test conversion of valid numeric string."""
+        assert safe_float("123.45") == 123.45
+        assert safe_float("0") == 0.0
+        assert safe_float("-99.9") == -99.9
+
+    def test_handles_empty_string(self):
+        """Test empty string returns default."""
+        assert safe_float("") == 0.0
+        assert safe_float("", 99.0) == 99.0
+
+    def test_handles_none(self):
+        """Test None returns default."""
+        assert safe_float(None) == 0.0
+        assert safe_float(None, 42.0) == 42.0
+
+    def test_handles_invalid_string(self):
+        """Test invalid string returns default."""
+        assert safe_float("invalid") == 0.0
+        assert safe_float("not_a_number", 100.0) == 100.0
+        assert safe_float("12.34.56") == 0.0
+
+    def test_handles_float_input(self):
+        """Test float input passes through."""
+        assert safe_float(123.45) == 123.45
+        assert safe_float(0.0) == 0.0
+
+    def test_custom_default(self):
+        """Test custom default value."""
+        assert safe_float("", -1.0) == -1.0
+        assert safe_float(None, 999.0) == 999.0
+
+
 class TestTradingCycleTelegramIntegration:
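These tests fully pin down the helper's contract. An implementation that satisfies them can be very small; this is a sketch, and the real `src.main.safe_float` may differ in detail:

```python
def safe_float(value, default: float = 0.0) -> float:
    """Convert broker API string fields to float, falling back on default.

    KIS API responses sometimes contain "" or missing fields (issue #49),
    which would make a bare float() raise ValueError or TypeError.
    """
    if value is None or value == "":
        return default
    try:
        return float(value)
    except (TypeError, ValueError):
        return default
```

Numeric inputs pass straight through `float()`, so callers can apply it uniformly to every field they parse.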
@@ -394,6 +430,26 @@ class TestOverseasBalanceParsing:
         broker.get_overseas_balance = AsyncMock(return_value={"output2": []})
         return broker

+    @pytest.fixture
+    def mock_overseas_broker_with_empty_price(self) -> MagicMock:
+        """Create mock overseas broker returning empty string for price."""
+        broker = MagicMock()
+        broker.get_overseas_price = AsyncMock(
+            return_value={"output": {"last": ""}}  # Empty string
+        )
+        broker.get_overseas_balance = AsyncMock(
+            return_value={
+                "output2": [
+                    {
+                        "frcr_evlu_tota": "10000.00",
+                        "frcr_dncl_amt_2": "5000.00",
+                        "frcr_buy_amt_smtl": "4500.00",
+                    }
+                ]
+            }
+        )
+        return broker
+
     @pytest.fixture
     def mock_domestic_broker(self) -> MagicMock:
         """Create minimal mock domestic broker."""
@@ -559,3 +615,37 @@ class TestOverseasBalanceParsing:

         # Verify balance API was called
         mock_overseas_broker_with_empty.get_overseas_balance.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_overseas_price_empty_string(
+        self,
+        mock_domestic_broker: MagicMock,
+        mock_overseas_broker_with_empty_price: MagicMock,
+        mock_brain_hold: MagicMock,
+        mock_risk: MagicMock,
+        mock_db: MagicMock,
+        mock_decision_logger: MagicMock,
+        mock_context_store: MagicMock,
+        mock_criticality_assessor: MagicMock,
+        mock_telegram: MagicMock,
+        mock_overseas_market: MagicMock,
+    ) -> None:
+        """Test overseas price parsing with empty string (issue #49)."""
+        with patch("src.main.log_trade"):
+            # Should not raise ValueError, should default to 0.0
+            await trading_cycle(
+                broker=mock_domestic_broker,
+                overseas_broker=mock_overseas_broker_with_empty_price,
+                brain=mock_brain_hold,
+                risk=mock_risk,
+                db_conn=mock_db,
+                decision_logger=mock_decision_logger,
+                context_store=mock_context_store,
+                criticality_assessor=mock_criticality_assessor,
+                telegram=mock_telegram,
+                market=mock_overseas_market,
+                stock_code="AAPL",
+            )
+
+        # Verify price API was called
+        mock_overseas_broker_with_empty_price.get_overseas_price.assert_called_once()
@@ -2,6 +2,7 @@

 from __future__ import annotations

+import asyncio
 import sqlite3
 from typing import Any
 from unittest.mock import AsyncMock
@@ -338,6 +339,28 @@ class TestMarketScanner:
         assert metrics.stock_code == "AAPL"
         assert metrics.current_price == 150.50

+    @pytest.mark.asyncio
+    async def test_scan_stock_overseas_empty_price(
+        self,
+        scanner: MarketScanner,
+        mock_overseas_broker: OverseasBroker,
+        context_store: ContextStore,
+    ) -> None:
+        """Test scanning overseas stock with empty price string (issue #49)."""
+        mock_overseas_broker.get_overseas_price.return_value = {
+            "output": {
+                "last": "",  # Empty string
+                "tvol": "",  # Empty string
+            }
+        }
+
+        market = MARKETS["US_NASDAQ"]
+        metrics = await scanner.scan_stock("AAPL", market)
+
+        assert metrics is not None
+        assert metrics.stock_code == "AAPL"
+        assert metrics.current_price == 0.0  # Should default to 0.0
+
     @pytest.mark.asyncio
     async def test_scan_stock_error_handling(
         self,
@@ -509,3 +532,45 @@ class TestMarketScanner:
         new_additions = [code for code in updated if code not in current_watchlist]
         assert len(new_additions) <= 1
         assert len(updated) == len(current_watchlist)
+
+    @pytest.mark.asyncio
+    async def test_scan_market_respects_concurrency_limit(
+        self,
+        mock_broker: KISBroker,
+        mock_overseas_broker: OverseasBroker,
+        volatility_analyzer: VolatilityAnalyzer,
+        context_store: ContextStore,
+    ) -> None:
+        """scan_market should limit concurrent scans to max_concurrent_scans."""
+        max_concurrent = 2
+        scanner = MarketScanner(
+            broker=mock_broker,
+            overseas_broker=mock_overseas_broker,
+            volatility_analyzer=volatility_analyzer,
+            context_store=context_store,
+            top_n=5,
+            max_concurrent_scans=max_concurrent,
+        )
+
+        # Track peak concurrency
+        active_count = 0
+        peak_count = 0
+
+        original_scan = scanner.scan_stock
+
+        async def tracking_scan(code: str, market: Any) -> VolatilityMetrics:
+            nonlocal active_count, peak_count
+            active_count += 1
+            peak_count = max(peak_count, active_count)
+            await asyncio.sleep(0.05)  # Simulate API call duration
+            active_count -= 1
+            return VolatilityMetrics(code, 50000, 500, 1.0, 1.0, 1.0, 1.0, 10.0, 50.0)
+
+        scanner.scan_stock = tracking_scan  # type: ignore[method-assign]
+
+        market = MARKETS["KR"]
+        stock_codes = ["001", "002", "003", "004", "005", "006"]
+
+        await scanner.scan_market(market, stock_codes)
+
+        assert peak_count <= max_concurrent
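The peak-concurrency assertion above implies that `scan_market` bounds its parallel `scan_stock` calls, the usual tool being an `asyncio.Semaphore` around each task. The following is a sketch of that pattern under assumed names (`scan_market`, `scan_stock` here are free functions for illustration, not the project's methods):

```python
import asyncio


async def scan_market(codes, scan_stock, max_concurrent_scans: int = 2):
    """Run scan_stock over codes with at most max_concurrent_scans in flight."""
    sem = asyncio.Semaphore(max_concurrent_scans)

    async def bounded(code):
        async with sem:  # at most N tasks pass this point concurrently
            return await scan_stock(code)

    return await asyncio.gather(*(bounded(c) for c in codes))


async def demo() -> int:
    active = peak = 0

    async def fake_scan(code):
        nonlocal active, peak
        active += 1
        peak = max(peak, active)
        await asyncio.sleep(0.01)  # simulate an API call
        active -= 1
        return code

    await scan_market(["001", "002", "003", "004", "005", "006"], fake_scan, 2)
    return peak
```

`gather` still returns results in input order, so bounding concurrency this way changes throughput but not the shape of the result.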