Compare commits
282 Commits
feature/is
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 990f9696ab | |||
|
|
9bf72c63ec | ||
|
|
1399fa4d09 | ||
| f63fb53289 | |||
|
|
5050a4cf84 | ||
|
|
4987b6393a | ||
| 8faf974522 | |||
|
|
d524159ad0 | ||
|
|
c7c740f446 | ||
|
|
1333c65455 | ||
| 9db7f903f8 | |||
|
|
4660310ee4 | ||
|
|
c383a411ff | ||
| 7b3ba27ef7 | |||
|
|
6ff887c047 | ||
| 219eef6388 | |||
|
|
9d7ca12275 | ||
|
|
ccb00ee77d | ||
| b1b728f62e | |||
|
|
df12be1305 | ||
| 6a6d3bd631 | |||
|
|
7aa5fedc12 | ||
|
|
3e777a5ab8 | ||
| 6f93258983 | |||
|
|
82167c5b8a | ||
| f87c4dc2f0 | |||
|
|
8af5f564c3 | ||
| 06e4fc5597 | |||
|
|
b697b6d515 | ||
| 42db5b3cc1 | |||
|
|
f252a84d65 | ||
| adc5211fd2 | |||
|
|
67e0e8df41 | ||
| ffdb99c6c7 | |||
|
|
ce5ea5abde | ||
| 5ae302b083 | |||
|
|
d31a61cd0b | ||
|
|
1c7a17320c | ||
| f58d42fdb0 | |||
|
|
0b20251de0 | ||
| bffe6e9288 | |||
|
|
0146d1bf8a | ||
| 497564e75c | |||
|
|
988a56c07c | ||
| c9f1345e3c | |||
|
|
8c492eae3a | ||
| 271c592a46 | |||
|
|
a063bd9d10 | ||
| 847456e0af | |||
|
|
a3a9fd1f24 | ||
|
|
f34117bc81 | ||
| 17e012cd04 | |||
|
|
a030dcc0dc | ||
|
|
d1698dee33 | ||
| 8a8ba3b0cb | |||
|
|
6b74e4cc77 | ||
| 1a1fe7e637 | |||
|
|
2e27000760 | ||
| 5a41f86112 | |||
|
|
ff9c4d6082 | ||
| 25ad4776c9 | |||
|
|
9339824e22 | ||
| e6eae6c6e0 | |||
| bb6bd0392e | |||
| a66181b7a7 | |||
| da585ee547 | |||
| c737d5009a | |||
|
|
f7d33e69d1 | ||
|
|
7d99d8ec4a | ||
|
|
0727f28f77 | ||
|
|
ac4fb00644 | ||
|
|
4fc4a57036 | ||
| 641f3e8811 | |||
|
|
ebd0a0297c | ||
| 02a72e0f7e | |||
| 478a659ac2 | |||
|
|
16b9b6832d | ||
|
|
48b87a79f6 | ||
|
|
ad79082dcc | ||
|
|
11dff9d3e5 | ||
|
|
3c5f1752e6 | ||
|
|
d6a389e0b7 | ||
| cd36d53a47 | |||
|
|
1242794fc4 | ||
| b45d136894 | |||
|
|
ce82121f04 | ||
| 0e2987e66d | |||
|
|
cdd5a218a7 | ||
|
|
f3491e94e4 | ||
|
|
342511a6ed | ||
| 2d5912dc08 | |||
|
|
40ea41cf3c | ||
| af5bfbac24 | |||
|
|
7e9a573390 | ||
| 7dbc48260c | |||
|
|
4b883a4fc4 | ||
|
|
98071a8ee3 | ||
|
|
f2ad270e8b | ||
| 04c73a1a06 | |||
|
|
4da22b10eb | ||
| c920b257b6 | |||
| 9927bfa13e | |||
|
|
aceba86186 | ||
|
|
b961c53a92 | ||
| 76a7ee7cdb | |||
|
|
77577f3f4d | ||
| 17112b864a | |||
|
|
28bcc7acd7 | ||
|
|
39b9f179f4 | ||
| bd2b3241b2 | |||
| 561faaaafa | |||
| a33d6a145f | |||
| 7e6c912214 | |||
|
|
d6edbc0fa2 | ||
|
|
c7640a30d7 | ||
|
|
60a22d6cd4 | ||
|
|
b1f48d859e | ||
| 03f8d220a4 | |||
|
|
305120f599 | ||
| faa23b3f1b | |||
|
|
5844ec5ad3 | ||
| ff5ff736d8 | |||
|
|
4a59d7e66d | ||
|
|
8dd625bfd1 | ||
| b50977aa76 | |||
|
|
fbcd016e1a | ||
| ce5773ba45 | |||
|
|
7834b89f10 | ||
| e0d6c9f81d | |||
|
|
2e550f8b58 | ||
| c76e2dfed5 | |||
|
|
24fa22e77b | ||
| cd1579058c | |||
| 45b48fa7cd | |||
|
|
3952a5337b | ||
|
|
ccc97ebaa9 | ||
|
|
3a54db8948 | ||
|
|
96e2ad4f1f | ||
| c5a8982122 | |||
|
|
f7289606fc | ||
| 0c5c90201f | |||
|
|
b484f0daff | ||
|
|
1288181e39 | ||
|
|
b625f41621 | ||
| 77d3ba967c | |||
|
|
aeed881d85 | ||
|
|
d0bbdb5dc1 | ||
| 44339c52d7 | |||
|
|
22ffdafacc | ||
|
|
c49765e951 | ||
| 64000b9967 | |||
|
|
733e6b36e9 | ||
|
|
0659cc0aca | ||
|
|
748b9b848e | ||
|
|
6a1ad230ee | ||
| 90bbc78867 | |||
|
|
1ef5dcb2b3 | ||
|
|
d105a3ff5e | ||
| 0424c78f6c | |||
|
|
3fdb7a29d4 | ||
| 31b4d0bf1e | |||
|
|
e2275a23b1 | ||
| 7522bb7e66 | |||
|
|
63fa6841a2 | ||
| ece3c5597b | |||
|
|
63f4e49d88 | ||
|
|
e0a6b307a2 | ||
| 75320eb587 | |||
|
|
afb31b7f4b | ||
| a429a9f4da | |||
|
|
d9763def85 | ||
| ab7f0444b2 | |||
|
|
6b3960a3a4 | ||
| 6cad8e74e1 | |||
|
|
86c94cff62 | ||
| 692cb61991 | |||
|
|
392422992b | ||
| cc637a9738 | |||
|
|
8c27473fed | ||
| bde54c7487 | |||
|
|
a14f944fcc | ||
| 56f7405baa | |||
|
|
e3b1ecc572 | ||
| 8acf72b22c | |||
|
|
c95102a0bd | ||
| 0685d62f9c | |||
|
|
78021d4695 | ||
| 3cdd10783b | |||
|
|
c4e31be27a | ||
| 9d9ade14eb | |||
|
|
9a8936ab34 | ||
| c5831966ed | |||
|
|
f03cc6039b | ||
| 9171e54652 | |||
|
|
d64e072f06 | ||
|
|
b2312fbe01 | ||
|
|
98c4a2413c | ||
| 6fba7c7ae8 | |||
|
|
be695a5d7c | ||
|
|
6471e66d89 | ||
|
|
149039a904 | ||
| 815d675529 | |||
|
|
e8634b93c3 | ||
| f20736fd2a | |||
|
|
7f2f96a819 | ||
| aaa74894dd | |||
|
|
e711d6702a | ||
|
|
d2fc829380 | ||
| de27b1af10 | |||
|
|
7370220497 | ||
| b01dacf328 | |||
|
|
1210c17989 | ||
|
|
9599b188e8 | ||
| c43660a58c | |||
|
|
7fd48c7764 | ||
| a105bb7c1a | |||
|
|
1a34a74232 | ||
| a82a167915 | |||
|
|
7725e7a8de | ||
|
|
f0ae25c533 | ||
| 27f581f17d | |||
|
|
18a098d9a6 | ||
| d2b07326ed | |||
|
|
1c5eadc23b | ||
| 10ff718045 | |||
|
|
0ca3fe9f5d | ||
| 462f8763ab | |||
|
|
57a45a24cb | ||
| a7696568cc | |||
|
|
70701bf73a | ||
| 20dbd94892 | |||
|
|
48a99962e3 | ||
| ee66ecc305 | |||
|
|
065c9daaad | ||
| c76b9d5c15 | |||
|
|
259f9d2e24 | ||
| 8e715c55cd | |||
|
|
0057de4d12 | ||
|
|
71ac59794e | ||
| be04820b00 | |||
| 10b6e34d44 | |||
| 58f1106dbd | |||
| cf5072cced | |||
|
|
702653e52e | ||
|
|
db0d966a6a | ||
|
|
a56adcd342 | ||
|
|
eaf509a895 | ||
|
|
854931bed2 | ||
| 33b5ff5e54 | |||
| 3923d03650 | |||
|
|
c57ccc4bca | ||
|
|
cb2e3fae57 | ||
| 5e4c68c9d8 | |||
|
|
95f540e5df | ||
| 0087a6b20a | |||
|
|
3dfd7c0935 | ||
| 4b2bb25d03 | |||
|
|
881bbb4240 | ||
| 5f7d61748b | |||
|
|
972e71a2f1 | ||
| 614b9939b1 | |||
|
|
6dbc2afbf4 | ||
| 6c96f9ac64 | |||
|
|
ed26915562 | ||
| 628a572c70 | |||
|
|
73e1d0a54e | ||
| b111157dc8 | |||
|
|
8c05448843 | ||
|
|
87556b145e | ||
| 645c761238 | |||
|
|
033d5fcadd | ||
| 128324427f | |||
|
|
61f5aaf4a3 | ||
|
|
4f61d5af8e | ||
|
|
62fd4ff5e1 | ||
| f40f19e735 | |||
|
|
ce952d97b2 | ||
| 53d3637b3e | |||
|
|
ae7195c829 | ||
| ad1f17bb56 | |||
|
|
62b1a1f37a | ||
| 2a80030ceb |
69
.env.example
69
.env.example
@@ -1,23 +1,82 @@
|
||||
# ============================================================
|
||||
# The Ouroboros — Environment Configuration
|
||||
# ============================================================
|
||||
# Copy this file to .env and fill in your values.
|
||||
# Lines starting with # are comments.
|
||||
|
||||
# ============================================================
|
||||
# Korea Investment Securities API
|
||||
# ============================================================
|
||||
KIS_APP_KEY=your_app_key_here
|
||||
KIS_APP_SECRET=your_app_secret_here
|
||||
KIS_ACCOUNT_NO=12345678-01
|
||||
KIS_BASE_URL=https://openapivts.koreainvestment.com:9443
|
||||
|
||||
# Paper trading (VTS): https://openapivts.koreainvestment.com:29443
|
||||
# Live trading: https://openapi.koreainvestment.com:9443
|
||||
KIS_BASE_URL=https://openapivts.koreainvestment.com:29443
|
||||
|
||||
# ============================================================
|
||||
# Trading Mode
|
||||
# ============================================================
|
||||
# paper = 모의투자 (safe for testing), live = 실전투자 (real money)
|
||||
MODE=paper
|
||||
|
||||
# daily = batch per session, realtime = per-stock continuous scan
|
||||
TRADE_MODE=daily
|
||||
|
||||
# Comma-separated market codes: KR, US, JP, HK, CN, VN
|
||||
ENABLED_MARKETS=KR,US
|
||||
|
||||
# Simulated USD cash for paper (VTS) overseas trading.
|
||||
# VTS overseas balance API often returns 0; this value is used as fallback.
|
||||
# Set to 0 to disable fallback (not used in live mode).
|
||||
PAPER_OVERSEAS_CASH=50000.0
|
||||
|
||||
# ============================================================
|
||||
# Google Gemini
|
||||
# ============================================================
|
||||
GEMINI_API_KEY=your_gemini_api_key_here
|
||||
GEMINI_MODEL=gemini-pro
|
||||
# Recommended: gemini-2.0-flash-exp or gemini-1.5-pro
|
||||
GEMINI_MODEL=gemini-2.0-flash-exp
|
||||
|
||||
# ============================================================
|
||||
# Risk Management
|
||||
# ============================================================
|
||||
CIRCUIT_BREAKER_PCT=-3.0
|
||||
FAT_FINGER_PCT=30.0
|
||||
CONFIDENCE_THRESHOLD=80
|
||||
|
||||
# ============================================================
|
||||
# Database
|
||||
# ============================================================
|
||||
DB_PATH=data/trade_logs.db
|
||||
|
||||
# ============================================================
|
||||
# Rate Limiting
|
||||
RATE_LIMIT_RPS=10.0
|
||||
# ============================================================
|
||||
# KIS API real limit is ~2 RPS. Keep at 2.0 for maximum safety.
|
||||
# Increasing this risks EGW00201 "초당 거래건수 초과" errors.
|
||||
RATE_LIMIT_RPS=2.0
|
||||
|
||||
# Trading Mode (paper / live)
|
||||
MODE=paper
|
||||
# ============================================================
|
||||
# External Data APIs (optional)
|
||||
# ============================================================
|
||||
# NEWS_API_KEY=your_news_api_key_here
|
||||
# NEWS_API_PROVIDER=alphavantage
|
||||
# MARKET_DATA_API_KEY=your_market_data_key_here
|
||||
|
||||
# ============================================================
|
||||
# Telegram Notifications (optional)
|
||||
# ============================================================
|
||||
# Get bot token from @BotFather on Telegram
|
||||
# Get chat ID from @userinfobot or your chat
|
||||
# TELEGRAM_BOT_TOKEN=1234567890:ABCdefGHIjklMNOpqrsTUVwxyz
|
||||
# TELEGRAM_CHAT_ID=123456789
|
||||
# TELEGRAM_ENABLED=true
|
||||
|
||||
# ============================================================
|
||||
# Dashboard (optional)
|
||||
# ============================================================
|
||||
# DASHBOARD_ENABLED=false
|
||||
# DASHBOARD_HOST=127.0.0.1
|
||||
# DASHBOARD_PORT=8080
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -174,4 +174,7 @@ cython_debug/
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
# Data files (trade logs, databases)
|
||||
# But NOT src/data/ which contains source code
|
||||
data/
|
||||
!src/data/
|
||||
|
||||
96
CLAUDE.md
96
CLAUDE.md
@@ -15,8 +15,76 @@ pytest -v --cov=src
|
||||
|
||||
# Run (paper trading)
|
||||
python -m src.main --mode=paper
|
||||
|
||||
# Run with dashboard
|
||||
python -m src.main --mode=paper --dashboard
|
||||
```
|
||||
|
||||
## Telegram Notifications (Optional)
|
||||
|
||||
Get real-time alerts for trades, circuit breakers, and system events via Telegram.
|
||||
|
||||
### Quick Setup
|
||||
|
||||
1. **Create bot**: Message [@BotFather](https://t.me/BotFather) on Telegram → `/newbot`
|
||||
2. **Get chat ID**: Message [@userinfobot](https://t.me/userinfobot) → `/start`
|
||||
3. **Configure**: Add to `.env`:
|
||||
```bash
|
||||
TELEGRAM_BOT_TOKEN=1234567890:ABCdefGHIjklMNOpqrsTUVwxyz
|
||||
TELEGRAM_CHAT_ID=123456789
|
||||
TELEGRAM_ENABLED=true
|
||||
```
|
||||
4. **Test**: Start bot conversation (`/start`), then run the agent
|
||||
|
||||
**Full documentation**: [src/notifications/README.md](src/notifications/README.md)
|
||||
|
||||
### What You'll Get
|
||||
|
||||
- 🟢 Trade execution alerts (BUY/SELL with confidence)
|
||||
- 🚨 Circuit breaker trips (automatic trading halt)
|
||||
- ⚠️ Fat-finger rejections (oversized orders blocked)
|
||||
- ℹ️ Market open/close notifications
|
||||
- 📝 System startup/shutdown status
|
||||
|
||||
### Interactive Commands
|
||||
|
||||
With `TELEGRAM_COMMANDS_ENABLED=true` (default), the bot supports 9 bidirectional commands: `/help`, `/status`, `/positions`, `/report`, `/scenarios`, `/review`, `/dashboard`, `/stop`, `/resume`.
|
||||
|
||||
**Fail-safe**: Notifications never crash the trading system. Missing credentials or API errors are logged but trading continues normally.
|
||||
|
||||
## Smart Volatility Scanner (Optional)
|
||||
|
||||
Python-first filtering pipeline that reduces Gemini API calls by pre-filtering stocks using technical indicators.
|
||||
|
||||
### How It Works
|
||||
|
||||
1. **Fetch Rankings** — KIS API volume surge rankings (top 30 stocks)
|
||||
2. **Python Filter** — RSI + volume ratio calculations (no AI)
|
||||
- Volume > 200% of previous day
|
||||
- RSI(14) < 30 (oversold) OR RSI(14) > 70 (momentum)
|
||||
3. **AI Judgment** — Only qualified candidates (1-3 stocks) sent to Gemini
|
||||
|
||||
### Configuration
|
||||
|
||||
Add to `.env` (optional, has sensible defaults):
|
||||
```bash
|
||||
RSI_OVERSOLD_THRESHOLD=30 # 0-50, default 30
|
||||
RSI_MOMENTUM_THRESHOLD=70 # 50-100, default 70
|
||||
VOL_MULTIPLIER=2.0 # Volume threshold (2.0 = 200%)
|
||||
SCANNER_TOP_N=3 # Max candidates per scan
|
||||
```
|
||||
|
||||
### Benefits
|
||||
|
||||
- **Reduces API costs** — Process 1-3 stocks instead of 20-30
|
||||
- **Python-based filtering** — Fast technical analysis before AI
|
||||
- **Evolution-ready** — Selection context logged for strategy optimization
|
||||
- **Fault-tolerant** — Falls back to static watchlist on API failure
|
||||
|
||||
### Realtime Mode Only
|
||||
|
||||
Smart Scanner runs in `TRADE_MODE=realtime` only. Daily mode uses static watchlists for batch efficiency.
|
||||
|
||||
## Documentation
|
||||
|
||||
- **[Workflow Guide](docs/workflow.md)** — Git workflow policy and agent-based development
|
||||
@@ -25,6 +93,8 @@ python -m src.main --mode=paper
|
||||
- **[Context Tree](docs/context-tree.md)** — L1-L7 hierarchical memory system
|
||||
- **[Testing](docs/testing.md)** — Test structure, coverage requirements, writing tests
|
||||
- **[Agent Policies](docs/agents.md)** — Prime directives, constraints, prohibited actions
|
||||
- **[Requirements Log](docs/requirements-log.md)** — User requirements and feedback tracking
|
||||
- **[Live Trading Checklist](docs/live-trading-checklist.md)** — 모의→실전 전환 체크리스트
|
||||
|
||||
## Core Principles
|
||||
|
||||
@@ -33,20 +103,37 @@ python -m src.main --mode=paper
|
||||
3. **Issue-Driven Development** — All work goes through Gitea issues → feature branches → PRs
|
||||
4. **Agent Specialization** — Use dedicated agents for design, coding, testing, docs, review
|
||||
|
||||
## Requirements Management
|
||||
|
||||
User requirements and feedback are tracked in [docs/requirements-log.md](docs/requirements-log.md):
|
||||
|
||||
- New requirements are added chronologically with dates
|
||||
- Code changes should reference related requirements
|
||||
- Helps maintain project evolution aligned with user needs
|
||||
- Preserves context across conversations and development cycles
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
src/
|
||||
├── analysis/ # Technical analysis (RSI, volatility, smart scanner)
|
||||
├── backup/ # Disaster recovery (scheduler, cloud storage, health)
|
||||
├── brain/ # Gemini AI decision engine (prompt optimizer, context selector)
|
||||
├── broker/ # KIS API client (domestic + overseas)
|
||||
├── brain/ # Gemini AI decision engine
|
||||
├── context/ # L1-L7 hierarchical memory system
|
||||
├── core/ # Risk manager (READ-ONLY)
|
||||
├── evolution/ # Self-improvement optimizer
|
||||
├── dashboard/ # FastAPI read-only monitoring (8 API endpoints)
|
||||
├── data/ # External data integration (news, market data, calendar)
|
||||
├── evolution/ # Self-improvement (optimizer, daily review, scorecard)
|
||||
├── logging/ # Decision logger (audit trail)
|
||||
├── markets/ # Market schedules and timezone handling
|
||||
├── notifications/ # Telegram alerts + bidirectional commands (9 commands)
|
||||
├── strategy/ # Pre-market planner, scenario engine, playbook store
|
||||
├── db.py # SQLite trade logging
|
||||
├── main.py # Trading loop orchestrator
|
||||
└── config.py # Settings (from .env)
|
||||
|
||||
tests/ # 54 tests across 4 files
|
||||
tests/ # 551 tests across 25 files
|
||||
docs/ # Extended documentation
|
||||
```
|
||||
|
||||
@@ -58,6 +145,7 @@ ruff check src/ tests/ # Lint
|
||||
mypy src/ --strict # Type check
|
||||
|
||||
python -m src.main --mode=paper # Paper trading
|
||||
python -m src.main --mode=paper --dashboard # With dashboard
|
||||
python -m src.main --mode=live # Live trading (⚠️ real money)
|
||||
|
||||
# Gitea workflow (requires tea CLI)
|
||||
@@ -83,7 +171,7 @@ Markets auto-detected based on timezone and enabled in `ENABLED_MARKETS` env var
|
||||
- `src/core/risk_manager.py` is **READ-ONLY** — changes require human approval
|
||||
- Circuit breaker at -3.0% P&L — may only be made **stricter**
|
||||
- Fat-finger protection: max 30% of cash per order — always enforced
|
||||
- Confidence < 80 → force HOLD — cannot be weakened
|
||||
- Confidence 임계값 (market_outlook별, 낮출 수 없음): BEARISH ≥ 90, NEUTRAL/기본 ≥ 80, BULLISH ≥ 75
|
||||
- All code changes → corresponding tests → coverage ≥ 80%
|
||||
|
||||
## Contributing
|
||||
|
||||
178
README.md
178
README.md
@@ -10,27 +10,41 @@ KIS(한국투자증권) API로 매매하고, Google Gemini로 판단하며, 자
|
||||
│ (매매 실행) │ │ (거래 루프) │ │ (의사결정) │
|
||||
└─────────────┘ └──────┬──────┘ └─────────────┘
|
||||
│
|
||||
┌──────┴──────┐
|
||||
│Risk Manager │
|
||||
│ (안전장치) │
|
||||
└──────┬──────┘
|
||||
┌────────────┼────────────┐
|
||||
│ │ │
|
||||
┌──────┴──────┐ ┌──┴───┐ ┌──────┴──────┐
|
||||
│Risk Manager │ │ DB │ │ Telegram │
|
||||
│ (안전장치) │ │ │ │ (알림+명령) │
|
||||
└──────┬──────┘ └──────┘ └─────────────┘
|
||||
│
|
||||
┌──────┴──────┐
|
||||
│ Evolution │
|
||||
│ (전략 진화) │
|
||||
└─────────────┘
|
||||
┌────────┼────────┐
|
||||
│ │ │
|
||||
┌────┴────┐┌──┴──┐┌────┴─────┐
|
||||
│Strategy ││Ctx ││Evolution │
|
||||
│(플레이북)││(메모리)││ (진화) │
|
||||
└─────────┘└─────┘└──────────┘
|
||||
```
|
||||
|
||||
**v2 핵심**: "Plan Once, Execute Locally" — 장 시작 전 AI가 시나리오 플레이북을 1회 생성하고, 거래 시간에는 로컬 시나리오 매칭만 수행하여 API 비용과 지연 시간을 대폭 절감.
|
||||
|
||||
## 핵심 모듈
|
||||
|
||||
| 모듈 | 파일 | 설명 |
|
||||
| 모듈 | 위치 | 설명 |
|
||||
|------|------|------|
|
||||
| 설정 | `src/config.py` | Pydantic 기반 환경변수 로딩 및 타입 검증 |
|
||||
| 브로커 | `src/broker/kis_api.py` | KIS API 비동기 래퍼 (토큰 갱신, 레이트 리미터, 해시키) |
|
||||
| 두뇌 | `src/brain/gemini_client.py` | Gemini 프롬프트 구성 및 JSON 응답 파싱 |
|
||||
| 방패 | `src/core/risk_manager.py` | 서킷 브레이커 + 팻 핑거 체크 |
|
||||
| 진화 | `src/evolution/optimizer.py` | 실패 패턴 분석 → 새 전략 생성 → 테스트 → PR |
|
||||
| DB | `src/db.py` | SQLite 거래 로그 기록 |
|
||||
| 설정 | `src/config.py` | Pydantic 기반 환경변수 로딩 및 타입 검증 (35+ 변수) |
|
||||
| 브로커 | `src/broker/` | KIS API 비동기 래퍼 (국내 + 해외 9개 시장) |
|
||||
| 두뇌 | `src/brain/` | Gemini 프롬프트 구성, JSON 파싱, 토큰 최적화 |
|
||||
| 방패 | `src/core/risk_manager.py` | 서킷 브레이커 + 팻 핑거 체크 (READ-ONLY) |
|
||||
| 전략 | `src/strategy/` | Pre-Market Planner, Scenario Engine, Playbook Store |
|
||||
| 컨텍스트 | `src/context/` | L1-L7 계층형 메모리 시스템 |
|
||||
| 분석 | `src/analysis/` | RSI, ATR, Smart Volatility Scanner |
|
||||
| 알림 | `src/notifications/` | 텔레그램 양방향 (알림 + 9개 명령어) |
|
||||
| 대시보드 | `src/dashboard/` | FastAPI 읽기 전용 모니터링 (8개 API) |
|
||||
| 진화 | `src/evolution/` | 전략 진화 + Daily Review + Scorecard |
|
||||
| 의사결정 로그 | `src/logging/` | 전체 거래 결정 감사 추적 |
|
||||
| 데이터 | `src/data/` | 뉴스, 시장 데이터, 경제 캘린더 연동 |
|
||||
| 백업 | `src/backup/` | 자동 백업, S3 클라우드, 무결성 검증 |
|
||||
| DB | `src/db.py` | SQLite 거래 로그 (5개 테이블) |
|
||||
|
||||
## 안전장치
|
||||
|
||||
@@ -41,6 +55,7 @@ KIS(한국투자증권) API로 매매하고, Google Gemini로 판단하며, 자
|
||||
| 신뢰도 임계값 | Gemini 신뢰도 80 미만이면 강제 HOLD |
|
||||
| 레이트 리미터 | Leaky Bucket 알고리즘으로 API 호출 제한 |
|
||||
| 토큰 자동 갱신 | 만료 1분 전 자동으로 Access Token 재발급 |
|
||||
| 손절 모니터링 | 플레이북 시나리오 기반 실시간 포지션 보호 |
|
||||
|
||||
## 빠른 시작
|
||||
|
||||
@@ -66,7 +81,11 @@ pytest -v --cov=src --cov-report=term-missing
|
||||
### 4. 실행 (모의투자)
|
||||
|
||||
```bash
|
||||
# 기본 실행
|
||||
python -m src.main --mode=paper
|
||||
|
||||
# 대시보드 활성화
|
||||
python -m src.main --mode=paper --dashboard
|
||||
```
|
||||
|
||||
### 5. Docker 실행
|
||||
@@ -75,23 +94,90 @@ python -m src.main --mode=paper
|
||||
docker compose up -d ouroboros
|
||||
```
|
||||
|
||||
## 지원 시장
|
||||
|
||||
| 국가 | 거래소 | 코드 |
|
||||
|------|--------|------|
|
||||
| 🇰🇷 한국 | KRX | KR |
|
||||
| 🇺🇸 미국 | NASDAQ, NYSE, AMEX | US_NASDAQ, US_NYSE, US_AMEX |
|
||||
| 🇯🇵 일본 | TSE | JP |
|
||||
| 🇭🇰 홍콩 | SEHK | HK |
|
||||
| 🇨🇳 중국 | 상하이, 선전 | CN_SHA, CN_SZA |
|
||||
| 🇻🇳 베트남 | 하노이, 호치민 | VN_HNX, VN_HSX |
|
||||
|
||||
`ENABLED_MARKETS` 환경변수로 활성 시장 선택 (기본: `KR,US`).
|
||||
|
||||
## 텔레그램 (선택사항)
|
||||
|
||||
거래 실행, 서킷 브레이커 발동, 시스템 상태 등을 텔레그램으로 실시간 알림 받을 수 있습니다.
|
||||
|
||||
### 빠른 설정
|
||||
|
||||
1. **봇 생성**: 텔레그램에서 [@BotFather](https://t.me/BotFather) 메시지 → `/newbot` 명령
|
||||
2. **채팅 ID 확인**: [@userinfobot](https://t.me/userinfobot) 메시지 → `/start` 명령
|
||||
3. **환경변수 설정**: `.env` 파일에 추가
|
||||
```bash
|
||||
TELEGRAM_BOT_TOKEN=1234567890:ABCdefGHIjklMNOpqrsTUVwxyz
|
||||
TELEGRAM_CHAT_ID=123456789
|
||||
TELEGRAM_ENABLED=true
|
||||
```
|
||||
4. **테스트**: 봇과 대화 시작 (`/start` 전송) 후 에이전트 실행
|
||||
|
||||
**상세 문서**: [src/notifications/README.md](src/notifications/README.md)
|
||||
|
||||
### 알림 종류
|
||||
|
||||
- 🟢 거래 체결 알림 (BUY/SELL + 신뢰도)
|
||||
- 🚨 서킷 브레이커 발동 (자동 거래 중단)
|
||||
- ⚠️ 팻 핑거 차단 (과도한 주문 차단)
|
||||
- ℹ️ 장 시작/종료 알림
|
||||
- 📝 시스템 시작/종료 상태
|
||||
|
||||
### 양방향 명령어
|
||||
|
||||
`TELEGRAM_COMMANDS_ENABLED=true` (기본값) 설정 시 9개 대화형 명령어 지원:
|
||||
|
||||
| 명령어 | 설명 |
|
||||
|--------|------|
|
||||
| `/help` | 사용 가능한 명령어 목록 |
|
||||
| `/status` | 거래 상태 (모드, 시장, P&L) |
|
||||
| `/positions` | 계좌 요약 (잔고, 현금, P&L) |
|
||||
| `/report` | 일일 요약 (거래 수, P&L, 승률) |
|
||||
| `/scenarios` | 오늘의 플레이북 시나리오 |
|
||||
| `/review` | 최근 스코어카드 (L6_DAILY) |
|
||||
| `/dashboard` | 대시보드 URL 표시 |
|
||||
| `/stop` | 거래 일시 정지 |
|
||||
| `/resume` | 거래 재개 |
|
||||
|
||||
**안전장치**: 알림 실패해도 거래는 계속 진행됩니다.
|
||||
|
||||
## 테스트
|
||||
|
||||
35개 테스트가 TDD 방식으로 구현 전에 먼저 작성되었습니다.
|
||||
551개 테스트가 25개 파일에 걸쳐 구현되어 있습니다. 최소 커버리지 80%.
|
||||
|
||||
```
|
||||
tests/test_risk.py — 서킷 브레이커, 팻 핑거, 통합 검증 (11개)
|
||||
tests/test_broker.py — 토큰 관리, 타임아웃, HTTP 에러, 해시키 (6개)
|
||||
tests/test_brain.py — JSON 파싱, 신뢰도 임계값, 비정상 응답 처리 (15개)
|
||||
tests/test_scenario_engine.py — 시나리오 매칭 (44개)
|
||||
tests/test_data_integration.py — 외부 데이터 연동 (38개)
|
||||
tests/test_pre_market_planner.py — 플레이북 생성 (37개)
|
||||
tests/test_main.py — 거래 루프 통합 (37개)
|
||||
tests/test_token_efficiency.py — 토큰 최적화 (34개)
|
||||
tests/test_strategy_models.py — 전략 모델 검증 (33개)
|
||||
tests/test_telegram_commands.py — 텔레그램 명령어 (31개)
|
||||
tests/test_latency_control.py — 지연시간 제어 (30개)
|
||||
tests/test_telegram.py — 텔레그램 알림 (25개)
|
||||
... 외 16개 파일
|
||||
```
|
||||
|
||||
**상세**: [docs/testing.md](docs/testing.md)
|
||||
|
||||
## 기술 스택
|
||||
|
||||
- **언어**: Python 3.11+ (asyncio 기반)
|
||||
- **브로커**: KIS Open API (REST)
|
||||
- **브로커**: KIS Open API (REST, 국내+해외)
|
||||
- **AI**: Google Gemini Pro
|
||||
- **DB**: SQLite
|
||||
- **검증**: pytest + coverage
|
||||
- **DB**: SQLite (5개 테이블: trades, contexts, decision_logs, playbooks, context_metadata)
|
||||
- **대시보드**: FastAPI + uvicorn
|
||||
- **검증**: pytest + coverage (551 tests)
|
||||
- **CI/CD**: GitHub Actions
|
||||
- **배포**: Docker + Docker Compose
|
||||
|
||||
@@ -99,26 +185,50 @@ tests/test_brain.py — JSON 파싱, 신뢰도 임계값, 비정상 응답 처
|
||||
|
||||
```
|
||||
The-Ouroboros/
|
||||
├── .github/workflows/ci.yml # CI 파이프라인
|
||||
├── docs/
|
||||
│ ├── agents.md # AI 에이전트 페르소나 정의
|
||||
│ └── skills.md # 사용 가능한 도구 목록
|
||||
│ ├── architecture.md # 시스템 아키텍처
|
||||
│ ├── testing.md # 테스트 가이드
|
||||
│ ├── commands.md # 명령어 레퍼런스
|
||||
│ ├── context-tree.md # L1-L7 메모리 시스템
|
||||
│ ├── workflow.md # Git 워크플로우
|
||||
│ ├── agents.md # 에이전트 정책
|
||||
│ ├── skills.md # 도구 목록
|
||||
│ ├── disaster_recovery.md # 백업/복구
|
||||
│ └── requirements-log.md # 요구사항 기록
|
||||
├── src/
|
||||
│ ├── analysis/ # 기술적 분석 (RSI, ATR, Smart Scanner)
|
||||
│ ├── backup/ # 백업 (스케줄러, S3, 무결성 검증)
|
||||
│ ├── brain/ # Gemini 의사결정 (프롬프트 최적화, 컨텍스트 선택)
|
||||
│ ├── broker/ # KIS API (국내 + 해외)
|
||||
│ ├── context/ # L1-L7 계층 메모리
|
||||
│ ├── core/ # 리스크 관리 (READ-ONLY)
|
||||
│ ├── dashboard/ # FastAPI 모니터링 대시보드
|
||||
│ ├── data/ # 외부 데이터 연동
|
||||
│ ├── evolution/ # 전략 진화 + Daily Review
|
||||
│ ├── logging/ # 의사결정 감사 추적
|
||||
│ ├── markets/ # 시장 스케줄 + 타임존
|
||||
│ ├── notifications/ # 텔레그램 알림 + 명령어
|
||||
│ ├── strategy/ # 플레이북 (Planner, Scenario Engine)
|
||||
│ ├── config.py # Pydantic 설정
|
||||
│ ├── logging_config.py # JSON 구조화 로깅
|
||||
│ ├── db.py # SQLite 거래 기록
|
||||
│ ├── main.py # 비동기 거래 루프
|
||||
│ ├── broker/kis_api.py # KIS API 클라이언트
|
||||
│ ├── brain/gemini_client.py # Gemini 의사결정 엔진
|
||||
│ ├── core/risk_manager.py # 리스크 관리
|
||||
│ ├── evolution/optimizer.py # 전략 진화 엔진
|
||||
│ └── strategies/base.py # 전략 베이스 클래스
|
||||
├── tests/ # TDD 테스트 스위트
|
||||
│ ├── db.py # SQLite 데이터베이스
|
||||
│ └── main.py # 비동기 거래 루프
|
||||
├── tests/ # 551개 테스트 (25개 파일)
|
||||
├── Dockerfile # 멀티스테이지 빌드
|
||||
├── docker-compose.yml # 서비스 오케스트레이션
|
||||
└── pyproject.toml # 의존성 및 도구 설정
|
||||
```
|
||||
|
||||
## 문서
|
||||
|
||||
- **[아키텍처](docs/architecture.md)** — 시스템 설계, 컴포넌트, 데이터 흐름
|
||||
- **[테스트](docs/testing.md)** — 테스트 구조, 커버리지, 작성 가이드
|
||||
- **[명령어](docs/commands.md)** — CLI, Dashboard, Telegram 명령어
|
||||
- **[컨텍스트 트리](docs/context-tree.md)** — L1-L7 계층 메모리
|
||||
- **[워크플로우](docs/workflow.md)** — Git 워크플로우 정책
|
||||
- **[에이전트 정책](docs/agents.md)** — 안전 제약, 금지 행위
|
||||
- **[백업/복구](docs/disaster_recovery.md)** — 재해 복구 절차
|
||||
- **[요구사항](docs/requirements-log.md)** — 사용자 요구사항 추적
|
||||
|
||||
## 라이선스
|
||||
|
||||
이 프로젝트의 라이선스는 [LICENSE](LICENSE) 파일을 참조하세요.
|
||||
|
||||
45
docs/agent-constraints.md
Normal file
45
docs/agent-constraints.md
Normal file
@@ -0,0 +1,45 @@
|
||||
# Agent Constraints
|
||||
|
||||
This document records **persistent behavioral constraints** for agents working on this repository.
|
||||
It is distinct from `docs/requirements-log.md`, which records **project/product requirements**.
|
||||
|
||||
## Scope
|
||||
|
||||
- Applies to all AI agents and automation that modify this repo.
|
||||
- Supplements (does not replace) `docs/agents.md` and `docs/workflow.md`.
|
||||
|
||||
## Persistent Rules
|
||||
|
||||
1. **Workflow enforcement**
|
||||
- Follow `docs/workflow.md` for all changes.
|
||||
- Create a Gitea issue before any code or documentation change.
|
||||
- Work on a feature branch `feature/issue-{N}-{short-description}` and open a PR.
|
||||
- Never commit directly to `main`.
|
||||
|
||||
2. **Document-first routing**
|
||||
- When performing work, consult relevant `docs/` files *before* making changes.
|
||||
- Route decisions to the documented policy whenever applicable.
|
||||
- If guidance conflicts, prefer the stricter/safety-first rule and note it in the PR.
|
||||
|
||||
3. **Docs with code**
|
||||
- Any code change must be accompanied by relevant documentation updates.
|
||||
- If no doc update is needed, state the reason explicitly in the PR.
|
||||
|
||||
4. **Session-persistent user constraints**
|
||||
- If the user requests that a behavior should persist across sessions, record it here
|
||||
(or in a dedicated policy doc) and reference it when working.
|
||||
- Keep entries short and concrete, with dates.
|
||||
|
||||
## Change Control
|
||||
|
||||
- Changes to this file follow the same workflow as code changes.
|
||||
- Keep the history chronological and minimize rewording of existing entries.
|
||||
|
||||
## History
|
||||
|
||||
### 2026-02-08
|
||||
|
||||
- Always enforce Gitea workflow: issue -> feature branch -> PR before changes.
|
||||
- When work requires guidance, consult the relevant `docs/` policies first.
|
||||
- Any code change must be accompanied by relevant documentation updates.
|
||||
- Persist user constraints across sessions by recording them in this document.
|
||||
@@ -2,7 +2,44 @@
|
||||
|
||||
## Overview
|
||||
|
||||
Self-evolving AI trading agent for global stock markets via KIS (Korea Investment & Securities) API. The main loop in `src/main.py` orchestrates four components in a 60-second cycle per stock across multiple markets.
|
||||
Self-evolving AI trading agent for global stock markets via KIS (Korea Investment & Securities) API. The main loop in `src/main.py` orchestrates components across multiple markets with two trading modes: daily (batch API calls) or realtime (per-stock decisions).
|
||||
|
||||
**v2 Proactive Playbook Architecture**: The system uses a "plan once, execute locally" approach. Pre-market, the AI generates a playbook of scenarios (one Gemini API call per market per day). During trading hours, a local scenario engine matches live market data against these pre-computed scenarios — no additional AI calls needed. This dramatically reduces API costs and latency.
|
||||
|
||||
## Trading Modes
|
||||
|
||||
The system supports two trading frequency modes controlled by the `TRADE_MODE` environment variable:
|
||||
|
||||
### Daily Mode (default)
|
||||
|
||||
Optimized for Gemini Free tier API limits (20 calls/day):
|
||||
|
||||
- **Batch decisions**: 1 API call per market per session
|
||||
- **Fixed schedule**: 4 sessions per day at 6-hour intervals (configurable)
|
||||
- **API efficiency**: Processes all stocks in a market simultaneously
|
||||
- **Use case**: Free tier users, cost-conscious deployments
|
||||
- **Configuration**:
|
||||
```bash
|
||||
TRADE_MODE=daily
|
||||
DAILY_SESSIONS=4 # Sessions per day (1-10)
|
||||
SESSION_INTERVAL_HOURS=6 # Hours between sessions (1-24)
|
||||
```
|
||||
|
||||
**Example**: With 2 markets (US, KR) and 4 sessions/day = 8 API calls/day (within 20 call limit)
|
||||
|
||||
### Realtime Mode
|
||||
|
||||
High-frequency trading with individual stock analysis:
|
||||
|
||||
- **Per-stock decisions**: 1 API call per stock per cycle
|
||||
- **60-second interval**: Continuous monitoring
|
||||
- **Use case**: Production deployments with Gemini paid tier
|
||||
- **Configuration**:
|
||||
```bash
|
||||
TRADE_MODE=realtime
|
||||
```
|
||||
|
||||
**Note**: Realtime mode requires Gemini API subscription due to high call volume.
|
||||
|
||||
## Core Components
|
||||
|
||||
@@ -11,9 +48,11 @@ Self-evolving AI trading agent for global stock markets via KIS (Korea Investmen
|
||||
**KISBroker** (`kis_api.py`) — Async KIS API client for domestic Korean market
|
||||
|
||||
- Automatic OAuth token refresh (valid for 24 hours)
|
||||
- Leaky-bucket rate limiter (10 requests per second)
|
||||
- Leaky-bucket rate limiter (configurable RPS, default 2.0)
|
||||
- POST body hash-key signing for order authentication
|
||||
- Custom SSL context with disabled hostname verification for VTS (virtual trading) endpoint due to known certificate mismatch
|
||||
- `fetch_market_rankings()` — Fetch volume surge rankings from KIS API
|
||||
- `get_daily_prices()` — Fetch OHLCV history for technical analysis
|
||||
|
||||
**OverseasBroker** (`overseas.py`) — KIS overseas stock API wrapper
|
||||
|
||||
@@ -28,10 +67,47 @@ Self-evolving AI trading agent for global stock markets via KIS (Korea Investmen
|
||||
- `is_market_open()` checks weekends, trading hours, lunch breaks
|
||||
- `get_open_markets()` returns currently active markets
|
||||
- `get_next_market_open()` finds next market to open and when
|
||||
- 10 global markets defined (KR, US_NASDAQ, US_NYSE, US_AMEX, JP, HK, CN_SHA, CN_SZA, VN_HNX, VN_HSX)
|
||||
|
||||
### 2. Brain (`src/brain/gemini_client.py`)
|
||||
**Overseas Ranking API Methods** (added in v0.10.x):
|
||||
- `fetch_overseas_rankings()` — Fetch overseas ranking universe (fluctuation / volume)
|
||||
- Ranking endpoint paths and TR_IDs are configurable via environment variables
|
||||
|
||||
**GeminiClient** — AI decision engine powered by Google Gemini
|
||||
### 2. Analysis (`src/analysis/`)
|
||||
|
||||
**VolatilityAnalyzer** (`volatility.py`) — Technical indicator calculations
|
||||
|
||||
- ATR (Average True Range) for volatility measurement
|
||||
- RSI (Relative Strength Index) using Wilder's smoothing method
|
||||
- Price change percentages across multiple timeframes
|
||||
- Volume surge ratios and price-volume divergence
|
||||
- Momentum scoring (0-100 scale)
|
||||
- Breakout/breakdown pattern detection
|
||||
|
||||
**SmartVolatilityScanner** (`smart_scanner.py`) — Python-first filtering pipeline
|
||||
|
||||
- **Domestic (KR)**:
|
||||
- **Step 1**: Fetch domestic fluctuation ranking as primary universe
|
||||
- **Step 2**: Fetch domestic volume ranking for liquidity bonus
|
||||
- **Step 3**: Compute volatility-first score (max of daily change% and intraday range%)
|
||||
- **Step 4**: Apply liquidity bonus and return top N candidates
|
||||
- **Overseas (US/JP/HK/CN/VN)**:
|
||||
- **Step 1**: Fetch overseas ranking universe (fluctuation rank + volume rank bonus)
|
||||
- **Step 2**: Compute volatility-first score (max of daily change% and intraday range%)
|
||||
- **Step 3**: Apply liquidity bonus from volume ranking
|
||||
- **Step 4**: Return top N candidates (default 3)
|
||||
- **Fallback (overseas only)**: If ranking API is unavailable, uses dynamic universe
|
||||
from runtime active symbols + recent traded symbols + current holdings (no static watchlist)
|
||||
- **Realtime mode only**: Daily mode uses batch processing for API efficiency
|
||||
|
||||
**Benefits:**
|
||||
- Reduces Gemini API calls from 20-30 stocks to 1-3 qualified candidates
|
||||
- Fast Python-based filtering before expensive AI judgment
|
||||
- Logs selection context (RSI-compatible proxy, volume_ratio, signal, score) for Evolution system
|
||||
|
||||
### 3. Brain (`src/brain/`)
|
||||
|
||||
**GeminiClient** (`gemini_client.py`) — AI decision engine powered by Google Gemini
|
||||
|
||||
- Constructs structured prompts from market data
|
||||
- Parses JSON responses into `TradeDecision` objects (`action`, `confidence`, `rationale`)
|
||||
@@ -39,11 +115,20 @@ Self-evolving AI trading agent for global stock markets via KIS (Korea Investmen
|
||||
- Falls back to safe HOLD on any parse/API error
|
||||
- Handles markdown-wrapped JSON, malformed responses, invalid actions
|
||||
|
||||
### 3. Risk Manager (`src/core/risk_manager.py`)
|
||||
**PromptOptimizer** (`prompt_optimizer.py`) — Token efficiency optimization
|
||||
|
||||
- Reduces prompt size while preserving decision quality
|
||||
- Caches optimized prompts
|
||||
|
||||
**ContextSelector** (`context_selector.py`) — Relevant context selection for prompts
|
||||
|
||||
- Selects appropriate context layers for current market conditions
|
||||
|
||||
### 4. Risk Manager (`src/core/risk_manager.py`)
|
||||
|
||||
**RiskManager** — Safety circuit breaker and order validation
|
||||
|
||||
⚠️ **READ-ONLY by policy** (see [`docs/agents.md`](./agents.md))
|
||||
> **READ-ONLY by policy** (see [`docs/agents.md`](./agents.md))
|
||||
|
||||
- **Circuit Breaker**: Halts all trading via `SystemExit` when daily P&L drops below -3.0%
|
||||
- Threshold may only be made stricter, never relaxed
|
||||
@@ -51,9 +136,106 @@ Self-evolving AI trading agent for global stock markets via KIS (Korea Investmen
|
||||
- **Fat-Finger Protection**: Rejects orders exceeding 30% of available cash
|
||||
- Must always be enforced, cannot be disabled
|
||||
|
||||
### 4. Evolution (`src/evolution/optimizer.py`)
|
||||
### 5. Strategy (`src/strategy/`)
|
||||
|
||||
**StrategyOptimizer** — Self-improvement loop
|
||||
**Pre-Market Planner** (`pre_market_planner.py`) — AI playbook generation
|
||||
|
||||
- Runs before market open (configurable `PRE_MARKET_MINUTES`, default 30)
|
||||
- Generates scenario-based playbooks via single Gemini API call per market
|
||||
- Handles timeout (`PLANNER_TIMEOUT_SECONDS`, default 60) with defensive playbook fallback
|
||||
- Persists playbooks to database for audit trail
|
||||
|
||||
**Scenario Engine** (`scenario_engine.py`) — Local scenario matching
|
||||
|
||||
- Matches live market data against pre-computed playbook scenarios
|
||||
- No AI calls during trading hours — pure Python matching logic
|
||||
- Returns matched scenarios with confidence scores
|
||||
- Configurable `MAX_SCENARIOS_PER_STOCK` (default 5)
|
||||
- Periodic rescan at `RESCAN_INTERVAL_SECONDS` (default 300)
|
||||
|
||||
**Playbook Store** (`playbook_store.py`) — Playbook persistence
|
||||
|
||||
- SQLite-backed storage for daily playbooks
|
||||
- Date and market-based retrieval
|
||||
- Status tracking (generated, active, expired)
|
||||
|
||||
**Models** (`models.py`) — Pydantic data models
|
||||
|
||||
- Scenario, Playbook, MatchResult, and related type definitions
|
||||
|
||||
### 6. Context System (`src/context/`)
|
||||
|
||||
**Context Store** (`store.py`) — L1-L7 hierarchical memory
|
||||
|
||||
- 7-layer context system (see [docs/context-tree.md](./context-tree.md)):
|
||||
- L1: Tick-level (real-time price)
|
||||
- L2: Intraday (session summary)
|
||||
- L3: Daily (end-of-day)
|
||||
- L4: Weekly (trend analysis)
|
||||
- L5: Monthly (strategy review)
|
||||
- L6: Daily Review (scorecard)
|
||||
- L7: Evolution (long-term learning)
|
||||
- Key-value storage with timeframe tagging
|
||||
- SQLite persistence in `contexts` table
|
||||
|
||||
**Context Scheduler** (`scheduler.py`) — Periodic aggregation
|
||||
|
||||
- Scheduled summarization from lower to higher layers
|
||||
- Configurable aggregation intervals
|
||||
|
||||
**Context Summarizer** (`summarizer.py`) — Layer summarization
|
||||
|
||||
- Aggregates lower-layer data into higher-layer summaries
|
||||
|
||||
### 7. Dashboard (`src/dashboard/`)
|
||||
|
||||
**FastAPI App** (`app.py`) — Read-only monitoring dashboard
|
||||
|
||||
- Runs as daemon thread when enabled (`--dashboard` CLI flag or `DASHBOARD_ENABLED=true`)
|
||||
- Configurable host/port (`DASHBOARD_HOST`, `DASHBOARD_PORT`, default `127.0.0.1:8080`)
|
||||
- Serves static HTML frontend
|
||||
|
||||
**8 API Endpoints:**
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
|----------|--------|-------------|
|
||||
| `/` | GET | Static HTML dashboard |
|
||||
| `/api/status` | GET | Daily trading status by market |
|
||||
| `/api/playbook/{date}` | GET | Playbook for specific date and market |
|
||||
| `/api/scorecard/{date}` | GET | Daily scorecard from L6_DAILY context |
|
||||
| `/api/performance` | GET | Trading performance metrics (by market + combined) |
|
||||
| `/api/context/{layer}` | GET | Query context by layer (L1-L7) |
|
||||
| `/api/decisions` | GET | Decision log entries with outcomes |
|
||||
| `/api/scenarios/active` | GET | Today's matched scenarios |
|
||||
|
||||
### 8. Notifications (`src/notifications/telegram_client.py`)
|
||||
|
||||
**TelegramClient** — Real-time event notifications via Telegram Bot API
|
||||
|
||||
- Sends alerts for trades, circuit breakers, fat-finger rejections, system events
|
||||
- Non-blocking: failures are logged but never crash trading
|
||||
- Rate-limited: 1 message/second default to respect Telegram API limits
|
||||
- Auto-disabled when credentials missing
|
||||
|
||||
**TelegramCommandHandler** — Bidirectional command interface
|
||||
|
||||
- Long polling from Telegram API (configurable `TELEGRAM_POLLING_INTERVAL`)
|
||||
- 9 interactive commands: `/help`, `/status`, `/positions`, `/report`, `/scenarios`, `/review`, `/dashboard`, `/stop`, `/resume`
|
||||
- Authorization filtering by `TELEGRAM_CHAT_ID`
|
||||
- Enable/disable via `TELEGRAM_COMMANDS_ENABLED` (default: true)
|
||||
|
||||
**Notification Types:**
|
||||
- Trade execution (BUY/SELL with confidence)
|
||||
- Circuit breaker trips (critical alert)
|
||||
- Fat-finger protection triggers (order rejection)
|
||||
- Market open/close events
|
||||
- System startup/shutdown status
|
||||
- Playbook generation results
|
||||
- Stop-loss monitoring alerts
|
||||
|
||||
### 9. Evolution (`src/evolution/`)
|
||||
|
||||
**StrategyOptimizer** (`optimizer.py`) — Self-improvement loop
|
||||
|
||||
- Analyzes high-confidence losing trades from SQLite
|
||||
- Asks Gemini to generate new `BaseStrategy` subclasses
|
||||
@@ -61,11 +243,127 @@ Self-evolving AI trading agent for global stock markets via KIS (Korea Investmen
|
||||
- Simulates PR creation for human review
|
||||
- Only activates strategies that pass all tests
|
||||
|
||||
**DailyReview** (`daily_review.py`) — End-of-day review
|
||||
|
||||
- Generates comprehensive trade performance summary
|
||||
- Stores results in L6_DAILY context layer
|
||||
- Tracks win rate, P&L, confidence accuracy
|
||||
|
||||
**DailyScorecard** (`scorecard.py`) — Performance scoring
|
||||
|
||||
- Calculates daily metrics (trades, P&L, win rate, avg confidence)
|
||||
- Enables trend tracking across days
|
||||
|
||||
**Stop-Loss Monitoring** — Real-time position protection
|
||||
|
||||
- Monitors positions against stop-loss levels from playbook scenarios
|
||||
- Sends Telegram alerts when thresholds approached or breached
|
||||
|
||||
### 10. Decision Logger (`src/logging/decision_logger.py`)
|
||||
|
||||
**DecisionLogger** — Comprehensive audit trail
|
||||
|
||||
- Logs every trading decision with full context snapshot
|
||||
- Captures input data, rationale, confidence, and outcomes
|
||||
- Supports outcome tracking (P&L, accuracy) for post-analysis
|
||||
- Stored in `decision_logs` table with indexed queries
|
||||
- Review workflow support (reviewed flag, review notes)
|
||||
|
||||
### 11. Data Integration (`src/data/`)
|
||||
|
||||
**External Data Sources** (optional):
|
||||
|
||||
- `news_api.py` — News sentiment data
|
||||
- `market_data.py` — Extended market data
|
||||
- `economic_calendar.py` — Economic event calendar
|
||||
|
||||
### 12. Backup (`src/backup/`)
|
||||
|
||||
**Disaster Recovery** (see [docs/disaster_recovery.md](./disaster_recovery.md)):
|
||||
|
||||
- `scheduler.py` — Automated backup scheduling
|
||||
- `exporter.py` — Data export to various formats
|
||||
- `cloud_storage.py` — S3-compatible cloud backup
|
||||
- `health_monitor.py` — Backup integrity verification
|
||||
|
||||
## Data Flow
|
||||
|
||||
### Playbook Mode (Daily — Primary v2 Flow)
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Main Loop (60s cycle per stock, per market) │
|
||||
│ Pre-Market Phase (before market open) │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────┐
|
||||
│ Pre-Market Planner │
|
||||
│ - 1 Gemini API call per market │
|
||||
│ - Generate scenario playbook │
|
||||
│ - Store in playbooks table │
|
||||
└──────────────────┬───────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Trading Hours (market open → close) │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────┐
|
||||
│ Market Schedule Check │
|
||||
│ - Get open markets │
|
||||
│ - Filter by enabled markets │
|
||||
└──────────────────┬───────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────┐
|
||||
│ Scenario Engine (local) │
|
||||
│ - Match live data vs playbook │
|
||||
│ - No AI calls needed │
|
||||
│ - Return matched scenarios │
|
||||
└──────────────────┬───────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────┐
|
||||
│ Risk Manager: Validate Order │
|
||||
│ - Check circuit breaker │
|
||||
│ - Check fat-finger limit │
|
||||
└──────────────────┬───────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────┐
|
||||
│ Broker: Execute Order │
|
||||
│ - Domestic: send_order() │
|
||||
│ - Overseas: send_overseas_order()│
|
||||
└──────────────────┬───────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────┐
|
||||
│ Decision Logger + DB │
|
||||
│ - Full audit trail │
|
||||
│ - Context snapshot │
|
||||
│ - Telegram notification │
|
||||
└──────────────────┬───────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Post-Market Phase │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────┐
|
||||
│ Daily Review + Scorecard │
|
||||
│ - Performance summary │
|
||||
│ - Store in L6_DAILY context │
|
||||
│ - Evolution learning │
|
||||
└──────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Realtime Mode (with Smart Scanner)
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Main Loop (60s cycle per market) │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
@@ -74,58 +372,69 @@ Self-evolving AI trading agent for global stock markets via KIS (Korea Investmen
|
||||
│ - Get open markets │
|
||||
│ - Filter by enabled markets │
|
||||
│ - Wait if all closed │
|
||||
└──────────────────┬────────────────┘
|
||||
└──────────────────┬───────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────┐
|
||||
│ Smart Scanner (Python-first) │
|
||||
│ - Domestic: fluctuation rank │
|
||||
│ + volume rank bonus │
|
||||
│ + volatility-first scoring │
|
||||
│ - Overseas: ranking universe │
|
||||
│ + volatility-first scoring │
|
||||
│ - Fallback: dynamic universe │
|
||||
│ - Return top 3 qualified stocks │
|
||||
└──────────────────┬───────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────┐
|
||||
│ For Each Qualified Candidate │
|
||||
└──────────────────┬───────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────┐
|
||||
│ Broker: Fetch Market Data │
|
||||
│ - Domestic: orderbook + balance │
|
||||
│ - Overseas: price + balance │
|
||||
└──────────────────┬────────────────┘
|
||||
└──────────────────┬───────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────┐
|
||||
│ Calculate P&L │
|
||||
│ pnl_pct = (eval - cost) / cost │
|
||||
└──────────────────┬────────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────┐
|
||||
│ Brain: Get Decision │
|
||||
│ Brain: Get Decision (AI) │
|
||||
│ - Build prompt with market data │
|
||||
│ - Call Gemini API │
|
||||
│ - Parse JSON response │
|
||||
│ - Return TradeDecision │
|
||||
└──────────────────┬────────────────┘
|
||||
└──────────────────┬───────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────┐
|
||||
│ Risk Manager: Validate Order │
|
||||
│ - Check circuit breaker │
|
||||
│ - Check fat-finger limit │
|
||||
│ - Raise if validation fails │
|
||||
└──────────────────┬────────────────┘
|
||||
└──────────────────┬───────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────┐
|
||||
│ Broker: Execute Order │
|
||||
│ - Domestic: send_order() │
|
||||
│ - Overseas: send_overseas_order() │
|
||||
└──────────────────┬────────────────┘
|
||||
│ - Overseas: send_overseas_order()│
|
||||
└──────────────────┬───────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────┐
|
||||
│ Database: Log Trade │
|
||||
│ - SQLite (data/trades.db) │
|
||||
│ - Track: action, confidence, │
|
||||
│ rationale, market, exchange │
|
||||
└───────────────────────────────────┘
|
||||
│ Decision Logger + Notifications │
|
||||
│ - Log trade to SQLite │
|
||||
│ - selection_context (JSON) │
|
||||
│ - Telegram notification │
|
||||
└──────────────────────────────────┘
|
||||
```
|
||||
|
||||
## Database Schema
|
||||
|
||||
**SQLite** (`src/db.py`)
|
||||
**SQLite** (`src/db.py`) — Database: `data/trades.db`
|
||||
|
||||
### trades
|
||||
```sql
|
||||
CREATE TABLE trades (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
@@ -137,12 +446,73 @@ CREATE TABLE trades (
|
||||
quantity INTEGER,
|
||||
price REAL,
|
||||
pnl REAL DEFAULT 0.0,
|
||||
market TEXT DEFAULT 'KR', -- KR | US_NASDAQ | JP | etc.
|
||||
exchange_code TEXT DEFAULT 'KRX' -- KRX | NASD | NYSE | etc.
|
||||
market TEXT DEFAULT 'KR',
|
||||
exchange_code TEXT DEFAULT 'KRX',
|
||||
selection_context TEXT, -- JSON: {rsi, volume_ratio, signal, score}
|
||||
decision_id TEXT -- Links to decision_logs
|
||||
);
|
||||
```
|
||||
|
||||
Auto-migration: Adds `market` and `exchange_code` columns if missing for backward compatibility.
|
||||
### contexts
|
||||
```sql
|
||||
CREATE TABLE contexts (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
layer TEXT NOT NULL, -- L1 through L7
|
||||
timeframe TEXT,
|
||||
key TEXT NOT NULL,
|
||||
value TEXT NOT NULL, -- JSON data
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
);
|
||||
-- Indices: idx_contexts_layer, idx_contexts_timeframe, idx_contexts_updated
|
||||
```
|
||||
|
||||
### decision_logs
|
||||
```sql
|
||||
CREATE TABLE decision_logs (
|
||||
decision_id TEXT PRIMARY KEY,
|
||||
timestamp TEXT NOT NULL,
|
||||
stock_code TEXT,
|
||||
market TEXT,
|
||||
exchange_code TEXT,
|
||||
action TEXT,
|
||||
confidence INTEGER,
|
||||
rationale TEXT,
|
||||
context_snapshot TEXT, -- JSON: full context at decision time
|
||||
input_data TEXT, -- JSON: market data used
|
||||
outcome_pnl REAL,
|
||||
outcome_accuracy REAL,
|
||||
reviewed INTEGER DEFAULT 0,
|
||||
review_notes TEXT
|
||||
);
|
||||
-- Indices: idx_decision_logs_timestamp, idx_decision_logs_reviewed, idx_decision_logs_confidence
|
||||
```
|
||||
|
||||
### playbooks
|
||||
```sql
|
||||
CREATE TABLE playbooks (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
date TEXT NOT NULL,
|
||||
market TEXT NOT NULL,
|
||||
status TEXT DEFAULT 'generated',
|
||||
playbook_json TEXT NOT NULL, -- Full playbook with scenarios
|
||||
generated_at TEXT NOT NULL,
|
||||
token_count INTEGER,
|
||||
scenario_count INTEGER,
|
||||
match_count INTEGER DEFAULT 0
|
||||
);
|
||||
-- Indices: idx_playbooks_date, idx_playbooks_market
|
||||
```
|
||||
|
||||
### context_metadata
|
||||
```sql
|
||||
CREATE TABLE context_metadata (
|
||||
layer TEXT PRIMARY KEY,
|
||||
description TEXT,
|
||||
retention_days INTEGER,
|
||||
aggregation_source TEXT
|
||||
);
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
@@ -157,13 +527,81 @@ KIS_APP_SECRET=your_app_secret
|
||||
KIS_ACCOUNT_NO=XXXXXXXX-XX
|
||||
GEMINI_API_KEY=your_gemini_key
|
||||
|
||||
# Optional
|
||||
# Optional — Trading Mode
|
||||
MODE=paper # paper | live
|
||||
TRADE_MODE=daily # daily | realtime
|
||||
DAILY_SESSIONS=4 # Sessions per day (daily mode only)
|
||||
SESSION_INTERVAL_HOURS=6 # Hours between sessions (daily mode only)
|
||||
|
||||
# Optional — Database
|
||||
DB_PATH=data/trades.db
|
||||
|
||||
# Optional — Risk
|
||||
CONFIDENCE_THRESHOLD=80
|
||||
MAX_LOSS_PCT=3.0
|
||||
MAX_ORDER_PCT=30.0
|
||||
ENABLED_MARKETS=KR,US_NASDAQ # Comma-separated market codes
|
||||
|
||||
# Optional — Markets
|
||||
ENABLED_MARKETS=KR,US # Comma-separated market codes
|
||||
RATE_LIMIT_RPS=2.0 # KIS API requests per second
|
||||
|
||||
# Optional — Pre-Market Planner (v2)
|
||||
PRE_MARKET_MINUTES=30 # Minutes before market open to generate playbook
|
||||
MAX_SCENARIOS_PER_STOCK=5 # Max scenarios per stock in playbook
|
||||
PLANNER_TIMEOUT_SECONDS=60 # Timeout for playbook generation
|
||||
DEFENSIVE_PLAYBOOK_ON_FAILURE=true # Fallback on AI failure
|
||||
RESCAN_INTERVAL_SECONDS=300 # Scenario rescan interval during trading
|
||||
|
||||
# Optional — Smart Scanner (realtime mode only)
|
||||
RSI_OVERSOLD_THRESHOLD=30 # 0-50, oversold threshold
|
||||
RSI_MOMENTUM_THRESHOLD=70 # 50-100, momentum threshold
|
||||
VOL_MULTIPLIER=2.0 # Minimum volume ratio (2.0 = 200%)
|
||||
SCANNER_TOP_N=3 # Max qualified candidates per scan
|
||||
|
||||
# Optional — Dashboard
|
||||
DASHBOARD_ENABLED=false # Enable FastAPI dashboard
|
||||
DASHBOARD_HOST=127.0.0.1 # Dashboard bind address
|
||||
DASHBOARD_PORT=8080 # Dashboard port (1-65535)
|
||||
|
||||
# Optional — Telegram
|
||||
TELEGRAM_BOT_TOKEN=1234567890:ABCdefGHIjklMNOpqrsTUVwxyz
|
||||
TELEGRAM_CHAT_ID=123456789
|
||||
TELEGRAM_ENABLED=true
|
||||
TELEGRAM_COMMANDS_ENABLED=true # Enable bidirectional commands
|
||||
TELEGRAM_POLLING_INTERVAL=1.0 # Command polling interval (seconds)
|
||||
|
||||
# Optional — Backup
|
||||
BACKUP_ENABLED=false
|
||||
BACKUP_DIR=data/backups
|
||||
S3_ENDPOINT_URL=...
|
||||
S3_ACCESS_KEY=...
|
||||
S3_SECRET_KEY=...
|
||||
S3_BUCKET_NAME=...
|
||||
S3_REGION=...
|
||||
|
||||
# Optional — External Data
|
||||
NEWS_API_KEY=...
|
||||
NEWS_API_PROVIDER=...
|
||||
MARKET_DATA_API_KEY=...
|
||||
|
||||
# Position Sizing (optional)
|
||||
POSITION_SIZING_ENABLED=true
|
||||
POSITION_BASE_ALLOCATION_PCT=5.0
|
||||
POSITION_MIN_ALLOCATION_PCT=1.0
|
||||
POSITION_MAX_ALLOCATION_PCT=10.0
|
||||
POSITION_VOLATILITY_TARGET_SCORE=50.0
|
||||
|
||||
# Legacy/compat scanner thresholds (kept for backward compatibility)
|
||||
RSI_OVERSOLD_THRESHOLD=30
|
||||
RSI_MOMENTUM_THRESHOLD=70
|
||||
VOL_MULTIPLIER=2.0
|
||||
|
||||
# Overseas Ranking API (optional override; account-dependent)
|
||||
OVERSEAS_RANKING_ENABLED=true
|
||||
OVERSEAS_RANKING_FLUCT_TR_ID=HHDFS76200100
|
||||
OVERSEAS_RANKING_VOLUME_TR_ID=HHDFS76200200
|
||||
OVERSEAS_RANKING_FLUCT_PATH=/uapi/overseas-price/v1/quotations/inquire-updown-rank
|
||||
OVERSEAS_RANKING_VOLUME_PATH=/uapi/overseas-price/v1/quotations/inquire-volume-rank
|
||||
```
|
||||
|
||||
Tests use in-memory SQLite (`DB_PATH=":memory:"`) and dummy credentials via `tests/conftest.py`.
|
||||
@@ -189,3 +627,17 @@ Tests use in-memory SQLite (`DB_PATH=":memory:"`) and dummy credentials via `tes
|
||||
- Wait until next market opens
|
||||
- Use `get_next_market_open()` to calculate wait time
|
||||
- Sleep until market open time
|
||||
|
||||
### Telegram API Errors
|
||||
- Log warning but continue trading
|
||||
- Missing credentials → auto-disable notifications
|
||||
- Network timeout → skip notification, no retry
|
||||
- Invalid token → log error, trading unaffected
|
||||
- Rate limit exceeded → queued via rate limiter
|
||||
|
||||
### Playbook Generation Failure
|
||||
- Timeout → fall back to defensive playbook (`DEFENSIVE_PLAYBOOK_ON_FAILURE`)
|
||||
- API error → use previous day's playbook if available
|
||||
- No playbook → skip pre-market phase, fall back to direct AI calls
|
||||
|
||||
**Guarantee**: Notification and dashboard failures never interrupt trading operations.
|
||||
|
||||
@@ -119,7 +119,7 @@ No decorator needed for async tests.
|
||||
# Install all dependencies (production + dev)
|
||||
pip install -e ".[dev]"
|
||||
|
||||
# Run full test suite with coverage
|
||||
# Run full test suite with coverage (551 tests across 25 files)
|
||||
pytest -v --cov=src --cov-report=term-missing
|
||||
|
||||
# Run a single test file
|
||||
@@ -137,11 +137,82 @@ mypy src/ --strict
|
||||
# Run the trading agent
|
||||
python -m src.main --mode=paper
|
||||
|
||||
# Run with dashboard enabled
|
||||
python -m src.main --mode=paper --dashboard
|
||||
|
||||
# Docker
|
||||
docker compose up -d ouroboros # Run agent
|
||||
docker compose --profile test up test # Run tests in container
|
||||
```
|
||||
|
||||
## Dashboard
|
||||
|
||||
The FastAPI dashboard provides read-only monitoring of the trading system.
|
||||
|
||||
### Starting the Dashboard
|
||||
|
||||
```bash
|
||||
# Via CLI flag
|
||||
python -m src.main --mode=paper --dashboard
|
||||
|
||||
# Via environment variable
|
||||
DASHBOARD_ENABLED=true python -m src.main --mode=paper
|
||||
```
|
||||
|
||||
Dashboard runs as a daemon thread on `DASHBOARD_HOST:DASHBOARD_PORT` (default: `127.0.0.1:8080`).
|
||||
|
||||
### API Endpoints
|
||||
|
||||
| Endpoint | Description |
|
||||
|----------|-------------|
|
||||
| `GET /` | HTML dashboard UI |
|
||||
| `GET /api/status` | Daily trading status by market |
|
||||
| `GET /api/playbook/{date}` | Playbook for specific date (query: `market`) |
|
||||
| `GET /api/scorecard/{date}` | Daily scorecard from L6_DAILY context |
|
||||
| `GET /api/performance` | Performance metrics by market and combined |
|
||||
| `GET /api/context/{layer}` | Context data by layer L1-L7 (query: `timeframe`) |
|
||||
| `GET /api/decisions` | Decision log entries (query: `limit`, `market`) |
|
||||
| `GET /api/scenarios/active` | Today's matched scenarios |
|
||||
|
||||
## Telegram Commands
|
||||
|
||||
When `TELEGRAM_COMMANDS_ENABLED=true` (default), the bot accepts these interactive commands:
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/help` | List available commands |
|
||||
| `/status` | Show trading status (mode, markets, P&L) |
|
||||
| `/positions` | Display account summary (balance, cash, P&L) |
|
||||
| `/report` | Daily summary metrics (trades, P&L, win rate) |
|
||||
| `/scenarios` | Show today's playbook scenarios |
|
||||
| `/review` | Display recent scorecards (L6_DAILY layer) |
|
||||
| `/dashboard` | Show dashboard URL if enabled |
|
||||
| `/stop` | Pause trading |
|
||||
| `/resume` | Resume trading |
|
||||
|
||||
Commands are only processed from the authorized `TELEGRAM_CHAT_ID`.
|
||||
|
||||
## KIS API TR_ID 참조 문서
|
||||
|
||||
**TR_ID를 추가하거나 수정할 때 반드시 공식 문서를 먼저 확인할 것.**
|
||||
|
||||
공식 문서: `docs/한국투자증권_오픈API_전체문서_20260221_030000.xlsx`
|
||||
|
||||
> ⚠️ 커뮤니티 블로그, GitHub 예제 등 비공식 자료의 TR_ID는 오래되거나 틀릴 수 있음.
|
||||
> 실제로 `VTTT1006U`(미국 매도 — 잘못됨)가 오랫동안 코드에 남아있던 사례가 있음 (Issue #189).
|
||||
|
||||
### 주요 TR_ID 목록
|
||||
|
||||
| 구분 | 모의투자 TR_ID | 실전투자 TR_ID | 시트명 |
|
||||
|------|---------------|---------------|--------|
|
||||
| 해외주식 매수 (미국) | `VTTT1002U` | `TTTT1002U` | 해외주식 주문 |
|
||||
| 해외주식 매도 (미국) | `VTTT1001U` | `TTTT1006U` | 해외주식 주문 |
|
||||
|
||||
새로운 TR_ID가 필요할 때:
|
||||
1. 위 xlsx 파일에서 해당 거래 유형의 시트를 찾는다.
|
||||
2. 모의투자(`VTTT`) / 실전투자(`TTTT`) 컬럼을 구분하여 정확한 값을 사용한다.
|
||||
3. 코드에 출처 주석을 남긴다: `# Source: 한국투자증권_오픈API_전체문서 — '<시트명>' 시트`
|
||||
|
||||
## Environment Setup
|
||||
|
||||
```bash
|
||||
|
||||
348
docs/disaster_recovery.md
Normal file
348
docs/disaster_recovery.md
Normal file
@@ -0,0 +1,348 @@
|
||||
# Disaster Recovery Guide
|
||||
|
||||
Complete guide for backing up and restoring The Ouroboros trading system.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Backup Strategy](#backup-strategy)
|
||||
- [Creating Backups](#creating-backups)
|
||||
- [Restoring from Backup](#restoring-from-backup)
|
||||
- [Health Monitoring](#health-monitoring)
|
||||
- [Export Formats](#export-formats)
|
||||
- [RTO/RPO](#rtorpo)
|
||||
- [Testing Recovery](#testing-recovery)
|
||||
|
||||
## Backup Strategy
|
||||
|
||||
The system implements a 3-tier backup retention policy:
|
||||
|
||||
| Policy | Frequency | Retention | Purpose |
|
||||
|--------|-----------|-----------|---------|
|
||||
| **Daily** | Every day | 30 days | Quick recovery from recent issues |
|
||||
| **Weekly** | Sunday | 1 year | Medium-term historical analysis |
|
||||
| **Monthly** | 1st of month | Forever | Long-term archival |
|
||||
|
||||
### Storage Structure
|
||||
|
||||
```
|
||||
data/backups/
|
||||
├── daily/ # Last 30 days
|
||||
├── weekly/ # Last 52 weeks
|
||||
└── monthly/ # Forever (cold storage)
|
||||
```
|
||||
|
||||
## Creating Backups
|
||||
|
||||
### Automated Backups (Recommended)
|
||||
|
||||
Set up a cron job to run daily:
|
||||
|
||||
```bash
|
||||
# Edit crontab
|
||||
crontab -e
|
||||
|
||||
# Run backup at 2 AM every day
|
||||
0 2 * * * cd /path/to/The-Ouroboros && ./scripts/backup.sh >> logs/backup.log 2>&1
|
||||
```
|
||||
|
||||
### Manual Backups
|
||||
|
||||
```bash
|
||||
# Run backup script
|
||||
./scripts/backup.sh
|
||||
|
||||
# Or use Python directly
|
||||
python3 -c "
|
||||
from pathlib import Path
|
||||
from src.backup.scheduler import BackupScheduler, BackupPolicy
|
||||
|
||||
scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))
|
||||
metadata = scheduler.create_backup(BackupPolicy.DAILY, verify=True)
|
||||
print(f'Backup created: {metadata.file_path}')
|
||||
"
|
||||
```
|
||||
|
||||
### Export to Other Formats
|
||||
|
||||
```bash
|
||||
python3 -c "
|
||||
from pathlib import Path
|
||||
from src.backup.exporter import BackupExporter, ExportFormat
|
||||
|
||||
exporter = BackupExporter('data/trade_logs.db')
|
||||
results = exporter.export_all(
|
||||
Path('exports'),
|
||||
formats=[ExportFormat.JSON, ExportFormat.CSV],
|
||||
compress=True
|
||||
)
|
||||
"
|
||||
```
|
||||
|
||||
## Restoring from Backup
|
||||
|
||||
### Interactive Restoration
|
||||
|
||||
```bash
|
||||
./scripts/restore.sh
|
||||
```
|
||||
|
||||
The script will:
|
||||
1. List available backups
|
||||
2. Ask you to select one
|
||||
3. Create a safety backup of current database
|
||||
4. Restore the selected backup
|
||||
5. Verify database integrity
|
||||
|
||||
### Manual Restoration
|
||||
|
||||
```python
|
||||
from pathlib import Path
|
||||
from src.backup.scheduler import BackupScheduler
|
||||
|
||||
scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))
|
||||
|
||||
# List backups
|
||||
backups = scheduler.list_backups()
|
||||
for backup in backups:
|
||||
print(f"{backup.timestamp}: {backup.file_path}")
|
||||
|
||||
# Restore specific backup
|
||||
scheduler.restore_backup(backups[0], verify=True)
|
||||
```
|
||||
|
||||
## Health Monitoring
|
||||
|
||||
### Check System Health
|
||||
|
||||
```python
|
||||
from pathlib import Path
|
||||
from src.backup.health_monitor import HealthMonitor
|
||||
|
||||
monitor = HealthMonitor('data/trade_logs.db', Path('data/backups'))
|
||||
|
||||
# Run all checks
|
||||
report = monitor.get_health_report()
|
||||
print(f"Overall status: {report['overall_status']}")
|
||||
|
||||
# Individual checks
|
||||
checks = monitor.run_all_checks()
|
||||
for name, result in checks.items():
|
||||
print(f"{name}: {result.status.value} - {result.message}")
|
||||
```
|
||||
|
||||
### Health Checks
|
||||
|
||||
The system monitors:
|
||||
|
||||
- **Database Health**: Accessibility, integrity, size
|
||||
- **Disk Space**: Available storage (alerts if < 10 GB)
|
||||
- **Backup Recency**: Ensures backups are < 25 hours old
|
||||
|
||||
### Health Status Levels
|
||||
|
||||
- **HEALTHY**: All systems operational
|
||||
- **DEGRADED**: Warning condition (e.g., low disk space)
|
||||
- **UNHEALTHY**: Critical issue (e.g., database corrupted, no backups)
|
||||
|
||||
## Export Formats
|
||||
|
||||
### JSON (Human-Readable)
|
||||
|
||||
```json
|
||||
{
|
||||
"export_timestamp": "2024-01-15T10:30:00Z",
|
||||
"record_count": 150,
|
||||
"trades": [
|
||||
{
|
||||
"timestamp": "2024-01-15T09:00:00Z",
|
||||
"stock_code": "005930",
|
||||
"action": "BUY",
|
||||
"quantity": 10,
|
||||
"price": 70000.0,
|
||||
"confidence": 85,
|
||||
"rationale": "Strong momentum",
|
||||
"pnl": 0.0
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### CSV (Analysis Tools)
|
||||
|
||||
Compatible with Excel, pandas, R:
|
||||
|
||||
```csv
|
||||
timestamp,stock_code,action,quantity,price,confidence,rationale,pnl
|
||||
2024-01-15T09:00:00Z,005930,BUY,10,70000.0,85,Strong momentum,0.0
|
||||
```
|
||||
|
||||
### Parquet (Big Data)
|
||||
|
||||
Columnar format for Spark, DuckDB:
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
df = pd.read_parquet('exports/trades_20240115.parquet')
|
||||
```
|
||||
|
||||
## RTO/RPO
|
||||
|
||||
### Recovery Time Objective (RTO)
|
||||
|
||||
**Target: < 5 minutes**
|
||||
|
||||
Time to restore trading operations:
|
||||
1. Identify backup to restore (1 min)
|
||||
2. Run restore script (2 min)
|
||||
3. Verify database integrity (1 min)
|
||||
4. Restart trading system (1 min)
|
||||
|
||||
### Recovery Point Objective (RPO)
|
||||
|
||||
**Target: < 24 hours**
|
||||
|
||||
Maximum acceptable data loss:
|
||||
- Daily backups ensure ≤ 24-hour data loss
|
||||
- For critical periods, run backups more frequently
|
||||
|
||||
## Testing Recovery
|
||||
|
||||
### Quarterly Recovery Test
|
||||
|
||||
Perform full disaster recovery test every quarter:
|
||||
|
||||
1. **Create test backup**
|
||||
```bash
|
||||
./scripts/backup.sh
|
||||
```
|
||||
|
||||
2. **Simulate disaster** (use test database)
|
||||
```bash
|
||||
cp data/trade_logs.db data/trade_logs_test.db
|
||||
rm data/trade_logs_test.db # Simulate data loss
|
||||
```
|
||||
|
||||
3. **Restore from backup**
|
||||
```bash
|
||||
DB_PATH=data/trade_logs_test.db ./scripts/restore.sh
|
||||
```
|
||||
|
||||
4. **Verify data integrity**
|
||||
```python
|
||||
import sqlite3
|
||||
conn = sqlite3.connect('data/trade_logs_test.db')
|
||||
cursor = conn.execute('SELECT COUNT(*) FROM trades')
|
||||
print(f"Restored {cursor.fetchone()[0]} trades")
|
||||
```
|
||||
|
||||
5. **Document results** in `logs/recovery_test_YYYYMMDD.md`
|
||||
|
||||
### Backup Verification
|
||||
|
||||
Always verify backups after creation:
|
||||
|
||||
```python
|
||||
from pathlib import Path
|
||||
from src.backup.scheduler import BackupScheduler
|
||||
|
||||
scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))
|
||||
|
||||
# Create and verify
|
||||
metadata = scheduler.create_backup(BackupPolicy.DAILY, verify=True)
|
||||
print(f"Checksum: {metadata.checksum}") # Should not be None
|
||||
```
|
||||
|
||||
## Emergency Procedures
|
||||
|
||||
### Database Corrupted
|
||||
|
||||
1. Stop trading system immediately
|
||||
2. Check most recent backup age: `ls -lht data/backups/daily/`
|
||||
3. Restore: `./scripts/restore.sh`
|
||||
4. Verify: Run health check
|
||||
5. Resume trading
|
||||
|
||||
### Disk Full
|
||||
|
||||
1. Check disk space: `df -h`
|
||||
2. Clean old backups: Run cleanup manually
|
||||
```python
|
||||
from pathlib import Path
|
||||
from src.backup.scheduler import BackupScheduler
|
||||
scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))
|
||||
scheduler.cleanup_old_backups()
|
||||
```
|
||||
3. Consider archiving old monthly backups to external storage
|
||||
4. Increase disk space if needed
|
||||
|
||||
### Lost All Backups
|
||||
|
||||
If local backups are lost:
|
||||
1. Check if exports exist in `exports/` directory
|
||||
2. Reconstruct database from CSV/JSON exports
|
||||
3. If no exports: Check broker API for trade history
|
||||
4. Manual reconstruction as last resort
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Test Restores Regularly**: Don't wait for disaster
|
||||
2. **Monitor Disk Space**: Set up alerts at 80% usage
|
||||
3. **Keep Multiple Generations**: Never delete all backups at once
|
||||
4. **Verify Checksums**: Always verify backup integrity
|
||||
5. **Document Changes**: Update this guide when backup strategy changes
|
||||
6. **Off-Site Storage**: Consider external backup for monthly archives
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Backup Script Fails
|
||||
|
||||
```bash
|
||||
# Check database file permissions
|
||||
ls -l data/trade_logs.db
|
||||
|
||||
# Check disk space
|
||||
df -h data/
|
||||
|
||||
# Run backup manually with debug
|
||||
python3 -c "
|
||||
import logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
from pathlib import Path
|
||||
from src.backup.scheduler import BackupScheduler, BackupPolicy
|
||||
scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))
|
||||
scheduler.create_backup(BackupPolicy.DAILY, verify=True)
|
||||
"
|
||||
```
|
||||
|
||||
### Restore Fails Verification
|
||||
|
||||
```bash
|
||||
# Check backup file integrity
|
||||
python3 -c "
|
||||
import sqlite3
|
||||
conn = sqlite3.connect('data/backups/daily/trade_logs_daily_20240115.db')
|
||||
cursor = conn.execute('PRAGMA integrity_check')
|
||||
print(cursor.fetchone()[0])
|
||||
"
|
||||
```
|
||||
|
||||
### Health Check Fails
|
||||
|
||||
```python
|
||||
from pathlib import Path
|
||||
from src.backup.health_monitor import HealthMonitor
|
||||
|
||||
monitor = HealthMonitor('data/trade_logs.db', Path('data/backups'))
|
||||
|
||||
# Check each component individually
|
||||
print("Database:", monitor.check_database_health())
|
||||
print("Disk Space:", monitor.check_disk_space())
|
||||
print("Backup Recency:", monitor.check_backup_recency())
|
||||
```
|
||||
|
||||
## Contact
|
||||
|
||||
For backup/recovery issues:
|
||||
- Check logs: `logs/backup.log`
|
||||
- Review health status: Run health monitor
|
||||
- Raise issue on GitHub if automated recovery fails
|
||||
131
docs/live-trading-checklist.md
Normal file
131
docs/live-trading-checklist.md
Normal file
@@ -0,0 +1,131 @@
|
||||
# 실전 전환 체크리스트
|
||||
|
||||
모의 거래(paper)에서 실전(live)으로 전환하기 전에 아래 항목을 **순서대로** 모두 확인하세요.
|
||||
|
||||
---
|
||||
|
||||
## 1. 사전 조건
|
||||
|
||||
### 1-1. KIS OpenAPI 실전 계좌 준비
|
||||
- [ ] 한국투자증권 계좌 개설 완료 (일반 위탁 계좌)
|
||||
- [ ] OpenAPI 실전 사용 신청 (KIS 홈페이지 → Open API → 서비스 신청)
|
||||
- [ ] 실전용 APP_KEY / APP_SECRET 발급 완료
|
||||
- [ ] KIS_ACCOUNT_NO 형식 확인: `XXXXXXXX-XX` (8자리-2자리)
|
||||
|
||||
### 1-2. 리스크 파라미터 검토
|
||||
- [ ] `CIRCUIT_BREAKER_PCT` 확인: 기본값 -3.0% (더 엄격하게 조정 권장)
|
||||
- [ ] `FAT_FINGER_PCT` 확인: 기본값 30.0% (1회 주문 최대 잔고 대비 %)
|
||||
- [ ] `CONFIDENCE_THRESHOLD` 확인: BEARISH ≥ 90, NEUTRAL ≥ 80, BULLISH ≥ 75
|
||||
- [ ] 초기 투자금 결정 및 해외 주식 운용 한도 설정
|
||||
|
||||
### 1-3. 시스템 요건
|
||||
- [ ] 커버리지 80% 이상 유지 확인: `pytest --cov=src`
|
||||
- [ ] 타입 체크 통과: `mypy src/ --strict`
|
||||
- [ ] Lint 통과: `ruff check src/ tests/`
|
||||
|
||||
---
|
||||
|
||||
## 2. 환경 설정
|
||||
|
||||
### 2-1. `.env` 파일 수정
|
||||
|
||||
```bash
|
||||
# 1. KIS 실전 URL로 변경 (모의: openapivts 포트 29443)
|
||||
KIS_BASE_URL=https://openapi.koreainvestment.com:9443
|
||||
|
||||
# 2. 실전 APP_KEY / APP_SECRET으로 교체
|
||||
KIS_APP_KEY=<실전_APP_KEY>
|
||||
KIS_APP_SECRET=<실전_APP_SECRET>
|
||||
KIS_ACCOUNT_NO=<실전_계좌번호>
|
||||
|
||||
# 3. 모드를 live로 변경
|
||||
MODE=live
|
||||
|
||||
# 4. PAPER_OVERSEAS_CASH 비활성화 (live 모드에선 무시되지만 명시적으로 0 설정)
|
||||
PAPER_OVERSEAS_CASH=0
|
||||
```
|
||||
|
||||
> ⚠️ `KIS_BASE_URL` 포트 주의:
|
||||
> - **모의(VTS)**: `https://openapivts.koreainvestment.com:29443`
|
||||
> - **실전**: `https://openapi.koreainvestment.com:9443`
|
||||
|
||||
### 2-2. TR_ID 자동 분기 확인
|
||||
|
||||
아래 TR_ID는 `MODE` 값에 따라 코드에서 **자동으로 선택**됩니다.
|
||||
별도 설정 불필요하나, 문제 발생 시 아래 표를 참조하세요.
|
||||
|
||||
| 구분 | 모의 TR_ID | 실전 TR_ID |
|
||||
|------|-----------|-----------|
|
||||
| 국내 잔고 조회 | `VTTC8434R` | `TTTC8434R` |
|
||||
| 국내 현금 매수 | `VTTC0012U` | `TTTC0012U` |
|
||||
| 국내 현금 매도 | `VTTC0011U` | `TTTC0011U` |
|
||||
| 해외 잔고 조회 | `VTTS3012R` | `TTTS3012R` |
|
||||
| 해외 매수 | `VTTT1002U` | `TTTT1002U` |
|
||||
| 해외 매도 | `VTTT1001U` | `TTTT1006U` |
|
||||
|
||||
> **출처**: `docs/한국투자증권_오픈API_전체문서_20260221_030000.xlsx` (공식 문서 기준)
|
||||
|
||||
---
|
||||
|
||||
## 3. 최종 확인
|
||||
|
||||
### 3-1. 실전 시작 전 점검
|
||||
- [ ] DB 백업 완료: `data/trade_logs.db` → `data/backups/`
|
||||
- [ ] Telegram 알림 설정 확인 (실전에서는 알림이 더욱 중요)
|
||||
- [ ] 소액으로 첫 거래 진행 후 TR_ID/계좌 정상 동작 확인
|
||||
|
||||
### 3-2. 실행 명령
|
||||
|
||||
```bash
|
||||
# 실전 모드로 실행
|
||||
python -m src.main --mode=live
|
||||
|
||||
# 대시보드 함께 실행 (별도 터미널에서 모니터링)
|
||||
python -m src.main --mode=live --dashboard
|
||||
```
|
||||
|
||||
### 3-3. 실전 시작 직후 확인 사항
|
||||
- [ ] 로그에 `MODE=live` 출력 확인
|
||||
- [ ] 첫 잔고 조회 성공 (ConnectionError 없음)
|
||||
- [ ] Telegram 알림 수신 확인 ("System started")
|
||||
- [ ] 첫 주문 후 KIS 앱에서 체결 내역 확인
|
||||
|
||||
---
|
||||
|
||||
## 4. 비상 정지 방법
|
||||
|
||||
### 즉각 정지
|
||||
```bash
|
||||
# 터미널에서 Ctrl+C (정상 종료 트리거)
|
||||
# 또는 Telegram 봇 명령:
|
||||
/stop
|
||||
```
|
||||
|
||||
### Circuit Breaker 발동 시
|
||||
- CB가 발동되면 자동으로 거래 중단 및 Telegram 알림 전송
|
||||
- CB 임계값: `CIRCUIT_BREAKER_PCT` (기본 -3.0%)
|
||||
- **임계값은 엄격하게만 조정 가능** (더 낮은 음수 값으로만 변경)
|
||||
|
||||
---
|
||||
|
||||
## 5. 롤백 절차
|
||||
|
||||
실전 전환 후 문제 발생 시:
|
||||
|
||||
```bash
|
||||
# 1. 즉시 .env에서 MODE=paper로 복원
|
||||
# 2. 재시작
|
||||
python -m src.main --mode=paper
|
||||
|
||||
# 3. DB에서 최근 거래 확인
|
||||
sqlite3 data/trade_logs.db "SELECT * FROM trades ORDER BY id DESC LIMIT 20;"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 관련 문서
|
||||
|
||||
- [시스템 아키텍처](architecture.md)
|
||||
- [워크플로우 가이드](workflow.md)
|
||||
- [재해 복구](disaster_recovery.md)
|
||||
- [Agent 제약 조건](agents.md)
|
||||
56
docs/ouroboros/00_validation_system.md
Normal file
56
docs/ouroboros/00_validation_system.md
Normal file
@@ -0,0 +1,56 @@
|
||||
<!--
|
||||
Doc-ID: DOC-VAL-001
|
||||
Version: 1.0.0
|
||||
Status: active
|
||||
Owner: strategy
|
||||
Updated: 2026-02-26
|
||||
-->
|
||||
|
||||
# 문서 검증 시스템
|
||||
|
||||
본 문서는 문서 간 허위 내용, 수치 충돌, 구현 불가능 지시를 사전에 제거하기 위한 검증 규칙이다.
|
||||
|
||||
## 검증 목표
|
||||
|
||||
- 단일 진실원장 기준으로 모든 지시서의 수치/규칙 정합성 보장
|
||||
- 설계 문장과 코드 작업 지시 간 추적성 보장
|
||||
- 테스트 미정의 상태에서 구현 착수 금지
|
||||
|
||||
## 불일치 유형 정의
|
||||
|
||||
- `RULE-DOC-001`: 정의되지 않은 요구사항 ID 사용
|
||||
- `RULE-DOC-002`: 동일 요구사항 ID에 상충되는 값(예: 슬리피지 수치) 기술
|
||||
- `RULE-DOC-003`: 시간대 미표기 또는 KST/UTC 혼용 지시
|
||||
- `RULE-DOC-004`: 주문 정책과 리스크 정책 충돌(예: 저유동 세션 시장가 허용)
|
||||
- `RULE-DOC-005`: 구현 태스크에 테스트 ID 미연결
|
||||
- `RULE-DOC-006`: 문서 라우팅 링크 깨짐
|
||||
|
||||
## 검증 파이프라인
|
||||
|
||||
1. 정적 검사 (자동)
|
||||
- 대상: `docs/ouroboros/*.md`
|
||||
- 검사: 메타데이터, 링크 유효성, ID 정의/참조 일치, REQ-추적성 매핑
|
||||
- 도구: `scripts/validate_ouroboros_docs.py`
|
||||
|
||||
2. 추적성 검사 (자동 + 수동)
|
||||
- 자동: `REQ-*`가 최소 1개 `TASK-*`와 1개 `TEST-*`에 연결되었는지 확인
|
||||
- 수동: 정책 충돌 후보를 PR 체크리스트로 검토
|
||||
|
||||
3. 도메인 무결성 검사 (수동)
|
||||
- KIS 점검시간 회피, 주문 유형 강제, Kill Switch 순서, 환율 정책이 동시에 존재하는지 점검
|
||||
- 백테스트 체결가가 보수 가정인지 점검
|
||||
|
||||
## 변경 통제 규칙
|
||||
|
||||
- `REQ-*` 추가/수정 시 반드시 `01_requirements_registry.md` 먼저 변경
|
||||
- `TASK-*` 수정 시 반드시 `40_acceptance_and_test_plan.md`의 대응 테스트를 동시 수정
|
||||
- 충돌 발생 시 우선순위: `requirements_registry > phase execution > code work order`
|
||||
|
||||
적용 룰셋:
|
||||
- `RULE-DOC-001` `RULE-DOC-002` `RULE-DOC-003` `RULE-DOC-004` `RULE-DOC-005` `RULE-DOC-006`
|
||||
|
||||
## PR 게이트
|
||||
|
||||
- `python3 scripts/validate_ouroboros_docs.py` 성공
|
||||
- 신규/변경 `REQ-*`가 테스트 기준(`TEST-*`)과 연결됨
|
||||
- 원본 계획(v2/v3)과 모순 없음
|
||||
39
docs/ouroboros/01_requirements_registry.md
Normal file
39
docs/ouroboros/01_requirements_registry.md
Normal file
@@ -0,0 +1,39 @@
|
||||
<!--
|
||||
Doc-ID: DOC-REQ-001
|
||||
Version: 1.0.0
|
||||
Status: active
|
||||
Owner: strategy
|
||||
Updated: 2026-02-26
|
||||
-->
|
||||
|
||||
# 요구사항 원장 (Single Source of Truth)
|
||||
|
||||
이 문서의 ID가 계획/구현/테스트 전 문서에서 참조되는 유일한 요구사항 집합이다.
|
||||
|
||||
## v2 핵심 요구사항
|
||||
|
||||
- `REQ-V2-001`: 상태는 `HOLDING`, `BE_LOCK`, `ARMED`, `EXITED` 4단계여야 한다.
|
||||
- `REQ-V2-002`: 상태 전이는 매 틱/바 평가 시 최상위 상태로 즉시 승격되어야 한다.
|
||||
- `REQ-V2-003`: `EXITED` 조건은 모든 상태보다 우선 평가되어야 한다.
|
||||
- `REQ-V2-004`: 청산 로직은 Hard Stop, BE Lock, ATR Trailing, 모델 확률 보조 트리거를 포함해야 한다.
|
||||
- `REQ-V2-005`: 라벨링은 Triple Barrier(Upper/Lower/Time) 방식이어야 한다.
|
||||
- `REQ-V2-006`: 검증은 Walk-forward + Purge/Embargo를 강제한다.
|
||||
- `REQ-V2-007`: 백테스트는 비용/슬리피지/체결실패를 반영하지 않으면 채택 불가다.
|
||||
- `REQ-V2-008`: Kill Switch는 신규주문차단 -> 미체결취소 -> 재조회 -> 리스크축소 -> 스냅샷 순서다.
|
||||
|
||||
## v3 핵심 요구사항
|
||||
|
||||
- `REQ-V3-001`: 모든 신호/주문/로그는 `session_id`를 포함해야 한다.
|
||||
- `REQ-V3-002`: 세션 전환 시 리스크 파라미터 재로딩이 수행되어야 한다.
|
||||
- `REQ-V3-003`: 브로커 블랙아웃 시간대에는 신규 주문이 금지되어야 한다.
|
||||
- `REQ-V3-004`: 블랙아웃 중 신호는 Queue에 적재되고, 복구 후 유효성 재검증을 거친다.
|
||||
- `REQ-V3-005`: 저유동 세션(`NXT_AFTER`, `US_PRE`, `US_DAY`, `US_AFTER`)은 시장가 주문 금지다.
|
||||
- `REQ-V3-006`: 백테스트 체결가는 불리한 방향 체결 가정을 기본으로 한다.
|
||||
- `REQ-V3-007`: US 운용은 환율 손익 분리 추적과 통화 버퍼 정책을 포함해야 한다.
|
||||
- `REQ-V3-008`: 마감/오버나잇 규칙은 Kill Switch와 충돌 없이 연동되어야 한다.
|
||||
|
||||
## 공통 운영 요구사항
|
||||
|
||||
- `REQ-OPS-001`: 타임존은 모든 시간 필드에 명시(KST/UTC)되어야 한다.
|
||||
- `REQ-OPS-002`: 문서의 수치 정책은 원장에서만 변경한다.
|
||||
- `REQ-OPS-003`: 구현 태스크는 반드시 테스트 태스크를 동반한다.
|
||||
63
docs/ouroboros/10_phase_v2_execution.md
Normal file
63
docs/ouroboros/10_phase_v2_execution.md
Normal file
@@ -0,0 +1,63 @@
|
||||
<!--
|
||||
Doc-ID: DOC-PHASE-V2-001
|
||||
Version: 1.0.0
|
||||
Status: active
|
||||
Owner: strategy
|
||||
Updated: 2026-02-26
|
||||
-->
|
||||
|
||||
# v2 실행 지시서 (설계 -> 코드)
|
||||
|
||||
참조 요구사항: `REQ-V2-001` `REQ-V2-002` `REQ-V2-003` `REQ-V2-004` `REQ-V2-005` `REQ-V2-006` `REQ-V2-007` `REQ-V2-008` `REQ-OPS-001` `REQ-OPS-002` `REQ-OPS-003`
|
||||
|
||||
## 단계 1: 도메인 모델 확정
|
||||
|
||||
- `TASK-V2-001`: 상태머신 enum/전이 이벤트/전이 사유 스키마 설계
|
||||
- `TASK-V2-002`: `position_state` 스냅샷 구조(현재상태, peak, stops, last_reason) 정의
|
||||
- `TASK-V2-003`: 청산 판단 입력 DTO(가격, ATR, pred_prob, liquidity_signal) 정의
|
||||
|
||||
완료 기준:
|
||||
- 상태와 전이 사유가 로그/DB에서 재현 가능
|
||||
- `REQ-V2-001`~`003`을 코드 타입 수준에서 강제
|
||||
|
||||
## 단계 2: 청산 엔진 구현
|
||||
|
||||
- `TASK-V2-004`: 우선순위 기반 전이 함수 구현(`evaluate_exit_first` -> `promote_state`)
|
||||
- `TASK-V2-005`: Hard Stop/BE Lock/ATR Trailing 결합 로직 구현
|
||||
- `TASK-V2-006`: 모델 확률 신호를 보조 트리거로 결합(단독 청산 금지)
|
||||
|
||||
완료 기준:
|
||||
- 갭 상황에서 다중 조건 동시 충족 시 최상위 상태로 단번 전이
|
||||
- `REQ-V2-004` 준수
|
||||
|
||||
## 단계 3: 라벨링/학습 데이터 파이프라인
|
||||
|
||||
- `TASK-V2-007`: Triple Barrier 라벨러 구현(장벽 선터치 우선)
|
||||
- `TASK-V2-008`: 피처 구간/라벨 구간 분리 검증 유틸 구현
|
||||
- `TASK-V2-009`: 라벨 생성 로그(진입시각, 터치장벽, 만기장벽) 기록
|
||||
|
||||
완료 기준:
|
||||
- look-ahead 차단 증빙 로그 확보
|
||||
- `REQ-V2-005` 충족
|
||||
|
||||
## 단계 4: 검증 프레임워크
|
||||
|
||||
- `TASK-V2-010`: Walk-forward split + Purge/Embargo 분할기 구현
|
||||
- `TASK-V2-011`: 베이스라인(`B0`,`B1`,`M1`) 비교 리포트 포맷 구현
|
||||
- `TASK-V2-012`: 체결 비용/슬리피지/실패 반영 백테스트 옵션 강제
|
||||
|
||||
완료 기준:
|
||||
- `REQ-V2-006`, `REQ-V2-007` 충족
|
||||
|
||||
## 단계 5: Kill Switch 통합
|
||||
|
||||
- `TASK-V2-013`: Kill Switch 순차 실행 오케스트레이터 구현 (`src/core/risk_manager.py` 수정 금지)
|
||||
- `TASK-V2-014`: 주문 차단 플래그/미체결 취소/재조회 재시도 로직 구현
|
||||
- `TASK-V2-015`: 스냅샷/알림/복구 진입 절차 구현
|
||||
|
||||
완료 기준:
|
||||
- `REQ-V2-008` 순서 일치
|
||||
|
||||
라우팅:
|
||||
- 코드 지시 상세: [30_code_level_work_orders.md](./30_code_level_work_orders.md)
|
||||
- 테스트 상세: [40_acceptance_and_test_plan.md](./40_acceptance_and_test_plan.md)
|
||||
60
docs/ouroboros/20_phase_v3_execution.md
Normal file
60
docs/ouroboros/20_phase_v3_execution.md
Normal file
@@ -0,0 +1,60 @@
|
||||
<!--
|
||||
Doc-ID: DOC-PHASE-V3-001
|
||||
Version: 1.0.0
|
||||
Status: active
|
||||
Owner: strategy
|
||||
Updated: 2026-02-26
|
||||
-->
|
||||
|
||||
# v3 실행 지시서 (세션 확장)
|
||||
|
||||
참조 요구사항: `REQ-V3-001` `REQ-V3-002` `REQ-V3-003` `REQ-V3-004` `REQ-V3-005` `REQ-V3-006` `REQ-V3-007` `REQ-V3-008` `REQ-OPS-001` `REQ-OPS-002` `REQ-OPS-003`
|
||||
|
||||
## 단계 1: 세션 엔진
|
||||
|
||||
- `TASK-V3-001`: `session_id` 분류기 구현(KR/US 확장 세션)
|
||||
- `TASK-V3-002`: 세션 전환 훅에서 리스크 파라미터 재로딩 구현
|
||||
- `TASK-V3-003`: 로그/DB 스키마에 `session_id` 필드 강제
|
||||
|
||||
완료 기준:
|
||||
- `REQ-V3-001`, `REQ-V3-002` 충족
|
||||
|
||||
## 단계 2: 블랙아웃/복구 제어
|
||||
|
||||
- `TASK-V3-004`: 블랙아웃 윈도우 정책 로더 구현(설정 기반)
|
||||
- `TASK-V3-005`: 블랙아웃 중 신규 주문 차단 + 의도 큐 적재 구현
|
||||
- `TASK-V3-006`: 복구 시 동기화(잔고/미체결/체결) 후 큐 재검증 실행
|
||||
|
||||
완료 기준:
|
||||
- `REQ-V3-003`, `REQ-V3-004` 충족
|
||||
|
||||
## 단계 3: 주문 정책 강화
|
||||
|
||||
- `TASK-V3-007`: 세션별 주문 타입 매트릭스 구현
|
||||
- `TASK-V3-008`: 저유동 세션 시장가 주문 하드 차단
|
||||
- `TASK-V3-009`: 재호가 간격/횟수 제한 및 주문 철회 조건 구현
|
||||
|
||||
완료 기준:
|
||||
- `REQ-V3-005` 충족
|
||||
|
||||
## 단계 4: 비용/체결 모델 정교화
|
||||
|
||||
- `TASK-V3-010`: 세션별 슬리피지/비용 테이블 엔진 반영
|
||||
- `TASK-V3-011`: 불리한 체결 가정(상대 호가 방향) 체결기 구현
|
||||
- `TASK-V3-012`: 시나리오별 체결 실패/부분체결 모델 반영
|
||||
|
||||
완료 기준:
|
||||
- `REQ-V3-006` 충족
|
||||
|
||||
## 단계 5: 환율/오버나잇/Kill Switch 연동
|
||||
|
||||
- `TASK-V3-013`: 전략 PnL과 FX PnL 분리 회계 구현
|
||||
- `TASK-V3-014`: USD/KRW 버퍼 규칙 위반 시 신규 진입 제한 구현
|
||||
- `TASK-V3-015`: 오버나잇 예외와 Kill Switch 우선순위 통합
|
||||
|
||||
완료 기준:
|
||||
- `REQ-V3-007`, `REQ-V3-008` 충족
|
||||
|
||||
라우팅:
|
||||
- 코드 지시 상세: [30_code_level_work_orders.md](./30_code_level_work_orders.md)
|
||||
- 테스트 상세: [40_acceptance_and_test_plan.md](./40_acceptance_and_test_plan.md)
|
||||
59
docs/ouroboros/30_code_level_work_orders.md
Normal file
59
docs/ouroboros/30_code_level_work_orders.md
Normal file
@@ -0,0 +1,59 @@
|
||||
<!--
|
||||
Doc-ID: DOC-CODE-001
|
||||
Version: 1.0.0
|
||||
Status: active
|
||||
Owner: strategy
|
||||
Updated: 2026-02-26
|
||||
-->
|
||||
|
||||
# 코드 레벨 작업 지시서
|
||||
|
||||
본 문서는 파일 단위 구현 지시서다. 모든 작업은 요구사항 ID와 테스트 ID를 포함해야 한다.
|
||||
|
||||
제약:
|
||||
- `src/core/risk_manager.py`는 READ-ONLY로 간주하고 수정하지 않는다.
|
||||
- Kill Switch는 별도 모듈(예: `src/core/kill_switch.py`)로 추가하고 상위 실행 루프에서 연동한다.
|
||||
|
||||
## 구현 단위 A: 상태기계/청산
|
||||
|
||||
- `TASK-CODE-001` (`REQ-V2-001`,`REQ-V2-002`,`REQ-V2-003`): `src/strategy/`에 상태기계 모듈 추가
|
||||
- `TASK-CODE-002` (`REQ-V2-004`): ATR/BE/Hard Stop 결합 청산 함수 추가
|
||||
- `TASK-CODE-003` (`REQ-V2-008`): Kill Switch 오케스트레이터를 `src/core/kill_switch.py`에 추가
|
||||
- `TEST-CODE-001`: 갭 점프 시 최고상태 승격 테스트
|
||||
- `TEST-CODE-002`: EXIT 우선순위 테스트
|
||||
|
||||
## 구현 단위 B: 라벨링/검증
|
||||
|
||||
- `TASK-CODE-004` (`REQ-V2-005`): Triple Barrier 라벨러 모듈 추가(`src/analysis/` 또는 `src/strategy/`)
|
||||
- `TASK-CODE-005` (`REQ-V2-006`): Walk-forward + Purge/Embargo 분할 유틸 추가
|
||||
- `TASK-CODE-006` (`REQ-V2-007`): 백테스트 실행기에서 비용/슬리피지 옵션 필수화
|
||||
- `TEST-CODE-003`: 라벨 선터치 우선 테스트
|
||||
- `TEST-CODE-004`: 누수 차단 테스트
|
||||
|
||||
## 구현 단위 C: 세션/주문 정책
|
||||
|
||||
- `TASK-CODE-007` (`REQ-V3-001`,`REQ-V3-002`): 세션 분류/전환 훅을 `src/markets/schedule.py` 연동
|
||||
- `TASK-CODE-008` (`REQ-V3-003`,`REQ-V3-004`): 블랙아웃 큐 처리기를 `src/broker/`에 추가
|
||||
- `TASK-CODE-009` (`REQ-V3-005`): 세션별 주문 타입 검증기 추가
|
||||
- `TEST-CODE-005`: 블랙아웃 신규주문 차단 테스트
|
||||
- `TEST-CODE-006`: 저유동 세션 시장가 거부 테스트
|
||||
|
||||
## 구현 단위 D: 체결/환율/오버나잇
|
||||
|
||||
- `TASK-CODE-010` (`REQ-V3-006`): 불리한 체결가 모델을 백테스트 체결기로 구현
|
||||
- `TASK-CODE-011` (`REQ-V3-007`): FX PnL 분리 회계 테이블/컬럼 추가
|
||||
- `TASK-CODE-012` (`REQ-V3-008`): 오버나잇 예외와 Kill Switch 충돌 해소 로직 구현
|
||||
- `TEST-CODE-007`: 불리한 체결가 모델 테스트
|
||||
- `TEST-CODE-008`: FX 버퍼 위반 시 신규진입 제한 테스트
|
||||
|
||||
## 구현 단위 E: 운영/문서 거버넌스
|
||||
|
||||
- `TASK-OPS-001` (`REQ-OPS-001`): 시간 필드/로그 스키마의 타임존 표기 강제 규칙 구현
|
||||
- `TASK-OPS-002` (`REQ-OPS-002`): 정책 수치 변경 시 `01_requirements_registry.md` 선수정 CI 체크 추가
|
||||
- `TASK-OPS-003` (`REQ-OPS-003`): `TASK-*` 없는 `REQ-*` 또는 `TEST-*` 없는 `REQ-*`를 차단하는 문서 검증 게이트 유지
|
||||
|
||||
## 커밋 규칙
|
||||
|
||||
- 커밋 메시지에 `TASK-*` 포함
|
||||
- PR 본문에 `REQ-*`, `TEST-*` 매핑 표 포함
|
||||
- 변경 파일마다 최소 1개 테스트 연결
|
||||
57
docs/ouroboros/40_acceptance_and_test_plan.md
Normal file
57
docs/ouroboros/40_acceptance_and_test_plan.md
Normal file
@@ -0,0 +1,57 @@
|
||||
<!--
|
||||
Doc-ID: DOC-TEST-001
|
||||
Version: 1.0.0
|
||||
Status: active
|
||||
Owner: strategy
|
||||
Updated: 2026-02-26
|
||||
-->
|
||||
|
||||
# 수용 기준 및 테스트 계획
|
||||
|
||||
## 수용 기준
|
||||
|
||||
- `TEST-ACC-000` (`REQ-V2-001`): 상태 enum은 4개(`HOLDING`,`BE_LOCK`,`ARMED`,`EXITED`)만 허용한다.
|
||||
- `TEST-ACC-001` (`REQ-V2-002`): 상태 전이는 순차 if-else가 아닌 우선순위 승격으로 동작한다.
|
||||
- `TEST-ACC-010` (`REQ-V2-003`): `EXITED` 조건은 어떤 상태보다 먼저 평가된다.
|
||||
- `TEST-ACC-011` (`REQ-V2-004`): 청산 판단은 Hard Stop/BE Lock/ATR/모델보조 4요소를 모두 포함한다.
|
||||
- `TEST-ACC-012` (`REQ-V2-005`): Triple Barrier 라벨은 first-touch 규칙으로 결정된다.
|
||||
- `TEST-ACC-013` (`REQ-V2-006`): 학습/검증 분할은 Walk-forward + Purge/Embargo를 적용한다.
|
||||
- `TEST-ACC-014` (`REQ-V2-007`): 비용/슬리피지/체결실패 옵션 비활성 시 백테스트 실행을 거부한다.
|
||||
- `TEST-ACC-002` (`REQ-V2-008`): Kill Switch 실행 순서가 고정 순서를 위반하지 않는다.
|
||||
- `TEST-ACC-015` (`REQ-V3-001`): 모든 주문/로그 레코드에 `session_id`가 저장된다.
|
||||
- `TEST-ACC-016` (`REQ-V3-002`): 세션 전환 이벤트 시 리스크 파라미터가 재로딩된다.
|
||||
- `TEST-ACC-003` (`REQ-V3-003`): 블랙아웃 중 신규 주문 API 호출이 발생하지 않는다.
|
||||
- `TEST-ACC-017` (`REQ-V3-004`): 블랙아웃 큐는 복구 후 재검증을 통과한 주문만 실행한다.
|
||||
- `TEST-ACC-004` (`REQ-V3-005`): 저유동 세션 시장가 주문은 항상 거부된다.
|
||||
- `TEST-ACC-005` (`REQ-V3-006`): 백테스트 체결가가 단순 종가 체결보다 보수적 손익을 낸다.
|
||||
- `TEST-ACC-006` (`REQ-V3-007`): 전략 손익과 환율 손익이 별도 집계된다.
|
||||
- `TEST-ACC-018` (`REQ-V3-008`): 오버나잇 예외 상태에서도 Kill Switch 우선순위가 유지된다.
|
||||
- `TEST-ACC-007` (`REQ-OPS-001`): 시간 관련 필드는 타임존(KST/UTC)이 누락되면 검증 실패한다.
|
||||
- `TEST-ACC-008` (`REQ-OPS-002`): 정책 수치 변경이 원장 미반영이면 검증 실패한다.
|
||||
- `TEST-ACC-009` (`REQ-OPS-003`): `REQ-*`가 `TASK-*`/`TEST-*` 매핑 없이 존재하면 검증 실패한다.
|
||||
|
||||
## 테스트 계층
|
||||
|
||||
1. 단위 테스트
|
||||
- 상태 전이, 주문타입 검증, 큐 복구 로직, 체결가 모델
|
||||
|
||||
2. 통합 테스트
|
||||
- 세션 전환 -> 주문 정책 -> 리스크 엔진 연동
|
||||
- 블랙아웃 시작/해제 이벤트 연동
|
||||
|
||||
3. 회귀 테스트
|
||||
- 기존 `tests/` 스위트 전량 실행
|
||||
- 신규 기능 플래그 ON/OFF 비교
|
||||
|
||||
## 실행 명령
|
||||
|
||||
```bash
|
||||
pytest -q
|
||||
python3 scripts/validate_ouroboros_docs.py
|
||||
```
|
||||
|
||||
## 실패 처리 규칙
|
||||
|
||||
- 문서 검증 실패 시 구현 PR 병합 금지
|
||||
- `REQ-*` 변경 후 테스트 매핑 누락 시 병합 금지
|
||||
- 회귀 실패 시 원인 모듈 분리 후 재검증
|
||||
68
docs/ouroboros/50_scenario_matrix_and_issue_taxonomy.md
Normal file
68
docs/ouroboros/50_scenario_matrix_and_issue_taxonomy.md
Normal file
@@ -0,0 +1,68 @@
|
||||
<!--
|
||||
Doc-ID: DOC-PM-001
|
||||
Version: 1.0.0
|
||||
Status: active
|
||||
Owner: strategy
|
||||
Updated: 2026-02-26
|
||||
-->
|
||||
|
||||
# 실전 시나리오 매트릭스 + 이슈 분류 체계
|
||||
|
||||
목표: 운영에서 바로 사용할 수 있는 형태로 Happy Path / Failure Path / Ops Incident를 추적 가능한 ID 체계(`REQ-*`, `TASK-*`, `TEST-*`)에 매핑한다.
|
||||
|
||||
## 1) 시나리오 매트릭스
|
||||
|
||||
| Scenario ID | Type | Trigger | Expected System Behavior | Primary IDs (REQ/TASK/TEST) | Ticket Priority |
|
||||
|---|---|---|---|---|---|
|
||||
| `SCN-HAPPY-001` | Happy Path | KR 정규 세션에서 진입 신호 발생, 블랙아웃 아님 | 주문/로그에 `session_id` 저장 후 정책에 맞는 주문 전송 | `REQ-V3-001`, `TASK-V3-001`, `TASK-V3-003`, `TEST-ACC-015` | P1 |
|
||||
| `SCN-HAPPY-002` | Happy Path | 보유 포지션에서 BE/ATR/Hard Stop 조건 순차 도달 | 상태가 즉시 상위 단계로 승격, `EXITED` 우선 평가 보장 | `REQ-V2-002`, `REQ-V2-003`, `TASK-V2-004`, `TEST-ACC-001`, `TEST-ACC-010` | P0 |
|
||||
| `SCN-HAPPY-003` | Happy Path | 세션 전환(KR->US) 이벤트 발생 | 리스크 파라미터 자동 재로딩, 새 세션 정책으로 즉시 전환 | `REQ-V3-002`, `TASK-V3-002`, `TEST-ACC-016` | P0 |
|
||||
| `SCN-HAPPY-004` | Happy Path | 백테스트 실행 요청 | 비용/슬리피지/체결실패 옵션 누락 시 실행 거부, 포함 시 실행 | `REQ-V2-007`, `TASK-V2-012`, `TEST-ACC-014` | P1 |
|
||||
| `SCN-FAIL-001` | Failure Path | 블랙아웃 중 신규 주문 신호 발생 | 신규 주문 차단 + 주문 의도 큐 적재, API 직접 호출 금지 | `REQ-V3-003`, `REQ-V3-004`, `TASK-V3-005`, `TEST-ACC-003`, `TEST-ACC-017` | P0 |
|
||||
| `SCN-FAIL-002` | Failure Path | 저유동 세션에 시장가 주문 요청 | 시장가 하드 거부, 지정가 대체 또는 주문 취소 | `REQ-V3-005`, `TASK-V3-007`, `TASK-V3-008`, `TEST-ACC-004` | P0 |
|
||||
| `SCN-FAIL-003` | Failure Path | Kill Switch 트리거(손실/연결/리스크 한도) | 신규주문차단->미체결취소->재조회->리스크축소->스냅샷 순서 강제 | `REQ-V2-008`, `TASK-V2-013`, `TEST-ACC-002` | P0 |
|
||||
| `SCN-FAIL-004` | Failure Path | FX 버퍼 부족 상태에서 US 진입 신호 | 전략 PnL/FX PnL 분리 집계 유지, 신규 진입 제한 | `REQ-V3-007`, `TASK-V3-013`, `TASK-V3-014`, `TEST-ACC-006` | P1 |
|
||||
| `SCN-OPS-001` | Ops Incident | 브로커 점검/블랙아웃 종료 직후 | 잔고/미체결/체결 동기화 후 큐 재검증 통과 주문만 집행 | `REQ-V3-004`, `TASK-V3-006`, `TEST-ACC-017` | P0 |
|
||||
| `SCN-OPS-002` | Ops Incident | 정책 수치가 코드에만 반영되고 원장 미수정 | 문서 검증에서 실패 처리, PR 병합 차단 | `REQ-OPS-002`, `TASK-OPS-002`, `TEST-ACC-008` | P0 |
|
||||
| `SCN-OPS-003` | Ops Incident | 타임존 누락 로그/스케줄 데이터 유입 | KST/UTC 미표기 레코드 검증 실패 처리 | `REQ-OPS-001`, `TASK-OPS-001`, `TEST-ACC-007` | P1 |
|
||||
| `SCN-OPS-004` | Ops Incident | 신규 REQ 추가 후 TASK/TEST 누락 | 추적성 게이트 실패, 구현 PR 병합 차단 | `REQ-OPS-003`, `TASK-OPS-003`, `TEST-ACC-009` | P0 |
|
||||
| `SCN-OPS-005` | Ops Incident | 배포 후 런타임 이상 동작(주문오류/상태전이오류/정책위반) 탐지 | Runtime Verifier가 즉시 이슈 발행, Dev 수정 후 재관측으로 클로즈 판정 | `REQ-V2-008`, `REQ-V3-003`, `REQ-V3-005`, `TEST-ACC-002`, `TEST-ACC-003`, `TEST-ACC-004` | P0 |
|
||||
|
||||
## 2) 이슈 분류 체계 (Issue Taxonomy)
|
||||
|
||||
| Taxonomy | Definition | Typical Symptoms | Default Owner | Mapping Baseline |
|
||||
|---|---|---|---|---|
|
||||
| `EXEC-STATE` | 상태기계/청산 우선순위 위반 | EXIT 우선순위 깨짐, 상태 역행, 갭 대응 실패 | Strategy | `REQ-V2-001`~`REQ-V2-004`, `TASK-V2-004`~`TASK-V2-006`, `TEST-ACC-000`,`001`,`010`,`011` |
|
||||
| `EXEC-POLICY` | 세션/주문 정책 위반 | 블랙아웃 주문 전송, 저유동 시장가 허용 | Broker/Execution | `REQ-V3-003`~`REQ-V3-005`, `TASK-V3-004`~`TASK-V3-009`, `TEST-ACC-003`,`004`,`017` |
|
||||
| `BACKTEST-MODEL` | 백테스트 현실성/검증 무결성 위반 | 비용 옵션 off로 실행, 체결가 과낙관 | Research | `REQ-V2-006`,`REQ-V2-007`,`REQ-V3-006`, `TASK-V2-010`~`012`, `TASK-V3-010`~`012`, `TEST-ACC-013`,`014`,`005` |
|
||||
| `RISK-EMERGENCY` | Kill Switch/리스크 비상 대응 실패 | 순서 위반, 차단 누락, 복구 절차 누락 | Risk | `REQ-V2-008`,`REQ-V3-008`, `TASK-V2-013`~`015`, `TASK-V3-015`, `TEST-ACC-002`,`018` |
|
||||
| `FX-ACCOUNTING` | 환율/통화 버퍼 정책 위반 | 전략손익/환차손익 혼합 집계, 버퍼 미적용 | Risk + Data | `REQ-V3-007`, `TASK-V3-013`,`014`, `TEST-ACC-006` |
|
||||
| `OPS-GOVERNANCE` | 문서/추적성/타임존 거버넌스 위반 | 원장 미수정, TEST 누락, 타임존 미표기 | PM + QA | `REQ-OPS-001`~`003`, `TASK-OPS-001`~`003`, `TEST-ACC-007`~`009` |
|
||||
| `RUNTIME-VERIFY` | 실동작 모니터링 검증 | 배포 후 이상 현상, 간헐 오류, 테스트 미포착 회귀 | Runtime Verifier + TPM | 관련 `REQ/TASK/TEST`와 런타임 로그 증적 필수 |
|
||||
|
||||
## 3) 티켓 생성 규칙 (Implementable)
|
||||
|
||||
1. 모든 이슈는 `taxonomy + scenario_id`를 제목에 포함한다.
|
||||
예: `[EXEC-POLICY][SCN-FAIL-001] blackout 주문 차단 누락`
|
||||
2. 본문 필수 항목: 재현절차, 기대결과, 실제결과, 영향범위, 롤백/완화책.
|
||||
3. 본문에 최소 1개 `REQ-*`, 1개 `TASK-*`, 1개 `TEST-*`를 명시한다.
|
||||
4. 우선순위 기준:
|
||||
- P0: 실주문 위험, Kill Switch, 블랙아웃/시장가 정책, 추적성 게이트 실패
|
||||
- P1: 손익 왜곡 가능성(체결/FX/시간대), 운영 리스크 증가
|
||||
- P2: 보고서/관측성 품질 이슈(거래 안전성 영향 없음)
|
||||
5. Runtime Verifier가 발행한 `RUNTIME-VERIFY` 이슈는 Main Agent 확인 전 클로즈 금지.
|
||||
|
||||
## 4) 즉시 생성 권장 티켓 (초기 백로그)
|
||||
|
||||
- `TKT-P0-001`: `[EXEC-POLICY][SCN-FAIL-001]` 블랙아웃 차단 + 큐적재 + 복구 재검증 e2e 점검 (`REQ-V3-003`,`REQ-V3-004`)
|
||||
- `TKT-P0-002`: `[RISK-EMERGENCY][SCN-FAIL-003]` Kill Switch 순서 강제 검증 자동화 (`REQ-V2-008`)
|
||||
- `TKT-P0-003`: `[OPS-GOVERNANCE][SCN-OPS-004]` REQ/TASK/TEST 누락 시 PR 차단 게이트 상시 점검 (`REQ-OPS-003`)
|
||||
- `TKT-P1-001`: `[FX-ACCOUNTING][SCN-FAIL-004]` FX 버퍼 위반 시 진입 제한 회귀 케이스 보강 (`REQ-V3-007`)
|
||||
- `TKT-P1-002`: `[BACKTEST-MODEL][SCN-HAPPY-004]` 비용/슬리피지 미설정 백테스트 거부 UX 명확화 (`REQ-V2-007`)
|
||||
- `TKT-P0-004`: `[RUNTIME-VERIFY][SCN-OPS-005]` 배포 후 런타임 이상 탐지/재현/클로즈 판정 절차 자동화
|
||||
|
||||
## 5) 운영 체크포인트
|
||||
|
||||
- 스프린트 계획 시 `P0` 시나리오 100% 테스트 통과를 출발 조건으로 둔다.
|
||||
- 배포 승인 시 `SCN-FAIL-*`, `SCN-OPS-*` 관련 `TEST-ACC-*`를 우선 확인한다.
|
||||
- 정책 변경 PR은 반드시 원장(`01_requirements_registry.md`) 선수정 후 진행한다.
|
||||
192
docs/ouroboros/50_tpm_control_protocol.md
Normal file
192
docs/ouroboros/50_tpm_control_protocol.md
Normal file
@@ -0,0 +1,192 @@
|
||||
<!--
|
||||
Doc-ID: DOC-TPM-001
|
||||
Version: 1.0.0
|
||||
Status: active
|
||||
Owner: tpm
|
||||
Updated: 2026-02-26
|
||||
-->
|
||||
|
||||
# TPM Control Protocol (Main <-> PM <-> TPM <-> Dev <-> Verifier <-> Runtime Verifier)
|
||||
|
||||
목적:
|
||||
- PM 시나리오가 구현 가능한 단위로 분해되고, 개발/검증이 동일 ID 체계(`REQ-*`, `TASK-*`, `TEST-*`)로 닫히도록 강제한다.
|
||||
- 각 단계는 Entry/Exit gate를 통과해야 다음 단계로 이동 가능하다.
|
||||
- 주요 의사결정 포인트마다 Main Agent의 승인/의견 확인을 강제한다.
|
||||
|
||||
## Team Roles
|
||||
|
||||
- Main Agent: 최종 취합/우선순위/승인 게이트 오너
|
||||
- PM Agent: 시나리오/요구사항/티켓 관리
|
||||
- TPM Agent: PM-Dev-검증 간 구현 가능성/달성률 통제, 티켓 등록 및 구현 우선순위 지정 오너
|
||||
- Dev Agent: 구현 수행, 블로커 발생 시 재계획 요청
|
||||
- Verifier Agent: 문서/코드/테스트 산출물 검증
|
||||
- Runtime Verifier Agent: 실제 동작 모니터링, 이상 징후 이슈 발행, 수정 후 이슈 클로즈 판정
|
||||
|
||||
Main Agent 아이디에이션 책임:
|
||||
- 진행 중 신규 구현 아이디어를 별도 문서에 누적 기록한다.
|
||||
- 기록 위치: [70_main_agent_ideation.md](./70_main_agent_ideation.md)
|
||||
- 각 항목은 `IDEA-*` 식별자, 배경, 기대효과, 리스크, 후속 티켓 후보를 포함해야 한다.
|
||||
|
||||
## Main Decision Checkpoints (Mandatory)
|
||||
|
||||
- DCP-01 범위 확정: Phase 0 종료 전 Main Agent 승인 필수
|
||||
- DCP-02 요구사항 확정: Phase 1 종료 전 Main Agent 승인 필수
|
||||
- DCP-03 구현 착수: Phase 2 종료 전 Main Agent 승인 필수
|
||||
- DCP-04 배포 승인: Phase 4 종료 후 Main Agent 최종 승인 필수
|
||||
|
||||
## Phase Control Gates
|
||||
|
||||
### Phase 0: Scenario Intake and Scope Lock
|
||||
|
||||
Entry criteria:
|
||||
- PM 시나리오가 사용자 가치, 실패 모드, 우선순위를 포함해 제출됨
|
||||
- 영향 범위(모듈/세션/KR-US 시장)가 명시됨
|
||||
|
||||
Exit criteria:
|
||||
- 시나리오가 `REQ-*` 후보에 1:1 또는 1:N 매핑됨
|
||||
- 모호한 표현("개선", "최적화")은 측정 가능한 조건으로 치환됨
|
||||
- 비범위 항목(out-of-scope) 명시
|
||||
|
||||
Control checks:
|
||||
- PM/TPM 합의 완료
|
||||
- Main Agent 승인(DCP-01)
|
||||
- 산출물: 시나리오 카드, 초기 매핑 메모
|
||||
|
||||
### Phase 1: Requirement Registry Gate
|
||||
|
||||
Entry criteria:
|
||||
- Phase 0 산출물 승인
|
||||
- 변경 대상 요구사항 문서 식별 완료
|
||||
|
||||
Exit criteria:
|
||||
- [01_requirements_registry.md](./01_requirements_registry.md)에 `REQ-*` 정의/수정 반영
|
||||
- 각 `REQ-*`가 최소 1개 `TASK-*`, 1개 `TEST-*`와 연결 가능 상태
|
||||
- 시간/정책 수치는 원장 단일 소스로 확정(`REQ-OPS-001`,`REQ-OPS-002`)
|
||||
|
||||
Control checks:
|
||||
- `python3 scripts/validate_ouroboros_docs.py` 통과
|
||||
- Main Agent 승인(DCP-02)
|
||||
- 산출물: 업데이트된 요구사항 원장
|
||||
|
||||
### Phase 2: Design and Work-Order Gate
|
||||
|
||||
Entry criteria:
|
||||
- 요구사항 원장 갱신 완료
|
||||
- 영향 모듈 분석 완료(상태기계, 주문정책, 백테스트, 세션)
|
||||
|
||||
Exit criteria:
|
||||
- [10_phase_v2_execution.md](./10_phase_v2_execution.md), [20_phase_v3_execution.md](./20_phase_v3_execution.md), [30_code_level_work_orders.md](./30_code_level_work_orders.md)에 작업 분해 완료
|
||||
- 각 작업은 구현 위치/제약/완료 조건을 가짐
|
||||
- 위험 작업(Kill Switch, blackout, session transition)은 별도 롤백 절차 포함
|
||||
|
||||
Control checks:
|
||||
- TPM이 `REQ -> TASK` 누락 여부 검토
|
||||
- Main Agent 승인(DCP-03)
|
||||
- 산출물: 승인된 Work Order 세트
|
||||
|
||||
### Phase 3: Implementation Gate
|
||||
|
||||
Entry criteria:
|
||||
- 승인된 `TASK-*`가 브랜치 작업 단위로 분리됨
|
||||
- 변경 범위별 테스트 계획이 PR 본문에 링크됨
|
||||
|
||||
Exit criteria:
|
||||
- 코드 변경이 `TASK-*`에 대응되어 추적 가능
|
||||
- 제약 준수(`src/core/risk_manager.py` 직접 수정 금지 등) 확인
|
||||
- 신규 로직마다 최소 1개 테스트 추가 또는 기존 테스트 확장
|
||||
|
||||
Control checks:
|
||||
- PR 템플릿 내 `REQ-*`/`TASK-*`/`TEST-*` 매핑 확인
|
||||
- 산출물: 리뷰 가능한 PR
|
||||
|
||||
### Phase 4: Verification and Acceptance Gate
|
||||
|
||||
Entry criteria:
|
||||
- 구현 PR ready 상태
|
||||
- 테스트 케이스/픽스처 준비 완료
|
||||
|
||||
Exit criteria:
|
||||
- [40_acceptance_and_test_plan.md](./40_acceptance_and_test_plan.md)의 해당 `TEST-ACC-*` 전부 통과
|
||||
- 회귀 테스트 통과(`pytest -q`)
|
||||
- 문서 검증 통과(`python3 scripts/validate_ouroboros_docs.py`)
|
||||
|
||||
Control checks:
|
||||
- Verifier가 테스트 증적(로그/리포트/실행 커맨드) 첨부
|
||||
- Runtime Verifier가 스테이징/실운영 모니터링 계획 승인
|
||||
- 산출물: 수용 승인 레코드
|
||||
|
||||
### Phase 5: Release and Post-Release Control
|
||||
|
||||
Entry criteria:
|
||||
- Phase 4 승인
|
||||
- 운영 체크리스트 준비(세션 전환, 블랙아웃, Kill Switch)
|
||||
|
||||
Exit criteria:
|
||||
- 배포 후 초기 관찰 윈도우에서 치명 경보 없음
|
||||
- 신규 시나리오/회귀 이슈는 다음 Cycle의 Phase 0 입력으로 환류
|
||||
- 요구사항/테스트 문서 버전 동기화 완료
|
||||
|
||||
Control checks:
|
||||
- PM/TPM/Dev 3자 종료 확인
|
||||
- Runtime Verifier가 운영 모니터링 이슈 상태(신규/진행/해결)를 리포트
|
||||
- Main Agent 최종 승인(DCP-04)
|
||||
- 산출물: 릴리즈 노트 + 후속 액션 목록
|
||||
|
||||
## Replan Protocol (Dev -> TPM)
|
||||
|
||||
- 트리거:
|
||||
- 구현 불가능(기술적 제약/외부 API 제약)
|
||||
- 예상 대비 개발 리소스 과다(공수/인력/의존성 급증)
|
||||
- 절차:
|
||||
1) Dev Agent가 `REPLAN-REQUEST` 발행(영향 REQ/TASK, 원인, 대안, 추가 공수 포함)
|
||||
2) TPM Agent가 1차 심사(범위 축소/단계 분할/요구사항 조정안)
|
||||
3) Verifier/PM 의견 수렴 후 Main Agent 승인으로 재계획 확정
|
||||
- 규칙:
|
||||
- Main Agent 승인 없는 재계획은 실행 금지
|
||||
- 재계획 반영 시 문서(`REQ/TASK/TEST`) 동시 갱신 필수
|
||||
|
||||
TPM 티켓 운영 규칙:
|
||||
- TPM은 합의된 변경을 이슈로 등록하고 우선순위(`P0/P1/P2`)를 지정한다.
|
||||
- PR 본문에는 TPM이 지정한 우선순위와 범위가 그대로 반영되어야 한다.
|
||||
- 우선순위 변경은 TPM 제안 + Main Agent 승인으로만 가능하다.
|
||||
|
||||
## Runtime Verification Protocol
|
||||
|
||||
- Runtime Verifier는 테스트 통과 이후 실제 동작(스테이징/실운영)을 모니터링한다.
|
||||
- 이상 동작/현상 발견 시 즉시 이슈 발행:
|
||||
- 제목 규칙: `[RUNTIME-VERIFY][SCN-*] ...`
|
||||
- 본문 필수: 재현조건, 관측 로그, 영향 범위, 임시 완화책, 관련 `REQ/TASK/TEST`
|
||||
- 이슈 클로즈 규칙:
|
||||
- Dev 수정 완료 + Verifier 재검증 통과 + Runtime Verifier 재관측 정상
|
||||
- 최종 클로즈 승인자는 Main Agent
|
||||
|
||||
## Server Reflection Rule (No-Merge by Default)
|
||||
|
||||
- 서버 반영 기본 규칙은 `브랜치 푸시 + PR 생성/코멘트`까지로 제한한다.
|
||||
- 기본 흐름에서 검증 승인 후 자동/수동 머지 실행은 금지한다.
|
||||
- 예외는 사용자 명시 승인 시에만 허용되며, Main Agent가 예외 근거를 PR에 기록한다.
|
||||
|
||||
## Acceptance Matrix (PM Scenario -> Dev Tasks -> Verifier Checks)
|
||||
|
||||
| PM Scenario | Requirement Coverage | Dev Tasks (Primary) | Verifier Checks (Must Pass) |
|
||||
|---|---|---|---|
|
||||
| 갭 급락/급등에서 청산 우선 처리 필요 | `REQ-V2-001`,`REQ-V2-002`,`REQ-V2-003` | `TASK-V2-004`,`TASK-CODE-001` | `TEST-ACC-000`,`TEST-ACC-001`,`TEST-ACC-010`,`TEST-CODE-001`,`TEST-CODE-002` |
|
||||
| 하드스탑 + BE락 + ATR + 모델보조를 한 엔진으로 통합 | `REQ-V2-004` | `TASK-V2-005`,`TASK-V2-006`,`TASK-CODE-002` | `TEST-ACC-011` |
|
||||
| 라벨 누수 없는 학습데이터 생성 | `REQ-V2-005` | `TASK-V2-007`,`TASK-CODE-004` | `TEST-ACC-012`,`TEST-CODE-003` |
|
||||
| 검증 프레임워크를 시계열 누수 방지 구조로 강제 | `REQ-V2-006` | `TASK-V2-010`,`TASK-CODE-005` | `TEST-ACC-013`,`TEST-CODE-004` |
|
||||
| 과낙관 백테스트 방지(비용/슬리피지/실패 강제) | `REQ-V2-007` | `TASK-V2-012`,`TASK-CODE-006` | `TEST-ACC-014` |
|
||||
| 장애 시 Kill Switch 실행 순서 고정 | `REQ-V2-008` | `TASK-V2-013`,`TASK-V2-014`,`TASK-V2-015`,`TASK-CODE-003` | `TEST-ACC-002`,`TEST-ACC-018` |
|
||||
| 세션 전환 단위 리스크/로그 추적 일관화 | `REQ-V3-001`,`REQ-V3-002` | `TASK-V3-001`,`TASK-V3-002`,`TASK-V3-003`,`TASK-CODE-007` | `TEST-ACC-015`,`TEST-ACC-016` |
|
||||
| 블랙아웃 중 주문 차단 + 복구 후 재검증 실행 | `REQ-V3-003`,`REQ-V3-004` | `TASK-V3-004`,`TASK-V3-005`,`TASK-V3-006`,`TASK-CODE-008` | `TEST-ACC-003`,`TEST-ACC-017`,`TEST-CODE-005` |
|
||||
| 저유동 세션 시장가 주문 금지 | `REQ-V3-005` | `TASK-V3-007`,`TASK-V3-008`,`TASK-CODE-009` | `TEST-ACC-004`,`TEST-CODE-006` |
|
||||
| 보수적 체결 모델을 백테스트 기본으로 설정 | `REQ-V3-006` | `TASK-V3-010`,`TASK-V3-011`,`TASK-V3-012`,`TASK-CODE-010` | `TEST-ACC-005`,`TEST-CODE-007` |
|
||||
| 전략손익/환율손익 분리 + 통화 버퍼 통제 | `REQ-V3-007` | `TASK-V3-013`,`TASK-V3-014`,`TASK-CODE-011` | `TEST-ACC-006`,`TEST-CODE-008` |
|
||||
| 오버나잇 규칙과 Kill Switch 충돌 방지 | `REQ-V3-008` | `TASK-V3-015`,`TASK-CODE-012` | `TEST-ACC-018` |
|
||||
| 타임존/정책변경/추적성 문서 거버넌스 | `REQ-OPS-001`,`REQ-OPS-002`,`REQ-OPS-003` | `TASK-OPS-001`,`TASK-OPS-002`,`TASK-OPS-003` | `TEST-ACC-007`,`TEST-ACC-008`,`TEST-ACC-009` |
|
||||
|
||||
## 운영 규율 (TPM Enforcement Rules)
|
||||
|
||||
- 어떤 PM 시나리오도 `REQ-*` 없는 구현 착수 금지.
|
||||
- 어떤 `REQ-*`도 `TASK-*`,`TEST-*` 없는 승인 금지.
|
||||
- Verifier는 "코드 리뷰 통과"만으로 승인 불가, 반드시 `TEST-ACC-*` 증적 필요.
|
||||
- 배포 승인권자는 Phase 4 체크리스트 미충족 시 릴리즈 보류 권한을 행사해야 한다.
|
||||
94
docs/ouroboros/60_repo_enforcement_checklist.md
Normal file
94
docs/ouroboros/60_repo_enforcement_checklist.md
Normal file
@@ -0,0 +1,94 @@
|
||||
<!--
|
||||
Doc-ID: DOC-OPS-002
|
||||
Version: 1.0.0
|
||||
Status: active
|
||||
Owner: tpm
|
||||
Updated: 2026-02-26
|
||||
-->
|
||||
|
||||
# 저장소 강제 설정 체크리스트
|
||||
|
||||
목표: "엄격 검증 운영"을 문서가 아니라 저장소 설정으로 강제한다.
|
||||
|
||||
## 1) main 브랜치 보호 (필수)
|
||||
|
||||
적용 항목:
|
||||
- direct push 금지
|
||||
- force push 금지
|
||||
- branch 삭제 금지
|
||||
- merge는 PR 경로만 허용
|
||||
|
||||
검증:
|
||||
- `main`에 대해 직접 `git push origin main` 시 거부되는지 확인
|
||||
|
||||
## 2) 필수 상태 체크 (필수)
|
||||
|
||||
필수 CI 항목:
|
||||
- `validate_ouroboros_docs` (명령: `python3 scripts/validate_ouroboros_docs.py`)
|
||||
- `test` (명령: `pytest -q`)
|
||||
|
||||
설정 기준:
|
||||
- 위 2개 체크가 `success` 아니면 머지 금지
|
||||
- 체크 스킵/중립 상태 허용 금지
|
||||
|
||||
## 3) 필수 리뷰어 규칙 (권장 -> 필수)
|
||||
|
||||
역할 기반 승인:
|
||||
- Verifier 1명 승인 필수
|
||||
- TPM 또는 PM 1명 승인 필수
|
||||
- Runtime Verifier 관련 변경(PR 본문에 runtime 영향 있음) 시 Runtime Verifier 승인 필수
|
||||
|
||||
설정 기준:
|
||||
- 최소 승인 수: 2
|
||||
- 작성자 self-approval 불가
|
||||
- 새 커밋 푸시 시 기존 승인 재검토 요구
|
||||
|
||||
## 4) 워크플로우 게이트
|
||||
|
||||
병합 전 체크리스트:
|
||||
- 이슈 연결(`Closes #N`) 존재
|
||||
- PR 본문에 `REQ-*`, `TASK-*`, `TEST-*` 매핑 표 존재
|
||||
- `src/core/risk_manager.py` 변경 없음
|
||||
- 주요 의사결정 체크포인트(DCP-01~04) 중 해당 단계 Main Agent 확인 기록 존재
|
||||
|
||||
자동 점검:
|
||||
- 문서 검증 스크립트 통과
|
||||
- 테스트 통과
|
||||
|
||||
## 5) 감사 추적
|
||||
|
||||
필수 보존 증적:
|
||||
- CI 실행 로그 링크
|
||||
- 검증 실패/복구 기록
|
||||
- 머지 승인 코멘트(Verifier/TPM)
|
||||
|
||||
분기별 점검:
|
||||
- 브랜치 보호 규칙 drift 여부
|
||||
- 필수 CI 이름 변경/누락 여부
|
||||
|
||||
## 6) 적용 순서 (운영 절차)
|
||||
|
||||
1. 브랜치 보호 활성화
|
||||
2. 필수 CI 체크 연결
|
||||
3. 리뷰어 규칙 적용
|
||||
4. 샘플 PR로 거부 시나리오 테스트
|
||||
5. 정상 머지 시나리오 테스트
|
||||
|
||||
## 7) 실패 시 조치
|
||||
|
||||
- 브랜치 보호 미적용 발견 시: 즉시 릴리즈 중지
|
||||
- 필수 CI 우회 발견 시: 관리자 권한 점검 및 감사 이슈 발행
|
||||
- 리뷰 규칙 무효화 발견 시: 규칙 복구 후 재머지 정책 시행
|
||||
- Runtime 이상 이슈 미해결 상태에서 클로즈 시도 발견 시: 즉시 이슈 재오픈 + 릴리즈 중지
|
||||
|
||||
## 8) 재계획(Dev Replan) 운영 규칙
|
||||
|
||||
- Dev가 `REPLAN-REQUEST` 발행 시 TPM 심사 없이는 스코프/일정 변경 금지
|
||||
- `REPLAN-REQUEST`는 Main Agent 승인 전 \"제안\" 상태로 유지
|
||||
- 승인된 재계획은 `REQ/TASK/TEST` 문서를 동시 갱신해야 유효
|
||||
|
||||
## 9) 서버 반영 규칙 (No-Merge by Default)
|
||||
|
||||
- 서버 반영은 `브랜치 푸시 + PR 코멘트(리뷰/논의/검증승인)`까지를 기본으로 한다.
|
||||
- 기본 규칙에서 `tea pulls merge` 실행은 금지한다.
|
||||
- 사용자 명시 승인 시에만 예외적으로 머지를 허용한다(예외 근거를 PR 코멘트에 기록).
|
||||
48
docs/ouroboros/70_main_agent_ideation.md
Normal file
48
docs/ouroboros/70_main_agent_ideation.md
Normal file
@@ -0,0 +1,48 @@
|
||||
<!--
|
||||
Doc-ID: DOC-IDEA-001
|
||||
Version: 1.0.0
|
||||
Status: active
|
||||
Owner: main-agent
|
||||
Updated: 2026-02-26
|
||||
-->
|
||||
|
||||
# 메인 에이전트 아이디에이션 백로그
|
||||
|
||||
목적:
|
||||
- 구현 진행 중 떠오른 신규 구현 아이디어를 계획 반영 전 임시 저장한다.
|
||||
- 본 문서는 사용자 검토 후 다음 계획 포함 여부를 결정하기 위한 검토 큐다.
|
||||
|
||||
운영 규칙:
|
||||
- 각 아이디어는 `IDEA-*` 식별자를 사용한다.
|
||||
- 필수 필드: 배경, 기대효과, 리스크, 후속 티켓 후보.
|
||||
- 상태는 `proposed`, `under-review`, `accepted`, `rejected` 중 하나를 사용한다.
|
||||
|
||||
## 아이디어 목록
|
||||
|
||||
- `IDEA-001` (status: proposed)
|
||||
- 제목: Kill-Switch 전역 상태를 프로세스 단일 전역에서 시장/세션 단위 상태로 분리
|
||||
- 배경: 현재는 전역 block 플래그 기반이라 시장별 분리 제어가 제한될 수 있음
|
||||
- 기대효과: KR/US 병행 운용 시 한 시장 장애가 다른 시장 주문을 불필요하게 막는 리스크 축소
|
||||
- 리스크: 상태 동기화 복잡도 증가, 테스트 케이스 확장 필요
|
||||
- 후속 티켓 후보: `TKT-P1-KS-SCOPE-SPLIT`
|
||||
|
||||
- `IDEA-002` (status: proposed)
|
||||
- 제목: Exit Engine 입력 계약(ATR/peak/model_prob/liquidity) 표준 DTO를 데이터 파이프라인에 고정
|
||||
- 배경: 현재 ATR/모델확률 일부가 fallback 기반이라 운영 일관성이 약함
|
||||
- 기대효과: 백테스트-실거래 입력 동형성 강화, 회귀 분석 용이
|
||||
- 리스크: 기존 스캐너/시나리오 엔진 연동 작업량 증가
|
||||
- 후속 티켓 후보: `TKT-P1-EXIT-CONTRACT`
|
||||
|
||||
- `IDEA-003` (status: proposed)
|
||||
- 제목: Runtime Verifier 자동 이슈 생성기(로그 패턴 -> 이슈 템플릿 자동화)
|
||||
- 배경: 런타임 이상 리포트가 수동 작성 중심이라 누락 가능성 존재
|
||||
- 기대효과: 이상 탐지 후 이슈 등록 리드타임 단축, 증적 표준화
|
||||
- 리스크: 오탐 이슈 폭증 가능성, 필터링 룰 필요
|
||||
- 후속 티켓 후보: `TKT-P1-RUNTIME-AUTO-ISSUE`
|
||||
|
||||
- `IDEA-004` (status: proposed)
|
||||
- 제목: PR 코멘트 워크플로우 자동 점검(리뷰어->개발논의->검증승인 누락 차단)
|
||||
- 배경: 현재 절차는 강력하지만 수행 확인이 수동
|
||||
- 기대효과: 절차 누락 방지, 감사 추적 자동화
|
||||
- 리스크: CLI/API 연동 유지보수 비용
|
||||
- 후속 티켓 후보: `TKT-P0-WORKFLOW-GUARD`
|
||||
40
docs/ouroboros/README.md
Normal file
40
docs/ouroboros/README.md
Normal file
@@ -0,0 +1,40 @@
|
||||
<!--
|
||||
Doc-ID: DOC-ROOT-001
|
||||
Version: 1.0.0
|
||||
Status: active
|
||||
Owner: strategy
|
||||
Updated: 2026-02-26
|
||||
-->
|
||||
|
||||
# The Ouroboros 실행 문서 허브
|
||||
|
||||
이 폴더는 `ouroboros_plan_v2.txt`, `ouroboros_plan_v3.txt`를 구현 가능한 작업 지시서 수준으로 분해한 문서 허브다.
|
||||
|
||||
## 읽기 순서 (Routing)
|
||||
|
||||
1. 검증 체계부터 확정: [00_validation_system.md](./00_validation_system.md)
|
||||
2. 단일 진실원장(요구사항): [01_requirements_registry.md](./01_requirements_registry.md)
|
||||
3. v2 실행 지시서: [10_phase_v2_execution.md](./10_phase_v2_execution.md)
|
||||
4. v3 실행 지시서: [20_phase_v3_execution.md](./20_phase_v3_execution.md)
|
||||
5. 코드 레벨 작업 지시: [30_code_level_work_orders.md](./30_code_level_work_orders.md)
|
||||
6. 수용 기준/테스트 계획: [40_acceptance_and_test_plan.md](./40_acceptance_and_test_plan.md)
|
||||
7. PM 시나리오/이슈 분류: [50_scenario_matrix_and_issue_taxonomy.md](./50_scenario_matrix_and_issue_taxonomy.md)
|
||||
8. TPM 제어 프로토콜/수용 매트릭스: [50_tpm_control_protocol.md](./50_tpm_control_protocol.md)
|
||||
9. 저장소 강제 설정 체크리스트: [60_repo_enforcement_checklist.md](./60_repo_enforcement_checklist.md)
|
||||
10. 메인 에이전트 아이디에이션 백로그: [70_main_agent_ideation.md](./70_main_agent_ideation.md)
|
||||
|
||||
## 운영 규칙
|
||||
|
||||
- 계획 변경은 반드시 `01_requirements_registry.md`의 ID 정의부터 수정한다.
|
||||
- 구현 문서는 원장 ID만 참조하고 자체 숫자/정책을 새로 만들지 않는다.
|
||||
- 문서 품질 룰셋(`RULE-DOC-001` `RULE-DOC-002` `RULE-DOC-003` `RULE-DOC-004` `RULE-DOC-005` `RULE-DOC-006`)은 [00_validation_system.md](./00_validation_system.md)를 기준으로 적용한다.
|
||||
- 문서 병합 전 아래 검증을 통과해야 한다.
|
||||
|
||||
```bash
|
||||
python3 scripts/validate_ouroboros_docs.py
|
||||
```
|
||||
|
||||
## 원본 계획 문서
|
||||
|
||||
- [v2](/home/agentson/repos/The-Ouroboros/ouroboros_plan_v2.txt)
|
||||
- [v3](/home/agentson/repos/The-Ouroboros/ouroboros_plan_v3.txt)
|
||||
357
docs/requirements-log.md
Normal file
357
docs/requirements-log.md
Normal file
@@ -0,0 +1,357 @@
|
||||
# Requirements Log
|
||||
|
||||
프로젝트 진화를 위한 사용자 요구사항 기록.
|
||||
|
||||
이 문서는 시간순으로 사용자와의 대화에서 나온 요구사항과 피드백을 기록합니다.
|
||||
새로운 요구사항이 있으면 날짜와 함께 추가하세요.
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-21
|
||||
|
||||
### 거래 상태 확인 중 발견된 버그 (#187)
|
||||
|
||||
- 거래 상태 점검 요청 → SELL 주문(손절/익절)이 Fat Finger에 막혀 전혀 실행 안 됨 발견
|
||||
- **#187 (Critical)**: SELL 주문에서 Fat Finger 오탐 — `order_amount/total_cash > 30%`가 SELL에도 적용되어 대형 포지션 매도 불가
|
||||
- JELD stop-loss -6.20% → 차단, RXT take-profit +46.13% → 차단
|
||||
- 수정: SELL은 `check_circuit_breaker`만 호출, `validate_order`(Fat Finger 포함) 미호출
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-20
|
||||
|
||||
### 지속적 모니터링 및 개선점 도출 (이슈 #178~#182)
|
||||
|
||||
- Dashboard 포함해서 실행하며 간헐적 문제 모니터링 및 개선점 자동 도출 요청
|
||||
- 모니터링 결과 발견된 이슈 목록:
|
||||
- **#178**: uvicorn 미설치 → dashboard 미작동 + 오해의 소지 있는 시작 로그 → uvicorn 설치 완료
|
||||
- **#179 (Critical)**: 잔액 부족 주문 실패 후 매 사이클마다 무한 재시도 (MLECW 20분 이상 반복)
|
||||
- **#180**: 다중 인스턴스 실행 시 Telegram 409 충돌
|
||||
- **#181**: implied_rsi 공식 포화 문제 (change_rate≥12.5% → RSI=100)
|
||||
- **#182 (Critical)**: 보유 종목이 SmartScanner 변동성 필터에 걸려 SELL 신호 미생성 → SELL 체결 0건, 잔고 소진
|
||||
- 요구사항: 모니터링 자동화 및 주기적 개선점 리포트 도출
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-05
|
||||
|
||||
### API 효율화
|
||||
- Gemini API는 귀중한 자원. 종목별 개별 호출 대신 배치 호출 필요
|
||||
- Free tier 한도(20 calls/day) 고려하여 일일 몇 차례 거래 모드로 전환
|
||||
- 배치 API 호출로 여러 종목을 한 번에 분석
|
||||
|
||||
### 거래 모드
|
||||
- **Daily Mode**: 하루 4회 거래 세션 (6시간 간격) - Free tier 호환
|
||||
- **Realtime Mode**: 60초 간격 실시간 거래 - 유료 구독 필요
|
||||
- `TRADE_MODE` 환경변수로 모드 선택
|
||||
|
||||
### 진화 시스템
|
||||
- 사용자 대화 내용을 문서로 기록하여 향후에도 의도 반영
|
||||
- 프롬프트 품질 검증은 별도 이슈로 다룰 예정
|
||||
|
||||
### 문서화
|
||||
- 시스템 구조, 기능별 설명 등 코드 문서화 항상 신경쓸 것
|
||||
- 새로운 기능 추가 시 관련 문서 업데이트 필수
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-06
|
||||
|
||||
### Smart Volatility Scanner (Python-First, AI-Last 파이프라인)
|
||||
|
||||
**배경:**
|
||||
- 정적 종목 리스트를 순회하는 방식은 비효율적
|
||||
- KIS API 거래량 순위를 통해 시장 주도주를 자동 탐지해야 함
|
||||
- Gemini API 호출 전에 Python 기반 기술적 분석으로 필터링 필요
|
||||
|
||||
**요구사항:**
|
||||
1. KIS API 거래량 순위 API 통합 (`fetch_market_rankings`)
|
||||
2. 일별 가격 히스토리 API 추가 (`get_daily_prices`)
|
||||
3. RSI(14) 계산 기능 구현 (Wilder's smoothing method)
|
||||
4. 필터 조건:
|
||||
- 거래량 > 전일 대비 200% (VOL_MULTIPLIER)
|
||||
- RSI < 30 (과매도) OR RSI > 70 (모멘텀)
|
||||
5. 상위 1-3개 적격 종목만 Gemini에 전달
|
||||
6. 종목 선정 배경(RSI, volume_ratio, signal, score) 데이터베이스 기록
|
||||
|
||||
**구현 결과:**
|
||||
- `src/analysis/smart_scanner.py`: SmartVolatilityScanner 클래스
|
||||
- `src/analysis/volatility.py`: calculate_rsi() 메서드 추가
|
||||
- `src/broker/kis_api.py`: 2개 신규 API 메서드
|
||||
- `src/db.py`: selection_context 컬럼 추가
|
||||
- 설정 가능한 임계값: RSI_OVERSOLD_THRESHOLD, RSI_MOMENTUM_THRESHOLD, VOL_MULTIPLIER, SCANNER_TOP_N
|
||||
|
||||
**효과:**
|
||||
- Gemini API 호출 20-30개 → 1-3개로 감소
|
||||
- Python 기반 빠른 필터링 → 비용 절감
|
||||
- 선정 기준 추적 → Evolution 시스템 최적화 가능
|
||||
- API 장애 시 정적 watchlist로 자동 전환
|
||||
|
||||
**참고:** Realtime 모드 전용. Daily 모드는 배치 효율성을 위해 정적 watchlist 사용.
|
||||
|
||||
**이슈/PR:** #76, #77
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-10
|
||||
|
||||
### 코드 리뷰 시 플랜-구현 일치 검증 규칙
|
||||
|
||||
**배경:**
|
||||
- 코드 리뷰 시 플랜(EnterPlanMode에서 승인된 계획)과 실제 구현이 일치하는지 확인하는 절차가 없었음
|
||||
- 플랜과 다른 구현이 리뷰 없이 통과될 위험
|
||||
|
||||
**요구사항:**
|
||||
1. 모든 PR 리뷰에서 플랜-구현 일치 여부를 필수 체크
|
||||
2. 플랜에 없는 변경은 정당한 사유 필요
|
||||
3. 플랜 항목이 누락되면 PR 설명에 사유 기록
|
||||
4. 스코프가 플랜과 일치하는지 확인
|
||||
|
||||
**구현 결과:**
|
||||
- `docs/workflow.md`에 Code Review Checklist 섹션 추가
|
||||
- Plan Consistency (필수), Safety & Constraints, Quality, Workflow 4개 카테고리
|
||||
|
||||
**이슈/PR:** #114
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-16
|
||||
|
||||
### 문서 v2 동기화 (전체 문서 현행화)
|
||||
|
||||
**배경:**
|
||||
- v2 기능 구현 완료 후 문서가 실제 코드 상태와 크게 괴리
|
||||
- 문서에는 54 tests / 4 files로 기록되었으나 실제로는 551 tests / 25 files
|
||||
- v2 핵심 기능(Playbook, Scenario Engine, Dashboard, Telegram Commands, Daily Review, Context System, Backup) 문서화 누락
|
||||
|
||||
**요구사항:**
|
||||
1. `docs/testing.md` — 551 tests / 25 files 반영, 전체 테스트 파일 설명
|
||||
2. `docs/architecture.md` — v2 컴포넌트(Strategy, Context, Dashboard, Decision Logger 등) 추가, Playbook Mode 데이터 플로우, DB 스키마 5개 테이블, v2 환경변수
|
||||
3. `docs/commands.md` — Dashboard 실행 명령어, Telegram 명령어 9종 레퍼런스
|
||||
4. `CLAUDE.md` — Project Structure 트리 확장, 테스트 수 업데이트, `--dashboard` 플래그
|
||||
5. `docs/skills.md` — DB 파일명 `trades.db`로 통일, Dashboard 명령어 추가
|
||||
6. 기존에 유효한 트러블슈팅, 코드 예제 등은 유지
|
||||
|
||||
**구현 결과:**
|
||||
- 6개 문서 파일 업데이트
|
||||
- 이전 시도(2개 커밋)는 기존 내용을 과도하게 삭제하여 폐기, main 기준으로 재작업
|
||||
|
||||
**이슈/PR:** #131, PR #134
|
||||
|
||||
### 해외 스캐너 개선: 랭킹 연동 + 변동성 우선 선별
|
||||
|
||||
**배경:**
|
||||
- `run_overnight` 실운영에서 미국장 동안 거래가 0건 지속
|
||||
- 원인: 해외 시장에서도 국내 랭킹/일봉 API 경로를 사용하던 구조적 불일치
|
||||
|
||||
**요구사항:**
|
||||
1. 해외 시장도 랭킹 API 기반 유니버스 탐색 지원
|
||||
2. 단순 상승률/거래대금 상위가 아니라, **변동성이 큰 종목**을 우선 선별
|
||||
3. 고정 티커 fallback 금지
|
||||
|
||||
**구현 결과:**
|
||||
- `src/broker/overseas.py`
|
||||
- `fetch_overseas_rankings()` 추가 (fluctuation / volume)
|
||||
- 해외 랭킹 API 경로/TR_ID를 설정값으로 오버라이드 가능하게 구현
|
||||
- `src/analysis/smart_scanner.py`
|
||||
- market-aware 스캔(국내/해외 분리)
|
||||
- 해외: 랭킹 API 유니버스 + 변동성 우선 점수(일변동률 vs 장중 고저폭)
|
||||
- 거래대금/거래량 랭킹은 유동성 보정 점수로 활용
|
||||
- 랭킹 실패 시에는 동적 유니버스(active/recent/holdings)만 사용
|
||||
- `src/config.py`
|
||||
- `OVERSEAS_RANKING_*` 설정 추가
|
||||
|
||||
**효과:**
|
||||
- 해외 시장에서 스캐너 후보 0개로 정지되는 상황 완화
|
||||
- 종목 선정 기준이 단순 상승률 중심에서 변동성 중심으로 개선
|
||||
- 고정 티커 없이도 시장 주도 변동 종목 탐지 가능
|
||||
|
||||
### 국내 스캐너/주문수량 정렬: 변동성 우선 + 리스크 타기팅
|
||||
|
||||
**배경:**
|
||||
- 해외만 변동성 우선으로 동작하고, 국내는 RSI/거래량 필터 중심으로 동작해 시장 간 전략 일관성이 낮았음
|
||||
- 매수 수량이 고정 1주라서 변동성 구간별 익스포저 관리가 어려웠음
|
||||
|
||||
**요구사항:**
|
||||
1. 국내 스캐너도 변동성 우선 선별로 해외와 통일
|
||||
2. 고변동 종목일수록 포지션 크기를 줄이는 수량 산식 적용
|
||||
|
||||
**구현 결과:**
|
||||
- `src/analysis/smart_scanner.py`
|
||||
- 국내: `fluctuation ranking + volume ranking bonus` 기반 점수화로 전환
|
||||
- 점수는 `max(abs(change_rate), intraday_range_pct)` 중심으로 계산
|
||||
- 국내 랭킹 응답 스키마 키(`price`, `change_rate`, `volume`) 파싱 보강
|
||||
- `src/main.py`
|
||||
- `_determine_order_quantity()` 추가
|
||||
- BUY 시 변동성 점수 기반 동적 수량 산정 적용
|
||||
- `trading_cycle`, `run_daily_session` 경로 모두 동일 수량 로직 사용
|
||||
- `src/config.py`
|
||||
- `POSITION_SIZING_*` 설정 추가
|
||||
|
||||
**효과:**
|
||||
- 국내/해외 스캐너 기준이 변동성 중심으로 일관화
|
||||
- 고변동 구간에서 자동 익스포저 축소, 저변동 구간에서 과소진입 완화
|
||||
|
||||
## 2026-02-18
|
||||
|
||||
### KIS 해외 랭킹 API 404 에러 수정
|
||||
|
||||
**배경:**
|
||||
- KIS 해외주식 랭킹 API(`fetch_overseas_rankings`)가 모든 거래소에서 HTTP 404를 반환
|
||||
- Smart Scanner가 해외 시장 후보 종목을 찾지 못해 거래가 전혀 실행되지 않음
|
||||
|
||||
**근본 원인:**
|
||||
- TR_ID, API 경로, 거래소 코드가 모두 KIS 공식 문서와 불일치
|
||||
|
||||
**구현 결과:**
|
||||
- `src/config.py`: TR_ID/Path 기본값을 KIS 공식 스펙으로 수정
|
||||
- `src/broker/overseas.py`: 랭킹 API 전용 거래소 코드 매핑 추가 (NASD→NAS, NYSE→NYS, AMEX→AMS), 올바른 API 파라미터 사용
|
||||
- `tests/test_overseas_broker.py`: 19개 단위 테스트 추가
|
||||
|
||||
**효과:**
|
||||
- 해외 시장 랭킹 스캔이 정상 동작하여 Smart Scanner가 후보 종목 탐지 가능
|
||||
|
||||
### Gemini prompt_override 미적용 버그 수정
|
||||
|
||||
**배경:**
|
||||
- `run_overnight` 실행 시 모든 시장에서 Playbook 생성 실패 (`JSONDecodeError`)
|
||||
- defensive playbook으로 폴백되어 모든 종목이 HOLD 처리
|
||||
|
||||
**근본 원인:**
|
||||
- `pre_market_planner.py`가 `market_data["prompt_override"]`에 Playbook 전용 프롬프트를 넣어 `gemini.decide()` 호출
|
||||
- `gemini_client.py`의 `decide()` 메서드가 `prompt_override` 키를 전혀 확인하지 않고 항상 일반 트레이드 결정 프롬프트 생성
|
||||
- Gemini가 Playbook JSON 대신 일반 트레이드 결정을 반환하여 파싱 실패
|
||||
|
||||
**구현 결과:**
|
||||
- `src/brain/gemini_client.py`: `decide()` 메서드에서 `prompt_override` 우선 사용 로직 추가
|
||||
- `tests/test_brain.py`: 3개 테스트 추가 (override 전달, optimization 우회, 미지정 시 기존 동작 유지)
|
||||
|
||||
**이슈/PR:** #143
|
||||
|
||||
### 미국장 거래 미실행 근본 원인 분석 및 수정 (자율 실행 세션)
|
||||
|
||||
**배경:**
|
||||
- 사용자 요청: "미국장 열면 프로그램 돌려서 거래 한 번도 못 한 거 꼭 원인 찾아서 해결해줘"
|
||||
- 프로그램을 미국장 개장(9:30 AM EST) 전부터 실행하여 실시간 로그를 분석
|
||||
|
||||
**발견된 근본 원인 #1: Defensive Playbook — BUY 조건 없음**
|
||||
|
||||
- Gemini free tier (20 RPD) 소진 → `generate_playbook()` 실패 → `_defensive_playbook()` 폴백
|
||||
- Defensive playbook은 `price_change_pct_below: -3.0 → SELL` 조건만 존재, BUY 조건 없음
|
||||
- ScenarioEngine이 항상 HOLD 반환 → 거래 0건
|
||||
|
||||
**수정 #1 (PR #146, Issue #145):**
|
||||
- `src/strategy/pre_market_planner.py`: `_smart_fallback_playbook()` 메서드 추가
|
||||
- 스캐너 signal 기반 BUY 조건 생성: `momentum → volume_ratio_above`, `oversold → rsi_below`
|
||||
- 기존 defensive stop-loss SELL 조건 유지
|
||||
- Gemini 실패 시 defensive → smart fallback으로 전환
|
||||
- 테스트 10개 추가
|
||||
|
||||
**발견된 근본 원인 #2: 가격 API 거래소 코드 불일치 + VTS 잔고 API 오류**
|
||||
|
||||
실제 로그:
|
||||
```
|
||||
Scenario matched for MRNX: BUY (confidence=80) ✓
|
||||
Decision for EWUS (NYSE American): BUY (confidence=80) ✓
|
||||
Skip BUY APLZ (NYSE American): no affordable quantity (cash=0.00, price=0.00) ✗
|
||||
```
|
||||
|
||||
- `get_overseas_price()`: `NASD`/`NYSE`/`AMEX` 전송 → API가 `NAS`/`NYS`/`AMS` 기대 → 빈 응답 → `price=0`
|
||||
- `VTTS3012R` 잔고 API: "ERROR : INPUT INVALID_CHECK_ACNO" → `total_cash=0`
|
||||
- 결과: `_determine_order_quantity()` 가 0 반환 → 주문 건너뜀
|
||||
|
||||
**수정 #2 (PR #148, Issue #147):**
|
||||
- `src/broker/overseas.py`: `_PRICE_EXCHANGE_MAP = _RANKING_EXCHANGE_MAP` 추가, 가격 API에 매핑 적용
|
||||
- `src/config.py`: `PAPER_OVERSEAS_CASH: float = Field(default=50000.0)` — paper 모드 시뮬레이션 잔고
|
||||
- `src/main.py`: 잔고 0일 때 PAPER_OVERSEAS_CASH 폴백, 가격 0일 때 candidate.price 폴백
|
||||
- 테스트 8개 추가
|
||||
|
||||
**효과:**
|
||||
- BUY 결정 → 실제 주문 전송까지의 파이프라인이 완전히 동작
|
||||
- Paper 모드에서 KIS VTS 해외 잔고 API 오류에 관계없이 시뮬레이션 거래 가능
|
||||
|
||||
**이슈/PR:** #145, #146, #147, #148
|
||||
|
||||
### 해외주식 시장가 주문 거부 수정 (Fix #3, 연속 발견)
|
||||
|
||||
**배경:**
|
||||
- Fix #147 적용 후 주문 전송 시작 → KIS VTS가 거부: "지정가만 가능한 상품입니다"
|
||||
|
||||
**근본 원인:**
|
||||
- `trading_cycle()`, `run_daily_session()` 양쪽에서 `send_overseas_order(price=0.0)` 하드코딩
|
||||
- `price=0` → `ORD_DVSN="01"` (시장가) 전송 → KIS VTS 거부
|
||||
- Fix #147에서 이미 `current_price`를 올바르게 계산했으나 주문 시 미사용
|
||||
|
||||
**구현 결과:**
|
||||
- `src/main.py`: 두 곳에서 `price=0.0` → `price=current_price`/`price=stock_data["current_price"]`
|
||||
- `tests/test_main.py`: 회귀 테스트 `test_overseas_buy_order_uses_limit_price` 추가
|
||||
|
||||
**최종 확인 로그:**
|
||||
```
|
||||
Order result: 모의투자 매수주문이 완료 되었습니다. ✓
|
||||
```
|
||||
|
||||
**이슈/PR:** #149, #150
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-23
|
||||
|
||||
### 국내주식 지정가 전환 및 미체결 처리 (#232)
|
||||
|
||||
**배경:**
|
||||
- 해외주식은 #211에서 지정가로 전환했으나 국내주식은 여전히 `price=0` (시장가)
|
||||
- KRX도 지정가 주문 사용 시 동일한 미체결 위험이 존재
|
||||
- 지정가 전환 + 미체결 처리를 함께 구현
|
||||
|
||||
**구현 내용:**
|
||||
|
||||
1. `src/broker/kis_api.py`
|
||||
- `get_domestic_pending_orders()`: 모의 즉시 `[]`, 실전 `TTTC0084R` GET
|
||||
- `cancel_domestic_order()`: 실전 `TTTC0013U` / 모의 `VTTC0013U`, hashkey 필수
|
||||
|
||||
2. `src/main.py`
|
||||
- import `kr_round_down` 추가
|
||||
- `trading_cycle`, `run_daily_session` 국내 주문 `price=0` → 지정가:
|
||||
BUY +0.2% / SELL -0.2%, `kr_round_down` KRX 틱 반올림 적용
|
||||
- `handle_domestic_pending_orders` 함수: BUY→취소+쿨다운, SELL→취소+재주문(-0.4%, 최대1회)
|
||||
- daily/realtime 두 모드에서 domestic pending 체크 호출 추가
|
||||
|
||||
3. 테스트 14개 추가:
|
||||
- `TestGetDomesticPendingOrders` (3), `TestCancelDomesticOrder` (5)
|
||||
- `TestHandleDomesticPendingOrders` (4), `TestDomesticLimitOrderPrice` (2)
|
||||
|
||||
**이슈/PR:** #232, PR #233
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-24
|
||||
|
||||
### 해외잔고 ghost position 수정 — '모의투자 잔고내역이 없습니다' 반복 방지 (#235)
|
||||
|
||||
**배경:**
|
||||
- 모의투자 실행 시 MLECW, KNRX, NBY, SNSE 등 만료/정지된 종목에 대해
|
||||
`모의투자 잔고내역이 없습니다` 오류가 매 사이클 반복됨
|
||||
|
||||
**근본 원인:**
|
||||
1. `ovrs_cblc_qty` (해외잔고수량, 총 보유) vs `ord_psbl_qty` (주문가능수량, 실제 매도 가능)
|
||||
- 기존 코드: `ovrs_cblc_qty` 우선 사용 → 만료 Warrant가 `ovrs_cblc_qty=289456`이지만 실제 `ord_psbl_qty=0`
|
||||
- startup sync / build_overseas_symbol_universe가 이 종목들을 포지션으로 기록
|
||||
2. SELL 실패 시 DB 포지션이 닫히지 않아 다음 사이클에서도 재시도 (무한 반복)
|
||||
|
||||
**구현 내용:**
|
||||
|
||||
1. `src/main.py` — `_extract_held_codes_from_balance`, `_extract_held_qty_from_balance`
|
||||
- 해외 잔고 필드 우선순위 변경: `ord_psbl_qty` → `ovrs_cblc_qty` → `hldg_qty` (fallback 유지)
|
||||
- KIS 공식 문서(VTTS3012R) 기준: `ord_psbl_qty`가 실제 매도 가능 수량
|
||||
|
||||
2. `src/main.py` — `trading_cycle` ghost-close 처리
|
||||
- 해외 SELL이 `잔고내역이 없습니다`로 실패 시 DB 포지션을 `[ghost-close]` SELL로 종료
|
||||
- exchange code 불일치 등 예외 상황에서 무한 반복 방지
|
||||
|
||||
3. 테스트 7개 추가:
|
||||
- `TestExtractHeldQtyFromBalance` 3개: ord_psbl_qty 우선, 0이면 0 반환, fallback
|
||||
- `TestExtractHeldCodesFromBalance` 2개: ord_psbl_qty=0인 종목 제외, fallback
|
||||
- `TestOverseasGhostPositionClose` 2개: ghost-close 로그 확인, 일반 오류 무시
|
||||
|
||||
**이슈/PR:** #235, PR #236
|
||||
@@ -34,6 +34,12 @@ python -m src.main --mode=paper
|
||||
```
|
||||
Runs the agent in paper-trading mode (no real orders).
|
||||
|
||||
### Start Trading Agent with Dashboard
|
||||
```bash
|
||||
python -m src.main --mode=paper --dashboard
|
||||
```
|
||||
Runs the agent with FastAPI dashboard on `127.0.0.1:8080` (configurable via `DASHBOARD_HOST`/`DASHBOARD_PORT`).
|
||||
|
||||
### Start Trading Agent (Production)
|
||||
```bash
|
||||
docker compose up -d ouroboros
|
||||
@@ -59,7 +65,7 @@ Analyze the last 30 days of trade logs and generate performance metrics.
|
||||
python -m src.evolution.optimizer --evolve
|
||||
```
|
||||
Triggers the evolution engine to:
|
||||
1. Analyze `trade_logs.db` for failing patterns
|
||||
1. Analyze `trades.db` for failing patterns
|
||||
2. Ask Gemini to generate a new strategy
|
||||
3. Run tests on the new strategy
|
||||
4. Create a PR if tests pass
|
||||
@@ -91,12 +97,12 @@ curl http://localhost:8080/health
|
||||
|
||||
### View Trade Logs
|
||||
```bash
|
||||
sqlite3 data/trade_logs.db "SELECT * FROM trades ORDER BY timestamp DESC LIMIT 20;"
|
||||
sqlite3 data/trades.db "SELECT * FROM trades ORDER BY timestamp DESC LIMIT 20;"
|
||||
```
|
||||
|
||||
### Export Trade History
|
||||
```bash
|
||||
sqlite3 -header -csv data/trade_logs.db "SELECT * FROM trades;" > trades_export.csv
|
||||
sqlite3 -header -csv data/trades.db "SELECT * FROM trades;" > trades_export.csv
|
||||
```
|
||||
|
||||
## Safety Checklist (Pre-Deploy)
|
||||
|
||||
206
docs/testing.md
206
docs/testing.md
@@ -2,51 +2,29 @@
|
||||
|
||||
## Test Structure
|
||||
|
||||
**54 tests** across four files. `asyncio_mode = "auto"` in pyproject.toml — async tests need no special decorator.
|
||||
**551 tests** across **25 files**. `asyncio_mode = "auto"` in pyproject.toml — async tests need no special decorator.
|
||||
|
||||
The `settings` fixture in `conftest.py` provides safe defaults with test credentials and in-memory DB.
|
||||
|
||||
### Test Files
|
||||
|
||||
#### `tests/test_risk.py` (11 tests)
|
||||
- Circuit breaker boundaries
|
||||
- Fat-finger edge cases
|
||||
#### Core Components
|
||||
|
||||
##### `tests/test_risk.py` (14 tests)
|
||||
- Circuit breaker boundaries and exact threshold triggers
|
||||
- Fat-finger edge cases and percentage validation
|
||||
- P&L calculation edge cases
|
||||
- Order validation logic
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
def test_circuit_breaker_exact_threshold(risk_manager):
|
||||
"""Circuit breaker should trip at exactly -3.0%."""
|
||||
with pytest.raises(CircuitBreakerTripped):
|
||||
risk_manager.validate_order(
|
||||
current_pnl_pct=-3.0,
|
||||
order_amount=1000,
|
||||
total_cash=10000
|
||||
)
|
||||
```
|
||||
|
||||
#### `tests/test_broker.py` (6 tests)
|
||||
##### `tests/test_broker.py` (11 tests)
|
||||
- OAuth token lifecycle
|
||||
- Rate limiting enforcement
|
||||
- Hash key generation
|
||||
- Network error handling
|
||||
- SSL context configuration
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
async def test_rate_limiter(broker):
|
||||
"""Rate limiter should delay requests to stay under 10 RPS."""
|
||||
start = time.monotonic()
|
||||
for _ in range(15): # 15 requests
|
||||
await broker._rate_limiter.acquire()
|
||||
elapsed = time.monotonic() - start
|
||||
assert elapsed >= 1.0 # Should take at least 1 second
|
||||
```
|
||||
|
||||
#### `tests/test_brain.py` (18 tests)
|
||||
- Valid JSON parsing
|
||||
- Markdown-wrapped JSON handling
|
||||
##### `tests/test_brain.py` (24 tests)
|
||||
- Valid JSON parsing and markdown-wrapped JSON handling
|
||||
- Malformed JSON fallback
|
||||
- Missing fields handling
|
||||
- Invalid action validation
|
||||
@@ -54,33 +32,143 @@ async def test_rate_limiter(broker):
|
||||
- Empty response handling
|
||||
- Prompt construction for different markets
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
async def test_confidence_below_threshold_forces_hold(brain):
|
||||
"""Decisions below confidence threshold should force HOLD."""
|
||||
decision = brain.parse_response('{"action":"BUY","confidence":70,"rationale":"test"}')
|
||||
assert decision.action == "HOLD"
|
||||
assert decision.confidence == 70
|
||||
```
|
||||
|
||||
#### `tests/test_market_schedule.py` (19 tests)
|
||||
##### `tests/test_market_schedule.py` (24 tests)
|
||||
- Market open/close logic
|
||||
- Timezone handling (UTC, Asia/Seoul, America/New_York, etc.)
|
||||
- DST (Daylight Saving Time) transitions
|
||||
- Weekend handling
|
||||
- Lunch break logic
|
||||
- Weekend handling and lunch break logic
|
||||
- Multiple market filtering
|
||||
- Next market open calculation
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
def test_is_market_open_during_trading_hours():
|
||||
"""Market should be open during regular trading hours."""
|
||||
# KRX: 9:00-15:30 KST, no lunch break
|
||||
market = MARKETS["KR"]
|
||||
trading_time = datetime(2026, 2, 3, 10, 0, tzinfo=ZoneInfo("Asia/Seoul")) # Monday 10:00
|
||||
assert is_market_open(market, trading_time) is True
|
||||
```
|
||||
##### `tests/test_db.py` (3 tests)
|
||||
- Database initialization and table creation
|
||||
- Trade logging with all fields (market, exchange_code, decision_id)
|
||||
- Query and retrieval operations
|
||||
|
||||
##### `tests/test_main.py` (37 tests)
|
||||
- Trading loop orchestration
|
||||
- Market iteration and stock processing
|
||||
- Dashboard integration (`--dashboard` flag)
|
||||
- Telegram command handler wiring
|
||||
- Error handling and graceful shutdown
|
||||
|
||||
#### Strategy & Playbook (v2)
|
||||
|
||||
##### `tests/test_pre_market_planner.py` (37 tests)
|
||||
- Pre-market playbook generation
|
||||
- Gemini API integration for scenario creation
|
||||
- Timeout handling and defensive playbook fallback
|
||||
- Multi-market playbook generation
|
||||
|
||||
##### `tests/test_scenario_engine.py` (44 tests)
|
||||
- Scenario matching against live market data
|
||||
- Confidence scoring and threshold filtering
|
||||
- Multiple scenario type handling
|
||||
- Edge cases (no match, partial match, expired scenarios)
|
||||
|
||||
##### `tests/test_playbook_store.py` (23 tests)
|
||||
- Playbook persistence to SQLite
|
||||
- Date-based retrieval and market filtering
|
||||
- Playbook status management (generated, active, expired)
|
||||
- JSON serialization/deserialization
|
||||
|
||||
##### `tests/test_strategy_models.py` (33 tests)
|
||||
- Pydantic model validation for scenarios, playbooks, decisions
|
||||
- Field constraints and default values
|
||||
- Serialization round-trips
|
||||
|
||||
#### Analysis & Scanning
|
||||
|
||||
##### `tests/test_volatility.py` (24 tests)
|
||||
- ATR and RSI calculation accuracy
|
||||
- Volume surge ratio computation
|
||||
- Momentum scoring
|
||||
- Breakout/breakdown pattern detection
|
||||
- Market scanner watchlist management
|
||||
|
||||
##### `tests/test_smart_scanner.py` (13 tests)
|
||||
- Python-first filtering pipeline
|
||||
- RSI and volume ratio filter logic
|
||||
- Candidate scoring and ranking
|
||||
- Fallback to static watchlist
|
||||
|
||||
#### Context & Memory
|
||||
|
||||
##### `tests/test_context.py` (18 tests)
|
||||
- L1-L7 layer storage and retrieval
|
||||
- Context key-value CRUD operations
|
||||
- Timeframe-based queries
|
||||
- Layer metadata management
|
||||
|
||||
##### `tests/test_context_scheduler.py` (5 tests)
|
||||
- Periodic context aggregation scheduling
|
||||
- Layer summarization triggers
|
||||
|
||||
#### Evolution & Review
|
||||
|
||||
##### `tests/test_evolution.py` (24 tests)
|
||||
- Strategy optimization loop
|
||||
- High-confidence losing trade analysis
|
||||
- Generated strategy validation
|
||||
|
||||
##### `tests/test_daily_review.py` (10 tests)
|
||||
- End-of-day review generation
|
||||
- Trade performance summarization
|
||||
- Context layer (L6_DAILY) integration
|
||||
|
||||
##### `tests/test_scorecard.py` (3 tests)
|
||||
- Daily scorecard metrics calculation
|
||||
- Win rate, P&L, confidence tracking
|
||||
|
||||
#### Notifications & Commands
|
||||
|
||||
##### `tests/test_telegram.py` (25 tests)
|
||||
- Message sending and formatting
|
||||
- Rate limiting (leaky bucket)
|
||||
- Error handling (network timeout, invalid token)
|
||||
- Auto-disable on missing credentials
|
||||
- Notification types (trade, circuit breaker, fat-finger, market events)
|
||||
|
||||
##### `tests/test_telegram_commands.py` (31 tests)
|
||||
- 9 command handlers (/help, /status, /positions, /report, /scenarios, /review, /dashboard, /stop, /resume)
|
||||
- Long polling and command dispatch
|
||||
- Authorization filtering by chat_id
|
||||
- Command response formatting
|
||||
|
||||
#### Dashboard
|
||||
|
||||
##### `tests/test_dashboard.py` (14 tests)
|
||||
- FastAPI endpoint responses (8 API routes)
|
||||
- Status, playbook, scorecard, performance, context, decisions, scenarios
|
||||
- Query parameter handling (market, date, limit)
|
||||
|
||||
#### Performance & Quality
|
||||
|
||||
##### `tests/test_token_efficiency.py` (34 tests)
|
||||
- Gemini token usage optimization
|
||||
- Prompt size reduction verification
|
||||
- Cache effectiveness
|
||||
|
||||
##### `tests/test_latency_control.py` (30 tests)
|
||||
- API call latency measurement
|
||||
- Rate limiter timing accuracy
|
||||
- Async operation overhead
|
||||
|
||||
##### `tests/test_decision_logger.py` (9 tests)
|
||||
- Decision audit trail completeness
|
||||
- Context snapshot capture
|
||||
- Outcome tracking (P&L, accuracy)
|
||||
|
||||
##### `tests/test_data_integration.py` (38 tests)
|
||||
- External data source integration
|
||||
- News API, market data, economic calendar
|
||||
- Error handling for API failures
|
||||
|
||||
##### `tests/test_backup.py` (23 tests)
|
||||
- Backup scheduler and execution
|
||||
- Cloud storage (S3) upload
|
||||
- Health monitoring
|
||||
- Data export functionality
|
||||
|
||||
## Coverage Requirements
|
||||
|
||||
@@ -91,20 +179,6 @@ Check coverage:
|
||||
pytest -v --cov=src --cov-report=term-missing
|
||||
```
|
||||
|
||||
Expected output:
|
||||
```
|
||||
Name Stmts Miss Cover Missing
|
||||
-----------------------------------------------------------
|
||||
src/brain/gemini_client.py 85 5 94% 165-169
|
||||
src/broker/kis_api.py 120 12 90% ...
|
||||
src/core/risk_manager.py 35 2 94% ...
|
||||
src/db.py 25 1 96% ...
|
||||
src/main.py 150 80 47% (excluded from CI)
|
||||
src/markets/schedule.py 95 3 97% ...
|
||||
-----------------------------------------------------------
|
||||
TOTAL 510 103 80%
|
||||
```
|
||||
|
||||
**Note:** `main.py` has lower coverage as it contains the main loop which is tested via integration/manual testing.
|
||||
|
||||
## Test Configuration
|
||||
|
||||
@@ -6,12 +6,64 @@
|
||||
|
||||
1. **Create Gitea Issue First** — All features, bug fixes, and policy changes require a Gitea issue before any code is written
|
||||
2. **Create Feature Branch** — Branch from `main` using format `feature/issue-{N}-{short-description}`
|
||||
- After creating the branch, run `git pull origin main` and rebase to ensure the branch is up to date
|
||||
3. **Implement Changes** — Write code, tests, and documentation on the feature branch
|
||||
4. **Create Pull Request** — Submit PR to `main` branch referencing the issue number
|
||||
5. **Review & Merge** — After approval, merge via PR (squash or merge commit)
|
||||
|
||||
**Never commit directly to `main`.** This policy applies to all changes, no exceptions.
|
||||
|
||||
## Gitea CLI Formatting Troubleshooting
|
||||
|
||||
Issue/PR 본문 작성 시 줄바꿈(`\n`)이 문자열 그대로 저장되는 문제가 반복될 수 있다. 원인은 `-d "...\n..."` 형태에서 쉘/CLI가 이스케이프를 실제 개행으로 해석하지 않기 때문이다.
|
||||
|
||||
권장 패턴:
|
||||
|
||||
```bash
|
||||
ISSUE_BODY=$(cat <<'EOF'
|
||||
## Summary
|
||||
- 변경 내용 1
|
||||
- 변경 내용 2
|
||||
|
||||
## Why
|
||||
- 배경 1
|
||||
- 배경 2
|
||||
|
||||
## Scope
|
||||
- 포함 범위
|
||||
- 제외 범위
|
||||
EOF
|
||||
)
|
||||
|
||||
tea issues create \
|
||||
-t "docs: 제목" \
|
||||
-d "$ISSUE_BODY"
|
||||
```
|
||||
|
||||
PR도 동일하게 적용:
|
||||
|
||||
```bash
|
||||
PR_BODY=$(cat <<'EOF'
|
||||
## Summary
|
||||
- ...
|
||||
|
||||
## Validation
|
||||
- python3 scripts/validate_ouroboros_docs.py
|
||||
EOF
|
||||
)
|
||||
|
||||
tea pr create \
|
||||
--base main \
|
||||
--head feature/issue-N-something \
|
||||
--title "docs: ... (#N)" \
|
||||
--description "$PR_BODY"
|
||||
```
|
||||
|
||||
금지 패턴:
|
||||
|
||||
- `-d "line1\nline2"` (웹 UI에 `\n` 문자 그대로 노출될 수 있음)
|
||||
- 본문에 백틱/괄호를 인라인로 넣고 적절한 quoting 없이 즉시 실행
|
||||
|
||||
## Agent Workflow
|
||||
|
||||
**Modern AI development leverages specialized agents for concurrent, efficient task execution.**
|
||||
@@ -73,3 +125,37 @@ task_tool(
|
||||
```
|
||||
|
||||
Use `run_in_background=True` for independent tasks that don't block subsequent work.
|
||||
|
||||
## Code Review Checklist
|
||||
|
||||
**CRITICAL: Every PR review MUST verify plan-implementation consistency.**
|
||||
|
||||
Before approving any PR, the reviewer (human or agent) must check ALL of the following:
|
||||
|
||||
### 1. Plan Consistency (MANDATORY)
|
||||
|
||||
- [ ] **Implementation matches the approved plan** — Compare the actual code changes against the plan created during `EnterPlanMode`. Every item in the plan must be addressed.
|
||||
- [ ] **No unplanned changes** — If the implementation includes changes not in the plan, they must be explicitly justified.
|
||||
- [ ] **No plan items omitted** — If any planned item was skipped, the reason must be documented in the PR description.
|
||||
- [ ] **Scope matches** — The PR does not exceed or fall short of the planned scope.
|
||||
|
||||
### 2. Safety & Constraints
|
||||
|
||||
- [ ] `src/core/risk_manager.py` is unchanged (READ-ONLY)
|
||||
- [ ] Circuit breaker threshold not weakened (only stricter allowed)
|
||||
- [ ] Fat-finger protection (30% max order) still enforced
|
||||
- [ ] Confidence < 80 still forces HOLD
|
||||
- [ ] No hardcoded API keys or secrets
|
||||
|
||||
### 3. Quality
|
||||
|
||||
- [ ] All new/modified code has corresponding tests
|
||||
- [ ] Test coverage >= 80%
|
||||
- [ ] `ruff check src/ tests/` passes (no lint errors)
|
||||
- [ ] No `assert` statements removed from tests
|
||||
|
||||
### 4. Workflow
|
||||
|
||||
- [ ] PR references the Gitea issue number
|
||||
- [ ] Feature branch follows naming convention (`feature/issue-N-description`)
|
||||
- [ ] Commit messages are clear and descriptive
|
||||
|
||||
165
ouroboros_plan_v2.txt
Normal file
165
ouroboros_plan_v2.txt
Normal file
@@ -0,0 +1,165 @@
|
||||
[The Ouroboros] 운영/전략 계획서 v2
|
||||
작성일: 2026-02-26
|
||||
상태: 코드 구현 전 설계안(전략/검증 중심)
|
||||
|
||||
==================================================
|
||||
0) 목적
|
||||
==================================================
|
||||
고정 익절(+3%) 중심 로직에서 벗어나, 다음을 만족하는 실전형 청산 체계로 전환한다.
|
||||
- 수익 구간 보호 (손익 역전 방지)
|
||||
- 변동성 적응형 청산
|
||||
- 예측 모델의 확률 신호를 보조적으로 결합
|
||||
- 과적합 방지를 최우선으로 한 검증 프레임워크
|
||||
|
||||
==================================================
|
||||
1) 핵심 설계 원칙
|
||||
==================================================
|
||||
1. 예측 성능과 전략 성능을 분리 평가
|
||||
- 예측 성능: PR-AUC, Brier, Calibration
|
||||
- 전략 성능: Net PnL, Sharpe, MDD, Profit Factor, Turnover
|
||||
|
||||
2. 시계열 검증 규율 강제
|
||||
- Walk-forward 분할
|
||||
- Purge/Embargo 적용
|
||||
- Random split 금지
|
||||
|
||||
3. 실거래 리얼리즘 우선
|
||||
- 거래비용/슬리피지/체결실패 반영 없는 백테스트 결과는 채택 금지
|
||||
|
||||
==================================================
|
||||
2) 매도 상태기계 (State Machine)
|
||||
==================================================
|
||||
상태:
|
||||
- HOLDING
|
||||
- BE_LOCK
|
||||
- ARMED
|
||||
- EXITED
|
||||
|
||||
정의:
|
||||
- HOLDING: 일반 보유 상태
|
||||
- BE_LOCK: 일정 수익권 진입 시 손절선을 본전(또는 비용 반영 본전)으로 상향
|
||||
- ARMED: 추세 추적(피크 추적) 기반 청산 준비 상태
|
||||
- EXITED: 청산 완료
|
||||
|
||||
전이 규칙(개념):
|
||||
- HOLDING -> BE_LOCK: unrealized_pnl_pct >= be_arm_pct
|
||||
- BE_LOCK -> ARMED: unrealized_pnl_pct >= arm_pct
|
||||
- ARMED -> EXITED: 아래 조건 중 하나 충족
|
||||
1) hard stop 도달
|
||||
2) trailing stop 도달 (peak 대비 하락)
|
||||
3) 모델 하락확률 + 유동성 약화 조건 충족
|
||||
|
||||
상태 전이 구현 규칙(필수):
|
||||
- 매 틱/바 평가 시 "현재 조건이 허용하는 최상위 상태"로 즉시 승격
|
||||
- 순차 if-else로 인한 전이 누락 금지 (예: 갭으로 BE_LOCK/ARMED 동시 충족)
|
||||
- EXITED 조건은 모든 상태보다 우선 평가
|
||||
- 상태 전이 로그에 이전/이후 상태, 전이 사유, 기준 가격/수익률 기록
|
||||
|
||||
==================================================
|
||||
3) 청산 로직 구성 (4중 안전장치)
|
||||
==================================================
|
||||
A. Hard Stop
|
||||
- 계좌/포지션 보호용 절대 하한
|
||||
- 항상 활성화
|
||||
|
||||
B. Dynamic Stop (Break-even Lock)
|
||||
- BE_LOCK 진입 시 손절선을 본전 이상으로 상향
|
||||
- "수익 포지션이 손실로 반전"되는 구조적 리스크 차단
|
||||
|
||||
C. ATR 기반 Trailing Stop
|
||||
- 고정 trail_pct 대신 변동성 적응형 사용
|
||||
- 예시: ExitPrice = PeakPrice - (k * ATR)
|
||||
|
||||
D. 모델 확률 신호
|
||||
- 하락전환 확률(pred_prob)이 임계값 이상일 때 청산 가중
|
||||
- 단독 트리거가 아닌 trailing/리스크 룰 보조 트리거로 사용
|
||||
|
||||
==================================================
|
||||
4) 라벨링 체계 (Triple Barrier)
|
||||
==================================================
|
||||
목표:
|
||||
고정 H-window 라벨 편향을 줄이고, 금융 시계열의 경로 의존성을 반영한다.
|
||||
|
||||
라벨 정의:
|
||||
- Upper barrier (익절)
|
||||
- Lower barrier (손절)
|
||||
- Time barrier (만기)
|
||||
|
||||
규칙:
|
||||
- 세 장벽 중 "먼저 터치한 장벽"으로 라벨 확정
|
||||
- 라벨은 entry 시점 이후 데이터만 사용해 생성
|
||||
- 피처 생성 구간과 라벨 구간을 엄격 분리해 look-ahead bias 방지
|
||||
|
||||
==================================================
|
||||
5) 검증 프레임워크
|
||||
==================================================
|
||||
5.1 분할 방식
|
||||
- Fold 단위 Walk-forward
|
||||
- Purge/Embargo로 인접 샘플 누수 차단
|
||||
|
||||
5.2 비교군(Baseline) 구조
|
||||
- B0: 기존 고정 손절/익절
|
||||
- B1: 모델 없는 trailing only
|
||||
- M1: trailing + 모델 확률 결합
|
||||
|
||||
5.3 채택 기준
|
||||
- M1이 B0/B1 대비 OOS(Out-of-sample)에서 일관된 우위
|
||||
- 단일 구간 성과가 아닌 fold 분포 기준으로 판단
|
||||
|
||||
==================================================
|
||||
6) 실행 아키텍처 원칙
|
||||
==================================================
|
||||
1. 저지연 실행 경로
|
||||
- 실시간 청산 판단은 경량 엔진(룰/GBDT) 담당
|
||||
- LLM은 레짐 판단/비중 조절/상위 의사결정 보조
|
||||
|
||||
2. 체결 현실 반영
|
||||
- 세션 유동성에 따른 슬리피지 페널티 차등 적용
|
||||
- 미체결/재호가/재접수 시나리오를 백테스트에 반영
|
||||
|
||||
==================================================
|
||||
7) 운영 리스크 관리
|
||||
==================================================
|
||||
승격 단계:
|
||||
- Offline backtest -> Paper shadow -> Small-capital live
|
||||
|
||||
중단(Kill Switch):
|
||||
- rolling Sharpe 악화
|
||||
- MDD 한도 초과
|
||||
- 체결 실패율/슬리피지 급등
|
||||
|
||||
Kill Switch 실행 순서(원자적):
|
||||
1) 모든 신규 주문 차단 플래그 ON
|
||||
2) 모든 미체결 주문 취소 요청
|
||||
3) 취소 결과 재조회(실패 건 재시도)
|
||||
4) 포지션 리스크 재계산 후 강제 축소/청산 판단
|
||||
5) 상태/로그 스냅샷 저장 및 운영 경보 발송
|
||||
|
||||
원칙:
|
||||
- 모델이 실패해도 hard stop 기반 보수 모드로 즉시 디그레이드 가능해야 함
|
||||
|
||||
==================================================
|
||||
8) 고정 파라미터(초기안)
|
||||
==================================================
|
||||
(15분봉 단기 스윙 기준 제안)
|
||||
- KR: be_arm_pct=1.2, arm_pct=2.8, atr_period=14, atr_multiplier_k=2.2,
|
||||
time_barrier_bars=26, p_thresh=0.62
|
||||
- US: be_arm_pct=1.0, arm_pct=2.4, atr_period=14, atr_multiplier_k=2.0,
|
||||
time_barrier_bars=32, p_thresh=0.60
|
||||
|
||||
민감도 범위(초기 탐색):
|
||||
- be_arm_pct: KR 0.9~1.8 / US 0.7~1.5
|
||||
- arm_pct: KR 2.2~3.8 / US 1.8~3.2
|
||||
- atr_multiplier_k: KR 1.8~2.8 / US 1.6~2.4
|
||||
- time_barrier_bars: KR 20~36 / US 24~48
|
||||
- p_thresh: 0.55~0.70
|
||||
|
||||
==================================================
|
||||
9) 구현 전 체크리스트
|
||||
==================================================
|
||||
- 파라미터 튜닝 시 nested leakage 방지
|
||||
- 수수료/세금/슬리피지 전부 반영 여부 확인
|
||||
- 세션/타임존/DST 처리 일관성 확인
|
||||
- 모델 버전/설정 해시/실험 로그 재현성 확보
|
||||
|
||||
끝.
|
||||
185
ouroboros_plan_v3.txt
Normal file
185
ouroboros_plan_v3.txt
Normal file
@@ -0,0 +1,185 @@
|
||||
[The Ouroboros] 운영확장 v3
|
||||
작성일: 2026-02-26
|
||||
상태: v2 확장판 / 야간·프리마켓 포함 글로벌 세션 운영 설계안
|
||||
|
||||
==================================================
|
||||
0) 목적
|
||||
==================================================
|
||||
"24시간 무중단 자산 증식" 비전을 위해 거래 세션 범위를 KR 정규장 중심에서
|
||||
NXT/미국 확장 세션까지 확대한다. 핵심은 다음 3가지다.
|
||||
- 세션 인지형 의사결정
|
||||
- 세션별 리스크/비용 차등 적용
|
||||
- 시간장벽의 현실적 재정의
|
||||
|
||||
==================================================
|
||||
1) 세션 모델 (Session-aware Engine)
|
||||
==================================================
|
||||
KR 세션:
|
||||
- NXT_PRE : 08:00 ~ 08:50 (KST)
|
||||
- KRX_REG : 09:00 ~ 15:30 (KST)
|
||||
- NXT_AFTER : 15:30 ~ 20:00 (KST)
|
||||
|
||||
US 세션(KST 관점 운영):
|
||||
- US_DAY : 10:00 ~ 18:00
|
||||
- US_PRE : 18:00 ~ 23:30
|
||||
- US_REG : 23:30 ~ 06:00
|
||||
- US_AFTER : 06:00 ~ 07:00
|
||||
|
||||
원칙:
|
||||
- 모든 피처/신호/주문/로그에 session_id를 명시적으로 포함
|
||||
- 세션 전환 시 상태 업데이트 및 리스크 파라미터 재로딩
|
||||
|
||||
==================================================
|
||||
2) 캘린더/휴장/DST 고정 소스
|
||||
==================================================
|
||||
KR:
|
||||
- 기본: pykrx 또는 FinanceDataReader (KRX 기준)
|
||||
- 예외: 연휴/임시 휴장/NXT 특이 운영은 KIS 공지 기반 보완
|
||||
|
||||
US:
|
||||
- pandas_market_calendars (NYSE 기준)
|
||||
- 2026 DST:
|
||||
- 시작: 2026-03-08
|
||||
- 종료: 2026-11-01
|
||||
|
||||
정합성 규칙:
|
||||
- 스케줄 충돌 시 "거래소 캘린더 > 로컬 추정" 우선
|
||||
- 시장 상태(open/close/half-day)는 주문 엔진 진입 전 최종 검증
|
||||
|
||||
KIS 점검시간 회피 정책(필수):
|
||||
- 브로커 점검/장애 블랙아웃 윈도우는 운영 설정으로 별도 관리
|
||||
- 블랙아웃 구간에는 신규 주문 전송 금지, 취소/정정도 정책적으로 제한
|
||||
- 신호는 유지하되 주문 의도는 Queue에 적재, 복구 후 유효성 재검증 뒤 실행
|
||||
- 복구 직후에는 잔고/미체결/체결내역을 우선 동기화한 뒤 주문 엔진 재가동
|
||||
|
||||
==================================================
|
||||
3) 시간장벽 재정의
|
||||
==================================================
|
||||
v2의 time_barrier_bars 고정값을 v3에서 다음으로 확장:
|
||||
- max_holding_minutes (시장별 기본 만기)
|
||||
- 봉 개수는 세션 길이/간격으로 동적 계산
|
||||
|
||||
기본값:
|
||||
- KR: max_holding_minutes = 2160 (약 3거래일, NXT 포함 관점)
|
||||
- US: max_holding_minutes = 4320 (약 72시간)
|
||||
|
||||
운영 주의:
|
||||
- 고정 "일중 청산"보다 "포지션 유지 시간" 기준 만기 적용
|
||||
- 세션 종료 강제청산 규칙과 충돌 시 우선순위 명시 필요
|
||||
|
||||
==================================================
|
||||
4) 세션별 비용/슬리피지 모델 (보수적)
|
||||
==================================================
|
||||
KRX_REG:
|
||||
- 슬리피지: 2~3틱 (약 0.05%)
|
||||
- 수수료+세금: 0.20% ~ 0.23%
|
||||
|
||||
NXT_AFTER:
|
||||
- 슬리피지: 5~8틱 (약 0.15%)
|
||||
- 수수료+세금: 0.20% ~ 0.23%
|
||||
|
||||
US_REG:
|
||||
- 슬리피지: 2~3틱 (약 0.03%)
|
||||
- 수수료+기타 비용: 0.07% ~ 0.15%
|
||||
|
||||
US_PRE / US_DAY:
|
||||
- 슬리피지: 10틱+ (약 0.3% ~ 0.5%)
|
||||
- 수수료+기타 비용: 0.07% ~ 0.15%
|
||||
|
||||
원칙:
|
||||
- 백테스트 체결가는 세션별 보수 가정 적용
|
||||
- 저유동 세션은 자동 보수 모드(p_thresh 상향, atr_k 상향) 권장
|
||||
- 백테스트 체결가 기본은 "불리한 방향 체결" 가정 (단순 close 체결 금지)
|
||||
|
||||
세션별 주문 유형 강제(필수):
|
||||
- KRX_REG / US_REG: 지정가 우선, 시장가 제한적 허용
|
||||
- NXT_AFTER / US_PRE / US_DAY / US_AFTER: 시장가 금지
|
||||
- 저유동 세션은 최우선 지정가 또는 IOC/FOK(가격 보호 한도 포함)만 허용
|
||||
- 주문 실패 시 재호가 간격/횟수 상한을 두고, 초과 시 주문 철회
|
||||
|
||||
==================================================
|
||||
5) 포지션/잔고 통합 규칙 (KIS 특성 반영)
|
||||
==================================================
|
||||
문제:
|
||||
- KRX/NXT 잔고 조회가 venue 단위로 분리되거나 반영 지연 가능
|
||||
|
||||
규칙:
|
||||
- 종목 식별은 동일 종목코드(또는 ISIN) 기준 통합 포지션으로 관리
|
||||
- 다만 주문 가능 수량은 venue별 API 응답을 최종 기준으로 사용
|
||||
- 매도 가능 수량 검증은 주문 직전 재조회로 확정
|
||||
|
||||
==================================================
|
||||
6) 마감 강제청산/오버나잇 예외 규칙
|
||||
==================================================
|
||||
기본 원칙:
|
||||
- 모든 포지션에 대해 세션 종료 10분 전 REDUCE_ALL 검토
|
||||
|
||||
오버나잇 예외 허용 (모두 충족 시):
|
||||
1) ARMED 상태 (예: +2.8% 이상)
|
||||
2) 모델 하락확률 < 0.30
|
||||
3) 포트폴리오 현금 비중 >= 50%
|
||||
|
||||
갭 리스크 통제:
|
||||
- 다음 개장 시 hard stop를 시가 기준으로 재산정
|
||||
- 조건 위반 시 즉시 청산 우선
|
||||
|
||||
Kill Switch 연동:
|
||||
- MDD/실패율 임계치 초과 시 "미체결 전량 취소 -> 신규 주문 차단 -> 리스크 축소" 순서 강제
|
||||
|
||||
==================================================
|
||||
7) 데이터 저장/용량 정책
|
||||
==================================================
|
||||
핵심 테이블(계획):
|
||||
- feature_snapshots
|
||||
- position_states
|
||||
- model_predictions
|
||||
|
||||
저장 규칙:
|
||||
- feature_hash 기반 중복 제거
|
||||
- 가격 변화가 작아도 session_id 변경 시 강제 스냅샷
|
||||
- 월 단위 DB 로테이션 권장 (예: trading_YYYY_MM.db)
|
||||
|
||||
==================================================
|
||||
8) 환율/정산 리스크 정책 (US 필수)
|
||||
==================================================
|
||||
원칙:
|
||||
- USD 노출은 전략 손익과 별도로 환율 손익을 분리 추적
|
||||
- 원화 주문 서비스 사용 시 가환율 체결/익일 정산 리스크를 예수금 규칙에 반영
|
||||
|
||||
운영 규칙:
|
||||
- 환전 시점 정책(사전 환전/수시 환전)을 고정하고 로그에 기록
|
||||
- 최소 USD 버퍼와 KRW 버퍼를 각각 설정해 주문 가능금 부족 리스크 완화
|
||||
- 환율 급변 구간에는 포지션 한도 축소 또는 신규 진입 제한
|
||||
|
||||
==================================================
|
||||
9) v3 실험 매트릭스 (우선 3선)
|
||||
==================================================
|
||||
EXP-KR-01:
|
||||
- 시장: KR
|
||||
- 포커스: NXT 야간 특화
|
||||
- 제안: time barrier 확장(예: 48 bars 상당), p_thresh 상향(0.65)
|
||||
|
||||
EXP-US-01:
|
||||
- 시장: US
|
||||
- 포커스: 21h 준연속 운용
|
||||
- 제안: time barrier 확장(예: 80 bars 상당), atr_k 상향(2.5)
|
||||
|
||||
EXP-HYB-01:
|
||||
- 시장: Global
|
||||
- 포커스: KR 낮 + US 밤 연계
|
||||
- 제안: 레짐 기반 자산배분 자동조절
|
||||
|
||||
==================================================
|
||||
10) 코드 착수 전 최종 확정 체크
|
||||
==================================================
|
||||
1) 세션별 공식 캘린더 소스/우선순위
|
||||
2) 세션별 슬리피지/비용 테이블 수치
|
||||
3) 시장별 max_holding_minutes
|
||||
4) 마감 강제청산 예외 조건 임계값
|
||||
5) 블랙아웃(점검/장애) 시간대와 주문 큐 처리 규칙
|
||||
6) 세션별 허용 주문 유형(시장가 허용 범위 포함)
|
||||
7) 환전/정산 정책 및 통화 버퍼 임계값
|
||||
|
||||
모든 항목 확정 후 Step 1 구현(코드)로 이동.
|
||||
|
||||
끝.
|
||||
@@ -8,6 +8,9 @@ dependencies = [
|
||||
"pydantic>=2.5,<3",
|
||||
"pydantic-settings>=2.1,<3",
|
||||
"google-genai>=1.0,<2",
|
||||
"scipy>=1.11,<2",
|
||||
"fastapi>=0.110,<1",
|
||||
"uvicorn>=0.29,<1",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
96
scripts/backup.sh
Normal file
96
scripts/backup.sh
Normal file
@@ -0,0 +1,96 @@
|
||||
#!/usr/bin/env bash
|
||||
# Automated backup script for The Ouroboros trading system
|
||||
# Runs daily/weekly/monthly backups
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Configuration
|
||||
DB_PATH="${DB_PATH:-data/trade_logs.db}"
|
||||
BACKUP_DIR="${BACKUP_DIR:-data/backups}"
|
||||
PYTHON="${PYTHON:-python3}"
|
||||
|
||||
# Colors for output
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
RED='\033[0;31m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
log_info() {
|
||||
echo -e "${GREEN}[INFO]${NC} $1"
|
||||
}
|
||||
|
||||
log_warn() {
|
||||
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||
}
|
||||
|
||||
log_error() {
|
||||
echo -e "${RED}[ERROR]${NC} $1"
|
||||
}
|
||||
|
||||
# Check if database exists
|
||||
if [ ! -f "$DB_PATH" ]; then
|
||||
log_error "Database not found: $DB_PATH"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Create backup directory
|
||||
mkdir -p "$BACKUP_DIR"
|
||||
|
||||
log_info "Starting backup process..."
|
||||
log_info "Database: $DB_PATH"
|
||||
log_info "Backup directory: $BACKUP_DIR"
|
||||
|
||||
# Determine backup policy based on day of week and month
|
||||
DAY_OF_WEEK=$(date +%u) # 1-7 (Monday-Sunday)
|
||||
DAY_OF_MONTH=$(date +%d)
|
||||
|
||||
if [ "$DAY_OF_MONTH" == "01" ]; then
|
||||
POLICY="monthly"
|
||||
log_info "Running MONTHLY backup (first day of month)"
|
||||
elif [ "$DAY_OF_WEEK" == "7" ]; then
|
||||
POLICY="weekly"
|
||||
log_info "Running WEEKLY backup (Sunday)"
|
||||
else
|
||||
POLICY="daily"
|
||||
log_info "Running DAILY backup"
|
||||
fi
|
||||
|
||||
# Run Python backup script
|
||||
$PYTHON -c "
|
||||
from pathlib import Path
|
||||
from src.backup.scheduler import BackupScheduler, BackupPolicy
|
||||
from src.backup.health_monitor import HealthMonitor
|
||||
|
||||
# Create scheduler
|
||||
scheduler = BackupScheduler(
|
||||
db_path='$DB_PATH',
|
||||
backup_dir=Path('$BACKUP_DIR')
|
||||
)
|
||||
|
||||
# Create backup
|
||||
policy = BackupPolicy.$POLICY.upper()
|
||||
metadata = scheduler.create_backup(policy, verify=True)
|
||||
print(f'Backup created: {metadata.file_path}')
|
||||
print(f'Size: {metadata.size_bytes / 1024 / 1024:.2f} MB')
|
||||
print(f'Checksum: {metadata.checksum}')
|
||||
|
||||
# Cleanup old backups
|
||||
removed = scheduler.cleanup_old_backups()
|
||||
total_removed = sum(removed.values())
|
||||
if total_removed > 0:
|
||||
print(f'Removed {total_removed} old backup(s)')
|
||||
|
||||
# Health check
|
||||
monitor = HealthMonitor('$DB_PATH', Path('$BACKUP_DIR'))
|
||||
status = monitor.get_overall_status()
|
||||
print(f'System health: {status.value}')
|
||||
"
|
||||
|
||||
if [ $? -eq 0 ]; then
|
||||
log_info "Backup completed successfully"
|
||||
else
|
||||
log_error "Backup failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
log_info "Backup process finished"
|
||||
54
scripts/morning_report.sh
Executable file
54
scripts/morning_report.sh
Executable file
@@ -0,0 +1,54 @@
|
||||
#!/usr/bin/env bash
|
||||
# Morning summary for overnight run logs.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
LOG_DIR="${LOG_DIR:-data/overnight}"
|
||||
|
||||
if [ ! -d "$LOG_DIR" ]; then
|
||||
echo "로그 디렉터리가 없습니다: $LOG_DIR"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
latest_run="$(ls -1t "$LOG_DIR"/run_*.log 2>/dev/null | head -n 1 || true)"
|
||||
latest_watchdog="$(ls -1t "$LOG_DIR"/watchdog_*.log 2>/dev/null | head -n 1 || true)"
|
||||
|
||||
if [ -z "$latest_run" ]; then
|
||||
echo "run 로그가 없습니다: $LOG_DIR/run_*.log"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Overnight report"
|
||||
echo "- run log: $latest_run"
|
||||
if [ -n "$latest_watchdog" ]; then
|
||||
echo "- watchdog log: $latest_watchdog"
|
||||
fi
|
||||
|
||||
start_line="$(head -n 1 "$latest_run" || true)"
|
||||
end_line="$(tail -n 1 "$latest_run" || true)"
|
||||
|
||||
info_count="$(rg -c '"level": "INFO"' "$latest_run" || true)"
|
||||
warn_count="$(rg -c '"level": "WARNING"' "$latest_run" || true)"
|
||||
error_count="$(rg -c '"level": "ERROR"' "$latest_run" || true)"
|
||||
critical_count="$(rg -c '"level": "CRITICAL"' "$latest_run" || true)"
|
||||
traceback_count="$(rg -c 'Traceback' "$latest_run" || true)"
|
||||
|
||||
echo "- start: ${start_line:-N/A}"
|
||||
echo "- end: ${end_line:-N/A}"
|
||||
echo "- INFO: ${info_count:-0}"
|
||||
echo "- WARNING: ${warn_count:-0}"
|
||||
echo "- ERROR: ${error_count:-0}"
|
||||
echo "- CRITICAL: ${critical_count:-0}"
|
||||
echo "- Traceback: ${traceback_count:-0}"
|
||||
|
||||
if [ -n "$latest_watchdog" ]; then
|
||||
watchdog_errors="$(rg -c '\[ERROR\]' "$latest_watchdog" || true)"
|
||||
echo "- watchdog ERROR: ${watchdog_errors:-0}"
|
||||
echo ""
|
||||
echo "최근 watchdog 로그:"
|
||||
tail -n 5 "$latest_watchdog" || true
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "최근 앱 로그:"
|
||||
tail -n 20 "$latest_run" || true
|
||||
111
scripts/restore.sh
Normal file
111
scripts/restore.sh
Normal file
@@ -0,0 +1,111 @@
|
||||
#!/usr/bin/env bash
|
||||
# Restore script for The Ouroboros trading system
|
||||
# Restores database from a backup file
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Configuration
|
||||
DB_PATH="${DB_PATH:-data/trade_logs.db}"
|
||||
BACKUP_DIR="${BACKUP_DIR:-data/backups}"
|
||||
PYTHON="${PYTHON:-python3}"
|
||||
|
||||
# Colors for output
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
RED='\033[0;31m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
log_info() {
|
||||
echo -e "${GREEN}[INFO]${NC} $1"
|
||||
}
|
||||
|
||||
log_warn() {
|
||||
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||
}
|
||||
|
||||
log_error() {
|
||||
echo -e "${RED}[ERROR]${NC} $1"
|
||||
}
|
||||
|
||||
# Check if backup directory exists
|
||||
if [ ! -d "$BACKUP_DIR" ]; then
|
||||
log_error "Backup directory not found: $BACKUP_DIR"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
log_info "Available backups:"
|
||||
log_info "=================="
|
||||
|
||||
# List available backups
|
||||
$PYTHON -c "
|
||||
from pathlib import Path
|
||||
from src.backup.scheduler import BackupScheduler
|
||||
|
||||
scheduler = BackupScheduler(
|
||||
db_path='$DB_PATH',
|
||||
backup_dir=Path('$BACKUP_DIR')
|
||||
)
|
||||
|
||||
backups = scheduler.list_backups()
|
||||
|
||||
if not backups:
|
||||
print('No backups found.')
|
||||
exit(1)
|
||||
|
||||
for i, backup in enumerate(backups, 1):
|
||||
size_mb = backup.size_bytes / 1024 / 1024
|
||||
print(f'{i}. [{backup.policy.value.upper()}] {backup.file_path.name}')
|
||||
print(f' Date: {backup.timestamp.strftime(\"%Y-%m-%d %H:%M:%S UTC\")}')
|
||||
print(f' Size: {size_mb:.2f} MB')
|
||||
print()
|
||||
"
|
||||
|
||||
# Ask user to select backup
|
||||
echo ""
|
||||
read -p "Enter backup number to restore (or 'q' to quit): " BACKUP_NUM
|
||||
|
||||
if [ "$BACKUP_NUM" == "q" ]; then
|
||||
log_info "Restore cancelled"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Confirm restoration
|
||||
log_warn "WARNING: This will replace the current database!"
|
||||
log_warn "Current database will be backed up to: ${DB_PATH}.before_restore"
|
||||
read -p "Are you sure you want to continue? (yes/no): " CONFIRM
|
||||
|
||||
if [ "$CONFIRM" != "yes" ]; then
|
||||
log_info "Restore cancelled"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Perform restoration
|
||||
$PYTHON -c "
|
||||
from pathlib import Path
|
||||
from src.backup.scheduler import BackupScheduler
|
||||
|
||||
scheduler = BackupScheduler(
|
||||
db_path='$DB_PATH',
|
||||
backup_dir=Path('$BACKUP_DIR')
|
||||
)
|
||||
|
||||
backups = scheduler.list_backups()
|
||||
backup_index = int('$BACKUP_NUM') - 1
|
||||
|
||||
if backup_index < 0 or backup_index >= len(backups):
|
||||
print('Invalid backup number')
|
||||
exit(1)
|
||||
|
||||
selected = backups[backup_index]
|
||||
print(f'Restoring: {selected.file_path.name}')
|
||||
|
||||
scheduler.restore_backup(selected, verify=True)
|
||||
print('Restore completed successfully')
|
||||
"
|
||||
|
||||
if [ $? -eq 0 ]; then
|
||||
log_info "Database restored successfully"
|
||||
else
|
||||
log_error "Restore failed"
|
||||
exit 1
|
||||
fi
|
||||
87
scripts/run_overnight.sh
Executable file
87
scripts/run_overnight.sh
Executable file
@@ -0,0 +1,87 @@
|
||||
#!/usr/bin/env bash
|
||||
# Start The Ouroboros overnight with logs and watchdog.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
LOG_DIR="${LOG_DIR:-data/overnight}"
|
||||
CHECK_INTERVAL="${CHECK_INTERVAL:-30}"
|
||||
TMUX_AUTO="${TMUX_AUTO:-true}"
|
||||
TMUX_ATTACH="${TMUX_ATTACH:-true}"
|
||||
TMUX_SESSION_PREFIX="${TMUX_SESSION_PREFIX:-ouroboros_overnight}"
|
||||
|
||||
if [ -z "${APP_CMD:-}" ]; then
|
||||
if [ -x ".venv/bin/python" ]; then
|
||||
PYTHON_BIN=".venv/bin/python"
|
||||
elif command -v python3 >/dev/null 2>&1; then
|
||||
PYTHON_BIN="python3"
|
||||
elif command -v python >/dev/null 2>&1; then
|
||||
PYTHON_BIN="python"
|
||||
else
|
||||
echo ".venv/bin/python 또는 python3/python 실행 파일을 찾을 수 없습니다."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
dashboard_port="${DASHBOARD_PORT:-8080}"
|
||||
|
||||
APP_CMD="DASHBOARD_PORT=$dashboard_port $PYTHON_BIN -m src.main --mode=live --dashboard"
|
||||
fi
|
||||
|
||||
mkdir -p "$LOG_DIR"
|
||||
|
||||
timestamp="$(date +"%Y%m%d_%H%M%S")"
|
||||
RUN_LOG="$LOG_DIR/run_${timestamp}.log"
|
||||
WATCHDOG_LOG="$LOG_DIR/watchdog_${timestamp}.log"
|
||||
PID_FILE="$LOG_DIR/app.pid"
|
||||
WATCHDOG_PID_FILE="$LOG_DIR/watchdog.pid"
|
||||
|
||||
if [ -f "$PID_FILE" ]; then
|
||||
old_pid="$(cat "$PID_FILE" || true)"
|
||||
if [ -n "$old_pid" ] && kill -0 "$old_pid" 2>/dev/null; then
|
||||
echo "앱이 이미 실행 중입니다. pid=$old_pid"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "[$(date -u +"%Y-%m-%dT%H:%M:%SZ")] starting: $APP_CMD" | tee -a "$RUN_LOG"
|
||||
nohup bash -lc "$APP_CMD" >>"$RUN_LOG" 2>&1 &
|
||||
app_pid=$!
|
||||
echo "$app_pid" > "$PID_FILE"
|
||||
|
||||
echo "[$(date -u +"%Y-%m-%dT%H:%M:%SZ")] app pid=$app_pid" | tee -a "$RUN_LOG"
|
||||
|
||||
nohup env PID_FILE="$PID_FILE" LOG_FILE="$WATCHDOG_LOG" CHECK_INTERVAL="$CHECK_INTERVAL" \
|
||||
bash scripts/watchdog.sh >/dev/null 2>&1 &
|
||||
watchdog_pid=$!
|
||||
echo "$watchdog_pid" > "$WATCHDOG_PID_FILE"
|
||||
|
||||
cat <<EOF
|
||||
시작 완료
|
||||
- app pid: $app_pid
|
||||
- watchdog pid: $watchdog_pid
|
||||
- app log: $RUN_LOG
|
||||
- watchdog log: $WATCHDOG_LOG
|
||||
|
||||
실시간 확인:
|
||||
tail -f "$RUN_LOG"
|
||||
tail -f "$WATCHDOG_LOG"
|
||||
EOF
|
||||
|
||||
if [ "$TMUX_AUTO" = "true" ]; then
|
||||
if ! command -v tmux >/dev/null 2>&1; then
|
||||
echo "tmux를 찾지 못해 자동 세션 생성은 건너뜁니다."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
session_name="${TMUX_SESSION_PREFIX}_${timestamp}"
|
||||
window_name="overnight"
|
||||
tmux new-session -d -s "$session_name" -n "$window_name" "tail -f '$RUN_LOG'"
|
||||
tmux split-window -t "${session_name}:${window_name}" -v "tail -f '$WATCHDOG_LOG'"
|
||||
tmux select-layout -t "${session_name}:${window_name}" even-vertical
|
||||
|
||||
echo "tmux session 생성: $session_name"
|
||||
echo "수동 접속: tmux attach -t $session_name"
|
||||
|
||||
if [ -z "${TMUX:-}" ] && [ "$TMUX_ATTACH" = "true" ]; then
|
||||
tmux attach -t "$session_name"
|
||||
fi
|
||||
fi
|
||||
76
scripts/stop_overnight.sh
Executable file
76
scripts/stop_overnight.sh
Executable file
@@ -0,0 +1,76 @@
|
||||
#!/usr/bin/env bash
|
||||
# Stop The Ouroboros overnight app/watchdog/tmux session.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
LOG_DIR="${LOG_DIR:-data/overnight}"
|
||||
PID_FILE="$LOG_DIR/app.pid"
|
||||
WATCHDOG_PID_FILE="$LOG_DIR/watchdog.pid"
|
||||
TMUX_SESSION_PREFIX="${TMUX_SESSION_PREFIX:-ouroboros_overnight}"
|
||||
KILL_TIMEOUT="${KILL_TIMEOUT:-5}"
|
||||
|
||||
stop_pid() {
|
||||
local name="$1"
|
||||
local pid="$2"
|
||||
|
||||
if [ -z "$pid" ]; then
|
||||
echo "$name PID가 비어 있습니다."
|
||||
return 1
|
||||
fi
|
||||
|
||||
if ! kill -0 "$pid" 2>/dev/null; then
|
||||
echo "$name 프로세스가 이미 종료됨 (pid=$pid)"
|
||||
return 0
|
||||
fi
|
||||
|
||||
kill "$pid" 2>/dev/null || true
|
||||
for _ in $(seq 1 "$KILL_TIMEOUT"); do
|
||||
if ! kill -0 "$pid" 2>/dev/null; then
|
||||
echo "$name 종료됨 (pid=$pid)"
|
||||
return 0
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
kill -9 "$pid" 2>/dev/null || true
|
||||
if ! kill -0 "$pid" 2>/dev/null; then
|
||||
echo "$name 강제 종료됨 (pid=$pid)"
|
||||
return 0
|
||||
fi
|
||||
|
||||
echo "$name 종료 실패 (pid=$pid)"
|
||||
return 1
|
||||
}
|
||||
|
||||
status=0
|
||||
|
||||
if [ -f "$WATCHDOG_PID_FILE" ]; then
|
||||
watchdog_pid="$(cat "$WATCHDOG_PID_FILE" || true)"
|
||||
stop_pid "watchdog" "$watchdog_pid" || status=1
|
||||
rm -f "$WATCHDOG_PID_FILE"
|
||||
else
|
||||
echo "watchdog pid 파일 없음: $WATCHDOG_PID_FILE"
|
||||
fi
|
||||
|
||||
if [ -f "$PID_FILE" ]; then
|
||||
app_pid="$(cat "$PID_FILE" || true)"
|
||||
stop_pid "app" "$app_pid" || status=1
|
||||
rm -f "$PID_FILE"
|
||||
else
|
||||
echo "app pid 파일 없음: $PID_FILE"
|
||||
fi
|
||||
|
||||
if command -v tmux >/dev/null 2>&1; then
|
||||
sessions="$(tmux ls 2>/dev/null | awk -F: -v p="$TMUX_SESSION_PREFIX" '$1 ~ "^" p "_" {print $1}')"
|
||||
if [ -n "$sessions" ]; then
|
||||
while IFS= read -r s; do
|
||||
[ -z "$s" ] && continue
|
||||
tmux kill-session -t "$s" 2>/dev/null || true
|
||||
echo "tmux 세션 종료: $s"
|
||||
done <<< "$sessions"
|
||||
else
|
||||
echo "종료할 tmux 세션 없음 (prefix=${TMUX_SESSION_PREFIX}_)"
|
||||
fi
|
||||
fi
|
||||
|
||||
exit "$status"
|
||||
140
scripts/validate_ouroboros_docs.py
Executable file
140
scripts/validate_ouroboros_docs.py
Executable file
@@ -0,0 +1,140 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Validate Ouroboros planning docs for metadata, links, and ID consistency."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
DOC_DIR = Path("docs/ouroboros")
|
||||
META_PATTERN = re.compile(
|
||||
r"<!--\n"
|
||||
r"Doc-ID: (?P<doc_id>[^\n]+)\n"
|
||||
r"Version: (?P<version>[^\n]+)\n"
|
||||
r"Status: (?P<status>[^\n]+)\n"
|
||||
r"Owner: (?P<owner>[^\n]+)\n"
|
||||
r"Updated: (?P<updated>\d{4}-\d{2}-\d{2})\n"
|
||||
r"-->",
|
||||
re.MULTILINE,
|
||||
)
|
||||
ID_PATTERN = re.compile(r"\b(?:REQ|RULE|TASK|TEST|DOC)-[A-Z0-9-]+-\d{3}\b")
|
||||
DEF_PATTERN = re.compile(r"^-\s+`(?P<id>(?:REQ|RULE|TASK|TEST|DOC)-[A-Z0-9-]+-\d{3})`", re.MULTILINE)
|
||||
LINK_PATTERN = re.compile(r"\[[^\]]+\]\((?P<link>[^)]+)\)")
|
||||
LINE_DEF_PATTERN = re.compile(r"^-\s+`(?P<id>(?:REQ|RULE|TASK|TEST|DOC)-[A-Z0-9-]+-\d{3})`.*$", re.MULTILINE)
|
||||
|
||||
|
||||
def iter_docs() -> list[Path]:
|
||||
return sorted([p for p in DOC_DIR.glob("*.md") if p.is_file()])
|
||||
|
||||
|
||||
def validate_metadata(path: Path, text: str, errors: list[str], doc_ids: dict[str, Path]) -> None:
|
||||
match = META_PATTERN.search(text)
|
||||
if not match:
|
||||
errors.append(f"{path}: missing or malformed metadata block")
|
||||
return
|
||||
doc_id = match.group("doc_id").strip()
|
||||
if doc_id in doc_ids:
|
||||
errors.append(f"{path}: duplicate Doc-ID {doc_id} (already in {doc_ids[doc_id]})")
|
||||
else:
|
||||
doc_ids[doc_id] = path
|
||||
|
||||
|
||||
def validate_links(path: Path, text: str, errors: list[str]) -> None:
|
||||
for m in LINK_PATTERN.finditer(text):
|
||||
link = m.group("link").strip()
|
||||
if not link or link.startswith("http") or link.startswith("#"):
|
||||
continue
|
||||
if link.startswith("/"):
|
||||
target = Path(link)
|
||||
else:
|
||||
target = (path.parent / link).resolve()
|
||||
if not target.exists():
|
||||
errors.append(f"{path}: broken link -> {link}")
|
||||
|
||||
|
||||
def collect_ids(path: Path, text: str, defs: dict[str, Path], refs: dict[str, set[Path]]) -> None:
|
||||
for m in DEF_PATTERN.finditer(text):
|
||||
defs[m.group("id")] = path
|
||||
for m in ID_PATTERN.finditer(text):
|
||||
idv = m.group(0)
|
||||
refs.setdefault(idv, set()).add(path)
|
||||
|
||||
|
||||
def collect_req_traceability(text: str, req_to_task: dict[str, set[str]], req_to_test: dict[str, set[str]]) -> None:
|
||||
for m in LINE_DEF_PATTERN.finditer(text):
|
||||
line = m.group(0)
|
||||
item_id = m.group("id")
|
||||
req_ids = [rid for rid in ID_PATTERN.findall(line) if rid.startswith("REQ-")]
|
||||
if item_id.startswith("TASK-"):
|
||||
for req_id in req_ids:
|
||||
req_to_task.setdefault(req_id, set()).add(item_id)
|
||||
if item_id.startswith("TEST-"):
|
||||
for req_id in req_ids:
|
||||
req_to_test.setdefault(req_id, set()).add(item_id)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
if not DOC_DIR.exists():
|
||||
print(f"ERROR: missing directory {DOC_DIR}")
|
||||
return 1
|
||||
|
||||
docs = iter_docs()
|
||||
if not docs:
|
||||
print(f"ERROR: no markdown docs found in {DOC_DIR}")
|
||||
return 1
|
||||
|
||||
errors: list[str] = []
|
||||
doc_ids: dict[str, Path] = {}
|
||||
defs: dict[str, Path] = {}
|
||||
refs: dict[str, set[Path]] = {}
|
||||
req_to_task: dict[str, set[str]] = {}
|
||||
req_to_test: dict[str, set[str]] = {}
|
||||
|
||||
for path in docs:
|
||||
text = path.read_text(encoding="utf-8")
|
||||
validate_metadata(path, text, errors, doc_ids)
|
||||
validate_links(path, text, errors)
|
||||
collect_ids(path, text, defs, refs)
|
||||
collect_req_traceability(text, req_to_task, req_to_test)
|
||||
|
||||
for idv, where_used in sorted(refs.items()):
|
||||
if idv.startswith("DOC-"):
|
||||
continue
|
||||
if idv not in defs:
|
||||
files = ", ".join(str(p) for p in sorted(where_used))
|
||||
errors.append(f"undefined ID {idv}, used in: {files}")
|
||||
|
||||
for idv in sorted(defs):
|
||||
if not idv.startswith("REQ-"):
|
||||
continue
|
||||
if idv not in req_to_task:
|
||||
errors.append(f"REQ without TASK mapping: {idv}")
|
||||
if idv not in req_to_test:
|
||||
errors.append(f"REQ without TEST mapping: {idv}")
|
||||
|
||||
warnings: list[str] = []
|
||||
for idv, where_def in sorted(defs.items()):
|
||||
if len(refs.get(idv, set())) <= 1 and (idv.startswith("REQ-") or idv.startswith("RULE-")):
|
||||
warnings.append(f"orphan ID {idv} defined in {where_def} (not referenced elsewhere)")
|
||||
|
||||
if errors:
|
||||
print("[FAIL] Ouroboros docs validation failed")
|
||||
for err in errors:
|
||||
print(f"- {err}")
|
||||
return 1
|
||||
|
||||
print(f"[OK] validated {len(docs)} docs in {DOC_DIR}")
|
||||
print(f"[OK] unique Doc-ID: {len(doc_ids)}")
|
||||
print(f"[OK] definitions: {len(defs)}, references: {len(refs)}")
|
||||
print(f"[OK] req->task mappings: {len(req_to_task)}")
|
||||
print(f"[OK] req->test mappings: {len(req_to_test)}")
|
||||
if warnings:
|
||||
print(f"[WARN] orphan IDs: {len(warnings)}")
|
||||
for w in warnings:
|
||||
print(f"- {w}")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
42
scripts/watchdog.sh
Executable file
42
scripts/watchdog.sh
Executable file
@@ -0,0 +1,42 @@
|
||||
#!/usr/bin/env bash
|
||||
# Simple watchdog for The Ouroboros process.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
PID_FILE="${PID_FILE:-data/overnight/app.pid}"
|
||||
LOG_FILE="${LOG_FILE:-data/overnight/watchdog.log}"
|
||||
CHECK_INTERVAL="${CHECK_INTERVAL:-30}"
|
||||
STATUS_EVERY="${STATUS_EVERY:-10}"
|
||||
|
||||
mkdir -p "$(dirname "$LOG_FILE")"
|
||||
|
||||
log() {
|
||||
printf '%s %s\n' "$(date -u +"%Y-%m-%dT%H:%M:%SZ")" "$1" | tee -a "$LOG_FILE"
|
||||
}
|
||||
|
||||
if [ ! -f "$PID_FILE" ]; then
|
||||
log "[ERROR] pid file not found: $PID_FILE"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PID="$(cat "$PID_FILE")"
|
||||
if [ -z "$PID" ]; then
|
||||
log "[ERROR] pid file is empty: $PID_FILE"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
log "[INFO] watchdog started (pid=$PID, interval=${CHECK_INTERVAL}s)"
|
||||
|
||||
count=0
|
||||
while true; do
|
||||
if kill -0 "$PID" 2>/dev/null; then
|
||||
count=$((count + 1))
|
||||
if [ $((count % STATUS_EVERY)) -eq 0 ]; then
|
||||
log "[INFO] process alive (pid=$PID)"
|
||||
fi
|
||||
else
|
||||
log "[ERROR] process stopped (pid=$PID)"
|
||||
exit 1
|
||||
fi
|
||||
sleep "$CHECK_INTERVAL"
|
||||
done
|
||||
9
src/analysis/__init__.py
Normal file
9
src/analysis/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""Technical analysis and market scanning modules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from src.analysis.scanner import MarketScanner
|
||||
from src.analysis.smart_scanner import ScanCandidate, SmartVolatilityScanner
|
||||
from src.analysis.volatility import VolatilityAnalyzer
|
||||
|
||||
__all__ = ["VolatilityAnalyzer", "MarketScanner", "SmartVolatilityScanner", "ScanCandidate"]
|
||||
244
src/analysis/scanner.py
Normal file
244
src/analysis/scanner.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""Real-time market scanner for detecting high-momentum stocks.
|
||||
|
||||
Scans all available stocks in a market and ranks by volatility/momentum score.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from src.analysis.volatility import VolatilityAnalyzer, VolatilityMetrics
|
||||
from src.broker.kis_api import KISBroker
|
||||
from src.broker.overseas import OverseasBroker
|
||||
from src.context.layer import ContextLayer
|
||||
from src.context.store import ContextStore
|
||||
from src.markets.schedule import MarketInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScanResult:
|
||||
"""Result from a market scan."""
|
||||
|
||||
market_code: str
|
||||
timestamp: str
|
||||
total_scanned: int
|
||||
top_movers: list[VolatilityMetrics]
|
||||
breakouts: list[str] # Stock codes with breakout patterns
|
||||
breakdowns: list[str] # Stock codes with breakdown patterns
|
||||
|
||||
|
||||
class MarketScanner:
|
||||
"""Scans markets for high-volatility, high-momentum stocks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
broker: KISBroker,
|
||||
overseas_broker: OverseasBroker,
|
||||
volatility_analyzer: VolatilityAnalyzer,
|
||||
context_store: ContextStore,
|
||||
top_n: int = 5,
|
||||
max_concurrent_scans: int = 1,
|
||||
) -> None:
|
||||
"""Initialize the market scanner.
|
||||
|
||||
Args:
|
||||
broker: KIS broker instance for domestic market
|
||||
overseas_broker: Overseas broker instance
|
||||
volatility_analyzer: Volatility analyzer instance
|
||||
context_store: Context store for L7 real-time data
|
||||
top_n: Number of top movers to return per market (default 5)
|
||||
max_concurrent_scans: Max concurrent stock scans (default 1, fully serialized)
|
||||
"""
|
||||
self.broker = broker
|
||||
self.overseas_broker = overseas_broker
|
||||
self.analyzer = volatility_analyzer
|
||||
self.context_store = context_store
|
||||
self.top_n = top_n
|
||||
self._scan_semaphore = asyncio.Semaphore(max_concurrent_scans)
|
||||
|
||||
async def scan_stock(
|
||||
self,
|
||||
stock_code: str,
|
||||
market: MarketInfo,
|
||||
) -> VolatilityMetrics | None:
|
||||
"""Scan a single stock for volatility metrics.
|
||||
|
||||
Args:
|
||||
stock_code: Stock code to scan
|
||||
market: Market information
|
||||
|
||||
Returns:
|
||||
VolatilityMetrics if successful, None on error
|
||||
"""
|
||||
try:
|
||||
if market.is_domestic:
|
||||
orderbook = await self.broker.get_orderbook(stock_code)
|
||||
else:
|
||||
# For overseas, we need to adapt the price data structure
|
||||
price_data = await self.overseas_broker.get_overseas_price(
|
||||
market.exchange_code, stock_code
|
||||
)
|
||||
# Convert to orderbook-like structure
|
||||
orderbook = {
|
||||
"output1": {
|
||||
"stck_prpr": price_data.get("output", {}).get("last", "0") or "0",
|
||||
"acml_vol": price_data.get("output", {}).get("tvol", "0") or "0",
|
||||
}
|
||||
}
|
||||
|
||||
# For now, use empty price history (would need real historical data)
|
||||
# In production, this would fetch from a time-series database or API
|
||||
price_history: dict[str, Any] = {
|
||||
"high": [],
|
||||
"low": [],
|
||||
"close": [],
|
||||
"volume": [],
|
||||
}
|
||||
|
||||
metrics = self.analyzer.analyze(stock_code, orderbook, price_history)
|
||||
|
||||
# Store in L7 real-time layer
|
||||
from datetime import UTC, datetime
|
||||
timeframe = datetime.now(UTC).isoformat()
|
||||
self.context_store.set_context(
|
||||
ContextLayer.L7_REALTIME,
|
||||
timeframe,
|
||||
f"volatility_{market.code}_{stock_code}",
|
||||
{
|
||||
"price": metrics.current_price,
|
||||
"atr": metrics.atr,
|
||||
"price_change_1m": metrics.price_change_1m,
|
||||
"volume_surge": metrics.volume_surge,
|
||||
"momentum_score": metrics.momentum_score,
|
||||
},
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to scan %s (%s): %s", stock_code, market.code, exc)
|
||||
return None
|
||||
|
||||
async def scan_market(
|
||||
self,
|
||||
market: MarketInfo,
|
||||
stock_codes: list[str],
|
||||
) -> ScanResult:
|
||||
"""Scan all stocks in a market and rank by momentum.
|
||||
|
||||
Args:
|
||||
market: Market to scan
|
||||
stock_codes: List of stock codes to scan
|
||||
|
||||
Returns:
|
||||
ScanResult with ranked stocks
|
||||
"""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
logger.info("Scanning %s market (%d stocks)", market.name, len(stock_codes))
|
||||
|
||||
# Scan stocks with bounded concurrency to prevent API rate limit burst
|
||||
async def _bounded_scan(code: str) -> VolatilityMetrics | None:
|
||||
async with self._scan_semaphore:
|
||||
return await self.scan_stock(code, market)
|
||||
|
||||
tasks = [_bounded_scan(code) for code in stock_codes]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Filter out failures and sort by momentum score
|
||||
valid_metrics = [m for m in results if m is not None]
|
||||
valid_metrics.sort(key=lambda m: m.momentum_score, reverse=True)
|
||||
|
||||
# Get top N movers
|
||||
top_movers = valid_metrics[: self.top_n]
|
||||
|
||||
# Detect breakouts and breakdowns
|
||||
breakouts = [
|
||||
m.stock_code for m in valid_metrics if self.analyzer.is_breakout(m)
|
||||
]
|
||||
breakdowns = [
|
||||
m.stock_code for m in valid_metrics if self.analyzer.is_breakdown(m)
|
||||
]
|
||||
|
||||
logger.info(
|
||||
"%s scan complete: %d scanned, top momentum=%.1f, %d breakouts, %d breakdowns",
|
||||
market.name,
|
||||
len(valid_metrics),
|
||||
top_movers[0].momentum_score if top_movers else 0.0,
|
||||
len(breakouts),
|
||||
len(breakdowns),
|
||||
)
|
||||
|
||||
# Store scan results in L7
|
||||
timeframe = datetime.now(UTC).isoformat()
|
||||
self.context_store.set_context(
|
||||
ContextLayer.L7_REALTIME,
|
||||
timeframe,
|
||||
f"scan_result_{market.code}",
|
||||
{
|
||||
"total_scanned": len(valid_metrics),
|
||||
"top_movers": [m.stock_code for m in top_movers],
|
||||
"breakouts": breakouts,
|
||||
"breakdowns": breakdowns,
|
||||
},
|
||||
)
|
||||
|
||||
return ScanResult(
|
||||
market_code=market.code,
|
||||
timestamp=timeframe,
|
||||
total_scanned=len(valid_metrics),
|
||||
top_movers=top_movers,
|
||||
breakouts=breakouts,
|
||||
breakdowns=breakdowns,
|
||||
)
|
||||
|
||||
def get_updated_watchlist(
|
||||
self,
|
||||
current_watchlist: list[str],
|
||||
scan_result: ScanResult,
|
||||
max_replacements: int = 2,
|
||||
) -> list[str]:
|
||||
"""Update watchlist by replacing laggards with leaders.
|
||||
|
||||
Args:
|
||||
current_watchlist: Current watchlist
|
||||
scan_result: Recent scan result
|
||||
max_replacements: Maximum stocks to replace per scan
|
||||
|
||||
Returns:
|
||||
Updated watchlist with leaders
|
||||
"""
|
||||
# Keep stocks that are in top movers
|
||||
top_codes = [m.stock_code for m in scan_result.top_movers]
|
||||
keepers = [code for code in current_watchlist if code in top_codes]
|
||||
|
||||
# Add new leaders not in current watchlist
|
||||
new_leaders = [code for code in top_codes if code not in current_watchlist]
|
||||
|
||||
# Limit replacements
|
||||
new_leaders = new_leaders[:max_replacements]
|
||||
|
||||
# Create updated watchlist
|
||||
updated = keepers + new_leaders
|
||||
|
||||
# If we removed too many, backfill from current watchlist
|
||||
if len(updated) < len(current_watchlist):
|
||||
backfill = [
|
||||
code for code in current_watchlist
|
||||
if code not in updated
|
||||
][: len(current_watchlist) - len(updated)]
|
||||
updated.extend(backfill)
|
||||
|
||||
logger.info(
|
||||
"Watchlist updated: %d kept, %d new leaders, %d total",
|
||||
len(keepers),
|
||||
len(new_leaders),
|
||||
len(updated),
|
||||
)
|
||||
|
||||
return updated
|
||||
449
src/analysis/smart_scanner.py
Normal file
449
src/analysis/smart_scanner.py
Normal file
@@ -0,0 +1,449 @@
|
||||
"""Smart Volatility Scanner with volatility-first market ranking logic."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from src.analysis.volatility import VolatilityAnalyzer
|
||||
from src.broker.kis_api import KISBroker
|
||||
from src.broker.overseas import OverseasBroker
|
||||
from src.config import Settings
|
||||
from src.markets.schedule import MarketInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScanCandidate:
|
||||
"""A qualified candidate from the smart scanner."""
|
||||
|
||||
stock_code: str
|
||||
name: str
|
||||
price: float
|
||||
volume: float
|
||||
volume_ratio: float # Current volume / previous day volume
|
||||
rsi: float
|
||||
signal: str # "oversold" or "momentum"
|
||||
score: float # Composite score for ranking
|
||||
|
||||
|
||||
class SmartVolatilityScanner:
|
||||
"""Scans market rankings and applies volatility-first filters.
|
||||
|
||||
Flow:
|
||||
1. Fetch fluctuation rankings as primary universe
|
||||
2. Fetch volume rankings for liquidity bonus
|
||||
3. Score by volatility first, liquidity second
|
||||
4. Return top N qualified candidates
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
broker: KISBroker,
|
||||
overseas_broker: OverseasBroker | None,
|
||||
volatility_analyzer: VolatilityAnalyzer,
|
||||
settings: Settings,
|
||||
) -> None:
|
||||
"""Initialize the smart scanner.
|
||||
|
||||
Args:
|
||||
broker: KIS broker for API calls
|
||||
volatility_analyzer: Analyzer for RSI calculation
|
||||
settings: Application settings
|
||||
"""
|
||||
self.broker = broker
|
||||
self.overseas_broker = overseas_broker
|
||||
self.analyzer = volatility_analyzer
|
||||
self.settings = settings
|
||||
|
||||
# Extract scanner settings
|
||||
self.rsi_oversold = settings.RSI_OVERSOLD_THRESHOLD
|
||||
self.rsi_momentum = settings.RSI_MOMENTUM_THRESHOLD
|
||||
self.vol_multiplier = settings.VOL_MULTIPLIER
|
||||
self.top_n = settings.SCANNER_TOP_N
|
||||
|
||||
async def scan(
|
||||
self,
|
||||
market: MarketInfo | None = None,
|
||||
fallback_stocks: list[str] | None = None,
|
||||
) -> list[ScanCandidate]:
|
||||
"""Execute smart scan and return qualified candidates.
|
||||
|
||||
Args:
|
||||
market: Target market info (domestic vs overseas behavior)
|
||||
fallback_stocks: Stock codes to use if ranking API fails
|
||||
|
||||
Returns:
|
||||
List of ScanCandidate, sorted by score, up to top_n items
|
||||
"""
|
||||
if market and not market.is_domestic:
|
||||
return await self._scan_overseas(market, fallback_stocks)
|
||||
|
||||
return await self._scan_domestic(fallback_stocks)
|
||||
|
||||
async def _scan_domestic(
|
||||
self,
|
||||
fallback_stocks: list[str] | None = None,
|
||||
) -> list[ScanCandidate]:
|
||||
"""Scan domestic market using volatility-first ranking + liquidity bonus."""
|
||||
# 1) Primary universe from fluctuation ranking.
|
||||
try:
|
||||
fluct_rows = await self.broker.fetch_market_rankings(
|
||||
ranking_type="fluctuation",
|
||||
limit=50,
|
||||
)
|
||||
except ConnectionError as exc:
|
||||
logger.warning("Domestic fluctuation ranking failed: %s", exc)
|
||||
fluct_rows = []
|
||||
|
||||
# 2) Liquidity bonus from volume ranking.
|
||||
try:
|
||||
volume_rows = await self.broker.fetch_market_rankings(
|
||||
ranking_type="volume",
|
||||
limit=50,
|
||||
)
|
||||
except ConnectionError as exc:
|
||||
logger.warning("Domestic volume ranking failed: %s", exc)
|
||||
volume_rows = []
|
||||
|
||||
if not fluct_rows and fallback_stocks:
|
||||
logger.info(
|
||||
"Domestic ranking unavailable; using fallback symbols (%d)",
|
||||
len(fallback_stocks),
|
||||
)
|
||||
fluct_rows = [
|
||||
{
|
||||
"stock_code": code,
|
||||
"name": code,
|
||||
"price": 0.0,
|
||||
"volume": 0.0,
|
||||
"change_rate": 0.0,
|
||||
"volume_increase_rate": 0.0,
|
||||
}
|
||||
for code in fallback_stocks
|
||||
]
|
||||
|
||||
if not fluct_rows:
|
||||
return []
|
||||
|
||||
volume_rank_bonus: dict[str, float] = {}
|
||||
for idx, row in enumerate(volume_rows):
|
||||
code = _extract_stock_code(row)
|
||||
if not code:
|
||||
continue
|
||||
volume_rank_bonus[code] = max(0.0, 15.0 - idx * 0.3)
|
||||
|
||||
candidates: list[ScanCandidate] = []
|
||||
for stock in fluct_rows:
|
||||
stock_code = _extract_stock_code(stock)
|
||||
if not stock_code:
|
||||
continue
|
||||
|
||||
try:
|
||||
price = _extract_last_price(stock)
|
||||
change_rate = _extract_change_rate_pct(stock)
|
||||
volume = _extract_volume(stock)
|
||||
|
||||
intraday_range_pct = 0.0
|
||||
volume_ratio = _safe_float(stock.get("volume_increase_rate"), 0.0) / 100.0 + 1.0
|
||||
|
||||
# Use daily chart to refine range/volume when available.
|
||||
daily_prices = await self.broker.get_daily_prices(stock_code, days=2)
|
||||
if daily_prices:
|
||||
latest = daily_prices[-1]
|
||||
latest_close = _safe_float(latest.get("close"), default=price)
|
||||
if price <= 0:
|
||||
price = latest_close
|
||||
latest_high = _safe_float(latest.get("high"))
|
||||
latest_low = _safe_float(latest.get("low"))
|
||||
if latest_close > 0 and latest_high > 0 and latest_low > 0 and latest_high >= latest_low:
|
||||
intraday_range_pct = (latest_high - latest_low) / latest_close * 100.0
|
||||
if volume <= 0:
|
||||
volume = _safe_float(latest.get("volume"))
|
||||
if len(daily_prices) >= 2:
|
||||
prev_day_volume = _safe_float(daily_prices[-2].get("volume"))
|
||||
if prev_day_volume > 0:
|
||||
volume_ratio = max(volume_ratio, volume / prev_day_volume)
|
||||
|
||||
volatility_pct = max(abs(change_rate), intraday_range_pct)
|
||||
if price <= 0 or volatility_pct < 0.8:
|
||||
continue
|
||||
|
||||
volatility_score = min(volatility_pct / 10.0, 1.0) * 85.0
|
||||
liquidity_score = volume_rank_bonus.get(stock_code, 0.0)
|
||||
score = min(100.0, volatility_score + liquidity_score)
|
||||
signal = "momentum" if change_rate >= 0 else "oversold"
|
||||
implied_rsi = max(0.0, min(100.0, 50.0 + (change_rate * 2.0)))
|
||||
|
||||
candidates.append(
|
||||
ScanCandidate(
|
||||
stock_code=stock_code,
|
||||
name=stock.get("name", stock_code),
|
||||
price=price,
|
||||
volume=volume,
|
||||
volume_ratio=max(1.0, volume_ratio, volatility_pct / 2.0),
|
||||
rsi=implied_rsi,
|
||||
signal=signal,
|
||||
score=score,
|
||||
)
|
||||
)
|
||||
|
||||
except ConnectionError as exc:
|
||||
logger.warning("Failed to analyze %s: %s", stock_code, exc)
|
||||
continue
|
||||
except Exception as exc:
|
||||
logger.error("Unexpected error analyzing %s: %s", stock_code, exc)
|
||||
continue
|
||||
|
||||
logger.info("Domestic ranking scan found %d candidates", len(candidates))
|
||||
candidates.sort(key=lambda c: c.score, reverse=True)
|
||||
return candidates[: self.top_n]
|
||||
|
||||
async def _scan_overseas(
|
||||
self,
|
||||
market: MarketInfo,
|
||||
fallback_stocks: list[str] | None = None,
|
||||
) -> list[ScanCandidate]:
|
||||
"""Scan overseas symbols using ranking API first, then fallback universe."""
|
||||
if self.overseas_broker is None:
|
||||
logger.warning(
|
||||
"Overseas scanner unavailable for %s: overseas broker not configured",
|
||||
market.name,
|
||||
)
|
||||
return []
|
||||
|
||||
candidates = await self._scan_overseas_from_rankings(market)
|
||||
if not candidates:
|
||||
candidates = await self._scan_overseas_from_symbols(market, fallback_stocks)
|
||||
|
||||
candidates.sort(key=lambda c: c.score, reverse=True)
|
||||
return candidates[: self.top_n]
|
||||
|
||||
async def _scan_overseas_from_rankings(
|
||||
self,
|
||||
market: MarketInfo,
|
||||
) -> list[ScanCandidate]:
|
||||
"""Build overseas candidates from ranking APIs using volatility-first scoring."""
|
||||
assert self.overseas_broker is not None
|
||||
try:
|
||||
fluct_rows = await self.overseas_broker.fetch_overseas_rankings(
|
||||
exchange_code=market.exchange_code,
|
||||
ranking_type="fluctuation",
|
||||
limit=50,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Overseas fluctuation ranking failed for %s: %s", market.code, exc
|
||||
)
|
||||
fluct_rows = []
|
||||
|
||||
if not fluct_rows:
|
||||
return []
|
||||
|
||||
volume_rank_bonus: dict[str, float] = {}
|
||||
try:
|
||||
volume_rows = await self.overseas_broker.fetch_overseas_rankings(
|
||||
exchange_code=market.exchange_code,
|
||||
ranking_type="volume",
|
||||
limit=50,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Overseas volume ranking failed for %s: %s", market.code, exc
|
||||
)
|
||||
volume_rows = []
|
||||
|
||||
for idx, row in enumerate(volume_rows):
|
||||
code = _extract_stock_code(row)
|
||||
if not code:
|
||||
continue
|
||||
# Top-ranked by traded value/volume gets higher liquidity bonus.
|
||||
volume_rank_bonus[code] = max(0.0, 15.0 - idx * 0.3)
|
||||
|
||||
candidates: list[ScanCandidate] = []
|
||||
for row in fluct_rows:
|
||||
stock_code = _extract_stock_code(row)
|
||||
if not stock_code:
|
||||
continue
|
||||
|
||||
price = _extract_last_price(row)
|
||||
change_rate = _extract_change_rate_pct(row)
|
||||
volume = _extract_volume(row)
|
||||
intraday_range_pct = _extract_intraday_range_pct(row, price)
|
||||
volatility_pct = max(abs(change_rate), intraday_range_pct)
|
||||
|
||||
# Volatility-first filter (not simple gainers/value ranking).
|
||||
if price <= 0 or volatility_pct < 0.8:
|
||||
continue
|
||||
|
||||
volatility_score = min(volatility_pct / 10.0, 1.0) * 85.0
|
||||
liquidity_score = volume_rank_bonus.get(stock_code, 0.0)
|
||||
score = min(100.0, volatility_score + liquidity_score)
|
||||
signal = "momentum" if change_rate >= 0 else "oversold"
|
||||
implied_rsi = max(0.0, min(100.0, 50.0 + (change_rate * 2.0)))
|
||||
candidates.append(
|
||||
ScanCandidate(
|
||||
stock_code=stock_code,
|
||||
name=str(row.get("name") or row.get("ovrs_item_name") or stock_code),
|
||||
price=price,
|
||||
volume=volume,
|
||||
volume_ratio=max(1.0, volatility_pct / 2.0),
|
||||
rsi=implied_rsi,
|
||||
signal=signal,
|
||||
score=score,
|
||||
)
|
||||
)
|
||||
|
||||
if candidates:
|
||||
logger.info(
|
||||
"Overseas ranking scan found %d candidates for %s",
|
||||
len(candidates),
|
||||
market.name,
|
||||
)
|
||||
return candidates
|
||||
|
||||
async def _scan_overseas_from_symbols(
|
||||
self,
|
||||
market: MarketInfo,
|
||||
symbols: list[str] | None,
|
||||
) -> list[ScanCandidate]:
|
||||
"""Fallback overseas scan from dynamic symbol universe."""
|
||||
assert self.overseas_broker is not None
|
||||
if not symbols:
|
||||
logger.info("Overseas scanner: no symbol universe for %s", market.name)
|
||||
return []
|
||||
|
||||
logger.info(
|
||||
"Overseas scanner: scanning %d fallback symbols for %s",
|
||||
len(symbols),
|
||||
market.name,
|
||||
)
|
||||
candidates: list[ScanCandidate] = []
|
||||
for stock_code in symbols:
|
||||
try:
|
||||
price_data = await self.overseas_broker.get_overseas_price(
|
||||
market.exchange_code, stock_code
|
||||
)
|
||||
output = price_data.get("output", {})
|
||||
price = _extract_last_price(output)
|
||||
change_rate = _extract_change_rate_pct(output)
|
||||
volume = _extract_volume(output)
|
||||
intraday_range_pct = _extract_intraday_range_pct(output, price)
|
||||
volatility_pct = max(abs(change_rate), intraday_range_pct)
|
||||
|
||||
if price <= 0 or volatility_pct < 0.8:
|
||||
continue
|
||||
|
||||
score = min(volatility_pct / 10.0, 1.0) * 100.0
|
||||
signal = "momentum" if change_rate >= 0 else "oversold"
|
||||
implied_rsi = max(0.0, min(100.0, 50.0 + (change_rate * 2.0)))
|
||||
candidates.append(
|
||||
ScanCandidate(
|
||||
stock_code=stock_code,
|
||||
name=stock_code,
|
||||
price=price,
|
||||
volume=volume,
|
||||
volume_ratio=max(1.0, volatility_pct / 2.0),
|
||||
rsi=implied_rsi,
|
||||
signal=signal,
|
||||
score=score,
|
||||
)
|
||||
)
|
||||
except ConnectionError as exc:
|
||||
logger.warning("Failed to analyze overseas %s: %s", stock_code, exc)
|
||||
except Exception as exc:
|
||||
logger.error("Unexpected error analyzing overseas %s: %s", stock_code, exc)
|
||||
logger.info(
|
||||
"Overseas symbol fallback scan found %d candidates for %s",
|
||||
len(candidates),
|
||||
market.name,
|
||||
)
|
||||
return candidates
|
||||
|
||||
def get_stock_codes(self, candidates: list[ScanCandidate]) -> list[str]:
|
||||
"""Extract stock codes from candidates for watchlist update.
|
||||
|
||||
Args:
|
||||
candidates: List of scan candidates
|
||||
|
||||
Returns:
|
||||
List of stock codes
|
||||
"""
|
||||
return [c.stock_code for c in candidates]
|
||||
|
||||
|
||||
def _safe_float(value: Any, default: float = 0.0) -> float:
|
||||
"""Convert arbitrary values to float safely."""
|
||||
if value in (None, ""):
|
||||
return default
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
def _extract_stock_code(row: dict[str, Any]) -> str:
|
||||
"""Extract normalized stock code from various API schemas."""
|
||||
return (
|
||||
str(
|
||||
row.get("symb")
|
||||
or row.get("ovrs_pdno")
|
||||
or row.get("stock_code")
|
||||
or row.get("pdno")
|
||||
or ""
|
||||
)
|
||||
.strip()
|
||||
.upper()
|
||||
)
|
||||
|
||||
|
||||
def _extract_last_price(row: dict[str, Any]) -> float:
|
||||
"""Extract last/close-like price from API schema variants."""
|
||||
return _safe_float(
|
||||
row.get("last")
|
||||
or row.get("ovrs_nmix_prpr")
|
||||
or row.get("stck_prpr")
|
||||
or row.get("price")
|
||||
or row.get("close")
|
||||
)
|
||||
|
||||
|
||||
def _extract_change_rate_pct(row: dict[str, Any]) -> float:
|
||||
"""Extract daily change rate (%) from API schema variants."""
|
||||
return _safe_float(
|
||||
row.get("rate")
|
||||
or row.get("change_rate")
|
||||
or row.get("prdy_ctrt")
|
||||
or row.get("evlu_pfls_rt")
|
||||
or row.get("chg_rt")
|
||||
)
|
||||
|
||||
|
||||
def _extract_volume(row: dict[str, Any]) -> float:
|
||||
"""Extract volume/traded-amount proxy from schema variants."""
|
||||
return _safe_float(
|
||||
row.get("tvol") or row.get("acml_vol") or row.get("vol") or row.get("volume")
|
||||
)
|
||||
|
||||
|
||||
def _extract_intraday_range_pct(row: dict[str, Any], price: float) -> float:
|
||||
"""Estimate intraday range percentage from high/low fields."""
|
||||
if price <= 0:
|
||||
return 0.0
|
||||
high = _safe_float(
|
||||
row.get("high")
|
||||
or row.get("ovrs_hgpr")
|
||||
or row.get("stck_hgpr")
|
||||
or row.get("day_hgpr")
|
||||
)
|
||||
low = _safe_float(
|
||||
row.get("low")
|
||||
or row.get("ovrs_lwpr")
|
||||
or row.get("stck_lwpr")
|
||||
or row.get("day_lwpr")
|
||||
)
|
||||
if high <= 0 or low <= 0 or high < low:
|
||||
return 0.0
|
||||
return (high - low) / price * 100.0
|
||||
373
src/analysis/volatility.py
Normal file
373
src/analysis/volatility.py
Normal file
@@ -0,0 +1,373 @@
|
||||
"""Volatility and momentum analysis for stock selection.
|
||||
|
||||
Calculates ATR, price change percentages, volume surges, and price-volume divergence.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class VolatilityMetrics:
|
||||
"""Volatility and momentum metrics for a stock."""
|
||||
|
||||
stock_code: str
|
||||
current_price: float
|
||||
atr: float # Average True Range (14 periods)
|
||||
price_change_1m: float # 1-minute price change %
|
||||
price_change_5m: float # 5-minute price change %
|
||||
price_change_15m: float # 15-minute price change %
|
||||
volume_surge: float # Volume vs average (ratio)
|
||||
pv_divergence: float # Price-volume divergence score
|
||||
momentum_score: float # Combined momentum score (0-100)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"VolatilityMetrics({self.stock_code}: "
|
||||
f"price={self.current_price:.2f}, "
|
||||
f"atr={self.atr:.2f}, "
|
||||
f"1m={self.price_change_1m:.2f}%, "
|
||||
f"vol_surge={self.volume_surge:.2f}x, "
|
||||
f"momentum={self.momentum_score:.1f})"
|
||||
)
|
||||
|
||||
|
||||
class VolatilityAnalyzer:
|
||||
"""Analyzes stock volatility and momentum for leader detection."""
|
||||
|
||||
def __init__(self, min_volume_surge: float = 2.0, min_price_change: float = 1.0) -> None:
|
||||
"""Initialize the volatility analyzer.
|
||||
|
||||
Args:
|
||||
min_volume_surge: Minimum volume surge ratio (default 2x average)
|
||||
min_price_change: Minimum price change % for breakout (default 1%)
|
||||
"""
|
||||
self.min_volume_surge = min_volume_surge
|
||||
self.min_price_change = min_price_change
|
||||
|
||||
def calculate_atr(
|
||||
self,
|
||||
high_prices: list[float],
|
||||
low_prices: list[float],
|
||||
close_prices: list[float],
|
||||
period: int = 14,
|
||||
) -> float:
|
||||
"""Calculate Average True Range (ATR).
|
||||
|
||||
Args:
|
||||
high_prices: List of high prices (most recent last)
|
||||
low_prices: List of low prices (most recent last)
|
||||
close_prices: List of close prices (most recent last)
|
||||
period: ATR period (default 14)
|
||||
|
||||
Returns:
|
||||
ATR value
|
||||
"""
|
||||
if (
|
||||
len(high_prices) < period + 1
|
||||
or len(low_prices) < period + 1
|
||||
or len(close_prices) < period + 1
|
||||
):
|
||||
return 0.0
|
||||
|
||||
true_ranges: list[float] = []
|
||||
for i in range(1, len(high_prices)):
|
||||
high = high_prices[i]
|
||||
low = low_prices[i]
|
||||
prev_close = close_prices[i - 1]
|
||||
|
||||
tr = max(
|
||||
high - low,
|
||||
abs(high - prev_close),
|
||||
abs(low - prev_close),
|
||||
)
|
||||
true_ranges.append(tr)
|
||||
|
||||
if len(true_ranges) < period:
|
||||
return 0.0
|
||||
|
||||
# Simple Moving Average of True Range
|
||||
recent_tr = true_ranges[-period:]
|
||||
return sum(recent_tr) / len(recent_tr)
|
||||
|
||||
def calculate_price_change(
|
||||
self, current_price: float, past_price: float
|
||||
) -> float:
|
||||
"""Calculate price change percentage.
|
||||
|
||||
Args:
|
||||
current_price: Current price
|
||||
past_price: Past price to compare against
|
||||
|
||||
Returns:
|
||||
Price change percentage
|
||||
"""
|
||||
if past_price == 0:
|
||||
return 0.0
|
||||
return ((current_price - past_price) / past_price) * 100
|
||||
|
||||
def calculate_volume_surge(
|
||||
self, current_volume: float, avg_volume: float
|
||||
) -> float:
|
||||
"""Calculate volume surge ratio.
|
||||
|
||||
Args:
|
||||
current_volume: Current volume
|
||||
avg_volume: Average volume
|
||||
|
||||
Returns:
|
||||
Volume surge ratio (current / average)
|
||||
"""
|
||||
if avg_volume == 0:
|
||||
return 1.0
|
||||
return current_volume / avg_volume
|
||||
|
||||
def calculate_rsi(
|
||||
self,
|
||||
close_prices: list[float],
|
||||
period: int = 14,
|
||||
) -> float:
|
||||
"""Calculate Relative Strength Index (RSI) using Wilder's smoothing.
|
||||
|
||||
Args:
|
||||
close_prices: List of closing prices (oldest to newest, minimum period+1 values)
|
||||
period: RSI period (default 14)
|
||||
|
||||
Returns:
|
||||
RSI value between 0 and 100, or 50.0 (neutral) if insufficient data
|
||||
|
||||
Examples:
|
||||
>>> analyzer = VolatilityAnalyzer()
|
||||
>>> prices = [100 - i * 0.5 for i in range(20)] # Downtrend
|
||||
>>> rsi = analyzer.calculate_rsi(prices)
|
||||
>>> assert rsi < 50 # Oversold territory
|
||||
"""
|
||||
if len(close_prices) < period + 1:
|
||||
return 50.0 # Neutral RSI if insufficient data
|
||||
|
||||
# Calculate price changes
|
||||
changes = [close_prices[i] - close_prices[i - 1] for i in range(1, len(close_prices))]
|
||||
|
||||
# Separate gains and losses
|
||||
gains = [max(0.0, change) for change in changes]
|
||||
losses = [max(0.0, -change) for change in changes]
|
||||
|
||||
# Calculate initial average gain/loss (simple average for first period)
|
||||
avg_gain = sum(gains[:period]) / period
|
||||
avg_loss = sum(losses[:period]) / period
|
||||
|
||||
# Apply Wilder's smoothing for remaining periods
|
||||
for i in range(period, len(changes)):
|
||||
avg_gain = (avg_gain * (period - 1) + gains[i]) / period
|
||||
avg_loss = (avg_loss * (period - 1) + losses[i]) / period
|
||||
|
||||
# Calculate RS and RSI
|
||||
if avg_loss == 0:
|
||||
return 100.0 # All gains, maximum RSI
|
||||
|
||||
rs = avg_gain / avg_loss
|
||||
rsi = 100 - (100 / (1 + rs))
|
||||
|
||||
return rsi
|
||||
|
||||
def calculate_pv_divergence(
|
||||
self,
|
||||
price_change: float,
|
||||
volume_surge: float,
|
||||
) -> float:
|
||||
"""Calculate price-volume divergence score.
|
||||
|
||||
Positive divergence: Price up + Volume up = bullish
|
||||
Negative divergence: Price up + Volume down = bearish
|
||||
Neutral: Price/volume move together moderately
|
||||
|
||||
Args:
|
||||
price_change: Price change percentage
|
||||
volume_surge: Volume surge ratio
|
||||
|
||||
Returns:
|
||||
Divergence score (-100 to +100)
|
||||
"""
|
||||
# Normalize volume surge to -1 to +1 scale (1.0 = neutral)
|
||||
volume_signal = (volume_surge - 1.0) * 10 # Scale for sensitivity
|
||||
|
||||
# Calculate divergence
|
||||
# Positive: price and volume move in same direction
|
||||
# Negative: price and volume move in opposite directions
|
||||
if price_change > 0 and volume_surge > 1.0:
|
||||
# Bullish: price up, volume up
|
||||
return min(100.0, price_change * volume_signal)
|
||||
elif price_change < 0 and volume_surge < 1.0:
|
||||
# Bearish confirmation: price down, volume down
|
||||
return max(-100.0, price_change * volume_signal)
|
||||
elif price_change > 0 and volume_surge < 1.0:
|
||||
# Bearish divergence: price up but volume low (weak rally)
|
||||
return -abs(price_change) * 0.5
|
||||
elif price_change < 0 and volume_surge > 1.0:
|
||||
# Selling pressure: price down, volume up
|
||||
return price_change * volume_signal
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
def calculate_momentum_score(
|
||||
self,
|
||||
price_change_1m: float,
|
||||
price_change_5m: float,
|
||||
price_change_15m: float,
|
||||
volume_surge: float,
|
||||
atr: float,
|
||||
current_price: float,
|
||||
) -> float:
|
||||
"""Calculate combined momentum score (0-100).
|
||||
|
||||
Weights:
|
||||
- 1m change: 40%
|
||||
- 5m change: 30%
|
||||
- 15m change: 20%
|
||||
- Volume surge: 10%
|
||||
|
||||
Args:
|
||||
price_change_1m: 1-minute price change %
|
||||
price_change_5m: 5-minute price change %
|
||||
price_change_15m: 15-minute price change %
|
||||
volume_surge: Volume surge ratio
|
||||
atr: Average True Range
|
||||
current_price: Current price
|
||||
|
||||
Returns:
|
||||
Momentum score (0-100)
|
||||
"""
|
||||
# Weight recent changes more heavily
|
||||
weighted_change = (
|
||||
price_change_1m * 0.4 +
|
||||
price_change_5m * 0.3 +
|
||||
price_change_15m * 0.2
|
||||
)
|
||||
|
||||
# Volume contribution (normalized to 0-10 scale)
|
||||
volume_contribution = min(10.0, (volume_surge - 1.0) * 5.0)
|
||||
|
||||
# Volatility bonus: higher ATR = higher potential (normalized)
|
||||
volatility_bonus = 0.0
|
||||
if current_price > 0:
|
||||
atr_pct = (atr / current_price) * 100
|
||||
volatility_bonus = min(10.0, atr_pct)
|
||||
|
||||
# Combine scores
|
||||
raw_score = weighted_change + volume_contribution + volatility_bonus
|
||||
|
||||
# Normalize to 0-100 scale
|
||||
# Assume typical momentum range is -10 to +30
|
||||
normalized = ((raw_score + 10) / 40) * 100
|
||||
|
||||
return max(0.0, min(100.0, normalized))
|
||||
|
||||
def analyze(
|
||||
self,
|
||||
stock_code: str,
|
||||
orderbook_data: dict[str, Any],
|
||||
price_history: dict[str, Any],
|
||||
) -> VolatilityMetrics:
|
||||
"""Analyze volatility and momentum for a stock.
|
||||
|
||||
Args:
|
||||
stock_code: Stock code
|
||||
orderbook_data: Current orderbook/quote data
|
||||
price_history: Historical price and volume data
|
||||
|
||||
Returns:
|
||||
VolatilityMetrics with calculated indicators
|
||||
"""
|
||||
# Extract current data from orderbook
|
||||
output1 = orderbook_data.get("output1", {})
|
||||
current_price = float(output1.get("stck_prpr", 0))
|
||||
current_volume = float(output1.get("acml_vol", 0))
|
||||
|
||||
# Extract historical data
|
||||
high_prices = price_history.get("high", [])
|
||||
low_prices = price_history.get("low", [])
|
||||
close_prices = price_history.get("close", [])
|
||||
volumes = price_history.get("volume", [])
|
||||
|
||||
# Calculate ATR
|
||||
atr = self.calculate_atr(high_prices, low_prices, close_prices)
|
||||
|
||||
# Calculate price changes (use historical data if available)
|
||||
price_change_1m = 0.0
|
||||
price_change_5m = 0.0
|
||||
price_change_15m = 0.0
|
||||
|
||||
if len(close_prices) > 0:
|
||||
if len(close_prices) >= 1:
|
||||
price_change_1m = self.calculate_price_change(
|
||||
current_price, close_prices[-1]
|
||||
)
|
||||
if len(close_prices) >= 5:
|
||||
price_change_5m = self.calculate_price_change(
|
||||
current_price, close_prices[-5]
|
||||
)
|
||||
if len(close_prices) >= 15:
|
||||
price_change_15m = self.calculate_price_change(
|
||||
current_price, close_prices[-15]
|
||||
)
|
||||
|
||||
# Calculate volume surge
|
||||
avg_volume = sum(volumes) / len(volumes) if volumes else current_volume
|
||||
volume_surge = self.calculate_volume_surge(current_volume, avg_volume)
|
||||
|
||||
# Calculate price-volume divergence
|
||||
pv_divergence = self.calculate_pv_divergence(price_change_1m, volume_surge)
|
||||
|
||||
# Calculate momentum score
|
||||
momentum_score = self.calculate_momentum_score(
|
||||
price_change_1m,
|
||||
price_change_5m,
|
||||
price_change_15m,
|
||||
volume_surge,
|
||||
atr,
|
||||
current_price,
|
||||
)
|
||||
|
||||
return VolatilityMetrics(
|
||||
stock_code=stock_code,
|
||||
current_price=current_price,
|
||||
atr=atr,
|
||||
price_change_1m=price_change_1m,
|
||||
price_change_5m=price_change_5m,
|
||||
price_change_15m=price_change_15m,
|
||||
volume_surge=volume_surge,
|
||||
pv_divergence=pv_divergence,
|
||||
momentum_score=momentum_score,
|
||||
)
|
||||
|
||||
def is_breakout(self, metrics: VolatilityMetrics) -> bool:
|
||||
"""Determine if a stock is experiencing a breakout.
|
||||
|
||||
Args:
|
||||
metrics: Volatility metrics for the stock
|
||||
|
||||
Returns:
|
||||
True if breakout conditions are met
|
||||
"""
|
||||
return (
|
||||
metrics.price_change_1m >= self.min_price_change
|
||||
and metrics.volume_surge >= self.min_volume_surge
|
||||
and metrics.pv_divergence > 0 # Bullish divergence
|
||||
)
|
||||
|
||||
def is_breakdown(self, metrics: VolatilityMetrics) -> bool:
|
||||
"""Determine if a stock is experiencing a breakdown.
|
||||
|
||||
Args:
|
||||
metrics: Volatility metrics for the stock
|
||||
|
||||
Returns:
|
||||
True if breakdown conditions are met
|
||||
"""
|
||||
return (
|
||||
metrics.price_change_1m <= -self.min_price_change
|
||||
and metrics.volume_surge >= self.min_volume_surge
|
||||
and metrics.pv_divergence < 0 # Bearish divergence
|
||||
)
|
||||
21
src/backup/__init__.py
Normal file
21
src/backup/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Backup and disaster recovery system for long-term sustainability.
|
||||
|
||||
This module provides:
|
||||
- Automated database backups (daily, weekly, monthly)
|
||||
- Multi-format exports (JSON, CSV, Parquet)
|
||||
- Cloud storage integration (S3-compatible)
|
||||
- Health monitoring and alerts
|
||||
"""
|
||||
|
||||
from src.backup.exporter import BackupExporter, ExportFormat
|
||||
from src.backup.scheduler import BackupScheduler, BackupPolicy
|
||||
from src.backup.cloud_storage import CloudStorage, S3Config
|
||||
|
||||
__all__ = [
|
||||
"BackupExporter",
|
||||
"ExportFormat",
|
||||
"BackupScheduler",
|
||||
"BackupPolicy",
|
||||
"CloudStorage",
|
||||
"S3Config",
|
||||
]
|
||||
274
src/backup/cloud_storage.py
Normal file
274
src/backup/cloud_storage.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""Cloud storage integration for off-site backups.
|
||||
|
||||
Supports S3-compatible storage providers:
|
||||
- AWS S3
|
||||
- MinIO
|
||||
- Backblaze B2
|
||||
- DigitalOcean Spaces
|
||||
- Cloudflare R2
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class S3Config:
|
||||
"""Configuration for S3-compatible storage."""
|
||||
|
||||
endpoint_url: str | None # None for AWS S3, custom URL for others
|
||||
access_key: str
|
||||
secret_key: str
|
||||
bucket_name: str
|
||||
region: str = "us-east-1"
|
||||
use_ssl: bool = True
|
||||
|
||||
|
||||
class CloudStorage:
|
||||
"""Upload backups to S3-compatible cloud storage."""
|
||||
|
||||
def __init__(self, config: S3Config) -> None:
|
||||
"""Initialize cloud storage client.
|
||||
|
||||
Args:
|
||||
config: S3 configuration
|
||||
|
||||
Raises:
|
||||
ImportError: If boto3 is not installed
|
||||
"""
|
||||
try:
|
||||
import boto3
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"boto3 is required for cloud storage. Install with: pip install boto3"
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.client = boto3.client(
|
||||
"s3",
|
||||
endpoint_url=config.endpoint_url,
|
||||
aws_access_key_id=config.access_key,
|
||||
aws_secret_access_key=config.secret_key,
|
||||
region_name=config.region,
|
||||
use_ssl=config.use_ssl,
|
||||
)
|
||||
|
||||
def upload_file(
|
||||
self,
|
||||
file_path: Path,
|
||||
object_key: str | None = None,
|
||||
metadata: dict[str, str] | None = None,
|
||||
) -> str:
|
||||
"""Upload a file to cloud storage.
|
||||
|
||||
Args:
|
||||
file_path: Local file to upload
|
||||
object_key: S3 object key (default: filename)
|
||||
metadata: Optional metadata to attach
|
||||
|
||||
Returns:
|
||||
S3 object key
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If file doesn't exist
|
||||
Exception: If upload fails
|
||||
"""
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
if object_key is None:
|
||||
object_key = file_path.name
|
||||
|
||||
extra_args: dict[str, Any] = {}
|
||||
|
||||
# Add server-side encryption
|
||||
extra_args["ServerSideEncryption"] = "AES256"
|
||||
|
||||
# Add metadata if provided
|
||||
if metadata:
|
||||
extra_args["Metadata"] = metadata
|
||||
|
||||
logger.info("Uploading %s to s3://%s/%s", file_path.name, self.config.bucket_name, object_key)
|
||||
|
||||
try:
|
||||
self.client.upload_file(
|
||||
str(file_path),
|
||||
self.config.bucket_name,
|
||||
object_key,
|
||||
ExtraArgs=extra_args,
|
||||
)
|
||||
logger.info("Upload successful: %s", object_key)
|
||||
return object_key
|
||||
except Exception as exc:
|
||||
logger.error("Upload failed: %s", exc)
|
||||
raise
|
||||
|
||||
def download_file(self, object_key: str, local_path: Path) -> Path:
|
||||
"""Download a file from cloud storage.
|
||||
|
||||
Args:
|
||||
object_key: S3 object key
|
||||
local_path: Local destination path
|
||||
|
||||
Returns:
|
||||
Path to downloaded file
|
||||
|
||||
Raises:
|
||||
Exception: If download fails
|
||||
"""
|
||||
local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.info("Downloading s3://%s/%s to %s", self.config.bucket_name, object_key, local_path)
|
||||
|
||||
try:
|
||||
self.client.download_file(
|
||||
self.config.bucket_name,
|
||||
object_key,
|
||||
str(local_path),
|
||||
)
|
||||
logger.info("Download successful: %s", local_path)
|
||||
return local_path
|
||||
except Exception as exc:
|
||||
logger.error("Download failed: %s", exc)
|
||||
raise
|
||||
|
||||
def list_files(self, prefix: str = "") -> list[dict[str, Any]]:
|
||||
"""List files in cloud storage.
|
||||
|
||||
Args:
|
||||
prefix: Filter by object key prefix
|
||||
|
||||
Returns:
|
||||
List of file metadata dictionaries
|
||||
"""
|
||||
try:
|
||||
response = self.client.list_objects_v2(
|
||||
Bucket=self.config.bucket_name,
|
||||
Prefix=prefix,
|
||||
)
|
||||
|
||||
if "Contents" not in response:
|
||||
return []
|
||||
|
||||
files = []
|
||||
for obj in response["Contents"]:
|
||||
files.append(
|
||||
{
|
||||
"key": obj["Key"],
|
||||
"size_bytes": obj["Size"],
|
||||
"last_modified": obj["LastModified"],
|
||||
"etag": obj["ETag"],
|
||||
}
|
||||
)
|
||||
|
||||
return files
|
||||
except Exception as exc:
|
||||
logger.error("Failed to list files: %s", exc)
|
||||
raise
|
||||
|
||||
def delete_file(self, object_key: str) -> None:
|
||||
"""Delete a file from cloud storage.
|
||||
|
||||
Args:
|
||||
object_key: S3 object key
|
||||
|
||||
Raises:
|
||||
Exception: If deletion fails
|
||||
"""
|
||||
logger.info("Deleting s3://%s/%s", self.config.bucket_name, object_key)
|
||||
|
||||
try:
|
||||
self.client.delete_object(
|
||||
Bucket=self.config.bucket_name,
|
||||
Key=object_key,
|
||||
)
|
||||
logger.info("Deletion successful: %s", object_key)
|
||||
except Exception as exc:
|
||||
logger.error("Deletion failed: %s", exc)
|
||||
raise
|
||||
|
||||
def get_storage_stats(self) -> dict[str, Any]:
|
||||
"""Get cloud storage statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with storage stats
|
||||
"""
|
||||
try:
|
||||
files = self.list_files()
|
||||
|
||||
total_size = sum(f["size_bytes"] for f in files)
|
||||
total_count = len(files)
|
||||
|
||||
return {
|
||||
"total_files": total_count,
|
||||
"total_size_bytes": total_size,
|
||||
"total_size_mb": total_size / 1024 / 1024,
|
||||
"total_size_gb": total_size / 1024 / 1024 / 1024,
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.error("Failed to get storage stats: %s", exc)
|
||||
return {
|
||||
"error": str(exc),
|
||||
"total_files": 0,
|
||||
"total_size_bytes": 0,
|
||||
}
|
||||
|
||||
def verify_connection(self) -> bool:
|
||||
"""Verify connection to cloud storage.
|
||||
|
||||
Returns:
|
||||
True if connection is successful
|
||||
"""
|
||||
try:
|
||||
self.client.head_bucket(Bucket=self.config.bucket_name)
|
||||
logger.info("Cloud storage connection verified")
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.error("Cloud storage connection failed: %s", exc)
|
||||
return False
|
||||
|
||||
def create_bucket_if_not_exists(self) -> None:
|
||||
"""Create storage bucket if it doesn't exist.
|
||||
|
||||
Raises:
|
||||
Exception: If bucket creation fails
|
||||
"""
|
||||
try:
|
||||
self.client.head_bucket(Bucket=self.config.bucket_name)
|
||||
logger.info("Bucket already exists: %s", self.config.bucket_name)
|
||||
except self.client.exceptions.NoSuchBucket:
|
||||
logger.info("Creating bucket: %s", self.config.bucket_name)
|
||||
if self.config.region == "us-east-1":
|
||||
# us-east-1 requires special handling
|
||||
self.client.create_bucket(Bucket=self.config.bucket_name)
|
||||
else:
|
||||
self.client.create_bucket(
|
||||
Bucket=self.config.bucket_name,
|
||||
CreateBucketConfiguration={"LocationConstraint": self.config.region},
|
||||
)
|
||||
logger.info("Bucket created successfully")
|
||||
except Exception as exc:
|
||||
logger.error("Failed to verify/create bucket: %s", exc)
|
||||
raise
|
||||
|
||||
def enable_versioning(self) -> None:
|
||||
"""Enable versioning on the bucket.
|
||||
|
||||
Raises:
|
||||
Exception: If versioning enablement fails
|
||||
"""
|
||||
try:
|
||||
self.client.put_bucket_versioning(
|
||||
Bucket=self.config.bucket_name,
|
||||
VersioningConfiguration={"Status": "Enabled"},
|
||||
)
|
||||
logger.info("Versioning enabled for bucket: %s", self.config.bucket_name)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to enable versioning: %s", exc)
|
||||
raise
|
||||
326
src/backup/exporter.py
Normal file
326
src/backup/exporter.py
Normal file
@@ -0,0 +1,326 @@
|
||||
"""Multi-format database exporter for backups.
|
||||
|
||||
Supports JSON, CSV, and Parquet formats for different use cases:
|
||||
- JSON: Human-readable, easy to inspect
|
||||
- CSV: Analysis tools (Excel, pandas)
|
||||
- Parquet: Big data tools (Spark, DuckDB)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import gzip
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExportFormat(str, Enum):
|
||||
"""Supported export formats."""
|
||||
|
||||
JSON = "json"
|
||||
CSV = "csv"
|
||||
PARQUET = "parquet"
|
||||
|
||||
|
||||
class BackupExporter:
|
||||
"""Export database to multiple formats."""
|
||||
|
||||
def __init__(self, db_path: str) -> None:
|
||||
"""Initialize the exporter.
|
||||
|
||||
Args:
|
||||
db_path: Path to SQLite database
|
||||
"""
|
||||
self.db_path = db_path
|
||||
|
||||
def export_all(
|
||||
self,
|
||||
output_dir: Path,
|
||||
formats: list[ExportFormat] | None = None,
|
||||
compress: bool = True,
|
||||
incremental_since: datetime | None = None,
|
||||
) -> dict[ExportFormat, Path]:
|
||||
"""Export database to multiple formats.
|
||||
|
||||
Args:
|
||||
output_dir: Directory to write export files
|
||||
formats: List of formats to export (default: all)
|
||||
compress: Whether to gzip compress exports
|
||||
incremental_since: Only export records after this timestamp
|
||||
|
||||
Returns:
|
||||
Dictionary mapping format to output file path
|
||||
"""
|
||||
if formats is None:
|
||||
formats = [ExportFormat.JSON, ExportFormat.CSV, ExportFormat.PARQUET]
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
results: dict[ExportFormat, Path] = {}
|
||||
|
||||
for fmt in formats:
|
||||
try:
|
||||
output_file = self._export_format(
|
||||
fmt, output_dir, timestamp, compress, incremental_since
|
||||
)
|
||||
results[fmt] = output_file
|
||||
logger.info("Exported to %s: %s", fmt.value, output_file)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to export to %s: %s", fmt.value, exc)
|
||||
|
||||
return results
|
||||
|
||||
def _export_format(
|
||||
self,
|
||||
fmt: ExportFormat,
|
||||
output_dir: Path,
|
||||
timestamp: str,
|
||||
compress: bool,
|
||||
incremental_since: datetime | None,
|
||||
) -> Path:
|
||||
"""Export to a specific format.
|
||||
|
||||
Args:
|
||||
fmt: Export format
|
||||
output_dir: Output directory
|
||||
timestamp: Timestamp string for filename
|
||||
compress: Whether to compress
|
||||
incremental_since: Incremental export cutoff
|
||||
|
||||
Returns:
|
||||
Path to output file
|
||||
"""
|
||||
if fmt == ExportFormat.JSON:
|
||||
return self._export_json(output_dir, timestamp, compress, incremental_since)
|
||||
elif fmt == ExportFormat.CSV:
|
||||
return self._export_csv(output_dir, timestamp, compress, incremental_since)
|
||||
elif fmt == ExportFormat.PARQUET:
|
||||
return self._export_parquet(
|
||||
output_dir, timestamp, compress, incremental_since
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: {fmt}")
|
||||
|
||||
def _get_trades(
|
||||
self, incremental_since: datetime | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch trades from database.
|
||||
|
||||
Args:
|
||||
incremental_since: Only fetch trades after this timestamp
|
||||
|
||||
Returns:
|
||||
List of trade records
|
||||
"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
if incremental_since:
|
||||
cursor = conn.execute(
|
||||
"SELECT * FROM trades WHERE timestamp > ?",
|
||||
(incremental_since.isoformat(),),
|
||||
)
|
||||
else:
|
||||
cursor = conn.execute("SELECT * FROM trades")
|
||||
|
||||
trades = [dict(row) for row in cursor.fetchall()]
|
||||
conn.close()
|
||||
|
||||
return trades
|
||||
|
||||
def _export_json(
|
||||
self,
|
||||
output_dir: Path,
|
||||
timestamp: str,
|
||||
compress: bool,
|
||||
incremental_since: datetime | None,
|
||||
) -> Path:
|
||||
"""Export to JSON format.
|
||||
|
||||
Args:
|
||||
output_dir: Output directory
|
||||
timestamp: Timestamp for filename
|
||||
compress: Whether to gzip
|
||||
incremental_since: Incremental cutoff
|
||||
|
||||
Returns:
|
||||
Path to output file
|
||||
"""
|
||||
trades = self._get_trades(incremental_since)
|
||||
|
||||
filename = f"trades_{timestamp}.json"
|
||||
if compress:
|
||||
filename += ".gz"
|
||||
|
||||
output_file = output_dir / filename
|
||||
|
||||
data = {
|
||||
"export_timestamp": datetime.now(UTC).isoformat(),
|
||||
"incremental_since": (
|
||||
incremental_since.isoformat() if incremental_since else None
|
||||
),
|
||||
"record_count": len(trades),
|
||||
"trades": trades,
|
||||
}
|
||||
|
||||
if compress:
|
||||
with gzip.open(output_file, "wt", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
else:
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
return output_file
|
||||
|
||||
def _export_csv(
|
||||
self,
|
||||
output_dir: Path,
|
||||
timestamp: str,
|
||||
compress: bool,
|
||||
incremental_since: datetime | None,
|
||||
) -> Path:
|
||||
"""Export to CSV format.
|
||||
|
||||
Args:
|
||||
output_dir: Output directory
|
||||
timestamp: Timestamp for filename
|
||||
compress: Whether to gzip
|
||||
incremental_since: Incremental cutoff
|
||||
|
||||
Returns:
|
||||
Path to output file
|
||||
"""
|
||||
trades = self._get_trades(incremental_since)
|
||||
|
||||
filename = f"trades_{timestamp}.csv"
|
||||
if compress:
|
||||
filename += ".gz"
|
||||
|
||||
output_file = output_dir / filename
|
||||
|
||||
if not trades:
|
||||
# Write empty CSV with headers
|
||||
if compress:
|
||||
with gzip.open(output_file, "wt", encoding="utf-8", newline="") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(
|
||||
[
|
||||
"timestamp",
|
||||
"stock_code",
|
||||
"action",
|
||||
"quantity",
|
||||
"price",
|
||||
"confidence",
|
||||
"rationale",
|
||||
"pnl",
|
||||
]
|
||||
)
|
||||
else:
|
||||
with open(output_file, "w", encoding="utf-8", newline="") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(
|
||||
[
|
||||
"timestamp",
|
||||
"stock_code",
|
||||
"action",
|
||||
"quantity",
|
||||
"price",
|
||||
"confidence",
|
||||
"rationale",
|
||||
"pnl",
|
||||
]
|
||||
)
|
||||
return output_file
|
||||
|
||||
# Get column names from first trade
|
||||
fieldnames = list(trades[0].keys())
|
||||
|
||||
if compress:
|
||||
with gzip.open(output_file, "wt", encoding="utf-8", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
writer.writerows(trades)
|
||||
else:
|
||||
with open(output_file, "w", encoding="utf-8", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
writer.writerows(trades)
|
||||
|
||||
return output_file
|
||||
|
||||
def _export_parquet(
|
||||
self,
|
||||
output_dir: Path,
|
||||
timestamp: str,
|
||||
compress: bool,
|
||||
incremental_since: datetime | None,
|
||||
) -> Path:
|
||||
"""Export to Parquet format.
|
||||
|
||||
Args:
|
||||
output_dir: Output directory
|
||||
timestamp: Timestamp for filename
|
||||
compress: Whether to compress (Parquet has built-in compression)
|
||||
incremental_since: Incremental cutoff
|
||||
|
||||
Returns:
|
||||
Path to output file
|
||||
"""
|
||||
trades = self._get_trades(incremental_since)
|
||||
|
||||
filename = f"trades_{timestamp}.parquet"
|
||||
output_file = output_dir / filename
|
||||
|
||||
try:
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"pyarrow is required for Parquet export. "
|
||||
"Install with: pip install pyarrow"
|
||||
)
|
||||
|
||||
# Convert to pyarrow table
|
||||
table = pa.Table.from_pylist(trades)
|
||||
|
||||
# Write with compression
|
||||
compression = "gzip" if compress else "none"
|
||||
pq.write_table(table, output_file, compression=compression)
|
||||
|
||||
return output_file
|
||||
|
||||
def get_export_stats(self) -> dict[str, Any]:
|
||||
"""Get statistics about exportable data.
|
||||
|
||||
Returns:
|
||||
Dictionary with data statistics
|
||||
"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
stats = {}
|
||||
|
||||
# Total trades
|
||||
cursor.execute("SELECT COUNT(*) FROM trades")
|
||||
stats["total_trades"] = cursor.fetchone()[0]
|
||||
|
||||
# Date range
|
||||
cursor.execute("SELECT MIN(timestamp), MAX(timestamp) FROM trades")
|
||||
min_date, max_date = cursor.fetchone()
|
||||
stats["date_range"] = {"earliest": min_date, "latest": max_date}
|
||||
|
||||
# Database size
|
||||
cursor.execute("SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()")
|
||||
stats["db_size_bytes"] = cursor.fetchone()[0]
|
||||
|
||||
conn.close()
|
||||
|
||||
return stats
|
||||
282
src/backup/health_monitor.py
Normal file
282
src/backup/health_monitor.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""Health monitoring for backup system.
|
||||
|
||||
Checks:
|
||||
- Database accessibility and integrity
|
||||
- Disk space availability
|
||||
- Backup success/failure tracking
|
||||
- Self-healing capabilities
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
import sqlite3
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HealthStatus(str, Enum):
|
||||
"""Health check status."""
|
||||
|
||||
HEALTHY = "healthy"
|
||||
DEGRADED = "degraded"
|
||||
UNHEALTHY = "unhealthy"
|
||||
|
||||
|
||||
@dataclass
|
||||
class HealthCheckResult:
|
||||
"""Result of a health check."""
|
||||
|
||||
status: HealthStatus
|
||||
message: str
|
||||
details: dict[str, Any] | None = None
|
||||
timestamp: datetime | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.timestamp is None:
|
||||
self.timestamp = datetime.now(UTC)
|
||||
|
||||
|
||||
class HealthMonitor:
|
||||
"""Monitor system health and backup status."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_path: str,
|
||||
backup_dir: Path,
|
||||
min_disk_space_gb: float = 10.0,
|
||||
max_backup_age_hours: int = 25, # Daily backups should be < 25 hours old
|
||||
) -> None:
|
||||
"""Initialize health monitor.
|
||||
|
||||
Args:
|
||||
db_path: Path to SQLite database
|
||||
backup_dir: Backup directory
|
||||
min_disk_space_gb: Minimum required disk space in GB
|
||||
max_backup_age_hours: Maximum acceptable backup age in hours
|
||||
"""
|
||||
self.db_path = Path(db_path)
|
||||
self.backup_dir = backup_dir
|
||||
self.min_disk_space_bytes = int(min_disk_space_gb * 1024 * 1024 * 1024)
|
||||
self.max_backup_age = timedelta(hours=max_backup_age_hours)
|
||||
|
||||
def check_database_health(self) -> HealthCheckResult:
|
||||
"""Check database accessibility and integrity.
|
||||
|
||||
Returns:
|
||||
HealthCheckResult
|
||||
"""
|
||||
# Check if database exists
|
||||
if not self.db_path.exists():
|
||||
return HealthCheckResult(
|
||||
status=HealthStatus.UNHEALTHY,
|
||||
message=f"Database not found: {self.db_path}",
|
||||
)
|
||||
|
||||
# Check if database is accessible
|
||||
try:
|
||||
conn = sqlite3.connect(str(self.db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Run integrity check
|
||||
cursor.execute("PRAGMA integrity_check")
|
||||
result = cursor.fetchone()[0]
|
||||
|
||||
if result != "ok":
|
||||
conn.close()
|
||||
return HealthCheckResult(
|
||||
status=HealthStatus.UNHEALTHY,
|
||||
message=f"Database integrity check failed: {result}",
|
||||
)
|
||||
|
||||
# Get database size
|
||||
cursor.execute(
|
||||
"SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()"
|
||||
)
|
||||
db_size = cursor.fetchone()[0]
|
||||
|
||||
# Get row counts
|
||||
cursor.execute("SELECT COUNT(*) FROM trades")
|
||||
trade_count = cursor.fetchone()[0]
|
||||
|
||||
conn.close()
|
||||
|
||||
return HealthCheckResult(
|
||||
status=HealthStatus.HEALTHY,
|
||||
message="Database is healthy",
|
||||
details={
|
||||
"size_bytes": db_size,
|
||||
"size_mb": db_size / 1024 / 1024,
|
||||
"trade_count": trade_count,
|
||||
},
|
||||
)
|
||||
|
||||
except sqlite3.Error as exc:
|
||||
return HealthCheckResult(
|
||||
status=HealthStatus.UNHEALTHY,
|
||||
message=f"Database access error: {exc}",
|
||||
)
|
||||
|
||||
def check_disk_space(self) -> HealthCheckResult:
|
||||
"""Check available disk space.
|
||||
|
||||
Returns:
|
||||
HealthCheckResult
|
||||
"""
|
||||
try:
|
||||
stat = shutil.disk_usage(self.backup_dir)
|
||||
|
||||
free_gb = stat.free / 1024 / 1024 / 1024
|
||||
total_gb = stat.total / 1024 / 1024 / 1024
|
||||
used_percent = (stat.used / stat.total) * 100
|
||||
|
||||
if stat.free < self.min_disk_space_bytes:
|
||||
return HealthCheckResult(
|
||||
status=HealthStatus.UNHEALTHY,
|
||||
message=f"Low disk space: {free_gb:.2f} GB free (minimum: {self.min_disk_space_bytes / 1024 / 1024 / 1024:.2f} GB)",
|
||||
details={
|
||||
"free_gb": free_gb,
|
||||
"total_gb": total_gb,
|
||||
"used_percent": used_percent,
|
||||
},
|
||||
)
|
||||
elif stat.free < self.min_disk_space_bytes * 2:
|
||||
return HealthCheckResult(
|
||||
status=HealthStatus.DEGRADED,
|
||||
message=f"Disk space low: {free_gb:.2f} GB free",
|
||||
details={
|
||||
"free_gb": free_gb,
|
||||
"total_gb": total_gb,
|
||||
"used_percent": used_percent,
|
||||
},
|
||||
)
|
||||
else:
|
||||
return HealthCheckResult(
|
||||
status=HealthStatus.HEALTHY,
|
||||
message=f"Disk space healthy: {free_gb:.2f} GB free",
|
||||
details={
|
||||
"free_gb": free_gb,
|
||||
"total_gb": total_gb,
|
||||
"used_percent": used_percent,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
return HealthCheckResult(
|
||||
status=HealthStatus.UNHEALTHY,
|
||||
message=f"Failed to check disk space: {exc}",
|
||||
)
|
||||
|
||||
def check_backup_recency(self) -> HealthCheckResult:
|
||||
"""Check if backups are recent enough.
|
||||
|
||||
Returns:
|
||||
HealthCheckResult
|
||||
"""
|
||||
daily_dir = self.backup_dir / "daily"
|
||||
|
||||
if not daily_dir.exists():
|
||||
return HealthCheckResult(
|
||||
status=HealthStatus.DEGRADED,
|
||||
message="Daily backup directory not found",
|
||||
)
|
||||
|
||||
# Find most recent backup
|
||||
backups = sorted(daily_dir.glob("*.db"), key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
|
||||
if not backups:
|
||||
return HealthCheckResult(
|
||||
status=HealthStatus.UNHEALTHY,
|
||||
message="No daily backups found",
|
||||
)
|
||||
|
||||
most_recent = backups[0]
|
||||
mtime = datetime.fromtimestamp(most_recent.stat().st_mtime, tz=UTC)
|
||||
age = datetime.now(UTC) - mtime
|
||||
|
||||
if age > self.max_backup_age:
|
||||
return HealthCheckResult(
|
||||
status=HealthStatus.DEGRADED,
|
||||
message=f"Most recent backup is {age.total_seconds() / 3600:.1f} hours old",
|
||||
details={
|
||||
"backup_file": most_recent.name,
|
||||
"age_hours": age.total_seconds() / 3600,
|
||||
"threshold_hours": self.max_backup_age.total_seconds() / 3600,
|
||||
},
|
||||
)
|
||||
else:
|
||||
return HealthCheckResult(
|
||||
status=HealthStatus.HEALTHY,
|
||||
message=f"Recent backup found ({age.total_seconds() / 3600:.1f} hours old)",
|
||||
details={
|
||||
"backup_file": most_recent.name,
|
||||
"age_hours": age.total_seconds() / 3600,
|
||||
},
|
||||
)
|
||||
|
||||
def run_all_checks(self) -> dict[str, HealthCheckResult]:
|
||||
"""Run all health checks.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping check name to result
|
||||
"""
|
||||
checks = {
|
||||
"database": self.check_database_health(),
|
||||
"disk_space": self.check_disk_space(),
|
||||
"backup_recency": self.check_backup_recency(),
|
||||
}
|
||||
|
||||
# Log results
|
||||
for check_name, result in checks.items():
|
||||
if result.status == HealthStatus.UNHEALTHY:
|
||||
logger.error("[%s] %s: %s", check_name, result.status.value, result.message)
|
||||
elif result.status == HealthStatus.DEGRADED:
|
||||
logger.warning("[%s] %s: %s", check_name, result.status.value, result.message)
|
||||
else:
|
||||
logger.info("[%s] %s: %s", check_name, result.status.value, result.message)
|
||||
|
||||
return checks
|
||||
|
||||
def get_overall_status(self) -> HealthStatus:
|
||||
"""Get overall system health status.
|
||||
|
||||
Returns:
|
||||
HealthStatus (worst status from all checks)
|
||||
"""
|
||||
checks = self.run_all_checks()
|
||||
|
||||
# Return worst status
|
||||
if any(c.status == HealthStatus.UNHEALTHY for c in checks.values()):
|
||||
return HealthStatus.UNHEALTHY
|
||||
elif any(c.status == HealthStatus.DEGRADED for c in checks.values()):
|
||||
return HealthStatus.DEGRADED
|
||||
else:
|
||||
return HealthStatus.HEALTHY
|
||||
|
||||
def get_health_report(self) -> dict[str, Any]:
|
||||
"""Get comprehensive health report.
|
||||
|
||||
Returns:
|
||||
Dictionary with health report
|
||||
"""
|
||||
checks = self.run_all_checks()
|
||||
overall = self.get_overall_status()
|
||||
|
||||
return {
|
||||
"overall_status": overall.value,
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"checks": {
|
||||
name: {
|
||||
"status": result.status.value,
|
||||
"message": result.message,
|
||||
"details": result.details,
|
||||
}
|
||||
for name, result in checks.items()
|
||||
},
|
||||
}
|
||||
336
src/backup/scheduler.py
Normal file
336
src/backup/scheduler.py
Normal file
@@ -0,0 +1,336 @@
|
||||
"""Backup scheduler for automated database backups.
|
||||
|
||||
Implements backup policies:
|
||||
- Daily: Keep for 30 days (hot storage)
|
||||
- Weekly: Keep for 1 year (warm storage)
|
||||
- Monthly: Keep forever (cold storage)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BackupPolicy(str, Enum):
|
||||
"""Backup retention policies."""
|
||||
|
||||
DAILY = "daily"
|
||||
WEEKLY = "weekly"
|
||||
MONTHLY = "monthly"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BackupMetadata:
|
||||
"""Metadata for a backup."""
|
||||
|
||||
timestamp: datetime
|
||||
policy: BackupPolicy
|
||||
file_path: Path
|
||||
size_bytes: int
|
||||
checksum: str | None = None
|
||||
|
||||
|
||||
class BackupScheduler:
|
||||
"""Manage automated database backups with retention policies."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_path: str,
|
||||
backup_dir: Path,
|
||||
daily_retention_days: int = 30,
|
||||
weekly_retention_days: int = 365,
|
||||
) -> None:
|
||||
"""Initialize the backup scheduler.
|
||||
|
||||
Args:
|
||||
db_path: Path to SQLite database
|
||||
backup_dir: Root directory for backups
|
||||
daily_retention_days: Days to keep daily backups
|
||||
weekly_retention_days: Days to keep weekly backups
|
||||
"""
|
||||
self.db_path = Path(db_path)
|
||||
self.backup_dir = backup_dir
|
||||
self.daily_retention = timedelta(days=daily_retention_days)
|
||||
self.weekly_retention = timedelta(days=weekly_retention_days)
|
||||
|
||||
# Create policy-specific directories
|
||||
self.daily_dir = backup_dir / "daily"
|
||||
self.weekly_dir = backup_dir / "weekly"
|
||||
self.monthly_dir = backup_dir / "monthly"
|
||||
|
||||
for d in [self.daily_dir, self.weekly_dir, self.monthly_dir]:
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def create_backup(
|
||||
self, policy: BackupPolicy, verify: bool = True
|
||||
) -> BackupMetadata:
|
||||
"""Create a database backup.
|
||||
|
||||
Args:
|
||||
policy: Backup policy (daily/weekly/monthly)
|
||||
verify: Whether to verify backup integrity
|
||||
|
||||
Returns:
|
||||
BackupMetadata object
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If database doesn't exist
|
||||
OSError: If backup fails
|
||||
"""
|
||||
if not self.db_path.exists():
|
||||
raise FileNotFoundError(f"Database not found: {self.db_path}")
|
||||
|
||||
timestamp = datetime.now(UTC)
|
||||
backup_filename = self._get_backup_filename(timestamp, policy)
|
||||
|
||||
# Determine output directory
|
||||
if policy == BackupPolicy.DAILY:
|
||||
output_dir = self.daily_dir
|
||||
elif policy == BackupPolicy.WEEKLY:
|
||||
output_dir = self.weekly_dir
|
||||
else: # MONTHLY
|
||||
output_dir = self.monthly_dir
|
||||
|
||||
backup_path = output_dir / backup_filename
|
||||
|
||||
# Create backup (copy database file)
|
||||
logger.info("Creating %s backup: %s", policy.value, backup_path)
|
||||
shutil.copy2(self.db_path, backup_path)
|
||||
|
||||
# Get file size
|
||||
size_bytes = backup_path.stat().st_size
|
||||
|
||||
# Verify backup if requested
|
||||
checksum = None
|
||||
if verify:
|
||||
checksum = self._verify_backup(backup_path)
|
||||
|
||||
metadata = BackupMetadata(
|
||||
timestamp=timestamp,
|
||||
policy=policy,
|
||||
file_path=backup_path,
|
||||
size_bytes=size_bytes,
|
||||
checksum=checksum,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Backup created: %s (%.2f MB)",
|
||||
backup_path.name,
|
||||
size_bytes / 1024 / 1024,
|
||||
)
|
||||
|
||||
return metadata
|
||||
|
||||
def _get_backup_filename(self, timestamp: datetime, policy: BackupPolicy) -> str:
|
||||
"""Generate backup filename.
|
||||
|
||||
Args:
|
||||
timestamp: Backup timestamp
|
||||
policy: Backup policy
|
||||
|
||||
Returns:
|
||||
Filename string
|
||||
"""
|
||||
ts_str = timestamp.strftime("%Y%m%d_%H%M%S")
|
||||
return f"trade_logs_{policy.value}_{ts_str}.db"
|
||||
|
||||
def _verify_backup(self, backup_path: Path) -> str:
|
||||
"""Verify backup integrity using SQLite integrity check.
|
||||
|
||||
Args:
|
||||
backup_path: Path to backup file
|
||||
|
||||
Returns:
|
||||
Checksum string (MD5 hash)
|
||||
|
||||
Raises:
|
||||
RuntimeError: If integrity check fails
|
||||
"""
|
||||
import hashlib
|
||||
import sqlite3
|
||||
|
||||
# Integrity check
|
||||
try:
|
||||
conn = sqlite3.connect(str(backup_path))
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("PRAGMA integrity_check")
|
||||
result = cursor.fetchone()[0]
|
||||
conn.close()
|
||||
|
||||
if result != "ok":
|
||||
raise RuntimeError(f"Integrity check failed: {result}")
|
||||
except sqlite3.Error as exc:
|
||||
raise RuntimeError(f"Failed to verify backup: {exc}")
|
||||
|
||||
# Calculate MD5 checksum
|
||||
md5 = hashlib.md5()
|
||||
with open(backup_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(8192), b""):
|
||||
md5.update(chunk)
|
||||
|
||||
return md5.hexdigest()
|
||||
|
||||
def cleanup_old_backups(self) -> dict[BackupPolicy, int]:
|
||||
"""Remove backups older than retention policies.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping policy to number of backups removed
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
removed_counts: dict[BackupPolicy, int] = {}
|
||||
|
||||
# Daily backups: remove older than retention
|
||||
removed_counts[BackupPolicy.DAILY] = self._cleanup_directory(
|
||||
self.daily_dir, now - self.daily_retention
|
||||
)
|
||||
|
||||
# Weekly backups: remove older than retention
|
||||
removed_counts[BackupPolicy.WEEKLY] = self._cleanup_directory(
|
||||
self.weekly_dir, now - self.weekly_retention
|
||||
)
|
||||
|
||||
# Monthly backups: never remove (kept forever)
|
||||
removed_counts[BackupPolicy.MONTHLY] = 0
|
||||
|
||||
total = sum(removed_counts.values())
|
||||
if total > 0:
|
||||
logger.info("Cleaned up %d old backup(s)", total)
|
||||
|
||||
return removed_counts
|
||||
|
||||
def _cleanup_directory(self, directory: Path, cutoff: datetime) -> int:
|
||||
"""Remove backups older than cutoff date.
|
||||
|
||||
Args:
|
||||
directory: Directory to clean
|
||||
cutoff: Remove files older than this
|
||||
|
||||
Returns:
|
||||
Number of files removed
|
||||
"""
|
||||
removed = 0
|
||||
|
||||
for backup_file in directory.glob("*.db"):
|
||||
# Get file modification time
|
||||
mtime = datetime.fromtimestamp(backup_file.stat().st_mtime, tz=UTC)
|
||||
|
||||
if mtime < cutoff:
|
||||
logger.debug("Removing old backup: %s", backup_file.name)
|
||||
backup_file.unlink()
|
||||
removed += 1
|
||||
|
||||
return removed
|
||||
|
||||
def list_backups(
|
||||
self, policy: BackupPolicy | None = None
|
||||
) -> list[BackupMetadata]:
|
||||
"""List available backups.
|
||||
|
||||
Args:
|
||||
policy: Filter by policy (None for all)
|
||||
|
||||
Returns:
|
||||
List of BackupMetadata objects
|
||||
"""
|
||||
backups: list[BackupMetadata] = []
|
||||
|
||||
policies_to_check = (
|
||||
[policy] if policy else [BackupPolicy.DAILY, BackupPolicy.WEEKLY, BackupPolicy.MONTHLY]
|
||||
)
|
||||
|
||||
for pol in policies_to_check:
|
||||
if pol == BackupPolicy.DAILY:
|
||||
directory = self.daily_dir
|
||||
elif pol == BackupPolicy.WEEKLY:
|
||||
directory = self.weekly_dir
|
||||
else:
|
||||
directory = self.monthly_dir
|
||||
|
||||
for backup_file in sorted(directory.glob("*.db")):
|
||||
mtime = datetime.fromtimestamp(backup_file.stat().st_mtime, tz=UTC)
|
||||
size = backup_file.stat().st_size
|
||||
|
||||
backups.append(
|
||||
BackupMetadata(
|
||||
timestamp=mtime,
|
||||
policy=pol,
|
||||
file_path=backup_file,
|
||||
size_bytes=size,
|
||||
)
|
||||
)
|
||||
|
||||
# Sort by timestamp (newest first)
|
||||
backups.sort(key=lambda b: b.timestamp, reverse=True)
|
||||
|
||||
return backups
|
||||
|
||||
def get_backup_stats(self) -> dict[str, Any]:
|
||||
"""Get backup statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with backup stats
|
||||
"""
|
||||
stats: dict[str, Any] = {}
|
||||
|
||||
for policy in BackupPolicy:
|
||||
if policy == BackupPolicy.DAILY:
|
||||
directory = self.daily_dir
|
||||
elif policy == BackupPolicy.WEEKLY:
|
||||
directory = self.weekly_dir
|
||||
else:
|
||||
directory = self.monthly_dir
|
||||
|
||||
backups = list(directory.glob("*.db"))
|
||||
total_size = sum(b.stat().st_size for b in backups)
|
||||
|
||||
stats[policy.value] = {
|
||||
"count": len(backups),
|
||||
"total_size_bytes": total_size,
|
||||
"total_size_mb": total_size / 1024 / 1024,
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
def restore_backup(self, backup_metadata: BackupMetadata, verify: bool = True) -> None:
|
||||
"""Restore database from backup.
|
||||
|
||||
Args:
|
||||
backup_metadata: Backup to restore
|
||||
verify: Whether to verify restored database
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If backup file doesn't exist
|
||||
RuntimeError: If verification fails
|
||||
"""
|
||||
if not backup_metadata.file_path.exists():
|
||||
raise FileNotFoundError(f"Backup not found: {backup_metadata.file_path}")
|
||||
|
||||
# Create backup of current database
|
||||
if self.db_path.exists():
|
||||
backup_current = self.db_path.with_suffix(".db.before_restore")
|
||||
logger.info("Backing up current database to: %s", backup_current)
|
||||
shutil.copy2(self.db_path, backup_current)
|
||||
|
||||
# Restore backup
|
||||
logger.info("Restoring backup: %s", backup_metadata.file_path.name)
|
||||
shutil.copy2(backup_metadata.file_path, self.db_path)
|
||||
|
||||
# Verify restored database
|
||||
if verify:
|
||||
try:
|
||||
self._verify_backup(self.db_path)
|
||||
logger.info("Backup restored and verified successfully")
|
||||
except RuntimeError as exc:
|
||||
# Restore failed, revert to backup
|
||||
if backup_current.exists():
|
||||
logger.error("Restore verification failed, reverting: %s", exc)
|
||||
shutil.copy2(backup_current, self.db_path)
|
||||
raise
|
||||
293
src/brain/cache.py
Normal file
293
src/brain/cache.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""Response caching system for reducing redundant LLM calls.
|
||||
|
||||
This module provides caching for common trading scenarios:
|
||||
- TTL-based cache invalidation
|
||||
- Cache key based on market conditions
|
||||
- Cache hit rate monitoring
|
||||
- Special handling for HOLD decisions in quiet markets
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.brain.gemini_client import TradeDecision
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheEntry:
|
||||
"""Cached decision with metadata."""
|
||||
|
||||
decision: "TradeDecision"
|
||||
cached_at: float # Unix timestamp
|
||||
hit_count: int = 0
|
||||
market_data_hash: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheMetrics:
|
||||
"""Metrics for cache performance monitoring."""
|
||||
|
||||
total_requests: int = 0
|
||||
cache_hits: int = 0
|
||||
cache_misses: int = 0
|
||||
evictions: int = 0
|
||||
total_entries: int = 0
|
||||
|
||||
@property
|
||||
def hit_rate(self) -> float:
|
||||
"""Calculate cache hit rate."""
|
||||
if self.total_requests == 0:
|
||||
return 0.0
|
||||
return self.cache_hits / self.total_requests
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert metrics to dictionary."""
|
||||
return {
|
||||
"total_requests": self.total_requests,
|
||||
"cache_hits": self.cache_hits,
|
||||
"cache_misses": self.cache_misses,
|
||||
"hit_rate": self.hit_rate,
|
||||
"evictions": self.evictions,
|
||||
"total_entries": self.total_entries,
|
||||
}
|
||||
|
||||
|
||||
class DecisionCache:
|
||||
"""TTL-based cache for trade decisions."""
|
||||
|
||||
def __init__(self, ttl_seconds: int = 300, max_size: int = 1000) -> None:
|
||||
"""Initialize the decision cache.
|
||||
|
||||
Args:
|
||||
ttl_seconds: Time-to-live for cache entries in seconds (default: 5 minutes)
|
||||
max_size: Maximum number of cache entries
|
||||
"""
|
||||
self.ttl_seconds = ttl_seconds
|
||||
self.max_size = max_size
|
||||
self._cache: dict[str, CacheEntry] = {}
|
||||
self._metrics = CacheMetrics()
|
||||
|
||||
def _generate_cache_key(self, market_data: dict[str, Any]) -> str:
|
||||
"""Generate cache key from market data.
|
||||
|
||||
Key is based on:
|
||||
- Stock code
|
||||
- Current price (rounded to reduce sensitivity)
|
||||
- Market conditions (orderbook snapshot)
|
||||
|
||||
Args:
|
||||
market_data: Market data dictionary
|
||||
|
||||
Returns:
|
||||
Cache key string
|
||||
"""
|
||||
# Extract key components
|
||||
stock_code = market_data.get("stock_code", "UNKNOWN")
|
||||
current_price = market_data.get("current_price", 0)
|
||||
|
||||
# Round price to reduce sensitivity (cache hits for similar prices)
|
||||
# For prices > 1000, round to nearest 10
|
||||
# For prices < 1000, round to nearest 1
|
||||
if current_price > 1000:
|
||||
price_rounded = round(current_price / 10) * 10
|
||||
else:
|
||||
price_rounded = round(current_price)
|
||||
|
||||
# Include orderbook snapshot (if available)
|
||||
orderbook_key = ""
|
||||
if "orderbook" in market_data and market_data["orderbook"]:
|
||||
ob = market_data["orderbook"]
|
||||
# Just use bid/ask spread as indicator
|
||||
if "bid" in ob and "ask" in ob and ob["bid"] and ob["ask"]:
|
||||
bid_price = ob["bid"][0].get("price", 0) if ob["bid"] else 0
|
||||
ask_price = ob["ask"][0].get("price", 0) if ob["ask"] else 0
|
||||
spread = ask_price - bid_price
|
||||
orderbook_key = f"_spread{spread}"
|
||||
|
||||
# Generate cache key
|
||||
key_str = f"{stock_code}_{price_rounded}{orderbook_key}"
|
||||
|
||||
return key_str
|
||||
|
||||
def _generate_market_hash(self, market_data: dict[str, Any]) -> str:
|
||||
"""Generate hash of full market data for invalidation checks.
|
||||
|
||||
Args:
|
||||
market_data: Market data dictionary
|
||||
|
||||
Returns:
|
||||
Hash string
|
||||
"""
|
||||
# Create stable JSON representation
|
||||
stable_json = json.dumps(market_data, sort_keys=True, ensure_ascii=False)
|
||||
return hashlib.md5(stable_json.encode()).hexdigest()
|
||||
|
||||
def get(self, market_data: dict[str, Any]) -> TradeDecision | None:
|
||||
"""Retrieve cached decision if valid.
|
||||
|
||||
Args:
|
||||
market_data: Market data dictionary
|
||||
|
||||
Returns:
|
||||
Cached TradeDecision if valid, None otherwise
|
||||
"""
|
||||
self._metrics.total_requests += 1
|
||||
|
||||
cache_key = self._generate_cache_key(market_data)
|
||||
|
||||
if cache_key not in self._cache:
|
||||
self._metrics.cache_misses += 1
|
||||
return None
|
||||
|
||||
entry = self._cache[cache_key]
|
||||
current_time = time.time()
|
||||
|
||||
# Check TTL
|
||||
if current_time - entry.cached_at > self.ttl_seconds:
|
||||
# Expired
|
||||
del self._cache[cache_key]
|
||||
self._metrics.cache_misses += 1
|
||||
self._metrics.evictions += 1
|
||||
logger.debug("Cache expired for key: %s", cache_key)
|
||||
return None
|
||||
|
||||
# Cache hit
|
||||
entry.hit_count += 1
|
||||
self._metrics.cache_hits += 1
|
||||
logger.debug("Cache hit for key: %s (hits: %d)", cache_key, entry.hit_count)
|
||||
|
||||
return entry.decision
|
||||
|
||||
def set(
|
||||
self,
|
||||
market_data: dict[str, Any],
|
||||
decision: TradeDecision,
|
||||
) -> None:
|
||||
"""Store decision in cache.
|
||||
|
||||
Args:
|
||||
market_data: Market data dictionary
|
||||
decision: TradeDecision to cache
|
||||
"""
|
||||
cache_key = self._generate_cache_key(market_data)
|
||||
market_hash = self._generate_market_hash(market_data)
|
||||
|
||||
# Enforce max size (evict oldest if full)
|
||||
if len(self._cache) >= self.max_size:
|
||||
# Find oldest entry
|
||||
oldest_key = min(self._cache.keys(), key=lambda k: self._cache[k].cached_at)
|
||||
del self._cache[oldest_key]
|
||||
self._metrics.evictions += 1
|
||||
logger.debug("Cache full, evicted key: %s", oldest_key)
|
||||
|
||||
# Store entry
|
||||
entry = CacheEntry(
|
||||
decision=decision,
|
||||
cached_at=time.time(),
|
||||
market_data_hash=market_hash,
|
||||
)
|
||||
self._cache[cache_key] = entry
|
||||
self._metrics.total_entries = len(self._cache)
|
||||
|
||||
logger.debug("Cached decision for key: %s", cache_key)
|
||||
|
||||
def invalidate(self, stock_code: str | None = None) -> int:
|
||||
"""Invalidate cache entries.
|
||||
|
||||
Args:
|
||||
stock_code: Specific stock code to invalidate, or None for all
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
if stock_code is None:
|
||||
# Clear all
|
||||
count = len(self._cache)
|
||||
self._cache.clear()
|
||||
self._metrics.evictions += count
|
||||
self._metrics.total_entries = 0
|
||||
logger.info("Invalidated all cache entries (%d)", count)
|
||||
return count
|
||||
|
||||
# Invalidate specific stock
|
||||
keys_to_remove = [k for k in self._cache.keys() if k.startswith(f"{stock_code}_")]
|
||||
count = len(keys_to_remove)
|
||||
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
|
||||
self._metrics.evictions += count
|
||||
self._metrics.total_entries = len(self._cache)
|
||||
logger.info("Invalidated %d cache entries for stock: %s", count, stock_code)
|
||||
|
||||
return count
|
||||
|
||||
def cleanup_expired(self) -> int:
|
||||
"""Remove expired entries from cache.
|
||||
|
||||
Returns:
|
||||
Number of entries removed
|
||||
"""
|
||||
current_time = time.time()
|
||||
expired_keys = [
|
||||
k
|
||||
for k, v in self._cache.items()
|
||||
if current_time - v.cached_at > self.ttl_seconds
|
||||
]
|
||||
|
||||
count = len(expired_keys)
|
||||
for key in expired_keys:
|
||||
del self._cache[key]
|
||||
|
||||
self._metrics.evictions += count
|
||||
self._metrics.total_entries = len(self._cache)
|
||||
|
||||
if count > 0:
|
||||
logger.debug("Cleaned up %d expired cache entries", count)
|
||||
|
||||
return count
|
||||
|
||||
def get_metrics(self) -> CacheMetrics:
|
||||
"""Get current cache metrics.
|
||||
|
||||
Returns:
|
||||
CacheMetrics object with current statistics
|
||||
"""
|
||||
return self._metrics
|
||||
|
||||
def reset_metrics(self) -> None:
|
||||
"""Reset cache metrics."""
|
||||
self._metrics = CacheMetrics(total_entries=len(self._cache))
|
||||
logger.info("Cache metrics reset")
|
||||
|
||||
def should_cache_decision(self, decision: TradeDecision) -> bool:
|
||||
"""Determine if a decision should be cached.
|
||||
|
||||
HOLD decisions with low confidence are good candidates for caching,
|
||||
as they're likely to recur in quiet markets.
|
||||
|
||||
Args:
|
||||
decision: TradeDecision to evaluate
|
||||
|
||||
Returns:
|
||||
True if decision should be cached
|
||||
"""
|
||||
# Cache HOLD decisions (common in quiet markets)
|
||||
if decision.action == "HOLD":
|
||||
return True
|
||||
|
||||
# Cache high-confidence decisions (stable signals)
|
||||
if decision.confidence >= 90:
|
||||
return True
|
||||
|
||||
# Don't cache low-confidence BUY/SELL (volatile signals)
|
||||
return False
|
||||
296
src/brain/context_selector.py
Normal file
296
src/brain/context_selector.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""Smart context selection for optimizing token usage.
|
||||
|
||||
This module implements intelligent selection of context layers (L1-L7) based on
|
||||
decision type and market conditions:
|
||||
- L7 (real-time) for normal trading decisions
|
||||
- L6-L5 (daily/weekly) for strategic decisions
|
||||
- L4-L1 (monthly/legacy) only for major events or policy changes
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from src.context.layer import ContextLayer
|
||||
from src.context.store import ContextStore
|
||||
|
||||
|
||||
class DecisionType(str, Enum):
|
||||
"""Type of trading decision being made."""
|
||||
|
||||
NORMAL = "normal" # Regular trade decision
|
||||
STRATEGIC = "strategic" # Strategy adjustment
|
||||
MAJOR_EVENT = "major_event" # Portfolio rebalancing, policy change
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ContextSelection:
|
||||
"""Selected context layers and their relevance scores."""
|
||||
|
||||
layers: list[ContextLayer]
|
||||
relevance_scores: dict[ContextLayer, float]
|
||||
total_score: float
|
||||
|
||||
|
||||
class ContextSelector:
|
||||
"""Selects optimal context layers to minimize token usage."""
|
||||
|
||||
def __init__(self, store: ContextStore) -> None:
|
||||
"""Initialize the context selector.
|
||||
|
||||
Args:
|
||||
store: ContextStore instance for retrieving context data
|
||||
"""
|
||||
self.store = store
|
||||
|
||||
def select_layers(
|
||||
self,
|
||||
decision_type: DecisionType = DecisionType.NORMAL,
|
||||
include_realtime: bool = True,
|
||||
) -> list[ContextLayer]:
|
||||
"""Select context layers based on decision type.
|
||||
|
||||
Strategy:
|
||||
- NORMAL: L7 (real-time) only
|
||||
- STRATEGIC: L7 + L6 + L5 (real-time + daily + weekly)
|
||||
- MAJOR_EVENT: All layers L1-L7
|
||||
|
||||
Args:
|
||||
decision_type: Type of decision being made
|
||||
include_realtime: Whether to include L7 real-time data
|
||||
|
||||
Returns:
|
||||
List of context layers to use (ordered by priority)
|
||||
"""
|
||||
if decision_type == DecisionType.NORMAL:
|
||||
# Normal trading: only real-time data
|
||||
return [ContextLayer.L7_REALTIME] if include_realtime else []
|
||||
|
||||
elif decision_type == DecisionType.STRATEGIC:
|
||||
# Strategic decisions: real-time + recent history
|
||||
layers = []
|
||||
if include_realtime:
|
||||
layers.append(ContextLayer.L7_REALTIME)
|
||||
layers.extend([ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY])
|
||||
return layers
|
||||
|
||||
else: # MAJOR_EVENT
|
||||
# Major events: all layers for comprehensive context
|
||||
layers = []
|
||||
if include_realtime:
|
||||
layers.append(ContextLayer.L7_REALTIME)
|
||||
layers.extend(
|
||||
[
|
||||
ContextLayer.L6_DAILY,
|
||||
ContextLayer.L5_WEEKLY,
|
||||
ContextLayer.L4_MONTHLY,
|
||||
ContextLayer.L3_QUARTERLY,
|
||||
ContextLayer.L2_ANNUAL,
|
||||
ContextLayer.L1_LEGACY,
|
||||
]
|
||||
)
|
||||
return layers
|
||||
|
||||
def score_layer_relevance(
|
||||
self,
|
||||
layer: ContextLayer,
|
||||
decision_type: DecisionType,
|
||||
current_time: datetime | None = None,
|
||||
) -> float:
|
||||
"""Calculate relevance score for a context layer.
|
||||
|
||||
Relevance is based on:
|
||||
1. Decision type (normal, strategic, major event)
|
||||
2. Layer recency (L7 > L6 > ... > L1)
|
||||
3. Data availability
|
||||
|
||||
Args:
|
||||
layer: Context layer to score
|
||||
decision_type: Type of decision being made
|
||||
current_time: Current time (defaults to now)
|
||||
|
||||
Returns:
|
||||
Relevance score (0.0 to 1.0)
|
||||
"""
|
||||
if current_time is None:
|
||||
current_time = datetime.now(UTC)
|
||||
|
||||
# Base scores by decision type
|
||||
base_scores = {
|
||||
DecisionType.NORMAL: {
|
||||
ContextLayer.L7_REALTIME: 1.0,
|
||||
ContextLayer.L6_DAILY: 0.1,
|
||||
ContextLayer.L5_WEEKLY: 0.05,
|
||||
ContextLayer.L4_MONTHLY: 0.01,
|
||||
ContextLayer.L3_QUARTERLY: 0.0,
|
||||
ContextLayer.L2_ANNUAL: 0.0,
|
||||
ContextLayer.L1_LEGACY: 0.0,
|
||||
},
|
||||
DecisionType.STRATEGIC: {
|
||||
ContextLayer.L7_REALTIME: 0.9,
|
||||
ContextLayer.L6_DAILY: 0.8,
|
||||
ContextLayer.L5_WEEKLY: 0.7,
|
||||
ContextLayer.L4_MONTHLY: 0.3,
|
||||
ContextLayer.L3_QUARTERLY: 0.2,
|
||||
ContextLayer.L2_ANNUAL: 0.1,
|
||||
ContextLayer.L1_LEGACY: 0.05,
|
||||
},
|
||||
DecisionType.MAJOR_EVENT: {
|
||||
ContextLayer.L7_REALTIME: 0.7,
|
||||
ContextLayer.L6_DAILY: 0.7,
|
||||
ContextLayer.L5_WEEKLY: 0.7,
|
||||
ContextLayer.L4_MONTHLY: 0.8,
|
||||
ContextLayer.L3_QUARTERLY: 0.8,
|
||||
ContextLayer.L2_ANNUAL: 0.9,
|
||||
ContextLayer.L1_LEGACY: 1.0,
|
||||
},
|
||||
}
|
||||
|
||||
score = base_scores[decision_type].get(layer, 0.0)
|
||||
|
||||
# Check data availability
|
||||
latest_timeframe = self.store.get_latest_timeframe(layer)
|
||||
if latest_timeframe is None:
|
||||
# No data available - reduce score significantly
|
||||
score *= 0.1
|
||||
|
||||
return score
|
||||
|
||||
def select_with_scoring(
|
||||
self,
|
||||
decision_type: DecisionType = DecisionType.NORMAL,
|
||||
min_score: float = 0.5,
|
||||
) -> ContextSelection:
|
||||
"""Select context layers with relevance scoring.
|
||||
|
||||
Args:
|
||||
decision_type: Type of decision being made
|
||||
min_score: Minimum relevance score to include a layer
|
||||
|
||||
Returns:
|
||||
ContextSelection with selected layers and scores
|
||||
"""
|
||||
all_layers = [
|
||||
ContextLayer.L7_REALTIME,
|
||||
ContextLayer.L6_DAILY,
|
||||
ContextLayer.L5_WEEKLY,
|
||||
ContextLayer.L4_MONTHLY,
|
||||
ContextLayer.L3_QUARTERLY,
|
||||
ContextLayer.L2_ANNUAL,
|
||||
ContextLayer.L1_LEGACY,
|
||||
]
|
||||
|
||||
scores = {
|
||||
layer: self.score_layer_relevance(layer, decision_type) for layer in all_layers
|
||||
}
|
||||
|
||||
# Filter by minimum score
|
||||
selected_layers = [layer for layer, score in scores.items() if score >= min_score]
|
||||
|
||||
# Sort by score (descending)
|
||||
selected_layers.sort(key=lambda layer: scores[layer], reverse=True)
|
||||
|
||||
total_score = sum(scores[layer] for layer in selected_layers)
|
||||
|
||||
return ContextSelection(
|
||||
layers=selected_layers,
|
||||
relevance_scores=scores,
|
||||
total_score=total_score,
|
||||
)
|
||||
|
||||
def get_context_data(
|
||||
self,
|
||||
layers: list[ContextLayer],
|
||||
max_items_per_layer: int = 10,
|
||||
) -> dict[str, Any]:
|
||||
"""Retrieve context data for selected layers.
|
||||
|
||||
Args:
|
||||
layers: List of context layers to retrieve
|
||||
max_items_per_layer: Maximum number of items per layer
|
||||
|
||||
Returns:
|
||||
Dictionary with context data organized by layer
|
||||
"""
|
||||
result: dict[str, Any] = {}
|
||||
|
||||
for layer in layers:
|
||||
# Get latest timeframe for this layer
|
||||
latest_timeframe = self.store.get_latest_timeframe(layer)
|
||||
if latest_timeframe:
|
||||
# Get all contexts for latest timeframe
|
||||
contexts = self.store.get_all_contexts(layer, latest_timeframe)
|
||||
|
||||
# Limit number of items
|
||||
if len(contexts) > max_items_per_layer:
|
||||
# Keep only first N items
|
||||
contexts = dict(list(contexts.items())[:max_items_per_layer])
|
||||
|
||||
result[layer.value] = contexts
|
||||
|
||||
return result
|
||||
|
||||
def estimate_context_tokens(self, context_data: dict[str, Any]) -> int:
|
||||
"""Estimate total tokens for context data.
|
||||
|
||||
Args:
|
||||
context_data: Context data dictionary
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
import json
|
||||
|
||||
from src.brain.prompt_optimizer import PromptOptimizer
|
||||
|
||||
# Serialize to JSON and estimate tokens
|
||||
json_str = json.dumps(context_data, ensure_ascii=False)
|
||||
return PromptOptimizer.estimate_tokens(json_str)
|
||||
|
||||
def optimize_context_for_budget(
|
||||
self,
|
||||
decision_type: DecisionType,
|
||||
max_tokens: int,
|
||||
) -> dict[str, Any]:
|
||||
"""Select and retrieve context data within a token budget.
|
||||
|
||||
Args:
|
||||
decision_type: Type of decision being made
|
||||
max_tokens: Maximum token budget for context
|
||||
|
||||
Returns:
|
||||
Optimized context data within budget
|
||||
"""
|
||||
# Start with minimal selection
|
||||
selection = self.select_with_scoring(decision_type, min_score=0.5)
|
||||
|
||||
# Retrieve data
|
||||
context_data = self.get_context_data(selection.layers)
|
||||
|
||||
# Check if within budget
|
||||
estimated_tokens = self.estimate_context_tokens(context_data)
|
||||
|
||||
if estimated_tokens <= max_tokens:
|
||||
return context_data
|
||||
|
||||
# If over budget, progressively reduce
|
||||
# 1. Reduce items per layer
|
||||
for max_items in [5, 3, 1]:
|
||||
context_data = self.get_context_data(selection.layers, max_items)
|
||||
estimated_tokens = self.estimate_context_tokens(context_data)
|
||||
if estimated_tokens <= max_tokens:
|
||||
return context_data
|
||||
|
||||
# 2. Remove lower-priority layers
|
||||
for min_score in [0.6, 0.7, 0.8, 0.9]:
|
||||
selection = self.select_with_scoring(decision_type, min_score=min_score)
|
||||
context_data = self.get_context_data(selection.layers, max_items_per_layer=1)
|
||||
estimated_tokens = self.estimate_context_tokens(context_data)
|
||||
if estimated_tokens <= max_tokens:
|
||||
return context_data
|
||||
|
||||
# Last resort: return only L7 with minimal data
|
||||
return self.get_context_data([ContextLayer.L7_REALTIME], max_items_per_layer=1)
|
||||
@@ -2,6 +2,17 @@
|
||||
|
||||
Constructs prompts from market data, calls Gemini, and parses structured
|
||||
JSON responses into validated TradeDecision objects.
|
||||
|
||||
Includes token efficiency optimizations:
|
||||
- Prompt compression and abbreviation
|
||||
- Response caching for common scenarios
|
||||
- Smart context selection
|
||||
- Token usage tracking and metrics
|
||||
|
||||
Includes external data integration:
|
||||
- News sentiment analysis
|
||||
- Economic calendar events
|
||||
- Market indicators
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -15,6 +26,11 @@ from typing import Any
|
||||
from google import genai
|
||||
|
||||
from src.config import Settings
|
||||
from src.data.news_api import NewsAPI, NewsSentiment
|
||||
from src.data.economic_calendar import EconomicCalendar
|
||||
from src.data.market_data import MarketData
|
||||
from src.brain.cache import DecisionCache
|
||||
from src.brain.prompt_optimizer import PromptOptimizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -28,23 +44,176 @@ class TradeDecision:
|
||||
action: str # "BUY" | "SELL" | "HOLD"
|
||||
confidence: int # 0-100
|
||||
rationale: str
|
||||
token_count: int = 0 # Estimated tokens used
|
||||
cached: bool = False # Whether decision came from cache
|
||||
|
||||
|
||||
class GeminiClient:
|
||||
"""Wraps the Gemini API for trade decision-making."""
|
||||
|
||||
def __init__(self, settings: Settings) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
settings: Settings,
|
||||
news_api: NewsAPI | None = None,
|
||||
economic_calendar: EconomicCalendar | None = None,
|
||||
market_data: MarketData | None = None,
|
||||
enable_cache: bool = True,
|
||||
enable_optimization: bool = True,
|
||||
) -> None:
|
||||
self._settings = settings
|
||||
self._confidence_threshold = settings.CONFIDENCE_THRESHOLD
|
||||
self._client = genai.Client(api_key=settings.GEMINI_API_KEY)
|
||||
self._model_name = settings.GEMINI_MODEL
|
||||
|
||||
# External data sources (optional)
|
||||
self._news_api = news_api
|
||||
self._economic_calendar = economic_calendar
|
||||
self._market_data = market_data
|
||||
|
||||
# Token efficiency features
|
||||
self._enable_cache = enable_cache
|
||||
self._enable_optimization = enable_optimization
|
||||
self._cache = DecisionCache(ttl_seconds=300) if enable_cache else None
|
||||
self._optimizer = PromptOptimizer()
|
||||
|
||||
# Token usage metrics
|
||||
self._total_tokens_used = 0
|
||||
self._total_decisions = 0
|
||||
self._total_cached_decisions = 0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# External Data Integration
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _build_external_context(
|
||||
self, stock_code: str, news_sentiment: NewsSentiment | None = None
|
||||
) -> str:
|
||||
"""Build external data context for the prompt.
|
||||
|
||||
Args:
|
||||
stock_code: Stock ticker symbol
|
||||
news_sentiment: Optional pre-fetched news sentiment
|
||||
|
||||
Returns:
|
||||
Formatted string with external data context
|
||||
"""
|
||||
context_parts: list[str] = []
|
||||
|
||||
# News sentiment
|
||||
if news_sentiment is not None:
|
||||
sentiment_str = self._format_news_sentiment(news_sentiment)
|
||||
if sentiment_str:
|
||||
context_parts.append(sentiment_str)
|
||||
elif self._news_api is not None:
|
||||
# Fetch news sentiment if not provided
|
||||
try:
|
||||
sentiment = await self._news_api.get_news_sentiment(stock_code)
|
||||
if sentiment is not None:
|
||||
sentiment_str = self._format_news_sentiment(sentiment)
|
||||
if sentiment_str:
|
||||
context_parts.append(sentiment_str)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to fetch news sentiment: %s", exc)
|
||||
|
||||
# Economic events
|
||||
if self._economic_calendar is not None:
|
||||
events_str = self._format_economic_events(stock_code)
|
||||
if events_str:
|
||||
context_parts.append(events_str)
|
||||
|
||||
# Market indicators
|
||||
if self._market_data is not None:
|
||||
indicators_str = self._format_market_indicators()
|
||||
if indicators_str:
|
||||
context_parts.append(indicators_str)
|
||||
|
||||
if not context_parts:
|
||||
return ""
|
||||
|
||||
return "EXTERNAL DATA:\n" + "\n\n".join(context_parts)
|
||||
|
||||
def _format_news_sentiment(self, sentiment: NewsSentiment) -> str:
|
||||
"""Format news sentiment for prompt."""
|
||||
if sentiment.article_count == 0:
|
||||
return ""
|
||||
|
||||
# Select top 3 most relevant articles
|
||||
top_articles = sentiment.articles[:3]
|
||||
|
||||
lines = [
|
||||
f"News Sentiment: {sentiment.avg_sentiment:.2f} "
|
||||
f"(from {sentiment.article_count} articles)",
|
||||
]
|
||||
|
||||
for i, article in enumerate(top_articles, 1):
|
||||
lines.append(
|
||||
f" {i}. [{article.source}] {article.title} "
|
||||
f"(sentiment: {article.sentiment_score:.2f})"
|
||||
)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _format_economic_events(self, stock_code: str) -> str:
|
||||
"""Format upcoming economic events for prompt."""
|
||||
if self._economic_calendar is None:
|
||||
return ""
|
||||
|
||||
# Check for upcoming high-impact events
|
||||
upcoming = self._economic_calendar.get_upcoming_events(
|
||||
days_ahead=7, min_impact="HIGH"
|
||||
)
|
||||
|
||||
if upcoming.high_impact_count == 0:
|
||||
return ""
|
||||
|
||||
lines = [
|
||||
f"Upcoming High-Impact Events: {upcoming.high_impact_count} in next 7 days"
|
||||
]
|
||||
|
||||
if upcoming.next_major_event is not None:
|
||||
event = upcoming.next_major_event
|
||||
lines.append(
|
||||
f" Next: {event.name} ({event.event_type}) "
|
||||
f"on {event.datetime.strftime('%Y-%m-%d')}"
|
||||
)
|
||||
|
||||
# Check for earnings
|
||||
earnings_date = self._economic_calendar.get_earnings_date(stock_code)
|
||||
if earnings_date is not None:
|
||||
lines.append(
|
||||
f" Earnings: {stock_code} on {earnings_date.strftime('%Y-%m-%d')}"
|
||||
)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _format_market_indicators(self) -> str:
|
||||
"""Format market indicators for prompt."""
|
||||
if self._market_data is None:
|
||||
return ""
|
||||
|
||||
try:
|
||||
indicators = self._market_data.get_market_indicators()
|
||||
lines = [f"Market Sentiment: {indicators.sentiment.name}"]
|
||||
|
||||
# Add breadth if meaningful
|
||||
if indicators.breadth.advance_decline_ratio != 1.0:
|
||||
lines.append(
|
||||
f"Advance/Decline Ratio: {indicators.breadth.advance_decline_ratio:.2f}"
|
||||
)
|
||||
|
||||
return "\n".join(lines)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to get market indicators: %s", exc)
|
||||
return ""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Prompt Construction
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def build_prompt(self, market_data: dict[str, Any]) -> str:
|
||||
"""Build a structured prompt from market data.
|
||||
async def build_prompt(
|
||||
self, market_data: dict[str, Any], news_sentiment: NewsSentiment | None = None
|
||||
) -> str:
|
||||
"""Build a structured prompt from market data and external sources.
|
||||
|
||||
The prompt instructs Gemini to return valid JSON with action,
|
||||
confidence, and rationale fields.
|
||||
@@ -72,6 +241,60 @@ class GeminiClient:
|
||||
|
||||
market_info = "\n".join(market_info_lines)
|
||||
|
||||
# Add external data context if available
|
||||
external_context = await self._build_external_context(
|
||||
market_data["stock_code"], news_sentiment
|
||||
)
|
||||
if external_context:
|
||||
market_info += f"\n\n{external_context}"
|
||||
|
||||
json_format = (
|
||||
'{"action": "BUY"|"SELL"|"HOLD", '
|
||||
'"confidence": <int 0-100>, "rationale": "<string>"}'
|
||||
)
|
||||
return (
|
||||
f"You are a professional {market_name} trading analyst.\n"
|
||||
"Analyze the following market data and decide whether to "
|
||||
"BUY, SELL, or HOLD.\n\n"
|
||||
f"{market_info}\n\n"
|
||||
"You MUST respond with ONLY valid JSON in the following format:\n"
|
||||
f"{json_format}\n\n"
|
||||
"Rules:\n"
|
||||
"- action must be exactly one of: BUY, SELL, HOLD\n"
|
||||
"- confidence must be an integer from 0 to 100\n"
|
||||
"- rationale must explain your reasoning concisely\n"
|
||||
"- Do NOT wrap the JSON in markdown code blocks\n"
|
||||
)
|
||||
|
||||
def build_prompt_sync(self, market_data: dict[str, Any]) -> str:
|
||||
"""Synchronous version of build_prompt (for backward compatibility).
|
||||
|
||||
This version does NOT include external data integration.
|
||||
Use async build_prompt() for full functionality.
|
||||
"""
|
||||
market_name = market_data.get("market_name", "Korean stock market")
|
||||
|
||||
# Build market data section dynamically based on available fields
|
||||
market_info_lines = [
|
||||
f"Market: {market_name}",
|
||||
f"Stock Code: {market_data['stock_code']}",
|
||||
f"Current Price: {market_data['current_price']}",
|
||||
]
|
||||
|
||||
# Add orderbook if available (domestic markets)
|
||||
if "orderbook" in market_data:
|
||||
market_info_lines.append(
|
||||
f"Orderbook: {json.dumps(market_data['orderbook'], ensure_ascii=False)}"
|
||||
)
|
||||
|
||||
# Add foreigner net if non-zero
|
||||
if market_data.get("foreigner_net", 0) != 0:
|
||||
market_info_lines.append(
|
||||
f"Foreigner Net Buy/Sell: {market_data['foreigner_net']}"
|
||||
)
|
||||
|
||||
market_info = "\n".join(market_info_lines)
|
||||
|
||||
json_format = (
|
||||
'{"action": "BUY"|"SELL"|"HOLD", '
|
||||
'"confidence": <int 0-100>, "rationale": "<string>"}'
|
||||
@@ -123,8 +346,10 @@ class GeminiClient:
|
||||
# Validate required fields
|
||||
if not all(k in data for k in ("action", "confidence", "rationale")):
|
||||
logger.warning("Missing fields in Gemini response — defaulting to HOLD")
|
||||
# Preserve raw text in rationale so prompt_override callers (e.g. pre_market_planner)
|
||||
# can extract their own JSON format from decision.rationale (#245)
|
||||
return TradeDecision(
|
||||
action="HOLD", confidence=0, rationale="Missing required fields"
|
||||
action="HOLD", confidence=0, rationale=raw
|
||||
)
|
||||
|
||||
action = str(data["action"]).upper()
|
||||
@@ -152,28 +377,397 @@ class GeminiClient:
|
||||
# API Call
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def decide(self, market_data: dict[str, Any]) -> TradeDecision:
|
||||
"""Build prompt, call Gemini, and return a parsed decision."""
|
||||
prompt = self.build_prompt(market_data)
|
||||
logger.info("Requesting trade decision from Gemini")
|
||||
async def decide(
|
||||
self, market_data: dict[str, Any], news_sentiment: NewsSentiment | None = None
|
||||
) -> TradeDecision:
|
||||
"""Build prompt, call Gemini, and return a parsed decision.
|
||||
|
||||
Args:
|
||||
market_data: Market data dictionary with price, orderbook, etc.
|
||||
news_sentiment: Optional pre-fetched news sentiment
|
||||
|
||||
Returns:
|
||||
Parsed TradeDecision
|
||||
"""
|
||||
# Check cache first
|
||||
if self._cache:
|
||||
cached_decision = self._cache.get(market_data)
|
||||
if cached_decision:
|
||||
self._total_cached_decisions += 1
|
||||
self._total_decisions += 1
|
||||
logger.info(
|
||||
"Cache hit for decision",
|
||||
extra={
|
||||
"action": cached_decision.action,
|
||||
"confidence": cached_decision.confidence,
|
||||
"cache_hit_rate": self.get_cache_hit_rate(),
|
||||
},
|
||||
)
|
||||
# Return cached decision with cached flag
|
||||
return TradeDecision(
|
||||
action=cached_decision.action,
|
||||
confidence=cached_decision.confidence,
|
||||
rationale=cached_decision.rationale,
|
||||
token_count=0,
|
||||
cached=True,
|
||||
)
|
||||
|
||||
# Build prompt (prompt_override takes priority for callers like pre_market_planner)
|
||||
if "prompt_override" in market_data:
|
||||
prompt = market_data["prompt_override"]
|
||||
elif self._enable_optimization:
|
||||
prompt = self._optimizer.build_compressed_prompt(market_data)
|
||||
else:
|
||||
prompt = await self.build_prompt(market_data, news_sentiment)
|
||||
|
||||
# Estimate tokens
|
||||
token_count = self._optimizer.estimate_tokens(prompt)
|
||||
self._total_tokens_used += token_count
|
||||
|
||||
logger.info(
|
||||
"Requesting trade decision from Gemini",
|
||||
extra={"estimated_tokens": token_count, "optimized": self._enable_optimization},
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self._client.aio.models.generate_content(
|
||||
model=self._model_name, contents=prompt,
|
||||
model=self._model_name,
|
||||
contents=prompt,
|
||||
)
|
||||
raw = response.text
|
||||
except Exception as exc:
|
||||
logger.error("Gemini API error: %s", exc)
|
||||
return TradeDecision(
|
||||
action="HOLD", confidence=0, rationale=f"API error: {exc}"
|
||||
action="HOLD", confidence=0, rationale=f"API error: {exc}", token_count=token_count
|
||||
)
|
||||
|
||||
# prompt_override callers (e.g. pre_market_planner) expect raw text back,
|
||||
# not a parsed TradeDecision. Skip parse_response to avoid spurious
|
||||
# "Missing fields" warnings and return the raw response directly. (#247)
|
||||
if "prompt_override" in market_data:
|
||||
logger.info(
|
||||
"Gemini raw response received (prompt_override, tokens=%d)", token_count
|
||||
)
|
||||
# Not a trade decision — don't inflate _total_decisions metrics
|
||||
return TradeDecision(
|
||||
action="HOLD", confidence=0, rationale=raw, token_count=token_count
|
||||
)
|
||||
|
||||
decision = self.parse_response(raw)
|
||||
self._total_decisions += 1
|
||||
|
||||
# Add token count to decision
|
||||
decision_with_tokens = TradeDecision(
|
||||
action=decision.action,
|
||||
confidence=decision.confidence,
|
||||
rationale=decision.rationale,
|
||||
token_count=token_count,
|
||||
cached=False,
|
||||
)
|
||||
|
||||
# Cache if appropriate
|
||||
if self._cache and self._cache.should_cache_decision(decision):
|
||||
self._cache.set(market_data, decision)
|
||||
|
||||
logger.info(
|
||||
"Gemini decision",
|
||||
extra={
|
||||
"action": decision.action,
|
||||
"confidence": decision.confidence,
|
||||
"tokens": token_count,
|
||||
"avg_tokens": self.get_avg_tokens_per_decision(),
|
||||
},
|
||||
)
|
||||
return decision
|
||||
|
||||
return decision_with_tokens
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Token Efficiency Metrics
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_token_metrics(self) -> dict[str, Any]:
|
||||
"""Get token usage metrics.
|
||||
|
||||
Returns:
|
||||
Dictionary with token usage statistics
|
||||
"""
|
||||
metrics = {
|
||||
"total_tokens_used": self._total_tokens_used,
|
||||
"total_decisions": self._total_decisions,
|
||||
"total_cached_decisions": self._total_cached_decisions,
|
||||
"avg_tokens_per_decision": self.get_avg_tokens_per_decision(),
|
||||
"cache_hit_rate": self.get_cache_hit_rate(),
|
||||
}
|
||||
|
||||
if self._cache:
|
||||
cache_metrics = self._cache.get_metrics()
|
||||
metrics["cache_metrics"] = cache_metrics.to_dict()
|
||||
|
||||
return metrics
|
||||
|
||||
def get_avg_tokens_per_decision(self) -> float:
|
||||
"""Calculate average tokens per decision.
|
||||
|
||||
Returns:
|
||||
Average tokens per decision
|
||||
"""
|
||||
if self._total_decisions == 0:
|
||||
return 0.0
|
||||
return self._total_tokens_used / self._total_decisions
|
||||
|
||||
def get_cache_hit_rate(self) -> float:
|
||||
"""Calculate cache hit rate.
|
||||
|
||||
Returns:
|
||||
Cache hit rate (0.0 to 1.0)
|
||||
"""
|
||||
if self._total_decisions == 0:
|
||||
return 0.0
|
||||
return self._total_cached_decisions / self._total_decisions
|
||||
|
||||
def reset_metrics(self) -> None:
|
||||
"""Reset token usage metrics."""
|
||||
self._total_tokens_used = 0
|
||||
self._total_decisions = 0
|
||||
self._total_cached_decisions = 0
|
||||
if self._cache:
|
||||
self._cache.reset_metrics()
|
||||
logger.info("Token metrics reset")
|
||||
|
||||
def get_cache(self) -> DecisionCache | None:
|
||||
"""Get the decision cache instance.
|
||||
|
||||
Returns:
|
||||
DecisionCache instance or None if caching disabled
|
||||
"""
|
||||
return self._cache
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Batch Decision Making (for daily trading mode)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def decide_batch(
|
||||
self, stocks_data: list[dict[str, Any]]
|
||||
) -> dict[str, TradeDecision]:
|
||||
"""Make decisions for multiple stocks in a single API call.
|
||||
|
||||
This is designed for daily trading mode to minimize API usage
|
||||
when working with Gemini Free tier (20 calls/day limit).
|
||||
|
||||
Args:
|
||||
stocks_data: List of market data dictionaries, each with:
|
||||
- stock_code: Stock ticker
|
||||
- current_price: Current price
|
||||
- market_name: Market name (optional)
|
||||
- foreigner_net: Foreigner net buy/sell (optional)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping stock_code to TradeDecision
|
||||
|
||||
Example:
|
||||
>>> stocks_data = [
|
||||
... {"stock_code": "AAPL", "current_price": 185.5},
|
||||
... {"stock_code": "MSFT", "current_price": 420.0},
|
||||
... ]
|
||||
>>> decisions = await client.decide_batch(stocks_data)
|
||||
>>> decisions["AAPL"].action
|
||||
'BUY'
|
||||
"""
|
||||
if not stocks_data:
|
||||
return {}
|
||||
|
||||
# Build compressed batch prompt
|
||||
market_name = stocks_data[0].get("market_name", "stock market")
|
||||
|
||||
# Format stock data as compact JSON array
|
||||
compact_stocks = []
|
||||
for stock in stocks_data:
|
||||
compact = {
|
||||
"code": stock["stock_code"],
|
||||
"price": stock["current_price"],
|
||||
}
|
||||
if stock.get("foreigner_net", 0) != 0:
|
||||
compact["frgn"] = stock["foreigner_net"]
|
||||
compact_stocks.append(compact)
|
||||
|
||||
data_str = json.dumps(compact_stocks, ensure_ascii=False)
|
||||
|
||||
prompt = (
|
||||
f"You are a professional {market_name} trading analyst.\n"
|
||||
"Analyze the following stocks and decide whether to BUY, SELL, or HOLD each one.\n\n"
|
||||
f"Stock Data: {data_str}\n\n"
|
||||
"You MUST respond with ONLY a valid JSON array in this format:\n"
|
||||
'[{"code": "AAPL", "action": "BUY", "confidence": 85, "rationale": "..."},\n'
|
||||
' {"code": "MSFT", "action": "HOLD", "confidence": 50, "rationale": "..."}, ...]\n\n'
|
||||
"Rules:\n"
|
||||
"- Return one decision object per stock\n"
|
||||
"- action must be exactly: BUY, SELL, or HOLD\n"
|
||||
"- confidence must be 0-100\n"
|
||||
"- rationale should be concise (1-2 sentences)\n"
|
||||
"- Do NOT wrap JSON in markdown code blocks\n"
|
||||
)
|
||||
|
||||
# Estimate tokens
|
||||
token_count = self._optimizer.estimate_tokens(prompt)
|
||||
self._total_tokens_used += token_count
|
||||
|
||||
logger.info(
|
||||
"Requesting batch decision for %d stocks from Gemini",
|
||||
len(stocks_data),
|
||||
extra={"estimated_tokens": token_count},
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self._client.aio.models.generate_content(
|
||||
model=self._model_name,
|
||||
contents=prompt,
|
||||
)
|
||||
raw = response.text
|
||||
except Exception as exc:
|
||||
logger.error("Gemini API error in batch decision: %s", exc)
|
||||
# Return HOLD for all stocks on API error
|
||||
return {
|
||||
stock["stock_code"]: TradeDecision(
|
||||
action="HOLD",
|
||||
confidence=0,
|
||||
rationale=f"API error: {exc}",
|
||||
token_count=token_count,
|
||||
cached=False,
|
||||
)
|
||||
for stock in stocks_data
|
||||
}
|
||||
|
||||
# Parse batch response
|
||||
return self._parse_batch_response(raw, stocks_data, token_count)
|
||||
|
||||
def _parse_batch_response(
|
||||
self, raw: str, stocks_data: list[dict[str, Any]], token_count: int
|
||||
) -> dict[str, TradeDecision]:
|
||||
"""Parse batch response into a dictionary of decisions.
|
||||
|
||||
Args:
|
||||
raw: Raw response from Gemini
|
||||
stocks_data: Original stock data list
|
||||
token_count: Token count for the request
|
||||
|
||||
Returns:
|
||||
Dictionary mapping stock_code to TradeDecision
|
||||
"""
|
||||
if not raw or not raw.strip():
|
||||
logger.warning("Empty batch response from Gemini — defaulting all to HOLD")
|
||||
return {
|
||||
stock["stock_code"]: TradeDecision(
|
||||
action="HOLD",
|
||||
confidence=0,
|
||||
rationale="Empty response",
|
||||
token_count=0,
|
||||
cached=False,
|
||||
)
|
||||
for stock in stocks_data
|
||||
}
|
||||
|
||||
# Strip markdown code fences if present
|
||||
cleaned = raw.strip()
|
||||
match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", cleaned, re.DOTALL)
|
||||
if match:
|
||||
cleaned = match.group(1).strip()
|
||||
|
||||
try:
|
||||
data = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Malformed JSON in batch response — defaulting all to HOLD")
|
||||
return {
|
||||
stock["stock_code"]: TradeDecision(
|
||||
action="HOLD",
|
||||
confidence=0,
|
||||
rationale="Malformed JSON response",
|
||||
token_count=0,
|
||||
cached=False,
|
||||
)
|
||||
for stock in stocks_data
|
||||
}
|
||||
|
||||
if not isinstance(data, list):
|
||||
logger.warning("Batch response is not a JSON array — defaulting all to HOLD")
|
||||
return {
|
||||
stock["stock_code"]: TradeDecision(
|
||||
action="HOLD",
|
||||
confidence=0,
|
||||
rationale="Invalid response format",
|
||||
token_count=0,
|
||||
cached=False,
|
||||
)
|
||||
for stock in stocks_data
|
||||
}
|
||||
|
||||
# Build decision map
|
||||
decisions: dict[str, TradeDecision] = {}
|
||||
stock_codes = {stock["stock_code"] for stock in stocks_data}
|
||||
|
||||
for item in data:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
code = item.get("code")
|
||||
if not code or code not in stock_codes:
|
||||
continue
|
||||
|
||||
# Validate required fields
|
||||
if not all(k in item for k in ("action", "confidence", "rationale")):
|
||||
logger.warning("Missing fields for %s — using HOLD", code)
|
||||
decisions[code] = TradeDecision(
|
||||
action="HOLD",
|
||||
confidence=0,
|
||||
rationale="Missing required fields",
|
||||
token_count=0,
|
||||
cached=False,
|
||||
)
|
||||
continue
|
||||
|
||||
action = str(item["action"]).upper()
|
||||
if action not in VALID_ACTIONS:
|
||||
logger.warning("Invalid action '%s' for %s — forcing HOLD", action, code)
|
||||
action = "HOLD"
|
||||
|
||||
confidence = int(item["confidence"])
|
||||
rationale = str(item["rationale"])
|
||||
|
||||
# Enforce confidence threshold
|
||||
if confidence < self._confidence_threshold:
|
||||
logger.info(
|
||||
"Confidence %d < threshold %d for %s — forcing HOLD",
|
||||
confidence,
|
||||
self._confidence_threshold,
|
||||
code,
|
||||
)
|
||||
action = "HOLD"
|
||||
|
||||
decisions[code] = TradeDecision(
|
||||
action=action,
|
||||
confidence=confidence,
|
||||
rationale=rationale,
|
||||
token_count=token_count // len(stocks_data), # Split token cost
|
||||
cached=False,
|
||||
)
|
||||
self._total_decisions += 1
|
||||
|
||||
# Fill in missing stocks with HOLD
|
||||
for stock in stocks_data:
|
||||
code = stock["stock_code"]
|
||||
if code not in decisions:
|
||||
logger.warning("No decision for %s in batch response — using HOLD", code)
|
||||
decisions[code] = TradeDecision(
|
||||
action="HOLD",
|
||||
confidence=0,
|
||||
rationale="Not found in batch response",
|
||||
token_count=0,
|
||||
cached=False,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Batch decision completed for %d stocks",
|
||||
len(decisions),
|
||||
extra={"tokens": token_count},
|
||||
)
|
||||
|
||||
return decisions
|
||||
|
||||
267
src/brain/prompt_optimizer.py
Normal file
267
src/brain/prompt_optimizer.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""Prompt optimization utilities for reducing token usage.
|
||||
|
||||
This module provides tools to compress prompts while maintaining decision quality:
|
||||
- Token counting
|
||||
- Text compression and abbreviation
|
||||
- Template-based prompts with variable slots
|
||||
- Priority-based context truncation
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
# Abbreviation mapping for common terms
|
||||
ABBREVIATIONS = {
|
||||
"price": "P",
|
||||
"volume": "V",
|
||||
"current": "cur",
|
||||
"previous": "prev",
|
||||
"change": "chg",
|
||||
"percentage": "pct",
|
||||
"market": "mkt",
|
||||
"orderbook": "ob",
|
||||
"foreigner": "fgn",
|
||||
"buy": "B",
|
||||
"sell": "S",
|
||||
"hold": "H",
|
||||
"confidence": "conf",
|
||||
"rationale": "reason",
|
||||
"action": "act",
|
||||
"net": "net",
|
||||
}
|
||||
|
||||
# Reverse mapping for decompression
|
||||
REVERSE_ABBREVIATIONS = {v: k for k, v in ABBREVIATIONS.items()}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TokenMetrics:
|
||||
"""Metrics about token usage in a prompt."""
|
||||
|
||||
char_count: int
|
||||
word_count: int
|
||||
estimated_tokens: int # Rough estimate: ~4 chars per token
|
||||
compression_ratio: float = 1.0 # Original / Compressed
|
||||
|
||||
|
||||
class PromptOptimizer:
|
||||
"""Optimizes prompts to reduce token usage while maintaining quality."""
|
||||
|
||||
@staticmethod
|
||||
def estimate_tokens(text: str) -> int:
|
||||
"""Estimate token count for text.
|
||||
|
||||
Uses a simple heuristic: ~4 characters per token for English.
|
||||
This is approximate but sufficient for optimization purposes.
|
||||
|
||||
Args:
|
||||
text: Input text to estimate tokens for
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
# Simple estimate: 1 token ≈ 4 characters
|
||||
return max(1, len(text) // 4)
|
||||
|
||||
@staticmethod
|
||||
def count_tokens(text: str) -> TokenMetrics:
|
||||
"""Count various metrics for a text.
|
||||
|
||||
Args:
|
||||
text: Input text to analyze
|
||||
|
||||
Returns:
|
||||
TokenMetrics with character, word, and estimated token counts
|
||||
"""
|
||||
char_count = len(text)
|
||||
word_count = len(text.split())
|
||||
estimated_tokens = PromptOptimizer.estimate_tokens(text)
|
||||
|
||||
return TokenMetrics(
|
||||
char_count=char_count,
|
||||
word_count=word_count,
|
||||
estimated_tokens=estimated_tokens,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def compress_json(data: dict[str, Any]) -> str:
|
||||
"""Compress JSON by removing whitespace.
|
||||
|
||||
Args:
|
||||
data: Dictionary to serialize
|
||||
|
||||
Returns:
|
||||
Compact JSON string without whitespace
|
||||
"""
|
||||
return json.dumps(data, separators=(",", ":"), ensure_ascii=False)
|
||||
|
||||
@staticmethod
|
||||
def abbreviate_text(text: str, aggressive: bool = False) -> str:
|
||||
"""Apply abbreviations to reduce text length.
|
||||
|
||||
Args:
|
||||
text: Input text to abbreviate
|
||||
aggressive: If True, apply more aggressive compression
|
||||
|
||||
Returns:
|
||||
Abbreviated text
|
||||
"""
|
||||
result = text
|
||||
|
||||
# Apply word-level abbreviations (case-insensitive)
|
||||
for full, abbr in ABBREVIATIONS.items():
|
||||
# Word boundaries to avoid partial replacements
|
||||
pattern = r"\b" + re.escape(full) + r"\b"
|
||||
result = re.sub(pattern, abbr, result, flags=re.IGNORECASE)
|
||||
|
||||
if aggressive:
|
||||
# Remove articles and filler words
|
||||
result = re.sub(r"\b(a|an|the)\b", "", result, flags=re.IGNORECASE)
|
||||
result = re.sub(r"\b(is|are|was|were)\b", "", result, flags=re.IGNORECASE)
|
||||
# Collapse multiple spaces
|
||||
result = re.sub(r"\s+", " ", result)
|
||||
|
||||
return result.strip()
|
||||
|
||||
@staticmethod
|
||||
def build_compressed_prompt(
|
||||
market_data: dict[str, Any],
|
||||
include_instructions: bool = True,
|
||||
max_length: int | None = None,
|
||||
) -> str:
|
||||
"""Build a compressed prompt from market data.
|
||||
|
||||
Args:
|
||||
market_data: Market data dictionary with stock info
|
||||
include_instructions: Whether to include full instructions
|
||||
max_length: Maximum character length (truncates if needed)
|
||||
|
||||
Returns:
|
||||
Compressed prompt string
|
||||
"""
|
||||
# Abbreviated market name
|
||||
market_name = market_data.get("market_name", "KR")
|
||||
if "Korea" in market_name:
|
||||
market_name = "KR"
|
||||
elif "United States" in market_name or "US" in market_name:
|
||||
market_name = "US"
|
||||
|
||||
# Core data - always included
|
||||
core_info = {
|
||||
"mkt": market_name,
|
||||
"code": market_data["stock_code"],
|
||||
"P": market_data["current_price"],
|
||||
}
|
||||
|
||||
# Optional fields
|
||||
if "orderbook" in market_data and market_data["orderbook"]:
|
||||
ob = market_data["orderbook"]
|
||||
# Compress orderbook: keep only top 3 levels
|
||||
compressed_ob = {
|
||||
"bid": ob.get("bid", [])[:3],
|
||||
"ask": ob.get("ask", [])[:3],
|
||||
}
|
||||
core_info["ob"] = compressed_ob
|
||||
|
||||
if market_data.get("foreigner_net", 0) != 0:
|
||||
core_info["fgn_net"] = market_data["foreigner_net"]
|
||||
|
||||
# Compress to JSON
|
||||
data_str = PromptOptimizer.compress_json(core_info)
|
||||
|
||||
if include_instructions:
|
||||
# Minimal instructions
|
||||
prompt = (
|
||||
f"{market_name} trader. Analyze:\n{data_str}\n\n"
|
||||
'Return JSON: {"action":"BUY"|"SELL"|"HOLD","confidence":<0-100>,"rationale":"<text>"}\n'
|
||||
"Rules: action=BUY/SELL/HOLD, confidence=0-100, rationale=concise. No markdown."
|
||||
)
|
||||
else:
|
||||
# Data only (for cached contexts where instructions are known)
|
||||
prompt = data_str
|
||||
|
||||
# Truncate if needed
|
||||
if max_length and len(prompt) > max_length:
|
||||
prompt = prompt[:max_length] + "..."
|
||||
|
||||
return prompt
|
||||
|
||||
@staticmethod
|
||||
def truncate_context(
|
||||
context: dict[str, Any],
|
||||
max_tokens: int,
|
||||
priority_keys: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Truncate context data to fit within token budget.
|
||||
|
||||
Keeps high-priority keys first, then truncates less important data.
|
||||
|
||||
Args:
|
||||
context: Context dictionary to truncate
|
||||
max_tokens: Maximum token budget
|
||||
priority_keys: List of keys to keep (in order of priority)
|
||||
|
||||
Returns:
|
||||
Truncated context dictionary
|
||||
"""
|
||||
if not context:
|
||||
return {}
|
||||
|
||||
if priority_keys is None:
|
||||
priority_keys = []
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
current_tokens = 0
|
||||
|
||||
# Add priority keys first
|
||||
for key in priority_keys:
|
||||
if key in context:
|
||||
value_str = json.dumps(context[key])
|
||||
tokens = PromptOptimizer.estimate_tokens(value_str)
|
||||
|
||||
if current_tokens + tokens <= max_tokens:
|
||||
result[key] = context[key]
|
||||
current_tokens += tokens
|
||||
else:
|
||||
break
|
||||
|
||||
# Add remaining keys if space available
|
||||
for key, value in context.items():
|
||||
if key in result:
|
||||
continue
|
||||
|
||||
value_str = json.dumps(value)
|
||||
tokens = PromptOptimizer.estimate_tokens(value_str)
|
||||
|
||||
if current_tokens + tokens <= max_tokens:
|
||||
result[key] = value
|
||||
current_tokens += tokens
|
||||
else:
|
||||
break
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def calculate_compression_ratio(original: str, compressed: str) -> float:
|
||||
"""Calculate compression ratio between original and compressed text.
|
||||
|
||||
Args:
|
||||
original: Original text
|
||||
compressed: Compressed text
|
||||
|
||||
Returns:
|
||||
Compression ratio (original_tokens / compressed_tokens)
|
||||
"""
|
||||
original_tokens = PromptOptimizer.estimate_tokens(original)
|
||||
compressed_tokens = PromptOptimizer.estimate_tokens(compressed)
|
||||
|
||||
if compressed_tokens == 0:
|
||||
return 1.0
|
||||
|
||||
return original_tokens / compressed_tokens
|
||||
@@ -8,7 +8,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import ssl
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import aiohttp
|
||||
|
||||
@@ -20,6 +20,39 @@ _KIS_VTS_HOST = "openapivts.koreainvestment.com"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def kr_tick_unit(price: float) -> int:
|
||||
"""Return KRX tick size for the given price level.
|
||||
|
||||
KRX price tick rules (domestic stocks):
|
||||
price < 2,000 → 1원
|
||||
2,000 ≤ price < 5,000 → 5원
|
||||
5,000 ≤ price < 20,000 → 10원
|
||||
20,000 ≤ price < 50,000 → 50원
|
||||
50,000 ≤ price < 200,000 → 100원
|
||||
200,000 ≤ price < 500,000 → 500원
|
||||
500,000 ≤ price → 1,000원
|
||||
"""
|
||||
if price < 2_000:
|
||||
return 1
|
||||
if price < 5_000:
|
||||
return 5
|
||||
if price < 20_000:
|
||||
return 10
|
||||
if price < 50_000:
|
||||
return 50
|
||||
if price < 200_000:
|
||||
return 100
|
||||
if price < 500_000:
|
||||
return 500
|
||||
return 1_000
|
||||
|
||||
|
||||
def kr_round_down(price: float) -> int:
|
||||
"""Round *down* price to the nearest KRX tick unit."""
|
||||
tick = kr_tick_unit(price)
|
||||
return int(price // tick * tick)
|
||||
|
||||
|
||||
class LeakyBucket:
|
||||
"""Simple leaky-bucket rate limiter for async code."""
|
||||
|
||||
@@ -55,6 +88,9 @@ class KISBroker:
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._access_token: str | None = None
|
||||
self._token_expires_at: float = 0.0
|
||||
self._token_lock = asyncio.Lock()
|
||||
self._last_refresh_attempt: float = 0.0
|
||||
self._refresh_cooldown: float = 60.0 # Seconds (matches KIS 1/minute limit)
|
||||
self._rate_limiter = LeakyBucket(settings.RATE_LIMIT_RPS)
|
||||
|
||||
def _get_session(self) -> aiohttp.ClientSession:
|
||||
@@ -80,12 +116,38 @@ class KISBroker:
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _ensure_token(self) -> str:
|
||||
"""Return a valid access token, refreshing if expired."""
|
||||
"""Return a valid access token, refreshing if expired.
|
||||
|
||||
Uses a lock to prevent concurrent token refresh attempts that would
|
||||
hit the API's 1-per-minute rate limit (EGW00133).
|
||||
"""
|
||||
# Fast path: check without lock
|
||||
now = asyncio.get_event_loop().time()
|
||||
if self._access_token and now < self._token_expires_at:
|
||||
return self._access_token
|
||||
|
||||
# Slow path: acquire lock and refresh
|
||||
async with self._token_lock:
|
||||
# Re-check after acquiring lock (another coroutine may have refreshed)
|
||||
now = asyncio.get_event_loop().time()
|
||||
if self._access_token and now < self._token_expires_at:
|
||||
return self._access_token
|
||||
|
||||
# Check cooldown period (prevents hitting EGW00133: 1/minute limit)
|
||||
time_since_last_attempt = now - self._last_refresh_attempt
|
||||
if time_since_last_attempt < self._refresh_cooldown:
|
||||
remaining = self._refresh_cooldown - time_since_last_attempt
|
||||
# Do not fail fast here. If token is unavailable, upstream calls
|
||||
# will all fail for up to a minute and scanning returns no trades.
|
||||
logger.warning(
|
||||
"Token refresh on cooldown. Waiting %.1fs before retry (KIS allows 1/minute)",
|
||||
remaining,
|
||||
)
|
||||
await asyncio.sleep(remaining)
|
||||
now = asyncio.get_event_loop().time()
|
||||
|
||||
logger.info("Refreshing KIS access token")
|
||||
self._last_refresh_attempt = now
|
||||
session = self._get_session()
|
||||
url = f"{self._base_url}/oauth2/tokenP"
|
||||
body = {
|
||||
@@ -111,6 +173,7 @@ class KISBroker:
|
||||
|
||||
async def _get_hash_key(self, body: dict[str, Any]) -> str:
|
||||
"""Request a hash key from KIS for POST request body signing."""
|
||||
await self._rate_limiter.acquire()
|
||||
session = self._get_session()
|
||||
url = f"{self._base_url}/uapi/hashkey"
|
||||
headers = {
|
||||
@@ -168,12 +231,64 @@ class KISBroker:
|
||||
except (TimeoutError, aiohttp.ClientError) as exc:
|
||||
raise ConnectionError(f"Network error fetching orderbook: {exc}") from exc
|
||||
|
||||
async def get_current_price(
|
||||
self, stock_code: str
|
||||
) -> tuple[float, float, float]:
|
||||
"""Fetch current price data for a domestic stock.
|
||||
|
||||
Uses the ``inquire-price`` API (FHKST01010100), which works in both
|
||||
real and VTS environments and returns the actual last-traded price.
|
||||
|
||||
Returns:
|
||||
(current_price, prdy_ctrt, frgn_ntby_qty)
|
||||
- current_price: Last traded price in KRW.
|
||||
- prdy_ctrt: Day change rate (%).
|
||||
- frgn_ntby_qty: Foreigner net buy quantity.
|
||||
"""
|
||||
await self._rate_limiter.acquire()
|
||||
session = self._get_session()
|
||||
|
||||
headers = await self._auth_headers("FHKST01010100")
|
||||
params = {
|
||||
"FID_COND_MRKT_DIV_CODE": "J",
|
||||
"FID_INPUT_ISCD": stock_code,
|
||||
}
|
||||
url = f"{self._base_url}/uapi/domestic-stock/v1/quotations/inquire-price"
|
||||
|
||||
def _f(val: str | None) -> float:
|
||||
try:
|
||||
return float(val or "0")
|
||||
except ValueError:
|
||||
return 0.0
|
||||
|
||||
try:
|
||||
async with session.get(url, headers=headers, params=params) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ConnectionError(
|
||||
f"get_current_price failed ({resp.status}): {text}"
|
||||
)
|
||||
data = await resp.json()
|
||||
out = data.get("output", {})
|
||||
return (
|
||||
_f(out.get("stck_prpr")),
|
||||
_f(out.get("prdy_ctrt")),
|
||||
_f(out.get("frgn_ntby_qty")),
|
||||
)
|
||||
except (TimeoutError, aiohttp.ClientError) as exc:
|
||||
raise ConnectionError(
|
||||
f"Network error fetching current price: {exc}"
|
||||
) from exc
|
||||
|
||||
async def get_balance(self) -> dict[str, Any]:
|
||||
"""Fetch current account balance and holdings."""
|
||||
await self._rate_limiter.acquire()
|
||||
session = self._get_session()
|
||||
|
||||
headers = await self._auth_headers("VTTC8434R") # 모의투자 잔고조회
|
||||
# TR_ID: 실전 TTTC8434R, 모의 VTTC8434R
|
||||
# Source: 한국투자증권 오픈API 전체문서 (20260221) — '국내주식 잔고조회' 시트
|
||||
tr_id = "TTTC8434R" if self._settings.MODE == "live" else "VTTC8434R"
|
||||
headers = await self._auth_headers(tr_id)
|
||||
params = {
|
||||
"CANO": self._account_no,
|
||||
"ACNT_PRDT_CD": self._product_cd,
|
||||
@@ -218,14 +333,30 @@ class KISBroker:
|
||||
await self._rate_limiter.acquire()
|
||||
session = self._get_session()
|
||||
|
||||
tr_id = "VTTC0802U" if order_type == "BUY" else "VTTC0801U"
|
||||
# TR_ID: 실전 BUY=TTTC0012U SELL=TTTC0011U, 모의 BUY=VTTC0012U SELL=VTTC0011U
|
||||
# Source: 한국투자증권 오픈API 전체문서 (20260221) — '주식주문(현금)' 시트
|
||||
# ※ TTTC0802U/VTTC0802U는 미수매수(증거금40% 계좌 전용) — 현금주문에 사용 금지
|
||||
if self._settings.MODE == "live":
|
||||
tr_id = "TTTC0012U" if order_type == "BUY" else "TTTC0011U"
|
||||
else:
|
||||
tr_id = "VTTC0012U" if order_type == "BUY" else "VTTC0011U"
|
||||
|
||||
# KRX requires limit orders to be rounded down to the tick unit.
|
||||
# ORD_DVSN: "00"=지정가, "01"=시장가
|
||||
if price > 0:
|
||||
ord_dvsn = "00" # 지정가
|
||||
ord_price = kr_round_down(price)
|
||||
else:
|
||||
ord_dvsn = "01" # 시장가
|
||||
ord_price = 0
|
||||
|
||||
body = {
|
||||
"CANO": self._account_no,
|
||||
"ACNT_PRDT_CD": self._product_cd,
|
||||
"PDNO": stock_code,
|
||||
"ORD_DVSN": "01" if price > 0 else "06", # 01=지정가, 06=시장가
|
||||
"ORD_DVSN": ord_dvsn,
|
||||
"ORD_QTY": str(quantity),
|
||||
"ORD_UNPR": str(price),
|
||||
"ORD_UNPR": str(ord_price),
|
||||
}
|
||||
|
||||
hash_key = await self._get_hash_key(body)
|
||||
@@ -252,3 +383,279 @@ class KISBroker:
|
||||
return data
|
||||
except (TimeoutError, aiohttp.ClientError) as exc:
|
||||
raise ConnectionError(f"Network error sending order: {exc}") from exc
|
||||
|
||||
async def fetch_market_rankings(
|
||||
self,
|
||||
ranking_type: str = "volume",
|
||||
limit: int = 30,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch market rankings from KIS API.
|
||||
|
||||
Args:
|
||||
ranking_type: Type of ranking ("volume" or "fluctuation")
|
||||
limit: Maximum number of results to return
|
||||
|
||||
Returns:
|
||||
List of stock data dicts with keys: stock_code, name, price, volume,
|
||||
change_rate, volume_increase_rate
|
||||
|
||||
Raises:
|
||||
ConnectionError: If API request fails
|
||||
"""
|
||||
await self._rate_limiter.acquire()
|
||||
session = self._get_session()
|
||||
|
||||
if ranking_type == "volume":
|
||||
# 거래량순위: FHPST01710000 / /quotations/volume-rank
|
||||
tr_id = "FHPST01710000"
|
||||
url = f"{self._base_url}/uapi/domestic-stock/v1/quotations/volume-rank"
|
||||
params: dict[str, str] = {
|
||||
"FID_COND_MRKT_DIV_CODE": "J",
|
||||
"FID_COND_SCR_DIV_CODE": "20171",
|
||||
"FID_INPUT_ISCD": "0000",
|
||||
"FID_DIV_CLS_CODE": "0",
|
||||
"FID_BLNG_CLS_CODE": "0",
|
||||
"FID_TRGT_CLS_CODE": "111111111",
|
||||
"FID_TRGT_EXLS_CLS_CODE": "0000000000",
|
||||
"FID_INPUT_PRICE_1": "0",
|
||||
"FID_INPUT_PRICE_2": "0",
|
||||
"FID_VOL_CNT": "0",
|
||||
"FID_INPUT_DATE_1": "",
|
||||
}
|
||||
else:
|
||||
# 등락률순위: FHPST01700000 / /ranking/fluctuation (소문자 파라미터)
|
||||
tr_id = "FHPST01700000"
|
||||
url = f"{self._base_url}/uapi/domestic-stock/v1/ranking/fluctuation"
|
||||
params = {
|
||||
"fid_cond_mrkt_div_code": "J",
|
||||
"fid_cond_scr_div_code": "20170",
|
||||
"fid_input_iscd": "0000",
|
||||
"fid_rank_sort_cls_code": "0",
|
||||
"fid_input_cnt_1": str(limit),
|
||||
"fid_prc_cls_code": "0",
|
||||
"fid_input_price_1": "0",
|
||||
"fid_input_price_2": "0",
|
||||
"fid_vol_cnt": "0",
|
||||
"fid_trgt_cls_code": "0",
|
||||
"fid_trgt_exls_cls_code": "0",
|
||||
"fid_div_cls_code": "0",
|
||||
"fid_rsfl_rate1": "0",
|
||||
"fid_rsfl_rate2": "0",
|
||||
}
|
||||
|
||||
headers = await self._auth_headers(tr_id)
|
||||
|
||||
try:
|
||||
async with session.get(url, headers=headers, params=params) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ConnectionError(
|
||||
f"fetch_market_rankings failed ({resp.status}): {text}"
|
||||
)
|
||||
data = await resp.json()
|
||||
|
||||
# Parse response - output is a list of ranked stocks
|
||||
def _safe_float(value: str | float | None, default: float = 0.0) -> float:
|
||||
if value is None or value == "":
|
||||
return default
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
return default
|
||||
|
||||
rankings = []
|
||||
for item in data.get("output", [])[:limit]:
|
||||
rankings.append({
|
||||
"stock_code": item.get("stck_shrn_iscd") or item.get("mksc_shrn_iscd", ""),
|
||||
"name": item.get("hts_kor_isnm", ""),
|
||||
"price": _safe_float(item.get("stck_prpr", "0")),
|
||||
"volume": _safe_float(item.get("acml_vol", "0")),
|
||||
"change_rate": _safe_float(item.get("prdy_ctrt", "0")),
|
||||
"volume_increase_rate": _safe_float(item.get("vol_inrt", "0")),
|
||||
})
|
||||
return rankings
|
||||
|
||||
except (TimeoutError, aiohttp.ClientError) as exc:
|
||||
raise ConnectionError(f"Network error fetching rankings: {exc}") from exc
|
||||
|
||||
async def get_domestic_pending_orders(self) -> list[dict[str, Any]]:
|
||||
"""Fetch unfilled (pending) domestic limit orders.
|
||||
|
||||
The KIS pending-orders API (TTTC0084R) is unsupported in paper (VTS)
|
||||
mode, so this method returns an empty list immediately when MODE is
|
||||
not "live".
|
||||
|
||||
Returns:
|
||||
List of pending order dicts from the KIS ``output`` field.
|
||||
Each dict includes keys such as ``odno``, ``orgn_odno``,
|
||||
``ord_gno_brno``, ``psbl_qty``, ``sll_buy_dvsn_cd``, ``pdno``.
|
||||
"""
|
||||
if self._settings.MODE != "live":
|
||||
logger.debug(
|
||||
"get_domestic_pending_orders: paper mode — TTTC0084R unsupported, returning []"
|
||||
)
|
||||
return []
|
||||
|
||||
await self._rate_limiter.acquire()
|
||||
session = self._get_session()
|
||||
|
||||
# TR_ID: 실전 TTTC0084R (모의 미지원)
|
||||
# Source: 한국투자증권 오픈API 전체문서 (20260221) — '주식 미체결조회' 시트
|
||||
headers = await self._auth_headers("TTTC0084R")
|
||||
params = {
|
||||
"CANO": self._account_no,
|
||||
"ACNT_PRDT_CD": self._product_cd,
|
||||
"INQR_DVSN_1": "0",
|
||||
"INQR_DVSN_2": "0",
|
||||
"CTX_AREA_FK100": "",
|
||||
"CTX_AREA_NK100": "",
|
||||
}
|
||||
url = f"{self._base_url}/uapi/domestic-stock/v1/trading/inquire-psbl-rvsecncl"
|
||||
|
||||
try:
|
||||
async with session.get(url, headers=headers, params=params) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ConnectionError(
|
||||
f"get_domestic_pending_orders failed ({resp.status}): {text}"
|
||||
)
|
||||
data = await resp.json()
|
||||
return data.get("output", []) or []
|
||||
except (TimeoutError, aiohttp.ClientError) as exc:
|
||||
raise ConnectionError(
|
||||
f"Network error fetching domestic pending orders: {exc}"
|
||||
) from exc
|
||||
|
||||
async def cancel_domestic_order(
|
||||
self,
|
||||
stock_code: str,
|
||||
orgn_odno: str,
|
||||
krx_fwdg_ord_orgno: str,
|
||||
qty: int,
|
||||
) -> dict[str, Any]:
|
||||
"""Cancel an unfilled domestic limit order.
|
||||
|
||||
Args:
|
||||
stock_code: 6-digit domestic stock code (``pdno``).
|
||||
orgn_odno: Original order number from pending-orders response
|
||||
(``orgn_odno`` field).
|
||||
krx_fwdg_ord_orgno: KRX forwarding order branch number from
|
||||
pending-orders response (``ord_gno_brno`` field).
|
||||
qty: Quantity to cancel (use ``psbl_qty`` from pending order).
|
||||
|
||||
Returns:
|
||||
Raw KIS API response dict (check ``rt_cd == "0"`` for success).
|
||||
"""
|
||||
await self._rate_limiter.acquire()
|
||||
session = self._get_session()
|
||||
|
||||
# TR_ID: 실전 TTTC0013U, 모의 VTTC0013U
|
||||
# Source: 한국투자증권 오픈API 전체문서 (20260221) — '주식주문(정정취소)' 시트
|
||||
tr_id = "TTTC0013U" if self._settings.MODE == "live" else "VTTC0013U"
|
||||
|
||||
body = {
|
||||
"CANO": self._account_no,
|
||||
"ACNT_PRDT_CD": self._product_cd,
|
||||
"KRX_FWDG_ORD_ORGNO": krx_fwdg_ord_orgno,
|
||||
"ORGN_ODNO": orgn_odno,
|
||||
"ORD_DVSN": "00",
|
||||
"ORD_QTY": str(qty),
|
||||
"ORD_UNPR": "0",
|
||||
"RVSE_CNCL_DVSN_CD": "02",
|
||||
"QTY_ALL_ORD_YN": "Y",
|
||||
}
|
||||
|
||||
hash_key = await self._get_hash_key(body)
|
||||
headers = await self._auth_headers(tr_id)
|
||||
headers["hashkey"] = hash_key
|
||||
|
||||
url = f"{self._base_url}/uapi/domestic-stock/v1/trading/order-rvsecncl"
|
||||
|
||||
try:
|
||||
async with session.post(url, headers=headers, json=body) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ConnectionError(
|
||||
f"cancel_domestic_order failed ({resp.status}): {text}"
|
||||
)
|
||||
return cast(dict[str, Any], await resp.json())
|
||||
except (TimeoutError, aiohttp.ClientError) as exc:
|
||||
raise ConnectionError(
|
||||
f"Network error cancelling domestic order: {exc}"
|
||||
) from exc
|
||||
|
||||
async def get_daily_prices(
|
||||
self,
|
||||
stock_code: str,
|
||||
days: int = 20,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch daily OHLCV price history for a stock.
|
||||
|
||||
Args:
|
||||
stock_code: 6-digit stock code
|
||||
days: Number of trading days to fetch (default 20 for RSI calculation)
|
||||
|
||||
Returns:
|
||||
List of daily price dicts with keys: date, open, high, low, close, volume
|
||||
Sorted oldest to newest
|
||||
|
||||
Raises:
|
||||
ConnectionError: If API request fails
|
||||
"""
|
||||
await self._rate_limiter.acquire()
|
||||
session = self._get_session()
|
||||
|
||||
headers = await self._auth_headers("FHKST03010100")
|
||||
|
||||
# Calculate date range (today and N days ago)
|
||||
from datetime import datetime, timedelta
|
||||
end_date = datetime.now().strftime("%Y%m%d")
|
||||
start_date = (datetime.now() - timedelta(days=days + 10)).strftime("%Y%m%d")
|
||||
|
||||
params = {
|
||||
"FID_COND_MRKT_DIV_CODE": "J",
|
||||
"FID_INPUT_ISCD": stock_code,
|
||||
"FID_INPUT_DATE_1": start_date,
|
||||
"FID_INPUT_DATE_2": end_date,
|
||||
"FID_PERIOD_DIV_CODE": "D", # Daily
|
||||
"FID_ORG_ADJ_PRC": "0", # Adjusted price
|
||||
}
|
||||
|
||||
url = f"{self._base_url}/uapi/domestic-stock/v1/quotations/inquire-daily-itemchartprice"
|
||||
|
||||
try:
|
||||
async with session.get(url, headers=headers, params=params) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ConnectionError(
|
||||
f"get_daily_prices failed ({resp.status}): {text}"
|
||||
)
|
||||
data = await resp.json()
|
||||
|
||||
# Parse response
|
||||
def _safe_float(value: str | float | None, default: float = 0.0) -> float:
|
||||
if value is None or value == "":
|
||||
return default
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
return default
|
||||
|
||||
prices = []
|
||||
for item in data.get("output2", []):
|
||||
prices.append({
|
||||
"date": item.get("stck_bsop_date", ""),
|
||||
"open": _safe_float(item.get("stck_oprc", "0")),
|
||||
"high": _safe_float(item.get("stck_hgpr", "0")),
|
||||
"low": _safe_float(item.get("stck_lwpr", "0")),
|
||||
"close": _safe_float(item.get("stck_clpr", "0")),
|
||||
"volume": _safe_float(item.get("acml_vol", "0")),
|
||||
})
|
||||
|
||||
# Sort oldest to newest (KIS returns newest first)
|
||||
prices.reverse()
|
||||
|
||||
return prices[:days] # Return only requested number of days
|
||||
|
||||
except (TimeoutError, aiohttp.ClientError) as exc:
|
||||
raise ConnectionError(f"Network error fetching daily prices: {exc}") from exc
|
||||
|
||||
@@ -12,6 +12,38 @@ from src.broker.kis_api import KISBroker
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Ranking API uses different exchange codes than order/quote APIs.
|
||||
_RANKING_EXCHANGE_MAP: dict[str, str] = {
|
||||
"NASD": "NAS",
|
||||
"NYSE": "NYS",
|
||||
"AMEX": "AMS",
|
||||
"SEHK": "HKS",
|
||||
"SHAA": "SHS",
|
||||
"SZAA": "SZS",
|
||||
"HSX": "HSX",
|
||||
"HNX": "HNX",
|
||||
"TSE": "TSE",
|
||||
}
|
||||
|
||||
# Price inquiry API (HHDFS00000300) uses the same short exchange codes as rankings.
|
||||
# NASD → NAS, NYSE → NYS, AMEX → AMS (confirmed: AMEX returns empty, AMS returns price).
|
||||
_PRICE_EXCHANGE_MAP: dict[str, str] = _RANKING_EXCHANGE_MAP
|
||||
|
||||
# Cancel order TR_IDs per exchange code — (live_tr_id, paper_tr_id).
|
||||
# Source: 한국투자증권 오픈API 전체문서 (20260221) — '해외주식 주문취소' 시트
|
||||
_CANCEL_TR_ID_MAP: dict[str, tuple[str, str]] = {
|
||||
"NASD": ("TTTT1004U", "VTTT1004U"),
|
||||
"NYSE": ("TTTT1004U", "VTTT1004U"),
|
||||
"AMEX": ("TTTT1004U", "VTTT1004U"),
|
||||
"SEHK": ("TTTS1003U", "VTTS1003U"),
|
||||
"TSE": ("TTTS0309U", "VTTS0309U"),
|
||||
"SHAA": ("TTTS0302U", "VTTS0302U"),
|
||||
"SZAA": ("TTTS0306U", "VTTS0306U"),
|
||||
"HNX": ("TTTS0312U", "VTTS0312U"),
|
||||
"HSX": ("TTTS0312U", "VTTS0312U"),
|
||||
}
|
||||
|
||||
|
||||
class OverseasBroker:
|
||||
"""KIS Overseas Stock API wrapper that reuses KISBroker infrastructure."""
|
||||
|
||||
@@ -44,9 +76,11 @@ class OverseasBroker:
|
||||
session = self._broker._get_session()
|
||||
|
||||
headers = await self._broker._auth_headers("HHDFS00000300")
|
||||
# Map internal exchange codes to the short form expected by the price API.
|
||||
price_excd = _PRICE_EXCHANGE_MAP.get(exchange_code, exchange_code)
|
||||
params = {
|
||||
"AUTH": "",
|
||||
"EXCD": exchange_code,
|
||||
"EXCD": price_excd,
|
||||
"SYMB": stock_code,
|
||||
}
|
||||
url = f"{self._broker._base_url}/uapi/overseas-price/v1/quotations/price"
|
||||
@@ -64,6 +98,83 @@ class OverseasBroker:
|
||||
f"Network error fetching overseas price: {exc}"
|
||||
) from exc
|
||||
|
||||
async def fetch_overseas_rankings(
|
||||
self,
|
||||
exchange_code: str,
|
||||
ranking_type: str = "fluctuation",
|
||||
limit: int = 30,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch overseas rankings (price change or volume surge).
|
||||
|
||||
Ranking API specs may differ by account/product. Endpoint paths and
|
||||
TR_IDs are configurable via settings and can be overridden in .env.
|
||||
"""
|
||||
if not self._broker._settings.OVERSEAS_RANKING_ENABLED:
|
||||
return []
|
||||
|
||||
await self._broker._rate_limiter.acquire()
|
||||
session = self._broker._get_session()
|
||||
|
||||
ranking_excd = _RANKING_EXCHANGE_MAP.get(exchange_code, exchange_code)
|
||||
|
||||
if ranking_type == "volume":
|
||||
tr_id = self._broker._settings.OVERSEAS_RANKING_VOLUME_TR_ID
|
||||
path = self._broker._settings.OVERSEAS_RANKING_VOLUME_PATH
|
||||
params: dict[str, str] = {
|
||||
"KEYB": "", # NEXT KEY BUFF — Required, 공백
|
||||
"AUTH": "",
|
||||
"EXCD": ranking_excd,
|
||||
"MIXN": "0",
|
||||
"VOL_RANG": "0",
|
||||
}
|
||||
else:
|
||||
tr_id = self._broker._settings.OVERSEAS_RANKING_FLUCT_TR_ID
|
||||
path = self._broker._settings.OVERSEAS_RANKING_FLUCT_PATH
|
||||
params = {
|
||||
"KEYB": "", # NEXT KEY BUFF — Required, 공백
|
||||
"AUTH": "",
|
||||
"EXCD": ranking_excd,
|
||||
"NDAY": "0",
|
||||
"GUBN": "1", # 0=하락율, 1=상승율 — 변동성 스캐너는 급등 종목 우선
|
||||
"VOL_RANG": "0",
|
||||
}
|
||||
|
||||
headers = await self._broker._auth_headers(tr_id)
|
||||
url = f"{self._broker._base_url}{path}"
|
||||
|
||||
try:
|
||||
async with session.get(url, headers=headers, params=params) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
if resp.status == 404:
|
||||
logger.warning(
|
||||
"Overseas ranking endpoint unavailable (404) for %s/%s; "
|
||||
"using symbol fallback scan",
|
||||
exchange_code,
|
||||
ranking_type,
|
||||
)
|
||||
return []
|
||||
raise ConnectionError(
|
||||
f"fetch_overseas_rankings failed ({resp.status}): {text}"
|
||||
)
|
||||
|
||||
data = await resp.json()
|
||||
rows = self._extract_ranking_rows(data)
|
||||
if rows:
|
||||
return rows[:limit]
|
||||
|
||||
logger.debug(
|
||||
"Overseas ranking returned empty for %s/%s (keys=%s)",
|
||||
exchange_code,
|
||||
ranking_type,
|
||||
list(data.keys()),
|
||||
)
|
||||
return []
|
||||
except (TimeoutError, aiohttp.ClientError) as exc:
|
||||
raise ConnectionError(
|
||||
f"Network error fetching overseas rankings: {exc}"
|
||||
) from exc
|
||||
|
||||
async def get_overseas_balance(self, exchange_code: str) -> dict[str, Any]:
|
||||
"""
|
||||
Fetch overseas account balance.
|
||||
@@ -80,8 +191,12 @@ class OverseasBroker:
|
||||
await self._broker._rate_limiter.acquire()
|
||||
session = self._broker._get_session()
|
||||
|
||||
# Virtual trading TR_ID for overseas balance inquiry
|
||||
headers = await self._broker._auth_headers("VTTS3012R")
|
||||
# TR_ID: 실전 TTTS3012R, 모의 VTTS3012R
|
||||
# Source: 한국투자증권 오픈API 전체문서 (20260221) — '해외주식 잔고조회' 시트
|
||||
balance_tr_id = (
|
||||
"TTTS3012R" if self._broker._settings.MODE == "live" else "VTTS3012R"
|
||||
)
|
||||
headers = await self._broker._auth_headers(balance_tr_id)
|
||||
params = {
|
||||
"CANO": self._broker._account_no,
|
||||
"ACNT_PRDT_CD": self._broker._product_cd,
|
||||
@@ -107,6 +222,59 @@ class OverseasBroker:
|
||||
f"Network error fetching overseas balance: {exc}"
|
||||
) from exc
|
||||
|
||||
async def get_overseas_buying_power(
|
||||
self,
|
||||
exchange_code: str,
|
||||
stock_code: str,
|
||||
price: float,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Fetch overseas buying power for a specific stock and price.
|
||||
|
||||
Args:
|
||||
exchange_code: Exchange code (e.g., "NASD", "NYSE")
|
||||
stock_code: Stock ticker symbol
|
||||
price: Current stock price (used for quantity calculation)
|
||||
|
||||
Returns:
|
||||
API response; key field: output.ord_psbl_frcr_amt (주문가능외화금액)
|
||||
|
||||
Raises:
|
||||
ConnectionError: On network or API errors
|
||||
"""
|
||||
await self._broker._rate_limiter.acquire()
|
||||
session = self._broker._get_session()
|
||||
|
||||
# TR_ID: 실전 TTTS3007R, 모의 VTTS3007R
|
||||
# Source: 한국투자증권 오픈API 전체문서 (20260221) — '해외주식 매수가능금액조회' 시트
|
||||
ps_tr_id = (
|
||||
"TTTS3007R" if self._broker._settings.MODE == "live" else "VTTS3007R"
|
||||
)
|
||||
headers = await self._broker._auth_headers(ps_tr_id)
|
||||
params = {
|
||||
"CANO": self._broker._account_no,
|
||||
"ACNT_PRDT_CD": self._broker._product_cd,
|
||||
"OVRS_EXCG_CD": exchange_code,
|
||||
"OVRS_ORD_UNPR": f"{price:.2f}",
|
||||
"ITEM_CD": stock_code,
|
||||
}
|
||||
url = (
|
||||
f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/inquire-psamount"
|
||||
)
|
||||
|
||||
try:
|
||||
async with session.get(url, headers=headers, params=params) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ConnectionError(
|
||||
f"get_overseas_buying_power failed ({resp.status}): {text}"
|
||||
)
|
||||
return await resp.json()
|
||||
except (TimeoutError, aiohttp.ClientError) as exc:
|
||||
raise ConnectionError(
|
||||
f"Network error fetching overseas buying power: {exc}"
|
||||
) from exc
|
||||
|
||||
async def send_overseas_order(
|
||||
self,
|
||||
exchange_code: str,
|
||||
@@ -134,8 +302,12 @@ class OverseasBroker:
|
||||
await self._broker._rate_limiter.acquire()
|
||||
session = self._broker._get_session()
|
||||
|
||||
# Virtual trading TR_IDs for overseas orders
|
||||
tr_id = "VTTT1002U" if order_type == "BUY" else "VTTT1006U"
|
||||
# TR_ID: 실전 BUY=TTTT1002U SELL=TTTT1006U, 모의 BUY=VTTT1002U SELL=VTTT1001U
|
||||
# Source: 한국투자증권 오픈API 전체문서 (20260221) — '해외주식 주문' 시트
|
||||
if self._broker._settings.MODE == "live":
|
||||
tr_id = "TTTT1002U" if order_type == "BUY" else "TTTT1006U"
|
||||
else:
|
||||
tr_id = "VTTT1002U" if order_type == "BUY" else "VTTT1001U"
|
||||
|
||||
body = {
|
||||
"CANO": self._broker._account_no,
|
||||
@@ -162,6 +334,9 @@ class OverseasBroker:
|
||||
f"send_overseas_order failed ({resp.status}): {text}"
|
||||
)
|
||||
data = await resp.json()
|
||||
rt_cd = data.get("rt_cd", "")
|
||||
msg1 = data.get("msg1", "")
|
||||
if rt_cd == "0":
|
||||
logger.info(
|
||||
"Overseas order submitted",
|
||||
extra={
|
||||
@@ -170,12 +345,147 @@ class OverseasBroker:
|
||||
"action": order_type,
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Overseas order rejected (rt_cd=%s): %s [%s %s %s qty=%d]",
|
||||
rt_cd,
|
||||
msg1,
|
||||
order_type,
|
||||
stock_code,
|
||||
exchange_code,
|
||||
quantity,
|
||||
)
|
||||
return data
|
||||
except (TimeoutError, aiohttp.ClientError) as exc:
|
||||
raise ConnectionError(
|
||||
f"Network error sending overseas order: {exc}"
|
||||
) from exc
|
||||
|
||||
async def get_overseas_pending_orders(
|
||||
self, exchange_code: str
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch unfilled (pending) overseas orders for a given exchange.
|
||||
|
||||
Args:
|
||||
exchange_code: Exchange code (e.g., "NASD", "SEHK").
|
||||
For US markets, NASD returns all US pending orders (NASD/NYSE/AMEX).
|
||||
|
||||
Returns:
|
||||
List of pending order dicts with fields: odno, pdno, sll_buy_dvsn_cd,
|
||||
ft_ord_qty, nccs_qty, ft_ord_unpr3, ovrs_excg_cd.
|
||||
Always returns [] in paper mode (TTTS3018R is live-only).
|
||||
|
||||
Raises:
|
||||
ConnectionError: On network or API errors (live mode only).
|
||||
"""
|
||||
if self._broker._settings.MODE != "live":
|
||||
logger.debug(
|
||||
"Pending orders API (TTTS3018R) not supported in paper mode; returning []"
|
||||
)
|
||||
return []
|
||||
|
||||
await self._broker._rate_limiter.acquire()
|
||||
session = self._broker._get_session()
|
||||
|
||||
# TTTS3018R: 해외주식 미체결내역조회 (실전 전용)
|
||||
# Source: 한국투자증권 오픈API 전체문서 (20260221) — '해외주식 미체결조회' 시트
|
||||
headers = await self._broker._auth_headers("TTTS3018R")
|
||||
params = {
|
||||
"CANO": self._broker._account_no,
|
||||
"ACNT_PRDT_CD": self._broker._product_cd,
|
||||
"OVRS_EXCG_CD": exchange_code,
|
||||
"SORT_SQN": "DS",
|
||||
"CTX_AREA_FK200": "",
|
||||
"CTX_AREA_NK200": "",
|
||||
}
|
||||
url = (
|
||||
f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/inquire-nccs"
|
||||
)
|
||||
|
||||
try:
|
||||
async with session.get(url, headers=headers, params=params) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ConnectionError(
|
||||
f"get_overseas_pending_orders failed ({resp.status}): {text}"
|
||||
)
|
||||
data = await resp.json()
|
||||
output = data.get("output", [])
|
||||
if isinstance(output, list):
|
||||
return output
|
||||
return []
|
||||
except (TimeoutError, aiohttp.ClientError) as exc:
|
||||
raise ConnectionError(
|
||||
f"Network error fetching pending orders: {exc}"
|
||||
) from exc
|
||||
|
||||
async def cancel_overseas_order(
|
||||
self,
|
||||
exchange_code: str,
|
||||
stock_code: str,
|
||||
odno: str,
|
||||
qty: int,
|
||||
) -> dict[str, Any]:
|
||||
"""Cancel an overseas limit order.
|
||||
|
||||
Args:
|
||||
exchange_code: Exchange code (e.g., "NASD", "SEHK").
|
||||
stock_code: Stock ticker symbol.
|
||||
odno: Original order number to cancel.
|
||||
qty: Unfilled quantity to cancel.
|
||||
|
||||
Returns:
|
||||
API response dict containing rt_cd and msg1.
|
||||
|
||||
Raises:
|
||||
ValueError: If exchange_code has no cancel TR_ID mapping.
|
||||
ConnectionError: On network or API errors.
|
||||
"""
|
||||
tr_ids = _CANCEL_TR_ID_MAP.get(exchange_code)
|
||||
if tr_ids is None:
|
||||
raise ValueError(f"No cancel TR_ID mapping for exchange: {exchange_code}")
|
||||
live_tr_id, paper_tr_id = tr_ids
|
||||
tr_id = live_tr_id if self._broker._settings.MODE == "live" else paper_tr_id
|
||||
|
||||
await self._broker._rate_limiter.acquire()
|
||||
session = self._broker._get_session()
|
||||
|
||||
# RVSE_CNCL_DVSN_CD="02" means cancel (not revision).
|
||||
# OVRS_ORD_UNPR must be "0" for cancellations.
|
||||
# Source: 한국투자증권 오픈API 전체문서 (20260221) — '해외주식 정정취소주문' 시트
|
||||
body = {
|
||||
"CANO": self._broker._account_no,
|
||||
"ACNT_PRDT_CD": self._broker._product_cd,
|
||||
"OVRS_EXCG_CD": exchange_code,
|
||||
"PDNO": stock_code,
|
||||
"ORGN_ODNO": odno,
|
||||
"RVSE_CNCL_DVSN_CD": "02",
|
||||
"ORD_QTY": str(qty),
|
||||
"OVRS_ORD_UNPR": "0",
|
||||
"ORD_SVR_DVSN_CD": "0",
|
||||
}
|
||||
|
||||
hash_key = await self._broker._get_hash_key(body)
|
||||
headers = await self._broker._auth_headers(tr_id)
|
||||
headers["hashkey"] = hash_key
|
||||
|
||||
url = (
|
||||
f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/order-rvsecncl"
|
||||
)
|
||||
|
||||
try:
|
||||
async with session.post(url, headers=headers, json=body) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ConnectionError(
|
||||
f"cancel_overseas_order failed ({resp.status}): {text}"
|
||||
)
|
||||
return await resp.json()
|
||||
except (TimeoutError, aiohttp.ClientError) as exc:
|
||||
raise ConnectionError(
|
||||
f"Network error cancelling overseas order: {exc}"
|
||||
) from exc
|
||||
|
||||
def _get_currency_code(self, exchange_code: str) -> str:
|
||||
"""
|
||||
Map exchange code to currency code.
|
||||
@@ -198,3 +508,11 @@ class OverseasBroker:
|
||||
"HSX": "VND",
|
||||
}
|
||||
return currency_map.get(exchange_code, "USD")
|
||||
|
||||
def _extract_ranking_rows(self, data: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""Extract list rows from ranking response across schema variants."""
|
||||
candidates = [data.get("output"), data.get("output1"), data.get("output2")]
|
||||
for value in candidates:
|
||||
if isinstance(value, list):
|
||||
return [row for row in value if isinstance(row, dict)]
|
||||
return []
|
||||
|
||||
@@ -13,28 +13,112 @@ class Settings(BaseSettings):
|
||||
KIS_APP_KEY: str
|
||||
KIS_APP_SECRET: str
|
||||
KIS_ACCOUNT_NO: str # format: "XXXXXXXX-XX"
|
||||
KIS_BASE_URL: str = "https://openapivts.koreainvestment.com:9443"
|
||||
KIS_BASE_URL: str = "https://openapivts.koreainvestment.com:29443"
|
||||
|
||||
# Google Gemini
|
||||
GEMINI_API_KEY: str
|
||||
GEMINI_MODEL: str = "gemini-pro"
|
||||
GEMINI_MODEL: str = "gemini-2.0-flash"
|
||||
|
||||
# External Data APIs (optional — for data-driven decisions)
|
||||
NEWS_API_KEY: str | None = None
|
||||
NEWS_API_PROVIDER: str = "alphavantage" # "alphavantage" or "newsapi"
|
||||
MARKET_DATA_API_KEY: str | None = None
|
||||
|
||||
# Legacy field names (for backward compatibility)
|
||||
ALPHA_VANTAGE_API_KEY: str | None = None
|
||||
NEWSAPI_KEY: str | None = None
|
||||
|
||||
# Risk Management
|
||||
CIRCUIT_BREAKER_PCT: float = Field(default=-3.0, le=0.0)
|
||||
FAT_FINGER_PCT: float = Field(default=30.0, gt=0.0, le=100.0)
|
||||
CONFIDENCE_THRESHOLD: int = Field(default=80, ge=0, le=100)
|
||||
|
||||
# Smart Scanner Configuration
|
||||
RSI_OVERSOLD_THRESHOLD: int = Field(default=30, ge=0, le=50)
|
||||
RSI_MOMENTUM_THRESHOLD: int = Field(default=70, ge=50, le=100)
|
||||
VOL_MULTIPLIER: float = Field(default=2.0, gt=1.0, le=10.0)
|
||||
SCANNER_TOP_N: int = Field(default=3, ge=1, le=10)
|
||||
POSITION_SIZING_ENABLED: bool = True
|
||||
POSITION_BASE_ALLOCATION_PCT: float = Field(default=5.0, gt=0.0, le=30.0)
|
||||
POSITION_MIN_ALLOCATION_PCT: float = Field(default=1.0, gt=0.0, le=20.0)
|
||||
POSITION_MAX_ALLOCATION_PCT: float = Field(default=10.0, gt=0.0, le=50.0)
|
||||
POSITION_VOLATILITY_TARGET_SCORE: float = Field(default=50.0, gt=0.0, le=100.0)
|
||||
|
||||
# Database
|
||||
DB_PATH: str = "data/trade_logs.db"
|
||||
|
||||
# Rate Limiting (requests per second for KIS API)
|
||||
RATE_LIMIT_RPS: float = 10.0
|
||||
# Conservative limit to avoid EGW00201 "초당 거래건수 초과" errors.
|
||||
# KIS API real limit is ~2 RPS; 2.0 provides maximum safety.
|
||||
RATE_LIMIT_RPS: float = 2.0
|
||||
|
||||
# Trading mode
|
||||
MODE: str = Field(default="paper", pattern="^(paper|live)$")
|
||||
|
||||
# Simulated USD cash for VTS (paper) overseas trading.
|
||||
# KIS VTS overseas balance API returns errors for most accounts.
|
||||
# This value is used as a fallback when the balance API returns 0 in paper mode.
|
||||
PAPER_OVERSEAS_CASH: float = Field(default=50000.0, ge=0.0)
|
||||
|
||||
# Trading frequency mode (daily = batch API calls, realtime = per-stock calls)
|
||||
TRADE_MODE: str = Field(default="daily", pattern="^(daily|realtime)$")
|
||||
DAILY_SESSIONS: int = Field(default=4, ge=1, le=10)
|
||||
SESSION_INTERVAL_HOURS: int = Field(default=6, ge=1, le=24)
|
||||
|
||||
# Pre-Market Planner
|
||||
PRE_MARKET_MINUTES: int = Field(default=30, ge=10, le=120)
|
||||
MAX_SCENARIOS_PER_STOCK: int = Field(default=5, ge=1, le=10)
|
||||
PLANNER_TIMEOUT_SECONDS: int = Field(default=60, ge=10, le=300)
|
||||
DEFENSIVE_PLAYBOOK_ON_FAILURE: bool = True
|
||||
RESCAN_INTERVAL_SECONDS: int = Field(default=300, ge=60, le=900)
|
||||
|
||||
# Market selection (comma-separated market codes)
|
||||
ENABLED_MARKETS: str = "KR"
|
||||
ENABLED_MARKETS: str = "KR,US"
|
||||
|
||||
# Backup and Disaster Recovery (optional)
|
||||
BACKUP_ENABLED: bool = True
|
||||
BACKUP_DIR: str = "data/backups"
|
||||
S3_ENDPOINT_URL: str | None = None # For MinIO, Backblaze B2, etc.
|
||||
S3_ACCESS_KEY: str | None = None
|
||||
S3_SECRET_KEY: str | None = None
|
||||
S3_BUCKET_NAME: str | None = None
|
||||
S3_REGION: str = "us-east-1"
|
||||
|
||||
# Telegram Notifications (optional)
|
||||
TELEGRAM_BOT_TOKEN: str | None = None
|
||||
TELEGRAM_CHAT_ID: str | None = None
|
||||
TELEGRAM_ENABLED: bool = True
|
||||
|
||||
# Telegram Commands (optional)
|
||||
TELEGRAM_COMMANDS_ENABLED: bool = True
|
||||
TELEGRAM_POLLING_INTERVAL: float = 1.0 # seconds
|
||||
|
||||
# Telegram notification type filters (granular control)
|
||||
# circuit_breaker is always sent regardless — safety-critical
|
||||
TELEGRAM_NOTIFY_TRADES: bool = True # BUY/SELL execution alerts
|
||||
TELEGRAM_NOTIFY_MARKET_OPEN_CLOSE: bool = True # Market open/close alerts
|
||||
TELEGRAM_NOTIFY_FAT_FINGER: bool = True # Fat-finger rejection alerts
|
||||
TELEGRAM_NOTIFY_SYSTEM_EVENTS: bool = True # System start/shutdown alerts
|
||||
TELEGRAM_NOTIFY_PLAYBOOK: bool = True # Playbook generated/failed alerts
|
||||
TELEGRAM_NOTIFY_SCENARIO_MATCH: bool = True # Scenario matched alerts (most frequent)
|
||||
TELEGRAM_NOTIFY_ERRORS: bool = True # Error alerts
|
||||
|
||||
# Overseas ranking API (KIS endpoint/TR_ID may vary by account/product)
|
||||
# Override these from .env if your account uses different specs.
|
||||
OVERSEAS_RANKING_ENABLED: bool = True
|
||||
OVERSEAS_RANKING_FLUCT_TR_ID: str = "HHDFS76290000"
|
||||
OVERSEAS_RANKING_VOLUME_TR_ID: str = "HHDFS76270000"
|
||||
OVERSEAS_RANKING_FLUCT_PATH: str = (
|
||||
"/uapi/overseas-stock/v1/ranking/updown-rate"
|
||||
)
|
||||
OVERSEAS_RANKING_VOLUME_PATH: str = (
|
||||
"/uapi/overseas-stock/v1/ranking/volume-surge"
|
||||
)
|
||||
|
||||
# Dashboard (optional)
|
||||
DASHBOARD_ENABLED: bool = False
|
||||
DASHBOARD_HOST: str = "127.0.0.1"
|
||||
DASHBOARD_PORT: int = Field(default=8080, ge=1, le=65535)
|
||||
|
||||
model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}
|
||||
|
||||
@@ -49,4 +133,7 @@ class Settings(BaseSettings):
|
||||
@property
|
||||
def enabled_market_list(self) -> list[str]:
|
||||
"""Parse ENABLED_MARKETS into list of market codes."""
|
||||
return [m.strip() for m in self.ENABLED_MARKETS.split(",") if m.strip()]
|
||||
from src.markets.schedule import expand_market_codes
|
||||
|
||||
raw = [m.strip() for m in self.ENABLED_MARKETS.split(",") if m.strip()]
|
||||
return expand_market_codes(raw)
|
||||
|
||||
@@ -5,6 +5,7 @@ The context tree implements Pillar 2: hierarchical memory management across
|
||||
"""
|
||||
|
||||
from src.context.layer import ContextLayer
|
||||
from src.context.scheduler import ContextScheduler
|
||||
from src.context.store import ContextStore
|
||||
|
||||
__all__ = ["ContextLayer", "ContextStore"]
|
||||
__all__ = ["ContextLayer", "ContextScheduler", "ContextStore"]
|
||||
|
||||
@@ -18,16 +18,33 @@ class ContextAggregator:
|
||||
self.conn = conn
|
||||
self.store = ContextStore(conn)
|
||||
|
||||
def aggregate_daily_from_trades(self, date: str | None = None) -> None:
|
||||
def aggregate_daily_from_trades(
|
||||
self, date: str | None = None, market: str | None = None
|
||||
) -> None:
|
||||
"""Aggregate L6 (daily) context from trades table.
|
||||
|
||||
Args:
|
||||
date: Date in YYYY-MM-DD format. If None, uses today.
|
||||
market: Market code filter (e.g., "KR", "US"). If None, aggregates all markets.
|
||||
"""
|
||||
if date is None:
|
||||
date = datetime.now(UTC).date().isoformat()
|
||||
|
||||
# Calculate daily metrics from trades
|
||||
if market is None:
|
||||
cursor = self.conn.execute(
|
||||
"""
|
||||
SELECT DISTINCT market
|
||||
FROM trades
|
||||
WHERE DATE(timestamp) = ?
|
||||
""",
|
||||
(date,),
|
||||
)
|
||||
markets = [row[0] for row in cursor.fetchall() if row[0]]
|
||||
else:
|
||||
markets = [market]
|
||||
|
||||
for market_code in markets:
|
||||
# Calculate daily metrics from trades for the market
|
||||
cursor = self.conn.execute(
|
||||
"""
|
||||
SELECT
|
||||
@@ -41,29 +58,43 @@ class ContextAggregator:
|
||||
SUM(CASE WHEN pnl > 0 THEN 1 ELSE 0 END) as wins,
|
||||
SUM(CASE WHEN pnl < 0 THEN 1 ELSE 0 END) as losses
|
||||
FROM trades
|
||||
WHERE DATE(timestamp) = ?
|
||||
WHERE DATE(timestamp) = ? AND market = ?
|
||||
""",
|
||||
(date,),
|
||||
(date, market_code),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
|
||||
if row and row[0] > 0: # At least one trade
|
||||
trade_count, buys, sells, holds, avg_conf, total_pnl, stocks, wins, losses = row
|
||||
|
||||
# Store daily metrics in L6
|
||||
self.store.set_context(ContextLayer.L6_DAILY, date, "trade_count", trade_count)
|
||||
self.store.set_context(ContextLayer.L6_DAILY, date, "buys", buys)
|
||||
self.store.set_context(ContextLayer.L6_DAILY, date, "sells", sells)
|
||||
self.store.set_context(ContextLayer.L6_DAILY, date, "holds", holds)
|
||||
key_suffix = f"_{market_code}"
|
||||
|
||||
# Store daily metrics in L6 with market suffix
|
||||
self.store.set_context(
|
||||
ContextLayer.L6_DAILY, date, "avg_confidence", round(avg_conf, 2)
|
||||
ContextLayer.L6_DAILY, date, f"trade_count{key_suffix}", trade_count
|
||||
)
|
||||
self.store.set_context(ContextLayer.L6_DAILY, date, f"buys{key_suffix}", buys)
|
||||
self.store.set_context(ContextLayer.L6_DAILY, date, f"sells{key_suffix}", sells)
|
||||
self.store.set_context(ContextLayer.L6_DAILY, date, f"holds{key_suffix}", holds)
|
||||
self.store.set_context(
|
||||
ContextLayer.L6_DAILY,
|
||||
date,
|
||||
f"avg_confidence{key_suffix}",
|
||||
round(avg_conf, 2),
|
||||
)
|
||||
self.store.set_context(
|
||||
ContextLayer.L6_DAILY, date, "total_pnl", round(total_pnl, 2)
|
||||
ContextLayer.L6_DAILY,
|
||||
date,
|
||||
f"total_pnl{key_suffix}",
|
||||
round(total_pnl, 2),
|
||||
)
|
||||
self.store.set_context(
|
||||
ContextLayer.L6_DAILY, date, f"unique_stocks{key_suffix}", stocks
|
||||
)
|
||||
self.store.set_context(ContextLayer.L6_DAILY, date, "unique_stocks", stocks)
|
||||
win_rate = round(wins / max(wins + losses, 1) * 100, 2)
|
||||
self.store.set_context(ContextLayer.L6_DAILY, date, "win_rate", win_rate)
|
||||
self.store.set_context(
|
||||
ContextLayer.L6_DAILY, date, f"win_rate{key_suffix}", win_rate
|
||||
)
|
||||
|
||||
def aggregate_weekly_from_daily(self, week: str | None = None) -> None:
|
||||
"""Aggregate L5 (weekly) context from L6 (daily).
|
||||
@@ -92,14 +123,25 @@ class ContextAggregator:
|
||||
daily_data[row[0]].append(json.loads(row[1]))
|
||||
|
||||
if daily_data:
|
||||
# Sum all PnL values
|
||||
# Sum all PnL values (market-specific if suffixed)
|
||||
if "total_pnl" in daily_data:
|
||||
total_pnl = sum(daily_data["total_pnl"])
|
||||
self.store.set_context(
|
||||
ContextLayer.L5_WEEKLY, week, "weekly_pnl", round(total_pnl, 2)
|
||||
)
|
||||
|
||||
# Average all confidence values
|
||||
for key, values in daily_data.items():
|
||||
if key.startswith("total_pnl_"):
|
||||
market_code = key.split("total_pnl_", 1)[1]
|
||||
total_pnl = sum(values)
|
||||
self.store.set_context(
|
||||
ContextLayer.L5_WEEKLY,
|
||||
week,
|
||||
f"weekly_pnl_{market_code}",
|
||||
round(total_pnl, 2),
|
||||
)
|
||||
|
||||
# Average all confidence values (market-specific if suffixed)
|
||||
if "avg_confidence" in daily_data:
|
||||
conf_values = daily_data["avg_confidence"]
|
||||
avg_conf = sum(conf_values) / len(conf_values)
|
||||
@@ -107,6 +149,17 @@ class ContextAggregator:
|
||||
ContextLayer.L5_WEEKLY, week, "avg_confidence", round(avg_conf, 2)
|
||||
)
|
||||
|
||||
for key, values in daily_data.items():
|
||||
if key.startswith("avg_confidence_"):
|
||||
market_code = key.split("avg_confidence_", 1)[1]
|
||||
avg_conf = sum(values) / len(values)
|
||||
self.store.set_context(
|
||||
ContextLayer.L5_WEEKLY,
|
||||
week,
|
||||
f"avg_confidence_{market_code}",
|
||||
round(avg_conf, 2),
|
||||
)
|
||||
|
||||
def aggregate_monthly_from_weekly(self, month: str | None = None) -> None:
|
||||
"""Aggregate L4 (monthly) context from L5 (weekly).
|
||||
|
||||
@@ -135,8 +188,16 @@ class ContextAggregator:
|
||||
|
||||
if weekly_data:
|
||||
# Sum all weekly PnL values
|
||||
total_pnl_values: list[float] = []
|
||||
if "weekly_pnl" in weekly_data:
|
||||
total_pnl = sum(weekly_data["weekly_pnl"])
|
||||
total_pnl_values.extend(weekly_data["weekly_pnl"])
|
||||
|
||||
for key, values in weekly_data.items():
|
||||
if key.startswith("weekly_pnl_"):
|
||||
total_pnl_values.extend(values)
|
||||
|
||||
if total_pnl_values:
|
||||
total_pnl = sum(total_pnl_values)
|
||||
self.store.set_context(
|
||||
ContextLayer.L4_MONTHLY, month, "monthly_pnl", round(total_pnl, 2)
|
||||
)
|
||||
@@ -230,21 +291,44 @@ class ContextAggregator:
|
||||
)
|
||||
|
||||
def run_all_aggregations(self) -> None:
|
||||
"""Run all aggregations from L7 to L1 (bottom-up)."""
|
||||
"""Run all aggregations from L7 to L1 (bottom-up).
|
||||
|
||||
All timeframes are derived from the latest trade timestamp so that
|
||||
past data re-aggregation produces consistent results across layers.
|
||||
"""
|
||||
cursor = self.conn.execute("SELECT MAX(timestamp) FROM trades")
|
||||
row = cursor.fetchone()
|
||||
if not row or row[0] is None:
|
||||
return
|
||||
|
||||
ts_raw = row[0]
|
||||
if ts_raw.endswith("Z"):
|
||||
ts_raw = ts_raw.replace("Z", "+00:00")
|
||||
latest_ts = datetime.fromisoformat(ts_raw)
|
||||
trade_date = latest_ts.date()
|
||||
date_str = trade_date.isoformat()
|
||||
|
||||
iso_year, iso_week, _ = trade_date.isocalendar()
|
||||
week_str = f"{iso_year}-W{iso_week:02d}"
|
||||
month_str = f"{trade_date.year}-{trade_date.month:02d}"
|
||||
quarter = (trade_date.month - 1) // 3 + 1
|
||||
quarter_str = f"{trade_date.year}-Q{quarter}"
|
||||
year_str = str(trade_date.year)
|
||||
|
||||
# L7 (trades) → L6 (daily)
|
||||
self.aggregate_daily_from_trades()
|
||||
self.aggregate_daily_from_trades(date_str)
|
||||
|
||||
# L6 (daily) → L5 (weekly)
|
||||
self.aggregate_weekly_from_daily()
|
||||
self.aggregate_weekly_from_daily(week_str)
|
||||
|
||||
# L5 (weekly) → L4 (monthly)
|
||||
self.aggregate_monthly_from_weekly()
|
||||
self.aggregate_monthly_from_weekly(month_str)
|
||||
|
||||
# L4 (monthly) → L3 (quarterly)
|
||||
self.aggregate_quarterly_from_monthly()
|
||||
self.aggregate_quarterly_from_monthly(quarter_str)
|
||||
|
||||
# L3 (quarterly) → L2 (annual)
|
||||
self.aggregate_annual_from_quarterly()
|
||||
self.aggregate_annual_from_quarterly(year_str)
|
||||
|
||||
# L2 (annual) → L1 (legacy)
|
||||
self.aggregate_legacy_from_annual()
|
||||
|
||||
135
src/context/scheduler.py
Normal file
135
src/context/scheduler.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Context aggregation scheduler for periodic rollups and cleanup."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from calendar import monthrange
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from src.context.aggregator import ContextAggregator
|
||||
from src.context.store import ContextStore
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleResult:
|
||||
"""Represents which scheduled tasks ran."""
|
||||
|
||||
weekly: bool = False
|
||||
monthly: bool = False
|
||||
quarterly: bool = False
|
||||
annual: bool = False
|
||||
legacy: bool = False
|
||||
cleanup: bool = False
|
||||
|
||||
|
||||
class ContextScheduler:
|
||||
"""Run periodic context aggregations and cleanup when due."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conn: sqlite3.Connection | None = None,
|
||||
aggregator: ContextAggregator | None = None,
|
||||
store: ContextStore | None = None,
|
||||
) -> None:
|
||||
if aggregator is None:
|
||||
if conn is None:
|
||||
raise ValueError("conn is required when aggregator is not provided")
|
||||
aggregator = ContextAggregator(conn)
|
||||
self.aggregator = aggregator
|
||||
|
||||
if store is None:
|
||||
store = getattr(aggregator, "store", None)
|
||||
if store is None:
|
||||
if conn is None:
|
||||
raise ValueError("conn is required when store is not provided")
|
||||
store = ContextStore(conn)
|
||||
self.store = store
|
||||
|
||||
self._last_run: dict[str, str] = {}
|
||||
|
||||
def run_if_due(self, now: datetime | None = None) -> ScheduleResult:
|
||||
"""Run scheduled aggregations if their schedule is due.
|
||||
|
||||
Args:
|
||||
now: Current datetime (UTC). If None, uses current time.
|
||||
|
||||
Returns:
|
||||
ScheduleResult indicating which tasks ran.
|
||||
"""
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
|
||||
today = now.date().isoformat()
|
||||
result = ScheduleResult()
|
||||
|
||||
if self._should_run("cleanup", today):
|
||||
self.store.cleanup_expired_contexts()
|
||||
result = self._with(result, cleanup=True)
|
||||
|
||||
if self._is_sunday(now) and self._should_run("weekly", today):
|
||||
week = now.strftime("%Y-W%V")
|
||||
self.aggregator.aggregate_weekly_from_daily(week)
|
||||
result = self._with(result, weekly=True)
|
||||
|
||||
if self._is_last_day_of_month(now) and self._should_run("monthly", today):
|
||||
month = now.strftime("%Y-%m")
|
||||
self.aggregator.aggregate_monthly_from_weekly(month)
|
||||
result = self._with(result, monthly=True)
|
||||
|
||||
if self._is_last_day_of_quarter(now) and self._should_run("quarterly", today):
|
||||
quarter = self._current_quarter(now)
|
||||
self.aggregator.aggregate_quarterly_from_monthly(quarter)
|
||||
result = self._with(result, quarterly=True)
|
||||
|
||||
if self._is_last_day_of_year(now) and self._should_run("annual", today):
|
||||
year = str(now.year)
|
||||
self.aggregator.aggregate_annual_from_quarterly(year)
|
||||
result = self._with(result, annual=True)
|
||||
|
||||
# Legacy rollup runs after annual aggregation.
|
||||
self.aggregator.aggregate_legacy_from_annual()
|
||||
result = self._with(result, legacy=True)
|
||||
|
||||
return result
|
||||
|
||||
def _should_run(self, key: str, date_str: str) -> bool:
|
||||
if self._last_run.get(key) == date_str:
|
||||
return False
|
||||
self._last_run[key] = date_str
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _is_sunday(now: datetime) -> bool:
|
||||
return now.weekday() == 6
|
||||
|
||||
@staticmethod
|
||||
def _is_last_day_of_month(now: datetime) -> bool:
|
||||
last_day = monthrange(now.year, now.month)[1]
|
||||
return now.day == last_day
|
||||
|
||||
@classmethod
|
||||
def _is_last_day_of_quarter(cls, now: datetime) -> bool:
|
||||
if now.month not in (3, 6, 9, 12):
|
||||
return False
|
||||
return cls._is_last_day_of_month(now)
|
||||
|
||||
@staticmethod
|
||||
def _is_last_day_of_year(now: datetime) -> bool:
|
||||
return now.month == 12 and now.day == 31
|
||||
|
||||
@staticmethod
|
||||
def _current_quarter(now: datetime) -> str:
|
||||
quarter = (now.month - 1) // 3 + 1
|
||||
return f"{now.year}-Q{quarter}"
|
||||
|
||||
@staticmethod
|
||||
def _with(result: ScheduleResult, **kwargs: bool) -> ScheduleResult:
|
||||
return ScheduleResult(
|
||||
weekly=kwargs.get("weekly", result.weekly),
|
||||
monthly=kwargs.get("monthly", result.monthly),
|
||||
quarterly=kwargs.get("quarterly", result.quarterly),
|
||||
annual=kwargs.get("annual", result.annual),
|
||||
legacy=kwargs.get("legacy", result.legacy),
|
||||
cleanup=kwargs.get("cleanup", result.cleanup),
|
||||
)
|
||||
328
src/context/summarizer.py
Normal file
328
src/context/summarizer.py
Normal file
@@ -0,0 +1,328 @@
|
||||
"""Context summarization for efficient historical data representation.
|
||||
|
||||
This module summarizes old context data instead of including raw details:
|
||||
- Key metrics only (averages, trends, not details)
|
||||
- Rolling window (keep last N days detailed, summarize older)
|
||||
- Aggregate historical data efficiently
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from src.context.layer import ContextLayer
|
||||
from src.context.store import ContextStore
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SummaryStats:
|
||||
"""Statistical summary of historical data."""
|
||||
|
||||
count: int
|
||||
mean: float | None = None
|
||||
min: float | None = None
|
||||
max: float | None = None
|
||||
std: float | None = None
|
||||
trend: str | None = None # "up", "down", "flat"
|
||||
|
||||
|
||||
class ContextSummarizer:
|
||||
"""Summarizes historical context data to reduce token usage."""
|
||||
|
||||
def __init__(self, store: ContextStore) -> None:
|
||||
"""Initialize the context summarizer.
|
||||
|
||||
Args:
|
||||
store: ContextStore instance for retrieving context data
|
||||
"""
|
||||
self.store = store
|
||||
|
||||
def summarize_numeric_values(self, values: list[float]) -> SummaryStats:
|
||||
"""Summarize a list of numeric values.
|
||||
|
||||
Args:
|
||||
values: List of numeric values to summarize
|
||||
|
||||
Returns:
|
||||
SummaryStats with mean, min, max, std, and trend
|
||||
"""
|
||||
if not values:
|
||||
return SummaryStats(count=0)
|
||||
|
||||
count = len(values)
|
||||
mean = sum(values) / count
|
||||
min_val = min(values)
|
||||
max_val = max(values)
|
||||
|
||||
# Calculate standard deviation
|
||||
if count > 1:
|
||||
variance = sum((x - mean) ** 2 for x in values) / (count - 1)
|
||||
std = variance**0.5
|
||||
else:
|
||||
std = 0.0
|
||||
|
||||
# Determine trend
|
||||
trend = "flat"
|
||||
if count >= 3:
|
||||
# Simple trend: compare first third vs last third
|
||||
first_third = values[: count // 3]
|
||||
last_third = values[-(count // 3) :]
|
||||
first_avg = sum(first_third) / len(first_third)
|
||||
last_avg = sum(last_third) / len(last_third)
|
||||
|
||||
# Trend threshold: 5% change
|
||||
threshold = 0.05 * abs(first_avg) if first_avg != 0 else 0.01
|
||||
|
||||
if last_avg > first_avg + threshold:
|
||||
trend = "up"
|
||||
elif last_avg < first_avg - threshold:
|
||||
trend = "down"
|
||||
|
||||
return SummaryStats(
|
||||
count=count,
|
||||
mean=round(mean, 4),
|
||||
min=round(min_val, 4),
|
||||
max=round(max_val, 4),
|
||||
std=round(std, 4),
|
||||
trend=trend,
|
||||
)
|
||||
|
||||
def summarize_layer(
|
||||
self,
|
||||
layer: ContextLayer,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Summarize all context data for a layer within a date range.
|
||||
|
||||
Args:
|
||||
layer: Context layer to summarize
|
||||
start_date: Start date (inclusive), None for all
|
||||
end_date: End date (inclusive), None for now
|
||||
|
||||
Returns:
|
||||
Dictionary with summarized metrics
|
||||
"""
|
||||
if end_date is None:
|
||||
end_date = datetime.now(UTC)
|
||||
|
||||
# Get all contexts for this layer
|
||||
all_contexts = self.store.get_all_contexts(layer)
|
||||
|
||||
if not all_contexts:
|
||||
return {"summary": "No data available", "count": 0}
|
||||
|
||||
# Group numeric values by key
|
||||
numeric_data: dict[str, list[float]] = {}
|
||||
text_data: dict[str, list[str]] = {}
|
||||
|
||||
for key, value in all_contexts.items():
|
||||
# Try to extract numeric values
|
||||
if isinstance(value, (int, float)):
|
||||
if key not in numeric_data:
|
||||
numeric_data[key] = []
|
||||
numeric_data[key].append(float(value))
|
||||
elif isinstance(value, dict):
|
||||
# Extract numeric fields from dict
|
||||
for subkey, subvalue in value.items():
|
||||
if isinstance(subvalue, (int, float)):
|
||||
full_key = f"{key}.{subkey}"
|
||||
if full_key not in numeric_data:
|
||||
numeric_data[full_key] = []
|
||||
numeric_data[full_key].append(float(subvalue))
|
||||
elif isinstance(value, str):
|
||||
if key not in text_data:
|
||||
text_data[key] = []
|
||||
text_data[key].append(value)
|
||||
|
||||
# Summarize numeric data
|
||||
summary: dict[str, Any] = {}
|
||||
|
||||
for key, values in numeric_data.items():
|
||||
stats = self.summarize_numeric_values(values)
|
||||
summary[key] = {
|
||||
"count": stats.count,
|
||||
"avg": stats.mean,
|
||||
"range": [stats.min, stats.max],
|
||||
"trend": stats.trend,
|
||||
}
|
||||
|
||||
# Summarize text data (just counts)
|
||||
for key, values in text_data.items():
|
||||
summary[f"{key}_count"] = len(values)
|
||||
|
||||
summary["total_entries"] = len(all_contexts)
|
||||
|
||||
return summary
|
||||
|
||||
def rolling_window_summary(
|
||||
self,
|
||||
layer: ContextLayer,
|
||||
window_days: int = 30,
|
||||
summarize_older: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a rolling window summary.
|
||||
|
||||
Recent data (within window) is kept detailed.
|
||||
Older data is summarized to key metrics.
|
||||
|
||||
Args:
|
||||
layer: Context layer to summarize
|
||||
window_days: Number of days to keep detailed
|
||||
summarize_older: Whether to summarize data older than window
|
||||
|
||||
Returns:
|
||||
Dictionary with recent (detailed) and historical (summary) data
|
||||
"""
|
||||
result: dict[str, Any] = {
|
||||
"window_days": window_days,
|
||||
"recent_data": {},
|
||||
"historical_summary": {},
|
||||
}
|
||||
|
||||
# Get all contexts
|
||||
all_contexts = self.store.get_all_contexts(layer)
|
||||
|
||||
recent_values: dict[str, list[float]] = {}
|
||||
historical_values: dict[str, list[float]] = {}
|
||||
|
||||
for key, value in all_contexts.items():
|
||||
# For simplicity, treat all numeric values
|
||||
if isinstance(value, (int, float)):
|
||||
# Note: We don't have timestamps in context keys
|
||||
# This is a simplified implementation
|
||||
# In practice, would need to check timeframe field
|
||||
|
||||
# For now, put recent data in window
|
||||
if key not in recent_values:
|
||||
recent_values[key] = []
|
||||
recent_values[key].append(float(value))
|
||||
|
||||
# Detailed recent data
|
||||
result["recent_data"] = {key: values[-10:] for key, values in recent_values.items()}
|
||||
|
||||
# Summarized historical data
|
||||
if summarize_older:
|
||||
for key, values in historical_values.items():
|
||||
stats = self.summarize_numeric_values(values)
|
||||
result["historical_summary"][key] = {
|
||||
"count": stats.count,
|
||||
"avg": stats.mean,
|
||||
"trend": stats.trend,
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def aggregate_to_higher_layer(
|
||||
self,
|
||||
source_layer: ContextLayer,
|
||||
target_layer: ContextLayer,
|
||||
metric_key: str,
|
||||
aggregation_func: str = "mean",
|
||||
) -> float | None:
|
||||
"""Aggregate data from source layer to target layer.
|
||||
|
||||
Args:
|
||||
source_layer: Source context layer (more granular)
|
||||
target_layer: Target context layer (less granular)
|
||||
metric_key: Key of metric to aggregate
|
||||
aggregation_func: Aggregation function ("mean", "sum", "max", "min")
|
||||
|
||||
Returns:
|
||||
Aggregated value, or None if no data available
|
||||
"""
|
||||
# Get all contexts from source layer
|
||||
source_contexts = self.store.get_all_contexts(source_layer)
|
||||
|
||||
# Extract values for metric_key
|
||||
values = []
|
||||
for key, value in source_contexts.items():
|
||||
if key == metric_key and isinstance(value, (int, float)):
|
||||
values.append(float(value))
|
||||
elif isinstance(value, dict) and metric_key in value:
|
||||
subvalue = value[metric_key]
|
||||
if isinstance(subvalue, (int, float)):
|
||||
values.append(float(subvalue))
|
||||
|
||||
if not values:
|
||||
return None
|
||||
|
||||
# Apply aggregation function
|
||||
if aggregation_func == "mean":
|
||||
return sum(values) / len(values)
|
||||
elif aggregation_func == "sum":
|
||||
return sum(values)
|
||||
elif aggregation_func == "max":
|
||||
return max(values)
|
||||
elif aggregation_func == "min":
|
||||
return min(values)
|
||||
else:
|
||||
return sum(values) / len(values) # Default to mean
|
||||
|
||||
def create_compact_summary(
|
||||
self,
|
||||
layers: list[ContextLayer],
|
||||
top_n_metrics: int = 5,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a compact summary across multiple layers.
|
||||
|
||||
Args:
|
||||
layers: List of context layers to summarize
|
||||
top_n_metrics: Number of top metrics to include per layer
|
||||
|
||||
Returns:
|
||||
Compact summary dictionary
|
||||
"""
|
||||
summary: dict[str, Any] = {}
|
||||
|
||||
for layer in layers:
|
||||
layer_summary = self.summarize_layer(layer)
|
||||
|
||||
# Keep only top N metrics (by count/relevance)
|
||||
metrics = []
|
||||
for key, value in layer_summary.items():
|
||||
if isinstance(value, dict) and "count" in value:
|
||||
metrics.append((key, value, value["count"]))
|
||||
|
||||
# Sort by count (descending)
|
||||
metrics.sort(key=lambda x: x[2], reverse=True)
|
||||
|
||||
# Keep top N
|
||||
top_metrics = {m[0]: m[1] for m in metrics[:top_n_metrics]}
|
||||
|
||||
summary[layer.value] = top_metrics
|
||||
|
||||
return summary
|
||||
|
||||
def format_summary_for_prompt(self, summary: dict[str, Any]) -> str:
|
||||
"""Format summary for inclusion in a prompt.
|
||||
|
||||
Args:
|
||||
summary: Summary dictionary
|
||||
|
||||
Returns:
|
||||
Formatted string for prompt
|
||||
"""
|
||||
lines = []
|
||||
|
||||
for layer, metrics in summary.items():
|
||||
if not metrics:
|
||||
continue
|
||||
|
||||
lines.append(f"{layer}:")
|
||||
for key, value in metrics.items():
|
||||
if isinstance(value, dict):
|
||||
# Format as: key: avg=X, trend=Y
|
||||
parts = []
|
||||
if "avg" in value and value["avg"] is not None:
|
||||
parts.append(f"avg={value['avg']:.2f}")
|
||||
if "trend" in value and value["trend"]:
|
||||
parts.append(f"trend={value['trend']}")
|
||||
if parts:
|
||||
lines.append(f" {key}: {', '.join(parts)}")
|
||||
else:
|
||||
lines.append(f" {key}: {value}")
|
||||
|
||||
return "\n".join(lines)
|
||||
110
src/core/criticality.py
Normal file
110
src/core/criticality.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""Criticality assessment for urgency-based response system.
|
||||
|
||||
Evaluates market conditions to determine response urgency and enable
|
||||
faster reactions in critical situations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class CriticalityLevel(StrEnum):
|
||||
"""Urgency levels for market conditions and trading decisions."""
|
||||
|
||||
CRITICAL = "CRITICAL" # <5s timeout - Emergency response required
|
||||
HIGH = "HIGH" # <30s timeout - Elevated priority
|
||||
NORMAL = "NORMAL" # <60s timeout - Standard processing
|
||||
LOW = "LOW" # No timeout - Batch processing
|
||||
|
||||
|
||||
class CriticalityAssessor:
|
||||
"""Assesses market conditions to determine response criticality level."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
critical_pnl_threshold: float = -2.5,
|
||||
critical_price_change_threshold: float = 5.0,
|
||||
critical_volume_surge_threshold: float = 10.0,
|
||||
high_volatility_threshold: float = 70.0,
|
||||
low_volatility_threshold: float = 30.0,
|
||||
) -> None:
|
||||
"""Initialize the criticality assessor.
|
||||
|
||||
Args:
|
||||
critical_pnl_threshold: P&L % that triggers CRITICAL (default -2.5%)
|
||||
critical_price_change_threshold: Price change % that triggers CRITICAL
|
||||
(default 5.0% in 1 minute)
|
||||
critical_volume_surge_threshold: Volume surge ratio that triggers CRITICAL
|
||||
(default 10x average)
|
||||
high_volatility_threshold: Volatility score that triggers HIGH
|
||||
(default 70.0)
|
||||
low_volatility_threshold: Volatility score below which is LOW
|
||||
(default 30.0)
|
||||
"""
|
||||
self.critical_pnl_threshold = critical_pnl_threshold
|
||||
self.critical_price_change_threshold = critical_price_change_threshold
|
||||
self.critical_volume_surge_threshold = critical_volume_surge_threshold
|
||||
self.high_volatility_threshold = high_volatility_threshold
|
||||
self.low_volatility_threshold = low_volatility_threshold
|
||||
|
||||
def assess_market_conditions(
|
||||
self,
|
||||
pnl_pct: float,
|
||||
volatility_score: float,
|
||||
volume_surge: float,
|
||||
price_change_1m: float = 0.0,
|
||||
is_market_open: bool = True,
|
||||
) -> CriticalityLevel:
|
||||
"""Assess criticality level based on market conditions.
|
||||
|
||||
Args:
|
||||
pnl_pct: Current P&L percentage
|
||||
volatility_score: Momentum score from VolatilityAnalyzer (0-100)
|
||||
volume_surge: Volume surge ratio (current / average)
|
||||
price_change_1m: 1-minute price change percentage
|
||||
is_market_open: Whether the market is currently open
|
||||
|
||||
Returns:
|
||||
CriticalityLevel indicating required response urgency
|
||||
"""
|
||||
# Market closed or very quiet → LOW priority (batch processing)
|
||||
if not is_market_open or volatility_score < self.low_volatility_threshold:
|
||||
return CriticalityLevel.LOW
|
||||
|
||||
# CRITICAL conditions: immediate action required
|
||||
# 1. P&L near circuit breaker (-2.5% is close to -3.0% breaker)
|
||||
if pnl_pct <= self.critical_pnl_threshold:
|
||||
return CriticalityLevel.CRITICAL
|
||||
|
||||
# 2. Large sudden price movement (>5% in 1 minute)
|
||||
if abs(price_change_1m) >= self.critical_price_change_threshold:
|
||||
return CriticalityLevel.CRITICAL
|
||||
|
||||
# 3. Extreme volume surge (>10x average) indicates major event
|
||||
if volume_surge >= self.critical_volume_surge_threshold:
|
||||
return CriticalityLevel.CRITICAL
|
||||
|
||||
# HIGH priority: elevated volatility requires faster response
|
||||
if volatility_score >= self.high_volatility_threshold:
|
||||
return CriticalityLevel.HIGH
|
||||
|
||||
# NORMAL: standard trading conditions
|
||||
return CriticalityLevel.NORMAL
|
||||
|
||||
def get_timeout(self, level: CriticalityLevel) -> float | None:
|
||||
"""Get timeout in seconds for a given criticality level.
|
||||
|
||||
Args:
|
||||
level: Criticality level
|
||||
|
||||
Returns:
|
||||
Timeout in seconds, or None for no timeout (LOW priority)
|
||||
"""
|
||||
timeout_map = {
|
||||
CriticalityLevel.CRITICAL: 5.0,
|
||||
CriticalityLevel.HIGH: 30.0,
|
||||
CriticalityLevel.NORMAL: 60.0,
|
||||
CriticalityLevel.LOW: None,
|
||||
}
|
||||
return timeout_map[level]
|
||||
71
src/core/kill_switch.py
Normal file
71
src/core/kill_switch.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Kill switch orchestration for emergency risk actions.
|
||||
|
||||
Order is fixed:
|
||||
1) block new orders
|
||||
2) cancel pending orders
|
||||
3) refresh order state
|
||||
4) reduce risk
|
||||
5) snapshot and notify
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
StepCallable = Callable[[], Any | Awaitable[Any]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class KillSwitchReport:
|
||||
reason: str
|
||||
steps: list[str] = field(default_factory=list)
|
||||
errors: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class KillSwitchOrchestrator:
|
||||
def __init__(self) -> None:
|
||||
self.new_orders_blocked = False
|
||||
|
||||
async def _run_step(
|
||||
self,
|
||||
report: KillSwitchReport,
|
||||
name: str,
|
||||
fn: StepCallable | None,
|
||||
) -> None:
|
||||
report.steps.append(name)
|
||||
if fn is None:
|
||||
return
|
||||
try:
|
||||
result = fn()
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
except Exception as exc: # pragma: no cover - intentionally resilient
|
||||
report.errors.append(f"{name}: {exc}")
|
||||
|
||||
async def trigger(
|
||||
self,
|
||||
*,
|
||||
reason: str,
|
||||
cancel_pending_orders: StepCallable | None = None,
|
||||
refresh_order_state: StepCallable | None = None,
|
||||
reduce_risk: StepCallable | None = None,
|
||||
snapshot_state: StepCallable | None = None,
|
||||
notify: StepCallable | None = None,
|
||||
) -> KillSwitchReport:
|
||||
report = KillSwitchReport(reason=reason)
|
||||
|
||||
self.new_orders_blocked = True
|
||||
report.steps.append("block_new_orders")
|
||||
|
||||
await self._run_step(report, "cancel_pending_orders", cancel_pending_orders)
|
||||
await self._run_step(report, "refresh_order_state", refresh_order_state)
|
||||
await self._run_step(report, "reduce_risk", reduce_risk)
|
||||
await self._run_step(report, "snapshot_state", snapshot_state)
|
||||
await self._run_step(report, "notify", notify)
|
||||
|
||||
return report
|
||||
|
||||
def clear_block(self) -> None:
|
||||
self.new_orders_blocked = False
|
||||
291
src/core/priority_queue.py
Normal file
291
src/core/priority_queue.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""Priority-based task queue for latency control.
|
||||
|
||||
Implements a thread-safe priority queue with timeout enforcement and metrics tracking.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import heapq
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable, Coroutine
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from src.core.criticality import CriticalityLevel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(order=True)
|
||||
class PriorityTask:
|
||||
"""Task with priority and timestamp for queue ordering."""
|
||||
|
||||
# Lower priority value = higher urgency (CRITICAL=0, HIGH=1, NORMAL=2, LOW=3)
|
||||
priority: int
|
||||
timestamp: float
|
||||
# Task data not used in comparison
|
||||
task_id: str = field(compare=False)
|
||||
task_data: dict[str, Any] = field(compare=False, default_factory=dict)
|
||||
callback: Callable[[], Coroutine[Any, Any, Any]] | None = field(
|
||||
compare=False, default=None
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueueMetrics:
|
||||
"""Metrics for priority queue performance monitoring."""
|
||||
|
||||
total_enqueued: int = 0
|
||||
total_dequeued: int = 0
|
||||
total_timeouts: int = 0
|
||||
total_errors: int = 0
|
||||
current_size: int = 0
|
||||
# Average wait time per criticality level (in seconds)
|
||||
avg_wait_time: dict[CriticalityLevel, float] = field(default_factory=dict)
|
||||
# P95 wait time per criticality level
|
||||
p95_wait_time: dict[CriticalityLevel, float] = field(default_factory=dict)
|
||||
|
||||
|
||||
class PriorityTaskQueue:
|
||||
"""Thread-safe priority queue with timeout enforcement."""
|
||||
|
||||
# Priority mapping for criticality levels
|
||||
PRIORITY_MAP = {
|
||||
CriticalityLevel.CRITICAL: 0,
|
||||
CriticalityLevel.HIGH: 1,
|
||||
CriticalityLevel.NORMAL: 2,
|
||||
CriticalityLevel.LOW: 3,
|
||||
}
|
||||
|
||||
def __init__(self, max_size: int = 1000) -> None:
|
||||
"""Initialize the priority task queue.
|
||||
|
||||
Args:
|
||||
max_size: Maximum queue size (default 1000)
|
||||
"""
|
||||
self._queue: list[PriorityTask] = []
|
||||
self._lock = asyncio.Lock()
|
||||
self._max_size = max_size
|
||||
self._metrics = QueueMetrics()
|
||||
# Track wait times for metrics
|
||||
self._wait_times: dict[CriticalityLevel, list[float]] = {
|
||||
level: [] for level in CriticalityLevel
|
||||
}
|
||||
|
||||
async def enqueue(
|
||||
self,
|
||||
task_id: str,
|
||||
criticality: CriticalityLevel,
|
||||
task_data: dict[str, Any],
|
||||
callback: Callable[[], Coroutine[Any, Any, Any]] | None = None,
|
||||
) -> bool:
|
||||
"""Add a task to the priority queue.
|
||||
|
||||
Args:
|
||||
task_id: Unique identifier for the task
|
||||
criticality: Criticality level determining priority
|
||||
task_data: Data associated with the task
|
||||
callback: Optional async callback to execute
|
||||
|
||||
Returns:
|
||||
True if enqueued successfully, False if queue is full
|
||||
"""
|
||||
async with self._lock:
|
||||
if len(self._queue) >= self._max_size:
|
||||
logger.warning(
|
||||
"Priority queue full (size=%d), rejecting task %s",
|
||||
len(self._queue),
|
||||
task_id,
|
||||
)
|
||||
return False
|
||||
|
||||
priority = self.PRIORITY_MAP[criticality]
|
||||
timestamp = time.time()
|
||||
|
||||
task = PriorityTask(
|
||||
priority=priority,
|
||||
timestamp=timestamp,
|
||||
task_id=task_id,
|
||||
task_data=task_data,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
heapq.heappush(self._queue, task)
|
||||
self._metrics.total_enqueued += 1
|
||||
self._metrics.current_size = len(self._queue)
|
||||
|
||||
logger.debug(
|
||||
"Enqueued task %s with criticality %s (priority=%d, queue_size=%d)",
|
||||
task_id,
|
||||
criticality.value,
|
||||
priority,
|
||||
len(self._queue),
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
async def dequeue(self, timeout: float | None = None) -> PriorityTask | None:
|
||||
"""Remove and return the highest priority task from the queue.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait for a task (seconds)
|
||||
|
||||
Returns:
|
||||
PriorityTask if available, None if queue is empty or timeout
|
||||
"""
|
||||
start_time = time.time()
|
||||
deadline = start_time + timeout if timeout else None
|
||||
|
||||
while True:
|
||||
async with self._lock:
|
||||
if self._queue:
|
||||
task = heapq.heappop(self._queue)
|
||||
self._metrics.total_dequeued += 1
|
||||
self._metrics.current_size = len(self._queue)
|
||||
|
||||
# Calculate wait time
|
||||
wait_time = time.time() - task.timestamp
|
||||
criticality = self._get_criticality_from_priority(task.priority)
|
||||
self._wait_times[criticality].append(wait_time)
|
||||
self._update_wait_time_metrics()
|
||||
|
||||
logger.debug(
|
||||
"Dequeued task %s (priority=%d, wait_time=%.2fs, queue_size=%d)",
|
||||
task.task_id,
|
||||
task.priority,
|
||||
wait_time,
|
||||
len(self._queue),
|
||||
)
|
||||
|
||||
return task
|
||||
|
||||
# Queue is empty
|
||||
if deadline and time.time() >= deadline:
|
||||
return None
|
||||
|
||||
# Wait a bit before checking again
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def execute_with_timeout(
|
||||
self,
|
||||
task: PriorityTask,
|
||||
timeout: float | None,
|
||||
) -> Any:
|
||||
"""Execute a task with timeout enforcement.
|
||||
|
||||
Args:
|
||||
task: Task to execute
|
||||
timeout: Timeout in seconds (None = no timeout)
|
||||
|
||||
Returns:
|
||||
Result from task callback
|
||||
|
||||
Raises:
|
||||
asyncio.TimeoutError: If task exceeds timeout
|
||||
Exception: Any exception raised by the task callback
|
||||
"""
|
||||
if not task.callback:
|
||||
logger.warning("Task %s has no callback, skipping execution", task.task_id)
|
||||
return None
|
||||
|
||||
criticality = self._get_criticality_from_priority(task.priority)
|
||||
|
||||
try:
|
||||
if timeout:
|
||||
result = await asyncio.wait_for(task.callback(), timeout=timeout)
|
||||
else:
|
||||
result = await task.callback()
|
||||
|
||||
logger.debug(
|
||||
"Task %s completed successfully (criticality=%s)",
|
||||
task.task_id,
|
||||
criticality.value,
|
||||
)
|
||||
return result
|
||||
|
||||
except TimeoutError:
|
||||
self._metrics.total_timeouts += 1
|
||||
logger.error(
|
||||
"Task %s timed out after %.2fs (criticality=%s)",
|
||||
task.task_id,
|
||||
timeout or 0.0,
|
||||
criticality.value,
|
||||
)
|
||||
raise
|
||||
|
||||
except Exception as exc:
|
||||
self._metrics.total_errors += 1
|
||||
logger.exception(
|
||||
"Task %s failed with error (criticality=%s): %s",
|
||||
task.task_id,
|
||||
criticality.value,
|
||||
exc,
|
||||
)
|
||||
raise
|
||||
|
||||
def _get_criticality_from_priority(self, priority: int) -> CriticalityLevel:
|
||||
"""Convert priority back to criticality level."""
|
||||
for level, prio in self.PRIORITY_MAP.items():
|
||||
if prio == priority:
|
||||
return level
|
||||
return CriticalityLevel.NORMAL
|
||||
|
||||
def _update_wait_time_metrics(self) -> None:
|
||||
"""Update average and p95 wait time metrics."""
|
||||
for level, times in self._wait_times.items():
|
||||
if not times:
|
||||
continue
|
||||
|
||||
# Keep only last 1000 measurements to avoid memory bloat
|
||||
if len(times) > 1000:
|
||||
self._wait_times[level] = times[-1000:]
|
||||
times = self._wait_times[level]
|
||||
|
||||
# Calculate average
|
||||
self._metrics.avg_wait_time[level] = sum(times) / len(times)
|
||||
|
||||
# Calculate P95
|
||||
sorted_times = sorted(times)
|
||||
p95_idx = int(len(sorted_times) * 0.95)
|
||||
self._metrics.p95_wait_time[level] = sorted_times[p95_idx]
|
||||
|
||||
async def get_metrics(self) -> QueueMetrics:
|
||||
"""Get current queue metrics.
|
||||
|
||||
Returns:
|
||||
QueueMetrics with current statistics
|
||||
"""
|
||||
async with self._lock:
|
||||
return QueueMetrics(
|
||||
total_enqueued=self._metrics.total_enqueued,
|
||||
total_dequeued=self._metrics.total_dequeued,
|
||||
total_timeouts=self._metrics.total_timeouts,
|
||||
total_errors=self._metrics.total_errors,
|
||||
current_size=self._metrics.current_size,
|
||||
avg_wait_time=dict(self._metrics.avg_wait_time),
|
||||
p95_wait_time=dict(self._metrics.p95_wait_time),
|
||||
)
|
||||
|
||||
async def size(self) -> int:
|
||||
"""Get current queue size.
|
||||
|
||||
Returns:
|
||||
Number of tasks in queue
|
||||
"""
|
||||
async with self._lock:
|
||||
return len(self._queue)
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""Clear all tasks from the queue.
|
||||
|
||||
Returns:
|
||||
Number of tasks cleared
|
||||
"""
|
||||
async with self._lock:
|
||||
count = len(self._queue)
|
||||
self._queue.clear()
|
||||
self._metrics.current_size = 0
|
||||
logger.info("Cleared %d tasks from priority queue", count)
|
||||
return count
|
||||
5
src/dashboard/__init__.py
Normal file
5
src/dashboard/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""FastAPI dashboard package for observability APIs."""
|
||||
|
||||
from src.dashboard.app import create_dashboard_app
|
||||
|
||||
__all__ = ["create_dashboard_app"]
|
||||
498
src/dashboard/app.py
Normal file
498
src/dashboard/app.py
Normal file
@@ -0,0 +1,498 @@
|
||||
"""FastAPI application for observability dashboard endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
from datetime import UTC, datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Query
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
|
||||
def create_dashboard_app(db_path: str, mode: str = "paper") -> FastAPI:
|
||||
"""Create dashboard FastAPI app bound to a SQLite database path."""
|
||||
app = FastAPI(title="The Ouroboros Dashboard", version="1.0.0")
|
||||
app.state.db_path = db_path
|
||||
app.state.mode = mode
|
||||
|
||||
@app.get("/")
|
||||
def index() -> FileResponse:
|
||||
index_path = Path(__file__).parent / "static" / "index.html"
|
||||
return FileResponse(index_path)
|
||||
|
||||
@app.get("/api/status")
|
||||
def get_status() -> dict[str, Any]:
|
||||
today = datetime.now(UTC).date().isoformat()
|
||||
with _connect(db_path) as conn:
|
||||
market_rows = conn.execute(
|
||||
"""
|
||||
SELECT DISTINCT market FROM (
|
||||
SELECT market FROM trades WHERE DATE(timestamp) = ?
|
||||
UNION
|
||||
SELECT market FROM decision_logs WHERE DATE(timestamp) = ?
|
||||
UNION
|
||||
SELECT market FROM playbooks WHERE date = ?
|
||||
) ORDER BY market
|
||||
""",
|
||||
(today, today, today),
|
||||
).fetchall()
|
||||
markets = [row[0] for row in market_rows] if market_rows else []
|
||||
market_status: dict[str, Any] = {}
|
||||
total_trades = 0
|
||||
total_pnl = 0.0
|
||||
total_decisions = 0
|
||||
for market in markets:
|
||||
trade_row = conn.execute(
|
||||
"""
|
||||
SELECT COUNT(*) AS c, COALESCE(SUM(pnl), 0.0) AS p
|
||||
FROM trades
|
||||
WHERE DATE(timestamp) = ? AND market = ?
|
||||
""",
|
||||
(today, market),
|
||||
).fetchone()
|
||||
decision_row = conn.execute(
|
||||
"""
|
||||
SELECT COUNT(*) AS c
|
||||
FROM decision_logs
|
||||
WHERE DATE(timestamp) = ? AND market = ?
|
||||
""",
|
||||
(today, market),
|
||||
).fetchone()
|
||||
playbook_row = conn.execute(
|
||||
"""
|
||||
SELECT status
|
||||
FROM playbooks
|
||||
WHERE date = ? AND market = ?
|
||||
LIMIT 1
|
||||
""",
|
||||
(today, market),
|
||||
).fetchone()
|
||||
market_status[market] = {
|
||||
"trade_count": int(trade_row["c"] if trade_row else 0),
|
||||
"total_pnl": float(trade_row["p"] if trade_row else 0.0),
|
||||
"decision_count": int(decision_row["c"] if decision_row else 0),
|
||||
"playbook_status": playbook_row["status"] if playbook_row else None,
|
||||
}
|
||||
total_trades += market_status[market]["trade_count"]
|
||||
total_pnl += market_status[market]["total_pnl"]
|
||||
total_decisions += market_status[market]["decision_count"]
|
||||
|
||||
cb_threshold = float(os.getenv("CIRCUIT_BREAKER_PCT", "-3.0"))
|
||||
pnl_pct_rows = conn.execute(
|
||||
"""
|
||||
SELECT key, value
|
||||
FROM system_metrics
|
||||
WHERE key LIKE 'portfolio_pnl_pct_%'
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT 20
|
||||
"""
|
||||
).fetchall()
|
||||
current_pnl_pct: float | None = None
|
||||
if pnl_pct_rows:
|
||||
values = [
|
||||
json.loads(row["value"]).get("pnl_pct")
|
||||
for row in pnl_pct_rows
|
||||
if json.loads(row["value"]).get("pnl_pct") is not None
|
||||
]
|
||||
if values:
|
||||
current_pnl_pct = round(min(values), 4)
|
||||
|
||||
if current_pnl_pct is None:
|
||||
cb_status = "unknown"
|
||||
elif current_pnl_pct <= cb_threshold:
|
||||
cb_status = "tripped"
|
||||
elif current_pnl_pct <= cb_threshold + 1.0:
|
||||
cb_status = "warning"
|
||||
else:
|
||||
cb_status = "ok"
|
||||
|
||||
return {
|
||||
"date": today,
|
||||
"mode": mode,
|
||||
"markets": market_status,
|
||||
"totals": {
|
||||
"trade_count": total_trades,
|
||||
"total_pnl": round(total_pnl, 2),
|
||||
"decision_count": total_decisions,
|
||||
},
|
||||
"circuit_breaker": {
|
||||
"threshold_pct": cb_threshold,
|
||||
"current_pnl_pct": current_pnl_pct,
|
||||
"status": cb_status,
|
||||
},
|
||||
}
|
||||
|
||||
@app.get("/api/playbook/{date_str}")
|
||||
def get_playbook(date_str: str, market: str = Query("KR")) -> dict[str, Any]:
|
||||
with _connect(db_path) as conn:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT date, market, status, playbook_json, generated_at,
|
||||
token_count, scenario_count, match_count
|
||||
FROM playbooks
|
||||
WHERE date = ? AND market = ?
|
||||
""",
|
||||
(date_str, market),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=404, detail="playbook not found")
|
||||
return {
|
||||
"date": row["date"],
|
||||
"market": row["market"],
|
||||
"status": row["status"],
|
||||
"playbook": json.loads(row["playbook_json"]),
|
||||
"generated_at": row["generated_at"],
|
||||
"token_count": row["token_count"],
|
||||
"scenario_count": row["scenario_count"],
|
||||
"match_count": row["match_count"],
|
||||
}
|
||||
|
||||
@app.get("/api/scorecard/{date_str}")
|
||||
def get_scorecard(date_str: str, market: str = Query("KR")) -> dict[str, Any]:
|
||||
key = f"scorecard_{market}"
|
||||
with _connect(db_path) as conn:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT value
|
||||
FROM contexts
|
||||
WHERE layer = 'L6_DAILY' AND timeframe = ? AND key = ?
|
||||
""",
|
||||
(date_str, key),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=404, detail="scorecard not found")
|
||||
return {"date": date_str, "market": market, "scorecard": json.loads(row["value"])}
|
||||
|
||||
@app.get("/api/performance")
|
||||
def get_performance(market: str = Query("all")) -> dict[str, Any]:
|
||||
with _connect(db_path) as conn:
|
||||
if market == "all":
|
||||
by_market_rows = conn.execute(
|
||||
"""
|
||||
SELECT market,
|
||||
COUNT(*) AS total_trades,
|
||||
SUM(CASE WHEN pnl > 0 THEN 1 ELSE 0 END) AS wins,
|
||||
SUM(CASE WHEN pnl < 0 THEN 1 ELSE 0 END) AS losses,
|
||||
COALESCE(SUM(pnl), 0.0) AS total_pnl,
|
||||
COALESCE(AVG(confidence), 0.0) AS avg_confidence
|
||||
FROM trades
|
||||
GROUP BY market
|
||||
ORDER BY market
|
||||
"""
|
||||
).fetchall()
|
||||
combined = _performance_from_rows(by_market_rows)
|
||||
return {
|
||||
"market": "all",
|
||||
"combined": combined,
|
||||
"by_market": [
|
||||
_row_to_performance(row)
|
||||
for row in by_market_rows
|
||||
],
|
||||
}
|
||||
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT market,
|
||||
COUNT(*) AS total_trades,
|
||||
SUM(CASE WHEN pnl > 0 THEN 1 ELSE 0 END) AS wins,
|
||||
SUM(CASE WHEN pnl < 0 THEN 1 ELSE 0 END) AS losses,
|
||||
COALESCE(SUM(pnl), 0.0) AS total_pnl,
|
||||
COALESCE(AVG(confidence), 0.0) AS avg_confidence
|
||||
FROM trades
|
||||
WHERE market = ?
|
||||
GROUP BY market
|
||||
""",
|
||||
(market,),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return {"market": market, "metrics": _empty_performance(market)}
|
||||
return {"market": market, "metrics": _row_to_performance(row)}
|
||||
|
||||
@app.get("/api/context/{layer}")
|
||||
def get_context_layer(
|
||||
layer: str,
|
||||
timeframe: str | None = Query(default=None),
|
||||
limit: int = Query(default=100, ge=1, le=1000),
|
||||
) -> dict[str, Any]:
|
||||
with _connect(db_path) as conn:
|
||||
if timeframe is None:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT timeframe, key, value, updated_at
|
||||
FROM contexts
|
||||
WHERE layer = ?
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(layer, limit),
|
||||
).fetchall()
|
||||
else:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT timeframe, key, value, updated_at
|
||||
FROM contexts
|
||||
WHERE layer = ? AND timeframe = ?
|
||||
ORDER BY key
|
||||
LIMIT ?
|
||||
""",
|
||||
(layer, timeframe, limit),
|
||||
).fetchall()
|
||||
|
||||
entries = [
|
||||
{
|
||||
"timeframe": row["timeframe"],
|
||||
"key": row["key"],
|
||||
"value": json.loads(row["value"]),
|
||||
"updated_at": row["updated_at"],
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
return {
|
||||
"layer": layer,
|
||||
"timeframe": timeframe,
|
||||
"count": len(entries),
|
||||
"entries": entries,
|
||||
}
|
||||
|
||||
@app.get("/api/decisions")
|
||||
def get_decisions(
|
||||
market: str = Query("KR"),
|
||||
limit: int = Query(default=50, ge=1, le=500),
|
||||
) -> dict[str, Any]:
|
||||
with _connect(db_path) as conn:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT decision_id, timestamp, stock_code, market, exchange_code,
|
||||
action, confidence, rationale, context_snapshot, input_data,
|
||||
outcome_pnl, outcome_accuracy
|
||||
FROM decision_logs
|
||||
WHERE market = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(market, limit),
|
||||
).fetchall()
|
||||
decisions = []
|
||||
for row in rows:
|
||||
decisions.append(
|
||||
{
|
||||
"decision_id": row["decision_id"],
|
||||
"timestamp": row["timestamp"],
|
||||
"stock_code": row["stock_code"],
|
||||
"market": row["market"],
|
||||
"exchange_code": row["exchange_code"],
|
||||
"action": row["action"],
|
||||
"confidence": row["confidence"],
|
||||
"rationale": row["rationale"],
|
||||
"context_snapshot": json.loads(row["context_snapshot"]),
|
||||
"input_data": json.loads(row["input_data"]),
|
||||
"outcome_pnl": row["outcome_pnl"],
|
||||
"outcome_accuracy": row["outcome_accuracy"],
|
||||
}
|
||||
)
|
||||
return {"market": market, "count": len(decisions), "decisions": decisions}
|
||||
|
||||
@app.get("/api/pnl/history")
|
||||
def get_pnl_history(
|
||||
days: int = Query(default=30, ge=1, le=365),
|
||||
market: str = Query("all"),
|
||||
) -> dict[str, Any]:
|
||||
"""Return daily P&L history for charting."""
|
||||
with _connect(db_path) as conn:
|
||||
if market == "all":
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT DATE(timestamp) AS date,
|
||||
SUM(pnl) AS daily_pnl,
|
||||
COUNT(*) AS trade_count
|
||||
FROM trades
|
||||
WHERE pnl IS NOT NULL
|
||||
AND DATE(timestamp) >= DATE('now', ?)
|
||||
GROUP BY DATE(timestamp)
|
||||
ORDER BY DATE(timestamp)
|
||||
""",
|
||||
(f"-{days} days",),
|
||||
).fetchall()
|
||||
else:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT DATE(timestamp) AS date,
|
||||
SUM(pnl) AS daily_pnl,
|
||||
COUNT(*) AS trade_count
|
||||
FROM trades
|
||||
WHERE pnl IS NOT NULL
|
||||
AND market = ?
|
||||
AND DATE(timestamp) >= DATE('now', ?)
|
||||
GROUP BY DATE(timestamp)
|
||||
ORDER BY DATE(timestamp)
|
||||
""",
|
||||
(market, f"-{days} days"),
|
||||
).fetchall()
|
||||
return {
|
||||
"days": days,
|
||||
"market": market,
|
||||
"labels": [row["date"] for row in rows],
|
||||
"pnl": [round(float(row["daily_pnl"]), 2) for row in rows],
|
||||
"trades": [int(row["trade_count"]) for row in rows],
|
||||
}
|
||||
|
||||
@app.get("/api/scenarios/active")
|
||||
def get_active_scenarios(
|
||||
market: str = Query("US"),
|
||||
date_str: str | None = Query(default=None),
|
||||
limit: int = Query(default=50, ge=1, le=500),
|
||||
) -> dict[str, Any]:
|
||||
if date_str is None:
|
||||
date_str = datetime.now(UTC).date().isoformat()
|
||||
|
||||
with _connect(db_path) as conn:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT timestamp, stock_code, action, confidence, rationale, context_snapshot
|
||||
FROM decision_logs
|
||||
WHERE market = ? AND DATE(timestamp) = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(market, date_str, limit),
|
||||
).fetchall()
|
||||
matches: list[dict[str, Any]] = []
|
||||
for row in rows:
|
||||
snapshot = json.loads(row["context_snapshot"])
|
||||
scenario_match = snapshot.get("scenario_match", {})
|
||||
if not isinstance(scenario_match, dict) or not scenario_match:
|
||||
continue
|
||||
matches.append(
|
||||
{
|
||||
"timestamp": row["timestamp"],
|
||||
"stock_code": row["stock_code"],
|
||||
"action": row["action"],
|
||||
"confidence": row["confidence"],
|
||||
"rationale": row["rationale"],
|
||||
"scenario_match": scenario_match,
|
||||
}
|
||||
)
|
||||
return {"market": market, "date": date_str, "count": len(matches), "matches": matches}
|
||||
|
||||
@app.get("/api/positions")
|
||||
def get_positions() -> dict[str, Any]:
|
||||
"""Return all currently open positions (last trade per symbol is BUY)."""
|
||||
with _connect(db_path) as conn:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT stock_code, market, exchange_code,
|
||||
price AS entry_price, quantity, timestamp AS entry_time,
|
||||
decision_id
|
||||
FROM (
|
||||
SELECT stock_code, market, exchange_code, price, quantity,
|
||||
timestamp, decision_id, action,
|
||||
ROW_NUMBER() OVER (
|
||||
PARTITION BY stock_code, market
|
||||
ORDER BY timestamp DESC
|
||||
) AS rn
|
||||
FROM trades
|
||||
)
|
||||
WHERE rn = 1 AND action = 'BUY'
|
||||
ORDER BY entry_time DESC
|
||||
"""
|
||||
).fetchall()
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
positions = []
|
||||
for row in rows:
|
||||
entry_time_str = row["entry_time"]
|
||||
try:
|
||||
entry_dt = datetime.fromisoformat(entry_time_str.replace("Z", "+00:00"))
|
||||
held_seconds = int((now - entry_dt).total_seconds())
|
||||
held_hours = held_seconds // 3600
|
||||
held_minutes = (held_seconds % 3600) // 60
|
||||
if held_hours >= 1:
|
||||
held_display = f"{held_hours}h {held_minutes}m"
|
||||
else:
|
||||
held_display = f"{held_minutes}m"
|
||||
except (ValueError, TypeError):
|
||||
held_display = "--"
|
||||
|
||||
positions.append(
|
||||
{
|
||||
"stock_code": row["stock_code"],
|
||||
"market": row["market"],
|
||||
"exchange_code": row["exchange_code"],
|
||||
"entry_price": row["entry_price"],
|
||||
"quantity": row["quantity"],
|
||||
"entry_time": entry_time_str,
|
||||
"held": held_display,
|
||||
"decision_id": row["decision_id"],
|
||||
}
|
||||
)
|
||||
|
||||
return {"count": len(positions), "positions": positions}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def _connect(db_path: str) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA busy_timeout=8000")
|
||||
return conn
|
||||
|
||||
|
||||
def _row_to_performance(row: sqlite3.Row) -> dict[str, Any]:
|
||||
wins = int(row["wins"] or 0)
|
||||
losses = int(row["losses"] or 0)
|
||||
total = int(row["total_trades"] or 0)
|
||||
win_rate = round((wins / (wins + losses) * 100), 2) if (wins + losses) > 0 else 0.0
|
||||
return {
|
||||
"market": row["market"],
|
||||
"total_trades": total,
|
||||
"wins": wins,
|
||||
"losses": losses,
|
||||
"win_rate": win_rate,
|
||||
"total_pnl": round(float(row["total_pnl"] or 0.0), 2),
|
||||
"avg_confidence": round(float(row["avg_confidence"] or 0.0), 2),
|
||||
}
|
||||
|
||||
|
||||
def _performance_from_rows(rows: list[sqlite3.Row]) -> dict[str, Any]:
|
||||
total_trades = 0
|
||||
wins = 0
|
||||
losses = 0
|
||||
total_pnl = 0.0
|
||||
confidence_weighted = 0.0
|
||||
for row in rows:
|
||||
market_total = int(row["total_trades"] or 0)
|
||||
market_conf = float(row["avg_confidence"] or 0.0)
|
||||
total_trades += market_total
|
||||
wins += int(row["wins"] or 0)
|
||||
losses += int(row["losses"] or 0)
|
||||
total_pnl += float(row["total_pnl"] or 0.0)
|
||||
confidence_weighted += market_total * market_conf
|
||||
win_rate = round((wins / (wins + losses) * 100), 2) if (wins + losses) > 0 else 0.0
|
||||
avg_confidence = round(confidence_weighted / total_trades, 2) if total_trades > 0 else 0.0
|
||||
return {
|
||||
"market": "all",
|
||||
"total_trades": total_trades,
|
||||
"wins": wins,
|
||||
"losses": losses,
|
||||
"win_rate": win_rate,
|
||||
"total_pnl": round(total_pnl, 2),
|
||||
"avg_confidence": avg_confidence,
|
||||
}
|
||||
|
||||
|
||||
def _empty_performance(market: str) -> dict[str, Any]:
|
||||
return {
|
||||
"market": market,
|
||||
"total_trades": 0,
|
||||
"wins": 0,
|
||||
"losses": 0,
|
||||
"win_rate": 0.0,
|
||||
"total_pnl": 0.0,
|
||||
"avg_confidence": 0.0,
|
||||
}
|
||||
798
src/dashboard/static/index.html
Normal file
798
src/dashboard/static/index.html
Normal file
@@ -0,0 +1,798 @@
|
||||
<!doctype html>
|
||||
<html lang="ko">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>The Ouroboros Dashboard</title>
|
||||
<script src="https://cdn.jsdelivr.net/npm/chart.js@4.4.0/dist/chart.umd.min.js"></script>
|
||||
<style>
|
||||
:root {
|
||||
--bg: #0b1724;
|
||||
--panel: #12263a;
|
||||
--fg: #e6eef7;
|
||||
--muted: #9fb3c8;
|
||||
--accent: #3cb371;
|
||||
--red: #e05555;
|
||||
--warn: #e8a040;
|
||||
--border: #28455f;
|
||||
}
|
||||
* { box-sizing: border-box; margin: 0; padding: 0; }
|
||||
body {
|
||||
font-family: ui-monospace, SFMono-Regular, Menlo, monospace;
|
||||
background: radial-gradient(circle at top left, #173b58, var(--bg));
|
||||
color: var(--fg);
|
||||
min-height: 100vh;
|
||||
font-size: 13px;
|
||||
}
|
||||
.wrap { max-width: 1100px; margin: 0 auto; padding: 20px 16px; }
|
||||
|
||||
/* Header */
|
||||
header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
margin-bottom: 20px;
|
||||
padding-bottom: 12px;
|
||||
border-bottom: 1px solid var(--border);
|
||||
}
|
||||
header h1 { font-size: 18px; color: var(--accent); letter-spacing: 0.5px; }
|
||||
.header-right { display: flex; align-items: center; gap: 12px; color: var(--muted); font-size: 12px; }
|
||||
.refresh-btn {
|
||||
background: none; border: 1px solid var(--border); color: var(--muted);
|
||||
padding: 4px 10px; border-radius: 6px; cursor: pointer; font-family: inherit;
|
||||
font-size: 12px; transition: border-color 0.2s;
|
||||
}
|
||||
.refresh-btn:hover { border-color: var(--accent); color: var(--accent); }
|
||||
.mode-badge {
|
||||
padding: 3px 10px; border-radius: 5px; font-size: 12px; font-weight: 700;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
.mode-badge.live {
|
||||
background: rgba(224, 85, 85, 0.15); color: var(--red);
|
||||
border: 1px solid rgba(224, 85, 85, 0.4);
|
||||
animation: pulse-warn 2s ease-in-out infinite;
|
||||
}
|
||||
.mode-badge.paper {
|
||||
background: rgba(232, 160, 64, 0.15); color: var(--warn);
|
||||
border: 1px solid rgba(232, 160, 64, 0.4);
|
||||
}
|
||||
|
||||
/* CB Gauge */
|
||||
.cb-gauge-wrap {
|
||||
display: flex; align-items: center; gap: 8px;
|
||||
font-size: 11px; color: var(--muted);
|
||||
}
|
||||
.cb-dot {
|
||||
width: 8px; height: 8px; border-radius: 50%; flex-shrink: 0;
|
||||
}
|
||||
.cb-dot.ok { background: var(--accent); }
|
||||
.cb-dot.warning { background: var(--warn); animation: pulse-warn 1.2s ease-in-out infinite; }
|
||||
.cb-dot.tripped { background: var(--red); animation: pulse-warn 0.6s ease-in-out infinite; }
|
||||
.cb-dot.unknown { background: var(--border); }
|
||||
@keyframes pulse-warn {
|
||||
0%, 100% { opacity: 1; }
|
||||
50% { opacity: 0.35; }
|
||||
}
|
||||
.cb-bar-wrap { width: 64px; height: 5px; background: rgba(255,255,255,0.08); border-radius: 3px; overflow: hidden; }
|
||||
.cb-bar-fill { height: 100%; border-radius: 3px; transition: width 0.4s, background 0.4s; }
|
||||
|
||||
/* Summary cards */
|
||||
.cards { display: grid; grid-template-columns: repeat(4, 1fr); gap: 12px; margin-bottom: 20px; }
|
||||
@media (max-width: 700px) { .cards { grid-template-columns: repeat(2, 1fr); } }
|
||||
.card {
|
||||
background: var(--panel);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 10px;
|
||||
padding: 16px;
|
||||
}
|
||||
.card-label { color: var(--muted); font-size: 11px; margin-bottom: 6px; text-transform: uppercase; letter-spacing: 0.5px; }
|
||||
.card-value { font-size: 22px; font-weight: 700; }
|
||||
.card-sub { color: var(--muted); font-size: 11px; margin-top: 4px; }
|
||||
.positive { color: var(--accent); }
|
||||
.negative { color: var(--red); }
|
||||
.neutral { color: var(--fg); }
|
||||
|
||||
/* Chart panel */
|
||||
.chart-panel {
|
||||
background: var(--panel);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 10px;
|
||||
padding: 16px;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
.panel-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
.panel-title { font-size: 13px; color: var(--muted); font-weight: 600; }
|
||||
.chart-container { position: relative; height: 180px; }
|
||||
.chart-error { color: var(--muted); text-align: center; padding: 40px 0; font-size: 12px; }
|
||||
|
||||
/* Days selector */
|
||||
.days-selector { display: flex; gap: 4px; }
|
||||
.day-btn {
|
||||
background: none; border: 1px solid var(--border); color: var(--muted);
|
||||
padding: 3px 8px; border-radius: 4px; cursor: pointer; font-family: inherit; font-size: 11px;
|
||||
}
|
||||
.day-btn.active { border-color: var(--accent); color: var(--accent); background: rgba(60, 179, 113, 0.08); }
|
||||
|
||||
/* Decisions panel */
|
||||
.decisions-panel {
|
||||
background: var(--panel);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 10px;
|
||||
padding: 16px;
|
||||
}
|
||||
.market-tabs { display: flex; gap: 6px; flex-wrap: wrap; }
|
||||
.tab-btn {
|
||||
background: none; border: 1px solid var(--border); color: var(--muted);
|
||||
padding: 4px 10px; border-radius: 6px; cursor: pointer; font-family: inherit; font-size: 11px;
|
||||
}
|
||||
.tab-btn.active { border-color: var(--accent); color: var(--accent); background: rgba(60, 179, 113, 0.08); }
|
||||
.decisions-table { width: 100%; border-collapse: collapse; margin-top: 14px; }
|
||||
.decisions-table th {
|
||||
text-align: left; color: var(--muted); font-size: 11px; font-weight: 600;
|
||||
padding: 6px 8px; border-bottom: 1px solid var(--border); white-space: nowrap;
|
||||
}
|
||||
.decisions-table td {
|
||||
padding: 8px 8px; border-bottom: 1px solid rgba(40, 69, 95, 0.5);
|
||||
vertical-align: middle; white-space: nowrap;
|
||||
}
|
||||
.decisions-table tr:last-child td { border-bottom: none; }
|
||||
.decisions-table tr:hover td { background: rgba(255,255,255,0.02); }
|
||||
.badge {
|
||||
display: inline-block; padding: 2px 7px; border-radius: 4px;
|
||||
font-size: 11px; font-weight: 700; letter-spacing: 0.5px;
|
||||
}
|
||||
.badge-buy { background: rgba(60, 179, 113, 0.15); color: var(--accent); }
|
||||
.badge-sell { background: rgba(224, 85, 85, 0.15); color: var(--red); }
|
||||
.badge-hold { background: rgba(159, 179, 200, 0.12); color: var(--muted); }
|
||||
.conf-bar-wrap { display: flex; align-items: center; gap: 6px; min-width: 90px; }
|
||||
.conf-bar { flex: 1; height: 6px; background: rgba(255,255,255,0.08); border-radius: 3px; overflow: hidden; }
|
||||
.conf-fill { height: 100%; border-radius: 3px; background: var(--accent); transition: width 0.3s; }
|
||||
.conf-val { color: var(--muted); font-size: 11px; min-width: 26px; text-align: right; }
|
||||
.rationale-cell { max-width: 200px; overflow: hidden; text-overflow: ellipsis; color: var(--muted); }
|
||||
.empty-row td { text-align: center; color: var(--muted); padding: 24px; }
|
||||
|
||||
/* Positions panel */
|
||||
.positions-panel {
|
||||
background: var(--panel);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 10px;
|
||||
padding: 16px;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
.positions-table { width: 100%; border-collapse: collapse; margin-top: 14px; }
|
||||
.positions-table th {
|
||||
text-align: left; color: var(--muted); font-size: 11px; font-weight: 600;
|
||||
padding: 6px 8px; border-bottom: 1px solid var(--border); white-space: nowrap;
|
||||
}
|
||||
.positions-table td {
|
||||
padding: 8px 8px; border-bottom: 1px solid rgba(40, 69, 95, 0.5);
|
||||
vertical-align: middle; white-space: nowrap;
|
||||
}
|
||||
.positions-table tr:last-child td { border-bottom: none; }
|
||||
.positions-table tr:hover td { background: rgba(255,255,255,0.02); }
|
||||
.pos-empty { color: var(--muted); text-align: center; padding: 20px 0; font-size: 12px; }
|
||||
.pos-count {
|
||||
display: inline-block; background: rgba(60, 179, 113, 0.12);
|
||||
color: var(--accent); font-size: 11px; font-weight: 700;
|
||||
padding: 2px 8px; border-radius: 10px; margin-left: 8px;
|
||||
}
|
||||
|
||||
/* Spinner */
|
||||
.spinner { display: inline-block; width: 12px; height: 12px; border: 2px solid var(--border); border-top-color: var(--accent); border-radius: 50%; animation: spin 0.8s linear infinite; }
|
||||
@keyframes spin { to { transform: rotate(360deg); } }
|
||||
|
||||
/* Generic panel */
|
||||
.panel {
|
||||
background: var(--panel);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 10px;
|
||||
padding: 16px;
|
||||
margin-top: 20px;
|
||||
}
|
||||
|
||||
/* Playbook panel - details/summary accordion */
|
||||
.playbook-panel details { border: 1px solid var(--border); border-radius: 4px; margin-bottom: 6px; }
|
||||
.playbook-panel summary { padding: 8px 12px; cursor: pointer; font-weight: 600; background: var(--bg); color: var(--fg); }
|
||||
.playbook-panel summary:hover { color: var(--accent); }
|
||||
.playbook-panel pre { margin: 0; padding: 12px; background: var(--bg); overflow-x: auto;
|
||||
font-size: 11px; color: #a0c4ff; white-space: pre-wrap; }
|
||||
|
||||
/* Scorecard KPI card grid */
|
||||
.scorecard-grid { display: grid; grid-template-columns: repeat(auto-fill, minmax(140px, 1fr)); gap: 10px; }
|
||||
.kpi-card { background: var(--bg); border: 1px solid var(--border); border-radius: 6px; padding: 12px; text-align: center; }
|
||||
.kpi-card .kpi-label { font-size: 11px; color: var(--muted); margin-bottom: 4px; }
|
||||
.kpi-card .kpi-value { font-size: 20px; font-weight: 700; color: var(--fg); }
|
||||
|
||||
/* Scenarios table */
|
||||
.scenarios-table { width: 100%; border-collapse: collapse; font-size: 13px; }
|
||||
.scenarios-table th { background: var(--bg); padding: 8px; text-align: left; border-bottom: 1px solid var(--border);
|
||||
color: var(--muted); font-size: 11px; font-weight: 600; white-space: nowrap; }
|
||||
.scenarios-table td { padding: 7px 8px; border-bottom: 1px solid rgba(40,69,95,0.5); }
|
||||
.scenarios-table tr:hover td { background: rgba(255,255,255,0.02); }
|
||||
|
||||
/* Context table */
|
||||
.context-table { width: 100%; border-collapse: collapse; font-size: 12px; }
|
||||
.context-table th { background: var(--bg); padding: 8px; text-align: left; border-bottom: 1px solid var(--border);
|
||||
color: var(--muted); font-size: 11px; font-weight: 600; white-space: nowrap; }
|
||||
.context-table td { padding: 6px 8px; border-bottom: 1px solid rgba(40,69,95,0.5); vertical-align: top; }
|
||||
.context-value { max-height: 60px; overflow-y: auto; color: #a0c4ff; word-break: break-all; }
|
||||
|
||||
/* Common panel select controls */
|
||||
.panel-controls { display: flex; gap: 8px; align-items: center; flex-wrap: wrap; }
|
||||
.panel-controls select, .panel-controls input[type="number"] {
|
||||
background: var(--bg); color: var(--fg); border: 1px solid var(--border);
|
||||
border-radius: 4px; padding: 4px 8px; font-size: 13px; font-family: inherit;
|
||||
}
|
||||
.panel-date { color: var(--muted); font-size: 12px; }
|
||||
.empty-msg { color: var(--muted); text-align: center; padding: 20px 0; font-size: 12px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="wrap">
|
||||
<!-- Header -->
|
||||
<header>
|
||||
<h1>🐍 The Ouroboros</h1>
|
||||
<div class="header-right">
|
||||
<span class="mode-badge" id="mode-badge">--</span>
|
||||
<div class="cb-gauge-wrap" id="cb-gauge" title="Circuit Breaker">
|
||||
<span class="cb-dot unknown" id="cb-dot"></span>
|
||||
<span id="cb-label">CB --</span>
|
||||
<div class="cb-bar-wrap">
|
||||
<div class="cb-bar-fill" id="cb-bar" style="width:0%;background:var(--accent)"></div>
|
||||
</div>
|
||||
</div>
|
||||
<span id="last-updated">--</span>
|
||||
<button class="refresh-btn" onclick="refreshAll()">↺ 새로고침</button>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<!-- Summary cards -->
|
||||
<div class="cards">
|
||||
<div class="card">
|
||||
<div class="card-label">오늘 거래</div>
|
||||
<div class="card-value neutral" id="card-trades">--</div>
|
||||
<div class="card-sub" id="card-trades-sub">거래 건수</div>
|
||||
</div>
|
||||
<div class="card">
|
||||
<div class="card-label">오늘 P&L</div>
|
||||
<div class="card-value" id="card-pnl">--</div>
|
||||
<div class="card-sub" id="card-pnl-sub">실현 손익</div>
|
||||
</div>
|
||||
<div class="card">
|
||||
<div class="card-label">승률</div>
|
||||
<div class="card-value neutral" id="card-winrate">--</div>
|
||||
<div class="card-sub">전체 누적</div>
|
||||
</div>
|
||||
<div class="card">
|
||||
<div class="card-label">누적 거래</div>
|
||||
<div class="card-value neutral" id="card-total">--</div>
|
||||
<div class="card-sub">전체 기간</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Open Positions -->
|
||||
<div class="positions-panel">
|
||||
<div class="panel-header">
|
||||
<span class="panel-title">
|
||||
현재 보유 포지션
|
||||
<span class="pos-count" id="positions-count">0</span>
|
||||
</span>
|
||||
</div>
|
||||
<table class="positions-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>종목</th>
|
||||
<th>시장</th>
|
||||
<th>수량</th>
|
||||
<th>진입가</th>
|
||||
<th>보유 시간</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody id="positions-body">
|
||||
<tr><td colspan="5" class="pos-empty"><span class="spinner"></span></td></tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<!-- P&L Chart -->
|
||||
<div class="chart-panel">
|
||||
<div class="panel-header">
|
||||
<span class="panel-title">P&L 추이</span>
|
||||
<div class="days-selector">
|
||||
<button class="day-btn active" data-days="7" onclick="selectDays(this)">7일</button>
|
||||
<button class="day-btn" data-days="30" onclick="selectDays(this)">30일</button>
|
||||
<button class="day-btn" data-days="90" onclick="selectDays(this)">90일</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="chart-container">
|
||||
<canvas id="pnl-chart"></canvas>
|
||||
<div class="chart-error" id="chart-error" style="display:none">데이터 없음</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Decisions log -->
|
||||
<div class="decisions-panel">
|
||||
<div class="panel-header">
|
||||
<span class="panel-title">최근 결정 로그</span>
|
||||
<div class="market-tabs" id="market-tabs">
|
||||
<button class="tab-btn active" data-market="KR" onclick="selectMarket(this)">KR</button>
|
||||
<button class="tab-btn" data-market="US_NASDAQ" onclick="selectMarket(this)">US_NASDAQ</button>
|
||||
<button class="tab-btn" data-market="US_NYSE" onclick="selectMarket(this)">US_NYSE</button>
|
||||
<button class="tab-btn" data-market="JP" onclick="selectMarket(this)">JP</button>
|
||||
<button class="tab-btn" data-market="HK" onclick="selectMarket(this)">HK</button>
|
||||
</div>
|
||||
</div>
|
||||
<table class="decisions-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>시각</th>
|
||||
<th>종목</th>
|
||||
<th>액션</th>
|
||||
<th>신뢰도</th>
|
||||
<th>사유</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody id="decisions-body">
|
||||
<tr class="empty-row"><td colspan="5"><span class="spinner"></span></td></tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<!-- playbook panel -->
|
||||
<div class="panel playbook-panel">
|
||||
<div class="panel-header">
|
||||
<span class="panel-title">📋 프리마켓 플레이북</span>
|
||||
<div class="panel-controls">
|
||||
<select id="pb-market-select" onchange="fetchPlaybook()">
|
||||
<option value="KR">KR</option>
|
||||
<option value="US_NASDAQ">US_NASDAQ</option>
|
||||
<option value="US_NYSE">US_NYSE</option>
|
||||
</select>
|
||||
<span id="pb-date" class="panel-date"></span>
|
||||
</div>
|
||||
</div>
|
||||
<div id="playbook-content"><p class="empty-msg">데이터 없음</p></div>
|
||||
</div>
|
||||
|
||||
<!-- scorecard panel -->
|
||||
<div class="panel">
|
||||
<div class="panel-header">
|
||||
<span class="panel-title">📊 일간 스코어카드</span>
|
||||
<div class="panel-controls">
|
||||
<select id="sc-market-select" onchange="fetchScorecard()">
|
||||
<option value="KR">KR</option>
|
||||
<option value="US_NASDAQ">US_NASDAQ</option>
|
||||
</select>
|
||||
<span id="sc-date" class="panel-date"></span>
|
||||
</div>
|
||||
</div>
|
||||
<div id="scorecard-grid" class="scorecard-grid"><p class="empty-msg">데이터 없음</p></div>
|
||||
</div>
|
||||
|
||||
<!-- scenarios panel -->
|
||||
<div class="panel">
|
||||
<div class="panel-header">
|
||||
<span class="panel-title">🎯 활성 시나리오 매칭</span>
|
||||
<div class="panel-controls">
|
||||
<select id="scen-market-select" onchange="fetchScenarios()">
|
||||
<option value="KR">KR</option>
|
||||
<option value="US_NASDAQ">US_NASDAQ</option>
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
<div id="scenarios-content"><p class="empty-msg">데이터 없음</p></div>
|
||||
</div>
|
||||
|
||||
<!-- context layer panel -->
|
||||
<div class="panel">
|
||||
<div class="panel-header">
|
||||
<span class="panel-title">🧠 컨텍스트 트리</span>
|
||||
<div class="panel-controls">
|
||||
<select id="ctx-layer-select" onchange="fetchContext()">
|
||||
<option value="L7_REALTIME">L7_REALTIME</option>
|
||||
<option value="L6_DAILY">L6_DAILY</option>
|
||||
<option value="L5_WEEKLY">L5_WEEKLY</option>
|
||||
<option value="L4_MONTHLY">L4_MONTHLY</option>
|
||||
<option value="L3_QUARTERLY">L3_QUARTERLY</option>
|
||||
<option value="L2_YEARLY">L2_YEARLY</option>
|
||||
<option value="L1_LIFETIME">L1_LIFETIME</option>
|
||||
</select>
|
||||
<input id="ctx-limit" type="number" value="20" min="1" max="200"
|
||||
style="width:60px;" onchange="fetchContext()">
|
||||
</div>
|
||||
</div>
|
||||
<div id="context-content"><p class="empty-msg">데이터 없음</p></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let pnlChart = null;
|
||||
let currentDays = 7;
|
||||
let currentMarket = 'KR';
|
||||
|
||||
function fmt(dt) {
|
||||
try {
|
||||
const d = new Date(dt);
|
||||
return d.toLocaleTimeString('ko-KR', { hour: '2-digit', minute: '2-digit', hour12: false });
|
||||
} catch { return dt || '--'; }
|
||||
}
|
||||
|
||||
function fmtPnl(v) {
|
||||
if (v === null || v === undefined) return '--';
|
||||
const n = parseFloat(v);
|
||||
const cls = n > 0 ? 'positive' : n < 0 ? 'negative' : 'neutral';
|
||||
const sign = n > 0 ? '+' : '';
|
||||
return `<span class="${cls}">${sign}${n.toFixed(2)}</span>`;
|
||||
}
|
||||
|
||||
function badge(action) {
|
||||
const a = (action || '').toUpperCase();
|
||||
const cls = a === 'BUY' ? 'badge-buy' : a === 'SELL' ? 'badge-sell' : 'badge-hold';
|
||||
return `<span class="badge ${cls}">${a}</span>`;
|
||||
}
|
||||
|
||||
function confBar(conf) {
|
||||
const pct = Math.min(Math.max(conf || 0, 0), 100);
|
||||
return `<div class="conf-bar-wrap">
|
||||
<div class="conf-bar"><div class="conf-fill" style="width:${pct}%"></div></div>
|
||||
<span class="conf-val">${pct}</span>
|
||||
</div>`;
|
||||
}
|
||||
|
||||
function fmtPrice(v, market) {
|
||||
if (v === null || v === undefined) return '--';
|
||||
const n = parseFloat(v);
|
||||
const sym = market === 'KR' ? '₩' : market === 'JP' ? '¥' : market === 'HK' ? 'HK$' : '$';
|
||||
return sym + n.toLocaleString('en-US', { minimumFractionDigits: 0, maximumFractionDigits: 4 });
|
||||
}
|
||||
|
||||
async function fetchPositions() {
|
||||
const tbody = document.getElementById('positions-body');
|
||||
const countEl = document.getElementById('positions-count');
|
||||
try {
|
||||
const r = await fetch('/api/positions');
|
||||
if (!r.ok) throw new Error('fetch failed');
|
||||
const d = await r.json();
|
||||
countEl.textContent = d.count ?? 0;
|
||||
if (!d.positions || d.positions.length === 0) {
|
||||
tbody.innerHTML = '<tr><td colspan="5" class="pos-empty">현재 보유 중인 포지션 없음</td></tr>';
|
||||
return;
|
||||
}
|
||||
tbody.innerHTML = d.positions.map(p => `
|
||||
<tr>
|
||||
<td><strong>${p.stock_code || '--'}</strong></td>
|
||||
<td><span style="color:var(--muted);font-size:11px">${p.market || '--'}</span></td>
|
||||
<td>${p.quantity ?? '--'}</td>
|
||||
<td>${fmtPrice(p.entry_price, p.market)}</td>
|
||||
<td style="color:var(--muted);font-size:11px">${p.held || '--'}</td>
|
||||
</tr>
|
||||
`).join('');
|
||||
} catch {
|
||||
tbody.innerHTML = '<tr><td colspan="5" class="pos-empty">데이터 로드 실패</td></tr>';
|
||||
}
|
||||
}
|
||||
|
||||
function renderCbGauge(cb) {
|
||||
if (!cb) return;
|
||||
const dot = document.getElementById('cb-dot');
|
||||
const label = document.getElementById('cb-label');
|
||||
const bar = document.getElementById('cb-bar');
|
||||
|
||||
const status = cb.status || 'unknown';
|
||||
const threshold = cb.threshold_pct ?? -3.0;
|
||||
const current = cb.current_pnl_pct;
|
||||
|
||||
// dot color
|
||||
dot.className = `cb-dot ${status}`;
|
||||
|
||||
// label
|
||||
if (current !== null && current !== undefined) {
|
||||
const sign = current > 0 ? '+' : '';
|
||||
label.textContent = `CB ${sign}${current.toFixed(2)}%`;
|
||||
} else {
|
||||
label.textContent = 'CB --';
|
||||
}
|
||||
|
||||
// bar: fill = how much of the threshold has been consumed (0%=safe, 100%=tripped)
|
||||
const colorMap = { ok: 'var(--accent)', warning: 'var(--warn)', tripped: 'var(--red)', unknown: 'var(--border)' };
|
||||
bar.style.background = colorMap[status] || 'var(--border)';
|
||||
if (current !== null && current !== undefined && threshold < 0) {
|
||||
const fillPct = Math.min(Math.max((current / threshold) * 100, 0), 100);
|
||||
bar.style.width = `${fillPct}%`;
|
||||
} else {
|
||||
bar.style.width = '0%';
|
||||
}
|
||||
}
|
||||
|
||||
async function fetchStatus() {
|
||||
try {
|
||||
const r = await fetch('/api/status');
|
||||
if (!r.ok) return;
|
||||
const d = await r.json();
|
||||
const t = d.totals || {};
|
||||
document.getElementById('card-trades').textContent = t.trade_count ?? '--';
|
||||
const pnlEl = document.getElementById('card-pnl');
|
||||
const pnlV = t.total_pnl;
|
||||
if (pnlV !== undefined) {
|
||||
const n = parseFloat(pnlV);
|
||||
const sign = n > 0 ? '+' : '';
|
||||
pnlEl.textContent = `${sign}${n.toFixed(2)}`;
|
||||
pnlEl.className = `card-value ${n > 0 ? 'positive' : n < 0 ? 'negative' : 'neutral'}`;
|
||||
}
|
||||
document.getElementById('card-pnl-sub').textContent = `결정 ${t.decision_count ?? 0}건`;
|
||||
renderCbGauge(d.circuit_breaker);
|
||||
renderModeBadge(d.mode);
|
||||
} catch {}
|
||||
}
|
||||
|
||||
function renderModeBadge(mode) {
|
||||
const el = document.getElementById('mode-badge');
|
||||
if (!el) return;
|
||||
if (mode === 'live') {
|
||||
el.textContent = '🔴 실전투자';
|
||||
el.className = 'mode-badge live';
|
||||
} else {
|
||||
el.textContent = '🟡 모의투자';
|
||||
el.className = 'mode-badge paper';
|
||||
}
|
||||
}
|
||||
|
||||
async function fetchPerformance() {
|
||||
try {
|
||||
const r = await fetch('/api/performance?market=all');
|
||||
if (!r.ok) return;
|
||||
const d = await r.json();
|
||||
const c = d.combined || {};
|
||||
document.getElementById('card-winrate').textContent = c.win_rate !== undefined ? `${c.win_rate}%` : '--';
|
||||
document.getElementById('card-total').textContent = c.total_trades ?? '--';
|
||||
} catch {}
|
||||
}
|
||||
|
||||
async function fetchPnlHistory(days) {
|
||||
try {
|
||||
const r = await fetch(`/api/pnl/history?days=${days}`);
|
||||
if (!r.ok) throw new Error('fetch failed');
|
||||
const d = await r.json();
|
||||
renderChart(d);
|
||||
} catch {
|
||||
document.getElementById('chart-error').style.display = 'block';
|
||||
}
|
||||
}
|
||||
|
||||
function renderChart(data) {
|
||||
const errEl = document.getElementById('chart-error');
|
||||
if (!data.labels || data.labels.length === 0) {
|
||||
errEl.style.display = 'block';
|
||||
return;
|
||||
}
|
||||
errEl.style.display = 'none';
|
||||
|
||||
const colors = data.pnl.map(v => v >= 0 ? 'rgba(60,179,113,0.75)' : 'rgba(224,85,85,0.75)');
|
||||
const borderColors = data.pnl.map(v => v >= 0 ? '#3cb371' : '#e05555');
|
||||
|
||||
if (pnlChart) { pnlChart.destroy(); pnlChart = null; }
|
||||
const ctx = document.getElementById('pnl-chart').getContext('2d');
|
||||
pnlChart = new Chart(ctx, {
|
||||
type: 'bar',
|
||||
data: {
|
||||
labels: data.labels,
|
||||
datasets: [{
|
||||
label: 'Daily P&L',
|
||||
data: data.pnl,
|
||||
backgroundColor: colors,
|
||||
borderColor: borderColors,
|
||||
borderWidth: 1,
|
||||
borderRadius: 3,
|
||||
}]
|
||||
},
|
||||
options: {
|
||||
responsive: true,
|
||||
maintainAspectRatio: false,
|
||||
plugins: {
|
||||
legend: { display: false },
|
||||
tooltip: {
|
||||
callbacks: {
|
||||
label: ctx => {
|
||||
const v = ctx.parsed.y;
|
||||
const sign = v >= 0 ? '+' : '';
|
||||
const trades = data.trades[ctx.dataIndex];
|
||||
return [`P&L: ${sign}${v.toFixed(2)}`, `거래: ${trades}건`];
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
scales: {
|
||||
x: {
|
||||
ticks: { color: '#9fb3c8', font: { size: 10 }, maxRotation: 0 },
|
||||
grid: { color: 'rgba(40,69,95,0.4)' }
|
||||
},
|
||||
y: {
|
||||
ticks: { color: '#9fb3c8', font: { size: 10 } },
|
||||
grid: { color: 'rgba(40,69,95,0.4)' }
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async function fetchDecisions(market) {
|
||||
const tbody = document.getElementById('decisions-body');
|
||||
tbody.innerHTML = '<tr class="empty-row"><td colspan="5"><span class="spinner"></span></td></tr>';
|
||||
try {
|
||||
const r = await fetch(`/api/decisions?market=${market}&limit=50`);
|
||||
if (!r.ok) throw new Error('fetch failed');
|
||||
const d = await r.json();
|
||||
if (!d.decisions || d.decisions.length === 0) {
|
||||
tbody.innerHTML = '<tr class="empty-row"><td colspan="5">결정 로그 없음</td></tr>';
|
||||
return;
|
||||
}
|
||||
tbody.innerHTML = d.decisions.map(dec => `
|
||||
<tr>
|
||||
<td>${fmt(dec.timestamp)}</td>
|
||||
<td>${dec.stock_code || '--'}</td>
|
||||
<td>${badge(dec.action)}</td>
|
||||
<td>${confBar(dec.confidence)}</td>
|
||||
<td class="rationale-cell" title="${(dec.rationale || '').replace(/"/g, '"')}">${dec.rationale || '--'}</td>
|
||||
</tr>
|
||||
`).join('');
|
||||
} catch {
|
||||
tbody.innerHTML = '<tr class="empty-row"><td colspan="5">데이터 로드 실패</td></tr>';
|
||||
}
|
||||
}
|
||||
|
||||
function selectDays(btn) {
|
||||
document.querySelectorAll('.day-btn').forEach(b => b.classList.remove('active'));
|
||||
btn.classList.add('active');
|
||||
currentDays = parseInt(btn.dataset.days, 10);
|
||||
fetchPnlHistory(currentDays);
|
||||
}
|
||||
|
||||
function selectMarket(btn) {
|
||||
document.querySelectorAll('.tab-btn').forEach(b => b.classList.remove('active'));
|
||||
btn.classList.add('active');
|
||||
currentMarket = btn.dataset.market;
|
||||
fetchDecisions(currentMarket);
|
||||
}
|
||||
|
||||
function todayStr() {
|
||||
return new Date().toISOString().slice(0, 10);
|
||||
}
|
||||
|
||||
function esc(s) {
|
||||
return String(s ?? '').replace(/&/g, '&').replace(/</g, '<').replace(/>/g, '>').replace(/"/g, '"');
|
||||
}
|
||||
|
||||
async function fetchJSON(url) {
|
||||
const r = await fetch(url);
|
||||
if (!r.ok) throw new Error(`HTTP ${r.status}`);
|
||||
return r.json();
|
||||
}
|
||||
|
||||
async function fetchPlaybook() {
|
||||
const market = document.getElementById('pb-market-select').value;
|
||||
const date = todayStr();
|
||||
document.getElementById('pb-date').textContent = date;
|
||||
const el = document.getElementById('playbook-content');
|
||||
try {
|
||||
const data = await fetchJSON(`/api/playbook/${date}?market=${market}`);
|
||||
const stocks = data.stock_playbooks ?? [];
|
||||
if (stocks.length === 0) {
|
||||
el.innerHTML = '<p class="empty-msg">오늘 플레이북 없음</p>';
|
||||
return;
|
||||
}
|
||||
el.innerHTML = stocks.map(sp =>
|
||||
`<details><summary>${esc(sp.stock_code ?? '?')} — ${esc(sp.signal ?? '')}</summary>` +
|
||||
`<pre>${esc(JSON.stringify(sp, null, 2))}</pre></details>`
|
||||
).join('');
|
||||
} catch {
|
||||
el.innerHTML = '<p class="empty-msg">플레이북 없음 (오늘 미생성 또는 API 오류)</p>';
|
||||
}
|
||||
}
|
||||
|
||||
async function fetchScorecard() {
|
||||
const market = document.getElementById('sc-market-select').value;
|
||||
const date = todayStr();
|
||||
document.getElementById('sc-date').textContent = date;
|
||||
const el = document.getElementById('scorecard-grid');
|
||||
try {
|
||||
const data = await fetchJSON(`/api/scorecard/${date}?market=${market}`);
|
||||
const sc = data.scorecard ?? {};
|
||||
const entries = Object.entries(sc);
|
||||
if (entries.length === 0) {
|
||||
el.innerHTML = '<p class="empty-msg">스코어카드 없음</p>';
|
||||
return;
|
||||
}
|
||||
el.className = 'scorecard-grid';
|
||||
el.innerHTML = entries.map(([k, v]) => `
|
||||
<div class="kpi-card">
|
||||
<div class="kpi-label">${esc(k)}</div>
|
||||
<div class="kpi-value">${typeof v === 'number' ? v.toFixed(2) : esc(String(v))}</div>
|
||||
</div>`).join('');
|
||||
} catch {
|
||||
el.innerHTML = '<p class="empty-msg">스코어카드 없음 (오늘 미생성 또는 API 오류)</p>';
|
||||
}
|
||||
}
|
||||
|
||||
async function fetchScenarios() {
|
||||
const market = document.getElementById('scen-market-select').value;
|
||||
const date = todayStr();
|
||||
const el = document.getElementById('scenarios-content');
|
||||
try {
|
||||
const data = await fetchJSON(`/api/scenarios/active?market=${market}&date_str=${date}&limit=50`);
|
||||
const matches = data.matches ?? [];
|
||||
if (matches.length === 0) {
|
||||
el.innerHTML = '<p class="empty-msg">활성 시나리오 없음</p>';
|
||||
return;
|
||||
}
|
||||
el.innerHTML = `<table class="scenarios-table">
|
||||
<thead><tr><th>종목</th><th>신호</th><th>신뢰도</th><th>매칭 조건</th></tr></thead>
|
||||
<tbody>${matches.map(m => `
|
||||
<tr>
|
||||
<td>${esc(m.stock_code)}</td>
|
||||
<td>${esc(m.signal ?? '-')}</td>
|
||||
<td>${esc(m.confidence ?? '-')}</td>
|
||||
<td><code style="font-size:11px">${esc(JSON.stringify(m.scenario_match ?? {}))}</code></td>
|
||||
</tr>`).join('')}
|
||||
</tbody></table>`;
|
||||
} catch {
|
||||
el.innerHTML = '<p class="empty-msg">데이터 없음</p>';
|
||||
}
|
||||
}
|
||||
|
||||
async function fetchContext() {
|
||||
const layer = document.getElementById('ctx-layer-select').value;
|
||||
const limit = Math.min(Math.max(parseInt(document.getElementById('ctx-limit').value, 10) || 20, 1), 200);
|
||||
const el = document.getElementById('context-content');
|
||||
try {
|
||||
const data = await fetchJSON(`/api/context/${layer}?limit=${limit}`);
|
||||
const entries = data.entries ?? [];
|
||||
if (entries.length === 0) {
|
||||
el.innerHTML = '<p class="empty-msg">컨텍스트 없음</p>';
|
||||
return;
|
||||
}
|
||||
el.innerHTML = `<table class="context-table">
|
||||
<thead><tr><th>timeframe</th><th>key</th><th>value</th><th>updated</th></tr></thead>
|
||||
<tbody>${entries.map(e => `
|
||||
<tr>
|
||||
<td>${esc(e.timeframe)}</td>
|
||||
<td>${esc(e.key)}</td>
|
||||
<td><div class="context-value">${esc(JSON.stringify(e.value ?? e.raw_value))}</div></td>
|
||||
<td style="font-size:11px;color:var(--muted)">${esc((e.updated_at ?? '').slice(0, 16))}</td>
|
||||
</tr>`).join('')}
|
||||
</tbody></table>`;
|
||||
} catch {
|
||||
el.innerHTML = '<p class="empty-msg">데이터 없음</p>';
|
||||
}
|
||||
}
|
||||
|
||||
async function refreshAll() {
|
||||
document.getElementById('last-updated').textContent = '업데이트 중...';
|
||||
await Promise.all([
|
||||
fetchStatus(),
|
||||
fetchPerformance(),
|
||||
fetchPositions(),
|
||||
fetchPnlHistory(currentDays),
|
||||
fetchDecisions(currentMarket),
|
||||
fetchPlaybook(),
|
||||
fetchScorecard(),
|
||||
fetchScenarios(),
|
||||
fetchContext(),
|
||||
]);
|
||||
const now = new Date();
|
||||
const timeStr = now.toLocaleTimeString('ko-KR', { hour: '2-digit', minute: '2-digit', second: '2-digit', hour12: false });
|
||||
document.getElementById('last-updated').textContent = `마지막 업데이트: ${timeStr}`;
|
||||
}
|
||||
|
||||
// Initial load
|
||||
refreshAll();
|
||||
|
||||
// Auto-refresh every 30 seconds
|
||||
setInterval(refreshAll, 30000);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
205
src/data/README.md
Normal file
205
src/data/README.md
Normal file
@@ -0,0 +1,205 @@
|
||||
# External Data Integration
|
||||
|
||||
This module provides objective external data sources to enhance trading decisions beyond just market prices and user input.
|
||||
|
||||
## Modules
|
||||
|
||||
### `news_api.py` - News Sentiment Analysis
|
||||
|
||||
Fetches real-time news for stocks with sentiment scoring.
|
||||
|
||||
**Features:**
|
||||
- Alpha Vantage and NewsAPI.org support
|
||||
- Sentiment scoring (-1.0 to +1.0)
|
||||
- 5-minute caching to minimize API quota usage
|
||||
- Graceful fallback when API unavailable
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from src.data.news_api import NewsAPI
|
||||
|
||||
# Initialize with API key
|
||||
news_api = NewsAPI(api_key="your_key", provider="alphavantage")
|
||||
|
||||
# Fetch news sentiment
|
||||
sentiment = await news_api.get_news_sentiment("AAPL")
|
||||
if sentiment:
|
||||
print(f"Average sentiment: {sentiment.avg_sentiment}")
|
||||
for article in sentiment.articles[:3]:
|
||||
print(f"{article.title} ({article.sentiment_score})")
|
||||
```
|
||||
|
||||
### `economic_calendar.py` - Major Economic Events
|
||||
|
||||
Tracks FOMC meetings, GDP releases, CPI, earnings calendars, and other market-moving events.
|
||||
|
||||
**Features:**
|
||||
- High-impact event tracking (FOMC, GDP, CPI)
|
||||
- Earnings calendar per stock
|
||||
- Event proximity checking
|
||||
- Hardcoded major events for 2026 (no API required)
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from src.data.economic_calendar import EconomicCalendar
|
||||
|
||||
calendar = EconomicCalendar()
|
||||
calendar.load_hardcoded_events()
|
||||
|
||||
# Get upcoming high-impact events
|
||||
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="HIGH")
|
||||
print(f"High-impact events: {upcoming.high_impact_count}")
|
||||
|
||||
# Check if near earnings
|
||||
earnings_date = calendar.get_earnings_date("AAPL")
|
||||
if earnings_date:
|
||||
print(f"Next earnings: {earnings_date}")
|
||||
|
||||
# Check for high volatility period
|
||||
if calendar.is_high_volatility_period(hours_ahead=24):
|
||||
print("High-impact event imminent!")
|
||||
```
|
||||
|
||||
### `market_data.py` - Market Indicators
|
||||
|
||||
Provides market breadth, sector performance, and sentiment indicators.
|
||||
|
||||
**Features:**
|
||||
- Market sentiment levels (Fear & Greed equivalent)
|
||||
- Market breadth (advancing/declining stocks)
|
||||
- Sector performance tracking
|
||||
- Fear/Greed score calculation
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from src.data.market_data import MarketData
|
||||
|
||||
market_data = MarketData(api_key="your_key")
|
||||
|
||||
# Get market sentiment
|
||||
sentiment = market_data.get_market_sentiment()
|
||||
print(f"Market sentiment: {sentiment.name}")
|
||||
|
||||
# Get full indicators
|
||||
indicators = market_data.get_market_indicators("US")
|
||||
print(f"Sentiment: {indicators.sentiment.name}")
|
||||
print(f"A/D Ratio: {indicators.breadth.advance_decline_ratio}")
|
||||
```
|
||||
|
||||
## Integration with GeminiClient
|
||||
|
||||
The external data sources are seamlessly integrated into the AI decision engine:
|
||||
|
||||
```python
|
||||
from src.brain.gemini_client import GeminiClient
|
||||
from src.data.news_api import NewsAPI
|
||||
from src.data.economic_calendar import EconomicCalendar
|
||||
from src.data.market_data import MarketData
|
||||
from src.config import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
# Initialize data sources
|
||||
news_api = NewsAPI(api_key=settings.NEWS_API_KEY, provider=settings.NEWS_API_PROVIDER)
|
||||
calendar = EconomicCalendar()
|
||||
calendar.load_hardcoded_events()
|
||||
market_data = MarketData(api_key=settings.MARKET_DATA_API_KEY)
|
||||
|
||||
# Create enhanced client
|
||||
client = GeminiClient(
|
||||
settings,
|
||||
news_api=news_api,
|
||||
economic_calendar=calendar,
|
||||
market_data=market_data
|
||||
)
|
||||
|
||||
# Make decision with external context
|
||||
market_data_dict = {
|
||||
"stock_code": "AAPL",
|
||||
"current_price": 180.0,
|
||||
"market_name": "US stock market"
|
||||
}
|
||||
|
||||
decision = await client.decide(market_data_dict)
|
||||
```
|
||||
|
||||
The external data is automatically included in the prompt sent to Gemini:
|
||||
|
||||
```
|
||||
Market: US stock market
|
||||
Stock Code: AAPL
|
||||
Current Price: 180.0
|
||||
|
||||
EXTERNAL DATA:
|
||||
News Sentiment: 0.85 (from 10 articles)
|
||||
1. [Reuters] Apple hits record high (sentiment: 0.92)
|
||||
2. [Bloomberg] Strong iPhone sales (sentiment: 0.78)
|
||||
3. [CNBC] Tech sector rallying (sentiment: 0.85)
|
||||
|
||||
Upcoming High-Impact Events: 2 in next 7 days
|
||||
Next: FOMC Meeting (FOMC) on 2026-03-18
|
||||
Earnings: AAPL on 2026-02-10
|
||||
|
||||
Market Sentiment: GREED
|
||||
Advance/Decline Ratio: 2.35
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
Add these to your `.env` file:
|
||||
|
||||
```bash
|
||||
# External Data APIs (optional)
|
||||
NEWS_API_KEY=your_alpha_vantage_key
|
||||
NEWS_API_PROVIDER=alphavantage # or "newsapi"
|
||||
MARKET_DATA_API_KEY=your_market_data_key
|
||||
```
|
||||
|
||||
## API Recommendations
|
||||
|
||||
### Alpha Vantage (News)
|
||||
- **Free tier:** 5 calls/min, 500 calls/day
|
||||
- **Pros:** Provides sentiment scores, no credit card required
|
||||
- **URL:** https://www.alphavantage.co/
|
||||
|
||||
### NewsAPI.org
|
||||
- **Free tier:** 100 requests/day
|
||||
- **Pros:** Large news coverage, easy to use
|
||||
- **Cons:** No sentiment scores (we use keyword heuristics)
|
||||
- **URL:** https://newsapi.org/
|
||||
|
||||
## Caching Strategy
|
||||
|
||||
To minimize API quota usage:
|
||||
|
||||
1. **News:** 5-minute TTL cache per stock
|
||||
2. **Economic Calendar:** Loaded once at startup (hardcoded events)
|
||||
3. **Market Data:** Fetched per decision (lightweight)
|
||||
|
||||
## Graceful Degradation
|
||||
|
||||
The system works gracefully without external data:
|
||||
|
||||
- If no API keys provided → decisions work with just market prices
|
||||
- If API fails → decision continues without external context
|
||||
- If cache expired → attempts refetch, falls back to no data
|
||||
- Errors are logged but never block trading decisions
|
||||
|
||||
## Testing
|
||||
|
||||
All modules have comprehensive test coverage (81%+):
|
||||
|
||||
```bash
|
||||
pytest tests/test_data_integration.py -v --cov=src/data
|
||||
```
|
||||
|
||||
Tests use mocks to avoid requiring real API keys.
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
- Twitter/X sentiment analysis
|
||||
- Reddit WallStreetBets sentiment
|
||||
- Options flow data
|
||||
- Insider trading activity
|
||||
- Analyst upgrades/downgrades
|
||||
- Real-time economic data APIs
|
||||
5
src/data/__init__.py
Normal file
5
src/data/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""External data integration for objective decision-making."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = ["NewsAPI", "EconomicCalendar", "MarketData"]
|
||||
219
src/data/economic_calendar.py
Normal file
219
src/data/economic_calendar.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""Economic calendar integration for major market events.
|
||||
|
||||
Tracks FOMC meetings, GDP releases, CPI, earnings calendars, and other
|
||||
market-moving events.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EconomicEvent:
|
||||
"""Single economic event."""
|
||||
|
||||
name: str
|
||||
event_type: str # "FOMC", "GDP", "CPI", "EARNINGS", etc.
|
||||
datetime: datetime
|
||||
impact: str # "HIGH", "MEDIUM", "LOW"
|
||||
country: str
|
||||
description: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpcomingEvents:
|
||||
"""Collection of upcoming economic events."""
|
||||
|
||||
events: list[EconomicEvent]
|
||||
high_impact_count: int
|
||||
next_major_event: EconomicEvent | None
|
||||
|
||||
|
||||
class EconomicCalendar:
|
||||
"""Economic calendar with event tracking and impact scoring."""
|
||||
|
||||
def __init__(self, api_key: str | None = None) -> None:
|
||||
"""Initialize economic calendar.
|
||||
|
||||
Args:
|
||||
api_key: API key for calendar provider (None for testing/hardcoded)
|
||||
"""
|
||||
self._api_key = api_key
|
||||
# For now, use hardcoded major events (can be extended with API)
|
||||
self._events: list[EconomicEvent] = []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_upcoming_events(
|
||||
self, days_ahead: int = 7, min_impact: str = "MEDIUM"
|
||||
) -> UpcomingEvents:
|
||||
"""Get upcoming economic events within specified timeframe.
|
||||
|
||||
Args:
|
||||
days_ahead: Number of days to look ahead
|
||||
min_impact: Minimum impact level ("LOW", "MEDIUM", "HIGH")
|
||||
|
||||
Returns:
|
||||
UpcomingEvents with filtered events
|
||||
"""
|
||||
now = datetime.now()
|
||||
end_date = now + timedelta(days=days_ahead)
|
||||
|
||||
# Filter events by timeframe and impact
|
||||
upcoming = [
|
||||
event
|
||||
for event in self._events
|
||||
if now <= event.datetime <= end_date
|
||||
and self._impact_level(event.impact) >= self._impact_level(min_impact)
|
||||
]
|
||||
|
||||
# Sort by datetime
|
||||
upcoming.sort(key=lambda e: e.datetime)
|
||||
|
||||
# Count high-impact events
|
||||
high_impact_count = sum(1 for e in upcoming if e.impact == "HIGH")
|
||||
|
||||
# Get next major event
|
||||
next_major = None
|
||||
for event in upcoming:
|
||||
if event.impact == "HIGH":
|
||||
next_major = event
|
||||
break
|
||||
|
||||
return UpcomingEvents(
|
||||
events=upcoming,
|
||||
high_impact_count=high_impact_count,
|
||||
next_major_event=next_major,
|
||||
)
|
||||
|
||||
def add_event(self, event: EconomicEvent) -> None:
|
||||
"""Add an economic event to the calendar."""
|
||||
self._events.append(event)
|
||||
|
||||
def clear_events(self) -> None:
|
||||
"""Clear all events (useful for testing)."""
|
||||
self._events.clear()
|
||||
|
||||
def get_earnings_date(self, stock_code: str) -> datetime | None:
|
||||
"""Get next earnings date for a stock.
|
||||
|
||||
Args:
|
||||
stock_code: Stock ticker symbol
|
||||
|
||||
Returns:
|
||||
Next earnings datetime or None if not found
|
||||
"""
|
||||
now = datetime.now()
|
||||
earnings_events = [
|
||||
event
|
||||
for event in self._events
|
||||
if event.event_type == "EARNINGS"
|
||||
and stock_code.upper() in event.name.upper()
|
||||
and event.datetime > now
|
||||
]
|
||||
|
||||
if not earnings_events:
|
||||
return None
|
||||
|
||||
# Return earliest upcoming earnings
|
||||
earnings_events.sort(key=lambda e: e.datetime)
|
||||
return earnings_events[0].datetime
|
||||
|
||||
def load_hardcoded_events(self) -> None:
|
||||
"""Load hardcoded major economic events for 2026.
|
||||
|
||||
This is a fallback when no API is available.
|
||||
"""
|
||||
# Major FOMC meetings in 2026 (estimated)
|
||||
fomc_dates = [
|
||||
datetime(2026, 3, 18),
|
||||
datetime(2026, 5, 6),
|
||||
datetime(2026, 6, 17),
|
||||
datetime(2026, 7, 29),
|
||||
datetime(2026, 9, 16),
|
||||
datetime(2026, 11, 4),
|
||||
datetime(2026, 12, 16),
|
||||
]
|
||||
|
||||
for date in fomc_dates:
|
||||
self.add_event(
|
||||
EconomicEvent(
|
||||
name="FOMC Meeting",
|
||||
event_type="FOMC",
|
||||
datetime=date,
|
||||
impact="HIGH",
|
||||
country="US",
|
||||
description="Federal Reserve interest rate decision",
|
||||
)
|
||||
)
|
||||
|
||||
# Quarterly GDP releases (estimated)
|
||||
gdp_dates = [
|
||||
datetime(2026, 4, 28),
|
||||
datetime(2026, 7, 30),
|
||||
datetime(2026, 10, 29),
|
||||
]
|
||||
|
||||
for date in gdp_dates:
|
||||
self.add_event(
|
||||
EconomicEvent(
|
||||
name="US GDP Release",
|
||||
event_type="GDP",
|
||||
datetime=date,
|
||||
impact="HIGH",
|
||||
country="US",
|
||||
description="Quarterly GDP growth rate",
|
||||
)
|
||||
)
|
||||
|
||||
# Monthly CPI releases (12th of each month, estimated)
|
||||
for month in range(1, 13):
|
||||
try:
|
||||
cpi_date = datetime(2026, month, 12)
|
||||
self.add_event(
|
||||
EconomicEvent(
|
||||
name="US CPI Release",
|
||||
event_type="CPI",
|
||||
datetime=cpi_date,
|
||||
impact="HIGH",
|
||||
country="US",
|
||||
description="Consumer Price Index inflation data",
|
||||
)
|
||||
)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _impact_level(self, impact: str) -> int:
|
||||
"""Convert impact string to numeric level."""
|
||||
levels = {"LOW": 1, "MEDIUM": 2, "HIGH": 3}
|
||||
return levels.get(impact.upper(), 0)
|
||||
|
||||
def is_high_volatility_period(self, hours_ahead: int = 24) -> bool:
|
||||
"""Check if we're near a high-impact event.
|
||||
|
||||
Args:
|
||||
hours_ahead: Number of hours to look ahead
|
||||
|
||||
Returns:
|
||||
True if high-impact event is imminent
|
||||
"""
|
||||
now = datetime.now()
|
||||
threshold = now + timedelta(hours=hours_ahead)
|
||||
|
||||
for event in self._events:
|
||||
if event.impact == "HIGH" and now <= event.datetime <= threshold:
|
||||
return True
|
||||
|
||||
return False
|
||||
198
src/data/market_data.py
Normal file
198
src/data/market_data.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""Additional market data indicators beyond basic price data.
|
||||
|
||||
Provides market breadth, sector performance, and market sentiment indicators.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MarketSentiment(Enum):
|
||||
"""Overall market sentiment levels."""
|
||||
|
||||
EXTREME_FEAR = 1
|
||||
FEAR = 2
|
||||
NEUTRAL = 3
|
||||
GREED = 4
|
||||
EXTREME_GREED = 5
|
||||
|
||||
|
||||
@dataclass
|
||||
class SectorPerformance:
|
||||
"""Performance metrics for a market sector."""
|
||||
|
||||
sector_name: str
|
||||
daily_change_pct: float
|
||||
weekly_change_pct: float
|
||||
leader_stock: str # Best performing stock in sector
|
||||
laggard_stock: str # Worst performing stock in sector
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarketBreadth:
|
||||
"""Market breadth indicators."""
|
||||
|
||||
advancing_stocks: int
|
||||
declining_stocks: int
|
||||
unchanged_stocks: int
|
||||
new_highs: int
|
||||
new_lows: int
|
||||
advance_decline_ratio: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarketIndicators:
|
||||
"""Aggregated market indicators."""
|
||||
|
||||
sentiment: MarketSentiment
|
||||
breadth: MarketBreadth
|
||||
sector_performance: list[SectorPerformance]
|
||||
vix_level: float | None # Volatility index if available
|
||||
|
||||
|
||||
class MarketData:
|
||||
"""Market data provider for additional indicators."""
|
||||
|
||||
def __init__(self, api_key: str | None = None) -> None:
|
||||
"""Initialize market data provider.
|
||||
|
||||
Args:
|
||||
api_key: API key for data provider (None for testing)
|
||||
"""
|
||||
self._api_key = api_key
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_market_sentiment(self) -> MarketSentiment:
|
||||
"""Get current market sentiment level.
|
||||
|
||||
This is a simplified version. In production, this would integrate
|
||||
with Fear & Greed Index or similar sentiment indicators.
|
||||
|
||||
Returns:
|
||||
MarketSentiment enum value
|
||||
"""
|
||||
# Default to neutral when API not available
|
||||
if self._api_key is None:
|
||||
logger.debug("No market data API key — returning NEUTRAL sentiment")
|
||||
return MarketSentiment.NEUTRAL
|
||||
|
||||
# TODO: Integrate with actual sentiment API
|
||||
return MarketSentiment.NEUTRAL
|
||||
|
||||
def get_market_breadth(self, market: str = "US") -> MarketBreadth | None:
|
||||
"""Get market breadth indicators.
|
||||
|
||||
Args:
|
||||
market: Market code ("US", "KR", etc.)
|
||||
|
||||
Returns:
|
||||
MarketBreadth object or None if unavailable
|
||||
"""
|
||||
if self._api_key is None:
|
||||
logger.debug("No market data API key — returning None for breadth")
|
||||
return None
|
||||
|
||||
# TODO: Integrate with actual market breadth API
|
||||
return None
|
||||
|
||||
def get_sector_performance(
|
||||
self, market: str = "US"
|
||||
) -> list[SectorPerformance]:
|
||||
"""Get sector performance rankings.
|
||||
|
||||
Args:
|
||||
market: Market code ("US", "KR", etc.)
|
||||
|
||||
Returns:
|
||||
List of SectorPerformance objects, sorted by daily change
|
||||
"""
|
||||
if self._api_key is None:
|
||||
logger.debug("No market data API key — returning empty sector list")
|
||||
return []
|
||||
|
||||
# TODO: Integrate with actual sector performance API
|
||||
return []
|
||||
|
||||
def get_market_indicators(self, market: str = "US") -> MarketIndicators:
|
||||
"""Get aggregated market indicators.
|
||||
|
||||
Args:
|
||||
market: Market code ("US", "KR", etc.)
|
||||
|
||||
Returns:
|
||||
MarketIndicators with all available data
|
||||
"""
|
||||
sentiment = self.get_market_sentiment()
|
||||
breadth = self.get_market_breadth(market)
|
||||
sectors = self.get_sector_performance(market)
|
||||
|
||||
# Default breadth if unavailable
|
||||
if breadth is None:
|
||||
breadth = MarketBreadth(
|
||||
advancing_stocks=0,
|
||||
declining_stocks=0,
|
||||
unchanged_stocks=0,
|
||||
new_highs=0,
|
||||
new_lows=0,
|
||||
advance_decline_ratio=1.0,
|
||||
)
|
||||
|
||||
return MarketIndicators(
|
||||
sentiment=sentiment,
|
||||
breadth=breadth,
|
||||
sector_performance=sectors,
|
||||
vix_level=None, # TODO: Add VIX integration
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helper Methods
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def calculate_fear_greed_score(
|
||||
self, breadth: MarketBreadth, vix: float | None = None
|
||||
) -> int:
|
||||
"""Calculate a simple fear/greed score (0-100).
|
||||
|
||||
Args:
|
||||
breadth: Market breadth data
|
||||
vix: VIX level (optional)
|
||||
|
||||
Returns:
|
||||
Score from 0 (extreme fear) to 100 (extreme greed)
|
||||
"""
|
||||
# Start at neutral
|
||||
score = 50
|
||||
|
||||
# Adjust based on advance/decline ratio
|
||||
if breadth.advance_decline_ratio > 1.5:
|
||||
score += 20
|
||||
elif breadth.advance_decline_ratio > 1.0:
|
||||
score += 10
|
||||
elif breadth.advance_decline_ratio < 0.5:
|
||||
score -= 20
|
||||
elif breadth.advance_decline_ratio < 1.0:
|
||||
score -= 10
|
||||
|
||||
# Adjust based on new highs/lows
|
||||
if breadth.new_highs > breadth.new_lows * 2:
|
||||
score += 15
|
||||
elif breadth.new_lows > breadth.new_highs * 2:
|
||||
score -= 15
|
||||
|
||||
# Adjust based on VIX if available
|
||||
if vix is not None:
|
||||
if vix > 30: # High volatility = fear
|
||||
score -= 15
|
||||
elif vix < 15: # Low volatility = complacency/greed
|
||||
score += 10
|
||||
|
||||
# Clamp to 0-100
|
||||
return max(0, min(100, score))
|
||||
316
src/data/news_api.py
Normal file
316
src/data/news_api.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""News API integration with sentiment analysis and caching.
|
||||
|
||||
Fetches real-time news for stocks using free-tier APIs (Alpha Vantage or NewsAPI).
|
||||
Includes 5-minute caching to minimize API quota usage.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache entries expire after 5 minutes
|
||||
CACHE_TTL_SECONDS = 300
|
||||
|
||||
|
||||
@dataclass
|
||||
class NewsArticle:
|
||||
"""Single news article with sentiment."""
|
||||
|
||||
title: str
|
||||
summary: str
|
||||
source: str
|
||||
published_at: str
|
||||
sentiment_score: float # -1.0 (negative) to +1.0 (positive)
|
||||
url: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class NewsSentiment:
|
||||
"""Aggregated news sentiment for a stock."""
|
||||
|
||||
stock_code: str
|
||||
articles: list[NewsArticle]
|
||||
avg_sentiment: float # Average sentiment across all articles
|
||||
article_count: int
|
||||
fetched_at: float # Unix timestamp
|
||||
|
||||
|
||||
class NewsAPI:
|
||||
"""News API client with sentiment analysis and caching."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
provider: str = "alphavantage",
|
||||
cache_ttl: int = CACHE_TTL_SECONDS,
|
||||
) -> None:
|
||||
"""Initialize NewsAPI client.
|
||||
|
||||
Args:
|
||||
api_key: API key for the news provider (None for testing)
|
||||
provider: News provider ("alphavantage" or "newsapi")
|
||||
cache_ttl: Cache time-to-live in seconds
|
||||
"""
|
||||
self._api_key = api_key
|
||||
self._provider = provider
|
||||
self._cache_ttl = cache_ttl
|
||||
self._cache: dict[str, NewsSentiment] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_news_sentiment(self, stock_code: str) -> NewsSentiment | None:
|
||||
"""Fetch news sentiment for a stock with caching.
|
||||
|
||||
Args:
|
||||
stock_code: Stock ticker symbol (e.g., "AAPL", "005930")
|
||||
|
||||
Returns:
|
||||
NewsSentiment object or None if fetch fails or API unavailable
|
||||
"""
|
||||
# Check cache first
|
||||
cached = self._get_from_cache(stock_code)
|
||||
if cached is not None:
|
||||
logger.debug("News cache hit for %s", stock_code)
|
||||
return cached
|
||||
|
||||
# API key required for real requests
|
||||
if self._api_key is None:
|
||||
logger.warning("No news API key provided — returning None")
|
||||
return None
|
||||
|
||||
# Fetch from API
|
||||
try:
|
||||
sentiment = await self._fetch_news(stock_code)
|
||||
if sentiment is not None:
|
||||
self._cache[stock_code] = sentiment
|
||||
return sentiment
|
||||
except Exception as exc:
|
||||
logger.error("Failed to fetch news for %s: %s", stock_code, exc)
|
||||
return None
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the news cache (useful for testing)."""
|
||||
self._cache.clear()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Cache Management
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_from_cache(self, stock_code: str) -> NewsSentiment | None:
|
||||
"""Retrieve cached sentiment if not expired."""
|
||||
if stock_code not in self._cache:
|
||||
return None
|
||||
|
||||
cached = self._cache[stock_code]
|
||||
age = time.time() - cached.fetched_at
|
||||
|
||||
if age > self._cache_ttl:
|
||||
logger.debug("News cache expired for %s (age: %.1fs)", stock_code, age)
|
||||
del self._cache[stock_code]
|
||||
return None
|
||||
|
||||
return cached
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# API Fetching
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _fetch_news(self, stock_code: str) -> NewsSentiment | None:
|
||||
"""Fetch news from the provider API."""
|
||||
if self._provider == "alphavantage":
|
||||
return await self._fetch_alphavantage(stock_code)
|
||||
elif self._provider == "newsapi":
|
||||
return await self._fetch_newsapi(stock_code)
|
||||
else:
|
||||
logger.error("Unknown news provider: %s", self._provider)
|
||||
return None
|
||||
|
||||
async def _fetch_alphavantage(self, stock_code: str) -> NewsSentiment | None:
|
||||
"""Fetch news from Alpha Vantage News Sentiment API."""
|
||||
url = "https://www.alphavantage.co/query"
|
||||
params = {
|
||||
"function": "NEWS_SENTIMENT",
|
||||
"tickers": stock_code,
|
||||
"apikey": self._api_key,
|
||||
"limit": 10, # Fetch top 10 articles
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, params=params, timeout=10) as resp:
|
||||
if resp.status != 200:
|
||||
logger.error(
|
||||
"Alpha Vantage API error: HTTP %d", resp.status
|
||||
)
|
||||
return None
|
||||
|
||||
data = await resp.json()
|
||||
return self._parse_alphavantage_response(stock_code, data)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Alpha Vantage request failed: %s", exc)
|
||||
return None
|
||||
|
||||
async def _fetch_newsapi(self, stock_code: str) -> NewsSentiment | None:
|
||||
"""Fetch news from NewsAPI.org."""
|
||||
url = "https://newsapi.org/v2/everything"
|
||||
params = {
|
||||
"q": stock_code,
|
||||
"apiKey": self._api_key,
|
||||
"pageSize": 10,
|
||||
"sortBy": "publishedAt",
|
||||
"language": "en",
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, params=params, timeout=10) as resp:
|
||||
if resp.status != 200:
|
||||
logger.error("NewsAPI error: HTTP %d", resp.status)
|
||||
return None
|
||||
|
||||
data = await resp.json()
|
||||
return self._parse_newsapi_response(stock_code, data)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("NewsAPI request failed: %s", exc)
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Response Parsing
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _parse_alphavantage_response(
|
||||
self, stock_code: str, data: dict[str, Any]
|
||||
) -> NewsSentiment | None:
|
||||
"""Parse Alpha Vantage API response."""
|
||||
if "feed" not in data:
|
||||
logger.warning("No 'feed' key in Alpha Vantage response")
|
||||
return None
|
||||
|
||||
articles: list[NewsArticle] = []
|
||||
for item in data["feed"]:
|
||||
# Extract sentiment for this specific ticker
|
||||
ticker_sentiment = self._extract_ticker_sentiment(item, stock_code)
|
||||
|
||||
article = NewsArticle(
|
||||
title=item.get("title", ""),
|
||||
summary=item.get("summary", "")[:200], # Truncate long summaries
|
||||
source=item.get("source", "Unknown"),
|
||||
published_at=item.get("time_published", ""),
|
||||
sentiment_score=ticker_sentiment,
|
||||
url=item.get("url", ""),
|
||||
)
|
||||
articles.append(article)
|
||||
|
||||
if not articles:
|
||||
return None
|
||||
|
||||
avg_sentiment = sum(a.sentiment_score for a in articles) / len(articles)
|
||||
|
||||
return NewsSentiment(
|
||||
stock_code=stock_code,
|
||||
articles=articles,
|
||||
avg_sentiment=avg_sentiment,
|
||||
article_count=len(articles),
|
||||
fetched_at=time.time(),
|
||||
)
|
||||
|
||||
def _extract_ticker_sentiment(
|
||||
self, item: dict[str, Any], stock_code: str
|
||||
) -> float:
|
||||
"""Extract sentiment score for specific ticker from article."""
|
||||
ticker_sentiments = item.get("ticker_sentiment", [])
|
||||
for ts in ticker_sentiments:
|
||||
if ts.get("ticker", "").upper() == stock_code.upper():
|
||||
# Alpha Vantage provides sentiment_score as string
|
||||
score_str = ts.get("ticker_sentiment_score", "0")
|
||||
try:
|
||||
return float(score_str)
|
||||
except ValueError:
|
||||
return 0.0
|
||||
|
||||
# Fallback to overall sentiment if ticker-specific not found
|
||||
overall_sentiment = item.get("overall_sentiment_score", "0")
|
||||
try:
|
||||
return float(overall_sentiment)
|
||||
except ValueError:
|
||||
return 0.0
|
||||
|
||||
def _parse_newsapi_response(
|
||||
self, stock_code: str, data: dict[str, Any]
|
||||
) -> NewsSentiment | None:
|
||||
"""Parse NewsAPI.org response.
|
||||
|
||||
Note: NewsAPI doesn't provide sentiment scores, so we use a
|
||||
simple heuristic based on title keywords.
|
||||
"""
|
||||
if data.get("status") != "ok" or "articles" not in data:
|
||||
logger.warning("Invalid NewsAPI response")
|
||||
return None
|
||||
|
||||
articles: list[NewsArticle] = []
|
||||
for item in data["articles"]:
|
||||
# Simple sentiment heuristic based on keywords
|
||||
sentiment = self._estimate_sentiment_from_text(
|
||||
item.get("title", "") + " " + item.get("description", "")
|
||||
)
|
||||
|
||||
article = NewsArticle(
|
||||
title=item.get("title", ""),
|
||||
summary=item.get("description", "")[:200],
|
||||
source=item.get("source", {}).get("name", "Unknown"),
|
||||
published_at=item.get("publishedAt", ""),
|
||||
sentiment_score=sentiment,
|
||||
url=item.get("url", ""),
|
||||
)
|
||||
articles.append(article)
|
||||
|
||||
if not articles:
|
||||
return None
|
||||
|
||||
avg_sentiment = sum(a.sentiment_score for a in articles) / len(articles)
|
||||
|
||||
return NewsSentiment(
|
||||
stock_code=stock_code,
|
||||
articles=articles,
|
||||
avg_sentiment=avg_sentiment,
|
||||
article_count=len(articles),
|
||||
fetched_at=time.time(),
|
||||
)
|
||||
|
||||
def _estimate_sentiment_from_text(self, text: str) -> float:
|
||||
"""Simple keyword-based sentiment estimation.
|
||||
|
||||
This is a fallback for APIs that don't provide sentiment scores.
|
||||
Returns a score between -1.0 and +1.0.
|
||||
"""
|
||||
text_lower = text.lower()
|
||||
|
||||
positive_keywords = [
|
||||
"surge", "jump", "gain", "rise", "soar", "rally", "profit",
|
||||
"growth", "upgrade", "beat", "strong", "bullish", "breakthrough",
|
||||
]
|
||||
negative_keywords = [
|
||||
"plunge", "fall", "drop", "decline", "crash", "loss", "weak",
|
||||
"downgrade", "miss", "bearish", "concern", "risk", "warning",
|
||||
]
|
||||
|
||||
positive_count = sum(1 for kw in positive_keywords if kw in text_lower)
|
||||
negative_count = sum(1 for kw in negative_keywords if kw in text_lower)
|
||||
|
||||
total = positive_count + negative_count
|
||||
if total == 0:
|
||||
return 0.0
|
||||
|
||||
# Normalize to -1.0 to +1.0 range
|
||||
return (positive_count - negative_count) / total
|
||||
154
src/db.py
154
src/db.py
@@ -2,9 +2,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
def init_db(db_path: str) -> sqlite3.Connection:
|
||||
@@ -12,6 +14,11 @@ def init_db(db_path: str) -> sqlite3.Connection:
|
||||
if db_path != ":memory:":
|
||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(db_path)
|
||||
# Enable WAL mode for concurrent read/write (dashboard + trading loop).
|
||||
# WAL does not apply to in-memory databases.
|
||||
if db_path != ":memory:":
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA busy_timeout=5000")
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS trades (
|
||||
@@ -25,12 +32,14 @@ def init_db(db_path: str) -> sqlite3.Connection:
|
||||
price REAL,
|
||||
pnl REAL DEFAULT 0.0,
|
||||
market TEXT DEFAULT 'KR',
|
||||
exchange_code TEXT DEFAULT 'KRX'
|
||||
exchange_code TEXT DEFAULT 'KRX',
|
||||
decision_id TEXT,
|
||||
mode TEXT DEFAULT 'paper'
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Migration: Add market and exchange_code columns if they don't exist
|
||||
# Migration: Add columns if they don't exist (backward-compatible schema upgrades)
|
||||
cursor = conn.execute("PRAGMA table_info(trades)")
|
||||
columns = {row[1] for row in cursor.fetchall()}
|
||||
|
||||
@@ -38,6 +47,12 @@ def init_db(db_path: str) -> sqlite3.Connection:
|
||||
conn.execute("ALTER TABLE trades ADD COLUMN market TEXT DEFAULT 'KR'")
|
||||
if "exchange_code" not in columns:
|
||||
conn.execute("ALTER TABLE trades ADD COLUMN exchange_code TEXT DEFAULT 'KRX'")
|
||||
if "selection_context" not in columns:
|
||||
conn.execute("ALTER TABLE trades ADD COLUMN selection_context TEXT")
|
||||
if "decision_id" not in columns:
|
||||
conn.execute("ALTER TABLE trades ADD COLUMN decision_id TEXT")
|
||||
if "mode" not in columns:
|
||||
conn.execute("ALTER TABLE trades ADD COLUMN mode TEXT DEFAULT 'paper'")
|
||||
|
||||
# Context tree tables for multi-layered memory management
|
||||
conn.execute(
|
||||
@@ -88,6 +103,27 @@ def init_db(db_path: str) -> sqlite3.Connection:
|
||||
"""
|
||||
)
|
||||
|
||||
# Playbook storage for pre-market strategy persistence
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS playbooks (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
date TEXT NOT NULL,
|
||||
market TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
playbook_json TEXT NOT NULL,
|
||||
generated_at TEXT NOT NULL,
|
||||
token_count INTEGER DEFAULT 0,
|
||||
scenario_count INTEGER DEFAULT 0,
|
||||
match_count INTEGER DEFAULT 0,
|
||||
UNIQUE(date, market)
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_playbooks_date ON playbooks(date)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_playbooks_market ON playbooks(market)")
|
||||
|
||||
# Create indices for efficient context queries
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_contexts_layer ON contexts(layer)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_contexts_timeframe ON contexts(timeframe)")
|
||||
@@ -103,6 +139,25 @@ def init_db(db_path: str) -> sqlite3.Connection:
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_decision_logs_confidence ON decision_logs(confidence)"
|
||||
)
|
||||
|
||||
# Index for open-position queries (partition by stock_code, market, ordered by timestamp)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_trades_stock_market_ts"
|
||||
" ON trades (stock_code, market, timestamp DESC)"
|
||||
)
|
||||
|
||||
# Lightweight key-value store for trading system runtime metrics (dashboard use only)
|
||||
# Intentionally separate from the AI context tree to preserve separation of concerns.
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS system_metrics (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
return conn
|
||||
|
||||
@@ -118,15 +173,38 @@ def log_trade(
|
||||
pnl: float = 0.0,
|
||||
market: str = "KR",
|
||||
exchange_code: str = "KRX",
|
||||
selection_context: dict[str, any] | None = None,
|
||||
decision_id: str | None = None,
|
||||
mode: str = "paper",
|
||||
) -> None:
|
||||
"""Insert a trade record into the database."""
|
||||
"""Insert a trade record into the database.
|
||||
|
||||
Args:
|
||||
conn: Database connection
|
||||
stock_code: Stock code
|
||||
action: Trade action (BUY/SELL/HOLD)
|
||||
confidence: Confidence level (0-100)
|
||||
rationale: AI decision rationale
|
||||
quantity: Number of shares
|
||||
price: Trade price
|
||||
pnl: Profit/loss
|
||||
market: Market code
|
||||
exchange_code: Exchange code
|
||||
selection_context: Scanner selection data (RSI, volume_ratio, signal, score)
|
||||
decision_id: Unique decision identifier for audit linking
|
||||
mode: Trading mode ('paper' or 'live') for data separation
|
||||
"""
|
||||
# Serialize selection context to JSON
|
||||
context_json = json.dumps(selection_context) if selection_context else None
|
||||
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO trades (
|
||||
timestamp, stock_code, action, confidence, rationale,
|
||||
quantity, price, pnl, market, exchange_code
|
||||
quantity, price, pnl, market, exchange_code, selection_context, decision_id,
|
||||
mode
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
datetime.now(UTC).isoformat(),
|
||||
@@ -139,6 +217,72 @@ def log_trade(
|
||||
pnl,
|
||||
market,
|
||||
exchange_code,
|
||||
context_json,
|
||||
decision_id,
|
||||
mode,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def get_latest_buy_trade(
|
||||
conn: sqlite3.Connection, stock_code: str, market: str
|
||||
) -> dict[str, Any] | None:
|
||||
"""Fetch the most recent BUY trade for a stock and market."""
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
SELECT decision_id, price, quantity
|
||||
FROM trades
|
||||
WHERE stock_code = ?
|
||||
AND market = ?
|
||||
AND action = 'BUY'
|
||||
AND decision_id IS NOT NULL
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 1
|
||||
""",
|
||||
(stock_code, market),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
return {"decision_id": row[0], "price": row[1], "quantity": row[2]}
|
||||
|
||||
|
||||
def get_open_position(
|
||||
conn: sqlite3.Connection, stock_code: str, market: str
|
||||
) -> dict[str, Any] | None:
|
||||
"""Return open position if latest trade is BUY, else None."""
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
SELECT action, decision_id, price, quantity, timestamp
|
||||
FROM trades
|
||||
WHERE stock_code = ?
|
||||
AND market = ?
|
||||
AND action IN ('BUY', 'SELL')
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 1
|
||||
""",
|
||||
(stock_code, market),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if not row or row[0] != "BUY":
|
||||
return None
|
||||
return {"decision_id": row[1], "price": row[2], "quantity": row[3], "timestamp": row[4]}
|
||||
|
||||
|
||||
def get_recent_symbols(
|
||||
conn: sqlite3.Connection, market: str, limit: int = 30
|
||||
) -> list[str]:
|
||||
"""Return recent unique symbols for a market, newest first."""
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
SELECT stock_code, MAX(timestamp) AS last_ts
|
||||
FROM trades
|
||||
WHERE market = ?
|
||||
GROUP BY stock_code
|
||||
ORDER BY last_ts DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(market, limit),
|
||||
)
|
||||
return [row[0] for row in cursor.fetchall() if row and row[0]]
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
"""Evolution engine for self-improving trading strategies."""
|
||||
|
||||
from src.evolution.ab_test import ABTester, ABTestResult, StrategyPerformance
|
||||
from src.evolution.daily_review import DailyReviewer
|
||||
from src.evolution.optimizer import EvolutionOptimizer
|
||||
from src.evolution.performance_tracker import (
|
||||
PerformanceDashboard,
|
||||
PerformanceTracker,
|
||||
StrategyMetrics,
|
||||
)
|
||||
from src.evolution.scorecard import DailyScorecard
|
||||
|
||||
__all__ = [
|
||||
"EvolutionOptimizer",
|
||||
"ABTester",
|
||||
"ABTestResult",
|
||||
"StrategyPerformance",
|
||||
"PerformanceTracker",
|
||||
"PerformanceDashboard",
|
||||
"StrategyMetrics",
|
||||
"DailyScorecard",
|
||||
"DailyReviewer",
|
||||
]
|
||||
|
||||
220
src/evolution/ab_test.py
Normal file
220
src/evolution/ab_test.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""A/B Testing framework for strategy comparison.
|
||||
|
||||
Runs multiple strategies in parallel, tracks their performance,
|
||||
and uses statistical significance testing to determine winners.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import scipy.stats as stats
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StrategyPerformance:
|
||||
"""Performance metrics for a single strategy."""
|
||||
|
||||
strategy_name: str
|
||||
total_trades: int
|
||||
wins: int
|
||||
losses: int
|
||||
total_pnl: float
|
||||
avg_pnl: float
|
||||
win_rate: float
|
||||
sharpe_ratio: float | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ABTestResult:
|
||||
"""Result of an A/B test between two strategies."""
|
||||
|
||||
strategy_a: str
|
||||
strategy_b: str
|
||||
winner: str | None
|
||||
p_value: float
|
||||
confidence_level: float
|
||||
is_significant: bool
|
||||
performance_a: StrategyPerformance
|
||||
performance_b: StrategyPerformance
|
||||
|
||||
|
||||
class ABTester:
|
||||
"""A/B testing framework for comparing trading strategies."""
|
||||
|
||||
def __init__(self, significance_level: float = 0.05) -> None:
|
||||
"""Initialize A/B tester.
|
||||
|
||||
Args:
|
||||
significance_level: P-value threshold for statistical significance (default 0.05)
|
||||
"""
|
||||
self._significance_level = significance_level
|
||||
|
||||
def calculate_performance(
|
||||
self, trades: list[dict[str, Any]], strategy_name: str
|
||||
) -> StrategyPerformance:
|
||||
"""Calculate performance metrics for a strategy.
|
||||
|
||||
Args:
|
||||
trades: List of trade records with pnl values
|
||||
strategy_name: Name of the strategy
|
||||
|
||||
Returns:
|
||||
StrategyPerformance object with calculated metrics
|
||||
"""
|
||||
if not trades:
|
||||
return StrategyPerformance(
|
||||
strategy_name=strategy_name,
|
||||
total_trades=0,
|
||||
wins=0,
|
||||
losses=0,
|
||||
total_pnl=0.0,
|
||||
avg_pnl=0.0,
|
||||
win_rate=0.0,
|
||||
sharpe_ratio=None,
|
||||
)
|
||||
|
||||
total_trades = len(trades)
|
||||
wins = sum(1 for t in trades if t.get("pnl", 0) > 0)
|
||||
losses = sum(1 for t in trades if t.get("pnl", 0) < 0)
|
||||
pnls = [t.get("pnl", 0.0) for t in trades]
|
||||
total_pnl = sum(pnls)
|
||||
avg_pnl = total_pnl / total_trades if total_trades > 0 else 0.0
|
||||
win_rate = (wins / total_trades * 100) if total_trades > 0 else 0.0
|
||||
|
||||
# Calculate Sharpe ratio (risk-adjusted return)
|
||||
sharpe_ratio = None
|
||||
if len(pnls) > 1:
|
||||
mean_return = avg_pnl
|
||||
std_return = (
|
||||
sum((p - mean_return) ** 2 for p in pnls) / (len(pnls) - 1)
|
||||
) ** 0.5
|
||||
if std_return > 0:
|
||||
sharpe_ratio = mean_return / std_return
|
||||
|
||||
return StrategyPerformance(
|
||||
strategy_name=strategy_name,
|
||||
total_trades=total_trades,
|
||||
wins=wins,
|
||||
losses=losses,
|
||||
total_pnl=round(total_pnl, 2),
|
||||
avg_pnl=round(avg_pnl, 2),
|
||||
win_rate=round(win_rate, 2),
|
||||
sharpe_ratio=round(sharpe_ratio, 4) if sharpe_ratio else None,
|
||||
)
|
||||
|
||||
def compare_strategies(
|
||||
self,
|
||||
trades_a: list[dict[str, Any]],
|
||||
trades_b: list[dict[str, Any]],
|
||||
strategy_a_name: str = "Strategy A",
|
||||
strategy_b_name: str = "Strategy B",
|
||||
) -> ABTestResult:
|
||||
"""Compare two strategies using statistical testing.
|
||||
|
||||
Uses a two-sample t-test to determine if performance difference is significant.
|
||||
|
||||
Args:
|
||||
trades_a: List of trades from strategy A
|
||||
trades_b: List of trades from strategy B
|
||||
strategy_a_name: Name of strategy A
|
||||
strategy_b_name: Name of strategy B
|
||||
|
||||
Returns:
|
||||
ABTestResult with comparison details
|
||||
"""
|
||||
perf_a = self.calculate_performance(trades_a, strategy_a_name)
|
||||
perf_b = self.calculate_performance(trades_b, strategy_b_name)
|
||||
|
||||
# Extract PnL arrays for statistical testing
|
||||
pnls_a = [t.get("pnl", 0.0) for t in trades_a]
|
||||
pnls_b = [t.get("pnl", 0.0) for t in trades_b]
|
||||
|
||||
# Perform two-sample t-test
|
||||
if len(pnls_a) > 1 and len(pnls_b) > 1:
|
||||
t_stat, p_value = stats.ttest_ind(pnls_a, pnls_b, equal_var=False)
|
||||
is_significant = p_value < self._significance_level
|
||||
confidence_level = (1 - p_value) * 100
|
||||
else:
|
||||
# Not enough data for statistical test
|
||||
p_value = 1.0
|
||||
is_significant = False
|
||||
confidence_level = 0.0
|
||||
|
||||
# Determine winner based on average PnL
|
||||
winner = None
|
||||
if is_significant:
|
||||
if perf_a.avg_pnl > perf_b.avg_pnl:
|
||||
winner = strategy_a_name
|
||||
elif perf_b.avg_pnl > perf_a.avg_pnl:
|
||||
winner = strategy_b_name
|
||||
|
||||
return ABTestResult(
|
||||
strategy_a=strategy_a_name,
|
||||
strategy_b=strategy_b_name,
|
||||
winner=winner,
|
||||
p_value=round(p_value, 4),
|
||||
confidence_level=round(confidence_level, 2),
|
||||
is_significant=is_significant,
|
||||
performance_a=perf_a,
|
||||
performance_b=perf_b,
|
||||
)
|
||||
|
||||
def should_deploy(
|
||||
self,
|
||||
result: ABTestResult,
|
||||
min_win_rate: float = 60.0,
|
||||
min_trades: int = 20,
|
||||
) -> bool:
|
||||
"""Determine if a winning strategy should be deployed.
|
||||
|
||||
Args:
|
||||
result: A/B test result
|
||||
min_win_rate: Minimum win rate percentage for deployment (default 60%)
|
||||
min_trades: Minimum number of trades required (default 20)
|
||||
|
||||
Returns:
|
||||
True if the winning strategy meets deployment criteria
|
||||
"""
|
||||
if not result.is_significant or result.winner is None:
|
||||
return False
|
||||
|
||||
# Get performance of winning strategy
|
||||
if result.winner == result.strategy_a:
|
||||
winning_perf = result.performance_a
|
||||
else:
|
||||
winning_perf = result.performance_b
|
||||
|
||||
# Check deployment criteria
|
||||
has_enough_trades = winning_perf.total_trades >= min_trades
|
||||
has_good_win_rate = winning_perf.win_rate >= min_win_rate
|
||||
is_profitable = winning_perf.avg_pnl > 0
|
||||
|
||||
meets_criteria = has_enough_trades and has_good_win_rate and is_profitable
|
||||
|
||||
if meets_criteria:
|
||||
logger.info(
|
||||
"Strategy '%s' meets deployment criteria: "
|
||||
"win_rate=%.2f%%, trades=%d, avg_pnl=%.2f",
|
||||
result.winner,
|
||||
winning_perf.win_rate,
|
||||
winning_perf.total_trades,
|
||||
winning_perf.avg_pnl,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Strategy '%s' does NOT meet deployment criteria: "
|
||||
"win_rate=%.2f%% (min %.2f%%), trades=%d (min %d), avg_pnl=%.2f",
|
||||
result.winner if result.winner else "unknown",
|
||||
winning_perf.win_rate if result.winner else 0.0,
|
||||
min_win_rate,
|
||||
winning_perf.total_trades if result.winner else 0,
|
||||
min_trades,
|
||||
winning_perf.avg_pnl if result.winner else 0.0,
|
||||
)
|
||||
|
||||
return meets_criteria
|
||||
196
src/evolution/daily_review.py
Normal file
196
src/evolution/daily_review.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""Daily review generator for market-scoped end-of-day scorecards."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import sqlite3
|
||||
from dataclasses import asdict
|
||||
|
||||
from src.brain.gemini_client import GeminiClient
|
||||
from src.context.layer import ContextLayer
|
||||
from src.context.store import ContextStore
|
||||
from src.evolution.scorecard import DailyScorecard
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DailyReviewer:
|
||||
"""Builds daily scorecards and optional AI-generated lessons."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conn: sqlite3.Connection,
|
||||
context_store: ContextStore,
|
||||
gemini_client: GeminiClient | None = None,
|
||||
) -> None:
|
||||
self._conn = conn
|
||||
self._context_store = context_store
|
||||
self._gemini = gemini_client
|
||||
|
||||
def generate_scorecard(self, date: str, market: str) -> DailyScorecard:
|
||||
"""Generate a market-scoped scorecard from decision logs and trades."""
|
||||
decision_rows = self._conn.execute(
|
||||
"""
|
||||
SELECT action, confidence, context_snapshot
|
||||
FROM decision_logs
|
||||
WHERE DATE(timestamp) = ? AND market = ?
|
||||
""",
|
||||
(date, market),
|
||||
).fetchall()
|
||||
|
||||
total_decisions = len(decision_rows)
|
||||
buys = sum(1 for row in decision_rows if row[0] == "BUY")
|
||||
sells = sum(1 for row in decision_rows if row[0] == "SELL")
|
||||
holds = sum(1 for row in decision_rows if row[0] == "HOLD")
|
||||
avg_confidence = (
|
||||
round(sum(int(row[1]) for row in decision_rows) / total_decisions, 2)
|
||||
if total_decisions > 0
|
||||
else 0.0
|
||||
)
|
||||
|
||||
matched = 0
|
||||
for row in decision_rows:
|
||||
try:
|
||||
snapshot = json.loads(row[2]) if row[2] else {}
|
||||
except json.JSONDecodeError:
|
||||
snapshot = {}
|
||||
scenario_match = snapshot.get("scenario_match", {})
|
||||
if isinstance(scenario_match, dict) and scenario_match:
|
||||
matched += 1
|
||||
scenario_match_rate = (
|
||||
round((matched / total_decisions) * 100, 2)
|
||||
if total_decisions
|
||||
else 0.0
|
||||
)
|
||||
|
||||
trade_stats = self._conn.execute(
|
||||
"""
|
||||
SELECT
|
||||
COALESCE(SUM(pnl), 0.0),
|
||||
SUM(CASE WHEN pnl > 0 THEN 1 ELSE 0 END),
|
||||
SUM(CASE WHEN pnl < 0 THEN 1 ELSE 0 END)
|
||||
FROM trades
|
||||
WHERE DATE(timestamp) = ? AND market = ?
|
||||
""",
|
||||
(date, market),
|
||||
).fetchone()
|
||||
total_pnl = round(float(trade_stats[0] or 0.0), 2) if trade_stats else 0.0
|
||||
wins = int(trade_stats[1] or 0) if trade_stats else 0
|
||||
losses = int(trade_stats[2] or 0) if trade_stats else 0
|
||||
win_rate = round((wins / (wins + losses)) * 100, 2) if (wins + losses) > 0 else 0.0
|
||||
|
||||
top_winners = [
|
||||
row[0]
|
||||
for row in self._conn.execute(
|
||||
"""
|
||||
SELECT stock_code, SUM(pnl) AS stock_pnl
|
||||
FROM trades
|
||||
WHERE DATE(timestamp) = ? AND market = ?
|
||||
GROUP BY stock_code
|
||||
HAVING stock_pnl > 0
|
||||
ORDER BY stock_pnl DESC
|
||||
LIMIT 3
|
||||
""",
|
||||
(date, market),
|
||||
).fetchall()
|
||||
]
|
||||
|
||||
top_losers = [
|
||||
row[0]
|
||||
for row in self._conn.execute(
|
||||
"""
|
||||
SELECT stock_code, SUM(pnl) AS stock_pnl
|
||||
FROM trades
|
||||
WHERE DATE(timestamp) = ? AND market = ?
|
||||
GROUP BY stock_code
|
||||
HAVING stock_pnl < 0
|
||||
ORDER BY stock_pnl ASC
|
||||
LIMIT 3
|
||||
""",
|
||||
(date, market),
|
||||
).fetchall()
|
||||
]
|
||||
|
||||
return DailyScorecard(
|
||||
date=date,
|
||||
market=market,
|
||||
total_decisions=total_decisions,
|
||||
buys=buys,
|
||||
sells=sells,
|
||||
holds=holds,
|
||||
total_pnl=total_pnl,
|
||||
win_rate=win_rate,
|
||||
avg_confidence=avg_confidence,
|
||||
scenario_match_rate=scenario_match_rate,
|
||||
top_winners=top_winners,
|
||||
top_losers=top_losers,
|
||||
lessons=[],
|
||||
cross_market_note="",
|
||||
)
|
||||
|
||||
async def generate_lessons(self, scorecard: DailyScorecard) -> list[str]:
|
||||
"""Generate concise lessons from scorecard metrics using Gemini."""
|
||||
if self._gemini is None:
|
||||
return []
|
||||
|
||||
prompt = (
|
||||
"You are a trading performance reviewer.\n"
|
||||
"Return ONLY a JSON array of 1-3 short lessons in English.\n"
|
||||
f"Market: {scorecard.market}\n"
|
||||
f"Date: {scorecard.date}\n"
|
||||
f"Total decisions: {scorecard.total_decisions}\n"
|
||||
f"Buys/Sells/Holds: {scorecard.buys}/{scorecard.sells}/{scorecard.holds}\n"
|
||||
f"Total PnL: {scorecard.total_pnl}\n"
|
||||
f"Win rate: {scorecard.win_rate}%\n"
|
||||
f"Average confidence: {scorecard.avg_confidence}\n"
|
||||
f"Scenario match rate: {scorecard.scenario_match_rate}%\n"
|
||||
f"Top winners: {', '.join(scorecard.top_winners) or 'N/A'}\n"
|
||||
f"Top losers: {', '.join(scorecard.top_losers) or 'N/A'}\n"
|
||||
)
|
||||
|
||||
try:
|
||||
decision = await self._gemini.decide(
|
||||
{
|
||||
"stock_code": "REVIEW",
|
||||
"market_name": scorecard.market,
|
||||
"current_price": 0,
|
||||
"prompt_override": prompt,
|
||||
}
|
||||
)
|
||||
return self._parse_lessons(decision.rationale)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to generate daily lessons: %s", exc)
|
||||
return []
|
||||
|
||||
def store_scorecard_in_context(self, scorecard: DailyScorecard) -> None:
|
||||
"""Store scorecard in L6 using market-scoped key."""
|
||||
self._context_store.set_context(
|
||||
ContextLayer.L6_DAILY,
|
||||
scorecard.date,
|
||||
f"scorecard_{scorecard.market}",
|
||||
asdict(scorecard),
|
||||
)
|
||||
|
||||
def _parse_lessons(self, raw_text: str) -> list[str]:
|
||||
"""Parse lessons from JSON array response or fallback text."""
|
||||
raw_text = raw_text.strip()
|
||||
try:
|
||||
parsed = json.loads(raw_text)
|
||||
if isinstance(parsed, list):
|
||||
return [str(item).strip() for item in parsed if str(item).strip()][:3]
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
match = re.search(r"\[.*\]", raw_text, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
parsed = json.loads(match.group(0))
|
||||
if isinstance(parsed, list):
|
||||
return [str(item).strip() for item in parsed if str(item).strip()][:3]
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
lines = [line.strip("-* \t") for line in raw_text.splitlines() if line.strip()]
|
||||
return lines[:3]
|
||||
@@ -1,10 +1,10 @@
|
||||
"""Evolution Engine — analyzes trade logs and generates new strategies.
|
||||
|
||||
This module:
|
||||
1. Reads trade_logs.db to identify failing patterns
|
||||
2. Asks Gemini to generate a new strategy class
|
||||
3. Runs pytest on the generated file
|
||||
4. Creates a simulated PR if tests pass
|
||||
1. Uses DecisionLogger.get_losing_decisions() to identify failing patterns
|
||||
2. Analyzes failure patterns by time, market conditions, stock characteristics
|
||||
3. Asks Gemini to generate improved strategy recommendations
|
||||
4. Generates new strategy classes with enhanced decision-making logic
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -14,6 +14,7 @@ import logging
|
||||
import sqlite3
|
||||
import subprocess
|
||||
import textwrap
|
||||
from collections import Counter
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -21,6 +22,8 @@ from typing import Any
|
||||
from google import genai
|
||||
|
||||
from src.config import Settings
|
||||
from src.db import init_db
|
||||
from src.logging.decision_logger import DecisionLogger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -53,29 +56,105 @@ class EvolutionOptimizer:
|
||||
self._db_path = settings.DB_PATH
|
||||
self._client = genai.Client(api_key=settings.GEMINI_API_KEY)
|
||||
self._model_name = settings.GEMINI_MODEL
|
||||
self._conn = init_db(self._db_path)
|
||||
self._decision_logger = DecisionLogger(self._conn)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Analysis
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def analyze_failures(self, limit: int = 50) -> list[dict[str, Any]]:
|
||||
"""Find trades where high confidence led to losses."""
|
||||
conn = sqlite3.connect(self._db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
try:
|
||||
rows = conn.execute(
|
||||
"""Find high-confidence decisions that resulted in losses.
|
||||
|
||||
Uses DecisionLogger.get_losing_decisions() to retrieve failures.
|
||||
"""
|
||||
SELECT stock_code, action, confidence, pnl, rationale, timestamp
|
||||
FROM trades
|
||||
WHERE confidence >= 80 AND pnl < 0
|
||||
ORDER BY pnl ASC
|
||||
LIMIT ?
|
||||
""",
|
||||
(limit,),
|
||||
).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
finally:
|
||||
conn.close()
|
||||
losing_decisions = self._decision_logger.get_losing_decisions(
|
||||
min_confidence=80, min_loss=-100.0
|
||||
)
|
||||
|
||||
# Limit results
|
||||
if len(losing_decisions) > limit:
|
||||
losing_decisions = losing_decisions[:limit]
|
||||
|
||||
# Convert to dict format for analysis
|
||||
failures = []
|
||||
for decision in losing_decisions:
|
||||
failures.append({
|
||||
"decision_id": decision.decision_id,
|
||||
"timestamp": decision.timestamp,
|
||||
"stock_code": decision.stock_code,
|
||||
"market": decision.market,
|
||||
"exchange_code": decision.exchange_code,
|
||||
"action": decision.action,
|
||||
"confidence": decision.confidence,
|
||||
"rationale": decision.rationale,
|
||||
"outcome_pnl": decision.outcome_pnl,
|
||||
"outcome_accuracy": decision.outcome_accuracy,
|
||||
"context_snapshot": decision.context_snapshot,
|
||||
"input_data": decision.input_data,
|
||||
})
|
||||
|
||||
return failures
|
||||
|
||||
def identify_failure_patterns(
|
||||
self, failures: list[dict[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
"""Identify patterns in losing decisions.
|
||||
|
||||
Analyzes:
|
||||
- Time patterns (hour of day, day of week)
|
||||
- Market conditions (volatility, volume)
|
||||
- Stock characteristics (price range, market)
|
||||
- Common failure modes in rationale
|
||||
"""
|
||||
if not failures:
|
||||
return {"pattern_count": 0, "patterns": {}}
|
||||
|
||||
patterns = {
|
||||
"markets": Counter(),
|
||||
"actions": Counter(),
|
||||
"hours": Counter(),
|
||||
"avg_confidence": 0.0,
|
||||
"avg_loss": 0.0,
|
||||
"total_failures": len(failures),
|
||||
}
|
||||
|
||||
total_confidence = 0
|
||||
total_loss = 0.0
|
||||
|
||||
for failure in failures:
|
||||
# Market distribution
|
||||
patterns["markets"][failure.get("market", "UNKNOWN")] += 1
|
||||
|
||||
# Action distribution
|
||||
patterns["actions"][failure.get("action", "UNKNOWN")] += 1
|
||||
|
||||
# Time pattern (extract hour from ISO timestamp)
|
||||
timestamp = failure.get("timestamp", "")
|
||||
if timestamp:
|
||||
try:
|
||||
dt = datetime.fromisoformat(timestamp)
|
||||
patterns["hours"][dt.hour] += 1
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
# Aggregate metrics
|
||||
total_confidence += failure.get("confidence", 0)
|
||||
total_loss += failure.get("outcome_pnl", 0.0)
|
||||
|
||||
patterns["avg_confidence"] = (
|
||||
round(total_confidence / len(failures), 2) if failures else 0.0
|
||||
)
|
||||
patterns["avg_loss"] = (
|
||||
round(total_loss / len(failures), 2) if failures else 0.0
|
||||
)
|
||||
|
||||
# Convert Counters to regular dicts for JSON serialization
|
||||
patterns["markets"] = dict(patterns["markets"])
|
||||
patterns["actions"] = dict(patterns["actions"])
|
||||
patterns["hours"] = dict(patterns["hours"])
|
||||
|
||||
return patterns
|
||||
|
||||
def get_performance_summary(self) -> dict[str, Any]:
|
||||
"""Return aggregate performance metrics from trade logs."""
|
||||
@@ -109,14 +188,25 @@ class EvolutionOptimizer:
|
||||
async def generate_strategy(self, failures: list[dict[str, Any]]) -> Path | None:
|
||||
"""Ask Gemini to generate a new strategy based on failure analysis.
|
||||
|
||||
Integrates failure patterns and market conditions to create improved strategies.
|
||||
Returns the path to the generated strategy file, or None on failure.
|
||||
"""
|
||||
# Identify failure patterns first
|
||||
patterns = self.identify_failure_patterns(failures)
|
||||
|
||||
prompt = (
|
||||
"You are a quantitative trading strategy developer.\n"
|
||||
"Analyze these failed trades and generate an improved strategy.\n\n"
|
||||
f"Failed trades:\n{json.dumps(failures, indent=2, default=str)}\n\n"
|
||||
"Generate a Python class that inherits from BaseStrategy.\n"
|
||||
"The class must have an `evaluate(self, market_data: dict) -> dict` method.\n"
|
||||
"Analyze these failed trades and their patterns, then generate an improved strategy.\n\n"
|
||||
f"Failure Patterns:\n{json.dumps(patterns, indent=2)}\n\n"
|
||||
f"Sample Failed Trades (first 5):\n"
|
||||
f"{json.dumps(failures[:5], indent=2, default=str)}\n\n"
|
||||
"Based on these patterns, generate an improved trading strategy.\n"
|
||||
"The strategy should:\n"
|
||||
"1. Avoid the identified failure patterns\n"
|
||||
"2. Consider market-specific conditions\n"
|
||||
"3. Adjust confidence based on historical performance\n\n"
|
||||
"Generate a Python method body that inherits from BaseStrategy.\n"
|
||||
"The method signature is: evaluate(self, market_data: dict) -> dict\n"
|
||||
"The method must return a dict with keys: action, confidence, rationale.\n"
|
||||
"Respond with ONLY the method body (Python code), no class definition.\n"
|
||||
)
|
||||
@@ -147,10 +237,15 @@ class EvolutionOptimizer:
|
||||
# Indent the body for the class method
|
||||
indented_body = textwrap.indent(body, " ")
|
||||
|
||||
# Generate rationale from patterns
|
||||
rationale = f"Auto-evolved from {len(failures)} failures. "
|
||||
rationale += f"Primary failure markets: {list(patterns.get('markets', {}).keys())}. "
|
||||
rationale += f"Average loss: {patterns.get('avg_loss', 0.0)}"
|
||||
|
||||
content = STRATEGY_TEMPLATE.format(
|
||||
name=version,
|
||||
timestamp=datetime.now(UTC).isoformat(),
|
||||
rationale="Auto-evolved from failure analysis",
|
||||
rationale=rationale,
|
||||
class_name=class_name,
|
||||
body=indented_body.strip(),
|
||||
)
|
||||
|
||||
303
src/evolution/performance_tracker.py
Normal file
303
src/evolution/performance_tracker.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""Performance tracking system for strategy monitoring.
|
||||
|
||||
Tracks win rates, monitors improvement over time,
|
||||
and provides performance metrics dashboard.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StrategyMetrics:
|
||||
"""Performance metrics for a strategy over a time period."""
|
||||
|
||||
strategy_name: str
|
||||
period_start: str
|
||||
period_end: str
|
||||
total_trades: int
|
||||
wins: int
|
||||
losses: int
|
||||
holds: int
|
||||
win_rate: float
|
||||
avg_pnl: float
|
||||
total_pnl: float
|
||||
best_trade: float
|
||||
worst_trade: float
|
||||
avg_confidence: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class PerformanceDashboard:
|
||||
"""Comprehensive performance dashboard."""
|
||||
|
||||
generated_at: str
|
||||
overall_metrics: StrategyMetrics
|
||||
daily_metrics: list[StrategyMetrics]
|
||||
weekly_metrics: list[StrategyMetrics]
|
||||
improvement_trend: dict[str, Any]
|
||||
|
||||
|
||||
class PerformanceTracker:
|
||||
"""Tracks and monitors strategy performance over time."""
|
||||
|
||||
def __init__(self, db_path: str) -> None:
|
||||
"""Initialize performance tracker.
|
||||
|
||||
Args:
|
||||
db_path: Path to the trade logs database
|
||||
"""
|
||||
self._db_path = db_path
|
||||
|
||||
def get_strategy_metrics(
|
||||
self,
|
||||
strategy_name: str | None = None,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
) -> StrategyMetrics:
|
||||
"""Get performance metrics for a strategy over a time period.
|
||||
|
||||
Args:
|
||||
strategy_name: Name of the strategy (None = all strategies)
|
||||
start_date: Start date in ISO format (None = beginning of time)
|
||||
end_date: End date in ISO format (None = now)
|
||||
|
||||
Returns:
|
||||
StrategyMetrics object with performance data
|
||||
"""
|
||||
conn = sqlite3.connect(self._db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
try:
|
||||
# Build query with optional filters
|
||||
query = """
|
||||
SELECT
|
||||
COUNT(*) as total_trades,
|
||||
SUM(CASE WHEN pnl > 0 THEN 1 ELSE 0 END) as wins,
|
||||
SUM(CASE WHEN pnl < 0 THEN 1 ELSE 0 END) as losses,
|
||||
SUM(CASE WHEN action = 'HOLD' THEN 1 ELSE 0 END) as holds,
|
||||
COALESCE(AVG(CASE WHEN pnl IS NOT NULL THEN pnl END), 0) as avg_pnl,
|
||||
COALESCE(SUM(CASE WHEN pnl IS NOT NULL THEN pnl ELSE 0 END), 0) as total_pnl,
|
||||
COALESCE(MAX(pnl), 0) as best_trade,
|
||||
COALESCE(MIN(pnl), 0) as worst_trade,
|
||||
COALESCE(AVG(confidence), 0) as avg_confidence,
|
||||
MIN(timestamp) as period_start,
|
||||
MAX(timestamp) as period_end
|
||||
FROM trades
|
||||
WHERE 1=1
|
||||
"""
|
||||
params: list[Any] = []
|
||||
|
||||
if start_date:
|
||||
query += " AND timestamp >= ?"
|
||||
params.append(start_date)
|
||||
|
||||
if end_date:
|
||||
query += " AND timestamp <= ?"
|
||||
params.append(end_date)
|
||||
|
||||
# Note: Currently trades table doesn't have strategy_name column
|
||||
# This is a placeholder for future extension
|
||||
|
||||
row = conn.execute(query, params).fetchone()
|
||||
|
||||
total_trades = row["total_trades"] or 0
|
||||
wins = row["wins"] or 0
|
||||
win_rate = (wins / total_trades * 100) if total_trades > 0 else 0.0
|
||||
|
||||
return StrategyMetrics(
|
||||
strategy_name=strategy_name or "default",
|
||||
period_start=row["period_start"] or "",
|
||||
period_end=row["period_end"] or "",
|
||||
total_trades=total_trades,
|
||||
wins=wins,
|
||||
losses=row["losses"] or 0,
|
||||
holds=row["holds"] or 0,
|
||||
win_rate=round(win_rate, 2),
|
||||
avg_pnl=round(row["avg_pnl"], 2),
|
||||
total_pnl=round(row["total_pnl"], 2),
|
||||
best_trade=round(row["best_trade"], 2),
|
||||
worst_trade=round(row["worst_trade"], 2),
|
||||
avg_confidence=round(row["avg_confidence"], 2),
|
||||
)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_daily_metrics(
|
||||
self, days: int = 7, strategy_name: str | None = None
|
||||
) -> list[StrategyMetrics]:
|
||||
"""Get daily performance metrics for the last N days.
|
||||
|
||||
Args:
|
||||
days: Number of days to retrieve (default 7)
|
||||
strategy_name: Name of the strategy (None = all strategies)
|
||||
|
||||
Returns:
|
||||
List of StrategyMetrics, one per day
|
||||
"""
|
||||
metrics = []
|
||||
end_date = datetime.now(UTC)
|
||||
|
||||
for i in range(days):
|
||||
day_end = end_date - timedelta(days=i)
|
||||
day_start = day_end - timedelta(days=1)
|
||||
|
||||
day_metrics = self.get_strategy_metrics(
|
||||
strategy_name=strategy_name,
|
||||
start_date=day_start.isoformat(),
|
||||
end_date=day_end.isoformat(),
|
||||
)
|
||||
metrics.append(day_metrics)
|
||||
|
||||
return metrics
|
||||
|
||||
def get_weekly_metrics(
|
||||
self, weeks: int = 4, strategy_name: str | None = None
|
||||
) -> list[StrategyMetrics]:
|
||||
"""Get weekly performance metrics for the last N weeks.
|
||||
|
||||
Args:
|
||||
weeks: Number of weeks to retrieve (default 4)
|
||||
strategy_name: Name of the strategy (None = all strategies)
|
||||
|
||||
Returns:
|
||||
List of StrategyMetrics, one per week
|
||||
"""
|
||||
metrics = []
|
||||
end_date = datetime.now(UTC)
|
||||
|
||||
for i in range(weeks):
|
||||
week_end = end_date - timedelta(weeks=i)
|
||||
week_start = week_end - timedelta(weeks=1)
|
||||
|
||||
week_metrics = self.get_strategy_metrics(
|
||||
strategy_name=strategy_name,
|
||||
start_date=week_start.isoformat(),
|
||||
end_date=week_end.isoformat(),
|
||||
)
|
||||
metrics.append(week_metrics)
|
||||
|
||||
return metrics
|
||||
|
||||
def calculate_improvement_trend(
|
||||
self, metrics_history: list[StrategyMetrics]
|
||||
) -> dict[str, Any]:
|
||||
"""Calculate improvement trend from historical metrics.
|
||||
|
||||
Args:
|
||||
metrics_history: List of StrategyMetrics ordered from oldest to newest
|
||||
|
||||
Returns:
|
||||
Dictionary with trend analysis
|
||||
"""
|
||||
if len(metrics_history) < 2:
|
||||
return {
|
||||
"trend": "insufficient_data",
|
||||
"win_rate_change": 0.0,
|
||||
"pnl_change": 0.0,
|
||||
"confidence_change": 0.0,
|
||||
}
|
||||
|
||||
oldest = metrics_history[0]
|
||||
newest = metrics_history[-1]
|
||||
|
||||
win_rate_change = newest.win_rate - oldest.win_rate
|
||||
pnl_change = newest.avg_pnl - oldest.avg_pnl
|
||||
confidence_change = newest.avg_confidence - oldest.avg_confidence
|
||||
|
||||
# Determine overall trend
|
||||
if win_rate_change > 5.0 and pnl_change > 0:
|
||||
trend = "improving"
|
||||
elif win_rate_change < -5.0 or pnl_change < 0:
|
||||
trend = "declining"
|
||||
else:
|
||||
trend = "stable"
|
||||
|
||||
return {
|
||||
"trend": trend,
|
||||
"win_rate_change": round(win_rate_change, 2),
|
||||
"pnl_change": round(pnl_change, 2),
|
||||
"confidence_change": round(confidence_change, 2),
|
||||
"period_count": len(metrics_history),
|
||||
}
|
||||
|
||||
def generate_dashboard(
|
||||
self, strategy_name: str | None = None
|
||||
) -> PerformanceDashboard:
|
||||
"""Generate a comprehensive performance dashboard.
|
||||
|
||||
Args:
|
||||
strategy_name: Name of the strategy (None = all strategies)
|
||||
|
||||
Returns:
|
||||
PerformanceDashboard with all metrics
|
||||
"""
|
||||
# Get overall metrics
|
||||
overall_metrics = self.get_strategy_metrics(strategy_name=strategy_name)
|
||||
|
||||
# Get daily metrics (last 7 days)
|
||||
daily_metrics = self.get_daily_metrics(days=7, strategy_name=strategy_name)
|
||||
|
||||
# Get weekly metrics (last 4 weeks)
|
||||
weekly_metrics = self.get_weekly_metrics(weeks=4, strategy_name=strategy_name)
|
||||
|
||||
# Calculate improvement trend
|
||||
improvement_trend = self.calculate_improvement_trend(weekly_metrics[::-1])
|
||||
|
||||
return PerformanceDashboard(
|
||||
generated_at=datetime.now(UTC).isoformat(),
|
||||
overall_metrics=overall_metrics,
|
||||
daily_metrics=daily_metrics,
|
||||
weekly_metrics=weekly_metrics,
|
||||
improvement_trend=improvement_trend,
|
||||
)
|
||||
|
||||
def export_dashboard_json(
|
||||
self, dashboard: PerformanceDashboard
|
||||
) -> str:
|
||||
"""Export dashboard as JSON string.
|
||||
|
||||
Args:
|
||||
dashboard: PerformanceDashboard object
|
||||
|
||||
Returns:
|
||||
JSON string representation
|
||||
"""
|
||||
data = {
|
||||
"generated_at": dashboard.generated_at,
|
||||
"overall_metrics": asdict(dashboard.overall_metrics),
|
||||
"daily_metrics": [asdict(m) for m in dashboard.daily_metrics],
|
||||
"weekly_metrics": [asdict(m) for m in dashboard.weekly_metrics],
|
||||
"improvement_trend": dashboard.improvement_trend,
|
||||
}
|
||||
return json.dumps(data, indent=2)
|
||||
|
||||
def log_dashboard(self, dashboard: PerformanceDashboard) -> None:
|
||||
"""Log dashboard summary to logger.
|
||||
|
||||
Args:
|
||||
dashboard: PerformanceDashboard object
|
||||
"""
|
||||
logger.info("=" * 60)
|
||||
logger.info("PERFORMANCE DASHBOARD")
|
||||
logger.info("=" * 60)
|
||||
logger.info("Generated: %s", dashboard.generated_at)
|
||||
logger.info("")
|
||||
logger.info("Overall Performance:")
|
||||
logger.info(" Total Trades: %d", dashboard.overall_metrics.total_trades)
|
||||
logger.info(" Win Rate: %.2f%%", dashboard.overall_metrics.win_rate)
|
||||
logger.info(" Average P&L: %.2f", dashboard.overall_metrics.avg_pnl)
|
||||
logger.info(" Total P&L: %.2f", dashboard.overall_metrics.total_pnl)
|
||||
logger.info("")
|
||||
logger.info("Improvement Trend (%s):", dashboard.improvement_trend["trend"])
|
||||
logger.info(" Win Rate Change: %+.2f%%", dashboard.improvement_trend["win_rate_change"])
|
||||
logger.info(" P&L Change: %+.2f", dashboard.improvement_trend["pnl_change"])
|
||||
logger.info("=" * 60)
|
||||
25
src/evolution/scorecard.py
Normal file
25
src/evolution/scorecard.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""Daily scorecard model for end-of-day performance review."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class DailyScorecard:
|
||||
"""Structured daily performance snapshot for a single market."""
|
||||
|
||||
date: str
|
||||
market: str
|
||||
total_decisions: int
|
||||
buys: int
|
||||
sells: int
|
||||
holds: int
|
||||
total_pnl: float
|
||||
win_rate: float
|
||||
avg_confidence: float
|
||||
scenario_match_rate: float
|
||||
top_winners: list[str] = field(default_factory=list)
|
||||
top_losers: list[str] = field(default_factory=list)
|
||||
lessons: list[str] = field(default_factory=list)
|
||||
cross_market_note: str = ""
|
||||
2943
src/main.py
2943
src/main.py
File diff suppressed because it is too large
Load Diff
@@ -123,6 +123,23 @@ MARKETS: dict[str, MarketInfo] = {
|
||||
),
|
||||
}
|
||||
|
||||
MARKET_SHORTHAND: dict[str, list[str]] = {
|
||||
"US": ["US_NASDAQ", "US_NYSE", "US_AMEX"],
|
||||
"CN": ["CN_SHA", "CN_SZA"],
|
||||
"VN": ["VN_HAN", "VN_HCM"],
|
||||
}
|
||||
|
||||
|
||||
def expand_market_codes(codes: list[str]) -> list[str]:
|
||||
"""Expand shorthand market codes into concrete exchange market codes."""
|
||||
expanded: list[str] = []
|
||||
for code in codes:
|
||||
if code in MARKET_SHORTHAND:
|
||||
expanded.extend(MARKET_SHORTHAND[code])
|
||||
else:
|
||||
expanded.append(code)
|
||||
return expanded
|
||||
|
||||
|
||||
def is_market_open(market: MarketInfo, now: datetime | None = None) -> bool:
|
||||
"""
|
||||
|
||||
350
src/notifications/README.md
Normal file
350
src/notifications/README.md
Normal file
@@ -0,0 +1,350 @@
|
||||
# Telegram Notifications
|
||||
|
||||
Real-time trading event notifications via Telegram Bot API.
|
||||
|
||||
## Setup
|
||||
|
||||
### 1. Create a Telegram Bot
|
||||
|
||||
1. Open Telegram and message [@BotFather](https://t.me/BotFather)
|
||||
2. Send `/newbot` command
|
||||
3. Follow prompts to name your bot
|
||||
4. Save the **bot token** (looks like `1234567890:ABCdefGHIjklMNOpqrsTUVwxyz`)
|
||||
|
||||
### 2. Get Your Chat ID
|
||||
|
||||
**Option A: Using @userinfobot**
|
||||
1. Message [@userinfobot](https://t.me/userinfobot) on Telegram
|
||||
2. Send `/start`
|
||||
3. Save your numeric **chat ID** (e.g., `123456789`)
|
||||
|
||||
**Option B: Using @RawDataBot**
|
||||
1. Message [@RawDataBot](https://t.me/rawdatabot) on Telegram
|
||||
2. Look for `"id":` in the JSON response
|
||||
3. Save your numeric **chat ID**
|
||||
|
||||
### 3. Configure Environment
|
||||
|
||||
Add to your `.env` file:
|
||||
|
||||
```bash
|
||||
TELEGRAM_BOT_TOKEN=1234567890:ABCdefGHIjklMNOpqrsTUVwxyz
|
||||
TELEGRAM_CHAT_ID=123456789
|
||||
TELEGRAM_ENABLED=true
|
||||
```
|
||||
|
||||
### 4. Test the Bot
|
||||
|
||||
Start a conversation with your bot on Telegram first (send `/start`), then run:
|
||||
|
||||
```bash
|
||||
python -m src.main --mode=paper
|
||||
```
|
||||
|
||||
You should receive a startup notification.
|
||||
|
||||
## Message Examples
|
||||
|
||||
### Trade Execution
|
||||
```
|
||||
🟢 BUY
|
||||
Symbol: AAPL (United States)
|
||||
Quantity: 10 shares
|
||||
Price: 150.25
|
||||
Confidence: 85%
|
||||
```
|
||||
|
||||
### Circuit Breaker
|
||||
```
|
||||
🚨 CIRCUIT BREAKER TRIPPED
|
||||
P&L: -3.15% (threshold: -3.0%)
|
||||
Trading halted for safety
|
||||
```
|
||||
|
||||
### Fat-Finger Protection
|
||||
```
|
||||
⚠️ Fat-Finger Protection
|
||||
Order rejected: TSLA
|
||||
Attempted: 45.0% of cash
|
||||
Max allowed: 30%
|
||||
Amount: 45,000 / 100,000
|
||||
```
|
||||
|
||||
### Market Open/Close
|
||||
```
|
||||
ℹ️ Market Open
|
||||
Korea trading session started
|
||||
|
||||
ℹ️ Market Close
|
||||
Korea trading session ended
|
||||
📈 P&L: +1.25%
|
||||
```
|
||||
|
||||
### System Status
|
||||
```
|
||||
📝 System Started
|
||||
Mode: PAPER
|
||||
Markets: KRX, NASDAQ
|
||||
|
||||
System Shutdown
|
||||
Normal shutdown
|
||||
```
|
||||
|
||||
## Notification Priorities
|
||||
|
||||
| Priority | Emoji | Use Case |
|
||||
|----------|-------|----------|
|
||||
| LOW | ℹ️ | Market open/close |
|
||||
| MEDIUM | 📊 | Trade execution, system start/stop |
|
||||
| HIGH | ⚠️ | Fat-finger protection, errors |
|
||||
| CRITICAL | 🚨 | Circuit breaker trips |
|
||||
|
||||
## Rate Limiting
|
||||
|
||||
- Default: 1 message per second
|
||||
- Prevents hitting Telegram's global rate limits
|
||||
- Configurable via `rate_limit` parameter
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### No notifications received
|
||||
|
||||
1. **Check bot configuration**
|
||||
```bash
|
||||
# Verify env variables are set
|
||||
grep TELEGRAM .env
|
||||
```
|
||||
|
||||
2. **Start conversation with bot**
|
||||
- Open bot in Telegram
|
||||
- Send `/start` command
|
||||
- Bot cannot message users who haven't started a conversation
|
||||
|
||||
3. **Check logs**
|
||||
```bash
|
||||
# Look for Telegram-related errors
|
||||
python -m src.main --mode=paper 2>&1 | grep -i telegram
|
||||
```
|
||||
|
||||
4. **Verify bot token**
|
||||
```bash
|
||||
curl https://api.telegram.org/bot<YOUR_TOKEN>/getMe
|
||||
# Should return bot info (not 401 error)
|
||||
```
|
||||
|
||||
5. **Verify chat ID**
|
||||
```bash
|
||||
curl -X POST https://api.telegram.org/bot<YOUR_TOKEN>/sendMessage \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"chat_id": "<YOUR_CHAT_ID>", "text": "Test"}'
|
||||
# Should send a test message
|
||||
```
|
||||
|
||||
### Notifications delayed
|
||||
|
||||
- Check rate limiter settings
|
||||
- Verify network connection
|
||||
- Look for timeout errors in logs
|
||||
|
||||
### "Chat not found" error
|
||||
|
||||
- Incorrect chat ID
|
||||
- Bot blocked by user
|
||||
- Need to send `/start` to bot first
|
||||
|
||||
### "Unauthorized" error
|
||||
|
||||
- Invalid bot token
|
||||
- Token revoked (regenerate with @BotFather)
|
||||
|
||||
## Graceful Degradation
|
||||
|
||||
The system works without Telegram notifications:
|
||||
|
||||
- Missing credentials → notifications disabled automatically
|
||||
- API errors → logged but trading continues
|
||||
- Network timeouts → trading loop unaffected
|
||||
- Rate limiting → messages queued, trading proceeds
|
||||
|
||||
**Notifications never crash the trading system.**
|
||||
|
||||
## Security Notes
|
||||
|
||||
- Never commit `.env` file with credentials
|
||||
- Bot token grants full bot control
|
||||
- Chat ID is not sensitive (just a number)
|
||||
- Messages are sent over HTTPS
|
||||
- No trading credentials in notifications
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Group Notifications
|
||||
|
||||
1. Add bot to Telegram group
|
||||
2. Get group chat ID (negative number like `-123456789`)
|
||||
3. Use group chat ID in `TELEGRAM_CHAT_ID`
|
||||
|
||||
### Multiple Recipients
|
||||
|
||||
Create multiple bots or use a broadcast group with multiple members.
|
||||
|
||||
### Custom Rate Limits
|
||||
|
||||
Not currently exposed in config, but can be modified in code:
|
||||
|
||||
```python
|
||||
telegram = TelegramClient(
|
||||
bot_token=settings.TELEGRAM_BOT_TOKEN,
|
||||
chat_id=settings.TELEGRAM_CHAT_ID,
|
||||
rate_limit=2.0, # 2 messages per second
|
||||
)
|
||||
```
|
||||
|
||||
## Bidirectional Commands
|
||||
|
||||
Control your trading bot remotely via Telegram commands. The bot not only sends notifications but also accepts commands for real-time control.
|
||||
|
||||
### Available Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/start` | Welcome message with quick start guide |
|
||||
| `/help` | List all available commands |
|
||||
| `/status` | Current trading status (mode, markets, P&L, circuit breaker) |
|
||||
| `/positions` | View current holdings grouped by market |
|
||||
| `/stop` | Pause all trading operations |
|
||||
| `/resume` | Resume trading operations |
|
||||
|
||||
### Command Examples
|
||||
|
||||
**Check Trading Status**
|
||||
```
|
||||
You: /status
|
||||
|
||||
Bot:
|
||||
📊 Trading Status
|
||||
|
||||
Mode: PAPER
|
||||
Markets: Korea, United States
|
||||
Trading: Active
|
||||
|
||||
Current P&L: +2.50%
|
||||
Circuit Breaker: -3.0%
|
||||
```
|
||||
|
||||
**View Holdings**
|
||||
```
|
||||
You: /positions
|
||||
|
||||
Bot:
|
||||
💼 Current Holdings
|
||||
|
||||
🇰🇷 Korea
|
||||
• 005930: 10 shares @ 70,000
|
||||
• 035420: 5 shares @ 200,000
|
||||
|
||||
🇺🇸 Overseas
|
||||
• AAPL: 15 shares @ 175
|
||||
• TSLA: 8 shares @ 245
|
||||
|
||||
Cash: ₩5,000,000
|
||||
```
|
||||
|
||||
**Pause Trading**
|
||||
```
|
||||
You: /stop
|
||||
|
||||
Bot:
|
||||
⏸️ Trading Paused
|
||||
|
||||
All trading operations have been suspended.
|
||||
Use /resume to restart trading.
|
||||
```
|
||||
|
||||
**Resume Trading**
|
||||
```
|
||||
You: /resume
|
||||
|
||||
Bot:
|
||||
▶️ Trading Resumed
|
||||
|
||||
Trading operations have been restarted.
|
||||
```
|
||||
|
||||
### Security
|
||||
|
||||
**Chat ID Verification**
|
||||
- Commands are only accepted from the configured `TELEGRAM_CHAT_ID`
|
||||
- Unauthorized users receive no response
|
||||
- Command attempts from wrong chat IDs are logged
|
||||
|
||||
**Authorization Required**
|
||||
- Only the bot owner (chat ID in `.env`) can control trading
|
||||
- No way for unauthorized users to discover or use commands
|
||||
- All command executions are logged for audit
|
||||
|
||||
### Configuration
|
||||
|
||||
Add to your `.env` file:
|
||||
|
||||
```bash
|
||||
# Commands are enabled by default
|
||||
TELEGRAM_COMMANDS_ENABLED=true
|
||||
|
||||
# Polling interval (seconds) - how often to check for commands
|
||||
TELEGRAM_POLLING_INTERVAL=1.0
|
||||
```
|
||||
|
||||
To disable commands but keep notifications:
|
||||
```bash
|
||||
TELEGRAM_COMMANDS_ENABLED=false
|
||||
```
|
||||
|
||||
### How It Works
|
||||
|
||||
1. **Long Polling**: Bot checks Telegram API every second for new messages
|
||||
2. **Command Parsing**: Messages starting with `/` are parsed as commands
|
||||
3. **Authentication**: Chat ID is verified before executing any command
|
||||
4. **Execution**: Command handler is called with current bot state
|
||||
5. **Response**: Result is sent back via Telegram
|
||||
|
||||
### Error Handling
|
||||
|
||||
- Command parsing errors → "Unknown command" response
|
||||
- API failures → Graceful degradation, error logged
|
||||
- Invalid state → Appropriate message (e.g., "Trading is already paused")
|
||||
- Trading loop isolation → Command errors never crash trading
|
||||
|
||||
### Troubleshooting Commands
|
||||
|
||||
**Commands not responding**
|
||||
1. Check `TELEGRAM_COMMANDS_ENABLED=true` in `.env`
|
||||
2. Verify you started conversation with `/start`
|
||||
3. Check logs for command handler errors
|
||||
4. Confirm chat ID matches `.env` configuration
|
||||
|
||||
**Wrong chat ID**
|
||||
- Commands from unauthorized chats are silently ignored
|
||||
- Check logs for "unauthorized chat_id" warnings
|
||||
|
||||
**Delayed responses**
|
||||
- Polling interval is 1 second by default
|
||||
- Network latency may add delay
|
||||
- Check `TELEGRAM_POLLING_INTERVAL` setting
|
||||
|
||||
## API Reference
|
||||
|
||||
See `telegram_client.py` for full API documentation.
|
||||
|
||||
### Notification Methods
|
||||
- `notify_trade_execution()` - Trade alerts
|
||||
- `notify_circuit_breaker()` - Emergency stops
|
||||
- `notify_fat_finger()` - Order rejections
|
||||
- `notify_market_open/close()` - Session tracking
|
||||
- `notify_system_start/shutdown()` - Lifecycle events
|
||||
- `notify_error()` - Error alerts
|
||||
|
||||
### Command Handler
|
||||
- `TelegramCommandHandler` - Bidirectional command processing
|
||||
- `register_command()` - Register custom command handlers
|
||||
- `start_polling()` / `stop_polling()` - Lifecycle management
|
||||
5
src/notifications/__init__.py
Normal file
5
src/notifications/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Real-time notifications for trading events."""
|
||||
|
||||
from src.notifications.telegram_client import TelegramClient
|
||||
|
||||
__all__ = ["TelegramClient"]
|
||||
734
src/notifications/telegram_client.py
Normal file
734
src/notifications/telegram_client.py
Normal file
@@ -0,0 +1,734 @@
|
||||
"""Telegram notification client for real-time trading alerts."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, fields
|
||||
from enum import Enum
|
||||
from typing import ClassVar
|
||||
|
||||
import aiohttp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NotificationPriority(Enum):
|
||||
"""Priority levels for notifications with emoji indicators."""
|
||||
|
||||
LOW = ("ℹ️", "info")
|
||||
MEDIUM = ("📊", "medium")
|
||||
HIGH = ("⚠️", "warning")
|
||||
CRITICAL = ("🚨", "critical")
|
||||
|
||||
def __init__(self, emoji: str, label: str) -> None:
|
||||
self.emoji = emoji
|
||||
self.label = label
|
||||
|
||||
|
||||
class LeakyBucket:
|
||||
"""Rate limiter using leaky bucket algorithm."""
|
||||
|
||||
def __init__(self, rate: float, capacity: int = 1) -> None:
|
||||
"""
|
||||
Initialize rate limiter.
|
||||
|
||||
Args:
|
||||
rate: Maximum requests per second
|
||||
capacity: Bucket capacity (burst size)
|
||||
"""
|
||||
self._rate = rate
|
||||
self._capacity = capacity
|
||||
self._tokens = float(capacity)
|
||||
self._last_update = time.monotonic()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def acquire(self) -> None:
|
||||
"""Wait until a token is available, then consume it."""
|
||||
async with self._lock:
|
||||
now = time.monotonic()
|
||||
elapsed = now - self._last_update
|
||||
self._tokens = min(self._capacity, self._tokens + elapsed * self._rate)
|
||||
self._last_update = now
|
||||
|
||||
if self._tokens < 1.0:
|
||||
wait_time = (1.0 - self._tokens) / self._rate
|
||||
await asyncio.sleep(wait_time)
|
||||
self._tokens = 0.0
|
||||
else:
|
||||
self._tokens -= 1.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class NotificationFilter:
|
||||
"""Granular on/off flags for each notification type.
|
||||
|
||||
circuit_breaker is intentionally omitted — it is always sent regardless.
|
||||
"""
|
||||
|
||||
# Maps user-facing command keys to dataclass field names
|
||||
KEYS: ClassVar[dict[str, str]] = {
|
||||
"trades": "trades",
|
||||
"market": "market_open_close",
|
||||
"fatfinger": "fat_finger",
|
||||
"system": "system_events",
|
||||
"playbook": "playbook",
|
||||
"scenario": "scenario_match",
|
||||
"errors": "errors",
|
||||
}
|
||||
|
||||
trades: bool = True
|
||||
market_open_close: bool = True
|
||||
fat_finger: bool = True
|
||||
system_events: bool = True
|
||||
playbook: bool = True
|
||||
scenario_match: bool = True
|
||||
errors: bool = True
|
||||
|
||||
def set_flag(self, key: str, value: bool) -> bool:
|
||||
"""Set a filter flag by user-facing key. Returns False if key is unknown."""
|
||||
field = self.KEYS.get(key.lower())
|
||||
if field is None:
|
||||
return False
|
||||
setattr(self, field, value)
|
||||
return True
|
||||
|
||||
def as_dict(self) -> dict[str, bool]:
|
||||
"""Return {user_key: current_value} for display."""
|
||||
return {k: getattr(self, field) for k, field in self.KEYS.items()}
|
||||
|
||||
|
||||
@dataclass
|
||||
class NotificationMessage:
|
||||
"""Internal notification message structure."""
|
||||
|
||||
priority: NotificationPriority
|
||||
message: str
|
||||
|
||||
|
||||
class TelegramClient:
|
||||
"""Telegram Bot API client for sending trading notifications."""
|
||||
|
||||
API_BASE = "https://api.telegram.org/bot{token}"
|
||||
DEFAULT_TIMEOUT = 5.0 # seconds
|
||||
DEFAULT_RATE = 1.0 # messages per second
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bot_token: str | None = None,
|
||||
chat_id: str | None = None,
|
||||
enabled: bool = True,
|
||||
rate_limit: float = DEFAULT_RATE,
|
||||
notification_filter: NotificationFilter | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize Telegram client.
|
||||
|
||||
Args:
|
||||
bot_token: Telegram bot token from @BotFather
|
||||
chat_id: Target chat ID (user or group)
|
||||
enabled: Enable/disable notifications globally
|
||||
rate_limit: Maximum messages per second
|
||||
notification_filter: Granular per-type on/off flags
|
||||
"""
|
||||
self._bot_token = bot_token
|
||||
self._chat_id = chat_id
|
||||
self._enabled = enabled
|
||||
self._rate_limiter = LeakyBucket(rate=rate_limit)
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._filter = notification_filter if notification_filter is not None else NotificationFilter()
|
||||
|
||||
if not enabled:
|
||||
logger.info("Telegram notifications disabled via configuration")
|
||||
elif bot_token is None or chat_id is None:
|
||||
logger.warning(
|
||||
"Telegram notifications disabled (missing bot_token or chat_id)"
|
||||
)
|
||||
self._enabled = False
|
||||
else:
|
||||
logger.info("Telegram notifications enabled for chat_id=%s", chat_id)
|
||||
|
||||
def _get_session(self) -> aiohttp.ClientSession:
|
||||
"""Get or create aiohttp session."""
|
||||
if self._session is None or self._session.closed:
|
||||
self._session = aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=self.DEFAULT_TIMEOUT)
|
||||
)
|
||||
return self._session
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close HTTP session."""
|
||||
if self._session is not None and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
def set_notification(self, key: str, value: bool) -> bool:
|
||||
"""Toggle a notification type by user-facing key at runtime.
|
||||
|
||||
Args:
|
||||
key: User-facing key (e.g. "scenario", "market", "all")
|
||||
value: True to enable, False to disable
|
||||
|
||||
Returns:
|
||||
True if key was valid, False if unknown.
|
||||
"""
|
||||
if key == "all":
|
||||
for k in NotificationFilter.KEYS:
|
||||
self._filter.set_flag(k, value)
|
||||
return True
|
||||
return self._filter.set_flag(key, value)
|
||||
|
||||
def filter_status(self) -> dict[str, bool]:
|
||||
"""Return current per-type filter state keyed by user-facing names."""
|
||||
return self._filter.as_dict()
|
||||
|
||||
async def send_message(self, text: str, parse_mode: str = "HTML") -> bool:
|
||||
"""
|
||||
Send a generic text message to Telegram.
|
||||
|
||||
Args:
|
||||
text: Message text to send
|
||||
parse_mode: Parse mode for formatting (HTML or Markdown)
|
||||
|
||||
Returns:
|
||||
True if message was sent successfully, False otherwise
|
||||
"""
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
try:
|
||||
await self._rate_limiter.acquire()
|
||||
|
||||
url = f"{self.API_BASE.format(token=self._bot_token)}/sendMessage"
|
||||
payload = {
|
||||
"chat_id": self._chat_id,
|
||||
"text": text,
|
||||
"parse_mode": parse_mode,
|
||||
}
|
||||
|
||||
session = self._get_session()
|
||||
async with session.post(url, json=payload) as resp:
|
||||
if resp.status != 200:
|
||||
error_text = await resp.text()
|
||||
logger.error(
|
||||
"Telegram API error (status=%d): %s", resp.status, error_text
|
||||
)
|
||||
return False
|
||||
logger.debug("Telegram message sent: %s", text[:50])
|
||||
return True
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Telegram message timeout")
|
||||
return False
|
||||
except aiohttp.ClientError as exc:
|
||||
logger.error("Telegram message failed: %s", exc)
|
||||
return False
|
||||
except Exception as exc:
|
||||
logger.error("Unexpected error sending message: %s", exc)
|
||||
return False
|
||||
|
||||
async def _send_notification(self, msg: NotificationMessage) -> None:
|
||||
"""
|
||||
Send notification to Telegram with graceful degradation.
|
||||
|
||||
Args:
|
||||
msg: Notification message to send
|
||||
"""
|
||||
formatted_message = f"{msg.priority.emoji} {msg.message}"
|
||||
await self.send_message(formatted_message)
|
||||
|
||||
async def notify_trade_execution(
|
||||
self,
|
||||
stock_code: str,
|
||||
market: str,
|
||||
action: str,
|
||||
quantity: int,
|
||||
price: float,
|
||||
confidence: float,
|
||||
) -> None:
|
||||
"""
|
||||
Notify trade execution.
|
||||
|
||||
Args:
|
||||
stock_code: Stock ticker symbol
|
||||
market: Market name (e.g., "Korea", "United States")
|
||||
action: "BUY" or "SELL"
|
||||
quantity: Number of shares
|
||||
price: Execution price
|
||||
confidence: AI confidence level (0-100)
|
||||
"""
|
||||
if not self._filter.trades:
|
||||
return
|
||||
emoji = "🟢" if action == "BUY" else "🔴"
|
||||
message = (
|
||||
f"<b>{emoji} {action}</b>\n"
|
||||
f"Symbol: <code>{stock_code}</code> ({market})\n"
|
||||
f"Quantity: {quantity:,} shares\n"
|
||||
f"Price: {price:,.2f}\n"
|
||||
f"Confidence: {confidence:.0f}%"
|
||||
)
|
||||
await self._send_notification(
|
||||
NotificationMessage(priority=NotificationPriority.MEDIUM, message=message)
|
||||
)
|
||||
|
||||
async def notify_market_open(self, market_name: str) -> None:
|
||||
"""
|
||||
Notify market opening.
|
||||
|
||||
Args:
|
||||
market_name: Name of the market (e.g., "Korea", "United States")
|
||||
"""
|
||||
if not self._filter.market_open_close:
|
||||
return
|
||||
message = f"<b>Market Open</b>\n{market_name} trading session started"
|
||||
await self._send_notification(
|
||||
NotificationMessage(priority=NotificationPriority.LOW, message=message)
|
||||
)
|
||||
|
||||
async def notify_market_close(self, market_name: str, pnl_pct: float) -> None:
|
||||
"""
|
||||
Notify market closing.
|
||||
|
||||
Args:
|
||||
market_name: Name of the market
|
||||
pnl_pct: Final P&L percentage for the session
|
||||
"""
|
||||
if not self._filter.market_open_close:
|
||||
return
|
||||
pnl_sign = "+" if pnl_pct >= 0 else ""
|
||||
pnl_emoji = "📈" if pnl_pct >= 0 else "📉"
|
||||
message = (
|
||||
f"<b>Market Close</b>\n"
|
||||
f"{market_name} trading session ended\n"
|
||||
f"{pnl_emoji} P&L: {pnl_sign}{pnl_pct:.2f}%"
|
||||
)
|
||||
await self._send_notification(
|
||||
NotificationMessage(priority=NotificationPriority.LOW, message=message)
|
||||
)
|
||||
|
||||
async def notify_circuit_breaker(
|
||||
self, pnl_pct: float, threshold: float
|
||||
) -> None:
|
||||
"""
|
||||
Notify circuit breaker activation.
|
||||
|
||||
Args:
|
||||
pnl_pct: Current P&L percentage
|
||||
threshold: Circuit breaker threshold
|
||||
"""
|
||||
message = (
|
||||
f"<b>CIRCUIT BREAKER TRIPPED</b>\n"
|
||||
f"P&L: {pnl_pct:.2f}% (threshold: {threshold:.1f}%)\n"
|
||||
f"Trading halted for safety"
|
||||
)
|
||||
await self._send_notification(
|
||||
NotificationMessage(priority=NotificationPriority.CRITICAL, message=message)
|
||||
)
|
||||
|
||||
async def notify_fat_finger(
|
||||
self,
|
||||
stock_code: str,
|
||||
order_amount: float,
|
||||
total_cash: float,
|
||||
max_pct: float,
|
||||
) -> None:
|
||||
"""
|
||||
Notify fat-finger protection rejection.
|
||||
|
||||
Args:
|
||||
stock_code: Stock ticker symbol
|
||||
order_amount: Attempted order amount
|
||||
total_cash: Total available cash
|
||||
max_pct: Maximum allowed percentage
|
||||
"""
|
||||
if not self._filter.fat_finger:
|
||||
return
|
||||
attempted_pct = (order_amount / total_cash) * 100 if total_cash > 0 else 0
|
||||
message = (
|
||||
f"<b>Fat-Finger Protection</b>\n"
|
||||
f"Order rejected: <code>{stock_code}</code>\n"
|
||||
f"Attempted: {attempted_pct:.1f}% of cash\n"
|
||||
f"Max allowed: {max_pct:.0f}%\n"
|
||||
f"Amount: {order_amount:,.0f} / {total_cash:,.0f}"
|
||||
)
|
||||
await self._send_notification(
|
||||
NotificationMessage(priority=NotificationPriority.HIGH, message=message)
|
||||
)
|
||||
|
||||
async def notify_system_start(
|
||||
self, mode: str, enabled_markets: list[str]
|
||||
) -> None:
|
||||
"""
|
||||
Notify system startup.
|
||||
|
||||
Args:
|
||||
mode: Trading mode ("paper" or "live")
|
||||
enabled_markets: List of enabled market codes
|
||||
"""
|
||||
if not self._filter.system_events:
|
||||
return
|
||||
mode_emoji = "📝" if mode == "paper" else "💰"
|
||||
markets_str = ", ".join(enabled_markets)
|
||||
message = (
|
||||
f"<b>{mode_emoji} System Started</b>\n"
|
||||
f"Mode: {mode.upper()}\n"
|
||||
f"Markets: {markets_str}"
|
||||
)
|
||||
await self._send_notification(
|
||||
NotificationMessage(priority=NotificationPriority.MEDIUM, message=message)
|
||||
)
|
||||
|
||||
async def notify_playbook_generated(
|
||||
self,
|
||||
market: str,
|
||||
stock_count: int,
|
||||
scenario_count: int,
|
||||
token_count: int,
|
||||
) -> None:
|
||||
"""
|
||||
Notify that a daily playbook was generated.
|
||||
|
||||
Args:
|
||||
market: Market code (e.g., "KR", "US")
|
||||
stock_count: Number of stocks in the playbook
|
||||
scenario_count: Total number of scenarios
|
||||
token_count: Gemini token usage for the playbook
|
||||
"""
|
||||
if not self._filter.playbook:
|
||||
return
|
||||
message = (
|
||||
f"<b>Playbook Generated</b>\n"
|
||||
f"Market: {market}\n"
|
||||
f"Stocks: {stock_count}\n"
|
||||
f"Scenarios: {scenario_count}\n"
|
||||
f"Tokens: {token_count}"
|
||||
)
|
||||
await self._send_notification(
|
||||
NotificationMessage(priority=NotificationPriority.MEDIUM, message=message)
|
||||
)
|
||||
|
||||
async def notify_scenario_matched(
|
||||
self,
|
||||
stock_code: str,
|
||||
action: str,
|
||||
condition_summary: str,
|
||||
confidence: float,
|
||||
) -> None:
|
||||
"""
|
||||
Notify that a scenario matched for a stock.
|
||||
|
||||
Args:
|
||||
stock_code: Stock ticker symbol
|
||||
action: Scenario action (BUY/SELL/HOLD/REDUCE_ALL)
|
||||
condition_summary: Short summary of the matched condition
|
||||
confidence: Scenario confidence (0-100)
|
||||
"""
|
||||
if not self._filter.scenario_match:
|
||||
return
|
||||
message = (
|
||||
f"<b>Scenario Matched</b>\n"
|
||||
f"Symbol: <code>{stock_code}</code>\n"
|
||||
f"Action: {action}\n"
|
||||
f"Condition: {condition_summary}\n"
|
||||
f"Confidence: {confidence:.0f}%"
|
||||
)
|
||||
await self._send_notification(
|
||||
NotificationMessage(priority=NotificationPriority.HIGH, message=message)
|
||||
)
|
||||
|
||||
async def notify_playbook_failed(self, market: str, reason: str) -> None:
|
||||
"""
|
||||
Notify that playbook generation failed.
|
||||
|
||||
Args:
|
||||
market: Market code (e.g., "KR", "US")
|
||||
reason: Failure reason summary
|
||||
"""
|
||||
if not self._filter.playbook:
|
||||
return
|
||||
message = (
|
||||
f"<b>Playbook Failed</b>\n"
|
||||
f"Market: {market}\n"
|
||||
f"Reason: {reason[:200]}"
|
||||
)
|
||||
await self._send_notification(
|
||||
NotificationMessage(priority=NotificationPriority.HIGH, message=message)
|
||||
)
|
||||
|
||||
async def notify_system_shutdown(self, reason: str) -> None:
|
||||
"""
|
||||
Notify system shutdown.
|
||||
|
||||
Args:
|
||||
reason: Reason for shutdown (e.g., "Normal shutdown", "Circuit breaker")
|
||||
"""
|
||||
if not self._filter.system_events:
|
||||
return
|
||||
message = f"<b>System Shutdown</b>\n{reason}"
|
||||
priority = (
|
||||
NotificationPriority.CRITICAL
|
||||
if "circuit breaker" in reason.lower()
|
||||
else NotificationPriority.MEDIUM
|
||||
)
|
||||
await self._send_notification(
|
||||
NotificationMessage(priority=priority, message=message)
|
||||
)
|
||||
|
||||
async def notify_unfilled_order(
|
||||
self,
|
||||
stock_code: str,
|
||||
market: str,
|
||||
action: str,
|
||||
quantity: int,
|
||||
outcome: str,
|
||||
new_price: float | None = None,
|
||||
) -> None:
|
||||
"""Notify about an unfilled overseas order that was cancelled or resubmitted.
|
||||
|
||||
Args:
|
||||
stock_code: Stock ticker symbol.
|
||||
market: Exchange/market code (e.g., "NASD", "SEHK").
|
||||
action: "BUY" or "SELL".
|
||||
quantity: Unfilled quantity.
|
||||
outcome: "cancelled" or "resubmitted".
|
||||
new_price: New order price if resubmitted (None if only cancelled).
|
||||
"""
|
||||
if not self._filter.trades:
|
||||
return
|
||||
# SELL resubmit is high priority — position liquidation at risk.
|
||||
# BUY cancel is medium priority — only cash is freed.
|
||||
priority = (
|
||||
NotificationPriority.HIGH
|
||||
if action == "SELL"
|
||||
else NotificationPriority.MEDIUM
|
||||
)
|
||||
outcome_emoji = "🔄" if outcome == "resubmitted" else "❌"
|
||||
outcome_label = "재주문" if outcome == "resubmitted" else "취소됨"
|
||||
action_emoji = "🔴" if action == "SELL" else "🟢"
|
||||
lines = [
|
||||
f"<b>{outcome_emoji} 미체결 주문 {outcome_label}</b>",
|
||||
f"Symbol: <code>{stock_code}</code> ({market})",
|
||||
f"Action: {action_emoji} {action}",
|
||||
f"Quantity: {quantity:,} shares",
|
||||
]
|
||||
if new_price is not None:
|
||||
lines.append(f"New Price: {new_price:.4f}")
|
||||
message = "\n".join(lines)
|
||||
await self._send_notification(NotificationMessage(priority=priority, message=message))
|
||||
|
||||
async def notify_error(
|
||||
self, error_type: str, error_msg: str, context: str
|
||||
) -> None:
|
||||
"""
|
||||
Notify system error.
|
||||
|
||||
Args:
|
||||
error_type: Type of error (e.g., "Connection Error")
|
||||
error_msg: Error message
|
||||
context: Error context (e.g., stock code, market)
|
||||
"""
|
||||
if not self._filter.errors:
|
||||
return
|
||||
message = (
|
||||
f"<b>Error: {error_type}</b>\n"
|
||||
f"Context: {context}\n"
|
||||
f"Message: {error_msg[:200]}" # Truncate long errors
|
||||
)
|
||||
await self._send_notification(
|
||||
NotificationMessage(priority=NotificationPriority.HIGH, message=message)
|
||||
)
|
||||
|
||||
|
||||
class TelegramCommandHandler:
|
||||
"""Handles incoming Telegram commands via long polling."""
|
||||
|
||||
def __init__(
|
||||
self, client: TelegramClient, polling_interval: float = 1.0
|
||||
) -> None:
|
||||
"""
|
||||
Initialize command handler.
|
||||
|
||||
Args:
|
||||
client: TelegramClient instance for sending responses
|
||||
polling_interval: Polling interval in seconds
|
||||
"""
|
||||
self._client = client
|
||||
self._polling_interval = polling_interval
|
||||
self._commands: dict[str, Callable[[], Awaitable[None]]] = {}
|
||||
self._commands_with_args: dict[str, Callable[[list[str]], Awaitable[None]]] = {}
|
||||
self._last_update_id = 0
|
||||
self._polling_task: asyncio.Task[None] | None = None
|
||||
self._running = False
|
||||
|
||||
def register_command(
|
||||
self, command: str, handler: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
"""
|
||||
Register a command handler (no arguments).
|
||||
|
||||
Args:
|
||||
command: Command name (without leading slash, e.g., "start")
|
||||
handler: Async function to handle the command
|
||||
"""
|
||||
self._commands[command] = handler
|
||||
logger.debug("Registered command handler: /%s", command)
|
||||
|
||||
def register_command_with_args(
|
||||
self, command: str, handler: Callable[[list[str]], Awaitable[None]]
|
||||
) -> None:
|
||||
"""
|
||||
Register a command handler that receives trailing arguments.
|
||||
|
||||
Args:
|
||||
command: Command name (without leading slash, e.g., "notify")
|
||||
handler: Async function receiving list of argument tokens
|
||||
"""
|
||||
self._commands_with_args[command] = handler
|
||||
logger.debug("Registered command handler (with args): /%s", command)
|
||||
|
||||
async def start_polling(self) -> None:
|
||||
"""Start long polling for commands."""
|
||||
if self._running:
|
||||
logger.warning("Command handler already running")
|
||||
return
|
||||
|
||||
if not self._client._enabled:
|
||||
logger.info("Command handler disabled (TelegramClient disabled)")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._polling_task = asyncio.create_task(self._poll_loop())
|
||||
logger.info("Started Telegram command polling")
|
||||
|
||||
async def stop_polling(self) -> None:
|
||||
"""Stop polling and cancel pending tasks."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
if self._polling_task:
|
||||
self._polling_task.cancel()
|
||||
try:
|
||||
await self._polling_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("Stopped Telegram command polling")
|
||||
|
||||
async def _poll_loop(self) -> None:
|
||||
"""Main polling loop that fetches updates."""
|
||||
while self._running:
|
||||
try:
|
||||
updates = await self._get_updates()
|
||||
for update in updates:
|
||||
await self._handle_update(update)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as exc:
|
||||
logger.error("Error in polling loop: %s", exc)
|
||||
|
||||
await asyncio.sleep(self._polling_interval)
|
||||
|
||||
async def _get_updates(self) -> list[dict]:
|
||||
"""
|
||||
Fetch updates from Telegram API.
|
||||
|
||||
Returns:
|
||||
List of update objects
|
||||
"""
|
||||
try:
|
||||
url = f"{self._client.API_BASE.format(token=self._client._bot_token)}/getUpdates"
|
||||
payload = {
|
||||
"offset": self._last_update_id + 1,
|
||||
"timeout": int(self._polling_interval),
|
||||
"allowed_updates": ["message"],
|
||||
}
|
||||
|
||||
session = self._client._get_session()
|
||||
async with session.post(url, json=payload) as resp:
|
||||
if resp.status != 200:
|
||||
error_text = await resp.text()
|
||||
if resp.status == 409:
|
||||
# Another bot instance is already polling — stop this poller entirely.
|
||||
# Retrying would keep conflicting with the other instance.
|
||||
self._running = False
|
||||
logger.warning(
|
||||
"Telegram conflict (409): another instance is already polling. "
|
||||
"Disabling Telegram commands for this process. "
|
||||
"Ensure only one instance of The Ouroboros is running at a time.",
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
"getUpdates API error (status=%d): %s", resp.status, error_text
|
||||
)
|
||||
return []
|
||||
|
||||
data = await resp.json()
|
||||
if not data.get("ok"):
|
||||
logger.error("getUpdates returned ok=false: %s", data)
|
||||
return []
|
||||
|
||||
updates = data.get("result", [])
|
||||
if updates:
|
||||
self._last_update_id = updates[-1]["update_id"]
|
||||
|
||||
return updates
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.debug("getUpdates timeout (normal)")
|
||||
return []
|
||||
except aiohttp.ClientError as exc:
|
||||
logger.error("getUpdates failed: %s", exc)
|
||||
return []
|
||||
except Exception as exc:
|
||||
logger.error("Unexpected error in _get_updates: %s", exc)
|
||||
return []
|
||||
|
||||
async def _handle_update(self, update: dict) -> None:
|
||||
"""
|
||||
Parse and handle a single update.
|
||||
|
||||
Args:
|
||||
update: Update object from Telegram API
|
||||
"""
|
||||
try:
|
||||
message = update.get("message")
|
||||
if not message:
|
||||
return
|
||||
|
||||
# Verify chat_id matches configured chat
|
||||
chat_id = str(message.get("chat", {}).get("id", ""))
|
||||
if chat_id != self._client._chat_id:
|
||||
logger.warning(
|
||||
"Ignoring command from unauthorized chat_id: %s", chat_id
|
||||
)
|
||||
return
|
||||
|
||||
# Extract command text
|
||||
text = message.get("text", "").strip()
|
||||
if not text.startswith("/"):
|
||||
return
|
||||
|
||||
# Parse command (remove leading slash and extract command name)
|
||||
command_parts = text[1:].split()
|
||||
if not command_parts:
|
||||
return
|
||||
|
||||
# Remove @botname suffix if present (for group chats)
|
||||
command_name = command_parts[0].split("@")[0]
|
||||
|
||||
# Execute handler (args-aware handlers take priority)
|
||||
args_handler = self._commands_with_args.get(command_name)
|
||||
if args_handler:
|
||||
logger.info("Executing command: /%s %s", command_name, command_parts[1:])
|
||||
await args_handler(command_parts[1:])
|
||||
elif command_name in self._commands:
|
||||
logger.info("Executing command: /%s", command_name)
|
||||
await self._commands[command_name]()
|
||||
else:
|
||||
logger.debug("Unknown command: /%s", command_name)
|
||||
await self._client.send_message(
|
||||
f"Unknown command: /{command_name}\nUse /help to see available commands."
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Error handling update: %s", exc)
|
||||
# Don't crash the polling loop on handler errors
|
||||
0
src/strategy/__init__.py
Normal file
0
src/strategy/__init__.py
Normal file
104
src/strategy/exit_rules.py
Normal file
104
src/strategy/exit_rules.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Composite exit rules: hard stop, break-even lock, ATR trailing, model assist."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from src.strategy.position_state_machine import PositionState, StateTransitionInput, promote_state
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExitRuleConfig:
|
||||
hard_stop_pct: float = -2.0
|
||||
be_arm_pct: float = 1.2
|
||||
arm_pct: float = 3.0
|
||||
atr_multiplier_k: float = 2.2
|
||||
model_prob_threshold: float = 0.62
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExitRuleInput:
|
||||
current_price: float
|
||||
entry_price: float
|
||||
peak_price: float
|
||||
atr_value: float = 0.0
|
||||
pred_down_prob: float = 0.0
|
||||
liquidity_weak: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExitEvaluation:
|
||||
state: PositionState
|
||||
should_exit: bool
|
||||
reason: str
|
||||
unrealized_pnl_pct: float
|
||||
trailing_stop_price: float | None
|
||||
|
||||
|
||||
def evaluate_exit(
|
||||
*,
|
||||
current_state: PositionState,
|
||||
config: ExitRuleConfig,
|
||||
inp: ExitRuleInput,
|
||||
) -> ExitEvaluation:
|
||||
"""Evaluate composite exit logic and return updated state."""
|
||||
if inp.entry_price <= 0 or inp.current_price <= 0:
|
||||
return ExitEvaluation(
|
||||
state=current_state,
|
||||
should_exit=False,
|
||||
reason="invalid_price",
|
||||
unrealized_pnl_pct=0.0,
|
||||
trailing_stop_price=None,
|
||||
)
|
||||
|
||||
unrealized = (inp.current_price - inp.entry_price) / inp.entry_price * 100.0
|
||||
hard_stop_hit = unrealized <= config.hard_stop_pct
|
||||
take_profit_hit = unrealized >= config.arm_pct
|
||||
|
||||
trailing_stop_price: float | None = None
|
||||
trailing_stop_hit = False
|
||||
if inp.atr_value > 0 and inp.peak_price > 0:
|
||||
trailing_stop_price = inp.peak_price - (config.atr_multiplier_k * inp.atr_value)
|
||||
trailing_stop_hit = inp.current_price <= trailing_stop_price
|
||||
|
||||
be_lock_threat = current_state in (PositionState.BE_LOCK, PositionState.ARMED) and (
|
||||
inp.current_price <= inp.entry_price
|
||||
)
|
||||
model_exit_signal = inp.pred_down_prob >= config.model_prob_threshold and inp.liquidity_weak
|
||||
|
||||
next_state = promote_state(
|
||||
current=current_state,
|
||||
inp=StateTransitionInput(
|
||||
unrealized_pnl_pct=unrealized,
|
||||
be_arm_pct=config.be_arm_pct,
|
||||
arm_pct=config.arm_pct,
|
||||
hard_stop_hit=hard_stop_hit,
|
||||
trailing_stop_hit=trailing_stop_hit,
|
||||
model_exit_signal=model_exit_signal,
|
||||
be_lock_threat=be_lock_threat,
|
||||
),
|
||||
)
|
||||
|
||||
if hard_stop_hit:
|
||||
reason = "hard_stop"
|
||||
elif trailing_stop_hit:
|
||||
reason = "atr_trailing_stop"
|
||||
elif be_lock_threat:
|
||||
reason = "be_lock_threat"
|
||||
elif model_exit_signal:
|
||||
reason = "model_liquidity_exit"
|
||||
elif take_profit_hit:
|
||||
# Backward-compatible immediate profit-taking path.
|
||||
reason = "arm_take_profit"
|
||||
else:
|
||||
reason = "hold"
|
||||
|
||||
should_exit = next_state == PositionState.EXITED or take_profit_hit
|
||||
|
||||
return ExitEvaluation(
|
||||
state=next_state,
|
||||
should_exit=should_exit,
|
||||
reason=reason,
|
||||
unrealized_pnl_pct=unrealized,
|
||||
trailing_stop_price=trailing_stop_price,
|
||||
)
|
||||
184
src/strategy/models.py
Normal file
184
src/strategy/models.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""Pydantic models for pre-market scenario planning.
|
||||
|
||||
Defines the data contracts for the proactive strategy system:
|
||||
- AI generates DayPlaybook before market open (structured JSON scenarios)
|
||||
- Local ScenarioEngine matches conditions during market hours (no API calls)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, date, datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class ScenarioAction(str, Enum):
|
||||
"""Actions that can be taken by scenarios."""
|
||||
|
||||
BUY = "BUY"
|
||||
SELL = "SELL"
|
||||
HOLD = "HOLD"
|
||||
REDUCE_ALL = "REDUCE_ALL"
|
||||
|
||||
|
||||
class MarketOutlook(str, Enum):
|
||||
"""AI's assessment of market direction."""
|
||||
|
||||
BULLISH = "bullish"
|
||||
NEUTRAL_TO_BULLISH = "neutral_to_bullish"
|
||||
NEUTRAL = "neutral"
|
||||
NEUTRAL_TO_BEARISH = "neutral_to_bearish"
|
||||
BEARISH = "bearish"
|
||||
|
||||
|
||||
class PlaybookStatus(str, Enum):
|
||||
"""Lifecycle status of a playbook."""
|
||||
|
||||
PENDING = "pending"
|
||||
READY = "ready"
|
||||
FAILED = "failed"
|
||||
EXPIRED = "expired"
|
||||
|
||||
|
||||
class StockCondition(BaseModel):
|
||||
"""Condition fields for scenario matching (all optional, AND-combined).
|
||||
|
||||
The ScenarioEngine evaluates all non-None fields as AND conditions.
|
||||
A condition matches only if ALL specified fields are satisfied.
|
||||
|
||||
Technical indicator fields:
|
||||
rsi_below / rsi_above — RSI threshold
|
||||
volume_ratio_above / volume_ratio_below — volume vs previous day
|
||||
price_above / price_below — absolute price level
|
||||
price_change_pct_above / price_change_pct_below — intraday % change
|
||||
|
||||
Position-aware fields (require market_data enrichment from open position):
|
||||
unrealized_pnl_pct_above — matches if unrealized P&L > threshold (e.g. 3.0 → +3%)
|
||||
unrealized_pnl_pct_below — matches if unrealized P&L < threshold (e.g. -2.0 → -2%)
|
||||
holding_days_above — matches if position held for more than N days
|
||||
holding_days_below — matches if position held for fewer than N days
|
||||
"""
|
||||
|
||||
rsi_below: float | None = None
|
||||
rsi_above: float | None = None
|
||||
volume_ratio_above: float | None = None
|
||||
volume_ratio_below: float | None = None
|
||||
price_above: float | None = None
|
||||
price_below: float | None = None
|
||||
price_change_pct_above: float | None = None
|
||||
price_change_pct_below: float | None = None
|
||||
unrealized_pnl_pct_above: float | None = None
|
||||
unrealized_pnl_pct_below: float | None = None
|
||||
holding_days_above: int | None = None
|
||||
holding_days_below: int | None = None
|
||||
|
||||
def has_any_condition(self) -> bool:
|
||||
"""Check if at least one condition field is set."""
|
||||
return any(
|
||||
v is not None
|
||||
for v in (
|
||||
self.rsi_below,
|
||||
self.rsi_above,
|
||||
self.volume_ratio_above,
|
||||
self.volume_ratio_below,
|
||||
self.price_above,
|
||||
self.price_below,
|
||||
self.price_change_pct_above,
|
||||
self.price_change_pct_below,
|
||||
self.unrealized_pnl_pct_above,
|
||||
self.unrealized_pnl_pct_below,
|
||||
self.holding_days_above,
|
||||
self.holding_days_below,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class StockScenario(BaseModel):
|
||||
"""A single condition-action rule for one stock."""
|
||||
|
||||
condition: StockCondition
|
||||
action: ScenarioAction
|
||||
confidence: int = Field(ge=0, le=100)
|
||||
allocation_pct: float = Field(ge=0, le=100, default=10.0)
|
||||
stop_loss_pct: float = Field(le=0, default=-2.0)
|
||||
take_profit_pct: float = Field(ge=0, default=3.0)
|
||||
rationale: str = ""
|
||||
|
||||
|
||||
class StockPlaybook(BaseModel):
|
||||
"""All scenarios for a single stock (ordered by priority)."""
|
||||
|
||||
stock_code: str
|
||||
stock_name: str = ""
|
||||
scenarios: list[StockScenario] = Field(min_length=1)
|
||||
|
||||
|
||||
class GlobalRule(BaseModel):
|
||||
"""Portfolio-level rule (checked before stock-level scenarios)."""
|
||||
|
||||
condition: str # e.g. "portfolio_pnl_pct < -2.0"
|
||||
action: ScenarioAction
|
||||
rationale: str = ""
|
||||
|
||||
|
||||
class CrossMarketContext(BaseModel):
|
||||
"""Summary of another market's state for cross-market awareness."""
|
||||
|
||||
market: str # e.g. "US" or "KR"
|
||||
date: str
|
||||
total_pnl: float = 0.0
|
||||
win_rate: float = 0.0
|
||||
index_change_pct: float = 0.0 # e.g. KOSPI or S&P500 change
|
||||
key_events: list[str] = Field(default_factory=list)
|
||||
lessons: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DayPlaybook(BaseModel):
|
||||
"""Complete playbook for a single trading day in a single market.
|
||||
|
||||
Generated by PreMarketPlanner (1 Gemini call per market per day).
|
||||
Consumed by ScenarioEngine during market hours (0 API calls).
|
||||
"""
|
||||
|
||||
date: date
|
||||
market: str # "KR" or "US"
|
||||
market_outlook: MarketOutlook = MarketOutlook.NEUTRAL
|
||||
generated_at: str = "" # ISO timestamp
|
||||
gemini_model: str = ""
|
||||
token_count: int = 0
|
||||
global_rules: list[GlobalRule] = Field(default_factory=list)
|
||||
stock_playbooks: list[StockPlaybook] = Field(default_factory=list)
|
||||
default_action: ScenarioAction = ScenarioAction.HOLD
|
||||
context_summary: dict = Field(default_factory=dict)
|
||||
cross_market: CrossMarketContext | None = None
|
||||
|
||||
@field_validator("stock_playbooks")
|
||||
@classmethod
|
||||
def validate_unique_stocks(cls, v: list[StockPlaybook]) -> list[StockPlaybook]:
|
||||
codes = [pb.stock_code for pb in v]
|
||||
if len(codes) != len(set(codes)):
|
||||
raise ValueError("Duplicate stock codes in playbook")
|
||||
return v
|
||||
|
||||
def get_stock_playbook(self, stock_code: str) -> StockPlaybook | None:
|
||||
"""Find the playbook for a specific stock."""
|
||||
for pb in self.stock_playbooks:
|
||||
if pb.stock_code == stock_code:
|
||||
return pb
|
||||
return None
|
||||
|
||||
@property
|
||||
def scenario_count(self) -> int:
|
||||
"""Total number of scenarios across all stocks."""
|
||||
return sum(len(pb.scenarios) for pb in self.stock_playbooks)
|
||||
|
||||
@property
|
||||
def stock_count(self) -> int:
|
||||
"""Number of stocks with scenarios."""
|
||||
return len(self.stock_playbooks)
|
||||
|
||||
def model_post_init(self, __context: object) -> None:
|
||||
"""Set generated_at if not provided."""
|
||||
if not self.generated_at:
|
||||
self.generated_at = datetime.now(UTC).isoformat()
|
||||
184
src/strategy/playbook_store.py
Normal file
184
src/strategy/playbook_store.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""Playbook persistence layer — CRUD for DayPlaybook in SQLite.
|
||||
|
||||
Stores and retrieves market-specific daily playbooks with JSON serialization.
|
||||
Designed for the pre-market strategy system (one playbook per market per day).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
from datetime import date
|
||||
|
||||
from src.strategy.models import DayPlaybook, PlaybookStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PlaybookStore:
|
||||
"""CRUD operations for DayPlaybook persistence."""
|
||||
|
||||
def __init__(self, conn: sqlite3.Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def save(self, playbook: DayPlaybook) -> int:
|
||||
"""Save or replace a playbook for a given date+market.
|
||||
|
||||
Uses INSERT OR REPLACE to enforce UNIQUE(date, market).
|
||||
|
||||
Returns:
|
||||
The row id of the inserted/replaced record.
|
||||
"""
|
||||
playbook_json = playbook.model_dump_json()
|
||||
cursor = self._conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO playbooks
|
||||
(date, market, status, playbook_json, generated_at,
|
||||
token_count, scenario_count, match_count)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
playbook.date.isoformat(),
|
||||
playbook.market,
|
||||
PlaybookStatus.READY.value,
|
||||
playbook_json,
|
||||
playbook.generated_at,
|
||||
playbook.token_count,
|
||||
playbook.scenario_count,
|
||||
0,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
row_id = cursor.lastrowid or 0
|
||||
logger.info(
|
||||
"Saved playbook for %s/%s (%d stocks, %d scenarios)",
|
||||
playbook.date, playbook.market,
|
||||
playbook.stock_count, playbook.scenario_count,
|
||||
)
|
||||
return row_id
|
||||
|
||||
def load(self, target_date: date, market: str) -> DayPlaybook | None:
|
||||
"""Load a playbook for a specific date and market.
|
||||
|
||||
Returns:
|
||||
DayPlaybook if found, None otherwise.
|
||||
"""
|
||||
row = self._conn.execute(
|
||||
"SELECT playbook_json FROM playbooks WHERE date = ? AND market = ?",
|
||||
(target_date.isoformat(), market),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return DayPlaybook.model_validate_json(row[0])
|
||||
|
||||
def get_status(self, target_date: date, market: str) -> PlaybookStatus | None:
|
||||
"""Get the status of a playbook without deserializing the full JSON."""
|
||||
row = self._conn.execute(
|
||||
"SELECT status FROM playbooks WHERE date = ? AND market = ?",
|
||||
(target_date.isoformat(), market),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return PlaybookStatus(row[0])
|
||||
|
||||
def update_status(self, target_date: date, market: str, status: PlaybookStatus) -> bool:
|
||||
"""Update the status of a playbook.
|
||||
|
||||
Returns:
|
||||
True if a row was updated, False if not found.
|
||||
"""
|
||||
cursor = self._conn.execute(
|
||||
"UPDATE playbooks SET status = ? WHERE date = ? AND market = ?",
|
||||
(status.value, target_date.isoformat(), market),
|
||||
)
|
||||
self._conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def increment_match_count(self, target_date: date, market: str) -> bool:
|
||||
"""Increment the match_count for tracking scenario hits during the day.
|
||||
|
||||
Returns:
|
||||
True if a row was updated, False if not found.
|
||||
"""
|
||||
cursor = self._conn.execute(
|
||||
"UPDATE playbooks SET match_count = match_count + 1 WHERE date = ? AND market = ?",
|
||||
(target_date.isoformat(), market),
|
||||
)
|
||||
self._conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def get_stats(self, target_date: date, market: str) -> dict | None:
|
||||
"""Get playbook stats without full deserialization.
|
||||
|
||||
Returns:
|
||||
Dict with status, token_count, scenario_count, match_count, or None.
|
||||
"""
|
||||
row = self._conn.execute(
|
||||
"""
|
||||
SELECT status, token_count, scenario_count, match_count, generated_at
|
||||
FROM playbooks WHERE date = ? AND market = ?
|
||||
""",
|
||||
(target_date.isoformat(), market),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return {
|
||||
"status": row[0],
|
||||
"token_count": row[1],
|
||||
"scenario_count": row[2],
|
||||
"match_count": row[3],
|
||||
"generated_at": row[4],
|
||||
}
|
||||
|
||||
def list_recent(self, market: str | None = None, limit: int = 7) -> list[dict]:
|
||||
"""List recent playbooks with summary info.
|
||||
|
||||
Args:
|
||||
market: Filter by market code. None for all markets.
|
||||
limit: Max number of results.
|
||||
|
||||
Returns:
|
||||
List of dicts with date, market, status, scenario_count, match_count.
|
||||
"""
|
||||
if market is not None:
|
||||
rows = self._conn.execute(
|
||||
"""
|
||||
SELECT date, market, status, scenario_count, match_count
|
||||
FROM playbooks WHERE market = ?
|
||||
ORDER BY date DESC LIMIT ?
|
||||
""",
|
||||
(market, limit),
|
||||
).fetchall()
|
||||
else:
|
||||
rows = self._conn.execute(
|
||||
"""
|
||||
SELECT date, market, status, scenario_count, match_count
|
||||
FROM playbooks
|
||||
ORDER BY date DESC LIMIT ?
|
||||
""",
|
||||
(limit,),
|
||||
).fetchall()
|
||||
return [
|
||||
{
|
||||
"date": row[0],
|
||||
"market": row[1],
|
||||
"status": row[2],
|
||||
"scenario_count": row[3],
|
||||
"match_count": row[4],
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
def delete(self, target_date: date, market: str) -> bool:
|
||||
"""Delete a playbook.
|
||||
|
||||
Returns:
|
||||
True if a row was deleted, False if not found.
|
||||
"""
|
||||
cursor = self._conn.execute(
|
||||
"DELETE FROM playbooks WHERE date = ? AND market = ?",
|
||||
(target_date.isoformat(), market),
|
||||
)
|
||||
self._conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
70
src/strategy/position_state_machine.py
Normal file
70
src/strategy/position_state_machine.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Position state machine for staged exit control.
|
||||
|
||||
State progression is monotonic (promotion-only) except terminal EXITED.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class PositionState(str, Enum):
|
||||
HOLDING = "HOLDING"
|
||||
BE_LOCK = "BE_LOCK"
|
||||
ARMED = "ARMED"
|
||||
EXITED = "EXITED"
|
||||
|
||||
|
||||
_STATE_RANK: dict[PositionState, int] = {
|
||||
PositionState.HOLDING: 0,
|
||||
PositionState.BE_LOCK: 1,
|
||||
PositionState.ARMED: 2,
|
||||
PositionState.EXITED: 3,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StateTransitionInput:
|
||||
unrealized_pnl_pct: float
|
||||
be_arm_pct: float
|
||||
arm_pct: float
|
||||
hard_stop_hit: bool = False
|
||||
trailing_stop_hit: bool = False
|
||||
model_exit_signal: bool = False
|
||||
be_lock_threat: bool = False
|
||||
|
||||
|
||||
def evaluate_exit_first(inp: StateTransitionInput) -> bool:
|
||||
"""Return True when terminal exit conditions are met.
|
||||
|
||||
EXITED must be evaluated before any promotion.
|
||||
"""
|
||||
return (
|
||||
inp.hard_stop_hit
|
||||
or inp.trailing_stop_hit
|
||||
or inp.model_exit_signal
|
||||
or inp.be_lock_threat
|
||||
)
|
||||
|
||||
|
||||
def promote_state(current: PositionState, inp: StateTransitionInput) -> PositionState:
|
||||
"""Promote to highest admissible state for current tick/bar.
|
||||
|
||||
Rules:
|
||||
- EXITED has highest precedence and is terminal.
|
||||
- Promotions are monotonic (no downgrade).
|
||||
"""
|
||||
if current == PositionState.EXITED:
|
||||
return PositionState.EXITED
|
||||
|
||||
if evaluate_exit_first(inp):
|
||||
return PositionState.EXITED
|
||||
|
||||
target = PositionState.HOLDING
|
||||
if inp.unrealized_pnl_pct >= inp.arm_pct:
|
||||
target = PositionState.ARMED
|
||||
elif inp.unrealized_pnl_pct >= inp.be_arm_pct:
|
||||
target = PositionState.BE_LOCK
|
||||
|
||||
return target if _STATE_RANK[target] > _STATE_RANK[current] else current
|
||||
620
src/strategy/pre_market_planner.py
Normal file
620
src/strategy/pre_market_planner.py
Normal file
@@ -0,0 +1,620 @@
|
||||
"""Pre-market planner — generates DayPlaybook via Gemini before market open.
|
||||
|
||||
One Gemini API call per market per day. Candidates come from SmartVolatilityScanner.
|
||||
On failure, returns a smart rule-based fallback playbook that uses scanner signals
|
||||
(momentum/oversold) to generate BUY conditions, avoiding the all-HOLD problem.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import date, timedelta
|
||||
from typing import Any
|
||||
|
||||
from src.analysis.smart_scanner import ScanCandidate
|
||||
from src.brain.context_selector import ContextSelector, DecisionType
|
||||
from src.brain.gemini_client import GeminiClient
|
||||
from src.config import Settings
|
||||
from src.context.store import ContextLayer, ContextStore
|
||||
from src.strategy.models import (
|
||||
CrossMarketContext,
|
||||
DayPlaybook,
|
||||
GlobalRule,
|
||||
MarketOutlook,
|
||||
ScenarioAction,
|
||||
StockCondition,
|
||||
StockPlaybook,
|
||||
StockScenario,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Mapping from string to MarketOutlook enum
|
||||
_OUTLOOK_MAP: dict[str, MarketOutlook] = {
|
||||
"bullish": MarketOutlook.BULLISH,
|
||||
"neutral_to_bullish": MarketOutlook.NEUTRAL_TO_BULLISH,
|
||||
"neutral": MarketOutlook.NEUTRAL,
|
||||
"neutral_to_bearish": MarketOutlook.NEUTRAL_TO_BEARISH,
|
||||
"bearish": MarketOutlook.BEARISH,
|
||||
}
|
||||
|
||||
_ACTION_MAP: dict[str, ScenarioAction] = {
|
||||
"BUY": ScenarioAction.BUY,
|
||||
"SELL": ScenarioAction.SELL,
|
||||
"HOLD": ScenarioAction.HOLD,
|
||||
"REDUCE_ALL": ScenarioAction.REDUCE_ALL,
|
||||
}
|
||||
|
||||
|
||||
class PreMarketPlanner:
|
||||
"""Generates a DayPlaybook by calling Gemini once before market open.
|
||||
|
||||
Flow:
|
||||
1. Collect strategic context (L5-L7) + cross-market context
|
||||
2. Build a structured prompt with scan candidates
|
||||
3. Call Gemini for JSON scenario generation
|
||||
4. Parse and validate response into DayPlaybook
|
||||
5. On failure → defensive playbook (HOLD everything)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
gemini_client: GeminiClient,
|
||||
context_store: ContextStore,
|
||||
context_selector: ContextSelector,
|
||||
settings: Settings,
|
||||
) -> None:
|
||||
self._gemini = gemini_client
|
||||
self._context_store = context_store
|
||||
self._context_selector = context_selector
|
||||
self._settings = settings
|
||||
|
||||
async def generate_playbook(
|
||||
self,
|
||||
market: str,
|
||||
candidates: list[ScanCandidate],
|
||||
today: date | None = None,
|
||||
current_holdings: list[dict] | None = None,
|
||||
) -> DayPlaybook:
|
||||
"""Generate a DayPlaybook for a market using Gemini.
|
||||
|
||||
Args:
|
||||
market: Market code ("KR" or "US")
|
||||
candidates: Stock candidates from SmartVolatilityScanner
|
||||
today: Override date (defaults to date.today()). Use market-local date.
|
||||
current_holdings: Currently held positions with entry_price and unrealized_pnl_pct.
|
||||
Each dict: {"stock_code": str, "name": str, "qty": int,
|
||||
"entry_price": float, "unrealized_pnl_pct": float,
|
||||
"holding_days": int}
|
||||
|
||||
Returns:
|
||||
DayPlaybook with scenarios. Empty/defensive if no candidates or failure.
|
||||
"""
|
||||
if today is None:
|
||||
today = date.today()
|
||||
|
||||
if not candidates:
|
||||
logger.info("No candidates for %s — returning empty playbook", market)
|
||||
return self._empty_playbook(today, market)
|
||||
|
||||
try:
|
||||
# 1. Gather context
|
||||
context_data = self._gather_context()
|
||||
self_market_scorecard = self.build_self_market_scorecard(market, today)
|
||||
cross_market = self.build_cross_market_context(market, today)
|
||||
|
||||
# 2. Build prompt
|
||||
prompt = self._build_prompt(
|
||||
market,
|
||||
candidates,
|
||||
context_data,
|
||||
self_market_scorecard,
|
||||
cross_market,
|
||||
current_holdings=current_holdings,
|
||||
)
|
||||
|
||||
# 3. Call Gemini
|
||||
market_data = {
|
||||
"stock_code": "PLANNER",
|
||||
"current_price": 0,
|
||||
"prompt_override": prompt,
|
||||
}
|
||||
decision = await self._gemini.decide(market_data)
|
||||
|
||||
# 4. Parse response
|
||||
playbook = self._parse_response(
|
||||
decision.rationale, today, market, candidates, cross_market,
|
||||
current_holdings=current_holdings,
|
||||
)
|
||||
playbook_with_tokens = playbook.model_copy(
|
||||
update={"token_count": decision.token_count}
|
||||
)
|
||||
logger.info(
|
||||
"Generated playbook for %s: %d stocks, %d scenarios, %d tokens",
|
||||
market,
|
||||
playbook_with_tokens.stock_count,
|
||||
playbook_with_tokens.scenario_count,
|
||||
playbook_with_tokens.token_count,
|
||||
)
|
||||
return playbook_with_tokens
|
||||
|
||||
except Exception:
|
||||
logger.exception("Playbook generation failed for %s", market)
|
||||
if self._settings.DEFENSIVE_PLAYBOOK_ON_FAILURE:
|
||||
return self._smart_fallback_playbook(today, market, candidates, self._settings)
|
||||
return self._empty_playbook(today, market)
|
||||
|
||||
def build_cross_market_context(
|
||||
self, target_market: str, today: date | None = None,
|
||||
) -> CrossMarketContext | None:
|
||||
"""Build cross-market context from the other market's L6 data.
|
||||
|
||||
KR planner → reads US scorecard from previous night.
|
||||
US planner → reads KR scorecard from today.
|
||||
|
||||
Args:
|
||||
target_market: The market being planned ("KR" or "US")
|
||||
today: Override date (defaults to date.today()). Use market-local date.
|
||||
"""
|
||||
other_market = "US" if target_market == "KR" else "KR"
|
||||
if today is None:
|
||||
today = date.today()
|
||||
timeframe_date = today - timedelta(days=1) if target_market == "KR" else today
|
||||
timeframe = timeframe_date.isoformat()
|
||||
|
||||
scorecard_key = f"scorecard_{other_market}"
|
||||
scorecard_data = self._context_store.get_context(
|
||||
ContextLayer.L6_DAILY, timeframe, scorecard_key
|
||||
)
|
||||
|
||||
if scorecard_data is None:
|
||||
logger.debug("No cross-market scorecard found for %s", other_market)
|
||||
return None
|
||||
|
||||
if isinstance(scorecard_data, str):
|
||||
try:
|
||||
scorecard_data = json.loads(scorecard_data)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
|
||||
if not isinstance(scorecard_data, dict):
|
||||
return None
|
||||
|
||||
return CrossMarketContext(
|
||||
market=other_market,
|
||||
date=timeframe,
|
||||
total_pnl=float(scorecard_data.get("total_pnl", 0.0)),
|
||||
win_rate=float(scorecard_data.get("win_rate", 0.0)),
|
||||
index_change_pct=float(scorecard_data.get("index_change_pct", 0.0)),
|
||||
key_events=scorecard_data.get("key_events", []),
|
||||
lessons=scorecard_data.get("lessons", []),
|
||||
)
|
||||
|
||||
def build_self_market_scorecard(
|
||||
self, market: str, today: date | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Build previous-day scorecard for the same market."""
|
||||
if today is None:
|
||||
today = date.today()
|
||||
timeframe = (today - timedelta(days=1)).isoformat()
|
||||
scorecard_key = f"scorecard_{market}"
|
||||
scorecard_data = self._context_store.get_context(
|
||||
ContextLayer.L6_DAILY, timeframe, scorecard_key
|
||||
)
|
||||
|
||||
if scorecard_data is None:
|
||||
return None
|
||||
|
||||
if isinstance(scorecard_data, str):
|
||||
try:
|
||||
scorecard_data = json.loads(scorecard_data)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
|
||||
if not isinstance(scorecard_data, dict):
|
||||
return None
|
||||
|
||||
return {
|
||||
"date": timeframe,
|
||||
"total_pnl": float(scorecard_data.get("total_pnl", 0.0)),
|
||||
"win_rate": float(scorecard_data.get("win_rate", 0.0)),
|
||||
"lessons": scorecard_data.get("lessons", []),
|
||||
}
|
||||
|
||||
def _gather_context(self) -> dict[str, Any]:
|
||||
"""Gather strategic context using ContextSelector."""
|
||||
layers = self._context_selector.select_layers(
|
||||
decision_type=DecisionType.STRATEGIC,
|
||||
include_realtime=True,
|
||||
)
|
||||
return self._context_selector.get_context_data(layers, max_items_per_layer=10)
|
||||
|
||||
def _build_prompt(
|
||||
self,
|
||||
market: str,
|
||||
candidates: list[ScanCandidate],
|
||||
context_data: dict[str, Any],
|
||||
self_market_scorecard: dict[str, Any] | None,
|
||||
cross_market: CrossMarketContext | None,
|
||||
current_holdings: list[dict] | None = None,
|
||||
) -> str:
|
||||
"""Build a structured prompt for Gemini to generate scenario JSON."""
|
||||
max_scenarios = self._settings.MAX_SCENARIOS_PER_STOCK
|
||||
|
||||
candidates_text = "\n".join(
|
||||
f" - {c.stock_code} ({c.name}): price={c.price}, "
|
||||
f"RSI={c.rsi:.1f}, volume_ratio={c.volume_ratio:.1f}, "
|
||||
f"signal={c.signal}, score={c.score:.1f}"
|
||||
for c in candidates
|
||||
)
|
||||
|
||||
holdings_text = ""
|
||||
if current_holdings:
|
||||
lines = []
|
||||
for h in current_holdings:
|
||||
code = h.get("stock_code", "")
|
||||
name = h.get("name", "")
|
||||
qty = h.get("qty", 0)
|
||||
entry_price = h.get("entry_price", 0.0)
|
||||
pnl_pct = h.get("unrealized_pnl_pct", 0.0)
|
||||
holding_days = h.get("holding_days", 0)
|
||||
lines.append(
|
||||
f" - {code} ({name}): {qty}주 @ {entry_price:,.0f}, "
|
||||
f"미실현손익 {pnl_pct:+.2f}%, 보유 {holding_days}일"
|
||||
)
|
||||
holdings_text = (
|
||||
"\n## Current Holdings (보유 중 — SELL/HOLD 전략 고려 필요)\n"
|
||||
+ "\n".join(lines)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
cross_market_text = ""
|
||||
if cross_market:
|
||||
cross_market_text = (
|
||||
f"\n## Other Market ({cross_market.market}) Summary\n"
|
||||
f"- P&L: {cross_market.total_pnl:+.2f}%\n"
|
||||
f"- Win Rate: {cross_market.win_rate:.0f}%\n"
|
||||
f"- Index Change: {cross_market.index_change_pct:+.2f}%\n"
|
||||
)
|
||||
if cross_market.lessons:
|
||||
cross_market_text += f"- Lessons: {'; '.join(cross_market.lessons[:3])}\n"
|
||||
|
||||
self_market_text = ""
|
||||
if self_market_scorecard:
|
||||
self_market_text = (
|
||||
f"\n## My Market Previous Day ({market})\n"
|
||||
f"- Date: {self_market_scorecard['date']}\n"
|
||||
f"- P&L: {self_market_scorecard['total_pnl']:+.2f}%\n"
|
||||
f"- Win Rate: {self_market_scorecard['win_rate']:.0f}%\n"
|
||||
)
|
||||
lessons = self_market_scorecard.get("lessons", [])
|
||||
if lessons:
|
||||
self_market_text += f"- Lessons: {'; '.join(lessons[:3])}\n"
|
||||
|
||||
context_text = ""
|
||||
if context_data:
|
||||
context_text = "\n## Strategic Context\n"
|
||||
for layer_name, layer_data in context_data.items():
|
||||
if layer_data:
|
||||
context_text += f"### {layer_name}\n"
|
||||
for key, value in list(layer_data.items())[:5]:
|
||||
context_text += f" - {key}: {value}\n"
|
||||
|
||||
holdings_instruction = ""
|
||||
if current_holdings:
|
||||
holding_codes = [h.get("stock_code", "") for h in current_holdings]
|
||||
holdings_instruction = (
|
||||
f"- Also include SELL/HOLD scenarios for held stocks: "
|
||||
f"{', '.join(holding_codes)} "
|
||||
f"(even if not in candidates list)\n"
|
||||
)
|
||||
|
||||
return (
|
||||
f"You are a pre-market trading strategist for the {market} market.\n"
|
||||
f"Generate structured trading scenarios for today.\n\n"
|
||||
f"## Candidates (from volatility scanner)\n{candidates_text}\n"
|
||||
f"{holdings_text}"
|
||||
f"{self_market_text}"
|
||||
f"{cross_market_text}"
|
||||
f"{context_text}\n"
|
||||
f"## Instructions\n"
|
||||
f"Return a JSON object with this exact structure:\n"
|
||||
f'{{\n'
|
||||
f' "market_outlook": "bullish|neutral_to_bullish|neutral'
|
||||
f'|neutral_to_bearish|bearish",\n'
|
||||
f' "global_rules": [\n'
|
||||
f' {{"condition": "portfolio_pnl_pct < -2.0",'
|
||||
f' "action": "REDUCE_ALL", "rationale": "..."}}\n'
|
||||
f' ],\n'
|
||||
f' "stocks": [\n'
|
||||
f' {{\n'
|
||||
f' "stock_code": "...",\n'
|
||||
f' "scenarios": [\n'
|
||||
f' {{\n'
|
||||
f' "condition": {{"rsi_below": 30, "volume_ratio_above": 2.0,'
|
||||
f' "unrealized_pnl_pct_above": 3.0, "holding_days_above": 5}},\n'
|
||||
f' "action": "BUY|SELL|HOLD",\n'
|
||||
f' "confidence": 85,\n'
|
||||
f' "allocation_pct": 10.0,\n'
|
||||
f' "stop_loss_pct": -2.0,\n'
|
||||
f' "take_profit_pct": 3.0,\n'
|
||||
f' "rationale": "..."\n'
|
||||
f' }}\n'
|
||||
f' ]\n'
|
||||
f' }}\n'
|
||||
f' ]\n'
|
||||
f'}}\n\n'
|
||||
f"Rules:\n"
|
||||
f"- Max {max_scenarios} scenarios per stock\n"
|
||||
f"- Candidates list is the primary source for BUY candidates\n"
|
||||
f"{holdings_instruction}"
|
||||
f"- Confidence 0-100 (80+ for actionable trades)\n"
|
||||
f"- stop_loss_pct must be <= 0, take_profit_pct must be >= 0\n"
|
||||
f"- Return ONLY the JSON, no markdown fences or explanation\n"
|
||||
)
|
||||
|
||||
def _parse_response(
|
||||
self,
|
||||
response_text: str,
|
||||
today: date,
|
||||
market: str,
|
||||
candidates: list[ScanCandidate],
|
||||
cross_market: CrossMarketContext | None,
|
||||
current_holdings: list[dict] | None = None,
|
||||
) -> DayPlaybook:
|
||||
"""Parse Gemini's JSON response into a validated DayPlaybook."""
|
||||
cleaned = self._extract_json(response_text)
|
||||
data = json.loads(cleaned)
|
||||
|
||||
valid_codes = {c.stock_code for c in candidates}
|
||||
# Holdings are also valid — AI may generate SELL/HOLD scenarios for them
|
||||
if current_holdings:
|
||||
for h in current_holdings:
|
||||
code = h.get("stock_code", "")
|
||||
if code:
|
||||
valid_codes.add(code)
|
||||
|
||||
# Parse market outlook
|
||||
outlook_str = data.get("market_outlook", "neutral")
|
||||
market_outlook = _OUTLOOK_MAP.get(outlook_str, MarketOutlook.NEUTRAL)
|
||||
|
||||
# Parse global rules
|
||||
global_rules = []
|
||||
for rule_data in data.get("global_rules", []):
|
||||
action_str = rule_data.get("action", "HOLD")
|
||||
action = _ACTION_MAP.get(action_str, ScenarioAction.HOLD)
|
||||
global_rules.append(
|
||||
GlobalRule(
|
||||
condition=rule_data.get("condition", ""),
|
||||
action=action,
|
||||
rationale=rule_data.get("rationale", ""),
|
||||
)
|
||||
)
|
||||
|
||||
# Parse stock playbooks
|
||||
stock_playbooks = []
|
||||
max_scenarios = self._settings.MAX_SCENARIOS_PER_STOCK
|
||||
for stock_data in data.get("stocks", []):
|
||||
code = stock_data.get("stock_code", "")
|
||||
if code not in valid_codes:
|
||||
logger.warning("Gemini returned unknown stock %s — skipping", code)
|
||||
continue
|
||||
|
||||
scenarios = []
|
||||
for sc_data in stock_data.get("scenarios", [])[:max_scenarios]:
|
||||
scenario = self._parse_scenario(sc_data)
|
||||
if scenario:
|
||||
scenarios.append(scenario)
|
||||
|
||||
if scenarios:
|
||||
stock_playbooks.append(
|
||||
StockPlaybook(
|
||||
stock_code=code,
|
||||
scenarios=scenarios,
|
||||
)
|
||||
)
|
||||
|
||||
return DayPlaybook(
|
||||
date=today,
|
||||
market=market,
|
||||
market_outlook=market_outlook,
|
||||
global_rules=global_rules,
|
||||
stock_playbooks=stock_playbooks,
|
||||
cross_market=cross_market,
|
||||
)
|
||||
|
||||
def _parse_scenario(self, sc_data: dict) -> StockScenario | None:
|
||||
"""Parse a single scenario from JSON data. Returns None if invalid."""
|
||||
try:
|
||||
cond_data = sc_data.get("condition", {})
|
||||
condition = StockCondition(
|
||||
rsi_below=cond_data.get("rsi_below"),
|
||||
rsi_above=cond_data.get("rsi_above"),
|
||||
volume_ratio_above=cond_data.get("volume_ratio_above"),
|
||||
volume_ratio_below=cond_data.get("volume_ratio_below"),
|
||||
price_above=cond_data.get("price_above"),
|
||||
price_below=cond_data.get("price_below"),
|
||||
price_change_pct_above=cond_data.get("price_change_pct_above"),
|
||||
price_change_pct_below=cond_data.get("price_change_pct_below"),
|
||||
unrealized_pnl_pct_above=cond_data.get("unrealized_pnl_pct_above"),
|
||||
unrealized_pnl_pct_below=cond_data.get("unrealized_pnl_pct_below"),
|
||||
holding_days_above=cond_data.get("holding_days_above"),
|
||||
holding_days_below=cond_data.get("holding_days_below"),
|
||||
)
|
||||
|
||||
if not condition.has_any_condition():
|
||||
logger.warning("Scenario has no conditions — skipping")
|
||||
return None
|
||||
|
||||
action_str = sc_data.get("action", "HOLD")
|
||||
action = _ACTION_MAP.get(action_str, ScenarioAction.HOLD)
|
||||
|
||||
return StockScenario(
|
||||
condition=condition,
|
||||
action=action,
|
||||
confidence=int(sc_data.get("confidence", 50)),
|
||||
allocation_pct=float(sc_data.get("allocation_pct", 10.0)),
|
||||
stop_loss_pct=float(sc_data.get("stop_loss_pct", -2.0)),
|
||||
take_profit_pct=float(sc_data.get("take_profit_pct", 3.0)),
|
||||
rationale=sc_data.get("rationale", ""),
|
||||
)
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning("Failed to parse scenario: %s", e)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_json(text: str) -> str:
|
||||
"""Extract JSON from response, stripping markdown fences if present."""
|
||||
stripped = text.strip()
|
||||
if stripped.startswith("```"):
|
||||
# Remove first line (```json or ```) and last line (```)
|
||||
lines = stripped.split("\n")
|
||||
lines = lines[1:] # Remove opening fence
|
||||
if lines and lines[-1].strip() == "```":
|
||||
lines = lines[:-1]
|
||||
stripped = "\n".join(lines)
|
||||
return stripped.strip()
|
||||
|
||||
@staticmethod
|
||||
def _empty_playbook(today: date, market: str) -> DayPlaybook:
|
||||
"""Return an empty playbook (no stocks, no scenarios)."""
|
||||
return DayPlaybook(
|
||||
date=today,
|
||||
market=market,
|
||||
market_outlook=MarketOutlook.NEUTRAL,
|
||||
stock_playbooks=[],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _defensive_playbook(
|
||||
today: date,
|
||||
market: str,
|
||||
candidates: list[ScanCandidate],
|
||||
) -> DayPlaybook:
|
||||
"""Return a defensive playbook — HOLD everything with stop-loss ready."""
|
||||
stock_playbooks = [
|
||||
StockPlaybook(
|
||||
stock_code=c.stock_code,
|
||||
scenarios=[
|
||||
StockScenario(
|
||||
condition=StockCondition(price_change_pct_below=-3.0),
|
||||
action=ScenarioAction.SELL,
|
||||
confidence=90,
|
||||
stop_loss_pct=-3.0,
|
||||
rationale="Defensive stop-loss (planner failure)",
|
||||
),
|
||||
],
|
||||
)
|
||||
for c in candidates
|
||||
]
|
||||
return DayPlaybook(
|
||||
date=today,
|
||||
market=market,
|
||||
market_outlook=MarketOutlook.NEUTRAL_TO_BEARISH,
|
||||
default_action=ScenarioAction.HOLD,
|
||||
stock_playbooks=stock_playbooks,
|
||||
global_rules=[
|
||||
GlobalRule(
|
||||
condition="portfolio_pnl_pct < -2.0",
|
||||
action=ScenarioAction.REDUCE_ALL,
|
||||
rationale="Defensive: reduce on loss threshold",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _smart_fallback_playbook(
|
||||
today: date,
|
||||
market: str,
|
||||
candidates: list[ScanCandidate],
|
||||
settings: Settings,
|
||||
) -> DayPlaybook:
|
||||
"""Rule-based fallback playbook when Gemini is unavailable.
|
||||
|
||||
Uses scanner signals (RSI, volume_ratio) to generate meaningful BUY
|
||||
conditions instead of the all-SELL defensive playbook. Candidates are
|
||||
already pre-qualified by SmartVolatilityScanner, so we trust their
|
||||
signals and build actionable scenarios from them.
|
||||
|
||||
Scenario logic per candidate:
|
||||
- momentum signal: BUY when volume_ratio exceeds scanner threshold
|
||||
- oversold signal: BUY when RSI is below oversold threshold
|
||||
- always: SELL stop-loss at -3.0% as guard
|
||||
"""
|
||||
stock_playbooks = []
|
||||
for c in candidates:
|
||||
scenarios: list[StockScenario] = []
|
||||
|
||||
if c.signal == "momentum":
|
||||
scenarios.append(
|
||||
StockScenario(
|
||||
condition=StockCondition(
|
||||
volume_ratio_above=settings.VOL_MULTIPLIER,
|
||||
),
|
||||
action=ScenarioAction.BUY,
|
||||
confidence=80,
|
||||
allocation_pct=10.0,
|
||||
stop_loss_pct=-3.0,
|
||||
take_profit_pct=5.0,
|
||||
rationale=(
|
||||
f"Rule-based BUY: momentum signal, "
|
||||
f"volume={c.volume_ratio:.1f}x (fallback planner)"
|
||||
),
|
||||
)
|
||||
)
|
||||
elif c.signal == "oversold":
|
||||
scenarios.append(
|
||||
StockScenario(
|
||||
condition=StockCondition(
|
||||
rsi_below=settings.RSI_OVERSOLD_THRESHOLD,
|
||||
),
|
||||
action=ScenarioAction.BUY,
|
||||
confidence=80,
|
||||
allocation_pct=10.0,
|
||||
stop_loss_pct=-3.0,
|
||||
take_profit_pct=5.0,
|
||||
rationale=(
|
||||
f"Rule-based BUY: oversold signal, "
|
||||
f"RSI={c.rsi:.0f} (fallback planner)"
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Always add stop-loss guard
|
||||
scenarios.append(
|
||||
StockScenario(
|
||||
condition=StockCondition(price_change_pct_below=-3.0),
|
||||
action=ScenarioAction.SELL,
|
||||
confidence=90,
|
||||
stop_loss_pct=-3.0,
|
||||
rationale="Rule-based stop-loss (fallback planner)",
|
||||
)
|
||||
)
|
||||
|
||||
stock_playbooks.append(
|
||||
StockPlaybook(
|
||||
stock_code=c.stock_code,
|
||||
scenarios=scenarios,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Smart fallback playbook for %s: %d stocks with rule-based BUY/SELL conditions",
|
||||
market,
|
||||
len(stock_playbooks),
|
||||
)
|
||||
return DayPlaybook(
|
||||
date=today,
|
||||
market=market,
|
||||
market_outlook=MarketOutlook.NEUTRAL,
|
||||
default_action=ScenarioAction.HOLD,
|
||||
stock_playbooks=stock_playbooks,
|
||||
global_rules=[
|
||||
GlobalRule(
|
||||
condition="portfolio_pnl_pct < -2.0",
|
||||
action=ScenarioAction.REDUCE_ALL,
|
||||
rationale="Defensive: reduce on loss threshold",
|
||||
),
|
||||
],
|
||||
)
|
||||
305
src/strategy/scenario_engine.py
Normal file
305
src/strategy/scenario_engine.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""Local scenario engine for playbook execution.
|
||||
|
||||
Matches real-time market conditions against pre-defined scenarios
|
||||
without any API calls. Designed for sub-100ms execution.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from src.strategy.models import (
|
||||
DayPlaybook,
|
||||
GlobalRule,
|
||||
ScenarioAction,
|
||||
StockCondition,
|
||||
StockScenario,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScenarioMatch:
|
||||
"""Result of matching market conditions against scenarios."""
|
||||
|
||||
stock_code: str
|
||||
matched_scenario: StockScenario | None
|
||||
action: ScenarioAction
|
||||
confidence: int
|
||||
rationale: str
|
||||
global_rule_triggered: GlobalRule | None = None
|
||||
match_details: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ScenarioEngine:
|
||||
"""Evaluates playbook scenarios against real-time market data.
|
||||
|
||||
No API calls — pure Python condition matching.
|
||||
|
||||
Expected market_data keys: "rsi", "volume_ratio", "current_price", "price_change_pct".
|
||||
Callers must normalize data source keys to match this contract.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._warned_keys: set[str] = set()
|
||||
|
||||
@staticmethod
|
||||
def _safe_float(value: Any) -> float | None:
|
||||
"""Safely cast a value to float. Returns None on failure."""
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
def _warn_missing_key(self, key: str) -> None:
|
||||
"""Log a missing-key warning once per key per engine instance."""
|
||||
if key not in self._warned_keys:
|
||||
self._warned_keys.add(key)
|
||||
logger.warning("Condition requires '%s' but key missing from market_data", key)
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
playbook: DayPlaybook,
|
||||
stock_code: str,
|
||||
market_data: dict[str, Any],
|
||||
portfolio_data: dict[str, Any],
|
||||
) -> ScenarioMatch:
|
||||
"""Match market conditions to scenarios and return a decision.
|
||||
|
||||
Algorithm:
|
||||
1. Check global rules first (portfolio-level circuit breakers)
|
||||
2. Find the StockPlaybook for the given stock_code
|
||||
3. Iterate scenarios in order (first match wins)
|
||||
4. If no match, return playbook.default_action (HOLD)
|
||||
|
||||
Args:
|
||||
playbook: Today's DayPlaybook for this market
|
||||
stock_code: Stock ticker to evaluate
|
||||
market_data: Real-time market data (price, rsi, volume_ratio, etc.)
|
||||
portfolio_data: Portfolio state (pnl_pct, total_cash, etc.)
|
||||
|
||||
Returns:
|
||||
ScenarioMatch with the decision
|
||||
"""
|
||||
# 1. Check global rules
|
||||
triggered_rule = self.check_global_rules(playbook, portfolio_data)
|
||||
if triggered_rule is not None:
|
||||
logger.info(
|
||||
"Global rule triggered for %s: %s -> %s",
|
||||
stock_code,
|
||||
triggered_rule.condition,
|
||||
triggered_rule.action.value,
|
||||
)
|
||||
return ScenarioMatch(
|
||||
stock_code=stock_code,
|
||||
matched_scenario=None,
|
||||
action=triggered_rule.action,
|
||||
confidence=100,
|
||||
rationale=f"Global rule: {triggered_rule.rationale or triggered_rule.condition}",
|
||||
global_rule_triggered=triggered_rule,
|
||||
)
|
||||
|
||||
# 2. Find stock playbook
|
||||
stock_pb = playbook.get_stock_playbook(stock_code)
|
||||
if stock_pb is None:
|
||||
logger.debug("No playbook for %s — defaulting to %s", stock_code, playbook.default_action)
|
||||
return ScenarioMatch(
|
||||
stock_code=stock_code,
|
||||
matched_scenario=None,
|
||||
action=playbook.default_action,
|
||||
confidence=0,
|
||||
rationale=f"No scenarios defined for {stock_code}",
|
||||
)
|
||||
|
||||
# 3. Iterate scenarios (first match wins)
|
||||
for scenario in stock_pb.scenarios:
|
||||
if self.evaluate_condition(scenario.condition, market_data):
|
||||
logger.info(
|
||||
"Scenario matched for %s: %s (confidence=%d)",
|
||||
stock_code,
|
||||
scenario.action.value,
|
||||
scenario.confidence,
|
||||
)
|
||||
return ScenarioMatch(
|
||||
stock_code=stock_code,
|
||||
matched_scenario=scenario,
|
||||
action=scenario.action,
|
||||
confidence=scenario.confidence,
|
||||
rationale=scenario.rationale,
|
||||
match_details=self._build_match_details(scenario.condition, market_data),
|
||||
)
|
||||
|
||||
# 4. No match — default action
|
||||
logger.debug("No scenario matched for %s — defaulting to %s", stock_code, playbook.default_action)
|
||||
return ScenarioMatch(
|
||||
stock_code=stock_code,
|
||||
matched_scenario=None,
|
||||
action=playbook.default_action,
|
||||
confidence=0,
|
||||
rationale="No scenario conditions met — holding position",
|
||||
)
|
||||
|
||||
def check_global_rules(
|
||||
self,
|
||||
playbook: DayPlaybook,
|
||||
portfolio_data: dict[str, Any],
|
||||
) -> GlobalRule | None:
|
||||
"""Check portfolio-level rules. Returns first triggered rule or None."""
|
||||
for rule in playbook.global_rules:
|
||||
if self._evaluate_global_condition(rule.condition, portfolio_data):
|
||||
return rule
|
||||
return None
|
||||
|
||||
def evaluate_condition(
|
||||
self,
|
||||
condition: StockCondition,
|
||||
market_data: dict[str, Any],
|
||||
) -> bool:
|
||||
"""Evaluate all non-None fields in condition as AND.
|
||||
|
||||
Returns True only if ALL specified conditions are met.
|
||||
Empty condition (no fields set) returns False for safety.
|
||||
"""
|
||||
if not condition.has_any_condition():
|
||||
return False
|
||||
|
||||
checks: list[bool] = []
|
||||
|
||||
rsi = self._safe_float(market_data.get("rsi"))
|
||||
if condition.rsi_below is not None or condition.rsi_above is not None:
|
||||
if "rsi" not in market_data:
|
||||
self._warn_missing_key("rsi")
|
||||
if condition.rsi_below is not None:
|
||||
checks.append(rsi is not None and rsi < condition.rsi_below)
|
||||
if condition.rsi_above is not None:
|
||||
checks.append(rsi is not None and rsi > condition.rsi_above)
|
||||
|
||||
volume_ratio = self._safe_float(market_data.get("volume_ratio"))
|
||||
if condition.volume_ratio_above is not None or condition.volume_ratio_below is not None:
|
||||
if "volume_ratio" not in market_data:
|
||||
self._warn_missing_key("volume_ratio")
|
||||
if condition.volume_ratio_above is not None:
|
||||
checks.append(volume_ratio is not None and volume_ratio > condition.volume_ratio_above)
|
||||
if condition.volume_ratio_below is not None:
|
||||
checks.append(volume_ratio is not None and volume_ratio < condition.volume_ratio_below)
|
||||
|
||||
price = self._safe_float(market_data.get("current_price"))
|
||||
if condition.price_above is not None or condition.price_below is not None:
|
||||
if "current_price" not in market_data:
|
||||
self._warn_missing_key("current_price")
|
||||
if condition.price_above is not None:
|
||||
checks.append(price is not None and price > condition.price_above)
|
||||
if condition.price_below is not None:
|
||||
checks.append(price is not None and price < condition.price_below)
|
||||
|
||||
price_change_pct = self._safe_float(market_data.get("price_change_pct"))
|
||||
if condition.price_change_pct_above is not None or condition.price_change_pct_below is not None:
|
||||
if "price_change_pct" not in market_data:
|
||||
self._warn_missing_key("price_change_pct")
|
||||
if condition.price_change_pct_above is not None:
|
||||
checks.append(price_change_pct is not None and price_change_pct > condition.price_change_pct_above)
|
||||
if condition.price_change_pct_below is not None:
|
||||
checks.append(price_change_pct is not None and price_change_pct < condition.price_change_pct_below)
|
||||
|
||||
# Position-aware conditions
|
||||
unrealized_pnl_pct = self._safe_float(market_data.get("unrealized_pnl_pct"))
|
||||
if condition.unrealized_pnl_pct_above is not None or condition.unrealized_pnl_pct_below is not None:
|
||||
if "unrealized_pnl_pct" not in market_data:
|
||||
self._warn_missing_key("unrealized_pnl_pct")
|
||||
if condition.unrealized_pnl_pct_above is not None:
|
||||
checks.append(
|
||||
unrealized_pnl_pct is not None
|
||||
and unrealized_pnl_pct > condition.unrealized_pnl_pct_above
|
||||
)
|
||||
if condition.unrealized_pnl_pct_below is not None:
|
||||
checks.append(
|
||||
unrealized_pnl_pct is not None
|
||||
and unrealized_pnl_pct < condition.unrealized_pnl_pct_below
|
||||
)
|
||||
|
||||
holding_days = self._safe_float(market_data.get("holding_days"))
|
||||
if condition.holding_days_above is not None or condition.holding_days_below is not None:
|
||||
if "holding_days" not in market_data:
|
||||
self._warn_missing_key("holding_days")
|
||||
if condition.holding_days_above is not None:
|
||||
checks.append(
|
||||
holding_days is not None
|
||||
and holding_days > condition.holding_days_above
|
||||
)
|
||||
if condition.holding_days_below is not None:
|
||||
checks.append(
|
||||
holding_days is not None
|
||||
and holding_days < condition.holding_days_below
|
||||
)
|
||||
|
||||
return len(checks) > 0 and all(checks)
|
||||
|
||||
def _evaluate_global_condition(
|
||||
self,
|
||||
condition_str: str,
|
||||
portfolio_data: dict[str, Any],
|
||||
) -> bool:
|
||||
"""Evaluate a simple global condition string against portfolio data.
|
||||
|
||||
Supports: "field < value", "field > value", "field <= value", "field >= value"
|
||||
"""
|
||||
parts = condition_str.strip().split()
|
||||
if len(parts) != 3:
|
||||
logger.warning("Invalid global condition format: %s", condition_str)
|
||||
return False
|
||||
|
||||
field_name, operator, value_str = parts
|
||||
try:
|
||||
threshold = float(value_str)
|
||||
except ValueError:
|
||||
logger.warning("Invalid threshold in condition: %s", condition_str)
|
||||
return False
|
||||
|
||||
actual = portfolio_data.get(field_name)
|
||||
if actual is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
actual_val = float(actual)
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
if operator == "<":
|
||||
return actual_val < threshold
|
||||
elif operator == ">":
|
||||
return actual_val > threshold
|
||||
elif operator == "<=":
|
||||
return actual_val <= threshold
|
||||
elif operator == ">=":
|
||||
return actual_val >= threshold
|
||||
else:
|
||||
logger.warning("Unknown operator in condition: %s", operator)
|
||||
return False
|
||||
|
||||
def _build_match_details(
|
||||
self,
|
||||
condition: StockCondition,
|
||||
market_data: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Build a summary of which conditions matched and their normalized values."""
|
||||
details: dict[str, Any] = {}
|
||||
|
||||
if condition.rsi_below is not None or condition.rsi_above is not None:
|
||||
details["rsi"] = self._safe_float(market_data.get("rsi"))
|
||||
if condition.volume_ratio_above is not None or condition.volume_ratio_below is not None:
|
||||
details["volume_ratio"] = self._safe_float(market_data.get("volume_ratio"))
|
||||
if condition.price_above is not None or condition.price_below is not None:
|
||||
details["current_price"] = self._safe_float(market_data.get("current_price"))
|
||||
if condition.price_change_pct_above is not None or condition.price_change_pct_below is not None:
|
||||
details["price_change_pct"] = self._safe_float(market_data.get("price_change_pct"))
|
||||
if condition.unrealized_pnl_pct_above is not None or condition.unrealized_pnl_pct_below is not None:
|
||||
details["unrealized_pnl_pct"] = self._safe_float(market_data.get("unrealized_pnl_pct"))
|
||||
if condition.holding_days_above is not None or condition.holding_days_below is not None:
|
||||
details["holding_days"] = self._safe_float(market_data.get("holding_days"))
|
||||
|
||||
return details
|
||||
799
tests/test_backup.py
Normal file
799
tests/test_backup.py
Normal file
@@ -0,0 +1,799 @@
|
||||
"""Tests for backup and disaster recovery system."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import sys
|
||||
import tempfile
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.backup.exporter import BackupExporter, ExportFormat
|
||||
from src.backup.health_monitor import HealthMonitor, HealthStatus
|
||||
from src.backup.scheduler import BackupPolicy, BackupScheduler
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db(tmp_path: Path) -> Path:
|
||||
"""Create a temporary test database."""
|
||||
db_path = tmp_path / "test_trades.db"
|
||||
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Create trades table
|
||||
cursor.execute("""
|
||||
CREATE TABLE trades (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
timestamp TEXT NOT NULL,
|
||||
stock_code TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
quantity INTEGER NOT NULL,
|
||||
price REAL NOT NULL,
|
||||
confidence INTEGER NOT NULL,
|
||||
rationale TEXT,
|
||||
pnl REAL DEFAULT 0.0
|
||||
)
|
||||
""")
|
||||
|
||||
# Insert test data
|
||||
test_trades = [
|
||||
("2024-01-01T10:00:00Z", "005930", "BUY", 10, 70000.0, 85, "Test buy", 0.0),
|
||||
("2024-01-01T11:00:00Z", "005930", "SELL", 10, 71000.0, 90, "Test sell", 10000.0),
|
||||
("2024-01-02T10:00:00Z", "AAPL", "BUY", 5, 180.0, 88, "Tech buy", 0.0),
|
||||
]
|
||||
|
||||
cursor.executemany(
|
||||
"""
|
||||
INSERT INTO trades (timestamp, stock_code, action, quantity, price, confidence, rationale, pnl)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
test_trades,
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
return db_path
|
||||
|
||||
|
||||
class TestBackupExporter:
|
||||
"""Test BackupExporter functionality."""
|
||||
|
||||
def test_exporter_init(self, temp_db: Path) -> None:
|
||||
"""Test exporter initialization."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
assert exporter.db_path == str(temp_db)
|
||||
|
||||
def test_export_json(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test JSON export."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
output_dir = tmp_path / "exports"
|
||||
|
||||
results = exporter.export_all(
|
||||
output_dir, formats=[ExportFormat.JSON], compress=False
|
||||
)
|
||||
|
||||
assert ExportFormat.JSON in results
|
||||
assert results[ExportFormat.JSON].exists()
|
||||
assert results[ExportFormat.JSON].suffix == ".json"
|
||||
|
||||
def test_export_json_compressed(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test compressed JSON export."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
output_dir = tmp_path / "exports"
|
||||
|
||||
results = exporter.export_all(
|
||||
output_dir, formats=[ExportFormat.JSON], compress=True
|
||||
)
|
||||
|
||||
assert ExportFormat.JSON in results
|
||||
assert results[ExportFormat.JSON].suffix == ".gz"
|
||||
|
||||
def test_export_csv(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test CSV export."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
output_dir = tmp_path / "exports"
|
||||
|
||||
results = exporter.export_all(
|
||||
output_dir, formats=[ExportFormat.CSV], compress=False
|
||||
)
|
||||
|
||||
assert ExportFormat.CSV in results
|
||||
assert results[ExportFormat.CSV].exists()
|
||||
|
||||
# Verify CSV content
|
||||
with open(results[ExportFormat.CSV], "r") as f:
|
||||
lines = f.readlines()
|
||||
assert len(lines) == 4 # Header + 3 rows
|
||||
|
||||
def test_export_all_formats(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test exporting all formats."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
output_dir = tmp_path / "exports"
|
||||
|
||||
# Skip Parquet if pyarrow not available
|
||||
try:
|
||||
import pyarrow # noqa: F401
|
||||
|
||||
formats = [ExportFormat.JSON, ExportFormat.CSV, ExportFormat.PARQUET]
|
||||
except ImportError:
|
||||
formats = [ExportFormat.JSON, ExportFormat.CSV]
|
||||
|
||||
results = exporter.export_all(output_dir, formats=formats, compress=False)
|
||||
|
||||
for fmt in formats:
|
||||
assert fmt in results
|
||||
assert results[fmt].exists()
|
||||
|
||||
def test_incremental_export(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test incremental export."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
output_dir = tmp_path / "exports"
|
||||
|
||||
# Export only trades after Jan 2
|
||||
cutoff = datetime(2024, 1, 2, tzinfo=UTC)
|
||||
results = exporter.export_all(
|
||||
output_dir,
|
||||
formats=[ExportFormat.JSON],
|
||||
compress=False,
|
||||
incremental_since=cutoff,
|
||||
)
|
||||
|
||||
# Should only have 1 trade (AAPL on Jan 2)
|
||||
import json
|
||||
|
||||
with open(results[ExportFormat.JSON], "r") as f:
|
||||
data = json.load(f)
|
||||
assert data["record_count"] == 1
|
||||
assert data["trades"][0]["stock_code"] == "AAPL"
|
||||
|
||||
def test_get_export_stats(self, temp_db: Path) -> None:
|
||||
"""Test export statistics."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
stats = exporter.get_export_stats()
|
||||
|
||||
assert stats["total_trades"] == 3
|
||||
assert "date_range" in stats
|
||||
assert "db_size_bytes" in stats
|
||||
|
||||
|
||||
class TestBackupScheduler:
|
||||
"""Test BackupScheduler functionality."""
|
||||
|
||||
def test_scheduler_init(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test scheduler initialization."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
|
||||
assert scheduler.db_path == temp_db
|
||||
assert (backup_dir / "daily").exists()
|
||||
assert (backup_dir / "weekly").exists()
|
||||
assert (backup_dir / "monthly").exists()
|
||||
|
||||
def test_create_daily_backup(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test daily backup creation."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
|
||||
metadata = scheduler.create_backup(BackupPolicy.DAILY, verify=True)
|
||||
|
||||
assert metadata.policy == BackupPolicy.DAILY
|
||||
assert metadata.file_path.exists()
|
||||
assert metadata.size_bytes > 0
|
||||
assert metadata.checksum is not None
|
||||
|
||||
def test_create_weekly_backup(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test weekly backup creation."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
|
||||
metadata = scheduler.create_backup(BackupPolicy.WEEKLY, verify=False)
|
||||
|
||||
assert metadata.policy == BackupPolicy.WEEKLY
|
||||
assert metadata.file_path.exists()
|
||||
assert metadata.checksum is None # verify=False
|
||||
|
||||
def test_list_backups(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test listing backups."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
|
||||
scheduler.create_backup(BackupPolicy.DAILY)
|
||||
scheduler.create_backup(BackupPolicy.WEEKLY)
|
||||
|
||||
backups = scheduler.list_backups()
|
||||
assert len(backups) == 2
|
||||
|
||||
daily_backups = scheduler.list_backups(BackupPolicy.DAILY)
|
||||
assert len(daily_backups) == 1
|
||||
assert daily_backups[0].policy == BackupPolicy.DAILY
|
||||
|
||||
def test_cleanup_old_backups(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test cleanup of old backups."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir, daily_retention_days=0)
|
||||
|
||||
# Create a backup
|
||||
scheduler.create_backup(BackupPolicy.DAILY)
|
||||
|
||||
# Cleanup should remove it (0 day retention)
|
||||
removed = scheduler.cleanup_old_backups()
|
||||
assert removed[BackupPolicy.DAILY] >= 1
|
||||
|
||||
def test_backup_stats(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test backup statistics."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
|
||||
scheduler.create_backup(BackupPolicy.DAILY)
|
||||
scheduler.create_backup(BackupPolicy.MONTHLY)
|
||||
|
||||
stats = scheduler.get_backup_stats()
|
||||
|
||||
assert stats["daily"]["count"] == 1
|
||||
assert stats["monthly"]["count"] == 1
|
||||
assert stats["daily"]["total_size_bytes"] > 0
|
||||
|
||||
def test_restore_backup(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test backup restoration."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
|
||||
# Create backup
|
||||
metadata = scheduler.create_backup(BackupPolicy.DAILY)
|
||||
|
||||
# Modify database
|
||||
conn = sqlite3.connect(str(temp_db))
|
||||
conn.execute("DELETE FROM trades")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
# Restore
|
||||
scheduler.restore_backup(metadata, verify=True)
|
||||
|
||||
# Verify restoration
|
||||
conn = sqlite3.connect(str(temp_db))
|
||||
cursor = conn.execute("SELECT COUNT(*) FROM trades")
|
||||
count = cursor.fetchone()[0]
|
||||
conn.close()
|
||||
|
||||
assert count == 3 # Original 3 trades restored
|
||||
|
||||
|
||||
class TestHealthMonitor:
|
||||
"""Test HealthMonitor functionality."""
|
||||
|
||||
def test_monitor_init(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test monitor initialization."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
monitor = HealthMonitor(str(temp_db), backup_dir)
|
||||
|
||||
assert monitor.db_path == temp_db
|
||||
|
||||
def test_check_database_health_ok(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test database health check (healthy)."""
|
||||
monitor = HealthMonitor(str(temp_db), tmp_path / "backups")
|
||||
result = monitor.check_database_health()
|
||||
|
||||
assert result.status == HealthStatus.HEALTHY
|
||||
assert "healthy" in result.message.lower()
|
||||
assert result.details is not None
|
||||
assert result.details["trade_count"] == 3
|
||||
|
||||
def test_check_database_health_missing(self, tmp_path: Path) -> None:
|
||||
"""Test database health check (missing file)."""
|
||||
non_existent = tmp_path / "missing.db"
|
||||
monitor = HealthMonitor(str(non_existent), tmp_path / "backups")
|
||||
result = monitor.check_database_health()
|
||||
|
||||
assert result.status == HealthStatus.UNHEALTHY
|
||||
assert "not found" in result.message.lower()
|
||||
|
||||
def test_check_disk_space(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test disk space check."""
|
||||
monitor = HealthMonitor(str(temp_db), tmp_path, min_disk_space_gb=0.001)
|
||||
result = monitor.check_disk_space()
|
||||
|
||||
# Should be healthy with minimal requirement
|
||||
assert result.status in [HealthStatus.HEALTHY, HealthStatus.DEGRADED]
|
||||
assert result.details is not None
|
||||
assert "free_gb" in result.details
|
||||
|
||||
def test_check_backup_recency_no_backups(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test backup recency check (no backups)."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
backup_dir.mkdir()
|
||||
(backup_dir / "daily").mkdir()
|
||||
|
||||
monitor = HealthMonitor(str(temp_db), backup_dir)
|
||||
result = monitor.check_backup_recency()
|
||||
|
||||
assert result.status == HealthStatus.UNHEALTHY
|
||||
assert "no" in result.message.lower()
|
||||
|
||||
def test_check_backup_recency_recent(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test backup recency check (recent backup)."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
scheduler.create_backup(BackupPolicy.DAILY)
|
||||
|
||||
monitor = HealthMonitor(str(temp_db), backup_dir)
|
||||
result = monitor.check_backup_recency()
|
||||
|
||||
assert result.status == HealthStatus.HEALTHY
|
||||
assert "recent" in result.message.lower()
|
||||
|
||||
def test_run_all_checks(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test running all health checks."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
scheduler.create_backup(BackupPolicy.DAILY)
|
||||
|
||||
monitor = HealthMonitor(str(temp_db), backup_dir, min_disk_space_gb=0.001)
|
||||
checks = monitor.run_all_checks()
|
||||
|
||||
assert "database" in checks
|
||||
assert "disk_space" in checks
|
||||
assert "backup_recency" in checks
|
||||
assert checks["database"].status == HealthStatus.HEALTHY
|
||||
|
||||
def test_get_overall_status(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test overall health status."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
scheduler.create_backup(BackupPolicy.DAILY)
|
||||
|
||||
monitor = HealthMonitor(str(temp_db), backup_dir, min_disk_space_gb=0.001)
|
||||
status = monitor.get_overall_status()
|
||||
|
||||
assert status in [HealthStatus.HEALTHY, HealthStatus.DEGRADED]
|
||||
|
||||
def test_get_health_report(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test health report generation."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
scheduler.create_backup(BackupPolicy.DAILY)
|
||||
|
||||
monitor = HealthMonitor(str(temp_db), backup_dir, min_disk_space_gb=0.001)
|
||||
report = monitor.get_health_report()
|
||||
|
||||
assert "overall_status" in report
|
||||
assert "timestamp" in report
|
||||
assert "checks" in report
|
||||
assert len(report["checks"]) == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BackupExporter — additional coverage for previously uncovered branches
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def empty_db(tmp_path: Path) -> Path:
|
||||
"""Create a temporary database with NO trade records."""
|
||||
db_path = tmp_path / "empty_trades.db"
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.execute(
|
||||
"""CREATE TABLE trades (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
timestamp TEXT NOT NULL,
|
||||
stock_code TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
quantity INTEGER NOT NULL,
|
||||
price REAL NOT NULL,
|
||||
confidence INTEGER NOT NULL,
|
||||
rationale TEXT,
|
||||
pnl REAL DEFAULT 0.0
|
||||
)"""
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return db_path
|
||||
|
||||
|
||||
class TestBackupExporterAdditional:
|
||||
"""Cover branches missed in the original TestBackupExporter suite."""
|
||||
|
||||
def test_export_all_default_formats(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""export_all with formats=None must default to JSON+CSV+Parquet path."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
# formats=None triggers the default list assignment (line 62)
|
||||
results = exporter.export_all(tmp_path / "out", formats=None, compress=False)
|
||||
# JSON and CSV must always succeed; Parquet needs pyarrow
|
||||
assert ExportFormat.JSON in results
|
||||
assert ExportFormat.CSV in results
|
||||
|
||||
def test_export_all_logs_error_on_failure(
|
||||
self, temp_db: Path, tmp_path: Path
|
||||
) -> None:
|
||||
"""export_all must log an error and continue when one format fails."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
# Patch _export_format to raise on JSON, succeed on CSV
|
||||
original = exporter._export_format
|
||||
|
||||
def failing_export(fmt, *args, **kwargs): # type: ignore[no-untyped-def]
|
||||
if fmt == ExportFormat.JSON:
|
||||
raise RuntimeError("simulated failure")
|
||||
return original(fmt, *args, **kwargs)
|
||||
|
||||
exporter._export_format = failing_export # type: ignore[method-assign]
|
||||
results = exporter.export_all(
|
||||
tmp_path / "out",
|
||||
formats=[ExportFormat.JSON, ExportFormat.CSV],
|
||||
compress=False,
|
||||
)
|
||||
# JSON failed → not in results; CSV succeeded → in results
|
||||
assert ExportFormat.JSON not in results
|
||||
assert ExportFormat.CSV in results
|
||||
|
||||
def test_export_csv_empty_trades_no_compress(
|
||||
self, empty_db: Path, tmp_path: Path
|
||||
) -> None:
|
||||
"""CSV export with no trades and compress=False must write header row only."""
|
||||
exporter = BackupExporter(str(empty_db))
|
||||
results = exporter.export_all(
|
||||
tmp_path / "out",
|
||||
formats=[ExportFormat.CSV],
|
||||
compress=False,
|
||||
)
|
||||
assert ExportFormat.CSV in results
|
||||
out = results[ExportFormat.CSV]
|
||||
assert out.exists()
|
||||
content = out.read_text()
|
||||
assert "timestamp" in content
|
||||
|
||||
def test_export_csv_empty_trades_compressed(
|
||||
self, empty_db: Path, tmp_path: Path
|
||||
) -> None:
|
||||
"""CSV export with no trades and compress=True must write gzipped header."""
|
||||
import gzip
|
||||
|
||||
exporter = BackupExporter(str(empty_db))
|
||||
results = exporter.export_all(
|
||||
tmp_path / "out",
|
||||
formats=[ExportFormat.CSV],
|
||||
compress=True,
|
||||
)
|
||||
assert ExportFormat.CSV in results
|
||||
out = results[ExportFormat.CSV]
|
||||
assert out.suffix == ".gz"
|
||||
with gzip.open(out, "rt", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
assert "timestamp" in content
|
||||
|
||||
def test_export_csv_with_data_compressed(
|
||||
self, temp_db: Path, tmp_path: Path
|
||||
) -> None:
|
||||
"""CSV export with data and compress=True must write gzipped rows."""
|
||||
import gzip
|
||||
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
results = exporter.export_all(
|
||||
tmp_path / "out",
|
||||
formats=[ExportFormat.CSV],
|
||||
compress=True,
|
||||
)
|
||||
assert ExportFormat.CSV in results
|
||||
out = results[ExportFormat.CSV]
|
||||
with gzip.open(out, "rt", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
# Header + 3 data rows
|
||||
assert len(lines) == 4
|
||||
|
||||
def test_export_parquet_raises_import_error_without_pyarrow(
|
||||
self, temp_db: Path, tmp_path: Path
|
||||
) -> None:
|
||||
"""Parquet export must raise ImportError when pyarrow is not installed."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
with patch.dict(sys.modules, {"pyarrow": None, "pyarrow.parquet": None}):
|
||||
try:
|
||||
import pyarrow # noqa: F401
|
||||
pytest.skip("pyarrow is installed; cannot test ImportError path")
|
||||
except ImportError:
|
||||
pass
|
||||
results = exporter.export_all(
|
||||
tmp_path / "out",
|
||||
formats=[ExportFormat.PARQUET],
|
||||
compress=False,
|
||||
)
|
||||
# Parquet export fails gracefully; result dict should not contain it
|
||||
assert ExportFormat.PARQUET not in results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CloudStorage — mocked boto3 tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_boto3_module():
|
||||
"""Inject a fake boto3 into sys.modules for the duration of the test."""
|
||||
mock = MagicMock()
|
||||
with patch.dict(sys.modules, {"boto3": mock}):
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def s3_config():
|
||||
"""Minimal S3Config for tests."""
|
||||
from src.backup.cloud_storage import S3Config
|
||||
|
||||
return S3Config(
|
||||
endpoint_url="http://localhost:9000",
|
||||
access_key="minioadmin",
|
||||
secret_key="minioadmin",
|
||||
bucket_name="test-bucket",
|
||||
region="us-east-1",
|
||||
)
|
||||
|
||||
|
||||
class TestCloudStorage:
|
||||
"""Test CloudStorage using mocked boto3."""
|
||||
|
||||
def test_init_creates_s3_client(self, mock_boto3_module, s3_config) -> None:
|
||||
"""CloudStorage.__init__ must call boto3.client with the correct args."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
mock_boto3_module.client.assert_called_once()
|
||||
call_kwargs = mock_boto3_module.client.call_args[1]
|
||||
assert call_kwargs["aws_access_key_id"] == "minioadmin"
|
||||
assert call_kwargs["aws_secret_access_key"] == "minioadmin"
|
||||
assert storage.config == s3_config
|
||||
|
||||
def test_init_raises_if_boto3_missing(self, s3_config) -> None:
|
||||
"""CloudStorage.__init__ must raise ImportError when boto3 is absent."""
|
||||
with patch.dict(sys.modules, {"boto3": None}): # type: ignore[dict-item]
|
||||
with pytest.raises((ImportError, TypeError)):
|
||||
# Re-import to trigger the try/except inside __init__
|
||||
import importlib
|
||||
|
||||
import src.backup.cloud_storage as m
|
||||
|
||||
importlib.reload(m)
|
||||
m.CloudStorage(s3_config)
|
||||
|
||||
def test_upload_file_success(
|
||||
self, mock_boto3_module, s3_config, tmp_path: Path
|
||||
) -> None:
|
||||
"""upload_file must call client.upload_file and return the object key."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
test_file = tmp_path / "backup.json.gz"
|
||||
test_file.write_bytes(b"data")
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
key = storage.upload_file(test_file, object_key="backups/backup.json.gz")
|
||||
|
||||
assert key == "backups/backup.json.gz"
|
||||
storage.client.upload_file.assert_called_once()
|
||||
|
||||
def test_upload_file_default_key(
|
||||
self, mock_boto3_module, s3_config, tmp_path: Path
|
||||
) -> None:
|
||||
"""upload_file without object_key must use the filename as key."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
test_file = tmp_path / "myfile.gz"
|
||||
test_file.write_bytes(b"data")
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
key = storage.upload_file(test_file)
|
||||
|
||||
assert key == "myfile.gz"
|
||||
|
||||
def test_upload_file_not_found(
|
||||
self, mock_boto3_module, s3_config, tmp_path: Path
|
||||
) -> None:
|
||||
"""upload_file must raise FileNotFoundError for missing files."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
with pytest.raises(FileNotFoundError):
|
||||
storage.upload_file(tmp_path / "nonexistent.gz")
|
||||
|
||||
def test_upload_file_propagates_client_error(
|
||||
self, mock_boto3_module, s3_config, tmp_path: Path
|
||||
) -> None:
|
||||
"""upload_file must re-raise exceptions from the boto3 client."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
test_file = tmp_path / "backup.gz"
|
||||
test_file.write_bytes(b"data")
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
storage.client.upload_file.side_effect = RuntimeError("network error")
|
||||
|
||||
with pytest.raises(RuntimeError, match="network error"):
|
||||
storage.upload_file(test_file)
|
||||
|
||||
def test_download_file_success(
|
||||
self, mock_boto3_module, s3_config, tmp_path: Path
|
||||
) -> None:
|
||||
"""download_file must call client.download_file and return local path."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
dest = tmp_path / "downloads" / "backup.gz"
|
||||
|
||||
result = storage.download_file("backups/backup.gz", dest)
|
||||
|
||||
assert result == dest
|
||||
storage.client.download_file.assert_called_once()
|
||||
|
||||
def test_download_file_propagates_error(
|
||||
self, mock_boto3_module, s3_config, tmp_path: Path
|
||||
) -> None:
|
||||
"""download_file must re-raise exceptions from the boto3 client."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
storage.client.download_file.side_effect = RuntimeError("timeout")
|
||||
|
||||
with pytest.raises(RuntimeError, match="timeout"):
|
||||
storage.download_file("key", tmp_path / "dest.gz")
|
||||
|
||||
def test_list_files_returns_objects(
|
||||
self, mock_boto3_module, s3_config
|
||||
) -> None:
|
||||
"""list_files must return parsed file metadata from S3 response."""
|
||||
from datetime import timezone
|
||||
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
storage.client.list_objects_v2.return_value = {
|
||||
"Contents": [
|
||||
{
|
||||
"Key": "backups/a.gz",
|
||||
"Size": 1024,
|
||||
"LastModified": datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
"ETag": '"abc123"',
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
files = storage.list_files(prefix="backups/")
|
||||
assert len(files) == 1
|
||||
assert files[0]["key"] == "backups/a.gz"
|
||||
assert files[0]["size_bytes"] == 1024
|
||||
|
||||
def test_list_files_empty_bucket(
|
||||
self, mock_boto3_module, s3_config
|
||||
) -> None:
|
||||
"""list_files must return empty list when bucket has no objects."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
storage.client.list_objects_v2.return_value = {}
|
||||
|
||||
files = storage.list_files()
|
||||
assert files == []
|
||||
|
||||
def test_list_files_propagates_error(
|
||||
self, mock_boto3_module, s3_config
|
||||
) -> None:
|
||||
"""list_files must re-raise exceptions from the boto3 client."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
storage.client.list_objects_v2.side_effect = RuntimeError("auth error")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
storage.list_files()
|
||||
|
||||
def test_delete_file_success(
|
||||
self, mock_boto3_module, s3_config
|
||||
) -> None:
|
||||
"""delete_file must call client.delete_object with the correct key."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
storage.delete_file("backups/old.gz")
|
||||
storage.client.delete_object.assert_called_once_with(
|
||||
Bucket="test-bucket", Key="backups/old.gz"
|
||||
)
|
||||
|
||||
def test_delete_file_propagates_error(
|
||||
self, mock_boto3_module, s3_config
|
||||
) -> None:
|
||||
"""delete_file must re-raise exceptions from the boto3 client."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
storage.client.delete_object.side_effect = RuntimeError("permission denied")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
storage.delete_file("backups/old.gz")
|
||||
|
||||
def test_get_storage_stats_success(
|
||||
self, mock_boto3_module, s3_config
|
||||
) -> None:
|
||||
"""get_storage_stats must aggregate file sizes correctly."""
|
||||
from datetime import timezone
|
||||
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
storage.client.list_objects_v2.return_value = {
|
||||
"Contents": [
|
||||
{
|
||||
"Key": "a.gz",
|
||||
"Size": 1024 * 1024,
|
||||
"LastModified": datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
"ETag": '"x"',
|
||||
},
|
||||
{
|
||||
"Key": "b.gz",
|
||||
"Size": 1024 * 1024,
|
||||
"LastModified": datetime(2026, 1, 2, tzinfo=timezone.utc),
|
||||
"ETag": '"y"',
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
stats = storage.get_storage_stats()
|
||||
assert stats["total_files"] == 2
|
||||
assert stats["total_size_bytes"] == 2 * 1024 * 1024
|
||||
assert stats["total_size_mb"] == pytest.approx(2.0)
|
||||
|
||||
def test_get_storage_stats_on_error(
|
||||
self, mock_boto3_module, s3_config
|
||||
) -> None:
|
||||
"""get_storage_stats must return error dict without raising on failure."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
storage.client.list_objects_v2.side_effect = RuntimeError("no connection")
|
||||
|
||||
stats = storage.get_storage_stats()
|
||||
assert "error" in stats
|
||||
assert stats["total_files"] == 0
|
||||
|
||||
def test_verify_connection_success(
|
||||
self, mock_boto3_module, s3_config
|
||||
) -> None:
|
||||
"""verify_connection must return True when head_bucket succeeds."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
result = storage.verify_connection()
|
||||
assert result is True
|
||||
|
||||
def test_verify_connection_failure(
|
||||
self, mock_boto3_module, s3_config
|
||||
) -> None:
|
||||
"""verify_connection must return False when head_bucket raises."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
storage.client.head_bucket.side_effect = RuntimeError("no such bucket")
|
||||
|
||||
result = storage.verify_connection()
|
||||
assert result is False
|
||||
|
||||
def test_enable_versioning(
|
||||
self, mock_boto3_module, s3_config
|
||||
) -> None:
|
||||
"""enable_versioning must call put_bucket_versioning."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
storage.enable_versioning()
|
||||
storage.client.put_bucket_versioning.assert_called_once()
|
||||
|
||||
def test_enable_versioning_propagates_error(
|
||||
self, mock_boto3_module, s3_config
|
||||
) -> None:
|
||||
"""enable_versioning must re-raise exceptions from the boto3 client."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
storage.client.put_bucket_versioning.side_effect = RuntimeError("denied")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
storage.enable_versioning()
|
||||
@@ -2,6 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.brain.gemini_client import GeminiClient
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -89,9 +93,21 @@ class TestMalformedJsonHandling:
|
||||
|
||||
def test_json_with_missing_fields_returns_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
decision = client.parse_response('{"action": "BUY"}')
|
||||
raw = '{"action": "BUY"}'
|
||||
decision = client.parse_response(raw)
|
||||
assert decision.action == "HOLD"
|
||||
assert decision.confidence == 0
|
||||
# rationale preserves raw so prompt_override callers (e.g. pre_market_planner)
|
||||
# can extract non-TradeDecision JSON from decision.rationale (#245)
|
||||
assert decision.rationale == raw
|
||||
|
||||
def test_non_trade_decision_json_preserves_raw_in_rationale(self, settings):
|
||||
"""Playbook JSON (no action/confidence/rationale) must be preserved for planner."""
|
||||
client = GeminiClient(settings)
|
||||
playbook_json = '{"market_outlook": "neutral", "stocks": []}'
|
||||
decision = client.parse_response(playbook_json)
|
||||
assert decision.action == "HOLD"
|
||||
assert decision.rationale == playbook_json
|
||||
|
||||
def test_json_with_invalid_action_returns_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
@@ -126,7 +142,7 @@ class TestPromptConstruction:
|
||||
"orderbook": {"asks": [], "bids": []},
|
||||
"foreigner_net": -50000,
|
||||
}
|
||||
prompt = client.build_prompt(market_data)
|
||||
prompt = client.build_prompt_sync(market_data)
|
||||
assert "005930" in prompt
|
||||
|
||||
def test_prompt_contains_price(self, settings):
|
||||
@@ -137,7 +153,7 @@ class TestPromptConstruction:
|
||||
"orderbook": {"asks": [], "bids": []},
|
||||
"foreigner_net": -50000,
|
||||
}
|
||||
prompt = client.build_prompt(market_data)
|
||||
prompt = client.build_prompt_sync(market_data)
|
||||
assert "72000" in prompt
|
||||
|
||||
def test_prompt_enforces_json_output_format(self, settings):
|
||||
@@ -148,7 +164,254 @@ class TestPromptConstruction:
|
||||
"orderbook": {"asks": [], "bids": []},
|
||||
"foreigner_net": 0,
|
||||
}
|
||||
prompt = client.build_prompt(market_data)
|
||||
prompt = client.build_prompt_sync(market_data)
|
||||
assert "JSON" in prompt
|
||||
assert "action" in prompt
|
||||
assert "confidence" in prompt
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Batch Decision Making
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBatchDecisionParsing:
|
||||
"""Batch response parser must handle JSON arrays correctly."""
|
||||
|
||||
def test_parse_valid_batch_response(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [
|
||||
{"stock_code": "AAPL", "current_price": 185.5},
|
||||
{"stock_code": "MSFT", "current_price": 420.0},
|
||||
]
|
||||
raw = """[
|
||||
{"code": "AAPL", "action": "BUY", "confidence": 85, "rationale": "Strong momentum"},
|
||||
{"code": "MSFT", "action": "HOLD", "confidence": 50, "rationale": "Wait for earnings"}
|
||||
]"""
|
||||
|
||||
decisions = client._parse_batch_response(raw, stocks_data, token_count=100)
|
||||
|
||||
assert len(decisions) == 2
|
||||
assert decisions["AAPL"].action == "BUY"
|
||||
assert decisions["AAPL"].confidence == 85
|
||||
assert decisions["MSFT"].action == "HOLD"
|
||||
assert decisions["MSFT"].confidence == 50
|
||||
|
||||
def test_parse_batch_with_markdown_wrapper(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [{"stock_code": "AAPL", "current_price": 185.5}]
|
||||
raw = """```json
|
||||
[{"code": "AAPL", "action": "BUY", "confidence": 90, "rationale": "Good"}]
|
||||
```"""
|
||||
|
||||
decisions = client._parse_batch_response(raw, stocks_data, token_count=100)
|
||||
|
||||
assert decisions["AAPL"].action == "BUY"
|
||||
assert decisions["AAPL"].confidence == 90
|
||||
|
||||
def test_parse_batch_empty_response_returns_hold_for_all(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [
|
||||
{"stock_code": "AAPL", "current_price": 185.5},
|
||||
{"stock_code": "MSFT", "current_price": 420.0},
|
||||
]
|
||||
|
||||
decisions = client._parse_batch_response("", stocks_data, token_count=100)
|
||||
|
||||
assert len(decisions) == 2
|
||||
assert decisions["AAPL"].action == "HOLD"
|
||||
assert decisions["AAPL"].confidence == 0
|
||||
assert decisions["MSFT"].action == "HOLD"
|
||||
|
||||
def test_parse_batch_malformed_json_returns_hold_for_all(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [{"stock_code": "AAPL", "current_price": 185.5}]
|
||||
raw = "This is not JSON"
|
||||
|
||||
decisions = client._parse_batch_response(raw, stocks_data, token_count=100)
|
||||
|
||||
assert decisions["AAPL"].action == "HOLD"
|
||||
assert decisions["AAPL"].confidence == 0
|
||||
|
||||
def test_parse_batch_not_array_returns_hold_for_all(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [{"stock_code": "AAPL", "current_price": 185.5}]
|
||||
raw = '{"code": "AAPL", "action": "BUY", "confidence": 90, "rationale": "Good"}'
|
||||
|
||||
decisions = client._parse_batch_response(raw, stocks_data, token_count=100)
|
||||
|
||||
assert decisions["AAPL"].action == "HOLD"
|
||||
assert decisions["AAPL"].confidence == 0
|
||||
|
||||
def test_parse_batch_missing_stock_gets_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [
|
||||
{"stock_code": "AAPL", "current_price": 185.5},
|
||||
{"stock_code": "MSFT", "current_price": 420.0},
|
||||
]
|
||||
# Response only has AAPL, MSFT is missing
|
||||
raw = '[{"code": "AAPL", "action": "BUY", "confidence": 85, "rationale": "Good"}]'
|
||||
|
||||
decisions = client._parse_batch_response(raw, stocks_data, token_count=100)
|
||||
|
||||
assert decisions["AAPL"].action == "BUY"
|
||||
assert decisions["MSFT"].action == "HOLD"
|
||||
assert decisions["MSFT"].confidence == 0
|
||||
|
||||
def test_parse_batch_invalid_action_becomes_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [{"stock_code": "AAPL", "current_price": 185.5}]
|
||||
raw = '[{"code": "AAPL", "action": "YOLO", "confidence": 90, "rationale": "Moon"}]'
|
||||
|
||||
decisions = client._parse_batch_response(raw, stocks_data, token_count=100)
|
||||
|
||||
assert decisions["AAPL"].action == "HOLD"
|
||||
|
||||
def test_parse_batch_low_confidence_becomes_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [{"stock_code": "AAPL", "current_price": 185.5}]
|
||||
raw = '[{"code": "AAPL", "action": "BUY", "confidence": 65, "rationale": "Weak"}]'
|
||||
|
||||
decisions = client._parse_batch_response(raw, stocks_data, token_count=100)
|
||||
|
||||
assert decisions["AAPL"].action == "HOLD"
|
||||
assert decisions["AAPL"].confidence == 65
|
||||
|
||||
def test_parse_batch_missing_fields_gets_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [{"stock_code": "AAPL", "current_price": 185.5}]
|
||||
raw = '[{"code": "AAPL", "action": "BUY"}]' # Missing confidence and rationale
|
||||
|
||||
decisions = client._parse_batch_response(raw, stocks_data, token_count=100)
|
||||
|
||||
assert decisions["AAPL"].action == "HOLD"
|
||||
assert decisions["AAPL"].confidence == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt Override (used by pre_market_planner)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPromptOverride:
|
||||
"""decide() must use prompt_override when present in market_data."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_override_is_sent_to_gemini(self, settings):
|
||||
"""When prompt_override is in market_data, it should be used as the prompt."""
|
||||
client = GeminiClient(settings)
|
||||
|
||||
custom_prompt = "You are a playbook generator. Return JSON with scenarios."
|
||||
playbook_json = '{"market_outlook": "neutral", "stocks": []}'
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = playbook_json
|
||||
|
||||
with patch.object(
|
||||
client._client.aio.models,
|
||||
"generate_content",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_generate:
|
||||
market_data = {
|
||||
"stock_code": "PLANNER",
|
||||
"current_price": 0,
|
||||
"prompt_override": custom_prompt,
|
||||
}
|
||||
decision = await client.decide(market_data)
|
||||
|
||||
# Verify the custom prompt was sent, not a built prompt
|
||||
mock_generate.assert_called_once()
|
||||
actual_prompt = mock_generate.call_args[1].get(
|
||||
"contents", mock_generate.call_args[0][1] if len(mock_generate.call_args[0]) > 1 else None
|
||||
)
|
||||
assert actual_prompt == custom_prompt
|
||||
# Raw response preserved in rationale without parse_response (#247)
|
||||
assert decision.rationale == playbook_json
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_override_skips_parse_response(self, settings):
|
||||
"""prompt_override bypasses parse_response — no Missing fields warning, raw preserved."""
|
||||
client = GeminiClient(settings)
|
||||
client._enable_optimization = True
|
||||
|
||||
custom_prompt = "Custom playbook prompt"
|
||||
playbook_json = '{"market_outlook": "bullish", "stocks": [{"stock_code": "AAPL"}]}'
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = playbook_json
|
||||
|
||||
with patch.object(
|
||||
client._client.aio.models,
|
||||
"generate_content",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
):
|
||||
with patch.object(client, "parse_response") as mock_parse:
|
||||
market_data = {
|
||||
"stock_code": "PLANNER",
|
||||
"current_price": 0,
|
||||
"prompt_override": custom_prompt,
|
||||
}
|
||||
decision = await client.decide(market_data)
|
||||
|
||||
# parse_response must NOT be called for prompt_override
|
||||
mock_parse.assert_not_called()
|
||||
# Raw playbook JSON preserved in rationale
|
||||
assert decision.rationale == playbook_json
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_override_takes_priority_over_optimization(self, settings):
|
||||
"""prompt_override must win over enable_optimization=True."""
|
||||
client = GeminiClient(settings)
|
||||
client._enable_optimization = True
|
||||
|
||||
custom_prompt = "Explicit playbook prompt"
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = '{"market_outlook": "neutral", "stocks": []}'
|
||||
|
||||
with patch.object(
|
||||
client._client.aio.models,
|
||||
"generate_content",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_generate:
|
||||
market_data = {
|
||||
"stock_code": "PLANNER",
|
||||
"current_price": 0,
|
||||
"prompt_override": custom_prompt,
|
||||
}
|
||||
await client.decide(market_data)
|
||||
|
||||
actual_prompt = mock_generate.call_args[1].get(
|
||||
"contents", mock_generate.call_args[0][1] if len(mock_generate.call_args[0]) > 1 else None
|
||||
)
|
||||
# The custom prompt must be used, not the compressed prompt
|
||||
assert actual_prompt == custom_prompt
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_without_prompt_override_uses_build_prompt(self, settings):
|
||||
"""Without prompt_override, decide() should use build_prompt as before."""
|
||||
client = GeminiClient(settings)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = '{"action": "HOLD", "confidence": 50, "rationale": "ok"}'
|
||||
|
||||
with patch.object(
|
||||
client._client.aio.models,
|
||||
"generate_content",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_generate:
|
||||
market_data = {
|
||||
"stock_code": "005930",
|
||||
"current_price": 72000,
|
||||
}
|
||||
await client.decide(market_data)
|
||||
|
||||
actual_prompt = mock_generate.call_args[1].get(
|
||||
"contents", mock_generate.call_args[0][1] if len(mock_generate.call_args[0]) > 1 else None
|
||||
)
|
||||
# Should contain stock code from build_prompt, not be a custom override
|
||||
assert "005930" in actual_prompt
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -49,6 +49,110 @@ class TestTokenManagement:
|
||||
|
||||
await broker.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_token_refresh_calls_api_once(self, settings):
|
||||
"""Multiple concurrent token requests should only call API once."""
|
||||
broker = KISBroker(settings)
|
||||
|
||||
# Track how many times the mock API is called
|
||||
call_count = [0]
|
||||
|
||||
def create_mock_resp():
|
||||
call_count[0] += 1
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(
|
||||
return_value={
|
||||
"access_token": "tok_concurrent",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 86400,
|
||||
}
|
||||
)
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
return mock_resp
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=create_mock_resp()):
|
||||
# Launch 5 concurrent token requests
|
||||
tokens = await asyncio.gather(
|
||||
broker._ensure_token(),
|
||||
broker._ensure_token(),
|
||||
broker._ensure_token(),
|
||||
broker._ensure_token(),
|
||||
broker._ensure_token(),
|
||||
)
|
||||
|
||||
# All should get the same token
|
||||
assert all(t == "tok_concurrent" for t in tokens)
|
||||
# API should be called only once (due to lock)
|
||||
assert call_count[0] == 1
|
||||
|
||||
await broker.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_refresh_cooldown_waits_then_retries(self, settings):
|
||||
"""Token refresh should wait out cooldown then retry (issue #54)."""
|
||||
broker = KISBroker(settings)
|
||||
broker._refresh_cooldown = 0.1 # Short cooldown for testing
|
||||
|
||||
# All attempts fail with 403 (EGW00133)
|
||||
mock_resp_403 = AsyncMock()
|
||||
mock_resp_403.status = 403
|
||||
mock_resp_403.text = AsyncMock(
|
||||
return_value='{"error_code":"EGW00133","error_description":"접근토큰 발급 잠시 후 다시 시도하세요(1분당 1회)"}'
|
||||
)
|
||||
mock_resp_403.__aenter__ = AsyncMock(return_value=mock_resp_403)
|
||||
mock_resp_403.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp_403):
|
||||
# First attempt should fail with 403
|
||||
with pytest.raises(ConnectionError, match="Token refresh failed"):
|
||||
await broker._ensure_token()
|
||||
|
||||
# Second attempt within cooldown should wait then retry (and still get 403)
|
||||
with pytest.raises(ConnectionError, match="Token refresh failed"):
|
||||
await broker._ensure_token()
|
||||
|
||||
await broker.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_refresh_allowed_after_cooldown(self, settings):
|
||||
"""Token refresh should be allowed after cooldown period expires."""
|
||||
broker = KISBroker(settings)
|
||||
broker._refresh_cooldown = 0.1 # Very short cooldown for testing
|
||||
|
||||
# First attempt fails
|
||||
mock_resp_403 = AsyncMock()
|
||||
mock_resp_403.status = 403
|
||||
mock_resp_403.text = AsyncMock(return_value='{"error_code":"EGW00133"}')
|
||||
mock_resp_403.__aenter__ = AsyncMock(return_value=mock_resp_403)
|
||||
mock_resp_403.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
# Second attempt succeeds
|
||||
mock_resp_200 = AsyncMock()
|
||||
mock_resp_200.status = 200
|
||||
mock_resp_200.json = AsyncMock(
|
||||
return_value={
|
||||
"access_token": "tok_after_cooldown",
|
||||
"expires_in": 86400,
|
||||
}
|
||||
)
|
||||
mock_resp_200.__aenter__ = AsyncMock(return_value=mock_resp_200)
|
||||
mock_resp_200.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp_403):
|
||||
with pytest.raises(ConnectionError, match="Token refresh failed"):
|
||||
await broker._ensure_token()
|
||||
|
||||
# Wait for cooldown to expire
|
||||
await asyncio.sleep(0.15)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp_200):
|
||||
token = await broker._ensure_token()
|
||||
assert token == "tok_after_cooldown"
|
||||
|
||||
await broker.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Network Error Handling
|
||||
@@ -107,6 +211,38 @@ class TestRateLimiter:
|
||||
await broker._rate_limiter.acquire()
|
||||
await broker.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_order_acquires_rate_limiter_twice(self, settings):
|
||||
"""send_order must acquire rate limiter for both hash key and order call."""
|
||||
broker = KISBroker(settings)
|
||||
broker._access_token = "tok"
|
||||
broker._token_expires_at = asyncio.get_event_loop().time() + 3600
|
||||
|
||||
# Mock hash key response
|
||||
mock_hash_resp = AsyncMock()
|
||||
mock_hash_resp.status = 200
|
||||
mock_hash_resp.json = AsyncMock(return_value={"HASH": "abc123"})
|
||||
mock_hash_resp.__aenter__ = AsyncMock(return_value=mock_hash_resp)
|
||||
mock_hash_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
# Mock order response
|
||||
mock_order_resp = AsyncMock()
|
||||
mock_order_resp.status = 200
|
||||
mock_order_resp.json = AsyncMock(return_value={"rt_cd": "0"})
|
||||
mock_order_resp.__aenter__ = AsyncMock(return_value=mock_order_resp)
|
||||
mock_order_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash_resp, mock_order_resp]
|
||||
):
|
||||
with patch.object(
|
||||
broker._rate_limiter, "acquire", new_callable=AsyncMock
|
||||
) as mock_acquire:
|
||||
await broker.send_order("005930", "BUY", 1, 50000)
|
||||
assert mock_acquire.call_count == 2
|
||||
|
||||
await broker.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hash Key Generation
|
||||
@@ -136,3 +272,671 @@ class TestHashKey:
|
||||
assert len(hash_key) > 0
|
||||
|
||||
await broker.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hash_key_acquires_rate_limiter(self, settings):
|
||||
"""_get_hash_key must go through the rate limiter to prevent burst."""
|
||||
broker = KISBroker(settings)
|
||||
broker._access_token = "tok"
|
||||
broker._token_expires_at = asyncio.get_event_loop().time() + 3600
|
||||
|
||||
body = {"CANO": "12345678", "ACNT_PRDT_CD": "01"}
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(return_value={"HASH": "abc123hash"})
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp):
|
||||
with patch.object(
|
||||
broker._rate_limiter, "acquire", new_callable=AsyncMock
|
||||
) as mock_acquire:
|
||||
await broker._get_hash_key(body)
|
||||
mock_acquire.assert_called_once()
|
||||
|
||||
await broker.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# fetch_market_rankings — TR_ID, path, params (issue #155)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_ranking_mock(items: list[dict]) -> AsyncMock:
|
||||
"""Build a mock HTTP response returning ranking items."""
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(return_value={"output": items})
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
return mock_resp
|
||||
|
||||
|
||||
class TestFetchMarketRankings:
|
||||
"""Verify correct TR_ID, API path, and params per ranking_type (issue #155)."""
|
||||
|
||||
@pytest.fixture
|
||||
def broker(self, settings) -> KISBroker:
|
||||
b = KISBroker(settings)
|
||||
b._access_token = "tok"
|
||||
b._token_expires_at = float("inf")
|
||||
b._rate_limiter.acquire = AsyncMock()
|
||||
return b
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_volume_uses_correct_tr_id_and_path(self, broker: KISBroker) -> None:
|
||||
mock_resp = _make_ranking_mock([])
|
||||
with patch("aiohttp.ClientSession.get", return_value=mock_resp) as mock_get:
|
||||
await broker.fetch_market_rankings(ranking_type="volume")
|
||||
|
||||
call_kwargs = mock_get.call_args
|
||||
url = call_kwargs[0][0] if call_kwargs[0] else call_kwargs[1].get("url", "")
|
||||
headers = call_kwargs[1].get("headers", {})
|
||||
params = call_kwargs[1].get("params", {})
|
||||
|
||||
assert "volume-rank" in url
|
||||
assert headers.get("tr_id") == "FHPST01710000"
|
||||
assert params.get("FID_COND_SCR_DIV_CODE") == "20171"
|
||||
assert params.get("FID_TRGT_EXLS_CLS_CODE") == "0000000000"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fluctuation_uses_correct_tr_id_and_path(self, broker: KISBroker) -> None:
|
||||
mock_resp = _make_ranking_mock([])
|
||||
with patch("aiohttp.ClientSession.get", return_value=mock_resp) as mock_get:
|
||||
await broker.fetch_market_rankings(ranking_type="fluctuation")
|
||||
|
||||
call_kwargs = mock_get.call_args
|
||||
url = call_kwargs[0][0] if call_kwargs[0] else call_kwargs[1].get("url", "")
|
||||
headers = call_kwargs[1].get("headers", {})
|
||||
params = call_kwargs[1].get("params", {})
|
||||
|
||||
assert "ranking/fluctuation" in url
|
||||
assert headers.get("tr_id") == "FHPST01700000"
|
||||
assert params.get("fid_cond_scr_div_code") == "20170"
|
||||
# 실전 API는 4자리("0000") 거부 — 1자리("0")여야 한다 (#240)
|
||||
assert params.get("fid_rank_sort_cls_code") == "0"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_volume_returns_parsed_rows(self, broker: KISBroker) -> None:
|
||||
items = [
|
||||
{
|
||||
"mksc_shrn_iscd": "005930",
|
||||
"hts_kor_isnm": "삼성전자",
|
||||
"stck_prpr": "75000",
|
||||
"acml_vol": "10000000",
|
||||
"prdy_ctrt": "2.5",
|
||||
"vol_inrt": "150",
|
||||
}
|
||||
]
|
||||
mock_resp = _make_ranking_mock(items)
|
||||
with patch("aiohttp.ClientSession.get", return_value=mock_resp):
|
||||
result = await broker.fetch_market_rankings(ranking_type="volume")
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["stock_code"] == "005930"
|
||||
assert result[0]["price"] == 75000.0
|
||||
assert result[0]["change_rate"] == 2.5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fluctuation_parses_stck_shrn_iscd(self, broker: KISBroker) -> None:
|
||||
"""실전 API는 mksc_shrn_iscd 대신 stck_shrn_iscd를 반환한다 (#240)."""
|
||||
items = [
|
||||
{
|
||||
"stck_shrn_iscd": "015260",
|
||||
"hts_kor_isnm": "에이엔피",
|
||||
"stck_prpr": "794",
|
||||
"acml_vol": "4896196",
|
||||
"prdy_ctrt": "29.74",
|
||||
"vol_inrt": "0",
|
||||
}
|
||||
]
|
||||
mock_resp = _make_ranking_mock(items)
|
||||
with patch("aiohttp.ClientSession.get", return_value=mock_resp):
|
||||
result = await broker.fetch_market_rankings(ranking_type="fluctuation")
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["stock_code"] == "015260"
|
||||
assert result[0]["change_rate"] == 29.74
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# KRX tick unit / round-down helpers (issue #157)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
from src.broker.kis_api import kr_tick_unit, kr_round_down # noqa: E402
|
||||
|
||||
|
||||
class TestKrTickUnit:
|
||||
"""kr_tick_unit and kr_round_down must implement KRX price tick rules."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"price, expected_tick",
|
||||
[
|
||||
(1999, 1),
|
||||
(2000, 5),
|
||||
(4999, 5),
|
||||
(5000, 10),
|
||||
(19999, 10),
|
||||
(20000, 50),
|
||||
(49999, 50),
|
||||
(50000, 100),
|
||||
(199999, 100),
|
||||
(200000, 500),
|
||||
(499999, 500),
|
||||
(500000, 1000),
|
||||
(1000000, 1000),
|
||||
],
|
||||
)
|
||||
def test_tick_unit_boundaries(self, price: int, expected_tick: int) -> None:
|
||||
assert kr_tick_unit(price) == expected_tick
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"price, expected_rounded",
|
||||
[
|
||||
(188150, 188100), # 100원 단위, 50원 잔여 → 내림
|
||||
(188100, 188100), # 이미 정렬됨
|
||||
(75050, 75000), # 100원 단위, 50원 잔여 → 내림
|
||||
(49950, 49950), # 50원 단위 정렬됨
|
||||
(49960, 49950), # 50원 단위, 10원 잔여 → 내림
|
||||
(1999, 1999), # 1원 단위 → 그대로
|
||||
(5003, 5000), # 10원 단위, 3원 잔여 → 내림
|
||||
],
|
||||
)
|
||||
def test_round_down_to_tick(self, price: int, expected_rounded: int) -> None:
|
||||
assert kr_round_down(price) == expected_rounded
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_current_price (issue #157)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetCurrentPrice:
|
||||
"""get_current_price must use inquire-price API and return (price, change, foreigner)."""
|
||||
|
||||
@pytest.fixture
|
||||
def broker(self, settings) -> KISBroker:
|
||||
b = KISBroker(settings)
|
||||
b._access_token = "tok"
|
||||
b._token_expires_at = float("inf")
|
||||
b._rate_limiter.acquire = AsyncMock()
|
||||
return b
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_correct_fields(self, broker: KISBroker) -> None:
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(
|
||||
return_value={
|
||||
"rt_cd": "0",
|
||||
"output": {
|
||||
"stck_prpr": "188600",
|
||||
"prdy_ctrt": "3.97",
|
||||
"frgn_ntby_qty": "12345",
|
||||
},
|
||||
}
|
||||
)
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.get", return_value=mock_resp) as mock_get:
|
||||
price, change_pct, foreigner = await broker.get_current_price("005930")
|
||||
|
||||
assert price == 188600.0
|
||||
assert change_pct == 3.97
|
||||
assert foreigner == 12345.0
|
||||
|
||||
call_kwargs = mock_get.call_args
|
||||
url = call_kwargs[0][0] if call_kwargs[0] else call_kwargs[1].get("url", "")
|
||||
headers = call_kwargs[1].get("headers", {})
|
||||
assert "inquire-price" in url
|
||||
assert headers.get("tr_id") == "FHKST01010100"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_error_raises_connection_error(self, broker: KISBroker) -> None:
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 500
|
||||
mock_resp.text = AsyncMock(return_value="Internal Server Error")
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.get", return_value=mock_resp):
|
||||
with pytest.raises(ConnectionError, match="get_current_price failed"):
|
||||
await broker.get_current_price("005930")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# send_order tick rounding and ORD_DVSN (issue #157)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSendOrderTickRounding:
|
||||
"""send_order must apply KRX tick rounding and correct ORD_DVSN codes."""
|
||||
|
||||
@pytest.fixture
|
||||
def broker(self, settings) -> KISBroker:
|
||||
b = KISBroker(settings)
|
||||
b._access_token = "tok"
|
||||
b._token_expires_at = float("inf")
|
||||
b._rate_limiter.acquire = AsyncMock()
|
||||
return b
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_limit_order_rounds_down_to_tick(self, broker: KISBroker) -> None:
|
||||
"""Price 188150 (not on 100-won tick) must be rounded to 188100."""
|
||||
mock_hash = AsyncMock()
|
||||
mock_hash.status = 200
|
||||
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
|
||||
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
|
||||
mock_hash.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_order = AsyncMock()
|
||||
mock_order.status = 200
|
||||
mock_order.json = AsyncMock(return_value={"rt_cd": "0"})
|
||||
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
|
||||
mock_order.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
|
||||
) as mock_post:
|
||||
await broker.send_order("005930", "BUY", 1, price=188150)
|
||||
|
||||
order_call = mock_post.call_args_list[1]
|
||||
body = order_call[1].get("json", {})
|
||||
assert body["ORD_UNPR"] == "188100" # rounded down
|
||||
assert body["ORD_DVSN"] == "00" # 지정가
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_limit_order_ord_dvsn_is_00(self, broker: KISBroker) -> None:
|
||||
"""send_order with price>0 must use ORD_DVSN='00' (지정가)."""
|
||||
mock_hash = AsyncMock()
|
||||
mock_hash.status = 200
|
||||
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
|
||||
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
|
||||
mock_hash.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_order = AsyncMock()
|
||||
mock_order.status = 200
|
||||
mock_order.json = AsyncMock(return_value={"rt_cd": "0"})
|
||||
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
|
||||
mock_order.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
|
||||
) as mock_post:
|
||||
await broker.send_order("005930", "BUY", 1, price=50000)
|
||||
|
||||
order_call = mock_post.call_args_list[1]
|
||||
body = order_call[1].get("json", {})
|
||||
assert body["ORD_DVSN"] == "00"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_market_order_ord_dvsn_is_01(self, broker: KISBroker) -> None:
|
||||
"""send_order with price=0 must use ORD_DVSN='01' (시장가)."""
|
||||
mock_hash = AsyncMock()
|
||||
mock_hash.status = 200
|
||||
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
|
||||
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
|
||||
mock_hash.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_order = AsyncMock()
|
||||
mock_order.status = 200
|
||||
mock_order.json = AsyncMock(return_value={"rt_cd": "0"})
|
||||
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
|
||||
mock_order.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
|
||||
) as mock_post:
|
||||
await broker.send_order("005930", "SELL", 1, price=0)
|
||||
|
||||
order_call = mock_post.call_args_list[1]
|
||||
body = order_call[1].get("json", {})
|
||||
assert body["ORD_DVSN"] == "01"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TR_ID live/paper branching (issues #201, #202, #203)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTRIDBranchingDomestic:
|
||||
"""get_balance and send_order must use correct TR_ID for live vs paper mode."""
|
||||
|
||||
def _make_broker(self, settings, mode: str) -> KISBroker:
|
||||
from src.config import Settings
|
||||
|
||||
s = Settings(
|
||||
KIS_APP_KEY=settings.KIS_APP_KEY,
|
||||
KIS_APP_SECRET=settings.KIS_APP_SECRET,
|
||||
KIS_ACCOUNT_NO=settings.KIS_ACCOUNT_NO,
|
||||
GEMINI_API_KEY=settings.GEMINI_API_KEY,
|
||||
DB_PATH=":memory:",
|
||||
ENABLED_MARKETS="KR",
|
||||
MODE=mode,
|
||||
)
|
||||
b = KISBroker(s)
|
||||
b._access_token = "tok"
|
||||
b._token_expires_at = float("inf")
|
||||
b._rate_limiter.acquire = AsyncMock()
|
||||
return b
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_balance_paper_uses_vttc8434r(self, settings) -> None:
|
||||
broker = self._make_broker(settings, "paper")
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(
|
||||
return_value={"output1": [], "output2": {}}
|
||||
)
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.get", return_value=mock_resp) as mock_get:
|
||||
await broker.get_balance()
|
||||
|
||||
headers = mock_get.call_args[1].get("headers", {})
|
||||
assert headers["tr_id"] == "VTTC8434R"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_balance_live_uses_tttc8434r(self, settings) -> None:
|
||||
broker = self._make_broker(settings, "live")
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(
|
||||
return_value={"output1": [], "output2": {}}
|
||||
)
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.get", return_value=mock_resp) as mock_get:
|
||||
await broker.get_balance()
|
||||
|
||||
headers = mock_get.call_args[1].get("headers", {})
|
||||
assert headers["tr_id"] == "TTTC8434R"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_order_buy_paper_uses_vttc0012u(self, settings) -> None:
|
||||
broker = self._make_broker(settings, "paper")
|
||||
mock_hash = AsyncMock()
|
||||
mock_hash.status = 200
|
||||
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
|
||||
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
|
||||
mock_hash.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_order = AsyncMock()
|
||||
mock_order.status = 200
|
||||
mock_order.json = AsyncMock(return_value={"rt_cd": "0"})
|
||||
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
|
||||
mock_order.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
|
||||
) as mock_post:
|
||||
await broker.send_order("005930", "BUY", 1)
|
||||
|
||||
order_headers = mock_post.call_args_list[1][1].get("headers", {})
|
||||
assert order_headers["tr_id"] == "VTTC0012U"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_order_buy_live_uses_tttc0012u(self, settings) -> None:
|
||||
broker = self._make_broker(settings, "live")
|
||||
mock_hash = AsyncMock()
|
||||
mock_hash.status = 200
|
||||
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
|
||||
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
|
||||
mock_hash.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_order = AsyncMock()
|
||||
mock_order.status = 200
|
||||
mock_order.json = AsyncMock(return_value={"rt_cd": "0"})
|
||||
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
|
||||
mock_order.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
|
||||
) as mock_post:
|
||||
await broker.send_order("005930", "BUY", 1)
|
||||
|
||||
order_headers = mock_post.call_args_list[1][1].get("headers", {})
|
||||
assert order_headers["tr_id"] == "TTTC0012U"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_order_sell_paper_uses_vttc0011u(self, settings) -> None:
|
||||
broker = self._make_broker(settings, "paper")
|
||||
mock_hash = AsyncMock()
|
||||
mock_hash.status = 200
|
||||
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
|
||||
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
|
||||
mock_hash.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_order = AsyncMock()
|
||||
mock_order.status = 200
|
||||
mock_order.json = AsyncMock(return_value={"rt_cd": "0"})
|
||||
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
|
||||
mock_order.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
|
||||
) as mock_post:
|
||||
await broker.send_order("005930", "SELL", 1)
|
||||
|
||||
order_headers = mock_post.call_args_list[1][1].get("headers", {})
|
||||
assert order_headers["tr_id"] == "VTTC0011U"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_order_sell_live_uses_tttc0011u(self, settings) -> None:
|
||||
broker = self._make_broker(settings, "live")
|
||||
mock_hash = AsyncMock()
|
||||
mock_hash.status = 200
|
||||
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
|
||||
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
|
||||
mock_hash.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_order = AsyncMock()
|
||||
mock_order.status = 200
|
||||
mock_order.json = AsyncMock(return_value={"rt_cd": "0"})
|
||||
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
|
||||
mock_order.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
|
||||
) as mock_post:
|
||||
await broker.send_order("005930", "SELL", 1)
|
||||
|
||||
order_headers = mock_post.call_args_list[1][1].get("headers", {})
|
||||
assert order_headers["tr_id"] == "TTTC0011U"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Domestic Pending Orders (get_domestic_pending_orders)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetDomesticPendingOrders:
|
||||
"""get_domestic_pending_orders must return [] in paper mode and call TTTC0084R in live."""
|
||||
|
||||
def _make_broker(self, settings, mode: str) -> KISBroker:
|
||||
from src.config import Settings
|
||||
|
||||
s = Settings(
|
||||
KIS_APP_KEY=settings.KIS_APP_KEY,
|
||||
KIS_APP_SECRET=settings.KIS_APP_SECRET,
|
||||
KIS_ACCOUNT_NO=settings.KIS_ACCOUNT_NO,
|
||||
GEMINI_API_KEY=settings.GEMINI_API_KEY,
|
||||
DB_PATH=":memory:",
|
||||
ENABLED_MARKETS="KR",
|
||||
MODE=mode,
|
||||
)
|
||||
b = KISBroker(s)
|
||||
b._access_token = "tok"
|
||||
b._token_expires_at = float("inf")
|
||||
b._rate_limiter.acquire = AsyncMock()
|
||||
return b
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_paper_mode_returns_empty(self, settings) -> None:
|
||||
"""Paper mode must return [] immediately without any API call."""
|
||||
broker = self._make_broker(settings, "paper")
|
||||
|
||||
with patch("aiohttp.ClientSession.get") as mock_get:
|
||||
result = await broker.get_domestic_pending_orders()
|
||||
|
||||
assert result == []
|
||||
mock_get.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_live_mode_calls_tttc0084r_with_correct_params(
|
||||
self, settings
|
||||
) -> None:
|
||||
"""Live mode must call TTTC0084R with INQR_DVSN_1/2 and paging params."""
|
||||
broker = self._make_broker(settings, "live")
|
||||
pending = [{"odno": "001", "pdno": "005930", "psbl_qty": "10"}]
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(return_value={"output": pending})
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.get", return_value=mock_resp) as mock_get:
|
||||
result = await broker.get_domestic_pending_orders()
|
||||
|
||||
assert result == pending
|
||||
headers = mock_get.call_args[1].get("headers", {})
|
||||
assert headers["tr_id"] == "TTTC0084R"
|
||||
params = mock_get.call_args[1].get("params", {})
|
||||
assert params["INQR_DVSN_1"] == "0"
|
||||
assert params["INQR_DVSN_2"] == "0"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_live_mode_connection_error(self, settings) -> None:
|
||||
"""Network error must raise ConnectionError."""
|
||||
import aiohttp as _aiohttp
|
||||
|
||||
broker = self._make_broker(settings, "live")
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.get",
|
||||
side_effect=_aiohttp.ClientError("timeout"),
|
||||
):
|
||||
with pytest.raises(ConnectionError):
|
||||
await broker.get_domestic_pending_orders()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Domestic Order Cancellation (cancel_domestic_order)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCancelDomesticOrder:
|
||||
"""cancel_domestic_order must use correct TR_ID and build body correctly."""
|
||||
|
||||
def _make_broker(self, settings, mode: str) -> KISBroker:
|
||||
from src.config import Settings
|
||||
|
||||
s = Settings(
|
||||
KIS_APP_KEY=settings.KIS_APP_KEY,
|
||||
KIS_APP_SECRET=settings.KIS_APP_SECRET,
|
||||
KIS_ACCOUNT_NO=settings.KIS_ACCOUNT_NO,
|
||||
GEMINI_API_KEY=settings.GEMINI_API_KEY,
|
||||
DB_PATH=":memory:",
|
||||
ENABLED_MARKETS="KR",
|
||||
MODE=mode,
|
||||
)
|
||||
b = KISBroker(s)
|
||||
b._access_token = "tok"
|
||||
b._token_expires_at = float("inf")
|
||||
b._rate_limiter.acquire = AsyncMock()
|
||||
return b
|
||||
|
||||
def _make_post_mocks(self, order_payload: dict) -> tuple:
|
||||
mock_hash = AsyncMock()
|
||||
mock_hash.status = 200
|
||||
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
|
||||
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
|
||||
mock_hash.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_order = AsyncMock()
|
||||
mock_order.status = 200
|
||||
mock_order.json = AsyncMock(return_value=order_payload)
|
||||
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
|
||||
mock_order.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
return mock_hash, mock_order
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_live_uses_tttc0013u(self, settings) -> None:
|
||||
"""Live mode must use TR_ID TTTC0013U."""
|
||||
broker = self._make_broker(settings, "live")
|
||||
mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"})
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
|
||||
) as mock_post:
|
||||
await broker.cancel_domestic_order("005930", "ORD001", "BRNO01", 5)
|
||||
|
||||
order_headers = mock_post.call_args_list[1][1].get("headers", {})
|
||||
assert order_headers["tr_id"] == "TTTC0013U"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_paper_uses_vttc0013u(self, settings) -> None:
|
||||
"""Paper mode must use TR_ID VTTC0013U."""
|
||||
broker = self._make_broker(settings, "paper")
|
||||
mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"})
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
|
||||
) as mock_post:
|
||||
await broker.cancel_domestic_order("005930", "ORD001", "BRNO01", 5)
|
||||
|
||||
order_headers = mock_post.call_args_list[1][1].get("headers", {})
|
||||
assert order_headers["tr_id"] == "VTTC0013U"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_sets_rvse_cncl_dvsn_cd_02(self, settings) -> None:
|
||||
"""Body must have RVSE_CNCL_DVSN_CD='02' (취소) and QTY_ALL_ORD_YN='Y'."""
|
||||
broker = self._make_broker(settings, "live")
|
||||
mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"})
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
|
||||
) as mock_post:
|
||||
await broker.cancel_domestic_order("005930", "ORD001", "BRNO01", 5)
|
||||
|
||||
body = mock_post.call_args_list[1][1].get("json", {})
|
||||
assert body["RVSE_CNCL_DVSN_CD"] == "02"
|
||||
assert body["QTY_ALL_ORD_YN"] == "Y"
|
||||
assert body["ORD_UNPR"] == "0"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_sets_krx_fwdg_ord_orgno_in_body(self, settings) -> None:
|
||||
"""Body must include KRX_FWDG_ORD_ORGNO and ORGN_ODNO from arguments."""
|
||||
broker = self._make_broker(settings, "live")
|
||||
mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"})
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
|
||||
) as mock_post:
|
||||
await broker.cancel_domestic_order("005930", "ORD123", "BRN456", 3)
|
||||
|
||||
body = mock_post.call_args_list[1][1].get("json", {})
|
||||
assert body["KRX_FWDG_ORD_ORGNO"] == "BRN456"
|
||||
assert body["ORGN_ODNO"] == "ORD123"
|
||||
assert body["ORD_QTY"] == "3"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_sets_hashkey_header(self, settings) -> None:
|
||||
"""Request must include hashkey header (same pattern as send_order)."""
|
||||
broker = self._make_broker(settings, "live")
|
||||
mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"})
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
|
||||
) as mock_post:
|
||||
await broker.cancel_domestic_order("005930", "ORD001", "BRNO01", 2)
|
||||
|
||||
order_headers = mock_post.call_args_list[1][1].get("headers", {})
|
||||
assert "hashkey" in order_headers
|
||||
assert order_headers["hashkey"] == "h"
|
||||
|
||||
@@ -10,6 +10,7 @@ import pytest
|
||||
from src.context.aggregator import ContextAggregator
|
||||
from src.context.layer import LAYER_CONFIG, ContextLayer
|
||||
from src.context.store import ContextStore
|
||||
from src.context.summarizer import ContextSummarizer
|
||||
from src.db import init_db, log_trade
|
||||
|
||||
|
||||
@@ -161,7 +162,7 @@ class TestContextAggregator:
|
||||
self, aggregator: ContextAggregator, db_conn: sqlite3.Connection
|
||||
) -> None:
|
||||
"""Test aggregating daily metrics from trades."""
|
||||
date = "2026-02-04"
|
||||
date = datetime.now(UTC).date().isoformat()
|
||||
|
||||
# Create sample trades
|
||||
log_trade(db_conn, "005930", "BUY", 85, "Good signal", quantity=10, price=70000, pnl=500)
|
||||
@@ -175,36 +176,44 @@ class TestContextAggregator:
|
||||
db_conn.commit()
|
||||
|
||||
# Aggregate
|
||||
aggregator.aggregate_daily_from_trades(date)
|
||||
aggregator.aggregate_daily_from_trades(date, market="KR")
|
||||
|
||||
# Verify L6 contexts
|
||||
store = aggregator.store
|
||||
assert store.get_context(ContextLayer.L6_DAILY, date, "trade_count") == 3
|
||||
assert store.get_context(ContextLayer.L6_DAILY, date, "buys") == 1
|
||||
assert store.get_context(ContextLayer.L6_DAILY, date, "sells") == 1
|
||||
assert store.get_context(ContextLayer.L6_DAILY, date, "holds") == 1
|
||||
assert store.get_context(ContextLayer.L6_DAILY, date, "total_pnl") == 2000.0
|
||||
assert store.get_context(ContextLayer.L6_DAILY, date, "unique_stocks") == 3
|
||||
assert store.get_context(ContextLayer.L6_DAILY, date, "trade_count_KR") == 3
|
||||
assert store.get_context(ContextLayer.L6_DAILY, date, "buys_KR") == 1
|
||||
assert store.get_context(ContextLayer.L6_DAILY, date, "sells_KR") == 1
|
||||
assert store.get_context(ContextLayer.L6_DAILY, date, "holds_KR") == 1
|
||||
assert store.get_context(ContextLayer.L6_DAILY, date, "total_pnl_KR") == 2000.0
|
||||
assert store.get_context(ContextLayer.L6_DAILY, date, "unique_stocks_KR") == 3
|
||||
# 2 wins, 0 losses
|
||||
assert store.get_context(ContextLayer.L6_DAILY, date, "win_rate") == 100.0
|
||||
assert store.get_context(ContextLayer.L6_DAILY, date, "win_rate_KR") == 100.0
|
||||
|
||||
def test_aggregate_weekly_from_daily(self, aggregator: ContextAggregator) -> None:
|
||||
"""Test aggregating weekly metrics from daily."""
|
||||
week = "2026-W06"
|
||||
|
||||
# Set daily contexts
|
||||
aggregator.store.set_context(ContextLayer.L6_DAILY, "2026-02-02", "total_pnl", 100.0)
|
||||
aggregator.store.set_context(ContextLayer.L6_DAILY, "2026-02-03", "total_pnl", 200.0)
|
||||
aggregator.store.set_context(ContextLayer.L6_DAILY, "2026-02-02", "avg_confidence", 80.0)
|
||||
aggregator.store.set_context(ContextLayer.L6_DAILY, "2026-02-03", "avg_confidence", 85.0)
|
||||
aggregator.store.set_context(
|
||||
ContextLayer.L6_DAILY, "2026-02-02", "total_pnl_KR", 100.0
|
||||
)
|
||||
aggregator.store.set_context(
|
||||
ContextLayer.L6_DAILY, "2026-02-03", "total_pnl_KR", 200.0
|
||||
)
|
||||
aggregator.store.set_context(
|
||||
ContextLayer.L6_DAILY, "2026-02-02", "avg_confidence_KR", 80.0
|
||||
)
|
||||
aggregator.store.set_context(
|
||||
ContextLayer.L6_DAILY, "2026-02-03", "avg_confidence_KR", 85.0
|
||||
)
|
||||
|
||||
# Aggregate
|
||||
aggregator.aggregate_weekly_from_daily(week)
|
||||
|
||||
# Verify L5 contexts
|
||||
store = aggregator.store
|
||||
weekly_pnl = store.get_context(ContextLayer.L5_WEEKLY, week, "weekly_pnl")
|
||||
avg_conf = store.get_context(ContextLayer.L5_WEEKLY, week, "avg_confidence")
|
||||
weekly_pnl = store.get_context(ContextLayer.L5_WEEKLY, week, "weekly_pnl_KR")
|
||||
avg_conf = store.get_context(ContextLayer.L5_WEEKLY, week, "avg_confidence_KR")
|
||||
|
||||
assert weekly_pnl == 300.0
|
||||
assert avg_conf == 82.5
|
||||
@@ -214,9 +223,15 @@ class TestContextAggregator:
|
||||
month = "2026-02"
|
||||
|
||||
# Set weekly contexts
|
||||
aggregator.store.set_context(ContextLayer.L5_WEEKLY, "2026-W05", "weekly_pnl", 100.0)
|
||||
aggregator.store.set_context(ContextLayer.L5_WEEKLY, "2026-W06", "weekly_pnl", 200.0)
|
||||
aggregator.store.set_context(ContextLayer.L5_WEEKLY, "2026-W07", "weekly_pnl", 150.0)
|
||||
aggregator.store.set_context(
|
||||
ContextLayer.L5_WEEKLY, "2026-W05", "weekly_pnl_KR", 100.0
|
||||
)
|
||||
aggregator.store.set_context(
|
||||
ContextLayer.L5_WEEKLY, "2026-W06", "weekly_pnl_KR", 200.0
|
||||
)
|
||||
aggregator.store.set_context(
|
||||
ContextLayer.L5_WEEKLY, "2026-W07", "weekly_pnl_KR", 150.0
|
||||
)
|
||||
|
||||
# Aggregate
|
||||
aggregator.aggregate_monthly_from_weekly(month)
|
||||
@@ -285,7 +300,7 @@ class TestContextAggregator:
|
||||
self, aggregator: ContextAggregator, db_conn: sqlite3.Connection
|
||||
) -> None:
|
||||
"""Test running all aggregations from L7 to L1."""
|
||||
date = "2026-02-04"
|
||||
date = datetime.now(UTC).date().isoformat()
|
||||
|
||||
# Create sample trades
|
||||
log_trade(db_conn, "005930", "BUY", 85, "Good signal", quantity=10, price=70000, pnl=1000)
|
||||
@@ -299,10 +314,18 @@ class TestContextAggregator:
|
||||
|
||||
# Verify data exists in each layer
|
||||
store = aggregator.store
|
||||
assert store.get_context(ContextLayer.L6_DAILY, date, "total_pnl") == 1000.0
|
||||
current_week = datetime.now(UTC).strftime("%Y-W%V")
|
||||
assert store.get_context(ContextLayer.L5_WEEKLY, current_week, "weekly_pnl") is not None
|
||||
# Further layers depend on time alignment, just verify no crashes
|
||||
assert store.get_context(ContextLayer.L6_DAILY, date, "total_pnl_KR") == 1000.0
|
||||
from datetime import date as date_cls
|
||||
trade_date = date_cls.fromisoformat(date)
|
||||
iso_year, iso_week, _ = trade_date.isocalendar()
|
||||
trade_week = f"{iso_year}-W{iso_week:02d}"
|
||||
assert store.get_context(ContextLayer.L5_WEEKLY, trade_week, "weekly_pnl_KR") is not None
|
||||
trade_month = f"{trade_date.year}-{trade_date.month:02d}"
|
||||
trade_quarter = f"{trade_date.year}-Q{(trade_date.month - 1) // 3 + 1}"
|
||||
trade_year = str(trade_date.year)
|
||||
assert store.get_context(ContextLayer.L4_MONTHLY, trade_month, "monthly_pnl") == 1000.0
|
||||
assert store.get_context(ContextLayer.L3_QUARTERLY, trade_quarter, "quarterly_pnl") == 1000.0
|
||||
assert store.get_context(ContextLayer.L2_ANNUAL, trade_year, "annual_pnl") == 1000.0
|
||||
|
||||
|
||||
class TestLayerMetadata:
|
||||
@@ -348,3 +371,259 @@ class TestLayerMetadata:
|
||||
|
||||
# L1 aggregates from L2
|
||||
assert LAYER_CONFIG[ContextLayer.L1_LEGACY].aggregation_source == ContextLayer.L2_ANNUAL
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ContextSummarizer tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def summarizer(db_conn: sqlite3.Connection) -> ContextSummarizer:
|
||||
"""Provide a ContextSummarizer backed by an in-memory store."""
|
||||
return ContextSummarizer(ContextStore(db_conn))
|
||||
|
||||
|
||||
class TestContextSummarizer:
|
||||
"""Test suite for ContextSummarizer."""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# summarize_numeric_values
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def test_summarize_empty_values(self, summarizer: ContextSummarizer) -> None:
|
||||
"""Empty list must return SummaryStats with count=0 and no other fields."""
|
||||
stats = summarizer.summarize_numeric_values([])
|
||||
assert stats.count == 0
|
||||
assert stats.mean is None
|
||||
assert stats.min is None
|
||||
assert stats.max is None
|
||||
|
||||
def test_summarize_single_value(self, summarizer: ContextSummarizer) -> None:
|
||||
"""Single-element list must return correct stats with std=0 and trend=flat."""
|
||||
stats = summarizer.summarize_numeric_values([42.0])
|
||||
assert stats.count == 1
|
||||
assert stats.mean == 42.0
|
||||
assert stats.std == 0.0
|
||||
assert stats.trend == "flat"
|
||||
|
||||
def test_summarize_upward_trend(self, summarizer: ContextSummarizer) -> None:
|
||||
"""Increasing values must produce trend='up'."""
|
||||
values = [1.0, 2.0, 3.0, 10.0, 20.0, 30.0]
|
||||
stats = summarizer.summarize_numeric_values(values)
|
||||
assert stats.trend == "up"
|
||||
|
||||
def test_summarize_downward_trend(self, summarizer: ContextSummarizer) -> None:
|
||||
"""Decreasing values must produce trend='down'."""
|
||||
values = [30.0, 20.0, 10.0, 3.0, 2.0, 1.0]
|
||||
stats = summarizer.summarize_numeric_values(values)
|
||||
assert stats.trend == "down"
|
||||
|
||||
def test_summarize_flat_trend(self, summarizer: ContextSummarizer) -> None:
|
||||
"""Stable values must produce trend='flat'."""
|
||||
values = [100.0, 100.1, 99.9, 100.0, 100.2, 99.8]
|
||||
stats = summarizer.summarize_numeric_values(values)
|
||||
assert stats.trend == "flat"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# summarize_layer
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def test_summarize_layer_no_data(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""summarize_layer with no data must return the 'No data' sentinel."""
|
||||
result = summarizer.summarize_layer(ContextLayer.L6_DAILY)
|
||||
assert result["count"] == 0
|
||||
assert "No data" in result["summary"]
|
||||
|
||||
def test_summarize_layer_numeric(
|
||||
self, summarizer: ContextSummarizer, db_conn: sqlite3.Connection
|
||||
) -> None:
|
||||
"""summarize_layer must collect numeric values and produce stats."""
|
||||
store = summarizer.store
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "total_pnl", 100.0)
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-02", "total_pnl", 200.0)
|
||||
|
||||
result = summarizer.summarize_layer(ContextLayer.L6_DAILY)
|
||||
assert "total_entries" in result
|
||||
|
||||
def test_summarize_layer_with_dict_values(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""summarize_layer must handle dict values by extracting numeric subkeys."""
|
||||
store = summarizer.store
|
||||
# set_context serialises the value as JSON, so passing a dict works
|
||||
store.set_context(
|
||||
ContextLayer.L6_DAILY, "2026-02-01", "metrics",
|
||||
{"win_rate": 65.0, "label": "good"}
|
||||
)
|
||||
|
||||
result = summarizer.summarize_layer(ContextLayer.L6_DAILY)
|
||||
assert "total_entries" in result
|
||||
# numeric subkey "win_rate" should appear as "metrics.win_rate"
|
||||
assert "metrics.win_rate" in result
|
||||
|
||||
def test_summarize_layer_with_string_values(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""summarize_layer must count string values separately."""
|
||||
store = summarizer.store
|
||||
# set_context stores string values as JSON-encoded strings
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "outlook", "BULLISH")
|
||||
|
||||
result = summarizer.summarize_layer(ContextLayer.L6_DAILY)
|
||||
# String fields contribute a `<key>_count` entry
|
||||
assert "outlook_count" in result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# rolling_window_summary
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def test_rolling_window_summary_basic(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""rolling_window_summary must return the expected structure."""
|
||||
store = summarizer.store
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "pnl", 500.0)
|
||||
|
||||
result = summarizer.rolling_window_summary(ContextLayer.L6_DAILY)
|
||||
assert "window_days" in result
|
||||
assert "recent_data" in result
|
||||
assert "historical_summary" in result
|
||||
|
||||
def test_rolling_window_summary_no_older_data(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""rolling_window_summary with summarize_older=False skips history."""
|
||||
result = summarizer.rolling_window_summary(
|
||||
ContextLayer.L6_DAILY, summarize_older=False
|
||||
)
|
||||
assert result["historical_summary"] == {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# aggregate_to_higher_layer
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def test_aggregate_to_higher_layer_mean(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""aggregate_to_higher_layer with 'mean' via dict subkeys returns average."""
|
||||
store = summarizer.store
|
||||
# Use different outer keys but same inner metric key so get_all_contexts
|
||||
# returns multiple rows with the target subkey.
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day1", {"pnl": 100.0})
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day2", {"pnl": 200.0})
|
||||
|
||||
result = summarizer.aggregate_to_higher_layer(
|
||||
ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY, "pnl", "mean"
|
||||
)
|
||||
assert result == pytest.approx(150.0)
|
||||
|
||||
def test_aggregate_to_higher_layer_sum(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""aggregate_to_higher_layer with 'sum' must return the total."""
|
||||
store = summarizer.store
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day1", {"pnl": 100.0})
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day2", {"pnl": 200.0})
|
||||
|
||||
result = summarizer.aggregate_to_higher_layer(
|
||||
ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY, "pnl", "sum"
|
||||
)
|
||||
assert result == pytest.approx(300.0)
|
||||
|
||||
def test_aggregate_to_higher_layer_max(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""aggregate_to_higher_layer with 'max' must return the maximum."""
|
||||
store = summarizer.store
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day1", {"pnl": 100.0})
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day2", {"pnl": 200.0})
|
||||
|
||||
result = summarizer.aggregate_to_higher_layer(
|
||||
ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY, "pnl", "max"
|
||||
)
|
||||
assert result == pytest.approx(200.0)
|
||||
|
||||
def test_aggregate_to_higher_layer_min(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""aggregate_to_higher_layer with 'min' must return the minimum."""
|
||||
store = summarizer.store
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day1", {"pnl": 100.0})
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day2", {"pnl": 200.0})
|
||||
|
||||
result = summarizer.aggregate_to_higher_layer(
|
||||
ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY, "pnl", "min"
|
||||
)
|
||||
assert result == pytest.approx(100.0)
|
||||
|
||||
def test_aggregate_to_higher_layer_no_data(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""aggregate_to_higher_layer with no matching key must return None."""
|
||||
result = summarizer.aggregate_to_higher_layer(
|
||||
ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY, "nonexistent", "mean"
|
||||
)
|
||||
assert result is None
|
||||
|
||||
def test_aggregate_to_higher_layer_unknown_func_defaults_to_mean(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""Unknown aggregation function must fall back to mean."""
|
||||
store = summarizer.store
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day1", {"pnl": 100.0})
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day2", {"pnl": 200.0})
|
||||
|
||||
result = summarizer.aggregate_to_higher_layer(
|
||||
ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY, "pnl", "unknown_func"
|
||||
)
|
||||
assert result == pytest.approx(150.0)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# create_compact_summary + format_summary_for_prompt
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def test_create_compact_summary(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""create_compact_summary must produce a dict keyed by layer value."""
|
||||
store = summarizer.store
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "pnl", 100.0)
|
||||
|
||||
result = summarizer.create_compact_summary([ContextLayer.L6_DAILY])
|
||||
assert ContextLayer.L6_DAILY.value in result
|
||||
|
||||
def test_format_summary_for_prompt_with_numeric_metrics(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""format_summary_for_prompt must render avg/trend fields."""
|
||||
store = summarizer.store
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "pnl", 100.0)
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-02", "pnl", 200.0)
|
||||
|
||||
compact = summarizer.create_compact_summary([ContextLayer.L6_DAILY])
|
||||
text = summarizer.format_summary_for_prompt(compact)
|
||||
assert isinstance(text, str)
|
||||
|
||||
def test_format_summary_for_prompt_skips_empty_layers(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""format_summary_for_prompt must skip layers with no metrics."""
|
||||
summary = {ContextLayer.L6_DAILY.value: {}}
|
||||
text = summarizer.format_summary_for_prompt(summary)
|
||||
assert text == ""
|
||||
|
||||
def test_format_summary_non_dict_value(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""format_summary_for_prompt must render non-dict values as plain text."""
|
||||
summary = {
|
||||
"daily": {
|
||||
"plain_count": 42,
|
||||
}
|
||||
}
|
||||
text = summarizer.format_summary_for_prompt(summary)
|
||||
assert "plain_count" in text
|
||||
assert "42" in text
|
||||
|
||||
104
tests/test_context_scheduler.py
Normal file
104
tests/test_context_scheduler.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Tests for ContextScheduler."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from src.context.scheduler import ContextScheduler
|
||||
|
||||
|
||||
@dataclass
|
||||
class StubAggregator:
|
||||
"""Stub aggregator that records calls."""
|
||||
|
||||
weekly_calls: list[str]
|
||||
monthly_calls: list[str]
|
||||
quarterly_calls: list[str]
|
||||
annual_calls: list[str]
|
||||
legacy_calls: int
|
||||
|
||||
def aggregate_weekly_from_daily(self, week: str) -> None:
|
||||
self.weekly_calls.append(week)
|
||||
|
||||
def aggregate_monthly_from_weekly(self, month: str) -> None:
|
||||
self.monthly_calls.append(month)
|
||||
|
||||
def aggregate_quarterly_from_monthly(self, quarter: str) -> None:
|
||||
self.quarterly_calls.append(quarter)
|
||||
|
||||
def aggregate_annual_from_quarterly(self, year: str) -> None:
|
||||
self.annual_calls.append(year)
|
||||
|
||||
def aggregate_legacy_from_annual(self) -> None:
|
||||
self.legacy_calls += 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class StubStore:
|
||||
"""Stub store that records cleanup calls."""
|
||||
|
||||
cleanup_calls: int = 0
|
||||
|
||||
def cleanup_expired_contexts(self) -> None:
|
||||
self.cleanup_calls += 1
|
||||
|
||||
|
||||
def make_scheduler() -> tuple[ContextScheduler, StubAggregator, StubStore]:
|
||||
aggregator = StubAggregator([], [], [], [], 0)
|
||||
store = StubStore()
|
||||
scheduler = ContextScheduler(aggregator=aggregator, store=store)
|
||||
return scheduler, aggregator, store
|
||||
|
||||
|
||||
def test_run_if_due_weekly() -> None:
|
||||
scheduler, aggregator, store = make_scheduler()
|
||||
now = datetime(2026, 2, 8, 10, 0, tzinfo=UTC) # Sunday
|
||||
|
||||
result = scheduler.run_if_due(now)
|
||||
|
||||
assert result.weekly is True
|
||||
assert aggregator.weekly_calls == ["2026-W06"]
|
||||
assert store.cleanup_calls == 1
|
||||
|
||||
|
||||
def test_run_if_due_monthly() -> None:
|
||||
scheduler, aggregator, _store = make_scheduler()
|
||||
now = datetime(2026, 2, 28, 12, 0, tzinfo=UTC) # Last day of month
|
||||
|
||||
result = scheduler.run_if_due(now)
|
||||
|
||||
assert result.monthly is True
|
||||
assert aggregator.monthly_calls == ["2026-02"]
|
||||
|
||||
|
||||
def test_run_if_due_quarterly() -> None:
|
||||
scheduler, aggregator, _store = make_scheduler()
|
||||
now = datetime(2026, 3, 31, 12, 0, tzinfo=UTC) # Last day of Q1
|
||||
|
||||
result = scheduler.run_if_due(now)
|
||||
|
||||
assert result.quarterly is True
|
||||
assert aggregator.quarterly_calls == ["2026-Q1"]
|
||||
|
||||
|
||||
def test_run_if_due_annual_and_legacy() -> None:
|
||||
scheduler, aggregator, _store = make_scheduler()
|
||||
now = datetime(2026, 12, 31, 12, 0, tzinfo=UTC)
|
||||
|
||||
result = scheduler.run_if_due(now)
|
||||
|
||||
assert result.annual is True
|
||||
assert result.legacy is True
|
||||
assert aggregator.annual_calls == ["2026"]
|
||||
assert aggregator.legacy_calls == 1
|
||||
|
||||
|
||||
def test_cleanup_runs_once_per_day() -> None:
|
||||
scheduler, _aggregator, store = make_scheduler()
|
||||
now = datetime(2026, 2, 9, 9, 0, tzinfo=UTC)
|
||||
|
||||
scheduler.run_if_due(now)
|
||||
scheduler.run_if_due(now)
|
||||
|
||||
assert store.cleanup_calls == 1
|
||||
387
tests/test_daily_review.py
Normal file
387
tests/test_daily_review.py
Normal file
@@ -0,0 +1,387 @@
|
||||
"""Tests for DailyReviewer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from src.context.layer import ContextLayer
|
||||
from src.context.store import ContextStore
|
||||
from src.db import init_db, log_trade
|
||||
from src.evolution.daily_review import DailyReviewer
|
||||
from src.evolution.scorecard import DailyScorecard
|
||||
from src.logging.decision_logger import DecisionLogger
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
TODAY = datetime.now(UTC).strftime("%Y-%m-%d")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_conn() -> sqlite3.Connection:
|
||||
return init_db(":memory:")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def context_store(db_conn: sqlite3.Connection) -> ContextStore:
|
||||
return ContextStore(db_conn)
|
||||
|
||||
|
||||
def _log_decision(
|
||||
logger: DecisionLogger,
|
||||
*,
|
||||
stock_code: str,
|
||||
market: str,
|
||||
action: str,
|
||||
confidence: int,
|
||||
scenario_match: dict[str, float] | None = None,
|
||||
) -> str:
|
||||
return logger.log_decision(
|
||||
stock_code=stock_code,
|
||||
market=market,
|
||||
exchange_code="KRX" if market == "KR" else "NASDAQ",
|
||||
action=action,
|
||||
confidence=confidence,
|
||||
rationale="test",
|
||||
context_snapshot={"scenario_match": scenario_match or {}},
|
||||
input_data={"stock_code": stock_code},
|
||||
)
|
||||
|
||||
|
||||
def test_generate_scorecard_market_scoped(
|
||||
db_conn: sqlite3.Connection, context_store: ContextStore,
|
||||
) -> None:
|
||||
reviewer = DailyReviewer(db_conn, context_store)
|
||||
logger = DecisionLogger(db_conn)
|
||||
|
||||
buy_id = _log_decision(
|
||||
logger,
|
||||
stock_code="005930",
|
||||
market="KR",
|
||||
action="BUY",
|
||||
confidence=90,
|
||||
scenario_match={"rsi": 29.0},
|
||||
)
|
||||
_log_decision(
|
||||
logger,
|
||||
stock_code="000660",
|
||||
market="KR",
|
||||
action="HOLD",
|
||||
confidence=60,
|
||||
)
|
||||
_log_decision(
|
||||
logger,
|
||||
stock_code="AAPL",
|
||||
market="US",
|
||||
action="SELL",
|
||||
confidence=80,
|
||||
scenario_match={"volume_ratio": 2.1},
|
||||
)
|
||||
|
||||
log_trade(
|
||||
db_conn,
|
||||
"005930",
|
||||
"BUY",
|
||||
90,
|
||||
"buy",
|
||||
quantity=1,
|
||||
price=100.0,
|
||||
pnl=10.0,
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
decision_id=buy_id,
|
||||
)
|
||||
log_trade(
|
||||
db_conn,
|
||||
"000660",
|
||||
"HOLD",
|
||||
60,
|
||||
"hold",
|
||||
quantity=0,
|
||||
price=0.0,
|
||||
pnl=0.0,
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
)
|
||||
log_trade(
|
||||
db_conn,
|
||||
"AAPL",
|
||||
"SELL",
|
||||
80,
|
||||
"sell",
|
||||
quantity=1,
|
||||
price=200.0,
|
||||
pnl=-5.0,
|
||||
market="US",
|
||||
exchange_code="NASDAQ",
|
||||
)
|
||||
|
||||
scorecard = reviewer.generate_scorecard(TODAY, "KR")
|
||||
|
||||
assert scorecard.market == "KR"
|
||||
assert scorecard.total_decisions == 2
|
||||
assert scorecard.buys == 1
|
||||
assert scorecard.sells == 0
|
||||
assert scorecard.holds == 1
|
||||
assert scorecard.total_pnl == 10.0
|
||||
assert scorecard.win_rate == 100.0
|
||||
assert scorecard.avg_confidence == 75.0
|
||||
assert scorecard.scenario_match_rate == 50.0
|
||||
|
||||
|
||||
def test_generate_scorecard_top_winners_and_losers(
|
||||
db_conn: sqlite3.Connection, context_store: ContextStore,
|
||||
) -> None:
|
||||
reviewer = DailyReviewer(db_conn, context_store)
|
||||
logger = DecisionLogger(db_conn)
|
||||
|
||||
for code, pnl in [("005930", 30.0), ("000660", 10.0), ("035420", -15.0), ("051910", -5.0)]:
|
||||
decision_id = _log_decision(
|
||||
logger,
|
||||
stock_code=code,
|
||||
market="KR",
|
||||
action="BUY" if pnl >= 0 else "SELL",
|
||||
confidence=80,
|
||||
scenario_match={"rsi": 30.0},
|
||||
)
|
||||
log_trade(
|
||||
db_conn,
|
||||
code,
|
||||
"BUY" if pnl >= 0 else "SELL",
|
||||
80,
|
||||
"test",
|
||||
quantity=1,
|
||||
price=100.0,
|
||||
pnl=pnl,
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
decision_id=decision_id,
|
||||
)
|
||||
|
||||
scorecard = reviewer.generate_scorecard(TODAY, "KR")
|
||||
assert scorecard.top_winners == ["005930", "000660"]
|
||||
assert scorecard.top_losers == ["035420", "051910"]
|
||||
|
||||
|
||||
def test_generate_scorecard_empty_day(
|
||||
db_conn: sqlite3.Connection, context_store: ContextStore,
|
||||
) -> None:
|
||||
reviewer = DailyReviewer(db_conn, context_store)
|
||||
scorecard = reviewer.generate_scorecard(TODAY, "KR")
|
||||
|
||||
assert scorecard.total_decisions == 0
|
||||
assert scorecard.total_pnl == 0.0
|
||||
assert scorecard.win_rate == 0.0
|
||||
assert scorecard.avg_confidence == 0.0
|
||||
assert scorecard.scenario_match_rate == 0.0
|
||||
assert scorecard.top_winners == []
|
||||
assert scorecard.top_losers == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_lessons_without_gemini_returns_empty(
|
||||
db_conn: sqlite3.Connection, context_store: ContextStore,
|
||||
) -> None:
|
||||
reviewer = DailyReviewer(db_conn, context_store, gemini_client=None)
|
||||
lessons = await reviewer.generate_lessons(
|
||||
DailyScorecard(
|
||||
date="2026-02-14",
|
||||
market="KR",
|
||||
total_decisions=1,
|
||||
buys=1,
|
||||
sells=0,
|
||||
holds=0,
|
||||
total_pnl=5.0,
|
||||
win_rate=100.0,
|
||||
avg_confidence=90.0,
|
||||
scenario_match_rate=100.0,
|
||||
)
|
||||
)
|
||||
assert lessons == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_lessons_parses_json_array(
|
||||
db_conn: sqlite3.Connection, context_store: ContextStore,
|
||||
) -> None:
|
||||
mock_gemini = MagicMock()
|
||||
mock_gemini.decide = AsyncMock(
|
||||
return_value=SimpleNamespace(rationale='["Cut losers earlier", "Reduce midday churn"]')
|
||||
)
|
||||
reviewer = DailyReviewer(db_conn, context_store, gemini_client=mock_gemini)
|
||||
|
||||
lessons = await reviewer.generate_lessons(
|
||||
DailyScorecard(
|
||||
date="2026-02-14",
|
||||
market="KR",
|
||||
total_decisions=3,
|
||||
buys=1,
|
||||
sells=1,
|
||||
holds=1,
|
||||
total_pnl=-2.5,
|
||||
win_rate=50.0,
|
||||
avg_confidence=70.0,
|
||||
scenario_match_rate=66.7,
|
||||
)
|
||||
)
|
||||
assert lessons == ["Cut losers earlier", "Reduce midday churn"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_lessons_fallback_to_lines(
|
||||
db_conn: sqlite3.Connection, context_store: ContextStore,
|
||||
) -> None:
|
||||
mock_gemini = MagicMock()
|
||||
mock_gemini.decide = AsyncMock(
|
||||
return_value=SimpleNamespace(rationale="- Keep risk tighter\n- Increase selectivity")
|
||||
)
|
||||
reviewer = DailyReviewer(db_conn, context_store, gemini_client=mock_gemini)
|
||||
|
||||
lessons = await reviewer.generate_lessons(
|
||||
DailyScorecard(
|
||||
date="2026-02-14",
|
||||
market="US",
|
||||
total_decisions=2,
|
||||
buys=1,
|
||||
sells=1,
|
||||
holds=0,
|
||||
total_pnl=1.0,
|
||||
win_rate=50.0,
|
||||
avg_confidence=75.0,
|
||||
scenario_match_rate=100.0,
|
||||
)
|
||||
)
|
||||
assert lessons == ["Keep risk tighter", "Increase selectivity"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_lessons_handles_gemini_error(
|
||||
db_conn: sqlite3.Connection, context_store: ContextStore,
|
||||
) -> None:
|
||||
mock_gemini = MagicMock()
|
||||
mock_gemini.decide = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
reviewer = DailyReviewer(db_conn, context_store, gemini_client=mock_gemini)
|
||||
|
||||
lessons = await reviewer.generate_lessons(
|
||||
DailyScorecard(
|
||||
date="2026-02-14",
|
||||
market="US",
|
||||
total_decisions=0,
|
||||
buys=0,
|
||||
sells=0,
|
||||
holds=0,
|
||||
total_pnl=0.0,
|
||||
win_rate=0.0,
|
||||
avg_confidence=0.0,
|
||||
scenario_match_rate=0.0,
|
||||
)
|
||||
)
|
||||
assert lessons == []
|
||||
|
||||
|
||||
def test_store_scorecard_in_context(
|
||||
db_conn: sqlite3.Connection, context_store: ContextStore,
|
||||
) -> None:
|
||||
reviewer = DailyReviewer(db_conn, context_store)
|
||||
scorecard = DailyScorecard(
|
||||
date="2026-02-14",
|
||||
market="KR",
|
||||
total_decisions=5,
|
||||
buys=2,
|
||||
sells=1,
|
||||
holds=2,
|
||||
total_pnl=15.0,
|
||||
win_rate=66.67,
|
||||
avg_confidence=82.0,
|
||||
scenario_match_rate=80.0,
|
||||
lessons=["Keep position sizing stable"],
|
||||
cross_market_note="US risk-off",
|
||||
)
|
||||
|
||||
reviewer.store_scorecard_in_context(scorecard)
|
||||
|
||||
stored = context_store.get_context(
|
||||
ContextLayer.L6_DAILY,
|
||||
"2026-02-14",
|
||||
"scorecard_KR",
|
||||
)
|
||||
assert stored is not None
|
||||
assert stored["market"] == "KR"
|
||||
assert stored["total_pnl"] == 15.0
|
||||
assert stored["lessons"] == ["Keep position sizing stable"]
|
||||
|
||||
|
||||
def test_store_scorecard_key_is_market_scoped(
|
||||
db_conn: sqlite3.Connection, context_store: ContextStore,
|
||||
) -> None:
|
||||
reviewer = DailyReviewer(db_conn, context_store)
|
||||
kr = DailyScorecard(
|
||||
date="2026-02-14",
|
||||
market="KR",
|
||||
total_decisions=1,
|
||||
buys=1,
|
||||
sells=0,
|
||||
holds=0,
|
||||
total_pnl=1.0,
|
||||
win_rate=100.0,
|
||||
avg_confidence=90.0,
|
||||
scenario_match_rate=100.0,
|
||||
)
|
||||
us = DailyScorecard(
|
||||
date="2026-02-14",
|
||||
market="US",
|
||||
total_decisions=1,
|
||||
buys=0,
|
||||
sells=1,
|
||||
holds=0,
|
||||
total_pnl=-1.0,
|
||||
win_rate=0.0,
|
||||
avg_confidence=70.0,
|
||||
scenario_match_rate=100.0,
|
||||
)
|
||||
|
||||
reviewer.store_scorecard_in_context(kr)
|
||||
reviewer.store_scorecard_in_context(us)
|
||||
|
||||
kr_ctx = context_store.get_context(ContextLayer.L6_DAILY, "2026-02-14", "scorecard_KR")
|
||||
us_ctx = context_store.get_context(ContextLayer.L6_DAILY, "2026-02-14", "scorecard_US")
|
||||
|
||||
assert kr_ctx["market"] == "KR"
|
||||
assert us_ctx["market"] == "US"
|
||||
assert kr_ctx["total_pnl"] == 1.0
|
||||
assert us_ctx["total_pnl"] == -1.0
|
||||
|
||||
|
||||
def test_generate_scorecard_handles_invalid_context_snapshot(
|
||||
db_conn: sqlite3.Connection, context_store: ContextStore,
|
||||
) -> None:
|
||||
reviewer = DailyReviewer(db_conn, context_store)
|
||||
db_conn.execute(
|
||||
"""
|
||||
INSERT INTO decision_logs (
|
||||
decision_id, timestamp, stock_code, market, exchange_code,
|
||||
action, confidence, rationale, context_snapshot, input_data
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
"d1",
|
||||
"2026-02-14T09:00:00+00:00",
|
||||
"005930",
|
||||
"KR",
|
||||
"KRX",
|
||||
"HOLD",
|
||||
50,
|
||||
"test",
|
||||
"{invalid_json",
|
||||
json.dumps({}),
|
||||
),
|
||||
)
|
||||
db_conn.commit()
|
||||
|
||||
scorecard = reviewer.generate_scorecard("2026-02-14", "KR")
|
||||
assert scorecard.total_decisions == 1
|
||||
assert scorecard.scenario_match_rate == 0.0
|
||||
451
tests/test_dashboard.py
Normal file
451
tests/test_dashboard.py
Normal file
@@ -0,0 +1,451 @@
|
||||
"""Tests for dashboard endpoint handlers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from collections.abc import Callable
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from src.dashboard.app import create_dashboard_app
|
||||
from src.db import init_db
|
||||
|
||||
|
||||
def _seed_db(conn: sqlite3.Connection) -> None:
|
||||
today = datetime.now(UTC).date().isoformat()
|
||||
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO playbooks (
|
||||
date, market, status, playbook_json, generated_at,
|
||||
token_count, scenario_count, match_count
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
"2026-02-14",
|
||||
"KR",
|
||||
"ready",
|
||||
json.dumps({"market": "KR", "stock_playbooks": []}),
|
||||
"2026-02-14T08:30:00+00:00",
|
||||
123,
|
||||
2,
|
||||
1,
|
||||
),
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO playbooks (
|
||||
date, market, status, playbook_json, generated_at,
|
||||
token_count, scenario_count, match_count
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
today,
|
||||
"US_NASDAQ",
|
||||
"ready",
|
||||
json.dumps({"market": "US_NASDAQ", "stock_playbooks": []}),
|
||||
f"{today}T08:30:00+00:00",
|
||||
100,
|
||||
1,
|
||||
0,
|
||||
),
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO contexts (layer, timeframe, key, value, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
"L6_DAILY",
|
||||
"2026-02-14",
|
||||
"scorecard_KR",
|
||||
json.dumps({"market": "KR", "total_pnl": 1.5, "win_rate": 60.0}),
|
||||
"2026-02-14T15:30:00+00:00",
|
||||
"2026-02-14T15:30:00+00:00",
|
||||
),
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO contexts (layer, timeframe, key, value, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
"L7_REALTIME",
|
||||
"2026-02-14T10:00:00+00:00",
|
||||
"volatility_KR_005930",
|
||||
json.dumps({"momentum_score": 70.0}),
|
||||
"2026-02-14T10:00:00+00:00",
|
||||
"2026-02-14T10:00:00+00:00",
|
||||
),
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO decision_logs (
|
||||
decision_id, timestamp, stock_code, market, exchange_code,
|
||||
action, confidence, rationale, context_snapshot, input_data
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
"d-kr-1",
|
||||
f"{today}T09:10:00+00:00",
|
||||
"005930",
|
||||
"KR",
|
||||
"KRX",
|
||||
"BUY",
|
||||
85,
|
||||
"signal matched",
|
||||
json.dumps({"scenario_match": {"rsi": 28.0}}),
|
||||
json.dumps({"current_price": 70000}),
|
||||
),
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO decision_logs (
|
||||
decision_id, timestamp, stock_code, market, exchange_code,
|
||||
action, confidence, rationale, context_snapshot, input_data
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
"d-us-1",
|
||||
f"{today}T21:10:00+00:00",
|
||||
"AAPL",
|
||||
"US_NASDAQ",
|
||||
"NASDAQ",
|
||||
"SELL",
|
||||
80,
|
||||
"no match",
|
||||
json.dumps({"scenario_match": {}}),
|
||||
json.dumps({"current_price": 200}),
|
||||
),
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO trades (
|
||||
timestamp, stock_code, action, confidence, rationale,
|
||||
quantity, price, pnl, market, exchange_code, selection_context, decision_id
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
f"{today}T09:11:00+00:00",
|
||||
"005930",
|
||||
"BUY",
|
||||
85,
|
||||
"buy",
|
||||
1,
|
||||
70000,
|
||||
2.0,
|
||||
"KR",
|
||||
"KRX",
|
||||
None,
|
||||
"d-kr-1",
|
||||
),
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO trades (
|
||||
timestamp, stock_code, action, confidence, rationale,
|
||||
quantity, price, pnl, market, exchange_code, selection_context, decision_id
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
f"{today}T21:11:00+00:00",
|
||||
"AAPL",
|
||||
"SELL",
|
||||
80,
|
||||
"sell",
|
||||
1,
|
||||
200,
|
||||
-1.0,
|
||||
"US_NASDAQ",
|
||||
"NASDAQ",
|
||||
None,
|
||||
"d-us-1",
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def _app(tmp_path: Path) -> Any:
|
||||
db_path = tmp_path / "dashboard_test.db"
|
||||
conn = init_db(str(db_path))
|
||||
_seed_db(conn)
|
||||
conn.close()
|
||||
return create_dashboard_app(str(db_path))
|
||||
|
||||
|
||||
def _endpoint(app: Any, path: str) -> Callable[..., Any]:
|
||||
for route in app.routes:
|
||||
if getattr(route, "path", None) == path:
|
||||
return route.endpoint
|
||||
raise AssertionError(f"route not found: {path}")
|
||||
|
||||
|
||||
def test_index_serves_html(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
index = _endpoint(app, "/")
|
||||
resp = index()
|
||||
assert isinstance(resp, FileResponse)
|
||||
assert "index.html" in str(resp.path)
|
||||
|
||||
|
||||
def test_status_endpoint(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_status = _endpoint(app, "/api/status")
|
||||
body = get_status()
|
||||
assert "KR" in body["markets"]
|
||||
assert "US_NASDAQ" in body["markets"]
|
||||
assert "totals" in body
|
||||
|
||||
|
||||
def test_playbook_found(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_playbook = _endpoint(app, "/api/playbook/{date_str}")
|
||||
body = get_playbook("2026-02-14", market="KR")
|
||||
assert body["market"] == "KR"
|
||||
|
||||
|
||||
def test_playbook_not_found(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_playbook = _endpoint(app, "/api/playbook/{date_str}")
|
||||
with pytest.raises(HTTPException, match="playbook not found"):
|
||||
get_playbook("2026-02-15", market="KR")
|
||||
|
||||
|
||||
def test_scorecard_found(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_scorecard = _endpoint(app, "/api/scorecard/{date_str}")
|
||||
body = get_scorecard("2026-02-14", market="KR")
|
||||
assert body["scorecard"]["total_pnl"] == 1.5
|
||||
|
||||
|
||||
def test_scorecard_not_found(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_scorecard = _endpoint(app, "/api/scorecard/{date_str}")
|
||||
with pytest.raises(HTTPException, match="scorecard not found"):
|
||||
get_scorecard("2026-02-15", market="KR")
|
||||
|
||||
|
||||
def test_performance_all(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_performance = _endpoint(app, "/api/performance")
|
||||
body = get_performance(market="all")
|
||||
assert body["market"] == "all"
|
||||
assert body["combined"]["total_trades"] == 2
|
||||
assert len(body["by_market"]) == 2
|
||||
|
||||
|
||||
def test_performance_market_filter(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_performance = _endpoint(app, "/api/performance")
|
||||
body = get_performance(market="KR")
|
||||
assert body["market"] == "KR"
|
||||
assert body["metrics"]["total_trades"] == 1
|
||||
|
||||
|
||||
def test_performance_empty_market(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_performance = _endpoint(app, "/api/performance")
|
||||
body = get_performance(market="JP")
|
||||
assert body["metrics"]["total_trades"] == 0
|
||||
|
||||
|
||||
def test_context_layer_all(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_context_layer = _endpoint(app, "/api/context/{layer}")
|
||||
body = get_context_layer("L7_REALTIME", timeframe=None, limit=100)
|
||||
assert body["layer"] == "L7_REALTIME"
|
||||
assert body["count"] == 1
|
||||
|
||||
|
||||
def test_context_layer_timeframe_filter(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_context_layer = _endpoint(app, "/api/context/{layer}")
|
||||
body = get_context_layer("L6_DAILY", timeframe="2026-02-14", limit=100)
|
||||
assert body["count"] == 1
|
||||
assert body["entries"][0]["key"] == "scorecard_KR"
|
||||
|
||||
|
||||
def test_decisions_endpoint(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_decisions = _endpoint(app, "/api/decisions")
|
||||
body = get_decisions(market="KR", limit=50)
|
||||
assert body["count"] == 1
|
||||
assert body["decisions"][0]["decision_id"] == "d-kr-1"
|
||||
|
||||
|
||||
def test_scenarios_active_filters_non_matched(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_active_scenarios = _endpoint(app, "/api/scenarios/active")
|
||||
body = get_active_scenarios(
|
||||
market="KR",
|
||||
date_str=datetime.now(UTC).date().isoformat(),
|
||||
limit=50,
|
||||
)
|
||||
assert body["count"] == 1
|
||||
assert body["matches"][0]["stock_code"] == "005930"
|
||||
|
||||
|
||||
def test_scenarios_active_empty_when_no_matches(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_active_scenarios = _endpoint(app, "/api/scenarios/active")
|
||||
body = get_active_scenarios(market="US", date_str="2026-02-14", limit=50)
|
||||
assert body["count"] == 0
|
||||
|
||||
|
||||
def test_pnl_history_all_markets(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_pnl_history = _endpoint(app, "/api/pnl/history")
|
||||
body = get_pnl_history(days=30, market="all")
|
||||
assert body["market"] == "all"
|
||||
assert isinstance(body["labels"], list)
|
||||
assert isinstance(body["pnl"], list)
|
||||
assert len(body["labels"]) == len(body["pnl"])
|
||||
|
||||
|
||||
def test_pnl_history_market_filter(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_pnl_history = _endpoint(app, "/api/pnl/history")
|
||||
body = get_pnl_history(days=30, market="KR")
|
||||
assert body["market"] == "KR"
|
||||
# KR has 1 trade with pnl=2.0
|
||||
assert len(body["labels"]) >= 1
|
||||
assert body["pnl"][0] == 2.0
|
||||
|
||||
|
||||
def test_positions_returns_open_buy(tmp_path: Path) -> None:
|
||||
"""BUY가 마지막 거래인 종목은 포지션으로 반환되어야 한다."""
|
||||
app = _app(tmp_path)
|
||||
get_positions = _endpoint(app, "/api/positions")
|
||||
body = get_positions()
|
||||
# seed_db: 005930은 BUY (오픈), AAPL은 SELL (마지막)
|
||||
assert body["count"] == 1
|
||||
pos = body["positions"][0]
|
||||
assert pos["stock_code"] == "005930"
|
||||
assert pos["market"] == "KR"
|
||||
assert pos["quantity"] == 1
|
||||
assert pos["entry_price"] == 70000
|
||||
|
||||
|
||||
def test_positions_excludes_closed_sell(tmp_path: Path) -> None:
|
||||
"""마지막 거래가 SELL인 종목은 포지션에 나타나지 않아야 한다."""
|
||||
app = _app(tmp_path)
|
||||
get_positions = _endpoint(app, "/api/positions")
|
||||
body = get_positions()
|
||||
codes = [p["stock_code"] for p in body["positions"]]
|
||||
assert "AAPL" not in codes
|
||||
|
||||
|
||||
def test_positions_empty_when_no_trades(tmp_path: Path) -> None:
|
||||
"""거래 내역이 없으면 빈 포지션 목록을 반환해야 한다."""
|
||||
db_path = tmp_path / "empty.db"
|
||||
conn = init_db(str(db_path))
|
||||
conn.close()
|
||||
app = create_dashboard_app(str(db_path))
|
||||
get_positions = _endpoint(app, "/api/positions")
|
||||
body = get_positions()
|
||||
assert body["count"] == 0
|
||||
assert body["positions"] == []
|
||||
|
||||
|
||||
def _seed_cb_context(conn: sqlite3.Connection, pnl_pct: float, market: str = "KR") -> None:
|
||||
import json as _json
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO system_metrics (key, value, updated_at) VALUES (?, ?, ?)",
|
||||
(
|
||||
f"portfolio_pnl_pct_{market}",
|
||||
_json.dumps({"pnl_pct": pnl_pct}),
|
||||
"2026-02-22T10:00:00+00:00",
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def test_status_circuit_breaker_ok(tmp_path: Path) -> None:
|
||||
"""pnl_pct가 -2.0%보다 높으면 status=ok를 반환해야 한다."""
|
||||
db_path = tmp_path / "cb_ok.db"
|
||||
conn = init_db(str(db_path))
|
||||
_seed_cb_context(conn, -1.0)
|
||||
conn.close()
|
||||
app = create_dashboard_app(str(db_path))
|
||||
get_status = _endpoint(app, "/api/status")
|
||||
body = get_status()
|
||||
cb = body["circuit_breaker"]
|
||||
assert cb["status"] == "ok"
|
||||
assert cb["current_pnl_pct"] == -1.0
|
||||
assert cb["threshold_pct"] == -3.0
|
||||
|
||||
|
||||
def test_status_circuit_breaker_warning(tmp_path: Path) -> None:
|
||||
"""pnl_pct가 -2.0% 이하이면 status=warning을 반환해야 한다."""
|
||||
db_path = tmp_path / "cb_warn.db"
|
||||
conn = init_db(str(db_path))
|
||||
_seed_cb_context(conn, -2.5)
|
||||
conn.close()
|
||||
app = create_dashboard_app(str(db_path))
|
||||
get_status = _endpoint(app, "/api/status")
|
||||
body = get_status()
|
||||
assert body["circuit_breaker"]["status"] == "warning"
|
||||
|
||||
|
||||
def test_status_circuit_breaker_tripped(tmp_path: Path) -> None:
|
||||
"""pnl_pct가 임계값(-3.0%) 이하이면 status=tripped를 반환해야 한다."""
|
||||
db_path = tmp_path / "cb_tripped.db"
|
||||
conn = init_db(str(db_path))
|
||||
_seed_cb_context(conn, -3.5)
|
||||
conn.close()
|
||||
app = create_dashboard_app(str(db_path))
|
||||
get_status = _endpoint(app, "/api/status")
|
||||
body = get_status()
|
||||
assert body["circuit_breaker"]["status"] == "tripped"
|
||||
|
||||
|
||||
def test_status_circuit_breaker_unknown_when_no_data(tmp_path: Path) -> None:
|
||||
"""L7 context에 pnl_pct 데이터가 없으면 status=unknown을 반환해야 한다."""
|
||||
app = _app(tmp_path) # seed_db에는 portfolio_pnl_pct 없음
|
||||
get_status = _endpoint(app, "/api/status")
|
||||
body = get_status()
|
||||
cb = body["circuit_breaker"]
|
||||
assert cb["status"] == "unknown"
|
||||
assert cb["current_pnl_pct"] is None
|
||||
|
||||
|
||||
def test_status_mode_paper(tmp_path: Path) -> None:
|
||||
"""mode=paper로 생성하면 status 응답에 mode=paper가 포함돼야 한다."""
|
||||
db_path = tmp_path / "dashboard_test.db"
|
||||
conn = init_db(str(db_path))
|
||||
_seed_db(conn)
|
||||
conn.close()
|
||||
app = create_dashboard_app(str(db_path), mode="paper")
|
||||
get_status = _endpoint(app, "/api/status")
|
||||
body = get_status()
|
||||
assert body["mode"] == "paper"
|
||||
|
||||
|
||||
def test_status_mode_live(tmp_path: Path) -> None:
|
||||
"""mode=live로 생성하면 status 응답에 mode=live가 포함돼야 한다."""
|
||||
db_path = tmp_path / "dashboard_test.db"
|
||||
conn = init_db(str(db_path))
|
||||
_seed_db(conn)
|
||||
conn.close()
|
||||
app = create_dashboard_app(str(db_path), mode="live")
|
||||
get_status = _endpoint(app, "/api/status")
|
||||
body = get_status()
|
||||
assert body["mode"] == "live"
|
||||
|
||||
|
||||
def test_status_mode_default_paper(tmp_path: Path) -> None:
|
||||
"""mode 파라미터 미전달 시 기본값은 paper여야 한다."""
|
||||
db_path = tmp_path / "dashboard_test.db"
|
||||
conn = init_db(str(db_path))
|
||||
_seed_db(conn)
|
||||
conn.close()
|
||||
app = create_dashboard_app(str(db_path))
|
||||
get_status = _endpoint(app, "/api/status")
|
||||
body = get_status()
|
||||
assert body["mode"] == "paper"
|
||||
673
tests/test_data_integration.py
Normal file
673
tests/test_data_integration.py
Normal file
@@ -0,0 +1,673 @@
|
||||
"""Tests for external data integration (news, economic calendar, market data)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.brain.gemini_client import GeminiClient
|
||||
from src.data.economic_calendar import EconomicCalendar, EconomicEvent
|
||||
from src.data.market_data import MarketBreadth, MarketData, MarketSentiment
|
||||
from src.data.news_api import NewsAPI, NewsArticle, NewsSentiment
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# NewsAPI Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNewsAPI:
|
||||
"""Test news API integration with caching."""
|
||||
|
||||
def test_news_api_init_without_key(self):
|
||||
"""NewsAPI should initialize without API key for testing."""
|
||||
api = NewsAPI(api_key=None)
|
||||
assert api._api_key is None
|
||||
assert api._provider == "alphavantage"
|
||||
assert api._cache_ttl == 300
|
||||
|
||||
def test_news_api_init_with_custom_settings(self):
|
||||
"""NewsAPI should accept custom provider and cache TTL."""
|
||||
api = NewsAPI(api_key="test_key", provider="newsapi", cache_ttl=600)
|
||||
assert api._api_key == "test_key"
|
||||
assert api._provider == "newsapi"
|
||||
assert api._cache_ttl == 600
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_news_sentiment_without_api_key_returns_none(self):
|
||||
"""Without API key, get_news_sentiment should return None."""
|
||||
api = NewsAPI(api_key=None)
|
||||
result = await api.get_news_sentiment("AAPL")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_hit_returns_cached_sentiment(self):
|
||||
"""Cache hit should return cached sentiment without API call."""
|
||||
api = NewsAPI(api_key="test_key")
|
||||
|
||||
# Manually populate cache
|
||||
cached_sentiment = NewsSentiment(
|
||||
stock_code="AAPL",
|
||||
articles=[],
|
||||
avg_sentiment=0.5,
|
||||
article_count=0,
|
||||
fetched_at=time.time(),
|
||||
)
|
||||
api._cache["AAPL"] = cached_sentiment
|
||||
|
||||
result = await api.get_news_sentiment("AAPL")
|
||||
assert result is cached_sentiment
|
||||
assert result.stock_code == "AAPL"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_expiry_triggers_refetch(self):
|
||||
"""Expired cache entry should trigger refetch."""
|
||||
api = NewsAPI(api_key="test_key", cache_ttl=1)
|
||||
|
||||
# Add expired cache entry
|
||||
expired_sentiment = NewsSentiment(
|
||||
stock_code="AAPL",
|
||||
articles=[],
|
||||
avg_sentiment=0.5,
|
||||
article_count=0,
|
||||
fetched_at=time.time() - 10, # 10 seconds ago
|
||||
)
|
||||
api._cache["AAPL"] = expired_sentiment
|
||||
|
||||
# Mock the fetch to avoid real API call
|
||||
with patch.object(api, "_fetch_news", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = None
|
||||
result = await api.get_news_sentiment("AAPL")
|
||||
|
||||
# Should have attempted refetch since cache expired
|
||||
mock_fetch.assert_called_once_with("AAPL")
|
||||
|
||||
def test_clear_cache(self):
|
||||
"""clear_cache should empty the cache."""
|
||||
api = NewsAPI(api_key="test_key")
|
||||
api._cache["AAPL"] = NewsSentiment(
|
||||
stock_code="AAPL",
|
||||
articles=[],
|
||||
avg_sentiment=0.0,
|
||||
article_count=0,
|
||||
fetched_at=time.time(),
|
||||
)
|
||||
assert len(api._cache) == 1
|
||||
|
||||
api.clear_cache()
|
||||
assert len(api._cache) == 0
|
||||
|
||||
def test_parse_alphavantage_response_with_valid_data(self):
|
||||
"""Should parse Alpha Vantage response correctly."""
|
||||
api = NewsAPI(api_key="test_key", provider="alphavantage")
|
||||
|
||||
mock_response = {
|
||||
"feed": [
|
||||
{
|
||||
"title": "Apple hits new high",
|
||||
"summary": "Apple stock surges to record levels",
|
||||
"source": "Reuters",
|
||||
"time_published": "2026-02-04T10:00:00",
|
||||
"url": "https://example.com/1",
|
||||
"ticker_sentiment": [
|
||||
{"ticker": "AAPL", "ticker_sentiment_score": "0.85"}
|
||||
],
|
||||
"overall_sentiment_score": "0.75",
|
||||
},
|
||||
{
|
||||
"title": "Market volatility rises",
|
||||
"summary": "Tech stocks face headwinds",
|
||||
"source": "Bloomberg",
|
||||
"time_published": "2026-02-04T09:00:00",
|
||||
"url": "https://example.com/2",
|
||||
"ticker_sentiment": [
|
||||
{"ticker": "AAPL", "ticker_sentiment_score": "-0.3"}
|
||||
],
|
||||
"overall_sentiment_score": "-0.2",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
result = api._parse_alphavantage_response("AAPL", mock_response)
|
||||
|
||||
assert result is not None
|
||||
assert result.stock_code == "AAPL"
|
||||
assert result.article_count == 2
|
||||
assert len(result.articles) == 2
|
||||
assert result.articles[0].title == "Apple hits new high"
|
||||
assert result.articles[0].sentiment_score == 0.85
|
||||
assert result.articles[1].sentiment_score == -0.3
|
||||
# Average: (0.85 - 0.3) / 2 = 0.275
|
||||
assert abs(result.avg_sentiment - 0.275) < 0.01
|
||||
|
||||
def test_parse_alphavantage_response_without_feed_returns_none(self):
|
||||
"""Should return None if 'feed' key is missing."""
|
||||
api = NewsAPI(api_key="test_key", provider="alphavantage")
|
||||
result = api._parse_alphavantage_response("AAPL", {})
|
||||
assert result is None
|
||||
|
||||
def test_parse_newsapi_response_with_valid_data(self):
|
||||
"""Should parse NewsAPI.org response correctly."""
|
||||
api = NewsAPI(api_key="test_key", provider="newsapi")
|
||||
|
||||
mock_response = {
|
||||
"status": "ok",
|
||||
"articles": [
|
||||
{
|
||||
"title": "Apple stock surges",
|
||||
"description": "Strong earnings beat expectations",
|
||||
"source": {"name": "TechCrunch"},
|
||||
"publishedAt": "2026-02-04T10:00:00Z",
|
||||
"url": "https://example.com/1",
|
||||
},
|
||||
{
|
||||
"title": "Tech sector faces risks",
|
||||
"description": "Concerns over market downturn",
|
||||
"source": {"name": "CNBC"},
|
||||
"publishedAt": "2026-02-04T09:00:00Z",
|
||||
"url": "https://example.com/2",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
result = api._parse_newsapi_response("AAPL", mock_response)
|
||||
|
||||
assert result is not None
|
||||
assert result.stock_code == "AAPL"
|
||||
assert result.article_count == 2
|
||||
assert len(result.articles) == 2
|
||||
assert result.articles[0].title == "Apple stock surges"
|
||||
assert result.articles[0].source == "TechCrunch"
|
||||
|
||||
def test_estimate_sentiment_from_text_positive(self):
|
||||
"""Should detect positive sentiment from keywords."""
|
||||
api = NewsAPI()
|
||||
text = "Stock price surges with strong profit growth and upgrade"
|
||||
sentiment = api._estimate_sentiment_from_text(text)
|
||||
assert sentiment > 0.5
|
||||
|
||||
def test_estimate_sentiment_from_text_negative(self):
|
||||
"""Should detect negative sentiment from keywords."""
|
||||
api = NewsAPI()
|
||||
text = "Stock plunges on weak earnings, downgrade warning"
|
||||
sentiment = api._estimate_sentiment_from_text(text)
|
||||
assert sentiment < -0.5
|
||||
|
||||
def test_estimate_sentiment_from_text_neutral(self):
|
||||
"""Should return neutral sentiment without keywords."""
|
||||
api = NewsAPI()
|
||||
text = "Company announces quarterly report"
|
||||
sentiment = api._estimate_sentiment_from_text(text)
|
||||
assert abs(sentiment) < 0.1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EconomicCalendar Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEconomicCalendar:
|
||||
"""Test economic calendar functionality."""
|
||||
|
||||
def test_economic_calendar_init(self):
|
||||
"""EconomicCalendar should initialize correctly."""
|
||||
calendar = EconomicCalendar(api_key="test_key")
|
||||
assert calendar._api_key == "test_key"
|
||||
assert len(calendar._events) == 0
|
||||
|
||||
def test_add_event(self):
|
||||
"""Should be able to add events to calendar."""
|
||||
calendar = EconomicCalendar()
|
||||
event = EconomicEvent(
|
||||
name="FOMC Meeting",
|
||||
event_type="FOMC",
|
||||
datetime=datetime(2026, 3, 18),
|
||||
impact="HIGH",
|
||||
country="US",
|
||||
description="Interest rate decision",
|
||||
)
|
||||
calendar.add_event(event)
|
||||
assert len(calendar._events) == 1
|
||||
assert calendar._events[0].name == "FOMC Meeting"
|
||||
|
||||
def test_get_upcoming_events_filters_by_timeframe(self):
|
||||
"""Should only return events within specified timeframe."""
|
||||
calendar = EconomicCalendar()
|
||||
|
||||
# Add events at different times
|
||||
now = datetime.now()
|
||||
calendar.add_event(
|
||||
EconomicEvent(
|
||||
name="Event Tomorrow",
|
||||
event_type="GDP",
|
||||
datetime=now + timedelta(days=1),
|
||||
impact="HIGH",
|
||||
country="US",
|
||||
description="Test event",
|
||||
)
|
||||
)
|
||||
calendar.add_event(
|
||||
EconomicEvent(
|
||||
name="Event Next Month",
|
||||
event_type="CPI",
|
||||
datetime=now + timedelta(days=30),
|
||||
impact="HIGH",
|
||||
country="US",
|
||||
description="Test event",
|
||||
)
|
||||
)
|
||||
|
||||
# Get events for next 7 days
|
||||
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="HIGH")
|
||||
assert upcoming.high_impact_count == 1
|
||||
assert upcoming.events[0].name == "Event Tomorrow"
|
||||
|
||||
def test_get_upcoming_events_filters_by_impact(self):
|
||||
"""Should filter events by minimum impact level."""
|
||||
calendar = EconomicCalendar()
|
||||
|
||||
now = datetime.now()
|
||||
calendar.add_event(
|
||||
EconomicEvent(
|
||||
name="High Impact Event",
|
||||
event_type="FOMC",
|
||||
datetime=now + timedelta(days=1),
|
||||
impact="HIGH",
|
||||
country="US",
|
||||
description="Test",
|
||||
)
|
||||
)
|
||||
calendar.add_event(
|
||||
EconomicEvent(
|
||||
name="Low Impact Event",
|
||||
event_type="OTHER",
|
||||
datetime=now + timedelta(days=1),
|
||||
impact="LOW",
|
||||
country="US",
|
||||
description="Test",
|
||||
)
|
||||
)
|
||||
|
||||
# Filter for HIGH impact only
|
||||
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="HIGH")
|
||||
assert upcoming.high_impact_count == 1
|
||||
assert upcoming.events[0].name == "High Impact Event"
|
||||
|
||||
# Filter for MEDIUM and above (should still get HIGH)
|
||||
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="MEDIUM")
|
||||
assert len(upcoming.events) == 1
|
||||
|
||||
# Filter for LOW and above (should get both)
|
||||
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="LOW")
|
||||
assert len(upcoming.events) == 2
|
||||
|
||||
def test_get_earnings_date_returns_next_earnings(self):
|
||||
"""Should return next earnings date for a stock."""
|
||||
calendar = EconomicCalendar()
|
||||
|
||||
now = datetime.now()
|
||||
earnings_date = now + timedelta(days=5)
|
||||
|
||||
calendar.add_event(
|
||||
EconomicEvent(
|
||||
name="AAPL Earnings",
|
||||
event_type="EARNINGS",
|
||||
datetime=earnings_date,
|
||||
impact="HIGH",
|
||||
country="US",
|
||||
description="Apple quarterly earnings",
|
||||
)
|
||||
)
|
||||
|
||||
result = calendar.get_earnings_date("AAPL")
|
||||
assert result == earnings_date
|
||||
|
||||
def test_get_earnings_date_returns_none_if_not_found(self):
|
||||
"""Should return None if no earnings found for stock."""
|
||||
calendar = EconomicCalendar()
|
||||
result = calendar.get_earnings_date("UNKNOWN")
|
||||
assert result is None
|
||||
|
||||
def test_load_hardcoded_events(self):
|
||||
"""Should load hardcoded major economic events."""
|
||||
calendar = EconomicCalendar()
|
||||
calendar.load_hardcoded_events()
|
||||
|
||||
# Should have multiple events (FOMC, GDP, CPI)
|
||||
assert len(calendar._events) > 10
|
||||
|
||||
# Check for FOMC events
|
||||
fomc_events = [e for e in calendar._events if e.event_type == "FOMC"]
|
||||
assert len(fomc_events) > 0
|
||||
|
||||
# Check for GDP events
|
||||
gdp_events = [e for e in calendar._events if e.event_type == "GDP"]
|
||||
assert len(gdp_events) > 0
|
||||
|
||||
# Check for CPI events
|
||||
cpi_events = [e for e in calendar._events if e.event_type == "CPI"]
|
||||
assert len(cpi_events) == 12 # Monthly CPI releases
|
||||
|
||||
def test_is_high_volatility_period_returns_true_near_high_impact(self):
|
||||
"""Should return True if high-impact event is within threshold."""
|
||||
calendar = EconomicCalendar()
|
||||
|
||||
now = datetime.now()
|
||||
calendar.add_event(
|
||||
EconomicEvent(
|
||||
name="FOMC Meeting",
|
||||
event_type="FOMC",
|
||||
datetime=now + timedelta(hours=12),
|
||||
impact="HIGH",
|
||||
country="US",
|
||||
description="Test",
|
||||
)
|
||||
)
|
||||
|
||||
assert calendar.is_high_volatility_period(hours_ahead=24) is True
|
||||
|
||||
def test_is_high_volatility_period_returns_false_when_no_events(self):
|
||||
"""Should return False if no high-impact events nearby."""
|
||||
calendar = EconomicCalendar()
|
||||
assert calendar.is_high_volatility_period(hours_ahead=24) is False
|
||||
|
||||
def test_clear_events(self):
|
||||
"""Should clear all events."""
|
||||
calendar = EconomicCalendar()
|
||||
calendar.add_event(
|
||||
EconomicEvent(
|
||||
name="Test",
|
||||
event_type="TEST",
|
||||
datetime=datetime.now(),
|
||||
impact="LOW",
|
||||
country="US",
|
||||
description="Test",
|
||||
)
|
||||
)
|
||||
assert len(calendar._events) == 1
|
||||
|
||||
calendar.clear_events()
|
||||
assert len(calendar._events) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MarketData Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMarketData:
|
||||
"""Test market data indicators."""
|
||||
|
||||
def test_market_data_init(self):
|
||||
"""MarketData should initialize correctly."""
|
||||
data = MarketData(api_key="test_key")
|
||||
assert data._api_key == "test_key"
|
||||
|
||||
def test_get_market_sentiment_without_api_key_returns_neutral(self):
|
||||
"""Without API key, should return NEUTRAL sentiment."""
|
||||
data = MarketData(api_key=None)
|
||||
sentiment = data.get_market_sentiment()
|
||||
assert sentiment == MarketSentiment.NEUTRAL
|
||||
|
||||
def test_get_market_breadth_without_api_key_returns_none(self):
|
||||
"""Without API key, should return None for breadth."""
|
||||
data = MarketData(api_key=None)
|
||||
breadth = data.get_market_breadth()
|
||||
assert breadth is None
|
||||
|
||||
def test_get_sector_performance_without_api_key_returns_empty(self):
|
||||
"""Without API key, should return empty list."""
|
||||
data = MarketData(api_key=None)
|
||||
sectors = data.get_sector_performance()
|
||||
assert sectors == []
|
||||
|
||||
def test_get_market_indicators_returns_defaults_without_api(self):
|
||||
"""Should return default indicators without API key."""
|
||||
data = MarketData(api_key=None)
|
||||
indicators = data.get_market_indicators()
|
||||
|
||||
assert indicators.sentiment == MarketSentiment.NEUTRAL
|
||||
assert indicators.breadth.advance_decline_ratio == 1.0
|
||||
assert indicators.sector_performance == []
|
||||
assert indicators.vix_level is None
|
||||
|
||||
def test_calculate_fear_greed_score_neutral_baseline(self):
|
||||
"""Should return neutral score (50) for balanced market."""
|
||||
data = MarketData()
|
||||
breadth = MarketBreadth(
|
||||
advancing_stocks=500,
|
||||
declining_stocks=500,
|
||||
unchanged_stocks=100,
|
||||
new_highs=50,
|
||||
new_lows=50,
|
||||
advance_decline_ratio=1.0,
|
||||
)
|
||||
|
||||
score = data.calculate_fear_greed_score(breadth)
|
||||
assert score == 50
|
||||
|
||||
def test_calculate_fear_greed_score_greedy_market(self):
|
||||
"""Should return high score for greedy market conditions."""
|
||||
data = MarketData()
|
||||
breadth = MarketBreadth(
|
||||
advancing_stocks=800,
|
||||
declining_stocks=200,
|
||||
unchanged_stocks=100,
|
||||
new_highs=100,
|
||||
new_lows=10,
|
||||
advance_decline_ratio=4.0,
|
||||
)
|
||||
|
||||
score = data.calculate_fear_greed_score(breadth, vix=12.0)
|
||||
assert score > 70
|
||||
|
||||
def test_calculate_fear_greed_score_fearful_market(self):
|
||||
"""Should return low score for fearful market conditions."""
|
||||
data = MarketData()
|
||||
breadth = MarketBreadth(
|
||||
advancing_stocks=200,
|
||||
declining_stocks=800,
|
||||
unchanged_stocks=100,
|
||||
new_highs=10,
|
||||
new_lows=100,
|
||||
advance_decline_ratio=0.25,
|
||||
)
|
||||
|
||||
score = data.calculate_fear_greed_score(breadth, vix=35.0)
|
||||
assert score < 30
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GeminiClient Integration Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGeminiClientWithExternalData:
|
||||
"""Test GeminiClient integration with external data sources."""
|
||||
|
||||
def test_gemini_client_accepts_optional_data_sources(self, settings):
|
||||
"""GeminiClient should accept optional external data sources."""
|
||||
news_api = NewsAPI(api_key="test_key")
|
||||
calendar = EconomicCalendar()
|
||||
market_data = MarketData()
|
||||
|
||||
client = GeminiClient(
|
||||
settings,
|
||||
news_api=news_api,
|
||||
economic_calendar=calendar,
|
||||
market_data=market_data,
|
||||
)
|
||||
|
||||
assert client._news_api is news_api
|
||||
assert client._economic_calendar is calendar
|
||||
assert client._market_data is market_data
|
||||
|
||||
def test_gemini_client_works_without_external_data(self, settings):
|
||||
"""GeminiClient should work without external data sources."""
|
||||
client = GeminiClient(settings)
|
||||
assert client._news_api is None
|
||||
assert client._economic_calendar is None
|
||||
assert client._market_data is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_prompt_includes_news_sentiment(self, settings):
|
||||
"""build_prompt should include news sentiment when available."""
|
||||
client = GeminiClient(settings)
|
||||
|
||||
market_data = {
|
||||
"stock_code": "AAPL",
|
||||
"current_price": 180.0,
|
||||
"market_name": "US stock market",
|
||||
}
|
||||
|
||||
sentiment = NewsSentiment(
|
||||
stock_code="AAPL",
|
||||
articles=[
|
||||
NewsArticle(
|
||||
title="Apple hits record high",
|
||||
summary="Strong earnings",
|
||||
source="Reuters",
|
||||
published_at="2026-02-04",
|
||||
sentiment_score=0.85,
|
||||
url="https://example.com",
|
||||
)
|
||||
],
|
||||
avg_sentiment=0.85,
|
||||
article_count=1,
|
||||
fetched_at=time.time(),
|
||||
)
|
||||
|
||||
prompt = await client.build_prompt(market_data, news_sentiment=sentiment)
|
||||
|
||||
assert "AAPL" in prompt
|
||||
assert "180.0" in prompt
|
||||
assert "EXTERNAL DATA" in prompt
|
||||
assert "News Sentiment" in prompt
|
||||
assert "0.85" in prompt
|
||||
assert "Apple hits record high" in prompt
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_prompt_with_economic_events(self, settings):
|
||||
"""build_prompt should include upcoming economic events."""
|
||||
calendar = EconomicCalendar()
|
||||
now = datetime.now()
|
||||
calendar.add_event(
|
||||
EconomicEvent(
|
||||
name="FOMC Meeting",
|
||||
event_type="FOMC",
|
||||
datetime=now + timedelta(days=2),
|
||||
impact="HIGH",
|
||||
country="US",
|
||||
description="Interest rate decision",
|
||||
)
|
||||
)
|
||||
|
||||
client = GeminiClient(settings, economic_calendar=calendar)
|
||||
|
||||
market_data = {
|
||||
"stock_code": "AAPL",
|
||||
"current_price": 180.0,
|
||||
"market_name": "US stock market",
|
||||
}
|
||||
|
||||
prompt = await client.build_prompt(market_data)
|
||||
|
||||
assert "EXTERNAL DATA" in prompt
|
||||
assert "High-Impact Events" in prompt
|
||||
assert "FOMC Meeting" in prompt
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_prompt_with_market_indicators(self, settings):
|
||||
"""build_prompt should include market sentiment indicators."""
|
||||
market_data_provider = MarketData(api_key="test_key")
|
||||
|
||||
# Mock the get_market_indicators to return test data
|
||||
with patch.object(market_data_provider, "get_market_indicators") as mock:
|
||||
mock.return_value = MagicMock(
|
||||
sentiment=MarketSentiment.EXTREME_GREED,
|
||||
breadth=MagicMock(advance_decline_ratio=2.5),
|
||||
)
|
||||
|
||||
client = GeminiClient(settings, market_data=market_data_provider)
|
||||
|
||||
market_data = {
|
||||
"stock_code": "AAPL",
|
||||
"current_price": 180.0,
|
||||
"market_name": "US stock market",
|
||||
}
|
||||
|
||||
prompt = await client.build_prompt(market_data)
|
||||
|
||||
assert "EXTERNAL DATA" in prompt
|
||||
assert "Market Sentiment" in prompt
|
||||
assert "EXTREME_GREED" in prompt
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_prompt_graceful_when_no_external_data(self, settings):
|
||||
"""build_prompt should work gracefully without external data."""
|
||||
client = GeminiClient(settings)
|
||||
|
||||
market_data = {
|
||||
"stock_code": "AAPL",
|
||||
"current_price": 180.0,
|
||||
"market_name": "US stock market",
|
||||
}
|
||||
|
||||
prompt = await client.build_prompt(market_data)
|
||||
|
||||
assert "AAPL" in prompt
|
||||
assert "180.0" in prompt
|
||||
# Should NOT have external data section
|
||||
assert "EXTERNAL DATA" not in prompt
|
||||
|
||||
def test_build_prompt_sync_backward_compatibility(self, settings):
|
||||
"""build_prompt_sync should maintain backward compatibility."""
|
||||
client = GeminiClient(settings)
|
||||
|
||||
market_data = {
|
||||
"stock_code": "005930",
|
||||
"current_price": 72000,
|
||||
"orderbook": {"asks": [], "bids": []},
|
||||
"foreigner_net": -50000,
|
||||
}
|
||||
|
||||
prompt = client.build_prompt_sync(market_data)
|
||||
|
||||
assert "005930" in prompt
|
||||
assert "72000" in prompt
|
||||
assert "JSON" in prompt
|
||||
# Sync version should NOT have external data
|
||||
assert "EXTERNAL DATA" not in prompt
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decide_with_news_sentiment_parameter(self, settings):
|
||||
"""decide should accept optional news_sentiment parameter."""
|
||||
client = GeminiClient(settings)
|
||||
|
||||
market_data = {
|
||||
"stock_code": "AAPL",
|
||||
"current_price": 180.0,
|
||||
"market_name": "US stock market",
|
||||
}
|
||||
|
||||
sentiment = NewsSentiment(
|
||||
stock_code="AAPL",
|
||||
articles=[],
|
||||
avg_sentiment=0.5,
|
||||
article_count=1,
|
||||
fetched_at=time.time(),
|
||||
)
|
||||
|
||||
# Mock the Gemini API call
|
||||
with patch.object(client._client.aio.models, "generate_content", new_callable=AsyncMock) as mock_gen:
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = '{"action": "BUY", "confidence": 85, "rationale": "Good news"}'
|
||||
mock_gen.return_value = mock_response
|
||||
|
||||
decision = await client.decide(market_data, news_sentiment=sentiment)
|
||||
|
||||
assert decision.action == "BUY"
|
||||
assert decision.confidence == 85
|
||||
mock_gen.assert_called_once()
|
||||
195
tests/test_db.py
Normal file
195
tests/test_db.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""Tests for database helper functions."""
|
||||
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
from src.db import get_open_position, init_db, log_trade
|
||||
|
||||
|
||||
def test_get_open_position_returns_latest_buy() -> None:
|
||||
conn = init_db(":memory:")
|
||||
log_trade(
|
||||
conn=conn,
|
||||
stock_code="005930",
|
||||
action="BUY",
|
||||
confidence=90,
|
||||
rationale="entry",
|
||||
quantity=2,
|
||||
price=70000.0,
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
decision_id="d-buy-1",
|
||||
)
|
||||
|
||||
position = get_open_position(conn, "005930", "KR")
|
||||
assert position is not None
|
||||
assert position["decision_id"] == "d-buy-1"
|
||||
assert position["price"] == 70000.0
|
||||
assert position["quantity"] == 2
|
||||
|
||||
|
||||
def test_get_open_position_returns_none_when_latest_is_sell() -> None:
|
||||
conn = init_db(":memory:")
|
||||
log_trade(
|
||||
conn=conn,
|
||||
stock_code="005930",
|
||||
action="BUY",
|
||||
confidence=90,
|
||||
rationale="entry",
|
||||
quantity=1,
|
||||
price=70000.0,
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
decision_id="d-buy-1",
|
||||
)
|
||||
log_trade(
|
||||
conn=conn,
|
||||
stock_code="005930",
|
||||
action="SELL",
|
||||
confidence=95,
|
||||
rationale="exit",
|
||||
quantity=1,
|
||||
price=71000.0,
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
decision_id="d-sell-1",
|
||||
)
|
||||
|
||||
assert get_open_position(conn, "005930", "KR") is None
|
||||
|
||||
|
||||
def test_get_open_position_returns_none_when_no_trades() -> None:
|
||||
conn = init_db(":memory:")
|
||||
assert get_open_position(conn, "AAPL", "US_NASDAQ") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WAL mode tests (issue #210)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_wal_mode_applied_to_file_db() -> None:
|
||||
"""File-based DB must use WAL journal mode for dashboard concurrent reads."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
db_path = f.name
|
||||
try:
|
||||
conn = init_db(db_path)
|
||||
cursor = conn.execute("PRAGMA journal_mode")
|
||||
mode = cursor.fetchone()[0]
|
||||
assert mode == "wal", f"Expected WAL mode, got {mode}"
|
||||
conn.close()
|
||||
finally:
|
||||
os.unlink(db_path)
|
||||
# Clean up WAL auxiliary files if they exist
|
||||
for ext in ("-wal", "-shm"):
|
||||
path = db_path + ext
|
||||
if os.path.exists(path):
|
||||
os.unlink(path)
|
||||
|
||||
|
||||
def test_wal_mode_not_applied_to_memory_db() -> None:
|
||||
""":memory: DB must not apply WAL (SQLite does not support WAL for in-memory)."""
|
||||
conn = init_db(":memory:")
|
||||
cursor = conn.execute("PRAGMA journal_mode")
|
||||
mode = cursor.fetchone()[0]
|
||||
# In-memory DBs default to 'memory' journal mode
|
||||
assert mode != "wal", "WAL should not be set on in-memory database"
|
||||
conn.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# mode column tests (issue #212)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_log_trade_stores_mode_paper() -> None:
|
||||
"""log_trade must persist mode='paper' in the trades table."""
|
||||
conn = init_db(":memory:")
|
||||
log_trade(
|
||||
conn=conn,
|
||||
stock_code="005930",
|
||||
action="BUY",
|
||||
confidence=85,
|
||||
rationale="test",
|
||||
mode="paper",
|
||||
)
|
||||
row = conn.execute("SELECT mode FROM trades ORDER BY id DESC LIMIT 1").fetchone()
|
||||
assert row is not None
|
||||
assert row[0] == "paper"
|
||||
|
||||
|
||||
def test_log_trade_stores_mode_live() -> None:
|
||||
"""log_trade must persist mode='live' in the trades table."""
|
||||
conn = init_db(":memory:")
|
||||
log_trade(
|
||||
conn=conn,
|
||||
stock_code="005930",
|
||||
action="BUY",
|
||||
confidence=85,
|
||||
rationale="test",
|
||||
mode="live",
|
||||
)
|
||||
row = conn.execute("SELECT mode FROM trades ORDER BY id DESC LIMIT 1").fetchone()
|
||||
assert row is not None
|
||||
assert row[0] == "live"
|
||||
|
||||
|
||||
def test_log_trade_default_mode_is_paper() -> None:
|
||||
"""log_trade without explicit mode must default to 'paper'."""
|
||||
conn = init_db(":memory:")
|
||||
log_trade(
|
||||
conn=conn,
|
||||
stock_code="005930",
|
||||
action="HOLD",
|
||||
confidence=50,
|
||||
rationale="test",
|
||||
)
|
||||
row = conn.execute("SELECT mode FROM trades ORDER BY id DESC LIMIT 1").fetchone()
|
||||
assert row is not None
|
||||
assert row[0] == "paper"
|
||||
|
||||
|
||||
def test_mode_column_exists_in_schema() -> None:
|
||||
"""trades table must have a mode column after init_db."""
|
||||
conn = init_db(":memory:")
|
||||
cursor = conn.execute("PRAGMA table_info(trades)")
|
||||
columns = {row[1] for row in cursor.fetchall()}
|
||||
assert "mode" in columns
|
||||
|
||||
|
||||
def test_mode_migration_adds_column_to_existing_db() -> None:
|
||||
"""init_db must add mode column to existing DBs that lack it (migration)."""
|
||||
import sqlite3
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
db_path = f.name
|
||||
try:
|
||||
# Create DB without mode column (simulate old schema)
|
||||
old_conn = sqlite3.connect(db_path)
|
||||
old_conn.execute(
|
||||
"""CREATE TABLE trades (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
timestamp TEXT NOT NULL,
|
||||
stock_code TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
confidence INTEGER NOT NULL,
|
||||
rationale TEXT,
|
||||
quantity INTEGER,
|
||||
price REAL,
|
||||
pnl REAL DEFAULT 0.0,
|
||||
market TEXT DEFAULT 'KR',
|
||||
exchange_code TEXT DEFAULT 'KRX',
|
||||
decision_id TEXT
|
||||
)"""
|
||||
)
|
||||
old_conn.commit()
|
||||
old_conn.close()
|
||||
|
||||
# Run init_db — should add mode column via migration
|
||||
conn = init_db(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(trades)")
|
||||
columns = {row[1] for row in cursor.fetchall()}
|
||||
assert "mode" in columns
|
||||
conn.close()
|
||||
finally:
|
||||
os.unlink(db_path)
|
||||
685
tests/test_evolution.py
Normal file
685
tests/test_evolution.py
Normal file
@@ -0,0 +1,685 @@
|
||||
"""Tests for the Evolution Engine components.
|
||||
|
||||
Tests cover:
|
||||
- EvolutionOptimizer: failure analysis and strategy generation
|
||||
- ABTester: A/B testing and statistical comparison
|
||||
- PerformanceTracker: metrics tracking and dashboard
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import tempfile
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.config import Settings
|
||||
from src.db import init_db, log_trade
|
||||
from src.evolution.ab_test import ABTester
|
||||
from src.evolution.optimizer import EvolutionOptimizer
|
||||
from src.evolution.performance_tracker import (
|
||||
PerformanceDashboard,
|
||||
PerformanceTracker,
|
||||
StrategyMetrics,
|
||||
)
|
||||
from src.logging.decision_logger import DecisionLogger
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_conn() -> sqlite3.Connection:
|
||||
"""Provide an in-memory database with initialized schema."""
|
||||
return init_db(":memory:")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def settings() -> Settings:
|
||||
"""Provide test settings."""
|
||||
return Settings(
|
||||
KIS_APP_KEY="test_key",
|
||||
KIS_APP_SECRET="test_secret",
|
||||
KIS_ACCOUNT_NO="12345678-01",
|
||||
GEMINI_API_KEY="test_gemini_key",
|
||||
GEMINI_MODEL="gemini-pro",
|
||||
DB_PATH=":memory:",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def optimizer(settings: Settings) -> EvolutionOptimizer:
|
||||
"""Provide an EvolutionOptimizer instance."""
|
||||
return EvolutionOptimizer(settings)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def decision_logger(db_conn: sqlite3.Connection) -> DecisionLogger:
|
||||
"""Provide a DecisionLogger instance."""
|
||||
return DecisionLogger(db_conn)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ab_tester() -> ABTester:
|
||||
"""Provide an ABTester instance."""
|
||||
return ABTester(significance_level=0.05)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def performance_tracker(settings: Settings) -> PerformanceTracker:
|
||||
"""Provide a PerformanceTracker instance."""
|
||||
return PerformanceTracker(db_path=":memory:")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# EvolutionOptimizer Tests
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_analyze_failures_uses_decision_logger(optimizer: EvolutionOptimizer) -> None:
|
||||
"""Test that analyze_failures uses DecisionLogger.get_losing_decisions()."""
|
||||
# Add some losing decisions to the database
|
||||
logger = optimizer._decision_logger
|
||||
|
||||
# High-confidence loss
|
||||
id1 = logger.log_decision(
|
||||
stock_code="005930",
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
action="BUY",
|
||||
confidence=85,
|
||||
rationale="Expected growth",
|
||||
context_snapshot={"L1": {"price": 70000}},
|
||||
input_data={"price": 70000, "volume": 1000},
|
||||
)
|
||||
logger.update_outcome(id1, pnl=-2000.0, accuracy=0)
|
||||
|
||||
# Another high-confidence loss
|
||||
id2 = logger.log_decision(
|
||||
stock_code="000660",
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
action="SELL",
|
||||
confidence=90,
|
||||
rationale="Expected drop",
|
||||
context_snapshot={"L1": {"price": 100000}},
|
||||
input_data={"price": 100000, "volume": 500},
|
||||
)
|
||||
logger.update_outcome(id2, pnl=-1500.0, accuracy=0)
|
||||
|
||||
# Low-confidence loss (should be ignored)
|
||||
id3 = logger.log_decision(
|
||||
stock_code="035420",
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
action="HOLD",
|
||||
confidence=70,
|
||||
rationale="Uncertain",
|
||||
context_snapshot={},
|
||||
input_data={},
|
||||
)
|
||||
logger.update_outcome(id3, pnl=-500.0, accuracy=0)
|
||||
|
||||
# Analyze failures
|
||||
failures = optimizer.analyze_failures(limit=10)
|
||||
|
||||
# Should get 2 failures (confidence >= 80)
|
||||
assert len(failures) == 2
|
||||
assert all(f["confidence"] >= 80 for f in failures)
|
||||
assert all(f["outcome_pnl"] <= -100.0 for f in failures)
|
||||
|
||||
|
||||
def test_analyze_failures_empty_database(optimizer: EvolutionOptimizer) -> None:
|
||||
"""Test analyze_failures with no losing decisions."""
|
||||
failures = optimizer.analyze_failures()
|
||||
assert failures == []
|
||||
|
||||
|
||||
def test_identify_failure_patterns(optimizer: EvolutionOptimizer) -> None:
|
||||
"""Test identification of failure patterns."""
|
||||
failures = [
|
||||
{
|
||||
"decision_id": "1",
|
||||
"timestamp": "2024-01-15T09:30:00+00:00",
|
||||
"stock_code": "005930",
|
||||
"market": "KR",
|
||||
"exchange_code": "KRX",
|
||||
"action": "BUY",
|
||||
"confidence": 85,
|
||||
"rationale": "Test",
|
||||
"outcome_pnl": -1000.0,
|
||||
"outcome_accuracy": 0,
|
||||
"context_snapshot": {},
|
||||
"input_data": {},
|
||||
},
|
||||
{
|
||||
"decision_id": "2",
|
||||
"timestamp": "2024-01-15T14:30:00+00:00",
|
||||
"stock_code": "000660",
|
||||
"market": "KR",
|
||||
"exchange_code": "KRX",
|
||||
"action": "SELL",
|
||||
"confidence": 90,
|
||||
"rationale": "Test",
|
||||
"outcome_pnl": -2000.0,
|
||||
"outcome_accuracy": 0,
|
||||
"context_snapshot": {},
|
||||
"input_data": {},
|
||||
},
|
||||
{
|
||||
"decision_id": "3",
|
||||
"timestamp": "2024-01-15T09:45:00+00:00",
|
||||
"stock_code": "035420",
|
||||
"market": "US_NASDAQ",
|
||||
"exchange_code": "NASDAQ",
|
||||
"action": "BUY",
|
||||
"confidence": 80,
|
||||
"rationale": "Test",
|
||||
"outcome_pnl": -500.0,
|
||||
"outcome_accuracy": 0,
|
||||
"context_snapshot": {},
|
||||
"input_data": {},
|
||||
},
|
||||
]
|
||||
|
||||
patterns = optimizer.identify_failure_patterns(failures)
|
||||
|
||||
assert patterns["total_failures"] == 3
|
||||
assert patterns["markets"]["KR"] == 2
|
||||
assert patterns["markets"]["US_NASDAQ"] == 1
|
||||
assert patterns["actions"]["BUY"] == 2
|
||||
assert patterns["actions"]["SELL"] == 1
|
||||
assert 9 in patterns["hours"] # 09:30 and 09:45
|
||||
assert 14 in patterns["hours"] # 14:30
|
||||
assert patterns["avg_confidence"] == 85.0
|
||||
assert patterns["avg_loss"] == -1166.67
|
||||
|
||||
|
||||
def test_identify_failure_patterns_empty(optimizer: EvolutionOptimizer) -> None:
|
||||
"""Test pattern identification with no failures."""
|
||||
patterns = optimizer.identify_failure_patterns([])
|
||||
assert patterns["pattern_count"] == 0
|
||||
assert patterns["patterns"] == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_strategy_creates_file(optimizer: EvolutionOptimizer, tmp_path: Path) -> None:
|
||||
"""Test that generate_strategy creates a strategy file."""
|
||||
failures = [
|
||||
{
|
||||
"decision_id": "1",
|
||||
"timestamp": "2024-01-15T09:30:00+00:00",
|
||||
"stock_code": "005930",
|
||||
"market": "KR",
|
||||
"action": "BUY",
|
||||
"confidence": 85,
|
||||
"outcome_pnl": -1000.0,
|
||||
"context_snapshot": {},
|
||||
"input_data": {},
|
||||
}
|
||||
]
|
||||
|
||||
# Mock Gemini response
|
||||
mock_response = Mock()
|
||||
mock_response.text = """
|
||||
# Simple strategy
|
||||
price = market_data.get("current_price", 0)
|
||||
if price > 50000:
|
||||
return {"action": "BUY", "confidence": 70, "rationale": "Price above threshold"}
|
||||
return {"action": "HOLD", "confidence": 50, "rationale": "Waiting"}
|
||||
"""
|
||||
|
||||
with patch.object(optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response)):
|
||||
with patch("src.evolution.optimizer.STRATEGIES_DIR", tmp_path):
|
||||
strategy_path = await optimizer.generate_strategy(failures)
|
||||
|
||||
assert strategy_path is not None
|
||||
assert strategy_path.exists()
|
||||
assert strategy_path.suffix == ".py"
|
||||
assert "class Strategy_" in strategy_path.read_text()
|
||||
assert "def evaluate" in strategy_path.read_text()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_strategy_handles_api_error(optimizer: EvolutionOptimizer) -> None:
|
||||
"""Test that generate_strategy handles Gemini API errors gracefully."""
|
||||
failures = [{"decision_id": "1", "timestamp": "2024-01-15T09:30:00+00:00"}]
|
||||
|
||||
with patch.object(
|
||||
optimizer._client.aio.models,
|
||||
"generate_content",
|
||||
side_effect=Exception("API Error"),
|
||||
):
|
||||
strategy_path = await optimizer.generate_strategy(failures)
|
||||
|
||||
assert strategy_path is None
|
||||
|
||||
|
||||
def test_get_performance_summary() -> None:
|
||||
"""Test getting performance summary from trades table."""
|
||||
# Create a temporary database with trades
|
||||
import tempfile
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
|
||||
tmp_path = tmp.name
|
||||
|
||||
conn = init_db(tmp_path)
|
||||
log_trade(conn, "005930", "BUY", 85, "Test win", quantity=10, price=70000, pnl=1000.0)
|
||||
log_trade(conn, "000660", "SELL", 90, "Test loss", quantity=5, price=100000, pnl=-500.0)
|
||||
log_trade(conn, "035420", "BUY", 80, "Test win", quantity=8, price=50000, pnl=800.0)
|
||||
conn.close()
|
||||
|
||||
# Create settings with temp database path
|
||||
settings = Settings(
|
||||
KIS_APP_KEY="test_key",
|
||||
KIS_APP_SECRET="test_secret",
|
||||
KIS_ACCOUNT_NO="12345678-01",
|
||||
GEMINI_API_KEY="test_gemini_key",
|
||||
GEMINI_MODEL="gemini-pro",
|
||||
DB_PATH=tmp_path,
|
||||
)
|
||||
|
||||
optimizer = EvolutionOptimizer(settings)
|
||||
summary = optimizer.get_performance_summary()
|
||||
|
||||
assert summary["total_trades"] == 3
|
||||
assert summary["wins"] == 2
|
||||
assert summary["losses"] == 1
|
||||
assert summary["total_pnl"] == 1300.0
|
||||
assert summary["avg_pnl"] == 433.33
|
||||
|
||||
# Clean up
|
||||
Path(tmp_path).unlink()
|
||||
|
||||
|
||||
def test_validate_strategy_success(optimizer: EvolutionOptimizer, tmp_path: Path) -> None:
|
||||
"""Test strategy validation when tests pass."""
|
||||
strategy_file = tmp_path / "test_strategy.py"
|
||||
strategy_file.write_text("# Valid strategy file")
|
||||
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = Mock(returncode=0, stdout="", stderr="")
|
||||
result = optimizer.validate_strategy(strategy_file)
|
||||
|
||||
assert result is True
|
||||
assert strategy_file.exists()
|
||||
|
||||
|
||||
def test_validate_strategy_failure(optimizer: EvolutionOptimizer, tmp_path: Path) -> None:
|
||||
"""Test strategy validation when tests fail."""
|
||||
strategy_file = tmp_path / "test_strategy.py"
|
||||
strategy_file.write_text("# Invalid strategy file")
|
||||
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = Mock(returncode=1, stdout="FAILED", stderr="")
|
||||
result = optimizer.validate_strategy(strategy_file)
|
||||
|
||||
assert result is False
|
||||
# File should be deleted on failure
|
||||
assert not strategy_file.exists()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# ABTester Tests
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_calculate_performance_basic(ab_tester: ABTester) -> None:
|
||||
"""Test basic performance calculation."""
|
||||
trades = [
|
||||
{"pnl": 1000.0},
|
||||
{"pnl": -500.0},
|
||||
{"pnl": 800.0},
|
||||
{"pnl": 200.0},
|
||||
]
|
||||
|
||||
perf = ab_tester.calculate_performance(trades, "TestStrategy")
|
||||
|
||||
assert perf.strategy_name == "TestStrategy"
|
||||
assert perf.total_trades == 4
|
||||
assert perf.wins == 3
|
||||
assert perf.losses == 1
|
||||
assert perf.total_pnl == 1500.0
|
||||
assert perf.avg_pnl == 375.0
|
||||
assert perf.win_rate == 75.0
|
||||
assert perf.sharpe_ratio is not None
|
||||
|
||||
|
||||
def test_calculate_performance_empty(ab_tester: ABTester) -> None:
|
||||
"""Test performance calculation with no trades."""
|
||||
perf = ab_tester.calculate_performance([], "EmptyStrategy")
|
||||
|
||||
assert perf.total_trades == 0
|
||||
assert perf.wins == 0
|
||||
assert perf.losses == 0
|
||||
assert perf.total_pnl == 0.0
|
||||
assert perf.avg_pnl == 0.0
|
||||
assert perf.win_rate == 0.0
|
||||
assert perf.sharpe_ratio is None
|
||||
|
||||
|
||||
def test_compare_strategies_significant_difference(ab_tester: ABTester) -> None:
|
||||
"""Test strategy comparison with significant performance difference."""
|
||||
# Strategy A: consistently profitable
|
||||
trades_a = [{"pnl": 1000.0} for _ in range(30)]
|
||||
|
||||
# Strategy B: consistently losing
|
||||
trades_b = [{"pnl": -500.0} for _ in range(30)]
|
||||
|
||||
result = ab_tester.compare_strategies(trades_a, trades_b, "Strategy A", "Strategy B")
|
||||
|
||||
# scipy returns np.True_ instead of Python bool
|
||||
assert bool(result.is_significant) is True
|
||||
assert result.winner == "Strategy A"
|
||||
assert result.p_value < 0.05
|
||||
assert result.performance_a.avg_pnl > result.performance_b.avg_pnl
|
||||
|
||||
|
||||
def test_compare_strategies_no_difference(ab_tester: ABTester) -> None:
|
||||
"""Test strategy comparison with no significant difference."""
|
||||
# Both strategies have similar performance
|
||||
trades_a = [{"pnl": 100.0}, {"pnl": -50.0}, {"pnl": 80.0}]
|
||||
trades_b = [{"pnl": 90.0}, {"pnl": -60.0}, {"pnl": 85.0}]
|
||||
|
||||
result = ab_tester.compare_strategies(trades_a, trades_b, "Strategy A", "Strategy B")
|
||||
|
||||
# With small samples and similar performance, likely not significant
|
||||
assert result.winner is None or not result.is_significant
|
||||
|
||||
|
||||
def test_should_deploy_meets_criteria(ab_tester: ABTester) -> None:
|
||||
"""Test deployment decision when criteria are met."""
|
||||
# Create a winning result that meets criteria
|
||||
trades_a = [{"pnl": 1000.0} for _ in range(25)] # 100% win rate
|
||||
trades_b = [{"pnl": -500.0} for _ in range(25)]
|
||||
|
||||
result = ab_tester.compare_strategies(trades_a, trades_b, "Winner", "Loser")
|
||||
|
||||
should_deploy = ab_tester.should_deploy(result, min_win_rate=60.0, min_trades=20)
|
||||
|
||||
assert should_deploy is True
|
||||
|
||||
|
||||
def test_should_deploy_insufficient_trades(ab_tester: ABTester) -> None:
|
||||
"""Test deployment decision with insufficient trades."""
|
||||
trades_a = [{"pnl": 1000.0} for _ in range(10)] # Only 10 trades
|
||||
trades_b = [{"pnl": -500.0} for _ in range(10)]
|
||||
|
||||
result = ab_tester.compare_strategies(trades_a, trades_b, "Winner", "Loser")
|
||||
|
||||
should_deploy = ab_tester.should_deploy(result, min_win_rate=60.0, min_trades=20)
|
||||
|
||||
assert should_deploy is False
|
||||
|
||||
|
||||
def test_should_deploy_low_win_rate(ab_tester: ABTester) -> None:
|
||||
"""Test deployment decision with low win rate."""
|
||||
# Mix of wins and losses, below 60% win rate
|
||||
trades_a = [{"pnl": 100.0}] * 10 + [{"pnl": -100.0}] * 15 # 40% win rate
|
||||
trades_b = [{"pnl": -500.0} for _ in range(25)]
|
||||
|
||||
result = ab_tester.compare_strategies(trades_a, trades_b, "LowWinner", "Loser")
|
||||
|
||||
should_deploy = ab_tester.should_deploy(result, min_win_rate=60.0, min_trades=20)
|
||||
|
||||
assert should_deploy is False
|
||||
|
||||
|
||||
def test_should_deploy_not_significant(ab_tester: ABTester) -> None:
|
||||
"""Test deployment decision when difference is not significant."""
|
||||
# Use more varied data to ensure statistical insignificance
|
||||
trades_a = [{"pnl": 100.0}, {"pnl": -50.0}] * 12 + [{"pnl": 100.0}]
|
||||
trades_b = [{"pnl": 95.0}, {"pnl": -45.0}] * 12 + [{"pnl": 95.0}]
|
||||
|
||||
result = ab_tester.compare_strategies(trades_a, trades_b, "A", "B")
|
||||
|
||||
should_deploy = ab_tester.should_deploy(result, min_win_rate=60.0, min_trades=20)
|
||||
|
||||
# Not significant or not profitable enough
|
||||
# Even if significant, win rate is 50% which is below 60% threshold
|
||||
assert should_deploy is False
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# PerformanceTracker Tests
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_strategy_metrics(db_conn: sqlite3.Connection) -> None:
|
||||
"""Test getting strategy metrics."""
|
||||
# Add some trades
|
||||
log_trade(db_conn, "005930", "BUY", 85, "Win 1", quantity=10, price=70000, pnl=1000.0)
|
||||
log_trade(db_conn, "000660", "SELL", 90, "Loss 1", quantity=5, price=100000, pnl=-500.0)
|
||||
log_trade(db_conn, "035420", "BUY", 80, "Win 2", quantity=8, price=50000, pnl=800.0)
|
||||
log_trade(db_conn, "005930", "HOLD", 75, "Hold", quantity=0, price=70000, pnl=0.0)
|
||||
|
||||
tracker = PerformanceTracker(db_path=":memory:")
|
||||
# Manually set connection for testing
|
||||
tracker._db_path = db_conn
|
||||
|
||||
# Need to use the same connection
|
||||
with patch("sqlite3.connect", return_value=db_conn):
|
||||
metrics = tracker.get_strategy_metrics()
|
||||
|
||||
assert metrics.total_trades == 4
|
||||
assert metrics.wins == 2
|
||||
assert metrics.losses == 1
|
||||
assert metrics.holds == 1
|
||||
assert metrics.win_rate == 50.0
|
||||
assert metrics.total_pnl == 1300.0
|
||||
|
||||
|
||||
def test_calculate_improvement_trend_improving(performance_tracker: PerformanceTracker) -> None:
|
||||
"""Test improvement trend calculation for improving strategy."""
|
||||
metrics = [
|
||||
StrategyMetrics(
|
||||
strategy_name="test",
|
||||
period_start="2024-01-01",
|
||||
period_end="2024-01-07",
|
||||
total_trades=10,
|
||||
wins=5,
|
||||
losses=5,
|
||||
holds=0,
|
||||
win_rate=50.0,
|
||||
avg_pnl=100.0,
|
||||
total_pnl=1000.0,
|
||||
best_trade=500.0,
|
||||
worst_trade=-300.0,
|
||||
avg_confidence=75.0,
|
||||
),
|
||||
StrategyMetrics(
|
||||
strategy_name="test",
|
||||
period_start="2024-01-08",
|
||||
period_end="2024-01-14",
|
||||
total_trades=10,
|
||||
wins=7,
|
||||
losses=3,
|
||||
holds=0,
|
||||
win_rate=70.0,
|
||||
avg_pnl=200.0,
|
||||
total_pnl=2000.0,
|
||||
best_trade=600.0,
|
||||
worst_trade=-200.0,
|
||||
avg_confidence=80.0,
|
||||
),
|
||||
]
|
||||
|
||||
trend = performance_tracker.calculate_improvement_trend(metrics)
|
||||
|
||||
assert trend["trend"] == "improving"
|
||||
assert trend["win_rate_change"] == 20.0
|
||||
assert trend["pnl_change"] == 100.0
|
||||
assert trend["confidence_change"] == 5.0
|
||||
|
||||
|
||||
def test_calculate_improvement_trend_declining(performance_tracker: PerformanceTracker) -> None:
|
||||
"""Test improvement trend calculation for declining strategy."""
|
||||
metrics = [
|
||||
StrategyMetrics(
|
||||
strategy_name="test",
|
||||
period_start="2024-01-01",
|
||||
period_end="2024-01-07",
|
||||
total_trades=10,
|
||||
wins=7,
|
||||
losses=3,
|
||||
holds=0,
|
||||
win_rate=70.0,
|
||||
avg_pnl=200.0,
|
||||
total_pnl=2000.0,
|
||||
best_trade=600.0,
|
||||
worst_trade=-200.0,
|
||||
avg_confidence=80.0,
|
||||
),
|
||||
StrategyMetrics(
|
||||
strategy_name="test",
|
||||
period_start="2024-01-08",
|
||||
period_end="2024-01-14",
|
||||
total_trades=10,
|
||||
wins=4,
|
||||
losses=6,
|
||||
holds=0,
|
||||
win_rate=40.0,
|
||||
avg_pnl=-50.0,
|
||||
total_pnl=-500.0,
|
||||
best_trade=300.0,
|
||||
worst_trade=-400.0,
|
||||
avg_confidence=70.0,
|
||||
),
|
||||
]
|
||||
|
||||
trend = performance_tracker.calculate_improvement_trend(metrics)
|
||||
|
||||
assert trend["trend"] == "declining"
|
||||
assert trend["win_rate_change"] == -30.0
|
||||
assert trend["pnl_change"] == -250.0
|
||||
|
||||
|
||||
def test_calculate_improvement_trend_insufficient_data(performance_tracker: PerformanceTracker) -> None:
|
||||
"""Test improvement trend with insufficient data."""
|
||||
metrics = [
|
||||
StrategyMetrics(
|
||||
strategy_name="test",
|
||||
period_start="2024-01-01",
|
||||
period_end="2024-01-07",
|
||||
total_trades=10,
|
||||
wins=5,
|
||||
losses=5,
|
||||
holds=0,
|
||||
win_rate=50.0,
|
||||
avg_pnl=100.0,
|
||||
total_pnl=1000.0,
|
||||
best_trade=500.0,
|
||||
worst_trade=-300.0,
|
||||
avg_confidence=75.0,
|
||||
)
|
||||
]
|
||||
|
||||
trend = performance_tracker.calculate_improvement_trend(metrics)
|
||||
|
||||
assert trend["trend"] == "insufficient_data"
|
||||
assert trend["win_rate_change"] == 0.0
|
||||
assert trend["pnl_change"] == 0.0
|
||||
|
||||
|
||||
def test_export_dashboard_json(performance_tracker: PerformanceTracker) -> None:
|
||||
"""Test exporting dashboard as JSON."""
|
||||
overall_metrics = StrategyMetrics(
|
||||
strategy_name="test",
|
||||
period_start="2024-01-01",
|
||||
period_end="2024-01-31",
|
||||
total_trades=100,
|
||||
wins=60,
|
||||
losses=40,
|
||||
holds=10,
|
||||
win_rate=60.0,
|
||||
avg_pnl=150.0,
|
||||
total_pnl=15000.0,
|
||||
best_trade=1000.0,
|
||||
worst_trade=-500.0,
|
||||
avg_confidence=80.0,
|
||||
)
|
||||
|
||||
dashboard = PerformanceDashboard(
|
||||
generated_at=datetime.now(UTC).isoformat(),
|
||||
overall_metrics=overall_metrics,
|
||||
daily_metrics=[],
|
||||
weekly_metrics=[],
|
||||
improvement_trend={"trend": "improving", "win_rate_change": 10.0},
|
||||
)
|
||||
|
||||
json_output = performance_tracker.export_dashboard_json(dashboard)
|
||||
|
||||
# Verify it's valid JSON
|
||||
data = json.loads(json_output)
|
||||
assert "generated_at" in data
|
||||
assert "overall_metrics" in data
|
||||
assert data["overall_metrics"]["total_trades"] == 100
|
||||
assert data["overall_metrics"]["win_rate"] == 60.0
|
||||
|
||||
|
||||
def test_generate_dashboard() -> None:
|
||||
"""Test generating a complete dashboard."""
|
||||
# Create tracker with temp database
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
|
||||
tmp_path = tmp.name
|
||||
|
||||
# Initialize with data
|
||||
conn = init_db(tmp_path)
|
||||
log_trade(conn, "005930", "BUY", 85, "Win", quantity=10, price=70000, pnl=1000.0)
|
||||
log_trade(conn, "000660", "SELL", 90, "Loss", quantity=5, price=100000, pnl=-500.0)
|
||||
conn.close()
|
||||
|
||||
tracker = PerformanceTracker(db_path=tmp_path)
|
||||
dashboard = tracker.generate_dashboard()
|
||||
|
||||
assert isinstance(dashboard, PerformanceDashboard)
|
||||
assert dashboard.overall_metrics.total_trades == 2
|
||||
assert len(dashboard.daily_metrics) == 7
|
||||
assert len(dashboard.weekly_metrics) == 4
|
||||
assert "trend" in dashboard.improvement_trend
|
||||
|
||||
# Clean up
|
||||
Path(tmp_path).unlink()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Integration Tests
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_evolution_pipeline(optimizer: EvolutionOptimizer, tmp_path: Path) -> None:
|
||||
"""Test the complete evolution pipeline."""
|
||||
# Add losing decisions
|
||||
logger = optimizer._decision_logger
|
||||
id1 = logger.log_decision(
|
||||
stock_code="005930",
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
action="BUY",
|
||||
confidence=85,
|
||||
rationale="Expected growth",
|
||||
context_snapshot={},
|
||||
input_data={},
|
||||
)
|
||||
logger.update_outcome(id1, pnl=-2000.0, accuracy=0)
|
||||
|
||||
# Mock Gemini and subprocess
|
||||
mock_response = Mock()
|
||||
mock_response.text = 'return {"action": "HOLD", "confidence": 50, "rationale": "Test"}'
|
||||
|
||||
with patch.object(optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response)):
|
||||
with patch("src.evolution.optimizer.STRATEGIES_DIR", tmp_path):
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = Mock(returncode=0, stdout="", stderr="")
|
||||
|
||||
result = await optimizer.evolve()
|
||||
|
||||
assert result is not None
|
||||
assert "title" in result
|
||||
assert "branch" in result
|
||||
assert "status" in result
|
||||
55
tests/test_kill_switch.py
Normal file
55
tests/test_kill_switch.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import pytest
|
||||
|
||||
from src.core.kill_switch import KillSwitchOrchestrator
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_kill_switch_executes_steps_in_order() -> None:
|
||||
ks = KillSwitchOrchestrator()
|
||||
calls: list[str] = []
|
||||
|
||||
async def _cancel() -> None:
|
||||
calls.append("cancel")
|
||||
|
||||
def _refresh() -> None:
|
||||
calls.append("refresh")
|
||||
|
||||
def _reduce() -> None:
|
||||
calls.append("reduce")
|
||||
|
||||
def _snapshot() -> None:
|
||||
calls.append("snapshot")
|
||||
|
||||
def _notify() -> None:
|
||||
calls.append("notify")
|
||||
|
||||
report = await ks.trigger(
|
||||
reason="test",
|
||||
cancel_pending_orders=_cancel,
|
||||
refresh_order_state=_refresh,
|
||||
reduce_risk=_reduce,
|
||||
snapshot_state=_snapshot,
|
||||
notify=_notify,
|
||||
)
|
||||
|
||||
assert report.steps == [
|
||||
"block_new_orders",
|
||||
"cancel_pending_orders",
|
||||
"refresh_order_state",
|
||||
"reduce_risk",
|
||||
"snapshot_state",
|
||||
"notify",
|
||||
]
|
||||
assert calls == ["cancel", "refresh", "reduce", "snapshot", "notify"]
|
||||
assert report.errors == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_kill_switch_collects_step_errors() -> None:
|
||||
ks = KillSwitchOrchestrator()
|
||||
|
||||
def _boom() -> None:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
report = await ks.trigger(reason="test", cancel_pending_orders=_boom)
|
||||
assert any(err.startswith("cancel_pending_orders:") for err in report.errors)
|
||||
558
tests/test_latency_control.py
Normal file
558
tests/test_latency_control.py
Normal file
@@ -0,0 +1,558 @@
|
||||
"""Tests for latency control system (criticality assessment and priority queue)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from src.core.criticality import CriticalityAssessor, CriticalityLevel
|
||||
from src.core.priority_queue import PriorityTask, PriorityTaskQueue
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CriticalityAssessor Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCriticalityAssessor:
|
||||
"""Test suite for criticality assessment logic."""
|
||||
|
||||
def test_market_closed_returns_low(self) -> None:
|
||||
"""Market closed should return LOW priority."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=0.0,
|
||||
volatility_score=50.0,
|
||||
volume_surge=1.0,
|
||||
is_market_open=False,
|
||||
)
|
||||
assert level == CriticalityLevel.LOW
|
||||
|
||||
def test_very_low_volatility_returns_low(self) -> None:
|
||||
"""Very low volatility should return LOW priority."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=0.0,
|
||||
volatility_score=20.0, # Below 30.0 threshold
|
||||
volume_surge=1.0,
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.LOW
|
||||
|
||||
def test_critical_pnl_threshold_triggered(self) -> None:
|
||||
"""P&L below -2.5% should trigger CRITICAL."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=-2.6, # Below -2.5% threshold
|
||||
volatility_score=50.0,
|
||||
volume_surge=1.0,
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.CRITICAL
|
||||
|
||||
def test_critical_pnl_at_circuit_breaker_proximity(self) -> None:
|
||||
"""P&L at exactly -2.5% (near -3.0% breaker) should be CRITICAL."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=-2.5,
|
||||
volatility_score=50.0,
|
||||
volume_surge=1.0,
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.CRITICAL
|
||||
|
||||
def test_critical_price_change_positive(self) -> None:
|
||||
"""Large positive price change (>5%) should trigger CRITICAL."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=0.0,
|
||||
volatility_score=50.0,
|
||||
volume_surge=1.0,
|
||||
price_change_1m=5.5, # Above 5.0% threshold
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.CRITICAL
|
||||
|
||||
def test_critical_price_change_negative(self) -> None:
|
||||
"""Large negative price change (<-5%) should trigger CRITICAL."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=0.0,
|
||||
volatility_score=50.0,
|
||||
volume_surge=1.0,
|
||||
price_change_1m=-6.0, # Below -5.0% threshold
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.CRITICAL
|
||||
|
||||
def test_critical_volume_surge(self) -> None:
|
||||
"""Extreme volume surge (>10x) should trigger CRITICAL."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=0.0,
|
||||
volatility_score=50.0,
|
||||
volume_surge=12.0, # Above 10.0x threshold
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.CRITICAL
|
||||
|
||||
def test_high_volatility_returns_high(self) -> None:
|
||||
"""High volatility score should return HIGH priority."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=0.0,
|
||||
volatility_score=75.0, # Above 70.0 threshold
|
||||
volume_surge=1.0,
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.HIGH
|
||||
|
||||
def test_normal_conditions_return_normal(self) -> None:
|
||||
"""Normal market conditions should return NORMAL priority."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=0.5,
|
||||
volatility_score=50.0, # Between 30-70
|
||||
volume_surge=1.5,
|
||||
price_change_1m=1.0,
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.NORMAL
|
||||
|
||||
def test_custom_thresholds(self) -> None:
|
||||
"""Custom thresholds should be respected."""
|
||||
assessor = CriticalityAssessor(
|
||||
critical_pnl_threshold=-1.0,
|
||||
critical_price_change_threshold=3.0,
|
||||
critical_volume_surge_threshold=5.0,
|
||||
high_volatility_threshold=60.0,
|
||||
low_volatility_threshold=20.0,
|
||||
)
|
||||
|
||||
# Test custom P&L threshold
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=-1.1,
|
||||
volatility_score=50.0,
|
||||
volume_surge=1.0,
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.CRITICAL
|
||||
|
||||
# Test custom price change threshold
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=0.0,
|
||||
volatility_score=50.0,
|
||||
volume_surge=1.0,
|
||||
price_change_1m=3.5,
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.CRITICAL
|
||||
|
||||
def test_get_timeout_returns_correct_values(self) -> None:
|
||||
"""Timeout values should match specification."""
|
||||
assessor = CriticalityAssessor()
|
||||
|
||||
assert assessor.get_timeout(CriticalityLevel.CRITICAL) == 5.0
|
||||
assert assessor.get_timeout(CriticalityLevel.HIGH) == 30.0
|
||||
assert assessor.get_timeout(CriticalityLevel.NORMAL) == 60.0
|
||||
assert assessor.get_timeout(CriticalityLevel.LOW) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PriorityTaskQueue Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPriorityTaskQueue:
|
||||
"""Test suite for priority queue implementation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_task(self) -> None:
|
||||
"""Tasks should be enqueued successfully."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
success = await queue.enqueue(
|
||||
task_id="test-1",
|
||||
criticality=CriticalityLevel.NORMAL,
|
||||
task_data={"action": "test"},
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert await queue.size() == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_rejects_when_full(self) -> None:
|
||||
"""Queue should reject tasks when full."""
|
||||
queue = PriorityTaskQueue(max_size=2)
|
||||
|
||||
# Fill the queue
|
||||
await queue.enqueue("task-1", CriticalityLevel.NORMAL, {})
|
||||
await queue.enqueue("task-2", CriticalityLevel.NORMAL, {})
|
||||
|
||||
# Third task should be rejected
|
||||
success = await queue.enqueue("task-3", CriticalityLevel.NORMAL, {})
|
||||
assert success is False
|
||||
assert await queue.size() == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dequeue_returns_highest_priority(self) -> None:
|
||||
"""Dequeue should return highest priority task first."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Enqueue tasks in reverse priority order
|
||||
await queue.enqueue("low", CriticalityLevel.LOW, {"priority": 3})
|
||||
await queue.enqueue("normal", CriticalityLevel.NORMAL, {"priority": 2})
|
||||
await queue.enqueue("high", CriticalityLevel.HIGH, {"priority": 1})
|
||||
await queue.enqueue("critical", CriticalityLevel.CRITICAL, {"priority": 0})
|
||||
|
||||
# Dequeue should return CRITICAL first
|
||||
task = await queue.dequeue(timeout=1.0)
|
||||
assert task is not None
|
||||
assert task.task_id == "critical"
|
||||
assert task.priority == 0
|
||||
|
||||
# Then HIGH
|
||||
task = await queue.dequeue(timeout=1.0)
|
||||
assert task is not None
|
||||
assert task.task_id == "high"
|
||||
assert task.priority == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dequeue_fifo_within_same_priority(self) -> None:
|
||||
"""Tasks with same priority should be FIFO."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Enqueue multiple tasks with same priority
|
||||
await queue.enqueue("task-1", CriticalityLevel.NORMAL, {})
|
||||
await asyncio.sleep(0.01) # Small delay to ensure different timestamps
|
||||
await queue.enqueue("task-2", CriticalityLevel.NORMAL, {})
|
||||
await asyncio.sleep(0.01)
|
||||
await queue.enqueue("task-3", CriticalityLevel.NORMAL, {})
|
||||
|
||||
# Should dequeue in FIFO order
|
||||
task1 = await queue.dequeue(timeout=1.0)
|
||||
task2 = await queue.dequeue(timeout=1.0)
|
||||
task3 = await queue.dequeue(timeout=1.0)
|
||||
|
||||
assert task1 is not None and task1.task_id == "task-1"
|
||||
assert task2 is not None and task2.task_id == "task-2"
|
||||
assert task3 is not None and task3.task_id == "task-3"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dequeue_returns_none_when_empty(self) -> None:
|
||||
"""Dequeue should return None when queue is empty after timeout."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
task = await queue.dequeue(timeout=0.1)
|
||||
assert task is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_with_timeout_success(self) -> None:
|
||||
"""Task execution should succeed within timeout."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Create a simple async callback
|
||||
async def test_callback() -> str:
|
||||
await asyncio.sleep(0.01)
|
||||
return "success"
|
||||
|
||||
task = PriorityTask(
|
||||
priority=0,
|
||||
timestamp=0.0,
|
||||
task_id="test",
|
||||
task_data={},
|
||||
callback=test_callback,
|
||||
)
|
||||
|
||||
result = await queue.execute_with_timeout(task, timeout=1.0)
|
||||
assert result == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_with_timeout_raises_timeout_error(self) -> None:
|
||||
"""Task execution should raise TimeoutError if exceeds timeout."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Create a slow async callback
|
||||
async def slow_callback() -> str:
|
||||
await asyncio.sleep(1.0)
|
||||
return "too slow"
|
||||
|
||||
task = PriorityTask(
|
||||
priority=0,
|
||||
timestamp=0.0,
|
||||
task_id="test",
|
||||
task_data={},
|
||||
callback=slow_callback,
|
||||
)
|
||||
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await queue.execute_with_timeout(task, timeout=0.1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_with_timeout_propagates_exceptions(self) -> None:
|
||||
"""Task execution should propagate exceptions from callback."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Create a failing async callback
|
||||
async def failing_callback() -> None:
|
||||
raise ValueError("Test error")
|
||||
|
||||
task = PriorityTask(
|
||||
priority=0,
|
||||
timestamp=0.0,
|
||||
task_id="test",
|
||||
task_data={},
|
||||
callback=failing_callback,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Test error"):
|
||||
await queue.execute_with_timeout(task, timeout=1.0)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_without_timeout(self) -> None:
|
||||
"""Task execution should work without timeout (LOW priority)."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
async def test_callback() -> str:
|
||||
await asyncio.sleep(0.01)
|
||||
return "success"
|
||||
|
||||
task = PriorityTask(
|
||||
priority=3,
|
||||
timestamp=0.0,
|
||||
task_id="test",
|
||||
task_data={},
|
||||
callback=test_callback,
|
||||
)
|
||||
|
||||
result = await queue.execute_with_timeout(task, timeout=None)
|
||||
assert result == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_metrics(self) -> None:
|
||||
"""Queue should track metrics correctly."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Enqueue and dequeue some tasks
|
||||
await queue.enqueue("task-1", CriticalityLevel.CRITICAL, {})
|
||||
await queue.enqueue("task-2", CriticalityLevel.HIGH, {})
|
||||
await queue.enqueue("task-3", CriticalityLevel.NORMAL, {})
|
||||
|
||||
await queue.dequeue(timeout=1.0)
|
||||
await queue.dequeue(timeout=1.0)
|
||||
|
||||
metrics = await queue.get_metrics()
|
||||
|
||||
assert metrics.total_enqueued == 3
|
||||
assert metrics.total_dequeued == 2
|
||||
assert metrics.current_size == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_time_metrics(self) -> None:
|
||||
"""Queue should track wait times per criticality level."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Enqueue tasks with different criticality
|
||||
await queue.enqueue("critical-1", CriticalityLevel.CRITICAL, {})
|
||||
await asyncio.sleep(0.05) # Add some wait time
|
||||
|
||||
await queue.dequeue(timeout=1.0)
|
||||
|
||||
metrics = await queue.get_metrics()
|
||||
|
||||
# Should have wait time metrics for CRITICAL
|
||||
assert CriticalityLevel.CRITICAL in metrics.avg_wait_time
|
||||
assert metrics.avg_wait_time[CriticalityLevel.CRITICAL] > 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_queue(self) -> None:
|
||||
"""Clear should remove all tasks from queue."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
await queue.enqueue("task-1", CriticalityLevel.NORMAL, {})
|
||||
await queue.enqueue("task-2", CriticalityLevel.NORMAL, {})
|
||||
await queue.enqueue("task-3", CriticalityLevel.NORMAL, {})
|
||||
|
||||
cleared = await queue.clear()
|
||||
|
||||
assert cleared == 3
|
||||
assert await queue.size() == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_enqueue_dequeue(self) -> None:
|
||||
"""Queue should handle concurrent operations safely."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Concurrent enqueue operations
|
||||
async def enqueue_tasks() -> None:
|
||||
for i in range(10):
|
||||
await queue.enqueue(
|
||||
f"task-{i}",
|
||||
CriticalityLevel.NORMAL,
|
||||
{"index": i},
|
||||
)
|
||||
|
||||
# Concurrent dequeue operations
|
||||
async def dequeue_tasks() -> list[str]:
|
||||
tasks = []
|
||||
for _ in range(10):
|
||||
task = await queue.dequeue(timeout=1.0)
|
||||
if task:
|
||||
tasks.append(task.task_id)
|
||||
await asyncio.sleep(0.01)
|
||||
return tasks
|
||||
|
||||
# Run both concurrently
|
||||
enqueue_task = asyncio.create_task(enqueue_tasks())
|
||||
dequeue_task = asyncio.create_task(dequeue_tasks())
|
||||
|
||||
await enqueue_task
|
||||
dequeued_ids = await dequeue_task
|
||||
|
||||
# All tasks should be processed
|
||||
assert len(dequeued_ids) == 10
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_metric_tracking(self) -> None:
|
||||
"""Queue should track timeout occurrences."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
async def slow_callback() -> str:
|
||||
await asyncio.sleep(1.0)
|
||||
return "too slow"
|
||||
|
||||
task = PriorityTask(
|
||||
priority=0,
|
||||
timestamp=0.0,
|
||||
task_id="test",
|
||||
task_data={},
|
||||
callback=slow_callback,
|
||||
)
|
||||
|
||||
try:
|
||||
await queue.execute_with_timeout(task, timeout=0.1)
|
||||
except TimeoutError:
|
||||
pass
|
||||
|
||||
metrics = await queue.get_metrics()
|
||||
assert metrics.total_timeouts == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_metric_tracking(self) -> None:
|
||||
"""Queue should track execution errors."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
async def failing_callback() -> None:
|
||||
raise ValueError("Test error")
|
||||
|
||||
task = PriorityTask(
|
||||
priority=0,
|
||||
timestamp=0.0,
|
||||
task_id="test",
|
||||
task_data={},
|
||||
callback=failing_callback,
|
||||
)
|
||||
|
||||
try:
|
||||
await queue.execute_with_timeout(task, timeout=1.0)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
metrics = await queue.get_metrics()
|
||||
assert metrics.total_errors == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLatencyControlIntegration:
|
||||
"""Integration tests for criticality assessment and priority queue."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_critical_task_bypass_queue(self) -> None:
|
||||
"""CRITICAL tasks should bypass lower priority tasks."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Add normal priority tasks
|
||||
await queue.enqueue("normal-1", CriticalityLevel.NORMAL, {})
|
||||
await queue.enqueue("normal-2", CriticalityLevel.NORMAL, {})
|
||||
|
||||
# Add critical task (should jump to front)
|
||||
await queue.enqueue("critical", CriticalityLevel.CRITICAL, {})
|
||||
|
||||
# Dequeue should return critical first
|
||||
task = await queue.dequeue(timeout=1.0)
|
||||
assert task is not None
|
||||
assert task.task_id == "critical"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_enforcement_by_criticality(self) -> None:
|
||||
"""Timeout enforcement should match criticality level."""
|
||||
assessor = CriticalityAssessor()
|
||||
|
||||
# CRITICAL should have 5s timeout
|
||||
critical_timeout = assessor.get_timeout(CriticalityLevel.CRITICAL)
|
||||
assert critical_timeout == 5.0
|
||||
|
||||
# HIGH should have 30s timeout
|
||||
high_timeout = assessor.get_timeout(CriticalityLevel.HIGH)
|
||||
assert high_timeout == 30.0
|
||||
|
||||
# NORMAL should have 60s timeout
|
||||
normal_timeout = assessor.get_timeout(CriticalityLevel.NORMAL)
|
||||
assert normal_timeout == 60.0
|
||||
|
||||
# LOW should have no timeout
|
||||
low_timeout = assessor.get_timeout(CriticalityLevel.LOW)
|
||||
assert low_timeout is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fast_path_execution_for_critical(self) -> None:
|
||||
"""CRITICAL tasks should complete quickly."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Create a fast callback simulating fast-path execution
|
||||
async def fast_path_callback() -> str:
|
||||
# Simulate simplified decision flow
|
||||
await asyncio.sleep(0.01) # Very fast execution
|
||||
return "fast_path_complete"
|
||||
|
||||
task = PriorityTask(
|
||||
priority=0, # CRITICAL
|
||||
timestamp=0.0,
|
||||
task_id="critical-fast",
|
||||
task_data={},
|
||||
callback=fast_path_callback,
|
||||
)
|
||||
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
result = await queue.execute_with_timeout(task, timeout=5.0)
|
||||
elapsed = time.time() - start
|
||||
|
||||
assert result == "fast_path_complete"
|
||||
assert elapsed < 5.0 # Should complete well under CRITICAL timeout
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graceful_degradation_when_queue_full(self) -> None:
|
||||
"""System should gracefully handle full queue."""
|
||||
queue = PriorityTaskQueue(max_size=2)
|
||||
|
||||
# Fill the queue
|
||||
await queue.enqueue("task-1", CriticalityLevel.NORMAL, {})
|
||||
await queue.enqueue("task-2", CriticalityLevel.NORMAL, {})
|
||||
|
||||
# Try to add more tasks
|
||||
success = await queue.enqueue("task-3", CriticalityLevel.NORMAL, {})
|
||||
assert success is False
|
||||
|
||||
# Queue should still function
|
||||
task = await queue.dequeue(timeout=1.0)
|
||||
assert task is not None
|
||||
|
||||
# Now we can add another task
|
||||
success = await queue.enqueue("task-4", CriticalityLevel.NORMAL, {})
|
||||
assert success is True
|
||||
117
tests/test_logging_config.py
Normal file
117
tests/test_logging_config.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""Tests for JSON structured logging configuration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from src.logging_config import JSONFormatter, setup_logging
|
||||
|
||||
|
||||
class TestJSONFormatter:
|
||||
"""Test JSONFormatter output."""
|
||||
|
||||
def test_basic_log_record(self) -> None:
|
||||
"""JSONFormatter must emit valid JSON with required fields."""
|
||||
formatter = JSONFormatter()
|
||||
record = logging.LogRecord(
|
||||
name="test.logger",
|
||||
level=logging.INFO,
|
||||
pathname="",
|
||||
lineno=0,
|
||||
msg="Hello %s",
|
||||
args=("world",),
|
||||
exc_info=None,
|
||||
)
|
||||
output = formatter.format(record)
|
||||
data = json.loads(output)
|
||||
assert data["level"] == "INFO"
|
||||
assert data["logger"] == "test.logger"
|
||||
assert data["message"] == "Hello world"
|
||||
assert "timestamp" in data
|
||||
|
||||
def test_includes_exception_info(self) -> None:
|
||||
"""JSONFormatter must include exception info when present."""
|
||||
formatter = JSONFormatter()
|
||||
try:
|
||||
raise ValueError("test error")
|
||||
except ValueError:
|
||||
exc_info = sys.exc_info()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.ERROR,
|
||||
pathname="",
|
||||
lineno=0,
|
||||
msg="oops",
|
||||
args=(),
|
||||
exc_info=exc_info,
|
||||
)
|
||||
output = formatter.format(record)
|
||||
data = json.loads(output)
|
||||
assert "exception" in data
|
||||
assert "ValueError" in data["exception"]
|
||||
|
||||
def test_extra_trading_fields_included(self) -> None:
|
||||
"""Extra trading fields attached to the record must appear in JSON."""
|
||||
formatter = JSONFormatter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="",
|
||||
lineno=0,
|
||||
msg="trade",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
record.stock_code = "005930" # type: ignore[attr-defined]
|
||||
record.action = "BUY" # type: ignore[attr-defined]
|
||||
record.confidence = 85 # type: ignore[attr-defined]
|
||||
record.pnl_pct = -1.5 # type: ignore[attr-defined]
|
||||
record.order_amount = 1_000_000 # type: ignore[attr-defined]
|
||||
output = formatter.format(record)
|
||||
data = json.loads(output)
|
||||
assert data["stock_code"] == "005930"
|
||||
assert data["action"] == "BUY"
|
||||
assert data["confidence"] == 85
|
||||
assert data["pnl_pct"] == -1.5
|
||||
assert data["order_amount"] == 1_000_000
|
||||
|
||||
def test_none_extra_fields_excluded(self) -> None:
|
||||
"""Extra fields that are None must not appear in JSON output."""
|
||||
formatter = JSONFormatter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="",
|
||||
lineno=0,
|
||||
msg="no extras",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
output = formatter.format(record)
|
||||
data = json.loads(output)
|
||||
assert "stock_code" not in data
|
||||
assert "action" not in data
|
||||
assert "confidence" not in data
|
||||
|
||||
|
||||
class TestSetupLogging:
|
||||
"""Test setup_logging function."""
|
||||
|
||||
def test_configures_root_logger(self) -> None:
|
||||
"""setup_logging must attach a JSON handler to the root logger."""
|
||||
setup_logging(level=logging.DEBUG)
|
||||
root = logging.getLogger()
|
||||
json_handlers = [
|
||||
h for h in root.handlers if isinstance(h.formatter, JSONFormatter)
|
||||
]
|
||||
assert len(json_handlers) == 1
|
||||
assert root.level == logging.DEBUG
|
||||
|
||||
def test_avoids_duplicate_handlers(self) -> None:
|
||||
"""Calling setup_logging twice must not add duplicate handlers."""
|
||||
setup_logging()
|
||||
setup_logging()
|
||||
root = logging.getLogger()
|
||||
assert len(root.handlers) == 1
|
||||
5118
tests/test_main.py
Normal file
5118
tests/test_main.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -7,6 +7,7 @@ import pytest
|
||||
|
||||
from src.markets.schedule import (
|
||||
MARKETS,
|
||||
expand_market_codes,
|
||||
get_next_market_open,
|
||||
get_open_markets,
|
||||
is_market_open,
|
||||
@@ -199,3 +200,28 @@ class TestGetNextMarketOpen:
|
||||
enabled_markets=["INVALID", "KR"], now=test_time
|
||||
)
|
||||
assert market.code == "KR"
|
||||
|
||||
|
||||
class TestExpandMarketCodes:
|
||||
"""Test shorthand market expansion."""
|
||||
|
||||
def test_expand_us_shorthand(self) -> None:
|
||||
assert expand_market_codes(["US"]) == ["US_NASDAQ", "US_NYSE", "US_AMEX"]
|
||||
|
||||
def test_expand_cn_shorthand(self) -> None:
|
||||
assert expand_market_codes(["CN"]) == ["CN_SHA", "CN_SZA"]
|
||||
|
||||
def test_expand_vn_shorthand(self) -> None:
|
||||
assert expand_market_codes(["VN"]) == ["VN_HAN", "VN_HCM"]
|
||||
|
||||
def test_expand_mixed_codes(self) -> None:
|
||||
assert expand_market_codes(["KR", "US", "JP"]) == [
|
||||
"KR",
|
||||
"US_NASDAQ",
|
||||
"US_NYSE",
|
||||
"US_AMEX",
|
||||
"JP",
|
||||
]
|
||||
|
||||
def test_expand_preserves_unknown_code(self) -> None:
|
||||
assert expand_market_codes(["KR", "UNKNOWN"]) == ["KR", "UNKNOWN"]
|
||||
|
||||
1036
tests/test_overseas_broker.py
Normal file
1036
tests/test_overseas_broker.py
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user