Compare commits
241 Commits
feature/is ... feature/is
.env.example (69)
@@ -1,23 +1,82 @@
# ============================================================
# The Ouroboros — Environment Configuration
# ============================================================
# Copy this file to .env and fill in your values.
# Lines starting with # are comments.

# ============================================================
# Korea Investment Securities API
# ============================================================
KIS_APP_KEY=your_app_key_here
KIS_APP_SECRET=your_app_secret_here
KIS_ACCOUNT_NO=12345678-01
KIS_BASE_URL=https://openapivts.koreainvestment.com:9443

# Paper trading (VTS): https://openapivts.koreainvestment.com:29443
# Live trading: https://openapi.koreainvestment.com:9443
KIS_BASE_URL=https://openapivts.koreainvestment.com:29443

# ============================================================
# Trading Mode
# ============================================================
# paper = simulated trading (safe for testing), live = real trading (real money)
MODE=paper

# daily = batch per session, realtime = per-stock continuous scan
TRADE_MODE=daily

# Comma-separated market codes: KR, US, JP, HK, CN, VN
ENABLED_MARKETS=KR,US

# Simulated USD cash for paper (VTS) overseas trading.
# VTS overseas balance API often returns 0; this value is used as fallback.
# Set to 0 to disable fallback (not used in live mode).
PAPER_OVERSEAS_CASH=50000.0

# ============================================================
# Google Gemini
# ============================================================
GEMINI_API_KEY=your_gemini_api_key_here
GEMINI_MODEL=gemini-pro
# Recommended: gemini-2.0-flash-exp or gemini-1.5-pro
GEMINI_MODEL=gemini-2.0-flash-exp

# ============================================================
# Risk Management
# ============================================================
CIRCUIT_BREAKER_PCT=-3.0
FAT_FINGER_PCT=30.0
CONFIDENCE_THRESHOLD=80

# ============================================================
# Database
# ============================================================
DB_PATH=data/trade_logs.db

# ============================================================
# Rate Limiting
RATE_LIMIT_RPS=10.0
# ============================================================
# KIS API real limit is ~2 RPS. Keep at 2.0 for maximum safety.
# Increasing this risks EGW00201 "초당 거래건수 초과" (per-second request limit exceeded) errors.
RATE_LIMIT_RPS=2.0

# Trading Mode (paper / live)
MODE=paper
# ============================================================
# External Data APIs (optional)
# ============================================================
# NEWS_API_KEY=your_news_api_key_here
# NEWS_API_PROVIDER=alphavantage
# MARKET_DATA_API_KEY=your_market_data_key_here

# ============================================================
# Telegram Notifications (optional)
# ============================================================
# Get bot token from @BotFather on Telegram
# Get chat ID from @userinfobot or your chat
# TELEGRAM_BOT_TOKEN=1234567890:ABCdefGHIjklMNOpqrsTUVwxyz
# TELEGRAM_CHAT_ID=123456789
# TELEGRAM_ENABLED=true

# ============================================================
# Dashboard (optional)
# ============================================================
# DASHBOARD_ENABLED=false
# DASHBOARD_HOST=127.0.0.1
# DASHBOARD_PORT=8080
.gitignore (3, vendored)
@@ -174,4 +174,7 @@ cython_debug/
# PyPI configuration file
.pypirc

# Data files (trade logs, databases)
# But NOT src/data/ which contains source code
data/
!src/data/
CLAUDE.md (97)
@@ -15,15 +15,86 @@ pytest -v --cov=src

# Run (paper trading)
python -m src.main --mode=paper

# Run with dashboard
python -m src.main --mode=paper --dashboard
```

## Telegram Notifications (Optional)

Get real-time alerts for trades, circuit breakers, and system events via Telegram.

### Quick Setup

1. **Create bot**: Message [@BotFather](https://t.me/BotFather) on Telegram → `/newbot`
2. **Get chat ID**: Message [@userinfobot](https://t.me/userinfobot) → `/start`
3. **Configure**: Add to `.env`:
   ```bash
   TELEGRAM_BOT_TOKEN=1234567890:ABCdefGHIjklMNOpqrsTUVwxyz
   TELEGRAM_CHAT_ID=123456789
   TELEGRAM_ENABLED=true
   ```
4. **Test**: Start bot conversation (`/start`), then run the agent

**Full documentation**: [src/notifications/README.md](src/notifications/README.md)

### What You'll Get

- 🟢 Trade execution alerts (BUY/SELL with confidence)
- 🚨 Circuit breaker trips (automatic trading halt)
- ⚠️ Fat-finger rejections (oversized orders blocked)
- ℹ️ Market open/close notifications
- 📝 System startup/shutdown status

### Interactive Commands

With `TELEGRAM_COMMANDS_ENABLED=true` (default), the bot supports 9 bidirectional commands: `/help`, `/status`, `/positions`, `/report`, `/scenarios`, `/review`, `/dashboard`, `/stop`, `/resume`.

**Fail-safe**: Notifications never crash the trading system. Missing credentials or API errors are logged but trading continues normally.

## Smart Volatility Scanner (Optional)

Python-first filtering pipeline that reduces Gemini API calls by pre-filtering stocks using technical indicators.

### How It Works

1. **Fetch Rankings** — KIS API volume surge rankings (top 30 stocks)
2. **Python Filter** — RSI + volume ratio calculations (no AI), as sketched below
   - Volume > 200% of previous day
   - RSI(14) < 30 (oversold) OR RSI(14) > 70 (momentum)
3. **AI Judgment** — Only qualified candidates (1-3 stocks) sent to Gemini

### Configuration

Add to `.env` (optional, has sensible defaults):
```bash
RSI_OVERSOLD_THRESHOLD=30   # 0-50, default 30
RSI_MOMENTUM_THRESHOLD=70   # 50-100, default 70
VOL_MULTIPLIER=2.0          # Volume threshold (2.0 = 200%)
SCANNER_TOP_N=3             # Max candidates per scan
```
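A minimal sketch of the pre-filter rule described in "How It Works", using the thresholds above. This is illustrative only — function and field names are assumptions, not the project's actual scanner API.

```python
def passes_python_filter(rsi_14: float, volume_ratio: float,
                         rsi_oversold: float = 30.0,
                         rsi_momentum: float = 70.0,
                         vol_multiplier: float = 2.0) -> bool:
    """True only when volume surges AND RSI signals oversold or momentum."""
    volume_ok = volume_ratio >= vol_multiplier          # e.g. 2.0 = 200% of previous day
    rsi_ok = rsi_14 < rsi_oversold or rsi_14 > rsi_momentum
    return volume_ok and rsi_ok

def qualified_candidates(stocks: list[dict], top_n: int = 3) -> list[dict]:
    """Keep only stocks that pass the Python filter; cap at SCANNER_TOP_N."""
    hits = [s for s in stocks if passes_python_filter(s["rsi_14"], s["volume_ratio"])]
    return hits[:top_n]
```

Only the stocks returned here would be forwarded to Gemini in step 3.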
### Benefits

- **Reduces API costs** — Process 1-3 stocks instead of 20-30
- **Python-based filtering** — Fast technical analysis before AI
- **Evolution-ready** — Selection context logged for strategy optimization
- **Fault-tolerant** — Falls back to static watchlist on API failure

### Realtime Mode Only

Smart Scanner runs in `TRADE_MODE=realtime` only. Daily mode uses static watchlists for batch efficiency.

## Documentation

- **[Workflow Guide](docs/workflow.md)** — Git workflow policy and agent-based development
- **[Command Reference](docs/commands.md)** — Common failures, build commands, troubleshooting
- **[Architecture](docs/architecture.md)** — System design, components, data flow
- **[Context Tree](docs/context-tree.md)** — L1-L7 hierarchical memory system
- **[Testing](docs/testing.md)** — Test structure, coverage requirements, writing tests
- **[Agent Policies](docs/agents.md)** — Prime directives, constraints, prohibited actions
- **[Requirements Log](docs/requirements-log.md)** — User requirements and feedback tracking
- **[Live Trading Checklist](docs/live-trading-checklist.md)** — Paper → live transition checklist

## Core Principles

@@ -32,20 +103,37 @@ python -m src.main --mode=paper
3. **Issue-Driven Development** — All work goes through Gitea issues → feature branches → PRs
4. **Agent Specialization** — Use dedicated agents for design, coding, testing, docs, review

## Requirements Management

User requirements and feedback are tracked in [docs/requirements-log.md](docs/requirements-log.md):

- New requirements are added chronologically with dates
- Code changes should reference related requirements
- Helps maintain project evolution aligned with user needs
- Preserves context across conversations and development cycles

## Project Structure

```
src/
├── analysis/        # Technical analysis (RSI, volatility, smart scanner)
├── backup/          # Disaster recovery (scheduler, cloud storage, health)
├── brain/           # Gemini AI decision engine (prompt optimizer, context selector)
├── broker/          # KIS API client (domestic + overseas)
├── brain/           # Gemini AI decision engine
├── context/         # L1-L7 hierarchical memory system
├── core/            # Risk manager (READ-ONLY)
├── evolution/       # Self-improvement optimizer
├── dashboard/       # FastAPI read-only monitoring (8 API endpoints)
├── data/            # External data integration (news, market data, calendar)
├── evolution/       # Self-improvement (optimizer, daily review, scorecard)
├── logging/         # Decision logger (audit trail)
├── markets/         # Market schedules and timezone handling
├── notifications/   # Telegram alerts + bidirectional commands (9 commands)
├── strategy/        # Pre-market planner, scenario engine, playbook store
├── db.py            # SQLite trade logging
├── main.py          # Trading loop orchestrator
└── config.py        # Settings (from .env)

tests/   # 54 tests across 4 files
tests/   # 551 tests across 25 files
docs/    # Extended documentation
```

@@ -57,6 +145,7 @@ ruff check src/ tests/ # Lint
mypy src/ --strict # Type check

python -m src.main --mode=paper              # Paper trading
python -m src.main --mode=paper --dashboard  # With dashboard
python -m src.main --mode=live               # Live trading (⚠️ real money)

# Gitea workflow (requires tea CLI)

@@ -82,7 +171,7 @@ Markets auto-detected based on timezone and enabled in `ENABLED_MARKETS` env var
- `src/core/risk_manager.py` is **READ-ONLY** — changes require human approval
- Circuit breaker at -3.0% P&L — may only be made **stricter**
- Fat-finger protection: max 30% of cash per order — always enforced
- Confidence < 80 → force HOLD — cannot be weakened
- Confidence thresholds (per market_outlook, cannot be lowered): BEARISH ≥ 90, NEUTRAL/default ≥ 80, BULLISH ≥ 75 — see the sketch below
- All code changes → corresponding tests → coverage ≥ 80%
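A minimal sketch of how these confidence floors can be applied. This is illustrative only — the actual rule lives in the READ-ONLY risk layer and its interface is not shown here.

```python
# Confidence floors by market outlook; values may only ever be raised, never lowered.
CONFIDENCE_FLOOR = {"BEARISH": 90, "NEUTRAL": 80, "BULLISH": 75}

def gate_action(action: str, confidence: int, market_outlook: str = "NEUTRAL") -> str:
    """Force HOLD whenever confidence is below the floor for the current outlook."""
    floor = CONFIDENCE_FLOOR.get(market_outlook, 80)   # unknown outlook -> default 80
    return action if confidence >= floor else "HOLD"
```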
## Contributing
README.md (178)
@@ -10,27 +10,41 @@ KIS(한국투자증권) API로 매매하고, Google Gemini로 판단하며, 자
│ (매매 실행) │ │ (거래 루프) │ │ (의사결정) │
└─────────────┘ └──────┬──────┘ └─────────────┘
                       │
                ┌──────┴──────┐
                │Risk Manager │
                │ (안전장치)  │
                └──────┬──────┘
          ┌────────────┼────────────┐
          │            │            │
   ┌──────┴──────┐  ┌──┴───┐  ┌──────┴──────┐
   │Risk Manager │  │ DB   │  │ Telegram    │
   │ (안전장치)  │  │      │  │ (알림+명령) │
   └──────┬──────┘  └──────┘  └─────────────┘
          │
   ┌──────┴──────┐
   │ Evolution   │
   │ (전략 진화) │
   └─────────────┘
     ┌────────┼────────┐
     │        │        │
┌────┴────┐┌──┴──┐┌────┴─────┐
│Strategy ││Ctx  ││Evolution │
│(플레이북)││(메모리)││ (진화) │
└─────────┘└─────┘└──────────┘
```

**v2 core idea**: "Plan Once, Execute Locally" — before the market opens, the AI generates a scenario playbook once; during trading hours only local scenario matching runs, which sharply reduces API cost and latency.

## Core Modules

| Module | File | Description |
| Module | Location | Description |
|------|------|------|
| Config | `src/config.py` | Pydantic-based environment variable loading and type validation |
| Broker | `src/broker/kis_api.py` | Async KIS API wrapper (token refresh, rate limiter, hash key) |
| Brain | `src/brain/gemini_client.py` | Gemini prompt construction and JSON response parsing |
| Shield | `src/core/risk_manager.py` | Circuit breaker + fat-finger checks |
| Evolution | `src/evolution/optimizer.py` | Failure-pattern analysis → new strategy generation → tests → PR |
| DB | `src/db.py` | SQLite trade logging |
| Config | `src/config.py` | Pydantic-based environment variable loading and type validation (35+ variables) |
| Broker | `src/broker/` | Async KIS API wrapper (domestic + 9 overseas markets) |
| Brain | `src/brain/` | Gemini prompt construction, JSON parsing, token optimization |
| Shield | `src/core/risk_manager.py` | Circuit breaker + fat-finger checks (READ-ONLY) |
| Strategy | `src/strategy/` | Pre-Market Planner, Scenario Engine, Playbook Store |
| Context | `src/context/` | L1-L7 hierarchical memory system |
| Analysis | `src/analysis/` | RSI, ATR, Smart Volatility Scanner |
| Notifications | `src/notifications/` | Bidirectional Telegram (alerts + 9 commands) |
| Dashboard | `src/dashboard/` | FastAPI read-only monitoring (8 APIs) |
| Evolution | `src/evolution/` | Strategy evolution + Daily Review + Scorecard |
| Decision log | `src/logging/` | Audit trail of every trade decision |
| Data | `src/data/` | News, market data, economic calendar integration |
| Backup | `src/backup/` | Automated backups, S3 cloud storage, integrity verification |
| DB | `src/db.py` | SQLite trade logs (5 tables) |

## Safety Mechanisms
@@ -41,6 +55,7 @@ KIS(한국투자증권) API로 매매하고, Google Gemini로 판단하며, 자
| Confidence threshold | Force HOLD when Gemini confidence is below 80 |
| Rate limiter | Caps API calls with a leaky-bucket algorithm |
| Automatic token refresh | Reissues the access token 1 minute before expiry |
| Stop-loss monitoring | Real-time position protection based on playbook scenarios |

## Quick Start
@@ -66,7 +81,11 @@ pytest -v --cov=src --cov-report=term-missing
### 4. Run (paper trading)

```bash
# Basic run
python -m src.main --mode=paper

# Enable the dashboard
python -m src.main --mode=paper --dashboard
```

### 5. Run with Docker
@@ -75,23 +94,90 @@ python -m src.main --mode=paper
docker compose up -d ouroboros
```

## Supported Markets

| Country | Exchange | Code |
|------|--------|------|
| 🇰🇷 Korea | KRX | KR |
| 🇺🇸 United States | NASDAQ, NYSE, AMEX | US_NASDAQ, US_NYSE, US_AMEX |
| 🇯🇵 Japan | TSE | JP |
| 🇭🇰 Hong Kong | SEHK | HK |
| 🇨🇳 China | Shanghai, Shenzhen | CN_SHA, CN_SZA |
| 🇻🇳 Vietnam | Hanoi, Ho Chi Minh City | VN_HNX, VN_HSX |

Select active markets with the `ENABLED_MARKETS` environment variable (default: `KR,US`).

## Telegram (Optional)

Receive real-time Telegram alerts for trade executions, circuit breaker trips, system status, and more.

### Quick Setup

1. **Create a bot**: Message [@BotFather](https://t.me/BotFather) on Telegram → `/newbot`
2. **Get your chat ID**: Message [@userinfobot](https://t.me/userinfobot) → `/start`
3. **Configure**: Add to `.env`:
   ```bash
   TELEGRAM_BOT_TOKEN=1234567890:ABCdefGHIjklMNOpqrsTUVwxyz
   TELEGRAM_CHAT_ID=123456789
   TELEGRAM_ENABLED=true
   ```
4. **Test**: Start a conversation with the bot (send `/start`), then run the agent

**Full documentation**: [src/notifications/README.md](src/notifications/README.md)

### Notification Types

- 🟢 Trade execution alerts (BUY/SELL + confidence)
- 🚨 Circuit breaker trips (automatic trading halt)
- ⚠️ Fat-finger rejections (oversized orders blocked)
- ℹ️ Market open/close notifications
- 📝 System startup/shutdown status

### Bidirectional Commands

With `TELEGRAM_COMMANDS_ENABLED=true` (the default), 9 interactive commands are supported:

| Command | Description |
|--------|------|
| `/help` | List available commands |
| `/status` | Trading status (mode, markets, P&L) |
| `/positions` | Account summary (balance, cash, P&L) |
| `/report` | Daily summary (trade count, P&L, win rate) |
| `/scenarios` | Today's playbook scenarios |
| `/review` | Recent scorecards (L6_DAILY) |
| `/dashboard` | Show the dashboard URL |
| `/stop` | Pause trading |
| `/resume` | Resume trading |

**Fail-safe**: Trading continues even if notifications fail.

## Tests

35 tests were written before implementation, TDD-style.
551 tests are implemented across 25 files. Minimum coverage: 80%.

```
tests/test_risk.py               — circuit breaker, fat finger, integration checks (11)
tests/test_broker.py             — token management, timeouts, HTTP errors, hash key (6)
tests/test_brain.py              — JSON parsing, confidence threshold, malformed responses (15)
tests/test_scenario_engine.py    — scenario matching (44)
tests/test_data_integration.py   — external data integration (38)
tests/test_pre_market_planner.py — playbook generation (37)
tests/test_main.py               — trading loop integration (37)
tests/test_token_efficiency.py   — token optimization (34)
tests/test_strategy_models.py    — strategy model validation (33)
tests/test_telegram_commands.py  — Telegram commands (31)
tests/test_latency_control.py    — latency control (30)
tests/test_telegram.py           — Telegram notifications (25)
... plus 16 more files
```

**Details**: [docs/testing.md](docs/testing.md)

## Tech Stack

- **Language**: Python 3.11+ (asyncio-based)
- **Broker**: KIS Open API (REST)
- **Broker**: KIS Open API (REST, domestic + overseas)
- **AI**: Google Gemini Pro
- **DB**: SQLite
- **Validation**: pytest + coverage
- **DB**: SQLite (5 tables: trades, contexts, decision_logs, playbooks, context_metadata)
- **Dashboard**: FastAPI + uvicorn
- **Validation**: pytest + coverage (551 tests)
- **CI/CD**: GitHub Actions
- **Deployment**: Docker + Docker Compose
@@ -99,26 +185,50 @@ tests/test_brain.py — JSON 파싱, 신뢰도 임계값, 비정상 응답 처

```
The-Ouroboros/
├── .github/workflows/ci.yml    # CI pipeline
├── docs/
│   ├── agents.md               # AI agent persona definitions
│   └── skills.md               # Available tool list
│   ├── architecture.md         # System architecture
│   ├── testing.md              # Testing guide
│   ├── commands.md             # Command reference
│   ├── context-tree.md         # L1-L7 memory system
│   ├── workflow.md             # Git workflow
│   ├── agents.md               # Agent policies
│   ├── skills.md               # Tool list
│   ├── disaster_recovery.md    # Backup/recovery
│   └── requirements-log.md     # Requirements log
├── src/
│   ├── analysis/               # Technical analysis (RSI, ATR, Smart Scanner)
│   ├── backup/                 # Backups (scheduler, S3, integrity checks)
│   ├── brain/                  # Gemini decisions (prompt optimization, context selection)
│   ├── broker/                 # KIS API (domestic + overseas)
│   ├── context/                # L1-L7 hierarchical memory
│   ├── core/                   # Risk management (READ-ONLY)
│   ├── dashboard/              # FastAPI monitoring dashboard
│   ├── data/                   # External data integration
│   ├── evolution/              # Strategy evolution + Daily Review
│   ├── logging/                # Decision audit trail
│   ├── markets/                # Market schedules + timezones
│   ├── notifications/          # Telegram alerts + commands
│   ├── strategy/               # Playbooks (Planner, Scenario Engine)
│   ├── config.py               # Pydantic settings
│   ├── logging_config.py       # Structured JSON logging
│   ├── db.py                   # SQLite trade records
│   ├── main.py                 # Async trading loop
│   ├── broker/kis_api.py       # KIS API client
│   ├── brain/gemini_client.py  # Gemini decision engine
│   ├── core/risk_manager.py    # Risk management
│   ├── evolution/optimizer.py  # Strategy evolution engine
│   └── strategies/base.py      # Strategy base class
├── tests/                      # TDD test suite
│   ├── db.py                   # SQLite database
│   └── main.py                 # Async trading loop
├── tests/                      # 551 tests (25 files)
├── Dockerfile                  # Multi-stage build
├── docker-compose.yml          # Service orchestration
└── pyproject.toml              # Dependencies and tool configuration
```

## Documentation

- **[Architecture](docs/architecture.md)** — System design, components, data flow
- **[Testing](docs/testing.md)** — Test structure, coverage, writing guide
- **[Commands](docs/commands.md)** — CLI, dashboard, and Telegram commands
- **[Context Tree](docs/context-tree.md)** — L1-L7 hierarchical memory
- **[Workflow](docs/workflow.md)** — Git workflow policy
- **[Agent Policies](docs/agents.md)** — Safety constraints, prohibited actions
- **[Backup/Recovery](docs/disaster_recovery.md)** — Disaster recovery procedures
- **[Requirements](docs/requirements-log.md)** — User requirement tracking

## License

See the [LICENSE](LICENSE) file for this project's license.
docs/agent-constraints.md (45, new file)
@@ -0,0 +1,45 @@
# Agent Constraints

This document records **persistent behavioral constraints** for agents working on this repository.
It is distinct from `docs/requirements-log.md`, which records **project/product requirements**.

## Scope

- Applies to all AI agents and automation that modify this repo.
- Supplements (does not replace) `docs/agents.md` and `docs/workflow.md`.

## Persistent Rules

1. **Workflow enforcement**
   - Follow `docs/workflow.md` for all changes.
   - Create a Gitea issue before any code or documentation change.
   - Work on a feature branch `feature/issue-{N}-{short-description}` and open a PR.
   - Never commit directly to `main`.

2. **Document-first routing**
   - When performing work, consult relevant `docs/` files *before* making changes.
   - Route decisions to the documented policy whenever applicable.
   - If guidance conflicts, prefer the stricter/safety-first rule and note it in the PR.

3. **Docs with code**
   - Any code change must be accompanied by relevant documentation updates.
   - If no doc update is needed, state the reason explicitly in the PR.

4. **Session-persistent user constraints**
   - If the user requests that a behavior should persist across sessions, record it here
     (or in a dedicated policy doc) and reference it when working.
   - Keep entries short and concrete, with dates.

## Change Control

- Changes to this file follow the same workflow as code changes.
- Keep the history chronological and minimize rewording of existing entries.

## History

### 2026-02-08

- Always enforce Gitea workflow: issue -> feature branch -> PR before changes.
- When work requires guidance, consult the relevant `docs/` policies first.
- Any code change must be accompanied by relevant documentation updates.
- Persist user constraints across sessions by recording them in this document.
@@ -2,7 +2,44 @@

## Overview

Self-evolving AI trading agent for global stock markets via KIS (Korea Investment & Securities) API. The main loop in `src/main.py` orchestrates four components in a 60-second cycle per stock across multiple markets.
Self-evolving AI trading agent for global stock markets via KIS (Korea Investment & Securities) API. The main loop in `src/main.py` orchestrates components across multiple markets with two trading modes: daily (batch API calls) or realtime (per-stock decisions).

**v2 Proactive Playbook Architecture**: The system uses a "plan once, execute locally" approach. Pre-market, the AI generates a playbook of scenarios (one Gemini API call per market per day). During trading hours, a local scenario engine matches live market data against these pre-computed scenarios — no additional AI calls needed. This dramatically reduces API costs and latency.

## Trading Modes

The system supports two trading frequency modes controlled by the `TRADE_MODE` environment variable:

### Daily Mode (default)

Optimized for Gemini Free tier API limits (20 calls/day):

- **Batch decisions**: 1 API call per market per session
- **Fixed schedule**: 4 sessions per day at 6-hour intervals (configurable)
- **API efficiency**: Processes all stocks in a market simultaneously
- **Use case**: Free tier users, cost-conscious deployments
- **Configuration**:
  ```bash
  TRADE_MODE=daily
  DAILY_SESSIONS=4          # Sessions per day (1-10)
  SESSION_INTERVAL_HOURS=6  # Hours between sessions (1-24)
  ```

**Example**: With 2 markets (US, KR) and 4 sessions/day = 8 API calls/day (within 20 call limit)

### Realtime Mode

High-frequency trading with individual stock analysis:

- **Per-stock decisions**: 1 API call per stock per cycle
- **60-second interval**: Continuous monitoring
- **Use case**: Production deployments with Gemini paid tier
- **Configuration**:
  ```bash
  TRADE_MODE=realtime
  ```

**Note**: Realtime mode requires Gemini API subscription due to high call volume.

## Core Components

@@ -11,9 +48,11 @@ Self-evolving AI trading agent for global stock markets via KIS (Korea Investmen
**KISBroker** (`kis_api.py`) — Async KIS API client for domestic Korean market

- Automatic OAuth token refresh (valid for 24 hours)
- Leaky-bucket rate limiter (10 requests per second)
- Leaky-bucket rate limiter (configurable RPS, default 2.0) — see the sketch below
- POST body hash-key signing for order authentication
- Custom SSL context with disabled hostname verification for VTS (virtual trading) endpoint due to known certificate mismatch
- `fetch_market_rankings()` — Fetch volume surge rankings from KIS API
- `get_daily_prices()` — Fetch OHLCV history for technical analysis
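A minimal sketch of a leaky-bucket limiter of this kind, assuming the default 2.0 RPS; the class and method names are illustrative, not the project's actual API.

```python
import asyncio
import time

class LeakyBucketLimiter:
    """Allow at most `rps` requests per second by spacing out acquisitions."""

    def __init__(self, rps: float = 2.0) -> None:
        self.min_interval = 1.0 / rps     # 0.5 s between calls at 2 RPS
        self._last = 0.0
        self._lock = asyncio.Lock()

    async def acquire(self) -> None:
        async with self._lock:
            now = time.monotonic()
            wait = self.min_interval - (now - self._last)
            if wait > 0:                   # too soon: let the bucket drain first
                await asyncio.sleep(wait)
            self._last = time.monotonic()

# Usage inside an API client (hypothetical call site):
#   await limiter.acquire()
#   response = await session.get(url)
```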
**OverseasBroker** (`overseas.py`) — KIS overseas stock API wrapper

@@ -28,10 +67,47 @@ Self-evolving AI trading agent for global stock markets via KIS (Korea Investmen
- `is_market_open()` checks weekends, trading hours, lunch breaks
- `get_open_markets()` returns currently active markets
- `get_next_market_open()` finds next market to open and when
- 10 global markets defined (KR, US_NASDAQ, US_NYSE, US_AMEX, JP, HK, CN_SHA, CN_SZA, VN_HNX, VN_HSX)

### 2. Brain (`src/brain/gemini_client.py`)
**Overseas Ranking API Methods** (added in v0.10.x):
- `fetch_overseas_rankings()` — Fetch overseas ranking universe (fluctuation / volume)
- Ranking endpoint paths and TR_IDs are configurable via environment variables

**GeminiClient** — AI decision engine powered by Google Gemini
### 2. Analysis (`src/analysis/`)

**VolatilityAnalyzer** (`volatility.py`) — Technical indicator calculations

- ATR (Average True Range) for volatility measurement
- RSI (Relative Strength Index) using Wilder's smoothing method — sketched below
- Price change percentages across multiple timeframes
- Volume surge ratios and price-volume divergence
- Momentum scoring (0-100 scale)
- Breakout/breakdown pattern detection
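A minimal sketch of RSI with Wilder's smoothing, in plain Python for illustration; the project's actual `VolatilityAnalyzer` interface may differ.

```python
def rsi_wilder(closes: list[float], period: int = 14) -> float:
    """RSI(period) over closing prices, using Wilder's recursive smoothing."""
    if len(closes) < period + 1:
        raise ValueError("need at least period + 1 closing prices")
    gains, losses = [], []
    for prev, curr in zip(closes, closes[1:]):
        change = curr - prev
        gains.append(max(change, 0.0))
        losses.append(max(-change, 0.0))
    # seed with simple averages, then smooth: avg = (avg*(n-1) + new) / n
    avg_gain = sum(gains[:period]) / period
    avg_loss = sum(losses[:period]) / period
    for g, l in zip(gains[period:], losses[period:]):
        avg_gain = (avg_gain * (period - 1) + g) / period
        avg_loss = (avg_loss * (period - 1) + l) / period
    if avg_loss == 0:
        return 100.0
    rs = avg_gain / avg_loss
    return 100.0 - 100.0 / (1.0 + rs)
```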
**SmartVolatilityScanner** (`smart_scanner.py`) — Python-first filtering pipeline

- **Domestic (KR)**:
  - **Step 1**: Fetch domestic fluctuation ranking as primary universe
  - **Step 2**: Fetch domestic volume ranking for liquidity bonus
  - **Step 3**: Compute volatility-first score (max of daily change% and intraday range%) — see the scoring sketch below
  - **Step 4**: Apply liquidity bonus and return top N candidates
- **Overseas (US/JP/HK/CN/VN)**:
  - **Step 1**: Fetch overseas ranking universe (fluctuation rank + volume rank bonus)
  - **Step 2**: Compute volatility-first score (max of daily change% and intraday range%)
  - **Step 3**: Apply liquidity bonus from volume ranking
  - **Step 4**: Return top N candidates (default 3)
- **Fallback (overseas only)**: If ranking API is unavailable, uses dynamic universe
  from runtime active symbols + recent traded symbols + current holdings (no static watchlist)
- **Realtime mode only**: Daily mode uses batch processing for API efficiency
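A minimal sketch of the volatility-first scoring described in the steps above. The max-of-change-and-range part follows the text; the liquidity-bonus weighting and field names are assumptions for illustration, not the project's exact formula.

```python
def volatility_score(daily_change_pct: float, intraday_range_pct: float,
                     in_volume_ranking: bool, liquidity_bonus: float = 5.0) -> float:
    """Score = max(|daily change %|, intraday range %), plus a bonus for liquid names."""
    score = max(abs(daily_change_pct), intraday_range_pct)
    if in_volume_ranking:            # also appears in the volume ranking
        score += liquidity_bonus
    return score

def top_candidates(universe: list[dict], n: int = 3) -> list[dict]:
    """Return the top-N scored candidates from a pre-scored universe."""
    return sorted(universe, key=lambda s: s["score"], reverse=True)[:n]
```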
**Benefits:**
- Reduces Gemini API calls from 20-30 stocks to 1-3 qualified candidates
- Fast Python-based filtering before expensive AI judgment
- Logs selection context (RSI-compatible proxy, volume_ratio, signal, score) for Evolution system

### 3. Brain (`src/brain/`)

**GeminiClient** (`gemini_client.py`) — AI decision engine powered by Google Gemini

- Constructs structured prompts from market data
- Parses JSON responses into `TradeDecision` objects (`action`, `confidence`, `rationale`)

@@ -39,11 +115,20 @@ Self-evolving AI trading agent for global stock markets via KIS (Korea Investmen
- Falls back to safe HOLD on any parse/API error
- Handles markdown-wrapped JSON, malformed responses, invalid actions

### 3. Risk Manager (`src/core/risk_manager.py`)
**PromptOptimizer** (`prompt_optimizer.py`) — Token efficiency optimization

- Reduces prompt size while preserving decision quality
- Caches optimized prompts

**ContextSelector** (`context_selector.py`) — Relevant context selection for prompts

- Selects appropriate context layers for current market conditions

### 4. Risk Manager (`src/core/risk_manager.py`)

**RiskManager** — Safety circuit breaker and order validation

⚠️ **READ-ONLY by policy** (see [`docs/agents.md`](./agents.md))
> **READ-ONLY by policy** (see [`docs/agents.md`](./agents.md))

- **Circuit Breaker**: Halts all trading via `SystemExit` when daily P&L drops below -3.0%
  - Threshold may only be made stricter, never relaxed

@@ -51,9 +136,106 @@ Self-evolving AI trading agent for global stock markets via KIS (Korea Investmen
- **Fat-Finger Protection**: Rejects orders exceeding 30% of available cash
  - Must always be enforced, cannot be disabled
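A minimal sketch of how these two checks can be enforced together, assuming the -3.0% and 30% values above; the real `RiskManager` interface is not shown here.

```python
MAX_DAILY_LOSS_PCT = -3.0   # circuit breaker threshold
MAX_ORDER_PCT = 30.0        # fat-finger limit, as % of available cash

def validate_order(daily_pnl_pct: float, order_value: float, available_cash: float) -> None:
    """Raise on any violation; the caller only trades when this returns silently."""
    if daily_pnl_pct <= MAX_DAILY_LOSS_PCT:
        # circuit breaker: halt everything, not just this order
        raise SystemExit(f"Circuit breaker tripped: daily P&L {daily_pnl_pct:.2f}%")
    if order_value > available_cash * MAX_ORDER_PCT / 100.0:
        raise ValueError(
            f"Fat-finger rejection: order {order_value:.0f} exceeds "
            f"{MAX_ORDER_PCT}% of cash {available_cash:.0f}"
        )
```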
### 4. Evolution (`src/evolution/optimizer.py`)
### 5. Strategy (`src/strategy/`)

**StrategyOptimizer** — Self-improvement loop
**Pre-Market Planner** (`pre_market_planner.py`) — AI playbook generation

- Runs before market open (configurable `PRE_MARKET_MINUTES`, default 30)
- Generates scenario-based playbooks via single Gemini API call per market
- Handles timeout (`PLANNER_TIMEOUT_SECONDS`, default 60) with defensive playbook fallback
- Persists playbooks to database for audit trail

**Scenario Engine** (`scenario_engine.py`) — Local scenario matching

- Matches live market data against pre-computed playbook scenarios
- No AI calls during trading hours — pure Python matching logic, sketched below
- Returns matched scenarios with confidence scores
- Configurable `MAX_SCENARIOS_PER_STOCK` (default 5)
- Periodic rescan at `RESCAN_INTERVAL_SECONDS` (default 300)
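A minimal sketch of local scenario matching of this kind. The data shapes and trigger fields are assumptions for illustration; the project's actual `Scenario`/`MatchResult` models are defined in `models.py`.

```python
from dataclasses import dataclass

@dataclass
class Scenario:
    name: str
    trigger_above: float | None   # fire when price >= this level
    trigger_below: float | None   # fire when price <= this level
    action: str                   # "BUY" / "SELL" / "HOLD"
    confidence: int

def match_scenarios(price: float, scenarios: list[Scenario]) -> list[Scenario]:
    """Pure-Python matching: no AI call, just compare live price to pre-computed levels."""
    hits = []
    for s in scenarios:
        if s.trigger_above is not None and price >= s.trigger_above:
            hits.append(s)
        elif s.trigger_below is not None and price <= s.trigger_below:
            hits.append(s)
    return sorted(hits, key=lambda s: s.confidence, reverse=True)  # best first
```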
**Playbook Store** (`playbook_store.py`) — Playbook persistence

- SQLite-backed storage for daily playbooks
- Date and market-based retrieval
- Status tracking (generated, active, expired)

**Models** (`models.py`) — Pydantic data models

- Scenario, Playbook, MatchResult, and related type definitions

### 6. Context System (`src/context/`)

**Context Store** (`store.py`) — L1-L7 hierarchical memory

- 7-layer context system (see [docs/context-tree.md](./context-tree.md)):
  - L1: Tick-level (real-time price)
  - L2: Intraday (session summary)
  - L3: Daily (end-of-day)
  - L4: Weekly (trend analysis)
  - L5: Monthly (strategy review)
  - L6: Daily Review (scorecard)
  - L7: Evolution (long-term learning)
- Key-value storage with timeframe tagging
- SQLite persistence in `contexts` table

**Context Scheduler** (`scheduler.py`) — Periodic aggregation

- Scheduled summarization from lower to higher layers
- Configurable aggregation intervals

**Context Summarizer** (`summarizer.py`) — Layer summarization

- Aggregates lower-layer data into higher-layer summaries

### 7. Dashboard (`src/dashboard/`)

**FastAPI App** (`app.py`) — Read-only monitoring dashboard

- Runs as daemon thread when enabled (`--dashboard` CLI flag or `DASHBOARD_ENABLED=true`)
- Configurable host/port (`DASHBOARD_HOST`, `DASHBOARD_PORT`, default `127.0.0.1:8080`)
- Serves static HTML frontend

**8 API Endpoints:**

| Endpoint | Method | Description |
|----------|--------|-------------|
| `/` | GET | Static HTML dashboard |
| `/api/status` | GET | Daily trading status by market |
| `/api/playbook/{date}` | GET | Playbook for specific date and market |
| `/api/scorecard/{date}` | GET | Daily scorecard from L6_DAILY context |
| `/api/performance` | GET | Trading performance metrics (by market + combined) |
| `/api/context/{layer}` | GET | Query context by layer (L1-L7) |
| `/api/decisions` | GET | Decision log entries with outcomes |
| `/api/scenarios/active` | GET | Today's matched scenarios |

### 8. Notifications (`src/notifications/telegram_client.py`)

**TelegramClient** — Real-time event notifications via Telegram Bot API

- Sends alerts for trades, circuit breakers, fat-finger rejections, system events
- Non-blocking: failures are logged but never crash trading
- Rate-limited: 1 message/second default to respect Telegram API limits
- Auto-disabled when credentials missing

**TelegramCommandHandler** — Bidirectional command interface

- Long polling from Telegram API (configurable `TELEGRAM_POLLING_INTERVAL`)
- 9 interactive commands: `/help`, `/status`, `/positions`, `/report`, `/scenarios`, `/review`, `/dashboard`, `/stop`, `/resume`
- Authorization filtering by `TELEGRAM_CHAT_ID`
- Enable/disable via `TELEGRAM_COMMANDS_ENABLED` (default: true)

**Notification Types:**
- Trade execution (BUY/SELL with confidence)
- Circuit breaker trips (critical alert)
- Fat-finger protection triggers (order rejection)
- Market open/close events
- System startup/shutdown status
- Playbook generation results
- Stop-loss monitoring alerts
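A minimal sketch of the fail-safe pattern described above: every send is wrapped so notification errors are logged and swallowed, never propagated into the trading loop. The client API shown is assumed for illustration, not the actual TelegramClient code.

```python
import logging

logger = logging.getLogger("notifications")

async def notify_safely(client, text: str) -> None:
    """Best-effort notification: log failures, never raise into the trading loop."""
    if client is None:                    # credentials missing -> notifications disabled
        return
    try:
        await client.send_message(text)   # hypothetical send call
    except Exception as exc:              # deliberately broad: nothing may escape
        logger.warning("Telegram notification failed: %s", exc)
```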
### 9. Evolution (`src/evolution/`)

**StrategyOptimizer** (`optimizer.py`) — Self-improvement loop

- Analyzes high-confidence losing trades from SQLite
- Asks Gemini to generate new `BaseStrategy` subclasses

@@ -61,11 +243,127 @@ Self-evolving AI trading agent for global stock markets via KIS (Korea Investmen
- Simulates PR creation for human review
- Only activates strategies that pass all tests

**DailyReview** (`daily_review.py`) — End-of-day review

- Generates comprehensive trade performance summary
- Stores results in L6_DAILY context layer
- Tracks win rate, P&L, confidence accuracy

**DailyScorecard** (`scorecard.py`) — Performance scoring

- Calculates daily metrics (trades, P&L, win rate, avg confidence)
- Enables trend tracking across days

**Stop-Loss Monitoring** — Real-time position protection

- Monitors positions against stop-loss levels from playbook scenarios
- Sends Telegram alerts when thresholds approached or breached

### 10. Decision Logger (`src/logging/decision_logger.py`)

**DecisionLogger** — Comprehensive audit trail

- Logs every trading decision with full context snapshot
- Captures input data, rationale, confidence, and outcomes
- Supports outcome tracking (P&L, accuracy) for post-analysis
- Stored in `decision_logs` table with indexed queries
- Review workflow support (reviewed flag, review notes)

### 11. Data Integration (`src/data/`)

**External Data Sources** (optional):

- `news_api.py` — News sentiment data
- `market_data.py` — Extended market data
- `economic_calendar.py` — Economic event calendar

### 12. Backup (`src/backup/`)

**Disaster Recovery** (see [docs/disaster_recovery.md](./disaster_recovery.md)):

- `scheduler.py` — Automated backup scheduling
- `exporter.py` — Data export to various formats
- `cloud_storage.py` — S3-compatible cloud backup
- `health_monitor.py` — Backup integrity verification
## Data Flow

### Playbook Mode (Daily — Primary v2 Flow)

```
┌─────────────────────────────────────────────────────────────┐
│ Main Loop (60s cycle per stock, per market) │
│ Pre-Market Phase (before market open) │
└─────────────────────────────────────────────────────────────┘
                  │
                  ▼
┌──────────────────────────────────┐
│ Pre-Market Planner │
│ - 1 Gemini API call per market │
│ - Generate scenario playbook │
│ - Store in playbooks table │
└──────────────────┬───────────────┘
                   │
                   ▼
┌─────────────────────────────────────────────────────────────┐
│ Trading Hours (market open → close) │
└─────────────────────────────────────────────────────────────┘
                  │
                  ▼
┌──────────────────────────────────┐
│ Market Schedule Check │
│ - Get open markets │
│ - Filter by enabled markets │
└──────────────────┬───────────────┘
                   │
                   ▼
┌──────────────────────────────────┐
│ Scenario Engine (local) │
│ - Match live data vs playbook │
│ - No AI calls needed │
│ - Return matched scenarios │
└──────────────────┬───────────────┘
                   │
                   ▼
┌──────────────────────────────────┐
│ Risk Manager: Validate Order │
│ - Check circuit breaker │
│ - Check fat-finger limit │
└──────────────────┬───────────────┘
                   │
                   ▼
┌──────────────────────────────────┐
│ Broker: Execute Order │
│ - Domestic: send_order() │
│ - Overseas: send_overseas_order()│
└──────────────────┬───────────────┘
                   │
                   ▼
┌──────────────────────────────────┐
│ Decision Logger + DB │
│ - Full audit trail │
│ - Context snapshot │
│ - Telegram notification │
└──────────────────┬───────────────┘
                   │
                   ▼
┌─────────────────────────────────────────────────────────────┐
│ Post-Market Phase │
└─────────────────────────────────────────────────────────────┘
                  │
                  ▼
┌──────────────────────────────────┐
│ Daily Review + Scorecard │
│ - Performance summary │
│ - Store in L6_DAILY context │
│ - Evolution learning │
└──────────────────────────────────┘
```
### Realtime Mode (with Smart Scanner)

```
┌─────────────────────────────────────────────────────────────┐
│ Main Loop (60s cycle per market) │
└─────────────────────────────────────────────────────────────┘
                  │
                  ▼
@@ -74,58 +372,69 @@ Self-evolving AI trading agent for global stock markets via KIS (Korea Investmen
│ - Get open markets │
│ - Filter by enabled markets │
│ - Wait if all closed │
└──────────────────┬───────────────┘
                   │
                   ▼
┌──────────────────────────────────┐
│ Smart Scanner (Python-first) │
│ - Domestic: fluctuation rank │
│   + volume rank bonus │
│   + volatility-first scoring │
│ - Overseas: ranking universe │
│   + volatility-first scoring │
│ - Fallback: dynamic universe │
│ - Return top 3 qualified stocks │
└──────────────────┬───────────────┘
                   │
                   ▼
┌──────────────────────────────────┐
│ For Each Qualified Candidate │
└──────────────────┬───────────────┘
                   │
                   ▼
┌──────────────────────────────────┐
│ Broker: Fetch Market Data │
│ - Domestic: orderbook + balance │
│ - Overseas: price + balance │
└──────────────────┬───────────────┘
                   │
                   ▼
┌──────────────────────────────────┐
│ Calculate P&L │
│ pnl_pct = (eval - cost) / cost │
└──────────────────┬───────────────┘
                   │
                   ▼
┌──────────────────────────────────┐
│ Brain: Get Decision │
│ Brain: Get Decision (AI) │
│ - Build prompt with market data │
│ - Call Gemini API │
│ - Parse JSON response │
│ - Return TradeDecision │
└──────────────────┬───────────────┘
                   │
                   ▼
┌──────────────────────────────────┐
│ Risk Manager: Validate Order │
│ - Check circuit breaker │
│ - Check fat-finger limit │
│ - Raise if validation fails │
└──────────────────┬───────────────┘
                   │
                   ▼
┌──────────────────────────────────┐
│ Broker: Execute Order │
│ - Domestic: send_order() │
│ - Overseas: send_overseas_order()│
└──────────────────┬───────────────┘
                   │
                   ▼
┌──────────────────────────────────┐
│ Database: Log Trade │
│ - SQLite (data/trades.db) │
│ - Track: action, confidence, │
│   rationale, market, exchange │
└──────────────────────────────────┘
┌──────────────────────────────────┐
│ Decision Logger + Notifications │
│ - Log trade to SQLite │
│ - selection_context (JSON) │
│ - Telegram notification │
└──────────────────────────────────┘
```
## Database Schema

**SQLite** (`src/db.py`)
**SQLite** (`src/db.py`) — Database: `data/trades.db`

### trades
```sql
CREATE TABLE trades (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -137,12 +446,73 @@ CREATE TABLE trades (
    quantity INTEGER,
    price REAL,
    pnl REAL DEFAULT 0.0,
    market TEXT DEFAULT 'KR',          -- KR | US_NASDAQ | JP | etc.
    exchange_code TEXT DEFAULT 'KRX'   -- KRX | NASD | NYSE | etc.
    market TEXT DEFAULT 'KR',
    exchange_code TEXT DEFAULT 'KRX',
    selection_context TEXT,            -- JSON: {rsi, volume_ratio, signal, score}
    decision_id TEXT                   -- Links to decision_logs
);
```

Auto-migration: Adds `market` and `exchange_code` columns if missing for backward compatibility.

### contexts
```sql
CREATE TABLE contexts (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    layer TEXT NOT NULL,               -- L1 through L7
    timeframe TEXT,
    key TEXT NOT NULL,
    value TEXT NOT NULL,               -- JSON data
    created_at TEXT NOT NULL,
    updated_at TEXT NOT NULL
);
-- Indices: idx_contexts_layer, idx_contexts_timeframe, idx_contexts_updated
```

### decision_logs
```sql
CREATE TABLE decision_logs (
    decision_id TEXT PRIMARY KEY,
    timestamp TEXT NOT NULL,
    stock_code TEXT,
    market TEXT,
    exchange_code TEXT,
    action TEXT,
    confidence INTEGER,
    rationale TEXT,
    context_snapshot TEXT,             -- JSON: full context at decision time
    input_data TEXT,                   -- JSON: market data used
    outcome_pnl REAL,
    outcome_accuracy REAL,
    reviewed INTEGER DEFAULT 0,
    review_notes TEXT
);
-- Indices: idx_decision_logs_timestamp, idx_decision_logs_reviewed, idx_decision_logs_confidence
```

### playbooks
```sql
CREATE TABLE playbooks (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    date TEXT NOT NULL,
    market TEXT NOT NULL,
    status TEXT DEFAULT 'generated',
    playbook_json TEXT NOT NULL,       -- Full playbook with scenarios
    generated_at TEXT NOT NULL,
    token_count INTEGER,
    scenario_count INTEGER,
    match_count INTEGER DEFAULT 0
);
-- Indices: idx_playbooks_date, idx_playbooks_market
```

### context_metadata
```sql
CREATE TABLE context_metadata (
    layer TEXT PRIMARY KEY,
    description TEXT,
    retention_days INTEGER,
    aggregation_source TEXT
);
```
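A minimal sketch of reading these tables with the standard library, joining trades to their decisions via `decision_id`. Illustrative only — the project's actual accessors live in `src/db.py`.

```python
import sqlite3

def recent_decisions(db_path: str = "data/trades.db", limit: int = 5) -> list[tuple]:
    """Fetch the latest decision-log rows with the P&L of any linked trade."""
    with sqlite3.connect(db_path) as conn:
        return conn.execute(
            """
            SELECT d.timestamp, d.stock_code, d.action, d.confidence, t.pnl
            FROM decision_logs AS d
            LEFT JOIN trades AS t ON t.decision_id = d.decision_id
            ORDER BY d.timestamp DESC
            LIMIT ?
            """,
            (limit,),
        ).fetchall()
```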
## Configuration

@@ -157,13 +527,81 @@ KIS_APP_SECRET=your_app_secret
KIS_ACCOUNT_NO=XXXXXXXX-XX
GEMINI_API_KEY=your_gemini_key

# Optional
# Optional — Trading Mode
MODE=paper                  # paper | live
TRADE_MODE=daily            # daily | realtime
DAILY_SESSIONS=4            # Sessions per day (daily mode only)
SESSION_INTERVAL_HOURS=6    # Hours between sessions (daily mode only)

# Optional — Database
DB_PATH=data/trades.db

# Optional — Risk
CONFIDENCE_THRESHOLD=80
MAX_LOSS_PCT=3.0
MAX_ORDER_PCT=30.0
ENABLED_MARKETS=KR,US_NASDAQ   # Comma-separated market codes

# Optional — Markets
ENABLED_MARKETS=KR,US          # Comma-separated market codes
RATE_LIMIT_RPS=2.0             # KIS API requests per second

# Optional — Pre-Market Planner (v2)
PRE_MARKET_MINUTES=30                 # Minutes before market open to generate playbook
MAX_SCENARIOS_PER_STOCK=5             # Max scenarios per stock in playbook
PLANNER_TIMEOUT_SECONDS=60            # Timeout for playbook generation
DEFENSIVE_PLAYBOOK_ON_FAILURE=true    # Fallback on AI failure
RESCAN_INTERVAL_SECONDS=300           # Scenario rescan interval during trading

# Optional — Smart Scanner (realtime mode only)
RSI_OVERSOLD_THRESHOLD=30    # 0-50, oversold threshold
RSI_MOMENTUM_THRESHOLD=70    # 50-100, momentum threshold
VOL_MULTIPLIER=2.0           # Minimum volume ratio (2.0 = 200%)
SCANNER_TOP_N=3              # Max qualified candidates per scan

# Optional — Dashboard
DASHBOARD_ENABLED=false      # Enable FastAPI dashboard
DASHBOARD_HOST=127.0.0.1     # Dashboard bind address
DASHBOARD_PORT=8080          # Dashboard port (1-65535)

# Optional — Telegram
TELEGRAM_BOT_TOKEN=1234567890:ABCdefGHIjklMNOpqrsTUVwxyz
TELEGRAM_CHAT_ID=123456789
TELEGRAM_ENABLED=true
TELEGRAM_COMMANDS_ENABLED=true    # Enable bidirectional commands
TELEGRAM_POLLING_INTERVAL=1.0     # Command polling interval (seconds)

# Optional — Backup
BACKUP_ENABLED=false
BACKUP_DIR=data/backups
S3_ENDPOINT_URL=...
S3_ACCESS_KEY=...
S3_SECRET_KEY=...
S3_BUCKET_NAME=...
S3_REGION=...

# Optional — External Data
NEWS_API_KEY=...
NEWS_API_PROVIDER=...
MARKET_DATA_API_KEY=...

# Position Sizing (optional)
POSITION_SIZING_ENABLED=true
POSITION_BASE_ALLOCATION_PCT=5.0
POSITION_MIN_ALLOCATION_PCT=1.0
POSITION_MAX_ALLOCATION_PCT=10.0
POSITION_VOLATILITY_TARGET_SCORE=50.0

# Legacy/compat scanner thresholds (kept for backward compatibility)
RSI_OVERSOLD_THRESHOLD=30
RSI_MOMENTUM_THRESHOLD=70
VOL_MULTIPLIER=2.0

# Overseas Ranking API (optional override; account-dependent)
OVERSEAS_RANKING_ENABLED=true
OVERSEAS_RANKING_FLUCT_TR_ID=HHDFS76200100
OVERSEAS_RANKING_VOLUME_TR_ID=HHDFS76200200
OVERSEAS_RANKING_FLUCT_PATH=/uapi/overseas-price/v1/quotations/inquire-updown-rank
OVERSEAS_RANKING_VOLUME_PATH=/uapi/overseas-price/v1/quotations/inquire-volume-rank
```

Tests use in-memory SQLite (`DB_PATH=":memory:"`) and dummy credentials via `tests/conftest.py`.

@@ -189,3 +627,17 @@ Tests use in-memory SQLite (`DB_PATH=":memory:"`) and dummy credentials via `tes
- Wait until next market opens
- Use `get_next_market_open()` to calculate wait time
- Sleep until market open time

### Telegram API Errors
- Log warning but continue trading
- Missing credentials → auto-disable notifications
- Network timeout → skip notification, no retry
- Invalid token → log error, trading unaffected
- Rate limit exceeded → queued via rate limiter

### Playbook Generation Failure
- Timeout → fall back to defensive playbook (`DEFENSIVE_PLAYBOOK_ON_FAILURE`)
- API error → use previous day's playbook if available
- No playbook → skip pre-market phase, fall back to direct AI calls

**Guarantee**: Notification and dashboard failures never interrupt trading operations.
@@ -119,7 +119,7 @@ No decorator needed for async tests.
# Install all dependencies (production + dev)
pip install -e ".[dev]"

# Run full test suite with coverage
# Run full test suite with coverage (551 tests across 25 files)
pytest -v --cov=src --cov-report=term-missing

# Run a single test file
@@ -137,11 +137,82 @@ mypy src/ --strict
# Run the trading agent
python -m src.main --mode=paper

# Run with dashboard enabled
python -m src.main --mode=paper --dashboard

# Docker
docker compose up -d ouroboros        # Run agent
docker compose --profile test up test # Run tests in container
```

## Dashboard

The FastAPI dashboard provides read-only monitoring of the trading system.

### Starting the Dashboard

```bash
# Via CLI flag
python -m src.main --mode=paper --dashboard

# Via environment variable
DASHBOARD_ENABLED=true python -m src.main --mode=paper
```

Dashboard runs as a daemon thread on `DASHBOARD_HOST:DASHBOARD_PORT` (default: `127.0.0.1:8080`).

### API Endpoints

| Endpoint | Description |
|----------|-------------|
| `GET /` | HTML dashboard UI |
| `GET /api/status` | Daily trading status by market |
| `GET /api/playbook/{date}` | Playbook for specific date (query: `market`) |
| `GET /api/scorecard/{date}` | Daily scorecard from L6_DAILY context |
| `GET /api/performance` | Performance metrics by market and combined |
| `GET /api/context/{layer}` | Context data by layer L1-L7 (query: `timeframe`) |
| `GET /api/decisions` | Decision log entries (query: `limit`, `market`) |
| `GET /api/scenarios/active` | Today's matched scenarios |

## Telegram Commands

When `TELEGRAM_COMMANDS_ENABLED=true` (default), the bot accepts these interactive commands:

| Command | Description |
|---------|-------------|
| `/help` | List available commands |
| `/status` | Show trading status (mode, markets, P&L) |
| `/positions` | Display account summary (balance, cash, P&L) |
| `/report` | Daily summary metrics (trades, P&L, win rate) |
| `/scenarios` | Show today's playbook scenarios |
| `/review` | Display recent scorecards (L6_DAILY layer) |
| `/dashboard` | Show dashboard URL if enabled |
| `/stop` | Pause trading |
| `/resume` | Resume trading |

Commands are only processed from the authorized `TELEGRAM_CHAT_ID`.

## KIS API TR_ID Reference

**Always check the official documentation first when adding or changing a TR_ID.**

Official documentation: `docs/한국투자증권_오픈API_전체문서_20260221_030000.xlsx`

> ⚠️ TR_IDs from unofficial sources (community blogs, GitHub examples, etc.) may be outdated or wrong.
> In practice, `VTTT1006U` (US sell — incorrect) lingered in the code for a long time (Issue #189).

### Key TR_IDs

| Type | Paper (VTS) TR_ID | Live TR_ID | Sheet |
|------|---------------|---------------|--------|
| Overseas stock buy (US) | `VTTT1002U` | `TTTT1002U` | 해외주식 주문 |
| Overseas stock sell (US) | `VTTT1001U` | `TTTT1006U` | 해외주식 주문 |

When a new TR_ID is needed:
1. Find the sheet for the relevant transaction type in the xlsx file above.
2. Use the exact value from the correct column, distinguishing paper (`VTTT`) from live (`TTTT`).
3. Leave a source comment in the code: `# Source: 한국투자증권_오픈API_전체문서 — '<시트명>' 시트`
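A small sketch of keeping these values in one place and selecting by mode, using only the US order TR_IDs from the table above; the helper name is illustrative.

```python
# TR_IDs from the official KIS API workbook ("해외주식 주문" sheet).
US_ORDER_TR_IDS = {
    ("paper", "BUY"):  "VTTT1002U",
    ("paper", "SELL"): "VTTT1001U",
    ("live",  "BUY"):  "TTTT1002U",
    ("live",  "SELL"): "TTTT1006U",
}

def us_order_tr_id(mode: str, side: str) -> str:
    """Pick the US order TR_ID for the given trading mode and order side."""
    return US_ORDER_TR_IDS[(mode, side)]
```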
## Environment Setup

```bash
docs/context-tree.md (338, new file)
@@ -0,0 +1,338 @@
|
||||
# Context Tree: Multi-Layered Memory Management
|
||||
|
||||
The context tree implements **Pillar 2** of The Ouroboros: hierarchical memory management across 7 time horizons, from real-time market data to generational trading wisdom.
|
||||
|
||||
## Overview
|
||||
|
||||
Instead of a flat memory structure, The Ouroboros maintains a **7-tier context tree** where each layer represents a different time horizon and level of abstraction:
|
||||
|
||||
```
|
||||
L1 (Legacy) ← Cumulative wisdom across generations
|
||||
↑
|
||||
L2 (Annual) ← Yearly performance metrics
|
||||
↑
|
||||
L3 (Quarterly) ← Quarterly strategy adjustments
|
||||
↑
|
||||
L4 (Monthly) ← Monthly portfolio rebalancing
|
||||
↑
|
||||
L5 (Weekly) ← Weekly stock selection
|
||||
↑
|
||||
L6 (Daily) ← Daily trade logs
|
||||
↑
|
||||
L7 (Real-time) ← Live market data
|
||||
```
|
||||
|
||||
Data flows **bottom-up**: real-time trades aggregate into daily summaries, which roll up to weekly, then monthly, quarterly, annual, and finally into permanent legacy knowledge.
|
||||
|
||||
## The 7 Layers
|
||||
|
||||
### L7: Real-time
|
||||
**Retention**: 7 days
|
||||
**Timeframe format**: `YYYY-MM-DD` (same-day)
|
||||
**Content**: Current positions, live quotes, orderbook snapshots, tick-by-tick volatility
|
||||
|
||||
**Use cases**:
|
||||
- Immediate execution decisions
|
||||
- Stop-loss triggers
|
||||
- Real-time P&L tracking
|
||||
|
||||
**Example keys**:
|
||||
- `current_position_{stock_code}`: Current holdings
|
||||
- `live_price_{stock_code}`: Latest quote
|
||||
- `volatility_5m_{stock_code}`: 5-minute rolling volatility
|
||||
|
||||
### L6: Daily
|
||||
**Retention**: 90 days
|
||||
**Timeframe format**: `YYYY-MM-DD`
|
||||
**Content**: Daily trade logs, end-of-day P&L, market summaries, decision accuracy
|
||||
|
||||
**Use cases**:
|
||||
- Daily performance review
|
||||
- Identify patterns in recent trading
|
||||
- Backtest strategy adjustments
|
||||
|
||||
**Example keys**:
|
||||
- `total_pnl`: Daily profit/loss
|
||||
- `trade_count`: Number of trades
|
||||
- `win_rate`: Percentage of profitable trades
|
||||
- `avg_confidence`: Average Gemini confidence
|
||||
|
||||
### L5: Weekly
|
||||
**Retention**: 1 year
|
||||
**Timeframe format**: `YYYY-Www` (ISO week, e.g., `2026-W06`)
|
||||
**Content**: Weekly stock selection, sector rotation, volatility regime classification
|
||||
|
||||
**Use cases**:
|
||||
- Weekly strategy adjustment
|
||||
- Sector momentum tracking
|
||||
- Identify hot/cold markets
|
||||
|
||||
**Example keys**:
|
||||
- `weekly_pnl`: Week's total P&L
|
||||
- `top_performers`: Best-performing stocks
|
||||
- `sector_focus`: Dominant sectors
|
||||
- `avg_confidence`: Weekly average confidence
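
The `YYYY-Www` key used above can be derived from a calendar date with the standard library. A sketch (the project may compute this elsewhere, for example inside the aggregator):

```python
# Build the L5 weekly timeframe key (ISO week) for a given date.
from datetime import date

def weekly_timeframe(d: date) -> str:
    iso = d.isocalendar()
    return f"{iso.year}-W{iso.week:02d}"

print(weekly_timeframe(date(2026, 2, 4)))  # -> 2026-W06
```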
|
||||
|
||||
### L4: Monthly
|
||||
**Retention**: 2 years
|
||||
**Timeframe format**: `YYYY-MM`
|
||||
**Content**: Monthly portfolio rebalancing, risk exposure analysis, drawdown recovery
|
||||
|
||||
**Use cases**:
|
||||
- Monthly performance reporting
|
||||
- Risk exposure adjustment
|
||||
- Correlation analysis
|
||||
|
||||
**Example keys**:
|
||||
- `monthly_pnl`: Month's total P&L
|
||||
- `sharpe_ratio`: Risk-adjusted return
|
||||
- `max_drawdown`: Largest peak-to-trough decline
|
||||
- `rebalancing_notes`: Manual insights
|
||||
|
||||
### L3: Quarterly
|
||||
**Retention**: 3 years
|
||||
**Timeframe format**: `YYYY-Qn` (e.g., `2026-Q1`)
|
||||
**Content**: Quarterly strategy pivots, market phase detection (bull/bear/sideways), macro regime changes
|
||||
|
||||
**Use cases**:
|
||||
- Strategic pivots (e.g., growth → value)
|
||||
- Macro regime classification
|
||||
- Long-term pattern recognition
|
||||
|
||||
**Example keys**:
|
||||
- `quarterly_pnl`: Quarter's total P&L
|
||||
- `market_phase`: Bull/Bear/Sideways
|
||||
- `strategy_adjustments`: Major changes made
|
||||
- `lessons_learned`: Key insights
|
||||
|
||||
### L2: Annual
|
||||
**Retention**: 10 years
|
||||
**Timeframe format**: `YYYY`
|
||||
**Content**: Yearly returns, Sharpe ratio, max drawdown, win rate, strategy effectiveness
|
||||
|
||||
**Use cases**:
|
||||
- Annual performance review
|
||||
- Multi-year trend analysis
|
||||
- Strategy benchmarking
|
||||
|
||||
**Example keys**:
|
||||
- `annual_pnl`: Year's total P&L
|
||||
- `sharpe_ratio`: Annual risk-adjusted return
|
||||
- `win_rate`: Yearly win percentage
|
||||
- `best_strategy`: Most successful strategy
|
||||
- `worst_mistake`: Biggest lesson learned
|
||||
|
||||
### L1: Legacy
|
||||
**Retention**: Forever
|
||||
**Timeframe format**: `LEGACY` (single timeframe)
|
||||
**Content**: Cumulative trading history, core principles, generational wisdom
|
||||
|
||||
**Use cases**:
|
||||
- Long-term philosophy
|
||||
- Foundational rules
|
||||
- Lessons that transcend market cycles
|
||||
|
||||
**Example keys**:
|
||||
- `total_pnl`: All-time profit/loss
|
||||
- `years_traded`: Trading longevity
|
||||
- `avg_annual_pnl`: Long-term average return
|
||||
- `core_principles`: Immutable trading rules
|
||||
- `greatest_trades`: Hall of fame
|
||||
- `never_again`: Permanent warnings
|
||||
|
||||
## Usage
|
||||
|
||||
### Setting Context
|
||||
|
||||
```python
|
||||
from src.context import ContextLayer, ContextStore
|
||||
from src.db import init_db
|
||||
|
||||
conn = init_db("data/ouroboros.db")
|
||||
store = ContextStore(conn)
|
||||
|
||||
# Store daily P&L
|
||||
store.set_context(
|
||||
layer=ContextLayer.L6_DAILY,
|
||||
timeframe="2026-02-04",
|
||||
key="total_pnl",
|
||||
value=1234.56
|
||||
)
|
||||
|
||||
# Store weekly insight
|
||||
store.set_context(
|
||||
layer=ContextLayer.L5_WEEKLY,
|
||||
timeframe="2026-W06",
|
||||
key="top_performers",
|
||||
value=["005930", "000660", "035720"] # JSON-serializable
|
||||
)
|
||||
|
||||
# Store legacy wisdom
|
||||
store.set_context(
|
||||
layer=ContextLayer.L1_LEGACY,
|
||||
timeframe="LEGACY",
|
||||
key="core_principles",
|
||||
value=[
|
||||
"Cut losses fast",
|
||||
"Let winners run",
|
||||
"Never average down on losing positions"
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
### Retrieving Context
|
||||
|
||||
```python
|
||||
# Get a specific value
|
||||
pnl = store.get_context(ContextLayer.L6_DAILY, "2026-02-04", "total_pnl")
|
||||
# Returns: 1234.56
|
||||
|
||||
# Get all keys for a timeframe
|
||||
daily_summary = store.get_all_contexts(ContextLayer.L6_DAILY, "2026-02-04")
|
||||
# Returns: {"total_pnl": 1234.56, "trade_count": 10, "win_rate": 60.0, ...}
|
||||
|
||||
# Get all data for a layer (any timeframe)
|
||||
all_daily = store.get_all_contexts(ContextLayer.L6_DAILY)
|
||||
# Returns: {"total_pnl": 1234.56, "trade_count": 10, ...} (latest timeframes first)
|
||||
|
||||
# Get the latest timeframe
|
||||
latest = store.get_latest_timeframe(ContextLayer.L6_DAILY)
|
||||
# Returns: "2026-02-04"
|
||||
```
|
||||
|
||||
### Automatic Aggregation
|
||||
|
||||
The `ContextAggregator` rolls up data from lower to higher layers:
|
||||
|
||||
```python
|
||||
from src.context.aggregator import ContextAggregator
|
||||
|
||||
aggregator = ContextAggregator(conn)
|
||||
|
||||
# Aggregate daily metrics from trades
|
||||
aggregator.aggregate_daily_from_trades("2026-02-04")
|
||||
|
||||
# Roll up weekly from daily
|
||||
aggregator.aggregate_weekly_from_daily("2026-W06")
|
||||
|
||||
# Roll up all layers at once (bottom-up)
|
||||
aggregator.run_all_aggregations()
|
||||
```
|
||||
|
||||
**Aggregation schedule** (recommended):
|
||||
- **L7 → L6**: Every midnight (daily rollup)
|
||||
- **L6 → L5**: Every Sunday (weekly rollup)
|
||||
- **L5 → L4**: First day of each month (monthly rollup)
|
||||
- **L4 → L3**: First day of quarter (quarterly rollup)
|
||||
- **L3 → L2**: January 1st (annual rollup)
|
||||
- **L2 → L1**: On demand (major milestones)
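
One way to wire that schedule into a single daily job, using only the aggregator methods shown above (a sketch; the actual scheduler in the codebase may differ):

```python
# Run the recommended rollups based on the current date (call once per day, after midnight).
from datetime import datetime

def run_scheduled_aggregations(aggregator, now: datetime) -> None:
    today = now.date()
    # L7 -> L6: every midnight
    aggregator.aggregate_daily_from_trades(today.isoformat())
    # L6 -> L5: Sundays
    if today.isoweekday() == 7:
        iso = today.isocalendar()
        aggregator.aggregate_weekly_from_daily(f"{iso.year}-W{iso.week:02d}")
    # Higher layers: first day of the month/quarter/year, handled in one bottom-up pass
    if today.day == 1:
        aggregator.run_all_aggregations()
```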
|
||||
|
||||
### Context Cleanup
|
||||
|
||||
Expired contexts are automatically deleted based on retention policies:
|
||||
|
||||
```python
|
||||
# Manual cleanup
|
||||
deleted = store.cleanup_expired_contexts()
|
||||
# Returns: {ContextLayer.L7_REALTIME: 42, ContextLayer.L6_DAILY: 15, ...}
|
||||
```
|
||||
|
||||
**Retention policies** (defined in `src/context/layer.py`):
|
||||
- L1: Forever
|
||||
- L2: 10 years
|
||||
- L3: 3 years
|
||||
- L4: 2 years
|
||||
- L5: 1 year
|
||||
- L6: 90 days
|
||||
- L7: 7 days
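
For reference, an expiry check consistent with this table looks roughly like the following (a sketch; the authoritative retention values and types live in `src/context/layer.py`):

```python
# Decide whether a context row is past its layer's retention window.
from datetime import datetime, timedelta

RETENTION_DAYS = {
    "L1_LEGACY": None,     # keep forever
    "L2_ANNUAL": 3650,     # ~10 years
    "L3_QUARTERLY": 1095,  # ~3 years
    "L4_MONTHLY": 730,     # ~2 years
    "L5_WEEKLY": 365,
    "L6_DAILY": 90,
    "L7_REALTIME": 7,
}

def is_expired(layer: str, updated_at: datetime, now: datetime) -> bool:
    days = RETENTION_DAYS[layer]
    return days is not None and updated_at < now - timedelta(days=days)
```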
|
||||
|
||||
## Integration with Gemini Brain
|
||||
|
||||
The context tree provides hierarchical memory for decision-making:
|
||||
|
||||
```python
|
||||
from src.brain.gemini_client import GeminiClient
|
||||
|
||||
# Build prompt with multi-layer context
|
||||
def build_enhanced_prompt(stock_code: str, store: ContextStore) -> str:
|
||||
# L7: Real-time data
|
||||
current_price = store.get_context(ContextLayer.L7_REALTIME, "2026-02-04", f"live_price_{stock_code}")
|
||||
|
||||
# L6: Recent daily performance
|
||||
yesterday_pnl = store.get_context(ContextLayer.L6_DAILY, "2026-02-03", "total_pnl")
|
||||
|
||||
# L5: Weekly trend
|
||||
weekly_data = store.get_all_contexts(ContextLayer.L5_WEEKLY, "2026-W06")
|
||||
|
||||
# L1: Core principles
|
||||
principles = store.get_context(ContextLayer.L1_LEGACY, "LEGACY", "core_principles")
|
||||
|
||||
return f"""
|
||||
Analyze {stock_code} for trading decision.
|
||||
|
||||
Current price: {current_price}
|
||||
Yesterday's P&L: {yesterday_pnl}
|
||||
This week: {weekly_data}
|
||||
|
||||
Core principles:
|
||||
{chr(10).join(f'- {p}' for p in principles)}
|
||||
|
||||
Decision (BUY/SELL/HOLD):
|
||||
"""
|
||||
```
|
||||
|
||||
## Database Schema
|
||||
|
||||
```sql
|
||||
-- Context storage
|
||||
CREATE TABLE contexts (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
layer TEXT NOT NULL, -- L1_LEGACY, L2_ANNUAL, ..., L7_REALTIME
|
||||
timeframe TEXT NOT NULL, -- "LEGACY", "2026", "2026-Q1", "2026-02", "2026-W06", "2026-02-04"
|
||||
key TEXT NOT NULL, -- "total_pnl", "win_rate", "core_principles", etc.
|
||||
value TEXT NOT NULL, -- JSON-serialized value
|
||||
created_at TEXT NOT NULL, -- ISO 8601 timestamp
|
||||
updated_at TEXT NOT NULL, -- ISO 8601 timestamp
|
||||
UNIQUE(layer, timeframe, key)
|
||||
);
|
||||
|
||||
-- Layer metadata
|
||||
CREATE TABLE context_metadata (
|
||||
layer TEXT PRIMARY KEY,
|
||||
description TEXT NOT NULL,
|
||||
retention_days INTEGER, -- NULL = keep forever
|
||||
aggregation_source TEXT -- Parent layer for rollup
|
||||
);
|
||||
|
||||
-- Indices for fast queries
|
||||
CREATE INDEX idx_contexts_layer ON contexts(layer);
|
||||
CREATE INDEX idx_contexts_timeframe ON contexts(timeframe);
|
||||
CREATE INDEX idx_contexts_updated ON contexts(updated_at);
|
||||
```
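
For orientation, a `set_context`-style write maps onto this schema as an upsert keyed on `(layer, timeframe, key)`. A sketch (the real logic lives in `src/context/store.py` and may differ):

```python
# Upsert a JSON-serialized value into the contexts table defined above.
import json
import sqlite3
from datetime import datetime, timezone

def upsert_context(conn: sqlite3.Connection, layer: str, timeframe: str, key: str, value) -> None:
    now = datetime.now(timezone.utc).isoformat()
    conn.execute(
        """
        INSERT INTO contexts (layer, timeframe, key, value, created_at, updated_at)
        VALUES (?, ?, ?, ?, ?, ?)
        ON CONFLICT(layer, timeframe, key)
        DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at
        """,
        (layer, timeframe, key, json.dumps(value), now, now),
    )
    conn.commit()
```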
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Write to leaf layers only** — Never manually write to L1-L5; let aggregation populate them
|
||||
2. **Aggregate regularly** — Schedule aggregation jobs to keep higher layers fresh
|
||||
3. **Query specific timeframes** — Use `get_context(layer, timeframe, key)` for precise retrieval
|
||||
4. **Clean up periodically** — Run `cleanup_expired_contexts()` weekly to free space
|
||||
5. **Preserve L1 forever** — Legacy wisdom should never expire
|
||||
6. **Use JSON-serializable values** — Store dicts, lists, strings, numbers (not custom objects)
|
||||
|
||||
## Testing
|
||||
|
||||
See `tests/test_context.py` for comprehensive test coverage (18 tests, 100% coverage on context modules).
|
||||
|
||||
```bash
|
||||
pytest tests/test_context.py -v
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- **Implementation**: `src/context/`
|
||||
- `layer.py`: Layer definitions and metadata
|
||||
- `store.py`: CRUD operations
|
||||
- `aggregator.py`: Bottom-up aggregation logic
|
||||
- **Database**: `src/db.py` (table initialization)
|
||||
- **Tests**: `tests/test_context.py`
|
||||
- **Related**: Pillar 2 (Multi-layered Context Management)
|
||||
348 docs/disaster_recovery.md (new file)
@@ -0,0 +1,348 @@
|
||||
# Disaster Recovery Guide
|
||||
|
||||
Complete guide for backing up and restoring The Ouroboros trading system.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Backup Strategy](#backup-strategy)
|
||||
- [Creating Backups](#creating-backups)
|
||||
- [Restoring from Backup](#restoring-from-backup)
|
||||
- [Health Monitoring](#health-monitoring)
|
||||
- [Export Formats](#export-formats)
|
||||
- [RTO/RPO](#rtorpo)
|
||||
- [Testing Recovery](#testing-recovery)
|
||||
|
||||
## Backup Strategy
|
||||
|
||||
The system implements a 3-tier backup retention policy:
|
||||
|
||||
| Policy | Frequency | Retention | Purpose |
|
||||
|--------|-----------|-----------|---------|
|
||||
| **Daily** | Every day | 30 days | Quick recovery from recent issues |
|
||||
| **Weekly** | Sunday | 1 year | Medium-term historical analysis |
|
||||
| **Monthly** | 1st of month | Forever | Long-term archival |
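
The mapping from calendar date to policy is simple; the sketch below mirrors the date logic in `scripts/backup.sh` shown later in this changeset:

```python
# Choose the backup policy: monthly on the 1st, weekly on Sundays, daily otherwise.
from datetime import date

def backup_policy_for(d: date) -> str:
    if d.day == 1:
        return "monthly"
    if d.isoweekday() == 7:  # Sunday
        return "weekly"
    return "daily"

print(backup_policy_for(date(2024, 1, 1)))   # monthly
print(backup_policy_for(date(2024, 1, 7)))   # weekly (a Sunday)
print(backup_policy_for(date(2024, 1, 15)))  # daily
```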
|
||||
|
||||
### Storage Structure
|
||||
|
||||
```
|
||||
data/backups/
|
||||
├── daily/ # Last 30 days
|
||||
├── weekly/ # Last 52 weeks
|
||||
└── monthly/ # Forever (cold storage)
|
||||
```
|
||||
|
||||
## Creating Backups
|
||||
|
||||
### Automated Backups (Recommended)
|
||||
|
||||
Set up a cron job to run daily:
|
||||
|
||||
```bash
|
||||
# Edit crontab
|
||||
crontab -e
|
||||
|
||||
# Run backup at 2 AM every day
|
||||
0 2 * * * cd /path/to/The-Ouroboros && ./scripts/backup.sh >> logs/backup.log 2>&1
|
||||
```
|
||||
|
||||
### Manual Backups
|
||||
|
||||
```bash
|
||||
# Run backup script
|
||||
./scripts/backup.sh
|
||||
|
||||
# Or use Python directly
|
||||
python3 -c "
|
||||
from pathlib import Path
|
||||
from src.backup.scheduler import BackupScheduler, BackupPolicy
|
||||
|
||||
scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))
|
||||
metadata = scheduler.create_backup(BackupPolicy.DAILY, verify=True)
|
||||
print(f'Backup created: {metadata.file_path}')
|
||||
"
|
||||
```
|
||||
|
||||
### Export to Other Formats
|
||||
|
||||
```bash
|
||||
python3 -c "
|
||||
from pathlib import Path
|
||||
from src.backup.exporter import BackupExporter, ExportFormat
|
||||
|
||||
exporter = BackupExporter('data/trade_logs.db')
|
||||
results = exporter.export_all(
|
||||
Path('exports'),
|
||||
formats=[ExportFormat.JSON, ExportFormat.CSV],
|
||||
compress=True
|
||||
)
|
||||
"
|
||||
```
|
||||
|
||||
## Restoring from Backup
|
||||
|
||||
### Interactive Restoration
|
||||
|
||||
```bash
|
||||
./scripts/restore.sh
|
||||
```
|
||||
|
||||
The script will:
|
||||
1. List available backups
|
||||
2. Ask you to select one
|
||||
3. Create a safety backup of current database
|
||||
4. Restore the selected backup
|
||||
5. Verify database integrity
|
||||
|
||||
### Manual Restoration
|
||||
|
||||
```python
|
||||
from pathlib import Path
|
||||
from src.backup.scheduler import BackupScheduler
|
||||
|
||||
scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))
|
||||
|
||||
# List backups
|
||||
backups = scheduler.list_backups()
|
||||
for backup in backups:
|
||||
print(f"{backup.timestamp}: {backup.file_path}")
|
||||
|
||||
# Restore specific backup
|
||||
scheduler.restore_backup(backups[0], verify=True)
|
||||
```
|
||||
|
||||
## Health Monitoring
|
||||
|
||||
### Check System Health
|
||||
|
||||
```python
|
||||
from pathlib import Path
|
||||
from src.backup.health_monitor import HealthMonitor
|
||||
|
||||
monitor = HealthMonitor('data/trade_logs.db', Path('data/backups'))
|
||||
|
||||
# Run all checks
|
||||
report = monitor.get_health_report()
|
||||
print(f"Overall status: {report['overall_status']}")
|
||||
|
||||
# Individual checks
|
||||
checks = monitor.run_all_checks()
|
||||
for name, result in checks.items():
|
||||
print(f"{name}: {result.status.value} - {result.message}")
|
||||
```
|
||||
|
||||
### Health Checks
|
||||
|
||||
The system monitors:
|
||||
|
||||
- **Database Health**: Accessibility, integrity, size
|
||||
- **Disk Space**: Available storage (alerts if < 10 GB)
|
||||
- **Backup Recency**: Ensures backups are < 25 hours old
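
As an illustration, the disk-space check boils down to something like this (a sketch; `HealthMonitor` has its own thresholds and result types, and the `data/` path is an assumption):

```python
# Report whether the data directory still has at least 10 GB of free space.
import shutil

def disk_space_ok(path: str = "data/", min_free_gb: float = 10.0) -> bool:
    free_gb = shutil.disk_usage(path).free / (1024 ** 3)
    return free_gb >= min_free_gb

print(disk_space_ok())
```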
|
||||
|
||||
### Health Status Levels
|
||||
|
||||
- **HEALTHY**: All systems operational
|
||||
- **DEGRADED**: Warning condition (e.g., low disk space)
|
||||
- **UNHEALTHY**: Critical issue (e.g., database corrupted, no backups)
|
||||
|
||||
## Export Formats
|
||||
|
||||
### JSON (Human-Readable)
|
||||
|
||||
```json
|
||||
{
|
||||
"export_timestamp": "2024-01-15T10:30:00Z",
|
||||
"record_count": 150,
|
||||
"trades": [
|
||||
{
|
||||
"timestamp": "2024-01-15T09:00:00Z",
|
||||
"stock_code": "005930",
|
||||
"action": "BUY",
|
||||
"quantity": 10,
|
||||
"price": 70000.0,
|
||||
"confidence": 85,
|
||||
"rationale": "Strong momentum",
|
||||
"pnl": 0.0
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### CSV (Analysis Tools)
|
||||
|
||||
Compatible with Excel, pandas, R:
|
||||
|
||||
```csv
|
||||
timestamp,stock_code,action,quantity,price,confidence,rationale,pnl
|
||||
2024-01-15T09:00:00Z,005930,BUY,10,70000.0,85,Strong momentum,0.0
|
||||
```
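
Loading the CSV export into pandas, for example (the file name follows the pattern used for the Parquet export below and is an assumption):

```python
# Read a CSV export and summarize P&L per action.
import pandas as pd

df = pd.read_csv("exports/trades_20240115.csv", parse_dates=["timestamp"])
print(df.groupby("action")["pnl"].sum())
```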
|
||||
|
||||
### Parquet (Big Data)
|
||||
|
||||
Columnar format for Spark, DuckDB:
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
df = pd.read_parquet('exports/trades_20240115.parquet')
|
||||
```
|
||||
|
||||
## RTO/RPO
|
||||
|
||||
### Recovery Time Objective (RTO)
|
||||
|
||||
**Target: < 5 minutes**
|
||||
|
||||
Time to restore trading operations:
|
||||
1. Identify backup to restore (1 min)
|
||||
2. Run restore script (2 min)
|
||||
3. Verify database integrity (1 min)
|
||||
4. Restart trading system (1 min)
|
||||
|
||||
### Recovery Point Objective (RPO)
|
||||
|
||||
**Target: < 24 hours**
|
||||
|
||||
Maximum acceptable data loss:
|
||||
- Daily backups ensure ≤ 24-hour data loss
|
||||
- For critical periods, run backups more frequently
|
||||
|
||||
## Testing Recovery
|
||||
|
||||
### Quarterly Recovery Test
|
||||
|
||||
Perform full disaster recovery test every quarter:
|
||||
|
||||
1. **Create test backup**
|
||||
```bash
|
||||
./scripts/backup.sh
|
||||
```
|
||||
|
||||
2. **Simulate disaster** (use test database)
|
||||
```bash
|
||||
cp data/trade_logs.db data/trade_logs_test.db
|
||||
rm data/trade_logs_test.db # Simulate data loss
|
||||
```
|
||||
|
||||
3. **Restore from backup**
|
||||
```bash
|
||||
DB_PATH=data/trade_logs_test.db ./scripts/restore.sh
|
||||
```
|
||||
|
||||
4. **Verify data integrity**
|
||||
```python
|
||||
import sqlite3
|
||||
conn = sqlite3.connect('data/trade_logs_test.db')
|
||||
cursor = conn.execute('SELECT COUNT(*) FROM trades')
|
||||
print(f"Restored {cursor.fetchone()[0]} trades")
|
||||
```
|
||||
|
||||
5. **Document results** in `logs/recovery_test_YYYYMMDD.md`
|
||||
|
||||
### Backup Verification
|
||||
|
||||
Always verify backups after creation:
|
||||
|
||||
```python
|
||||
from pathlib import Path
|
||||
from src.backup.scheduler import BackupScheduler
|
||||
|
||||
scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))
|
||||
|
||||
# Create and verify
|
||||
metadata = scheduler.create_backup(BackupPolicy.DAILY, verify=True)
|
||||
print(f"Checksum: {metadata.checksum}") # Should not be None
|
||||
```
|
||||
|
||||
## Emergency Procedures
|
||||
|
||||
### Database Corrupted
|
||||
|
||||
1. Stop trading system immediately
|
||||
2. Check most recent backup age: `ls -lht data/backups/daily/`
|
||||
3. Restore: `./scripts/restore.sh`
|
||||
4. Verify: Run health check
|
||||
5. Resume trading
|
||||
|
||||
### Disk Full
|
||||
|
||||
1. Check disk space: `df -h`
|
||||
2. Clean old backups: Run cleanup manually
|
||||
```python
|
||||
from pathlib import Path
|
||||
from src.backup.scheduler import BackupScheduler
|
||||
scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))
|
||||
scheduler.cleanup_old_backups()
|
||||
```
|
||||
3. Consider archiving old monthly backups to external storage
|
||||
4. Increase disk space if needed
|
||||
|
||||
### Lost All Backups
|
||||
|
||||
If local backups are lost:
|
||||
1. Check if exports exist in `exports/` directory
|
||||
2. Reconstruct database from CSV/JSON exports
|
||||
3. If no exports: Check broker API for trade history
|
||||
4. Manual reconstruction as last resort
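
For step 2, a rough reconstruction from a CSV export can be scripted as below. This is a sketch only: the column list matches the Export Formats section above, but the real `trades` schema in `src/db.py` likely has more fields and should be used instead when available.

```python
# Rebuild a minimal trades table from a CSV export (disaster-recovery last resort).
import csv
import sqlite3

conn = sqlite3.connect("data/trade_logs.db")
conn.execute(
    """CREATE TABLE IF NOT EXISTS trades (
        timestamp TEXT, stock_code TEXT, action TEXT, quantity INTEGER,
        price REAL, confidence INTEGER, rationale TEXT, pnl REAL
    )"""
)
with open("exports/trades_20240115.csv", newline="") as f:
    rows = [
        (r["timestamp"], r["stock_code"], r["action"], int(r["quantity"]),
         float(r["price"]), int(r["confidence"]), r["rationale"], float(r["pnl"]))
        for r in csv.DictReader(f)
    ]
conn.executemany("INSERT INTO trades VALUES (?, ?, ?, ?, ?, ?, ?, ?)", rows)
conn.commit()
```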
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Test Restores Regularly**: Don't wait for disaster
|
||||
2. **Monitor Disk Space**: Set up alerts at 80% usage
|
||||
3. **Keep Multiple Generations**: Never delete all backups at once
|
||||
4. **Verify Checksums**: Always verify backup integrity
|
||||
5. **Document Changes**: Update this guide when backup strategy changes
|
||||
6. **Off-Site Storage**: Consider external backup for monthly archives
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Backup Script Fails
|
||||
|
||||
```bash
|
||||
# Check database file permissions
|
||||
ls -l data/trade_logs.db
|
||||
|
||||
# Check disk space
|
||||
df -h data/
|
||||
|
||||
# Run backup manually with debug
|
||||
python3 -c "
|
||||
import logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
from pathlib import Path
|
||||
from src.backup.scheduler import BackupScheduler, BackupPolicy
|
||||
scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))
|
||||
scheduler.create_backup(BackupPolicy.DAILY, verify=True)
|
||||
"
|
||||
```
|
||||
|
||||
### Restore Fails Verification
|
||||
|
||||
```bash
|
||||
# Check backup file integrity
|
||||
python3 -c "
|
||||
import sqlite3
|
||||
conn = sqlite3.connect('data/backups/daily/trade_logs_daily_20240115.db')
|
||||
cursor = conn.execute('PRAGMA integrity_check')
|
||||
print(cursor.fetchone()[0])
|
||||
"
|
||||
```
|
||||
|
||||
### Health Check Fails
|
||||
|
||||
```python
|
||||
from pathlib import Path
|
||||
from src.backup.health_monitor import HealthMonitor
|
||||
|
||||
monitor = HealthMonitor('data/trade_logs.db', Path('data/backups'))
|
||||
|
||||
# Check each component individually
|
||||
print("Database:", monitor.check_database_health())
|
||||
print("Disk Space:", monitor.check_disk_space())
|
||||
print("Backup Recency:", monitor.check_backup_recency())
|
||||
```
|
||||
|
||||
## Contact
|
||||
|
||||
For backup/recovery issues:
|
||||
- Check logs: `logs/backup.log`
|
||||
- Review health status: Run health monitor
|
||||
- Raise issue on GitHub if automated recovery fails
|
||||
131 docs/live-trading-checklist.md (new file)
@@ -0,0 +1,131 @@
|
||||
# Live Trading Transition Checklist

Before switching from paper trading to live trading, confirm every item below, **in order**.

---

## 1. Prerequisites

### 1-1. KIS OpenAPI live account preparation
- [ ] Korea Investment & Securities account opened (standard brokerage account)
- [ ] OpenAPI live access requested (KIS website → Open API → service application)
- [ ] Live APP_KEY / APP_SECRET issued
- [ ] KIS_ACCOUNT_NO format verified: `XXXXXXXX-XX` (8 digits, hyphen, 2 digits)

### 1-2. Risk parameter review
- [ ] `CIRCUIT_BREAKER_PCT` checked: default -3.0% (tightening it further is recommended)
- [ ] `FAT_FINGER_PCT` checked: default 30.0% (maximum single order as a % of balance)
- [ ] `CONFIDENCE_THRESHOLD` checked: BEARISH ≥ 90, NEUTRAL ≥ 80, BULLISH ≥ 75
- [ ] Initial capital decided and overseas trading limit set

### 1-3. System requirements
- [ ] Coverage stays at 80% or above: `pytest --cov=src`
- [ ] Type check passes: `mypy src/ --strict`
- [ ] Lint passes: `ruff check src/ tests/`

---

## 2. Environment Configuration

### 2-1. Edit the `.env` file

```bash
# 1. Switch to the KIS live URL (paper uses openapivts on port 29443)
KIS_BASE_URL=https://openapi.koreainvestment.com:9443

# 2. Replace with the live APP_KEY / APP_SECRET
KIS_APP_KEY=<live_APP_KEY>
KIS_APP_SECRET=<live_APP_SECRET>
KIS_ACCOUNT_NO=<live_account_number>

# 3. Change the mode to live
MODE=live

# 4. Disable PAPER_OVERSEAS_CASH (ignored in live mode, but set it to 0 explicitly)
PAPER_OVERSEAS_CASH=0
```

> ⚠️ Mind the `KIS_BASE_URL` port:
> - **Paper (VTS)**: `https://openapivts.koreainvestment.com:29443`
> - **Live**: `https://openapi.koreainvestment.com:9443`

### 2-2. Verify automatic TR_ID branching

The TR_IDs below are **selected automatically** in code based on the `MODE` value.
No separate configuration is needed, but consult the table below if problems arise.

| Category | Paper TR_ID | Live TR_ID |
|----------|-------------|------------|
| Domestic balance inquiry | `VTTC8434R` | `TTTC8434R` |
| Domestic cash buy | `VTTC0012U` | `TTTC0012U` |
| Domestic cash sell | `VTTC0011U` | `TTTC0011U` |
| Overseas balance inquiry | `VTTS3012R` | `TTTS3012R` |
| Overseas buy | `VTTT1002U` | `TTTT1002U` |
| Overseas sell | `VTTT1001U` | `TTTT1006U` |

> **Source**: `docs/한국투자증권_오픈API_전체문서_20260221_030000.xlsx` (per the official documentation)
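
Conceptually the branching is just a mode lookup, along these lines (illustrative only; the real mapping lives in the broker modules):

```python
# Pick the domestic balance-inquiry TR_ID based on MODE.
def domestic_balance_tr_id(mode: str) -> str:
    return "TTTC8434R" if mode == "live" else "VTTC8434R"

print(domestic_balance_tr_id("live"))   # TTTC8434R
print(domestic_balance_tr_id("paper"))  # VTTC8434R
```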
|
||||
|
||||
---

## 3. Final Checks

### 3-1. Pre-launch checks
- [ ] DB backup completed: `data/trade_logs.db` → `data/backups/`
- [ ] Telegram notification setup verified (notifications matter even more in live trading)
- [ ] First trade placed with a small amount to confirm TR_IDs and the account behave correctly

### 3-2. Run commands

```bash
# Run in live mode
python -m src.main --mode=live

# Run with the dashboard (monitor from a separate terminal)
python -m src.main --mode=live --dashboard
```

### 3-3. Checks right after going live
- [ ] Log output shows `MODE=live`
- [ ] First balance inquiry succeeds (no ConnectionError)
- [ ] Telegram notification received ("System started")
- [ ] After the first order, execution confirmed in the KIS app

---

## 4. Emergency Stop

### Immediate stop
```bash
# Ctrl+C in the terminal (triggers graceful shutdown)
# Or via Telegram bot command:
/stop
```

### When the circuit breaker trips
- When the CB trips, trading halts automatically and a Telegram alert is sent
- CB threshold: `CIRCUIT_BREAKER_PCT` (default -3.0%)
- **The threshold may only be adjusted to be stricter** (never relaxed)

---

## 5. Rollback Procedure

If problems occur after going live:

```bash
# 1. Immediately restore MODE=paper in .env
# 2. Restart
python -m src.main --mode=paper

# 3. Check recent trades in the DB
sqlite3 data/trade_logs.db "SELECT * FROM trades ORDER BY id DESC LIMIT 20;"
```

---

## Related Documents

- [System architecture](architecture.md)
- [Workflow guide](workflow.md)
- [Disaster recovery](disaster_recovery.md)
- [Agent constraints](agents.md)
|
||||
357 docs/requirements-log.md (new file)
@@ -0,0 +1,357 @@
|
||||
# Requirements Log
|
||||
|
||||
프로젝트 진화를 위한 사용자 요구사항 기록.
|
||||
|
||||
이 문서는 시간순으로 사용자와의 대화에서 나온 요구사항과 피드백을 기록합니다.
|
||||
새로운 요구사항이 있으면 날짜와 함께 추가하세요.
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-21
|
||||
|
||||
### 거래 상태 확인 중 발견된 버그 (#187)
|
||||
|
||||
- 거래 상태 점검 요청 → SELL 주문(손절/익절)이 Fat Finger에 막혀 전혀 실행 안 됨 발견
|
||||
- **#187 (Critical)**: SELL 주문에서 Fat Finger 오탐 — `order_amount/total_cash > 30%`가 SELL에도 적용되어 대형 포지션 매도 불가
|
||||
- JELD stop-loss -6.20% → 차단, RXT take-profit +46.13% → 차단
|
||||
- 수정: SELL은 `check_circuit_breaker`만 호출, `validate_order`(Fat Finger 포함) 미호출
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-20
|
||||
|
||||
### 지속적 모니터링 및 개선점 도출 (이슈 #178~#182)
|
||||
|
||||
- Dashboard 포함해서 실행하며 간헐적 문제 모니터링 및 개선점 자동 도출 요청
|
||||
- 모니터링 결과 발견된 이슈 목록:
|
||||
- **#178**: uvicorn 미설치 → dashboard 미작동 + 오해의 소지 있는 시작 로그 → uvicorn 설치 완료
|
||||
- **#179 (Critical)**: 잔액 부족 주문 실패 후 매 사이클마다 무한 재시도 (MLECW 20분 이상 반복)
|
||||
- **#180**: 다중 인스턴스 실행 시 Telegram 409 충돌
|
||||
- **#181**: implied_rsi 공식 포화 문제 (change_rate≥12.5% → RSI=100)
|
||||
- **#182 (Critical)**: 보유 종목이 SmartScanner 변동성 필터에 걸려 SELL 신호 미생성 → SELL 체결 0건, 잔고 소진
|
||||
- 요구사항: 모니터링 자동화 및 주기적 개선점 리포트 도출
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-05
|
||||
|
||||
### API 효율화
|
||||
- Gemini API는 귀중한 자원. 종목별 개별 호출 대신 배치 호출 필요
|
||||
- Free tier 한도(20 calls/day) 고려하여 일일 몇 차례 거래 모드로 전환
|
||||
- 배치 API 호출로 여러 종목을 한 번에 분석
|
||||
|
||||
### 거래 모드
|
||||
- **Daily Mode**: 하루 4회 거래 세션 (6시간 간격) - Free tier 호환
|
||||
- **Realtime Mode**: 60초 간격 실시간 거래 - 유료 구독 필요
|
||||
- `TRADE_MODE` 환경변수로 모드 선택
|
||||
|
||||
### 진화 시스템
|
||||
- 사용자 대화 내용을 문서로 기록하여 향후에도 의도 반영
|
||||
- 프롬프트 품질 검증은 별도 이슈로 다룰 예정
|
||||
|
||||
### 문서화
|
||||
- 시스템 구조, 기능별 설명 등 코드 문서화 항상 신경쓸 것
|
||||
- 새로운 기능 추가 시 관련 문서 업데이트 필수
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-06
|
||||
|
||||
### Smart Volatility Scanner (Python-First, AI-Last 파이프라인)
|
||||
|
||||
**배경:**
|
||||
- 정적 종목 리스트를 순회하는 방식은 비효율적
|
||||
- KIS API 거래량 순위를 통해 시장 주도주를 자동 탐지해야 함
|
||||
- Gemini API 호출 전에 Python 기반 기술적 분석으로 필터링 필요
|
||||
|
||||
**요구사항:**
|
||||
1. KIS API 거래량 순위 API 통합 (`fetch_market_rankings`)
|
||||
2. 일별 가격 히스토리 API 추가 (`get_daily_prices`)
|
||||
3. RSI(14) 계산 기능 구현 (Wilder's smoothing method)
|
||||
4. 필터 조건:
|
||||
- 거래량 > 전일 대비 200% (VOL_MULTIPLIER)
|
||||
- RSI < 30 (과매도) OR RSI > 70 (모멘텀)
|
||||
5. 상위 1-3개 적격 종목만 Gemini에 전달
|
||||
6. 종목 선정 배경(RSI, volume_ratio, signal, score) 데이터베이스 기록
|
||||
|
||||
**구현 결과:**
|
||||
- `src/analysis/smart_scanner.py`: SmartVolatilityScanner 클래스
|
||||
- `src/analysis/volatility.py`: calculate_rsi() 메서드 추가
|
||||
- `src/broker/kis_api.py`: 2개 신규 API 메서드
|
||||
- `src/db.py`: selection_context 컬럼 추가
|
||||
- 설정 가능한 임계값: RSI_OVERSOLD_THRESHOLD, RSI_MOMENTUM_THRESHOLD, VOL_MULTIPLIER, SCANNER_TOP_N
|
||||
|
||||
**효과:**
|
||||
- Gemini API 호출 20-30개 → 1-3개로 감소
|
||||
- Python 기반 빠른 필터링 → 비용 절감
|
||||
- 선정 기준 추적 → Evolution 시스템 최적화 가능
|
||||
- API 장애 시 정적 watchlist로 자동 전환
|
||||
|
||||
**참고:** Realtime 모드 전용. Daily 모드는 배치 효율성을 위해 정적 watchlist 사용.
|
||||
|
||||
**이슈/PR:** #76, #77
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-10
|
||||
|
||||
### 코드 리뷰 시 플랜-구현 일치 검증 규칙
|
||||
|
||||
**배경:**
|
||||
- 코드 리뷰 시 플랜(EnterPlanMode에서 승인된 계획)과 실제 구현이 일치하는지 확인하는 절차가 없었음
|
||||
- 플랜과 다른 구현이 리뷰 없이 통과될 위험
|
||||
|
||||
**요구사항:**
|
||||
1. 모든 PR 리뷰에서 플랜-구현 일치 여부를 필수 체크
|
||||
2. 플랜에 없는 변경은 정당한 사유 필요
|
||||
3. 플랜 항목이 누락되면 PR 설명에 사유 기록
|
||||
4. 스코프가 플랜과 일치하는지 확인
|
||||
|
||||
**구현 결과:**
|
||||
- `docs/workflow.md`에 Code Review Checklist 섹션 추가
|
||||
- Plan Consistency (필수), Safety & Constraints, Quality, Workflow 4개 카테고리
|
||||
|
||||
**이슈/PR:** #114
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-16
|
||||
|
||||
### 문서 v2 동기화 (전체 문서 현행화)
|
||||
|
||||
**배경:**
|
||||
- v2 기능 구현 완료 후 문서가 실제 코드 상태와 크게 괴리
|
||||
- 문서에는 54 tests / 4 files로 기록되었으나 실제로는 551 tests / 25 files
|
||||
- v2 핵심 기능(Playbook, Scenario Engine, Dashboard, Telegram Commands, Daily Review, Context System, Backup) 문서화 누락
|
||||
|
||||
**요구사항:**
|
||||
1. `docs/testing.md` — 551 tests / 25 files 반영, 전체 테스트 파일 설명
|
||||
2. `docs/architecture.md` — v2 컴포넌트(Strategy, Context, Dashboard, Decision Logger 등) 추가, Playbook Mode 데이터 플로우, DB 스키마 5개 테이블, v2 환경변수
|
||||
3. `docs/commands.md` — Dashboard 실행 명령어, Telegram 명령어 9종 레퍼런스
|
||||
4. `CLAUDE.md` — Project Structure 트리 확장, 테스트 수 업데이트, `--dashboard` 플래그
|
||||
5. `docs/skills.md` — DB 파일명 `trades.db`로 통일, Dashboard 명령어 추가
|
||||
6. 기존에 유효한 트러블슈팅, 코드 예제 등은 유지
|
||||
|
||||
**구현 결과:**
|
||||
- 6개 문서 파일 업데이트
|
||||
- 이전 시도(2개 커밋)는 기존 내용을 과도하게 삭제하여 폐기, main 기준으로 재작업
|
||||
|
||||
**이슈/PR:** #131, PR #134
|
||||
|
||||
### 해외 스캐너 개선: 랭킹 연동 + 변동성 우선 선별
|
||||
|
||||
**배경:**
|
||||
- `run_overnight` 실운영에서 미국장 동안 거래가 0건 지속
|
||||
- 원인: 해외 시장에서도 국내 랭킹/일봉 API 경로를 사용하던 구조적 불일치
|
||||
|
||||
**요구사항:**
|
||||
1. 해외 시장도 랭킹 API 기반 유니버스 탐색 지원
|
||||
2. 단순 상승률/거래대금 상위가 아니라, **변동성이 큰 종목**을 우선 선별
|
||||
3. 고정 티커 fallback 금지
|
||||
|
||||
**구현 결과:**
|
||||
- `src/broker/overseas.py`
|
||||
- `fetch_overseas_rankings()` 추가 (fluctuation / volume)
|
||||
- 해외 랭킹 API 경로/TR_ID를 설정값으로 오버라이드 가능하게 구현
|
||||
- `src/analysis/smart_scanner.py`
|
||||
- market-aware 스캔(국내/해외 분리)
|
||||
- 해외: 랭킹 API 유니버스 + 변동성 우선 점수(일변동률 vs 장중 고저폭)
|
||||
- 거래대금/거래량 랭킹은 유동성 보정 점수로 활용
|
||||
- 랭킹 실패 시에는 동적 유니버스(active/recent/holdings)만 사용
|
||||
- `src/config.py`
|
||||
- `OVERSEAS_RANKING_*` 설정 추가
|
||||
|
||||
**효과:**
|
||||
- 해외 시장에서 스캐너 후보 0개로 정지되는 상황 완화
|
||||
- 종목 선정 기준이 단순 상승률 중심에서 변동성 중심으로 개선
|
||||
- 고정 티커 없이도 시장 주도 변동 종목 탐지 가능
|
||||
|
||||
### 국내 스캐너/주문수량 정렬: 변동성 우선 + 리스크 타기팅
|
||||
|
||||
**배경:**
|
||||
- 해외만 변동성 우선으로 동작하고, 국내는 RSI/거래량 필터 중심으로 동작해 시장 간 전략 일관성이 낮았음
|
||||
- 매수 수량이 고정 1주라서 변동성 구간별 익스포저 관리가 어려웠음
|
||||
|
||||
**요구사항:**
|
||||
1. 국내 스캐너도 변동성 우선 선별로 해외와 통일
|
||||
2. 고변동 종목일수록 포지션 크기를 줄이는 수량 산식 적용
|
||||
|
||||
**구현 결과:**
|
||||
- `src/analysis/smart_scanner.py`
|
||||
- 국내: `fluctuation ranking + volume ranking bonus` 기반 점수화로 전환
|
||||
- 점수는 `max(abs(change_rate), intraday_range_pct)` 중심으로 계산
|
||||
- 국내 랭킹 응답 스키마 키(`price`, `change_rate`, `volume`) 파싱 보강
|
||||
- `src/main.py`
|
||||
- `_determine_order_quantity()` 추가
|
||||
- BUY 시 변동성 점수 기반 동적 수량 산정 적용
|
||||
- `trading_cycle`, `run_daily_session` 경로 모두 동일 수량 로직 사용
|
||||
- `src/config.py`
|
||||
- `POSITION_SIZING_*` 설정 추가
|
||||
|
||||
**효과:**
|
||||
- 국내/해외 스캐너 기준이 변동성 중심으로 일관화
|
||||
- 고변동 구간에서 자동 익스포저 축소, 저변동 구간에서 과소진입 완화
|
||||
|
||||
## 2026-02-18
|
||||
|
||||
### KIS 해외 랭킹 API 404 에러 수정
|
||||
|
||||
**배경:**
|
||||
- KIS 해외주식 랭킹 API(`fetch_overseas_rankings`)가 모든 거래소에서 HTTP 404를 반환
|
||||
- Smart Scanner가 해외 시장 후보 종목을 찾지 못해 거래가 전혀 실행되지 않음
|
||||
|
||||
**근본 원인:**
|
||||
- TR_ID, API 경로, 거래소 코드가 모두 KIS 공식 문서와 불일치
|
||||
|
||||
**구현 결과:**
|
||||
- `src/config.py`: TR_ID/Path 기본값을 KIS 공식 스펙으로 수정
|
||||
- `src/broker/overseas.py`: 랭킹 API 전용 거래소 코드 매핑 추가 (NASD→NAS, NYSE→NYS, AMEX→AMS), 올바른 API 파라미터 사용
|
||||
- `tests/test_overseas_broker.py`: 19개 단위 테스트 추가
|
||||
|
||||
**효과:**
|
||||
- 해외 시장 랭킹 스캔이 정상 동작하여 Smart Scanner가 후보 종목 탐지 가능
|
||||
|
||||
### Gemini prompt_override 미적용 버그 수정
|
||||
|
||||
**배경:**
|
||||
- `run_overnight` 실행 시 모든 시장에서 Playbook 생성 실패 (`JSONDecodeError`)
|
||||
- defensive playbook으로 폴백되어 모든 종목이 HOLD 처리
|
||||
|
||||
**근본 원인:**
|
||||
- `pre_market_planner.py`가 `market_data["prompt_override"]`에 Playbook 전용 프롬프트를 넣어 `gemini.decide()` 호출
|
||||
- `gemini_client.py`의 `decide()` 메서드가 `prompt_override` 키를 전혀 확인하지 않고 항상 일반 트레이드 결정 프롬프트 생성
|
||||
- Gemini가 Playbook JSON 대신 일반 트레이드 결정을 반환하여 파싱 실패
|
||||
|
||||
**구현 결과:**
|
||||
- `src/brain/gemini_client.py`: `decide()` 메서드에서 `prompt_override` 우선 사용 로직 추가
|
||||
- `tests/test_brain.py`: 3개 테스트 추가 (override 전달, optimization 우회, 미지정 시 기존 동작 유지)
|
||||
|
||||
**이슈/PR:** #143
|
||||
|
||||
### 미국장 거래 미실행 근본 원인 분석 및 수정 (자율 실행 세션)
|
||||
|
||||
**배경:**
|
||||
- 사용자 요청: "미국장 열면 프로그램 돌려서 거래 한 번도 못 한 거 꼭 원인 찾아서 해결해줘"
|
||||
- 프로그램을 미국장 개장(9:30 AM EST) 전부터 실행하여 실시간 로그를 분석
|
||||
|
||||
**발견된 근본 원인 #1: Defensive Playbook — BUY 조건 없음**
|
||||
|
||||
- Gemini free tier (20 RPD) 소진 → `generate_playbook()` 실패 → `_defensive_playbook()` 폴백
|
||||
- Defensive playbook은 `price_change_pct_below: -3.0 → SELL` 조건만 존재, BUY 조건 없음
|
||||
- ScenarioEngine이 항상 HOLD 반환 → 거래 0건
|
||||
|
||||
**수정 #1 (PR #146, Issue #145):**
|
||||
- `src/strategy/pre_market_planner.py`: `_smart_fallback_playbook()` 메서드 추가
|
||||
- 스캐너 signal 기반 BUY 조건 생성: `momentum → volume_ratio_above`, `oversold → rsi_below`
|
||||
- 기존 defensive stop-loss SELL 조건 유지
|
||||
- Gemini 실패 시 defensive → smart fallback으로 전환
|
||||
- 테스트 10개 추가
|
||||
|
||||
**발견된 근본 원인 #2: 가격 API 거래소 코드 불일치 + VTS 잔고 API 오류**
|
||||
|
||||
실제 로그:
|
||||
```
|
||||
Scenario matched for MRNX: BUY (confidence=80) ✓
|
||||
Decision for EWUS (NYSE American): BUY (confidence=80) ✓
|
||||
Skip BUY APLZ (NYSE American): no affordable quantity (cash=0.00, price=0.00) ✗
|
||||
```
|
||||
|
||||
- `get_overseas_price()`: `NASD`/`NYSE`/`AMEX` 전송 → API가 `NAS`/`NYS`/`AMS` 기대 → 빈 응답 → `price=0`
|
||||
- `VTTS3012R` 잔고 API: "ERROR : INPUT INVALID_CHECK_ACNO" → `total_cash=0`
|
||||
- 결과: `_determine_order_quantity()` 가 0 반환 → 주문 건너뜀
|
||||
|
||||
**수정 #2 (PR #148, Issue #147):**
|
||||
- `src/broker/overseas.py`: `_PRICE_EXCHANGE_MAP = _RANKING_EXCHANGE_MAP` 추가, 가격 API에 매핑 적용
|
||||
- `src/config.py`: `PAPER_OVERSEAS_CASH: float = Field(default=50000.0)` — paper 모드 시뮬레이션 잔고
|
||||
- `src/main.py`: 잔고 0일 때 PAPER_OVERSEAS_CASH 폴백, 가격 0일 때 candidate.price 폴백
|
||||
- 테스트 8개 추가
|
||||
|
||||
**효과:**
|
||||
- BUY 결정 → 실제 주문 전송까지의 파이프라인이 완전히 동작
|
||||
- Paper 모드에서 KIS VTS 해외 잔고 API 오류에 관계없이 시뮬레이션 거래 가능
|
||||
|
||||
**이슈/PR:** #145, #146, #147, #148
|
||||
|
||||
### 해외주식 시장가 주문 거부 수정 (Fix #3, 연속 발견)
|
||||
|
||||
**배경:**
|
||||
- Fix #147 적용 후 주문 전송 시작 → KIS VTS가 거부: "지정가만 가능한 상품입니다"
|
||||
|
||||
**근본 원인:**
|
||||
- `trading_cycle()`, `run_daily_session()` 양쪽에서 `send_overseas_order(price=0.0)` 하드코딩
|
||||
- `price=0` → `ORD_DVSN="01"` (시장가) 전송 → KIS VTS 거부
|
||||
- Fix #147에서 이미 `current_price`를 올바르게 계산했으나 주문 시 미사용
|
||||
|
||||
**구현 결과:**
|
||||
- `src/main.py`: 두 곳에서 `price=0.0` → `price=current_price`/`price=stock_data["current_price"]`
|
||||
- `tests/test_main.py`: 회귀 테스트 `test_overseas_buy_order_uses_limit_price` 추가
|
||||
|
||||
**최종 확인 로그:**
|
||||
```
|
||||
Order result: 모의투자 매수주문이 완료 되었습니다. ✓
|
||||
```
|
||||
|
||||
**이슈/PR:** #149, #150
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-23
|
||||
|
||||
### 국내주식 지정가 전환 및 미체결 처리 (#232)
|
||||
|
||||
**배경:**
|
||||
- 해외주식은 #211에서 지정가로 전환했으나 국내주식은 여전히 `price=0` (시장가)
|
||||
- KRX도 지정가 주문 사용 시 동일한 미체결 위험이 존재
|
||||
- 지정가 전환 + 미체결 처리를 함께 구현
|
||||
|
||||
**구현 내용:**
|
||||
|
||||
1. `src/broker/kis_api.py`
|
||||
- `get_domestic_pending_orders()`: 모의 즉시 `[]`, 실전 `TTTC0084R` GET
|
||||
- `cancel_domestic_order()`: 실전 `TTTC0013U` / 모의 `VTTC0013U`, hashkey 필수
|
||||
|
||||
2. `src/main.py`
|
||||
- import `kr_round_down` 추가
|
||||
- `trading_cycle`, `run_daily_session` 국내 주문 `price=0` → 지정가:
|
||||
BUY +0.2% / SELL -0.2%, `kr_round_down` KRX 틱 반올림 적용
|
||||
- `handle_domestic_pending_orders` 함수: BUY→취소+쿨다운, SELL→취소+재주문(-0.4%, 최대1회)
|
||||
- daily/realtime 두 모드에서 domestic pending 체크 호출 추가
|
||||
|
||||
3. 테스트 14개 추가:
|
||||
- `TestGetDomesticPendingOrders` (3), `TestCancelDomesticOrder` (5)
|
||||
- `TestHandleDomesticPendingOrders` (4), `TestDomesticLimitOrderPrice` (2)
|
||||
|
||||
**이슈/PR:** #232, PR #233
|
||||
|
||||
---
|
||||
|
||||
## 2026-02-24
|
||||
|
||||
### 해외잔고 ghost position 수정 — '모의투자 잔고내역이 없습니다' 반복 방지 (#235)
|
||||
|
||||
**배경:**
|
||||
- 모의투자 실행 시 MLECW, KNRX, NBY, SNSE 등 만료/정지된 종목에 대해
|
||||
`모의투자 잔고내역이 없습니다` 오류가 매 사이클 반복됨
|
||||
|
||||
**근본 원인:**
|
||||
1. `ovrs_cblc_qty` (해외잔고수량, 총 보유) vs `ord_psbl_qty` (주문가능수량, 실제 매도 가능)
|
||||
- 기존 코드: `ovrs_cblc_qty` 우선 사용 → 만료 Warrant가 `ovrs_cblc_qty=289456`이지만 실제 `ord_psbl_qty=0`
|
||||
- startup sync / build_overseas_symbol_universe가 이 종목들을 포지션으로 기록
|
||||
2. SELL 실패 시 DB 포지션이 닫히지 않아 다음 사이클에서도 재시도 (무한 반복)
|
||||
|
||||
**구현 내용:**
|
||||
|
||||
1. `src/main.py` — `_extract_held_codes_from_balance`, `_extract_held_qty_from_balance`
|
||||
- 해외 잔고 필드 우선순위 변경: `ord_psbl_qty` → `ovrs_cblc_qty` → `hldg_qty` (fallback 유지)
|
||||
- KIS 공식 문서(VTTS3012R) 기준: `ord_psbl_qty`가 실제 매도 가능 수량
|
||||
|
||||
2. `src/main.py` — `trading_cycle` ghost-close 처리
|
||||
- 해외 SELL이 `잔고내역이 없습니다`로 실패 시 DB 포지션을 `[ghost-close]` SELL로 종료
|
||||
- exchange code 불일치 등 예외 상황에서 무한 반복 방지
|
||||
|
||||
3. 테스트 7개 추가:
|
||||
- `TestExtractHeldQtyFromBalance` 3개: ord_psbl_qty 우선, 0이면 0 반환, fallback
|
||||
- `TestExtractHeldCodesFromBalance` 2개: ord_psbl_qty=0인 종목 제외, fallback
|
||||
- `TestOverseasGhostPositionClose` 2개: ghost-close 로그 확인, 일반 오류 무시
|
||||
|
||||
**이슈/PR:** #235, PR #236
|
||||
@@ -34,6 +34,12 @@ python -m src.main --mode=paper
|
||||
```
|
||||
Runs the agent in paper-trading mode (no real orders).
|
||||
|
||||
### Start Trading Agent with Dashboard
|
||||
```bash
|
||||
python -m src.main --mode=paper --dashboard
|
||||
```
|
||||
Runs the agent with FastAPI dashboard on `127.0.0.1:8080` (configurable via `DASHBOARD_HOST`/`DASHBOARD_PORT`).
|
||||
|
||||
### Start Trading Agent (Production)
|
||||
```bash
|
||||
docker compose up -d ouroboros
|
||||
@@ -59,7 +65,7 @@ Analyze the last 30 days of trade logs and generate performance metrics.
|
||||
python -m src.evolution.optimizer --evolve
|
||||
```
|
||||
Triggers the evolution engine to:
|
||||
1. Analyze `trade_logs.db` for failing patterns
|
||||
1. Analyze `trades.db` for failing patterns
|
||||
2. Ask Gemini to generate a new strategy
|
||||
3. Run tests on the new strategy
|
||||
4. Create a PR if tests pass
|
||||
@@ -91,12 +97,12 @@ curl http://localhost:8080/health
|
||||
|
||||
### View Trade Logs
|
||||
```bash
|
||||
sqlite3 data/trade_logs.db "SELECT * FROM trades ORDER BY timestamp DESC LIMIT 20;"
|
||||
sqlite3 data/trades.db "SELECT * FROM trades ORDER BY timestamp DESC LIMIT 20;"
|
||||
```
|
||||
|
||||
### Export Trade History
|
||||
```bash
|
||||
sqlite3 -header -csv data/trade_logs.db "SELECT * FROM trades;" > trades_export.csv
|
||||
sqlite3 -header -csv data/trades.db "SELECT * FROM trades;" > trades_export.csv
|
||||
```
|
||||
|
||||
## Safety Checklist (Pre-Deploy)
|
||||
|
||||
206 docs/testing.md
@@ -2,51 +2,29 @@
|
||||
|
||||
## Test Structure
|
||||
|
||||
**54 tests** across four files. `asyncio_mode = "auto"` in pyproject.toml — async tests need no special decorator.
|
||||
**551 tests** across **25 files**. `asyncio_mode = "auto"` in pyproject.toml — async tests need no special decorator.
|
||||
|
||||
The `settings` fixture in `conftest.py` provides safe defaults with test credentials and in-memory DB.
|
||||
|
||||
### Test Files
|
||||
|
||||
#### `tests/test_risk.py` (11 tests)
|
||||
- Circuit breaker boundaries
|
||||
- Fat-finger edge cases
|
||||
#### Core Components
|
||||
|
||||
##### `tests/test_risk.py` (14 tests)
|
||||
- Circuit breaker boundaries and exact threshold triggers
|
||||
- Fat-finger edge cases and percentage validation
|
||||
- P&L calculation edge cases
|
||||
- Order validation logic
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
def test_circuit_breaker_exact_threshold(risk_manager):
|
||||
"""Circuit breaker should trip at exactly -3.0%."""
|
||||
with pytest.raises(CircuitBreakerTripped):
|
||||
risk_manager.validate_order(
|
||||
current_pnl_pct=-3.0,
|
||||
order_amount=1000,
|
||||
total_cash=10000
|
||||
)
|
||||
```
|
||||
|
||||
#### `tests/test_broker.py` (6 tests)
|
||||
##### `tests/test_broker.py` (11 tests)
|
||||
- OAuth token lifecycle
|
||||
- Rate limiting enforcement
|
||||
- Hash key generation
|
||||
- Network error handling
|
||||
- SSL context configuration
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
async def test_rate_limiter(broker):
|
||||
"""Rate limiter should delay requests to stay under 10 RPS."""
|
||||
start = time.monotonic()
|
||||
for _ in range(15): # 15 requests
|
||||
await broker._rate_limiter.acquire()
|
||||
elapsed = time.monotonic() - start
|
||||
assert elapsed >= 1.0 # Should take at least 1 second
|
||||
```
|
||||
|
||||
#### `tests/test_brain.py` (18 tests)
|
||||
- Valid JSON parsing
|
||||
- Markdown-wrapped JSON handling
|
||||
##### `tests/test_brain.py` (24 tests)
|
||||
- Valid JSON parsing and markdown-wrapped JSON handling
|
||||
- Malformed JSON fallback
|
||||
- Missing fields handling
|
||||
- Invalid action validation
|
||||
@@ -54,33 +32,143 @@ async def test_rate_limiter(broker):
|
||||
- Empty response handling
|
||||
- Prompt construction for different markets
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
async def test_confidence_below_threshold_forces_hold(brain):
|
||||
"""Decisions below confidence threshold should force HOLD."""
|
||||
decision = brain.parse_response('{"action":"BUY","confidence":70,"rationale":"test"}')
|
||||
assert decision.action == "HOLD"
|
||||
assert decision.confidence == 70
|
||||
```
|
||||
|
||||
#### `tests/test_market_schedule.py` (19 tests)
|
||||
##### `tests/test_market_schedule.py` (24 tests)
|
||||
- Market open/close logic
|
||||
- Timezone handling (UTC, Asia/Seoul, America/New_York, etc.)
|
||||
- DST (Daylight Saving Time) transitions
|
||||
- Weekend handling
|
||||
- Lunch break logic
|
||||
- Weekend handling and lunch break logic
|
||||
- Multiple market filtering
|
||||
- Next market open calculation
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
def test_is_market_open_during_trading_hours():
|
||||
"""Market should be open during regular trading hours."""
|
||||
# KRX: 9:00-15:30 KST, no lunch break
|
||||
market = MARKETS["KR"]
|
||||
trading_time = datetime(2026, 2, 3, 10, 0, tzinfo=ZoneInfo("Asia/Seoul")) # Monday 10:00
|
||||
assert is_market_open(market, trading_time) is True
|
||||
```
|
||||
##### `tests/test_db.py` (3 tests)
|
||||
- Database initialization and table creation
|
||||
- Trade logging with all fields (market, exchange_code, decision_id)
|
||||
- Query and retrieval operations
|
||||
|
||||
##### `tests/test_main.py` (37 tests)
|
||||
- Trading loop orchestration
|
||||
- Market iteration and stock processing
|
||||
- Dashboard integration (`--dashboard` flag)
|
||||
- Telegram command handler wiring
|
||||
- Error handling and graceful shutdown
|
||||
|
||||
#### Strategy & Playbook (v2)
|
||||
|
||||
##### `tests/test_pre_market_planner.py` (37 tests)
|
||||
- Pre-market playbook generation
|
||||
- Gemini API integration for scenario creation
|
||||
- Timeout handling and defensive playbook fallback
|
||||
- Multi-market playbook generation
|
||||
|
||||
##### `tests/test_scenario_engine.py` (44 tests)
|
||||
- Scenario matching against live market data
|
||||
- Confidence scoring and threshold filtering
|
||||
- Multiple scenario type handling
|
||||
- Edge cases (no match, partial match, expired scenarios)
|
||||
|
||||
##### `tests/test_playbook_store.py` (23 tests)
|
||||
- Playbook persistence to SQLite
|
||||
- Date-based retrieval and market filtering
|
||||
- Playbook status management (generated, active, expired)
|
||||
- JSON serialization/deserialization
|
||||
|
||||
##### `tests/test_strategy_models.py` (33 tests)
|
||||
- Pydantic model validation for scenarios, playbooks, decisions
|
||||
- Field constraints and default values
|
||||
- Serialization round-trips
|
||||
|
||||
#### Analysis & Scanning
|
||||
|
||||
##### `tests/test_volatility.py` (24 tests)
|
||||
- ATR and RSI calculation accuracy
|
||||
- Volume surge ratio computation
|
||||
- Momentum scoring
|
||||
- Breakout/breakdown pattern detection
|
||||
- Market scanner watchlist management
|
||||
|
||||
##### `tests/test_smart_scanner.py` (13 tests)
|
||||
- Python-first filtering pipeline
|
||||
- RSI and volume ratio filter logic
|
||||
- Candidate scoring and ranking
|
||||
- Fallback to static watchlist
|
||||
|
||||
#### Context & Memory
|
||||
|
||||
##### `tests/test_context.py` (18 tests)
|
||||
- L1-L7 layer storage and retrieval
|
||||
- Context key-value CRUD operations
|
||||
- Timeframe-based queries
|
||||
- Layer metadata management
|
||||
|
||||
##### `tests/test_context_scheduler.py` (5 tests)
|
||||
- Periodic context aggregation scheduling
|
||||
- Layer summarization triggers
|
||||
|
||||
#### Evolution & Review
|
||||
|
||||
##### `tests/test_evolution.py` (24 tests)
|
||||
- Strategy optimization loop
|
||||
- High-confidence losing trade analysis
|
||||
- Generated strategy validation
|
||||
|
||||
##### `tests/test_daily_review.py` (10 tests)
|
||||
- End-of-day review generation
|
||||
- Trade performance summarization
|
||||
- Context layer (L6_DAILY) integration
|
||||
|
||||
##### `tests/test_scorecard.py` (3 tests)
|
||||
- Daily scorecard metrics calculation
|
||||
- Win rate, P&L, confidence tracking
|
||||
|
||||
#### Notifications & Commands
|
||||
|
||||
##### `tests/test_telegram.py` (25 tests)
|
||||
- Message sending and formatting
|
||||
- Rate limiting (leaky bucket)
|
||||
- Error handling (network timeout, invalid token)
|
||||
- Auto-disable on missing credentials
|
||||
- Notification types (trade, circuit breaker, fat-finger, market events)
|
||||
|
||||
##### `tests/test_telegram_commands.py` (31 tests)
|
||||
- 9 command handlers (/help, /status, /positions, /report, /scenarios, /review, /dashboard, /stop, /resume)
|
||||
- Long polling and command dispatch
|
||||
- Authorization filtering by chat_id
|
||||
- Command response formatting
|
||||
|
||||
#### Dashboard
|
||||
|
||||
##### `tests/test_dashboard.py` (14 tests)
|
||||
- FastAPI endpoint responses (8 API routes)
|
||||
- Status, playbook, scorecard, performance, context, decisions, scenarios
|
||||
- Query parameter handling (market, date, limit)
|
||||
|
||||
#### Performance & Quality
|
||||
|
||||
##### `tests/test_token_efficiency.py` (34 tests)
|
||||
- Gemini token usage optimization
|
||||
- Prompt size reduction verification
|
||||
- Cache effectiveness
|
||||
|
||||
##### `tests/test_latency_control.py` (30 tests)
|
||||
- API call latency measurement
|
||||
- Rate limiter timing accuracy
|
||||
- Async operation overhead
|
||||
|
||||
##### `tests/test_decision_logger.py` (9 tests)
|
||||
- Decision audit trail completeness
|
||||
- Context snapshot capture
|
||||
- Outcome tracking (P&L, accuracy)
|
||||
|
||||
##### `tests/test_data_integration.py` (38 tests)
|
||||
- External data source integration
|
||||
- News API, market data, economic calendar
|
||||
- Error handling for API failures
|
||||
|
||||
##### `tests/test_backup.py` (23 tests)
|
||||
- Backup scheduler and execution
|
||||
- Cloud storage (S3) upload
|
||||
- Health monitoring
|
||||
- Data export functionality
|
||||
|
||||
## Coverage Requirements
|
||||
|
||||
@@ -91,20 +179,6 @@ Check coverage:
|
||||
pytest -v --cov=src --cov-report=term-missing
|
||||
```
|
||||
|
||||
Expected output:
|
||||
```
|
||||
Name Stmts Miss Cover Missing
|
||||
-----------------------------------------------------------
|
||||
src/brain/gemini_client.py 85 5 94% 165-169
|
||||
src/broker/kis_api.py 120 12 90% ...
|
||||
src/core/risk_manager.py 35 2 94% ...
|
||||
src/db.py 25 1 96% ...
|
||||
src/main.py 150 80 47% (excluded from CI)
|
||||
src/markets/schedule.py 95 3 97% ...
|
||||
-----------------------------------------------------------
|
||||
TOTAL 510 103 80%
|
||||
```
|
||||
|
||||
**Note:** `main.py` has lower coverage as it contains the main loop which is tested via integration/manual testing.
|
||||
|
||||
## Test Configuration
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
1. **Create Gitea Issue First** — All features, bug fixes, and policy changes require a Gitea issue before any code is written
|
||||
2. **Create Feature Branch** — Branch from `main` using format `feature/issue-{N}-{short-description}`
|
||||
- After creating the branch, run `git pull origin main` and rebase to ensure the branch is up to date
|
||||
3. **Implement Changes** — Write code, tests, and documentation on the feature branch
|
||||
4. **Create Pull Request** — Submit PR to `main` branch referencing the issue number
|
||||
5. **Review & Merge** — After approval, merge via PR (squash or merge commit)
|
||||
@@ -73,3 +74,37 @@ task_tool(
|
||||
```
|
||||
|
||||
Use `run_in_background=True` for independent tasks that don't block subsequent work.
|
||||
|
||||
## Code Review Checklist
|
||||
|
||||
**CRITICAL: Every PR review MUST verify plan-implementation consistency.**
|
||||
|
||||
Before approving any PR, the reviewer (human or agent) must check ALL of the following:
|
||||
|
||||
### 1. Plan Consistency (MANDATORY)
|
||||
|
||||
- [ ] **Implementation matches the approved plan** — Compare the actual code changes against the plan created during `EnterPlanMode`. Every item in the plan must be addressed.
|
||||
- [ ] **No unplanned changes** — If the implementation includes changes not in the plan, they must be explicitly justified.
|
||||
- [ ] **No plan items omitted** — If any planned item was skipped, the reason must be documented in the PR description.
|
||||
- [ ] **Scope matches** — The PR does not exceed or fall short of the planned scope.
|
||||
|
||||
### 2. Safety & Constraints
|
||||
|
||||
- [ ] `src/core/risk_manager.py` is unchanged (READ-ONLY)
|
||||
- [ ] Circuit breaker threshold not weakened (only stricter allowed)
|
||||
- [ ] Fat-finger protection (30% max order) still enforced
|
||||
- [ ] Confidence < 80 still forces HOLD
|
||||
- [ ] No hardcoded API keys or secrets
|
||||
|
||||
### 3. Quality
|
||||
|
||||
- [ ] All new/modified code has corresponding tests
|
||||
- [ ] Test coverage >= 80%
|
||||
- [ ] `ruff check src/ tests/` passes (no lint errors)
|
||||
- [ ] No `assert` statements removed from tests
|
||||
|
||||
### 4. Workflow
|
||||
|
||||
- [ ] PR references the Gitea issue number
|
||||
- [ ] Feature branch follows naming convention (`feature/issue-N-description`)
|
||||
- [ ] Commit messages are clear and descriptive
|
||||
|
||||
@@ -8,6 +8,9 @@ dependencies = [
|
||||
"pydantic>=2.5,<3",
|
||||
"pydantic-settings>=2.1,<3",
|
||||
"google-genai>=1.0,<2",
|
||||
"scipy>=1.11,<2",
|
||||
"fastapi>=0.110,<1",
|
||||
"uvicorn>=0.29,<1",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
96 scripts/backup.sh (new file)
@@ -0,0 +1,96 @@
|
||||
#!/usr/bin/env bash
|
||||
# Automated backup script for The Ouroboros trading system
|
||||
# Runs daily/weekly/monthly backups
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Configuration
|
||||
DB_PATH="${DB_PATH:-data/trade_logs.db}"
|
||||
BACKUP_DIR="${BACKUP_DIR:-data/backups}"
|
||||
PYTHON="${PYTHON:-python3}"
|
||||
|
||||
# Colors for output
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
RED='\033[0;31m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
log_info() {
|
||||
echo -e "${GREEN}[INFO]${NC} $1"
|
||||
}
|
||||
|
||||
log_warn() {
|
||||
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||
}
|
||||
|
||||
log_error() {
|
||||
echo -e "${RED}[ERROR]${NC} $1"
|
||||
}
|
||||
|
||||
# Check if database exists
|
||||
if [ ! -f "$DB_PATH" ]; then
|
||||
log_error "Database not found: $DB_PATH"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Create backup directory
|
||||
mkdir -p "$BACKUP_DIR"
|
||||
|
||||
log_info "Starting backup process..."
|
||||
log_info "Database: $DB_PATH"
|
||||
log_info "Backup directory: $BACKUP_DIR"
|
||||
|
||||
# Determine backup policy based on day of week and month
|
||||
DAY_OF_WEEK=$(date +%u) # 1-7 (Monday-Sunday)
|
||||
DAY_OF_MONTH=$(date +%d)
|
||||
|
||||
if [ "$DAY_OF_MONTH" == "01" ]; then
|
||||
POLICY="monthly"
|
||||
log_info "Running MONTHLY backup (first day of month)"
|
||||
elif [ "$DAY_OF_WEEK" == "7" ]; then
|
||||
POLICY="weekly"
|
||||
log_info "Running WEEKLY backup (Sunday)"
|
||||
else
|
||||
POLICY="daily"
|
||||
log_info "Running DAILY backup"
|
||||
fi
|
||||
|
||||
# Run Python backup script
|
||||
$PYTHON -c "
|
||||
from pathlib import Path
|
||||
from src.backup.scheduler import BackupScheduler, BackupPolicy
|
||||
from src.backup.health_monitor import HealthMonitor
|
||||
|
||||
# Create scheduler
|
||||
scheduler = BackupScheduler(
|
||||
db_path='$DB_PATH',
|
||||
backup_dir=Path('$BACKUP_DIR')
|
||||
)
|
||||
|
||||
# Create backup
|
||||
policy = BackupPolicy['$POLICY'.upper()]  # look up the enum member by name (DAILY/WEEKLY/MONTHLY)
|
||||
metadata = scheduler.create_backup(policy, verify=True)
|
||||
print(f'Backup created: {metadata.file_path}')
|
||||
print(f'Size: {metadata.size_bytes / 1024 / 1024:.2f} MB')
|
||||
print(f'Checksum: {metadata.checksum}')
|
||||
|
||||
# Cleanup old backups
|
||||
removed = scheduler.cleanup_old_backups()
|
||||
total_removed = sum(removed.values())
|
||||
if total_removed > 0:
|
||||
print(f'Removed {total_removed} old backup(s)')
|
||||
|
||||
# Health check
|
||||
monitor = HealthMonitor('$DB_PATH', Path('$BACKUP_DIR'))
|
||||
status = monitor.get_overall_status()
|
||||
print(f'System health: {status.value}')
|
||||
"
|
||||
|
||||
if [ $? -eq 0 ]; then
|
||||
log_info "Backup completed successfully"
|
||||
else
|
||||
log_error "Backup failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
log_info "Backup process finished"
|
||||
54 scripts/morning_report.sh (new executable file)
@@ -0,0 +1,54 @@
|
||||
#!/usr/bin/env bash
|
||||
# Morning summary for overnight run logs.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
LOG_DIR="${LOG_DIR:-data/overnight}"
|
||||
|
||||
if [ ! -d "$LOG_DIR" ]; then
|
||||
echo "로그 디렉터리가 없습니다: $LOG_DIR"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
latest_run="$(ls -1t "$LOG_DIR"/run_*.log 2>/dev/null | head -n 1 || true)"
|
||||
latest_watchdog="$(ls -1t "$LOG_DIR"/watchdog_*.log 2>/dev/null | head -n 1 || true)"
|
||||
|
||||
if [ -z "$latest_run" ]; then
|
||||
echo "run 로그가 없습니다: $LOG_DIR/run_*.log"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Overnight report"
|
||||
echo "- run log: $latest_run"
|
||||
if [ -n "$latest_watchdog" ]; then
|
||||
echo "- watchdog log: $latest_watchdog"
|
||||
fi
|
||||
|
||||
start_line="$(head -n 1 "$latest_run" || true)"
|
||||
end_line="$(tail -n 1 "$latest_run" || true)"
|
||||
|
||||
info_count="$(rg -c '"level": "INFO"' "$latest_run" || true)"
|
||||
warn_count="$(rg -c '"level": "WARNING"' "$latest_run" || true)"
|
||||
error_count="$(rg -c '"level": "ERROR"' "$latest_run" || true)"
|
||||
critical_count="$(rg -c '"level": "CRITICAL"' "$latest_run" || true)"
|
||||
traceback_count="$(rg -c 'Traceback' "$latest_run" || true)"
|
||||
|
||||
echo "- start: ${start_line:-N/A}"
|
||||
echo "- end: ${end_line:-N/A}"
|
||||
echo "- INFO: ${info_count:-0}"
|
||||
echo "- WARNING: ${warn_count:-0}"
|
||||
echo "- ERROR: ${error_count:-0}"
|
||||
echo "- CRITICAL: ${critical_count:-0}"
|
||||
echo "- Traceback: ${traceback_count:-0}"
|
||||
|
||||
if [ -n "$latest_watchdog" ]; then
|
||||
watchdog_errors="$(rg -c '\[ERROR\]' "$latest_watchdog" || true)"
|
||||
echo "- watchdog ERROR: ${watchdog_errors:-0}"
|
||||
echo ""
|
||||
echo "최근 watchdog 로그:"
|
||||
tail -n 5 "$latest_watchdog" || true
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "최근 앱 로그:"
|
||||
tail -n 20 "$latest_run" || true
|
||||
111
scripts/restore.sh
Normal file
111
scripts/restore.sh
Normal file
@@ -0,0 +1,111 @@
|
||||
#!/usr/bin/env bash
|
||||
# Restore script for The Ouroboros trading system
|
||||
# Restores database from a backup file
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Configuration
|
||||
DB_PATH="${DB_PATH:-data/trade_logs.db}"
|
||||
BACKUP_DIR="${BACKUP_DIR:-data/backups}"
|
||||
PYTHON="${PYTHON:-python3}"
|
||||
|
||||
# Colors for output
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
RED='\033[0;31m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
log_info() {
|
||||
echo -e "${GREEN}[INFO]${NC} $1"
|
||||
}
|
||||
|
||||
log_warn() {
|
||||
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||
}
|
||||
|
||||
log_error() {
|
||||
echo -e "${RED}[ERROR]${NC} $1"
|
||||
}
|
||||
|
||||
# Check if backup directory exists
|
||||
if [ ! -d "$BACKUP_DIR" ]; then
|
||||
log_error "Backup directory not found: $BACKUP_DIR"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
log_info "Available backups:"
|
||||
log_info "=================="
|
||||
|
||||
# List available backups
|
||||
$PYTHON -c "
|
||||
from pathlib import Path
|
||||
from src.backup.scheduler import BackupScheduler
|
||||
|
||||
scheduler = BackupScheduler(
|
||||
db_path='$DB_PATH',
|
||||
backup_dir=Path('$BACKUP_DIR')
|
||||
)
|
||||
|
||||
backups = scheduler.list_backups()
|
||||
|
||||
if not backups:
|
||||
print('No backups found.')
|
||||
exit(1)
|
||||
|
||||
for i, backup in enumerate(backups, 1):
|
||||
size_mb = backup.size_bytes / 1024 / 1024
|
||||
print(f'{i}. [{backup.policy.value.upper()}] {backup.file_path.name}')
|
||||
print(f' Date: {backup.timestamp.strftime(\"%Y-%m-%d %H:%M:%S UTC\")}')
|
||||
print(f' Size: {size_mb:.2f} MB')
|
||||
print()
|
||||
"
|
||||
|
||||
# Ask user to select backup
|
||||
echo ""
|
||||
read -p "Enter backup number to restore (or 'q' to quit): " BACKUP_NUM
|
||||
|
||||
if [ "$BACKUP_NUM" == "q" ]; then
|
||||
log_info "Restore cancelled"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Confirm restoration
|
||||
log_warn "WARNING: This will replace the current database!"
|
||||
log_warn "Current database will be backed up to: ${DB_PATH}.before_restore"
|
||||
read -p "Are you sure you want to continue? (yes/no): " CONFIRM
|
||||
|
||||
if [ "$CONFIRM" != "yes" ]; then
|
||||
log_info "Restore cancelled"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Perform restoration
|
||||
$PYTHON -c "
|
||||
from pathlib import Path
|
||||
from src.backup.scheduler import BackupScheduler
|
||||
|
||||
scheduler = BackupScheduler(
|
||||
db_path='$DB_PATH',
|
||||
backup_dir=Path('$BACKUP_DIR')
|
||||
)
|
||||
|
||||
backups = scheduler.list_backups()
|
||||
backup_index = int('$BACKUP_NUM') - 1
|
||||
|
||||
if backup_index < 0 or backup_index >= len(backups):
|
||||
print('Invalid backup number')
|
||||
exit(1)
|
||||
|
||||
selected = backups[backup_index]
|
||||
print(f'Restoring: {selected.file_path.name}')
|
||||
|
||||
scheduler.restore_backup(selected, verify=True)
|
||||
print('Restore completed successfully')
|
||||
"
|
||||
|
||||
if [ $? -eq 0 ]; then
|
||||
log_info "Database restored successfully"
|
||||
else
|
||||
log_error "Restore failed"
|
||||
exit 1
|
||||
fi
|
||||
87
scripts/run_overnight.sh
Executable file
87
scripts/run_overnight.sh
Executable file
@@ -0,0 +1,87 @@
|
||||
#!/usr/bin/env bash
|
||||
# Start The Ouroboros overnight with logs and watchdog.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
LOG_DIR="${LOG_DIR:-data/overnight}"
|
||||
CHECK_INTERVAL="${CHECK_INTERVAL:-30}"
|
||||
TMUX_AUTO="${TMUX_AUTO:-true}"
|
||||
TMUX_ATTACH="${TMUX_ATTACH:-true}"
|
||||
TMUX_SESSION_PREFIX="${TMUX_SESSION_PREFIX:-ouroboros_overnight}"
|
||||
|
||||
if [ -z "${APP_CMD:-}" ]; then
|
||||
if [ -x ".venv/bin/python" ]; then
|
||||
PYTHON_BIN=".venv/bin/python"
|
||||
elif command -v python3 >/dev/null 2>&1; then
|
||||
PYTHON_BIN="python3"
|
||||
elif command -v python >/dev/null 2>&1; then
|
||||
PYTHON_BIN="python"
|
||||
else
|
||||
echo ".venv/bin/python 또는 python3/python 실행 파일을 찾을 수 없습니다."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
dashboard_port="${DASHBOARD_PORT:-8080}"
|
||||
|
||||
APP_CMD="DASHBOARD_PORT=$dashboard_port $PYTHON_BIN -m src.main --mode=paper --dashboard"
|
||||
fi
|
||||
|
||||
mkdir -p "$LOG_DIR"
|
||||
|
||||
timestamp="$(date +"%Y%m%d_%H%M%S")"
|
||||
RUN_LOG="$LOG_DIR/run_${timestamp}.log"
|
||||
WATCHDOG_LOG="$LOG_DIR/watchdog_${timestamp}.log"
|
||||
PID_FILE="$LOG_DIR/app.pid"
|
||||
WATCHDOG_PID_FILE="$LOG_DIR/watchdog.pid"
|
||||
|
||||
if [ -f "$PID_FILE" ]; then
|
||||
old_pid="$(cat "$PID_FILE" || true)"
|
||||
if [ -n "$old_pid" ] && kill -0 "$old_pid" 2>/dev/null; then
|
||||
echo "앱이 이미 실행 중입니다. pid=$old_pid"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "[$(date -u +"%Y-%m-%dT%H:%M:%SZ")] starting: $APP_CMD" | tee -a "$RUN_LOG"
|
||||
nohup bash -lc "$APP_CMD" >>"$RUN_LOG" 2>&1 &
|
||||
app_pid=$!
|
||||
echo "$app_pid" > "$PID_FILE"
|
||||
|
||||
echo "[$(date -u +"%Y-%m-%dT%H:%M:%SZ")] app pid=$app_pid" | tee -a "$RUN_LOG"
|
||||
|
||||
nohup env PID_FILE="$PID_FILE" LOG_FILE="$WATCHDOG_LOG" CHECK_INTERVAL="$CHECK_INTERVAL" \
|
||||
bash scripts/watchdog.sh >/dev/null 2>&1 &
|
||||
watchdog_pid=$!
|
||||
echo "$watchdog_pid" > "$WATCHDOG_PID_FILE"
|
||||
|
||||
cat <<EOF
|
||||
시작 완료
|
||||
- app pid: $app_pid
|
||||
- watchdog pid: $watchdog_pid
|
||||
- app log: $RUN_LOG
|
||||
- watchdog log: $WATCHDOG_LOG
|
||||
|
||||
실시간 확인:
|
||||
tail -f "$RUN_LOG"
|
||||
tail -f "$WATCHDOG_LOG"
|
||||
EOF
|
||||
|
||||
if [ "$TMUX_AUTO" = "true" ]; then
|
||||
if ! command -v tmux >/dev/null 2>&1; then
|
||||
echo "tmux를 찾지 못해 자동 세션 생성은 건너뜁니다."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
session_name="${TMUX_SESSION_PREFIX}_${timestamp}"
|
||||
window_name="overnight"
|
||||
tmux new-session -d -s "$session_name" -n "$window_name" "tail -f '$RUN_LOG'"
|
||||
tmux split-window -t "${session_name}:${window_name}" -v "tail -f '$WATCHDOG_LOG'"
|
||||
tmux select-layout -t "${session_name}:${window_name}" even-vertical
|
||||
|
||||
echo "tmux session 생성: $session_name"
|
||||
echo "수동 접속: tmux attach -t $session_name"
|
||||
|
||||
if [ -z "${TMUX:-}" ] && [ "$TMUX_ATTACH" = "true" ]; then
|
||||
tmux attach -t "$session_name"
|
||||
fi
|
||||
fi
|
||||
76
scripts/stop_overnight.sh
Executable file
76
scripts/stop_overnight.sh
Executable file
@@ -0,0 +1,76 @@
|
||||
#!/usr/bin/env bash
|
||||
# Stop The Ouroboros overnight app/watchdog/tmux session.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
LOG_DIR="${LOG_DIR:-data/overnight}"
|
||||
PID_FILE="$LOG_DIR/app.pid"
|
||||
WATCHDOG_PID_FILE="$LOG_DIR/watchdog.pid"
|
||||
TMUX_SESSION_PREFIX="${TMUX_SESSION_PREFIX:-ouroboros_overnight}"
|
||||
KILL_TIMEOUT="${KILL_TIMEOUT:-5}"
|
||||
|
||||
stop_pid() {
|
||||
local name="$1"
|
||||
local pid="$2"
|
||||
|
||||
if [ -z "$pid" ]; then
|
||||
echo "$name PID가 비어 있습니다."
|
||||
return 1
|
||||
fi
|
||||
|
||||
if ! kill -0 "$pid" 2>/dev/null; then
|
||||
echo "$name 프로세스가 이미 종료됨 (pid=$pid)"
|
||||
return 0
|
||||
fi
|
||||
|
||||
kill "$pid" 2>/dev/null || true
|
||||
for _ in $(seq 1 "$KILL_TIMEOUT"); do
|
||||
if ! kill -0 "$pid" 2>/dev/null; then
|
||||
echo "$name 종료됨 (pid=$pid)"
|
||||
return 0
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
kill -9 "$pid" 2>/dev/null || true
|
||||
if ! kill -0 "$pid" 2>/dev/null; then
|
||||
echo "$name 강제 종료됨 (pid=$pid)"
|
||||
return 0
|
||||
fi
|
||||
|
||||
echo "$name 종료 실패 (pid=$pid)"
|
||||
return 1
|
||||
}
|
||||
|
||||
status=0
|
||||
|
||||
if [ -f "$WATCHDOG_PID_FILE" ]; then
|
||||
watchdog_pid="$(cat "$WATCHDOG_PID_FILE" || true)"
|
||||
stop_pid "watchdog" "$watchdog_pid" || status=1
|
||||
rm -f "$WATCHDOG_PID_FILE"
|
||||
else
|
||||
echo "watchdog pid 파일 없음: $WATCHDOG_PID_FILE"
|
||||
fi
|
||||
|
||||
if [ -f "$PID_FILE" ]; then
|
||||
app_pid="$(cat "$PID_FILE" || true)"
|
||||
stop_pid "app" "$app_pid" || status=1
|
||||
rm -f "$PID_FILE"
|
||||
else
|
||||
echo "app pid 파일 없음: $PID_FILE"
|
||||
fi
|
||||
|
||||
if command -v tmux >/dev/null 2>&1; then
|
||||
sessions="$(tmux ls 2>/dev/null | awk -F: -v p="$TMUX_SESSION_PREFIX" '$1 ~ "^" p "_" {print $1}')"
|
||||
if [ -n "$sessions" ]; then
|
||||
while IFS= read -r s; do
|
||||
[ -z "$s" ] && continue
|
||||
tmux kill-session -t "$s" 2>/dev/null || true
|
||||
echo "tmux 세션 종료: $s"
|
||||
done <<< "$sessions"
|
||||
else
|
||||
echo "종료할 tmux 세션 없음 (prefix=${TMUX_SESSION_PREFIX}_)"
|
||||
fi
|
||||
fi
|
||||
|
||||
exit "$status"
|
||||
42
scripts/watchdog.sh
Executable file
42
scripts/watchdog.sh
Executable file
@@ -0,0 +1,42 @@
|
||||
#!/usr/bin/env bash
|
||||
# Simple watchdog for The Ouroboros process.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
PID_FILE="${PID_FILE:-data/overnight/app.pid}"
|
||||
LOG_FILE="${LOG_FILE:-data/overnight/watchdog.log}"
|
||||
CHECK_INTERVAL="${CHECK_INTERVAL:-30}"
|
||||
STATUS_EVERY="${STATUS_EVERY:-10}"
|
||||
|
||||
mkdir -p "$(dirname "$LOG_FILE")"
|
||||
|
||||
log() {
|
||||
printf '%s %s\n' "$(date -u +"%Y-%m-%dT%H:%M:%SZ")" "$1" | tee -a "$LOG_FILE"
|
||||
}
|
||||
|
||||
if [ ! -f "$PID_FILE" ]; then
|
||||
log "[ERROR] pid file not found: $PID_FILE"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PID="$(cat "$PID_FILE")"
|
||||
if [ -z "$PID" ]; then
|
||||
log "[ERROR] pid file is empty: $PID_FILE"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
log "[INFO] watchdog started (pid=$PID, interval=${CHECK_INTERVAL}s)"
|
||||
|
||||
count=0
|
||||
while true; do
|
||||
if kill -0 "$PID" 2>/dev/null; then
|
||||
count=$((count + 1))
|
||||
if [ $((count % STATUS_EVERY)) -eq 0 ]; then
|
||||
log "[INFO] process alive (pid=$PID)"
|
||||
fi
|
||||
else
|
||||
log "[ERROR] process stopped (pid=$PID)"
|
||||
exit 1
|
||||
fi
|
||||
sleep "$CHECK_INTERVAL"
|
||||
done
|
||||
9
src/analysis/__init__.py
Normal file
9
src/analysis/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""Technical analysis and market scanning modules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from src.analysis.scanner import MarketScanner
|
||||
from src.analysis.smart_scanner import ScanCandidate, SmartVolatilityScanner
|
||||
from src.analysis.volatility import VolatilityAnalyzer
|
||||
|
||||
__all__ = ["VolatilityAnalyzer", "MarketScanner", "SmartVolatilityScanner", "ScanCandidate"]
|
||||
244
src/analysis/scanner.py
Normal file
244
src/analysis/scanner.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""Real-time market scanner for detecting high-momentum stocks.
|
||||
|
||||
Scans all available stocks in a market and ranks by volatility/momentum score.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from src.analysis.volatility import VolatilityAnalyzer, VolatilityMetrics
|
||||
from src.broker.kis_api import KISBroker
|
||||
from src.broker.overseas import OverseasBroker
|
||||
from src.context.layer import ContextLayer
|
||||
from src.context.store import ContextStore
|
||||
from src.markets.schedule import MarketInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScanResult:
|
||||
"""Result from a market scan."""
|
||||
|
||||
market_code: str
|
||||
timestamp: str
|
||||
total_scanned: int
|
||||
top_movers: list[VolatilityMetrics]
|
||||
breakouts: list[str] # Stock codes with breakout patterns
|
||||
breakdowns: list[str] # Stock codes with breakdown patterns
|
||||
|
||||
|
||||
class MarketScanner:
|
||||
"""Scans markets for high-volatility, high-momentum stocks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
broker: KISBroker,
|
||||
overseas_broker: OverseasBroker,
|
||||
volatility_analyzer: VolatilityAnalyzer,
|
||||
context_store: ContextStore,
|
||||
top_n: int = 5,
|
||||
max_concurrent_scans: int = 1,
|
||||
) -> None:
|
||||
"""Initialize the market scanner.
|
||||
|
||||
Args:
|
||||
broker: KIS broker instance for domestic market
|
||||
overseas_broker: Overseas broker instance
|
||||
volatility_analyzer: Volatility analyzer instance
|
||||
context_store: Context store for L7 real-time data
|
||||
top_n: Number of top movers to return per market (default 5)
|
||||
max_concurrent_scans: Max concurrent stock scans (default 1, fully serialized)
|
||||
"""
|
||||
self.broker = broker
|
||||
self.overseas_broker = overseas_broker
|
||||
self.analyzer = volatility_analyzer
|
||||
self.context_store = context_store
|
||||
self.top_n = top_n
|
||||
self._scan_semaphore = asyncio.Semaphore(max_concurrent_scans)
|
||||
|
||||
async def scan_stock(
|
||||
self,
|
||||
stock_code: str,
|
||||
market: MarketInfo,
|
||||
) -> VolatilityMetrics | None:
|
||||
"""Scan a single stock for volatility metrics.
|
||||
|
||||
Args:
|
||||
stock_code: Stock code to scan
|
||||
market: Market information
|
||||
|
||||
Returns:
|
||||
VolatilityMetrics if successful, None on error
|
||||
"""
|
||||
try:
|
||||
if market.is_domestic:
|
||||
orderbook = await self.broker.get_orderbook(stock_code)
|
||||
else:
|
||||
# For overseas, we need to adapt the price data structure
|
||||
price_data = await self.overseas_broker.get_overseas_price(
|
||||
market.exchange_code, stock_code
|
||||
)
|
||||
# Convert to orderbook-like structure
|
||||
orderbook = {
|
||||
"output1": {
|
||||
"stck_prpr": price_data.get("output", {}).get("last", "0") or "0",
|
||||
"acml_vol": price_data.get("output", {}).get("tvol", "0") or "0",
|
||||
}
|
||||
}
|
||||
|
||||
# For now, use empty price history (would need real historical data)
|
||||
# In production, this would fetch from a time-series database or API
|
||||
price_history: dict[str, Any] = {
|
||||
"high": [],
|
||||
"low": [],
|
||||
"close": [],
|
||||
"volume": [],
|
||||
}
|
||||
|
||||
metrics = self.analyzer.analyze(stock_code, orderbook, price_history)
|
||||
|
||||
# Store in L7 real-time layer
|
||||
from datetime import UTC, datetime
|
||||
timeframe = datetime.now(UTC).isoformat()
|
||||
self.context_store.set_context(
|
||||
ContextLayer.L7_REALTIME,
|
||||
timeframe,
|
||||
f"volatility_{market.code}_{stock_code}",
|
||||
{
|
||||
"price": metrics.current_price,
|
||||
"atr": metrics.atr,
|
||||
"price_change_1m": metrics.price_change_1m,
|
||||
"volume_surge": metrics.volume_surge,
|
||||
"momentum_score": metrics.momentum_score,
|
||||
},
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to scan %s (%s): %s", stock_code, market.code, exc)
|
||||
return None
|
||||
|
||||
async def scan_market(
|
||||
self,
|
||||
market: MarketInfo,
|
||||
stock_codes: list[str],
|
||||
) -> ScanResult:
|
||||
"""Scan all stocks in a market and rank by momentum.
|
||||
|
||||
Args:
|
||||
market: Market to scan
|
||||
stock_codes: List of stock codes to scan
|
||||
|
||||
Returns:
|
||||
ScanResult with ranked stocks
|
||||
"""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
logger.info("Scanning %s market (%d stocks)", market.name, len(stock_codes))
|
||||
|
||||
# Scan stocks with bounded concurrency to prevent API rate limit burst
|
||||
async def _bounded_scan(code: str) -> VolatilityMetrics | None:
|
||||
async with self._scan_semaphore:
|
||||
return await self.scan_stock(code, market)
|
||||
|
||||
tasks = [_bounded_scan(code) for code in stock_codes]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Filter out failures and sort by momentum score
|
||||
valid_metrics = [m for m in results if m is not None]
|
||||
valid_metrics.sort(key=lambda m: m.momentum_score, reverse=True)
|
||||
|
||||
# Get top N movers
|
||||
top_movers = valid_metrics[: self.top_n]
|
||||
|
||||
# Detect breakouts and breakdowns
|
||||
breakouts = [
|
||||
m.stock_code for m in valid_metrics if self.analyzer.is_breakout(m)
|
||||
]
|
||||
breakdowns = [
|
||||
m.stock_code for m in valid_metrics if self.analyzer.is_breakdown(m)
|
||||
]
|
||||
|
||||
logger.info(
|
||||
"%s scan complete: %d scanned, top momentum=%.1f, %d breakouts, %d breakdowns",
|
||||
market.name,
|
||||
len(valid_metrics),
|
||||
top_movers[0].momentum_score if top_movers else 0.0,
|
||||
len(breakouts),
|
||||
len(breakdowns),
|
||||
)
|
||||
|
||||
# Store scan results in L7
|
||||
timeframe = datetime.now(UTC).isoformat()
|
||||
self.context_store.set_context(
|
||||
ContextLayer.L7_REALTIME,
|
||||
timeframe,
|
||||
f"scan_result_{market.code}",
|
||||
{
|
||||
"total_scanned": len(valid_metrics),
|
||||
"top_movers": [m.stock_code for m in top_movers],
|
||||
"breakouts": breakouts,
|
||||
"breakdowns": breakdowns,
|
||||
},
|
||||
)
|
||||
|
||||
return ScanResult(
|
||||
market_code=market.code,
|
||||
timestamp=timeframe,
|
||||
total_scanned=len(valid_metrics),
|
||||
top_movers=top_movers,
|
||||
breakouts=breakouts,
|
||||
breakdowns=breakdowns,
|
||||
)
|
||||
|
||||
def get_updated_watchlist(
|
||||
self,
|
||||
current_watchlist: list[str],
|
||||
scan_result: ScanResult,
|
||||
max_replacements: int = 2,
|
||||
) -> list[str]:
|
||||
"""Update watchlist by replacing laggards with leaders.
|
||||
|
||||
Args:
|
||||
current_watchlist: Current watchlist
|
||||
scan_result: Recent scan result
|
||||
max_replacements: Maximum stocks to replace per scan
|
||||
|
||||
Returns:
|
||||
Updated watchlist with leaders
|
||||
"""
|
||||
# Keep stocks that are in top movers
|
||||
top_codes = [m.stock_code for m in scan_result.top_movers]
|
||||
keepers = [code for code in current_watchlist if code in top_codes]
|
||||
|
||||
# Add new leaders not in current watchlist
|
||||
new_leaders = [code for code in top_codes if code not in current_watchlist]
|
||||
|
||||
# Limit replacements
|
||||
new_leaders = new_leaders[:max_replacements]
|
||||
|
||||
# Create updated watchlist
|
||||
updated = keepers + new_leaders
|
||||
|
||||
# If we removed too many, backfill from current watchlist
|
||||
if len(updated) < len(current_watchlist):
|
||||
backfill = [
|
||||
code for code in current_watchlist
|
||||
if code not in updated
|
||||
][: len(current_watchlist) - len(updated)]
|
||||
updated.extend(backfill)
|
||||
|
||||
logger.info(
|
||||
"Watchlist updated: %d kept, %d new leaders, %d total",
|
||||
len(keepers),
|
||||
len(new_leaders),
|
||||
len(updated),
|
||||
)
|
||||
|
||||
return updated
|
||||
449
src/analysis/smart_scanner.py
Normal file
449
src/analysis/smart_scanner.py
Normal file
@@ -0,0 +1,449 @@
|
||||
"""Smart Volatility Scanner with volatility-first market ranking logic."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from src.analysis.volatility import VolatilityAnalyzer
|
||||
from src.broker.kis_api import KISBroker
|
||||
from src.broker.overseas import OverseasBroker
|
||||
from src.config import Settings
|
||||
from src.markets.schedule import MarketInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScanCandidate:
|
||||
"""A qualified candidate from the smart scanner."""
|
||||
|
||||
stock_code: str
|
||||
name: str
|
||||
price: float
|
||||
volume: float
|
||||
volume_ratio: float # Current volume / previous day volume
|
||||
rsi: float
|
||||
signal: str # "oversold" or "momentum"
|
||||
score: float # Composite score for ranking
|
||||
|
||||
|
||||
class SmartVolatilityScanner:
|
||||
"""Scans market rankings and applies volatility-first filters.
|
||||
|
||||
Flow:
|
||||
1. Fetch fluctuation rankings as primary universe
|
||||
2. Fetch volume rankings for liquidity bonus
|
||||
3. Score by volatility first, liquidity second
|
||||
4. Return top N qualified candidates
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
broker: KISBroker,
|
||||
overseas_broker: OverseasBroker | None,
|
||||
volatility_analyzer: VolatilityAnalyzer,
|
||||
settings: Settings,
|
||||
) -> None:
|
||||
"""Initialize the smart scanner.
|
||||
|
||||
Args:
|
||||
broker: KIS broker for API calls
|
||||
volatility_analyzer: Analyzer for RSI calculation
|
||||
settings: Application settings
|
||||
"""
|
||||
self.broker = broker
|
||||
self.overseas_broker = overseas_broker
|
||||
self.analyzer = volatility_analyzer
|
||||
self.settings = settings
|
||||
|
||||
# Extract scanner settings
|
||||
self.rsi_oversold = settings.RSI_OVERSOLD_THRESHOLD
|
||||
self.rsi_momentum = settings.RSI_MOMENTUM_THRESHOLD
|
||||
self.vol_multiplier = settings.VOL_MULTIPLIER
|
||||
self.top_n = settings.SCANNER_TOP_N
|
||||
|
||||
async def scan(
|
||||
self,
|
||||
market: MarketInfo | None = None,
|
||||
fallback_stocks: list[str] | None = None,
|
||||
) -> list[ScanCandidate]:
|
||||
"""Execute smart scan and return qualified candidates.
|
||||
|
||||
Args:
|
||||
market: Target market info (domestic vs overseas behavior)
|
||||
fallback_stocks: Stock codes to use if ranking API fails
|
||||
|
||||
Returns:
|
||||
List of ScanCandidate, sorted by score, up to top_n items
|
||||
"""
|
||||
if market and not market.is_domestic:
|
||||
return await self._scan_overseas(market, fallback_stocks)
|
||||
|
||||
return await self._scan_domestic(fallback_stocks)
|
||||
|
||||
async def _scan_domestic(
|
||||
self,
|
||||
fallback_stocks: list[str] | None = None,
|
||||
) -> list[ScanCandidate]:
|
||||
"""Scan domestic market using volatility-first ranking + liquidity bonus."""
|
||||
# 1) Primary universe from fluctuation ranking.
|
||||
try:
|
||||
fluct_rows = await self.broker.fetch_market_rankings(
|
||||
ranking_type="fluctuation",
|
||||
limit=50,
|
||||
)
|
||||
except ConnectionError as exc:
|
||||
logger.warning("Domestic fluctuation ranking failed: %s", exc)
|
||||
fluct_rows = []
|
||||
|
||||
# 2) Liquidity bonus from volume ranking.
|
||||
try:
|
||||
volume_rows = await self.broker.fetch_market_rankings(
|
||||
ranking_type="volume",
|
||||
limit=50,
|
||||
)
|
||||
except ConnectionError as exc:
|
||||
logger.warning("Domestic volume ranking failed: %s", exc)
|
||||
volume_rows = []
|
||||
|
||||
if not fluct_rows and fallback_stocks:
|
||||
logger.info(
|
||||
"Domestic ranking unavailable; using fallback symbols (%d)",
|
||||
len(fallback_stocks),
|
||||
)
|
||||
fluct_rows = [
|
||||
{
|
||||
"stock_code": code,
|
||||
"name": code,
|
||||
"price": 0.0,
|
||||
"volume": 0.0,
|
||||
"change_rate": 0.0,
|
||||
"volume_increase_rate": 0.0,
|
||||
}
|
||||
for code in fallback_stocks
|
||||
]
|
||||
|
||||
if not fluct_rows:
|
||||
return []
|
||||
|
||||
volume_rank_bonus: dict[str, float] = {}
|
||||
for idx, row in enumerate(volume_rows):
|
||||
code = _extract_stock_code(row)
|
||||
if not code:
|
||||
continue
|
||||
volume_rank_bonus[code] = max(0.0, 15.0 - idx * 0.3)
|
||||
|
||||
candidates: list[ScanCandidate] = []
|
||||
for stock in fluct_rows:
|
||||
stock_code = _extract_stock_code(stock)
|
||||
if not stock_code:
|
||||
continue
|
||||
|
||||
try:
|
||||
price = _extract_last_price(stock)
|
||||
change_rate = _extract_change_rate_pct(stock)
|
||||
volume = _extract_volume(stock)
|
||||
|
||||
intraday_range_pct = 0.0
|
||||
volume_ratio = _safe_float(stock.get("volume_increase_rate"), 0.0) / 100.0 + 1.0
|
||||
|
||||
# Use daily chart to refine range/volume when available.
|
||||
daily_prices = await self.broker.get_daily_prices(stock_code, days=2)
|
||||
if daily_prices:
|
||||
latest = daily_prices[-1]
|
||||
latest_close = _safe_float(latest.get("close"), default=price)
|
||||
if price <= 0:
|
||||
price = latest_close
|
||||
latest_high = _safe_float(latest.get("high"))
|
||||
latest_low = _safe_float(latest.get("low"))
|
||||
if latest_close > 0 and latest_high > 0 and latest_low > 0 and latest_high >= latest_low:
|
||||
intraday_range_pct = (latest_high - latest_low) / latest_close * 100.0
|
||||
if volume <= 0:
|
||||
volume = _safe_float(latest.get("volume"))
|
||||
if len(daily_prices) >= 2:
|
||||
prev_day_volume = _safe_float(daily_prices[-2].get("volume"))
|
||||
if prev_day_volume > 0:
|
||||
volume_ratio = max(volume_ratio, volume / prev_day_volume)
|
||||
|
||||
volatility_pct = max(abs(change_rate), intraday_range_pct)
|
||||
if price <= 0 or volatility_pct < 0.8:
|
||||
continue
|
||||
|
||||
volatility_score = min(volatility_pct / 10.0, 1.0) * 85.0
|
||||
liquidity_score = volume_rank_bonus.get(stock_code, 0.0)
|
||||
score = min(100.0, volatility_score + liquidity_score)
|
||||
signal = "momentum" if change_rate >= 0 else "oversold"
|
||||
implied_rsi = max(0.0, min(100.0, 50.0 + (change_rate * 2.0)))
|
||||
|
||||
candidates.append(
|
||||
ScanCandidate(
|
||||
stock_code=stock_code,
|
||||
name=stock.get("name", stock_code),
|
||||
price=price,
|
||||
volume=volume,
|
||||
volume_ratio=max(1.0, volume_ratio, volatility_pct / 2.0),
|
||||
rsi=implied_rsi,
|
||||
signal=signal,
|
||||
score=score,
|
||||
)
|
||||
)
|
||||
|
||||
except ConnectionError as exc:
|
||||
logger.warning("Failed to analyze %s: %s", stock_code, exc)
|
||||
continue
|
||||
except Exception as exc:
|
||||
logger.error("Unexpected error analyzing %s: %s", stock_code, exc)
|
||||
continue
|
||||
|
||||
logger.info("Domestic ranking scan found %d candidates", len(candidates))
|
||||
candidates.sort(key=lambda c: c.score, reverse=True)
|
||||
return candidates[: self.top_n]
|
||||
|
||||
async def _scan_overseas(
|
||||
self,
|
||||
market: MarketInfo,
|
||||
fallback_stocks: list[str] | None = None,
|
||||
) -> list[ScanCandidate]:
|
||||
"""Scan overseas symbols using ranking API first, then fallback universe."""
|
||||
if self.overseas_broker is None:
|
||||
logger.warning(
|
||||
"Overseas scanner unavailable for %s: overseas broker not configured",
|
||||
market.name,
|
||||
)
|
||||
return []
|
||||
|
||||
candidates = await self._scan_overseas_from_rankings(market)
|
||||
if not candidates:
|
||||
candidates = await self._scan_overseas_from_symbols(market, fallback_stocks)
|
||||
|
||||
candidates.sort(key=lambda c: c.score, reverse=True)
|
||||
return candidates[: self.top_n]
|
||||
|
||||
async def _scan_overseas_from_rankings(
|
||||
self,
|
||||
market: MarketInfo,
|
||||
) -> list[ScanCandidate]:
|
||||
"""Build overseas candidates from ranking APIs using volatility-first scoring."""
|
||||
assert self.overseas_broker is not None
|
||||
try:
|
||||
fluct_rows = await self.overseas_broker.fetch_overseas_rankings(
|
||||
exchange_code=market.exchange_code,
|
||||
ranking_type="fluctuation",
|
||||
limit=50,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Overseas fluctuation ranking failed for %s: %s", market.code, exc
|
||||
)
|
||||
fluct_rows = []
|
||||
|
||||
if not fluct_rows:
|
||||
return []
|
||||
|
||||
volume_rank_bonus: dict[str, float] = {}
|
||||
try:
|
||||
volume_rows = await self.overseas_broker.fetch_overseas_rankings(
|
||||
exchange_code=market.exchange_code,
|
||||
ranking_type="volume",
|
||||
limit=50,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Overseas volume ranking failed for %s: %s", market.code, exc
|
||||
)
|
||||
volume_rows = []
|
||||
|
||||
for idx, row in enumerate(volume_rows):
|
||||
code = _extract_stock_code(row)
|
||||
if not code:
|
||||
continue
|
||||
# Top-ranked by traded value/volume gets higher liquidity bonus.
|
||||
volume_rank_bonus[code] = max(0.0, 15.0 - idx * 0.3)
|
||||
|
||||
candidates: list[ScanCandidate] = []
|
||||
for row in fluct_rows:
|
||||
stock_code = _extract_stock_code(row)
|
||||
if not stock_code:
|
||||
continue
|
||||
|
||||
price = _extract_last_price(row)
|
||||
change_rate = _extract_change_rate_pct(row)
|
||||
volume = _extract_volume(row)
|
||||
intraday_range_pct = _extract_intraday_range_pct(row, price)
|
||||
volatility_pct = max(abs(change_rate), intraday_range_pct)
|
||||
|
||||
# Volatility-first filter (not simple gainers/value ranking).
|
||||
if price <= 0 or volatility_pct < 0.8:
|
||||
continue
|
||||
|
||||
volatility_score = min(volatility_pct / 10.0, 1.0) * 85.0
|
||||
liquidity_score = volume_rank_bonus.get(stock_code, 0.0)
|
||||
score = min(100.0, volatility_score + liquidity_score)
|
||||
signal = "momentum" if change_rate >= 0 else "oversold"
|
||||
implied_rsi = max(0.0, min(100.0, 50.0 + (change_rate * 2.0)))
|
||||
candidates.append(
|
||||
ScanCandidate(
|
||||
stock_code=stock_code,
|
||||
name=str(row.get("name") or row.get("ovrs_item_name") or stock_code),
|
||||
price=price,
|
||||
volume=volume,
|
||||
volume_ratio=max(1.0, volatility_pct / 2.0),
|
||||
rsi=implied_rsi,
|
||||
signal=signal,
|
||||
score=score,
|
||||
)
|
||||
)
|
||||
|
||||
if candidates:
|
||||
logger.info(
|
||||
"Overseas ranking scan found %d candidates for %s",
|
||||
len(candidates),
|
||||
market.name,
|
||||
)
|
||||
return candidates
|
||||
|
||||
async def _scan_overseas_from_symbols(
|
||||
self,
|
||||
market: MarketInfo,
|
||||
symbols: list[str] | None,
|
||||
) -> list[ScanCandidate]:
|
||||
"""Fallback overseas scan from dynamic symbol universe."""
|
||||
assert self.overseas_broker is not None
|
||||
if not symbols:
|
||||
logger.info("Overseas scanner: no symbol universe for %s", market.name)
|
||||
return []
|
||||
|
||||
logger.info(
|
||||
"Overseas scanner: scanning %d fallback symbols for %s",
|
||||
len(symbols),
|
||||
market.name,
|
||||
)
|
||||
candidates: list[ScanCandidate] = []
|
||||
for stock_code in symbols:
|
||||
try:
|
||||
price_data = await self.overseas_broker.get_overseas_price(
|
||||
market.exchange_code, stock_code
|
||||
)
|
||||
output = price_data.get("output", {})
|
||||
price = _extract_last_price(output)
|
||||
change_rate = _extract_change_rate_pct(output)
|
||||
volume = _extract_volume(output)
|
||||
intraday_range_pct = _extract_intraday_range_pct(output, price)
|
||||
volatility_pct = max(abs(change_rate), intraday_range_pct)
|
||||
|
||||
if price <= 0 or volatility_pct < 0.8:
|
||||
continue
|
||||
|
||||
score = min(volatility_pct / 10.0, 1.0) * 100.0
|
||||
signal = "momentum" if change_rate >= 0 else "oversold"
|
||||
implied_rsi = max(0.0, min(100.0, 50.0 + (change_rate * 2.0)))
|
||||
candidates.append(
|
||||
ScanCandidate(
|
||||
stock_code=stock_code,
|
||||
name=stock_code,
|
||||
price=price,
|
||||
volume=volume,
|
||||
volume_ratio=max(1.0, volatility_pct / 2.0),
|
||||
rsi=implied_rsi,
|
||||
signal=signal,
|
||||
score=score,
|
||||
)
|
||||
)
|
||||
except ConnectionError as exc:
|
||||
logger.warning("Failed to analyze overseas %s: %s", stock_code, exc)
|
||||
except Exception as exc:
|
||||
logger.error("Unexpected error analyzing overseas %s: %s", stock_code, exc)
|
||||
logger.info(
|
||||
"Overseas symbol fallback scan found %d candidates for %s",
|
||||
len(candidates),
|
||||
market.name,
|
||||
)
|
||||
return candidates
|
||||
|
||||
def get_stock_codes(self, candidates: list[ScanCandidate]) -> list[str]:
|
||||
"""Extract stock codes from candidates for watchlist update.
|
||||
|
||||
Args:
|
||||
candidates: List of scan candidates
|
||||
|
||||
Returns:
|
||||
List of stock codes
|
||||
"""
|
||||
return [c.stock_code for c in candidates]
|
||||
|
||||
|
||||
def _safe_float(value: Any, default: float = 0.0) -> float:
|
||||
"""Convert arbitrary values to float safely."""
|
||||
if value in (None, ""):
|
||||
return default
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
def _extract_stock_code(row: dict[str, Any]) -> str:
|
||||
"""Extract normalized stock code from various API schemas."""
|
||||
return (
|
||||
str(
|
||||
row.get("symb")
|
||||
or row.get("ovrs_pdno")
|
||||
or row.get("stock_code")
|
||||
or row.get("pdno")
|
||||
or ""
|
||||
)
|
||||
.strip()
|
||||
.upper()
|
||||
)
|
||||
|
||||
|
||||
def _extract_last_price(row: dict[str, Any]) -> float:
|
||||
"""Extract last/close-like price from API schema variants."""
|
||||
return _safe_float(
|
||||
row.get("last")
|
||||
or row.get("ovrs_nmix_prpr")
|
||||
or row.get("stck_prpr")
|
||||
or row.get("price")
|
||||
or row.get("close")
|
||||
)
|
||||
|
||||
|
||||
def _extract_change_rate_pct(row: dict[str, Any]) -> float:
|
||||
"""Extract daily change rate (%) from API schema variants."""
|
||||
return _safe_float(
|
||||
row.get("rate")
|
||||
or row.get("change_rate")
|
||||
or row.get("prdy_ctrt")
|
||||
or row.get("evlu_pfls_rt")
|
||||
or row.get("chg_rt")
|
||||
)
|
||||
|
||||
|
||||
def _extract_volume(row: dict[str, Any]) -> float:
|
||||
"""Extract volume/traded-amount proxy from schema variants."""
|
||||
return _safe_float(
|
||||
row.get("tvol") or row.get("acml_vol") or row.get("vol") or row.get("volume")
|
||||
)
|
||||
|
||||
|
||||
def _extract_intraday_range_pct(row: dict[str, Any], price: float) -> float:
|
||||
"""Estimate intraday range percentage from high/low fields."""
|
||||
if price <= 0:
|
||||
return 0.0
|
||||
high = _safe_float(
|
||||
row.get("high")
|
||||
or row.get("ovrs_hgpr")
|
||||
or row.get("stck_hgpr")
|
||||
or row.get("day_hgpr")
|
||||
)
|
||||
low = _safe_float(
|
||||
row.get("low")
|
||||
or row.get("ovrs_lwpr")
|
||||
or row.get("stck_lwpr")
|
||||
or row.get("day_lwpr")
|
||||
)
|
||||
if high <= 0 or low <= 0 or high < low:
|
||||
return 0.0
|
||||
return (high - low) / price * 100.0
|
||||
373
src/analysis/volatility.py
Normal file
373
src/analysis/volatility.py
Normal file
@@ -0,0 +1,373 @@
|
||||
"""Volatility and momentum analysis for stock selection.
|
||||
|
||||
Calculates ATR, price change percentages, volume surges, and price-volume divergence.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class VolatilityMetrics:
|
||||
"""Volatility and momentum metrics for a stock."""
|
||||
|
||||
stock_code: str
|
||||
current_price: float
|
||||
atr: float # Average True Range (14 periods)
|
||||
price_change_1m: float # 1-minute price change %
|
||||
price_change_5m: float # 5-minute price change %
|
||||
price_change_15m: float # 15-minute price change %
|
||||
volume_surge: float # Volume vs average (ratio)
|
||||
pv_divergence: float # Price-volume divergence score
|
||||
momentum_score: float # Combined momentum score (0-100)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"VolatilityMetrics({self.stock_code}: "
|
||||
f"price={self.current_price:.2f}, "
|
||||
f"atr={self.atr:.2f}, "
|
||||
f"1m={self.price_change_1m:.2f}%, "
|
||||
f"vol_surge={self.volume_surge:.2f}x, "
|
||||
f"momentum={self.momentum_score:.1f})"
|
||||
)
|
||||
|
||||
|
||||
class VolatilityAnalyzer:
|
||||
"""Analyzes stock volatility and momentum for leader detection."""
|
||||
|
||||
def __init__(self, min_volume_surge: float = 2.0, min_price_change: float = 1.0) -> None:
|
||||
"""Initialize the volatility analyzer.
|
||||
|
||||
Args:
|
||||
min_volume_surge: Minimum volume surge ratio (default 2x average)
|
||||
min_price_change: Minimum price change % for breakout (default 1%)
|
||||
"""
|
||||
self.min_volume_surge = min_volume_surge
|
||||
self.min_price_change = min_price_change
|
||||
|
||||
def calculate_atr(
|
||||
self,
|
||||
high_prices: list[float],
|
||||
low_prices: list[float],
|
||||
close_prices: list[float],
|
||||
period: int = 14,
|
||||
) -> float:
|
||||
"""Calculate Average True Range (ATR).
|
||||
|
||||
Args:
|
||||
high_prices: List of high prices (most recent last)
|
||||
low_prices: List of low prices (most recent last)
|
||||
close_prices: List of close prices (most recent last)
|
||||
period: ATR period (default 14)
|
||||
|
||||
Returns:
|
||||
ATR value
|
||||
"""
|
||||
if (
|
||||
len(high_prices) < period + 1
|
||||
or len(low_prices) < period + 1
|
||||
or len(close_prices) < period + 1
|
||||
):
|
||||
return 0.0
|
||||
|
||||
true_ranges: list[float] = []
|
||||
for i in range(1, len(high_prices)):
|
||||
high = high_prices[i]
|
||||
low = low_prices[i]
|
||||
prev_close = close_prices[i - 1]
|
||||
|
||||
tr = max(
|
||||
high - low,
|
||||
abs(high - prev_close),
|
||||
abs(low - prev_close),
|
||||
)
|
||||
true_ranges.append(tr)
|
||||
|
||||
if len(true_ranges) < period:
|
||||
return 0.0
|
||||
|
||||
# Simple Moving Average of True Range
|
||||
recent_tr = true_ranges[-period:]
|
||||
return sum(recent_tr) / len(recent_tr)
|
||||
|
||||
def calculate_price_change(
|
||||
self, current_price: float, past_price: float
|
||||
) -> float:
|
||||
"""Calculate price change percentage.
|
||||
|
||||
Args:
|
||||
current_price: Current price
|
||||
past_price: Past price to compare against
|
||||
|
||||
Returns:
|
||||
Price change percentage
|
||||
"""
|
||||
if past_price == 0:
|
||||
return 0.0
|
||||
return ((current_price - past_price) / past_price) * 100
|
||||
|
||||
def calculate_volume_surge(
|
||||
self, current_volume: float, avg_volume: float
|
||||
) -> float:
|
||||
"""Calculate volume surge ratio.
|
||||
|
||||
Args:
|
||||
current_volume: Current volume
|
||||
avg_volume: Average volume
|
||||
|
||||
Returns:
|
||||
Volume surge ratio (current / average)
|
||||
"""
|
||||
if avg_volume == 0:
|
||||
return 1.0
|
||||
return current_volume / avg_volume
|
||||
|
||||
def calculate_rsi(
|
||||
self,
|
||||
close_prices: list[float],
|
||||
period: int = 14,
|
||||
) -> float:
|
||||
"""Calculate Relative Strength Index (RSI) using Wilder's smoothing.
|
||||
|
||||
Args:
|
||||
close_prices: List of closing prices (oldest to newest, minimum period+1 values)
|
||||
period: RSI period (default 14)
|
||||
|
||||
Returns:
|
||||
RSI value between 0 and 100, or 50.0 (neutral) if insufficient data
|
||||
|
||||
Examples:
|
||||
>>> analyzer = VolatilityAnalyzer()
|
||||
>>> prices = [100 - i * 0.5 for i in range(20)] # Downtrend
|
||||
>>> rsi = analyzer.calculate_rsi(prices)
|
||||
>>> assert rsi < 50 # Oversold territory
|
||||
"""
|
||||
if len(close_prices) < period + 1:
|
||||
return 50.0 # Neutral RSI if insufficient data
|
||||
|
||||
# Calculate price changes
|
||||
changes = [close_prices[i] - close_prices[i - 1] for i in range(1, len(close_prices))]
|
||||
|
||||
# Separate gains and losses
|
||||
gains = [max(0.0, change) for change in changes]
|
||||
losses = [max(0.0, -change) for change in changes]
|
||||
|
||||
# Calculate initial average gain/loss (simple average for first period)
|
||||
avg_gain = sum(gains[:period]) / period
|
||||
avg_loss = sum(losses[:period]) / period
|
||||
|
||||
# Apply Wilder's smoothing for remaining periods
|
||||
for i in range(period, len(changes)):
|
||||
avg_gain = (avg_gain * (period - 1) + gains[i]) / period
|
||||
avg_loss = (avg_loss * (period - 1) + losses[i]) / period
|
||||
|
||||
# Calculate RS and RSI
|
||||
if avg_loss == 0:
|
||||
return 100.0 # All gains, maximum RSI
|
||||
|
||||
rs = avg_gain / avg_loss
|
||||
rsi = 100 - (100 / (1 + rs))
|
||||
|
||||
return rsi
|
||||
|
||||
def calculate_pv_divergence(
|
||||
self,
|
||||
price_change: float,
|
||||
volume_surge: float,
|
||||
) -> float:
|
||||
"""Calculate price-volume divergence score.
|
||||
|
||||
Positive divergence: Price up + Volume up = bullish
|
||||
Negative divergence: Price up + Volume down = bearish
|
||||
Neutral: Price/volume move together moderately
|
||||
|
||||
Args:
|
||||
price_change: Price change percentage
|
||||
volume_surge: Volume surge ratio
|
||||
|
||||
Returns:
|
||||
Divergence score (-100 to +100)
|
||||
"""
|
||||
# Normalize volume surge to -1 to +1 scale (1.0 = neutral)
|
||||
volume_signal = (volume_surge - 1.0) * 10 # Scale for sensitivity
|
||||
|
||||
# Calculate divergence
|
||||
# Positive: price and volume move in same direction
|
||||
# Negative: price and volume move in opposite directions
|
||||
if price_change > 0 and volume_surge > 1.0:
|
||||
# Bullish: price up, volume up
|
||||
return min(100.0, price_change * volume_signal)
|
||||
elif price_change < 0 and volume_surge < 1.0:
|
||||
# Bearish confirmation: price down, volume down
|
||||
return max(-100.0, price_change * volume_signal)
|
||||
elif price_change > 0 and volume_surge < 1.0:
|
||||
# Bearish divergence: price up but volume low (weak rally)
|
||||
return -abs(price_change) * 0.5
|
||||
elif price_change < 0 and volume_surge > 1.0:
|
||||
# Selling pressure: price down, volume up
|
||||
return price_change * volume_signal
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
def calculate_momentum_score(
|
||||
self,
|
||||
price_change_1m: float,
|
||||
price_change_5m: float,
|
||||
price_change_15m: float,
|
||||
volume_surge: float,
|
||||
atr: float,
|
||||
current_price: float,
|
||||
) -> float:
|
||||
"""Calculate combined momentum score (0-100).
|
||||
|
||||
Weights:
|
||||
- 1m change: 40%
|
||||
- 5m change: 30%
|
||||
- 15m change: 20%
|
||||
- Volume surge: 10%
|
||||
|
||||
Args:
|
||||
price_change_1m: 1-minute price change %
|
||||
price_change_5m: 5-minute price change %
|
||||
price_change_15m: 15-minute price change %
|
||||
volume_surge: Volume surge ratio
|
||||
atr: Average True Range
|
||||
current_price: Current price
|
||||
|
||||
Returns:
|
||||
Momentum score (0-100)
|
||||
"""
|
||||
# Weight recent changes more heavily
|
||||
weighted_change = (
|
||||
price_change_1m * 0.4 +
|
||||
price_change_5m * 0.3 +
|
||||
price_change_15m * 0.2
|
||||
)
|
||||
|
||||
# Volume contribution (normalized to 0-10 scale)
|
||||
volume_contribution = min(10.0, (volume_surge - 1.0) * 5.0)
|
||||
|
||||
# Volatility bonus: higher ATR = higher potential (normalized)
|
||||
volatility_bonus = 0.0
|
||||
if current_price > 0:
|
||||
atr_pct = (atr / current_price) * 100
|
||||
volatility_bonus = min(10.0, atr_pct)
|
||||
|
||||
# Combine scores
|
||||
raw_score = weighted_change + volume_contribution + volatility_bonus
|
||||
|
||||
# Normalize to 0-100 scale
|
||||
# Assume typical momentum range is -10 to +30
|
||||
normalized = ((raw_score + 10) / 40) * 100
|
||||
|
||||
return max(0.0, min(100.0, normalized))
|
||||
|
||||
def analyze(
|
||||
self,
|
||||
stock_code: str,
|
||||
orderbook_data: dict[str, Any],
|
||||
price_history: dict[str, Any],
|
||||
) -> VolatilityMetrics:
|
||||
"""Analyze volatility and momentum for a stock.
|
||||
|
||||
Args:
|
||||
stock_code: Stock code
|
||||
orderbook_data: Current orderbook/quote data
|
||||
price_history: Historical price and volume data
|
||||
|
||||
Returns:
|
||||
VolatilityMetrics with calculated indicators
|
||||
"""
|
||||
# Extract current data from orderbook
|
||||
output1 = orderbook_data.get("output1", {})
|
||||
current_price = float(output1.get("stck_prpr", 0))
|
||||
current_volume = float(output1.get("acml_vol", 0))
|
||||
|
||||
# Extract historical data
|
||||
high_prices = price_history.get("high", [])
|
||||
low_prices = price_history.get("low", [])
|
||||
close_prices = price_history.get("close", [])
|
||||
volumes = price_history.get("volume", [])
|
||||
|
||||
# Calculate ATR
|
||||
atr = self.calculate_atr(high_prices, low_prices, close_prices)
|
||||
|
||||
# Calculate price changes (use historical data if available)
|
||||
price_change_1m = 0.0
|
||||
price_change_5m = 0.0
|
||||
price_change_15m = 0.0
|
||||
|
||||
if len(close_prices) > 0:
|
||||
if len(close_prices) >= 1:
|
||||
price_change_1m = self.calculate_price_change(
|
||||
current_price, close_prices[-1]
|
||||
)
|
||||
if len(close_prices) >= 5:
|
||||
price_change_5m = self.calculate_price_change(
|
||||
current_price, close_prices[-5]
|
||||
)
|
||||
if len(close_prices) >= 15:
|
||||
price_change_15m = self.calculate_price_change(
|
||||
current_price, close_prices[-15]
|
||||
)
|
||||
|
||||
# Calculate volume surge
|
||||
avg_volume = sum(volumes) / len(volumes) if volumes else current_volume
|
||||
volume_surge = self.calculate_volume_surge(current_volume, avg_volume)
|
||||
|
||||
# Calculate price-volume divergence
|
||||
pv_divergence = self.calculate_pv_divergence(price_change_1m, volume_surge)
|
||||
|
||||
# Calculate momentum score
|
||||
momentum_score = self.calculate_momentum_score(
|
||||
price_change_1m,
|
||||
price_change_5m,
|
||||
price_change_15m,
|
||||
volume_surge,
|
||||
atr,
|
||||
current_price,
|
||||
)
|
||||
|
||||
return VolatilityMetrics(
|
||||
stock_code=stock_code,
|
||||
current_price=current_price,
|
||||
atr=atr,
|
||||
price_change_1m=price_change_1m,
|
||||
price_change_5m=price_change_5m,
|
||||
price_change_15m=price_change_15m,
|
||||
volume_surge=volume_surge,
|
||||
pv_divergence=pv_divergence,
|
||||
momentum_score=momentum_score,
|
||||
)
|
||||
|
||||
def is_breakout(self, metrics: VolatilityMetrics) -> bool:
|
||||
"""Determine if a stock is experiencing a breakout.
|
||||
|
||||
Args:
|
||||
metrics: Volatility metrics for the stock
|
||||
|
||||
Returns:
|
||||
True if breakout conditions are met
|
||||
"""
|
||||
return (
|
||||
metrics.price_change_1m >= self.min_price_change
|
||||
and metrics.volume_surge >= self.min_volume_surge
|
||||
and metrics.pv_divergence > 0 # Bullish divergence
|
||||
)
|
||||
|
||||
def is_breakdown(self, metrics: VolatilityMetrics) -> bool:
|
||||
"""Determine if a stock is experiencing a breakdown.
|
||||
|
||||
Args:
|
||||
metrics: Volatility metrics for the stock
|
||||
|
||||
Returns:
|
||||
True if breakdown conditions are met
|
||||
"""
|
||||
return (
|
||||
metrics.price_change_1m <= -self.min_price_change
|
||||
and metrics.volume_surge >= self.min_volume_surge
|
||||
and metrics.pv_divergence < 0 # Bearish divergence
|
||||
)
|
||||
21
src/backup/__init__.py
Normal file
21
src/backup/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Backup and disaster recovery system for long-term sustainability.
|
||||
|
||||
This module provides:
|
||||
- Automated database backups (daily, weekly, monthly)
|
||||
- Multi-format exports (JSON, CSV, Parquet)
|
||||
- Cloud storage integration (S3-compatible)
|
||||
- Health monitoring and alerts
|
||||
"""
|
||||
|
||||
from src.backup.exporter import BackupExporter, ExportFormat
|
||||
from src.backup.scheduler import BackupScheduler, BackupPolicy
|
||||
from src.backup.cloud_storage import CloudStorage, S3Config
|
||||
|
||||
__all__ = [
|
||||
"BackupExporter",
|
||||
"ExportFormat",
|
||||
"BackupScheduler",
|
||||
"BackupPolicy",
|
||||
"CloudStorage",
|
||||
"S3Config",
|
||||
]
|
||||
274
src/backup/cloud_storage.py
Normal file
274
src/backup/cloud_storage.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""Cloud storage integration for off-site backups.
|
||||
|
||||
Supports S3-compatible storage providers:
|
||||
- AWS S3
|
||||
- MinIO
|
||||
- Backblaze B2
|
||||
- DigitalOcean Spaces
|
||||
- Cloudflare R2
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class S3Config:
|
||||
"""Configuration for S3-compatible storage."""
|
||||
|
||||
endpoint_url: str | None # None for AWS S3, custom URL for others
|
||||
access_key: str
|
||||
secret_key: str
|
||||
bucket_name: str
|
||||
region: str = "us-east-1"
|
||||
use_ssl: bool = True
|
||||
|
||||
|
||||
class CloudStorage:
|
||||
"""Upload backups to S3-compatible cloud storage."""
|
||||
|
||||
def __init__(self, config: S3Config) -> None:
|
||||
"""Initialize cloud storage client.
|
||||
|
||||
Args:
|
||||
config: S3 configuration
|
||||
|
||||
Raises:
|
||||
ImportError: If boto3 is not installed
|
||||
"""
|
||||
try:
|
||||
import boto3
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"boto3 is required for cloud storage. Install with: pip install boto3"
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.client = boto3.client(
|
||||
"s3",
|
||||
endpoint_url=config.endpoint_url,
|
||||
aws_access_key_id=config.access_key,
|
||||
aws_secret_access_key=config.secret_key,
|
||||
region_name=config.region,
|
||||
use_ssl=config.use_ssl,
|
||||
)
|
||||
|
||||
def upload_file(
|
||||
self,
|
||||
file_path: Path,
|
||||
object_key: str | None = None,
|
||||
metadata: dict[str, str] | None = None,
|
||||
) -> str:
|
||||
"""Upload a file to cloud storage.
|
||||
|
||||
Args:
|
||||
file_path: Local file to upload
|
||||
object_key: S3 object key (default: filename)
|
||||
metadata: Optional metadata to attach
|
||||
|
||||
Returns:
|
||||
S3 object key
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If file doesn't exist
|
||||
Exception: If upload fails
|
||||
"""
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
if object_key is None:
|
||||
object_key = file_path.name
|
||||
|
||||
extra_args: dict[str, Any] = {}
|
||||
|
||||
# Add server-side encryption
|
||||
extra_args["ServerSideEncryption"] = "AES256"
|
||||
|
||||
# Add metadata if provided
|
||||
if metadata:
|
||||
extra_args["Metadata"] = metadata
|
||||
|
||||
logger.info("Uploading %s to s3://%s/%s", file_path.name, self.config.bucket_name, object_key)
|
||||
|
||||
try:
|
||||
self.client.upload_file(
|
||||
str(file_path),
|
||||
self.config.bucket_name,
|
||||
object_key,
|
||||
ExtraArgs=extra_args,
|
||||
)
|
||||
logger.info("Upload successful: %s", object_key)
|
||||
return object_key
|
||||
except Exception as exc:
|
||||
logger.error("Upload failed: %s", exc)
|
||||
raise
|
||||
|
||||
def download_file(self, object_key: str, local_path: Path) -> Path:
|
||||
"""Download a file from cloud storage.
|
||||
|
||||
Args:
|
||||
object_key: S3 object key
|
||||
local_path: Local destination path
|
||||
|
||||
Returns:
|
||||
Path to downloaded file
|
||||
|
||||
Raises:
|
||||
Exception: If download fails
|
||||
"""
|
||||
local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.info("Downloading s3://%s/%s to %s", self.config.bucket_name, object_key, local_path)
|
||||
|
||||
try:
|
||||
self.client.download_file(
|
||||
self.config.bucket_name,
|
||||
object_key,
|
||||
str(local_path),
|
||||
)
|
||||
logger.info("Download successful: %s", local_path)
|
||||
return local_path
|
||||
except Exception as exc:
|
||||
logger.error("Download failed: %s", exc)
|
||||
raise
|
||||
|
||||
def list_files(self, prefix: str = "") -> list[dict[str, Any]]:
|
||||
"""List files in cloud storage.
|
||||
|
||||
Args:
|
||||
prefix: Filter by object key prefix
|
||||
|
||||
Returns:
|
||||
List of file metadata dictionaries
|
||||
"""
|
||||
try:
|
||||
response = self.client.list_objects_v2(
|
||||
Bucket=self.config.bucket_name,
|
||||
Prefix=prefix,
|
||||
)
|
||||
|
||||
if "Contents" not in response:
|
||||
return []
|
||||
|
||||
files = []
|
||||
for obj in response["Contents"]:
|
||||
files.append(
|
||||
{
|
||||
"key": obj["Key"],
|
||||
"size_bytes": obj["Size"],
|
||||
"last_modified": obj["LastModified"],
|
||||
"etag": obj["ETag"],
|
||||
}
|
||||
)
|
||||
|
||||
return files
|
||||
except Exception as exc:
|
||||
logger.error("Failed to list files: %s", exc)
|
||||
raise
|
||||
|
||||
def delete_file(self, object_key: str) -> None:
|
||||
"""Delete a file from cloud storage.
|
||||
|
||||
Args:
|
||||
object_key: S3 object key
|
||||
|
||||
Raises:
|
||||
Exception: If deletion fails
|
||||
"""
|
||||
logger.info("Deleting s3://%s/%s", self.config.bucket_name, object_key)
|
||||
|
||||
try:
|
||||
self.client.delete_object(
|
||||
Bucket=self.config.bucket_name,
|
||||
Key=object_key,
|
||||
)
|
||||
logger.info("Deletion successful: %s", object_key)
|
||||
except Exception as exc:
|
||||
logger.error("Deletion failed: %s", exc)
|
||||
raise
|
||||
|
||||
def get_storage_stats(self) -> dict[str, Any]:
|
||||
"""Get cloud storage statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with storage stats
|
||||
"""
|
||||
try:
|
||||
files = self.list_files()
|
||||
|
||||
total_size = sum(f["size_bytes"] for f in files)
|
||||
total_count = len(files)
|
||||
|
||||
return {
|
||||
"total_files": total_count,
|
||||
"total_size_bytes": total_size,
|
||||
"total_size_mb": total_size / 1024 / 1024,
|
||||
"total_size_gb": total_size / 1024 / 1024 / 1024,
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.error("Failed to get storage stats: %s", exc)
|
||||
return {
|
||||
"error": str(exc),
|
||||
"total_files": 0,
|
||||
"total_size_bytes": 0,
|
||||
}
|
||||
|
||||
def verify_connection(self) -> bool:
|
||||
"""Verify connection to cloud storage.
|
||||
|
||||
Returns:
|
||||
True if connection is successful
|
||||
"""
|
||||
try:
|
||||
self.client.head_bucket(Bucket=self.config.bucket_name)
|
||||
logger.info("Cloud storage connection verified")
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.error("Cloud storage connection failed: %s", exc)
|
||||
return False
|
||||
|
||||
def create_bucket_if_not_exists(self) -> None:
|
||||
"""Create storage bucket if it doesn't exist.
|
||||
|
||||
Raises:
|
||||
Exception: If bucket creation fails
|
||||
"""
|
||||
try:
|
||||
self.client.head_bucket(Bucket=self.config.bucket_name)
|
||||
logger.info("Bucket already exists: %s", self.config.bucket_name)
|
||||
except self.client.exceptions.NoSuchBucket:
|
||||
logger.info("Creating bucket: %s", self.config.bucket_name)
|
||||
if self.config.region == "us-east-1":
|
||||
# us-east-1 requires special handling
|
||||
self.client.create_bucket(Bucket=self.config.bucket_name)
|
||||
else:
|
||||
self.client.create_bucket(
|
||||
Bucket=self.config.bucket_name,
|
||||
CreateBucketConfiguration={"LocationConstraint": self.config.region},
|
||||
)
|
||||
logger.info("Bucket created successfully")
|
||||
except Exception as exc:
|
||||
logger.error("Failed to verify/create bucket: %s", exc)
|
||||
raise
|
||||
|
||||
def enable_versioning(self) -> None:
|
||||
"""Enable versioning on the bucket.
|
||||
|
||||
Raises:
|
||||
Exception: If versioning enablement fails
|
||||
"""
|
||||
try:
|
||||
self.client.put_bucket_versioning(
|
||||
Bucket=self.config.bucket_name,
|
||||
VersioningConfiguration={"Status": "Enabled"},
|
||||
)
|
||||
logger.info("Versioning enabled for bucket: %s", self.config.bucket_name)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to enable versioning: %s", exc)
|
||||
raise
|
||||
326
src/backup/exporter.py
Normal file
326
src/backup/exporter.py
Normal file
@@ -0,0 +1,326 @@
|
||||
"""Multi-format database exporter for backups.
|
||||
|
||||
Supports JSON, CSV, and Parquet formats for different use cases:
|
||||
- JSON: Human-readable, easy to inspect
|
||||
- CSV: Analysis tools (Excel, pandas)
|
||||
- Parquet: Big data tools (Spark, DuckDB)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import gzip
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExportFormat(str, Enum):
|
||||
"""Supported export formats."""
|
||||
|
||||
JSON = "json"
|
||||
CSV = "csv"
|
||||
PARQUET = "parquet"
|
||||
|
||||
|
||||
class BackupExporter:
|
||||
"""Export database to multiple formats."""
|
||||
|
||||
def __init__(self, db_path: str) -> None:
|
||||
"""Initialize the exporter.
|
||||
|
||||
Args:
|
||||
db_path: Path to SQLite database
|
||||
"""
|
||||
self.db_path = db_path
|
||||
|
||||
def export_all(
|
||||
self,
|
||||
output_dir: Path,
|
||||
formats: list[ExportFormat] | None = None,
|
||||
compress: bool = True,
|
||||
incremental_since: datetime | None = None,
|
||||
) -> dict[ExportFormat, Path]:
|
||||
"""Export database to multiple formats.
|
||||
|
||||
Args:
|
||||
output_dir: Directory to write export files
|
||||
formats: List of formats to export (default: all)
|
||||
compress: Whether to gzip compress exports
|
||||
incremental_since: Only export records after this timestamp
|
||||
|
||||
Returns:
|
||||
Dictionary mapping format to output file path
|
||||
"""
|
||||
if formats is None:
|
||||
formats = [ExportFormat.JSON, ExportFormat.CSV, ExportFormat.PARQUET]
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
results: dict[ExportFormat, Path] = {}
|
||||
|
||||
for fmt in formats:
|
||||
try:
|
||||
output_file = self._export_format(
|
||||
fmt, output_dir, timestamp, compress, incremental_since
|
||||
)
|
||||
results[fmt] = output_file
|
||||
logger.info("Exported to %s: %s", fmt.value, output_file)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to export to %s: %s", fmt.value, exc)
|
||||
|
||||
return results
|
||||
|
||||
def _export_format(
|
||||
self,
|
||||
fmt: ExportFormat,
|
||||
output_dir: Path,
|
||||
timestamp: str,
|
||||
compress: bool,
|
||||
incremental_since: datetime | None,
|
||||
) -> Path:
|
||||
"""Export to a specific format.
|
||||
|
||||
Args:
|
||||
fmt: Export format
|
||||
output_dir: Output directory
|
||||
timestamp: Timestamp string for filename
|
||||
compress: Whether to compress
|
||||
incremental_since: Incremental export cutoff
|
||||
|
||||
Returns:
|
||||
Path to output file
|
||||
"""
|
||||
if fmt == ExportFormat.JSON:
|
||||
return self._export_json(output_dir, timestamp, compress, incremental_since)
|
||||
elif fmt == ExportFormat.CSV:
|
||||
return self._export_csv(output_dir, timestamp, compress, incremental_since)
|
||||
elif fmt == ExportFormat.PARQUET:
|
||||
return self._export_parquet(
|
||||
output_dir, timestamp, compress, incremental_since
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: {fmt}")
|
||||
|
||||
def _get_trades(
|
||||
self, incremental_since: datetime | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch trades from database.
|
||||
|
||||
Args:
|
||||
incremental_since: Only fetch trades after this timestamp
|
||||
|
||||
Returns:
|
||||
List of trade records
|
||||
"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
if incremental_since:
|
||||
cursor = conn.execute(
|
||||
"SELECT * FROM trades WHERE timestamp > ?",
|
||||
(incremental_since.isoformat(),),
|
||||
)
|
||||
else:
|
||||
cursor = conn.execute("SELECT * FROM trades")
|
||||
|
||||
trades = [dict(row) for row in cursor.fetchall()]
|
||||
conn.close()
|
||||
|
||||
return trades
|
||||
|
||||
def _export_json(
|
||||
self,
|
||||
output_dir: Path,
|
||||
timestamp: str,
|
||||
compress: bool,
|
||||
incremental_since: datetime | None,
|
||||
) -> Path:
|
||||
"""Export to JSON format.
|
||||
|
||||
Args:
|
||||
output_dir: Output directory
|
||||
timestamp: Timestamp for filename
|
||||
compress: Whether to gzip
|
||||
incremental_since: Incremental cutoff
|
||||
|
||||
Returns:
|
||||
Path to output file
|
||||
"""
|
||||
trades = self._get_trades(incremental_since)
|
||||
|
||||
filename = f"trades_{timestamp}.json"
|
||||
if compress:
|
||||
filename += ".gz"
|
||||
|
||||
output_file = output_dir / filename
|
||||
|
||||
data = {
|
||||
"export_timestamp": datetime.now(UTC).isoformat(),
|
||||
"incremental_since": (
|
||||
incremental_since.isoformat() if incremental_since else None
|
||||
),
|
||||
"record_count": len(trades),
|
||||
"trades": trades,
|
||||
}
|
||||
|
||||
if compress:
|
||||
with gzip.open(output_file, "wt", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
else:
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
return output_file
|
||||
|
||||
def _export_csv(
|
||||
self,
|
||||
output_dir: Path,
|
||||
timestamp: str,
|
||||
compress: bool,
|
||||
incremental_since: datetime | None,
|
||||
) -> Path:
|
||||
"""Export to CSV format.
|
||||
|
||||
Args:
|
||||
output_dir: Output directory
|
||||
timestamp: Timestamp for filename
|
||||
compress: Whether to gzip
|
||||
incremental_since: Incremental cutoff
|
||||
|
||||
Returns:
|
||||
Path to output file
|
||||
"""
|
||||
trades = self._get_trades(incremental_since)
|
||||
|
||||
filename = f"trades_{timestamp}.csv"
|
||||
if compress:
|
||||
filename += ".gz"
|
||||
|
||||
output_file = output_dir / filename
|
||||
|
||||
if not trades:
|
||||
# Write empty CSV with headers
|
||||
if compress:
|
||||
with gzip.open(output_file, "wt", encoding="utf-8", newline="") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(
|
||||
[
|
||||
"timestamp",
|
||||
"stock_code",
|
||||
"action",
|
||||
"quantity",
|
||||
"price",
|
||||
"confidence",
|
||||
"rationale",
|
||||
"pnl",
|
||||
]
|
||||
)
|
||||
else:
|
||||
with open(output_file, "w", encoding="utf-8", newline="") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(
|
||||
[
|
||||
"timestamp",
|
||||
"stock_code",
|
||||
"action",
|
||||
"quantity",
|
||||
"price",
|
||||
"confidence",
|
||||
"rationale",
|
||||
"pnl",
|
||||
]
|
||||
)
|
||||
return output_file
|
||||
|
||||
# Get column names from first trade
|
||||
fieldnames = list(trades[0].keys())
|
||||
|
||||
if compress:
|
||||
with gzip.open(output_file, "wt", encoding="utf-8", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
writer.writerows(trades)
|
||||
else:
|
||||
with open(output_file, "w", encoding="utf-8", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
writer.writerows(trades)
|
||||
|
||||
return output_file
|
||||
|
||||
def _export_parquet(
|
||||
self,
|
||||
output_dir: Path,
|
||||
timestamp: str,
|
||||
compress: bool,
|
||||
incremental_since: datetime | None,
|
||||
) -> Path:
|
||||
"""Export to Parquet format.
|
||||
|
||||
Args:
|
||||
output_dir: Output directory
|
||||
timestamp: Timestamp for filename
|
||||
compress: Whether to compress (Parquet has built-in compression)
|
||||
incremental_since: Incremental cutoff
|
||||
|
||||
Returns:
|
||||
Path to output file
|
||||
"""
|
||||
trades = self._get_trades(incremental_since)
|
||||
|
||||
filename = f"trades_{timestamp}.parquet"
|
||||
output_file = output_dir / filename
|
||||
|
||||
        try:
            import pyarrow as pa
            import pyarrow.parquet as pq
        except ImportError as exc:
            raise ImportError(
                "pyarrow is required for Parquet export. "
                "Install with: pip install pyarrow"
            ) from exc
|
||||
|
||||
# Convert to pyarrow table
|
||||
table = pa.Table.from_pylist(trades)
|
||||
|
||||
# Write with compression
|
||||
compression = "gzip" if compress else "none"
|
||||
pq.write_table(table, output_file, compression=compression)
|
||||
|
||||
return output_file
|
||||
|
||||
def get_export_stats(self) -> dict[str, Any]:
|
||||
"""Get statistics about exportable data.
|
||||
|
||||
Returns:
|
||||
Dictionary with data statistics
|
||||
"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
stats = {}
|
||||
|
||||
# Total trades
|
||||
cursor.execute("SELECT COUNT(*) FROM trades")
|
||||
stats["total_trades"] = cursor.fetchone()[0]
|
||||
|
||||
# Date range
|
||||
cursor.execute("SELECT MIN(timestamp), MAX(timestamp) FROM trades")
|
||||
min_date, max_date = cursor.fetchone()
|
||||
stats["date_range"] = {"earliest": min_date, "latest": max_date}
|
||||
|
||||
# Database size
|
||||
cursor.execute("SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()")
|
||||
stats["db_size_bytes"] = cursor.fetchone()[0]
|
||||
|
||||
conn.close()
|
||||
|
||||
return stats
|
||||
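To make the exporter's API concrete, here is a minimal usage sketch of the class above; the database path and output directories are illustrative, and pyarrow is only needed if Parquet is requested.

from datetime import UTC, datetime, timedelta
from pathlib import Path

from src.backup.exporter import BackupExporter, ExportFormat

exporter = BackupExporter(db_path="data/trade_logs.db")  # illustrative path

# Full export to gzip-compressed JSON and CSV
outputs = exporter.export_all(
    output_dir=Path("exports"),
    formats=[ExportFormat.JSON, ExportFormat.CSV],
    compress=True,
)
for fmt, path in outputs.items():
    print(fmt.value, path)

# Incremental export of only the last 24 hours of trades
exporter.export_all(
    output_dir=Path("exports/incremental"),
    formats=[ExportFormat.JSON],
    incremental_since=datetime.now(UTC) - timedelta(hours=24),
)

print(exporter.get_export_stats())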
src/backup/health_monitor.py (new file, 282 lines)
@@ -0,0 +1,282 @@
|
||||
"""Health monitoring for backup system.
|
||||
|
||||
Checks:
|
||||
- Database accessibility and integrity
|
||||
- Disk space availability
|
||||
- Backup success/failure tracking
|
||||
- Self-healing capabilities
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
import sqlite3
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HealthStatus(str, Enum):
|
||||
"""Health check status."""
|
||||
|
||||
HEALTHY = "healthy"
|
||||
DEGRADED = "degraded"
|
||||
UNHEALTHY = "unhealthy"
|
||||
|
||||
|
||||
@dataclass
|
||||
class HealthCheckResult:
|
||||
"""Result of a health check."""
|
||||
|
||||
status: HealthStatus
|
||||
message: str
|
||||
details: dict[str, Any] | None = None
|
||||
timestamp: datetime | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.timestamp is None:
|
||||
self.timestamp = datetime.now(UTC)
|
||||
|
||||
|
||||
class HealthMonitor:
|
||||
"""Monitor system health and backup status."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_path: str,
|
||||
backup_dir: Path,
|
||||
min_disk_space_gb: float = 10.0,
|
||||
max_backup_age_hours: int = 25, # Daily backups should be < 25 hours old
|
||||
) -> None:
|
||||
"""Initialize health monitor.
|
||||
|
||||
Args:
|
||||
db_path: Path to SQLite database
|
||||
backup_dir: Backup directory
|
||||
min_disk_space_gb: Minimum required disk space in GB
|
||||
max_backup_age_hours: Maximum acceptable backup age in hours
|
||||
"""
|
||||
self.db_path = Path(db_path)
|
||||
self.backup_dir = backup_dir
|
||||
self.min_disk_space_bytes = int(min_disk_space_gb * 1024 * 1024 * 1024)
|
||||
self.max_backup_age = timedelta(hours=max_backup_age_hours)
|
||||
|
||||
def check_database_health(self) -> HealthCheckResult:
|
||||
"""Check database accessibility and integrity.
|
||||
|
||||
Returns:
|
||||
HealthCheckResult
|
||||
"""
|
||||
# Check if database exists
|
||||
if not self.db_path.exists():
|
||||
return HealthCheckResult(
|
||||
status=HealthStatus.UNHEALTHY,
|
||||
message=f"Database not found: {self.db_path}",
|
||||
)
|
||||
|
||||
# Check if database is accessible
|
||||
try:
|
||||
conn = sqlite3.connect(str(self.db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Run integrity check
|
||||
cursor.execute("PRAGMA integrity_check")
|
||||
result = cursor.fetchone()[0]
|
||||
|
||||
if result != "ok":
|
||||
conn.close()
|
||||
return HealthCheckResult(
|
||||
status=HealthStatus.UNHEALTHY,
|
||||
message=f"Database integrity check failed: {result}",
|
||||
)
|
||||
|
||||
# Get database size
|
||||
cursor.execute(
|
||||
"SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()"
|
||||
)
|
||||
db_size = cursor.fetchone()[0]
|
||||
|
||||
# Get row counts
|
||||
cursor.execute("SELECT COUNT(*) FROM trades")
|
||||
trade_count = cursor.fetchone()[0]
|
||||
|
||||
conn.close()
|
||||
|
||||
return HealthCheckResult(
|
||||
status=HealthStatus.HEALTHY,
|
||||
message="Database is healthy",
|
||||
details={
|
||||
"size_bytes": db_size,
|
||||
"size_mb": db_size / 1024 / 1024,
|
||||
"trade_count": trade_count,
|
||||
},
|
||||
)
|
||||
|
||||
except sqlite3.Error as exc:
|
||||
return HealthCheckResult(
|
||||
status=HealthStatus.UNHEALTHY,
|
||||
message=f"Database access error: {exc}",
|
||||
)
|
||||
|
||||
def check_disk_space(self) -> HealthCheckResult:
|
||||
"""Check available disk space.
|
||||
|
||||
Returns:
|
||||
HealthCheckResult
|
||||
"""
|
||||
try:
|
||||
stat = shutil.disk_usage(self.backup_dir)
|
||||
|
||||
free_gb = stat.free / 1024 / 1024 / 1024
|
||||
total_gb = stat.total / 1024 / 1024 / 1024
|
||||
used_percent = (stat.used / stat.total) * 100
|
||||
|
||||
if stat.free < self.min_disk_space_bytes:
|
||||
return HealthCheckResult(
|
||||
status=HealthStatus.UNHEALTHY,
|
||||
message=f"Low disk space: {free_gb:.2f} GB free (minimum: {self.min_disk_space_bytes / 1024 / 1024 / 1024:.2f} GB)",
|
||||
details={
|
||||
"free_gb": free_gb,
|
||||
"total_gb": total_gb,
|
||||
"used_percent": used_percent,
|
||||
},
|
||||
)
|
||||
elif stat.free < self.min_disk_space_bytes * 2:
|
||||
return HealthCheckResult(
|
||||
status=HealthStatus.DEGRADED,
|
||||
message=f"Disk space low: {free_gb:.2f} GB free",
|
||||
details={
|
||||
"free_gb": free_gb,
|
||||
"total_gb": total_gb,
|
||||
"used_percent": used_percent,
|
||||
},
|
||||
)
|
||||
else:
|
||||
return HealthCheckResult(
|
||||
status=HealthStatus.HEALTHY,
|
||||
message=f"Disk space healthy: {free_gb:.2f} GB free",
|
||||
details={
|
||||
"free_gb": free_gb,
|
||||
"total_gb": total_gb,
|
||||
"used_percent": used_percent,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
return HealthCheckResult(
|
||||
status=HealthStatus.UNHEALTHY,
|
||||
message=f"Failed to check disk space: {exc}",
|
||||
)
|
||||
|
||||
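A quick worked example of the two thresholds above, with the default min_disk_space_gb=10: below 10 GB free is unhealthy, below 20 GB is degraded, anything above that is healthy. A small standalone check mirroring that logic (not the class itself):

min_bytes = int(10.0 * 1024**3)  # default min_disk_space_gb = 10

def classify(free_bytes: int) -> str:
    # Mirrors the branching in check_disk_space above
    if free_bytes < min_bytes:
        return "unhealthy"
    if free_bytes < min_bytes * 2:
        return "degraded"
    return "healthy"

assert classify(5 * 1024**3) == "unhealthy"
assert classify(15 * 1024**3) == "degraded"
assert classify(40 * 1024**3) == "healthy"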
def check_backup_recency(self) -> HealthCheckResult:
|
||||
"""Check if backups are recent enough.
|
||||
|
||||
Returns:
|
||||
HealthCheckResult
|
||||
"""
|
||||
daily_dir = self.backup_dir / "daily"
|
||||
|
||||
if not daily_dir.exists():
|
||||
return HealthCheckResult(
|
||||
status=HealthStatus.DEGRADED,
|
||||
message="Daily backup directory not found",
|
||||
)
|
||||
|
||||
# Find most recent backup
|
||||
backups = sorted(daily_dir.glob("*.db"), key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
|
||||
if not backups:
|
||||
return HealthCheckResult(
|
||||
status=HealthStatus.UNHEALTHY,
|
||||
message="No daily backups found",
|
||||
)
|
||||
|
||||
most_recent = backups[0]
|
||||
mtime = datetime.fromtimestamp(most_recent.stat().st_mtime, tz=UTC)
|
||||
age = datetime.now(UTC) - mtime
|
||||
|
||||
if age > self.max_backup_age:
|
||||
return HealthCheckResult(
|
||||
status=HealthStatus.DEGRADED,
|
||||
message=f"Most recent backup is {age.total_seconds() / 3600:.1f} hours old",
|
||||
details={
|
||||
"backup_file": most_recent.name,
|
||||
"age_hours": age.total_seconds() / 3600,
|
||||
"threshold_hours": self.max_backup_age.total_seconds() / 3600,
|
||||
},
|
||||
)
|
||||
else:
|
||||
return HealthCheckResult(
|
||||
status=HealthStatus.HEALTHY,
|
||||
message=f"Recent backup found ({age.total_seconds() / 3600:.1f} hours old)",
|
||||
details={
|
||||
"backup_file": most_recent.name,
|
||||
"age_hours": age.total_seconds() / 3600,
|
||||
},
|
||||
)
|
||||
|
||||
def run_all_checks(self) -> dict[str, HealthCheckResult]:
|
||||
"""Run all health checks.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping check name to result
|
||||
"""
|
||||
checks = {
|
||||
"database": self.check_database_health(),
|
||||
"disk_space": self.check_disk_space(),
|
||||
"backup_recency": self.check_backup_recency(),
|
||||
}
|
||||
|
||||
# Log results
|
||||
for check_name, result in checks.items():
|
||||
if result.status == HealthStatus.UNHEALTHY:
|
||||
logger.error("[%s] %s: %s", check_name, result.status.value, result.message)
|
||||
elif result.status == HealthStatus.DEGRADED:
|
||||
logger.warning("[%s] %s: %s", check_name, result.status.value, result.message)
|
||||
else:
|
||||
logger.info("[%s] %s: %s", check_name, result.status.value, result.message)
|
||||
|
||||
return checks
|
||||
|
||||
def get_overall_status(self) -> HealthStatus:
|
||||
"""Get overall system health status.
|
||||
|
||||
Returns:
|
||||
HealthStatus (worst status from all checks)
|
||||
"""
|
||||
checks = self.run_all_checks()
|
||||
|
||||
# Return worst status
|
||||
if any(c.status == HealthStatus.UNHEALTHY for c in checks.values()):
|
||||
return HealthStatus.UNHEALTHY
|
||||
elif any(c.status == HealthStatus.DEGRADED for c in checks.values()):
|
||||
return HealthStatus.DEGRADED
|
||||
else:
|
||||
return HealthStatus.HEALTHY
|
||||
|
||||
def get_health_report(self) -> dict[str, Any]:
|
||||
"""Get comprehensive health report.
|
||||
|
||||
Returns:
|
||||
Dictionary with health report
|
||||
"""
|
||||
checks = self.run_all_checks()
|
||||
overall = self.get_overall_status()
|
||||
|
||||
return {
|
||||
"overall_status": overall.value,
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"checks": {
|
||||
name: {
|
||||
"status": result.status.value,
|
||||
"message": result.message,
|
||||
"details": result.details,
|
||||
}
|
||||
for name, result in checks.items()
|
||||
},
|
||||
}
|
||||
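A minimal sketch of how the monitor above might be wired into a periodic job; the paths and thresholds are illustrative.

from pathlib import Path

from src.backup.health_monitor import HealthMonitor, HealthStatus

monitor = HealthMonitor(
    db_path="data/trade_logs.db",  # illustrative path
    backup_dir=Path("backups"),
    min_disk_space_gb=10.0,
    max_backup_age_hours=25,
)

report = monitor.get_health_report()
print(report["overall_status"], report["checks"]["backup_recency"]["message"])

# Fail fast in a cron job or startup script when anything is unhealthy
if monitor.get_overall_status() == HealthStatus.UNHEALTHY:
    raise SystemExit("backup system unhealthy")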
src/backup/scheduler.py (new file, 336 lines)
@@ -0,0 +1,336 @@
|
||||
"""Backup scheduler for automated database backups.
|
||||
|
||||
Implements backup policies:
|
||||
- Daily: Keep for 30 days (hot storage)
|
||||
- Weekly: Keep for 1 year (warm storage)
|
||||
- Monthly: Keep forever (cold storage)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BackupPolicy(str, Enum):
|
||||
"""Backup retention policies."""
|
||||
|
||||
DAILY = "daily"
|
||||
WEEKLY = "weekly"
|
||||
MONTHLY = "monthly"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BackupMetadata:
|
||||
"""Metadata for a backup."""
|
||||
|
||||
timestamp: datetime
|
||||
policy: BackupPolicy
|
||||
file_path: Path
|
||||
size_bytes: int
|
||||
checksum: str | None = None
|
||||
|
||||
|
||||
class BackupScheduler:
|
||||
"""Manage automated database backups with retention policies."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_path: str,
|
||||
backup_dir: Path,
|
||||
daily_retention_days: int = 30,
|
||||
weekly_retention_days: int = 365,
|
||||
) -> None:
|
||||
"""Initialize the backup scheduler.
|
||||
|
||||
Args:
|
||||
db_path: Path to SQLite database
|
||||
backup_dir: Root directory for backups
|
||||
daily_retention_days: Days to keep daily backups
|
||||
weekly_retention_days: Days to keep weekly backups
|
||||
"""
|
||||
self.db_path = Path(db_path)
|
||||
self.backup_dir = backup_dir
|
||||
self.daily_retention = timedelta(days=daily_retention_days)
|
||||
self.weekly_retention = timedelta(days=weekly_retention_days)
|
||||
|
||||
# Create policy-specific directories
|
||||
self.daily_dir = backup_dir / "daily"
|
||||
self.weekly_dir = backup_dir / "weekly"
|
||||
self.monthly_dir = backup_dir / "monthly"
|
||||
|
||||
for d in [self.daily_dir, self.weekly_dir, self.monthly_dir]:
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def create_backup(
|
||||
self, policy: BackupPolicy, verify: bool = True
|
||||
) -> BackupMetadata:
|
||||
"""Create a database backup.
|
||||
|
||||
Args:
|
||||
policy: Backup policy (daily/weekly/monthly)
|
||||
verify: Whether to verify backup integrity
|
||||
|
||||
Returns:
|
||||
BackupMetadata object
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If database doesn't exist
|
||||
OSError: If backup fails
|
||||
"""
|
||||
if not self.db_path.exists():
|
||||
raise FileNotFoundError(f"Database not found: {self.db_path}")
|
||||
|
||||
timestamp = datetime.now(UTC)
|
||||
backup_filename = self._get_backup_filename(timestamp, policy)
|
||||
|
||||
# Determine output directory
|
||||
if policy == BackupPolicy.DAILY:
|
||||
output_dir = self.daily_dir
|
||||
elif policy == BackupPolicy.WEEKLY:
|
||||
output_dir = self.weekly_dir
|
||||
else: # MONTHLY
|
||||
output_dir = self.monthly_dir
|
||||
|
||||
backup_path = output_dir / backup_filename
|
||||
|
||||
# Create backup (copy database file)
|
||||
logger.info("Creating %s backup: %s", policy.value, backup_path)
|
||||
shutil.copy2(self.db_path, backup_path)
|
||||
|
||||
# Get file size
|
||||
size_bytes = backup_path.stat().st_size
|
||||
|
||||
# Verify backup if requested
|
||||
checksum = None
|
||||
if verify:
|
||||
checksum = self._verify_backup(backup_path)
|
||||
|
||||
metadata = BackupMetadata(
|
||||
timestamp=timestamp,
|
||||
policy=policy,
|
||||
file_path=backup_path,
|
||||
size_bytes=size_bytes,
|
||||
checksum=checksum,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Backup created: %s (%.2f MB)",
|
||||
backup_path.name,
|
||||
size_bytes / 1024 / 1024,
|
||||
)
|
||||
|
||||
return metadata
|
||||
|
||||
def _get_backup_filename(self, timestamp: datetime, policy: BackupPolicy) -> str:
|
||||
"""Generate backup filename.
|
||||
|
||||
Args:
|
||||
timestamp: Backup timestamp
|
||||
policy: Backup policy
|
||||
|
||||
Returns:
|
||||
Filename string
|
||||
"""
|
||||
ts_str = timestamp.strftime("%Y%m%d_%H%M%S")
|
||||
return f"trade_logs_{policy.value}_{ts_str}.db"
|
||||
|
||||
def _verify_backup(self, backup_path: Path) -> str:
|
||||
"""Verify backup integrity using SQLite integrity check.
|
||||
|
||||
Args:
|
||||
backup_path: Path to backup file
|
||||
|
||||
Returns:
|
||||
Checksum string (MD5 hash)
|
||||
|
||||
Raises:
|
||||
RuntimeError: If integrity check fails
|
||||
"""
|
||||
import hashlib
|
||||
import sqlite3
|
||||
|
||||
# Integrity check
|
||||
try:
|
||||
conn = sqlite3.connect(str(backup_path))
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("PRAGMA integrity_check")
|
||||
result = cursor.fetchone()[0]
|
||||
conn.close()
|
||||
|
||||
if result != "ok":
|
||||
raise RuntimeError(f"Integrity check failed: {result}")
|
||||
        except sqlite3.Error as exc:
            raise RuntimeError(f"Failed to verify backup: {exc}") from exc
|
||||
|
||||
# Calculate MD5 checksum
|
||||
md5 = hashlib.md5()
|
||||
with open(backup_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(8192), b""):
|
||||
md5.update(chunk)
|
||||
|
||||
return md5.hexdigest()
|
||||
|
||||
def cleanup_old_backups(self) -> dict[BackupPolicy, int]:
|
||||
"""Remove backups older than retention policies.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping policy to number of backups removed
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
removed_counts: dict[BackupPolicy, int] = {}
|
||||
|
||||
# Daily backups: remove older than retention
|
||||
removed_counts[BackupPolicy.DAILY] = self._cleanup_directory(
|
||||
self.daily_dir, now - self.daily_retention
|
||||
)
|
||||
|
||||
# Weekly backups: remove older than retention
|
||||
removed_counts[BackupPolicy.WEEKLY] = self._cleanup_directory(
|
||||
self.weekly_dir, now - self.weekly_retention
|
||||
)
|
||||
|
||||
# Monthly backups: never remove (kept forever)
|
||||
removed_counts[BackupPolicy.MONTHLY] = 0
|
||||
|
||||
total = sum(removed_counts.values())
|
||||
if total > 0:
|
||||
logger.info("Cleaned up %d old backup(s)", total)
|
||||
|
||||
return removed_counts
|
||||
|
||||
def _cleanup_directory(self, directory: Path, cutoff: datetime) -> int:
|
||||
"""Remove backups older than cutoff date.
|
||||
|
||||
Args:
|
||||
directory: Directory to clean
|
||||
cutoff: Remove files older than this
|
||||
|
||||
Returns:
|
||||
Number of files removed
|
||||
"""
|
||||
removed = 0
|
||||
|
||||
for backup_file in directory.glob("*.db"):
|
||||
# Get file modification time
|
||||
mtime = datetime.fromtimestamp(backup_file.stat().st_mtime, tz=UTC)
|
||||
|
||||
if mtime < cutoff:
|
||||
logger.debug("Removing old backup: %s", backup_file.name)
|
||||
backup_file.unlink()
|
||||
removed += 1
|
||||
|
||||
return removed
|
||||
|
||||
def list_backups(
|
||||
self, policy: BackupPolicy | None = None
|
||||
) -> list[BackupMetadata]:
|
||||
"""List available backups.
|
||||
|
||||
Args:
|
||||
policy: Filter by policy (None for all)
|
||||
|
||||
Returns:
|
||||
List of BackupMetadata objects
|
||||
"""
|
||||
backups: list[BackupMetadata] = []
|
||||
|
||||
policies_to_check = (
|
||||
[policy] if policy else [BackupPolicy.DAILY, BackupPolicy.WEEKLY, BackupPolicy.MONTHLY]
|
||||
)
|
||||
|
||||
for pol in policies_to_check:
|
||||
if pol == BackupPolicy.DAILY:
|
||||
directory = self.daily_dir
|
||||
elif pol == BackupPolicy.WEEKLY:
|
||||
directory = self.weekly_dir
|
||||
else:
|
||||
directory = self.monthly_dir
|
||||
|
||||
for backup_file in sorted(directory.glob("*.db")):
|
||||
mtime = datetime.fromtimestamp(backup_file.stat().st_mtime, tz=UTC)
|
||||
size = backup_file.stat().st_size
|
||||
|
||||
backups.append(
|
||||
BackupMetadata(
|
||||
timestamp=mtime,
|
||||
policy=pol,
|
||||
file_path=backup_file,
|
||||
size_bytes=size,
|
||||
)
|
||||
)
|
||||
|
||||
# Sort by timestamp (newest first)
|
||||
backups.sort(key=lambda b: b.timestamp, reverse=True)
|
||||
|
||||
return backups
|
||||
|
||||
def get_backup_stats(self) -> dict[str, Any]:
|
||||
"""Get backup statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with backup stats
|
||||
"""
|
||||
stats: dict[str, Any] = {}
|
||||
|
||||
for policy in BackupPolicy:
|
||||
if policy == BackupPolicy.DAILY:
|
||||
directory = self.daily_dir
|
||||
elif policy == BackupPolicy.WEEKLY:
|
||||
directory = self.weekly_dir
|
||||
else:
|
||||
directory = self.monthly_dir
|
||||
|
||||
backups = list(directory.glob("*.db"))
|
||||
total_size = sum(b.stat().st_size for b in backups)
|
||||
|
||||
stats[policy.value] = {
|
||||
"count": len(backups),
|
||||
"total_size_bytes": total_size,
|
||||
"total_size_mb": total_size / 1024 / 1024,
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
def restore_backup(self, backup_metadata: BackupMetadata, verify: bool = True) -> None:
|
||||
"""Restore database from backup.
|
||||
|
||||
Args:
|
||||
backup_metadata: Backup to restore
|
||||
verify: Whether to verify restored database
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If backup file doesn't exist
|
||||
RuntimeError: If verification fails
|
||||
"""
|
||||
if not backup_metadata.file_path.exists():
|
||||
raise FileNotFoundError(f"Backup not found: {backup_metadata.file_path}")
|
||||
|
||||
        # Create a safety copy of the current database (if one exists)
        backup_current: Path | None = None
        if self.db_path.exists():
            backup_current = self.db_path.with_suffix(".db.before_restore")
            logger.info("Backing up current database to: %s", backup_current)
            shutil.copy2(self.db_path, backup_current)

        # Restore backup
        logger.info("Restoring backup: %s", backup_metadata.file_path.name)
        shutil.copy2(backup_metadata.file_path, self.db_path)

        # Verify restored database
        if verify:
            try:
                self._verify_backup(self.db_path)
                logger.info("Backup restored and verified successfully")
            except RuntimeError as exc:
                # Restore failed: revert to the pre-restore copy if one was made
                if backup_current is not None and backup_current.exists():
                    logger.error("Restore verification failed, reverting: %s", exc)
                    shutil.copy2(backup_current, self.db_path)
                raise
|
||||
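A minimal usage sketch of the scheduler above, as a daily job might call it; the paths are illustrative.

from pathlib import Path

from src.backup.scheduler import BackupPolicy, BackupScheduler

scheduler = BackupScheduler(
    db_path="data/trade_logs.db",  # illustrative path
    backup_dir=Path("backups"),
    daily_retention_days=30,
    weekly_retention_days=365,
)

# Take a verified daily backup, then prune anything past its retention window
meta = scheduler.create_backup(BackupPolicy.DAILY, verify=True)
print(meta.file_path.name, meta.size_bytes, meta.checksum)

removed = scheduler.cleanup_old_backups()
print({policy.value: count for policy, count in removed.items()})

# Inspect what is currently on disk, newest first
for backup in scheduler.list_backups(BackupPolicy.DAILY):
    print(backup.timestamp.isoformat(), backup.file_path.name)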
src/brain/cache.py (new file, 293 lines)
@@ -0,0 +1,293 @@
|
||||
"""Response caching system for reducing redundant LLM calls.
|
||||
|
||||
This module provides caching for common trading scenarios:
|
||||
- TTL-based cache invalidation
|
||||
- Cache key based on market conditions
|
||||
- Cache hit rate monitoring
|
||||
- Special handling for HOLD decisions in quiet markets
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.brain.gemini_client import TradeDecision
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheEntry:
|
||||
"""Cached decision with metadata."""
|
||||
|
||||
decision: "TradeDecision"
|
||||
cached_at: float # Unix timestamp
|
||||
hit_count: int = 0
|
||||
market_data_hash: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheMetrics:
|
||||
"""Metrics for cache performance monitoring."""
|
||||
|
||||
total_requests: int = 0
|
||||
cache_hits: int = 0
|
||||
cache_misses: int = 0
|
||||
evictions: int = 0
|
||||
total_entries: int = 0
|
||||
|
||||
@property
|
||||
def hit_rate(self) -> float:
|
||||
"""Calculate cache hit rate."""
|
||||
if self.total_requests == 0:
|
||||
return 0.0
|
||||
return self.cache_hits / self.total_requests
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert metrics to dictionary."""
|
||||
return {
|
||||
"total_requests": self.total_requests,
|
||||
"cache_hits": self.cache_hits,
|
||||
"cache_misses": self.cache_misses,
|
||||
"hit_rate": self.hit_rate,
|
||||
"evictions": self.evictions,
|
||||
"total_entries": self.total_entries,
|
||||
}
|
||||
|
||||
|
||||
class DecisionCache:
|
||||
"""TTL-based cache for trade decisions."""
|
||||
|
||||
def __init__(self, ttl_seconds: int = 300, max_size: int = 1000) -> None:
|
||||
"""Initialize the decision cache.
|
||||
|
||||
Args:
|
||||
ttl_seconds: Time-to-live for cache entries in seconds (default: 5 minutes)
|
||||
max_size: Maximum number of cache entries
|
||||
"""
|
||||
self.ttl_seconds = ttl_seconds
|
||||
self.max_size = max_size
|
||||
self._cache: dict[str, CacheEntry] = {}
|
||||
self._metrics = CacheMetrics()
|
||||
|
||||
def _generate_cache_key(self, market_data: dict[str, Any]) -> str:
|
||||
"""Generate cache key from market data.
|
||||
|
||||
Key is based on:
|
||||
- Stock code
|
||||
- Current price (rounded to reduce sensitivity)
|
||||
- Market conditions (orderbook snapshot)
|
||||
|
||||
Args:
|
||||
market_data: Market data dictionary
|
||||
|
||||
Returns:
|
||||
Cache key string
|
||||
"""
|
||||
# Extract key components
|
||||
stock_code = market_data.get("stock_code", "UNKNOWN")
|
||||
current_price = market_data.get("current_price", 0)
|
||||
|
||||
# Round price to reduce sensitivity (cache hits for similar prices)
|
||||
# For prices > 1000, round to nearest 10
|
||||
# For prices < 1000, round to nearest 1
|
||||
if current_price > 1000:
|
||||
price_rounded = round(current_price / 10) * 10
|
||||
else:
|
||||
price_rounded = round(current_price)
|
||||
|
||||
# Include orderbook snapshot (if available)
|
||||
orderbook_key = ""
|
||||
if "orderbook" in market_data and market_data["orderbook"]:
|
||||
ob = market_data["orderbook"]
|
||||
# Just use bid/ask spread as indicator
|
||||
if "bid" in ob and "ask" in ob and ob["bid"] and ob["ask"]:
|
||||
bid_price = ob["bid"][0].get("price", 0) if ob["bid"] else 0
|
||||
ask_price = ob["ask"][0].get("price", 0) if ob["ask"] else 0
|
||||
spread = ask_price - bid_price
|
||||
orderbook_key = f"_spread{spread}"
|
||||
|
||||
# Generate cache key
|
||||
key_str = f"{stock_code}_{price_rounded}{orderbook_key}"
|
||||
|
||||
return key_str
|
||||
|
||||
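A short worked example of the rounding above (prices are illustrative): values over 1000 snap to the nearest 10, so two quotes a few ticks apart land on the same key and can share a cached decision.

def rounded(price: float) -> int:
    # Mirrors the price bucketing in _generate_cache_key above
    return round(price / 10) * 10 if price > 1000 else round(price)

assert rounded(71234) == rounded(71231) == 71230  # same cache bucket
assert rounded(998.4) == 998                      # sub-1000 prices round to the nearest 1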
def _generate_market_hash(self, market_data: dict[str, Any]) -> str:
|
||||
"""Generate hash of full market data for invalidation checks.
|
||||
|
||||
Args:
|
||||
market_data: Market data dictionary
|
||||
|
||||
Returns:
|
||||
Hash string
|
||||
"""
|
||||
# Create stable JSON representation
|
||||
stable_json = json.dumps(market_data, sort_keys=True, ensure_ascii=False)
|
||||
return hashlib.md5(stable_json.encode()).hexdigest()
|
||||
|
||||
def get(self, market_data: dict[str, Any]) -> TradeDecision | None:
|
||||
"""Retrieve cached decision if valid.
|
||||
|
||||
Args:
|
||||
market_data: Market data dictionary
|
||||
|
||||
Returns:
|
||||
Cached TradeDecision if valid, None otherwise
|
||||
"""
|
||||
self._metrics.total_requests += 1
|
||||
|
||||
cache_key = self._generate_cache_key(market_data)
|
||||
|
||||
if cache_key not in self._cache:
|
||||
self._metrics.cache_misses += 1
|
||||
return None
|
||||
|
||||
entry = self._cache[cache_key]
|
||||
current_time = time.time()
|
||||
|
||||
# Check TTL
|
||||
if current_time - entry.cached_at > self.ttl_seconds:
|
||||
# Expired
|
||||
del self._cache[cache_key]
|
||||
self._metrics.cache_misses += 1
|
||||
self._metrics.evictions += 1
|
||||
logger.debug("Cache expired for key: %s", cache_key)
|
||||
return None
|
||||
|
||||
# Cache hit
|
||||
entry.hit_count += 1
|
||||
self._metrics.cache_hits += 1
|
||||
logger.debug("Cache hit for key: %s (hits: %d)", cache_key, entry.hit_count)
|
||||
|
||||
return entry.decision
|
||||
|
||||
def set(
|
||||
self,
|
||||
market_data: dict[str, Any],
|
||||
decision: TradeDecision,
|
||||
) -> None:
|
||||
"""Store decision in cache.
|
||||
|
||||
Args:
|
||||
market_data: Market data dictionary
|
||||
decision: TradeDecision to cache
|
||||
"""
|
||||
cache_key = self._generate_cache_key(market_data)
|
||||
market_hash = self._generate_market_hash(market_data)
|
||||
|
||||
# Enforce max size (evict oldest if full)
|
||||
if len(self._cache) >= self.max_size:
|
||||
# Find oldest entry
|
||||
oldest_key = min(self._cache.keys(), key=lambda k: self._cache[k].cached_at)
|
||||
del self._cache[oldest_key]
|
||||
self._metrics.evictions += 1
|
||||
logger.debug("Cache full, evicted key: %s", oldest_key)
|
||||
|
||||
# Store entry
|
||||
entry = CacheEntry(
|
||||
decision=decision,
|
||||
cached_at=time.time(),
|
||||
market_data_hash=market_hash,
|
||||
)
|
||||
self._cache[cache_key] = entry
|
||||
self._metrics.total_entries = len(self._cache)
|
||||
|
||||
logger.debug("Cached decision for key: %s", cache_key)
|
||||
|
||||
def invalidate(self, stock_code: str | None = None) -> int:
|
||||
"""Invalidate cache entries.
|
||||
|
||||
Args:
|
||||
stock_code: Specific stock code to invalidate, or None for all
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
if stock_code is None:
|
||||
# Clear all
|
||||
count = len(self._cache)
|
||||
self._cache.clear()
|
||||
self._metrics.evictions += count
|
||||
self._metrics.total_entries = 0
|
||||
logger.info("Invalidated all cache entries (%d)", count)
|
||||
return count
|
||||
|
||||
# Invalidate specific stock
|
||||
keys_to_remove = [k for k in self._cache.keys() if k.startswith(f"{stock_code}_")]
|
||||
count = len(keys_to_remove)
|
||||
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
|
||||
self._metrics.evictions += count
|
||||
self._metrics.total_entries = len(self._cache)
|
||||
logger.info("Invalidated %d cache entries for stock: %s", count, stock_code)
|
||||
|
||||
return count
|
||||
|
||||
def cleanup_expired(self) -> int:
|
||||
"""Remove expired entries from cache.
|
||||
|
||||
Returns:
|
||||
Number of entries removed
|
||||
"""
|
||||
current_time = time.time()
|
||||
expired_keys = [
|
||||
k
|
||||
for k, v in self._cache.items()
|
||||
if current_time - v.cached_at > self.ttl_seconds
|
||||
]
|
||||
|
||||
count = len(expired_keys)
|
||||
for key in expired_keys:
|
||||
del self._cache[key]
|
||||
|
||||
self._metrics.evictions += count
|
||||
self._metrics.total_entries = len(self._cache)
|
||||
|
||||
if count > 0:
|
||||
logger.debug("Cleaned up %d expired cache entries", count)
|
||||
|
||||
return count
|
||||
|
||||
def get_metrics(self) -> CacheMetrics:
|
||||
"""Get current cache metrics.
|
||||
|
||||
Returns:
|
||||
CacheMetrics object with current statistics
|
||||
"""
|
||||
return self._metrics
|
||||
|
||||
def reset_metrics(self) -> None:
|
||||
"""Reset cache metrics."""
|
||||
self._metrics = CacheMetrics(total_entries=len(self._cache))
|
||||
logger.info("Cache metrics reset")
|
||||
|
||||
def should_cache_decision(self, decision: TradeDecision) -> bool:
|
||||
"""Determine if a decision should be cached.
|
||||
|
||||
HOLD decisions with low confidence are good candidates for caching,
|
||||
as they're likely to recur in quiet markets.
|
||||
|
||||
Args:
|
||||
decision: TradeDecision to evaluate
|
||||
|
||||
Returns:
|
||||
True if decision should be cached
|
||||
"""
|
||||
# Cache HOLD decisions (common in quiet markets)
|
||||
if decision.action == "HOLD":
|
||||
return True
|
||||
|
||||
# Cache high-confidence decisions (stable signals)
|
||||
if decision.confidence >= 90:
|
||||
return True
|
||||
|
||||
# Don't cache low-confidence BUY/SELL (volatile signals)
|
||||
return False
|
||||
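A minimal sketch of the cache in isolation. It assumes TradeDecision can be constructed directly with the fields shown later in this diff (action, confidence, rationale); the stock code and prices are illustrative.

from src.brain.cache import DecisionCache
from src.brain.gemini_client import TradeDecision

cache = DecisionCache(ttl_seconds=300, max_size=1000)

market_data = {"stock_code": "005930", "current_price": 71234}
decision = TradeDecision(action="HOLD", confidence=40, rationale="Quiet market")

# Only store decisions the policy considers stable (HOLD, or confidence >= 90)
if cache.should_cache_decision(decision):
    cache.set(market_data, decision)

# A nearby price maps to the same rounded key, so this lookup hits within the TTL
hit = cache.get({"stock_code": "005930", "current_price": 71231})
print(hit is not None, cache.get_metrics().hit_rate)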
src/brain/context_selector.py (new file, 296 lines)
@@ -0,0 +1,296 @@
|
||||
"""Smart context selection for optimizing token usage.
|
||||
|
||||
This module implements intelligent selection of context layers (L1-L7) based on
|
||||
decision type and market conditions:
|
||||
- L7 (real-time) for normal trading decisions
|
||||
- L6-L5 (daily/weekly) for strategic decisions
|
||||
- L4-L1 (monthly/legacy) only for major events or policy changes
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from src.context.layer import ContextLayer
|
||||
from src.context.store import ContextStore
|
||||
|
||||
|
||||
class DecisionType(str, Enum):
|
||||
"""Type of trading decision being made."""
|
||||
|
||||
NORMAL = "normal" # Regular trade decision
|
||||
STRATEGIC = "strategic" # Strategy adjustment
|
||||
MAJOR_EVENT = "major_event" # Portfolio rebalancing, policy change
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ContextSelection:
|
||||
"""Selected context layers and their relevance scores."""
|
||||
|
||||
layers: list[ContextLayer]
|
||||
relevance_scores: dict[ContextLayer, float]
|
||||
total_score: float
|
||||
|
||||
|
||||
class ContextSelector:
|
||||
"""Selects optimal context layers to minimize token usage."""
|
||||
|
||||
def __init__(self, store: ContextStore) -> None:
|
||||
"""Initialize the context selector.
|
||||
|
||||
Args:
|
||||
store: ContextStore instance for retrieving context data
|
||||
"""
|
||||
self.store = store
|
||||
|
||||
def select_layers(
|
||||
self,
|
||||
decision_type: DecisionType = DecisionType.NORMAL,
|
||||
include_realtime: bool = True,
|
||||
) -> list[ContextLayer]:
|
||||
"""Select context layers based on decision type.
|
||||
|
||||
Strategy:
|
||||
- NORMAL: L7 (real-time) only
|
||||
- STRATEGIC: L7 + L6 + L5 (real-time + daily + weekly)
|
||||
- MAJOR_EVENT: All layers L1-L7
|
||||
|
||||
Args:
|
||||
decision_type: Type of decision being made
|
||||
include_realtime: Whether to include L7 real-time data
|
||||
|
||||
Returns:
|
||||
List of context layers to use (ordered by priority)
|
||||
"""
|
||||
if decision_type == DecisionType.NORMAL:
|
||||
# Normal trading: only real-time data
|
||||
return [ContextLayer.L7_REALTIME] if include_realtime else []
|
||||
|
||||
elif decision_type == DecisionType.STRATEGIC:
|
||||
# Strategic decisions: real-time + recent history
|
||||
layers = []
|
||||
if include_realtime:
|
||||
layers.append(ContextLayer.L7_REALTIME)
|
||||
layers.extend([ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY])
|
||||
return layers
|
||||
|
||||
else: # MAJOR_EVENT
|
||||
# Major events: all layers for comprehensive context
|
||||
layers = []
|
||||
if include_realtime:
|
||||
layers.append(ContextLayer.L7_REALTIME)
|
||||
layers.extend(
|
||||
[
|
||||
ContextLayer.L6_DAILY,
|
||||
ContextLayer.L5_WEEKLY,
|
||||
ContextLayer.L4_MONTHLY,
|
||||
ContextLayer.L3_QUARTERLY,
|
||||
ContextLayer.L2_ANNUAL,
|
||||
ContextLayer.L1_LEGACY,
|
||||
]
|
||||
)
|
||||
return layers
|
||||
|
||||
def score_layer_relevance(
|
||||
self,
|
||||
layer: ContextLayer,
|
||||
decision_type: DecisionType,
|
||||
current_time: datetime | None = None,
|
||||
) -> float:
|
||||
"""Calculate relevance score for a context layer.
|
||||
|
||||
Relevance is based on:
|
||||
1. Decision type (normal, strategic, major event)
|
||||
2. Layer recency (L7 > L6 > ... > L1)
|
||||
3. Data availability
|
||||
|
||||
Args:
|
||||
layer: Context layer to score
|
||||
decision_type: Type of decision being made
|
||||
current_time: Current time (defaults to now)
|
||||
|
||||
Returns:
|
||||
Relevance score (0.0 to 1.0)
|
||||
"""
|
||||
if current_time is None:
|
||||
current_time = datetime.now(UTC)
|
||||
|
||||
# Base scores by decision type
|
||||
base_scores = {
|
||||
DecisionType.NORMAL: {
|
||||
ContextLayer.L7_REALTIME: 1.0,
|
||||
ContextLayer.L6_DAILY: 0.1,
|
||||
ContextLayer.L5_WEEKLY: 0.05,
|
||||
ContextLayer.L4_MONTHLY: 0.01,
|
||||
ContextLayer.L3_QUARTERLY: 0.0,
|
||||
ContextLayer.L2_ANNUAL: 0.0,
|
||||
ContextLayer.L1_LEGACY: 0.0,
|
||||
},
|
||||
DecisionType.STRATEGIC: {
|
||||
ContextLayer.L7_REALTIME: 0.9,
|
||||
ContextLayer.L6_DAILY: 0.8,
|
||||
ContextLayer.L5_WEEKLY: 0.7,
|
||||
ContextLayer.L4_MONTHLY: 0.3,
|
||||
ContextLayer.L3_QUARTERLY: 0.2,
|
||||
ContextLayer.L2_ANNUAL: 0.1,
|
||||
ContextLayer.L1_LEGACY: 0.05,
|
||||
},
|
||||
DecisionType.MAJOR_EVENT: {
|
||||
ContextLayer.L7_REALTIME: 0.7,
|
||||
ContextLayer.L6_DAILY: 0.7,
|
||||
ContextLayer.L5_WEEKLY: 0.7,
|
||||
ContextLayer.L4_MONTHLY: 0.8,
|
||||
ContextLayer.L3_QUARTERLY: 0.8,
|
||||
ContextLayer.L2_ANNUAL: 0.9,
|
||||
ContextLayer.L1_LEGACY: 1.0,
|
||||
},
|
||||
}
|
||||
|
||||
score = base_scores[decision_type].get(layer, 0.0)
|
||||
|
||||
# Check data availability
|
||||
latest_timeframe = self.store.get_latest_timeframe(layer)
|
||||
if latest_timeframe is None:
|
||||
# No data available - reduce score significantly
|
||||
score *= 0.1
|
||||
|
||||
return score
|
||||
|
||||
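For example, with the STRATEGIC base scores above and min_score=0.5, only L7, L6, and L5 survive the filter (0.9, 0.8, 0.7); a layer with no stored timeframe is multiplied by 0.1 and drops out as well. A small illustrative check:

# Illustrative filter over the STRATEGIC base scores above, assuming data is available
strategic = {"L7": 0.9, "L6": 0.8, "L5": 0.7, "L4": 0.3, "L3": 0.2, "L2": 0.1, "L1": 0.05}
selected = [layer for layer, score in strategic.items() if score >= 0.5]
assert selected == ["L7", "L6", "L5"]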
def select_with_scoring(
|
||||
self,
|
||||
decision_type: DecisionType = DecisionType.NORMAL,
|
||||
min_score: float = 0.5,
|
||||
) -> ContextSelection:
|
||||
"""Select context layers with relevance scoring.
|
||||
|
||||
Args:
|
||||
decision_type: Type of decision being made
|
||||
min_score: Minimum relevance score to include a layer
|
||||
|
||||
Returns:
|
||||
ContextSelection with selected layers and scores
|
||||
"""
|
||||
all_layers = [
|
||||
ContextLayer.L7_REALTIME,
|
||||
ContextLayer.L6_DAILY,
|
||||
ContextLayer.L5_WEEKLY,
|
||||
ContextLayer.L4_MONTHLY,
|
||||
ContextLayer.L3_QUARTERLY,
|
||||
ContextLayer.L2_ANNUAL,
|
||||
ContextLayer.L1_LEGACY,
|
||||
]
|
||||
|
||||
scores = {
|
||||
layer: self.score_layer_relevance(layer, decision_type) for layer in all_layers
|
||||
}
|
||||
|
||||
# Filter by minimum score
|
||||
selected_layers = [layer for layer, score in scores.items() if score >= min_score]
|
||||
|
||||
# Sort by score (descending)
|
||||
selected_layers.sort(key=lambda layer: scores[layer], reverse=True)
|
||||
|
||||
total_score = sum(scores[layer] for layer in selected_layers)
|
||||
|
||||
return ContextSelection(
|
||||
layers=selected_layers,
|
||||
relevance_scores=scores,
|
||||
total_score=total_score,
|
||||
)
|
||||
|
||||
def get_context_data(
|
||||
self,
|
||||
layers: list[ContextLayer],
|
||||
max_items_per_layer: int = 10,
|
||||
) -> dict[str, Any]:
|
||||
"""Retrieve context data for selected layers.
|
||||
|
||||
Args:
|
||||
layers: List of context layers to retrieve
|
||||
max_items_per_layer: Maximum number of items per layer
|
||||
|
||||
Returns:
|
||||
Dictionary with context data organized by layer
|
||||
"""
|
||||
result: dict[str, Any] = {}
|
||||
|
||||
for layer in layers:
|
||||
# Get latest timeframe for this layer
|
||||
latest_timeframe = self.store.get_latest_timeframe(layer)
|
||||
if latest_timeframe:
|
||||
# Get all contexts for latest timeframe
|
||||
contexts = self.store.get_all_contexts(layer, latest_timeframe)
|
||||
|
||||
# Limit number of items
|
||||
if len(contexts) > max_items_per_layer:
|
||||
# Keep only first N items
|
||||
contexts = dict(list(contexts.items())[:max_items_per_layer])
|
||||
|
||||
result[layer.value] = contexts
|
||||
|
||||
return result
|
||||
|
||||
def estimate_context_tokens(self, context_data: dict[str, Any]) -> int:
|
||||
"""Estimate total tokens for context data.
|
||||
|
||||
Args:
|
||||
context_data: Context data dictionary
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
import json
|
||||
|
||||
from src.brain.prompt_optimizer import PromptOptimizer
|
||||
|
||||
# Serialize to JSON and estimate tokens
|
||||
json_str = json.dumps(context_data, ensure_ascii=False)
|
||||
return PromptOptimizer.estimate_tokens(json_str)
|
||||
|
||||
def optimize_context_for_budget(
|
||||
self,
|
||||
decision_type: DecisionType,
|
||||
max_tokens: int,
|
||||
) -> dict[str, Any]:
|
||||
"""Select and retrieve context data within a token budget.
|
||||
|
||||
Args:
|
||||
decision_type: Type of decision being made
|
||||
max_tokens: Maximum token budget for context
|
||||
|
||||
Returns:
|
||||
Optimized context data within budget
|
||||
"""
|
||||
# Start with minimal selection
|
||||
selection = self.select_with_scoring(decision_type, min_score=0.5)
|
||||
|
||||
# Retrieve data
|
||||
context_data = self.get_context_data(selection.layers)
|
||||
|
||||
# Check if within budget
|
||||
estimated_tokens = self.estimate_context_tokens(context_data)
|
||||
|
||||
if estimated_tokens <= max_tokens:
|
||||
return context_data
|
||||
|
||||
# If over budget, progressively reduce
|
||||
# 1. Reduce items per layer
|
||||
for max_items in [5, 3, 1]:
|
||||
context_data = self.get_context_data(selection.layers, max_items)
|
||||
estimated_tokens = self.estimate_context_tokens(context_data)
|
||||
if estimated_tokens <= max_tokens:
|
||||
return context_data
|
||||
|
||||
# 2. Remove lower-priority layers
|
||||
for min_score in [0.6, 0.7, 0.8, 0.9]:
|
||||
selection = self.select_with_scoring(decision_type, min_score=min_score)
|
||||
context_data = self.get_context_data(selection.layers, max_items_per_layer=1)
|
||||
estimated_tokens = self.estimate_context_tokens(context_data)
|
||||
if estimated_tokens <= max_tokens:
|
||||
return context_data
|
||||
|
||||
# Last resort: return only L7 with minimal data
|
||||
return self.get_context_data([ContextLayer.L7_REALTIME], max_items_per_layer=1)
|
||||
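A minimal sketch of budget-aware selection with the class above; the ContextStore constructor is not shown in this diff, so its construction here is an assumption.

from src.brain.context_selector import ContextSelector, DecisionType
from src.context.store import ContextStore

store = ContextStore()  # assumed default constructor; adjust to the real signature
selector = ContextSelector(store)

# Strategic decision, capped at roughly 2,000 tokens of context
context = selector.optimize_context_for_budget(
    decision_type=DecisionType.STRATEGIC,
    max_tokens=2000,
)
print(list(context.keys()), selector.estimate_context_tokens(context))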
@@ -2,6 +2,17 @@
|
||||
|
||||
Constructs prompts from market data, calls Gemini, and parses structured
|
||||
JSON responses into validated TradeDecision objects.
|
||||
|
||||
Includes token efficiency optimizations:
|
||||
- Prompt compression and abbreviation
|
||||
- Response caching for common scenarios
|
||||
- Smart context selection
|
||||
- Token usage tracking and metrics
|
||||
|
||||
Includes external data integration:
|
||||
- News sentiment analysis
|
||||
- Economic calendar events
|
||||
- Market indicators
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -15,6 +26,11 @@ from typing import Any
|
||||
from google import genai
|
||||
|
||||
from src.config import Settings
|
||||
from src.data.news_api import NewsAPI, NewsSentiment
|
||||
from src.data.economic_calendar import EconomicCalendar
|
||||
from src.data.market_data import MarketData
|
||||
from src.brain.cache import DecisionCache
|
||||
from src.brain.prompt_optimizer import PromptOptimizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -28,23 +44,176 @@ class TradeDecision:
|
||||
action: str # "BUY" | "SELL" | "HOLD"
|
||||
confidence: int # 0-100
|
||||
rationale: str
|
||||
token_count: int = 0 # Estimated tokens used
|
||||
cached: bool = False # Whether decision came from cache
|
||||
|
||||
|
||||
class GeminiClient:
|
||||
"""Wraps the Gemini API for trade decision-making."""
|
||||
|
||||
def __init__(self, settings: Settings) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
settings: Settings,
|
||||
news_api: NewsAPI | None = None,
|
||||
economic_calendar: EconomicCalendar | None = None,
|
||||
market_data: MarketData | None = None,
|
||||
enable_cache: bool = True,
|
||||
enable_optimization: bool = True,
|
||||
) -> None:
|
||||
self._settings = settings
|
||||
self._confidence_threshold = settings.CONFIDENCE_THRESHOLD
|
||||
self._client = genai.Client(api_key=settings.GEMINI_API_KEY)
|
||||
self._model_name = settings.GEMINI_MODEL
|
||||
|
||||
# External data sources (optional)
|
||||
self._news_api = news_api
|
||||
self._economic_calendar = economic_calendar
|
||||
self._market_data = market_data
|
||||
|
||||
# Token efficiency features
|
||||
self._enable_cache = enable_cache
|
||||
self._enable_optimization = enable_optimization
|
||||
self._cache = DecisionCache(ttl_seconds=300) if enable_cache else None
|
||||
self._optimizer = PromptOptimizer()
|
||||
|
||||
# Token usage metrics
|
||||
self._total_tokens_used = 0
|
||||
self._total_decisions = 0
|
||||
self._total_cached_decisions = 0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# External Data Integration
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _build_external_context(
|
||||
self, stock_code: str, news_sentiment: NewsSentiment | None = None
|
||||
) -> str:
|
||||
"""Build external data context for the prompt.
|
||||
|
||||
Args:
|
||||
stock_code: Stock ticker symbol
|
||||
news_sentiment: Optional pre-fetched news sentiment
|
||||
|
||||
Returns:
|
||||
Formatted string with external data context
|
||||
"""
|
||||
context_parts: list[str] = []
|
||||
|
||||
# News sentiment
|
||||
if news_sentiment is not None:
|
||||
sentiment_str = self._format_news_sentiment(news_sentiment)
|
||||
if sentiment_str:
|
||||
context_parts.append(sentiment_str)
|
||||
elif self._news_api is not None:
|
||||
# Fetch news sentiment if not provided
|
||||
try:
|
||||
sentiment = await self._news_api.get_news_sentiment(stock_code)
|
||||
if sentiment is not None:
|
||||
sentiment_str = self._format_news_sentiment(sentiment)
|
||||
if sentiment_str:
|
||||
context_parts.append(sentiment_str)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to fetch news sentiment: %s", exc)
|
||||
|
||||
# Economic events
|
||||
if self._economic_calendar is not None:
|
||||
events_str = self._format_economic_events(stock_code)
|
||||
if events_str:
|
||||
context_parts.append(events_str)
|
||||
|
||||
# Market indicators
|
||||
if self._market_data is not None:
|
||||
indicators_str = self._format_market_indicators()
|
||||
if indicators_str:
|
||||
context_parts.append(indicators_str)
|
||||
|
||||
if not context_parts:
|
||||
return ""
|
||||
|
||||
return "EXTERNAL DATA:\n" + "\n\n".join(context_parts)
|
||||
|
||||
def _format_news_sentiment(self, sentiment: NewsSentiment) -> str:
|
||||
"""Format news sentiment for prompt."""
|
||||
if sentiment.article_count == 0:
|
||||
return ""
|
||||
|
||||
# Select top 3 most relevant articles
|
||||
top_articles = sentiment.articles[:3]
|
||||
|
||||
lines = [
|
||||
f"News Sentiment: {sentiment.avg_sentiment:.2f} "
|
||||
f"(from {sentiment.article_count} articles)",
|
||||
]
|
||||
|
||||
for i, article in enumerate(top_articles, 1):
|
||||
lines.append(
|
||||
f" {i}. [{article.source}] {article.title} "
|
||||
f"(sentiment: {article.sentiment_score:.2f})"
|
||||
)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _format_economic_events(self, stock_code: str) -> str:
|
||||
"""Format upcoming economic events for prompt."""
|
||||
if self._economic_calendar is None:
|
||||
return ""
|
||||
|
||||
# Check for upcoming high-impact events
|
||||
upcoming = self._economic_calendar.get_upcoming_events(
|
||||
days_ahead=7, min_impact="HIGH"
|
||||
)
|
||||
|
||||
if upcoming.high_impact_count == 0:
|
||||
return ""
|
||||
|
||||
lines = [
|
||||
f"Upcoming High-Impact Events: {upcoming.high_impact_count} in next 7 days"
|
||||
]
|
||||
|
||||
if upcoming.next_major_event is not None:
|
||||
event = upcoming.next_major_event
|
||||
lines.append(
|
||||
f" Next: {event.name} ({event.event_type}) "
|
||||
f"on {event.datetime.strftime('%Y-%m-%d')}"
|
||||
)
|
||||
|
||||
# Check for earnings
|
||||
earnings_date = self._economic_calendar.get_earnings_date(stock_code)
|
||||
if earnings_date is not None:
|
||||
lines.append(
|
||||
f" Earnings: {stock_code} on {earnings_date.strftime('%Y-%m-%d')}"
|
||||
)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _format_market_indicators(self) -> str:
|
||||
"""Format market indicators for prompt."""
|
||||
if self._market_data is None:
|
||||
return ""
|
||||
|
||||
try:
|
||||
indicators = self._market_data.get_market_indicators()
|
||||
lines = [f"Market Sentiment: {indicators.sentiment.name}"]
|
||||
|
||||
# Add breadth if meaningful
|
||||
if indicators.breadth.advance_decline_ratio != 1.0:
|
||||
lines.append(
|
||||
f"Advance/Decline Ratio: {indicators.breadth.advance_decline_ratio:.2f}"
|
||||
)
|
||||
|
||||
return "\n".join(lines)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to get market indicators: %s", exc)
|
||||
return ""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Prompt Construction
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def build_prompt(self, market_data: dict[str, Any]) -> str:
|
||||
"""Build a structured prompt from market data.
|
||||
async def build_prompt(
|
||||
self, market_data: dict[str, Any], news_sentiment: NewsSentiment | None = None
|
||||
) -> str:
|
||||
"""Build a structured prompt from market data and external sources.
|
||||
|
||||
The prompt instructs Gemini to return valid JSON with action,
|
||||
confidence, and rationale fields.
|
||||
@@ -72,6 +241,60 @@ class GeminiClient:
|
||||
|
||||
market_info = "\n".join(market_info_lines)
|
||||
|
||||
# Add external data context if available
|
||||
external_context = await self._build_external_context(
|
||||
market_data["stock_code"], news_sentiment
|
||||
)
|
||||
if external_context:
|
||||
market_info += f"\n\n{external_context}"
|
||||
|
||||
json_format = (
|
||||
'{"action": "BUY"|"SELL"|"HOLD", '
|
||||
'"confidence": <int 0-100>, "rationale": "<string>"}'
|
||||
)
|
||||
return (
|
||||
f"You are a professional {market_name} trading analyst.\n"
|
||||
"Analyze the following market data and decide whether to "
|
||||
"BUY, SELL, or HOLD.\n\n"
|
||||
f"{market_info}\n\n"
|
||||
"You MUST respond with ONLY valid JSON in the following format:\n"
|
||||
f"{json_format}\n\n"
|
||||
"Rules:\n"
|
||||
"- action must be exactly one of: BUY, SELL, HOLD\n"
|
||||
"- confidence must be an integer from 0 to 100\n"
|
||||
"- rationale must explain your reasoning concisely\n"
|
||||
"- Do NOT wrap the JSON in markdown code blocks\n"
|
||||
)
|
||||
|
||||
def build_prompt_sync(self, market_data: dict[str, Any]) -> str:
|
||||
"""Synchronous version of build_prompt (for backward compatibility).
|
||||
|
||||
This version does NOT include external data integration.
|
||||
Use async build_prompt() for full functionality.
|
||||
"""
|
||||
market_name = market_data.get("market_name", "Korean stock market")
|
||||
|
||||
# Build market data section dynamically based on available fields
|
||||
market_info_lines = [
|
||||
f"Market: {market_name}",
|
||||
f"Stock Code: {market_data['stock_code']}",
|
||||
f"Current Price: {market_data['current_price']}",
|
||||
]
|
||||
|
||||
# Add orderbook if available (domestic markets)
|
||||
if "orderbook" in market_data:
|
||||
market_info_lines.append(
|
||||
f"Orderbook: {json.dumps(market_data['orderbook'], ensure_ascii=False)}"
|
||||
)
|
||||
|
||||
# Add foreigner net if non-zero
|
||||
if market_data.get("foreigner_net", 0) != 0:
|
||||
market_info_lines.append(
|
||||
f"Foreigner Net Buy/Sell: {market_data['foreigner_net']}"
|
||||
)
|
||||
|
||||
market_info = "\n".join(market_info_lines)
|
||||
|
||||
json_format = (
|
||||
'{"action": "BUY"|"SELL"|"HOLD", '
|
||||
'"confidence": <int 0-100>, "rationale": "<string>"}'
|
||||
@@ -152,28 +375,385 @@ class GeminiClient:
|
||||
# API Call
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def decide(self, market_data: dict[str, Any]) -> TradeDecision:
|
||||
"""Build prompt, call Gemini, and return a parsed decision."""
|
||||
prompt = self.build_prompt(market_data)
|
||||
logger.info("Requesting trade decision from Gemini")
|
||||
async def decide(
|
||||
self, market_data: dict[str, Any], news_sentiment: NewsSentiment | None = None
|
||||
) -> TradeDecision:
|
||||
"""Build prompt, call Gemini, and return a parsed decision.
|
||||
|
||||
Args:
|
||||
market_data: Market data dictionary with price, orderbook, etc.
|
||||
news_sentiment: Optional pre-fetched news sentiment
|
||||
|
||||
Returns:
|
||||
Parsed TradeDecision
|
||||
"""
|
||||
# Check cache first
|
||||
if self._cache:
|
||||
cached_decision = self._cache.get(market_data)
|
||||
if cached_decision:
|
||||
self._total_cached_decisions += 1
|
||||
self._total_decisions += 1
|
||||
logger.info(
|
||||
"Cache hit for decision",
|
||||
extra={
|
||||
"action": cached_decision.action,
|
||||
"confidence": cached_decision.confidence,
|
||||
"cache_hit_rate": self.get_cache_hit_rate(),
|
||||
},
|
||||
)
|
||||
# Return cached decision with cached flag
|
||||
return TradeDecision(
|
||||
action=cached_decision.action,
|
||||
confidence=cached_decision.confidence,
|
||||
rationale=cached_decision.rationale,
|
||||
token_count=0,
|
||||
cached=True,
|
||||
)
|
||||
|
||||
# Build prompt (prompt_override takes priority for callers like pre_market_planner)
|
||||
if "prompt_override" in market_data:
|
||||
prompt = market_data["prompt_override"]
|
||||
elif self._enable_optimization:
|
||||
prompt = self._optimizer.build_compressed_prompt(market_data)
|
||||
else:
|
||||
prompt = await self.build_prompt(market_data, news_sentiment)
|
||||
|
||||
# Estimate tokens
|
||||
token_count = self._optimizer.estimate_tokens(prompt)
|
||||
self._total_tokens_used += token_count
|
||||
|
||||
logger.info(
|
||||
"Requesting trade decision from Gemini",
|
||||
extra={"estimated_tokens": token_count, "optimized": self._enable_optimization},
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self._client.aio.models.generate_content(
|
||||
model=self._model_name, contents=prompt,
|
||||
model=self._model_name,
|
||||
contents=prompt,
|
||||
)
|
||||
raw = response.text
|
||||
except Exception as exc:
|
||||
logger.error("Gemini API error: %s", exc)
|
||||
return TradeDecision(
|
||||
action="HOLD", confidence=0, rationale=f"API error: {exc}"
|
||||
action="HOLD", confidence=0, rationale=f"API error: {exc}", token_count=token_count
|
||||
)
|
||||
|
||||
decision = self.parse_response(raw)
|
||||
self._total_decisions += 1
|
||||
|
||||
# Add token count to decision
|
||||
decision_with_tokens = TradeDecision(
|
||||
action=decision.action,
|
||||
confidence=decision.confidence,
|
||||
rationale=decision.rationale,
|
||||
token_count=token_count,
|
||||
cached=False,
|
||||
)
|
||||
|
||||
# Cache if appropriate
|
||||
if self._cache and self._cache.should_cache_decision(decision):
|
||||
self._cache.set(market_data, decision)
|
||||
|
||||
logger.info(
|
||||
"Gemini decision",
|
||||
extra={
|
||||
"action": decision.action,
|
||||
"confidence": decision.confidence,
|
||||
"tokens": token_count,
|
||||
"avg_tokens": self.get_avg_tokens_per_decision(),
|
||||
},
|
||||
)
|
||||
return decision
|
||||
|
||||
return decision_with_tokens
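For orientation, a minimal sketch of how a caller would consume the decision returned above, including the new token_count and cached fields; the default GeminiClient() construction and the sample market_data values are assumptions, not part of this change.
import asyncio

async def demo() -> None:
    client = GeminiClient()  # assumed default construction
    market_data = {"stock_code": "005930", "current_price": 71200}
    decision = await client.decide(market_data)
    # action/confidence/rationale come from Gemini; token_count and cached
    # are the bookkeeping fields added in this change.
    print(decision.action, decision.confidence, decision.token_count, decision.cached)

asyncio.run(demo())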
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Token Efficiency Metrics
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_token_metrics(self) -> dict[str, Any]:
|
||||
"""Get token usage metrics.
|
||||
|
||||
Returns:
|
||||
Dictionary with token usage statistics
|
||||
"""
|
||||
metrics = {
|
||||
"total_tokens_used": self._total_tokens_used,
|
||||
"total_decisions": self._total_decisions,
|
||||
"total_cached_decisions": self._total_cached_decisions,
|
||||
"avg_tokens_per_decision": self.get_avg_tokens_per_decision(),
|
||||
"cache_hit_rate": self.get_cache_hit_rate(),
|
||||
}
|
||||
|
||||
if self._cache:
|
||||
cache_metrics = self._cache.get_metrics()
|
||||
metrics["cache_metrics"] = cache_metrics.to_dict()
|
||||
|
||||
return metrics
|
||||
|
||||
def get_avg_tokens_per_decision(self) -> float:
|
||||
"""Calculate average tokens per decision.
|
||||
|
||||
Returns:
|
||||
Average tokens per decision
|
||||
"""
|
||||
if self._total_decisions == 0:
|
||||
return 0.0
|
||||
return self._total_tokens_used / self._total_decisions
|
||||
|
||||
def get_cache_hit_rate(self) -> float:
|
||||
"""Calculate cache hit rate.
|
||||
|
||||
Returns:
|
||||
Cache hit rate (0.0 to 1.0)
|
||||
"""
|
||||
if self._total_decisions == 0:
|
||||
return 0.0
|
||||
return self._total_cached_decisions / self._total_decisions
|
||||
|
||||
def reset_metrics(self) -> None:
|
||||
"""Reset token usage metrics."""
|
||||
self._total_tokens_used = 0
|
||||
self._total_decisions = 0
|
||||
self._total_cached_decisions = 0
|
||||
if self._cache:
|
||||
self._cache.reset_metrics()
|
||||
logger.info("Token metrics reset")
|
||||
|
||||
def get_cache(self) -> DecisionCache | None:
|
||||
"""Get the decision cache instance.
|
||||
|
||||
Returns:
|
||||
DecisionCache instance or None if caching disabled
|
||||
"""
|
||||
return self._cache
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Batch Decision Making (for daily trading mode)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def decide_batch(
|
||||
self, stocks_data: list[dict[str, Any]]
|
||||
) -> dict[str, TradeDecision]:
|
||||
"""Make decisions for multiple stocks in a single API call.
|
||||
|
||||
This is designed for daily trading mode to minimize API usage
|
||||
when working with Gemini Free tier (20 calls/day limit).
|
||||
|
||||
Args:
|
||||
stocks_data: List of market data dictionaries, each with:
|
||||
- stock_code: Stock ticker
|
||||
- current_price: Current price
|
||||
- market_name: Market name (optional)
|
||||
- foreigner_net: Foreigner net buy/sell (optional)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping stock_code to TradeDecision
|
||||
|
||||
Example:
|
||||
>>> stocks_data = [
|
||||
... {"stock_code": "AAPL", "current_price": 185.5},
|
||||
... {"stock_code": "MSFT", "current_price": 420.0},
|
||||
... ]
|
||||
>>> decisions = await client.decide_batch(stocks_data)
|
||||
>>> decisions["AAPL"].action
|
||||
'BUY'
|
||||
"""
|
||||
if not stocks_data:
|
||||
return {}
|
||||
|
||||
# Build compressed batch prompt
|
||||
market_name = stocks_data[0].get("market_name", "stock market")
|
||||
|
||||
# Format stock data as compact JSON array
|
||||
compact_stocks = []
|
||||
for stock in stocks_data:
|
||||
compact = {
|
||||
"code": stock["stock_code"],
|
||||
"price": stock["current_price"],
|
||||
}
|
||||
if stock.get("foreigner_net", 0) != 0:
|
||||
compact["frgn"] = stock["foreigner_net"]
|
||||
compact_stocks.append(compact)
|
||||
|
||||
data_str = json.dumps(compact_stocks, ensure_ascii=False)
|
||||
|
||||
prompt = (
|
||||
f"You are a professional {market_name} trading analyst.\n"
|
||||
"Analyze the following stocks and decide whether to BUY, SELL, or HOLD each one.\n\n"
|
||||
f"Stock Data: {data_str}\n\n"
|
||||
"You MUST respond with ONLY a valid JSON array in this format:\n"
|
||||
'[{"code": "AAPL", "action": "BUY", "confidence": 85, "rationale": "..."},\n'
|
||||
' {"code": "MSFT", "action": "HOLD", "confidence": 50, "rationale": "..."}, ...]\n\n'
|
||||
"Rules:\n"
|
||||
"- Return one decision object per stock\n"
|
||||
"- action must be exactly: BUY, SELL, or HOLD\n"
|
||||
"- confidence must be 0-100\n"
|
||||
"- rationale should be concise (1-2 sentences)\n"
|
||||
"- Do NOT wrap JSON in markdown code blocks\n"
|
||||
)
|
||||
|
||||
# Estimate tokens
|
||||
token_count = self._optimizer.estimate_tokens(prompt)
|
||||
self._total_tokens_used += token_count
|
||||
|
||||
logger.info(
|
||||
"Requesting batch decision for %d stocks from Gemini",
|
||||
len(stocks_data),
|
||||
extra={"estimated_tokens": token_count},
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self._client.aio.models.generate_content(
|
||||
model=self._model_name,
|
||||
contents=prompt,
|
||||
)
|
||||
raw = response.text
|
||||
except Exception as exc:
|
||||
logger.error("Gemini API error in batch decision: %s", exc)
|
||||
# Return HOLD for all stocks on API error
|
||||
return {
|
||||
stock["stock_code"]: TradeDecision(
|
||||
action="HOLD",
|
||||
confidence=0,
|
||||
rationale=f"API error: {exc}",
|
||||
token_count=token_count,
|
||||
cached=False,
|
||||
)
|
||||
for stock in stocks_data
|
||||
}
|
||||
|
||||
# Parse batch response
|
||||
return self._parse_batch_response(raw, stocks_data, token_count)
|
||||
|
||||
def _parse_batch_response(
|
||||
self, raw: str, stocks_data: list[dict[str, Any]], token_count: int
|
||||
) -> dict[str, TradeDecision]:
|
||||
"""Parse batch response into a dictionary of decisions.
|
||||
|
||||
Args:
|
||||
raw: Raw response from Gemini
|
||||
stocks_data: Original stock data list
|
||||
token_count: Token count for the request
|
||||
|
||||
Returns:
|
||||
Dictionary mapping stock_code to TradeDecision
|
||||
"""
|
||||
if not raw or not raw.strip():
|
||||
logger.warning("Empty batch response from Gemini — defaulting all to HOLD")
|
||||
return {
|
||||
stock["stock_code"]: TradeDecision(
|
||||
action="HOLD",
|
||||
confidence=0,
|
||||
rationale="Empty response",
|
||||
token_count=0,
|
||||
cached=False,
|
||||
)
|
||||
for stock in stocks_data
|
||||
}
|
||||
|
||||
# Strip markdown code fences if present
|
||||
cleaned = raw.strip()
|
||||
match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", cleaned, re.DOTALL)
|
||||
if match:
|
||||
cleaned = match.group(1).strip()
|
||||
|
||||
try:
|
||||
data = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Malformed JSON in batch response — defaulting all to HOLD")
|
||||
return {
|
||||
stock["stock_code"]: TradeDecision(
|
||||
action="HOLD",
|
||||
confidence=0,
|
||||
rationale="Malformed JSON response",
|
||||
token_count=0,
|
||||
cached=False,
|
||||
)
|
||||
for stock in stocks_data
|
||||
}
|
||||
|
||||
if not isinstance(data, list):
|
||||
logger.warning("Batch response is not a JSON array — defaulting all to HOLD")
|
||||
return {
|
||||
stock["stock_code"]: TradeDecision(
|
||||
action="HOLD",
|
||||
confidence=0,
|
||||
rationale="Invalid response format",
|
||||
token_count=0,
|
||||
cached=False,
|
||||
)
|
||||
for stock in stocks_data
|
||||
}
|
||||
|
||||
# Build decision map
|
||||
decisions: dict[str, TradeDecision] = {}
|
||||
stock_codes = {stock["stock_code"] for stock in stocks_data}
|
||||
|
||||
for item in data:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
code = item.get("code")
|
||||
if not code or code not in stock_codes:
|
||||
continue
|
||||
|
||||
# Validate required fields
|
||||
if not all(k in item for k in ("action", "confidence", "rationale")):
|
||||
logger.warning("Missing fields for %s — using HOLD", code)
|
||||
decisions[code] = TradeDecision(
|
||||
action="HOLD",
|
||||
confidence=0,
|
||||
rationale="Missing required fields",
|
||||
token_count=0,
|
||||
cached=False,
|
||||
)
|
||||
continue
|
||||
|
||||
action = str(item["action"]).upper()
|
||||
if action not in VALID_ACTIONS:
|
||||
logger.warning("Invalid action '%s' for %s — forcing HOLD", action, code)
|
||||
action = "HOLD"
|
||||
|
||||
confidence = int(item["confidence"])
|
||||
rationale = str(item["rationale"])
|
||||
|
||||
# Enforce confidence threshold
|
||||
if confidence < self._confidence_threshold:
|
||||
logger.info(
|
||||
"Confidence %d < threshold %d for %s — forcing HOLD",
|
||||
confidence,
|
||||
self._confidence_threshold,
|
||||
code,
|
||||
)
|
||||
action = "HOLD"
|
||||
|
||||
decisions[code] = TradeDecision(
|
||||
action=action,
|
||||
confidence=confidence,
|
||||
rationale=rationale,
|
||||
token_count=token_count // len(stocks_data), # Split token cost
|
||||
cached=False,
|
||||
)
|
||||
self._total_decisions += 1
|
||||
|
||||
# Fill in missing stocks with HOLD
|
||||
for stock in stocks_data:
|
||||
code = stock["stock_code"]
|
||||
if code not in decisions:
|
||||
logger.warning("No decision for %s in batch response — using HOLD", code)
|
||||
decisions[code] = TradeDecision(
|
||||
action="HOLD",
|
||||
confidence=0,
|
||||
rationale="Not found in batch response",
|
||||
token_count=0,
|
||||
cached=False,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Batch decision completed for %d stocks",
|
||||
len(decisions),
|
||||
extra={"tokens": token_count},
|
||||
)
|
||||
|
||||
return decisions
|
||||
|
||||
267 src/brain/prompt_optimizer.py Normal file
@@ -0,0 +1,267 @@
|
||||
"""Prompt optimization utilities for reducing token usage.
|
||||
|
||||
This module provides tools to compress prompts while maintaining decision quality:
|
||||
- Token counting
|
||||
- Text compression and abbreviation
|
||||
- Template-based prompts with variable slots
|
||||
- Priority-based context truncation
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
# Abbreviation mapping for common terms
|
||||
ABBREVIATIONS = {
|
||||
"price": "P",
|
||||
"volume": "V",
|
||||
"current": "cur",
|
||||
"previous": "prev",
|
||||
"change": "chg",
|
||||
"percentage": "pct",
|
||||
"market": "mkt",
|
||||
"orderbook": "ob",
|
||||
"foreigner": "fgn",
|
||||
"buy": "B",
|
||||
"sell": "S",
|
||||
"hold": "H",
|
||||
"confidence": "conf",
|
||||
"rationale": "reason",
|
||||
"action": "act",
|
||||
"net": "net",
|
||||
}
|
||||
|
||||
# Reverse mapping for decompression
|
||||
REVERSE_ABBREVIATIONS = {v: k for k, v in ABBREVIATIONS.items()}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TokenMetrics:
|
||||
"""Metrics about token usage in a prompt."""
|
||||
|
||||
char_count: int
|
||||
word_count: int
|
||||
estimated_tokens: int # Rough estimate: ~4 chars per token
|
||||
compression_ratio: float = 1.0 # Original / Compressed
|
||||
|
||||
|
||||
class PromptOptimizer:
|
||||
"""Optimizes prompts to reduce token usage while maintaining quality."""
|
||||
|
||||
@staticmethod
|
||||
def estimate_tokens(text: str) -> int:
|
||||
"""Estimate token count for text.
|
||||
|
||||
Uses a simple heuristic: ~4 characters per token for English.
|
||||
This is approximate but sufficient for optimization purposes.
|
||||
|
||||
Args:
|
||||
text: Input text to estimate tokens for
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
# Simple estimate: 1 token ≈ 4 characters
|
||||
return max(1, len(text) // 4)
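A quick sanity check of the 4-characters-per-token heuristic; the sample prompt below is illustrative only.
prompt = 'KR trader. Analyze: {"code":"005930","P":71200}'
# len(prompt) is 47, so the heuristic reports max(1, 47 // 4) == 11 tokens.
assert PromptOptimizer.estimate_tokens(prompt) == 11
assert PromptOptimizer.estimate_tokens("") == 0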
|
||||
|
||||
@staticmethod
|
||||
def count_tokens(text: str) -> TokenMetrics:
|
||||
"""Count various metrics for a text.
|
||||
|
||||
Args:
|
||||
text: Input text to analyze
|
||||
|
||||
Returns:
|
||||
TokenMetrics with character, word, and estimated token counts
|
||||
"""
|
||||
char_count = len(text)
|
||||
word_count = len(text.split())
|
||||
estimated_tokens = PromptOptimizer.estimate_tokens(text)
|
||||
|
||||
return TokenMetrics(
|
||||
char_count=char_count,
|
||||
word_count=word_count,
|
||||
estimated_tokens=estimated_tokens,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def compress_json(data: dict[str, Any]) -> str:
|
||||
"""Compress JSON by removing whitespace.
|
||||
|
||||
Args:
|
||||
data: Dictionary to serialize
|
||||
|
||||
Returns:
|
||||
Compact JSON string without whitespace
|
||||
"""
|
||||
return json.dumps(data, separators=(",", ":"), ensure_ascii=False)
|
||||
|
||||
@staticmethod
|
||||
def abbreviate_text(text: str, aggressive: bool = False) -> str:
|
||||
"""Apply abbreviations to reduce text length.
|
||||
|
||||
Args:
|
||||
text: Input text to abbreviate
|
||||
aggressive: If True, apply more aggressive compression
|
||||
|
||||
Returns:
|
||||
Abbreviated text
|
||||
"""
|
||||
result = text
|
||||
|
||||
# Apply word-level abbreviations (case-insensitive)
|
||||
for full, abbr in ABBREVIATIONS.items():
|
||||
# Word boundaries to avoid partial replacements
|
||||
pattern = r"\b" + re.escape(full) + r"\b"
|
||||
result = re.sub(pattern, abbr, result, flags=re.IGNORECASE)
|
||||
|
||||
if aggressive:
|
||||
# Remove articles and filler words
|
||||
result = re.sub(r"\b(a|an|the)\b", "", result, flags=re.IGNORECASE)
|
||||
result = re.sub(r"\b(is|are|was|were)\b", "", result, flags=re.IGNORECASE)
|
||||
# Collapse multiple spaces
|
||||
result = re.sub(r"\s+", " ", result)
|
||||
|
||||
return result.strip()
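To make the word-boundary substitution concrete, a small example using the ABBREVIATIONS table above (the sample sentence is made up):
text = "Current price and volume for the market"
assert PromptOptimizer.abbreviate_text(text) == "cur P and V for the mkt"
# aggressive=True also drops articles and collapses the leftover whitespace.
assert PromptOptimizer.abbreviate_text(text, aggressive=True) == "cur P and V for mkt"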
|
||||
|
||||
@staticmethod
|
||||
def build_compressed_prompt(
|
||||
market_data: dict[str, Any],
|
||||
include_instructions: bool = True,
|
||||
max_length: int | None = None,
|
||||
) -> str:
|
||||
"""Build a compressed prompt from market data.
|
||||
|
||||
Args:
|
||||
market_data: Market data dictionary with stock info
|
||||
include_instructions: Whether to include full instructions
|
||||
max_length: Maximum character length (truncates if needed)
|
||||
|
||||
Returns:
|
||||
Compressed prompt string
|
||||
"""
|
||||
# Abbreviated market name
|
||||
market_name = market_data.get("market_name", "KR")
|
||||
if "Korea" in market_name:
|
||||
market_name = "KR"
|
||||
elif "United States" in market_name or "US" in market_name:
|
||||
market_name = "US"
|
||||
|
||||
# Core data - always included
|
||||
core_info = {
|
||||
"mkt": market_name,
|
||||
"code": market_data["stock_code"],
|
||||
"P": market_data["current_price"],
|
||||
}
|
||||
|
||||
# Optional fields
|
||||
if "orderbook" in market_data and market_data["orderbook"]:
|
||||
ob = market_data["orderbook"]
|
||||
# Compress orderbook: keep only top 3 levels
|
||||
compressed_ob = {
|
||||
"bid": ob.get("bid", [])[:3],
|
||||
"ask": ob.get("ask", [])[:3],
|
||||
}
|
||||
core_info["ob"] = compressed_ob
|
||||
|
||||
if market_data.get("foreigner_net", 0) != 0:
|
||||
core_info["fgn_net"] = market_data["foreigner_net"]
|
||||
|
||||
# Compress to JSON
|
||||
data_str = PromptOptimizer.compress_json(core_info)
|
||||
|
||||
if include_instructions:
|
||||
# Minimal instructions
|
||||
prompt = (
|
||||
f"{market_name} trader. Analyze:\n{data_str}\n\n"
|
||||
'Return JSON: {"act":"BUY"|"SELL"|"HOLD","conf":<0-100>,"reason":"<text>"}\n'
|
||||
"Rules: act=BUY/SELL/HOLD, conf=0-100, reason=concise. No markdown."
|
||||
)
|
||||
else:
|
||||
# Data only (for cached contexts where instructions are known)
|
||||
prompt = data_str
|
||||
|
||||
# Truncate if needed
|
||||
if max_length and len(prompt) > max_length:
|
||||
prompt = prompt[:max_length] + "..."
|
||||
|
||||
return prompt
|
||||
|
||||
@staticmethod
|
||||
def truncate_context(
|
||||
context: dict[str, Any],
|
||||
max_tokens: int,
|
||||
priority_keys: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Truncate context data to fit within token budget.
|
||||
|
||||
Keeps high-priority keys first, then truncates less important data.
|
||||
|
||||
Args:
|
||||
context: Context dictionary to truncate
|
||||
max_tokens: Maximum token budget
|
||||
priority_keys: List of keys to keep (in order of priority)
|
||||
|
||||
Returns:
|
||||
Truncated context dictionary
|
||||
"""
|
||||
if not context:
|
||||
return {}
|
||||
|
||||
if priority_keys is None:
|
||||
priority_keys = []
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
current_tokens = 0
|
||||
|
||||
# Add priority keys first
|
||||
for key in priority_keys:
|
||||
if key in context:
|
||||
value_str = json.dumps(context[key])
|
||||
tokens = PromptOptimizer.estimate_tokens(value_str)
|
||||
|
||||
if current_tokens + tokens <= max_tokens:
|
||||
result[key] = context[key]
|
||||
current_tokens += tokens
|
||||
else:
|
||||
break
|
||||
|
||||
# Add remaining keys if space available
|
||||
for key, value in context.items():
|
||||
if key in result:
|
||||
continue
|
||||
|
||||
value_str = json.dumps(value)
|
||||
tokens = PromptOptimizer.estimate_tokens(value_str)
|
||||
|
||||
if current_tokens + tokens <= max_tokens:
|
||||
result[key] = value
|
||||
current_tokens += tokens
|
||||
else:
|
||||
break
|
||||
|
||||
return result
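A short sketch of the budget behaviour: priority keys are admitted first, and a bulky low-priority key that would blow the budget is dropped. The sample data and the 20-token budget are illustrative.
context = {
    "news": ["sample headline"] * 50,   # bulky, low priority
    "price": 71200,
    "orderbook": {"bid": [1, 2, 3]},
}
slim = PromptOptimizer.truncate_context(
    context, max_tokens=20, priority_keys=["price", "orderbook"]
)
assert set(slim) == {"price", "orderbook"}  # "news" no longer fits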
|
||||
|
||||
@staticmethod
|
||||
def calculate_compression_ratio(original: str, compressed: str) -> float:
|
||||
"""Calculate compression ratio between original and compressed text.
|
||||
|
||||
Args:
|
||||
original: Original text
|
||||
compressed: Compressed text
|
||||
|
||||
Returns:
|
||||
Compression ratio (original_tokens / compressed_tokens)
|
||||
"""
|
||||
original_tokens = PromptOptimizer.estimate_tokens(original)
|
||||
compressed_tokens = PromptOptimizer.estimate_tokens(compressed)
|
||||
|
||||
if compressed_tokens == 0:
|
||||
return 1.0
|
||||
|
||||
return original_tokens / compressed_tokens
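Putting compress_json and calculate_compression_ratio together, a minimal sketch of how the savings could be measured; the sample payload is illustrative.
payload = {"current_price": 71200, "orderbook": {"bid": [71100, 71000, 70900]}}
verbose = json.dumps(payload, indent=2)
compact = PromptOptimizer.compress_json(payload)
ratio = PromptOptimizer.calculate_compression_ratio(verbose, compact)
print(f"compact form uses {ratio:.1f}x fewer estimated tokens")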
|
||||
@@ -8,7 +8,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import ssl
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import aiohttp
|
||||
|
||||
@@ -20,6 +20,39 @@ _KIS_VTS_HOST = "openapivts.koreainvestment.com"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def kr_tick_unit(price: float) -> int:
|
||||
"""Return KRX tick size for the given price level.
|
||||
|
||||
KRX price tick rules (domestic stocks):
|
||||
price < 2,000             → 1 KRW
2,000 ≤ price < 5,000     → 5 KRW
5,000 ≤ price < 20,000    → 10 KRW
20,000 ≤ price < 50,000   → 50 KRW
50,000 ≤ price < 200,000  → 100 KRW
200,000 ≤ price < 500,000 → 500 KRW
500,000 ≤ price           → 1,000 KRW
|
||||
"""
|
||||
if price < 2_000:
|
||||
return 1
|
||||
if price < 5_000:
|
||||
return 5
|
||||
if price < 20_000:
|
||||
return 10
|
||||
if price < 50_000:
|
||||
return 50
|
||||
if price < 200_000:
|
||||
return 100
|
||||
if price < 500_000:
|
||||
return 500
|
||||
return 1_000
|
||||
|
||||
|
||||
def kr_round_down(price: float) -> int:
|
||||
"""Round *down* price to the nearest KRX tick unit."""
|
||||
tick = kr_tick_unit(price)
|
||||
return int(price // tick * tick)
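A few spot checks of the tick table above, with prices chosen to straddle the boundaries:
assert kr_tick_unit(1_999) == 1 and kr_tick_unit(2_000) == 5
assert kr_round_down(4_999) == 4_995      # 5 KRW tick
assert kr_round_down(50_123) == 50_100    # 100 KRW tick
assert kr_round_down(512_345) == 512_000  # 1,000 KRW tick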
|
||||
|
||||
|
||||
class LeakyBucket:
|
||||
"""Simple leaky-bucket rate limiter for async code."""
|
||||
|
||||
@@ -55,6 +88,9 @@ class KISBroker:
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._access_token: str | None = None
|
||||
self._token_expires_at: float = 0.0
|
||||
self._token_lock = asyncio.Lock()
|
||||
self._last_refresh_attempt: float = 0.0
|
||||
self._refresh_cooldown: float = 60.0 # Seconds (matches KIS 1/minute limit)
|
||||
self._rate_limiter = LeakyBucket(settings.RATE_LIMIT_RPS)
|
||||
|
||||
def _get_session(self) -> aiohttp.ClientSession:
|
||||
@@ -80,12 +116,38 @@ class KISBroker:
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _ensure_token(self) -> str:
|
||||
"""Return a valid access token, refreshing if expired."""
|
||||
"""Return a valid access token, refreshing if expired.
|
||||
|
||||
Uses a lock to prevent concurrent token refresh attempts that would
|
||||
hit the API's 1-per-minute rate limit (EGW00133).
|
||||
"""
|
||||
# Fast path: check without lock
|
||||
now = asyncio.get_event_loop().time()
|
||||
if self._access_token and now < self._token_expires_at:
|
||||
return self._access_token
|
||||
|
||||
# Slow path: acquire lock and refresh
|
||||
async with self._token_lock:
|
||||
# Re-check after acquiring lock (another coroutine may have refreshed)
|
||||
now = asyncio.get_event_loop().time()
|
||||
if self._access_token and now < self._token_expires_at:
|
||||
return self._access_token
|
||||
|
||||
# Check cooldown period (prevents hitting EGW00133: 1/minute limit)
|
||||
time_since_last_attempt = now - self._last_refresh_attempt
|
||||
if time_since_last_attempt < self._refresh_cooldown:
|
||||
remaining = self._refresh_cooldown - time_since_last_attempt
|
||||
# Do not fail fast here. If token is unavailable, upstream calls
|
||||
# will all fail for up to a minute and scanning returns no trades.
|
||||
logger.warning(
|
||||
"Token refresh on cooldown. Waiting %.1fs before retry (KIS allows 1/minute)",
|
||||
remaining,
|
||||
)
|
||||
await asyncio.sleep(remaining)
|
||||
now = asyncio.get_event_loop().time()
|
||||
|
||||
logger.info("Refreshing KIS access token")
|
||||
self._last_refresh_attempt = now
|
||||
session = self._get_session()
|
||||
url = f"{self._base_url}/oauth2/tokenP"
|
||||
body = {
|
||||
@@ -111,6 +173,7 @@ class KISBroker:
|
||||
|
||||
async def _get_hash_key(self, body: dict[str, Any]) -> str:
|
||||
"""Request a hash key from KIS for POST request body signing."""
|
||||
await self._rate_limiter.acquire()
|
||||
session = self._get_session()
|
||||
url = f"{self._base_url}/uapi/hashkey"
|
||||
headers = {
|
||||
@@ -168,12 +231,64 @@ class KISBroker:
|
||||
except (TimeoutError, aiohttp.ClientError) as exc:
|
||||
raise ConnectionError(f"Network error fetching orderbook: {exc}") from exc
|
||||
|
||||
async def get_current_price(
|
||||
self, stock_code: str
|
||||
) -> tuple[float, float, float]:
|
||||
"""Fetch current price data for a domestic stock.
|
||||
|
||||
Uses the ``inquire-price`` API (FHKST01010100), which works in both
|
||||
real and VTS environments and returns the actual last-traded price.
|
||||
|
||||
Returns:
|
||||
(current_price, prdy_ctrt, frgn_ntby_qty)
|
||||
- current_price: Last traded price in KRW.
|
||||
- prdy_ctrt: Day change rate (%).
|
||||
- frgn_ntby_qty: Foreigner net buy quantity.
|
||||
"""
|
||||
await self._rate_limiter.acquire()
|
||||
session = self._get_session()
|
||||
|
||||
headers = await self._auth_headers("FHKST01010100")
|
||||
params = {
|
||||
"FID_COND_MRKT_DIV_CODE": "J",
|
||||
"FID_INPUT_ISCD": stock_code,
|
||||
}
|
||||
url = f"{self._base_url}/uapi/domestic-stock/v1/quotations/inquire-price"
|
||||
|
||||
def _f(val: str | None) -> float:
|
||||
try:
|
||||
return float(val or "0")
|
||||
except ValueError:
|
||||
return 0.0
|
||||
|
||||
try:
|
||||
async with session.get(url, headers=headers, params=params) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ConnectionError(
|
||||
f"get_current_price failed ({resp.status}): {text}"
|
||||
)
|
||||
data = await resp.json()
|
||||
out = data.get("output", {})
|
||||
return (
|
||||
_f(out.get("stck_prpr")),
|
||||
_f(out.get("prdy_ctrt")),
|
||||
_f(out.get("frgn_ntby_qty")),
|
||||
)
|
||||
except (TimeoutError, aiohttp.ClientError) as exc:
|
||||
raise ConnectionError(
|
||||
f"Network error fetching current price: {exc}"
|
||||
) from exc
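For reference, a minimal sketch of feeding the returned tuple into the market_data dict consumed by the decision layer; the broker argument and the sample code "005930" are assumptions.
async def sample_quote(broker: KISBroker) -> dict[str, float | str]:
    price, day_change_pct, foreigner_net = await broker.get_current_price("005930")
    return {
        "stock_code": "005930",
        "current_price": price,
        "foreigner_net": foreigner_net,
    }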
|
||||
|
||||
async def get_balance(self) -> dict[str, Any]:
|
||||
"""Fetch current account balance and holdings."""
|
||||
await self._rate_limiter.acquire()
|
||||
session = self._get_session()
|
||||
|
||||
headers = await self._auth_headers("VTTC8434R")  # paper-trading balance inquiry
|
||||
# TR_ID: live TTTC8434R, paper VTTC8434R
# Source: Korea Investment & Securities Open API full documentation (20260221), 'Domestic stock balance inquiry' sheet
|
||||
tr_id = "TTTC8434R" if self._settings.MODE == "live" else "VTTC8434R"
|
||||
headers = await self._auth_headers(tr_id)
|
||||
params = {
|
||||
"CANO": self._account_no,
|
||||
"ACNT_PRDT_CD": self._product_cd,
|
||||
@@ -218,14 +333,30 @@ class KISBroker:
|
||||
await self._rate_limiter.acquire()
|
||||
session = self._get_session()
|
||||
|
||||
tr_id = "VTTC0802U" if order_type == "BUY" else "VTTC0801U"
|
||||
# TR_ID: live BUY=TTTC0012U SELL=TTTC0011U, paper BUY=VTTC0012U SELL=VTTC0011U
# Source: Korea Investment & Securities Open API full documentation (20260221), 'Stock order (cash)' sheet
# Note: TTTC0802U/VTTC0802U are margin-purchase TR_IDs (40%-margin accounts only); never use them for cash orders
|
||||
if self._settings.MODE == "live":
|
||||
tr_id = "TTTC0012U" if order_type == "BUY" else "TTTC0011U"
|
||||
else:
|
||||
tr_id = "VTTC0012U" if order_type == "BUY" else "VTTC0011U"
|
||||
|
||||
# KRX requires limit orders to be rounded down to the tick unit.
|
||||
# ORD_DVSN: "00"=limit order, "01"=market order
|
||||
if price > 0:
|
||||
ord_dvsn = "00" # 지정가
|
||||
ord_price = kr_round_down(price)
|
||||
else:
|
||||
ord_dvsn = "01" # 시장가
|
||||
ord_price = 0
|
||||
|
||||
body = {
|
||||
"CANO": self._account_no,
|
||||
"ACNT_PRDT_CD": self._product_cd,
|
||||
"PDNO": stock_code,
|
||||
"ORD_DVSN": "01" if price > 0 else "06", # 01=지정가, 06=시장가
|
||||
"ORD_DVSN": ord_dvsn,
|
||||
"ORD_QTY": str(quantity),
|
||||
"ORD_UNPR": str(price),
|
||||
"ORD_UNPR": str(ord_price),
|
||||
}
|
||||
|
||||
hash_key = await self._get_hash_key(body)
|
||||
@@ -252,3 +383,279 @@ class KISBroker:
|
||||
return data
|
||||
except (TimeoutError, aiohttp.ClientError) as exc:
|
||||
raise ConnectionError(f"Network error sending order: {exc}") from exc
|
||||
|
||||
async def fetch_market_rankings(
|
||||
self,
|
||||
ranking_type: str = "volume",
|
||||
limit: int = 30,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch market rankings from KIS API.
|
||||
|
||||
Args:
|
||||
ranking_type: Type of ranking ("volume" or "fluctuation")
|
||||
limit: Maximum number of results to return
|
||||
|
||||
Returns:
|
||||
List of stock data dicts with keys: stock_code, name, price, volume,
|
||||
change_rate, volume_increase_rate
|
||||
|
||||
Raises:
|
||||
ConnectionError: If API request fails
|
||||
"""
|
||||
await self._rate_limiter.acquire()
|
||||
session = self._get_session()
|
||||
|
||||
if ranking_type == "volume":
|
||||
# Volume ranking: FHPST01710000 / /quotations/volume-rank
|
||||
tr_id = "FHPST01710000"
|
||||
url = f"{self._base_url}/uapi/domestic-stock/v1/quotations/volume-rank"
|
||||
params: dict[str, str] = {
|
||||
"FID_COND_MRKT_DIV_CODE": "J",
|
||||
"FID_COND_SCR_DIV_CODE": "20171",
|
||||
"FID_INPUT_ISCD": "0000",
|
||||
"FID_DIV_CLS_CODE": "0",
|
||||
"FID_BLNG_CLS_CODE": "0",
|
||||
"FID_TRGT_CLS_CODE": "111111111",
|
||||
"FID_TRGT_EXLS_CLS_CODE": "0000000000",
|
||||
"FID_INPUT_PRICE_1": "0",
|
||||
"FID_INPUT_PRICE_2": "0",
|
||||
"FID_VOL_CNT": "0",
|
||||
"FID_INPUT_DATE_1": "",
|
||||
}
|
||||
else:
|
||||
# Fluctuation-rate ranking: FHPST01700000 / /ranking/fluctuation (lowercase parameters)
|
||||
tr_id = "FHPST01700000"
|
||||
url = f"{self._base_url}/uapi/domestic-stock/v1/ranking/fluctuation"
|
||||
params = {
|
||||
"fid_cond_mrkt_div_code": "J",
|
||||
"fid_cond_scr_div_code": "20170",
|
||||
"fid_input_iscd": "0000",
|
||||
"fid_rank_sort_cls_code": "0000",
|
||||
"fid_input_cnt_1": str(limit),
|
||||
"fid_prc_cls_code": "0",
|
||||
"fid_input_price_1": "0",
|
||||
"fid_input_price_2": "0",
|
||||
"fid_vol_cnt": "0",
|
||||
"fid_trgt_cls_code": "0",
|
||||
"fid_trgt_exls_cls_code": "0",
|
||||
"fid_div_cls_code": "0",
|
||||
"fid_rsfl_rate1": "0",
|
||||
"fid_rsfl_rate2": "0",
|
||||
}
|
||||
|
||||
headers = await self._auth_headers(tr_id)
|
||||
|
||||
try:
|
||||
async with session.get(url, headers=headers, params=params) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ConnectionError(
|
||||
f"fetch_market_rankings failed ({resp.status}): {text}"
|
||||
)
|
||||
data = await resp.json()
|
||||
|
||||
# Parse response - output is a list of ranked stocks
|
||||
def _safe_float(value: str | float | None, default: float = 0.0) -> float:
|
||||
if value is None or value == "":
|
||||
return default
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
return default
|
||||
|
||||
rankings = []
|
||||
for item in data.get("output", [])[:limit]:
|
||||
rankings.append({
|
||||
"stock_code": item.get("mksc_shrn_iscd", ""),
|
||||
"name": item.get("hts_kor_isnm", ""),
|
||||
"price": _safe_float(item.get("stck_prpr", "0")),
|
||||
"volume": _safe_float(item.get("acml_vol", "0")),
|
||||
"change_rate": _safe_float(item.get("prdy_ctrt", "0")),
|
||||
"volume_increase_rate": _safe_float(item.get("vol_inrt", "0")),
|
||||
})
|
||||
return rankings
|
||||
|
||||
except (TimeoutError, aiohttp.ClientError) as exc:
|
||||
raise ConnectionError(f"Network error fetching rankings: {exc}") from exc
|
||||
|
||||
async def get_domestic_pending_orders(self) -> list[dict[str, Any]]:
|
||||
"""Fetch unfilled (pending) domestic limit orders.
|
||||
|
||||
The KIS pending-orders API (TTTC0084R) is unsupported in paper (VTS)
|
||||
mode, so this method returns an empty list immediately when MODE is
|
||||
not "live".
|
||||
|
||||
Returns:
|
||||
List of pending order dicts from the KIS ``output`` field.
|
||||
Each dict includes keys such as ``odno``, ``orgn_odno``,
|
||||
``ord_gno_brno``, ``psbl_qty``, ``sll_buy_dvsn_cd``, ``pdno``.
|
||||
"""
|
||||
if self._settings.MODE != "live":
|
||||
logger.debug(
|
||||
"get_domestic_pending_orders: paper mode — TTTC0084R unsupported, returning []"
|
||||
)
|
||||
return []
|
||||
|
||||
await self._rate_limiter.acquire()
|
||||
session = self._get_session()
|
||||
|
||||
# TR_ID: live TTTC0084R (not supported in paper mode)
# Source: Korea Investment & Securities Open API full documentation (20260221), 'Stock pending-order inquiry' sheet
|
||||
headers = await self._auth_headers("TTTC0084R")
|
||||
params = {
|
||||
"CANO": self._account_no,
|
||||
"ACNT_PRDT_CD": self._product_cd,
|
||||
"INQR_DVSN_1": "0",
|
||||
"INQR_DVSN_2": "0",
|
||||
"CTX_AREA_FK100": "",
|
||||
"CTX_AREA_NK100": "",
|
||||
}
|
||||
url = f"{self._base_url}/uapi/domestic-stock/v1/trading/inquire-psbl-rvsecncl"
|
||||
|
||||
try:
|
||||
async with session.get(url, headers=headers, params=params) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ConnectionError(
|
||||
f"get_domestic_pending_orders failed ({resp.status}): {text}"
|
||||
)
|
||||
data = await resp.json()
|
||||
return data.get("output", []) or []
|
||||
except (TimeoutError, aiohttp.ClientError) as exc:
|
||||
raise ConnectionError(
|
||||
f"Network error fetching domestic pending orders: {exc}"
|
||||
) from exc
|
||||
|
||||
async def cancel_domestic_order(
|
||||
self,
|
||||
stock_code: str,
|
||||
orgn_odno: str,
|
||||
krx_fwdg_ord_orgno: str,
|
||||
qty: int,
|
||||
) -> dict[str, Any]:
|
||||
"""Cancel an unfilled domestic limit order.
|
||||
|
||||
Args:
|
||||
stock_code: 6-digit domestic stock code (``pdno``).
|
||||
orgn_odno: Original order number from pending-orders response
|
||||
(``orgn_odno`` field).
|
||||
krx_fwdg_ord_orgno: KRX forwarding order branch number from
|
||||
pending-orders response (``ord_gno_brno`` field).
|
||||
qty: Quantity to cancel (use ``psbl_qty`` from pending order).
|
||||
|
||||
Returns:
|
||||
Raw KIS API response dict (check ``rt_cd == "0"`` for success).
|
||||
"""
|
||||
await self._rate_limiter.acquire()
|
||||
session = self._get_session()
|
||||
|
||||
# TR_ID: live TTTC0013U, paper VTTC0013U
# Source: Korea Investment & Securities Open API full documentation (20260221), 'Stock order (revise/cancel)' sheet
|
||||
tr_id = "TTTC0013U" if self._settings.MODE == "live" else "VTTC0013U"
|
||||
|
||||
body = {
|
||||
"CANO": self._account_no,
|
||||
"ACNT_PRDT_CD": self._product_cd,
|
||||
"KRX_FWDG_ORD_ORGNO": krx_fwdg_ord_orgno,
|
||||
"ORGN_ODNO": orgn_odno,
|
||||
"ORD_DVSN": "00",
|
||||
"ORD_QTY": str(qty),
|
||||
"ORD_UNPR": "0",
|
||||
"RVSE_CNCL_DVSN_CD": "02",
|
||||
"QTY_ALL_ORD_YN": "Y",
|
||||
}
|
||||
|
||||
hash_key = await self._get_hash_key(body)
|
||||
headers = await self._auth_headers(tr_id)
|
||||
headers["hashkey"] = hash_key
|
||||
|
||||
url = f"{self._base_url}/uapi/domestic-stock/v1/trading/order-rvsecncl"
|
||||
|
||||
try:
|
||||
async with session.post(url, headers=headers, json=body) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ConnectionError(
|
||||
f"cancel_domestic_order failed ({resp.status}): {text}"
|
||||
)
|
||||
return cast(dict[str, Any], await resp.json())
|
||||
except (TimeoutError, aiohttp.ClientError) as exc:
|
||||
raise ConnectionError(
|
||||
f"Network error cancelling domestic order: {exc}"
|
||||
) from exc
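A hypothetical helper showing how the two endpoints above are meant to compose; the field names follow the pending-orders docstring, and the helper itself is not part of this change.
async def cancel_all_domestic_pending(broker: KISBroker) -> None:
    # Cancel every unfilled domestic limit order returned by the inquiry API.
    for order in await broker.get_domestic_pending_orders():
        await broker.cancel_domestic_order(
            stock_code=order["pdno"],
            orgn_odno=order["orgn_odno"],
            krx_fwdg_ord_orgno=order["ord_gno_brno"],
            qty=int(order["psbl_qty"]),
        )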
|
||||
|
||||
async def get_daily_prices(
|
||||
self,
|
||||
stock_code: str,
|
||||
days: int = 20,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch daily OHLCV price history for a stock.
|
||||
|
||||
Args:
|
||||
stock_code: 6-digit stock code
|
||||
days: Number of trading days to fetch (default 20 for RSI calculation)
|
||||
|
||||
Returns:
|
||||
List of daily price dicts with keys: date, open, high, low, close, volume
|
||||
Sorted oldest to newest
|
||||
|
||||
Raises:
|
||||
ConnectionError: If API request fails
|
||||
"""
|
||||
await self._rate_limiter.acquire()
|
||||
session = self._get_session()
|
||||
|
||||
headers = await self._auth_headers("FHKST03010100")
|
||||
|
||||
# Calculate date range (today and N days ago)
|
||||
from datetime import datetime, timedelta
|
||||
end_date = datetime.now().strftime("%Y%m%d")
|
||||
start_date = (datetime.now() - timedelta(days=days + 10)).strftime("%Y%m%d")
|
||||
|
||||
params = {
|
||||
"FID_COND_MRKT_DIV_CODE": "J",
|
||||
"FID_INPUT_ISCD": stock_code,
|
||||
"FID_INPUT_DATE_1": start_date,
|
||||
"FID_INPUT_DATE_2": end_date,
|
||||
"FID_PERIOD_DIV_CODE": "D", # Daily
|
||||
"FID_ORG_ADJ_PRC": "0", # Adjusted price
|
||||
}
|
||||
|
||||
url = f"{self._base_url}/uapi/domestic-stock/v1/quotations/inquire-daily-itemchartprice"
|
||||
|
||||
try:
|
||||
async with session.get(url, headers=headers, params=params) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ConnectionError(
|
||||
f"get_daily_prices failed ({resp.status}): {text}"
|
||||
)
|
||||
data = await resp.json()
|
||||
|
||||
# Parse response
|
||||
def _safe_float(value: str | float | None, default: float = 0.0) -> float:
|
||||
if value is None or value == "":
|
||||
return default
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
return default
|
||||
|
||||
prices = []
|
||||
for item in data.get("output2", []):
|
||||
prices.append({
|
||||
"date": item.get("stck_bsop_date", ""),
|
||||
"open": _safe_float(item.get("stck_oprc", "0")),
|
||||
"high": _safe_float(item.get("stck_hgpr", "0")),
|
||||
"low": _safe_float(item.get("stck_lwpr", "0")),
|
||||
"close": _safe_float(item.get("stck_clpr", "0")),
|
||||
"volume": _safe_float(item.get("acml_vol", "0")),
|
||||
})
|
||||
|
||||
# Sort oldest to newest (KIS returns newest first)
|
||||
prices.reverse()
|
||||
|
||||
return prices[:days] # Return only requested number of days
|
||||
|
||||
except (TimeoutError, aiohttp.ClientError) as exc:
|
||||
raise ConnectionError(f"Network error fetching daily prices: {exc}") from exc
|
||||
|
||||
@@ -12,6 +12,38 @@ from src.broker.kis_api import KISBroker
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Ranking API uses different exchange codes than order/quote APIs.
|
||||
_RANKING_EXCHANGE_MAP: dict[str, str] = {
|
||||
"NASD": "NAS",
|
||||
"NYSE": "NYS",
|
||||
"AMEX": "AMS",
|
||||
"SEHK": "HKS",
|
||||
"SHAA": "SHS",
|
||||
"SZAA": "SZS",
|
||||
"HSX": "HSX",
|
||||
"HNX": "HNX",
|
||||
"TSE": "TSE",
|
||||
}
|
||||
|
||||
# Price inquiry API (HHDFS00000300) uses the same short exchange codes as rankings.
|
||||
# NASD → NAS, NYSE → NYS, AMEX → AMS (confirmed: AMEX returns empty, AMS returns price).
|
||||
_PRICE_EXCHANGE_MAP: dict[str, str] = _RANKING_EXCHANGE_MAP
|
||||
|
||||
# Cancel order TR_IDs per exchange code — (live_tr_id, paper_tr_id).
|
||||
# Source: Korea Investment & Securities Open API full documentation (20260221), 'Overseas stock order cancellation' sheet
|
||||
_CANCEL_TR_ID_MAP: dict[str, tuple[str, str]] = {
|
||||
"NASD": ("TTTT1004U", "VTTT1004U"),
|
||||
"NYSE": ("TTTT1004U", "VTTT1004U"),
|
||||
"AMEX": ("TTTT1004U", "VTTT1004U"),
|
||||
"SEHK": ("TTTS1003U", "VTTS1003U"),
|
||||
"TSE": ("TTTS0309U", "VTTS0309U"),
|
||||
"SHAA": ("TTTS0302U", "VTTS0302U"),
|
||||
"SZAA": ("TTTS0306U", "VTTS0306U"),
|
||||
"HNX": ("TTTS0312U", "VTTS0312U"),
|
||||
"HSX": ("TTTS0312U", "VTTS0312U"),
|
||||
}
|
||||
|
||||
|
||||
class OverseasBroker:
|
||||
"""KIS Overseas Stock API wrapper that reuses KISBroker infrastructure."""
|
||||
|
||||
@@ -44,9 +76,11 @@ class OverseasBroker:
|
||||
session = self._broker._get_session()
|
||||
|
||||
headers = await self._broker._auth_headers("HHDFS00000300")
|
||||
# Map internal exchange codes to the short form expected by the price API.
|
||||
price_excd = _PRICE_EXCHANGE_MAP.get(exchange_code, exchange_code)
|
||||
params = {
|
||||
"AUTH": "",
|
||||
"EXCD": exchange_code,
|
||||
"EXCD": price_excd,
|
||||
"SYMB": stock_code,
|
||||
}
|
||||
url = f"{self._broker._base_url}/uapi/overseas-price/v1/quotations/price"
|
||||
@@ -64,6 +98,81 @@ class OverseasBroker:
|
||||
f"Network error fetching overseas price: {exc}"
|
||||
) from exc
|
||||
|
||||
async def fetch_overseas_rankings(
|
||||
self,
|
||||
exchange_code: str,
|
||||
ranking_type: str = "fluctuation",
|
||||
limit: int = 30,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch overseas rankings (price change or volume surge).
|
||||
|
||||
Ranking API specs may differ by account/product. Endpoint paths and
|
||||
TR_IDs are configurable via settings and can be overridden in .env.
|
||||
"""
|
||||
if not self._broker._settings.OVERSEAS_RANKING_ENABLED:
|
||||
return []
|
||||
|
||||
await self._broker._rate_limiter.acquire()
|
||||
session = self._broker._get_session()
|
||||
|
||||
ranking_excd = _RANKING_EXCHANGE_MAP.get(exchange_code, exchange_code)
|
||||
|
||||
if ranking_type == "volume":
|
||||
tr_id = self._broker._settings.OVERSEAS_RANKING_VOLUME_TR_ID
|
||||
path = self._broker._settings.OVERSEAS_RANKING_VOLUME_PATH
|
||||
params: dict[str, str] = {
|
||||
"AUTH": "",
|
||||
"EXCD": ranking_excd,
|
||||
"MIXN": "0",
|
||||
"VOL_RANG": "0",
|
||||
}
|
||||
else:
|
||||
tr_id = self._broker._settings.OVERSEAS_RANKING_FLUCT_TR_ID
|
||||
path = self._broker._settings.OVERSEAS_RANKING_FLUCT_PATH
|
||||
params = {
|
||||
"AUTH": "",
|
||||
"EXCD": ranking_excd,
|
||||
"NDAY": "0",
|
||||
"GUBN": "1",
|
||||
"VOL_RANG": "0",
|
||||
}
|
||||
|
||||
headers = await self._broker._auth_headers(tr_id)
|
||||
url = f"{self._broker._base_url}{path}"
|
||||
|
||||
try:
|
||||
async with session.get(url, headers=headers, params=params) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
if resp.status == 404:
|
||||
logger.warning(
|
||||
"Overseas ranking endpoint unavailable (404) for %s/%s; "
|
||||
"using symbol fallback scan",
|
||||
exchange_code,
|
||||
ranking_type,
|
||||
)
|
||||
return []
|
||||
raise ConnectionError(
|
||||
f"fetch_overseas_rankings failed ({resp.status}): {text}"
|
||||
)
|
||||
|
||||
data = await resp.json()
|
||||
rows = self._extract_ranking_rows(data)
|
||||
if rows:
|
||||
return rows[:limit]
|
||||
|
||||
logger.debug(
|
||||
"Overseas ranking returned empty for %s/%s (keys=%s)",
|
||||
exchange_code,
|
||||
ranking_type,
|
||||
list(data.keys()),
|
||||
)
|
||||
return []
|
||||
except (TimeoutError, aiohttp.ClientError) as exc:
|
||||
raise ConnectionError(
|
||||
f"Network error fetching overseas rankings: {exc}"
|
||||
) from exc
|
||||
|
||||
async def get_overseas_balance(self, exchange_code: str) -> dict[str, Any]:
|
||||
"""
|
||||
Fetch overseas account balance.
|
||||
@@ -80,8 +189,12 @@ class OverseasBroker:
|
||||
await self._broker._rate_limiter.acquire()
|
||||
session = self._broker._get_session()
|
||||
|
||||
# Virtual trading TR_ID for overseas balance inquiry
|
||||
headers = await self._broker._auth_headers("VTTS3012R")
|
||||
# TR_ID: live TTTS3012R, paper VTTS3012R
# Source: Korea Investment & Securities Open API full documentation (20260221), 'Overseas stock balance inquiry' sheet
|
||||
balance_tr_id = (
|
||||
"TTTS3012R" if self._broker._settings.MODE == "live" else "VTTS3012R"
|
||||
)
|
||||
headers = await self._broker._auth_headers(balance_tr_id)
|
||||
params = {
|
||||
"CANO": self._broker._account_no,
|
||||
"ACNT_PRDT_CD": self._broker._product_cd,
|
||||
@@ -134,8 +247,12 @@ class OverseasBroker:
|
||||
await self._broker._rate_limiter.acquire()
|
||||
session = self._broker._get_session()
|
||||
|
||||
# Virtual trading TR_IDs for overseas orders
|
||||
tr_id = "VTTT1002U" if order_type == "BUY" else "VTTT1006U"
|
||||
# TR_ID: live BUY=TTTT1002U SELL=TTTT1006U, paper BUY=VTTT1002U SELL=VTTT1001U
# Source: Korea Investment & Securities Open API full documentation (20260221), 'Overseas stock order' sheet
|
||||
if self._broker._settings.MODE == "live":
|
||||
tr_id = "TTTT1002U" if order_type == "BUY" else "TTTT1006U"
|
||||
else:
|
||||
tr_id = "VTTT1002U" if order_type == "BUY" else "VTTT1001U"
|
||||
|
||||
body = {
|
||||
"CANO": self._broker._account_no,
|
||||
@@ -162,6 +279,9 @@ class OverseasBroker:
|
||||
f"send_overseas_order failed ({resp.status}): {text}"
|
||||
)
|
||||
data = await resp.json()
|
||||
rt_cd = data.get("rt_cd", "")
|
||||
msg1 = data.get("msg1", "")
|
||||
if rt_cd == "0":
|
||||
logger.info(
|
||||
"Overseas order submitted",
|
||||
extra={
|
||||
@@ -170,12 +290,147 @@ class OverseasBroker:
|
||||
"action": order_type,
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Overseas order rejected (rt_cd=%s): %s [%s %s %s qty=%d]",
|
||||
rt_cd,
|
||||
msg1,
|
||||
order_type,
|
||||
stock_code,
|
||||
exchange_code,
|
||||
quantity,
|
||||
)
|
||||
return data
|
||||
except (TimeoutError, aiohttp.ClientError) as exc:
|
||||
raise ConnectionError(
|
||||
f"Network error sending overseas order: {exc}"
|
||||
) from exc
|
||||
|
||||
async def get_overseas_pending_orders(
|
||||
self, exchange_code: str
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch unfilled (pending) overseas orders for a given exchange.
|
||||
|
||||
Args:
|
||||
exchange_code: Exchange code (e.g., "NASD", "SEHK").
|
||||
For US markets, NASD returns all US pending orders (NASD/NYSE/AMEX).
|
||||
|
||||
Returns:
|
||||
List of pending order dicts with fields: odno, pdno, sll_buy_dvsn_cd,
|
||||
ft_ord_qty, nccs_qty, ft_ord_unpr3, ovrs_excg_cd.
|
||||
Always returns [] in paper mode (TTTS3018R is live-only).
|
||||
|
||||
Raises:
|
||||
ConnectionError: On network or API errors (live mode only).
|
||||
"""
|
||||
if self._broker._settings.MODE != "live":
|
||||
logger.debug(
|
||||
"Pending orders API (TTTS3018R) not supported in paper mode; returning []"
|
||||
)
|
||||
return []
|
||||
|
||||
await self._broker._rate_limiter.acquire()
|
||||
session = self._broker._get_session()
|
||||
|
||||
# TTTS3018R: overseas stock unfilled-order inquiry (live mode only)
# Source: Korea Investment & Securities Open API full documentation (20260221), 'Overseas stock pending-order inquiry' sheet
|
||||
headers = await self._broker._auth_headers("TTTS3018R")
|
||||
params = {
|
||||
"CANO": self._broker._account_no,
|
||||
"ACNT_PRDT_CD": self._broker._product_cd,
|
||||
"OVRS_EXCG_CD": exchange_code,
|
||||
"SORT_SQN": "DS",
|
||||
"CTX_AREA_FK200": "",
|
||||
"CTX_AREA_NK200": "",
|
||||
}
|
||||
url = (
|
||||
f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/inquire-nccs"
|
||||
)
|
||||
|
||||
try:
|
||||
async with session.get(url, headers=headers, params=params) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ConnectionError(
|
||||
f"get_overseas_pending_orders failed ({resp.status}): {text}"
|
||||
)
|
||||
data = await resp.json()
|
||||
output = data.get("output", [])
|
||||
if isinstance(output, list):
|
||||
return output
|
||||
return []
|
||||
except (TimeoutError, aiohttp.ClientError) as exc:
|
||||
raise ConnectionError(
|
||||
f"Network error fetching pending orders: {exc}"
|
||||
) from exc
|
||||
|
||||
async def cancel_overseas_order(
|
||||
self,
|
||||
exchange_code: str,
|
||||
stock_code: str,
|
||||
odno: str,
|
||||
qty: int,
|
||||
) -> dict[str, Any]:
|
||||
"""Cancel an overseas limit order.
|
||||
|
||||
Args:
|
||||
exchange_code: Exchange code (e.g., "NASD", "SEHK").
|
||||
stock_code: Stock ticker symbol.
|
||||
odno: Original order number to cancel.
|
||||
qty: Unfilled quantity to cancel.
|
||||
|
||||
Returns:
|
||||
API response dict containing rt_cd and msg1.
|
||||
|
||||
Raises:
|
||||
ValueError: If exchange_code has no cancel TR_ID mapping.
|
||||
ConnectionError: On network or API errors.
|
||||
"""
|
||||
tr_ids = _CANCEL_TR_ID_MAP.get(exchange_code)
|
||||
if tr_ids is None:
|
||||
raise ValueError(f"No cancel TR_ID mapping for exchange: {exchange_code}")
|
||||
live_tr_id, paper_tr_id = tr_ids
|
||||
tr_id = live_tr_id if self._broker._settings.MODE == "live" else paper_tr_id
|
||||
|
||||
await self._broker._rate_limiter.acquire()
|
||||
session = self._broker._get_session()
|
||||
|
||||
# RVSE_CNCL_DVSN_CD="02" means cancel (not revision).
|
||||
# OVRS_ORD_UNPR must be "0" for cancellations.
|
||||
# Source: Korea Investment & Securities Open API full documentation (20260221), 'Overseas stock revise/cancel order' sheet
|
||||
body = {
|
||||
"CANO": self._broker._account_no,
|
||||
"ACNT_PRDT_CD": self._broker._product_cd,
|
||||
"OVRS_EXCG_CD": exchange_code,
|
||||
"PDNO": stock_code,
|
||||
"ORGN_ODNO": odno,
|
||||
"RVSE_CNCL_DVSN_CD": "02",
|
||||
"ORD_QTY": str(qty),
|
||||
"OVRS_ORD_UNPR": "0",
|
||||
"ORD_SVR_DVSN_CD": "0",
|
||||
}
|
||||
|
||||
hash_key = await self._broker._get_hash_key(body)
|
||||
headers = await self._broker._auth_headers(tr_id)
|
||||
headers["hashkey"] = hash_key
|
||||
|
||||
url = (
|
||||
f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/order-rvsecncl"
|
||||
)
|
||||
|
||||
try:
|
||||
async with session.post(url, headers=headers, json=body) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ConnectionError(
|
||||
f"cancel_overseas_order failed ({resp.status}): {text}"
|
||||
)
|
||||
return await resp.json()
|
||||
except (TimeoutError, aiohttp.ClientError) as exc:
|
||||
raise ConnectionError(
|
||||
f"Network error cancelling overseas order: {exc}"
|
||||
) from exc
|
||||
|
||||
def _get_currency_code(self, exchange_code: str) -> str:
|
||||
"""
|
||||
Map exchange code to currency code.
|
||||
@@ -198,3 +453,11 @@ class OverseasBroker:
|
||||
"HSX": "VND",
|
||||
}
|
||||
return currency_map.get(exchange_code, "USD")
|
||||
|
||||
def _extract_ranking_rows(self, data: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""Extract list rows from ranking response across schema variants."""
|
||||
candidates = [data.get("output"), data.get("output1"), data.get("output2")]
|
||||
for value in candidates:
|
||||
if isinstance(value, list):
|
||||
return [row for row in value if isinstance(row, dict)]
|
||||
return []
|
||||
|
||||
@@ -13,28 +13,112 @@ class Settings(BaseSettings):
|
||||
KIS_APP_KEY: str
|
||||
KIS_APP_SECRET: str
|
||||
KIS_ACCOUNT_NO: str # format: "XXXXXXXX-XX"
|
||||
KIS_BASE_URL: str = "https://openapivts.koreainvestment.com:9443"
|
||||
KIS_BASE_URL: str = "https://openapivts.koreainvestment.com:29443"
|
||||
|
||||
# Google Gemini
|
||||
GEMINI_API_KEY: str
|
||||
GEMINI_MODEL: str = "gemini-pro"
|
||||
GEMINI_MODEL: str = "gemini-2.0-flash"
|
||||
|
||||
# External Data APIs (optional — for data-driven decisions)
|
||||
NEWS_API_KEY: str | None = None
|
||||
NEWS_API_PROVIDER: str = "alphavantage" # "alphavantage" or "newsapi"
|
||||
MARKET_DATA_API_KEY: str | None = None
|
||||
|
||||
# Legacy field names (for backward compatibility)
|
||||
ALPHA_VANTAGE_API_KEY: str | None = None
|
||||
NEWSAPI_KEY: str | None = None
|
||||
|
||||
# Risk Management
|
||||
CIRCUIT_BREAKER_PCT: float = Field(default=-3.0, le=0.0)
|
||||
FAT_FINGER_PCT: float = Field(default=30.0, gt=0.0, le=100.0)
|
||||
CONFIDENCE_THRESHOLD: int = Field(default=80, ge=0, le=100)
|
||||
|
||||
# Smart Scanner Configuration
|
||||
RSI_OVERSOLD_THRESHOLD: int = Field(default=30, ge=0, le=50)
|
||||
RSI_MOMENTUM_THRESHOLD: int = Field(default=70, ge=50, le=100)
|
||||
VOL_MULTIPLIER: float = Field(default=2.0, gt=1.0, le=10.0)
|
||||
SCANNER_TOP_N: int = Field(default=3, ge=1, le=10)
|
||||
POSITION_SIZING_ENABLED: bool = True
|
||||
POSITION_BASE_ALLOCATION_PCT: float = Field(default=5.0, gt=0.0, le=30.0)
|
||||
POSITION_MIN_ALLOCATION_PCT: float = Field(default=1.0, gt=0.0, le=20.0)
|
||||
POSITION_MAX_ALLOCATION_PCT: float = Field(default=10.0, gt=0.0, le=50.0)
|
||||
POSITION_VOLATILITY_TARGET_SCORE: float = Field(default=50.0, gt=0.0, le=100.0)
|
||||
|
||||
# Database
|
||||
DB_PATH: str = "data/trade_logs.db"
|
||||
|
||||
# Rate Limiting (requests per second for KIS API)
|
||||
RATE_LIMIT_RPS: float = 10.0
|
||||
# Conservative limit to avoid EGW00201 "transactions per second exceeded" errors.
# KIS API real limit is ~2 RPS; 2.0 provides maximum safety.
|
||||
RATE_LIMIT_RPS: float = 2.0
|
||||
|
||||
# Trading mode
|
||||
MODE: str = Field(default="paper", pattern="^(paper|live)$")
|
||||
|
||||
# Simulated USD cash for VTS (paper) overseas trading.
|
||||
# KIS VTS overseas balance API returns errors for most accounts.
|
||||
# This value is used as a fallback when the balance API returns 0 in paper mode.
|
||||
PAPER_OVERSEAS_CASH: float = Field(default=50000.0, ge=0.0)
|
||||
|
||||
# Trading frequency mode (daily = batch API calls, realtime = per-stock calls)
|
||||
TRADE_MODE: str = Field(default="daily", pattern="^(daily|realtime)$")
|
||||
DAILY_SESSIONS: int = Field(default=4, ge=1, le=10)
|
||||
SESSION_INTERVAL_HOURS: int = Field(default=6, ge=1, le=24)
|
||||
|
||||
# Pre-Market Planner
|
||||
PRE_MARKET_MINUTES: int = Field(default=30, ge=10, le=120)
|
||||
MAX_SCENARIOS_PER_STOCK: int = Field(default=5, ge=1, le=10)
|
||||
PLANNER_TIMEOUT_SECONDS: int = Field(default=60, ge=10, le=300)
|
||||
DEFENSIVE_PLAYBOOK_ON_FAILURE: bool = True
|
||||
RESCAN_INTERVAL_SECONDS: int = Field(default=300, ge=60, le=900)
|
||||
|
||||
# Market selection (comma-separated market codes)
|
||||
ENABLED_MARKETS: str = "KR"
|
||||
ENABLED_MARKETS: str = "KR,US"
|
||||
|
||||
# Backup and Disaster Recovery (optional)
|
||||
BACKUP_ENABLED: bool = True
|
||||
BACKUP_DIR: str = "data/backups"
|
||||
S3_ENDPOINT_URL: str | None = None # For MinIO, Backblaze B2, etc.
|
||||
S3_ACCESS_KEY: str | None = None
|
||||
S3_SECRET_KEY: str | None = None
|
||||
S3_BUCKET_NAME: str | None = None
|
||||
S3_REGION: str = "us-east-1"
|
||||
|
||||
# Telegram Notifications (optional)
|
||||
TELEGRAM_BOT_TOKEN: str | None = None
|
||||
TELEGRAM_CHAT_ID: str | None = None
|
||||
TELEGRAM_ENABLED: bool = True
|
||||
|
||||
# Telegram Commands (optional)
|
||||
TELEGRAM_COMMANDS_ENABLED: bool = True
|
||||
TELEGRAM_POLLING_INTERVAL: float = 1.0 # seconds
|
||||
|
||||
# Telegram notification type filters (granular control)
|
||||
# circuit_breaker is always sent regardless — safety-critical
|
||||
TELEGRAM_NOTIFY_TRADES: bool = True # BUY/SELL execution alerts
|
||||
TELEGRAM_NOTIFY_MARKET_OPEN_CLOSE: bool = True # Market open/close alerts
|
||||
TELEGRAM_NOTIFY_FAT_FINGER: bool = True # Fat-finger rejection alerts
|
||||
TELEGRAM_NOTIFY_SYSTEM_EVENTS: bool = True # System start/shutdown alerts
|
||||
TELEGRAM_NOTIFY_PLAYBOOK: bool = True # Playbook generated/failed alerts
|
||||
TELEGRAM_NOTIFY_SCENARIO_MATCH: bool = True # Scenario matched alerts (most frequent)
|
||||
TELEGRAM_NOTIFY_ERRORS: bool = True # Error alerts
|
||||
|
||||
# Overseas ranking API (KIS endpoint/TR_ID may vary by account/product)
|
||||
# Override these from .env if your account uses different specs.
|
||||
OVERSEAS_RANKING_ENABLED: bool = True
|
||||
OVERSEAS_RANKING_FLUCT_TR_ID: str = "HHDFS76290000"
|
||||
OVERSEAS_RANKING_VOLUME_TR_ID: str = "HHDFS76270000"
|
||||
OVERSEAS_RANKING_FLUCT_PATH: str = (
|
||||
"/uapi/overseas-stock/v1/ranking/updown-rate"
|
||||
)
|
||||
OVERSEAS_RANKING_VOLUME_PATH: str = (
|
||||
"/uapi/overseas-stock/v1/ranking/volume-surge"
|
||||
)
|
||||
|
||||
# Dashboard (optional)
|
||||
DASHBOARD_ENABLED: bool = False
|
||||
DASHBOARD_HOST: str = "127.0.0.1"
|
||||
DASHBOARD_PORT: int = Field(default=8080, ge=1, le=65535)
|
||||
|
||||
model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}
|
||||
|
||||
@@ -49,4 +133,7 @@ class Settings(BaseSettings):
|
||||
@property
|
||||
def enabled_market_list(self) -> list[str]:
|
||||
"""Parse ENABLED_MARKETS into list of market codes."""
|
||||
return [m.strip() for m in self.ENABLED_MARKETS.split(",") if m.strip()]
|
||||
from src.markets.schedule import expand_market_codes
|
||||
|
||||
raw = [m.strip() for m in self.ENABLED_MARKETS.split(",") if m.strip()]
|
||||
return expand_market_codes(raw)
|
||||
|
||||
11 src/context/__init__.py Normal file
@@ -0,0 +1,11 @@
|
||||
"""Multi-layered context management system for trading decisions.
|
||||
|
||||
The context tree implements Pillar 2: hierarchical memory management across
|
||||
7 time horizons, from real-time quotes to generational wisdom.
|
||||
"""
|
||||
|
||||
from src.context.layer import ContextLayer
|
||||
from src.context.scheduler import ContextScheduler
|
||||
from src.context.store import ContextStore
|
||||
|
||||
__all__ = ["ContextLayer", "ContextScheduler", "ContextStore"]
|
||||
334 src/context/aggregator.py Normal file
@@ -0,0 +1,334 @@
|
||||
"""Context aggregation logic for rolling up data from lower to higher layers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from src.context.layer import ContextLayer
|
||||
from src.context.store import ContextStore
|
||||
|
||||
|
||||
class ContextAggregator:
|
||||
"""Aggregates context data from lower (finer) to higher (coarser) layers."""
|
||||
|
||||
def __init__(self, conn: sqlite3.Connection) -> None:
|
||||
"""Initialize the aggregator with a database connection."""
|
||||
self.conn = conn
|
||||
self.store = ContextStore(conn)
|
||||
|
||||
def aggregate_daily_from_trades(
|
||||
self, date: str | None = None, market: str | None = None
|
||||
) -> None:
|
||||
"""Aggregate L6 (daily) context from trades table.
|
||||
|
||||
Args:
|
||||
date: Date in YYYY-MM-DD format. If None, uses today.
|
||||
market: Market code filter (e.g., "KR", "US"). If None, aggregates all markets.
|
||||
"""
|
||||
if date is None:
|
||||
date = datetime.now(UTC).date().isoformat()
|
||||
|
||||
if market is None:
|
||||
cursor = self.conn.execute(
|
||||
"""
|
||||
SELECT DISTINCT market
|
||||
FROM trades
|
||||
WHERE DATE(timestamp) = ?
|
||||
""",
|
||||
(date,),
|
||||
)
|
||||
markets = [row[0] for row in cursor.fetchall() if row[0]]
|
||||
else:
|
||||
markets = [market]
|
||||
|
||||
for market_code in markets:
|
||||
# Calculate daily metrics from trades for the market
|
||||
cursor = self.conn.execute(
|
||||
"""
|
||||
SELECT
|
||||
COUNT(*) as trade_count,
|
||||
SUM(CASE WHEN action = 'BUY' THEN 1 ELSE 0 END) as buys,
|
||||
SUM(CASE WHEN action = 'SELL' THEN 1 ELSE 0 END) as sells,
|
||||
SUM(CASE WHEN action = 'HOLD' THEN 1 ELSE 0 END) as holds,
|
||||
AVG(confidence) as avg_confidence,
|
||||
SUM(pnl) as total_pnl,
|
||||
COUNT(DISTINCT stock_code) as unique_stocks,
|
||||
SUM(CASE WHEN pnl > 0 THEN 1 ELSE 0 END) as wins,
|
||||
SUM(CASE WHEN pnl < 0 THEN 1 ELSE 0 END) as losses
|
||||
FROM trades
|
||||
WHERE DATE(timestamp) = ? AND market = ?
|
||||
""",
|
||||
(date, market_code),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
|
||||
if row and row[0] > 0: # At least one trade
|
||||
trade_count, buys, sells, holds, avg_conf, total_pnl, stocks, wins, losses = row
|
||||
|
||||
key_suffix = f"_{market_code}"
|
||||
|
||||
# Store daily metrics in L6 with market suffix
|
||||
self.store.set_context(
|
||||
ContextLayer.L6_DAILY, date, f"trade_count{key_suffix}", trade_count
|
||||
)
|
||||
self.store.set_context(ContextLayer.L6_DAILY, date, f"buys{key_suffix}", buys)
|
||||
self.store.set_context(ContextLayer.L6_DAILY, date, f"sells{key_suffix}", sells)
|
||||
self.store.set_context(ContextLayer.L6_DAILY, date, f"holds{key_suffix}", holds)
|
||||
self.store.set_context(
|
||||
ContextLayer.L6_DAILY,
|
||||
date,
|
||||
f"avg_confidence{key_suffix}",
|
||||
round(avg_conf, 2),
|
||||
)
|
||||
self.store.set_context(
|
||||
ContextLayer.L6_DAILY,
|
||||
date,
|
||||
f"total_pnl{key_suffix}",
|
||||
round(total_pnl, 2),
|
||||
)
|
||||
self.store.set_context(
|
||||
ContextLayer.L6_DAILY, date, f"unique_stocks{key_suffix}", stocks
|
||||
)
|
||||
win_rate = round(wins / max(wins + losses, 1) * 100, 2)
|
||||
self.store.set_context(
|
||||
ContextLayer.L6_DAILY, date, f"win_rate{key_suffix}", win_rate
|
||||
)
|
||||
|
||||
def aggregate_weekly_from_daily(self, week: str | None = None) -> None:
|
||||
"""Aggregate L5 (weekly) context from L6 (daily).
|
||||
|
||||
Args:
|
||||
week: Week in YYYY-Www format (ISO week). If None, uses current week.
|
||||
"""
|
||||
if week is None:
|
||||
week = datetime.now(UTC).strftime("%Y-W%V")
|
||||
|
||||
# Get all daily contexts for this week
|
||||
cursor = self.conn.execute(
|
||||
"""
|
||||
SELECT key, value FROM contexts
|
||||
WHERE layer = ? AND timeframe LIKE ?
|
||||
""",
|
||||
(ContextLayer.L6_DAILY.value, f"{week[:4]}-%"), # All days in the year
|
||||
)
|
||||
|
||||
# Group by key and collect all values
|
||||
import json
|
||||
from collections import defaultdict
|
||||
|
||||
daily_data: dict[str, list[Any]] = defaultdict(list)
|
||||
for row in cursor.fetchall():
|
||||
daily_data[row[0]].append(json.loads(row[1]))
|
||||
|
||||
if daily_data:
|
||||
# Sum all PnL values (market-specific if suffixed)
|
||||
if "total_pnl" in daily_data:
|
||||
total_pnl = sum(daily_data["total_pnl"])
|
||||
self.store.set_context(
|
||||
ContextLayer.L5_WEEKLY, week, "weekly_pnl", round(total_pnl, 2)
|
||||
)
|
||||
|
||||
for key, values in daily_data.items():
|
||||
if key.startswith("total_pnl_"):
|
||||
market_code = key.split("total_pnl_", 1)[1]
|
||||
total_pnl = sum(values)
|
||||
self.store.set_context(
|
||||
ContextLayer.L5_WEEKLY,
|
||||
week,
|
||||
f"weekly_pnl_{market_code}",
|
||||
round(total_pnl, 2),
|
||||
)
|
||||
|
||||
# Average all confidence values (market-specific if suffixed)
|
||||
if "avg_confidence" in daily_data:
|
||||
conf_values = daily_data["avg_confidence"]
|
||||
avg_conf = sum(conf_values) / len(conf_values)
|
||||
self.store.set_context(
|
||||
ContextLayer.L5_WEEKLY, week, "avg_confidence", round(avg_conf, 2)
|
||||
)
|
||||
|
||||
for key, values in daily_data.items():
|
||||
if key.startswith("avg_confidence_"):
|
||||
market_code = key.split("avg_confidence_", 1)[1]
|
||||
avg_conf = sum(values) / len(values)
|
||||
self.store.set_context(
|
||||
ContextLayer.L5_WEEKLY,
|
||||
week,
|
||||
f"avg_confidence_{market_code}",
|
||||
round(avg_conf, 2),
|
||||
)
|
||||
|
||||
def aggregate_monthly_from_weekly(self, month: str | None = None) -> None:
|
||||
"""Aggregate L4 (monthly) context from L5 (weekly).
|
||||
|
||||
Args:
|
||||
month: Month in YYYY-MM format. If None, uses current month.
|
||||
"""
|
||||
if month is None:
|
||||
month = datetime.now(UTC).strftime("%Y-%m")
|
||||
|
||||
# Get all weekly contexts for this month
|
||||
cursor = self.conn.execute(
|
||||
"""
|
||||
SELECT key, value FROM contexts
|
||||
WHERE layer = ? AND timeframe LIKE ?
|
||||
""",
|
||||
(ContextLayer.L5_WEEKLY.value, f"{month[:4]}-W%"),
|
||||
)
|
||||
|
||||
# Group by key and collect all values
|
||||
import json
|
||||
from collections import defaultdict
|
||||
|
||||
weekly_data: dict[str, list[Any]] = defaultdict(list)
|
||||
for row in cursor.fetchall():
|
||||
weekly_data[row[0]].append(json.loads(row[1]))
|
||||
|
||||
if weekly_data:
|
||||
# Sum all weekly PnL values
|
||||
total_pnl_values: list[float] = []
|
||||
if "weekly_pnl" in weekly_data:
|
||||
total_pnl_values.extend(weekly_data["weekly_pnl"])
|
||||
|
||||
for key, values in weekly_data.items():
|
||||
if key.startswith("weekly_pnl_"):
|
||||
total_pnl_values.extend(values)
|
||||
|
||||
if total_pnl_values:
|
||||
total_pnl = sum(total_pnl_values)
|
||||
self.store.set_context(
|
||||
ContextLayer.L4_MONTHLY, month, "monthly_pnl", round(total_pnl, 2)
|
||||
)
|
||||
|
||||
def aggregate_quarterly_from_monthly(self, quarter: str | None = None) -> None:
|
||||
"""Aggregate L3 (quarterly) context from L4 (monthly).
|
||||
|
||||
Args:
|
||||
quarter: Quarter in YYYY-Qn format. If None, uses current quarter.
|
||||
"""
|
||||
if quarter is None:
|
||||
from datetime import datetime
|
||||
|
||||
now = datetime.now(UTC)
|
||||
q = (now.month - 1) // 3 + 1
|
||||
quarter = f"{now.year}-Q{q}"
|
||||
|
||||
# Get all monthly contexts for this quarter
|
||||
# Q1: 01-03, Q2: 04-06, Q3: 07-09, Q4: 10-12
|
||||
q_num = int(quarter.split("-Q")[1])
|
||||
months = [f"{quarter[:4]}-{m:02d}" for m in range((q_num - 1) * 3 + 1, q_num * 3 + 1)]
|
||||
|
||||
total_pnl = 0.0
|
||||
for month in months:
|
||||
monthly_pnl = self.store.get_context(
|
||||
ContextLayer.L4_MONTHLY, month, "monthly_pnl"
|
||||
)
|
||||
if monthly_pnl is not None:
|
||||
total_pnl += monthly_pnl
|
||||
|
||||
self.store.set_context(
|
||||
ContextLayer.L3_QUARTERLY, quarter, "quarterly_pnl", round(total_pnl, 2)
|
||||
)
|
||||
|
||||
def aggregate_annual_from_quarterly(self, year: str | None = None) -> None:
|
||||
"""Aggregate L2 (annual) context from L3 (quarterly).
|
||||
|
||||
Args:
|
||||
year: Year in YYYY format. If None, uses current year.
|
||||
"""
|
||||
if year is None:
|
||||
year = str(datetime.now(UTC).year)
|
||||
|
||||
# Get all quarterly contexts for this year
|
||||
total_pnl = 0.0
|
||||
for q in range(1, 5):
|
||||
quarter = f"{year}-Q{q}"
|
||||
quarterly_pnl = self.store.get_context(
|
||||
ContextLayer.L3_QUARTERLY, quarter, "quarterly_pnl"
|
||||
)
|
||||
if quarterly_pnl is not None:
|
||||
total_pnl += quarterly_pnl
|
||||
|
||||
self.store.set_context(
|
||||
ContextLayer.L2_ANNUAL, year, "annual_pnl", round(total_pnl, 2)
|
||||
)
|
||||
|
||||
def aggregate_legacy_from_annual(self) -> None:
|
||||
"""Aggregate L1 (legacy) context from all L2 (annual) data."""
|
||||
# Get all annual PnL
|
||||
cursor = self.conn.execute(
|
||||
"""
|
||||
SELECT timeframe, value FROM contexts
|
||||
WHERE layer = ? AND key = ?
|
||||
ORDER BY timeframe
|
||||
""",
|
||||
(ContextLayer.L2_ANNUAL.value, "annual_pnl"),
|
||||
)
|
||||
|
||||
import json
|
||||
|
||||
annual_data = [(row[0], json.loads(row[1])) for row in cursor.fetchall()]
|
||||
|
||||
if annual_data:
|
||||
total_pnl = sum(pnl for _, pnl in annual_data)
|
||||
years_traded = len(annual_data)
|
||||
avg_annual_pnl = total_pnl / years_traded
|
||||
|
||||
# Store in L1 (single "LEGACY" timeframe)
|
||||
self.store.set_context(
|
||||
ContextLayer.L1_LEGACY, "LEGACY", "total_pnl", round(total_pnl, 2)
|
||||
)
|
||||
self.store.set_context(
|
||||
ContextLayer.L1_LEGACY, "LEGACY", "years_traded", years_traded
|
||||
)
|
||||
self.store.set_context(
|
||||
ContextLayer.L1_LEGACY,
|
||||
"LEGACY",
|
||||
"avg_annual_pnl",
|
||||
round(avg_annual_pnl, 2),
|
||||
)
|
||||
|
||||
def run_all_aggregations(self) -> None:
|
||||
"""Run all aggregations from L7 to L1 (bottom-up).
|
||||
|
||||
All timeframes are derived from the latest trade timestamp so that
|
||||
past data re-aggregation produces consistent results across layers.
|
||||
"""
|
||||
cursor = self.conn.execute("SELECT MAX(timestamp) FROM trades")
|
||||
row = cursor.fetchone()
|
||||
if not row or row[0] is None:
|
||||
return
|
||||
|
||||
ts_raw = row[0]
|
||||
if ts_raw.endswith("Z"):
|
||||
ts_raw = ts_raw.replace("Z", "+00:00")
|
||||
latest_ts = datetime.fromisoformat(ts_raw)
|
||||
trade_date = latest_ts.date()
|
||||
date_str = trade_date.isoformat()
|
||||
|
||||
iso_year, iso_week, _ = trade_date.isocalendar()
|
||||
week_str = f"{iso_year}-W{iso_week:02d}"
|
||||
month_str = f"{trade_date.year}-{trade_date.month:02d}"
|
||||
quarter = (trade_date.month - 1) // 3 + 1
|
||||
quarter_str = f"{trade_date.year}-Q{quarter}"
|
||||
year_str = str(trade_date.year)
|
||||
|
||||
# L7 (trades) → L6 (daily)
|
||||
self.aggregate_daily_from_trades(date_str)
|
||||
|
||||
# L6 (daily) → L5 (weekly)
|
||||
self.aggregate_weekly_from_daily(week_str)
|
||||
|
||||
# L5 (weekly) → L4 (monthly)
|
||||
self.aggregate_monthly_from_weekly(month_str)
|
||||
|
||||
# L4 (monthly) → L3 (quarterly)
|
||||
self.aggregate_quarterly_from_monthly(quarter_str)
|
||||
|
||||
# L3 (quarterly) → L2 (annual)
|
||||
self.aggregate_annual_from_quarterly(year_str)
|
||||
|
||||
# L2 (annual) → L1 (legacy)
|
||||
self.aggregate_legacy_from_annual()
|
||||
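For reference, a rough usage sketch of ContextAggregator. The table definitions below are assumptions inferred from the SQL in this file, not the project's real migrations:

import sqlite3

from src.context.aggregator import ContextAggregator

# Assumed minimal schema (trades, contexts, context_metadata) for illustration.
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE trades (timestamp TEXT, market TEXT, action TEXT, confidence REAL,
                         pnl REAL, stock_code TEXT);
    CREATE TABLE contexts (layer TEXT, timeframe TEXT, key TEXT, value TEXT,
                           created_at TEXT, updated_at TEXT,
                           PRIMARY KEY (layer, timeframe, key));
    CREATE TABLE context_metadata (layer TEXT PRIMARY KEY, description TEXT,
                                   retention_days INTEGER, aggregation_source TEXT);
    """
)
conn.execute(
    "INSERT INTO trades VALUES ('2026-02-04T10:00:00+00:00', 'KR', 'BUY', 0.8, 12.5, '005930')"
)

aggregator = ContextAggregator(conn)
aggregator.run_all_aggregations()  # rolls the single trade up from L6_DAILY to L1_LEGACY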
75
src/context/layer.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Context layer definitions for multi-tier memory management."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ContextLayer(str, Enum):
|
||||
"""7-tier context hierarchy from real-time to generational."""
|
||||
|
||||
L1_LEGACY = "L1_LEGACY" # Cumulative/generational wisdom
|
||||
L2_ANNUAL = "L2_ANNUAL" # Yearly performance
|
||||
L3_QUARTERLY = "L3_QUARTERLY" # Quarterly strategy adjustments
|
||||
L4_MONTHLY = "L4_MONTHLY" # Monthly rebalancing
|
||||
L5_WEEKLY = "L5_WEEKLY" # Weekly stock selection
|
||||
L6_DAILY = "L6_DAILY" # Daily trade logs
|
||||
L7_REALTIME = "L7_REALTIME" # Real-time market data
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LayerMetadata:
|
||||
"""Metadata for each context layer."""
|
||||
|
||||
layer: ContextLayer
|
||||
description: str
|
||||
retention_days: int | None # None = keep forever
|
||||
aggregation_source: ContextLayer | None # Parent layer for aggregation
|
||||
|
||||
|
||||
# Layer configuration
|
||||
LAYER_CONFIG: dict[ContextLayer, LayerMetadata] = {
|
||||
ContextLayer.L1_LEGACY: LayerMetadata(
|
||||
layer=ContextLayer.L1_LEGACY,
|
||||
description="Cumulative trading history and core lessons learned across generations",
|
||||
retention_days=None, # Keep forever
|
||||
aggregation_source=ContextLayer.L2_ANNUAL,
|
||||
),
|
||||
ContextLayer.L2_ANNUAL: LayerMetadata(
|
||||
layer=ContextLayer.L2_ANNUAL,
|
||||
description="Yearly returns, Sharpe ratio, max drawdown, win rate",
|
||||
retention_days=365 * 10, # 10 years
|
||||
aggregation_source=ContextLayer.L3_QUARTERLY,
|
||||
),
|
||||
ContextLayer.L3_QUARTERLY: LayerMetadata(
|
||||
layer=ContextLayer.L3_QUARTERLY,
|
||||
description="Quarterly strategy adjustments, market phase detection, sector rotation",
|
||||
retention_days=365 * 3, # 3 years
|
||||
aggregation_source=ContextLayer.L4_MONTHLY,
|
||||
),
|
||||
ContextLayer.L4_MONTHLY: LayerMetadata(
|
||||
layer=ContextLayer.L4_MONTHLY,
|
||||
description="Monthly portfolio rebalancing, risk exposure, drawdown recovery",
|
||||
retention_days=365 * 2, # 2 years
|
||||
aggregation_source=ContextLayer.L5_WEEKLY,
|
||||
),
|
||||
ContextLayer.L5_WEEKLY: LayerMetadata(
|
||||
layer=ContextLayer.L5_WEEKLY,
|
||||
description="Weekly stock selection, sector focus, volatility regime",
|
||||
retention_days=365, # 1 year
|
||||
aggregation_source=ContextLayer.L6_DAILY,
|
||||
),
|
||||
ContextLayer.L6_DAILY: LayerMetadata(
|
||||
layer=ContextLayer.L6_DAILY,
|
||||
description="Daily trade logs, P&L, market summaries, decision accuracy",
|
||||
retention_days=90, # 90 days
|
||||
aggregation_source=ContextLayer.L7_REALTIME,
|
||||
),
|
||||
ContextLayer.L7_REALTIME: LayerMetadata(
|
||||
layer=ContextLayer.L7_REALTIME,
|
||||
description="Real-time positions, quotes, orderbook, volatility, live P&L",
|
||||
retention_days=7, # 7 days (real-time data is ephemeral)
|
||||
aggregation_source=None, # No aggregation source (leaf layer)
|
||||
),
|
||||
}
|
||||
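A small sketch of how LAYER_CONFIG can drive a retention cutoff, mirroring what ContextStore.cleanup_expired_contexts does further down:

from datetime import UTC, datetime, timedelta

from src.context.layer import LAYER_CONFIG, ContextLayer

meta = LAYER_CONFIG[ContextLayer.L6_DAILY]
if meta.retention_days is not None:
    cutoff = datetime.now(UTC) - timedelta(days=meta.retention_days)
    print(f"{meta.layer.value}: drop entries older than {cutoff.date()}")  # 90-day window
else:
    print(f"{meta.layer.value}: kept forever")  # e.g. L1_LEGACY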
135
src/context/scheduler.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Context aggregation scheduler for periodic rollups and cleanup."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from calendar import monthrange
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from src.context.aggregator import ContextAggregator
|
||||
from src.context.store import ContextStore
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleResult:
|
||||
"""Represents which scheduled tasks ran."""
|
||||
|
||||
weekly: bool = False
|
||||
monthly: bool = False
|
||||
quarterly: bool = False
|
||||
annual: bool = False
|
||||
legacy: bool = False
|
||||
cleanup: bool = False
|
||||
|
||||
|
||||
class ContextScheduler:
|
||||
"""Run periodic context aggregations and cleanup when due."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conn: sqlite3.Connection | None = None,
|
||||
aggregator: ContextAggregator | None = None,
|
||||
store: ContextStore | None = None,
|
||||
) -> None:
|
||||
if aggregator is None:
|
||||
if conn is None:
|
||||
raise ValueError("conn is required when aggregator is not provided")
|
||||
aggregator = ContextAggregator(conn)
|
||||
self.aggregator = aggregator
|
||||
|
||||
if store is None:
|
||||
store = getattr(aggregator, "store", None)
|
||||
if store is None:
|
||||
if conn is None:
|
||||
raise ValueError("conn is required when store is not provided")
|
||||
store = ContextStore(conn)
|
||||
self.store = store
|
||||
|
||||
self._last_run: dict[str, str] = {}
|
||||
|
||||
def run_if_due(self, now: datetime | None = None) -> ScheduleResult:
|
||||
"""Run scheduled aggregations if their schedule is due.
|
||||
|
||||
Args:
|
||||
now: Current datetime (UTC). If None, uses current time.
|
||||
|
||||
Returns:
|
||||
ScheduleResult indicating which tasks ran.
|
||||
"""
|
||||
if now is None:
|
||||
now = datetime.now(UTC)
|
||||
|
||||
today = now.date().isoformat()
|
||||
result = ScheduleResult()
|
||||
|
||||
if self._should_run("cleanup", today):
|
||||
self.store.cleanup_expired_contexts()
|
||||
result = self._with(result, cleanup=True)
|
||||
|
||||
if self._is_sunday(now) and self._should_run("weekly", today):
|
||||
week = now.strftime("%Y-W%V")
|
||||
self.aggregator.aggregate_weekly_from_daily(week)
|
||||
result = self._with(result, weekly=True)
|
||||
|
||||
if self._is_last_day_of_month(now) and self._should_run("monthly", today):
|
||||
month = now.strftime("%Y-%m")
|
||||
self.aggregator.aggregate_monthly_from_weekly(month)
|
||||
result = self._with(result, monthly=True)
|
||||
|
||||
if self._is_last_day_of_quarter(now) and self._should_run("quarterly", today):
|
||||
quarter = self._current_quarter(now)
|
||||
self.aggregator.aggregate_quarterly_from_monthly(quarter)
|
||||
result = self._with(result, quarterly=True)
|
||||
|
||||
if self._is_last_day_of_year(now) and self._should_run("annual", today):
|
||||
year = str(now.year)
|
||||
self.aggregator.aggregate_annual_from_quarterly(year)
|
||||
result = self._with(result, annual=True)
|
||||
|
||||
# Legacy rollup runs after annual aggregation.
|
||||
self.aggregator.aggregate_legacy_from_annual()
|
||||
result = self._with(result, legacy=True)
|
||||
|
||||
return result
|
||||
|
||||
def _should_run(self, key: str, date_str: str) -> bool:
|
||||
if self._last_run.get(key) == date_str:
|
||||
return False
|
||||
self._last_run[key] = date_str
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _is_sunday(now: datetime) -> bool:
|
||||
return now.weekday() == 6
|
||||
|
||||
@staticmethod
|
||||
def _is_last_day_of_month(now: datetime) -> bool:
|
||||
last_day = monthrange(now.year, now.month)[1]
|
||||
return now.day == last_day
|
||||
|
||||
@classmethod
|
||||
def _is_last_day_of_quarter(cls, now: datetime) -> bool:
|
||||
if now.month not in (3, 6, 9, 12):
|
||||
return False
|
||||
return cls._is_last_day_of_month(now)
|
||||
|
||||
@staticmethod
|
||||
def _is_last_day_of_year(now: datetime) -> bool:
|
||||
return now.month == 12 and now.day == 31
|
||||
|
||||
@staticmethod
|
||||
def _current_quarter(now: datetime) -> str:
|
||||
quarter = (now.month - 1) // 3 + 1
|
||||
return f"{now.year}-Q{quarter}"
|
||||
|
||||
@staticmethod
|
||||
def _with(result: ScheduleResult, **kwargs: bool) -> ScheduleResult:
|
||||
return ScheduleResult(
|
||||
weekly=kwargs.get("weekly", result.weekly),
|
||||
monthly=kwargs.get("monthly", result.monthly),
|
||||
quarterly=kwargs.get("quarterly", result.quarterly),
|
||||
annual=kwargs.get("annual", result.annual),
|
||||
legacy=kwargs.get("legacy", result.legacy),
|
||||
cleanup=kwargs.get("cleanup", result.cleanup),
|
||||
)
|
||||
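The constructor's optional aggregator and store arguments make the scheduler easy to exercise without a real database; a usage sketch with stand-in objects (illustrative only):

from datetime import UTC, datetime
from unittest.mock import MagicMock

from src.context.scheduler import ContextScheduler

scheduler = ContextScheduler(aggregator=MagicMock(), store=MagicMock())
sunday = datetime(2026, 2, 8, 12, 0, tzinfo=UTC)  # a Sunday, so the weekly rollup is due
result = scheduler.run_if_due(now=sunday)
print(result.weekly, result.monthly, result.cleanup)  # True False True

Calling run_if_due again on the same date leaves the weekly and cleanup flags False, since _should_run records one run per key per day.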
193
src/context/store.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""Context storage and retrieval for the 7-tier memory system."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from src.context.layer import LAYER_CONFIG, ContextLayer
|
||||
|
||||
|
||||
class ContextStore:
|
||||
"""Manages context data across the 7-tier hierarchy."""
|
||||
|
||||
def __init__(self, conn: sqlite3.Connection) -> None:
|
||||
"""Initialize the context store with a database connection."""
|
||||
self.conn = conn
|
||||
self._init_metadata()
|
||||
|
||||
def _init_metadata(self) -> None:
|
||||
"""Initialize context_metadata table with layer configurations."""
|
||||
for config in LAYER_CONFIG.values():
|
||||
self.conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO context_metadata
|
||||
(layer, description, retention_days, aggregation_source)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
config.layer.value,
|
||||
config.description,
|
||||
config.retention_days,
|
||||
config.aggregation_source.value if config.aggregation_source else None,
|
||||
),
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def set_context(
|
||||
self,
|
||||
layer: ContextLayer,
|
||||
timeframe: str,
|
||||
key: str,
|
||||
value: Any,
|
||||
) -> None:
|
||||
"""Set a context value for a given layer and timeframe.
|
||||
|
||||
Args:
|
||||
layer: The context layer (L1-L7)
|
||||
timeframe: Time identifier (e.g., "2026", "2026-Q1", "2026-01",
|
||||
"2026-W05", "2026-02-04")
|
||||
key: Context key (e.g., "sharpe_ratio", "win_rate", "lesson_learned")
|
||||
value: Context value (will be JSON-serialized)
|
||||
"""
|
||||
now = datetime.now(UTC).isoformat()
|
||||
value_json = json.dumps(value)
|
||||
|
||||
self.conn.execute(
|
||||
"""
|
||||
INSERT INTO contexts (layer, timeframe, key, value, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(layer, timeframe, key)
|
||||
DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at
|
||||
""",
|
||||
(layer.value, timeframe, key, value_json, now, now),
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def get_context(
|
||||
self,
|
||||
layer: ContextLayer,
|
||||
timeframe: str,
|
||||
key: str,
|
||||
) -> Any | None:
|
||||
"""Get a context value for a given layer and timeframe.
|
||||
|
||||
Args:
|
||||
layer: The context layer (L1-L7)
|
||||
timeframe: Time identifier
|
||||
key: Context key
|
||||
|
||||
Returns:
|
||||
The context value (deserialized from JSON), or None if not found
|
||||
"""
|
||||
cursor = self.conn.execute(
|
||||
"""
|
||||
SELECT value FROM contexts
|
||||
WHERE layer = ? AND timeframe = ? AND key = ?
|
||||
""",
|
||||
(layer.value, timeframe, key),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
return json.loads(row[0])
|
||||
return None
|
||||
|
||||
def get_all_contexts(
|
||||
self,
|
||||
layer: ContextLayer,
|
||||
timeframe: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Get all context values for a given layer and optional timeframe.
|
||||
|
||||
Args:
|
||||
layer: The context layer (L1-L7)
|
||||
timeframe: Optional time identifier filter
|
||||
|
||||
Returns:
|
||||
Dictionary of key-value pairs for the specified layer/timeframe
|
||||
"""
|
||||
if timeframe:
|
||||
cursor = self.conn.execute(
|
||||
"""
|
||||
SELECT key, value FROM contexts
|
||||
WHERE layer = ? AND timeframe = ?
|
||||
ORDER BY key
|
||||
""",
|
||||
(layer.value, timeframe),
|
||||
)
|
||||
else:
|
||||
cursor = self.conn.execute(
|
||||
"""
|
||||
SELECT key, value FROM contexts
|
||||
WHERE layer = ?
|
||||
ORDER BY timeframe DESC, key
|
||||
""",
|
||||
(layer.value,),
|
||||
)
|
||||
|
||||
return {row[0]: json.loads(row[1]) for row in cursor.fetchall()}
|
||||
|
||||
def get_latest_timeframe(self, layer: ContextLayer) -> str | None:
|
||||
"""Get the most recent timeframe for a given layer.
|
||||
|
||||
Args:
|
||||
layer: The context layer (L1-L7)
|
||||
|
||||
Returns:
|
||||
The latest timeframe string, or None if no data exists
|
||||
"""
|
||||
cursor = self.conn.execute(
|
||||
"""
|
||||
SELECT timeframe FROM contexts
|
||||
WHERE layer = ?
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT 1
|
||||
""",
|
||||
(layer.value,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return row[0] if row else None
|
||||
|
||||
def delete_old_contexts(self, layer: ContextLayer, cutoff_date: str) -> int:
|
||||
"""Delete contexts older than the cutoff date for a given layer.
|
||||
|
||||
Args:
|
||||
layer: The context layer (L1-L7)
|
||||
cutoff_date: ISO format date string (contexts before this will be deleted)
|
||||
|
||||
Returns:
|
||||
Number of rows deleted
|
||||
"""
|
||||
cursor = self.conn.execute(
|
||||
"""
|
||||
DELETE FROM contexts
|
||||
WHERE layer = ? AND updated_at < ?
|
||||
""",
|
||||
(layer.value, cutoff_date),
|
||||
)
|
||||
self.conn.commit()
|
||||
return cursor.rowcount
|
||||
|
||||
def cleanup_expired_contexts(self) -> dict[ContextLayer, int]:
|
||||
"""Delete expired contexts based on retention policies.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping layer to number of deleted rows
|
||||
"""
|
||||
deleted_counts: dict[ContextLayer, int] = {}
|
||||
|
||||
for layer, config in LAYER_CONFIG.items():
|
||||
if config.retention_days is None:
|
||||
# Keep forever (e.g., L1_LEGACY)
|
||||
deleted_counts[layer] = 0
|
||||
continue
|
||||
|
||||
# Calculate cutoff date
|
||||
from datetime import timedelta
|
||||
|
||||
cutoff = datetime.now(UTC) - timedelta(days=config.retention_days)
|
||||
deleted_counts[layer] = self.delete_old_contexts(layer, cutoff.isoformat())
|
||||
|
||||
return deleted_counts
|
||||
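A quick usage sketch for ContextStore; the table definitions are assumptions inferred from the SQL above, not the project's actual schema:

import sqlite3

from src.context.layer import ContextLayer
from src.context.store import ContextStore

# Assumed minimal schema for illustration.
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE contexts (layer TEXT, timeframe TEXT, key TEXT, value TEXT,
                           created_at TEXT, updated_at TEXT,
                           PRIMARY KEY (layer, timeframe, key));
    CREATE TABLE context_metadata (layer TEXT PRIMARY KEY, description TEXT,
                                   retention_days INTEGER, aggregation_source TEXT);
    """
)

store = ContextStore(conn)
store.set_context(ContextLayer.L6_DAILY, "2026-02-04", "win_rate", 62.5)
print(store.get_context(ContextLayer.L6_DAILY, "2026-02-04", "win_rate"))  # 62.5
print(store.get_latest_timeframe(ContextLayer.L6_DAILY))  # 2026-02-04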
328
src/context/summarizer.py
Normal file
@@ -0,0 +1,328 @@
|
||||
"""Context summarization for efficient historical data representation.
|
||||
|
||||
This module summarizes old context data instead of including raw details:
|
||||
- Key metrics only (averages, trends, not details)
|
||||
- Rolling window (keep last N days detailed, summarize older)
|
||||
- Aggregate historical data efficiently
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from src.context.layer import ContextLayer
|
||||
from src.context.store import ContextStore
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SummaryStats:
|
||||
"""Statistical summary of historical data."""
|
||||
|
||||
count: int
|
||||
mean: float | None = None
|
||||
min: float | None = None
|
||||
max: float | None = None
|
||||
std: float | None = None
|
||||
trend: str | None = None # "up", "down", "flat"
|
||||
|
||||
|
||||
class ContextSummarizer:
|
||||
"""Summarizes historical context data to reduce token usage."""
|
||||
|
||||
def __init__(self, store: ContextStore) -> None:
|
||||
"""Initialize the context summarizer.
|
||||
|
||||
Args:
|
||||
store: ContextStore instance for retrieving context data
|
||||
"""
|
||||
self.store = store
|
||||
|
||||
def summarize_numeric_values(self, values: list[float]) -> SummaryStats:
|
||||
"""Summarize a list of numeric values.
|
||||
|
||||
Args:
|
||||
values: List of numeric values to summarize
|
||||
|
||||
Returns:
|
||||
SummaryStats with mean, min, max, std, and trend
|
||||
"""
|
||||
if not values:
|
||||
return SummaryStats(count=0)
|
||||
|
||||
count = len(values)
|
||||
mean = sum(values) / count
|
||||
min_val = min(values)
|
||||
max_val = max(values)
|
||||
|
||||
# Calculate standard deviation
|
||||
if count > 1:
|
||||
variance = sum((x - mean) ** 2 for x in values) / (count - 1)
|
||||
std = variance**0.5
|
||||
else:
|
||||
std = 0.0
|
||||
|
||||
# Determine trend
|
||||
trend = "flat"
|
||||
if count >= 3:
|
||||
# Simple trend: compare first third vs last third
|
||||
first_third = values[: count // 3]
|
||||
last_third = values[-(count // 3) :]
|
||||
first_avg = sum(first_third) / len(first_third)
|
||||
last_avg = sum(last_third) / len(last_third)
|
||||
|
||||
# Trend threshold: 5% change
|
||||
threshold = 0.05 * abs(first_avg) if first_avg != 0 else 0.01
|
||||
|
||||
if last_avg > first_avg + threshold:
|
||||
trend = "up"
|
||||
elif last_avg < first_avg - threshold:
|
||||
trend = "down"
|
||||
|
||||
return SummaryStats(
|
||||
count=count,
|
||||
mean=round(mean, 4),
|
||||
min=round(min_val, 4),
|
||||
max=round(max_val, 4),
|
||||
std=round(std, 4),
|
||||
trend=trend,
|
||||
)
|
||||
|
||||
def summarize_layer(
|
||||
self,
|
||||
layer: ContextLayer,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Summarize all context data for a layer within a date range.
|
||||
|
||||
Args:
|
||||
layer: Context layer to summarize
|
||||
start_date: Start date (inclusive), None for all
|
||||
end_date: End date (inclusive), None for now
|
||||
|
||||
Returns:
|
||||
Dictionary with summarized metrics
|
||||
"""
|
||||
if end_date is None:
|
||||
end_date = datetime.now(UTC)
|
||||
|
||||
# Get all contexts for this layer
|
||||
all_contexts = self.store.get_all_contexts(layer)
|
||||
|
||||
if not all_contexts:
|
||||
return {"summary": "No data available", "count": 0}
|
||||
|
||||
# Group numeric values by key
|
||||
numeric_data: dict[str, list[float]] = {}
|
||||
text_data: dict[str, list[str]] = {}
|
||||
|
||||
for key, value in all_contexts.items():
|
||||
# Try to extract numeric values
|
||||
if isinstance(value, (int, float)):
|
||||
if key not in numeric_data:
|
||||
numeric_data[key] = []
|
||||
numeric_data[key].append(float(value))
|
||||
elif isinstance(value, dict):
|
||||
# Extract numeric fields from dict
|
||||
for subkey, subvalue in value.items():
|
||||
if isinstance(subvalue, (int, float)):
|
||||
full_key = f"{key}.{subkey}"
|
||||
if full_key not in numeric_data:
|
||||
numeric_data[full_key] = []
|
||||
numeric_data[full_key].append(float(subvalue))
|
||||
elif isinstance(value, str):
|
||||
if key not in text_data:
|
||||
text_data[key] = []
|
||||
text_data[key].append(value)
|
||||
|
||||
# Summarize numeric data
|
||||
summary: dict[str, Any] = {}
|
||||
|
||||
for key, values in numeric_data.items():
|
||||
stats = self.summarize_numeric_values(values)
|
||||
summary[key] = {
|
||||
"count": stats.count,
|
||||
"avg": stats.mean,
|
||||
"range": [stats.min, stats.max],
|
||||
"trend": stats.trend,
|
||||
}
|
||||
|
||||
# Summarize text data (just counts)
|
||||
for key, values in text_data.items():
|
||||
summary[f"{key}_count"] = len(values)
|
||||
|
||||
summary["total_entries"] = len(all_contexts)
|
||||
|
||||
return summary
|
||||
|
||||
def rolling_window_summary(
|
||||
self,
|
||||
layer: ContextLayer,
|
||||
window_days: int = 30,
|
||||
summarize_older: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a rolling window summary.
|
||||
|
||||
Recent data (within window) is kept detailed.
|
||||
Older data is summarized to key metrics.
|
||||
|
||||
Args:
|
||||
layer: Context layer to summarize
|
||||
window_days: Number of days to keep detailed
|
||||
summarize_older: Whether to summarize data older than window
|
||||
|
||||
Returns:
|
||||
Dictionary with recent (detailed) and historical (summary) data
|
||||
"""
|
||||
result: dict[str, Any] = {
|
||||
"window_days": window_days,
|
||||
"recent_data": {},
|
||||
"historical_summary": {},
|
||||
}
|
||||
|
||||
# Get all contexts
|
||||
all_contexts = self.store.get_all_contexts(layer)
|
||||
|
||||
recent_values: dict[str, list[float]] = {}
|
||||
historical_values: dict[str, list[float]] = {}
|
||||
|
||||
for key, value in all_contexts.items():
|
||||
# For simplicity, treat all numeric values
|
||||
if isinstance(value, (int, float)):
|
||||
# Note: We don't have timestamps in context keys
|
||||
# This is a simplified implementation
|
||||
# In practice, would need to check timeframe field
|
||||
|
||||
# For now, put recent data in window
|
||||
if key not in recent_values:
|
||||
recent_values[key] = []
|
||||
recent_values[key].append(float(value))
|
||||
|
||||
# Detailed recent data
|
||||
result["recent_data"] = {key: values[-10:] for key, values in recent_values.items()}
|
||||
|
||||
# Summarized historical data
|
||||
if summarize_older:
|
||||
for key, values in historical_values.items():
|
||||
stats = self.summarize_numeric_values(values)
|
||||
result["historical_summary"][key] = {
|
||||
"count": stats.count,
|
||||
"avg": stats.mean,
|
||||
"trend": stats.trend,
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def aggregate_to_higher_layer(
|
||||
self,
|
||||
source_layer: ContextLayer,
|
||||
target_layer: ContextLayer,
|
||||
metric_key: str,
|
||||
aggregation_func: str = "mean",
|
||||
) -> float | None:
|
||||
"""Aggregate data from source layer to target layer.
|
||||
|
||||
Args:
|
||||
source_layer: Source context layer (more granular)
|
||||
target_layer: Target context layer (less granular)
|
||||
metric_key: Key of metric to aggregate
|
||||
aggregation_func: Aggregation function ("mean", "sum", "max", "min")
|
||||
|
||||
Returns:
|
||||
Aggregated value, or None if no data available
|
||||
"""
|
||||
# Get all contexts from source layer
|
||||
source_contexts = self.store.get_all_contexts(source_layer)
|
||||
|
||||
# Extract values for metric_key
|
||||
values = []
|
||||
for key, value in source_contexts.items():
|
||||
if key == metric_key and isinstance(value, (int, float)):
|
||||
values.append(float(value))
|
||||
elif isinstance(value, dict) and metric_key in value:
|
||||
subvalue = value[metric_key]
|
||||
if isinstance(subvalue, (int, float)):
|
||||
values.append(float(subvalue))
|
||||
|
||||
if not values:
|
||||
return None
|
||||
|
||||
# Apply aggregation function
|
||||
if aggregation_func == "mean":
|
||||
return sum(values) / len(values)
|
||||
elif aggregation_func == "sum":
|
||||
return sum(values)
|
||||
elif aggregation_func == "max":
|
||||
return max(values)
|
||||
elif aggregation_func == "min":
|
||||
return min(values)
|
||||
else:
|
||||
return sum(values) / len(values) # Default to mean
|
||||
|
||||
def create_compact_summary(
|
||||
self,
|
||||
layers: list[ContextLayer],
|
||||
top_n_metrics: int = 5,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a compact summary across multiple layers.
|
||||
|
||||
Args:
|
||||
layers: List of context layers to summarize
|
||||
top_n_metrics: Number of top metrics to include per layer
|
||||
|
||||
Returns:
|
||||
Compact summary dictionary
|
||||
"""
|
||||
summary: dict[str, Any] = {}
|
||||
|
||||
for layer in layers:
|
||||
layer_summary = self.summarize_layer(layer)
|
||||
|
||||
# Keep only top N metrics (by count/relevance)
|
||||
metrics = []
|
||||
for key, value in layer_summary.items():
|
||||
if isinstance(value, dict) and "count" in value:
|
||||
metrics.append((key, value, value["count"]))
|
||||
|
||||
# Sort by count (descending)
|
||||
metrics.sort(key=lambda x: x[2], reverse=True)
|
||||
|
||||
# Keep top N
|
||||
top_metrics = {m[0]: m[1] for m in metrics[:top_n_metrics]}
|
||||
|
||||
summary[layer.value] = top_metrics
|
||||
|
||||
return summary
|
||||
|
||||
def format_summary_for_prompt(self, summary: dict[str, Any]) -> str:
|
||||
"""Format summary for inclusion in a prompt.
|
||||
|
||||
Args:
|
||||
summary: Summary dictionary
|
||||
|
||||
Returns:
|
||||
Formatted string for prompt
|
||||
"""
|
||||
lines = []
|
||||
|
||||
for layer, metrics in summary.items():
|
||||
if not metrics:
|
||||
continue
|
||||
|
||||
lines.append(f"{layer}:")
|
||||
for key, value in metrics.items():
|
||||
if isinstance(value, dict):
|
||||
# Format as: key: avg=X, trend=Y
|
||||
parts = []
|
||||
if "avg" in value and value["avg"] is not None:
|
||||
parts.append(f"avg={value['avg']:.2f}")
|
||||
if "trend" in value and value["trend"]:
|
||||
parts.append(f"trend={value['trend']}")
|
||||
if parts:
|
||||
lines.append(f" {key}: {', '.join(parts)}")
|
||||
else:
|
||||
lines.append(f" {key}: {value}")
|
||||
|
||||
return "\n".join(lines)
|
||||
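A quick check of summarize_numeric_values; a stand-in store is enough here because the numeric summary never touches the database:

from unittest.mock import MagicMock

from src.context.summarizer import ContextSummarizer

summarizer = ContextSummarizer(store=MagicMock())  # stand-in; real code passes a ContextStore
stats = summarizer.summarize_numeric_values([1.0, 1.2, 1.1, 1.6, 1.9, 2.1])
print(stats.mean, stats.trend)  # 1.4833 up (first-third avg ~1.1 vs last-third avg ~2.0)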
110
src/core/criticality.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""Criticality assessment for urgency-based response system.
|
||||
|
||||
Evaluates market conditions to determine response urgency and enable
|
||||
faster reactions in critical situations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class CriticalityLevel(StrEnum):
|
||||
"""Urgency levels for market conditions and trading decisions."""
|
||||
|
||||
CRITICAL = "CRITICAL" # <5s timeout - Emergency response required
|
||||
HIGH = "HIGH" # <30s timeout - Elevated priority
|
||||
NORMAL = "NORMAL" # <60s timeout - Standard processing
|
||||
LOW = "LOW" # No timeout - Batch processing
|
||||
|
||||
|
||||
class CriticalityAssessor:
|
||||
"""Assesses market conditions to determine response criticality level."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
critical_pnl_threshold: float = -2.5,
|
||||
critical_price_change_threshold: float = 5.0,
|
||||
critical_volume_surge_threshold: float = 10.0,
|
||||
high_volatility_threshold: float = 70.0,
|
||||
low_volatility_threshold: float = 30.0,
|
||||
) -> None:
|
||||
"""Initialize the criticality assessor.
|
||||
|
||||
Args:
|
||||
critical_pnl_threshold: P&L % that triggers CRITICAL (default -2.5%)
|
||||
critical_price_change_threshold: Price change % that triggers CRITICAL
|
||||
(default 5.0% in 1 minute)
|
||||
critical_volume_surge_threshold: Volume surge ratio that triggers CRITICAL
|
||||
(default 10x average)
|
||||
high_volatility_threshold: Volatility score that triggers HIGH
|
||||
(default 70.0)
|
||||
low_volatility_threshold: Volatility score below which is LOW
|
||||
(default 30.0)
|
||||
"""
|
||||
self.critical_pnl_threshold = critical_pnl_threshold
|
||||
self.critical_price_change_threshold = critical_price_change_threshold
|
||||
self.critical_volume_surge_threshold = critical_volume_surge_threshold
|
||||
self.high_volatility_threshold = high_volatility_threshold
|
||||
self.low_volatility_threshold = low_volatility_threshold
|
||||
|
||||
def assess_market_conditions(
|
||||
self,
|
||||
pnl_pct: float,
|
||||
volatility_score: float,
|
||||
volume_surge: float,
|
||||
price_change_1m: float = 0.0,
|
||||
is_market_open: bool = True,
|
||||
) -> CriticalityLevel:
|
||||
"""Assess criticality level based on market conditions.
|
||||
|
||||
Args:
|
||||
pnl_pct: Current P&L percentage
|
||||
volatility_score: Momentum score from VolatilityAnalyzer (0-100)
|
||||
volume_surge: Volume surge ratio (current / average)
|
||||
price_change_1m: 1-minute price change percentage
|
||||
is_market_open: Whether the market is currently open
|
||||
|
||||
Returns:
|
||||
CriticalityLevel indicating required response urgency
|
||||
"""
|
||||
# Market closed or very quiet → LOW priority (batch processing)
|
||||
if not is_market_open or volatility_score < self.low_volatility_threshold:
|
||||
return CriticalityLevel.LOW
|
||||
|
||||
# CRITICAL conditions: immediate action required
|
||||
# 1. P&L near circuit breaker (-2.5% is close to -3.0% breaker)
|
||||
if pnl_pct <= self.critical_pnl_threshold:
|
||||
return CriticalityLevel.CRITICAL
|
||||
|
||||
# 2. Large sudden price movement (>5% in 1 minute)
|
||||
if abs(price_change_1m) >= self.critical_price_change_threshold:
|
||||
return CriticalityLevel.CRITICAL
|
||||
|
||||
# 3. Extreme volume surge (>10x average) indicates major event
|
||||
if volume_surge >= self.critical_volume_surge_threshold:
|
||||
return CriticalityLevel.CRITICAL
|
||||
|
||||
# HIGH priority: elevated volatility requires faster response
|
||||
if volatility_score >= self.high_volatility_threshold:
|
||||
return CriticalityLevel.HIGH
|
||||
|
||||
# NORMAL: standard trading conditions
|
||||
return CriticalityLevel.NORMAL
|
||||
|
||||
def get_timeout(self, level: CriticalityLevel) -> float | None:
|
||||
"""Get timeout in seconds for a given criticality level.
|
||||
|
||||
Args:
|
||||
level: Criticality level
|
||||
|
||||
Returns:
|
||||
Timeout in seconds, or None for no timeout (LOW priority)
|
||||
"""
|
||||
timeout_map = {
|
||||
CriticalityLevel.CRITICAL: 5.0,
|
||||
CriticalityLevel.HIGH: 30.0,
|
||||
CriticalityLevel.NORMAL: 60.0,
|
||||
CriticalityLevel.LOW: None,
|
||||
}
|
||||
return timeout_map[level]
|
||||
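An example of the urgency mapping using the default thresholds documented in the constructor:

from src.core.criticality import CriticalityAssessor, CriticalityLevel

assessor = CriticalityAssessor()

level = assessor.assess_market_conditions(
    pnl_pct=-2.7,  # below the -2.5% default, i.e. near the circuit breaker
    volatility_score=55.0,
    volume_surge=1.2,
)
print(level, assessor.get_timeout(level))  # CRITICAL 5.0

calm = assessor.assess_market_conditions(pnl_pct=0.4, volatility_score=20.0, volume_surge=0.9)
print(calm, assessor.get_timeout(calm))  # LOW None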
291
src/core/priority_queue.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""Priority-based task queue for latency control.
|
||||
|
||||
Implements a thread-safe priority queue with timeout enforcement and metrics tracking.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import heapq
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable, Coroutine
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from src.core.criticality import CriticalityLevel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(order=True)
|
||||
class PriorityTask:
|
||||
"""Task with priority and timestamp for queue ordering."""
|
||||
|
||||
# Lower priority value = higher urgency (CRITICAL=0, HIGH=1, NORMAL=2, LOW=3)
|
||||
priority: int
|
||||
timestamp: float
|
||||
# Task data not used in comparison
|
||||
task_id: str = field(compare=False)
|
||||
task_data: dict[str, Any] = field(compare=False, default_factory=dict)
|
||||
callback: Callable[[], Coroutine[Any, Any, Any]] | None = field(
|
||||
compare=False, default=None
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueueMetrics:
|
||||
"""Metrics for priority queue performance monitoring."""
|
||||
|
||||
total_enqueued: int = 0
|
||||
total_dequeued: int = 0
|
||||
total_timeouts: int = 0
|
||||
total_errors: int = 0
|
||||
current_size: int = 0
|
||||
# Average wait time per criticality level (in seconds)
|
||||
avg_wait_time: dict[CriticalityLevel, float] = field(default_factory=dict)
|
||||
# P95 wait time per criticality level
|
||||
p95_wait_time: dict[CriticalityLevel, float] = field(default_factory=dict)
|
||||
|
||||
|
||||
class PriorityTaskQueue:
|
||||
"""Thread-safe priority queue with timeout enforcement."""
|
||||
|
||||
# Priority mapping for criticality levels
|
||||
PRIORITY_MAP = {
|
||||
CriticalityLevel.CRITICAL: 0,
|
||||
CriticalityLevel.HIGH: 1,
|
||||
CriticalityLevel.NORMAL: 2,
|
||||
CriticalityLevel.LOW: 3,
|
||||
}
|
||||
|
||||
def __init__(self, max_size: int = 1000) -> None:
|
||||
"""Initialize the priority task queue.
|
||||
|
||||
Args:
|
||||
max_size: Maximum queue size (default 1000)
|
||||
"""
|
||||
self._queue: list[PriorityTask] = []
|
||||
self._lock = asyncio.Lock()
|
||||
self._max_size = max_size
|
||||
self._metrics = QueueMetrics()
|
||||
# Track wait times for metrics
|
||||
self._wait_times: dict[CriticalityLevel, list[float]] = {
|
||||
level: [] for level in CriticalityLevel
|
||||
}
|
||||
|
||||
async def enqueue(
|
||||
self,
|
||||
task_id: str,
|
||||
criticality: CriticalityLevel,
|
||||
task_data: dict[str, Any],
|
||||
callback: Callable[[], Coroutine[Any, Any, Any]] | None = None,
|
||||
) -> bool:
|
||||
"""Add a task to the priority queue.
|
||||
|
||||
Args:
|
||||
task_id: Unique identifier for the task
|
||||
criticality: Criticality level determining priority
|
||||
task_data: Data associated with the task
|
||||
callback: Optional async callback to execute
|
||||
|
||||
Returns:
|
||||
True if enqueued successfully, False if queue is full
|
||||
"""
|
||||
async with self._lock:
|
||||
if len(self._queue) >= self._max_size:
|
||||
logger.warning(
|
||||
"Priority queue full (size=%d), rejecting task %s",
|
||||
len(self._queue),
|
||||
task_id,
|
||||
)
|
||||
return False
|
||||
|
||||
priority = self.PRIORITY_MAP[criticality]
|
||||
timestamp = time.time()
|
||||
|
||||
task = PriorityTask(
|
||||
priority=priority,
|
||||
timestamp=timestamp,
|
||||
task_id=task_id,
|
||||
task_data=task_data,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
heapq.heappush(self._queue, task)
|
||||
self._metrics.total_enqueued += 1
|
||||
self._metrics.current_size = len(self._queue)
|
||||
|
||||
logger.debug(
|
||||
"Enqueued task %s with criticality %s (priority=%d, queue_size=%d)",
|
||||
task_id,
|
||||
criticality.value,
|
||||
priority,
|
||||
len(self._queue),
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
async def dequeue(self, timeout: float | None = None) -> PriorityTask | None:
|
||||
"""Remove and return the highest priority task from the queue.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait for a task (seconds)
|
||||
|
||||
Returns:
|
||||
PriorityTask if available, None if queue is empty or timeout
|
||||
"""
|
||||
start_time = time.time()
|
||||
deadline = start_time + timeout if timeout else None
|
||||
|
||||
while True:
|
||||
async with self._lock:
|
||||
if self._queue:
|
||||
task = heapq.heappop(self._queue)
|
||||
self._metrics.total_dequeued += 1
|
||||
self._metrics.current_size = len(self._queue)
|
||||
|
||||
# Calculate wait time
|
||||
wait_time = time.time() - task.timestamp
|
||||
criticality = self._get_criticality_from_priority(task.priority)
|
||||
self._wait_times[criticality].append(wait_time)
|
||||
self._update_wait_time_metrics()
|
||||
|
||||
logger.debug(
|
||||
"Dequeued task %s (priority=%d, wait_time=%.2fs, queue_size=%d)",
|
||||
task.task_id,
|
||||
task.priority,
|
||||
wait_time,
|
||||
len(self._queue),
|
||||
)
|
||||
|
||||
return task
|
||||
|
||||
# Queue is empty
|
||||
if deadline and time.time() >= deadline:
|
||||
return None
|
||||
|
||||
# Wait a bit before checking again
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def execute_with_timeout(
|
||||
self,
|
||||
task: PriorityTask,
|
||||
timeout: float | None,
|
||||
) -> Any:
|
||||
"""Execute a task with timeout enforcement.
|
||||
|
||||
Args:
|
||||
task: Task to execute
|
||||
timeout: Timeout in seconds (None = no timeout)
|
||||
|
||||
Returns:
|
||||
Result from task callback
|
||||
|
||||
Raises:
|
||||
asyncio.TimeoutError: If task exceeds timeout
|
||||
Exception: Any exception raised by the task callback
|
||||
"""
|
||||
if not task.callback:
|
||||
logger.warning("Task %s has no callback, skipping execution", task.task_id)
|
||||
return None
|
||||
|
||||
criticality = self._get_criticality_from_priority(task.priority)
|
||||
|
||||
try:
|
||||
if timeout:
|
||||
result = await asyncio.wait_for(task.callback(), timeout=timeout)
|
||||
else:
|
||||
result = await task.callback()
|
||||
|
||||
logger.debug(
|
||||
"Task %s completed successfully (criticality=%s)",
|
||||
task.task_id,
|
||||
criticality.value,
|
||||
)
|
||||
return result
|
||||
|
||||
except TimeoutError:
|
||||
self._metrics.total_timeouts += 1
|
||||
logger.error(
|
||||
"Task %s timed out after %.2fs (criticality=%s)",
|
||||
task.task_id,
|
||||
timeout or 0.0,
|
||||
criticality.value,
|
||||
)
|
||||
raise
|
||||
|
||||
except Exception as exc:
|
||||
self._metrics.total_errors += 1
|
||||
logger.exception(
|
||||
"Task %s failed with error (criticality=%s): %s",
|
||||
task.task_id,
|
||||
criticality.value,
|
||||
exc,
|
||||
)
|
||||
raise
|
||||
|
||||
def _get_criticality_from_priority(self, priority: int) -> CriticalityLevel:
|
||||
"""Convert priority back to criticality level."""
|
||||
for level, prio in self.PRIORITY_MAP.items():
|
||||
if prio == priority:
|
||||
return level
|
||||
return CriticalityLevel.NORMAL
|
||||
|
||||
def _update_wait_time_metrics(self) -> None:
|
||||
"""Update average and p95 wait time metrics."""
|
||||
for level, times in self._wait_times.items():
|
||||
if not times:
|
||||
continue
|
||||
|
||||
# Keep only last 1000 measurements to avoid memory bloat
|
||||
if len(times) > 1000:
|
||||
self._wait_times[level] = times[-1000:]
|
||||
times = self._wait_times[level]
|
||||
|
||||
# Calculate average
|
||||
self._metrics.avg_wait_time[level] = sum(times) / len(times)
|
||||
|
||||
# Calculate P95
|
||||
sorted_times = sorted(times)
|
||||
p95_idx = int(len(sorted_times) * 0.95)
|
||||
self._metrics.p95_wait_time[level] = sorted_times[p95_idx]
|
||||
|
||||
async def get_metrics(self) -> QueueMetrics:
|
||||
"""Get current queue metrics.
|
||||
|
||||
Returns:
|
||||
QueueMetrics with current statistics
|
||||
"""
|
||||
async with self._lock:
|
||||
return QueueMetrics(
|
||||
total_enqueued=self._metrics.total_enqueued,
|
||||
total_dequeued=self._metrics.total_dequeued,
|
||||
total_timeouts=self._metrics.total_timeouts,
|
||||
total_errors=self._metrics.total_errors,
|
||||
current_size=self._metrics.current_size,
|
||||
avg_wait_time=dict(self._metrics.avg_wait_time),
|
||||
p95_wait_time=dict(self._metrics.p95_wait_time),
|
||||
)
|
||||
|
||||
async def size(self) -> int:
|
||||
"""Get current queue size.
|
||||
|
||||
Returns:
|
||||
Number of tasks in queue
|
||||
"""
|
||||
async with self._lock:
|
||||
return len(self._queue)
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""Clear all tasks from the queue.
|
||||
|
||||
Returns:
|
||||
Number of tasks cleared
|
||||
"""
|
||||
async with self._lock:
|
||||
count = len(self._queue)
|
||||
self._queue.clear()
|
||||
self._metrics.current_size = 0
|
||||
logger.info("Cleared %d tasks from priority queue", count)
|
||||
return count
|
||||
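A minimal asyncio sketch: a LOW task enqueued first is still served after the CRITICAL one, and the CRITICAL callback runs under its 5-second budget:

import asyncio

from src.core.criticality import CriticalityLevel
from src.core.priority_queue import PriorityTaskQueue


async def main() -> None:
    queue = PriorityTaskQueue(max_size=10)

    async def close_position() -> str:
        return "closed"

    await queue.enqueue("batch-report", CriticalityLevel.LOW, {"note": "end of day"})
    await queue.enqueue("emergency-exit", CriticalityLevel.CRITICAL, {}, callback=close_position)

    task = await queue.dequeue(timeout=1.0)  # CRITICAL (priority 0) comes out first
    assert task is not None and task.task_id == "emergency-exit"
    print(await queue.execute_with_timeout(task, timeout=5.0))  # closed


asyncio.run(main())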
5
src/dashboard/__init__.py
Normal file
@@ -0,0 +1,5 @@
+"""FastAPI dashboard package for observability APIs."""
+
+from src.dashboard.app import create_dashboard_app
+
+__all__ = ["create_dashboard_app"]
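One possible way to serve the dashboard locally; uvicorn and the database path are assumptions, not something this diff pins down:

import uvicorn

from src.dashboard import create_dashboard_app

app = create_dashboard_app(db_path="data/trading.db")  # illustrative path
uvicorn.run(app, host="127.0.0.1", port=8000)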
497
src/dashboard/app.py
Normal file
@@ -0,0 +1,497 @@
|
||||
"""FastAPI application for observability dashboard endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
from datetime import UTC, datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Query
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
|
||||
def create_dashboard_app(db_path: str) -> FastAPI:
|
||||
"""Create dashboard FastAPI app bound to a SQLite database path."""
|
||||
app = FastAPI(title="The Ouroboros Dashboard", version="1.0.0")
|
||||
app.state.db_path = db_path
|
||||
|
||||
@app.get("/")
|
||||
def index() -> FileResponse:
|
||||
index_path = Path(__file__).parent / "static" / "index.html"
|
||||
return FileResponse(index_path)
|
||||
|
||||
@app.get("/api/status")
|
||||
def get_status() -> dict[str, Any]:
|
||||
today = datetime.now(UTC).date().isoformat()
|
||||
with _connect(db_path) as conn:
|
||||
market_rows = conn.execute(
|
||||
"""
|
||||
SELECT DISTINCT market FROM (
|
||||
SELECT market FROM trades WHERE DATE(timestamp) = ?
|
||||
UNION
|
||||
SELECT market FROM decision_logs WHERE DATE(timestamp) = ?
|
||||
UNION
|
||||
SELECT market FROM playbooks WHERE date = ?
|
||||
) ORDER BY market
|
||||
""",
|
||||
(today, today, today),
|
||||
).fetchall()
|
||||
markets = [row[0] for row in market_rows] if market_rows else []
|
||||
market_status: dict[str, Any] = {}
|
||||
total_trades = 0
|
||||
total_pnl = 0.0
|
||||
total_decisions = 0
|
||||
for market in markets:
|
||||
trade_row = conn.execute(
|
||||
"""
|
||||
SELECT COUNT(*) AS c, COALESCE(SUM(pnl), 0.0) AS p
|
||||
FROM trades
|
||||
WHERE DATE(timestamp) = ? AND market = ?
|
||||
""",
|
||||
(today, market),
|
||||
).fetchone()
|
||||
decision_row = conn.execute(
|
||||
"""
|
||||
SELECT COUNT(*) AS c
|
||||
FROM decision_logs
|
||||
WHERE DATE(timestamp) = ? AND market = ?
|
||||
""",
|
||||
(today, market),
|
||||
).fetchone()
|
||||
playbook_row = conn.execute(
|
||||
"""
|
||||
SELECT status
|
||||
FROM playbooks
|
||||
WHERE date = ? AND market = ?
|
||||
LIMIT 1
|
||||
""",
|
||||
(today, market),
|
||||
).fetchone()
|
||||
market_status[market] = {
|
||||
"trade_count": int(trade_row["c"] if trade_row else 0),
|
||||
"total_pnl": float(trade_row["p"] if trade_row else 0.0),
|
||||
"decision_count": int(decision_row["c"] if decision_row else 0),
|
||||
"playbook_status": playbook_row["status"] if playbook_row else None,
|
||||
}
|
||||
total_trades += market_status[market]["trade_count"]
|
||||
total_pnl += market_status[market]["total_pnl"]
|
||||
total_decisions += market_status[market]["decision_count"]
|
||||
|
||||
cb_threshold = float(os.getenv("CIRCUIT_BREAKER_PCT", "-3.0"))
|
||||
pnl_pct_rows = conn.execute(
|
||||
"""
|
||||
SELECT key, value
|
||||
FROM system_metrics
|
||||
WHERE key LIKE 'portfolio_pnl_pct_%'
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT 20
|
||||
"""
|
||||
).fetchall()
|
||||
current_pnl_pct: float | None = None
|
||||
if pnl_pct_rows:
|
||||
values = [
|
||||
json.loads(row["value"]).get("pnl_pct")
|
||||
for row in pnl_pct_rows
|
||||
if json.loads(row["value"]).get("pnl_pct") is not None
|
||||
]
|
||||
if values:
|
||||
current_pnl_pct = round(min(values), 4)
|
||||
|
||||
if current_pnl_pct is None:
|
||||
cb_status = "unknown"
|
||||
elif current_pnl_pct <= cb_threshold:
|
||||
cb_status = "tripped"
|
||||
elif current_pnl_pct <= cb_threshold + 1.0:
|
||||
cb_status = "warning"
|
||||
else:
|
||||
cb_status = "ok"
|
||||
|
||||
return {
|
||||
"date": today,
|
||||
"mode": os.getenv("MODE", "paper"),
|
||||
"markets": market_status,
|
||||
"totals": {
|
||||
"trade_count": total_trades,
|
||||
"total_pnl": round(total_pnl, 2),
|
||||
"decision_count": total_decisions,
|
||||
},
|
||||
"circuit_breaker": {
|
||||
"threshold_pct": cb_threshold,
|
||||
"current_pnl_pct": current_pnl_pct,
|
||||
"status": cb_status,
|
||||
},
|
||||
}
|
||||
|
||||
@app.get("/api/playbook/{date_str}")
|
||||
def get_playbook(date_str: str, market: str = Query("KR")) -> dict[str, Any]:
|
||||
with _connect(db_path) as conn:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT date, market, status, playbook_json, generated_at,
|
||||
token_count, scenario_count, match_count
|
||||
FROM playbooks
|
||||
WHERE date = ? AND market = ?
|
||||
""",
|
||||
(date_str, market),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=404, detail="playbook not found")
|
||||
return {
|
||||
"date": row["date"],
|
||||
"market": row["market"],
|
||||
"status": row["status"],
|
||||
"playbook": json.loads(row["playbook_json"]),
|
||||
"generated_at": row["generated_at"],
|
||||
"token_count": row["token_count"],
|
||||
"scenario_count": row["scenario_count"],
|
||||
"match_count": row["match_count"],
|
||||
}
|
||||
|
||||
@app.get("/api/scorecard/{date_str}")
|
||||
def get_scorecard(date_str: str, market: str = Query("KR")) -> dict[str, Any]:
|
||||
key = f"scorecard_{market}"
|
||||
with _connect(db_path) as conn:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT value
|
||||
FROM contexts
|
||||
WHERE layer = 'L6_DAILY' AND timeframe = ? AND key = ?
|
||||
""",
|
||||
(date_str, key),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=404, detail="scorecard not found")
|
||||
return {"date": date_str, "market": market, "scorecard": json.loads(row["value"])}
|
||||
|
||||
@app.get("/api/performance")
|
||||
def get_performance(market: str = Query("all")) -> dict[str, Any]:
|
||||
with _connect(db_path) as conn:
|
||||
if market == "all":
|
||||
by_market_rows = conn.execute(
|
||||
"""
|
||||
SELECT market,
|
||||
COUNT(*) AS total_trades,
|
||||
SUM(CASE WHEN pnl > 0 THEN 1 ELSE 0 END) AS wins,
|
||||
SUM(CASE WHEN pnl < 0 THEN 1 ELSE 0 END) AS losses,
|
||||
COALESCE(SUM(pnl), 0.0) AS total_pnl,
|
||||
COALESCE(AVG(confidence), 0.0) AS avg_confidence
|
||||
FROM trades
|
||||
GROUP BY market
|
||||
ORDER BY market
|
||||
"""
|
||||
).fetchall()
|
||||
combined = _performance_from_rows(by_market_rows)
|
||||
return {
|
||||
"market": "all",
|
||||
"combined": combined,
|
||||
"by_market": [
|
||||
_row_to_performance(row)
|
||||
for row in by_market_rows
|
||||
],
|
||||
}
|
||||
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT market,
|
||||
COUNT(*) AS total_trades,
|
||||
SUM(CASE WHEN pnl > 0 THEN 1 ELSE 0 END) AS wins,
|
||||
SUM(CASE WHEN pnl < 0 THEN 1 ELSE 0 END) AS losses,
|
||||
COALESCE(SUM(pnl), 0.0) AS total_pnl,
|
||||
COALESCE(AVG(confidence), 0.0) AS avg_confidence
|
||||
FROM trades
|
||||
WHERE market = ?
|
||||
GROUP BY market
|
||||
""",
|
||||
(market,),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return {"market": market, "metrics": _empty_performance(market)}
|
||||
return {"market": market, "metrics": _row_to_performance(row)}
|
||||
|
||||
@app.get("/api/context/{layer}")
|
||||
def get_context_layer(
|
||||
layer: str,
|
||||
timeframe: str | None = Query(default=None),
|
||||
limit: int = Query(default=100, ge=1, le=1000),
|
||||
) -> dict[str, Any]:
|
||||
with _connect(db_path) as conn:
|
||||
if timeframe is None:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT timeframe, key, value, updated_at
|
||||
FROM contexts
|
||||
WHERE layer = ?
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(layer, limit),
|
||||
).fetchall()
|
||||
else:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT timeframe, key, value, updated_at
|
||||
FROM contexts
|
||||
WHERE layer = ? AND timeframe = ?
|
||||
ORDER BY key
|
||||
LIMIT ?
|
||||
""",
|
||||
(layer, timeframe, limit),
|
||||
).fetchall()
|
||||
|
||||
entries = [
|
||||
{
|
||||
"timeframe": row["timeframe"],
|
||||
"key": row["key"],
|
||||
"value": json.loads(row["value"]),
|
||||
"updated_at": row["updated_at"],
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
return {
|
||||
"layer": layer,
|
||||
"timeframe": timeframe,
|
||||
"count": len(entries),
|
||||
"entries": entries,
|
||||
}
|
||||
|
||||
@app.get("/api/decisions")
|
||||
def get_decisions(
|
||||
market: str = Query("KR"),
|
||||
limit: int = Query(default=50, ge=1, le=500),
|
||||
) -> dict[str, Any]:
|
||||
with _connect(db_path) as conn:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT decision_id, timestamp, stock_code, market, exchange_code,
|
||||
action, confidence, rationale, context_snapshot, input_data,
|
||||
outcome_pnl, outcome_accuracy
|
||||
FROM decision_logs
|
||||
WHERE market = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(market, limit),
|
||||
).fetchall()
|
||||
decisions = []
|
||||
for row in rows:
|
||||
decisions.append(
|
||||
{
|
||||
"decision_id": row["decision_id"],
|
||||
"timestamp": row["timestamp"],
|
||||
"stock_code": row["stock_code"],
|
||||
"market": row["market"],
|
||||
"exchange_code": row["exchange_code"],
|
||||
"action": row["action"],
|
||||
"confidence": row["confidence"],
|
||||
"rationale": row["rationale"],
|
||||
"context_snapshot": json.loads(row["context_snapshot"]),
|
||||
"input_data": json.loads(row["input_data"]),
|
||||
"outcome_pnl": row["outcome_pnl"],
|
||||
"outcome_accuracy": row["outcome_accuracy"],
|
||||
}
|
||||
)
|
||||
return {"market": market, "count": len(decisions), "decisions": decisions}
|
||||
|
||||
@app.get("/api/pnl/history")
|
||||
def get_pnl_history(
|
||||
days: int = Query(default=30, ge=1, le=365),
|
||||
market: str = Query("all"),
|
||||
) -> dict[str, Any]:
|
||||
"""Return daily P&L history for charting."""
|
||||
with _connect(db_path) as conn:
|
||||
if market == "all":
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT DATE(timestamp) AS date,
|
||||
SUM(pnl) AS daily_pnl,
|
||||
COUNT(*) AS trade_count
|
||||
FROM trades
|
||||
WHERE pnl IS NOT NULL
|
||||
AND DATE(timestamp) >= DATE('now', ?)
|
||||
GROUP BY DATE(timestamp)
|
||||
ORDER BY DATE(timestamp)
|
||||
""",
|
||||
(f"-{days} days",),
|
||||
).fetchall()
|
||||
else:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT DATE(timestamp) AS date,
|
||||
SUM(pnl) AS daily_pnl,
|
||||
COUNT(*) AS trade_count
|
||||
FROM trades
|
||||
WHERE pnl IS NOT NULL
|
||||
AND market = ?
|
||||
AND DATE(timestamp) >= DATE('now', ?)
|
||||
GROUP BY DATE(timestamp)
|
||||
ORDER BY DATE(timestamp)
|
||||
""",
|
||||
(market, f"-{days} days"),
|
||||
).fetchall()
|
||||
return {
|
||||
"days": days,
|
||||
"market": market,
|
||||
"labels": [row["date"] for row in rows],
|
||||
"pnl": [round(float(row["daily_pnl"]), 2) for row in rows],
|
||||
"trades": [int(row["trade_count"]) for row in rows],
|
||||
}
|
||||
|
||||
@app.get("/api/scenarios/active")
|
||||
def get_active_scenarios(
|
||||
market: str = Query("US"),
|
||||
date_str: str | None = Query(default=None),
|
||||
limit: int = Query(default=50, ge=1, le=500),
|
||||
) -> dict[str, Any]:
|
||||
if date_str is None:
|
||||
date_str = datetime.now(UTC).date().isoformat()
|
||||
|
||||
with _connect(db_path) as conn:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT timestamp, stock_code, action, confidence, rationale, context_snapshot
|
||||
FROM decision_logs
|
||||
WHERE market = ? AND DATE(timestamp) = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(market, date_str, limit),
|
||||
).fetchall()
|
||||
matches: list[dict[str, Any]] = []
|
||||
for row in rows:
|
||||
snapshot = json.loads(row["context_snapshot"])
|
||||
scenario_match = snapshot.get("scenario_match", {})
|
||||
if not isinstance(scenario_match, dict) or not scenario_match:
|
||||
continue
|
||||
matches.append(
|
||||
{
|
||||
"timestamp": row["timestamp"],
|
||||
"stock_code": row["stock_code"],
|
||||
"action": row["action"],
|
||||
"confidence": row["confidence"],
|
||||
"rationale": row["rationale"],
|
||||
"scenario_match": scenario_match,
|
||||
}
|
||||
)
|
||||
return {"market": market, "date": date_str, "count": len(matches), "matches": matches}
|
||||
|
||||
@app.get("/api/positions")
|
||||
def get_positions() -> dict[str, Any]:
|
||||
"""Return all currently open positions (last trade per symbol is BUY)."""
|
||||
with _connect(db_path) as conn:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT stock_code, market, exchange_code,
|
||||
price AS entry_price, quantity, timestamp AS entry_time,
|
||||
decision_id
|
||||
FROM (
|
||||
SELECT stock_code, market, exchange_code, price, quantity,
|
||||
timestamp, decision_id, action,
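-- rank each symbol's trades newest-first; rn = 1 in the outer query keeps only the latest trade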
|
||||
ROW_NUMBER() OVER (
|
||||
PARTITION BY stock_code, market
|
||||
ORDER BY timestamp DESC
|
||||
) AS rn
|
||||
FROM trades
|
||||
)
|
||||
WHERE rn = 1 AND action = 'BUY'
|
||||
ORDER BY entry_time DESC
|
||||
"""
|
||||
).fetchall()
|
||||
|
||||
now = datetime.now(UTC)
|
||||
positions = []
|
||||
for row in rows:
|
||||
entry_time_str = row["entry_time"]
|
||||
try:
|
||||
entry_dt = datetime.fromisoformat(entry_time_str.replace("Z", "+00:00"))
|
||||
held_seconds = int((now - entry_dt).total_seconds())
|
||||
held_hours = held_seconds // 3600
|
||||
held_minutes = (held_seconds % 3600) // 60
|
||||
if held_hours >= 1:
|
||||
held_display = f"{held_hours}h {held_minutes}m"
|
||||
else:
|
||||
held_display = f"{held_minutes}m"
|
||||
except (ValueError, TypeError):
|
||||
held_display = "--"
|
||||
|
||||
positions.append(
|
||||
{
|
||||
"stock_code": row["stock_code"],
|
||||
"market": row["market"],
|
||||
"exchange_code": row["exchange_code"],
|
||||
"entry_price": row["entry_price"],
|
||||
"quantity": row["quantity"],
|
||||
"entry_time": entry_time_str,
|
||||
"held": held_display,
|
||||
"decision_id": row["decision_id"],
|
||||
}
|
||||
)
|
||||
|
||||
return {"count": len(positions), "positions": positions}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def _connect(db_path: str) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA busy_timeout=8000")
|
||||
return conn
|
||||
|
||||
|
||||
def _row_to_performance(row: sqlite3.Row) -> dict[str, Any]:
|
||||
wins = int(row["wins"] or 0)
|
||||
losses = int(row["losses"] or 0)
|
||||
total = int(row["total_trades"] or 0)
|
||||
win_rate = round((wins / (wins + losses) * 100), 2) if (wins + losses) > 0 else 0.0
|
||||
return {
|
||||
"market": row["market"],
|
||||
"total_trades": total,
|
||||
"wins": wins,
|
||||
"losses": losses,
|
||||
"win_rate": win_rate,
|
||||
"total_pnl": round(float(row["total_pnl"] or 0.0), 2),
|
||||
"avg_confidence": round(float(row["avg_confidence"] or 0.0), 2),
|
||||
}
|
||||
|
||||
|
||||
def _performance_from_rows(rows: list[sqlite3.Row]) -> dict[str, Any]:
|
||||
total_trades = 0
|
||||
wins = 0
|
||||
losses = 0
|
||||
total_pnl = 0.0
|
||||
confidence_weighted = 0.0
|
||||
for row in rows:
|
||||
market_total = int(row["total_trades"] or 0)
|
||||
market_conf = float(row["avg_confidence"] or 0.0)
|
||||
total_trades += market_total
|
||||
wins += int(row["wins"] or 0)
|
||||
losses += int(row["losses"] or 0)
|
||||
total_pnl += float(row["total_pnl"] or 0.0)
|
||||
confidence_weighted += market_total * market_conf
|
||||
win_rate = round((wins / (wins + losses) * 100), 2) if (wins + losses) > 0 else 0.0
|
||||
avg_confidence = round(confidence_weighted / total_trades, 2) if total_trades > 0 else 0.0
|
||||
return {
|
||||
"market": "all",
|
||||
"total_trades": total_trades,
|
||||
"wins": wins,
|
||||
"losses": losses,
|
||||
"win_rate": win_rate,
|
||||
"total_pnl": round(total_pnl, 2),
|
||||
"avg_confidence": avg_confidence,
|
||||
}
|
||||
|
||||
|
||||
def _empty_performance(market: str) -> dict[str, Any]:
|
||||
return {
|
||||
"market": market,
|
||||
"total_trades": 0,
|
||||
"wins": 0,
|
||||
"losses": 0,
|
||||
"win_rate": 0.0,
|
||||
"total_pnl": 0.0,
|
||||
"avg_confidence": 0.0,
|
||||
}
|
||||
798
src/dashboard/static/index.html
Normal file
@@ -0,0 +1,798 @@
|
||||
<!doctype html>
|
||||
<html lang="ko">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>The Ouroboros Dashboard</title>
|
||||
<script src="https://cdn.jsdelivr.net/npm/chart.js@4.4.0/dist/chart.umd.min.js"></script>
|
||||
<style>
|
||||
:root {
|
||||
--bg: #0b1724;
|
||||
--panel: #12263a;
|
||||
--fg: #e6eef7;
|
||||
--muted: #9fb3c8;
|
||||
--accent: #3cb371;
|
||||
--red: #e05555;
|
||||
--warn: #e8a040;
|
||||
--border: #28455f;
|
||||
}
|
||||
* { box-sizing: border-box; margin: 0; padding: 0; }
|
||||
body {
|
||||
font-family: ui-monospace, SFMono-Regular, Menlo, monospace;
|
||||
background: radial-gradient(circle at top left, #173b58, var(--bg));
|
||||
color: var(--fg);
|
||||
min-height: 100vh;
|
||||
font-size: 13px;
|
||||
}
|
||||
.wrap { max-width: 1100px; margin: 0 auto; padding: 20px 16px; }
|
||||
|
||||
/* Header */
|
||||
header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
margin-bottom: 20px;
|
||||
padding-bottom: 12px;
|
||||
border-bottom: 1px solid var(--border);
|
||||
}
|
||||
header h1 { font-size: 18px; color: var(--accent); letter-spacing: 0.5px; }
|
||||
.header-right { display: flex; align-items: center; gap: 12px; color: var(--muted); font-size: 12px; }
|
||||
.refresh-btn {
|
||||
background: none; border: 1px solid var(--border); color: var(--muted);
|
||||
padding: 4px 10px; border-radius: 6px; cursor: pointer; font-family: inherit;
|
||||
font-size: 12px; transition: border-color 0.2s;
|
||||
}
|
||||
.refresh-btn:hover { border-color: var(--accent); color: var(--accent); }
|
||||
.mode-badge {
|
||||
padding: 3px 10px; border-radius: 5px; font-size: 12px; font-weight: 700;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
.mode-badge.live {
|
||||
background: rgba(224, 85, 85, 0.15); color: var(--red);
|
||||
border: 1px solid rgba(224, 85, 85, 0.4);
|
||||
animation: pulse-warn 2s ease-in-out infinite;
|
||||
}
|
||||
.mode-badge.paper {
|
||||
background: rgba(232, 160, 64, 0.15); color: var(--warn);
|
||||
border: 1px solid rgba(232, 160, 64, 0.4);
|
||||
}
|
||||
|
||||
/* CB Gauge */
|
||||
.cb-gauge-wrap {
|
||||
display: flex; align-items: center; gap: 8px;
|
||||
font-size: 11px; color: var(--muted);
|
||||
}
|
||||
.cb-dot {
|
||||
width: 8px; height: 8px; border-radius: 50%; flex-shrink: 0;
|
||||
}
|
||||
.cb-dot.ok { background: var(--accent); }
|
||||
.cb-dot.warning { background: var(--warn); animation: pulse-warn 1.2s ease-in-out infinite; }
|
||||
.cb-dot.tripped { background: var(--red); animation: pulse-warn 0.6s ease-in-out infinite; }
|
||||
.cb-dot.unknown { background: var(--border); }
|
||||
@keyframes pulse-warn {
|
||||
0%, 100% { opacity: 1; }
|
||||
50% { opacity: 0.35; }
|
||||
}
|
||||
.cb-bar-wrap { width: 64px; height: 5px; background: rgba(255,255,255,0.08); border-radius: 3px; overflow: hidden; }
|
||||
.cb-bar-fill { height: 100%; border-radius: 3px; transition: width 0.4s, background 0.4s; }
|
||||
|
||||
/* Summary cards */
|
||||
.cards { display: grid; grid-template-columns: repeat(4, 1fr); gap: 12px; margin-bottom: 20px; }
|
||||
@media (max-width: 700px) { .cards { grid-template-columns: repeat(2, 1fr); } }
|
||||
.card {
|
||||
background: var(--panel);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 10px;
|
||||
padding: 16px;
|
||||
}
|
||||
.card-label { color: var(--muted); font-size: 11px; margin-bottom: 6px; text-transform: uppercase; letter-spacing: 0.5px; }
|
||||
.card-value { font-size: 22px; font-weight: 700; }
|
||||
.card-sub { color: var(--muted); font-size: 11px; margin-top: 4px; }
|
||||
.positive { color: var(--accent); }
|
||||
.negative { color: var(--red); }
|
||||
.neutral { color: var(--fg); }
|
||||
|
||||
/* Chart panel */
|
||||
.chart-panel {
|
||||
background: var(--panel);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 10px;
|
||||
padding: 16px;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
.panel-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
.panel-title { font-size: 13px; color: var(--muted); font-weight: 600; }
|
||||
.chart-container { position: relative; height: 180px; }
|
||||
.chart-error { color: var(--muted); text-align: center; padding: 40px 0; font-size: 12px; }
|
||||
|
||||
/* Days selector */
|
||||
.days-selector { display: flex; gap: 4px; }
|
||||
.day-btn {
|
||||
background: none; border: 1px solid var(--border); color: var(--muted);
|
||||
padding: 3px 8px; border-radius: 4px; cursor: pointer; font-family: inherit; font-size: 11px;
|
||||
}
|
||||
.day-btn.active { border-color: var(--accent); color: var(--accent); background: rgba(60, 179, 113, 0.08); }
|
||||
|
||||
/* Decisions panel */
|
||||
.decisions-panel {
|
||||
background: var(--panel);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 10px;
|
||||
padding: 16px;
|
||||
}
|
||||
.market-tabs { display: flex; gap: 6px; flex-wrap: wrap; }
|
||||
.tab-btn {
|
||||
background: none; border: 1px solid var(--border); color: var(--muted);
|
||||
padding: 4px 10px; border-radius: 6px; cursor: pointer; font-family: inherit; font-size: 11px;
|
||||
}
|
||||
.tab-btn.active { border-color: var(--accent); color: var(--accent); background: rgba(60, 179, 113, 0.08); }
|
||||
.decisions-table { width: 100%; border-collapse: collapse; margin-top: 14px; }
|
||||
.decisions-table th {
|
||||
text-align: left; color: var(--muted); font-size: 11px; font-weight: 600;
|
||||
padding: 6px 8px; border-bottom: 1px solid var(--border); white-space: nowrap;
|
||||
}
|
||||
.decisions-table td {
|
||||
padding: 8px 8px; border-bottom: 1px solid rgba(40, 69, 95, 0.5);
|
||||
vertical-align: middle; white-space: nowrap;
|
||||
}
|
||||
.decisions-table tr:last-child td { border-bottom: none; }
|
||||
.decisions-table tr:hover td { background: rgba(255,255,255,0.02); }
|
||||
.badge {
|
||||
display: inline-block; padding: 2px 7px; border-radius: 4px;
|
||||
font-size: 11px; font-weight: 700; letter-spacing: 0.5px;
|
||||
}
|
||||
.badge-buy { background: rgba(60, 179, 113, 0.15); color: var(--accent); }
|
||||
.badge-sell { background: rgba(224, 85, 85, 0.15); color: var(--red); }
|
||||
.badge-hold { background: rgba(159, 179, 200, 0.12); color: var(--muted); }
|
||||
.conf-bar-wrap { display: flex; align-items: center; gap: 6px; min-width: 90px; }
|
||||
.conf-bar { flex: 1; height: 6px; background: rgba(255,255,255,0.08); border-radius: 3px; overflow: hidden; }
|
||||
.conf-fill { height: 100%; border-radius: 3px; background: var(--accent); transition: width 0.3s; }
|
||||
.conf-val { color: var(--muted); font-size: 11px; min-width: 26px; text-align: right; }
|
||||
.rationale-cell { max-width: 200px; overflow: hidden; text-overflow: ellipsis; color: var(--muted); }
|
||||
.empty-row td { text-align: center; color: var(--muted); padding: 24px; }
|
||||
|
||||
/* Positions panel */
|
||||
.positions-panel {
|
||||
background: var(--panel);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 10px;
|
||||
padding: 16px;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
.positions-table { width: 100%; border-collapse: collapse; margin-top: 14px; }
|
||||
.positions-table th {
|
||||
text-align: left; color: var(--muted); font-size: 11px; font-weight: 600;
|
||||
padding: 6px 8px; border-bottom: 1px solid var(--border); white-space: nowrap;
|
||||
}
|
||||
.positions-table td {
|
||||
padding: 8px 8px; border-bottom: 1px solid rgba(40, 69, 95, 0.5);
|
||||
vertical-align: middle; white-space: nowrap;
|
||||
}
|
||||
.positions-table tr:last-child td { border-bottom: none; }
|
||||
.positions-table tr:hover td { background: rgba(255,255,255,0.02); }
|
||||
.pos-empty { color: var(--muted); text-align: center; padding: 20px 0; font-size: 12px; }
|
||||
.pos-count {
|
||||
display: inline-block; background: rgba(60, 179, 113, 0.12);
|
||||
color: var(--accent); font-size: 11px; font-weight: 700;
|
||||
padding: 2px 8px; border-radius: 10px; margin-left: 8px;
|
||||
}
|
||||
|
||||
/* Spinner */
|
||||
.spinner { display: inline-block; width: 12px; height: 12px; border: 2px solid var(--border); border-top-color: var(--accent); border-radius: 50%; animation: spin 0.8s linear infinite; }
|
||||
@keyframes spin { to { transform: rotate(360deg); } }
|
||||
|
||||
/* Generic panel */
|
||||
.panel {
|
||||
background: var(--panel);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 10px;
|
||||
padding: 16px;
|
||||
margin-top: 20px;
|
||||
}
|
||||
|
||||
/* Playbook panel - details/summary accordion */
|
||||
.playbook-panel details { border: 1px solid var(--border); border-radius: 4px; margin-bottom: 6px; }
|
||||
.playbook-panel summary { padding: 8px 12px; cursor: pointer; font-weight: 600; background: var(--bg); color: var(--fg); }
|
||||
.playbook-panel summary:hover { color: var(--accent); }
|
||||
.playbook-panel pre { margin: 0; padding: 12px; background: var(--bg); overflow-x: auto;
|
||||
font-size: 11px; color: #a0c4ff; white-space: pre-wrap; }
|
||||
|
||||
/* Scorecard KPI card grid */
|
||||
.scorecard-grid { display: grid; grid-template-columns: repeat(auto-fill, minmax(140px, 1fr)); gap: 10px; }
|
||||
.kpi-card { background: var(--bg); border: 1px solid var(--border); border-radius: 6px; padding: 12px; text-align: center; }
|
||||
.kpi-card .kpi-label { font-size: 11px; color: var(--muted); margin-bottom: 4px; }
|
||||
.kpi-card .kpi-value { font-size: 20px; font-weight: 700; color: var(--fg); }
|
||||
|
||||
/* Scenarios table */
|
||||
.scenarios-table { width: 100%; border-collapse: collapse; font-size: 13px; }
|
||||
.scenarios-table th { background: var(--bg); padding: 8px; text-align: left; border-bottom: 1px solid var(--border);
|
||||
color: var(--muted); font-size: 11px; font-weight: 600; white-space: nowrap; }
|
||||
.scenarios-table td { padding: 7px 8px; border-bottom: 1px solid rgba(40,69,95,0.5); }
|
||||
.scenarios-table tr:hover td { background: rgba(255,255,255,0.02); }
|
||||
|
||||
/* Context table */
|
||||
.context-table { width: 100%; border-collapse: collapse; font-size: 12px; }
|
||||
.context-table th { background: var(--bg); padding: 8px; text-align: left; border-bottom: 1px solid var(--border);
|
||||
color: var(--muted); font-size: 11px; font-weight: 600; white-space: nowrap; }
|
||||
.context-table td { padding: 6px 8px; border-bottom: 1px solid rgba(40,69,95,0.5); vertical-align: top; }
|
||||
.context-value { max-height: 60px; overflow-y: auto; color: #a0c4ff; word-break: break-all; }
|
||||
|
||||
/* Common panel select controls */
|
||||
.panel-controls { display: flex; gap: 8px; align-items: center; flex-wrap: wrap; }
|
||||
.panel-controls select, .panel-controls input[type="number"] {
|
||||
background: var(--bg); color: var(--fg); border: 1px solid var(--border);
|
||||
border-radius: 4px; padding: 4px 8px; font-size: 13px; font-family: inherit;
|
||||
}
|
||||
.panel-date { color: var(--muted); font-size: 12px; }
|
||||
.empty-msg { color: var(--muted); text-align: center; padding: 20px 0; font-size: 12px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="wrap">
|
||||
<!-- Header -->
|
||||
<header>
|
||||
<h1>🐍 The Ouroboros</h1>
|
||||
<div class="header-right">
|
||||
<span class="mode-badge" id="mode-badge">--</span>
|
||||
<div class="cb-gauge-wrap" id="cb-gauge" title="Circuit Breaker">
|
||||
<span class="cb-dot unknown" id="cb-dot"></span>
|
||||
<span id="cb-label">CB --</span>
|
||||
<div class="cb-bar-wrap">
|
||||
<div class="cb-bar-fill" id="cb-bar" style="width:0%;background:var(--accent)"></div>
|
||||
</div>
|
||||
</div>
|
||||
<span id="last-updated">--</span>
|
||||
<button class="refresh-btn" onclick="refreshAll()">↺ 새로고침</button>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<!-- Summary cards -->
|
||||
<div class="cards">
|
||||
<div class="card">
|
||||
<div class="card-label">오늘 거래</div>
|
||||
<div class="card-value neutral" id="card-trades">--</div>
|
||||
<div class="card-sub" id="card-trades-sub">거래 건수</div>
|
||||
</div>
|
||||
<div class="card">
|
||||
<div class="card-label">오늘 P&L</div>
|
||||
<div class="card-value" id="card-pnl">--</div>
|
||||
<div class="card-sub" id="card-pnl-sub">실현 손익</div>
|
||||
</div>
|
||||
<div class="card">
|
||||
<div class="card-label">승률</div>
|
||||
<div class="card-value neutral" id="card-winrate">--</div>
|
||||
<div class="card-sub">전체 누적</div>
|
||||
</div>
|
||||
<div class="card">
|
||||
<div class="card-label">누적 거래</div>
|
||||
<div class="card-value neutral" id="card-total">--</div>
|
||||
<div class="card-sub">전체 기간</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Open Positions -->
|
||||
<div class="positions-panel">
|
||||
<div class="panel-header">
|
||||
<span class="panel-title">
|
||||
현재 보유 포지션
|
||||
<span class="pos-count" id="positions-count">0</span>
|
||||
</span>
|
||||
</div>
|
||||
<table class="positions-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>종목</th>
|
||||
<th>시장</th>
|
||||
<th>수량</th>
|
||||
<th>진입가</th>
|
||||
<th>보유 시간</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody id="positions-body">
|
||||
<tr><td colspan="5" class="pos-empty"><span class="spinner"></span></td></tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<!-- P&L Chart -->
|
||||
<div class="chart-panel">
|
||||
<div class="panel-header">
|
||||
<span class="panel-title">P&L 추이</span>
|
||||
<div class="days-selector">
|
||||
<button class="day-btn active" data-days="7" onclick="selectDays(this)">7일</button>
|
||||
<button class="day-btn" data-days="30" onclick="selectDays(this)">30일</button>
|
||||
<button class="day-btn" data-days="90" onclick="selectDays(this)">90일</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="chart-container">
|
||||
<canvas id="pnl-chart"></canvas>
|
||||
<div class="chart-error" id="chart-error" style="display:none">데이터 없음</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Decisions log -->
|
||||
<div class="decisions-panel">
|
||||
<div class="panel-header">
|
||||
<span class="panel-title">최근 결정 로그</span>
|
||||
<div class="market-tabs" id="market-tabs">
|
||||
<button class="tab-btn active" data-market="KR" onclick="selectMarket(this)">KR</button>
|
||||
<button class="tab-btn" data-market="US_NASDAQ" onclick="selectMarket(this)">US_NASDAQ</button>
|
||||
<button class="tab-btn" data-market="US_NYSE" onclick="selectMarket(this)">US_NYSE</button>
|
||||
<button class="tab-btn" data-market="JP" onclick="selectMarket(this)">JP</button>
|
||||
<button class="tab-btn" data-market="HK" onclick="selectMarket(this)">HK</button>
|
||||
</div>
|
||||
</div>
|
||||
<table class="decisions-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>시각</th>
|
||||
<th>종목</th>
|
||||
<th>액션</th>
|
||||
<th>신뢰도</th>
|
||||
<th>사유</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody id="decisions-body">
|
||||
<tr class="empty-row"><td colspan="5"><span class="spinner"></span></td></tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<!-- playbook panel -->
|
||||
<div class="panel playbook-panel">
|
||||
<div class="panel-header">
|
||||
<span class="panel-title">📋 프리마켓 플레이북</span>
|
||||
<div class="panel-controls">
|
||||
<select id="pb-market-select" onchange="fetchPlaybook()">
|
||||
<option value="KR">KR</option>
|
||||
<option value="US_NASDAQ">US_NASDAQ</option>
|
||||
<option value="US_NYSE">US_NYSE</option>
|
||||
</select>
|
||||
<span id="pb-date" class="panel-date"></span>
|
||||
</div>
|
||||
</div>
|
||||
<div id="playbook-content"><p class="empty-msg">데이터 없음</p></div>
|
||||
</div>
|
||||
|
||||
<!-- scorecard panel -->
|
||||
<div class="panel">
|
||||
<div class="panel-header">
|
||||
<span class="panel-title">📊 일간 스코어카드</span>
|
||||
<div class="panel-controls">
|
||||
<select id="sc-market-select" onchange="fetchScorecard()">
|
||||
<option value="KR">KR</option>
|
||||
<option value="US_NASDAQ">US_NASDAQ</option>
|
||||
</select>
|
||||
<span id="sc-date" class="panel-date"></span>
|
||||
</div>
|
||||
</div>
|
||||
<div id="scorecard-grid" class="scorecard-grid"><p class="empty-msg">데이터 없음</p></div>
|
||||
</div>
|
||||
|
||||
<!-- scenarios panel -->
|
||||
<div class="panel">
|
||||
<div class="panel-header">
|
||||
<span class="panel-title">🎯 활성 시나리오 매칭</span>
|
||||
<div class="panel-controls">
|
||||
<select id="scen-market-select" onchange="fetchScenarios()">
|
||||
<option value="KR">KR</option>
|
||||
<option value="US_NASDAQ">US_NASDAQ</option>
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
<div id="scenarios-content"><p class="empty-msg">데이터 없음</p></div>
|
||||
</div>
|
||||
|
||||
<!-- context layer panel -->
|
||||
<div class="panel">
|
||||
<div class="panel-header">
|
||||
<span class="panel-title">🧠 컨텍스트 트리</span>
|
||||
<div class="panel-controls">
|
||||
<select id="ctx-layer-select" onchange="fetchContext()">
|
||||
<option value="L7_REALTIME">L7_REALTIME</option>
|
||||
<option value="L6_DAILY">L6_DAILY</option>
|
||||
<option value="L5_WEEKLY">L5_WEEKLY</option>
|
||||
<option value="L4_MONTHLY">L4_MONTHLY</option>
|
||||
<option value="L3_QUARTERLY">L3_QUARTERLY</option>
|
||||
<option value="L2_YEARLY">L2_YEARLY</option>
|
||||
<option value="L1_LIFETIME">L1_LIFETIME</option>
|
||||
</select>
|
||||
<input id="ctx-limit" type="number" value="20" min="1" max="200"
|
||||
style="width:60px;" onchange="fetchContext()">
|
||||
</div>
|
||||
</div>
|
||||
<div id="context-content"><p class="empty-msg">데이터 없음</p></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let pnlChart = null;
|
||||
let currentDays = 7;
|
||||
let currentMarket = 'KR';
|
||||
|
||||
function fmt(dt) {
|
||||
try {
|
||||
const d = new Date(dt);
|
||||
return d.toLocaleTimeString('ko-KR', { hour: '2-digit', minute: '2-digit', hour12: false });
|
||||
} catch { return dt || '--'; }
|
||||
}
|
||||
|
||||
function fmtPnl(v) {
|
||||
if (v === null || v === undefined) return '--';
|
||||
const n = parseFloat(v);
|
||||
const cls = n > 0 ? 'positive' : n < 0 ? 'negative' : 'neutral';
|
||||
const sign = n > 0 ? '+' : '';
|
||||
return `<span class="${cls}">${sign}${n.toFixed(2)}</span>`;
|
||||
}
|
||||
|
||||
function badge(action) {
|
||||
const a = (action || '').toUpperCase();
|
||||
const cls = a === 'BUY' ? 'badge-buy' : a === 'SELL' ? 'badge-sell' : 'badge-hold';
|
||||
return `<span class="badge ${cls}">${a}</span>`;
|
||||
}
|
||||
|
||||
function confBar(conf) {
|
||||
const pct = Math.min(Math.max(conf || 0, 0), 100);
|
||||
return `<div class="conf-bar-wrap">
|
||||
<div class="conf-bar"><div class="conf-fill" style="width:${pct}%"></div></div>
|
||||
<span class="conf-val">${pct}</span>
|
||||
</div>`;
|
||||
}
|
||||
|
||||
function fmtPrice(v, market) {
|
||||
if (v === null || v === undefined) return '--';
|
||||
const n = parseFloat(v);
|
||||
const sym = market === 'KR' ? '₩' : market === 'JP' ? '¥' : market === 'HK' ? 'HK$' : '$';
|
||||
return sym + n.toLocaleString('en-US', { minimumFractionDigits: 0, maximumFractionDigits: 4 });
|
||||
}
|
||||
|
||||
async function fetchPositions() {
|
||||
const tbody = document.getElementById('positions-body');
|
||||
const countEl = document.getElementById('positions-count');
|
||||
try {
|
||||
const r = await fetch('/api/positions');
|
||||
if (!r.ok) throw new Error('fetch failed');
|
||||
const d = await r.json();
|
||||
countEl.textContent = d.count ?? 0;
|
||||
if (!d.positions || d.positions.length === 0) {
|
||||
tbody.innerHTML = '<tr><td colspan="5" class="pos-empty">현재 보유 중인 포지션 없음</td></tr>';
|
||||
return;
|
||||
}
|
||||
tbody.innerHTML = d.positions.map(p => `
|
||||
<tr>
|
||||
<td><strong>${p.stock_code || '--'}</strong></td>
|
||||
<td><span style="color:var(--muted);font-size:11px">${p.market || '--'}</span></td>
|
||||
<td>${p.quantity ?? '--'}</td>
|
||||
<td>${fmtPrice(p.entry_price, p.market)}</td>
|
||||
<td style="color:var(--muted);font-size:11px">${p.held || '--'}</td>
|
||||
</tr>
|
||||
`).join('');
|
||||
} catch {
|
||||
tbody.innerHTML = '<tr><td colspan="5" class="pos-empty">데이터 로드 실패</td></tr>';
|
||||
}
|
||||
}
|
||||
|
||||
function renderCbGauge(cb) {
|
||||
if (!cb) return;
|
||||
const dot = document.getElementById('cb-dot');
|
||||
const label = document.getElementById('cb-label');
|
||||
const bar = document.getElementById('cb-bar');
|
||||
|
||||
const status = cb.status || 'unknown';
|
||||
const threshold = cb.threshold_pct ?? -3.0;
|
||||
const current = cb.current_pnl_pct;
|
||||
|
||||
// dot color
|
||||
dot.className = `cb-dot ${status}`;
|
||||
|
||||
// label
|
||||
if (current !== null && current !== undefined) {
|
||||
const sign = current > 0 ? '+' : '';
|
||||
label.textContent = `CB ${sign}${current.toFixed(2)}%`;
|
||||
} else {
|
||||
label.textContent = 'CB --';
|
||||
}
|
||||
|
||||
// bar: fill = how much of the threshold has been consumed (0%=safe, 100%=tripped)
|
||||
const colorMap = { ok: 'var(--accent)', warning: 'var(--warn)', tripped: 'var(--red)', unknown: 'var(--border)' };
|
||||
bar.style.background = colorMap[status] || 'var(--border)';
|
||||
if (current !== null && current !== undefined && threshold < 0) {
|
||||
const fillPct = Math.min(Math.max((current / threshold) * 100, 0), 100);
|
||||
bar.style.width = `${fillPct}%`;
|
||||
} else {
|
||||
bar.style.width = '0%';
|
||||
}
|
||||
}
|
||||
|
||||
async function fetchStatus() {
|
||||
try {
|
||||
const r = await fetch('/api/status');
|
||||
if (!r.ok) return;
|
||||
const d = await r.json();
|
||||
const t = d.totals || {};
|
||||
document.getElementById('card-trades').textContent = t.trade_count ?? '--';
|
||||
const pnlEl = document.getElementById('card-pnl');
|
||||
const pnlV = t.total_pnl;
|
||||
if (pnlV !== undefined) {
|
||||
const n = parseFloat(pnlV);
|
||||
const sign = n > 0 ? '+' : '';
|
||||
pnlEl.textContent = `${sign}${n.toFixed(2)}`;
|
||||
pnlEl.className = `card-value ${n > 0 ? 'positive' : n < 0 ? 'negative' : 'neutral'}`;
|
||||
}
|
||||
document.getElementById('card-pnl-sub').textContent = `결정 ${t.decision_count ?? 0}건`;
|
||||
renderCbGauge(d.circuit_breaker);
|
||||
renderModeBadge(d.mode);
|
||||
} catch {}
|
||||
}
|
||||
|
||||
function renderModeBadge(mode) {
|
||||
const el = document.getElementById('mode-badge');
|
||||
if (!el) return;
|
||||
if (mode === 'live') {
|
||||
el.textContent = '🔴 실전투자';
|
||||
el.className = 'mode-badge live';
|
||||
} else {
|
||||
el.textContent = '🟡 모의투자';
|
||||
el.className = 'mode-badge paper';
|
||||
}
|
||||
}
|
||||
|
||||
async function fetchPerformance() {
|
||||
try {
|
||||
const r = await fetch('/api/performance?market=all');
|
||||
if (!r.ok) return;
|
||||
const d = await r.json();
|
||||
const c = d.combined || {};
|
||||
document.getElementById('card-winrate').textContent = c.win_rate !== undefined ? `${c.win_rate}%` : '--';
|
||||
document.getElementById('card-total').textContent = c.total_trades ?? '--';
|
||||
} catch {}
|
||||
}
|
||||
|
||||
async function fetchPnlHistory(days) {
|
||||
try {
|
||||
const r = await fetch(`/api/pnl/history?days=${days}`);
|
||||
if (!r.ok) throw new Error('fetch failed');
|
||||
const d = await r.json();
|
||||
renderChart(d);
|
||||
} catch {
|
||||
document.getElementById('chart-error').style.display = 'block';
|
||||
}
|
||||
}
|
||||
|
||||
function renderChart(data) {
|
||||
const errEl = document.getElementById('chart-error');
|
||||
if (!data.labels || data.labels.length === 0) {
|
||||
errEl.style.display = 'block';
|
||||
return;
|
||||
}
|
||||
errEl.style.display = 'none';
|
||||
|
||||
const colors = data.pnl.map(v => v >= 0 ? 'rgba(60,179,113,0.75)' : 'rgba(224,85,85,0.75)');
|
||||
const borderColors = data.pnl.map(v => v >= 0 ? '#3cb371' : '#e05555');
|
||||
|
||||
if (pnlChart) { pnlChart.destroy(); pnlChart = null; }
|
||||
const ctx = document.getElementById('pnl-chart').getContext('2d');
|
||||
pnlChart = new Chart(ctx, {
|
||||
type: 'bar',
|
||||
data: {
|
||||
labels: data.labels,
|
||||
datasets: [{
|
||||
label: 'Daily P&L',
|
||||
data: data.pnl,
|
||||
backgroundColor: colors,
|
||||
borderColor: borderColors,
|
||||
borderWidth: 1,
|
||||
borderRadius: 3,
|
||||
}]
|
||||
},
|
||||
options: {
|
||||
responsive: true,
|
||||
maintainAspectRatio: false,
|
||||
plugins: {
|
||||
legend: { display: false },
|
||||
tooltip: {
|
||||
callbacks: {
|
||||
label: ctx => {
|
||||
const v = ctx.parsed.y;
|
||||
const sign = v >= 0 ? '+' : '';
|
||||
const trades = data.trades[ctx.dataIndex];
|
||||
return [`P&L: ${sign}${v.toFixed(2)}`, `거래: ${trades}건`];
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
scales: {
|
||||
x: {
|
||||
ticks: { color: '#9fb3c8', font: { size: 10 }, maxRotation: 0 },
|
||||
grid: { color: 'rgba(40,69,95,0.4)' }
|
||||
},
|
||||
y: {
|
||||
ticks: { color: '#9fb3c8', font: { size: 10 } },
|
||||
grid: { color: 'rgba(40,69,95,0.4)' }
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async function fetchDecisions(market) {
|
||||
const tbody = document.getElementById('decisions-body');
|
||||
tbody.innerHTML = '<tr class="empty-row"><td colspan="5"><span class="spinner"></span></td></tr>';
|
||||
try {
|
||||
const r = await fetch(`/api/decisions?market=${market}&limit=50`);
|
||||
if (!r.ok) throw new Error('fetch failed');
|
||||
const d = await r.json();
|
||||
if (!d.decisions || d.decisions.length === 0) {
|
||||
tbody.innerHTML = '<tr class="empty-row"><td colspan="5">결정 로그 없음</td></tr>';
|
||||
return;
|
||||
}
|
||||
tbody.innerHTML = d.decisions.map(dec => `
|
||||
<tr>
|
||||
<td>${fmt(dec.timestamp)}</td>
|
||||
<td>${dec.stock_code || '--'}</td>
|
||||
<td>${badge(dec.action)}</td>
|
||||
<td>${confBar(dec.confidence)}</td>
|
||||
<td class="rationale-cell" title="${(dec.rationale || '').replace(/"/g, '"')}">${dec.rationale || '--'}</td>
|
||||
</tr>
|
||||
`).join('');
|
||||
} catch {
|
||||
tbody.innerHTML = '<tr class="empty-row"><td colspan="5">데이터 로드 실패</td></tr>';
|
||||
}
|
||||
}
|
||||
|
||||
function selectDays(btn) {
|
||||
document.querySelectorAll('.day-btn').forEach(b => b.classList.remove('active'));
|
||||
btn.classList.add('active');
|
||||
currentDays = parseInt(btn.dataset.days, 10);
|
||||
fetchPnlHistory(currentDays);
|
||||
}
|
||||
|
||||
function selectMarket(btn) {
|
||||
document.querySelectorAll('.tab-btn').forEach(b => b.classList.remove('active'));
|
||||
btn.classList.add('active');
|
||||
currentMarket = btn.dataset.market;
|
||||
fetchDecisions(currentMarket);
|
||||
}
|
||||
|
||||
function todayStr() {
|
||||
return new Date().toISOString().slice(0, 10);
|
||||
}
|
||||
|
||||
function esc(s) {
|
||||
return String(s ?? '').replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;').replace(/"/g, '&quot;');
|
||||
}
|
||||
|
||||
async function fetchJSON(url) {
|
||||
const r = await fetch(url);
|
||||
if (!r.ok) throw new Error(`HTTP ${r.status}`);
|
||||
return r.json();
|
||||
}
|
||||
|
||||
async function fetchPlaybook() {
|
||||
const market = document.getElementById('pb-market-select').value;
|
||||
const date = todayStr();
|
||||
document.getElementById('pb-date').textContent = date;
|
||||
const el = document.getElementById('playbook-content');
|
||||
try {
|
||||
const data = await fetchJSON(`/api/playbook/${date}?market=${market}`);
|
||||
const stocks = data.stock_playbooks ?? [];
|
||||
if (stocks.length === 0) {
|
||||
el.innerHTML = '<p class="empty-msg">오늘 플레이북 없음</p>';
|
||||
return;
|
||||
}
|
||||
el.innerHTML = stocks.map(sp =>
|
||||
`<details><summary>${esc(sp.stock_code ?? '?')} — ${esc(sp.signal ?? '')}</summary>` +
|
||||
`<pre>${esc(JSON.stringify(sp, null, 2))}</pre></details>`
|
||||
).join('');
|
||||
} catch {
|
||||
el.innerHTML = '<p class="empty-msg">플레이북 없음 (오늘 미생성 또는 API 오류)</p>';
|
||||
}
|
||||
}
|
||||
|
||||
async function fetchScorecard() {
|
||||
const market = document.getElementById('sc-market-select').value;
|
||||
const date = todayStr();
|
||||
document.getElementById('sc-date').textContent = date;
|
||||
const el = document.getElementById('scorecard-grid');
|
||||
try {
|
||||
const data = await fetchJSON(`/api/scorecard/${date}?market=${market}`);
|
||||
const sc = data.scorecard ?? {};
|
||||
const entries = Object.entries(sc);
|
||||
if (entries.length === 0) {
|
||||
el.innerHTML = '<p class="empty-msg">스코어카드 없음</p>';
|
||||
return;
|
||||
}
|
||||
el.className = 'scorecard-grid';
|
||||
el.innerHTML = entries.map(([k, v]) => `
|
||||
<div class="kpi-card">
|
||||
<div class="kpi-label">${esc(k)}</div>
|
||||
<div class="kpi-value">${typeof v === 'number' ? v.toFixed(2) : esc(String(v))}</div>
|
||||
</div>`).join('');
|
||||
} catch {
|
||||
el.innerHTML = '<p class="empty-msg">스코어카드 없음 (오늘 미생성 또는 API 오류)</p>';
|
||||
}
|
||||
}
|
||||
|
||||
async function fetchScenarios() {
|
||||
const market = document.getElementById('scen-market-select').value;
|
||||
const date = todayStr();
|
||||
const el = document.getElementById('scenarios-content');
|
||||
try {
|
||||
const data = await fetchJSON(`/api/scenarios/active?market=${market}&date_str=${date}&limit=50`);
|
||||
const matches = data.matches ?? [];
|
||||
if (matches.length === 0) {
|
||||
el.innerHTML = '<p class="empty-msg">활성 시나리오 없음</p>';
|
||||
return;
|
||||
}
|
||||
el.innerHTML = `<table class="scenarios-table">
|
||||
<thead><tr><th>종목</th><th>신호</th><th>신뢰도</th><th>매칭 조건</th></tr></thead>
|
||||
<tbody>${matches.map(m => `
|
||||
<tr>
|
||||
<td>${esc(m.stock_code)}</td>
|
||||
<td>${esc(m.signal ?? '-')}</td>
|
||||
<td>${esc(m.confidence ?? '-')}</td>
|
||||
<td><code style="font-size:11px">${esc(JSON.stringify(m.scenario_match ?? {}))}</code></td>
|
||||
</tr>`).join('')}
|
||||
</tbody></table>`;
|
||||
} catch {
|
||||
el.innerHTML = '<p class="empty-msg">데이터 없음</p>';
|
||||
}
|
||||
}
|
||||
|
||||
async function fetchContext() {
|
||||
const layer = document.getElementById('ctx-layer-select').value;
|
||||
const limit = Math.min(Math.max(parseInt(document.getElementById('ctx-limit').value, 10) || 20, 1), 200);
|
||||
const el = document.getElementById('context-content');
|
||||
try {
|
||||
const data = await fetchJSON(`/api/context/${layer}?limit=${limit}`);
|
||||
const entries = data.entries ?? [];
|
||||
if (entries.length === 0) {
|
||||
el.innerHTML = '<p class="empty-msg">컨텍스트 없음</p>';
|
||||
return;
|
||||
}
|
||||
el.innerHTML = `<table class="context-table">
|
||||
<thead><tr><th>timeframe</th><th>key</th><th>value</th><th>updated</th></tr></thead>
|
||||
<tbody>${entries.map(e => `
|
||||
<tr>
|
||||
<td>${esc(e.timeframe)}</td>
|
||||
<td>${esc(e.key)}</td>
|
||||
<td><div class="context-value">${esc(JSON.stringify(e.value ?? e.raw_value))}</div></td>
|
||||
<td style="font-size:11px;color:var(--muted)">${esc((e.updated_at ?? '').slice(0, 16))}</td>
|
||||
</tr>`).join('')}
|
||||
</tbody></table>`;
|
||||
} catch {
|
||||
el.innerHTML = '<p class="empty-msg">데이터 없음</p>';
|
||||
}
|
||||
}
|
||||
|
||||
async function refreshAll() {
|
||||
document.getElementById('last-updated').textContent = '업데이트 중...';
|
||||
await Promise.all([
|
||||
fetchStatus(),
|
||||
fetchPerformance(),
|
||||
fetchPositions(),
|
||||
fetchPnlHistory(currentDays),
|
||||
fetchDecisions(currentMarket),
|
||||
fetchPlaybook(),
|
||||
fetchScorecard(),
|
||||
fetchScenarios(),
|
||||
fetchContext(),
|
||||
]);
|
||||
const now = new Date();
|
||||
const timeStr = now.toLocaleTimeString('ko-KR', { hour: '2-digit', minute: '2-digit', second: '2-digit', hour12: false });
|
||||
document.getElementById('last-updated').textContent = `마지막 업데이트: ${timeStr}`;
|
||||
}
|
||||
|
||||
// Initial load
|
||||
refreshAll();
|
||||
|
||||
// Auto-refresh every 30 seconds
|
||||
setInterval(refreshAll, 30000);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
205
src/data/README.md
Normal file
@@ -0,0 +1,205 @@
|
||||
# External Data Integration
|
||||
|
||||
This module provides objective external data sources to enhance trading decisions beyond just market prices and user input.
|
||||
|
||||
## Modules
|
||||
|
||||
### `news_api.py` - News Sentiment Analysis
|
||||
|
||||
Fetches real-time news for stocks with sentiment scoring.
|
||||
|
||||
**Features:**
|
||||
- Alpha Vantage and NewsAPI.org support
|
||||
- Sentiment scoring (-1.0 to +1.0)
|
||||
- 5-minute caching to minimize API quota usage
|
||||
- Graceful fallback when API unavailable
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from src.data.news_api import NewsAPI
|
||||
|
||||
# Initialize with API key
|
||||
news_api = NewsAPI(api_key="your_key", provider="alphavantage")
|
||||
|
||||
# Fetch news sentiment
|
||||
sentiment = await news_api.get_news_sentiment("AAPL")
|
||||
if sentiment:
|
||||
print(f"Average sentiment: {sentiment.avg_sentiment}")
|
||||
for article in sentiment.articles[:3]:
|
||||
print(f"{article.title} ({article.sentiment_score})")
|
||||
```
|
||||
|
||||
### `economic_calendar.py` - Major Economic Events
|
||||
|
||||
Tracks FOMC meetings, GDP releases, CPI, earnings calendars, and other market-moving events.
|
||||
|
||||
**Features:**
|
||||
- High-impact event tracking (FOMC, GDP, CPI)
|
||||
- Earnings calendar per stock
|
||||
- Event proximity checking
|
||||
- Hardcoded major events for 2026 (no API required)
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from src.data.economic_calendar import EconomicCalendar
|
||||
|
||||
calendar = EconomicCalendar()
|
||||
calendar.load_hardcoded_events()
|
||||
|
||||
# Get upcoming high-impact events
|
||||
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="HIGH")
|
||||
print(f"High-impact events: {upcoming.high_impact_count}")
|
||||
|
||||
# Check if near earnings
|
||||
earnings_date = calendar.get_earnings_date("AAPL")
|
||||
if earnings_date:
|
||||
print(f"Next earnings: {earnings_date}")
|
||||
|
||||
# Check for high volatility period
|
||||
if calendar.is_high_volatility_period(hours_ahead=24):
|
||||
print("High-impact event imminent!")
|
||||
```
|
||||
|
||||
### `market_data.py` - Market Indicators
|
||||
|
||||
Provides market breadth, sector performance, and sentiment indicators.
|
||||
|
||||
**Features:**
|
||||
- Market sentiment levels (Fear & Greed equivalent)
|
||||
- Market breadth (advancing/declining stocks)
|
||||
- Sector performance tracking
|
||||
- Fear/Greed score calculation
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from src.data.market_data import MarketData
|
||||
|
||||
market_data = MarketData(api_key="your_key")
|
||||
|
||||
# Get market sentiment
|
||||
sentiment = market_data.get_market_sentiment()
|
||||
print(f"Market sentiment: {sentiment.name}")
|
||||
|
||||
# Get full indicators
|
||||
indicators = market_data.get_market_indicators("US")
|
||||
print(f"Sentiment: {indicators.sentiment.name}")
|
||||
print(f"A/D Ratio: {indicators.breadth.advance_decline_ratio}")
|
||||
```
|
||||
|
||||
## Integration with GeminiClient
|
||||
|
||||
The external data sources are seamlessly integrated into the AI decision engine:
|
||||
|
||||
```python
|
||||
from src.brain.gemini_client import GeminiClient
|
||||
from src.data.news_api import NewsAPI
|
||||
from src.data.economic_calendar import EconomicCalendar
|
||||
from src.data.market_data import MarketData
|
||||
from src.config import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
# Initialize data sources
|
||||
news_api = NewsAPI(api_key=settings.NEWS_API_KEY, provider=settings.NEWS_API_PROVIDER)
|
||||
calendar = EconomicCalendar()
|
||||
calendar.load_hardcoded_events()
|
||||
market_data = MarketData(api_key=settings.MARKET_DATA_API_KEY)
|
||||
|
||||
# Create enhanced client
|
||||
client = GeminiClient(
|
||||
settings,
|
||||
news_api=news_api,
|
||||
economic_calendar=calendar,
|
||||
market_data=market_data
|
||||
)
|
||||
|
||||
# Make decision with external context
|
||||
market_data_dict = {
|
||||
"stock_code": "AAPL",
|
||||
"current_price": 180.0,
|
||||
"market_name": "US stock market"
|
||||
}
|
||||
|
||||
decision = await client.decide(market_data_dict)
|
||||
```
|
||||
|
||||
The external data is automatically included in the prompt sent to Gemini:
|
||||
|
||||
```
|
||||
Market: US stock market
|
||||
Stock Code: AAPL
|
||||
Current Price: 180.0
|
||||
|
||||
EXTERNAL DATA:
|
||||
News Sentiment: 0.85 (from 10 articles)
|
||||
1. [Reuters] Apple hits record high (sentiment: 0.92)
|
||||
2. [Bloomberg] Strong iPhone sales (sentiment: 0.78)
|
||||
3. [CNBC] Tech sector rallying (sentiment: 0.85)
|
||||
|
||||
Upcoming High-Impact Events: 2 in next 7 days
|
||||
Next: FOMC Meeting (FOMC) on 2026-03-18
|
||||
Earnings: AAPL on 2026-02-10
|
||||
|
||||
Market Sentiment: GREED
|
||||
Advance/Decline Ratio: 2.35
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
Add these to your `.env` file:
|
||||
|
||||
```bash
|
||||
# External Data APIs (optional)
|
||||
NEWS_API_KEY=your_alpha_vantage_key
|
||||
NEWS_API_PROVIDER=alphavantage # or "newsapi"
|
||||
MARKET_DATA_API_KEY=your_market_data_key
|
||||
```
|
||||
|
||||
## API Recommendations
|
||||
|
||||
### Alpha Vantage (News)
|
||||
- **Free tier:** 5 calls/min, 500 calls/day
|
||||
- **Pros:** Provides sentiment scores, no credit card required
|
||||
- **URL:** https://www.alphavantage.co/
|
||||
|
||||
### NewsAPI.org
|
||||
- **Free tier:** 100 requests/day
|
||||
- **Pros:** Large news coverage, easy to use
|
||||
- **Cons:** No sentiment scores (we use keyword heuristics)
|
||||
- **URL:** https://newsapi.org/
|
||||
|
||||
## Caching Strategy
|
||||
|
||||
To minimize API quota usage:
|
||||
|
||||
1. **News:** 5-minute TTL cache per stock (see the sketch after this list)
|
||||
2. **Economic Calendar:** Loaded once at startup (hardcoded events)
|
||||
3. **Market Data:** Fetched per decision (lightweight)
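The news cache check works roughly as follows. This is an illustrative sketch, not a copy of `news_api.py`; it only assumes the documented `NewsSentiment.fetched_at` Unix timestamp and the 300-second TTL.

```python
import time

CACHE_TTL_SECONDS = 300  # 5 minutes, matching news_api.CACHE_TTL_SECONDS


def get_cached_sentiment(cache: dict, stock_code: str, ttl: int = CACHE_TTL_SECONDS):
    """Return a cached NewsSentiment if it is still fresh, else None."""
    entry = cache.get(stock_code)
    if entry is None:
        return None  # never fetched
    if time.time() - entry.fetched_at > ttl:
        return None  # expired — caller should refetch from the provider
    return entry
```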
|
||||
|
||||
## Graceful Degradation
|
||||
|
||||
The system works gracefully without external data:
|
||||
|
||||
- If no API keys provided → decisions work with just market prices
|
||||
- If API fails → decision continues without external context
|
||||
- If cache expired → attempts refetch, falls back to no data
|
||||
- Errors are logged but never block trading decisions (sketched below)
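A minimal sketch of that fail-open pattern is shown below. `NewsAPI.get_news_sentiment` and the `NewsSentiment` fields are from this module; the surrounding `gather_news_context` helper is hypothetical and only illustrates how a caller can swallow failures.

```python
import logging

logger = logging.getLogger(__name__)


async def gather_news_context(news_api, stock_code: str) -> str:
    """Build an optional news line for the prompt; any failure yields ''."""
    if news_api is None:
        return ""  # no API key configured — decide on price data alone
    try:
        sentiment = await news_api.get_news_sentiment(stock_code)
    except Exception as exc:  # external data must never block a trading decision
        logger.warning("news fetch failed for %s: %s", stock_code, exc)
        return ""
    if sentiment is None:
        return ""
    return f"News Sentiment: {sentiment.avg_sentiment} (from {sentiment.article_count} articles)"
```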
|
||||
|
||||
## Testing
|
||||
|
||||
All modules have comprehensive test coverage (81%+):
|
||||
|
||||
```bash
|
||||
pytest tests/test_data_integration.py -v --cov=src/data
|
||||
```
|
||||
|
||||
Tests use mocks to avoid requiring real API keys.
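For reference, tests along these lines can run entirely offline. The cases below are illustrative sketches, not copies from `tests/test_data_integration.py`; they rely only on the no-API-key fallbacks documented above.

```python
from src.data.economic_calendar import EconomicCalendar
from src.data.market_data import MarketData, MarketSentiment


def test_market_sentiment_defaults_to_neutral_without_key():
    # With no API key the provider degrades to NEUTRAL instead of raising.
    assert MarketData(api_key=None).get_market_sentiment() is MarketSentiment.NEUTRAL


def test_no_earnings_date_without_earnings_events():
    calendar = EconomicCalendar()
    calendar.load_hardcoded_events()  # hardcoded FOMC/GDP/CPI only — no EARNINGS entries
    assert calendar.get_earnings_date("AAPL") is None
```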
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
- Twitter/X sentiment analysis
|
||||
- Reddit WallStreetBets sentiment
|
||||
- Options flow data
|
||||
- Insider trading activity
|
||||
- Analyst upgrades/downgrades
|
||||
- Real-time economic data APIs
|
||||
5
src/data/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""External data integration for objective decision-making."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = ["NewsAPI", "EconomicCalendar", "MarketData"]
|
||||
219
src/data/economic_calendar.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""Economic calendar integration for major market events.
|
||||
|
||||
Tracks FOMC meetings, GDP releases, CPI, earnings calendars, and other
|
||||
market-moving events.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EconomicEvent:
|
||||
"""Single economic event."""
|
||||
|
||||
name: str
|
||||
event_type: str # "FOMC", "GDP", "CPI", "EARNINGS", etc.
|
||||
datetime: datetime
|
||||
impact: str # "HIGH", "MEDIUM", "LOW"
|
||||
country: str
|
||||
description: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpcomingEvents:
|
||||
"""Collection of upcoming economic events."""
|
||||
|
||||
events: list[EconomicEvent]
|
||||
high_impact_count: int
|
||||
next_major_event: EconomicEvent | None
|
||||
|
||||
|
||||
class EconomicCalendar:
|
||||
"""Economic calendar with event tracking and impact scoring."""
|
||||
|
||||
def __init__(self, api_key: str | None = None) -> None:
|
||||
"""Initialize economic calendar.
|
||||
|
||||
Args:
|
||||
api_key: API key for calendar provider (None for testing/hardcoded)
|
||||
"""
|
||||
self._api_key = api_key
|
||||
# For now, use hardcoded major events (can be extended with API)
|
||||
self._events: list[EconomicEvent] = []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_upcoming_events(
|
||||
self, days_ahead: int = 7, min_impact: str = "MEDIUM"
|
||||
) -> UpcomingEvents:
|
||||
"""Get upcoming economic events within specified timeframe.
|
||||
|
||||
Args:
|
||||
days_ahead: Number of days to look ahead
|
||||
min_impact: Minimum impact level ("LOW", "MEDIUM", "HIGH")
|
||||
|
||||
Returns:
|
||||
UpcomingEvents with filtered events
|
||||
"""
|
||||
now = datetime.now()
|
||||
end_date = now + timedelta(days=days_ahead)
|
||||
|
||||
# Filter events by timeframe and impact
|
||||
upcoming = [
|
||||
event
|
||||
for event in self._events
|
||||
if now <= event.datetime <= end_date
|
||||
and self._impact_level(event.impact) >= self._impact_level(min_impact)
|
||||
]
|
||||
|
||||
# Sort by datetime
|
||||
upcoming.sort(key=lambda e: e.datetime)
|
||||
|
||||
# Count high-impact events
|
||||
high_impact_count = sum(1 for e in upcoming if e.impact == "HIGH")
|
||||
|
||||
# Get next major event
|
||||
next_major = None
|
||||
for event in upcoming:
|
||||
if event.impact == "HIGH":
|
||||
next_major = event
|
||||
break
|
||||
|
||||
return UpcomingEvents(
|
||||
events=upcoming,
|
||||
high_impact_count=high_impact_count,
|
||||
next_major_event=next_major,
|
||||
)
|
||||
|
||||
def add_event(self, event: EconomicEvent) -> None:
|
||||
"""Add an economic event to the calendar."""
|
||||
self._events.append(event)
|
||||
|
||||
def clear_events(self) -> None:
|
||||
"""Clear all events (useful for testing)."""
|
||||
self._events.clear()
|
||||
|
||||
def get_earnings_date(self, stock_code: str) -> datetime | None:
|
||||
"""Get next earnings date for a stock.
|
||||
|
||||
Args:
|
||||
stock_code: Stock ticker symbol
|
||||
|
||||
Returns:
|
||||
Next earnings datetime or None if not found
|
||||
"""
|
||||
now = datetime.now()
|
||||
earnings_events = [
|
||||
event
|
||||
for event in self._events
|
||||
if event.event_type == "EARNINGS"
|
||||
and stock_code.upper() in event.name.upper()
|
||||
and event.datetime > now
|
||||
]
|
||||
|
||||
if not earnings_events:
|
||||
return None
|
||||
|
||||
# Return earliest upcoming earnings
|
||||
earnings_events.sort(key=lambda e: e.datetime)
|
||||
return earnings_events[0].datetime
|
||||
|
||||
def load_hardcoded_events(self) -> None:
|
||||
"""Load hardcoded major economic events for 2026.
|
||||
|
||||
This is a fallback when no API is available.
|
||||
"""
|
||||
# Major FOMC meetings in 2026 (estimated)
|
||||
fomc_dates = [
|
||||
datetime(2026, 3, 18),
|
||||
datetime(2026, 5, 6),
|
||||
datetime(2026, 6, 17),
|
||||
datetime(2026, 7, 29),
|
||||
datetime(2026, 9, 16),
|
||||
datetime(2026, 11, 4),
|
||||
datetime(2026, 12, 16),
|
||||
]
|
||||
|
||||
for date in fomc_dates:
|
||||
self.add_event(
|
||||
EconomicEvent(
|
||||
name="FOMC Meeting",
|
||||
event_type="FOMC",
|
||||
datetime=date,
|
||||
impact="HIGH",
|
||||
country="US",
|
||||
description="Federal Reserve interest rate decision",
|
||||
)
|
||||
)
|
||||
|
||||
# Quarterly GDP releases (estimated)
|
||||
gdp_dates = [
|
||||
datetime(2026, 4, 28),
|
||||
datetime(2026, 7, 30),
|
||||
datetime(2026, 10, 29),
|
||||
]
|
||||
|
||||
for date in gdp_dates:
|
||||
self.add_event(
|
||||
EconomicEvent(
|
||||
name="US GDP Release",
|
||||
event_type="GDP",
|
||||
datetime=date,
|
||||
impact="HIGH",
|
||||
country="US",
|
||||
description="Quarterly GDP growth rate",
|
||||
)
|
||||
)
|
||||
|
||||
# Monthly CPI releases (12th of each month, estimated)
|
||||
for month in range(1, 13):
|
||||
try:
|
||||
cpi_date = datetime(2026, month, 12)
|
||||
self.add_event(
|
||||
EconomicEvent(
|
||||
name="US CPI Release",
|
||||
event_type="CPI",
|
||||
datetime=cpi_date,
|
||||
impact="HIGH",
|
||||
country="US",
|
||||
description="Consumer Price Index inflation data",
|
||||
)
|
||||
)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _impact_level(self, impact: str) -> int:
|
||||
"""Convert impact string to numeric level."""
|
||||
levels = {"LOW": 1, "MEDIUM": 2, "HIGH": 3}
|
||||
return levels.get(impact.upper(), 0)
|
||||
|
||||
def is_high_volatility_period(self, hours_ahead: int = 24) -> bool:
|
||||
"""Check if we're near a high-impact event.
|
||||
|
||||
Args:
|
||||
hours_ahead: Number of hours to look ahead
|
||||
|
||||
Returns:
|
||||
True if high-impact event is imminent
|
||||
"""
|
||||
now = datetime.now()
|
||||
threshold = now + timedelta(hours=hours_ahead)
|
||||
|
||||
for event in self._events:
|
||||
if event.impact == "HIGH" and now <= event.datetime <= threshold:
|
||||
return True
|
||||
|
||||
return False
|
||||
198
src/data/market_data.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""Additional market data indicators beyond basic price data.
|
||||
|
||||
Provides market breadth, sector performance, and market sentiment indicators.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MarketSentiment(Enum):
|
||||
"""Overall market sentiment levels."""
|
||||
|
||||
EXTREME_FEAR = 1
|
||||
FEAR = 2
|
||||
NEUTRAL = 3
|
||||
GREED = 4
|
||||
EXTREME_GREED = 5
|
||||
|
||||
|
||||
@dataclass
|
||||
class SectorPerformance:
|
||||
"""Performance metrics for a market sector."""
|
||||
|
||||
sector_name: str
|
||||
daily_change_pct: float
|
||||
weekly_change_pct: float
|
||||
leader_stock: str # Best performing stock in sector
|
||||
laggard_stock: str # Worst performing stock in sector
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarketBreadth:
|
||||
"""Market breadth indicators."""
|
||||
|
||||
advancing_stocks: int
|
||||
declining_stocks: int
|
||||
unchanged_stocks: int
|
||||
new_highs: int
|
||||
new_lows: int
|
||||
advance_decline_ratio: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarketIndicators:
|
||||
"""Aggregated market indicators."""
|
||||
|
||||
sentiment: MarketSentiment
|
||||
breadth: MarketBreadth
|
||||
sector_performance: list[SectorPerformance]
|
||||
vix_level: float | None # Volatility index if available
|
||||
|
||||
|
||||
class MarketData:
|
||||
"""Market data provider for additional indicators."""
|
||||
|
||||
def __init__(self, api_key: str | None = None) -> None:
|
||||
"""Initialize market data provider.
|
||||
|
||||
Args:
|
||||
api_key: API key for data provider (None for testing)
|
||||
"""
|
||||
self._api_key = api_key
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_market_sentiment(self) -> MarketSentiment:
|
||||
"""Get current market sentiment level.
|
||||
|
||||
This is a simplified version. In production, this would integrate
|
||||
with Fear & Greed Index or similar sentiment indicators.
|
||||
|
||||
Returns:
|
||||
MarketSentiment enum value
|
||||
"""
|
||||
# Default to neutral when API not available
|
||||
if self._api_key is None:
|
||||
logger.debug("No market data API key — returning NEUTRAL sentiment")
|
||||
return MarketSentiment.NEUTRAL
|
||||
|
||||
# TODO: Integrate with actual sentiment API
|
||||
return MarketSentiment.NEUTRAL
|
||||
|
||||
def get_market_breadth(self, market: str = "US") -> MarketBreadth | None:
|
||||
"""Get market breadth indicators.
|
||||
|
||||
Args:
|
||||
market: Market code ("US", "KR", etc.)
|
||||
|
||||
Returns:
|
||||
MarketBreadth object or None if unavailable
|
||||
"""
|
||||
if self._api_key is None:
|
||||
logger.debug("No market data API key — returning None for breadth")
|
||||
return None
|
||||
|
||||
# TODO: Integrate with actual market breadth API
|
||||
return None
|
||||
|
||||
def get_sector_performance(
|
||||
self, market: str = "US"
|
||||
) -> list[SectorPerformance]:
|
||||
"""Get sector performance rankings.
|
||||
|
||||
Args:
|
||||
market: Market code ("US", "KR", etc.)
|
||||
|
||||
Returns:
|
||||
List of SectorPerformance objects, sorted by daily change
|
||||
"""
|
||||
if self._api_key is None:
|
||||
logger.debug("No market data API key — returning empty sector list")
|
||||
return []
|
||||
|
||||
# TODO: Integrate with actual sector performance API
|
||||
return []
|
||||
|
||||
def get_market_indicators(self, market: str = "US") -> MarketIndicators:
|
||||
"""Get aggregated market indicators.
|
||||
|
||||
Args:
|
||||
market: Market code ("US", "KR", etc.)
|
||||
|
||||
Returns:
|
||||
MarketIndicators with all available data
|
||||
"""
|
||||
sentiment = self.get_market_sentiment()
|
||||
breadth = self.get_market_breadth(market)
|
||||
sectors = self.get_sector_performance(market)
|
||||
|
||||
# Default breadth if unavailable
|
||||
if breadth is None:
|
||||
breadth = MarketBreadth(
|
||||
advancing_stocks=0,
|
||||
declining_stocks=0,
|
||||
unchanged_stocks=0,
|
||||
new_highs=0,
|
||||
new_lows=0,
|
||||
advance_decline_ratio=1.0,
|
||||
)
|
||||
|
||||
return MarketIndicators(
|
||||
sentiment=sentiment,
|
||||
breadth=breadth,
|
||||
sector_performance=sectors,
|
||||
vix_level=None, # TODO: Add VIX integration
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helper Methods
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def calculate_fear_greed_score(
|
||||
self, breadth: MarketBreadth, vix: float | None = None
|
||||
) -> int:
|
||||
"""Calculate a simple fear/greed score (0-100).
|
||||
|
||||
Args:
|
||||
breadth: Market breadth data
|
||||
vix: VIX level (optional)
|
||||
|
||||
Returns:
|
||||
Score from 0 (extreme fear) to 100 (extreme greed)
|
||||
"""
|
||||
# Start at neutral
|
||||
score = 50
|
||||
|
||||
# Adjust based on advance/decline ratio
|
||||
if breadth.advance_decline_ratio > 1.5:
|
||||
score += 20
|
||||
elif breadth.advance_decline_ratio > 1.0:
|
||||
score += 10
|
||||
elif breadth.advance_decline_ratio < 0.5:
|
||||
score -= 20
|
||||
elif breadth.advance_decline_ratio < 1.0:
|
||||
score -= 10
|
||||
|
||||
# Adjust based on new highs/lows
|
||||
if breadth.new_highs > breadth.new_lows * 2:
|
||||
score += 15
|
||||
elif breadth.new_lows > breadth.new_highs * 2:
|
||||
score -= 15
|
||||
|
||||
# Adjust based on VIX if available
|
||||
if vix is not None:
|
||||
if vix > 30: # High volatility = fear
|
||||
score -= 15
|
||||
elif vix < 15: # Low volatility = complacency/greed
|
||||
score += 10
|
||||
|
||||
# Clamp to 0-100
|
||||
return max(0, min(100, score))
|
||||
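The fear/greed score above is a set of additive adjustments around a neutral baseline of 50, clamped to 0-100. A small standalone sketch of the same thresholds (the breadth and VIX values below are hypothetical, and the helper re-implements the rules rather than importing MarketData, whose module path is not visible in this hunk):

def fear_greed(adv_decl_ratio: float, new_highs: int, new_lows: int,
               vix: float | None = None) -> int:
    score = 50  # neutral baseline
    # Advance/decline ratio adjustment
    if adv_decl_ratio > 1.5:
        score += 20
    elif adv_decl_ratio > 1.0:
        score += 10
    elif adv_decl_ratio < 0.5:
        score -= 20
    elif adv_decl_ratio < 1.0:
        score -= 10
    # New highs vs. new lows
    if new_highs > new_lows * 2:
        score += 15
    elif new_lows > new_highs * 2:
        score -= 15
    # VIX adjustment
    if vix is not None:
        if vix > 30:
            score -= 15
        elif vix < 15:
            score += 10
    return max(0, min(100, score))

# Broad rally with many new highs and a calm VIX: 50 + 20 + 15 + 10 = 95 (greed)
assert fear_greed(1.8, new_highs=120, new_lows=10, vix=13.5) == 95
# Broad selloff with elevated VIX: 50 - 20 - 15 - 15 = 0 (extreme fear)
assert fear_greed(0.4, new_highs=5, new_lows=80, vix=35.0) == 0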
316 src/data/news_api.py (Normal file)
@@ -0,0 +1,316 @@
|
||||
"""News API integration with sentiment analysis and caching.
|
||||
|
||||
Fetches real-time news for stocks using free-tier APIs (Alpha Vantage or NewsAPI).
|
||||
Includes 5-minute caching to minimize API quota usage.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache entries expire after 5 minutes
|
||||
CACHE_TTL_SECONDS = 300
|
||||
|
||||
|
||||
@dataclass
|
||||
class NewsArticle:
|
||||
"""Single news article with sentiment."""
|
||||
|
||||
title: str
|
||||
summary: str
|
||||
source: str
|
||||
published_at: str
|
||||
sentiment_score: float # -1.0 (negative) to +1.0 (positive)
|
||||
url: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class NewsSentiment:
|
||||
"""Aggregated news sentiment for a stock."""
|
||||
|
||||
stock_code: str
|
||||
articles: list[NewsArticle]
|
||||
avg_sentiment: float # Average sentiment across all articles
|
||||
article_count: int
|
||||
fetched_at: float # Unix timestamp
|
||||
|
||||
|
||||
class NewsAPI:
|
||||
"""News API client with sentiment analysis and caching."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
provider: str = "alphavantage",
|
||||
cache_ttl: int = CACHE_TTL_SECONDS,
|
||||
) -> None:
|
||||
"""Initialize NewsAPI client.
|
||||
|
||||
Args:
|
||||
api_key: API key for the news provider (None for testing)
|
||||
provider: News provider ("alphavantage" or "newsapi")
|
||||
cache_ttl: Cache time-to-live in seconds
|
||||
"""
|
||||
self._api_key = api_key
|
||||
self._provider = provider
|
||||
self._cache_ttl = cache_ttl
|
||||
self._cache: dict[str, NewsSentiment] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_news_sentiment(self, stock_code: str) -> NewsSentiment | None:
|
||||
"""Fetch news sentiment for a stock with caching.
|
||||
|
||||
Args:
|
||||
stock_code: Stock ticker symbol (e.g., "AAPL", "005930")
|
||||
|
||||
Returns:
|
||||
NewsSentiment object or None if fetch fails or API unavailable
|
||||
"""
|
||||
# Check cache first
|
||||
cached = self._get_from_cache(stock_code)
|
||||
if cached is not None:
|
||||
logger.debug("News cache hit for %s", stock_code)
|
||||
return cached
|
||||
|
||||
# API key required for real requests
|
||||
if self._api_key is None:
|
||||
logger.warning("No news API key provided — returning None")
|
||||
return None
|
||||
|
||||
# Fetch from API
|
||||
try:
|
||||
sentiment = await self._fetch_news(stock_code)
|
||||
if sentiment is not None:
|
||||
self._cache[stock_code] = sentiment
|
||||
return sentiment
|
||||
except Exception as exc:
|
||||
logger.error("Failed to fetch news for %s: %s", stock_code, exc)
|
||||
return None
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the news cache (useful for testing)."""
|
||||
self._cache.clear()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Cache Management
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_from_cache(self, stock_code: str) -> NewsSentiment | None:
|
||||
"""Retrieve cached sentiment if not expired."""
|
||||
if stock_code not in self._cache:
|
||||
return None
|
||||
|
||||
cached = self._cache[stock_code]
|
||||
age = time.time() - cached.fetched_at
|
||||
|
||||
if age > self._cache_ttl:
|
||||
logger.debug("News cache expired for %s (age: %.1fs)", stock_code, age)
|
||||
del self._cache[stock_code]
|
||||
return None
|
||||
|
||||
return cached
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# API Fetching
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _fetch_news(self, stock_code: str) -> NewsSentiment | None:
|
||||
"""Fetch news from the provider API."""
|
||||
if self._provider == "alphavantage":
|
||||
return await self._fetch_alphavantage(stock_code)
|
||||
elif self._provider == "newsapi":
|
||||
return await self._fetch_newsapi(stock_code)
|
||||
else:
|
||||
logger.error("Unknown news provider: %s", self._provider)
|
||||
return None
|
||||
|
||||
async def _fetch_alphavantage(self, stock_code: str) -> NewsSentiment | None:
|
||||
"""Fetch news from Alpha Vantage News Sentiment API."""
|
||||
url = "https://www.alphavantage.co/query"
|
||||
params = {
|
||||
"function": "NEWS_SENTIMENT",
|
||||
"tickers": stock_code,
|
||||
"apikey": self._api_key,
|
||||
"limit": 10, # Fetch top 10 articles
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, params=params, timeout=10) as resp:
|
||||
if resp.status != 200:
|
||||
logger.error(
|
||||
"Alpha Vantage API error: HTTP %d", resp.status
|
||||
)
|
||||
return None
|
||||
|
||||
data = await resp.json()
|
||||
return self._parse_alphavantage_response(stock_code, data)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Alpha Vantage request failed: %s", exc)
|
||||
return None
|
||||
|
||||
async def _fetch_newsapi(self, stock_code: str) -> NewsSentiment | None:
|
||||
"""Fetch news from NewsAPI.org."""
|
||||
url = "https://newsapi.org/v2/everything"
|
||||
params = {
|
||||
"q": stock_code,
|
||||
"apiKey": self._api_key,
|
||||
"pageSize": 10,
|
||||
"sortBy": "publishedAt",
|
||||
"language": "en",
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, params=params, timeout=10) as resp:
|
||||
if resp.status != 200:
|
||||
logger.error("NewsAPI error: HTTP %d", resp.status)
|
||||
return None
|
||||
|
||||
data = await resp.json()
|
||||
return self._parse_newsapi_response(stock_code, data)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("NewsAPI request failed: %s", exc)
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Response Parsing
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _parse_alphavantage_response(
|
||||
self, stock_code: str, data: dict[str, Any]
|
||||
) -> NewsSentiment | None:
|
||||
"""Parse Alpha Vantage API response."""
|
||||
if "feed" not in data:
|
||||
logger.warning("No 'feed' key in Alpha Vantage response")
|
||||
return None
|
||||
|
||||
articles: list[NewsArticle] = []
|
||||
for item in data["feed"]:
|
||||
# Extract sentiment for this specific ticker
|
||||
ticker_sentiment = self._extract_ticker_sentiment(item, stock_code)
|
||||
|
||||
article = NewsArticle(
|
||||
title=item.get("title", ""),
|
||||
summary=item.get("summary", "")[:200], # Truncate long summaries
|
||||
source=item.get("source", "Unknown"),
|
||||
published_at=item.get("time_published", ""),
|
||||
sentiment_score=ticker_sentiment,
|
||||
url=item.get("url", ""),
|
||||
)
|
||||
articles.append(article)
|
||||
|
||||
if not articles:
|
||||
return None
|
||||
|
||||
avg_sentiment = sum(a.sentiment_score for a in articles) / len(articles)
|
||||
|
||||
return NewsSentiment(
|
||||
stock_code=stock_code,
|
||||
articles=articles,
|
||||
avg_sentiment=avg_sentiment,
|
||||
article_count=len(articles),
|
||||
fetched_at=time.time(),
|
||||
)
|
||||
|
||||
def _extract_ticker_sentiment(
|
||||
self, item: dict[str, Any], stock_code: str
|
||||
) -> float:
|
||||
"""Extract sentiment score for specific ticker from article."""
|
||||
ticker_sentiments = item.get("ticker_sentiment", [])
|
||||
for ts in ticker_sentiments:
|
||||
if ts.get("ticker", "").upper() == stock_code.upper():
|
||||
# Alpha Vantage provides sentiment_score as string
|
||||
score_str = ts.get("ticker_sentiment_score", "0")
|
||||
try:
|
||||
return float(score_str)
|
||||
except ValueError:
|
||||
return 0.0
|
||||
|
||||
# Fallback to overall sentiment if ticker-specific not found
|
||||
overall_sentiment = item.get("overall_sentiment_score", "0")
|
||||
try:
|
||||
return float(overall_sentiment)
|
||||
except ValueError:
|
||||
return 0.0
|
||||
|
||||
def _parse_newsapi_response(
|
||||
self, stock_code: str, data: dict[str, Any]
|
||||
) -> NewsSentiment | None:
|
||||
"""Parse NewsAPI.org response.
|
||||
|
||||
Note: NewsAPI doesn't provide sentiment scores, so we use a
|
||||
simple heuristic based on title keywords.
|
||||
"""
|
||||
if data.get("status") != "ok" or "articles" not in data:
|
||||
logger.warning("Invalid NewsAPI response")
|
||||
return None
|
||||
|
||||
articles: list[NewsArticle] = []
|
||||
for item in data["articles"]:
|
||||
# Simple sentiment heuristic based on keywords
|
||||
sentiment = self._estimate_sentiment_from_text(
|
||||
item.get("title", "") + " " + item.get("description", "")
|
||||
)
|
||||
|
||||
article = NewsArticle(
|
||||
title=item.get("title", ""),
|
||||
summary=(item.get("description") or "")[:200],
|
||||
source=item.get("source", {}).get("name", "Unknown"),
|
||||
published_at=item.get("publishedAt", ""),
|
||||
sentiment_score=sentiment,
|
||||
url=item.get("url", ""),
|
||||
)
|
||||
articles.append(article)
|
||||
|
||||
if not articles:
|
||||
return None
|
||||
|
||||
avg_sentiment = sum(a.sentiment_score for a in articles) / len(articles)
|
||||
|
||||
return NewsSentiment(
|
||||
stock_code=stock_code,
|
||||
articles=articles,
|
||||
avg_sentiment=avg_sentiment,
|
||||
article_count=len(articles),
|
||||
fetched_at=time.time(),
|
||||
)
|
||||
|
||||
def _estimate_sentiment_from_text(self, text: str) -> float:
|
||||
"""Simple keyword-based sentiment estimation.
|
||||
|
||||
This is a fallback for APIs that don't provide sentiment scores.
|
||||
Returns a score between -1.0 and +1.0.
|
||||
"""
|
||||
text_lower = text.lower()
|
||||
|
||||
positive_keywords = [
|
||||
"surge", "jump", "gain", "rise", "soar", "rally", "profit",
|
||||
"growth", "upgrade", "beat", "strong", "bullish", "breakthrough",
|
||||
]
|
||||
negative_keywords = [
|
||||
"plunge", "fall", "drop", "decline", "crash", "loss", "weak",
|
||||
"downgrade", "miss", "bearish", "concern", "risk", "warning",
|
||||
]
|
||||
|
||||
positive_count = sum(1 for kw in positive_keywords if kw in text_lower)
|
||||
negative_count = sum(1 for kw in negative_keywords if kw in text_lower)
|
||||
|
||||
total = positive_count + negative_count
|
||||
if total == 0:
|
||||
return 0.0
|
||||
|
||||
# Normalize to -1.0 to +1.0 range
|
||||
return (positive_count - negative_count) / total
|
||||
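The keyword heuristic scores text as (positive hits - negative hits) / total hits, so mixed headlines cancel out to 0.0. A minimal sketch with invented headlines (the keyword sets here are trimmed subsets of the lists above):

POSITIVE = {"surge", "jump", "gain", "rally", "beat", "bullish"}
NEGATIVE = {"plunge", "drop", "miss", "bearish", "warning", "loss"}

def estimate_sentiment(text: str) -> float:
    text = text.lower()
    pos = sum(1 for kw in POSITIVE if kw in text)
    neg = sum(1 for kw in NEGATIVE if kw in text)
    total = pos + neg
    return 0.0 if total == 0 else (pos - neg) / total

print(estimate_sentiment("Shares surge after earnings beat"))      # 1.0
print(estimate_sentiment("Stock drops on profit warning"))         # -1.0
print(estimate_sentiment("Rally fades as analysts turn bearish"))  # 0.0 (one hit each way)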
217 src/db.py
@@ -2,9 +2,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
def init_db(db_path: str) -> sqlite3.Connection:
|
||||
@@ -12,6 +14,11 @@ def init_db(db_path: str) -> sqlite3.Connection:
|
||||
if db_path != ":memory:":
|
||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(db_path)
|
||||
# Enable WAL mode for concurrent read/write (dashboard + trading loop).
|
||||
# WAL does not apply to in-memory databases.
|
||||
if db_path != ":memory:":
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA busy_timeout=5000")
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS trades (
|
||||
@@ -25,12 +32,14 @@ def init_db(db_path: str) -> sqlite3.Connection:
|
||||
price REAL,
|
||||
pnl REAL DEFAULT 0.0,
|
||||
market TEXT DEFAULT 'KR',
|
||||
exchange_code TEXT DEFAULT 'KRX'
|
||||
exchange_code TEXT DEFAULT 'KRX',
|
||||
decision_id TEXT,
|
||||
mode TEXT DEFAULT 'paper'
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Migration: Add market and exchange_code columns if they don't exist
|
||||
# Migration: Add columns if they don't exist (backward-compatible schema upgrades)
|
||||
cursor = conn.execute("PRAGMA table_info(trades)")
|
||||
columns = {row[1] for row in cursor.fetchall()}
|
||||
|
||||
@@ -38,6 +47,116 @@ def init_db(db_path: str) -> sqlite3.Connection:
|
||||
conn.execute("ALTER TABLE trades ADD COLUMN market TEXT DEFAULT 'KR'")
|
||||
if "exchange_code" not in columns:
|
||||
conn.execute("ALTER TABLE trades ADD COLUMN exchange_code TEXT DEFAULT 'KRX'")
|
||||
if "selection_context" not in columns:
|
||||
conn.execute("ALTER TABLE trades ADD COLUMN selection_context TEXT")
|
||||
if "decision_id" not in columns:
|
||||
conn.execute("ALTER TABLE trades ADD COLUMN decision_id TEXT")
|
||||
if "mode" not in columns:
|
||||
conn.execute("ALTER TABLE trades ADD COLUMN mode TEXT DEFAULT 'paper'")
|
||||
|
||||
# Context tree tables for multi-layered memory management
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS contexts (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
layer TEXT NOT NULL,
|
||||
timeframe TEXT NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
UNIQUE(layer, timeframe, key)
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Decision logging table for comprehensive audit trail
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS decision_logs (
|
||||
decision_id TEXT PRIMARY KEY,
|
||||
timestamp TEXT NOT NULL,
|
||||
stock_code TEXT NOT NULL,
|
||||
market TEXT NOT NULL,
|
||||
exchange_code TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
confidence INTEGER NOT NULL,
|
||||
rationale TEXT NOT NULL,
|
||||
context_snapshot TEXT NOT NULL,
|
||||
input_data TEXT NOT NULL,
|
||||
outcome_pnl REAL,
|
||||
outcome_accuracy INTEGER,
|
||||
reviewed INTEGER DEFAULT 0,
|
||||
review_notes TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS context_metadata (
|
||||
layer TEXT PRIMARY KEY,
|
||||
description TEXT NOT NULL,
|
||||
retention_days INTEGER,
|
||||
aggregation_source TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Playbook storage for pre-market strategy persistence
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS playbooks (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
date TEXT NOT NULL,
|
||||
market TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
playbook_json TEXT NOT NULL,
|
||||
generated_at TEXT NOT NULL,
|
||||
token_count INTEGER DEFAULT 0,
|
||||
scenario_count INTEGER DEFAULT 0,
|
||||
match_count INTEGER DEFAULT 0,
|
||||
UNIQUE(date, market)
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_playbooks_date ON playbooks(date)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_playbooks_market ON playbooks(market)")
|
||||
|
||||
# Create indices for efficient context queries
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_contexts_layer ON contexts(layer)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_contexts_timeframe ON contexts(timeframe)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_contexts_updated ON contexts(updated_at)")
|
||||
|
||||
# Create indices for efficient decision log queries
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_decision_logs_timestamp ON decision_logs(timestamp)"
|
||||
)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_decision_logs_reviewed ON decision_logs(reviewed)"
|
||||
)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_decision_logs_confidence ON decision_logs(confidence)"
|
||||
)
|
||||
|
||||
# Index for open-position queries (partition by stock_code, market, ordered by timestamp)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_trades_stock_market_ts"
|
||||
" ON trades (stock_code, market, timestamp DESC)"
|
||||
)
|
||||
|
||||
# Lightweight key-value store for trading system runtime metrics (dashboard use only)
|
||||
# Intentionally separate from the AI context tree to preserve separation of concerns.
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS system_metrics (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
return conn
|
||||
@@ -54,15 +173,38 @@ def log_trade(
|
||||
pnl: float = 0.0,
|
||||
market: str = "KR",
|
||||
exchange_code: str = "KRX",
|
||||
selection_context: dict[str, Any] | None = None,
|
||||
decision_id: str | None = None,
|
||||
mode: str = "paper",
|
||||
) -> None:
|
||||
"""Insert a trade record into the database."""
|
||||
"""Insert a trade record into the database.
|
||||
|
||||
Args:
|
||||
conn: Database connection
|
||||
stock_code: Stock code
|
||||
action: Trade action (BUY/SELL/HOLD)
|
||||
confidence: Confidence level (0-100)
|
||||
rationale: AI decision rationale
|
||||
quantity: Number of shares
|
||||
price: Trade price
|
||||
pnl: Profit/loss
|
||||
market: Market code
|
||||
exchange_code: Exchange code
|
||||
selection_context: Scanner selection data (RSI, volume_ratio, signal, score)
|
||||
decision_id: Unique decision identifier for audit linking
|
||||
mode: Trading mode ('paper' or 'live') for data separation
|
||||
"""
|
||||
# Serialize selection context to JSON
|
||||
context_json = json.dumps(selection_context) if selection_context else None
|
||||
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO trades (
|
||||
timestamp, stock_code, action, confidence, rationale,
|
||||
quantity, price, pnl, market, exchange_code
|
||||
quantity, price, pnl, market, exchange_code, selection_context, decision_id,
|
||||
mode
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
datetime.now(UTC).isoformat(),
|
||||
@@ -75,6 +217,71 @@ def log_trade(
|
||||
pnl,
|
||||
market,
|
||||
exchange_code,
|
||||
context_json,
|
||||
decision_id,
|
||||
mode,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def get_latest_buy_trade(
|
||||
conn: sqlite3.Connection, stock_code: str, market: str
|
||||
) -> dict[str, Any] | None:
|
||||
"""Fetch the most recent BUY trade for a stock and market."""
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
SELECT decision_id, price, quantity
|
||||
FROM trades
|
||||
WHERE stock_code = ?
|
||||
AND market = ?
|
||||
AND action = 'BUY'
|
||||
AND decision_id IS NOT NULL
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 1
|
||||
""",
|
||||
(stock_code, market),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
return {"decision_id": row[0], "price": row[1], "quantity": row[2]}
|
||||
|
||||
|
||||
def get_open_position(
|
||||
conn: sqlite3.Connection, stock_code: str, market: str
|
||||
) -> dict[str, Any] | None:
|
||||
"""Return open position if latest trade is BUY, else None."""
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
SELECT action, decision_id, price, quantity
|
||||
FROM trades
|
||||
WHERE stock_code = ?
|
||||
AND market = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 1
|
||||
""",
|
||||
(stock_code, market),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if not row or row[0] != "BUY":
|
||||
return None
|
||||
return {"decision_id": row[1], "price": row[2], "quantity": row[3]}
|
||||
|
||||
|
||||
def get_recent_symbols(
|
||||
conn: sqlite3.Connection, market: str, limit: int = 30
|
||||
) -> list[str]:
|
||||
"""Return recent unique symbols for a market, newest first."""
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
SELECT stock_code, MAX(timestamp) AS last_ts
|
||||
FROM trades
|
||||
WHERE market = ?
|
||||
GROUP BY stock_code
|
||||
ORDER BY last_ts DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(market, limit),
|
||||
)
|
||||
return [row[0] for row in cursor.fetchall() if row and row[0]]
|
||||
|
||||
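With the new columns in place, a trade logged with a decision_id can be recovered later as an open position. A minimal round-trip sketch against an in-memory database (the ticker, prices and ids are invented; keyword arguments follow the parameter names listed in the docstring above):

from src.db import get_open_position, init_db, log_trade

conn = init_db(":memory:")  # WAL pragmas are skipped for in-memory databases

log_trade(
    conn,
    stock_code="AAPL",
    action="BUY",
    confidence=82,
    rationale="breakout above 20-day high",
    quantity=10,
    price=185.5,
    market="US",
    exchange_code="NASDAQ",
    selection_context={"rsi": 61.2, "volume_ratio": 2.4},
    decision_id="d-0001",
    mode="paper",
)

# The latest trade for (AAPL, US) is a BUY, so an open position is returned.
print(get_open_position(conn, "AAPL", "US"))
# {'decision_id': 'd-0001', 'price': 185.5, 'quantity': 10}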
src/evolution/__init__.py
@@ -0,0 +1,23 @@
|
||||
"""Evolution engine for self-improving trading strategies."""
|
||||
|
||||
from src.evolution.ab_test import ABTester, ABTestResult, StrategyPerformance
|
||||
from src.evolution.daily_review import DailyReviewer
|
||||
from src.evolution.optimizer import EvolutionOptimizer
|
||||
from src.evolution.performance_tracker import (
|
||||
PerformanceDashboard,
|
||||
PerformanceTracker,
|
||||
StrategyMetrics,
|
||||
)
|
||||
from src.evolution.scorecard import DailyScorecard
|
||||
|
||||
__all__ = [
|
||||
"EvolutionOptimizer",
|
||||
"ABTester",
|
||||
"ABTestResult",
|
||||
"StrategyPerformance",
|
||||
"PerformanceTracker",
|
||||
"PerformanceDashboard",
|
||||
"StrategyMetrics",
|
||||
"DailyScorecard",
|
||||
"DailyReviewer",
|
||||
]
|
||||
|
||||
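Because the package __init__ above re-exports the main components, callers can import everything from src.evolution directly rather than from the individual modules:

# Equivalent to the module-level imports listed in __all__ above.
from src.evolution import ABTester, DailyReviewer, EvolutionOptimizer, PerformanceTracker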
220 src/evolution/ab_test.py (Normal file)
@@ -0,0 +1,220 @@
|
||||
"""A/B Testing framework for strategy comparison.
|
||||
|
||||
Runs multiple strategies in parallel, tracks their performance,
|
||||
and uses statistical significance testing to determine winners.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import scipy.stats as stats
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StrategyPerformance:
|
||||
"""Performance metrics for a single strategy."""
|
||||
|
||||
strategy_name: str
|
||||
total_trades: int
|
||||
wins: int
|
||||
losses: int
|
||||
total_pnl: float
|
||||
avg_pnl: float
|
||||
win_rate: float
|
||||
sharpe_ratio: float | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ABTestResult:
|
||||
"""Result of an A/B test between two strategies."""
|
||||
|
||||
strategy_a: str
|
||||
strategy_b: str
|
||||
winner: str | None
|
||||
p_value: float
|
||||
confidence_level: float
|
||||
is_significant: bool
|
||||
performance_a: StrategyPerformance
|
||||
performance_b: StrategyPerformance
|
||||
|
||||
|
||||
class ABTester:
|
||||
"""A/B testing framework for comparing trading strategies."""
|
||||
|
||||
def __init__(self, significance_level: float = 0.05) -> None:
|
||||
"""Initialize A/B tester.
|
||||
|
||||
Args:
|
||||
significance_level: P-value threshold for statistical significance (default 0.05)
|
||||
"""
|
||||
self._significance_level = significance_level
|
||||
|
||||
def calculate_performance(
|
||||
self, trades: list[dict[str, Any]], strategy_name: str
|
||||
) -> StrategyPerformance:
|
||||
"""Calculate performance metrics for a strategy.
|
||||
|
||||
Args:
|
||||
trades: List of trade records with pnl values
|
||||
strategy_name: Name of the strategy
|
||||
|
||||
Returns:
|
||||
StrategyPerformance object with calculated metrics
|
||||
"""
|
||||
if not trades:
|
||||
return StrategyPerformance(
|
||||
strategy_name=strategy_name,
|
||||
total_trades=0,
|
||||
wins=0,
|
||||
losses=0,
|
||||
total_pnl=0.0,
|
||||
avg_pnl=0.0,
|
||||
win_rate=0.0,
|
||||
sharpe_ratio=None,
|
||||
)
|
||||
|
||||
total_trades = len(trades)
|
||||
wins = sum(1 for t in trades if t.get("pnl", 0) > 0)
|
||||
losses = sum(1 for t in trades if t.get("pnl", 0) < 0)
|
||||
pnls = [t.get("pnl", 0.0) for t in trades]
|
||||
total_pnl = sum(pnls)
|
||||
avg_pnl = total_pnl / total_trades if total_trades > 0 else 0.0
|
||||
win_rate = (wins / total_trades * 100) if total_trades > 0 else 0.0
|
||||
|
||||
# Calculate Sharpe ratio (risk-adjusted return)
|
||||
sharpe_ratio = None
|
||||
if len(pnls) > 1:
|
||||
mean_return = avg_pnl
|
||||
std_return = (
|
||||
sum((p - mean_return) ** 2 for p in pnls) / (len(pnls) - 1)
|
||||
) ** 0.5
|
||||
if std_return > 0:
|
||||
sharpe_ratio = mean_return / std_return
|
||||
|
||||
return StrategyPerformance(
|
||||
strategy_name=strategy_name,
|
||||
total_trades=total_trades,
|
||||
wins=wins,
|
||||
losses=losses,
|
||||
total_pnl=round(total_pnl, 2),
|
||||
avg_pnl=round(avg_pnl, 2),
|
||||
win_rate=round(win_rate, 2),
|
||||
sharpe_ratio=round(sharpe_ratio, 4) if sharpe_ratio is not None else None,
|
||||
)
|
||||
|
||||
def compare_strategies(
|
||||
self,
|
||||
trades_a: list[dict[str, Any]],
|
||||
trades_b: list[dict[str, Any]],
|
||||
strategy_a_name: str = "Strategy A",
|
||||
strategy_b_name: str = "Strategy B",
|
||||
) -> ABTestResult:
|
||||
"""Compare two strategies using statistical testing.
|
||||
|
||||
Uses Welch's two-sample t-test to determine whether the performance difference is significant.
|
||||
|
||||
Args:
|
||||
trades_a: List of trades from strategy A
|
||||
trades_b: List of trades from strategy B
|
||||
strategy_a_name: Name of strategy A
|
||||
strategy_b_name: Name of strategy B
|
||||
|
||||
Returns:
|
||||
ABTestResult with comparison details
|
||||
"""
|
||||
perf_a = self.calculate_performance(trades_a, strategy_a_name)
|
||||
perf_b = self.calculate_performance(trades_b, strategy_b_name)
|
||||
|
||||
# Extract PnL arrays for statistical testing
|
||||
pnls_a = [t.get("pnl", 0.0) for t in trades_a]
|
||||
pnls_b = [t.get("pnl", 0.0) for t in trades_b]
|
||||
|
||||
# Perform two-sample t-test
|
||||
if len(pnls_a) > 1 and len(pnls_b) > 1:
|
||||
t_stat, p_value = stats.ttest_ind(pnls_a, pnls_b, equal_var=False)
|
||||
is_significant = p_value < self._significance_level
|
||||
confidence_level = (1 - p_value) * 100
|
||||
else:
|
||||
# Not enough data for statistical test
|
||||
p_value = 1.0
|
||||
is_significant = False
|
||||
confidence_level = 0.0
|
||||
|
||||
# Determine winner based on average PnL
|
||||
winner = None
|
||||
if is_significant:
|
||||
if perf_a.avg_pnl > perf_b.avg_pnl:
|
||||
winner = strategy_a_name
|
||||
elif perf_b.avg_pnl > perf_a.avg_pnl:
|
||||
winner = strategy_b_name
|
||||
|
||||
return ABTestResult(
|
||||
strategy_a=strategy_a_name,
|
||||
strategy_b=strategy_b_name,
|
||||
winner=winner,
|
||||
p_value=round(p_value, 4),
|
||||
confidence_level=round(confidence_level, 2),
|
||||
is_significant=is_significant,
|
||||
performance_a=perf_a,
|
||||
performance_b=perf_b,
|
||||
)
|
||||
|
||||
def should_deploy(
|
||||
self,
|
||||
result: ABTestResult,
|
||||
min_win_rate: float = 60.0,
|
||||
min_trades: int = 20,
|
||||
) -> bool:
|
||||
"""Determine if a winning strategy should be deployed.
|
||||
|
||||
Args:
|
||||
result: A/B test result
|
||||
min_win_rate: Minimum win rate percentage for deployment (default 60%)
|
||||
min_trades: Minimum number of trades required (default 20)
|
||||
|
||||
Returns:
|
||||
True if the winning strategy meets deployment criteria
|
||||
"""
|
||||
if not result.is_significant or result.winner is None:
|
||||
return False
|
||||
|
||||
# Get performance of winning strategy
|
||||
if result.winner == result.strategy_a:
|
||||
winning_perf = result.performance_a
|
||||
else:
|
||||
winning_perf = result.performance_b
|
||||
|
||||
# Check deployment criteria
|
||||
has_enough_trades = winning_perf.total_trades >= min_trades
|
||||
has_good_win_rate = winning_perf.win_rate >= min_win_rate
|
||||
is_profitable = winning_perf.avg_pnl > 0
|
||||
|
||||
meets_criteria = has_enough_trades and has_good_win_rate and is_profitable
|
||||
|
||||
if meets_criteria:
|
||||
logger.info(
|
||||
"Strategy '%s' meets deployment criteria: "
|
||||
"win_rate=%.2f%%, trades=%d, avg_pnl=%.2f",
|
||||
result.winner,
|
||||
winning_perf.win_rate,
|
||||
winning_perf.total_trades,
|
||||
winning_perf.avg_pnl,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Strategy '%s' does NOT meet deployment criteria: "
|
||||
"win_rate=%.2f%% (min %.2f%%), trades=%d (min %d), avg_pnl=%.2f",
|
||||
result.winner if result.winner else "unknown",
|
||||
winning_perf.win_rate if result.winner else 0.0,
|
||||
min_win_rate,
|
||||
winning_perf.total_trades if result.winner else 0,
|
||||
min_trades,
|
||||
winning_perf.avg_pnl if result.winner else 0.0,
|
||||
)
|
||||
|
||||
return meets_criteria
|
||||
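A quick end-to-end sketch of the framework: build two lists of trade dicts (only the "pnl" key is read), compare them, then apply the deployment gate. The PnL values and strategy names below are invented:

from src.evolution.ab_test import ABTester

tester = ABTester(significance_level=0.05)

trades_a = [{"pnl": p} for p in (12.0, 8.5, 15.0, -3.0, 10.0, 9.5)]
trades_b = [{"pnl": p} for p in (-6.0, -4.5, 2.0, -8.0, -1.5, -5.0)]

result = tester.compare_strategies(trades_a, trades_b, "momentum", "mean_revert")
print(result.winner, result.p_value, result.is_significant)

# The winner must also clear win-rate, sample-size, and profitability checks.
print(tester.should_deploy(result, min_win_rate=60.0, min_trades=5))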
196 src/evolution/daily_review.py (Normal file)
@@ -0,0 +1,196 @@
|
||||
"""Daily review generator for market-scoped end-of-day scorecards."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import sqlite3
|
||||
from dataclasses import asdict
|
||||
|
||||
from src.brain.gemini_client import GeminiClient
|
||||
from src.context.layer import ContextLayer
|
||||
from src.context.store import ContextStore
|
||||
from src.evolution.scorecard import DailyScorecard
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DailyReviewer:
|
||||
"""Builds daily scorecards and optional AI-generated lessons."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conn: sqlite3.Connection,
|
||||
context_store: ContextStore,
|
||||
gemini_client: GeminiClient | None = None,
|
||||
) -> None:
|
||||
self._conn = conn
|
||||
self._context_store = context_store
|
||||
self._gemini = gemini_client
|
||||
|
||||
def generate_scorecard(self, date: str, market: str) -> DailyScorecard:
|
||||
"""Generate a market-scoped scorecard from decision logs and trades."""
|
||||
decision_rows = self._conn.execute(
|
||||
"""
|
||||
SELECT action, confidence, context_snapshot
|
||||
FROM decision_logs
|
||||
WHERE DATE(timestamp) = ? AND market = ?
|
||||
""",
|
||||
(date, market),
|
||||
).fetchall()
|
||||
|
||||
total_decisions = len(decision_rows)
|
||||
buys = sum(1 for row in decision_rows if row[0] == "BUY")
|
||||
sells = sum(1 for row in decision_rows if row[0] == "SELL")
|
||||
holds = sum(1 for row in decision_rows if row[0] == "HOLD")
|
||||
avg_confidence = (
|
||||
round(sum(int(row[1]) for row in decision_rows) / total_decisions, 2)
|
||||
if total_decisions > 0
|
||||
else 0.0
|
||||
)
|
||||
|
||||
matched = 0
|
||||
for row in decision_rows:
|
||||
try:
|
||||
snapshot = json.loads(row[2]) if row[2] else {}
|
||||
except json.JSONDecodeError:
|
||||
snapshot = {}
|
||||
scenario_match = snapshot.get("scenario_match", {})
|
||||
if isinstance(scenario_match, dict) and scenario_match:
|
||||
matched += 1
|
||||
scenario_match_rate = (
|
||||
round((matched / total_decisions) * 100, 2)
|
||||
if total_decisions
|
||||
else 0.0
|
||||
)
|
||||
|
||||
trade_stats = self._conn.execute(
|
||||
"""
|
||||
SELECT
|
||||
COALESCE(SUM(pnl), 0.0),
|
||||
SUM(CASE WHEN pnl > 0 THEN 1 ELSE 0 END),
|
||||
SUM(CASE WHEN pnl < 0 THEN 1 ELSE 0 END)
|
||||
FROM trades
|
||||
WHERE DATE(timestamp) = ? AND market = ?
|
||||
""",
|
||||
(date, market),
|
||||
).fetchone()
|
||||
total_pnl = round(float(trade_stats[0] or 0.0), 2) if trade_stats else 0.0
|
||||
wins = int(trade_stats[1] or 0) if trade_stats else 0
|
||||
losses = int(trade_stats[2] or 0) if trade_stats else 0
|
||||
win_rate = round((wins / (wins + losses)) * 100, 2) if (wins + losses) > 0 else 0.0
|
||||
|
||||
top_winners = [
|
||||
row[0]
|
||||
for row in self._conn.execute(
|
||||
"""
|
||||
SELECT stock_code, SUM(pnl) AS stock_pnl
|
||||
FROM trades
|
||||
WHERE DATE(timestamp) = ? AND market = ?
|
||||
GROUP BY stock_code
|
||||
HAVING stock_pnl > 0
|
||||
ORDER BY stock_pnl DESC
|
||||
LIMIT 3
|
||||
""",
|
||||
(date, market),
|
||||
).fetchall()
|
||||
]
|
||||
|
||||
top_losers = [
|
||||
row[0]
|
||||
for row in self._conn.execute(
|
||||
"""
|
||||
SELECT stock_code, SUM(pnl) AS stock_pnl
|
||||
FROM trades
|
||||
WHERE DATE(timestamp) = ? AND market = ?
|
||||
GROUP BY stock_code
|
||||
HAVING stock_pnl < 0
|
||||
ORDER BY stock_pnl ASC
|
||||
LIMIT 3
|
||||
""",
|
||||
(date, market),
|
||||
).fetchall()
|
||||
]
|
||||
|
||||
return DailyScorecard(
|
||||
date=date,
|
||||
market=market,
|
||||
total_decisions=total_decisions,
|
||||
buys=buys,
|
||||
sells=sells,
|
||||
holds=holds,
|
||||
total_pnl=total_pnl,
|
||||
win_rate=win_rate,
|
||||
avg_confidence=avg_confidence,
|
||||
scenario_match_rate=scenario_match_rate,
|
||||
top_winners=top_winners,
|
||||
top_losers=top_losers,
|
||||
lessons=[],
|
||||
cross_market_note="",
|
||||
)
|
||||
|
||||
async def generate_lessons(self, scorecard: DailyScorecard) -> list[str]:
|
||||
"""Generate concise lessons from scorecard metrics using Gemini."""
|
||||
if self._gemini is None:
|
||||
return []
|
||||
|
||||
prompt = (
|
||||
"You are a trading performance reviewer.\n"
|
||||
"Return ONLY a JSON array of 1-3 short lessons in English.\n"
|
||||
f"Market: {scorecard.market}\n"
|
||||
f"Date: {scorecard.date}\n"
|
||||
f"Total decisions: {scorecard.total_decisions}\n"
|
||||
f"Buys/Sells/Holds: {scorecard.buys}/{scorecard.sells}/{scorecard.holds}\n"
|
||||
f"Total PnL: {scorecard.total_pnl}\n"
|
||||
f"Win rate: {scorecard.win_rate}%\n"
|
||||
f"Average confidence: {scorecard.avg_confidence}\n"
|
||||
f"Scenario match rate: {scorecard.scenario_match_rate}%\n"
|
||||
f"Top winners: {', '.join(scorecard.top_winners) or 'N/A'}\n"
|
||||
f"Top losers: {', '.join(scorecard.top_losers) or 'N/A'}\n"
|
||||
)
|
||||
|
||||
try:
|
||||
decision = await self._gemini.decide(
|
||||
{
|
||||
"stock_code": "REVIEW",
|
||||
"market_name": scorecard.market,
|
||||
"current_price": 0,
|
||||
"prompt_override": prompt,
|
||||
}
|
||||
)
|
||||
return self._parse_lessons(decision.rationale)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to generate daily lessons: %s", exc)
|
||||
return []
|
||||
|
||||
def store_scorecard_in_context(self, scorecard: DailyScorecard) -> None:
|
||||
"""Store scorecard in L6 using market-scoped key."""
|
||||
self._context_store.set_context(
|
||||
ContextLayer.L6_DAILY,
|
||||
scorecard.date,
|
||||
f"scorecard_{scorecard.market}",
|
||||
asdict(scorecard),
|
||||
)
|
||||
|
||||
def _parse_lessons(self, raw_text: str) -> list[str]:
|
||||
"""Parse lessons from JSON array response or fallback text."""
|
||||
raw_text = raw_text.strip()
|
||||
try:
|
||||
parsed = json.loads(raw_text)
|
||||
if isinstance(parsed, list):
|
||||
return [str(item).strip() for item in parsed if str(item).strip()][:3]
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
match = re.search(r"\[.*\]", raw_text, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
parsed = json.loads(match.group(0))
|
||||
if isinstance(parsed, list):
|
||||
return [str(item).strip() for item in parsed if str(item).strip()][:3]
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
lines = [line.strip("-* \t") for line in raw_text.splitlines() if line.strip()]
|
||||
return lines[:3]
|
||||
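The scorecard fields are plain aggregates over the day's decision logs: avg_confidence is the mean confidence, and scenario_match_rate is the share of decisions whose context snapshot carried a non-empty scenario_match. A small worked example with invented decisions (the "matched" flag stands in for that snapshot check):

decisions = [
    {"action": "BUY",  "confidence": 85, "matched": True},
    {"action": "BUY",  "confidence": 70, "matched": False},
    {"action": "SELL", "confidence": 90, "matched": True},
    {"action": "HOLD", "confidence": 60, "matched": False},
]

avg_confidence = round(sum(d["confidence"] for d in decisions) / len(decisions), 2)
scenario_match_rate = round(100 * sum(d["matched"] for d in decisions) / len(decisions), 2)

print(avg_confidence, scenario_match_rate)  # 76.25 50.0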
src/evolution/optimizer.py
@@ -1,10 +1,10 @@
|
||||
"""Evolution Engine — analyzes trade logs and generates new strategies.
|
||||
|
||||
This module:
|
||||
1. Reads trade_logs.db to identify failing patterns
|
||||
2. Asks Gemini to generate a new strategy class
|
||||
3. Runs pytest on the generated file
|
||||
4. Creates a simulated PR if tests pass
|
||||
1. Uses DecisionLogger.get_losing_decisions() to identify failing patterns
|
||||
2. Analyzes failure patterns by time, market conditions, stock characteristics
|
||||
3. Asks Gemini to generate improved strategy recommendations
|
||||
4. Generates new strategy classes with enhanced decision-making logic
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -14,6 +14,7 @@ import logging
|
||||
import sqlite3
|
||||
import subprocess
|
||||
import textwrap
|
||||
from collections import Counter
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -21,6 +22,8 @@ from typing import Any
|
||||
from google import genai
|
||||
|
||||
from src.config import Settings
|
||||
from src.db import init_db
|
||||
from src.logging.decision_logger import DecisionLogger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -53,29 +56,105 @@ class EvolutionOptimizer:
|
||||
self._db_path = settings.DB_PATH
|
||||
self._client = genai.Client(api_key=settings.GEMINI_API_KEY)
|
||||
self._model_name = settings.GEMINI_MODEL
|
||||
self._conn = init_db(self._db_path)
|
||||
self._decision_logger = DecisionLogger(self._conn)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Analysis
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def analyze_failures(self, limit: int = 50) -> list[dict[str, Any]]:
|
||||
"""Find trades where high confidence led to losses."""
|
||||
conn = sqlite3.connect(self._db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
try:
|
||||
rows = conn.execute(
|
||||
"""Find high-confidence decisions that resulted in losses.
|
||||
|
||||
Uses DecisionLogger.get_losing_decisions() to retrieve failures.
|
||||
"""
|
||||
SELECT stock_code, action, confidence, pnl, rationale, timestamp
|
||||
FROM trades
|
||||
WHERE confidence >= 80 AND pnl < 0
|
||||
ORDER BY pnl ASC
|
||||
LIMIT ?
|
||||
""",
|
||||
(limit,),
|
||||
).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
finally:
|
||||
conn.close()
|
||||
losing_decisions = self._decision_logger.get_losing_decisions(
|
||||
min_confidence=80, min_loss=-100.0
|
||||
)
|
||||
|
||||
# Limit results
|
||||
if len(losing_decisions) > limit:
|
||||
losing_decisions = losing_decisions[:limit]
|
||||
|
||||
# Convert to dict format for analysis
|
||||
failures = []
|
||||
for decision in losing_decisions:
|
||||
failures.append({
|
||||
"decision_id": decision.decision_id,
|
||||
"timestamp": decision.timestamp,
|
||||
"stock_code": decision.stock_code,
|
||||
"market": decision.market,
|
||||
"exchange_code": decision.exchange_code,
|
||||
"action": decision.action,
|
||||
"confidence": decision.confidence,
|
||||
"rationale": decision.rationale,
|
||||
"outcome_pnl": decision.outcome_pnl,
|
||||
"outcome_accuracy": decision.outcome_accuracy,
|
||||
"context_snapshot": decision.context_snapshot,
|
||||
"input_data": decision.input_data,
|
||||
})
|
||||
|
||||
return failures
|
||||
|
||||
def identify_failure_patterns(
|
||||
self, failures: list[dict[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
"""Identify patterns in losing decisions.
|
||||
|
||||
Analyzes:
|
||||
- Time patterns (hour of day, day of week)
|
||||
- Market conditions (volatility, volume)
|
||||
- Stock characteristics (price range, market)
|
||||
- Common failure modes in rationale
|
||||
"""
|
||||
if not failures:
|
||||
return {"pattern_count": 0, "patterns": {}}
|
||||
|
||||
patterns = {
|
||||
"markets": Counter(),
|
||||
"actions": Counter(),
|
||||
"hours": Counter(),
|
||||
"avg_confidence": 0.0,
|
||||
"avg_loss": 0.0,
|
||||
"total_failures": len(failures),
|
||||
}
|
||||
|
||||
total_confidence = 0
|
||||
total_loss = 0.0
|
||||
|
||||
for failure in failures:
|
||||
# Market distribution
|
||||
patterns["markets"][failure.get("market", "UNKNOWN")] += 1
|
||||
|
||||
# Action distribution
|
||||
patterns["actions"][failure.get("action", "UNKNOWN")] += 1
|
||||
|
||||
# Time pattern (extract hour from ISO timestamp)
|
||||
timestamp = failure.get("timestamp", "")
|
||||
if timestamp:
|
||||
try:
|
||||
dt = datetime.fromisoformat(timestamp)
|
||||
patterns["hours"][dt.hour] += 1
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
# Aggregate metrics
|
||||
total_confidence += failure.get("confidence", 0)
|
||||
total_loss += failure.get("outcome_pnl", 0.0)
|
||||
|
||||
patterns["avg_confidence"] = (
|
||||
round(total_confidence / len(failures), 2) if failures else 0.0
|
||||
)
|
||||
patterns["avg_loss"] = (
|
||||
round(total_loss / len(failures), 2) if failures else 0.0
|
||||
)
|
||||
|
||||
# Convert Counters to regular dicts for JSON serialization
|
||||
patterns["markets"] = dict(patterns["markets"])
|
||||
patterns["actions"] = dict(patterns["actions"])
|
||||
patterns["hours"] = dict(patterns["hours"])
|
||||
|
||||
return patterns
|
||||
|
||||
def get_performance_summary(self) -> dict[str, Any]:
|
||||
"""Return aggregate performance metrics from trade logs."""
|
||||
@@ -109,14 +188,25 @@ class EvolutionOptimizer:
|
||||
async def generate_strategy(self, failures: list[dict[str, Any]]) -> Path | None:
|
||||
"""Ask Gemini to generate a new strategy based on failure analysis.
|
||||
|
||||
Integrates failure patterns and market conditions to create improved strategies.
|
||||
Returns the path to the generated strategy file, or None on failure.
|
||||
"""
|
||||
# Identify failure patterns first
|
||||
patterns = self.identify_failure_patterns(failures)
|
||||
|
||||
prompt = (
|
||||
"You are a quantitative trading strategy developer.\n"
|
||||
"Analyze these failed trades and generate an improved strategy.\n\n"
|
||||
f"Failed trades:\n{json.dumps(failures, indent=2, default=str)}\n\n"
|
||||
"Generate a Python class that inherits from BaseStrategy.\n"
|
||||
"The class must have an `evaluate(self, market_data: dict) -> dict` method.\n"
|
||||
"Analyze these failed trades and their patterns, then generate an improved strategy.\n\n"
|
||||
f"Failure Patterns:\n{json.dumps(patterns, indent=2)}\n\n"
|
||||
f"Sample Failed Trades (first 5):\n"
|
||||
f"{json.dumps(failures[:5], indent=2, default=str)}\n\n"
|
||||
"Based on these patterns, generate an improved trading strategy.\n"
|
||||
"The strategy should:\n"
|
||||
"1. Avoid the identified failure patterns\n"
|
||||
"2. Consider market-specific conditions\n"
|
||||
"3. Adjust confidence based on historical performance\n\n"
|
||||
"Generate a Python method body that inherits from BaseStrategy.\n"
|
||||
"The method signature is: evaluate(self, market_data: dict) -> dict\n"
|
||||
"The method must return a dict with keys: action, confidence, rationale.\n"
|
||||
"Respond with ONLY the method body (Python code), no class definition.\n"
|
||||
)
|
||||
@@ -147,10 +237,15 @@ class EvolutionOptimizer:
|
||||
# Indent the body for the class method
|
||||
indented_body = textwrap.indent(body, " ")
|
||||
|
||||
# Generate rationale from patterns
|
||||
rationale = f"Auto-evolved from {len(failures)} failures. "
|
||||
rationale += f"Primary failure markets: {list(patterns.get('markets', {}).keys())}. "
|
||||
rationale += f"Average loss: {patterns.get('avg_loss', 0.0)}"
|
||||
|
||||
content = STRATEGY_TEMPLATE.format(
|
||||
name=version,
|
||||
timestamp=datetime.now(UTC).isoformat(),
|
||||
rationale="Auto-evolved from failure analysis",
|
||||
rationale=rationale,
|
||||
class_name=class_name,
|
||||
body=indented_body.strip(),
|
||||
)
|
||||
|
||||
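The optimizer asks Gemini for only the body of evaluate() and then splices it into a class template via textwrap.indent. A sketch of that splicing step with a hypothetical stand-in template and an invented body (the real STRATEGY_TEMPLATE is not shown in this hunk and presumably differs in indentation and fields):

import textwrap

# Stand-in template (hypothetical); the repository's STRATEGY_TEMPLATE is assumed to be similar.
TEMPLATE = '''"""{rationale}"""


class {class_name}:
    def evaluate(self, market_data: dict) -> dict:
{body}
'''

# A method body such as the model might return (invented for illustration).
body = (
    'if market_data.get("rsi", 50) > 70:\n'
    '    return {"action": "SELL", "confidence": 75, "rationale": "overbought"}\n'
    'return {"action": "HOLD", "confidence": 55, "rationale": "no edge"}\n'
)

print(TEMPLATE.format(
    rationale="Auto-evolved from 12 failures (example values).",
    class_name="StrategyV2",
    body=textwrap.indent(body, " " * 8).rstrip(),
))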
303 src/evolution/performance_tracker.py (Normal file)
@@ -0,0 +1,303 @@
|
||||
"""Performance tracking system for strategy monitoring.
|
||||
|
||||
Tracks win rates, monitors improvement over time,
|
||||
and provides a performance metrics dashboard.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StrategyMetrics:
|
||||
"""Performance metrics for a strategy over a time period."""
|
||||
|
||||
strategy_name: str
|
||||
period_start: str
|
||||
period_end: str
|
||||
total_trades: int
|
||||
wins: int
|
||||
losses: int
|
||||
holds: int
|
||||
win_rate: float
|
||||
avg_pnl: float
|
||||
total_pnl: float
|
||||
best_trade: float
|
||||
worst_trade: float
|
||||
avg_confidence: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class PerformanceDashboard:
|
||||
"""Comprehensive performance dashboard."""
|
||||
|
||||
generated_at: str
|
||||
overall_metrics: StrategyMetrics
|
||||
daily_metrics: list[StrategyMetrics]
|
||||
weekly_metrics: list[StrategyMetrics]
|
||||
improvement_trend: dict[str, Any]
|
||||
|
||||
|
||||
class PerformanceTracker:
|
||||
"""Tracks and monitors strategy performance over time."""
|
||||
|
||||
def __init__(self, db_path: str) -> None:
|
||||
"""Initialize performance tracker.
|
||||
|
||||
Args:
|
||||
db_path: Path to the trade logs database
|
||||
"""
|
||||
self._db_path = db_path
|
||||
|
||||
def get_strategy_metrics(
|
||||
self,
|
||||
strategy_name: str | None = None,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
) -> StrategyMetrics:
|
||||
"""Get performance metrics for a strategy over a time period.
|
||||
|
||||
Args:
|
||||
strategy_name: Name of the strategy (None = all strategies)
|
||||
start_date: Start date in ISO format (None = beginning of time)
|
||||
end_date: End date in ISO format (None = now)
|
||||
|
||||
Returns:
|
||||
StrategyMetrics object with performance data
|
||||
"""
|
||||
conn = sqlite3.connect(self._db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
try:
|
||||
# Build query with optional filters
|
||||
query = """
|
||||
SELECT
|
||||
COUNT(*) as total_trades,
|
||||
SUM(CASE WHEN pnl > 0 THEN 1 ELSE 0 END) as wins,
|
||||
SUM(CASE WHEN pnl < 0 THEN 1 ELSE 0 END) as losses,
|
||||
SUM(CASE WHEN action = 'HOLD' THEN 1 ELSE 0 END) as holds,
|
||||
COALESCE(AVG(CASE WHEN pnl IS NOT NULL THEN pnl END), 0) as avg_pnl,
|
||||
COALESCE(SUM(CASE WHEN pnl IS NOT NULL THEN pnl ELSE 0 END), 0) as total_pnl,
|
||||
COALESCE(MAX(pnl), 0) as best_trade,
|
||||
COALESCE(MIN(pnl), 0) as worst_trade,
|
||||
COALESCE(AVG(confidence), 0) as avg_confidence,
|
||||
MIN(timestamp) as period_start,
|
||||
MAX(timestamp) as period_end
|
||||
FROM trades
|
||||
WHERE 1=1
|
||||
"""
|
||||
params: list[Any] = []
|
||||
|
||||
if start_date:
|
||||
query += " AND timestamp >= ?"
|
||||
params.append(start_date)
|
||||
|
||||
if end_date:
|
||||
query += " AND timestamp <= ?"
|
||||
params.append(end_date)
|
||||
|
||||
# Note: Currently trades table doesn't have strategy_name column
|
||||
# This is a placeholder for future extension
|
||||
|
||||
row = conn.execute(query, params).fetchone()
|
||||
|
||||
total_trades = row["total_trades"] or 0
|
||||
wins = row["wins"] or 0
|
||||
win_rate = (wins / total_trades * 100) if total_trades > 0 else 0.0
|
||||
|
||||
return StrategyMetrics(
|
||||
strategy_name=strategy_name or "default",
|
||||
period_start=row["period_start"] or "",
|
||||
period_end=row["period_end"] or "",
|
||||
total_trades=total_trades,
|
||||
wins=wins,
|
||||
losses=row["losses"] or 0,
|
||||
holds=row["holds"] or 0,
|
||||
win_rate=round(win_rate, 2),
|
||||
avg_pnl=round(row["avg_pnl"], 2),
|
||||
total_pnl=round(row["total_pnl"], 2),
|
||||
best_trade=round(row["best_trade"], 2),
|
||||
worst_trade=round(row["worst_trade"], 2),
|
||||
avg_confidence=round(row["avg_confidence"], 2),
|
||||
)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_daily_metrics(
|
||||
self, days: int = 7, strategy_name: str | None = None
|
||||
) -> list[StrategyMetrics]:
|
||||
"""Get daily performance metrics for the last N days.
|
||||
|
||||
Args:
|
||||
days: Number of days to retrieve (default 7)
|
||||
strategy_name: Name of the strategy (None = all strategies)
|
||||
|
||||
Returns:
|
||||
List of StrategyMetrics, one per day
|
||||
"""
|
||||
metrics = []
|
||||
end_date = datetime.now(UTC)
|
||||
|
||||
for i in range(days):
|
||||
day_end = end_date - timedelta(days=i)
|
||||
day_start = day_end - timedelta(days=1)
|
||||
|
||||
day_metrics = self.get_strategy_metrics(
|
||||
strategy_name=strategy_name,
|
||||
start_date=day_start.isoformat(),
|
||||
end_date=day_end.isoformat(),
|
||||
)
|
||||
metrics.append(day_metrics)
|
||||
|
||||
return metrics
|
||||
|
||||
def get_weekly_metrics(
|
||||
self, weeks: int = 4, strategy_name: str | None = None
|
||||
) -> list[StrategyMetrics]:
|
||||
"""Get weekly performance metrics for the last N weeks.
|
||||
|
||||
Args:
|
||||
weeks: Number of weeks to retrieve (default 4)
|
||||
strategy_name: Name of the strategy (None = all strategies)
|
||||
|
||||
Returns:
|
||||
List of StrategyMetrics, one per week
|
||||
"""
|
||||
metrics = []
|
||||
end_date = datetime.now(UTC)
|
||||
|
||||
for i in range(weeks):
|
||||
week_end = end_date - timedelta(weeks=i)
|
||||
week_start = week_end - timedelta(weeks=1)
|
||||
|
||||
week_metrics = self.get_strategy_metrics(
|
||||
strategy_name=strategy_name,
|
||||
start_date=week_start.isoformat(),
|
||||
end_date=week_end.isoformat(),
|
||||
)
|
||||
metrics.append(week_metrics)
|
||||
|
||||
return metrics
|
||||
|
||||
def calculate_improvement_trend(
|
||||
self, metrics_history: list[StrategyMetrics]
|
||||
) -> dict[str, Any]:
|
||||
"""Calculate improvement trend from historical metrics.
|
||||
|
||||
Args:
|
||||
metrics_history: List of StrategyMetrics ordered from oldest to newest
|
||||
|
||||
Returns:
|
||||
Dictionary with trend analysis
|
||||
"""
|
||||
if len(metrics_history) < 2:
|
||||
return {
|
||||
"trend": "insufficient_data",
|
||||
"win_rate_change": 0.0,
|
||||
"pnl_change": 0.0,
|
||||
"confidence_change": 0.0,
|
||||
}
|
||||
|
||||
oldest = metrics_history[0]
|
||||
newest = metrics_history[-1]
|
||||
|
||||
win_rate_change = newest.win_rate - oldest.win_rate
|
||||
pnl_change = newest.avg_pnl - oldest.avg_pnl
|
||||
confidence_change = newest.avg_confidence - oldest.avg_confidence
|
||||
|
||||
# Determine overall trend
|
||||
if win_rate_change > 5.0 and pnl_change > 0:
|
||||
trend = "improving"
|
||||
elif win_rate_change < -5.0 or pnl_change < 0:
|
||||
trend = "declining"
|
||||
else:
|
||||
trend = "stable"
|
||||
|
||||
return {
|
||||
"trend": trend,
|
||||
"win_rate_change": round(win_rate_change, 2),
|
||||
"pnl_change": round(pnl_change, 2),
|
||||
"confidence_change": round(confidence_change, 2),
|
||||
"period_count": len(metrics_history),
|
||||
}
|
||||
|
||||
def generate_dashboard(
|
||||
self, strategy_name: str | None = None
|
||||
) -> PerformanceDashboard:
|
||||
"""Generate a comprehensive performance dashboard.
|
||||
|
||||
Args:
|
||||
strategy_name: Name of the strategy (None = all strategies)
|
||||
|
||||
Returns:
|
||||
PerformanceDashboard with all metrics
|
||||
"""
|
||||
# Get overall metrics
|
||||
overall_metrics = self.get_strategy_metrics(strategy_name=strategy_name)
|
||||
|
||||
# Get daily metrics (last 7 days)
|
||||
daily_metrics = self.get_daily_metrics(days=7, strategy_name=strategy_name)
|
||||
|
||||
# Get weekly metrics (last 4 weeks)
|
||||
weekly_metrics = self.get_weekly_metrics(weeks=4, strategy_name=strategy_name)
|
||||
|
||||
# Calculate improvement trend
|
||||
improvement_trend = self.calculate_improvement_trend(weekly_metrics[::-1])
|
||||
|
||||
return PerformanceDashboard(
|
||||
generated_at=datetime.now(UTC).isoformat(),
|
||||
overall_metrics=overall_metrics,
|
||||
daily_metrics=daily_metrics,
|
||||
weekly_metrics=weekly_metrics,
|
||||
improvement_trend=improvement_trend,
|
||||
)
|
||||
|
||||
def export_dashboard_json(
|
||||
self, dashboard: PerformanceDashboard
|
||||
) -> str:
|
||||
"""Export dashboard as JSON string.
|
||||
|
||||
Args:
|
||||
dashboard: PerformanceDashboard object
|
||||
|
||||
Returns:
|
||||
JSON string representation
|
||||
"""
|
||||
data = {
|
||||
"generated_at": dashboard.generated_at,
|
||||
"overall_metrics": asdict(dashboard.overall_metrics),
|
||||
"daily_metrics": [asdict(m) for m in dashboard.daily_metrics],
|
||||
"weekly_metrics": [asdict(m) for m in dashboard.weekly_metrics],
|
||||
"improvement_trend": dashboard.improvement_trend,
|
||||
}
|
||||
return json.dumps(data, indent=2)
|
||||
|
||||
def log_dashboard(self, dashboard: PerformanceDashboard) -> None:
|
||||
"""Log dashboard summary to logger.
|
||||
|
||||
Args:
|
||||
dashboard: PerformanceDashboard object
|
||||
"""
|
||||
logger.info("=" * 60)
|
||||
logger.info("PERFORMANCE DASHBOARD")
|
||||
logger.info("=" * 60)
|
||||
logger.info("Generated: %s", dashboard.generated_at)
|
||||
logger.info("")
|
||||
logger.info("Overall Performance:")
|
||||
logger.info(" Total Trades: %d", dashboard.overall_metrics.total_trades)
|
||||
logger.info(" Win Rate: %.2f%%", dashboard.overall_metrics.win_rate)
|
||||
logger.info(" Average P&L: %.2f", dashboard.overall_metrics.avg_pnl)
|
||||
logger.info(" Total P&L: %.2f", dashboard.overall_metrics.total_pnl)
|
||||
logger.info("")
|
||||
logger.info("Improvement Trend (%s):", dashboard.improvement_trend["trend"])
|
||||
logger.info(" Win Rate Change: %+.2f%%", dashboard.improvement_trend["win_rate_change"])
|
||||
logger.info(" P&L Change: %+.2f", dashboard.improvement_trend["pnl_change"])
|
||||
logger.info("=" * 60)
|
||||
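The tracker reads directly from the trades table, so pointing it at the same SQLite file used by src/db.py is enough to build a dashboard; an empty database simply yields zeroed metrics. A minimal sketch (the file path is a placeholder):

from src.db import init_db
from src.evolution.performance_tracker import PerformanceTracker

init_db("data/trades.db").close()  # ensure the schema exists (placeholder path)

tracker = PerformanceTracker("data/trades.db")
dashboard = tracker.generate_dashboard()

print(dashboard.overall_metrics.win_rate, dashboard.improvement_trend["trend"])
print(tracker.export_dashboard_json(dashboard)[:200])  # JSON preview
tracker.log_dashboard(dashboard)  # human-readable summary via the module logger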
25 src/evolution/scorecard.py (Normal file)
@@ -0,0 +1,25 @@
|
||||
"""Daily scorecard model for end-of-day performance review."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class DailyScorecard:
|
||||
"""Structured daily performance snapshot for a single market."""
|
||||
|
||||
date: str
|
||||
market: str
|
||||
total_decisions: int
|
||||
buys: int
|
||||
sells: int
|
||||
holds: int
|
||||
total_pnl: float
|
||||
win_rate: float
|
||||
avg_confidence: float
|
||||
scenario_match_rate: float
|
||||
top_winners: list[str] = field(default_factory=list)
|
||||
top_losers: list[str] = field(default_factory=list)
|
||||
lessons: list[str] = field(default_factory=list)
|
||||
cross_market_note: str = ""
|
||||
5 src/logging/__init__.py (Normal file)
@@ -0,0 +1,5 @@
|
||||
"""Decision logging and audit trail for trade decisions."""
|
||||
|
||||
from src.logging.decision_logger import DecisionLog, DecisionLogger
|
||||
|
||||
__all__ = ["DecisionLog", "DecisionLogger"]
|
||||
235 src/logging/decision_logger.py (Normal file)
@@ -0,0 +1,235 @@
|
||||
"""Decision logging system with context snapshots for comprehensive audit trail."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class DecisionLog:
|
||||
"""A logged trading decision with context and outcome."""
|
||||
|
||||
decision_id: str
|
||||
timestamp: str
|
||||
stock_code: str
|
||||
market: str
|
||||
exchange_code: str
|
||||
action: str
|
||||
confidence: int
|
||||
rationale: str
|
||||
context_snapshot: dict[str, Any]
|
||||
input_data: dict[str, Any]
|
||||
outcome_pnl: float | None = None
|
||||
outcome_accuracy: int | None = None
|
||||
reviewed: bool = False
|
||||
review_notes: str | None = None
|
||||
|
||||
|
||||
class DecisionLogger:
|
||||
"""Logs trading decisions with full context for review and evolution."""
|
||||
|
||||
def __init__(self, conn: sqlite3.Connection) -> None:
|
||||
"""Initialize the decision logger with a database connection."""
|
||||
self.conn = conn
|
||||
|
||||
def log_decision(
|
||||
self,
|
||||
stock_code: str,
|
||||
market: str,
|
||||
exchange_code: str,
|
||||
action: str,
|
||||
confidence: int,
|
||||
rationale: str,
|
||||
context_snapshot: dict[str, Any],
|
||||
input_data: dict[str, Any],
|
||||
) -> str:
|
||||
"""Log a trading decision with full context.
|
||||
|
||||
Args:
|
||||
stock_code: Stock symbol
|
||||
market: Market code (e.g., "KR", "US_NASDAQ")
|
||||
exchange_code: Exchange code (e.g., "KRX", "NASDAQ")
|
||||
action: Trading action (BUY/SELL/HOLD)
|
||||
confidence: Confidence level (0-100)
|
||||
rationale: Reasoning for the decision
|
||||
context_snapshot: L1-L7 context snapshot at decision time
|
||||
input_data: Market data inputs (price, volume, orderbook, etc.)
|
||||
|
||||
Returns:
|
||||
decision_id: Unique identifier for this decision
|
||||
"""
|
||||
decision_id = str(uuid.uuid4())
|
||||
timestamp = datetime.now(UTC).isoformat()
|
||||
|
||||
self.conn.execute(
|
||||
"""
|
||||
INSERT INTO decision_logs (
|
||||
decision_id, timestamp, stock_code, market, exchange_code,
|
||||
action, confidence, rationale, context_snapshot, input_data
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
decision_id,
|
||||
timestamp,
|
||||
stock_code,
|
||||
market,
|
||||
exchange_code,
|
||||
action,
|
||||
confidence,
|
||||
rationale,
|
||||
json.dumps(context_snapshot),
|
||||
json.dumps(input_data),
|
||||
),
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
return decision_id
|
||||
|
||||
def get_unreviewed_decisions(
|
||||
self, min_confidence: int = 80, limit: int | None = None
|
||||
) -> list[DecisionLog]:
|
||||
"""Get unreviewed decisions with high confidence.
|
||||
|
||||
Args:
|
||||
min_confidence: Minimum confidence threshold (default 80)
|
||||
limit: Maximum number of results (None = unlimited)
|
||||
|
||||
Returns:
|
||||
List of unreviewed DecisionLog objects
|
||||
"""
|
||||
query = """
|
||||
SELECT
|
||||
decision_id, timestamp, stock_code, market, exchange_code,
|
||||
action, confidence, rationale, context_snapshot, input_data,
|
||||
outcome_pnl, outcome_accuracy, reviewed, review_notes
|
||||
FROM decision_logs
|
||||
WHERE reviewed = 0 AND confidence >= ?
|
||||
ORDER BY timestamp DESC
|
||||
"""
|
||||
if limit is not None:
|
||||
query += f" LIMIT {limit}"
|
||||
|
||||
cursor = self.conn.execute(query, (min_confidence,))
|
||||
return [self._row_to_decision_log(row) for row in cursor.fetchall()]
|
||||
|
||||
def mark_reviewed(self, decision_id: str, notes: str) -> None:
|
||||
"""Mark a decision as reviewed with notes.
|
||||
|
||||
Args:
|
||||
decision_id: Decision identifier
|
||||
notes: Review notes and insights
|
||||
"""
|
||||
self.conn.execute(
|
||||
"""
|
||||
UPDATE decision_logs
|
||||
SET reviewed = 1, review_notes = ?
|
||||
WHERE decision_id = ?
|
||||
""",
|
||||
(notes, decision_id),
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def update_outcome(
|
||||
self, decision_id: str, pnl: float, accuracy: int
|
||||
) -> None:
|
||||
"""Update the outcome of a decision after trade execution.
|
||||
|
||||
Args:
|
||||
decision_id: Decision identifier
|
||||
pnl: Actual profit/loss realized
|
||||
accuracy: 1 if decision was correct, 0 if wrong
|
||||
"""
|
||||
self.conn.execute(
|
||||
"""
|
||||
UPDATE decision_logs
|
||||
SET outcome_pnl = ?, outcome_accuracy = ?
|
||||
WHERE decision_id = ?
|
||||
""",
|
||||
(pnl, accuracy, decision_id),
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def get_decision_by_id(self, decision_id: str) -> DecisionLog | None:
|
||||
"""Get a specific decision by ID.
|
||||
|
||||
Args:
|
||||
decision_id: Decision identifier
|
||||
|
||||
Returns:
|
||||
DecisionLog object or None if not found
|
||||
"""
|
||||
cursor = self.conn.execute(
|
||||
"""
|
||||
SELECT
|
||||
decision_id, timestamp, stock_code, market, exchange_code,
|
||||
action, confidence, rationale, context_snapshot, input_data,
|
||||
outcome_pnl, outcome_accuracy, reviewed, review_notes
|
||||
FROM decision_logs
|
||||
WHERE decision_id = ?
|
||||
""",
|
||||
(decision_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return self._row_to_decision_log(row) if row else None
|
||||
|
||||
def get_losing_decisions(
|
||||
self, min_confidence: int = 80, min_loss: float = -100.0
|
||||
) -> list[DecisionLog]:
|
||||
"""Get high-confidence decisions that resulted in losses.
|
||||
|
||||
Useful for identifying patterns in failed predictions.
|
||||
|
||||
Args:
|
||||
min_confidence: Minimum confidence threshold (default 80)
|
||||
min_loss: Minimum loss amount (default -100.0, i.e., loss >= 100)
|
||||
|
||||
Returns:
|
||||
List of losing DecisionLog objects
|
||||
"""
|
||||
cursor = self.conn.execute(
|
||||
"""
|
||||
SELECT
|
||||
decision_id, timestamp, stock_code, market, exchange_code,
|
||||
action, confidence, rationale, context_snapshot, input_data,
|
||||
outcome_pnl, outcome_accuracy, reviewed, review_notes
|
||||
FROM decision_logs
|
||||
WHERE confidence >= ?
|
||||
AND outcome_pnl IS NOT NULL
|
||||
AND outcome_pnl <= ?
|
||||
ORDER BY outcome_pnl ASC
|
||||
""",
|
||||
(min_confidence, min_loss),
|
||||
)
|
||||
return [self._row_to_decision_log(row) for row in cursor.fetchall()]
|
||||
|
||||
def _row_to_decision_log(self, row: tuple[Any, ...]) -> DecisionLog:
|
||||
"""Convert a database row to a DecisionLog object.
|
||||
|
||||
Args:
|
||||
row: Database row tuple
|
||||
|
||||
Returns:
|
||||
DecisionLog object
|
||||
"""
|
||||
return DecisionLog(
|
||||
decision_id=row[0],
|
||||
timestamp=row[1],
|
||||
stock_code=row[2],
|
||||
market=row[3],
|
||||
exchange_code=row[4],
|
||||
action=row[5],
|
||||
confidence=row[6],
|
||||
rationale=row[7],
|
||||
context_snapshot=json.loads(row[8]),
|
||||
input_data=json.loads(row[9]),
|
||||
outcome_pnl=row[10],
|
||||
outcome_accuracy=row[11],
|
||||
reviewed=bool(row[12]),
|
||||
review_notes=row[13],
|
||||
)
|
||||
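# Illustrative usage sketch (not part of decision_logger.py). The real
# `decision_logs` CREATE TABLE statement is not in this excerpt; the schema
# below is only a guess that is compatible with the columns this module
# reads and writes. Stock, rationale, and P&L values are made up.
import sqlite3

from src.logging.decision_logger import DecisionLogger

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE decision_logs (
        decision_id TEXT PRIMARY KEY,
        timestamp TEXT,
        stock_code TEXT,
        market TEXT,
        exchange_code TEXT,
        action TEXT,
        confidence INTEGER,
        rationale TEXT,
        context_snapshot TEXT,
        input_data TEXT,
        outcome_pnl REAL,
        outcome_accuracy INTEGER,
        reviewed INTEGER DEFAULT 0,
        review_notes TEXT
    )
    """
)

logger = DecisionLogger(conn)
decision_id = logger.log_decision(
    stock_code="AAPL",
    market="US_NASDAQ",
    exchange_code="NASDAQ",
    action="BUY",
    confidence=85,
    rationale="Momentum breakout with volume confirmation",
    context_snapshot={"L1": "risk-on"},
    input_data={"price": 150.25, "volume_ratio": 3.2},
)
logger.update_outcome(decision_id, pnl=-120.0, accuracy=0)
losers = logger.get_losing_decisions(min_confidence=80, min_loss=-100.0)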
2792
src/main.py
File diff suppressed because it is too large
@@ -123,6 +123,23 @@ MARKETS: dict[str, MarketInfo] = {
    ),
}

MARKET_SHORTHAND: dict[str, list[str]] = {
    "US": ["US_NASDAQ", "US_NYSE", "US_AMEX"],
    "CN": ["CN_SHA", "CN_SZA"],
    "VN": ["VN_HAN", "VN_HCM"],
}


def expand_market_codes(codes: list[str]) -> list[str]:
    """Expand shorthand market codes into concrete exchange market codes."""
    expanded: list[str] = []
    for code in codes:
        if code in MARKET_SHORTHAND:
            expanded.extend(MARKET_SHORTHAND[code])
        else:
            expanded.append(code)
    return expanded


def is_market_open(market: MarketInfo, now: datetime | None = None) -> bool:
    """
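# Quick illustration of the shorthand expansion above (not part of main.py):
# codes with a MARKET_SHORTHAND entry expand, all others pass through unchanged.
assert expand_market_codes(["US", "KR"]) == [
    "US_NASDAQ", "US_NYSE", "US_AMEX", "KR"
]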
350
src/notifications/README.md
Normal file
@@ -0,0 +1,350 @@
|
||||
# Telegram Notifications
|
||||
|
||||
Real-time trading event notifications via Telegram Bot API.
|
||||
|
||||
## Setup
|
||||
|
||||
### 1. Create a Telegram Bot
|
||||
|
||||
1. Open Telegram and message [@BotFather](https://t.me/BotFather)
|
||||
2. Send `/newbot` command
|
||||
3. Follow prompts to name your bot
|
||||
4. Save the **bot token** (looks like `1234567890:ABCdefGHIjklMNOpqrsTUVwxyz`)
|
||||
|
||||
### 2. Get Your Chat ID
|
||||
|
||||
**Option A: Using @userinfobot**
|
||||
1. Message [@userinfobot](https://t.me/userinfobot) on Telegram
|
||||
2. Send `/start`
|
||||
3. Save your numeric **chat ID** (e.g., `123456789`)
|
||||
|
||||
**Option B: Using @RawDataBot**
|
||||
1. Message [@RawDataBot](https://t.me/rawdatabot) on Telegram
|
||||
2. Look for `"id":` in the JSON response
|
||||
3. Save your numeric **chat ID**
|
||||
|
||||
### 3. Configure Environment
|
||||
|
||||
Add to your `.env` file:
|
||||
|
||||
```bash
|
||||
TELEGRAM_BOT_TOKEN=1234567890:ABCdefGHIjklMNOpqrsTUVwxyz
|
||||
TELEGRAM_CHAT_ID=123456789
|
||||
TELEGRAM_ENABLED=true
|
||||
```
|
||||
|
||||
### 4. Test the Bot
|
||||
|
||||
Start a conversation with your bot on Telegram first (send `/start`), then run:
|
||||
|
||||
```bash
|
||||
python -m src.main --mode=paper
|
||||
```
|
||||
|
||||
You should receive a startup notification.
|
||||
|
||||
## Message Examples
|
||||
|
||||
### Trade Execution
|
||||
```
|
||||
🟢 BUY
|
||||
Symbol: AAPL (United States)
|
||||
Quantity: 10 shares
|
||||
Price: 150.25
|
||||
Confidence: 85%
|
||||
```
|
||||
|
||||
### Circuit Breaker
|
||||
```
|
||||
🚨 CIRCUIT BREAKER TRIPPED
|
||||
P&L: -3.15% (threshold: -3.0%)
|
||||
Trading halted for safety
|
||||
```
|
||||
|
||||
### Fat-Finger Protection
|
||||
```
|
||||
⚠️ Fat-Finger Protection
|
||||
Order rejected: TSLA
|
||||
Attempted: 45.0% of cash
|
||||
Max allowed: 30%
|
||||
Amount: 45,000 / 100,000
|
||||
```
|
||||
|
||||
### Market Open/Close
|
||||
```
|
||||
ℹ️ Market Open
|
||||
Korea trading session started
|
||||
|
||||
ℹ️ Market Close
|
||||
Korea trading session ended
|
||||
📈 P&L: +1.25%
|
||||
```
|
||||
|
||||
### System Status
|
||||
```
|
||||
📝 System Started
|
||||
Mode: PAPER
|
||||
Markets: KRX, NASDAQ
|
||||
|
||||
System Shutdown
|
||||
Normal shutdown
|
||||
```
|
||||
|
||||
## Notification Priorities
|
||||
|
||||
| Priority | Emoji | Use Case |
|
||||
|----------|-------|----------|
|
||||
| LOW | ℹ️ | Market open/close |
|
||||
| MEDIUM | 📊 | Trade execution, system start/stop |
|
||||
| HIGH | ⚠️ | Fat-finger protection, errors |
|
||||
| CRITICAL | 🚨 | Circuit breaker trips |
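
These emoji/label pairs come directly from the `NotificationPriority` enum in `telegram_client.py` further down in this diff:

```python
from src.notifications.telegram_client import NotificationPriority

assert NotificationPriority.CRITICAL.emoji == "🚨"
assert NotificationPriority.LOW.label == "info"
```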
|
||||
|
||||
## Rate Limiting
|
||||
|
||||
- Default: 1 message per second
|
||||
- Prevents hitting Telegram's global rate limits
|
||||
- Configurable via `rate_limit` parameter
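
Under the hood this is the `LeakyBucket` limiter defined in `telegram_client.py` below. A rough sketch of what the default 1 message/second setting means in practice (timings are approximate):

```python
import asyncio
import time

from src.notifications.telegram_client import LeakyBucket

async def demo() -> None:
    bucket = LeakyBucket(rate=1.0)   # default used by TelegramClient
    start = time.monotonic()
    await bucket.acquire()           # first message goes out immediately
    await bucket.acquire()           # second waits roughly one second
    print(f"elapsed: {time.monotonic() - start:.1f}s")

asyncio.run(demo())
```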
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### No notifications received
|
||||
|
||||
1. **Check bot configuration**
|
||||
```bash
|
||||
# Verify env variables are set
|
||||
grep TELEGRAM .env
|
||||
```
|
||||
|
||||
2. **Start conversation with bot**
|
||||
- Open bot in Telegram
|
||||
- Send `/start` command
|
||||
- Bot cannot message users who haven't started a conversation
|
||||
|
||||
3. **Check logs**
|
||||
```bash
|
||||
# Look for Telegram-related errors
|
||||
python -m src.main --mode=paper 2>&1 | grep -i telegram
|
||||
```
|
||||
|
||||
4. **Verify bot token**
|
||||
```bash
|
||||
curl https://api.telegram.org/bot<YOUR_TOKEN>/getMe
|
||||
# Should return bot info (not 401 error)
|
||||
```
|
||||
|
||||
5. **Verify chat ID**
|
||||
```bash
|
||||
curl -X POST https://api.telegram.org/bot<YOUR_TOKEN>/sendMessage \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"chat_id": "<YOUR_CHAT_ID>", "text": "Test"}'
|
||||
# Should send a test message
|
||||
```
|
||||
|
||||
### Notifications delayed
|
||||
|
||||
- Check rate limiter settings
|
||||
- Verify network connection
|
||||
- Look for timeout errors in logs
|
||||
|
||||
### "Chat not found" error
|
||||
|
||||
- Incorrect chat ID
|
||||
- Bot blocked by user
|
||||
- Need to send `/start` to bot first
|
||||
|
||||
### "Unauthorized" error
|
||||
|
||||
- Invalid bot token
|
||||
- Token revoked (regenerate with @BotFather)
|
||||
|
||||
## Graceful Degradation
|
||||
|
||||
The system works without Telegram notifications:
|
||||
|
||||
- Missing credentials → notifications disabled automatically
|
||||
- API errors → logged but trading continues
|
||||
- Network timeouts → trading loop unaffected
|
||||
- Rate limiting → messages queued, trading proceeds
|
||||
|
||||
**Notifications never crash the trading system.**
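
A minimal sketch of the degraded path, based on the `TelegramClient` constructor shown later in this diff: with missing credentials the client disables itself and `send_message()` returns `False` instead of raising.

```python
import asyncio

from src.notifications.telegram_client import TelegramClient

client = TelegramClient(bot_token=None, chat_id=None)   # credentials missing
ok = asyncio.run(client.send_message("hello"))
print(ok)   # False — the call is a no-op, nothing is raised
```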
|
||||
|
||||
## Security Notes
|
||||
|
||||
- Never commit `.env` file with credentials
|
||||
- Bot token grants full bot control
|
||||
- Chat ID is not sensitive (just a number)
|
||||
- Messages are sent over HTTPS
|
||||
- No trading credentials in notifications
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Group Notifications
|
||||
|
||||
1. Add bot to Telegram group
|
||||
2. Get group chat ID (negative number like `-123456789`)
|
||||
3. Use group chat ID in `TELEGRAM_CHAT_ID`
|
||||
|
||||
### Multiple Recipients
|
||||
|
||||
Create multiple bots or use a broadcast group with multiple members.
|
||||
|
||||
### Custom Rate Limits
|
||||
|
||||
Not currently exposed in config, but can be modified in code:
|
||||
|
||||
```python
|
||||
telegram = TelegramClient(
|
||||
bot_token=settings.TELEGRAM_BOT_TOKEN,
|
||||
chat_id=settings.TELEGRAM_CHAT_ID,
|
||||
rate_limit=2.0, # 2 messages per second
|
||||
)
|
||||
```
|
||||
|
||||
## Bidirectional Commands
|
||||
|
||||
Control your trading bot remotely via Telegram commands. The bot not only sends notifications but also accepts commands for real-time control.
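
As a rough sketch (the handler body and reply text here are illustrative, not part of this diff), wiring a command up looks like this:

```python
import asyncio

from src.notifications.telegram_client import TelegramClient, TelegramCommandHandler

async def main() -> None:
    client = TelegramClient(
        bot_token="1234567890:ABCdefGHIjklMNOpqrsTUVwxyz",
        chat_id="123456789",
    )
    handler = TelegramCommandHandler(client, polling_interval=1.0)

    async def on_status() -> None:
        await client.send_message("📊 Trading Status\nMode: PAPER")

    handler.register_command("status", on_status)
    await handler.start_polling()
    # ... trading loop runs here ...
    await handler.stop_polling()
    await client.close()

asyncio.run(main())
```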
|
||||
|
||||
### Available Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/start` | Welcome message with quick start guide |
|
||||
| `/help` | List all available commands |
|
||||
| `/status` | Current trading status (mode, markets, P&L, circuit breaker) |
|
||||
| `/positions` | View current holdings grouped by market |
|
||||
| `/stop` | Pause all trading operations |
|
||||
| `/resume` | Resume trading operations |
|
||||
|
||||
### Command Examples
|
||||
|
||||
**Check Trading Status**
|
||||
```
|
||||
You: /status
|
||||
|
||||
Bot:
|
||||
📊 Trading Status
|
||||
|
||||
Mode: PAPER
|
||||
Markets: Korea, United States
|
||||
Trading: Active
|
||||
|
||||
Current P&L: +2.50%
|
||||
Circuit Breaker: -3.0%
|
||||
```
|
||||
|
||||
**View Holdings**
|
||||
```
|
||||
You: /positions
|
||||
|
||||
Bot:
|
||||
💼 Current Holdings
|
||||
|
||||
🇰🇷 Korea
|
||||
• 005930: 10 shares @ 70,000
|
||||
• 035420: 5 shares @ 200,000
|
||||
|
||||
🇺🇸 Overseas
|
||||
• AAPL: 15 shares @ 175
|
||||
• TSLA: 8 shares @ 245
|
||||
|
||||
Cash: ₩5,000,000
|
||||
```
|
||||
|
||||
**Pause Trading**
|
||||
```
|
||||
You: /stop
|
||||
|
||||
Bot:
|
||||
⏸️ Trading Paused
|
||||
|
||||
All trading operations have been suspended.
|
||||
Use /resume to restart trading.
|
||||
```
|
||||
|
||||
**Resume Trading**
|
||||
```
|
||||
You: /resume
|
||||
|
||||
Bot:
|
||||
▶️ Trading Resumed
|
||||
|
||||
Trading operations have been restarted.
|
||||
```
|
||||
|
||||
### Security
|
||||
|
||||
**Chat ID Verification**
|
||||
- Commands are only accepted from the configured `TELEGRAM_CHAT_ID`
|
||||
- Unauthorized users receive no response
|
||||
- Command attempts from wrong chat IDs are logged
|
||||
|
||||
**Authorization Required**
|
||||
- Only the bot owner (chat ID in `.env`) can control trading
|
||||
- No way for unauthorized users to discover or use commands
|
||||
- All command executions are logged for audit
|
||||
|
||||
### Configuration
|
||||
|
||||
Add to your `.env` file:
|
||||
|
||||
```bash
|
||||
# Commands are enabled by default
|
||||
TELEGRAM_COMMANDS_ENABLED=true
|
||||
|
||||
# Polling interval (seconds) - how often to check for commands
|
||||
TELEGRAM_POLLING_INTERVAL=1.0
|
||||
```
|
||||
|
||||
To disable commands but keep notifications:
|
||||
```bash
|
||||
TELEGRAM_COMMANDS_ENABLED=false
|
||||
```
|
||||
|
||||
### How It Works
|
||||
|
||||
1. **Long Polling**: Bot checks Telegram API every second for new messages
|
||||
2. **Command Parsing**: Messages starting with `/` are parsed as commands
|
||||
3. **Authentication**: Chat ID is verified before executing any command
|
||||
4. **Execution**: Command handler is called with current bot state
|
||||
5. **Response**: Result is sent back via Telegram
|
||||
|
||||
### Error Handling
|
||||
|
||||
- Command parsing errors → "Unknown command" response
|
||||
- API failures → Graceful degradation, error logged
|
||||
- Invalid state → Appropriate message (e.g., "Trading is already paused")
|
||||
- Trading loop isolation → Command errors never crash trading
|
||||
|
||||
### Troubleshooting Commands
|
||||
|
||||
**Commands not responding**
|
||||
1. Check `TELEGRAM_COMMANDS_ENABLED=true` in `.env`
|
||||
2. Verify you started conversation with `/start`
|
||||
3. Check logs for command handler errors
|
||||
4. Confirm chat ID matches `.env` configuration
|
||||
|
||||
**Wrong chat ID**
|
||||
- Commands from unauthorized chats are silently ignored
|
||||
- Check logs for "unauthorized chat_id" warnings
|
||||
|
||||
**Delayed responses**
|
||||
- Polling interval is 1 second by default
|
||||
- Network latency may add delay
|
||||
- Check `TELEGRAM_POLLING_INTERVAL` setting
|
||||
|
||||
## API Reference
|
||||
|
||||
See `telegram_client.py` for full API documentation.
|
||||
|
||||
### Notification Methods
|
||||
- `notify_trade_execution()` - Trade alerts
|
||||
- `notify_circuit_breaker()` - Emergency stops
|
||||
- `notify_fat_finger()` - Order rejections
|
||||
- `notify_market_open/close()` - Session tracking
|
||||
- `notify_system_start/shutdown()` - Lifecycle events
|
||||
- `notify_error()` - Error alerts
|
||||
|
||||
### Command Handler
|
||||
- `TelegramCommandHandler` - Bidirectional command processing
|
||||
- `register_command()` - Register custom command handlers
|
||||
- `start_polling()` / `stop_polling()` - Lifecycle management
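
As a hedged sketch, an argument-taking command can be built on `register_command_with_args()`; the `/notify` command name and reply strings below are illustrative, not something this diff defines:

```python
from src.notifications.telegram_client import TelegramClient, TelegramCommandHandler

client = TelegramClient(bot_token="1234567890:ABCdefGHIjklMNOpqrsTUVwxyz", chat_id="123456789")
handler = TelegramCommandHandler(client)

async def on_notify(args: list[str]) -> None:
    # e.g. "/notify scenario off" arrives here as args == ["scenario", "off"]
    if len(args) == 2 and args[1] in ("on", "off"):
        ok = client.set_notification(args[0], args[1] == "on")
        await client.send_message("Updated." if ok else "Unknown notification type.")
    else:
        await client.send_message("Usage: /notify <type> <on|off>")

handler.register_command_with_args("notify", on_notify)
```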
|
||||
5
src/notifications/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""Real-time notifications for trading events."""

from src.notifications.telegram_client import TelegramClient

__all__ = ["TelegramClient"]
734
src/notifications/telegram_client.py
Normal file
@@ -0,0 +1,734 @@
|
||||
"""Telegram notification client for real-time trading alerts."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, fields
|
||||
from enum import Enum
|
||||
from typing import ClassVar
|
||||
|
||||
import aiohttp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NotificationPriority(Enum):
|
||||
"""Priority levels for notifications with emoji indicators."""
|
||||
|
||||
LOW = ("ℹ️", "info")
|
||||
MEDIUM = ("📊", "medium")
|
||||
HIGH = ("⚠️", "warning")
|
||||
CRITICAL = ("🚨", "critical")
|
||||
|
||||
def __init__(self, emoji: str, label: str) -> None:
|
||||
self.emoji = emoji
|
||||
self.label = label
|
||||
|
||||
|
||||
class LeakyBucket:
|
||||
"""Rate limiter using leaky bucket algorithm."""
|
||||
|
||||
def __init__(self, rate: float, capacity: int = 1) -> None:
|
||||
"""
|
||||
Initialize rate limiter.
|
||||
|
||||
Args:
|
||||
rate: Maximum requests per second
|
||||
capacity: Bucket capacity (burst size)
|
||||
"""
|
||||
self._rate = rate
|
||||
self._capacity = capacity
|
||||
self._tokens = float(capacity)
|
||||
self._last_update = time.monotonic()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def acquire(self) -> None:
|
||||
"""Wait until a token is available, then consume it."""
|
||||
async with self._lock:
|
||||
now = time.monotonic()
|
||||
elapsed = now - self._last_update
|
||||
self._tokens = min(self._capacity, self._tokens + elapsed * self._rate)
|
||||
self._last_update = now
|
||||
|
||||
if self._tokens < 1.0:
|
||||
wait_time = (1.0 - self._tokens) / self._rate
|
||||
await asyncio.sleep(wait_time)
|
||||
self._tokens = 0.0
|
||||
else:
|
||||
self._tokens -= 1.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class NotificationFilter:
|
||||
"""Granular on/off flags for each notification type.
|
||||
|
||||
circuit_breaker is intentionally omitted — it is always sent regardless.
|
||||
"""
|
||||
|
||||
# Maps user-facing command keys to dataclass field names
|
||||
KEYS: ClassVar[dict[str, str]] = {
|
||||
"trades": "trades",
|
||||
"market": "market_open_close",
|
||||
"fatfinger": "fat_finger",
|
||||
"system": "system_events",
|
||||
"playbook": "playbook",
|
||||
"scenario": "scenario_match",
|
||||
"errors": "errors",
|
||||
}
|
||||
|
||||
trades: bool = True
|
||||
market_open_close: bool = True
|
||||
fat_finger: bool = True
|
||||
system_events: bool = True
|
||||
playbook: bool = True
|
||||
scenario_match: bool = True
|
||||
errors: bool = True
|
||||
|
||||
def set_flag(self, key: str, value: bool) -> bool:
|
||||
"""Set a filter flag by user-facing key. Returns False if key is unknown."""
|
||||
field = self.KEYS.get(key.lower())
|
||||
if field is None:
|
||||
return False
|
||||
setattr(self, field, value)
|
||||
return True
|
||||
|
||||
def as_dict(self) -> dict[str, bool]:
|
||||
"""Return {user_key: current_value} for display."""
|
||||
return {k: getattr(self, field) for k, field in self.KEYS.items()}
|
||||
|
||||
|
||||
@dataclass
|
||||
class NotificationMessage:
|
||||
"""Internal notification message structure."""
|
||||
|
||||
priority: NotificationPriority
|
||||
message: str
|
||||
|
||||
|
||||
class TelegramClient:
|
||||
"""Telegram Bot API client for sending trading notifications."""
|
||||
|
||||
API_BASE = "https://api.telegram.org/bot{token}"
|
||||
DEFAULT_TIMEOUT = 5.0 # seconds
|
||||
DEFAULT_RATE = 1.0 # messages per second
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bot_token: str | None = None,
|
||||
chat_id: str | None = None,
|
||||
enabled: bool = True,
|
||||
rate_limit: float = DEFAULT_RATE,
|
||||
notification_filter: NotificationFilter | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize Telegram client.
|
||||
|
||||
Args:
|
||||
bot_token: Telegram bot token from @BotFather
|
||||
chat_id: Target chat ID (user or group)
|
||||
enabled: Enable/disable notifications globally
|
||||
rate_limit: Maximum messages per second
|
||||
notification_filter: Granular per-type on/off flags
|
||||
"""
|
||||
self._bot_token = bot_token
|
||||
self._chat_id = chat_id
|
||||
self._enabled = enabled
|
||||
self._rate_limiter = LeakyBucket(rate=rate_limit)
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._filter = notification_filter if notification_filter is not None else NotificationFilter()
|
||||
|
||||
if not enabled:
|
||||
logger.info("Telegram notifications disabled via configuration")
|
||||
elif bot_token is None or chat_id is None:
|
||||
logger.warning(
|
||||
"Telegram notifications disabled (missing bot_token or chat_id)"
|
||||
)
|
||||
self._enabled = False
|
||||
else:
|
||||
logger.info("Telegram notifications enabled for chat_id=%s", chat_id)
|
||||
|
||||
def _get_session(self) -> aiohttp.ClientSession:
|
||||
"""Get or create aiohttp session."""
|
||||
if self._session is None or self._session.closed:
|
||||
self._session = aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=self.DEFAULT_TIMEOUT)
|
||||
)
|
||||
return self._session
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close HTTP session."""
|
||||
if self._session is not None and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
def set_notification(self, key: str, value: bool) -> bool:
|
||||
"""Toggle a notification type by user-facing key at runtime.
|
||||
|
||||
Args:
|
||||
key: User-facing key (e.g. "scenario", "market", "all")
|
||||
value: True to enable, False to disable
|
||||
|
||||
Returns:
|
||||
True if key was valid, False if unknown.
|
||||
"""
|
||||
if key == "all":
|
||||
for k in NotificationFilter.KEYS:
|
||||
self._filter.set_flag(k, value)
|
||||
return True
|
||||
return self._filter.set_flag(key, value)
|
||||
|
||||
def filter_status(self) -> dict[str, bool]:
|
||||
"""Return current per-type filter state keyed by user-facing names."""
|
||||
return self._filter.as_dict()
|
||||
|
||||
async def send_message(self, text: str, parse_mode: str = "HTML") -> bool:
|
||||
"""
|
||||
Send a generic text message to Telegram.
|
||||
|
||||
Args:
|
||||
text: Message text to send
|
||||
parse_mode: Parse mode for formatting (HTML or Markdown)
|
||||
|
||||
Returns:
|
||||
True if message was sent successfully, False otherwise
|
||||
"""
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
try:
|
||||
await self._rate_limiter.acquire()
|
||||
|
||||
url = f"{self.API_BASE.format(token=self._bot_token)}/sendMessage"
|
||||
payload = {
|
||||
"chat_id": self._chat_id,
|
||||
"text": text,
|
||||
"parse_mode": parse_mode,
|
||||
}
|
||||
|
||||
session = self._get_session()
|
||||
async with session.post(url, json=payload) as resp:
|
||||
if resp.status != 200:
|
||||
error_text = await resp.text()
|
||||
logger.error(
|
||||
"Telegram API error (status=%d): %s", resp.status, error_text
|
||||
)
|
||||
return False
|
||||
logger.debug("Telegram message sent: %s", text[:50])
|
||||
return True
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Telegram message timeout")
|
||||
return False
|
||||
except aiohttp.ClientError as exc:
|
||||
logger.error("Telegram message failed: %s", exc)
|
||||
return False
|
||||
except Exception as exc:
|
||||
logger.error("Unexpected error sending message: %s", exc)
|
||||
return False
|
||||
|
||||
async def _send_notification(self, msg: NotificationMessage) -> None:
|
||||
"""
|
||||
Send notification to Telegram with graceful degradation.
|
||||
|
||||
Args:
|
||||
msg: Notification message to send
|
||||
"""
|
||||
formatted_message = f"{msg.priority.emoji} {msg.message}"
|
||||
await self.send_message(formatted_message)
|
||||
|
||||
async def notify_trade_execution(
|
||||
self,
|
||||
stock_code: str,
|
||||
market: str,
|
||||
action: str,
|
||||
quantity: int,
|
||||
price: float,
|
||||
confidence: float,
|
||||
) -> None:
|
||||
"""
|
||||
Notify trade execution.
|
||||
|
||||
Args:
|
||||
stock_code: Stock ticker symbol
|
||||
market: Market name (e.g., "Korea", "United States")
|
||||
action: "BUY" or "SELL"
|
||||
quantity: Number of shares
|
||||
price: Execution price
|
||||
confidence: AI confidence level (0-100)
|
||||
"""
|
||||
if not self._filter.trades:
|
||||
return
|
||||
emoji = "🟢" if action == "BUY" else "🔴"
|
||||
message = (
|
||||
f"<b>{emoji} {action}</b>\n"
|
||||
f"Symbol: <code>{stock_code}</code> ({market})\n"
|
||||
f"Quantity: {quantity:,} shares\n"
|
||||
f"Price: {price:,.2f}\n"
|
||||
f"Confidence: {confidence:.0f}%"
|
||||
)
|
||||
await self._send_notification(
|
||||
NotificationMessage(priority=NotificationPriority.MEDIUM, message=message)
|
||||
)
|
||||
|
||||
async def notify_market_open(self, market_name: str) -> None:
|
||||
"""
|
||||
Notify market opening.
|
||||
|
||||
Args:
|
||||
market_name: Name of the market (e.g., "Korea", "United States")
|
||||
"""
|
||||
if not self._filter.market_open_close:
|
||||
return
|
||||
message = f"<b>Market Open</b>\n{market_name} trading session started"
|
||||
await self._send_notification(
|
||||
NotificationMessage(priority=NotificationPriority.LOW, message=message)
|
||||
)
|
||||
|
||||
async def notify_market_close(self, market_name: str, pnl_pct: float) -> None:
|
||||
"""
|
||||
Notify market closing.
|
||||
|
||||
Args:
|
||||
market_name: Name of the market
|
||||
pnl_pct: Final P&L percentage for the session
|
||||
"""
|
||||
if not self._filter.market_open_close:
|
||||
return
|
||||
pnl_sign = "+" if pnl_pct >= 0 else ""
|
||||
pnl_emoji = "📈" if pnl_pct >= 0 else "📉"
|
||||
message = (
|
||||
f"<b>Market Close</b>\n"
|
||||
f"{market_name} trading session ended\n"
|
||||
f"{pnl_emoji} P&L: {pnl_sign}{pnl_pct:.2f}%"
|
||||
)
|
||||
await self._send_notification(
|
||||
NotificationMessage(priority=NotificationPriority.LOW, message=message)
|
||||
)
|
||||
|
||||
async def notify_circuit_breaker(
|
||||
self, pnl_pct: float, threshold: float
|
||||
) -> None:
|
||||
"""
|
||||
Notify circuit breaker activation.
|
||||
|
||||
Args:
|
||||
pnl_pct: Current P&L percentage
|
||||
threshold: Circuit breaker threshold
|
||||
"""
|
||||
message = (
|
||||
f"<b>CIRCUIT BREAKER TRIPPED</b>\n"
|
||||
f"P&L: {pnl_pct:.2f}% (threshold: {threshold:.1f}%)\n"
|
||||
f"Trading halted for safety"
|
||||
)
|
||||
await self._send_notification(
|
||||
NotificationMessage(priority=NotificationPriority.CRITICAL, message=message)
|
||||
)
|
||||
|
||||
async def notify_fat_finger(
|
||||
self,
|
||||
stock_code: str,
|
||||
order_amount: float,
|
||||
total_cash: float,
|
||||
max_pct: float,
|
||||
) -> None:
|
||||
"""
|
||||
Notify fat-finger protection rejection.
|
||||
|
||||
Args:
|
||||
stock_code: Stock ticker symbol
|
||||
order_amount: Attempted order amount
|
||||
total_cash: Total available cash
|
||||
max_pct: Maximum allowed percentage
|
||||
"""
|
||||
if not self._filter.fat_finger:
|
||||
return
|
||||
attempted_pct = (order_amount / total_cash) * 100 if total_cash > 0 else 0
|
||||
message = (
|
||||
f"<b>Fat-Finger Protection</b>\n"
|
||||
f"Order rejected: <code>{stock_code}</code>\n"
|
||||
f"Attempted: {attempted_pct:.1f}% of cash\n"
|
||||
f"Max allowed: {max_pct:.0f}%\n"
|
||||
f"Amount: {order_amount:,.0f} / {total_cash:,.0f}"
|
||||
)
|
||||
await self._send_notification(
|
||||
NotificationMessage(priority=NotificationPriority.HIGH, message=message)
|
||||
)
|
||||
|
||||
async def notify_system_start(
|
||||
self, mode: str, enabled_markets: list[str]
|
||||
) -> None:
|
||||
"""
|
||||
Notify system startup.
|
||||
|
||||
Args:
|
||||
mode: Trading mode ("paper" or "live")
|
||||
enabled_markets: List of enabled market codes
|
||||
"""
|
||||
if not self._filter.system_events:
|
||||
return
|
||||
mode_emoji = "📝" if mode == "paper" else "💰"
|
||||
markets_str = ", ".join(enabled_markets)
|
||||
message = (
|
||||
f"<b>{mode_emoji} System Started</b>\n"
|
||||
f"Mode: {mode.upper()}\n"
|
||||
f"Markets: {markets_str}"
|
||||
)
|
||||
await self._send_notification(
|
||||
NotificationMessage(priority=NotificationPriority.MEDIUM, message=message)
|
||||
)
|
||||
|
||||
async def notify_playbook_generated(
|
||||
self,
|
||||
market: str,
|
||||
stock_count: int,
|
||||
scenario_count: int,
|
||||
token_count: int,
|
||||
) -> None:
|
||||
"""
|
||||
Notify that a daily playbook was generated.
|
||||
|
||||
Args:
|
||||
market: Market code (e.g., "KR", "US")
|
||||
stock_count: Number of stocks in the playbook
|
||||
scenario_count: Total number of scenarios
|
||||
token_count: Gemini token usage for the playbook
|
||||
"""
|
||||
if not self._filter.playbook:
|
||||
return
|
||||
message = (
|
||||
f"<b>Playbook Generated</b>\n"
|
||||
f"Market: {market}\n"
|
||||
f"Stocks: {stock_count}\n"
|
||||
f"Scenarios: {scenario_count}\n"
|
||||
f"Tokens: {token_count}"
|
||||
)
|
||||
await self._send_notification(
|
||||
NotificationMessage(priority=NotificationPriority.MEDIUM, message=message)
|
||||
)
|
||||
|
||||
async def notify_scenario_matched(
|
||||
self,
|
||||
stock_code: str,
|
||||
action: str,
|
||||
condition_summary: str,
|
||||
confidence: float,
|
||||
) -> None:
|
||||
"""
|
||||
Notify that a scenario matched for a stock.
|
||||
|
||||
Args:
|
||||
stock_code: Stock ticker symbol
|
||||
action: Scenario action (BUY/SELL/HOLD/REDUCE_ALL)
|
||||
condition_summary: Short summary of the matched condition
|
||||
confidence: Scenario confidence (0-100)
|
||||
"""
|
||||
if not self._filter.scenario_match:
|
||||
return
|
||||
message = (
|
||||
f"<b>Scenario Matched</b>\n"
|
||||
f"Symbol: <code>{stock_code}</code>\n"
|
||||
f"Action: {action}\n"
|
||||
f"Condition: {condition_summary}\n"
|
||||
f"Confidence: {confidence:.0f}%"
|
||||
)
|
||||
await self._send_notification(
|
||||
NotificationMessage(priority=NotificationPriority.HIGH, message=message)
|
||||
)
|
||||
|
||||
async def notify_playbook_failed(self, market: str, reason: str) -> None:
|
||||
"""
|
||||
Notify that playbook generation failed.
|
||||
|
||||
Args:
|
||||
market: Market code (e.g., "KR", "US")
|
||||
reason: Failure reason summary
|
||||
"""
|
||||
if not self._filter.playbook:
|
||||
return
|
||||
message = (
|
||||
f"<b>Playbook Failed</b>\n"
|
||||
f"Market: {market}\n"
|
||||
f"Reason: {reason[:200]}"
|
||||
)
|
||||
await self._send_notification(
|
||||
NotificationMessage(priority=NotificationPriority.HIGH, message=message)
|
||||
)
|
||||
|
||||
async def notify_system_shutdown(self, reason: str) -> None:
|
||||
"""
|
||||
Notify system shutdown.
|
||||
|
||||
Args:
|
||||
reason: Reason for shutdown (e.g., "Normal shutdown", "Circuit breaker")
|
||||
"""
|
||||
if not self._filter.system_events:
|
||||
return
|
||||
message = f"<b>System Shutdown</b>\n{reason}"
|
||||
priority = (
|
||||
NotificationPriority.CRITICAL
|
||||
if "circuit breaker" in reason.lower()
|
||||
else NotificationPriority.MEDIUM
|
||||
)
|
||||
await self._send_notification(
|
||||
NotificationMessage(priority=priority, message=message)
|
||||
)
|
||||
|
||||
async def notify_unfilled_order(
|
||||
self,
|
||||
stock_code: str,
|
||||
market: str,
|
||||
action: str,
|
||||
quantity: int,
|
||||
outcome: str,
|
||||
new_price: float | None = None,
|
||||
) -> None:
|
||||
"""Notify about an unfilled overseas order that was cancelled or resubmitted.
|
||||
|
||||
Args:
|
||||
stock_code: Stock ticker symbol.
|
||||
market: Exchange/market code (e.g., "NASD", "SEHK").
|
||||
action: "BUY" or "SELL".
|
||||
quantity: Unfilled quantity.
|
||||
outcome: "cancelled" or "resubmitted".
|
||||
new_price: New order price if resubmitted (None if only cancelled).
|
||||
"""
|
||||
if not self._filter.trades:
|
||||
return
|
||||
# SELL resubmit is high priority — position liquidation at risk.
|
||||
# BUY cancel is medium priority — only cash is freed.
|
||||
priority = (
|
||||
NotificationPriority.HIGH
|
||||
if action == "SELL"
|
||||
else NotificationPriority.MEDIUM
|
||||
)
|
||||
outcome_emoji = "🔄" if outcome == "resubmitted" else "❌"
|
||||
outcome_label = "재주문" if outcome == "resubmitted" else "취소됨"
|
||||
action_emoji = "🔴" if action == "SELL" else "🟢"
|
||||
lines = [
|
||||
f"<b>{outcome_emoji} 미체결 주문 {outcome_label}</b>",
|
||||
f"Symbol: <code>{stock_code}</code> ({market})",
|
||||
f"Action: {action_emoji} {action}",
|
||||
f"Quantity: {quantity:,} shares",
|
||||
]
|
||||
if new_price is not None:
|
||||
lines.append(f"New Price: {new_price:.4f}")
|
||||
message = "\n".join(lines)
|
||||
await self._send_notification(NotificationMessage(priority=priority, message=message))
|
||||
|
||||
async def notify_error(
|
||||
self, error_type: str, error_msg: str, context: str
|
||||
) -> None:
|
||||
"""
|
||||
Notify system error.
|
||||
|
||||
Args:
|
||||
error_type: Type of error (e.g., "Connection Error")
|
||||
error_msg: Error message
|
||||
context: Error context (e.g., stock code, market)
|
||||
"""
|
||||
if not self._filter.errors:
|
||||
return
|
||||
message = (
|
||||
f"<b>Error: {error_type}</b>\n"
|
||||
f"Context: {context}\n"
|
||||
f"Message: {error_msg[:200]}" # Truncate long errors
|
||||
)
|
||||
await self._send_notification(
|
||||
NotificationMessage(priority=NotificationPriority.HIGH, message=message)
|
||||
)
|
||||
|
||||
|
||||
class TelegramCommandHandler:
|
||||
"""Handles incoming Telegram commands via long polling."""
|
||||
|
||||
def __init__(
|
||||
self, client: TelegramClient, polling_interval: float = 1.0
|
||||
) -> None:
|
||||
"""
|
||||
Initialize command handler.
|
||||
|
||||
Args:
|
||||
client: TelegramClient instance for sending responses
|
||||
polling_interval: Polling interval in seconds
|
||||
"""
|
||||
self._client = client
|
||||
self._polling_interval = polling_interval
|
||||
self._commands: dict[str, Callable[[], Awaitable[None]]] = {}
|
||||
self._commands_with_args: dict[str, Callable[[list[str]], Awaitable[None]]] = {}
|
||||
self._last_update_id = 0
|
||||
self._polling_task: asyncio.Task[None] | None = None
|
||||
self._running = False
|
||||
|
||||
def register_command(
|
||||
self, command: str, handler: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
"""
|
||||
Register a command handler (no arguments).
|
||||
|
||||
Args:
|
||||
command: Command name (without leading slash, e.g., "start")
|
||||
handler: Async function to handle the command
|
||||
"""
|
||||
self._commands[command] = handler
|
||||
logger.debug("Registered command handler: /%s", command)
|
||||
|
||||
def register_command_with_args(
|
||||
self, command: str, handler: Callable[[list[str]], Awaitable[None]]
|
||||
) -> None:
|
||||
"""
|
||||
Register a command handler that receives trailing arguments.
|
||||
|
||||
Args:
|
||||
command: Command name (without leading slash, e.g., "notify")
|
||||
handler: Async function receiving list of argument tokens
|
||||
"""
|
||||
self._commands_with_args[command] = handler
|
||||
logger.debug("Registered command handler (with args): /%s", command)
|
||||
|
||||
async def start_polling(self) -> None:
|
||||
"""Start long polling for commands."""
|
||||
if self._running:
|
||||
logger.warning("Command handler already running")
|
||||
return
|
||||
|
||||
if not self._client._enabled:
|
||||
logger.info("Command handler disabled (TelegramClient disabled)")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._polling_task = asyncio.create_task(self._poll_loop())
|
||||
logger.info("Started Telegram command polling")
|
||||
|
||||
async def stop_polling(self) -> None:
|
||||
"""Stop polling and cancel pending tasks."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
if self._polling_task:
|
||||
self._polling_task.cancel()
|
||||
try:
|
||||
await self._polling_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("Stopped Telegram command polling")
|
||||
|
||||
async def _poll_loop(self) -> None:
|
||||
"""Main polling loop that fetches updates."""
|
||||
while self._running:
|
||||
try:
|
||||
updates = await self._get_updates()
|
||||
for update in updates:
|
||||
await self._handle_update(update)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as exc:
|
||||
logger.error("Error in polling loop: %s", exc)
|
||||
|
||||
await asyncio.sleep(self._polling_interval)
|
||||
|
||||
async def _get_updates(self) -> list[dict]:
|
||||
"""
|
||||
Fetch updates from Telegram API.
|
||||
|
||||
Returns:
|
||||
List of update objects
|
||||
"""
|
||||
try:
|
||||
url = f"{self._client.API_BASE.format(token=self._client._bot_token)}/getUpdates"
|
||||
payload = {
|
||||
"offset": self._last_update_id + 1,
|
||||
"timeout": int(self._polling_interval),
|
||||
"allowed_updates": ["message"],
|
||||
}
|
||||
|
||||
session = self._client._get_session()
|
||||
async with session.post(url, json=payload) as resp:
|
||||
if resp.status != 200:
|
||||
error_text = await resp.text()
|
||||
if resp.status == 409:
|
||||
# Another bot instance is already polling — stop this poller entirely.
|
||||
# Retrying would keep conflicting with the other instance.
|
||||
self._running = False
|
||||
logger.warning(
|
||||
"Telegram conflict (409): another instance is already polling. "
|
||||
"Disabling Telegram commands for this process. "
|
||||
"Ensure only one instance of The Ouroboros is running at a time.",
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
"getUpdates API error (status=%d): %s", resp.status, error_text
|
||||
)
|
||||
return []
|
||||
|
||||
data = await resp.json()
|
||||
if not data.get("ok"):
|
||||
logger.error("getUpdates returned ok=false: %s", data)
|
||||
return []
|
||||
|
||||
updates = data.get("result", [])
|
||||
if updates:
|
||||
self._last_update_id = updates[-1]["update_id"]
|
||||
|
||||
return updates
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.debug("getUpdates timeout (normal)")
|
||||
return []
|
||||
except aiohttp.ClientError as exc:
|
||||
logger.error("getUpdates failed: %s", exc)
|
||||
return []
|
||||
except Exception as exc:
|
||||
logger.error("Unexpected error in _get_updates: %s", exc)
|
||||
return []
|
||||
|
||||
async def _handle_update(self, update: dict) -> None:
|
||||
"""
|
||||
Parse and handle a single update.
|
||||
|
||||
Args:
|
||||
update: Update object from Telegram API
|
||||
"""
|
||||
try:
|
||||
message = update.get("message")
|
||||
if not message:
|
||||
return
|
||||
|
||||
# Verify chat_id matches configured chat
|
||||
chat_id = str(message.get("chat", {}).get("id", ""))
|
||||
if chat_id != self._client._chat_id:
|
||||
logger.warning(
|
||||
"Ignoring command from unauthorized chat_id: %s", chat_id
|
||||
)
|
||||
return
|
||||
|
||||
# Extract command text
|
||||
text = message.get("text", "").strip()
|
||||
if not text.startswith("/"):
|
||||
return
|
||||
|
||||
# Parse command (remove leading slash and extract command name)
|
||||
command_parts = text[1:].split()
|
||||
if not command_parts:
|
||||
return
|
||||
|
||||
# Remove @botname suffix if present (for group chats)
|
||||
command_name = command_parts[0].split("@")[0]
|
||||
|
||||
# Execute handler (args-aware handlers take priority)
|
||||
args_handler = self._commands_with_args.get(command_name)
|
||||
if args_handler:
|
||||
logger.info("Executing command: /%s %s", command_name, command_parts[1:])
|
||||
await args_handler(command_parts[1:])
|
||||
elif command_name in self._commands:
|
||||
logger.info("Executing command: /%s", command_name)
|
||||
await self._commands[command_name]()
|
||||
else:
|
||||
logger.debug("Unknown command: /%s", command_name)
|
||||
await self._client.send_message(
|
||||
f"Unknown command: /{command_name}\nUse /help to see available commands."
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Error handling update: %s", exc)
|
||||
# Don't crash the polling loop on handler errors
|
||||
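# Illustrative only (not part of telegram_client.py): toggling notification
# types at runtime via the user-facing NotificationFilter keys defined above.
# The token and chat_id are dummies for the example.
from src.notifications.telegram_client import TelegramClient

client = TelegramClient(bot_token="123:ABC", chat_id="42")
client.set_notification("market", False)     # mute open/close messages
client.set_notification("scenario", False)   # mute scenario-match alerts
print(client.filter_status())
# -> {'trades': True, 'market': False, 'fatfinger': True, 'system': True,
#     'playbook': True, 'scenario': False, 'errors': True}
client.set_notification("all", True)         # re-enable everything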
114
src/strategies/v20260220_210124_evolved.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""Auto-generated strategy: v20260220_210124
|
||||
|
||||
Generated at: 2026-02-20T21:01:24.706847+00:00
|
||||
Rationale: Auto-evolved from 6 failures. Primary failure markets: ['US_AMEX', 'US_NYSE', 'US_NASDAQ']. Average loss: -194.69
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Any
|
||||
from src.strategies.base import BaseStrategy
|
||||
|
||||
|
||||
class Strategy_v20260220_210124(BaseStrategy):
|
||||
"""Strategy: v20260220_210124"""
|
||||
|
||||
def evaluate(self, market_data: dict[str, Any]) -> dict[str, Any]:
|
||||
import datetime
|
||||
|
||||
# --- Strategy Constants ---
|
||||
# Minimum price for a stock to be considered for trading (avoids penny stocks)
|
||||
MIN_PRICE = 5.0
|
||||
|
||||
# Momentum signal thresholds (stricter than previous failures)
|
||||
MOMENTUM_PRICE_CHANGE_THRESHOLD = 7.0 # % price change
|
||||
MOMENTUM_VOLUME_RATIO_THRESHOLD = 4.0 # X times average volume
|
||||
|
||||
# Oversold signal thresholds (more conservative)
|
||||
OVERSOLD_RSI_THRESHOLD = 25.0 # RSI value (lower means more oversold)
|
||||
|
||||
# Confidence levels
|
||||
CONFIDENCE_HOLD = 30
|
||||
CONFIDENCE_BUY_OVERSOLD = 65
|
||||
CONFIDENCE_BUY_MOMENTUM = 85
|
||||
CONFIDENCE_BUY_STRONG_MOMENTUM = 90 # For higher-priced stocks with strong momentum
|
||||
|
||||
# Market hours in UTC (9:30 AM ET to 4:00 PM ET)
|
||||
MARKET_OPEN_UTC = datetime.time(14, 30)
|
||||
MARKET_CLOSE_UTC = datetime.time(21, 0)
|
||||
|
||||
# Volatile periods within market hours (UTC) to avoid
|
||||
# First hour after open (14:30 UTC - 15:30 UTC)
|
||||
VOLATILE_OPEN_END_UTC = datetime.time(15, 30)
|
||||
# Last 30 minutes before close (20:30 UTC - 21:00 UTC)
|
||||
VOLATILE_CLOSE_START_UTC = datetime.time(20, 30)
|
||||
|
||||
current_price = market_data.get('current_price')
|
||||
price_change_pct = market_data.get('price_change_pct')
|
||||
volume_ratio = market_data.get('volume_ratio') # Assumed pre-computed indicator
|
||||
rsi = market_data.get('rsi') # Assumed pre-computed indicator
|
||||
timestamp_str = market_data.get('timestamp')
|
||||
|
||||
action = "HOLD"
|
||||
confidence = CONFIDENCE_HOLD
|
||||
rationale = "Initial HOLD: No clear signal or conditions not met."
|
||||
|
||||
# --- 1. Basic Data Validation ---
|
||||
if current_price is None or price_change_pct is None:
|
||||
return {"action": "HOLD", "confidence": CONFIDENCE_HOLD,
|
||||
"rationale": "Insufficient core data (price or price change) to evaluate."}
|
||||
|
||||
# --- 2. Price Filter: Avoid low-priced/penny stocks ---
|
||||
if current_price < MIN_PRICE:
|
||||
return {"action": "HOLD", "confidence": CONFIDENCE_HOLD,
|
||||
"rationale": f"Avoiding low-priced stock (${current_price:.2f} < ${MIN_PRICE:.2f})."}
|
||||
|
||||
# --- 3. Time Filter: Only trade during core market hours ---
|
||||
if timestamp_str:
|
||||
try:
|
||||
dt_object = datetime.datetime.fromisoformat(timestamp_str)
|
||||
current_time_utc = dt_object.time()
|
||||
|
||||
if not (MARKET_OPEN_UTC <= current_time_utc < MARKET_CLOSE_UTC):
|
||||
return {"action": "HOLD", "confidence": CONFIDENCE_HOLD,
|
||||
"rationale": f"Avoiding trade outside core market hours ({current_time_utc} UTC)."}
|
||||
|
||||
if (MARKET_OPEN_UTC <= current_time_utc < VOLATILE_OPEN_END_UTC) or \
|
||||
(VOLATILE_CLOSE_START_UTC <= current_time_utc < MARKET_CLOSE_UTC):
|
||||
return {"action": "HOLD", "confidence": CONFIDENCE_HOLD,
|
||||
"rationale": f"Avoiding trade during volatile market open/close periods ({current_time_utc} UTC)."}
|
||||
|
||||
except ValueError:
|
||||
rationale += " (Warning: Malformed timestamp, time filters skipped)"
|
||||
|
||||
# --- Initialize signal states ---
|
||||
has_momentum_buy_signal = False
|
||||
has_oversold_buy_signal = False
|
||||
|
||||
# --- 4. Evaluate Enhanced Buy Signals ---
|
||||
|
||||
# Momentum Buy Signal
|
||||
if volume_ratio is not None and \
|
||||
price_change_pct > MOMENTUM_PRICE_CHANGE_THRESHOLD and \
|
||||
volume_ratio > MOMENTUM_VOLUME_RATIO_THRESHOLD:
|
||||
has_momentum_buy_signal = True
|
||||
rationale = f"Momentum BUY: Price change {price_change_pct:.2f}%, Volume {volume_ratio:.2f}x."
|
||||
confidence = CONFIDENCE_BUY_MOMENTUM
|
||||
if current_price >= 10.0:
|
||||
confidence = CONFIDENCE_BUY_STRONG_MOMENTUM
|
||||
|
||||
# Oversold Buy Signal
|
||||
if rsi is not None and rsi < OVERSOLD_RSI_THRESHOLD:
|
||||
has_oversold_buy_signal = True
|
||||
if not has_momentum_buy_signal:
|
||||
rationale = f"Oversold BUY: RSI {rsi:.2f}."
|
||||
confidence = CONFIDENCE_BUY_OVERSOLD
|
||||
if current_price >= 10.0:
|
||||
confidence = min(CONFIDENCE_BUY_OVERSOLD + 5, 80)
|
||||
|
||||
# --- 5. Decision Logic ---
|
||||
if has_momentum_buy_signal:
|
||||
action = "BUY"
|
||||
elif has_oversold_buy_signal:
|
||||
action = "BUY"
|
||||
|
||||
return {"action": action, "confidence": confidence, "rationale": rationale}
|
||||
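# Illustrative only (not part of the strategy file): the market_data keys that
# evaluate() above reads, with made-up sample values. With these numbers the
# momentum branch fires (8.1% > 7.0% move on 4.5x > 4.0x volume, price >= $10),
# so the result would be {"action": "BUY", "confidence": 90, "rationale": ...}.
sample_market_data = {
    "current_price": 12.40,                    # above the $5 penny-stock floor
    "price_change_pct": 8.1,                   # intraday % change
    "volume_ratio": 4.5,                       # multiple of average volume
    "rsi": 55.0,                               # not oversold
    "timestamp": "2026-02-20T16:05:00+00:00",  # inside core hours, outside volatile windows
}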
97
src/strategies/v20260220_210159_evolved.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Auto-generated strategy: v20260220_210159
|
||||
|
||||
Generated at: 2026-02-20T21:01:59.391523+00:00
|
||||
Rationale: Auto-evolved from 6 failures. Primary failure markets: ['US_AMEX', 'US_NYSE', 'US_NASDAQ']. Average loss: -194.69
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Any
|
||||
from src.strategies.base import BaseStrategy
|
||||
|
||||
|
||||
class Strategy_v20260220_210159(BaseStrategy):
|
||||
"""Strategy: v20260220_210159"""
|
||||
|
||||
def evaluate(self, market_data: dict[str, Any]) -> dict[str, Any]:
|
||||
import datetime
|
||||
|
||||
current_price = market_data.get('current_price')
|
||||
price_change_pct = market_data.get('price_change_pct')
|
||||
volume_ratio = market_data.get('volume_ratio')
|
||||
rsi = market_data.get('rsi')
|
||||
timestamp_str = market_data.get('timestamp')
|
||||
market_name = market_data.get('market')
|
||||
|
||||
# Default action
|
||||
action = "HOLD"
|
||||
confidence = 0
|
||||
rationale = "No strong signal or conditions not met."
|
||||
|
||||
# --- FAILURE PATTERN AVOIDANCE ---
|
||||
|
||||
# 1. Avoid low-priced/penny stocks
|
||||
MIN_PRICE_THRESHOLD = 5.0 # USD
|
||||
if current_price is not None and current_price < MIN_PRICE_THRESHOLD:
|
||||
rationale = (
|
||||
f"HOLD: Stock price (${current_price:.2f}) is below minimum threshold "
|
||||
f"(${MIN_PRICE_THRESHOLD:.2f}). Past failures consistently involved low-priced stocks."
|
||||
)
|
||||
return {"action": action, "confidence": confidence, "rationale": rationale}
|
||||
|
||||
# 2. Avoid early market hour volatility
|
||||
if timestamp_str:
|
||||
try:
|
||||
dt_obj = datetime.datetime.fromisoformat(timestamp_str)
|
||||
utc_hour = dt_obj.hour
|
||||
utc_minute = dt_obj.minute
|
||||
|
||||
if (utc_hour == 14 and utc_minute < 45) or (utc_hour == 13 and utc_minute >= 30):
|
||||
rationale = (
|
||||
f"HOLD: Trading during early market hours (UTC {utc_hour}:{utc_minute}), "
|
||||
f"a period identified with past failures due to high volatility."
|
||||
)
|
||||
return {"action": action, "confidence": confidence, "rationale": rationale}
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# --- IMPROVED BUY STRATEGY ---
|
||||
|
||||
# Momentum BUY signal
|
||||
if volume_ratio is not None and price_change_pct is not None:
|
||||
if price_change_pct > 7.0 and volume_ratio > 3.0:
|
||||
action = "BUY"
|
||||
confidence = 70
|
||||
rationale = "Improved BUY: Momentum signal with high volume and above price threshold."
|
||||
|
||||
if market_name == 'US_AMEX':
|
||||
confidence = max(55, confidence - 5)
|
||||
rationale += " (Adjusted lower for AMEX market's higher risk profile)."
|
||||
elif market_name == 'US_NASDAQ' and price_change_pct > 20:
|
||||
confidence = max(50, confidence - 10)
|
||||
rationale += " (Adjusted lower for aggressive NASDAQ momentum volatility)."
|
||||
|
||||
if price_change_pct > 15.0:
|
||||
confidence = max(50, confidence - 5)
|
||||
rationale += " (Caution: Very high daily price change, potential for reversal)."
|
||||
|
||||
return {"action": action, "confidence": confidence, "rationale": rationale}
|
||||
|
||||
# Oversold BUY signal
|
||||
if rsi is not None and price_change_pct is not None:
|
||||
if rsi < 30 and price_change_pct < -3.0:
|
||||
action = "BUY"
|
||||
confidence = 65
|
||||
rationale = "Improved BUY: Oversold signal with recent decline and above price threshold."
|
||||
|
||||
if market_name == 'US_AMEX':
|
||||
confidence = max(50, confidence - 5)
|
||||
rationale += " (Adjusted lower for AMEX market's higher risk on oversold assets)."
|
||||
|
||||
if price_change_pct < -10.0:
|
||||
confidence = max(45, confidence - 10)
|
||||
rationale += " (Caution: Very steep decline, potential falling knife)."
|
||||
|
||||
return {"action": action, "confidence": confidence, "rationale": rationale}
|
||||
|
||||
# If no specific BUY signal, default to HOLD
|
||||
return {"action": action, "confidence": confidence, "rationale": rationale}
|
||||
88
src/strategies/v20260220_210244_evolved.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Auto-generated strategy: v20260220_210244
|
||||
|
||||
Generated at: 2026-02-20T21:02:44.387355+00:00
|
||||
Rationale: Auto-evolved from 6 failures. Primary failure markets: ['US_AMEX', 'US_NYSE', 'US_NASDAQ']. Average loss: -194.69
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Any
|
||||
from src.strategies.base import BaseStrategy
|
||||
|
||||
|
||||
class Strategy_v20260220_210244(BaseStrategy):
|
||||
"""Strategy: v20260220_210244"""
|
||||
|
||||
def evaluate(self, market_data: dict[str, Any]) -> dict[str, Any]:
|
||||
from datetime import datetime
|
||||
|
||||
# Extract required data points safely
|
||||
current_price = market_data.get("current_price")
|
||||
price_change_pct = market_data.get("price_change_pct")
|
||||
volume_ratio = market_data.get("volume_ratio")
|
||||
rsi = market_data.get("rsi")
|
||||
timestamp_str = market_data.get("timestamp")
|
||||
market_name = market_data.get("market")
|
||||
stock_code = market_data.get("stock_code", "UNKNOWN")
|
||||
|
||||
# Default action is HOLD with conservative confidence and rationale
|
||||
action = "HOLD"
|
||||
confidence = 50
|
||||
rationale = f"No strong BUY signal for {stock_code} or awaiting more favorable conditions after avoiding known failure patterns."
|
||||
|
||||
# --- 1. Failure Pattern Avoidance Filters ---
|
||||
|
||||
# A. Avoid low-priced (penny) stocks
|
||||
if current_price is not None and current_price < 5.0:
|
||||
return {
|
||||
"action": "HOLD",
|
||||
"confidence": 50,
|
||||
"rationale": f"AVOID {stock_code}: Stock price (${current_price:.2f}) is below minimum threshold ($5.00) for BUY action. Identified past failures on highly volatile, low-priced stocks."
|
||||
}
|
||||
|
||||
# B. Avoid initiating BUY trades during identified high-volatility hours
|
||||
if timestamp_str:
|
||||
try:
|
||||
trade_hour = datetime.fromisoformat(timestamp_str).hour
|
||||
if trade_hour in [14, 20]:
|
||||
return {
|
||||
"action": "HOLD",
|
||||
"confidence": 50,
|
||||
"rationale": f"AVOID {stock_code}: Trading during historically volatile hour ({trade_hour} UTC) where previous BUYs resulted in losses. Prefer to observe market stability."
|
||||
}
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# C. Be cautious with extreme momentum spikes
|
||||
if volume_ratio is not None and price_change_pct is not None:
|
||||
if volume_ratio >= 9.0 and price_change_pct >= 15.0:
|
||||
return {
|
||||
"action": "HOLD",
|
||||
"confidence": 50,
|
||||
"rationale": f"AVOID {stock_code}: Extreme short-term momentum detected (price change: +{price_change_pct:.2f}%, volume ratio: {volume_ratio:.1f}x). Historical failures indicate buying into such rapid spikes often leads to reversals."
|
||||
}
|
||||
|
||||
# D. Be cautious with "oversold" signals without further confirmation
|
||||
if rsi is not None and rsi < 30:
|
||||
return {
|
||||
"action": "HOLD",
|
||||
"confidence": 50,
|
||||
"rationale": f"AVOID {stock_code}: Oversold signal (RSI={rsi:.1f}) detected. While often a BUY signal, historical failures on similar 'oversold' trades suggest waiting for stronger confirmation."
|
||||
}
|
||||
|
||||
# --- 2. Improved BUY Signal Generation ---
|
||||
if volume_ratio is not None and 2.0 <= volume_ratio < 9.0 and \
|
||||
price_change_pct is not None and 2.0 <= price_change_pct < 15.0:
|
||||
|
||||
action = "BUY"
|
||||
confidence = 70
|
||||
rationale = f"BUY {stock_code}: Moderate momentum detected (price change: +{price_change_pct:.2f}%, volume ratio: {volume_ratio:.1f}x). Passed filters for price and extreme momentum, avoiding past failure patterns."
|
||||
|
||||
if market_name in ["US_AMEX", "US_NASDAQ"]:
|
||||
confidence = max(60, confidence - 5)
|
||||
rationale += f" Adjusted confidence for {market_name} market characteristics."
|
||||
elif market_name == "US_NYSE":
|
||||
confidence = max(65, confidence)
|
||||
|
||||
confidence = max(50, min(85, confidence))
|
||||
|
||||
return {"action": action, "confidence": confidence, "rationale": rationale}
|
||||
0
src/strategy/__init__.py
Normal file
184
src/strategy/models.py
Normal file
@@ -0,0 +1,184 @@
"""Pydantic models for pre-market scenario planning.

Defines the data contracts for the proactive strategy system:
- AI generates DayPlaybook before market open (structured JSON scenarios)
- Local ScenarioEngine matches conditions during market hours (no API calls)
"""

from __future__ import annotations

from datetime import UTC, date, datetime
from enum import Enum

from pydantic import BaseModel, Field, field_validator


class ScenarioAction(str, Enum):
    """Actions that can be taken by scenarios."""

    BUY = "BUY"
    SELL = "SELL"
    HOLD = "HOLD"
    REDUCE_ALL = "REDUCE_ALL"


class MarketOutlook(str, Enum):
    """AI's assessment of market direction."""

    BULLISH = "bullish"
    NEUTRAL_TO_BULLISH = "neutral_to_bullish"
    NEUTRAL = "neutral"
    NEUTRAL_TO_BEARISH = "neutral_to_bearish"
    BEARISH = "bearish"


class PlaybookStatus(str, Enum):
    """Lifecycle status of a playbook."""

    PENDING = "pending"
    READY = "ready"
    FAILED = "failed"
    EXPIRED = "expired"


class StockCondition(BaseModel):
    """Condition fields for scenario matching (all optional, AND-combined).

    The ScenarioEngine evaluates all non-None fields as AND conditions.
    A condition matches only if ALL specified fields are satisfied.

    Technical indicator fields:
        rsi_below / rsi_above — RSI threshold
        volume_ratio_above / volume_ratio_below — volume vs previous day
        price_above / price_below — absolute price level
        price_change_pct_above / price_change_pct_below — intraday % change

    Position-aware fields (require market_data enrichment from open position):
        unrealized_pnl_pct_above — matches if unrealized P&L > threshold (e.g. 3.0 → +3%)
        unrealized_pnl_pct_below — matches if unrealized P&L < threshold (e.g. -2.0 → -2%)
        holding_days_above — matches if position held for more than N days
        holding_days_below — matches if position held for fewer than N days
    """

    rsi_below: float | None = None
    rsi_above: float | None = None
    volume_ratio_above: float | None = None
    volume_ratio_below: float | None = None
    price_above: float | None = None
    price_below: float | None = None
    price_change_pct_above: float | None = None
    price_change_pct_below: float | None = None
    unrealized_pnl_pct_above: float | None = None
    unrealized_pnl_pct_below: float | None = None
    holding_days_above: int | None = None
    holding_days_below: int | None = None

    def has_any_condition(self) -> bool:
        """Check if at least one condition field is set."""
        return any(
            v is not None
            for v in (
                self.rsi_below,
                self.rsi_above,
                self.volume_ratio_above,
                self.volume_ratio_below,
                self.price_above,
                self.price_below,
                self.price_change_pct_above,
                self.price_change_pct_below,
                self.unrealized_pnl_pct_above,
                self.unrealized_pnl_pct_below,
                self.holding_days_above,
                self.holding_days_below,
            )
        )


class StockScenario(BaseModel):
    """A single condition-action rule for one stock."""

    condition: StockCondition
    action: ScenarioAction
    confidence: int = Field(ge=0, le=100)
    allocation_pct: float = Field(ge=0, le=100, default=10.0)
    stop_loss_pct: float = Field(le=0, default=-2.0)
    take_profit_pct: float = Field(ge=0, default=3.0)
    rationale: str = ""


class StockPlaybook(BaseModel):
    """All scenarios for a single stock (ordered by priority)."""

    stock_code: str
    stock_name: str = ""
    scenarios: list[StockScenario] = Field(min_length=1)


class GlobalRule(BaseModel):
    """Portfolio-level rule (checked before stock-level scenarios)."""

    condition: str  # e.g. "portfolio_pnl_pct < -2.0"
    action: ScenarioAction
    rationale: str = ""


class CrossMarketContext(BaseModel):
    """Summary of another market's state for cross-market awareness."""

    market: str  # e.g. "US" or "KR"
    date: str
    total_pnl: float = 0.0
    win_rate: float = 0.0
    index_change_pct: float = 0.0  # e.g. KOSPI or S&P500 change
    key_events: list[str] = Field(default_factory=list)
    lessons: list[str] = Field(default_factory=list)


class DayPlaybook(BaseModel):
    """Complete playbook for a single trading day in a single market.

    Generated by PreMarketPlanner (1 Gemini call per market per day).
    Consumed by ScenarioEngine during market hours (0 API calls).
    """

    date: date
    market: str  # "KR" or "US"
    market_outlook: MarketOutlook = MarketOutlook.NEUTRAL
    generated_at: str = ""  # ISO timestamp
    gemini_model: str = ""
    token_count: int = 0
    global_rules: list[GlobalRule] = Field(default_factory=list)
    stock_playbooks: list[StockPlaybook] = Field(default_factory=list)
    default_action: ScenarioAction = ScenarioAction.HOLD
    context_summary: dict = Field(default_factory=dict)
    cross_market: CrossMarketContext | None = None

    @field_validator("stock_playbooks")
    @classmethod
    def validate_unique_stocks(cls, v: list[StockPlaybook]) -> list[StockPlaybook]:
        codes = [pb.stock_code for pb in v]
        if len(codes) != len(set(codes)):
            raise ValueError("Duplicate stock codes in playbook")
        return v

    def get_stock_playbook(self, stock_code: str) -> StockPlaybook | None:
        """Find the playbook for a specific stock."""
        for pb in self.stock_playbooks:
            if pb.stock_code == stock_code:
                return pb
        return None

    @property
    def scenario_count(self) -> int:
        """Total number of scenarios across all stocks."""
        return sum(len(pb.scenarios) for pb in self.stock_playbooks)

    @property
    def stock_count(self) -> int:
        """Number of stocks with scenarios."""
        return len(self.stock_playbooks)

    def model_post_init(self, __context: object) -> None:
        """Set generated_at if not provided."""
        if not self.generated_at:
            self.generated_at = datetime.now(UTC).isoformat()
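A minimal usage sketch, not part of the diff, showing how the models above compose; it assumes pydantic v2 is installed and that the import path below matches the repo layout.

# Illustrative only: one stock, one BUY scenario, assembled by hand.
from datetime import date as _date
from src.strategy.models import (
    DayPlaybook, ScenarioAction, StockCondition, StockPlaybook, StockScenario,
)

scenario = StockScenario(
    condition=StockCondition(rsi_below=30, volume_ratio_above=2.0),
    action=ScenarioAction.BUY,
    confidence=85,
)
playbook = DayPlaybook(
    date=_date(2024, 1, 2),
    market="KR",
    stock_playbooks=[StockPlaybook(stock_code="005930", scenarios=[scenario])],
)
assert playbook.scenario_count == 1
assert playbook.get_stock_playbook("005930") is not None
assert playbook.generated_at  # filled in by model_post_init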
src/strategy/playbook_store.py (new file, 184 lines)
@@ -0,0 +1,184 @@
"""Playbook persistence layer — CRUD for DayPlaybook in SQLite.

Stores and retrieves market-specific daily playbooks with JSON serialization.
Designed for the pre-market strategy system (one playbook per market per day).
"""

from __future__ import annotations

import json
import logging
import sqlite3
from datetime import date

from src.strategy.models import DayPlaybook, PlaybookStatus

logger = logging.getLogger(__name__)


class PlaybookStore:
    """CRUD operations for DayPlaybook persistence."""

    def __init__(self, conn: sqlite3.Connection) -> None:
        self._conn = conn

    def save(self, playbook: DayPlaybook) -> int:
        """Save or replace a playbook for a given date+market.

        Uses INSERT OR REPLACE to enforce UNIQUE(date, market).

        Returns:
            The row id of the inserted/replaced record.
        """
        playbook_json = playbook.model_dump_json()
        cursor = self._conn.execute(
            """
            INSERT OR REPLACE INTO playbooks
                (date, market, status, playbook_json, generated_at,
                 token_count, scenario_count, match_count)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                playbook.date.isoformat(),
                playbook.market,
                PlaybookStatus.READY.value,
                playbook_json,
                playbook.generated_at,
                playbook.token_count,
                playbook.scenario_count,
                0,
            ),
        )
        self._conn.commit()
        row_id = cursor.lastrowid or 0
        logger.info(
            "Saved playbook for %s/%s (%d stocks, %d scenarios)",
            playbook.date, playbook.market,
            playbook.stock_count, playbook.scenario_count,
        )
        return row_id

    def load(self, target_date: date, market: str) -> DayPlaybook | None:
        """Load a playbook for a specific date and market.

        Returns:
            DayPlaybook if found, None otherwise.
        """
        row = self._conn.execute(
            "SELECT playbook_json FROM playbooks WHERE date = ? AND market = ?",
            (target_date.isoformat(), market),
        ).fetchone()
        if row is None:
            return None
        return DayPlaybook.model_validate_json(row[0])

    def get_status(self, target_date: date, market: str) -> PlaybookStatus | None:
        """Get the status of a playbook without deserializing the full JSON."""
        row = self._conn.execute(
            "SELECT status FROM playbooks WHERE date = ? AND market = ?",
            (target_date.isoformat(), market),
        ).fetchone()
        if row is None:
            return None
        return PlaybookStatus(row[0])

    def update_status(self, target_date: date, market: str, status: PlaybookStatus) -> bool:
        """Update the status of a playbook.

        Returns:
            True if a row was updated, False if not found.
        """
        cursor = self._conn.execute(
            "UPDATE playbooks SET status = ? WHERE date = ? AND market = ?",
            (status.value, target_date.isoformat(), market),
        )
        self._conn.commit()
        return cursor.rowcount > 0

    def increment_match_count(self, target_date: date, market: str) -> bool:
        """Increment the match_count for tracking scenario hits during the day.

        Returns:
            True if a row was updated, False if not found.
        """
        cursor = self._conn.execute(
            "UPDATE playbooks SET match_count = match_count + 1 WHERE date = ? AND market = ?",
            (target_date.isoformat(), market),
        )
        self._conn.commit()
        return cursor.rowcount > 0

    def get_stats(self, target_date: date, market: str) -> dict | None:
        """Get playbook stats without full deserialization.

        Returns:
            Dict with status, token_count, scenario_count, match_count, or None.
        """
        row = self._conn.execute(
            """
            SELECT status, token_count, scenario_count, match_count, generated_at
            FROM playbooks WHERE date = ? AND market = ?
            """,
            (target_date.isoformat(), market),
        ).fetchone()
        if row is None:
            return None
        return {
            "status": row[0],
            "token_count": row[1],
            "scenario_count": row[2],
            "match_count": row[3],
            "generated_at": row[4],
        }

    def list_recent(self, market: str | None = None, limit: int = 7) -> list[dict]:
        """List recent playbooks with summary info.

        Args:
            market: Filter by market code. None for all markets.
            limit: Max number of results.

        Returns:
            List of dicts with date, market, status, scenario_count, match_count.
        """
        if market is not None:
            rows = self._conn.execute(
                """
                SELECT date, market, status, scenario_count, match_count
                FROM playbooks WHERE market = ?
                ORDER BY date DESC LIMIT ?
                """,
                (market, limit),
            ).fetchall()
        else:
            rows = self._conn.execute(
                """
                SELECT date, market, status, scenario_count, match_count
                FROM playbooks
                ORDER BY date DESC LIMIT ?
                """,
                (limit,),
            ).fetchall()
        return [
            {
                "date": row[0],
                "market": row[1],
                "status": row[2],
                "scenario_count": row[3],
                "match_count": row[4],
            }
            for row in rows
        ]

    def delete(self, target_date: date, market: str) -> bool:
        """Delete a playbook.

        Returns:
            True if a row was deleted, False if not found.
        """
        cursor = self._conn.execute(
            "DELETE FROM playbooks WHERE date = ? AND market = ?",
            (target_date.isoformat(), market),
        )
        self._conn.commit()
        return cursor.rowcount > 0
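The playbooks table itself is not created in this file. The CREATE TABLE below is an inferred sketch, reconstructed from the INSERT/SELECT statements above (an assumption, not the repo's actual migration), followed by a round trip through the store.

# Assumed schema plus a save/load round trip; illustrative only.
import sqlite3
from datetime import date as _date

from src.strategy.models import DayPlaybook
from src.strategy.playbook_store import PlaybookStore

conn = sqlite3.connect(":memory:")
conn.execute(
    """CREATE TABLE IF NOT EXISTS playbooks (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        date TEXT NOT NULL,
        market TEXT NOT NULL,
        status TEXT NOT NULL,
        playbook_json TEXT NOT NULL,
        generated_at TEXT,
        token_count INTEGER DEFAULT 0,
        scenario_count INTEGER DEFAULT 0,
        match_count INTEGER DEFAULT 0,
        UNIQUE(date, market)
    )"""
)
store = PlaybookStore(conn)
store.save(DayPlaybook(date=_date(2024, 1, 2), market="KR"))
assert store.load(_date(2024, 1, 2), "KR") is not None
assert store.get_stats(_date(2024, 1, 2), "KR")["match_count"] == 0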
src/strategy/pre_market_planner.py (new file, 620 lines)
@@ -0,0 +1,620 @@
|
||||
"""Pre-market planner — generates DayPlaybook via Gemini before market open.
|
||||
|
||||
One Gemini API call per market per day. Candidates come from SmartVolatilityScanner.
|
||||
On failure, returns a smart rule-based fallback playbook that uses scanner signals
|
||||
(momentum/oversold) to generate BUY conditions, avoiding the all-HOLD problem.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import date, timedelta
|
||||
from typing import Any
|
||||
|
||||
from src.analysis.smart_scanner import ScanCandidate
|
||||
from src.brain.context_selector import ContextSelector, DecisionType
|
||||
from src.brain.gemini_client import GeminiClient
|
||||
from src.config import Settings
|
||||
from src.context.store import ContextLayer, ContextStore
|
||||
from src.strategy.models import (
|
||||
CrossMarketContext,
|
||||
DayPlaybook,
|
||||
GlobalRule,
|
||||
MarketOutlook,
|
||||
ScenarioAction,
|
||||
StockCondition,
|
||||
StockPlaybook,
|
||||
StockScenario,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Mapping from string to MarketOutlook enum
|
||||
_OUTLOOK_MAP: dict[str, MarketOutlook] = {
|
||||
"bullish": MarketOutlook.BULLISH,
|
||||
"neutral_to_bullish": MarketOutlook.NEUTRAL_TO_BULLISH,
|
||||
"neutral": MarketOutlook.NEUTRAL,
|
||||
"neutral_to_bearish": MarketOutlook.NEUTRAL_TO_BEARISH,
|
||||
"bearish": MarketOutlook.BEARISH,
|
||||
}
|
||||
|
||||
_ACTION_MAP: dict[str, ScenarioAction] = {
|
||||
"BUY": ScenarioAction.BUY,
|
||||
"SELL": ScenarioAction.SELL,
|
||||
"HOLD": ScenarioAction.HOLD,
|
||||
"REDUCE_ALL": ScenarioAction.REDUCE_ALL,
|
||||
}
|
||||
|
||||
|
||||
class PreMarketPlanner:
|
||||
"""Generates a DayPlaybook by calling Gemini once before market open.
|
||||
|
||||
Flow:
|
||||
1. Collect strategic context (L5-L7) + cross-market context
|
||||
2. Build a structured prompt with scan candidates
|
||||
3. Call Gemini for JSON scenario generation
|
||||
4. Parse and validate response into DayPlaybook
|
||||
5. On failure → defensive playbook (HOLD everything)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
gemini_client: GeminiClient,
|
||||
context_store: ContextStore,
|
||||
context_selector: ContextSelector,
|
||||
settings: Settings,
|
||||
) -> None:
|
||||
self._gemini = gemini_client
|
||||
self._context_store = context_store
|
||||
self._context_selector = context_selector
|
||||
self._settings = settings
|
||||
|
||||
async def generate_playbook(
|
||||
self,
|
||||
market: str,
|
||||
candidates: list[ScanCandidate],
|
||||
today: date | None = None,
|
||||
current_holdings: list[dict] | None = None,
|
||||
) -> DayPlaybook:
|
||||
"""Generate a DayPlaybook for a market using Gemini.
|
||||
|
||||
Args:
|
||||
market: Market code ("KR" or "US")
|
||||
candidates: Stock candidates from SmartVolatilityScanner
|
||||
today: Override date (defaults to date.today()). Use market-local date.
|
||||
current_holdings: Currently held positions with entry_price and unrealized_pnl_pct.
|
||||
Each dict: {"stock_code": str, "name": str, "qty": int,
|
||||
"entry_price": float, "unrealized_pnl_pct": float,
|
||||
"holding_days": int}
|
||||
|
||||
Returns:
|
||||
DayPlaybook with scenarios. Empty/defensive if no candidates or failure.
|
||||
"""
|
||||
if today is None:
|
||||
today = date.today()
|
||||
|
||||
if not candidates:
|
||||
logger.info("No candidates for %s — returning empty playbook", market)
|
||||
return self._empty_playbook(today, market)
|
||||
|
||||
try:
|
||||
# 1. Gather context
|
||||
context_data = self._gather_context()
|
||||
self_market_scorecard = self.build_self_market_scorecard(market, today)
|
||||
cross_market = self.build_cross_market_context(market, today)
|
||||
|
||||
# 2. Build prompt
|
||||
prompt = self._build_prompt(
|
||||
market,
|
||||
candidates,
|
||||
context_data,
|
||||
self_market_scorecard,
|
||||
cross_market,
|
||||
current_holdings=current_holdings,
|
||||
)
|
||||
|
||||
# 3. Call Gemini
|
||||
market_data = {
|
||||
"stock_code": "PLANNER",
|
||||
"current_price": 0,
|
||||
"prompt_override": prompt,
|
||||
}
|
||||
decision = await self._gemini.decide(market_data)
|
||||
|
||||
# 4. Parse response
|
||||
playbook = self._parse_response(
|
||||
decision.rationale, today, market, candidates, cross_market,
|
||||
current_holdings=current_holdings,
|
||||
)
|
||||
playbook_with_tokens = playbook.model_copy(
|
||||
update={"token_count": decision.token_count}
|
||||
)
|
||||
logger.info(
|
||||
"Generated playbook for %s: %d stocks, %d scenarios, %d tokens",
|
||||
market,
|
||||
playbook_with_tokens.stock_count,
|
||||
playbook_with_tokens.scenario_count,
|
||||
playbook_with_tokens.token_count,
|
||||
)
|
||||
return playbook_with_tokens
|
||||
|
||||
except Exception:
|
||||
logger.exception("Playbook generation failed for %s", market)
|
||||
if self._settings.DEFENSIVE_PLAYBOOK_ON_FAILURE:
|
||||
return self._smart_fallback_playbook(today, market, candidates, self._settings)
|
||||
return self._empty_playbook(today, market)
|
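# Illustrative wiring sketch, not part of the diff. The planner and candidate objects are
# assumed to be constructed elsewhere; only the generate_playbook() call comes from this file.
import asyncio

from src.analysis.smart_scanner import ScanCandidate
from src.strategy.models import DayPlaybook
from src.strategy.pre_market_planner import PreMarketPlanner


async def plan_day(planner: PreMarketPlanner, candidates: list[ScanCandidate]) -> DayPlaybook:
    # One Gemini call per market per day; exceptions fall back to the rule-based playbook above.
    return await planner.generate_playbook("KR", candidates)

# playbook = asyncio.run(plan_day(planner, candidates))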
||||
|
||||
def build_cross_market_context(
|
||||
self, target_market: str, today: date | None = None,
|
||||
) -> CrossMarketContext | None:
|
||||
"""Build cross-market context from the other market's L6 data.
|
||||
|
||||
KR planner → reads US scorecard from previous night.
|
||||
US planner → reads KR scorecard from today.
|
||||
|
||||
Args:
|
||||
target_market: The market being planned ("KR" or "US")
|
||||
today: Override date (defaults to date.today()). Use market-local date.
|
||||
"""
|
||||
other_market = "US" if target_market == "KR" else "KR"
|
||||
if today is None:
|
||||
today = date.today()
|
||||
timeframe_date = today - timedelta(days=1) if target_market == "KR" else today
|
||||
timeframe = timeframe_date.isoformat()
|
||||
|
||||
scorecard_key = f"scorecard_{other_market}"
|
||||
scorecard_data = self._context_store.get_context(
|
||||
ContextLayer.L6_DAILY, timeframe, scorecard_key
|
||||
)
|
||||
|
||||
if scorecard_data is None:
|
||||
logger.debug("No cross-market scorecard found for %s", other_market)
|
||||
return None
|
||||
|
||||
if isinstance(scorecard_data, str):
|
||||
try:
|
||||
scorecard_data = json.loads(scorecard_data)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
|
||||
if not isinstance(scorecard_data, dict):
|
||||
return None
|
||||
|
||||
return CrossMarketContext(
|
||||
market=other_market,
|
||||
date=timeframe,
|
||||
total_pnl=float(scorecard_data.get("total_pnl", 0.0)),
|
||||
win_rate=float(scorecard_data.get("win_rate", 0.0)),
|
||||
index_change_pct=float(scorecard_data.get("index_change_pct", 0.0)),
|
||||
key_events=scorecard_data.get("key_events", []),
|
||||
lessons=scorecard_data.get("lessons", []),
|
||||
)
|
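# Sketch of the L6 scorecard payload this method expects, with illustrative values only.
# How the scorecard is written into ContextStore is not shown in this diff, so that call is omitted.
scorecard_us = {
    "total_pnl": 1.8,            # percent
    "win_rate": 62.0,            # percent
    "index_change_pct": 0.4,     # e.g. S&P 500 daily change
    "key_events": ["FOMC minutes released"],
    "lessons": ["Avoid chasing gaps on low volume"],
}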
||||
|
||||
def build_self_market_scorecard(
|
||||
self, market: str, today: date | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Build previous-day scorecard for the same market."""
|
||||
if today is None:
|
||||
today = date.today()
|
||||
timeframe = (today - timedelta(days=1)).isoformat()
|
||||
scorecard_key = f"scorecard_{market}"
|
||||
scorecard_data = self._context_store.get_context(
|
||||
ContextLayer.L6_DAILY, timeframe, scorecard_key
|
||||
)
|
||||
|
||||
if scorecard_data is None:
|
||||
return None
|
||||
|
||||
if isinstance(scorecard_data, str):
|
||||
try:
|
||||
scorecard_data = json.loads(scorecard_data)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
|
||||
if not isinstance(scorecard_data, dict):
|
||||
return None
|
||||
|
||||
return {
|
||||
"date": timeframe,
|
||||
"total_pnl": float(scorecard_data.get("total_pnl", 0.0)),
|
||||
"win_rate": float(scorecard_data.get("win_rate", 0.0)),
|
||||
"lessons": scorecard_data.get("lessons", []),
|
||||
}
|
||||
|
||||
def _gather_context(self) -> dict[str, Any]:
|
||||
"""Gather strategic context using ContextSelector."""
|
||||
layers = self._context_selector.select_layers(
|
||||
decision_type=DecisionType.STRATEGIC,
|
||||
include_realtime=True,
|
||||
)
|
||||
return self._context_selector.get_context_data(layers, max_items_per_layer=10)
|
||||
|
||||
def _build_prompt(
|
||||
self,
|
||||
market: str,
|
||||
candidates: list[ScanCandidate],
|
||||
context_data: dict[str, Any],
|
||||
self_market_scorecard: dict[str, Any] | None,
|
||||
cross_market: CrossMarketContext | None,
|
||||
current_holdings: list[dict] | None = None,
|
||||
) -> str:
|
||||
"""Build a structured prompt for Gemini to generate scenario JSON."""
|
||||
max_scenarios = self._settings.MAX_SCENARIOS_PER_STOCK
|
||||
|
||||
candidates_text = "\n".join(
|
||||
f" - {c.stock_code} ({c.name}): price={c.price}, "
|
||||
f"RSI={c.rsi:.1f}, volume_ratio={c.volume_ratio:.1f}, "
|
||||
f"signal={c.signal}, score={c.score:.1f}"
|
||||
for c in candidates
|
||||
)
|
||||
|
||||
holdings_text = ""
|
||||
if current_holdings:
|
||||
lines = []
|
||||
for h in current_holdings:
|
||||
code = h.get("stock_code", "")
|
||||
name = h.get("name", "")
|
||||
qty = h.get("qty", 0)
|
||||
entry_price = h.get("entry_price", 0.0)
|
||||
pnl_pct = h.get("unrealized_pnl_pct", 0.0)
|
||||
holding_days = h.get("holding_days", 0)
|
||||
lines.append(
|
||||
f" - {code} ({name}): {qty}주 @ {entry_price:,.0f}, "
|
||||
f"미실현손익 {pnl_pct:+.2f}%, 보유 {holding_days}일"
|
||||
)
|
||||
holdings_text = (
|
||||
"\n## Current Holdings (보유 중 — SELL/HOLD 전략 고려 필요)\n"
|
||||
+ "\n".join(lines)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
cross_market_text = ""
|
||||
if cross_market:
|
||||
cross_market_text = (
|
||||
f"\n## Other Market ({cross_market.market}) Summary\n"
|
||||
f"- P&L: {cross_market.total_pnl:+.2f}%\n"
|
||||
f"- Win Rate: {cross_market.win_rate:.0f}%\n"
|
||||
f"- Index Change: {cross_market.index_change_pct:+.2f}%\n"
|
||||
)
|
||||
if cross_market.lessons:
|
||||
cross_market_text += f"- Lessons: {'; '.join(cross_market.lessons[:3])}\n"
|
||||
|
||||
self_market_text = ""
|
||||
if self_market_scorecard:
|
||||
self_market_text = (
|
||||
f"\n## My Market Previous Day ({market})\n"
|
||||
f"- Date: {self_market_scorecard['date']}\n"
|
||||
f"- P&L: {self_market_scorecard['total_pnl']:+.2f}%\n"
|
||||
f"- Win Rate: {self_market_scorecard['win_rate']:.0f}%\n"
|
||||
)
|
||||
lessons = self_market_scorecard.get("lessons", [])
|
||||
if lessons:
|
||||
self_market_text += f"- Lessons: {'; '.join(lessons[:3])}\n"
|
||||
|
||||
context_text = ""
|
||||
if context_data:
|
||||
context_text = "\n## Strategic Context\n"
|
||||
for layer_name, layer_data in context_data.items():
|
||||
if layer_data:
|
||||
context_text += f"### {layer_name}\n"
|
||||
for key, value in list(layer_data.items())[:5]:
|
||||
context_text += f" - {key}: {value}\n"
|
||||
|
||||
holdings_instruction = ""
|
||||
if current_holdings:
|
||||
holding_codes = [h.get("stock_code", "") for h in current_holdings]
|
||||
holdings_instruction = (
|
||||
f"- Also include SELL/HOLD scenarios for held stocks: "
|
||||
f"{', '.join(holding_codes)} "
|
||||
f"(even if not in candidates list)\n"
|
||||
)
|
||||
|
||||
return (
|
||||
f"You are a pre-market trading strategist for the {market} market.\n"
|
||||
f"Generate structured trading scenarios for today.\n\n"
|
||||
f"## Candidates (from volatility scanner)\n{candidates_text}\n"
|
||||
f"{holdings_text}"
|
||||
f"{self_market_text}"
|
||||
f"{cross_market_text}"
|
||||
f"{context_text}\n"
|
||||
f"## Instructions\n"
|
||||
f"Return a JSON object with this exact structure:\n"
|
||||
f'{{\n'
|
||||
f' "market_outlook": "bullish|neutral_to_bullish|neutral'
|
||||
f'|neutral_to_bearish|bearish",\n'
|
||||
f' "global_rules": [\n'
|
||||
f' {{"condition": "portfolio_pnl_pct < -2.0",'
|
||||
f' "action": "REDUCE_ALL", "rationale": "..."}}\n'
|
||||
f' ],\n'
|
||||
f' "stocks": [\n'
|
||||
f' {{\n'
|
||||
f' "stock_code": "...",\n'
|
||||
f' "scenarios": [\n'
|
||||
f' {{\n'
|
||||
f' "condition": {{"rsi_below": 30, "volume_ratio_above": 2.0,'
|
||||
f' "unrealized_pnl_pct_above": 3.0, "holding_days_above": 5}},\n'
|
||||
f' "action": "BUY|SELL|HOLD",\n'
|
||||
f' "confidence": 85,\n'
|
||||
f' "allocation_pct": 10.0,\n'
|
||||
f' "stop_loss_pct": -2.0,\n'
|
||||
f' "take_profit_pct": 3.0,\n'
|
||||
f' "rationale": "..."\n'
|
||||
f' }}\n'
|
||||
f' ]\n'
|
||||
f' }}\n'
|
||||
f' ]\n'
|
||||
f'}}\n\n'
|
||||
f"Rules:\n"
|
||||
f"- Max {max_scenarios} scenarios per stock\n"
|
||||
f"- Candidates list is the primary source for BUY candidates\n"
|
||||
f"{holdings_instruction}"
|
||||
f"- Confidence 0-100 (80+ for actionable trades)\n"
|
||||
f"- stop_loss_pct must be <= 0, take_profit_pct must be >= 0\n"
|
||||
f"- Return ONLY the JSON, no markdown fences or explanation\n"
|
||||
)
|
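# For reference, a response that satisfies the contract spelled out in the prompt above
# (illustrative values, not part of the diff): one stock, one scenario, one circuit breaker.
example_response = {
    "market_outlook": "neutral_to_bullish",
    "global_rules": [
        {"condition": "portfolio_pnl_pct < -2.0", "action": "REDUCE_ALL", "rationale": "Circuit breaker"},
    ],
    "stocks": [
        {
            "stock_code": "005930",
            "scenarios": [
                {
                    "condition": {"rsi_below": 30, "volume_ratio_above": 2.0},
                    "action": "BUY",
                    "confidence": 85,
                    "allocation_pct": 10.0,
                    "stop_loss_pct": -2.0,
                    "take_profit_pct": 3.0,
                    "rationale": "Oversold bounce with volume confirmation",
                },
            ],
        },
    ],
}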
||||
|
||||
def _parse_response(
|
||||
self,
|
||||
response_text: str,
|
||||
today: date,
|
||||
market: str,
|
||||
candidates: list[ScanCandidate],
|
||||
cross_market: CrossMarketContext | None,
|
||||
current_holdings: list[dict] | None = None,
|
||||
) -> DayPlaybook:
|
||||
"""Parse Gemini's JSON response into a validated DayPlaybook."""
|
||||
cleaned = self._extract_json(response_text)
|
||||
data = json.loads(cleaned)
|
||||
|
||||
valid_codes = {c.stock_code for c in candidates}
|
||||
# Holdings are also valid — AI may generate SELL/HOLD scenarios for them
|
||||
if current_holdings:
|
||||
for h in current_holdings:
|
||||
code = h.get("stock_code", "")
|
||||
if code:
|
||||
valid_codes.add(code)
|
||||
|
||||
# Parse market outlook
|
||||
outlook_str = data.get("market_outlook", "neutral")
|
||||
market_outlook = _OUTLOOK_MAP.get(outlook_str, MarketOutlook.NEUTRAL)
|
||||
|
||||
# Parse global rules
|
||||
global_rules = []
|
||||
for rule_data in data.get("global_rules", []):
|
||||
action_str = rule_data.get("action", "HOLD")
|
||||
action = _ACTION_MAP.get(action_str, ScenarioAction.HOLD)
|
||||
global_rules.append(
|
||||
GlobalRule(
|
||||
condition=rule_data.get("condition", ""),
|
||||
action=action,
|
||||
rationale=rule_data.get("rationale", ""),
|
||||
)
|
||||
)
|
||||
|
||||
# Parse stock playbooks
|
||||
stock_playbooks = []
|
||||
max_scenarios = self._settings.MAX_SCENARIOS_PER_STOCK
|
||||
for stock_data in data.get("stocks", []):
|
||||
code = stock_data.get("stock_code", "")
|
||||
if code not in valid_codes:
|
||||
logger.warning("Gemini returned unknown stock %s — skipping", code)
|
||||
continue
|
||||
|
||||
scenarios = []
|
||||
for sc_data in stock_data.get("scenarios", [])[:max_scenarios]:
|
||||
scenario = self._parse_scenario(sc_data)
|
||||
if scenario:
|
||||
scenarios.append(scenario)
|
||||
|
||||
if scenarios:
|
||||
stock_playbooks.append(
|
||||
StockPlaybook(
|
||||
stock_code=code,
|
||||
scenarios=scenarios,
|
||||
)
|
||||
)
|
||||
|
||||
return DayPlaybook(
|
||||
date=today,
|
||||
market=market,
|
||||
market_outlook=market_outlook,
|
||||
global_rules=global_rules,
|
||||
stock_playbooks=stock_playbooks,
|
||||
cross_market=cross_market,
|
||||
)
|
||||
|
||||
def _parse_scenario(self, sc_data: dict) -> StockScenario | None:
|
||||
"""Parse a single scenario from JSON data. Returns None if invalid."""
|
||||
try:
|
||||
cond_data = sc_data.get("condition", {})
|
||||
condition = StockCondition(
|
||||
rsi_below=cond_data.get("rsi_below"),
|
||||
rsi_above=cond_data.get("rsi_above"),
|
||||
volume_ratio_above=cond_data.get("volume_ratio_above"),
|
||||
volume_ratio_below=cond_data.get("volume_ratio_below"),
|
||||
price_above=cond_data.get("price_above"),
|
||||
price_below=cond_data.get("price_below"),
|
||||
price_change_pct_above=cond_data.get("price_change_pct_above"),
|
||||
price_change_pct_below=cond_data.get("price_change_pct_below"),
|
||||
unrealized_pnl_pct_above=cond_data.get("unrealized_pnl_pct_above"),
|
||||
unrealized_pnl_pct_below=cond_data.get("unrealized_pnl_pct_below"),
|
||||
holding_days_above=cond_data.get("holding_days_above"),
|
||||
holding_days_below=cond_data.get("holding_days_below"),
|
||||
)
|
||||
|
||||
if not condition.has_any_condition():
|
||||
logger.warning("Scenario has no conditions — skipping")
|
||||
return None
|
||||
|
||||
action_str = sc_data.get("action", "HOLD")
|
||||
action = _ACTION_MAP.get(action_str, ScenarioAction.HOLD)
|
||||
|
||||
return StockScenario(
|
||||
condition=condition,
|
||||
action=action,
|
||||
confidence=int(sc_data.get("confidence", 50)),
|
||||
allocation_pct=float(sc_data.get("allocation_pct", 10.0)),
|
||||
stop_loss_pct=float(sc_data.get("stop_loss_pct", -2.0)),
|
||||
take_profit_pct=float(sc_data.get("take_profit_pct", 3.0)),
|
||||
rationale=sc_data.get("rationale", ""),
|
||||
)
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning("Failed to parse scenario: %s", e)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_json(text: str) -> str:
|
||||
"""Extract JSON from response, stripping markdown fences if present."""
|
||||
stripped = text.strip()
|
||||
if stripped.startswith("```"):
|
||||
# Remove first line (```json or ```) and last line (```)
|
||||
lines = stripped.split("\n")
|
||||
lines = lines[1:] # Remove opening fence
|
||||
if lines and lines[-1].strip() == "```":
|
||||
lines = lines[:-1]
|
||||
stripped = "\n".join(lines)
|
||||
return stripped.strip()
|
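# Quick check of the fence-stripping behaviour (illustrative, not part of the diff).
from src.strategy.pre_market_planner import PreMarketPlanner

fenced = '```json\n{"market_outlook": "neutral", "stocks": []}\n```'
assert PreMarketPlanner._extract_json(fenced) == '{"market_outlook": "neutral", "stocks": []}'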
||||
|
||||
@staticmethod
|
||||
def _empty_playbook(today: date, market: str) -> DayPlaybook:
|
||||
"""Return an empty playbook (no stocks, no scenarios)."""
|
||||
return DayPlaybook(
|
||||
date=today,
|
||||
market=market,
|
||||
market_outlook=MarketOutlook.NEUTRAL,
|
||||
stock_playbooks=[],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _defensive_playbook(
|
||||
today: date,
|
||||
market: str,
|
||||
candidates: list[ScanCandidate],
|
||||
) -> DayPlaybook:
|
||||
"""Return a defensive playbook — HOLD everything with stop-loss ready."""
|
||||
stock_playbooks = [
|
||||
StockPlaybook(
|
||||
stock_code=c.stock_code,
|
||||
scenarios=[
|
||||
StockScenario(
|
||||
condition=StockCondition(price_change_pct_below=-3.0),
|
||||
action=ScenarioAction.SELL,
|
||||
confidence=90,
|
||||
stop_loss_pct=-3.0,
|
||||
rationale="Defensive stop-loss (planner failure)",
|
||||
),
|
||||
],
|
||||
)
|
||||
for c in candidates
|
||||
]
|
||||
return DayPlaybook(
|
||||
date=today,
|
||||
market=market,
|
||||
market_outlook=MarketOutlook.NEUTRAL_TO_BEARISH,
|
||||
default_action=ScenarioAction.HOLD,
|
||||
stock_playbooks=stock_playbooks,
|
||||
global_rules=[
|
||||
GlobalRule(
|
||||
condition="portfolio_pnl_pct < -2.0",
|
||||
action=ScenarioAction.REDUCE_ALL,
|
||||
rationale="Defensive: reduce on loss threshold",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _smart_fallback_playbook(
|
||||
today: date,
|
||||
market: str,
|
||||
candidates: list[ScanCandidate],
|
||||
settings: Settings,
|
||||
) -> DayPlaybook:
|
||||
"""Rule-based fallback playbook when Gemini is unavailable.
|
||||
|
||||
Uses scanner signals (RSI, volume_ratio) to generate meaningful BUY
|
||||
conditions instead of the all-SELL defensive playbook. Candidates are
|
||||
already pre-qualified by SmartVolatilityScanner, so we trust their
|
||||
signals and build actionable scenarios from them.
|
||||
|
||||
Scenario logic per candidate:
|
||||
- momentum signal: BUY when volume_ratio exceeds scanner threshold
|
||||
- oversold signal: BUY when RSI is below oversold threshold
|
||||
- always: SELL stop-loss at -3.0% as guard
|
||||
"""
|
||||
stock_playbooks = []
|
||||
for c in candidates:
|
||||
scenarios: list[StockScenario] = []
|
||||
|
||||
if c.signal == "momentum":
|
||||
scenarios.append(
|
||||
StockScenario(
|
||||
condition=StockCondition(
|
||||
volume_ratio_above=settings.VOL_MULTIPLIER,
|
||||
),
|
||||
action=ScenarioAction.BUY,
|
||||
confidence=80,
|
||||
allocation_pct=10.0,
|
||||
stop_loss_pct=-3.0,
|
||||
take_profit_pct=5.0,
|
||||
rationale=(
|
||||
f"Rule-based BUY: momentum signal, "
|
||||
f"volume={c.volume_ratio:.1f}x (fallback planner)"
|
||||
),
|
||||
)
|
||||
)
|
||||
elif c.signal == "oversold":
|
||||
scenarios.append(
|
||||
StockScenario(
|
||||
condition=StockCondition(
|
||||
rsi_below=settings.RSI_OVERSOLD_THRESHOLD,
|
||||
),
|
||||
action=ScenarioAction.BUY,
|
||||
confidence=80,
|
||||
allocation_pct=10.0,
|
||||
stop_loss_pct=-3.0,
|
||||
take_profit_pct=5.0,
|
||||
rationale=(
|
||||
f"Rule-based BUY: oversold signal, "
|
||||
f"RSI={c.rsi:.0f} (fallback planner)"
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Always add stop-loss guard
|
||||
scenarios.append(
|
||||
StockScenario(
|
||||
condition=StockCondition(price_change_pct_below=-3.0),
|
||||
action=ScenarioAction.SELL,
|
||||
confidence=90,
|
||||
stop_loss_pct=-3.0,
|
||||
rationale="Rule-based stop-loss (fallback planner)",
|
||||
)
|
||||
)
|
||||
|
||||
stock_playbooks.append(
|
||||
StockPlaybook(
|
||||
stock_code=c.stock_code,
|
||||
scenarios=scenarios,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Smart fallback playbook for %s: %d stocks with rule-based BUY/SELL conditions",
|
||||
market,
|
||||
len(stock_playbooks),
|
||||
)
|
||||
return DayPlaybook(
|
||||
date=today,
|
||||
market=market,
|
||||
market_outlook=MarketOutlook.NEUTRAL,
|
||||
default_action=ScenarioAction.HOLD,
|
||||
stock_playbooks=stock_playbooks,
|
||||
global_rules=[
|
||||
GlobalRule(
|
||||
condition="portfolio_pnl_pct < -2.0",
|
||||
action=ScenarioAction.REDUCE_ALL,
|
||||
rationale="Defensive: reduce on loss threshold",
|
||||
),
|
||||
],
|
||||
)
|
||||
src/strategy/scenario_engine.py (new file, 305 lines)
@@ -0,0 +1,305 @@
|
||||
"""Local scenario engine for playbook execution.
|
||||
|
||||
Matches real-time market conditions against pre-defined scenarios
|
||||
without any API calls. Designed for sub-100ms execution.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from src.strategy.models import (
|
||||
DayPlaybook,
|
||||
GlobalRule,
|
||||
ScenarioAction,
|
||||
StockCondition,
|
||||
StockScenario,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScenarioMatch:
|
||||
"""Result of matching market conditions against scenarios."""
|
||||
|
||||
stock_code: str
|
||||
matched_scenario: StockScenario | None
|
||||
action: ScenarioAction
|
||||
confidence: int
|
||||
rationale: str
|
||||
global_rule_triggered: GlobalRule | None = None
|
||||
match_details: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ScenarioEngine:
|
||||
"""Evaluates playbook scenarios against real-time market data.
|
||||
|
||||
No API calls — pure Python condition matching.
|
||||
|
||||
Expected market_data keys: "rsi", "volume_ratio", "current_price", "price_change_pct".
|
||||
Callers must normalize data source keys to match this contract.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._warned_keys: set[str] = set()
|
||||
|
||||
@staticmethod
|
||||
def _safe_float(value: Any) -> float | None:
|
||||
"""Safely cast a value to float. Returns None on failure."""
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
def _warn_missing_key(self, key: str) -> None:
|
||||
"""Log a missing-key warning once per key per engine instance."""
|
||||
if key not in self._warned_keys:
|
||||
self._warned_keys.add(key)
|
||||
logger.warning("Condition requires '%s' but key missing from market_data", key)
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
playbook: DayPlaybook,
|
||||
stock_code: str,
|
||||
market_data: dict[str, Any],
|
||||
portfolio_data: dict[str, Any],
|
||||
) -> ScenarioMatch:
|
||||
"""Match market conditions to scenarios and return a decision.
|
||||
|
||||
Algorithm:
|
||||
1. Check global rules first (portfolio-level circuit breakers)
|
||||
2. Find the StockPlaybook for the given stock_code
|
||||
3. Iterate scenarios in order (first match wins)
|
||||
4. If no match, return playbook.default_action (HOLD)
|
||||
|
||||
Args:
|
||||
playbook: Today's DayPlaybook for this market
|
||||
stock_code: Stock ticker to evaluate
|
||||
market_data: Real-time market data (price, rsi, volume_ratio, etc.)
|
||||
portfolio_data: Portfolio state (pnl_pct, total_cash, etc.)
|
||||
|
||||
Returns:
|
||||
ScenarioMatch with the decision
|
||||
"""
|
||||
# 1. Check global rules
|
||||
triggered_rule = self.check_global_rules(playbook, portfolio_data)
|
||||
if triggered_rule is not None:
|
||||
logger.info(
|
||||
"Global rule triggered for %s: %s -> %s",
|
||||
stock_code,
|
||||
triggered_rule.condition,
|
||||
triggered_rule.action.value,
|
||||
)
|
||||
return ScenarioMatch(
|
||||
stock_code=stock_code,
|
||||
matched_scenario=None,
|
||||
action=triggered_rule.action,
|
||||
confidence=100,
|
||||
rationale=f"Global rule: {triggered_rule.rationale or triggered_rule.condition}",
|
||||
global_rule_triggered=triggered_rule,
|
||||
)
|
||||
|
||||
# 2. Find stock playbook
|
||||
stock_pb = playbook.get_stock_playbook(stock_code)
|
||||
if stock_pb is None:
|
||||
logger.debug("No playbook for %s — defaulting to %s", stock_code, playbook.default_action)
|
||||
return ScenarioMatch(
|
||||
stock_code=stock_code,
|
||||
matched_scenario=None,
|
||||
action=playbook.default_action,
|
||||
confidence=0,
|
||||
rationale=f"No scenarios defined for {stock_code}",
|
||||
)
|
||||
|
||||
# 3. Iterate scenarios (first match wins)
|
||||
for scenario in stock_pb.scenarios:
|
||||
if self.evaluate_condition(scenario.condition, market_data):
|
||||
logger.info(
|
||||
"Scenario matched for %s: %s (confidence=%d)",
|
||||
stock_code,
|
||||
scenario.action.value,
|
||||
scenario.confidence,
|
||||
)
|
||||
return ScenarioMatch(
|
||||
stock_code=stock_code,
|
||||
matched_scenario=scenario,
|
||||
action=scenario.action,
|
||||
confidence=scenario.confidence,
|
||||
rationale=scenario.rationale,
|
||||
match_details=self._build_match_details(scenario.condition, market_data),
|
||||
)
|
||||
|
||||
# 4. No match — default action
|
||||
logger.debug("No scenario matched for %s — defaulting to %s", stock_code, playbook.default_action)
|
||||
return ScenarioMatch(
|
||||
stock_code=stock_code,
|
||||
matched_scenario=None,
|
||||
action=playbook.default_action,
|
||||
confidence=0,
|
||||
rationale="No scenario conditions met — holding position",
|
||||
)
|
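# Illustrative walk-through of evaluate(), not part of the diff: global rules first,
# then first-match-wins over stock scenarios, then the playbook's default action.
from datetime import date as _date

from src.strategy.models import (
    DayPlaybook, GlobalRule, ScenarioAction, StockCondition, StockPlaybook, StockScenario,
)
from src.strategy.scenario_engine import ScenarioEngine

pb = DayPlaybook(
    date=_date(2024, 1, 2),
    market="KR",
    global_rules=[GlobalRule(condition="portfolio_pnl_pct < -2.0", action=ScenarioAction.REDUCE_ALL)],
    stock_playbooks=[
        StockPlaybook(
            stock_code="005930",
            scenarios=[
                StockScenario(
                    condition=StockCondition(rsi_below=30),
                    action=ScenarioAction.BUY,
                    confidence=85,
                )
            ],
        )
    ],
)
engine = ScenarioEngine()
# Scenario match: RSI below 30 triggers the BUY rule.
assert engine.evaluate(pb, "005930", {"rsi": 25}, {"portfolio_pnl_pct": 0.0}).action is ScenarioAction.BUY
# Circuit breaker: the global rule pre-empts any stock scenario.
assert engine.evaluate(pb, "005930", {"rsi": 25}, {"portfolio_pnl_pct": -3.0}).action is ScenarioAction.REDUCE_ALL
# No match: falls back to the playbook default (HOLD).
assert engine.evaluate(pb, "005930", {"rsi": 55}, {"portfolio_pnl_pct": 0.0}).action is ScenarioAction.HOLD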
||||
|
||||
def check_global_rules(
|
||||
self,
|
||||
playbook: DayPlaybook,
|
||||
portfolio_data: dict[str, Any],
|
||||
) -> GlobalRule | None:
|
||||
"""Check portfolio-level rules. Returns first triggered rule or None."""
|
||||
for rule in playbook.global_rules:
|
||||
if self._evaluate_global_condition(rule.condition, portfolio_data):
|
||||
return rule
|
||||
return None
|
||||
|
||||
def evaluate_condition(
|
||||
self,
|
||||
condition: StockCondition,
|
||||
market_data: dict[str, Any],
|
||||
) -> bool:
|
||||
"""Evaluate all non-None fields in condition as AND.
|
||||
|
||||
Returns True only if ALL specified conditions are met.
|
||||
Empty condition (no fields set) returns False for safety.
|
||||
"""
|
||||
if not condition.has_any_condition():
|
||||
return False
|
||||
|
||||
checks: list[bool] = []
|
||||
|
||||
rsi = self._safe_float(market_data.get("rsi"))
|
||||
if condition.rsi_below is not None or condition.rsi_above is not None:
|
||||
if "rsi" not in market_data:
|
||||
self._warn_missing_key("rsi")
|
||||
if condition.rsi_below is not None:
|
||||
checks.append(rsi is not None and rsi < condition.rsi_below)
|
||||
if condition.rsi_above is not None:
|
||||
checks.append(rsi is not None and rsi > condition.rsi_above)
|
||||
|
||||
volume_ratio = self._safe_float(market_data.get("volume_ratio"))
|
||||
if condition.volume_ratio_above is not None or condition.volume_ratio_below is not None:
|
||||
if "volume_ratio" not in market_data:
|
||||
self._warn_missing_key("volume_ratio")
|
||||
if condition.volume_ratio_above is not None:
|
||||
checks.append(volume_ratio is not None and volume_ratio > condition.volume_ratio_above)
|
||||
if condition.volume_ratio_below is not None:
|
||||
checks.append(volume_ratio is not None and volume_ratio < condition.volume_ratio_below)
|
||||
|
||||
price = self._safe_float(market_data.get("current_price"))
|
||||
if condition.price_above is not None or condition.price_below is not None:
|
||||
if "current_price" not in market_data:
|
||||
self._warn_missing_key("current_price")
|
||||
if condition.price_above is not None:
|
||||
checks.append(price is not None and price > condition.price_above)
|
||||
if condition.price_below is not None:
|
||||
checks.append(price is not None and price < condition.price_below)
|
||||
|
||||
price_change_pct = self._safe_float(market_data.get("price_change_pct"))
|
||||
if condition.price_change_pct_above is not None or condition.price_change_pct_below is not None:
|
||||
if "price_change_pct" not in market_data:
|
||||
self._warn_missing_key("price_change_pct")
|
||||
if condition.price_change_pct_above is not None:
|
||||
checks.append(price_change_pct is not None and price_change_pct > condition.price_change_pct_above)
|
||||
if condition.price_change_pct_below is not None:
|
||||
checks.append(price_change_pct is not None and price_change_pct < condition.price_change_pct_below)
|
||||
|
||||
# Position-aware conditions
|
||||
unrealized_pnl_pct = self._safe_float(market_data.get("unrealized_pnl_pct"))
|
||||
if condition.unrealized_pnl_pct_above is not None or condition.unrealized_pnl_pct_below is not None:
|
||||
if "unrealized_pnl_pct" not in market_data:
|
||||
self._warn_missing_key("unrealized_pnl_pct")
|
||||
if condition.unrealized_pnl_pct_above is not None:
|
||||
checks.append(
|
||||
unrealized_pnl_pct is not None
|
||||
and unrealized_pnl_pct > condition.unrealized_pnl_pct_above
|
||||
)
|
||||
if condition.unrealized_pnl_pct_below is not None:
|
||||
checks.append(
|
||||
unrealized_pnl_pct is not None
|
||||
and unrealized_pnl_pct < condition.unrealized_pnl_pct_below
|
||||
)
|
||||
|
||||
holding_days = self._safe_float(market_data.get("holding_days"))
|
||||
if condition.holding_days_above is not None or condition.holding_days_below is not None:
|
||||
if "holding_days" not in market_data:
|
||||
self._warn_missing_key("holding_days")
|
||||
if condition.holding_days_above is not None:
|
||||
checks.append(
|
||||
holding_days is not None
|
||||
and holding_days > condition.holding_days_above
|
||||
)
|
||||
if condition.holding_days_below is not None:
|
||||
checks.append(
|
||||
holding_days is not None
|
||||
and holding_days < condition.holding_days_below
|
||||
)
|
||||
|
||||
return len(checks) > 0 and all(checks)
|
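# Condensed check of the AND semantics above (illustrative, not part of the diff).
from src.strategy.models import StockCondition
from src.strategy.scenario_engine import ScenarioEngine

engine = ScenarioEngine()
cond = StockCondition(rsi_below=30, volume_ratio_above=2.0)
assert engine.evaluate_condition(cond, {"rsi": 25, "volume_ratio": 2.5}) is True
assert engine.evaluate_condition(cond, {"rsi": 25, "volume_ratio": 1.0}) is False  # one failing leg fails the match
assert engine.evaluate_condition(StockCondition(), {"rsi": 25}) is False           # empty condition never matches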
||||
|
||||
def _evaluate_global_condition(
|
||||
self,
|
||||
condition_str: str,
|
||||
portfolio_data: dict[str, Any],
|
||||
) -> bool:
|
||||
"""Evaluate a simple global condition string against portfolio data.
|
||||
|
||||
Supports: "field < value", "field > value", "field <= value", "field >= value"
|
||||
"""
|
||||
parts = condition_str.strip().split()
|
||||
if len(parts) != 3:
|
||||
logger.warning("Invalid global condition format: %s", condition_str)
|
||||
return False
|
||||
|
||||
field_name, operator, value_str = parts
|
||||
try:
|
||||
threshold = float(value_str)
|
||||
except ValueError:
|
||||
logger.warning("Invalid threshold in condition: %s", condition_str)
|
||||
return False
|
||||
|
||||
actual = portfolio_data.get(field_name)
|
||||
if actual is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
actual_val = float(actual)
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
if operator == "<":
|
||||
return actual_val < threshold
|
||||
elif operator == ">":
|
||||
return actual_val > threshold
|
||||
elif operator == "<=":
|
||||
return actual_val <= threshold
|
||||
elif operator == ">=":
|
||||
return actual_val >= threshold
|
||||
else:
|
||||
logger.warning("Unknown operator in condition: %s", operator)
|
||||
return False
|
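# The supported grammar is a single "field operator value" triple. A quick check
# (illustrative, not part of the diff; it exercises a private helper purely for demonstration).
from src.strategy.scenario_engine import ScenarioEngine

engine = ScenarioEngine()
assert engine._evaluate_global_condition("portfolio_pnl_pct < -2.0", {"portfolio_pnl_pct": -3.1}) is True
assert engine._evaluate_global_condition("portfolio_pnl_pct < -2.0", {"portfolio_pnl_pct": -1.0}) is False
assert engine._evaluate_global_condition("pnl<-2.0", {"pnl": -9.9}) is False  # missing spaces → invalid format, never triggers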
||||
|
||||
def _build_match_details(
|
||||
self,
|
||||
condition: StockCondition,
|
||||
market_data: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Build a summary of which conditions matched and their normalized values."""
|
||||
details: dict[str, Any] = {}
|
||||
|
||||
if condition.rsi_below is not None or condition.rsi_above is not None:
|
||||
details["rsi"] = self._safe_float(market_data.get("rsi"))
|
||||
if condition.volume_ratio_above is not None or condition.volume_ratio_below is not None:
|
||||
details["volume_ratio"] = self._safe_float(market_data.get("volume_ratio"))
|
||||
if condition.price_above is not None or condition.price_below is not None:
|
||||
details["current_price"] = self._safe_float(market_data.get("current_price"))
|
||||
if condition.price_change_pct_above is not None or condition.price_change_pct_below is not None:
|
||||
details["price_change_pct"] = self._safe_float(market_data.get("price_change_pct"))
|
||||
if condition.unrealized_pnl_pct_above is not None or condition.unrealized_pnl_pct_below is not None:
|
||||
details["unrealized_pnl_pct"] = self._safe_float(market_data.get("unrealized_pnl_pct"))
|
||||
if condition.holding_days_above is not None or condition.holding_days_below is not None:
|
||||
details["holding_days"] = self._safe_float(market_data.get("holding_days"))
|
||||
|
||||
return details
|
||||
tests/test_backup.py (new file, 799 lines)
@@ -0,0 +1,799 @@
|
||||
"""Tests for backup and disaster recovery system."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import sys
|
||||
import tempfile
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.backup.exporter import BackupExporter, ExportFormat
|
||||
from src.backup.health_monitor import HealthMonitor, HealthStatus
|
||||
from src.backup.scheduler import BackupPolicy, BackupScheduler
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db(tmp_path: Path) -> Path:
|
||||
"""Create a temporary test database."""
|
||||
db_path = tmp_path / "test_trades.db"
|
||||
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Create trades table
|
||||
cursor.execute("""
|
||||
CREATE TABLE trades (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
timestamp TEXT NOT NULL,
|
||||
stock_code TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
quantity INTEGER NOT NULL,
|
||||
price REAL NOT NULL,
|
||||
confidence INTEGER NOT NULL,
|
||||
rationale TEXT,
|
||||
pnl REAL DEFAULT 0.0
|
||||
)
|
||||
""")
|
||||
|
||||
# Insert test data
|
||||
test_trades = [
|
||||
("2024-01-01T10:00:00Z", "005930", "BUY", 10, 70000.0, 85, "Test buy", 0.0),
|
||||
("2024-01-01T11:00:00Z", "005930", "SELL", 10, 71000.0, 90, "Test sell", 10000.0),
|
||||
("2024-01-02T10:00:00Z", "AAPL", "BUY", 5, 180.0, 88, "Tech buy", 0.0),
|
||||
]
|
||||
|
||||
cursor.executemany(
|
||||
"""
|
||||
INSERT INTO trades (timestamp, stock_code, action, quantity, price, confidence, rationale, pnl)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
test_trades,
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
return db_path
|
||||
|
||||
|
||||
class TestBackupExporter:
|
||||
"""Test BackupExporter functionality."""
|
||||
|
||||
def test_exporter_init(self, temp_db: Path) -> None:
|
||||
"""Test exporter initialization."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
assert exporter.db_path == str(temp_db)
|
||||
|
||||
def test_export_json(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test JSON export."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
output_dir = tmp_path / "exports"
|
||||
|
||||
results = exporter.export_all(
|
||||
output_dir, formats=[ExportFormat.JSON], compress=False
|
||||
)
|
||||
|
||||
assert ExportFormat.JSON in results
|
||||
assert results[ExportFormat.JSON].exists()
|
||||
assert results[ExportFormat.JSON].suffix == ".json"
|
||||
|
||||
def test_export_json_compressed(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test compressed JSON export."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
output_dir = tmp_path / "exports"
|
||||
|
||||
results = exporter.export_all(
|
||||
output_dir, formats=[ExportFormat.JSON], compress=True
|
||||
)
|
||||
|
||||
assert ExportFormat.JSON in results
|
||||
assert results[ExportFormat.JSON].suffix == ".gz"
|
||||
|
||||
def test_export_csv(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test CSV export."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
output_dir = tmp_path / "exports"
|
||||
|
||||
results = exporter.export_all(
|
||||
output_dir, formats=[ExportFormat.CSV], compress=False
|
||||
)
|
||||
|
||||
assert ExportFormat.CSV in results
|
||||
assert results[ExportFormat.CSV].exists()
|
||||
|
||||
# Verify CSV content
|
||||
with open(results[ExportFormat.CSV], "r") as f:
|
||||
lines = f.readlines()
|
||||
assert len(lines) == 4 # Header + 3 rows
|
||||
|
||||
def test_export_all_formats(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test exporting all formats."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
output_dir = tmp_path / "exports"
|
||||
|
||||
# Skip Parquet if pyarrow not available
|
||||
try:
|
||||
import pyarrow # noqa: F401
|
||||
|
||||
formats = [ExportFormat.JSON, ExportFormat.CSV, ExportFormat.PARQUET]
|
||||
except ImportError:
|
||||
formats = [ExportFormat.JSON, ExportFormat.CSV]
|
||||
|
||||
results = exporter.export_all(output_dir, formats=formats, compress=False)
|
||||
|
||||
for fmt in formats:
|
||||
assert fmt in results
|
||||
assert results[fmt].exists()
|
||||
|
||||
def test_incremental_export(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test incremental export."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
output_dir = tmp_path / "exports"
|
||||
|
||||
# Export only trades after Jan 2
|
||||
cutoff = datetime(2024, 1, 2, tzinfo=UTC)
|
||||
results = exporter.export_all(
|
||||
output_dir,
|
||||
formats=[ExportFormat.JSON],
|
||||
compress=False,
|
||||
incremental_since=cutoff,
|
||||
)
|
||||
|
||||
# Should only have 1 trade (AAPL on Jan 2)
|
||||
import json
|
||||
|
||||
with open(results[ExportFormat.JSON], "r") as f:
|
||||
data = json.load(f)
|
||||
assert data["record_count"] == 1
|
||||
assert data["trades"][0]["stock_code"] == "AAPL"
|
||||
|
||||
def test_get_export_stats(self, temp_db: Path) -> None:
|
||||
"""Test export statistics."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
stats = exporter.get_export_stats()
|
||||
|
||||
assert stats["total_trades"] == 3
|
||||
assert "date_range" in stats
|
||||
assert "db_size_bytes" in stats
|
||||
|
||||
|
||||
class TestBackupScheduler:
|
||||
"""Test BackupScheduler functionality."""
|
||||
|
||||
def test_scheduler_init(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test scheduler initialization."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
|
||||
assert scheduler.db_path == temp_db
|
||||
assert (backup_dir / "daily").exists()
|
||||
assert (backup_dir / "weekly").exists()
|
||||
assert (backup_dir / "monthly").exists()
|
||||
|
||||
def test_create_daily_backup(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test daily backup creation."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
|
||||
metadata = scheduler.create_backup(BackupPolicy.DAILY, verify=True)
|
||||
|
||||
assert metadata.policy == BackupPolicy.DAILY
|
||||
assert metadata.file_path.exists()
|
||||
assert metadata.size_bytes > 0
|
||||
assert metadata.checksum is not None
|
||||
|
||||
def test_create_weekly_backup(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test weekly backup creation."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
|
||||
metadata = scheduler.create_backup(BackupPolicy.WEEKLY, verify=False)
|
||||
|
||||
assert metadata.policy == BackupPolicy.WEEKLY
|
||||
assert metadata.file_path.exists()
|
||||
assert metadata.checksum is None # verify=False
|
||||
|
||||
def test_list_backups(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test listing backups."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
|
||||
scheduler.create_backup(BackupPolicy.DAILY)
|
||||
scheduler.create_backup(BackupPolicy.WEEKLY)
|
||||
|
||||
backups = scheduler.list_backups()
|
||||
assert len(backups) == 2
|
||||
|
||||
daily_backups = scheduler.list_backups(BackupPolicy.DAILY)
|
||||
assert len(daily_backups) == 1
|
||||
assert daily_backups[0].policy == BackupPolicy.DAILY
|
||||
|
||||
def test_cleanup_old_backups(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test cleanup of old backups."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir, daily_retention_days=0)
|
||||
|
||||
# Create a backup
|
||||
scheduler.create_backup(BackupPolicy.DAILY)
|
||||
|
||||
# Cleanup should remove it (0 day retention)
|
||||
removed = scheduler.cleanup_old_backups()
|
||||
assert removed[BackupPolicy.DAILY] >= 1
|
||||
|
||||
def test_backup_stats(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test backup statistics."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
|
||||
scheduler.create_backup(BackupPolicy.DAILY)
|
||||
scheduler.create_backup(BackupPolicy.MONTHLY)
|
||||
|
||||
stats = scheduler.get_backup_stats()
|
||||
|
||||
assert stats["daily"]["count"] == 1
|
||||
assert stats["monthly"]["count"] == 1
|
||||
assert stats["daily"]["total_size_bytes"] > 0
|
||||
|
||||
def test_restore_backup(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test backup restoration."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
|
||||
# Create backup
|
||||
metadata = scheduler.create_backup(BackupPolicy.DAILY)
|
||||
|
||||
# Modify database
|
||||
conn = sqlite3.connect(str(temp_db))
|
||||
conn.execute("DELETE FROM trades")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
# Restore
|
||||
scheduler.restore_backup(metadata, verify=True)
|
||||
|
||||
# Verify restoration
|
||||
conn = sqlite3.connect(str(temp_db))
|
||||
cursor = conn.execute("SELECT COUNT(*) FROM trades")
|
||||
count = cursor.fetchone()[0]
|
||||
conn.close()
|
||||
|
||||
assert count == 3 # Original 3 trades restored
|
||||
|
||||
|
||||
class TestHealthMonitor:
|
||||
"""Test HealthMonitor functionality."""
|
||||
|
||||
def test_monitor_init(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test monitor initialization."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
monitor = HealthMonitor(str(temp_db), backup_dir)
|
||||
|
||||
assert monitor.db_path == temp_db
|
||||
|
||||
def test_check_database_health_ok(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test database health check (healthy)."""
|
||||
monitor = HealthMonitor(str(temp_db), tmp_path / "backups")
|
||||
result = monitor.check_database_health()
|
||||
|
||||
assert result.status == HealthStatus.HEALTHY
|
||||
assert "healthy" in result.message.lower()
|
||||
assert result.details is not None
|
||||
assert result.details["trade_count"] == 3
|
||||
|
||||
def test_check_database_health_missing(self, tmp_path: Path) -> None:
|
||||
"""Test database health check (missing file)."""
|
||||
non_existent = tmp_path / "missing.db"
|
||||
monitor = HealthMonitor(str(non_existent), tmp_path / "backups")
|
||||
result = monitor.check_database_health()
|
||||
|
||||
assert result.status == HealthStatus.UNHEALTHY
|
||||
assert "not found" in result.message.lower()
|
||||
|
||||
def test_check_disk_space(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test disk space check."""
|
||||
monitor = HealthMonitor(str(temp_db), tmp_path, min_disk_space_gb=0.001)
|
||||
result = monitor.check_disk_space()
|
||||
|
||||
# Should be healthy with minimal requirement
|
||||
assert result.status in [HealthStatus.HEALTHY, HealthStatus.DEGRADED]
|
||||
assert result.details is not None
|
||||
assert "free_gb" in result.details
|
||||
|
||||
def test_check_backup_recency_no_backups(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test backup recency check (no backups)."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
backup_dir.mkdir()
|
||||
(backup_dir / "daily").mkdir()
|
||||
|
||||
monitor = HealthMonitor(str(temp_db), backup_dir)
|
||||
result = monitor.check_backup_recency()
|
||||
|
||||
assert result.status == HealthStatus.UNHEALTHY
|
||||
assert "no" in result.message.lower()
|
||||
|
||||
def test_check_backup_recency_recent(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test backup recency check (recent backup)."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
scheduler.create_backup(BackupPolicy.DAILY)
|
||||
|
||||
monitor = HealthMonitor(str(temp_db), backup_dir)
|
||||
result = monitor.check_backup_recency()
|
||||
|
||||
assert result.status == HealthStatus.HEALTHY
|
||||
assert "recent" in result.message.lower()
|
||||
|
||||
def test_run_all_checks(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test running all health checks."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
scheduler.create_backup(BackupPolicy.DAILY)
|
||||
|
||||
monitor = HealthMonitor(str(temp_db), backup_dir, min_disk_space_gb=0.001)
|
||||
checks = monitor.run_all_checks()
|
||||
|
||||
assert "database" in checks
|
||||
assert "disk_space" in checks
|
||||
assert "backup_recency" in checks
|
||||
assert checks["database"].status == HealthStatus.HEALTHY
|
||||
|
||||
def test_get_overall_status(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test overall health status."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
scheduler.create_backup(BackupPolicy.DAILY)
|
||||
|
||||
monitor = HealthMonitor(str(temp_db), backup_dir, min_disk_space_gb=0.001)
|
||||
status = monitor.get_overall_status()
|
||||
|
||||
assert status in [HealthStatus.HEALTHY, HealthStatus.DEGRADED]
|
||||
|
||||
def test_get_health_report(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test health report generation."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
scheduler.create_backup(BackupPolicy.DAILY)
|
||||
|
||||
monitor = HealthMonitor(str(temp_db), backup_dir, min_disk_space_gb=0.001)
|
||||
report = monitor.get_health_report()
|
||||
|
||||
assert "overall_status" in report
|
||||
assert "timestamp" in report
|
||||
assert "checks" in report
|
||||
assert len(report["checks"]) == 3
|
||||
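
# --- Editor's sketch (illustrative, not part of this diff) -----------------
# The assertions above imply that HealthMonitor reduces three named checks to
# one overall status and a report dict with "overall_status", "timestamp" and
# "checks" keys. A minimal worst-of aggregation sketch using hypothetical
# Sketch* names; the real classes in src.backup may be shaped differently.
from dataclasses import dataclass
from datetime import datetime, timezone
from enum import Enum
from typing import Any


class SketchHealthStatus(Enum):
    HEALTHY = "healthy"
    DEGRADED = "degraded"
    UNHEALTHY = "unhealthy"


@dataclass
class SketchCheckResult:
    status: SketchHealthStatus
    message: str
    details: dict[str, Any] | None = None


def sketch_overall_status(checks: dict[str, SketchCheckResult]) -> SketchHealthStatus:
    """Worst-of reduction: any UNHEALTHY wins, then DEGRADED, else HEALTHY."""
    statuses = {check.status for check in checks.values()}
    if SketchHealthStatus.UNHEALTHY in statuses:
        return SketchHealthStatus.UNHEALTHY
    if SketchHealthStatus.DEGRADED in statuses:
        return SketchHealthStatus.DEGRADED
    return SketchHealthStatus.HEALTHY


def sketch_health_report(checks: dict[str, SketchCheckResult]) -> dict[str, Any]:
    """Keys mirror those asserted in test_get_health_report above."""
    return {
        "overall_status": sketch_overall_status(checks).value,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "checks": {
            name: {"status": check.status.value, "message": check.message}
            for name, check in checks.items()
        },
    }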
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BackupExporter — additional coverage for previously uncovered branches
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def empty_db(tmp_path: Path) -> Path:
|
||||
"""Create a temporary database with NO trade records."""
|
||||
db_path = tmp_path / "empty_trades.db"
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.execute(
|
||||
"""CREATE TABLE trades (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
timestamp TEXT NOT NULL,
|
||||
stock_code TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
quantity INTEGER NOT NULL,
|
||||
price REAL NOT NULL,
|
||||
confidence INTEGER NOT NULL,
|
||||
rationale TEXT,
|
||||
pnl REAL DEFAULT 0.0
|
||||
)"""
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return db_path
|
||||
|
||||
|
||||
class TestBackupExporterAdditional:
|
||||
"""Cover branches missed in the original TestBackupExporter suite."""
|
||||
|
||||
def test_export_all_default_formats(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""export_all with formats=None must default to JSON+CSV+Parquet path."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
# formats=None triggers the default list assignment (line 62)
|
||||
results = exporter.export_all(tmp_path / "out", formats=None, compress=False)
|
||||
# JSON and CSV must always succeed; Parquet needs pyarrow
|
||||
assert ExportFormat.JSON in results
|
||||
assert ExportFormat.CSV in results
|
||||
|
||||
def test_export_all_logs_error_on_failure(
|
||||
self, temp_db: Path, tmp_path: Path
|
||||
) -> None:
|
||||
"""export_all must log an error and continue when one format fails."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
# Patch _export_format to raise on JSON, succeed on CSV
|
||||
original = exporter._export_format
|
||||
|
||||
def failing_export(fmt, *args, **kwargs): # type: ignore[no-untyped-def]
|
||||
if fmt == ExportFormat.JSON:
|
||||
raise RuntimeError("simulated failure")
|
||||
return original(fmt, *args, **kwargs)
|
||||
|
||||
exporter._export_format = failing_export # type: ignore[method-assign]
|
||||
results = exporter.export_all(
|
||||
tmp_path / "out",
|
||||
formats=[ExportFormat.JSON, ExportFormat.CSV],
|
||||
compress=False,
|
||||
)
|
||||
# JSON failed → not in results; CSV succeeded → in results
|
||||
assert ExportFormat.JSON not in results
|
||||
assert ExportFormat.CSV in results
|
||||
|
||||
def test_export_csv_empty_trades_no_compress(
|
||||
self, empty_db: Path, tmp_path: Path
|
||||
) -> None:
|
||||
"""CSV export with no trades and compress=False must write header row only."""
|
||||
exporter = BackupExporter(str(empty_db))
|
||||
results = exporter.export_all(
|
||||
tmp_path / "out",
|
||||
formats=[ExportFormat.CSV],
|
||||
compress=False,
|
||||
)
|
||||
assert ExportFormat.CSV in results
|
||||
out = results[ExportFormat.CSV]
|
||||
assert out.exists()
|
||||
content = out.read_text()
|
||||
assert "timestamp" in content
|
||||
|
||||
def test_export_csv_empty_trades_compressed(
|
||||
self, empty_db: Path, tmp_path: Path
|
||||
) -> None:
|
||||
"""CSV export with no trades and compress=True must write gzipped header."""
|
||||
import gzip
|
||||
|
||||
exporter = BackupExporter(str(empty_db))
|
||||
results = exporter.export_all(
|
||||
tmp_path / "out",
|
||||
formats=[ExportFormat.CSV],
|
||||
compress=True,
|
||||
)
|
||||
assert ExportFormat.CSV in results
|
||||
out = results[ExportFormat.CSV]
|
||||
assert out.suffix == ".gz"
|
||||
with gzip.open(out, "rt", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
assert "timestamp" in content
|
||||
|
||||
def test_export_csv_with_data_compressed(
|
||||
self, temp_db: Path, tmp_path: Path
|
||||
) -> None:
|
||||
"""CSV export with data and compress=True must write gzipped rows."""
|
||||
import gzip
|
||||
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
results = exporter.export_all(
|
||||
tmp_path / "out",
|
||||
formats=[ExportFormat.CSV],
|
||||
compress=True,
|
||||
)
|
||||
assert ExportFormat.CSV in results
|
||||
out = results[ExportFormat.CSV]
|
||||
with gzip.open(out, "rt", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
# Header + 3 data rows
|
||||
assert len(lines) == 4
|
||||
|
||||
def test_export_parquet_raises_import_error_without_pyarrow(
|
||||
self, temp_db: Path, tmp_path: Path
|
||||
) -> None:
|
||||
"""Parquet export must raise ImportError when pyarrow is not installed."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
with patch.dict(sys.modules, {"pyarrow": None, "pyarrow.parquet": None}):
|
||||
try:
|
||||
import pyarrow # noqa: F401
|
||||
pytest.skip("pyarrow is installed; cannot test ImportError path")
|
||||
except ImportError:
|
||||
pass
|
||||
results = exporter.export_all(
|
||||
tmp_path / "out",
|
||||
formats=[ExportFormat.PARQUET],
|
||||
compress=False,
|
||||
)
|
||||
# Parquet export fails gracefully; result dict should not contain it
|
||||
assert ExportFormat.PARQUET not in results
|
||||
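
# --- Editor's sketch (illustrative, not part of this diff) -----------------
# The tests above pin down export_all's control flow: formats=None falls back
# to the full JSON/CSV/Parquet list, and a failure in one format is logged and
# dropped from the result dict instead of aborting the rest. A minimal sketch
# of that loop with hypothetical names (SketchFormat, export_one); the real
# BackupExporter also handles compression and per-format writers.
import logging
from enum import Enum
from pathlib import Path
from typing import Callable

logger = logging.getLogger(__name__)


class SketchFormat(Enum):
    JSON = "json"
    CSV = "csv"
    PARQUET = "parquet"


def sketch_export_all(
    export_one: Callable[[SketchFormat, Path, bool], Path],
    output_dir: Path,
    formats: list[SketchFormat] | None = None,
    compress: bool = True,
) -> dict[SketchFormat, Path]:
    if formats is None:  # default-format branch exercised by formats=None
        formats = [SketchFormat.JSON, SketchFormat.CSV, SketchFormat.PARQUET]
    output_dir.mkdir(parents=True, exist_ok=True)
    results: dict[SketchFormat, Path] = {}
    for fmt in formats:
        try:
            results[fmt] = export_one(fmt, output_dir, compress)
        except Exception:  # a failing format (e.g. missing pyarrow) is skipped
            logger.exception("export failed for %s", fmt.value)
    return results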
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CloudStorage — mocked boto3 tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_boto3_module():
|
||||
"""Inject a fake boto3 into sys.modules for the duration of the test."""
|
||||
mock = MagicMock()
|
||||
with patch.dict(sys.modules, {"boto3": mock}):
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def s3_config():
|
||||
"""Minimal S3Config for tests."""
|
||||
from src.backup.cloud_storage import S3Config
|
||||
|
||||
return S3Config(
|
||||
endpoint_url="http://localhost:9000",
|
||||
access_key="minioadmin",
|
||||
secret_key="minioadmin",
|
||||
bucket_name="test-bucket",
|
||||
region="us-east-1",
|
||||
)
|
||||
|
||||
|
||||
class TestCloudStorage:
|
||||
"""Test CloudStorage using mocked boto3."""
|
||||
|
||||
def test_init_creates_s3_client(self, mock_boto3_module, s3_config) -> None:
|
||||
"""CloudStorage.__init__ must call boto3.client with the correct args."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
mock_boto3_module.client.assert_called_once()
|
||||
call_kwargs = mock_boto3_module.client.call_args[1]
|
||||
assert call_kwargs["aws_access_key_id"] == "minioadmin"
|
||||
assert call_kwargs["aws_secret_access_key"] == "minioadmin"
|
||||
assert storage.config == s3_config
|
||||
|
||||
def test_init_raises_if_boto3_missing(self, s3_config) -> None:
|
||||
"""CloudStorage.__init__ must raise ImportError when boto3 is absent."""
|
||||
with patch.dict(sys.modules, {"boto3": None}): # type: ignore[dict-item]
|
||||
with pytest.raises((ImportError, TypeError)):
|
||||
# Re-import to trigger the try/except inside __init__
|
||||
import importlib
|
||||
|
||||
import src.backup.cloud_storage as m
|
||||
|
||||
importlib.reload(m)
|
||||
m.CloudStorage(s3_config)
|
||||
|
||||
def test_upload_file_success(
|
||||
self, mock_boto3_module, s3_config, tmp_path: Path
|
||||
) -> None:
|
||||
"""upload_file must call client.upload_file and return the object key."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
test_file = tmp_path / "backup.json.gz"
|
||||
test_file.write_bytes(b"data")
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
key = storage.upload_file(test_file, object_key="backups/backup.json.gz")
|
||||
|
||||
assert key == "backups/backup.json.gz"
|
||||
storage.client.upload_file.assert_called_once()
|
||||
|
||||
def test_upload_file_default_key(
|
||||
self, mock_boto3_module, s3_config, tmp_path: Path
|
||||
) -> None:
|
||||
"""upload_file without object_key must use the filename as key."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
test_file = tmp_path / "myfile.gz"
|
||||
test_file.write_bytes(b"data")
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
key = storage.upload_file(test_file)
|
||||
|
||||
assert key == "myfile.gz"
|
||||
|
||||
def test_upload_file_not_found(
|
||||
self, mock_boto3_module, s3_config, tmp_path: Path
|
||||
) -> None:
|
||||
"""upload_file must raise FileNotFoundError for missing files."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
with pytest.raises(FileNotFoundError):
|
||||
storage.upload_file(tmp_path / "nonexistent.gz")
|
||||
|
||||
def test_upload_file_propagates_client_error(
|
||||
self, mock_boto3_module, s3_config, tmp_path: Path
|
||||
) -> None:
|
||||
"""upload_file must re-raise exceptions from the boto3 client."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
test_file = tmp_path / "backup.gz"
|
||||
test_file.write_bytes(b"data")
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
storage.client.upload_file.side_effect = RuntimeError("network error")
|
||||
|
||||
with pytest.raises(RuntimeError, match="network error"):
|
||||
storage.upload_file(test_file)
|
||||
|
||||
def test_download_file_success(
|
||||
self, mock_boto3_module, s3_config, tmp_path: Path
|
||||
) -> None:
|
||||
"""download_file must call client.download_file and return local path."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
dest = tmp_path / "downloads" / "backup.gz"
|
||||
|
||||
result = storage.download_file("backups/backup.gz", dest)
|
||||
|
||||
assert result == dest
|
||||
storage.client.download_file.assert_called_once()
|
||||
|
||||
def test_download_file_propagates_error(
|
||||
self, mock_boto3_module, s3_config, tmp_path: Path
|
||||
) -> None:
|
||||
"""download_file must re-raise exceptions from the boto3 client."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
storage.client.download_file.side_effect = RuntimeError("timeout")
|
||||
|
||||
with pytest.raises(RuntimeError, match="timeout"):
|
||||
storage.download_file("key", tmp_path / "dest.gz")
|
||||
|
||||
def test_list_files_returns_objects(
|
||||
self, mock_boto3_module, s3_config
|
||||
) -> None:
|
||||
"""list_files must return parsed file metadata from S3 response."""
|
||||
from datetime import timezone
|
||||
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
storage.client.list_objects_v2.return_value = {
|
||||
"Contents": [
|
||||
{
|
||||
"Key": "backups/a.gz",
|
||||
"Size": 1024,
|
||||
"LastModified": datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
"ETag": '"abc123"',
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
files = storage.list_files(prefix="backups/")
|
||||
assert len(files) == 1
|
||||
assert files[0]["key"] == "backups/a.gz"
|
||||
assert files[0]["size_bytes"] == 1024
|
||||
|
||||
def test_list_files_empty_bucket(
|
||||
self, mock_boto3_module, s3_config
|
||||
) -> None:
|
||||
"""list_files must return empty list when bucket has no objects."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
storage.client.list_objects_v2.return_value = {}
|
||||
|
||||
files = storage.list_files()
|
||||
assert files == []
|
||||
|
||||
def test_list_files_propagates_error(
|
||||
self, mock_boto3_module, s3_config
|
||||
) -> None:
|
||||
"""list_files must re-raise exceptions from the boto3 client."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
storage.client.list_objects_v2.side_effect = RuntimeError("auth error")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
storage.list_files()
|
||||
|
||||
def test_delete_file_success(
|
||||
self, mock_boto3_module, s3_config
|
||||
) -> None:
|
||||
"""delete_file must call client.delete_object with the correct key."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
storage.delete_file("backups/old.gz")
|
||||
storage.client.delete_object.assert_called_once_with(
|
||||
Bucket="test-bucket", Key="backups/old.gz"
|
||||
)
|
||||
|
||||
def test_delete_file_propagates_error(
|
||||
self, mock_boto3_module, s3_config
|
||||
) -> None:
|
||||
"""delete_file must re-raise exceptions from the boto3 client."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
storage.client.delete_object.side_effect = RuntimeError("permission denied")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
storage.delete_file("backups/old.gz")
|
||||
|
||||
def test_get_storage_stats_success(
|
||||
self, mock_boto3_module, s3_config
|
||||
) -> None:
|
||||
"""get_storage_stats must aggregate file sizes correctly."""
|
||||
from datetime import timezone
|
||||
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
storage.client.list_objects_v2.return_value = {
|
||||
"Contents": [
|
||||
{
|
||||
"Key": "a.gz",
|
||||
"Size": 1024 * 1024,
|
||||
"LastModified": datetime(2026, 1, 1, tzinfo=timezone.utc),
|
||||
"ETag": '"x"',
|
||||
},
|
||||
{
|
||||
"Key": "b.gz",
|
||||
"Size": 1024 * 1024,
|
||||
"LastModified": datetime(2026, 1, 2, tzinfo=timezone.utc),
|
||||
"ETag": '"y"',
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
stats = storage.get_storage_stats()
|
||||
assert stats["total_files"] == 2
|
||||
assert stats["total_size_bytes"] == 2 * 1024 * 1024
|
||||
assert stats["total_size_mb"] == pytest.approx(2.0)
|
||||
|
||||
def test_get_storage_stats_on_error(
|
||||
self, mock_boto3_module, s3_config
|
||||
) -> None:
|
||||
"""get_storage_stats must return error dict without raising on failure."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
storage.client.list_objects_v2.side_effect = RuntimeError("no connection")
|
||||
|
||||
stats = storage.get_storage_stats()
|
||||
assert "error" in stats
|
||||
assert stats["total_files"] == 0
|
||||
|
||||
def test_verify_connection_success(
|
||||
self, mock_boto3_module, s3_config
|
||||
) -> None:
|
||||
"""verify_connection must return True when head_bucket succeeds."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
result = storage.verify_connection()
|
||||
assert result is True
|
||||
|
||||
def test_verify_connection_failure(
|
||||
self, mock_boto3_module, s3_config
|
||||
) -> None:
|
||||
"""verify_connection must return False when head_bucket raises."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
storage.client.head_bucket.side_effect = RuntimeError("no such bucket")
|
||||
|
||||
result = storage.verify_connection()
|
||||
assert result is False
|
||||
|
||||
def test_enable_versioning(
|
||||
self, mock_boto3_module, s3_config
|
||||
) -> None:
|
||||
"""enable_versioning must call put_bucket_versioning."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
storage.enable_versioning()
|
||||
storage.client.put_bucket_versioning.assert_called_once()
|
||||
|
||||
def test_enable_versioning_propagates_error(
|
||||
self, mock_boto3_module, s3_config
|
||||
) -> None:
|
||||
"""enable_versioning must re-raise exceptions from the boto3 client."""
|
||||
from src.backup.cloud_storage import CloudStorage
|
||||
|
||||
storage = CloudStorage(s3_config)
|
||||
storage.client.put_bucket_versioning.side_effect = RuntimeError("denied")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
storage.enable_versioning()
|
||||
@@ -2,6 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.brain.gemini_client import GeminiClient
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -126,7 +130,7 @@ class TestPromptConstruction:
|
||||
"orderbook": {"asks": [], "bids": []},
|
||||
"foreigner_net": -50000,
|
||||
}
|
||||
prompt = client.build_prompt(market_data)
|
||||
prompt = client.build_prompt_sync(market_data)
|
||||
assert "005930" in prompt
|
||||
|
||||
def test_prompt_contains_price(self, settings):
|
||||
@@ -137,7 +141,7 @@ class TestPromptConstruction:
|
||||
"orderbook": {"asks": [], "bids": []},
|
||||
"foreigner_net": -50000,
|
||||
}
|
||||
prompt = client.build_prompt(market_data)
|
||||
prompt = client.build_prompt_sync(market_data)
|
||||
assert "72000" in prompt
|
||||
|
||||
def test_prompt_enforces_json_output_format(self, settings):
|
||||
@@ -148,7 +152,219 @@ class TestPromptConstruction:
|
||||
"orderbook": {"asks": [], "bids": []},
|
||||
"foreigner_net": 0,
|
||||
}
|
||||
prompt = client.build_prompt(market_data)
|
||||
prompt = client.build_prompt_sync(market_data)
|
||||
assert "JSON" in prompt
|
||||
assert "action" in prompt
|
||||
assert "confidence" in prompt
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Batch Decision Making
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBatchDecisionParsing:
|
||||
"""Batch response parser must handle JSON arrays correctly."""
|
||||
|
||||
def test_parse_valid_batch_response(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [
|
||||
{"stock_code": "AAPL", "current_price": 185.5},
|
||||
{"stock_code": "MSFT", "current_price": 420.0},
|
||||
]
|
||||
raw = """[
|
||||
{"code": "AAPL", "action": "BUY", "confidence": 85, "rationale": "Strong momentum"},
|
||||
{"code": "MSFT", "action": "HOLD", "confidence": 50, "rationale": "Wait for earnings"}
|
||||
]"""
|
||||
|
||||
decisions = client._parse_batch_response(raw, stocks_data, token_count=100)
|
||||
|
||||
assert len(decisions) == 2
|
||||
assert decisions["AAPL"].action == "BUY"
|
||||
assert decisions["AAPL"].confidence == 85
|
||||
assert decisions["MSFT"].action == "HOLD"
|
||||
assert decisions["MSFT"].confidence == 50
|
||||
|
||||
def test_parse_batch_with_markdown_wrapper(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [{"stock_code": "AAPL", "current_price": 185.5}]
|
||||
raw = """```json
|
||||
[{"code": "AAPL", "action": "BUY", "confidence": 90, "rationale": "Good"}]
|
||||
```"""
|
||||
|
||||
decisions = client._parse_batch_response(raw, stocks_data, token_count=100)
|
||||
|
||||
assert decisions["AAPL"].action == "BUY"
|
||||
assert decisions["AAPL"].confidence == 90
|
||||
|
||||
def test_parse_batch_empty_response_returns_hold_for_all(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [
|
||||
{"stock_code": "AAPL", "current_price": 185.5},
|
||||
{"stock_code": "MSFT", "current_price": 420.0},
|
||||
]
|
||||
|
||||
decisions = client._parse_batch_response("", stocks_data, token_count=100)
|
||||
|
||||
assert len(decisions) == 2
|
||||
assert decisions["AAPL"].action == "HOLD"
|
||||
assert decisions["AAPL"].confidence == 0
|
||||
assert decisions["MSFT"].action == "HOLD"
|
||||
|
||||
def test_parse_batch_malformed_json_returns_hold_for_all(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [{"stock_code": "AAPL", "current_price": 185.5}]
|
||||
raw = "This is not JSON"
|
||||
|
||||
decisions = client._parse_batch_response(raw, stocks_data, token_count=100)
|
||||
|
||||
assert decisions["AAPL"].action == "HOLD"
|
||||
assert decisions["AAPL"].confidence == 0
|
||||
|
||||
def test_parse_batch_not_array_returns_hold_for_all(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [{"stock_code": "AAPL", "current_price": 185.5}]
|
||||
raw = '{"code": "AAPL", "action": "BUY", "confidence": 90, "rationale": "Good"}'
|
||||
|
||||
decisions = client._parse_batch_response(raw, stocks_data, token_count=100)
|
||||
|
||||
assert decisions["AAPL"].action == "HOLD"
|
||||
assert decisions["AAPL"].confidence == 0
|
||||
|
||||
def test_parse_batch_missing_stock_gets_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [
|
||||
{"stock_code": "AAPL", "current_price": 185.5},
|
||||
{"stock_code": "MSFT", "current_price": 420.0},
|
||||
]
|
||||
# Response only has AAPL, MSFT is missing
|
||||
raw = '[{"code": "AAPL", "action": "BUY", "confidence": 85, "rationale": "Good"}]'
|
||||
|
||||
decisions = client._parse_batch_response(raw, stocks_data, token_count=100)
|
||||
|
||||
assert decisions["AAPL"].action == "BUY"
|
||||
assert decisions["MSFT"].action == "HOLD"
|
||||
assert decisions["MSFT"].confidence == 0
|
||||
|
||||
def test_parse_batch_invalid_action_becomes_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [{"stock_code": "AAPL", "current_price": 185.5}]
|
||||
raw = '[{"code": "AAPL", "action": "YOLO", "confidence": 90, "rationale": "Moon"}]'
|
||||
|
||||
decisions = client._parse_batch_response(raw, stocks_data, token_count=100)
|
||||
|
||||
assert decisions["AAPL"].action == "HOLD"
|
||||
|
||||
def test_parse_batch_low_confidence_becomes_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [{"stock_code": "AAPL", "current_price": 185.5}]
|
||||
raw = '[{"code": "AAPL", "action": "BUY", "confidence": 65, "rationale": "Weak"}]'
|
||||
|
||||
decisions = client._parse_batch_response(raw, stocks_data, token_count=100)
|
||||
|
||||
assert decisions["AAPL"].action == "HOLD"
|
||||
assert decisions["AAPL"].confidence == 65
|
||||
|
||||
def test_parse_batch_missing_fields_gets_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [{"stock_code": "AAPL", "current_price": 185.5}]
|
||||
raw = '[{"code": "AAPL", "action": "BUY"}]' # Missing confidence and rationale
|
||||
|
||||
decisions = client._parse_batch_response(raw, stocks_data, token_count=100)
|
||||
|
||||
assert decisions["AAPL"].action == "HOLD"
|
||||
assert decisions["AAPL"].confidence == 0
|
||||
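
# --- Editor's sketch (illustrative, not part of this diff) -----------------
# The parser behaviour these tests encode: strip an optional markdown code
# fence, parse a JSON array, and fall back to a defensive HOLD/confidence-0
# decision for anything missing, malformed, invalid, or (for BUY/SELL) below a
# confidence threshold of 70, which the confidence-65 case implies. Names
# prefixed with Sketch are hypothetical; the real GeminiClient may differ.
import json
from dataclasses import dataclass


@dataclass
class SketchDecision:
    action: str
    confidence: int
    rationale: str = ""


def sketch_parse_batch(raw: str, stocks_data: list[dict]) -> dict[str, SketchDecision]:
    def hold() -> SketchDecision:
        return SketchDecision("HOLD", 0, "fallback")

    decisions = {stock["stock_code"]: hold() for stock in stocks_data}

    text = raw.strip()
    if text.startswith("```"):  # tolerate fenced model output
        text = text.strip("`").removeprefix("json").strip()
    try:
        items = json.loads(text)
    except ValueError:  # empty or malformed response
        return decisions
    if not isinstance(items, list):  # single object instead of an array
        return decisions

    for item in items:
        code = item.get("code")
        if code not in decisions:
            continue
        action = item.get("action")
        confidence = item.get("confidence")
        if action not in {"BUY", "SELL", "HOLD"} or confidence is None:
            continue  # keep the HOLD/0 fallback for this stock
        if action != "HOLD" and confidence < 70:
            action = "HOLD"  # low conviction is demoted; confidence is kept
        decisions[code] = SketchDecision(action, int(confidence), item.get("rationale", ""))
    return decisions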
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt Override (used by pre_market_planner)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPromptOverride:
|
||||
"""decide() must use prompt_override when present in market_data."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_override_is_sent_to_gemini(self, settings):
|
||||
"""When prompt_override is in market_data, it should be used as the prompt."""
|
||||
client = GeminiClient(settings)
|
||||
|
||||
custom_prompt = "You are a playbook generator. Return JSON with scenarios."
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = '{"action": "HOLD", "confidence": 50, "rationale": "test"}'
|
||||
|
||||
with patch.object(
|
||||
client._client.aio.models,
|
||||
"generate_content",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_generate:
|
||||
market_data = {
|
||||
"stock_code": "PLANNER",
|
||||
"current_price": 0,
|
||||
"prompt_override": custom_prompt,
|
||||
}
|
||||
await client.decide(market_data)
|
||||
|
||||
# Verify the custom prompt was sent, not a built prompt
|
||||
mock_generate.assert_called_once()
|
||||
actual_prompt = mock_generate.call_args[1].get(
|
||||
"contents", mock_generate.call_args[0][1] if len(mock_generate.call_args[0]) > 1 else None
|
||||
)
|
||||
assert actual_prompt == custom_prompt
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_override_skips_optimization(self, settings):
|
||||
"""prompt_override should bypass prompt optimization."""
|
||||
client = GeminiClient(settings)
|
||||
client._enable_optimization = True
|
||||
|
||||
custom_prompt = "Custom playbook prompt"
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = '{"action": "HOLD", "confidence": 50, "rationale": "ok"}'
|
||||
|
||||
with patch.object(
|
||||
client._client.aio.models,
|
||||
"generate_content",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_generate:
|
||||
market_data = {
|
||||
"stock_code": "PLANNER",
|
||||
"current_price": 0,
|
||||
"prompt_override": custom_prompt,
|
||||
}
|
||||
await client.decide(market_data)
|
||||
|
||||
actual_prompt = mock_generate.call_args[1].get(
|
||||
"contents", mock_generate.call_args[0][1] if len(mock_generate.call_args[0]) > 1 else None
|
||||
)
|
||||
assert actual_prompt == custom_prompt
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_without_prompt_override_uses_build_prompt(self, settings):
|
||||
"""Without prompt_override, decide() should use build_prompt as before."""
|
||||
client = GeminiClient(settings)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = '{"action": "HOLD", "confidence": 50, "rationale": "ok"}'
|
||||
|
||||
with patch.object(
|
||||
client._client.aio.models,
|
||||
"generate_content",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_generate:
|
||||
market_data = {
|
||||
"stock_code": "005930",
|
||||
"current_price": 72000,
|
||||
}
|
||||
await client.decide(market_data)
|
||||
|
||||
actual_prompt = mock_generate.call_args[1].get(
|
||||
"contents", mock_generate.call_args[0][1] if len(mock_generate.call_args[0]) > 1 else None
|
||||
)
|
||||
# Should contain stock code from build_prompt, not be a custom override
|
||||
assert "005930" in actual_prompt
|
||||
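
# --- Editor's sketch (illustrative, not part of this diff) -----------------
# These tests only constrain how decide() chooses its prompt: an explicit
# prompt_override in market_data is sent verbatim (skipping optimisation),
# otherwise the prompt is built from the market data. A minimal sketch of that
# selection step; the Gemini call and response parsing are exercised above.
from typing import Any, Callable


def sketch_select_prompt(
    market_data: dict[str, Any],
    build_prompt: Callable[[dict[str, Any]], str],
    optimize: Callable[[str], str] | None = None,
) -> str:
    override = market_data.get("prompt_override")
    if override:
        return override  # used as-is, no optimisation pass
    prompt = build_prompt(market_data)
    return optimize(prompt) if optimize else prompt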
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -49,6 +49,110 @@ class TestTokenManagement:
|
||||
|
||||
await broker.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_token_refresh_calls_api_once(self, settings):
|
||||
"""Multiple concurrent token requests should only call API once."""
|
||||
broker = KISBroker(settings)
|
||||
|
||||
# Track how many times the mock API is called
|
||||
call_count = [0]
|
||||
|
||||
def create_mock_resp():
|
||||
call_count[0] += 1
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(
|
||||
return_value={
|
||||
"access_token": "tok_concurrent",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 86400,
|
||||
}
|
||||
)
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
return mock_resp
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=create_mock_resp()):
|
||||
# Launch 5 concurrent token requests
|
||||
tokens = await asyncio.gather(
|
||||
broker._ensure_token(),
|
||||
broker._ensure_token(),
|
||||
broker._ensure_token(),
|
||||
broker._ensure_token(),
|
||||
broker._ensure_token(),
|
||||
)
|
||||
|
||||
# All should get the same token
|
||||
assert all(t == "tok_concurrent" for t in tokens)
|
||||
# API should be called only once (due to lock)
|
||||
assert call_count[0] == 1
|
||||
|
||||
await broker.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_refresh_cooldown_waits_then_retries(self, settings):
|
||||
"""Token refresh should wait out cooldown then retry (issue #54)."""
|
||||
broker = KISBroker(settings)
|
||||
broker._refresh_cooldown = 0.1 # Short cooldown for testing
|
||||
|
||||
# All attempts fail with 403 (EGW00133)
|
||||
mock_resp_403 = AsyncMock()
|
||||
mock_resp_403.status = 403
|
||||
mock_resp_403.text = AsyncMock(
|
||||
return_value='{"error_code":"EGW00133","error_description":"접근토큰 발급 잠시 후 다시 시도하세요(1분당 1회)"}'
|
||||
)
|
||||
mock_resp_403.__aenter__ = AsyncMock(return_value=mock_resp_403)
|
||||
mock_resp_403.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp_403):
|
||||
# First attempt should fail with 403
|
||||
with pytest.raises(ConnectionError, match="Token refresh failed"):
|
||||
await broker._ensure_token()
|
||||
|
||||
# Second attempt within cooldown should wait then retry (and still get 403)
|
||||
with pytest.raises(ConnectionError, match="Token refresh failed"):
|
||||
await broker._ensure_token()
|
||||
|
||||
await broker.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_refresh_allowed_after_cooldown(self, settings):
|
||||
"""Token refresh should be allowed after cooldown period expires."""
|
||||
broker = KISBroker(settings)
|
||||
broker._refresh_cooldown = 0.1 # Very short cooldown for testing
|
||||
|
||||
# First attempt fails
|
||||
mock_resp_403 = AsyncMock()
|
||||
mock_resp_403.status = 403
|
||||
mock_resp_403.text = AsyncMock(return_value='{"error_code":"EGW00133"}')
|
||||
mock_resp_403.__aenter__ = AsyncMock(return_value=mock_resp_403)
|
||||
mock_resp_403.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
# Second attempt succeeds
|
||||
mock_resp_200 = AsyncMock()
|
||||
mock_resp_200.status = 200
|
||||
mock_resp_200.json = AsyncMock(
|
||||
return_value={
|
||||
"access_token": "tok_after_cooldown",
|
||||
"expires_in": 86400,
|
||||
}
|
||||
)
|
||||
mock_resp_200.__aenter__ = AsyncMock(return_value=mock_resp_200)
|
||||
mock_resp_200.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp_403):
|
||||
with pytest.raises(ConnectionError, match="Token refresh failed"):
|
||||
await broker._ensure_token()
|
||||
|
||||
# Wait for cooldown to expire
|
||||
await asyncio.sleep(0.15)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp_200):
|
||||
token = await broker._ensure_token()
|
||||
assert token == "tok_after_cooldown"
|
||||
|
||||
await broker.close()
|
||||
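
# --- Editor's sketch (illustrative, not part of this diff) -----------------
# The three token tests above imply: a cached token is reused until it
# expires, concurrent callers are serialised behind a lock so only one HTTP
# request is made, and refresh attempts respect a cooldown (issue #54) by
# sleeping before retrying. A minimal sketch of that pattern around a
# hypothetical fetch_token coroutine; the real KISBroker._ensure_token differs
# in its HTTP details.
import asyncio
import time


class SketchTokenCache:
    def __init__(self, fetch_token, refresh_cooldown: float = 60.0) -> None:
        self._fetch_token = fetch_token  # async () -> (token, expires_in_seconds)
        self._refresh_cooldown = refresh_cooldown
        self._lock = asyncio.Lock()
        self._token = ""
        self._expires_at = 0.0
        self._last_attempt = float("-inf")

    async def ensure_token(self) -> str:
        async with self._lock:  # serialise concurrent refresh attempts
            now = time.monotonic()
            if self._token and now < self._expires_at:
                return self._token  # cached token, no API call
            wait = self._refresh_cooldown - (now - self._last_attempt)
            if wait > 0:
                await asyncio.sleep(wait)  # wait out the 1-per-minute limit
            self._last_attempt = time.monotonic()
            token, expires_in = await self._fetch_token()  # may raise ConnectionError
            self._token = token
            self._expires_at = time.monotonic() + expires_in
            return token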
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Network Error Handling
|
||||
@@ -107,6 +211,38 @@ class TestRateLimiter:
|
||||
await broker._rate_limiter.acquire()
|
||||
await broker.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_order_acquires_rate_limiter_twice(self, settings):
|
||||
"""send_order must acquire rate limiter for both hash key and order call."""
|
||||
broker = KISBroker(settings)
|
||||
broker._access_token = "tok"
|
||||
broker._token_expires_at = asyncio.get_event_loop().time() + 3600
|
||||
|
||||
# Mock hash key response
|
||||
mock_hash_resp = AsyncMock()
|
||||
mock_hash_resp.status = 200
|
||||
mock_hash_resp.json = AsyncMock(return_value={"HASH": "abc123"})
|
||||
mock_hash_resp.__aenter__ = AsyncMock(return_value=mock_hash_resp)
|
||||
mock_hash_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
# Mock order response
|
||||
mock_order_resp = AsyncMock()
|
||||
mock_order_resp.status = 200
|
||||
mock_order_resp.json = AsyncMock(return_value={"rt_cd": "0"})
|
||||
mock_order_resp.__aenter__ = AsyncMock(return_value=mock_order_resp)
|
||||
mock_order_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash_resp, mock_order_resp]
|
||||
):
|
||||
with patch.object(
|
||||
broker._rate_limiter, "acquire", new_callable=AsyncMock
|
||||
) as mock_acquire:
|
||||
await broker.send_order("005930", "BUY", 1, 50000)
|
||||
assert mock_acquire.call_count == 2
|
||||
|
||||
await broker.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hash Key Generation
|
||||
@@ -136,3 +272,648 @@ class TestHashKey:
|
||||
assert len(hash_key) > 0
|
||||
|
||||
await broker.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hash_key_acquires_rate_limiter(self, settings):
|
||||
"""_get_hash_key must go through the rate limiter to prevent burst."""
|
||||
broker = KISBroker(settings)
|
||||
broker._access_token = "tok"
|
||||
broker._token_expires_at = asyncio.get_event_loop().time() + 3600
|
||||
|
||||
body = {"CANO": "12345678", "ACNT_PRDT_CD": "01"}
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(return_value={"HASH": "abc123hash"})
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp):
|
||||
with patch.object(
|
||||
broker._rate_limiter, "acquire", new_callable=AsyncMock
|
||||
) as mock_acquire:
|
||||
await broker._get_hash_key(body)
|
||||
mock_acquire.assert_called_once()
|
||||
|
||||
await broker.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# fetch_market_rankings — TR_ID, path, params (issue #155)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_ranking_mock(items: list[dict]) -> AsyncMock:
|
||||
"""Build a mock HTTP response returning ranking items."""
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(return_value={"output": items})
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
return mock_resp
|
||||
|
||||
|
||||
class TestFetchMarketRankings:
|
||||
"""Verify correct TR_ID, API path, and params per ranking_type (issue #155)."""
|
||||
|
||||
@pytest.fixture
|
||||
def broker(self, settings) -> KISBroker:
|
||||
b = KISBroker(settings)
|
||||
b._access_token = "tok"
|
||||
b._token_expires_at = float("inf")
|
||||
b._rate_limiter.acquire = AsyncMock()
|
||||
return b
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_volume_uses_correct_tr_id_and_path(self, broker: KISBroker) -> None:
|
||||
mock_resp = _make_ranking_mock([])
|
||||
with patch("aiohttp.ClientSession.get", return_value=mock_resp) as mock_get:
|
||||
await broker.fetch_market_rankings(ranking_type="volume")
|
||||
|
||||
call_kwargs = mock_get.call_args
|
||||
url = call_kwargs[0][0] if call_kwargs[0] else call_kwargs[1].get("url", "")
|
||||
headers = call_kwargs[1].get("headers", {})
|
||||
params = call_kwargs[1].get("params", {})
|
||||
|
||||
assert "volume-rank" in url
|
||||
assert headers.get("tr_id") == "FHPST01710000"
|
||||
assert params.get("FID_COND_SCR_DIV_CODE") == "20171"
|
||||
assert params.get("FID_TRGT_EXLS_CLS_CODE") == "0000000000"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fluctuation_uses_correct_tr_id_and_path(self, broker: KISBroker) -> None:
|
||||
mock_resp = _make_ranking_mock([])
|
||||
with patch("aiohttp.ClientSession.get", return_value=mock_resp) as mock_get:
|
||||
await broker.fetch_market_rankings(ranking_type="fluctuation")
|
||||
|
||||
call_kwargs = mock_get.call_args
|
||||
url = call_kwargs[0][0] if call_kwargs[0] else call_kwargs[1].get("url", "")
|
||||
headers = call_kwargs[1].get("headers", {})
|
||||
params = call_kwargs[1].get("params", {})
|
||||
|
||||
assert "ranking/fluctuation" in url
|
||||
assert headers.get("tr_id") == "FHPST01700000"
|
||||
assert params.get("fid_cond_scr_div_code") == "20170"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_volume_returns_parsed_rows(self, broker: KISBroker) -> None:
|
||||
items = [
|
||||
{
|
||||
"mksc_shrn_iscd": "005930",
|
||||
"hts_kor_isnm": "삼성전자",
|
||||
"stck_prpr": "75000",
|
||||
"acml_vol": "10000000",
|
||||
"prdy_ctrt": "2.5",
|
||||
"vol_inrt": "150",
|
||||
}
|
||||
]
|
||||
mock_resp = _make_ranking_mock(items)
|
||||
with patch("aiohttp.ClientSession.get", return_value=mock_resp):
|
||||
result = await broker.fetch_market_rankings(ranking_type="volume")
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["stock_code"] == "005930"
|
||||
assert result[0]["price"] == 75000.0
|
||||
assert result[0]["change_rate"] == 2.5
|
||||
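
# --- Editor's sketch (illustrative, not part of this diff) -----------------
# The assertions above fix the TR_ID, path fragment and screen-division codes
# per ranking_type (issue #155), plus the field mapping for parsed rows. The
# tests only assert the "volume-rank" and "ranking/fluctuation" path
# fragments, so the path values below are deliberately partial; parameter
# names also differ in case between the two endpoints, which is why each type
# carries its own params dict.
SKETCH_RANKING_ENDPOINTS: dict[str, dict] = {
    "volume": {
        "path_fragment": "volume-rank",
        "tr_id": "FHPST01710000",
        "params": {
            "FID_COND_SCR_DIV_CODE": "20171",
            "FID_TRGT_EXLS_CLS_CODE": "0000000000",
        },
    },
    "fluctuation": {
        "path_fragment": "ranking/fluctuation",
        "tr_id": "FHPST01700000",
        "params": {"fid_cond_scr_div_code": "20170"},
    },
}


def sketch_parse_ranking_row(item: dict) -> dict:
    """Field mapping implied by test_volume_returns_parsed_rows."""
    return {
        "stock_code": item["mksc_shrn_iscd"],
        "name": item.get("hts_kor_isnm", ""),  # not asserted above; assumed present
        "price": float(item["stck_prpr"]),
        "volume": float(item.get("acml_vol", 0)),  # not asserted above; assumed present
        "change_rate": float(item["prdy_ctrt"]),
    }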
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# KRX tick unit / round-down helpers (issue #157)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
from src.broker.kis_api import kr_tick_unit, kr_round_down # noqa: E402
|
||||
|
||||
|
||||
class TestKrTickUnit:
|
||||
"""kr_tick_unit and kr_round_down must implement KRX price tick rules."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"price, expected_tick",
|
||||
[
|
||||
(1999, 1),
|
||||
(2000, 5),
|
||||
(4999, 5),
|
||||
(5000, 10),
|
||||
(19999, 10),
|
||||
(20000, 50),
|
||||
(49999, 50),
|
||||
(50000, 100),
|
||||
(199999, 100),
|
||||
(200000, 500),
|
||||
(499999, 500),
|
||||
(500000, 1000),
|
||||
(1000000, 1000),
|
||||
],
|
||||
)
|
||||
def test_tick_unit_boundaries(self, price: int, expected_tick: int) -> None:
|
||||
assert kr_tick_unit(price) == expected_tick
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"price, expected_rounded",
|
||||
[
|
||||
(188150, 188100), # 100-won tick, 50-won remainder → round down
|
||||
(188100, 188100), # already on tick
|
||||
(75050, 75000), # 100-won tick, 50-won remainder → round down
|
||||
(49950, 49950), # already on 50-won tick
|
||||
(49960, 49950), # 50-won tick, 10-won remainder → round down
|
||||
(1999, 1999), # 1-won tick → unchanged
|
||||
(5003, 5000), # 10-won tick, 3-won remainder → round down
|
||||
],
|
||||
)
|
||||
def test_round_down_to_tick(self, price: int, expected_rounded: int) -> None:
|
||||
assert kr_round_down(price) == expected_rounded
|
||||
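
# --- Editor's sketch (illustrative, not part of this diff) -----------------
# The parametrised cases above fully determine the KRX tick table, so a
# reference implementation can be read off directly. This is only a sketch of
# the behaviour the tests encode; the actual helpers live in src/broker/kis_api.py.
def sketch_kr_tick_unit(price: float) -> int:
    """Tick size for a given KRW price, per the boundaries tested above."""
    if price < 2_000:
        return 1
    if price < 5_000:
        return 5
    if price < 20_000:
        return 10
    if price < 50_000:
        return 50
    if price < 200_000:
        return 100
    if price < 500_000:
        return 500
    return 1_000


def sketch_kr_round_down(price: float) -> int:
    """Round a price down to the nearest valid tick (e.g. 188150 -> 188100)."""
    tick = sketch_kr_tick_unit(price)
    return int(price // tick) * tick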
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_current_price (issue #157)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetCurrentPrice:
|
||||
"""get_current_price must use inquire-price API and return (price, change, foreigner)."""
|
||||
|
||||
@pytest.fixture
|
||||
def broker(self, settings) -> KISBroker:
|
||||
b = KISBroker(settings)
|
||||
b._access_token = "tok"
|
||||
b._token_expires_at = float("inf")
|
||||
b._rate_limiter.acquire = AsyncMock()
|
||||
return b
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_correct_fields(self, broker: KISBroker) -> None:
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(
|
||||
return_value={
|
||||
"rt_cd": "0",
|
||||
"output": {
|
||||
"stck_prpr": "188600",
|
||||
"prdy_ctrt": "3.97",
|
||||
"frgn_ntby_qty": "12345",
|
||||
},
|
||||
}
|
||||
)
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.get", return_value=mock_resp) as mock_get:
|
||||
price, change_pct, foreigner = await broker.get_current_price("005930")
|
||||
|
||||
assert price == 188600.0
|
||||
assert change_pct == 3.97
|
||||
assert foreigner == 12345.0
|
||||
|
||||
call_kwargs = mock_get.call_args
|
||||
url = call_kwargs[0][0] if call_kwargs[0] else call_kwargs[1].get("url", "")
|
||||
headers = call_kwargs[1].get("headers", {})
|
||||
assert "inquire-price" in url
|
||||
assert headers.get("tr_id") == "FHKST01010100"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_error_raises_connection_error(self, broker: KISBroker) -> None:
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 500
|
||||
mock_resp.text = AsyncMock(return_value="Internal Server Error")
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.get", return_value=mock_resp):
|
||||
with pytest.raises(ConnectionError, match="get_current_price failed"):
|
||||
await broker.get_current_price("005930")
|
||||
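
# --- Editor's sketch (illustrative, not part of this diff) -----------------
# The two tests above pin down the response mapping: the "output" block of the
# inquire-price payload becomes a (price, change_pct, foreigner_net) float
# tuple, and a non-200 status surfaces as ConnectionError. A sketch of just
# that mapping; the HTTP call, tr_id FHKST01010100 and params are checked in
# the tests themselves.
def sketch_parse_current_price(status: int, payload: dict) -> tuple[float, float, float]:
    if status != 200:
        raise ConnectionError(f"get_current_price failed: HTTP {status}")
    output = payload.get("output", {})
    return (
        float(output.get("stck_prpr", 0)),      # current price
        float(output.get("prdy_ctrt", 0)),      # change vs. previous close, %
        float(output.get("frgn_ntby_qty", 0)),  # foreigner net-buy quantity
    )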
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# send_order tick rounding and ORD_DVSN (issue #157)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSendOrderTickRounding:
|
||||
"""send_order must apply KRX tick rounding and correct ORD_DVSN codes."""
|
||||
|
||||
@pytest.fixture
|
||||
def broker(self, settings) -> KISBroker:
|
||||
b = KISBroker(settings)
|
||||
b._access_token = "tok"
|
||||
b._token_expires_at = float("inf")
|
||||
b._rate_limiter.acquire = AsyncMock()
|
||||
return b
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_limit_order_rounds_down_to_tick(self, broker: KISBroker) -> None:
|
||||
"""Price 188150 (not on 100-won tick) must be rounded to 188100."""
|
||||
mock_hash = AsyncMock()
|
||||
mock_hash.status = 200
|
||||
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
|
||||
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
|
||||
mock_hash.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_order = AsyncMock()
|
||||
mock_order.status = 200
|
||||
mock_order.json = AsyncMock(return_value={"rt_cd": "0"})
|
||||
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
|
||||
mock_order.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
|
||||
) as mock_post:
|
||||
await broker.send_order("005930", "BUY", 1, price=188150)
|
||||
|
||||
order_call = mock_post.call_args_list[1]
|
||||
body = order_call[1].get("json", {})
|
||||
assert body["ORD_UNPR"] == "188100" # rounded down
|
||||
assert body["ORD_DVSN"] == "00" # 지정가
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_limit_order_ord_dvsn_is_00(self, broker: KISBroker) -> None:
|
||||
"""send_order with price>0 must use ORD_DVSN='00' (지정가)."""
|
||||
mock_hash = AsyncMock()
|
||||
mock_hash.status = 200
|
||||
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
|
||||
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
|
||||
mock_hash.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_order = AsyncMock()
|
||||
mock_order.status = 200
|
||||
mock_order.json = AsyncMock(return_value={"rt_cd": "0"})
|
||||
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
|
||||
mock_order.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
|
||||
) as mock_post:
|
||||
await broker.send_order("005930", "BUY", 1, price=50000)
|
||||
|
||||
order_call = mock_post.call_args_list[1]
|
||||
body = order_call[1].get("json", {})
|
||||
assert body["ORD_DVSN"] == "00"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_market_order_ord_dvsn_is_01(self, broker: KISBroker) -> None:
|
||||
"""send_order with price=0 must use ORD_DVSN='01' (시장가)."""
|
||||
mock_hash = AsyncMock()
|
||||
mock_hash.status = 200
|
||||
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
|
||||
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
|
||||
mock_hash.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_order = AsyncMock()
|
||||
mock_order.status = 200
|
||||
mock_order.json = AsyncMock(return_value={"rt_cd": "0"})
|
||||
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
|
||||
mock_order.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
|
||||
) as mock_post:
|
||||
await broker.send_order("005930", "SELL", 1, price=0)
|
||||
|
||||
order_call = mock_post.call_args_list[1]
|
||||
body = order_call[1].get("json", {})
|
||||
assert body["ORD_DVSN"] == "01"
|
||||
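
# --- Editor's sketch (illustrative, not part of this diff) -----------------
# The three tests above constrain the price fields of the order body: a
# positive price is rounded down to a valid KRX tick and sent as a limit order
# (ORD_DVSN "00"), while price=0 means a market order (ORD_DVSN "01"). The
# ORD_UNPR "0" for market orders is an assumption consistent with the cancel
# tests below; the round_down callable stands in for kr_round_down.
from typing import Callable


def sketch_order_price_fields(price: float, round_down: Callable[[float], int]) -> dict[str, str]:
    if price > 0:
        return {"ORD_DVSN": "00", "ORD_UNPR": str(round_down(price))}  # limit order
    return {"ORD_DVSN": "01", "ORD_UNPR": "0"}  # market order (assumed unit price "0")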
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TR_ID live/paper branching (issues #201, #202, #203)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTRIDBranchingDomestic:
|
||||
"""get_balance and send_order must use correct TR_ID for live vs paper mode."""
|
||||
|
||||
def _make_broker(self, settings, mode: str) -> KISBroker:
|
||||
from src.config import Settings
|
||||
|
||||
s = Settings(
|
||||
KIS_APP_KEY=settings.KIS_APP_KEY,
|
||||
KIS_APP_SECRET=settings.KIS_APP_SECRET,
|
||||
KIS_ACCOUNT_NO=settings.KIS_ACCOUNT_NO,
|
||||
GEMINI_API_KEY=settings.GEMINI_API_KEY,
|
||||
DB_PATH=":memory:",
|
||||
ENABLED_MARKETS="KR",
|
||||
MODE=mode,
|
||||
)
|
||||
b = KISBroker(s)
|
||||
b._access_token = "tok"
|
||||
b._token_expires_at = float("inf")
|
||||
b._rate_limiter.acquire = AsyncMock()
|
||||
return b
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_balance_paper_uses_vttc8434r(self, settings) -> None:
|
||||
broker = self._make_broker(settings, "paper")
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(
|
||||
return_value={"output1": [], "output2": {}}
|
||||
)
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.get", return_value=mock_resp) as mock_get:
|
||||
await broker.get_balance()
|
||||
|
||||
headers = mock_get.call_args[1].get("headers", {})
|
||||
assert headers["tr_id"] == "VTTC8434R"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_balance_live_uses_tttc8434r(self, settings) -> None:
|
||||
broker = self._make_broker(settings, "live")
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(
|
||||
return_value={"output1": [], "output2": {}}
|
||||
)
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.get", return_value=mock_resp) as mock_get:
|
||||
await broker.get_balance()
|
||||
|
||||
headers = mock_get.call_args[1].get("headers", {})
|
||||
assert headers["tr_id"] == "TTTC8434R"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_order_buy_paper_uses_vttc0012u(self, settings) -> None:
|
||||
broker = self._make_broker(settings, "paper")
|
||||
mock_hash = AsyncMock()
|
||||
mock_hash.status = 200
|
||||
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
|
||||
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
|
||||
mock_hash.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_order = AsyncMock()
|
||||
mock_order.status = 200
|
||||
mock_order.json = AsyncMock(return_value={"rt_cd": "0"})
|
||||
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
|
||||
mock_order.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
|
||||
) as mock_post:
|
||||
await broker.send_order("005930", "BUY", 1)
|
||||
|
||||
order_headers = mock_post.call_args_list[1][1].get("headers", {})
|
||||
assert order_headers["tr_id"] == "VTTC0012U"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_order_buy_live_uses_tttc0012u(self, settings) -> None:
|
||||
broker = self._make_broker(settings, "live")
|
||||
mock_hash = AsyncMock()
|
||||
mock_hash.status = 200
|
||||
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
|
||||
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
|
||||
mock_hash.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_order = AsyncMock()
|
||||
mock_order.status = 200
|
||||
mock_order.json = AsyncMock(return_value={"rt_cd": "0"})
|
||||
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
|
||||
mock_order.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
|
||||
) as mock_post:
|
||||
await broker.send_order("005930", "BUY", 1)
|
||||
|
||||
order_headers = mock_post.call_args_list[1][1].get("headers", {})
|
||||
assert order_headers["tr_id"] == "TTTC0012U"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_order_sell_paper_uses_vttc0011u(self, settings) -> None:
|
||||
broker = self._make_broker(settings, "paper")
|
||||
mock_hash = AsyncMock()
|
||||
mock_hash.status = 200
|
||||
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
|
||||
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
|
||||
mock_hash.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_order = AsyncMock()
|
||||
mock_order.status = 200
|
||||
mock_order.json = AsyncMock(return_value={"rt_cd": "0"})
|
||||
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
|
||||
mock_order.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
|
||||
) as mock_post:
|
||||
await broker.send_order("005930", "SELL", 1)
|
||||
|
||||
order_headers = mock_post.call_args_list[1][1].get("headers", {})
|
||||
assert order_headers["tr_id"] == "VTTC0011U"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_order_sell_live_uses_tttc0011u(self, settings) -> None:
|
||||
broker = self._make_broker(settings, "live")
|
||||
mock_hash = AsyncMock()
|
||||
mock_hash.status = 200
|
||||
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
|
||||
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
|
||||
mock_hash.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_order = AsyncMock()
|
||||
mock_order.status = 200
|
||||
mock_order.json = AsyncMock(return_value={"rt_cd": "0"})
|
||||
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
|
||||
mock_order.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
|
||||
) as mock_post:
|
||||
await broker.send_order("005930", "SELL", 1)
|
||||
|
||||
order_headers = mock_post.call_args_list[1][1].get("headers", {})
|
||||
assert order_headers["tr_id"] == "TTTC0011U"
|
||||
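
# --- Editor's sketch (illustrative, not part of this diff) -----------------
# The six tests above encode a simple rule: each domestic TR_ID has a live
# form starting with "T" and a paper-trading form starting with "V", with the
# remainder of the code identical. A lookup capturing exactly the IDs the
# tests assert (balance inquiry, cash buy, cash sell); the "kind" keys and the
# table name are hypothetical.
SKETCH_DOMESTIC_TR_IDS: dict[str, dict[str, str]] = {
    "balance": {"live": "TTTC8434R", "paper": "VTTC8434R"},
    "order_buy": {"live": "TTTC0012U", "paper": "VTTC0012U"},
    "order_sell": {"live": "TTTC0011U", "paper": "VTTC0011U"},
}


def sketch_tr_id(kind: str, mode: str) -> str:
    """mode is "live" or "paper", mirroring Settings.MODE in these tests."""
    return SKETCH_DOMESTIC_TR_IDS[kind][mode]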
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Domestic Pending Orders (get_domestic_pending_orders)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetDomesticPendingOrders:
|
||||
"""get_domestic_pending_orders must return [] in paper mode and call TTTC0084R in live."""
|
||||
|
||||
def _make_broker(self, settings, mode: str) -> KISBroker:
|
||||
from src.config import Settings
|
||||
|
||||
s = Settings(
|
||||
KIS_APP_KEY=settings.KIS_APP_KEY,
|
||||
KIS_APP_SECRET=settings.KIS_APP_SECRET,
|
||||
KIS_ACCOUNT_NO=settings.KIS_ACCOUNT_NO,
|
||||
GEMINI_API_KEY=settings.GEMINI_API_KEY,
|
||||
DB_PATH=":memory:",
|
||||
ENABLED_MARKETS="KR",
|
||||
MODE=mode,
|
||||
)
|
||||
b = KISBroker(s)
|
||||
b._access_token = "tok"
|
||||
b._token_expires_at = float("inf")
|
||||
b._rate_limiter.acquire = AsyncMock()
|
||||
return b
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_paper_mode_returns_empty(self, settings) -> None:
|
||||
"""Paper mode must return [] immediately without any API call."""
|
||||
broker = self._make_broker(settings, "paper")
|
||||
|
||||
with patch("aiohttp.ClientSession.get") as mock_get:
|
||||
result = await broker.get_domestic_pending_orders()
|
||||
|
||||
assert result == []
|
||||
mock_get.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_live_mode_calls_tttc0084r_with_correct_params(
|
||||
self, settings
|
||||
) -> None:
|
||||
"""Live mode must call TTTC0084R with INQR_DVSN_1/2 and paging params."""
|
||||
broker = self._make_broker(settings, "live")
|
||||
pending = [{"odno": "001", "pdno": "005930", "psbl_qty": "10"}]
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(return_value={"output": pending})
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.get", return_value=mock_resp) as mock_get:
|
||||
result = await broker.get_domestic_pending_orders()
|
||||
|
||||
assert result == pending
|
||||
headers = mock_get.call_args[1].get("headers", {})
|
||||
assert headers["tr_id"] == "TTTC0084R"
|
||||
params = mock_get.call_args[1].get("params", {})
|
||||
assert params["INQR_DVSN_1"] == "0"
|
||||
assert params["INQR_DVSN_2"] == "0"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_live_mode_connection_error(self, settings) -> None:
|
||||
"""Network error must raise ConnectionError."""
|
||||
import aiohttp as _aiohttp
|
||||
|
||||
broker = self._make_broker(settings, "live")
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.get",
|
||||
side_effect=_aiohttp.ClientError("timeout"),
|
||||
):
|
||||
with pytest.raises(ConnectionError):
|
||||
await broker.get_domestic_pending_orders()
|
||||
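
# --- Editor's sketch (illustrative, not part of this diff) -----------------
# The three tests above describe the whole control flow: paper mode returns []
# without touching the network, live mode queries TTTC0084R with
# INQR_DVSN_1/2 = "0" and returns the response's "output" list, and network
# errors surface as ConnectionError. The async fetch callable is hypothetical;
# the real method also builds headers and paging parameters.
from typing import Any, Awaitable, Callable


async def sketch_pending_orders(
    mode: str,
    fetch: Callable[[str, dict[str, str]], Awaitable[dict[str, Any]]],
) -> list[dict[str, Any]]:
    if mode == "paper":
        return []  # no pending-order inquiry in paper mode
    params = {"INQR_DVSN_1": "0", "INQR_DVSN_2": "0"}
    try:
        payload = await fetch("TTTC0084R", params)  # (tr_id, query params)
    except Exception as exc:  # e.g. aiohttp.ClientError
        raise ConnectionError("pending order inquiry failed") from exc
    return payload.get("output", [])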
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Domestic Order Cancellation (cancel_domestic_order)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCancelDomesticOrder:
|
||||
"""cancel_domestic_order must use correct TR_ID and build body correctly."""
|
||||
|
||||
def _make_broker(self, settings, mode: str) -> KISBroker:
|
||||
from src.config import Settings
|
||||
|
||||
s = Settings(
|
||||
KIS_APP_KEY=settings.KIS_APP_KEY,
|
||||
KIS_APP_SECRET=settings.KIS_APP_SECRET,
|
||||
KIS_ACCOUNT_NO=settings.KIS_ACCOUNT_NO,
|
||||
GEMINI_API_KEY=settings.GEMINI_API_KEY,
|
||||
DB_PATH=":memory:",
|
||||
ENABLED_MARKETS="KR",
|
||||
MODE=mode,
|
||||
)
|
||||
b = KISBroker(s)
|
||||
b._access_token = "tok"
|
||||
b._token_expires_at = float("inf")
|
||||
b._rate_limiter.acquire = AsyncMock()
|
||||
return b
|
||||
|
||||
def _make_post_mocks(self, order_payload: dict) -> tuple:
|
||||
mock_hash = AsyncMock()
|
||||
mock_hash.status = 200
|
||||
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
|
||||
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
|
||||
mock_hash.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_order = AsyncMock()
|
||||
mock_order.status = 200
|
||||
mock_order.json = AsyncMock(return_value=order_payload)
|
||||
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
|
||||
mock_order.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
return mock_hash, mock_order
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_live_uses_tttc0013u(self, settings) -> None:
|
||||
"""Live mode must use TR_ID TTTC0013U."""
|
||||
broker = self._make_broker(settings, "live")
|
||||
mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"})
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
|
||||
) as mock_post:
|
||||
await broker.cancel_domestic_order("005930", "ORD001", "BRNO01", 5)
|
||||
|
||||
order_headers = mock_post.call_args_list[1][1].get("headers", {})
|
||||
assert order_headers["tr_id"] == "TTTC0013U"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_paper_uses_vttc0013u(self, settings) -> None:
|
||||
"""Paper mode must use TR_ID VTTC0013U."""
|
||||
broker = self._make_broker(settings, "paper")
|
||||
mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"})
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
|
||||
) as mock_post:
|
||||
await broker.cancel_domestic_order("005930", "ORD001", "BRNO01", 5)
|
||||
|
||||
order_headers = mock_post.call_args_list[1][1].get("headers", {})
|
||||
assert order_headers["tr_id"] == "VTTC0013U"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_sets_rvse_cncl_dvsn_cd_02(self, settings) -> None:
|
||||
"""Body must have RVSE_CNCL_DVSN_CD='02' (취소) and QTY_ALL_ORD_YN='Y'."""
|
||||
broker = self._make_broker(settings, "live")
|
||||
mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"})
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
|
||||
) as mock_post:
|
||||
await broker.cancel_domestic_order("005930", "ORD001", "BRNO01", 5)
|
||||
|
||||
body = mock_post.call_args_list[1][1].get("json", {})
|
||||
assert body["RVSE_CNCL_DVSN_CD"] == "02"
|
||||
assert body["QTY_ALL_ORD_YN"] == "Y"
|
||||
assert body["ORD_UNPR"] == "0"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_sets_krx_fwdg_ord_orgno_in_body(self, settings) -> None:
|
||||
"""Body must include KRX_FWDG_ORD_ORGNO and ORGN_ODNO from arguments."""
|
||||
broker = self._make_broker(settings, "live")
|
||||
mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"})
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
|
||||
) as mock_post:
|
||||
await broker.cancel_domestic_order("005930", "ORD123", "BRN456", 3)
|
||||
|
||||
body = mock_post.call_args_list[1][1].get("json", {})
|
||||
assert body["KRX_FWDG_ORD_ORGNO"] == "BRN456"
|
||||
assert body["ORGN_ODNO"] == "ORD123"
|
||||
assert body["ORD_QTY"] == "3"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_sets_hashkey_header(self, settings) -> None:
|
||||
"""Request must include hashkey header (same pattern as send_order)."""
|
||||
broker = self._make_broker(settings, "live")
|
||||
mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"})
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
|
||||
) as mock_post:
|
||||
await broker.cancel_domestic_order("005930", "ORD001", "BRNO01", 2)
|
||||
|
||||
order_headers = mock_post.call_args_list[1][1].get("headers", {})
|
||||
assert "hashkey" in order_headers
|
||||
assert order_headers["hashkey"] == "h"
|
||||
|
||||
629
tests/test_context.py
Normal file
@@ -0,0 +1,629 @@
|
||||
"""Tests for the multi-layered context management system."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from src.context.aggregator import ContextAggregator
|
||||
from src.context.layer import LAYER_CONFIG, ContextLayer
|
||||
from src.context.store import ContextStore
|
||||
from src.context.summarizer import ContextSummarizer
|
||||
from src.db import init_db, log_trade
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_conn() -> sqlite3.Connection:
|
||||
"""Provide an in-memory database connection."""
|
||||
return init_db(":memory:")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(db_conn: sqlite3.Connection) -> ContextStore:
|
||||
"""Provide a ContextStore instance."""
|
||||
return ContextStore(db_conn)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def aggregator(db_conn: sqlite3.Connection) -> ContextAggregator:
|
||||
"""Provide a ContextAggregator instance."""
|
||||
return ContextAggregator(db_conn)
|
||||
|
||||
|
||||
class TestContextStore:
|
||||
"""Test suite for ContextStore CRUD operations."""
|
||||
|
||||
def test_set_and_get_context(self, store: ContextStore) -> None:
|
||||
"""Test setting and retrieving a context value."""
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-04", "total_pnl", 1234.56)
|
||||
|
||||
value = store.get_context(ContextLayer.L6_DAILY, "2026-02-04", "total_pnl")
|
||||
assert value == 1234.56
|
||||
|
||||
def test_get_nonexistent_context(self, store: ContextStore) -> None:
|
||||
"""Test retrieving a non-existent context returns None."""
|
||||
value = store.get_context(ContextLayer.L6_DAILY, "2026-02-04", "nonexistent")
|
||||
assert value is None
|
||||
|
||||
def test_update_existing_context(self, store: ContextStore) -> None:
|
||||
"""Test updating an existing context value."""
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-04", "total_pnl", 100.0)
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-04", "total_pnl", 200.0)
|
||||
|
||||
value = store.get_context(ContextLayer.L6_DAILY, "2026-02-04", "total_pnl")
|
||||
assert value == 200.0
|
||||
|
||||
def test_get_all_contexts_for_layer(self, store: ContextStore) -> None:
|
||||
"""Test retrieving all contexts for a specific layer."""
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-04", "total_pnl", 100.0)
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-04", "trade_count", 10)
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-04", "win_rate", 60.5)
|
||||
|
||||
contexts = store.get_all_contexts(ContextLayer.L6_DAILY, "2026-02-04")
|
||||
assert len(contexts) == 3
|
||||
assert contexts["total_pnl"] == 100.0
|
||||
assert contexts["trade_count"] == 10
|
||||
assert contexts["win_rate"] == 60.5
|
||||
|
||||
def test_get_latest_timeframe(self, store: ContextStore) -> None:
|
||||
"""Test getting the most recent timeframe for a layer."""
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "total_pnl", 100.0)
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-03", "total_pnl", 200.0)
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-02", "total_pnl", 150.0)
|
||||
|
||||
latest = store.get_latest_timeframe(ContextLayer.L6_DAILY)
|
||||
# Latest by updated_at, which should be the last one set
|
||||
assert latest == "2026-02-02"
|
||||
|
||||
def test_delete_old_contexts(
|
||||
self, store: ContextStore, db_conn: sqlite3.Connection
|
||||
) -> None:
|
||||
"""Test deleting contexts older than a cutoff date."""
|
||||
# Insert contexts with specific old timestamps
|
||||
# (bypassing set_context which uses current time)
|
||||
old_date = "2026-01-01T00:00:00+00:00"
|
||||
new_date = "2026-02-01T00:00:00+00:00"
|
||||
|
||||
db_conn.execute(
|
||||
"""
|
||||
INSERT INTO contexts (layer, timeframe, key, value, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(ContextLayer.L6_DAILY.value, "2026-01-01", "total_pnl", "100.0", old_date, old_date),
|
||||
)
|
||||
db_conn.execute(
|
||||
"""
|
||||
INSERT INTO contexts (layer, timeframe, key, value, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(ContextLayer.L6_DAILY.value, "2026-02-01", "total_pnl", "200.0", new_date, new_date),
|
||||
)
|
||||
db_conn.commit()
|
||||
|
||||
# Delete contexts before 2026-01-15
|
||||
cutoff = "2026-01-15T00:00:00+00:00"
|
||||
deleted = store.delete_old_contexts(ContextLayer.L6_DAILY, cutoff)
|
||||
|
||||
# Should delete the 2026-01-01 context
|
||||
assert deleted == 1
|
||||
assert store.get_context(ContextLayer.L6_DAILY, "2026-02-01", "total_pnl") == 200.0
|
||||
assert store.get_context(ContextLayer.L6_DAILY, "2026-01-01", "total_pnl") is None
|
||||
|
||||
def test_cleanup_expired_contexts(
|
||||
self, store: ContextStore, db_conn: sqlite3.Connection
|
||||
) -> None:
|
||||
"""Test automatic cleanup based on retention policies."""
|
||||
# Set old contexts for L7 (7 day retention)
|
||||
old_date = (datetime.now(UTC) - timedelta(days=10)).isoformat()
|
||||
db_conn.execute(
|
||||
"""
|
||||
INSERT INTO contexts (layer, timeframe, key, value, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(ContextLayer.L7_REALTIME.value, "2026-01-01", "price", "100.0", old_date, old_date),
|
||||
)
|
||||
db_conn.commit()
|
||||
|
||||
deleted_counts = store.cleanup_expired_contexts()
|
||||
|
||||
# Should delete the old L7 context (10 days > 7 day retention)
|
||||
assert deleted_counts[ContextLayer.L7_REALTIME] == 1
|
||||
|
||||
# L1 has no retention limit, so nothing should be deleted
|
||||
assert deleted_counts[ContextLayer.L1_LEGACY] == 0
|
||||
|
||||
def test_context_metadata_initialized(
|
||||
self, store: ContextStore, db_conn: sqlite3.Connection
|
||||
) -> None:
|
||||
"""Test that context metadata is properly initialized."""
|
||||
cursor = db_conn.execute("SELECT COUNT(*) FROM context_metadata")
|
||||
count = cursor.fetchone()[0]
|
||||
|
||||
# Should have metadata for all 7 layers
|
||||
assert count == 7
|
||||
|
||||
# Verify L1 metadata
|
||||
cursor = db_conn.execute(
|
||||
"SELECT description, retention_days FROM context_metadata WHERE layer = ?",
|
||||
(ContextLayer.L1_LEGACY.value,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
assert row is not None
|
||||
assert "Cumulative trading history" in row[0]
|
||||
assert row[1] is None # No retention limit for L1
|
||||
|
||||
|
||||
class TestContextAggregator:
|
||||
"""Test suite for ContextAggregator."""
|
||||
|
||||
def test_aggregate_daily_from_trades(
|
||||
self, aggregator: ContextAggregator, db_conn: sqlite3.Connection
|
||||
) -> None:
|
||||
"""Test aggregating daily metrics from trades."""
|
||||
date = datetime.now(UTC).date().isoformat()
|
||||
|
||||
# Create sample trades
|
||||
log_trade(db_conn, "005930", "BUY", 85, "Good signal", quantity=10, price=70000, pnl=500)
|
||||
log_trade(db_conn, "000660", "SELL", 90, "Take profit", quantity=5, price=50000, pnl=1500)
|
||||
log_trade(db_conn, "035720", "HOLD", 75, "Wait", quantity=0, price=0, pnl=0)
|
||||
|
||||
# Manually set timestamps to the target date
|
||||
db_conn.execute(
|
||||
f"UPDATE trades SET timestamp = '{date}T10:00:00+00:00'"
|
||||
)
|
||||
db_conn.commit()
|
||||
|
||||
# Aggregate
|
||||
aggregator.aggregate_daily_from_trades(date, market="KR")
|
||||
|
||||
# Verify L6 contexts
|
||||
store = aggregator.store
|
||||
assert store.get_context(ContextLayer.L6_DAILY, date, "trade_count_KR") == 3
|
||||
assert store.get_context(ContextLayer.L6_DAILY, date, "buys_KR") == 1
|
||||
assert store.get_context(ContextLayer.L6_DAILY, date, "sells_KR") == 1
|
||||
assert store.get_context(ContextLayer.L6_DAILY, date, "holds_KR") == 1
|
||||
assert store.get_context(ContextLayer.L6_DAILY, date, "total_pnl_KR") == 2000.0
|
||||
assert store.get_context(ContextLayer.L6_DAILY, date, "unique_stocks_KR") == 3
|
||||
# 2 wins, 0 losses
|
||||
assert store.get_context(ContextLayer.L6_DAILY, date, "win_rate_KR") == 100.0
|
||||
|
||||
def test_aggregate_weekly_from_daily(self, aggregator: ContextAggregator) -> None:
|
||||
"""Test aggregating weekly metrics from daily."""
|
||||
week = "2026-W06"
|
||||
|
||||
# Set daily contexts
|
||||
aggregator.store.set_context(
|
||||
ContextLayer.L6_DAILY, "2026-02-02", "total_pnl_KR", 100.0
|
||||
)
|
||||
aggregator.store.set_context(
|
||||
ContextLayer.L6_DAILY, "2026-02-03", "total_pnl_KR", 200.0
|
||||
)
|
||||
aggregator.store.set_context(
|
||||
ContextLayer.L6_DAILY, "2026-02-02", "avg_confidence_KR", 80.0
|
||||
)
|
||||
aggregator.store.set_context(
|
||||
ContextLayer.L6_DAILY, "2026-02-03", "avg_confidence_KR", 85.0
|
||||
)
|
||||
|
||||
# Aggregate
|
||||
aggregator.aggregate_weekly_from_daily(week)
|
||||
|
||||
# Verify L5 contexts
|
||||
store = aggregator.store
|
||||
weekly_pnl = store.get_context(ContextLayer.L5_WEEKLY, week, "weekly_pnl_KR")
|
||||
avg_conf = store.get_context(ContextLayer.L5_WEEKLY, week, "avg_confidence_KR")
|
||||
|
||||
assert weekly_pnl == 300.0
|
||||
assert avg_conf == 82.5
|
||||
|
||||
def test_aggregate_monthly_from_weekly(self, aggregator: ContextAggregator) -> None:
|
||||
"""Test aggregating monthly metrics from weekly."""
|
||||
month = "2026-02"
|
||||
|
||||
# Set weekly contexts
|
||||
aggregator.store.set_context(
|
||||
ContextLayer.L5_WEEKLY, "2026-W05", "weekly_pnl_KR", 100.0
|
||||
)
|
||||
aggregator.store.set_context(
|
||||
ContextLayer.L5_WEEKLY, "2026-W06", "weekly_pnl_KR", 200.0
|
||||
)
|
||||
aggregator.store.set_context(
|
||||
ContextLayer.L5_WEEKLY, "2026-W07", "weekly_pnl_KR", 150.0
|
||||
)
|
||||
|
||||
# Aggregate
|
||||
aggregator.aggregate_monthly_from_weekly(month)
|
||||
|
||||
# Verify L4 contexts
|
||||
store = aggregator.store
|
||||
monthly_pnl = store.get_context(ContextLayer.L4_MONTHLY, month, "monthly_pnl")
|
||||
assert monthly_pnl == 450.0
|
||||
|
||||
def test_aggregate_quarterly_from_monthly(self, aggregator: ContextAggregator) -> None:
|
||||
"""Test aggregating quarterly metrics from monthly."""
|
||||
quarter = "2026-Q1"
|
||||
|
||||
# Set monthly contexts for Q1 (Jan, Feb, Mar)
|
||||
aggregator.store.set_context(ContextLayer.L4_MONTHLY, "2026-01", "monthly_pnl", 1000.0)
|
||||
aggregator.store.set_context(ContextLayer.L4_MONTHLY, "2026-02", "monthly_pnl", 2000.0)
|
||||
aggregator.store.set_context(ContextLayer.L4_MONTHLY, "2026-03", "monthly_pnl", 1500.0)
|
||||
|
||||
# Aggregate
|
||||
aggregator.aggregate_quarterly_from_monthly(quarter)
|
||||
|
||||
# Verify L3 contexts
|
||||
store = aggregator.store
|
||||
quarterly_pnl = store.get_context(ContextLayer.L3_QUARTERLY, quarter, "quarterly_pnl")
|
||||
assert quarterly_pnl == 4500.0
|
||||
|
||||
def test_aggregate_annual_from_quarterly(self, aggregator: ContextAggregator) -> None:
|
||||
"""Test aggregating annual metrics from quarterly."""
|
||||
year = "2026"
|
||||
|
||||
# Set quarterly contexts for all 4 quarters
|
||||
aggregator.store.set_context(ContextLayer.L3_QUARTERLY, "2026-Q1", "quarterly_pnl", 4500.0)
|
||||
aggregator.store.set_context(ContextLayer.L3_QUARTERLY, "2026-Q2", "quarterly_pnl", 5000.0)
|
||||
aggregator.store.set_context(ContextLayer.L3_QUARTERLY, "2026-Q3", "quarterly_pnl", 4800.0)
|
||||
aggregator.store.set_context(ContextLayer.L3_QUARTERLY, "2026-Q4", "quarterly_pnl", 5200.0)
|
||||
|
||||
# Aggregate
|
||||
aggregator.aggregate_annual_from_quarterly(year)
|
||||
|
||||
# Verify L2 contexts
|
||||
store = aggregator.store
|
||||
annual_pnl = store.get_context(ContextLayer.L2_ANNUAL, year, "annual_pnl")
|
||||
assert annual_pnl == 19500.0
|
||||
|
||||
def test_aggregate_legacy_from_annual(self, aggregator: ContextAggregator) -> None:
|
||||
"""Test aggregating legacy metrics from all annual data."""
|
||||
# Set annual contexts for multiple years
|
||||
aggregator.store.set_context(ContextLayer.L2_ANNUAL, "2024", "annual_pnl", 10000.0)
|
||||
aggregator.store.set_context(ContextLayer.L2_ANNUAL, "2025", "annual_pnl", 15000.0)
|
||||
aggregator.store.set_context(ContextLayer.L2_ANNUAL, "2026", "annual_pnl", 20000.0)
|
||||
|
||||
# Aggregate
|
||||
aggregator.aggregate_legacy_from_annual()
|
||||
|
||||
# Verify L1 contexts
|
||||
store = aggregator.store
|
||||
total_pnl = store.get_context(ContextLayer.L1_LEGACY, "LEGACY", "total_pnl")
|
||||
years_traded = store.get_context(ContextLayer.L1_LEGACY, "LEGACY", "years_traded")
|
||||
avg_annual_pnl = store.get_context(ContextLayer.L1_LEGACY, "LEGACY", "avg_annual_pnl")
|
||||
|
||||
assert total_pnl == 45000.0
|
||||
assert years_traded == 3
|
||||
assert avg_annual_pnl == 15000.0
|
||||
|
||||
def test_run_all_aggregations(
|
||||
self, aggregator: ContextAggregator, db_conn: sqlite3.Connection
|
||||
) -> None:
|
||||
"""Test running all aggregations from L7 to L1."""
|
||||
date = datetime.now(UTC).date().isoformat()
|
||||
|
||||
# Create sample trades
|
||||
log_trade(db_conn, "005930", "BUY", 85, "Good signal", quantity=10, price=70000, pnl=1000)
|
||||
|
||||
# Set timestamp
|
||||
db_conn.execute(f"UPDATE trades SET timestamp = '{date}T10:00:00+00:00'")
|
||||
db_conn.commit()
|
||||
|
||||
# Run all aggregations
|
||||
aggregator.run_all_aggregations()
|
||||
|
||||
# Verify data exists in each layer
|
||||
store = aggregator.store
|
||||
assert store.get_context(ContextLayer.L6_DAILY, date, "total_pnl_KR") == 1000.0
|
||||
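# Rebuild the ISO-week / month / quarter / year labels the aggregator uses as timeframe keys.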
from datetime import date as date_cls
|
||||
trade_date = date_cls.fromisoformat(date)
|
||||
iso_year, iso_week, _ = trade_date.isocalendar()
|
||||
trade_week = f"{iso_year}-W{iso_week:02d}"
|
||||
assert store.get_context(ContextLayer.L5_WEEKLY, trade_week, "weekly_pnl_KR") is not None
|
||||
trade_month = f"{trade_date.year}-{trade_date.month:02d}"
|
||||
trade_quarter = f"{trade_date.year}-Q{(trade_date.month - 1) // 3 + 1}"
|
||||
trade_year = str(trade_date.year)
|
||||
assert store.get_context(ContextLayer.L4_MONTHLY, trade_month, "monthly_pnl") == 1000.0
|
||||
assert store.get_context(ContextLayer.L3_QUARTERLY, trade_quarter, "quarterly_pnl") == 1000.0
|
||||
assert store.get_context(ContextLayer.L2_ANNUAL, trade_year, "annual_pnl") == 1000.0
|
||||
|
||||
|
||||
class TestLayerMetadata:
|
||||
"""Test suite for layer metadata configuration."""
|
||||
|
||||
def test_all_layers_have_metadata(self) -> None:
|
||||
"""Test that all 7 layers have metadata defined."""
|
||||
assert len(LAYER_CONFIG) == 7
|
||||
|
||||
for layer in ContextLayer:
|
||||
assert layer in LAYER_CONFIG
|
||||
|
||||
def test_layer_retention_policies(self) -> None:
|
||||
"""Test layer retention policies are correctly configured."""
|
||||
# L1 should have no retention limit
|
||||
assert LAYER_CONFIG[ContextLayer.L1_LEGACY].retention_days is None
|
||||
|
||||
# L7 should have the shortest retention (7 days)
|
||||
assert LAYER_CONFIG[ContextLayer.L7_REALTIME].retention_days == 7
|
||||
|
||||
# L2 should have a long retention (10 years)
|
||||
assert LAYER_CONFIG[ContextLayer.L2_ANNUAL].retention_days == 365 * 10
|
||||
|
||||
def test_layer_aggregation_chain(self) -> None:
|
||||
"""Test that the aggregation chain is properly configured."""
|
||||
# L7 has no source (leaf layer)
|
||||
assert LAYER_CONFIG[ContextLayer.L7_REALTIME].aggregation_source is None
|
||||
|
||||
# L6 aggregates from L7
|
||||
assert LAYER_CONFIG[ContextLayer.L6_DAILY].aggregation_source == ContextLayer.L7_REALTIME
|
||||
|
||||
# L5 aggregates from L6
|
||||
assert LAYER_CONFIG[ContextLayer.L5_WEEKLY].aggregation_source == ContextLayer.L6_DAILY
|
||||
|
||||
# L4 aggregates from L5
|
||||
assert LAYER_CONFIG[ContextLayer.L4_MONTHLY].aggregation_source == ContextLayer.L5_WEEKLY
|
||||
|
||||
# L3 aggregates from L4
|
||||
assert LAYER_CONFIG[ContextLayer.L3_QUARTERLY].aggregation_source == ContextLayer.L4_MONTHLY
|
||||
|
||||
# L2 aggregates from L3
|
||||
assert LAYER_CONFIG[ContextLayer.L2_ANNUAL].aggregation_source == ContextLayer.L3_QUARTERLY
|
||||
|
||||
# L1 aggregates from L2
|
||||
assert LAYER_CONFIG[ContextLayer.L1_LEGACY].aggregation_source == ContextLayer.L2_ANNUAL
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ContextSummarizer tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def summarizer(db_conn: sqlite3.Connection) -> ContextSummarizer:
|
||||
"""Provide a ContextSummarizer backed by an in-memory store."""
|
||||
return ContextSummarizer(ContextStore(db_conn))
|
||||
|
||||
|
||||
class TestContextSummarizer:
|
||||
"""Test suite for ContextSummarizer."""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# summarize_numeric_values
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def test_summarize_empty_values(self, summarizer: ContextSummarizer) -> None:
|
||||
"""Empty list must return SummaryStats with count=0 and no other fields."""
|
||||
stats = summarizer.summarize_numeric_values([])
|
||||
assert stats.count == 0
|
||||
assert stats.mean is None
|
||||
assert stats.min is None
|
||||
assert stats.max is None
|
||||
|
||||
def test_summarize_single_value(self, summarizer: ContextSummarizer) -> None:
|
||||
"""Single-element list must return correct stats with std=0 and trend=flat."""
|
||||
stats = summarizer.summarize_numeric_values([42.0])
|
||||
assert stats.count == 1
|
||||
assert stats.mean == 42.0
|
||||
assert stats.std == 0.0
|
||||
assert stats.trend == "flat"
|
||||
|
||||
def test_summarize_upward_trend(self, summarizer: ContextSummarizer) -> None:
|
||||
"""Increasing values must produce trend='up'."""
|
||||
values = [1.0, 2.0, 3.0, 10.0, 20.0, 30.0]
|
||||
stats = summarizer.summarize_numeric_values(values)
|
||||
assert stats.trend == "up"
|
||||
|
||||
def test_summarize_downward_trend(self, summarizer: ContextSummarizer) -> None:
|
||||
"""Decreasing values must produce trend='down'."""
|
||||
values = [30.0, 20.0, 10.0, 3.0, 2.0, 1.0]
|
||||
stats = summarizer.summarize_numeric_values(values)
|
||||
assert stats.trend == "down"
|
||||
|
||||
def test_summarize_flat_trend(self, summarizer: ContextSummarizer) -> None:
|
||||
"""Stable values must produce trend='flat'."""
|
||||
values = [100.0, 100.1, 99.9, 100.0, 100.2, 99.8]
|
||||
stats = summarizer.summarize_numeric_values(values)
|
||||
assert stats.trend == "flat"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# summarize_layer
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def test_summarize_layer_no_data(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""summarize_layer with no data must return the 'No data' sentinel."""
|
||||
result = summarizer.summarize_layer(ContextLayer.L6_DAILY)
|
||||
assert result["count"] == 0
|
||||
assert "No data" in result["summary"]
|
||||
|
||||
def test_summarize_layer_numeric(
|
||||
self, summarizer: ContextSummarizer, db_conn: sqlite3.Connection
|
||||
) -> None:
|
||||
"""summarize_layer must collect numeric values and produce stats."""
|
||||
store = summarizer.store
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "total_pnl", 100.0)
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-02", "total_pnl", 200.0)
|
||||
|
||||
result = summarizer.summarize_layer(ContextLayer.L6_DAILY)
|
||||
assert "total_entries" in result
|
||||
|
||||
def test_summarize_layer_with_dict_values(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""summarize_layer must handle dict values by extracting numeric subkeys."""
|
||||
store = summarizer.store
|
||||
# set_context serialises the value as JSON, so passing a dict works
|
||||
store.set_context(
|
||||
ContextLayer.L6_DAILY, "2026-02-01", "metrics",
|
||||
{"win_rate": 65.0, "label": "good"}
|
||||
)
|
||||
|
||||
result = summarizer.summarize_layer(ContextLayer.L6_DAILY)
|
||||
assert "total_entries" in result
|
||||
# numeric subkey "win_rate" should appear as "metrics.win_rate"
|
||||
assert "metrics.win_rate" in result
|
||||
|
||||
def test_summarize_layer_with_string_values(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""summarize_layer must count string values separately."""
|
||||
store = summarizer.store
|
||||
# set_context stores string values as JSON-encoded strings
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "outlook", "BULLISH")
|
||||
|
||||
result = summarizer.summarize_layer(ContextLayer.L6_DAILY)
|
||||
# String fields contribute a `<key>_count` entry
|
||||
assert "outlook_count" in result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# rolling_window_summary
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def test_rolling_window_summary_basic(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""rolling_window_summary must return the expected structure."""
|
||||
store = summarizer.store
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "pnl", 500.0)
|
||||
|
||||
result = summarizer.rolling_window_summary(ContextLayer.L6_DAILY)
|
||||
assert "window_days" in result
|
||||
assert "recent_data" in result
|
||||
assert "historical_summary" in result
|
||||
|
||||
def test_rolling_window_summary_no_older_data(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""rolling_window_summary with summarize_older=False skips history."""
|
||||
result = summarizer.rolling_window_summary(
|
||||
ContextLayer.L6_DAILY, summarize_older=False
|
||||
)
|
||||
assert result["historical_summary"] == {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# aggregate_to_higher_layer
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def test_aggregate_to_higher_layer_mean(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""aggregate_to_higher_layer with 'mean' via dict subkeys returns average."""
|
||||
store = summarizer.store
|
||||
# Use different outer keys but same inner metric key so get_all_contexts
|
||||
# returns multiple rows with the target subkey.
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day1", {"pnl": 100.0})
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day2", {"pnl": 200.0})
|
||||
|
||||
result = summarizer.aggregate_to_higher_layer(
|
||||
ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY, "pnl", "mean"
|
||||
)
|
||||
assert result == pytest.approx(150.0)
|
||||
|
||||
def test_aggregate_to_higher_layer_sum(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""aggregate_to_higher_layer with 'sum' must return the total."""
|
||||
store = summarizer.store
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day1", {"pnl": 100.0})
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day2", {"pnl": 200.0})
|
||||
|
||||
result = summarizer.aggregate_to_higher_layer(
|
||||
ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY, "pnl", "sum"
|
||||
)
|
||||
assert result == pytest.approx(300.0)
|
||||
|
||||
def test_aggregate_to_higher_layer_max(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""aggregate_to_higher_layer with 'max' must return the maximum."""
|
||||
store = summarizer.store
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day1", {"pnl": 100.0})
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day2", {"pnl": 200.0})
|
||||
|
||||
result = summarizer.aggregate_to_higher_layer(
|
||||
ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY, "pnl", "max"
|
||||
)
|
||||
assert result == pytest.approx(200.0)
|
||||
|
||||
def test_aggregate_to_higher_layer_min(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""aggregate_to_higher_layer with 'min' must return the minimum."""
|
||||
store = summarizer.store
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day1", {"pnl": 100.0})
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day2", {"pnl": 200.0})
|
||||
|
||||
result = summarizer.aggregate_to_higher_layer(
|
||||
ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY, "pnl", "min"
|
||||
)
|
||||
assert result == pytest.approx(100.0)
|
||||
|
||||
def test_aggregate_to_higher_layer_no_data(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""aggregate_to_higher_layer with no matching key must return None."""
|
||||
result = summarizer.aggregate_to_higher_layer(
|
||||
ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY, "nonexistent", "mean"
|
||||
)
|
||||
assert result is None
|
||||
|
||||
def test_aggregate_to_higher_layer_unknown_func_defaults_to_mean(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""Unknown aggregation function must fall back to mean."""
|
||||
store = summarizer.store
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day1", {"pnl": 100.0})
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day2", {"pnl": 200.0})
|
||||
|
||||
result = summarizer.aggregate_to_higher_layer(
|
||||
ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY, "pnl", "unknown_func"
|
||||
)
|
||||
assert result == pytest.approx(150.0)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# create_compact_summary + format_summary_for_prompt
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def test_create_compact_summary(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""create_compact_summary must produce a dict keyed by layer value."""
|
||||
store = summarizer.store
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "pnl", 100.0)
|
||||
|
||||
result = summarizer.create_compact_summary([ContextLayer.L6_DAILY])
|
||||
assert ContextLayer.L6_DAILY.value in result
|
||||
|
||||
def test_format_summary_for_prompt_with_numeric_metrics(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""format_summary_for_prompt must render avg/trend fields."""
|
||||
store = summarizer.store
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "pnl", 100.0)
|
||||
store.set_context(ContextLayer.L6_DAILY, "2026-02-02", "pnl", 200.0)
|
||||
|
||||
compact = summarizer.create_compact_summary([ContextLayer.L6_DAILY])
|
||||
text = summarizer.format_summary_for_prompt(compact)
|
||||
assert isinstance(text, str)
|
||||
|
||||
def test_format_summary_for_prompt_skips_empty_layers(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""format_summary_for_prompt must skip layers with no metrics."""
|
||||
summary = {ContextLayer.L6_DAILY.value: {}}
|
||||
text = summarizer.format_summary_for_prompt(summary)
|
||||
assert text == ""
|
||||
|
||||
def test_format_summary_non_dict_value(
|
||||
self, summarizer: ContextSummarizer
|
||||
) -> None:
|
||||
"""format_summary_for_prompt must render non-dict values as plain text."""
|
||||
summary = {
|
||||
"daily": {
|
||||
"plain_count": 42,
|
||||
}
|
||||
}
|
||||
text = summarizer.format_summary_for_prompt(summary)
|
||||
assert "plain_count" in text
|
||||
assert "42" in text
|
||||
104
tests/test_context_scheduler.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Tests for ContextScheduler."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from src.context.scheduler import ContextScheduler
|
||||
|
||||
|
||||
@dataclass
|
||||
class StubAggregator:
|
||||
"""Stub aggregator that records calls."""
|
||||
|
||||
weekly_calls: list[str]
|
||||
monthly_calls: list[str]
|
||||
quarterly_calls: list[str]
|
||||
annual_calls: list[str]
|
||||
legacy_calls: int
|
||||
|
||||
def aggregate_weekly_from_daily(self, week: str) -> None:
|
||||
self.weekly_calls.append(week)
|
||||
|
||||
def aggregate_monthly_from_weekly(self, month: str) -> None:
|
||||
self.monthly_calls.append(month)
|
||||
|
||||
def aggregate_quarterly_from_monthly(self, quarter: str) -> None:
|
||||
self.quarterly_calls.append(quarter)
|
||||
|
||||
def aggregate_annual_from_quarterly(self, year: str) -> None:
|
||||
self.annual_calls.append(year)
|
||||
|
||||
def aggregate_legacy_from_annual(self) -> None:
|
||||
self.legacy_calls += 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class StubStore:
|
||||
"""Stub store that records cleanup calls."""
|
||||
|
||||
cleanup_calls: int = 0
|
||||
|
||||
def cleanup_expired_contexts(self) -> None:
|
||||
self.cleanup_calls += 1
|
||||
|
||||
|
||||
def make_scheduler() -> tuple[ContextScheduler, StubAggregator, StubStore]:
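"""Wire a ContextScheduler to recording stubs so tests can assert which aggregation tiers ran."""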
|
||||
aggregator = StubAggregator([], [], [], [], 0)
|
||||
store = StubStore()
|
||||
scheduler = ContextScheduler(aggregator=aggregator, store=store)
|
||||
return scheduler, aggregator, store
|
||||
|
||||
|
||||
def test_run_if_due_weekly() -> None:
|
||||
scheduler, aggregator, store = make_scheduler()
|
||||
now = datetime(2026, 2, 8, 10, 0, tzinfo=UTC) # Sunday
|
||||
|
||||
result = scheduler.run_if_due(now)
|
||||
|
||||
assert result.weekly is True
|
||||
assert aggregator.weekly_calls == ["2026-W06"]
|
||||
assert store.cleanup_calls == 1
|
||||
|
||||
|
||||
def test_run_if_due_monthly() -> None:
|
||||
scheduler, aggregator, _store = make_scheduler()
|
||||
now = datetime(2026, 2, 28, 12, 0, tzinfo=UTC) # Last day of month
|
||||
|
||||
result = scheduler.run_if_due(now)
|
||||
|
||||
assert result.monthly is True
|
||||
assert aggregator.monthly_calls == ["2026-02"]
|
||||
|
||||
|
||||
def test_run_if_due_quarterly() -> None:
|
||||
scheduler, aggregator, _store = make_scheduler()
|
||||
now = datetime(2026, 3, 31, 12, 0, tzinfo=UTC) # Last day of Q1
|
||||
|
||||
result = scheduler.run_if_due(now)
|
||||
|
||||
assert result.quarterly is True
|
||||
assert aggregator.quarterly_calls == ["2026-Q1"]
|
||||
|
||||
|
||||
def test_run_if_due_annual_and_legacy() -> None:
|
||||
scheduler, aggregator, _store = make_scheduler()
|
||||
now = datetime(2026, 12, 31, 12, 0, tzinfo=UTC)
|
||||
|
||||
result = scheduler.run_if_due(now)
|
||||
|
||||
assert result.annual is True
|
||||
assert result.legacy is True
|
||||
assert aggregator.annual_calls == ["2026"]
|
||||
assert aggregator.legacy_calls == 1
|
||||
|
||||
|
||||
def test_cleanup_runs_once_per_day() -> None:
|
||||
scheduler, _aggregator, store = make_scheduler()
|
||||
now = datetime(2026, 2, 9, 9, 0, tzinfo=UTC)
|
||||
|
||||
scheduler.run_if_due(now)
|
||||
scheduler.run_if_due(now)
|
||||
|
||||
assert store.cleanup_calls == 1
|
||||
387
tests/test_daily_review.py
Normal file
@@ -0,0 +1,387 @@
|
||||
"""Tests for DailyReviewer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from src.context.layer import ContextLayer
|
||||
from src.context.store import ContextStore
|
||||
from src.db import init_db, log_trade
|
||||
from src.evolution.daily_review import DailyReviewer
|
||||
from src.evolution.scorecard import DailyScorecard
|
||||
from src.logging.decision_logger import DecisionLogger
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
TODAY = datetime.now(UTC).strftime("%Y-%m-%d")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_conn() -> sqlite3.Connection:
|
||||
return init_db(":memory:")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def context_store(db_conn: sqlite3.Connection) -> ContextStore:
|
||||
return ContextStore(db_conn)
|
||||
|
||||
|
||||
def _log_decision(
|
||||
logger: DecisionLogger,
|
||||
*,
|
||||
stock_code: str,
|
||||
market: str,
|
||||
action: str,
|
||||
confidence: int,
|
||||
scenario_match: dict[str, float] | None = None,
|
||||
) -> str:
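"""Log a single decision with a minimal context snapshot and return its decision_id."""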
|
||||
return logger.log_decision(
|
||||
stock_code=stock_code,
|
||||
market=market,
|
||||
exchange_code="KRX" if market == "KR" else "NASDAQ",
|
||||
action=action,
|
||||
confidence=confidence,
|
||||
rationale="test",
|
||||
context_snapshot={"scenario_match": scenario_match or {}},
|
||||
input_data={"stock_code": stock_code},
|
||||
)
|
||||
|
||||
|
||||
def test_generate_scorecard_market_scoped(
|
||||
db_conn: sqlite3.Connection, context_store: ContextStore,
|
||||
) -> None:
|
||||
reviewer = DailyReviewer(db_conn, context_store)
|
||||
logger = DecisionLogger(db_conn)
|
||||
|
||||
buy_id = _log_decision(
|
||||
logger,
|
||||
stock_code="005930",
|
||||
market="KR",
|
||||
action="BUY",
|
||||
confidence=90,
|
||||
scenario_match={"rsi": 29.0},
|
||||
)
|
||||
_log_decision(
|
||||
logger,
|
||||
stock_code="000660",
|
||||
market="KR",
|
||||
action="HOLD",
|
||||
confidence=60,
|
||||
)
|
||||
_log_decision(
|
||||
logger,
|
||||
stock_code="AAPL",
|
||||
market="US",
|
||||
action="SELL",
|
||||
confidence=80,
|
||||
scenario_match={"volume_ratio": 2.1},
|
||||
)
|
||||
|
||||
log_trade(
|
||||
db_conn,
|
||||
"005930",
|
||||
"BUY",
|
||||
90,
|
||||
"buy",
|
||||
quantity=1,
|
||||
price=100.0,
|
||||
pnl=10.0,
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
decision_id=buy_id,
|
||||
)
|
||||
log_trade(
|
||||
db_conn,
|
||||
"000660",
|
||||
"HOLD",
|
||||
60,
|
||||
"hold",
|
||||
quantity=0,
|
||||
price=0.0,
|
||||
pnl=0.0,
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
)
|
||||
log_trade(
|
||||
db_conn,
|
||||
"AAPL",
|
||||
"SELL",
|
||||
80,
|
||||
"sell",
|
||||
quantity=1,
|
||||
price=200.0,
|
||||
pnl=-5.0,
|
||||
market="US",
|
||||
exchange_code="NASDAQ",
|
||||
)
|
||||
|
||||
scorecard = reviewer.generate_scorecard(TODAY, "KR")
|
||||
|
||||
assert scorecard.market == "KR"
|
||||
assert scorecard.total_decisions == 2
|
||||
assert scorecard.buys == 1
|
||||
assert scorecard.sells == 0
|
||||
assert scorecard.holds == 1
|
||||
assert scorecard.total_pnl == 10.0
|
||||
assert scorecard.win_rate == 100.0
|
||||
assert scorecard.avg_confidence == 75.0
|
||||
assert scorecard.scenario_match_rate == 50.0
|
||||
|
||||
|
||||
def test_generate_scorecard_top_winners_and_losers(
|
||||
db_conn: sqlite3.Connection, context_store: ContextStore,
|
||||
) -> None:
|
||||
reviewer = DailyReviewer(db_conn, context_store)
|
||||
logger = DecisionLogger(db_conn)
|
||||
|
||||
for code, pnl in [("005930", 30.0), ("000660", 10.0), ("035420", -15.0), ("051910", -5.0)]:
|
||||
decision_id = _log_decision(
|
||||
logger,
|
||||
stock_code=code,
|
||||
market="KR",
|
||||
action="BUY" if pnl >= 0 else "SELL",
|
||||
confidence=80,
|
||||
scenario_match={"rsi": 30.0},
|
||||
)
|
||||
log_trade(
|
||||
db_conn,
|
||||
code,
|
||||
"BUY" if pnl >= 0 else "SELL",
|
||||
80,
|
||||
"test",
|
||||
quantity=1,
|
||||
price=100.0,
|
||||
pnl=pnl,
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
decision_id=decision_id,
|
||||
)
|
||||
|
||||
scorecard = reviewer.generate_scorecard(TODAY, "KR")
|
||||
assert scorecard.top_winners == ["005930", "000660"]
|
||||
assert scorecard.top_losers == ["035420", "051910"]
|
||||
|
||||
|
||||
def test_generate_scorecard_empty_day(
|
||||
db_conn: sqlite3.Connection, context_store: ContextStore,
|
||||
) -> None:
|
||||
reviewer = DailyReviewer(db_conn, context_store)
|
||||
scorecard = reviewer.generate_scorecard(TODAY, "KR")
|
||||
|
||||
assert scorecard.total_decisions == 0
|
||||
assert scorecard.total_pnl == 0.0
|
||||
assert scorecard.win_rate == 0.0
|
||||
assert scorecard.avg_confidence == 0.0
|
||||
assert scorecard.scenario_match_rate == 0.0
|
||||
assert scorecard.top_winners == []
|
||||
assert scorecard.top_losers == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_lessons_without_gemini_returns_empty(
|
||||
db_conn: sqlite3.Connection, context_store: ContextStore,
|
||||
) -> None:
|
||||
reviewer = DailyReviewer(db_conn, context_store, gemini_client=None)
|
||||
lessons = await reviewer.generate_lessons(
|
||||
DailyScorecard(
|
||||
date="2026-02-14",
|
||||
market="KR",
|
||||
total_decisions=1,
|
||||
buys=1,
|
||||
sells=0,
|
||||
holds=0,
|
||||
total_pnl=5.0,
|
||||
win_rate=100.0,
|
||||
avg_confidence=90.0,
|
||||
scenario_match_rate=100.0,
|
||||
)
|
||||
)
|
||||
assert lessons == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_lessons_parses_json_array(
|
||||
db_conn: sqlite3.Connection, context_store: ContextStore,
|
||||
) -> None:
|
||||
mock_gemini = MagicMock()
|
||||
mock_gemini.decide = AsyncMock(
|
||||
return_value=SimpleNamespace(rationale='["Cut losers earlier", "Reduce midday churn"]')
|
||||
)
|
||||
reviewer = DailyReviewer(db_conn, context_store, gemini_client=mock_gemini)
|
||||
|
||||
lessons = await reviewer.generate_lessons(
|
||||
DailyScorecard(
|
||||
date="2026-02-14",
|
||||
market="KR",
|
||||
total_decisions=3,
|
||||
buys=1,
|
||||
sells=1,
|
||||
holds=1,
|
||||
total_pnl=-2.5,
|
||||
win_rate=50.0,
|
||||
avg_confidence=70.0,
|
||||
scenario_match_rate=66.7,
|
||||
)
|
||||
)
|
||||
assert lessons == ["Cut losers earlier", "Reduce midday churn"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_lessons_fallback_to_lines(
|
||||
db_conn: sqlite3.Connection, context_store: ContextStore,
|
||||
) -> None:
|
||||
mock_gemini = MagicMock()
|
||||
mock_gemini.decide = AsyncMock(
|
||||
return_value=SimpleNamespace(rationale="- Keep risk tighter\n- Increase selectivity")
|
||||
)
|
||||
reviewer = DailyReviewer(db_conn, context_store, gemini_client=mock_gemini)
|
||||
|
||||
lessons = await reviewer.generate_lessons(
|
||||
DailyScorecard(
|
||||
date="2026-02-14",
|
||||
market="US",
|
||||
total_decisions=2,
|
||||
buys=1,
|
||||
sells=1,
|
||||
holds=0,
|
||||
total_pnl=1.0,
|
||||
win_rate=50.0,
|
||||
avg_confidence=75.0,
|
||||
scenario_match_rate=100.0,
|
||||
)
|
||||
)
|
||||
assert lessons == ["Keep risk tighter", "Increase selectivity"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_lessons_handles_gemini_error(
|
||||
db_conn: sqlite3.Connection, context_store: ContextStore,
|
||||
) -> None:
|
||||
mock_gemini = MagicMock()
|
||||
mock_gemini.decide = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
reviewer = DailyReviewer(db_conn, context_store, gemini_client=mock_gemini)
|
||||
|
||||
lessons = await reviewer.generate_lessons(
|
||||
DailyScorecard(
|
||||
date="2026-02-14",
|
||||
market="US",
|
||||
total_decisions=0,
|
||||
buys=0,
|
||||
sells=0,
|
||||
holds=0,
|
||||
total_pnl=0.0,
|
||||
win_rate=0.0,
|
||||
avg_confidence=0.0,
|
||||
scenario_match_rate=0.0,
|
||||
)
|
||||
)
|
||||
assert lessons == []
|
||||
|
||||
|
||||
def test_store_scorecard_in_context(
|
||||
db_conn: sqlite3.Connection, context_store: ContextStore,
|
||||
) -> None:
|
||||
reviewer = DailyReviewer(db_conn, context_store)
|
||||
scorecard = DailyScorecard(
|
||||
date="2026-02-14",
|
||||
market="KR",
|
||||
total_decisions=5,
|
||||
buys=2,
|
||||
sells=1,
|
||||
holds=2,
|
||||
total_pnl=15.0,
|
||||
win_rate=66.67,
|
||||
avg_confidence=82.0,
|
||||
scenario_match_rate=80.0,
|
||||
lessons=["Keep position sizing stable"],
|
||||
cross_market_note="US risk-off",
|
||||
)
|
||||
|
||||
reviewer.store_scorecard_in_context(scorecard)
|
||||
|
||||
stored = context_store.get_context(
|
||||
ContextLayer.L6_DAILY,
|
||||
"2026-02-14",
|
||||
"scorecard_KR",
|
||||
)
|
||||
assert stored is not None
|
||||
assert stored["market"] == "KR"
|
||||
assert stored["total_pnl"] == 15.0
|
||||
assert stored["lessons"] == ["Keep position sizing stable"]
|
||||
|
||||
|
||||
def test_store_scorecard_key_is_market_scoped(
|
||||
db_conn: sqlite3.Connection, context_store: ContextStore,
|
||||
) -> None:
|
||||
reviewer = DailyReviewer(db_conn, context_store)
|
||||
kr = DailyScorecard(
|
||||
date="2026-02-14",
|
||||
market="KR",
|
||||
total_decisions=1,
|
||||
buys=1,
|
||||
sells=0,
|
||||
holds=0,
|
||||
total_pnl=1.0,
|
||||
win_rate=100.0,
|
||||
avg_confidence=90.0,
|
||||
scenario_match_rate=100.0,
|
||||
)
|
||||
us = DailyScorecard(
|
||||
date="2026-02-14",
|
||||
market="US",
|
||||
total_decisions=1,
|
||||
buys=0,
|
||||
sells=1,
|
||||
holds=0,
|
||||
total_pnl=-1.0,
|
||||
win_rate=0.0,
|
||||
avg_confidence=70.0,
|
||||
scenario_match_rate=100.0,
|
||||
)
|
||||
|
||||
reviewer.store_scorecard_in_context(kr)
|
||||
reviewer.store_scorecard_in_context(us)
|
||||
|
||||
kr_ctx = context_store.get_context(ContextLayer.L6_DAILY, "2026-02-14", "scorecard_KR")
|
||||
us_ctx = context_store.get_context(ContextLayer.L6_DAILY, "2026-02-14", "scorecard_US")
|
||||
|
||||
assert kr_ctx["market"] == "KR"
|
||||
assert us_ctx["market"] == "US"
|
||||
assert kr_ctx["total_pnl"] == 1.0
|
||||
assert us_ctx["total_pnl"] == -1.0
|
||||
|
||||
|
||||
def test_generate_scorecard_handles_invalid_context_snapshot(
|
||||
db_conn: sqlite3.Connection, context_store: ContextStore,
|
||||
) -> None:
|
||||
reviewer = DailyReviewer(db_conn, context_store)
|
||||
db_conn.execute(
|
||||
"""
|
||||
INSERT INTO decision_logs (
|
||||
decision_id, timestamp, stock_code, market, exchange_code,
|
||||
action, confidence, rationale, context_snapshot, input_data
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
"d1",
|
||||
"2026-02-14T09:00:00+00:00",
|
||||
"005930",
|
||||
"KR",
|
||||
"KRX",
|
||||
"HOLD",
|
||||
50,
|
||||
"test",
|
||||
"{invalid_json",
|
||||
json.dumps({}),
|
||||
),
|
||||
)
|
||||
db_conn.commit()
|
||||
|
||||
scorecard = reviewer.generate_scorecard("2026-02-14", "KR")
|
||||
assert scorecard.total_decisions == 1
|
||||
assert scorecard.scenario_match_rate == 0.0
|
||||
442
tests/test_dashboard.py
Normal file
@@ -0,0 +1,442 @@
|
||||
"""Tests for dashboard endpoint handlers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from collections.abc import Callable
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from src.dashboard.app import create_dashboard_app
|
||||
from src.db import init_db
|
||||
|
||||
|
||||
def _seed_db(conn: sqlite3.Connection) -> None:
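"""Seed the playbook, context, decision, and trade tables with a small KR / US_NASDAQ fixture set."""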
|
||||
today = datetime.now(UTC).date().isoformat()
|
||||
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO playbooks (
|
||||
date, market, status, playbook_json, generated_at,
|
||||
token_count, scenario_count, match_count
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
"2026-02-14",
|
||||
"KR",
|
||||
"ready",
|
||||
json.dumps({"market": "KR", "stock_playbooks": []}),
|
||||
"2026-02-14T08:30:00+00:00",
|
||||
123,
|
||||
2,
|
||||
1,
|
||||
),
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO playbooks (
|
||||
date, market, status, playbook_json, generated_at,
|
||||
token_count, scenario_count, match_count
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
today,
|
||||
"US_NASDAQ",
|
||||
"ready",
|
||||
json.dumps({"market": "US_NASDAQ", "stock_playbooks": []}),
|
||||
f"{today}T08:30:00+00:00",
|
||||
100,
|
||||
1,
|
||||
0,
|
||||
),
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO contexts (layer, timeframe, key, value, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
"L6_DAILY",
|
||||
"2026-02-14",
|
||||
"scorecard_KR",
|
||||
json.dumps({"market": "KR", "total_pnl": 1.5, "win_rate": 60.0}),
|
||||
"2026-02-14T15:30:00+00:00",
|
||||
"2026-02-14T15:30:00+00:00",
|
||||
),
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO contexts (layer, timeframe, key, value, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
"L7_REALTIME",
|
||||
"2026-02-14T10:00:00+00:00",
|
||||
"volatility_KR_005930",
|
||||
json.dumps({"momentum_score": 70.0}),
|
||||
"2026-02-14T10:00:00+00:00",
|
||||
"2026-02-14T10:00:00+00:00",
|
||||
),
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO decision_logs (
|
||||
decision_id, timestamp, stock_code, market, exchange_code,
|
||||
action, confidence, rationale, context_snapshot, input_data
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
"d-kr-1",
|
||||
f"{today}T09:10:00+00:00",
|
||||
"005930",
|
||||
"KR",
|
||||
"KRX",
|
||||
"BUY",
|
||||
85,
|
||||
"signal matched",
|
||||
json.dumps({"scenario_match": {"rsi": 28.0}}),
|
||||
json.dumps({"current_price": 70000}),
|
||||
),
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO decision_logs (
|
||||
decision_id, timestamp, stock_code, market, exchange_code,
|
||||
action, confidence, rationale, context_snapshot, input_data
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
"d-us-1",
|
||||
f"{today}T21:10:00+00:00",
|
||||
"AAPL",
|
||||
"US_NASDAQ",
|
||||
"NASDAQ",
|
||||
"SELL",
|
||||
80,
|
||||
"no match",
|
||||
json.dumps({"scenario_match": {}}),
|
||||
json.dumps({"current_price": 200}),
|
||||
),
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO trades (
|
||||
timestamp, stock_code, action, confidence, rationale,
|
||||
quantity, price, pnl, market, exchange_code, selection_context, decision_id
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
f"{today}T09:11:00+00:00",
|
||||
"005930",
|
||||
"BUY",
|
||||
85,
|
||||
"buy",
|
||||
1,
|
||||
70000,
|
||||
2.0,
|
||||
"KR",
|
||||
"KRX",
|
||||
None,
|
||||
"d-kr-1",
|
||||
),
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO trades (
|
||||
timestamp, stock_code, action, confidence, rationale,
|
||||
quantity, price, pnl, market, exchange_code, selection_context, decision_id
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
f"{today}T21:11:00+00:00",
|
||||
"AAPL",
|
||||
"SELL",
|
||||
80,
|
||||
"sell",
|
||||
1,
|
||||
200,
|
||||
-1.0,
|
||||
"US_NASDAQ",
|
||||
"NASDAQ",
|
||||
None,
|
||||
"d-us-1",
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def _app(tmp_path: Path) -> Any:
|
||||
db_path = tmp_path / "dashboard_test.db"
|
||||
conn = init_db(str(db_path))
|
||||
_seed_db(conn)
|
||||
conn.close()
|
||||
return create_dashboard_app(str(db_path))
|
||||
|
||||
|
||||
def _endpoint(app: Any, path: str) -> Callable[..., Any]:
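"""Look up a route's handler by path so tests can invoke it directly, without an HTTP client."""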
|
||||
for route in app.routes:
|
||||
if getattr(route, "path", None) == path:
|
||||
return route.endpoint
|
||||
raise AssertionError(f"route not found: {path}")
|
||||
|
||||
|
||||
def test_index_serves_html(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
index = _endpoint(app, "/")
|
||||
resp = index()
|
||||
assert isinstance(resp, FileResponse)
|
||||
assert "index.html" in str(resp.path)
|
||||
|
||||
|
||||
def test_status_endpoint(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_status = _endpoint(app, "/api/status")
|
||||
body = get_status()
|
||||
assert "KR" in body["markets"]
|
||||
assert "US_NASDAQ" in body["markets"]
|
||||
assert "totals" in body
|
||||
|
||||
|
||||
def test_playbook_found(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_playbook = _endpoint(app, "/api/playbook/{date_str}")
|
||||
body = get_playbook("2026-02-14", market="KR")
|
||||
assert body["market"] == "KR"
|
||||
|
||||
|
||||
def test_playbook_not_found(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_playbook = _endpoint(app, "/api/playbook/{date_str}")
|
||||
with pytest.raises(HTTPException, match="playbook not found"):
|
||||
get_playbook("2026-02-15", market="KR")
|
||||
|
||||
|
||||
def test_scorecard_found(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_scorecard = _endpoint(app, "/api/scorecard/{date_str}")
|
||||
body = get_scorecard("2026-02-14", market="KR")
|
||||
assert body["scorecard"]["total_pnl"] == 1.5
|
||||
|
||||
|
||||
def test_scorecard_not_found(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_scorecard = _endpoint(app, "/api/scorecard/{date_str}")
|
||||
with pytest.raises(HTTPException, match="scorecard not found"):
|
||||
get_scorecard("2026-02-15", market="KR")
|
||||
|
||||
|
||||
def test_performance_all(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_performance = _endpoint(app, "/api/performance")
|
||||
body = get_performance(market="all")
|
||||
assert body["market"] == "all"
|
||||
assert body["combined"]["total_trades"] == 2
|
||||
assert len(body["by_market"]) == 2
|
||||
|
||||
|
||||
def test_performance_market_filter(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_performance = _endpoint(app, "/api/performance")
|
||||
body = get_performance(market="KR")
|
||||
assert body["market"] == "KR"
|
||||
assert body["metrics"]["total_trades"] == 1
|
||||
|
||||
|
||||
def test_performance_empty_market(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_performance = _endpoint(app, "/api/performance")
|
||||
body = get_performance(market="JP")
|
||||
assert body["metrics"]["total_trades"] == 0
|
||||
|
||||
|
||||
def test_context_layer_all(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_context_layer = _endpoint(app, "/api/context/{layer}")
|
||||
body = get_context_layer("L7_REALTIME", timeframe=None, limit=100)
|
||||
assert body["layer"] == "L7_REALTIME"
|
||||
assert body["count"] == 1
|
||||
|
||||
|
||||
def test_context_layer_timeframe_filter(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_context_layer = _endpoint(app, "/api/context/{layer}")
|
||||
body = get_context_layer("L6_DAILY", timeframe="2026-02-14", limit=100)
|
||||
assert body["count"] == 1
|
||||
assert body["entries"][0]["key"] == "scorecard_KR"
|
||||
|
||||
|
||||
def test_decisions_endpoint(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_decisions = _endpoint(app, "/api/decisions")
|
||||
body = get_decisions(market="KR", limit=50)
|
||||
assert body["count"] == 1
|
||||
assert body["decisions"][0]["decision_id"] == "d-kr-1"
|
||||
|
||||
|
||||
def test_scenarios_active_filters_non_matched(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_active_scenarios = _endpoint(app, "/api/scenarios/active")
|
||||
body = get_active_scenarios(
|
||||
market="KR",
|
||||
date_str=datetime.now(UTC).date().isoformat(),
|
||||
limit=50,
|
||||
)
|
||||
assert body["count"] == 1
|
||||
assert body["matches"][0]["stock_code"] == "005930"
|
||||
|
||||
|
||||
def test_scenarios_active_empty_when_no_matches(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_active_scenarios = _endpoint(app, "/api/scenarios/active")
|
||||
body = get_active_scenarios(market="US", date_str="2026-02-14", limit=50)
|
||||
assert body["count"] == 0
|
||||
|
||||
|
||||
def test_pnl_history_all_markets(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_pnl_history = _endpoint(app, "/api/pnl/history")
|
||||
body = get_pnl_history(days=30, market="all")
|
||||
assert body["market"] == "all"
|
||||
assert isinstance(body["labels"], list)
|
||||
assert isinstance(body["pnl"], list)
|
||||
assert len(body["labels"]) == len(body["pnl"])
|
||||
|
||||
|
||||
def test_pnl_history_market_filter(tmp_path: Path) -> None:
|
||||
app = _app(tmp_path)
|
||||
get_pnl_history = _endpoint(app, "/api/pnl/history")
|
||||
body = get_pnl_history(days=30, market="KR")
|
||||
assert body["market"] == "KR"
|
||||
# KR has 1 trade with pnl=2.0
|
||||
assert len(body["labels"]) >= 1
|
||||
assert body["pnl"][0] == 2.0
|
||||
|
||||
|
||||
def test_positions_returns_open_buy(tmp_path: Path) -> None:
|
||||
"""BUY가 마지막 거래인 종목은 포지션으로 반환되어야 한다."""
|
||||
app = _app(tmp_path)
|
||||
get_positions = _endpoint(app, "/api/positions")
|
||||
body = get_positions()
|
||||
# seed_db: 005930 ends on a BUY (open position), AAPL ends on a SELL (closed)
|
||||
assert body["count"] == 1
|
||||
pos = body["positions"][0]
|
||||
assert pos["stock_code"] == "005930"
|
||||
assert pos["market"] == "KR"
|
||||
assert pos["quantity"] == 1
|
||||
assert pos["entry_price"] == 70000
|
||||
|
||||
|
||||
def test_positions_excludes_closed_sell(tmp_path: Path) -> None:
|
||||
"""마지막 거래가 SELL인 종목은 포지션에 나타나지 않아야 한다."""
|
||||
app = _app(tmp_path)
|
||||
get_positions = _endpoint(app, "/api/positions")
|
||||
body = get_positions()
|
||||
codes = [p["stock_code"] for p in body["positions"]]
|
||||
assert "AAPL" not in codes
|
||||
|
||||
|
||||
def test_positions_empty_when_no_trades(tmp_path: Path) -> None:
|
||||
"""거래 내역이 없으면 빈 포지션 목록을 반환해야 한다."""
|
||||
db_path = tmp_path / "empty.db"
|
||||
conn = init_db(str(db_path))
|
||||
conn.close()
|
||||
app = create_dashboard_app(str(db_path))
|
||||
get_positions = _endpoint(app, "/api/positions")
|
||||
body = get_positions()
|
||||
assert body["count"] == 0
|
||||
assert body["positions"] == []
|
||||
|
||||
|
||||
def _seed_cb_context(conn: sqlite3.Connection, pnl_pct: float, market: str = "KR") -> None:
|
||||
import json as _json
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO system_metrics (key, value, updated_at) VALUES (?, ?, ?)",
|
||||
(
|
||||
f"portfolio_pnl_pct_{market}",
|
||||
_json.dumps({"pnl_pct": pnl_pct}),
|
||||
"2026-02-22T10:00:00+00:00",
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def test_status_circuit_breaker_ok(tmp_path: Path) -> None:
|
||||
"""pnl_pct가 -2.0%보다 높으면 status=ok를 반환해야 한다."""
|
||||
db_path = tmp_path / "cb_ok.db"
|
||||
conn = init_db(str(db_path))
|
||||
_seed_cb_context(conn, -1.0)
|
||||
conn.close()
|
||||
app = create_dashboard_app(str(db_path))
|
||||
get_status = _endpoint(app, "/api/status")
|
||||
body = get_status()
|
||||
cb = body["circuit_breaker"]
|
||||
assert cb["status"] == "ok"
|
||||
assert cb["current_pnl_pct"] == -1.0
|
||||
assert cb["threshold_pct"] == -3.0
|
||||
|
||||
|
||||
def test_status_circuit_breaker_warning(tmp_path: Path) -> None:
|
||||
"""pnl_pct가 -2.0% 이하이면 status=warning을 반환해야 한다."""
|
||||
db_path = tmp_path / "cb_warn.db"
|
||||
conn = init_db(str(db_path))
|
||||
_seed_cb_context(conn, -2.5)
|
||||
conn.close()
|
||||
app = create_dashboard_app(str(db_path))
|
||||
get_status = _endpoint(app, "/api/status")
|
||||
body = get_status()
|
||||
assert body["circuit_breaker"]["status"] == "warning"
|
||||
|
||||
|
||||
def test_status_circuit_breaker_tripped(tmp_path: Path) -> None:
|
||||
"""pnl_pct가 임계값(-3.0%) 이하이면 status=tripped를 반환해야 한다."""
|
||||
db_path = tmp_path / "cb_tripped.db"
|
||||
conn = init_db(str(db_path))
|
||||
_seed_cb_context(conn, -3.5)
|
||||
conn.close()
|
||||
app = create_dashboard_app(str(db_path))
|
||||
get_status = _endpoint(app, "/api/status")
|
||||
body = get_status()
|
||||
assert body["circuit_breaker"]["status"] == "tripped"
|
||||
|
||||
|
||||
def test_status_circuit_breaker_unknown_when_no_data(tmp_path: Path) -> None:
|
||||
"""L7 context에 pnl_pct 데이터가 없으면 status=unknown을 반환해야 한다."""
|
||||
app = _app(tmp_path)  # seed_db does not set portfolio_pnl_pct
|
||||
get_status = _endpoint(app, "/api/status")
|
||||
body = get_status()
|
||||
cb = body["circuit_breaker"]
|
||||
assert cb["status"] == "unknown"
|
||||
assert cb["current_pnl_pct"] is None
|
||||
|
||||
|
||||
def test_status_mode_paper(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""MODE=paper일 때 status 응답에 mode=paper가 포함돼야 한다."""
|
||||
monkeypatch.setenv("MODE", "paper")
|
||||
app = _app(tmp_path)
|
||||
get_status = _endpoint(app, "/api/status")
|
||||
body = get_status()
|
||||
assert body["mode"] == "paper"
|
||||
|
||||
|
||||
def test_status_mode_live(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""MODE=live일 때 status 응답에 mode=live가 포함돼야 한다."""
|
||||
monkeypatch.setenv("MODE", "live")
|
||||
app = _app(tmp_path)
|
||||
get_status = _endpoint(app, "/api/status")
|
||||
body = get_status()
|
||||
assert body["mode"] == "live"
|
||||
|
||||
|
||||
def test_status_mode_default_paper(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""MODE 환경변수가 없으면 mode 기본값은 paper여야 한다."""
|
||||
monkeypatch.delenv("MODE", raising=False)
|
||||
app = _app(tmp_path)
|
||||
get_status = _endpoint(app, "/api/status")
|
||||
body = get_status()
|
||||
assert body["mode"] == "paper"
|
||||
tests/test_data_integration.py (new file, 673 lines)
@@ -0,0 +1,673 @@
|
||||
"""Tests for external data integration (news, economic calendar, market data)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.brain.gemini_client import GeminiClient
|
||||
from src.data.economic_calendar import EconomicCalendar, EconomicEvent
|
||||
from src.data.market_data import MarketBreadth, MarketData, MarketSentiment
|
||||
from src.data.news_api import NewsAPI, NewsArticle, NewsSentiment
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# NewsAPI Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNewsAPI:
|
||||
"""Test news API integration with caching."""
|
||||
|
||||
def test_news_api_init_without_key(self):
|
||||
"""NewsAPI should initialize without API key for testing."""
|
||||
api = NewsAPI(api_key=None)
|
||||
assert api._api_key is None
|
||||
assert api._provider == "alphavantage"
|
||||
assert api._cache_ttl == 300
|
||||
|
||||
def test_news_api_init_with_custom_settings(self):
|
||||
"""NewsAPI should accept custom provider and cache TTL."""
|
||||
api = NewsAPI(api_key="test_key", provider="newsapi", cache_ttl=600)
|
||||
assert api._api_key == "test_key"
|
||||
assert api._provider == "newsapi"
|
||||
assert api._cache_ttl == 600
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_news_sentiment_without_api_key_returns_none(self):
|
||||
"""Without API key, get_news_sentiment should return None."""
|
||||
api = NewsAPI(api_key=None)
|
||||
result = await api.get_news_sentiment("AAPL")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_hit_returns_cached_sentiment(self):
|
||||
"""Cache hit should return cached sentiment without API call."""
|
||||
api = NewsAPI(api_key="test_key")
|
||||
|
||||
# Manually populate cache
|
||||
cached_sentiment = NewsSentiment(
|
||||
stock_code="AAPL",
|
||||
articles=[],
|
||||
avg_sentiment=0.5,
|
||||
article_count=0,
|
||||
fetched_at=time.time(),
|
||||
)
|
||||
api._cache["AAPL"] = cached_sentiment
|
||||
|
||||
result = await api.get_news_sentiment("AAPL")
|
||||
assert result is cached_sentiment
|
||||
assert result.stock_code == "AAPL"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_expiry_triggers_refetch(self):
|
||||
"""Expired cache entry should trigger refetch."""
|
||||
api = NewsAPI(api_key="test_key", cache_ttl=1)
|
||||
|
||||
# Add expired cache entry
|
||||
expired_sentiment = NewsSentiment(
|
||||
stock_code="AAPL",
|
||||
articles=[],
|
||||
avg_sentiment=0.5,
|
||||
article_count=0,
|
||||
fetched_at=time.time() - 10, # 10 seconds ago
|
||||
)
|
||||
api._cache["AAPL"] = expired_sentiment
|
||||
|
||||
# Mock the fetch to avoid real API call
|
||||
with patch.object(api, "_fetch_news", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = None
|
||||
result = await api.get_news_sentiment("AAPL")
|
||||
|
||||
# Should have attempted refetch since cache expired
|
||||
mock_fetch.assert_called_once_with("AAPL")
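# NOTE (illustrative sketch): the two cache tests above imply a TTL check roughly
# like the helper below. The helper name is an assumption; the real NewsAPI internals
# may differ.
def _cache_get_sketch(cache, key, ttl):
    """Return a cached entry only while it is younger than `ttl` seconds."""
    entry = cache.get(key)
    if entry is not None and (time.time() - entry.fetched_at) < ttl:
        return entry
    return None  # missing or expired -> caller refetches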
|
||||
|
||||
def test_clear_cache(self):
|
||||
"""clear_cache should empty the cache."""
|
||||
api = NewsAPI(api_key="test_key")
|
||||
api._cache["AAPL"] = NewsSentiment(
|
||||
stock_code="AAPL",
|
||||
articles=[],
|
||||
avg_sentiment=0.0,
|
||||
article_count=0,
|
||||
fetched_at=time.time(),
|
||||
)
|
||||
assert len(api._cache) == 1
|
||||
|
||||
api.clear_cache()
|
||||
assert len(api._cache) == 0
|
||||
|
||||
def test_parse_alphavantage_response_with_valid_data(self):
|
||||
"""Should parse Alpha Vantage response correctly."""
|
||||
api = NewsAPI(api_key="test_key", provider="alphavantage")
|
||||
|
||||
mock_response = {
|
||||
"feed": [
|
||||
{
|
||||
"title": "Apple hits new high",
|
||||
"summary": "Apple stock surges to record levels",
|
||||
"source": "Reuters",
|
||||
"time_published": "2026-02-04T10:00:00",
|
||||
"url": "https://example.com/1",
|
||||
"ticker_sentiment": [
|
||||
{"ticker": "AAPL", "ticker_sentiment_score": "0.85"}
|
||||
],
|
||||
"overall_sentiment_score": "0.75",
|
||||
},
|
||||
{
|
||||
"title": "Market volatility rises",
|
||||
"summary": "Tech stocks face headwinds",
|
||||
"source": "Bloomberg",
|
||||
"time_published": "2026-02-04T09:00:00",
|
||||
"url": "https://example.com/2",
|
||||
"ticker_sentiment": [
|
||||
{"ticker": "AAPL", "ticker_sentiment_score": "-0.3"}
|
||||
],
|
||||
"overall_sentiment_score": "-0.2",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
result = api._parse_alphavantage_response("AAPL", mock_response)
|
||||
|
||||
assert result is not None
|
||||
assert result.stock_code == "AAPL"
|
||||
assert result.article_count == 2
|
||||
assert len(result.articles) == 2
|
||||
assert result.articles[0].title == "Apple hits new high"
|
||||
assert result.articles[0].sentiment_score == 0.85
|
||||
assert result.articles[1].sentiment_score == -0.3
|
||||
# Average: (0.85 - 0.3) / 2 = 0.275
|
||||
assert abs(result.avg_sentiment - 0.275) < 0.01
|
||||
|
||||
def test_parse_alphavantage_response_without_feed_returns_none(self):
|
||||
"""Should return None if 'feed' key is missing."""
|
||||
api = NewsAPI(api_key="test_key", provider="alphavantage")
|
||||
result = api._parse_alphavantage_response("AAPL", {})
|
||||
assert result is None
|
||||
|
||||
def test_parse_newsapi_response_with_valid_data(self):
|
||||
"""Should parse NewsAPI.org response correctly."""
|
||||
api = NewsAPI(api_key="test_key", provider="newsapi")
|
||||
|
||||
mock_response = {
|
||||
"status": "ok",
|
||||
"articles": [
|
||||
{
|
||||
"title": "Apple stock surges",
|
||||
"description": "Strong earnings beat expectations",
|
||||
"source": {"name": "TechCrunch"},
|
||||
"publishedAt": "2026-02-04T10:00:00Z",
|
||||
"url": "https://example.com/1",
|
||||
},
|
||||
{
|
||||
"title": "Tech sector faces risks",
|
||||
"description": "Concerns over market downturn",
|
||||
"source": {"name": "CNBC"},
|
||||
"publishedAt": "2026-02-04T09:00:00Z",
|
||||
"url": "https://example.com/2",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
result = api._parse_newsapi_response("AAPL", mock_response)
|
||||
|
||||
assert result is not None
|
||||
assert result.stock_code == "AAPL"
|
||||
assert result.article_count == 2
|
||||
assert len(result.articles) == 2
|
||||
assert result.articles[0].title == "Apple stock surges"
|
||||
assert result.articles[0].source == "TechCrunch"
|
||||
|
||||
def test_estimate_sentiment_from_text_positive(self):
|
||||
"""Should detect positive sentiment from keywords."""
|
||||
api = NewsAPI()
|
||||
text = "Stock price surges with strong profit growth and upgrade"
|
||||
sentiment = api._estimate_sentiment_from_text(text)
|
||||
assert sentiment > 0.5
|
||||
|
||||
def test_estimate_sentiment_from_text_negative(self):
|
||||
"""Should detect negative sentiment from keywords."""
|
||||
api = NewsAPI()
|
||||
text = "Stock plunges on weak earnings, downgrade warning"
|
||||
sentiment = api._estimate_sentiment_from_text(text)
|
||||
assert sentiment < -0.5
|
||||
|
||||
def test_estimate_sentiment_from_text_neutral(self):
|
||||
"""Should return neutral sentiment without keywords."""
|
||||
api = NewsAPI()
|
||||
text = "Company announces quarterly report"
|
||||
sentiment = api._estimate_sentiment_from_text(text)
|
||||
assert abs(sentiment) < 0.1
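# NOTE (illustrative sketch): a minimal keyword-count estimator consistent with the
# three assertions above. The keyword sets and helper name are assumptions; the real
# _estimate_sentiment_from_text may weight terms differently.
_POSITIVE_WORDS = {"surge", "surges", "profit", "growth", "upgrade", "beat", "record"}
_NEGATIVE_WORDS = {"plunge", "plunges", "weak", "downgrade", "warning", "loss", "miss"}

def _estimate_sentiment_sketch(text: str) -> float:
    words = [w.strip(",.") for w in text.lower().split()]
    pos = sum(w in _POSITIVE_WORDS for w in words)
    neg = sum(w in _NEGATIVE_WORDS for w in words)
    total = pos + neg
    return 0.0 if total == 0 else (pos - neg) / total  # value in [-1.0, 1.0]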
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EconomicCalendar Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEconomicCalendar:
|
||||
"""Test economic calendar functionality."""
|
||||
|
||||
def test_economic_calendar_init(self):
|
||||
"""EconomicCalendar should initialize correctly."""
|
||||
calendar = EconomicCalendar(api_key="test_key")
|
||||
assert calendar._api_key == "test_key"
|
||||
assert len(calendar._events) == 0
|
||||
|
||||
def test_add_event(self):
|
||||
"""Should be able to add events to calendar."""
|
||||
calendar = EconomicCalendar()
|
||||
event = EconomicEvent(
|
||||
name="FOMC Meeting",
|
||||
event_type="FOMC",
|
||||
datetime=datetime(2026, 3, 18),
|
||||
impact="HIGH",
|
||||
country="US",
|
||||
description="Interest rate decision",
|
||||
)
|
||||
calendar.add_event(event)
|
||||
assert len(calendar._events) == 1
|
||||
assert calendar._events[0].name == "FOMC Meeting"
|
||||
|
||||
def test_get_upcoming_events_filters_by_timeframe(self):
|
||||
"""Should only return events within specified timeframe."""
|
||||
calendar = EconomicCalendar()
|
||||
|
||||
# Add events at different times
|
||||
now = datetime.now()
|
||||
calendar.add_event(
|
||||
EconomicEvent(
|
||||
name="Event Tomorrow",
|
||||
event_type="GDP",
|
||||
datetime=now + timedelta(days=1),
|
||||
impact="HIGH",
|
||||
country="US",
|
||||
description="Test event",
|
||||
)
|
||||
)
|
||||
calendar.add_event(
|
||||
EconomicEvent(
|
||||
name="Event Next Month",
|
||||
event_type="CPI",
|
||||
datetime=now + timedelta(days=30),
|
||||
impact="HIGH",
|
||||
country="US",
|
||||
description="Test event",
|
||||
)
|
||||
)
|
||||
|
||||
# Get events for next 7 days
|
||||
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="HIGH")
|
||||
assert upcoming.high_impact_count == 1
|
||||
assert upcoming.events[0].name == "Event Tomorrow"
|
||||
|
||||
def test_get_upcoming_events_filters_by_impact(self):
|
||||
"""Should filter events by minimum impact level."""
|
||||
calendar = EconomicCalendar()
|
||||
|
||||
now = datetime.now()
|
||||
calendar.add_event(
|
||||
EconomicEvent(
|
||||
name="High Impact Event",
|
||||
event_type="FOMC",
|
||||
datetime=now + timedelta(days=1),
|
||||
impact="HIGH",
|
||||
country="US",
|
||||
description="Test",
|
||||
)
|
||||
)
|
||||
calendar.add_event(
|
||||
EconomicEvent(
|
||||
name="Low Impact Event",
|
||||
event_type="OTHER",
|
||||
datetime=now + timedelta(days=1),
|
||||
impact="LOW",
|
||||
country="US",
|
||||
description="Test",
|
||||
)
|
||||
)
|
||||
|
||||
# Filter for HIGH impact only
|
||||
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="HIGH")
|
||||
assert upcoming.high_impact_count == 1
|
||||
assert upcoming.events[0].name == "High Impact Event"
|
||||
|
||||
# Filter for MEDIUM and above (should still get HIGH)
|
||||
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="MEDIUM")
|
||||
assert len(upcoming.events) == 1
|
||||
|
||||
# Filter for LOW and above (should get both)
|
||||
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="LOW")
|
||||
assert len(upcoming.events) == 2
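# NOTE (illustrative): min_impact behaves as an ordered threshold in the test above.
# A minimal sketch of that ordering (the ranking dict and helper name are assumptions):
_IMPACT_RANK = {"LOW": 0, "MEDIUM": 1, "HIGH": 2}

def _passes_impact_sketch(event_impact: str, min_impact: str) -> bool:
    return _IMPACT_RANK[event_impact] >= _IMPACT_RANK[min_impact]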
|
||||
|
||||
def test_get_earnings_date_returns_next_earnings(self):
|
||||
"""Should return next earnings date for a stock."""
|
||||
calendar = EconomicCalendar()
|
||||
|
||||
now = datetime.now()
|
||||
earnings_date = now + timedelta(days=5)
|
||||
|
||||
calendar.add_event(
|
||||
EconomicEvent(
|
||||
name="AAPL Earnings",
|
||||
event_type="EARNINGS",
|
||||
datetime=earnings_date,
|
||||
impact="HIGH",
|
||||
country="US",
|
||||
description="Apple quarterly earnings",
|
||||
)
|
||||
)
|
||||
|
||||
result = calendar.get_earnings_date("AAPL")
|
||||
assert result == earnings_date
|
||||
|
||||
def test_get_earnings_date_returns_none_if_not_found(self):
|
||||
"""Should return None if no earnings found for stock."""
|
||||
calendar = EconomicCalendar()
|
||||
result = calendar.get_earnings_date("UNKNOWN")
|
||||
assert result is None
|
||||
|
||||
def test_load_hardcoded_events(self):
|
||||
"""Should load hardcoded major economic events."""
|
||||
calendar = EconomicCalendar()
|
||||
calendar.load_hardcoded_events()
|
||||
|
||||
# Should have multiple events (FOMC, GDP, CPI)
|
||||
assert len(calendar._events) > 10
|
||||
|
||||
# Check for FOMC events
|
||||
fomc_events = [e for e in calendar._events if e.event_type == "FOMC"]
|
||||
assert len(fomc_events) > 0
|
||||
|
||||
# Check for GDP events
|
||||
gdp_events = [e for e in calendar._events if e.event_type == "GDP"]
|
||||
assert len(gdp_events) > 0
|
||||
|
||||
# Check for CPI events
|
||||
cpi_events = [e for e in calendar._events if e.event_type == "CPI"]
|
||||
assert len(cpi_events) == 12 # Monthly CPI releases
|
||||
|
||||
def test_is_high_volatility_period_returns_true_near_high_impact(self):
|
||||
"""Should return True if high-impact event is within threshold."""
|
||||
calendar = EconomicCalendar()
|
||||
|
||||
now = datetime.now()
|
||||
calendar.add_event(
|
||||
EconomicEvent(
|
||||
name="FOMC Meeting",
|
||||
event_type="FOMC",
|
||||
datetime=now + timedelta(hours=12),
|
||||
impact="HIGH",
|
||||
country="US",
|
||||
description="Test",
|
||||
)
|
||||
)
|
||||
|
||||
assert calendar.is_high_volatility_period(hours_ahead=24) is True
|
||||
|
||||
def test_is_high_volatility_period_returns_false_when_no_events(self):
|
||||
"""Should return False if no high-impact events nearby."""
|
||||
calendar = EconomicCalendar()
|
||||
assert calendar.is_high_volatility_period(hours_ahead=24) is False
|
||||
|
||||
def test_clear_events(self):
|
||||
"""Should clear all events."""
|
||||
calendar = EconomicCalendar()
|
||||
calendar.add_event(
|
||||
EconomicEvent(
|
||||
name="Test",
|
||||
event_type="TEST",
|
||||
datetime=datetime.now(),
|
||||
impact="LOW",
|
||||
country="US",
|
||||
description="Test",
|
||||
)
|
||||
)
|
||||
assert len(calendar._events) == 1
|
||||
|
||||
calendar.clear_events()
|
||||
assert len(calendar._events) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MarketData Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMarketData:
|
||||
"""Test market data indicators."""
|
||||
|
||||
def test_market_data_init(self):
|
||||
"""MarketData should initialize correctly."""
|
||||
data = MarketData(api_key="test_key")
|
||||
assert data._api_key == "test_key"
|
||||
|
||||
def test_get_market_sentiment_without_api_key_returns_neutral(self):
|
||||
"""Without API key, should return NEUTRAL sentiment."""
|
||||
data = MarketData(api_key=None)
|
||||
sentiment = data.get_market_sentiment()
|
||||
assert sentiment == MarketSentiment.NEUTRAL
|
||||
|
||||
def test_get_market_breadth_without_api_key_returns_none(self):
|
||||
"""Without API key, should return None for breadth."""
|
||||
data = MarketData(api_key=None)
|
||||
breadth = data.get_market_breadth()
|
||||
assert breadth is None
|
||||
|
||||
def test_get_sector_performance_without_api_key_returns_empty(self):
|
||||
"""Without API key, should return empty list."""
|
||||
data = MarketData(api_key=None)
|
||||
sectors = data.get_sector_performance()
|
||||
assert sectors == []
|
||||
|
||||
def test_get_market_indicators_returns_defaults_without_api(self):
|
||||
"""Should return default indicators without API key."""
|
||||
data = MarketData(api_key=None)
|
||||
indicators = data.get_market_indicators()
|
||||
|
||||
assert indicators.sentiment == MarketSentiment.NEUTRAL
|
||||
assert indicators.breadth.advance_decline_ratio == 1.0
|
||||
assert indicators.sector_performance == []
|
||||
assert indicators.vix_level is None
|
||||
|
||||
def test_calculate_fear_greed_score_neutral_baseline(self):
|
||||
"""Should return neutral score (50) for balanced market."""
|
||||
data = MarketData()
|
||||
breadth = MarketBreadth(
|
||||
advancing_stocks=500,
|
||||
declining_stocks=500,
|
||||
unchanged_stocks=100,
|
||||
new_highs=50,
|
||||
new_lows=50,
|
||||
advance_decline_ratio=1.0,
|
||||
)
|
||||
|
||||
score = data.calculate_fear_greed_score(breadth)
|
||||
assert score == 50
|
||||
|
||||
def test_calculate_fear_greed_score_greedy_market(self):
|
||||
"""Should return high score for greedy market conditions."""
|
||||
data = MarketData()
|
||||
breadth = MarketBreadth(
|
||||
advancing_stocks=800,
|
||||
declining_stocks=200,
|
||||
unchanged_stocks=100,
|
||||
new_highs=100,
|
||||
new_lows=10,
|
||||
advance_decline_ratio=4.0,
|
||||
)
|
||||
|
||||
score = data.calculate_fear_greed_score(breadth, vix=12.0)
|
||||
assert score > 70
|
||||
|
||||
def test_calculate_fear_greed_score_fearful_market(self):
|
||||
"""Should return low score for fearful market conditions."""
|
||||
data = MarketData()
|
||||
breadth = MarketBreadth(
|
||||
advancing_stocks=200,
|
||||
declining_stocks=800,
|
||||
unchanged_stocks=100,
|
||||
new_highs=10,
|
||||
new_lows=100,
|
||||
advance_decline_ratio=0.25,
|
||||
)
|
||||
|
||||
score = data.calculate_fear_greed_score(breadth, vix=35.0)
|
||||
assert score < 30
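# NOTE (illustrative sketch): calculate_fear_greed_score appears to start from a
# neutral 50 and shift the score with market breadth and, when provided, VIX. The
# sketch below reproduces the three assertions above, but the weights are assumptions,
# not the actual MarketData implementation.
def _fear_greed_sketch(ad_ratio: float, vix: float | None = None) -> int:
    score = 50.0
    score += min(max((ad_ratio - 1.0) * 10.0, -25.0), 25.0)  # breadth component
    if vix is not None:
        score += min(max((20.0 - vix) * 1.0, -25.0), 25.0)   # low VIX reads as greed
    return int(min(max(score, 0.0), 100.0))
# 1.0 -> 50 (neutral), 4.0 with VIX 12 -> ~83 (greed), 0.25 with VIX 35 -> ~27 (fear)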
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GeminiClient Integration Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGeminiClientWithExternalData:
|
||||
"""Test GeminiClient integration with external data sources."""
|
||||
|
||||
def test_gemini_client_accepts_optional_data_sources(self, settings):
|
||||
"""GeminiClient should accept optional external data sources."""
|
||||
news_api = NewsAPI(api_key="test_key")
|
||||
calendar = EconomicCalendar()
|
||||
market_data = MarketData()
|
||||
|
||||
client = GeminiClient(
|
||||
settings,
|
||||
news_api=news_api,
|
||||
economic_calendar=calendar,
|
||||
market_data=market_data,
|
||||
)
|
||||
|
||||
assert client._news_api is news_api
|
||||
assert client._economic_calendar is calendar
|
||||
assert client._market_data is market_data
|
||||
|
||||
def test_gemini_client_works_without_external_data(self, settings):
|
||||
"""GeminiClient should work without external data sources."""
|
||||
client = GeminiClient(settings)
|
||||
assert client._news_api is None
|
||||
assert client._economic_calendar is None
|
||||
assert client._market_data is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_prompt_includes_news_sentiment(self, settings):
|
||||
"""build_prompt should include news sentiment when available."""
|
||||
client = GeminiClient(settings)
|
||||
|
||||
market_data = {
|
||||
"stock_code": "AAPL",
|
||||
"current_price": 180.0,
|
||||
"market_name": "US stock market",
|
||||
}
|
||||
|
||||
sentiment = NewsSentiment(
|
||||
stock_code="AAPL",
|
||||
articles=[
|
||||
NewsArticle(
|
||||
title="Apple hits record high",
|
||||
summary="Strong earnings",
|
||||
source="Reuters",
|
||||
published_at="2026-02-04",
|
||||
sentiment_score=0.85,
|
||||
url="https://example.com",
|
||||
)
|
||||
],
|
||||
avg_sentiment=0.85,
|
||||
article_count=1,
|
||||
fetched_at=time.time(),
|
||||
)
|
||||
|
||||
prompt = await client.build_prompt(market_data, news_sentiment=sentiment)
|
||||
|
||||
assert "AAPL" in prompt
|
||||
assert "180.0" in prompt
|
||||
assert "EXTERNAL DATA" in prompt
|
||||
assert "News Sentiment" in prompt
|
||||
assert "0.85" in prompt
|
||||
assert "Apple hits record high" in prompt
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_prompt_with_economic_events(self, settings):
|
||||
"""build_prompt should include upcoming economic events."""
|
||||
calendar = EconomicCalendar()
|
||||
now = datetime.now()
|
||||
calendar.add_event(
|
||||
EconomicEvent(
|
||||
name="FOMC Meeting",
|
||||
event_type="FOMC",
|
||||
datetime=now + timedelta(days=2),
|
||||
impact="HIGH",
|
||||
country="US",
|
||||
description="Interest rate decision",
|
||||
)
|
||||
)
|
||||
|
||||
client = GeminiClient(settings, economic_calendar=calendar)
|
||||
|
||||
market_data = {
|
||||
"stock_code": "AAPL",
|
||||
"current_price": 180.0,
|
||||
"market_name": "US stock market",
|
||||
}
|
||||
|
||||
prompt = await client.build_prompt(market_data)
|
||||
|
||||
assert "EXTERNAL DATA" in prompt
|
||||
assert "High-Impact Events" in prompt
|
||||
assert "FOMC Meeting" in prompt
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_prompt_with_market_indicators(self, settings):
|
||||
"""build_prompt should include market sentiment indicators."""
|
||||
market_data_provider = MarketData(api_key="test_key")
|
||||
|
||||
# Mock the get_market_indicators to return test data
|
||||
with patch.object(market_data_provider, "get_market_indicators") as mock:
|
||||
mock.return_value = MagicMock(
|
||||
sentiment=MarketSentiment.EXTREME_GREED,
|
||||
breadth=MagicMock(advance_decline_ratio=2.5),
|
||||
)
|
||||
|
||||
client = GeminiClient(settings, market_data=market_data_provider)
|
||||
|
||||
market_data = {
|
||||
"stock_code": "AAPL",
|
||||
"current_price": 180.0,
|
||||
"market_name": "US stock market",
|
||||
}
|
||||
|
||||
prompt = await client.build_prompt(market_data)
|
||||
|
||||
assert "EXTERNAL DATA" in prompt
|
||||
assert "Market Sentiment" in prompt
|
||||
assert "EXTREME_GREED" in prompt
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_prompt_graceful_when_no_external_data(self, settings):
|
||||
"""build_prompt should work gracefully without external data."""
|
||||
client = GeminiClient(settings)
|
||||
|
||||
market_data = {
|
||||
"stock_code": "AAPL",
|
||||
"current_price": 180.0,
|
||||
"market_name": "US stock market",
|
||||
}
|
||||
|
||||
prompt = await client.build_prompt(market_data)
|
||||
|
||||
assert "AAPL" in prompt
|
||||
assert "180.0" in prompt
|
||||
# Should NOT have external data section
|
||||
assert "EXTERNAL DATA" not in prompt
|
||||
|
||||
def test_build_prompt_sync_backward_compatibility(self, settings):
|
||||
"""build_prompt_sync should maintain backward compatibility."""
|
||||
client = GeminiClient(settings)
|
||||
|
||||
market_data = {
|
||||
"stock_code": "005930",
|
||||
"current_price": 72000,
|
||||
"orderbook": {"asks": [], "bids": []},
|
||||
"foreigner_net": -50000,
|
||||
}
|
||||
|
||||
prompt = client.build_prompt_sync(market_data)
|
||||
|
||||
assert "005930" in prompt
|
||||
assert "72000" in prompt
|
||||
assert "JSON" in prompt
|
||||
# Sync version should NOT have external data
|
||||
assert "EXTERNAL DATA" not in prompt
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decide_with_news_sentiment_parameter(self, settings):
|
||||
"""decide should accept optional news_sentiment parameter."""
|
||||
client = GeminiClient(settings)
|
||||
|
||||
market_data = {
|
||||
"stock_code": "AAPL",
|
||||
"current_price": 180.0,
|
||||
"market_name": "US stock market",
|
||||
}
|
||||
|
||||
sentiment = NewsSentiment(
|
||||
stock_code="AAPL",
|
||||
articles=[],
|
||||
avg_sentiment=0.5,
|
||||
article_count=1,
|
||||
fetched_at=time.time(),
|
||||
)
|
||||
|
||||
# Mock the Gemini API call
|
||||
with patch.object(client._client.aio.models, "generate_content", new_callable=AsyncMock) as mock_gen:
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = '{"action": "BUY", "confidence": 85, "rationale": "Good news"}'
|
||||
mock_gen.return_value = mock_response
|
||||
|
||||
decision = await client.decide(market_data, news_sentiment=sentiment)
|
||||
|
||||
assert decision.action == "BUY"
|
||||
assert decision.confidence == 85
|
||||
mock_gen.assert_called_once()
|
||||
tests/test_db.py (new file, 195 lines)
@@ -0,0 +1,195 @@
|
||||
"""Tests for database helper functions."""
|
||||
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
from src.db import get_open_position, init_db, log_trade
|
||||
|
||||
|
||||
def test_get_open_position_returns_latest_buy() -> None:
|
||||
conn = init_db(":memory:")
|
||||
log_trade(
|
||||
conn=conn,
|
||||
stock_code="005930",
|
||||
action="BUY",
|
||||
confidence=90,
|
||||
rationale="entry",
|
||||
quantity=2,
|
||||
price=70000.0,
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
decision_id="d-buy-1",
|
||||
)
|
||||
|
||||
position = get_open_position(conn, "005930", "KR")
|
||||
assert position is not None
|
||||
assert position["decision_id"] == "d-buy-1"
|
||||
assert position["price"] == 70000.0
|
||||
assert position["quantity"] == 2
|
||||
|
||||
|
||||
def test_get_open_position_returns_none_when_latest_is_sell() -> None:
|
||||
conn = init_db(":memory:")
|
||||
log_trade(
|
||||
conn=conn,
|
||||
stock_code="005930",
|
||||
action="BUY",
|
||||
confidence=90,
|
||||
rationale="entry",
|
||||
quantity=1,
|
||||
price=70000.0,
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
decision_id="d-buy-1",
|
||||
)
|
||||
log_trade(
|
||||
conn=conn,
|
||||
stock_code="005930",
|
||||
action="SELL",
|
||||
confidence=95,
|
||||
rationale="exit",
|
||||
quantity=1,
|
||||
price=71000.0,
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
decision_id="d-sell-1",
|
||||
)
|
||||
|
||||
assert get_open_position(conn, "005930", "KR") is None
|
||||
|
||||
|
||||
def test_get_open_position_returns_none_when_no_trades() -> None:
|
||||
conn = init_db(":memory:")
|
||||
assert get_open_position(conn, "AAPL", "US_NASDAQ") is None
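# NOTE (illustrative sketch): the three tests above are satisfied by "take the most
# recent trade for (stock_code, market); it is an open position only if that trade is
# a BUY". The query below is an assumption; the real src.db implementation may select
# additional columns.
def _open_position_sketch(conn, stock_code, market):
    row = conn.execute(
        "SELECT action, price, quantity, decision_id FROM trades "
        "WHERE stock_code = ? AND market = ? ORDER BY id DESC LIMIT 1",
        (stock_code, market),
    ).fetchone()
    if row is None or row[0] != "BUY":
        return None
    return {"price": row[1], "quantity": row[2], "decision_id": row[3]}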
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WAL mode tests (issue #210)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_wal_mode_applied_to_file_db() -> None:
|
||||
"""File-based DB must use WAL journal mode for dashboard concurrent reads."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
db_path = f.name
|
||||
try:
|
||||
conn = init_db(db_path)
|
||||
cursor = conn.execute("PRAGMA journal_mode")
|
||||
mode = cursor.fetchone()[0]
|
||||
assert mode == "wal", f"Expected WAL mode, got {mode}"
|
||||
conn.close()
|
||||
finally:
|
||||
os.unlink(db_path)
|
||||
# Clean up WAL auxiliary files if they exist
|
||||
for ext in ("-wal", "-shm"):
|
||||
path = db_path + ext
|
||||
if os.path.exists(path):
|
||||
os.unlink(path)
|
||||
|
||||
|
||||
def test_wal_mode_not_applied_to_memory_db() -> None:
|
||||
""":memory: DB must not apply WAL (SQLite does not support WAL for in-memory)."""
|
||||
conn = init_db(":memory:")
|
||||
cursor = conn.execute("PRAGMA journal_mode")
|
||||
mode = cursor.fetchone()[0]
|
||||
# In-memory DBs default to 'memory' journal mode
|
||||
assert mode != "wal", "WAL should not be set on in-memory database"
|
||||
conn.close()
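# NOTE (illustrative): the WAL tests above are consistent with init_db applying the
# pragma only to file-backed databases, e.g.
#
#     if db_path != ":memory:":
#         conn.execute("PRAGMA journal_mode=WAL")
#
# SQLite does not support WAL for in-memory databases, so even an unconditional pragma
# would leave the journal mode reported as "memory".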
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# mode column tests (issue #212)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_log_trade_stores_mode_paper() -> None:
|
||||
"""log_trade must persist mode='paper' in the trades table."""
|
||||
conn = init_db(":memory:")
|
||||
log_trade(
|
||||
conn=conn,
|
||||
stock_code="005930",
|
||||
action="BUY",
|
||||
confidence=85,
|
||||
rationale="test",
|
||||
mode="paper",
|
||||
)
|
||||
row = conn.execute("SELECT mode FROM trades ORDER BY id DESC LIMIT 1").fetchone()
|
||||
assert row is not None
|
||||
assert row[0] == "paper"
|
||||
|
||||
|
||||
def test_log_trade_stores_mode_live() -> None:
|
||||
"""log_trade must persist mode='live' in the trades table."""
|
||||
conn = init_db(":memory:")
|
||||
log_trade(
|
||||
conn=conn,
|
||||
stock_code="005930",
|
||||
action="BUY",
|
||||
confidence=85,
|
||||
rationale="test",
|
||||
mode="live",
|
||||
)
|
||||
row = conn.execute("SELECT mode FROM trades ORDER BY id DESC LIMIT 1").fetchone()
|
||||
assert row is not None
|
||||
assert row[0] == "live"
|
||||
|
||||
|
||||
def test_log_trade_default_mode_is_paper() -> None:
|
||||
"""log_trade without explicit mode must default to 'paper'."""
|
||||
conn = init_db(":memory:")
|
||||
log_trade(
|
||||
conn=conn,
|
||||
stock_code="005930",
|
||||
action="HOLD",
|
||||
confidence=50,
|
||||
rationale="test",
|
||||
)
|
||||
row = conn.execute("SELECT mode FROM trades ORDER BY id DESC LIMIT 1").fetchone()
|
||||
assert row is not None
|
||||
assert row[0] == "paper"
|
||||
|
||||
|
||||
def test_mode_column_exists_in_schema() -> None:
|
||||
"""trades table must have a mode column after init_db."""
|
||||
conn = init_db(":memory:")
|
||||
cursor = conn.execute("PRAGMA table_info(trades)")
|
||||
columns = {row[1] for row in cursor.fetchall()}
|
||||
assert "mode" in columns
|
||||
|
||||
|
||||
def test_mode_migration_adds_column_to_existing_db() -> None:
|
||||
"""init_db must add mode column to existing DBs that lack it (migration)."""
|
||||
import sqlite3
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
db_path = f.name
|
||||
try:
|
||||
# Create DB without mode column (simulate old schema)
|
||||
old_conn = sqlite3.connect(db_path)
|
||||
old_conn.execute(
|
||||
"""CREATE TABLE trades (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
timestamp TEXT NOT NULL,
|
||||
stock_code TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
confidence INTEGER NOT NULL,
|
||||
rationale TEXT,
|
||||
quantity INTEGER,
|
||||
price REAL,
|
||||
pnl REAL DEFAULT 0.0,
|
||||
market TEXT DEFAULT 'KR',
|
||||
exchange_code TEXT DEFAULT 'KRX',
|
||||
decision_id TEXT
|
||||
)"""
|
||||
)
|
||||
old_conn.commit()
|
||||
old_conn.close()
|
||||
|
||||
# Run init_db — should add mode column via migration
|
||||
conn = init_db(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(trades)")
|
||||
columns = {row[1] for row in cursor.fetchall()}
|
||||
assert "mode" in columns
|
||||
conn.close()
|
||||
finally:
|
||||
os.unlink(db_path)
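# NOTE (illustrative sketch): a typical way to satisfy the migration test above is to
# inspect the existing columns and add the missing one. The helper name is an
# assumption, not the actual src.db code.
def _ensure_mode_column_sketch(conn):
    columns = {row[1] for row in conn.execute("PRAGMA table_info(trades)")}
    if "mode" not in columns:
        conn.execute("ALTER TABLE trades ADD COLUMN mode TEXT DEFAULT 'paper'")
        conn.commit()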
|
||||
tests/test_decision_logger.py (new file, 292 lines)
@@ -0,0 +1,292 @@
|
||||
"""Tests for decision logging and audit trail."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from src.db import init_db
|
||||
from src.logging.decision_logger import DecisionLog, DecisionLogger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_conn() -> sqlite3.Connection:
|
||||
"""Provide an in-memory database with initialized schema."""
|
||||
conn = init_db(":memory:")
|
||||
return conn
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def logger(db_conn: sqlite3.Connection) -> DecisionLogger:
|
||||
"""Provide a DecisionLogger instance."""
|
||||
return DecisionLogger(db_conn)
|
||||
|
||||
|
||||
def test_log_decision_creates_record(logger: DecisionLogger, db_conn: sqlite3.Connection) -> None:
|
||||
"""Test that log_decision creates a database record."""
|
||||
context_snapshot = {
|
||||
"L1": {"quote": {"price": 100.0, "volume": 1000}},
|
||||
"L2": {"orderbook": {"bid": [99.0], "ask": [101.0]}},
|
||||
}
|
||||
input_data = {"price": 100.0, "volume": 1000, "foreigner_net": 500}
|
||||
|
||||
decision_id = logger.log_decision(
|
||||
stock_code="005930",
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
action="BUY",
|
||||
confidence=85,
|
||||
rationale="Strong upward momentum",
|
||||
context_snapshot=context_snapshot,
|
||||
input_data=input_data,
|
||||
)
|
||||
|
||||
# Verify decision_id is a valid UUID
|
||||
assert decision_id is not None
|
||||
assert len(decision_id) == 36  # canonical UUID string length (8-4-4-4-12 with hyphens)
|
||||
|
||||
# Verify record exists in database
|
||||
cursor = db_conn.execute(
|
||||
"SELECT decision_id, action, confidence FROM decision_logs WHERE decision_id = ?",
|
||||
(decision_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row[0] == decision_id
|
||||
assert row[1] == "BUY"
|
||||
assert row[2] == 85
|
||||
|
||||
|
||||
def test_log_decision_stores_context_snapshot(logger: DecisionLogger) -> None:
|
||||
"""Test that context snapshot is stored as JSON."""
|
||||
context_snapshot = {
|
||||
"L1": {"real_time": "data"},
|
||||
"L3": {"daily": "aggregate"},
|
||||
"L7": {"legacy": "wisdom"},
|
||||
}
|
||||
input_data = {"price": 50000.0, "volume": 2000}
|
||||
|
||||
decision_id = logger.log_decision(
|
||||
stock_code="035420",
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
action="HOLD",
|
||||
confidence=75,
|
||||
rationale="Waiting for clearer signal",
|
||||
context_snapshot=context_snapshot,
|
||||
input_data=input_data,
|
||||
)
|
||||
|
||||
# Retrieve and verify context snapshot
|
||||
decision = logger.get_decision_by_id(decision_id)
|
||||
assert decision is not None
|
||||
assert decision.context_snapshot == context_snapshot
|
||||
assert decision.input_data == input_data
|
||||
|
||||
|
||||
def test_get_unreviewed_decisions(logger: DecisionLogger) -> None:
|
||||
"""Test retrieving unreviewed decisions with confidence filter."""
|
||||
# Log multiple decisions with varying confidence
|
||||
logger.log_decision(
|
||||
stock_code="005930",
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
action="BUY",
|
||||
confidence=90,
|
||||
rationale="High confidence buy",
|
||||
context_snapshot={},
|
||||
input_data={},
|
||||
)
|
||||
logger.log_decision(
|
||||
stock_code="000660",
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
action="SELL",
|
||||
confidence=75,
|
||||
rationale="Low confidence sell",
|
||||
context_snapshot={},
|
||||
input_data={},
|
||||
)
|
||||
logger.log_decision(
|
||||
stock_code="035420",
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
action="HOLD",
|
||||
confidence=85,
|
||||
rationale="Medium confidence hold",
|
||||
context_snapshot={},
|
||||
input_data={},
|
||||
)
|
||||
|
||||
# Get unreviewed decisions with default threshold (80)
|
||||
unreviewed = logger.get_unreviewed_decisions()
|
||||
assert len(unreviewed) == 2 # Only confidence >= 80
|
||||
assert all(d.confidence >= 80 for d in unreviewed)
|
||||
assert all(not d.reviewed for d in unreviewed)
|
||||
|
||||
# Get with lower threshold
|
||||
unreviewed_all = logger.get_unreviewed_decisions(min_confidence=70)
|
||||
assert len(unreviewed_all) == 3
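# NOTE (illustrative): the filtering above is consistent with a query along these
# lines (table and column names taken from the other decision_logs tests; the exact
# SQL is an assumption):
#
#     SELECT * FROM decision_logs
#     WHERE reviewed = 0 AND confidence >= :min_confidence
#     ORDER BY timestamp DESC
#     LIMIT :limit
#
# with min_confidence defaulting to 80.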
|
||||
|
||||
|
||||
def test_mark_reviewed(logger: DecisionLogger) -> None:
|
||||
"""Test marking a decision as reviewed."""
|
||||
decision_id = logger.log_decision(
|
||||
stock_code="005930",
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
action="BUY",
|
||||
confidence=85,
|
||||
rationale="Test decision",
|
||||
context_snapshot={},
|
||||
input_data={},
|
||||
)
|
||||
|
||||
# Initially unreviewed
|
||||
decision = logger.get_decision_by_id(decision_id)
|
||||
assert decision is not None
|
||||
assert not decision.reviewed
|
||||
assert decision.review_notes is None
|
||||
|
||||
# Mark as reviewed
|
||||
review_notes = "Good decision, captured bullish momentum correctly"
|
||||
logger.mark_reviewed(decision_id, review_notes)
|
||||
|
||||
# Verify updated
|
||||
decision = logger.get_decision_by_id(decision_id)
|
||||
assert decision is not None
|
||||
assert decision.reviewed
|
||||
assert decision.review_notes == review_notes
|
||||
|
||||
# Should not appear in unreviewed list
|
||||
unreviewed = logger.get_unreviewed_decisions()
|
||||
assert all(d.decision_id != decision_id for d in unreviewed)
|
||||
|
||||
|
||||
def test_update_outcome(logger: DecisionLogger) -> None:
|
||||
"""Test updating decision outcome with P&L and accuracy."""
|
||||
decision_id = logger.log_decision(
|
||||
stock_code="005930",
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
action="BUY",
|
||||
confidence=90,
|
||||
rationale="Expecting price increase",
|
||||
context_snapshot={},
|
||||
input_data={},
|
||||
)
|
||||
|
||||
# Initially no outcome
|
||||
decision = logger.get_decision_by_id(decision_id)
|
||||
assert decision is not None
|
||||
assert decision.outcome_pnl is None
|
||||
assert decision.outcome_accuracy is None
|
||||
|
||||
# Update outcome (profitable trade)
|
||||
logger.update_outcome(decision_id, pnl=5000.0, accuracy=1)
|
||||
|
||||
# Verify updated
|
||||
decision = logger.get_decision_by_id(decision_id)
|
||||
assert decision is not None
|
||||
assert decision.outcome_pnl == 5000.0
|
||||
assert decision.outcome_accuracy == 1
|
||||
|
||||
|
||||
def test_get_losing_decisions(logger: DecisionLogger) -> None:
|
||||
"""Test retrieving high-confidence losing decisions."""
|
||||
# Profitable decision
|
||||
id1 = logger.log_decision(
|
||||
stock_code="005930",
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
action="BUY",
|
||||
confidence=85,
|
||||
rationale="Correct prediction",
|
||||
context_snapshot={},
|
||||
input_data={},
|
||||
)
|
||||
logger.update_outcome(id1, pnl=3000.0, accuracy=1)
|
||||
|
||||
# High-confidence loss
|
||||
id2 = logger.log_decision(
|
||||
stock_code="000660",
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
action="SELL",
|
||||
confidence=90,
|
||||
rationale="Wrong prediction",
|
||||
context_snapshot={},
|
||||
input_data={},
|
||||
)
|
||||
logger.update_outcome(id2, pnl=-2000.0, accuracy=0)
|
||||
|
||||
# Low-confidence loss (should be ignored)
|
||||
id3 = logger.log_decision(
|
||||
stock_code="035420",
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
action="BUY",
|
||||
confidence=70,
|
||||
rationale="Low confidence, wrong",
|
||||
context_snapshot={},
|
||||
input_data={},
|
||||
)
|
||||
logger.update_outcome(id3, pnl=-1500.0, accuracy=0)
|
||||
|
||||
# Get high-confidence losing decisions
|
||||
losers = logger.get_losing_decisions(min_confidence=80, min_loss=-1000.0)
|
||||
assert len(losers) == 1
|
||||
assert losers[0].decision_id == id2
|
||||
assert losers[0].outcome_pnl == -2000.0
|
||||
assert losers[0].confidence == 90
|
||||
|
||||
|
||||
def test_get_decision_by_id_not_found(logger: DecisionLogger) -> None:
|
||||
"""Test that get_decision_by_id returns None for non-existent ID."""
|
||||
decision = logger.get_decision_by_id("non-existent-uuid")
|
||||
assert decision is None
|
||||
|
||||
|
||||
def test_unreviewed_limit(logger: DecisionLogger) -> None:
|
||||
"""Test that get_unreviewed_decisions respects limit parameter."""
|
||||
# Create 5 unreviewed decisions
|
||||
for i in range(5):
|
||||
logger.log_decision(
|
||||
stock_code=f"00{i}",
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
action="HOLD",
|
||||
confidence=85,
|
||||
rationale=f"Decision {i}",
|
||||
context_snapshot={},
|
||||
input_data={},
|
||||
)
|
||||
|
||||
# Get only 3
|
||||
unreviewed = logger.get_unreviewed_decisions(limit=3)
|
||||
assert len(unreviewed) == 3
|
||||
|
||||
|
||||
def test_decision_log_dataclass() -> None:
|
||||
"""Test DecisionLog dataclass creation."""
|
||||
now = datetime.now(UTC).isoformat()
|
||||
log = DecisionLog(
|
||||
decision_id="test-uuid",
|
||||
timestamp=now,
|
||||
stock_code="005930",
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
action="BUY",
|
||||
confidence=85,
|
||||
rationale="Test",
|
||||
context_snapshot={"L1": "data"},
|
||||
input_data={"price": 100.0},
|
||||
)
|
||||
|
||||
assert log.decision_id == "test-uuid"
|
||||
assert log.action == "BUY"
|
||||
assert log.confidence == 85
|
||||
assert log.reviewed is False
|
||||
assert log.outcome_pnl is None
|
||||
tests/test_evolution.py (new file, 685 lines)
@@ -0,0 +1,685 @@
|
||||
"""Tests for the Evolution Engine components.
|
||||
|
||||
Tests cover:
|
||||
- EvolutionOptimizer: failure analysis and strategy generation
|
||||
- ABTester: A/B testing and statistical comparison
|
||||
- PerformanceTracker: metrics tracking and dashboard
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import tempfile
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.config import Settings
|
||||
from src.db import init_db, log_trade
|
||||
from src.evolution.ab_test import ABTester
|
||||
from src.evolution.optimizer import EvolutionOptimizer
|
||||
from src.evolution.performance_tracker import (
|
||||
PerformanceDashboard,
|
||||
PerformanceTracker,
|
||||
StrategyMetrics,
|
||||
)
|
||||
from src.logging.decision_logger import DecisionLogger
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_conn() -> sqlite3.Connection:
|
||||
"""Provide an in-memory database with initialized schema."""
|
||||
return init_db(":memory:")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def settings() -> Settings:
|
||||
"""Provide test settings."""
|
||||
return Settings(
|
||||
KIS_APP_KEY="test_key",
|
||||
KIS_APP_SECRET="test_secret",
|
||||
KIS_ACCOUNT_NO="12345678-01",
|
||||
GEMINI_API_KEY="test_gemini_key",
|
||||
GEMINI_MODEL="gemini-pro",
|
||||
DB_PATH=":memory:",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def optimizer(settings: Settings) -> EvolutionOptimizer:
|
||||
"""Provide an EvolutionOptimizer instance."""
|
||||
return EvolutionOptimizer(settings)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def decision_logger(db_conn: sqlite3.Connection) -> DecisionLogger:
|
||||
"""Provide a DecisionLogger instance."""
|
||||
return DecisionLogger(db_conn)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ab_tester() -> ABTester:
|
||||
"""Provide an ABTester instance."""
|
||||
return ABTester(significance_level=0.05)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def performance_tracker(settings: Settings) -> PerformanceTracker:
|
||||
"""Provide a PerformanceTracker instance."""
|
||||
return PerformanceTracker(db_path=":memory:")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# EvolutionOptimizer Tests
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_analyze_failures_uses_decision_logger(optimizer: EvolutionOptimizer) -> None:
|
||||
"""Test that analyze_failures uses DecisionLogger.get_losing_decisions()."""
|
||||
# Add some losing decisions to the database
|
||||
logger = optimizer._decision_logger
|
||||
|
||||
# High-confidence loss
|
||||
id1 = logger.log_decision(
|
||||
stock_code="005930",
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
action="BUY",
|
||||
confidence=85,
|
||||
rationale="Expected growth",
|
||||
context_snapshot={"L1": {"price": 70000}},
|
||||
input_data={"price": 70000, "volume": 1000},
|
||||
)
|
||||
logger.update_outcome(id1, pnl=-2000.0, accuracy=0)
|
||||
|
||||
# Another high-confidence loss
|
||||
id2 = logger.log_decision(
|
||||
stock_code="000660",
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
action="SELL",
|
||||
confidence=90,
|
||||
rationale="Expected drop",
|
||||
context_snapshot={"L1": {"price": 100000}},
|
||||
input_data={"price": 100000, "volume": 500},
|
||||
)
|
||||
logger.update_outcome(id2, pnl=-1500.0, accuracy=0)
|
||||
|
||||
# Low-confidence loss (should be ignored)
|
||||
id3 = logger.log_decision(
|
||||
stock_code="035420",
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
action="HOLD",
|
||||
confidence=70,
|
||||
rationale="Uncertain",
|
||||
context_snapshot={},
|
||||
input_data={},
|
||||
)
|
||||
logger.update_outcome(id3, pnl=-500.0, accuracy=0)
|
||||
|
||||
# Analyze failures
|
||||
failures = optimizer.analyze_failures(limit=10)
|
||||
|
||||
# Should get 2 failures (confidence >= 80)
|
||||
assert len(failures) == 2
|
||||
assert all(f["confidence"] >= 80 for f in failures)
|
||||
assert all(f["outcome_pnl"] <= -100.0 for f in failures)
|
||||
|
||||
|
||||
def test_analyze_failures_empty_database(optimizer: EvolutionOptimizer) -> None:
|
||||
"""Test analyze_failures with no losing decisions."""
|
||||
failures = optimizer.analyze_failures()
|
||||
assert failures == []
|
||||
|
||||
|
||||
def test_identify_failure_patterns(optimizer: EvolutionOptimizer) -> None:
|
||||
"""Test identification of failure patterns."""
|
||||
failures = [
|
||||
{
|
||||
"decision_id": "1",
|
||||
"timestamp": "2024-01-15T09:30:00+00:00",
|
||||
"stock_code": "005930",
|
||||
"market": "KR",
|
||||
"exchange_code": "KRX",
|
||||
"action": "BUY",
|
||||
"confidence": 85,
|
||||
"rationale": "Test",
|
||||
"outcome_pnl": -1000.0,
|
||||
"outcome_accuracy": 0,
|
||||
"context_snapshot": {},
|
||||
"input_data": {},
|
||||
},
|
||||
{
|
||||
"decision_id": "2",
|
||||
"timestamp": "2024-01-15T14:30:00+00:00",
|
||||
"stock_code": "000660",
|
||||
"market": "KR",
|
||||
"exchange_code": "KRX",
|
||||
"action": "SELL",
|
||||
"confidence": 90,
|
||||
"rationale": "Test",
|
||||
"outcome_pnl": -2000.0,
|
||||
"outcome_accuracy": 0,
|
||||
"context_snapshot": {},
|
||||
"input_data": {},
|
||||
},
|
||||
{
|
||||
"decision_id": "3",
|
||||
"timestamp": "2024-01-15T09:45:00+00:00",
|
||||
"stock_code": "035420",
|
||||
"market": "US_NASDAQ",
|
||||
"exchange_code": "NASDAQ",
|
||||
"action": "BUY",
|
||||
"confidence": 80,
|
||||
"rationale": "Test",
|
||||
"outcome_pnl": -500.0,
|
||||
"outcome_accuracy": 0,
|
||||
"context_snapshot": {},
|
||||
"input_data": {},
|
||||
},
|
||||
]
|
||||
|
||||
patterns = optimizer.identify_failure_patterns(failures)
|
||||
|
||||
assert patterns["total_failures"] == 3
|
||||
assert patterns["markets"]["KR"] == 2
|
||||
assert patterns["markets"]["US_NASDAQ"] == 1
|
||||
assert patterns["actions"]["BUY"] == 2
|
||||
assert patterns["actions"]["SELL"] == 1
|
||||
assert 9 in patterns["hours"] # 09:30 and 09:45
|
||||
assert 14 in patterns["hours"] # 14:30
|
||||
assert patterns["avg_confidence"] == 85.0
|
||||
assert patterns["avg_loss"] == -1166.67
|
||||
|
||||
|
||||
def test_identify_failure_patterns_empty(optimizer: EvolutionOptimizer) -> None:
|
||||
"""Test pattern identification with no failures."""
|
||||
patterns = optimizer.identify_failure_patterns([])
|
||||
assert patterns["pattern_count"] == 0
|
||||
assert patterns["patterns"] == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_strategy_creates_file(optimizer: EvolutionOptimizer, tmp_path: Path) -> None:
|
||||
"""Test that generate_strategy creates a strategy file."""
|
||||
failures = [
|
||||
{
|
||||
"decision_id": "1",
|
||||
"timestamp": "2024-01-15T09:30:00+00:00",
|
||||
"stock_code": "005930",
|
||||
"market": "KR",
|
||||
"action": "BUY",
|
||||
"confidence": 85,
|
||||
"outcome_pnl": -1000.0,
|
||||
"context_snapshot": {},
|
||||
"input_data": {},
|
||||
}
|
||||
]
|
||||
|
||||
# Mock Gemini response
|
||||
mock_response = Mock()
|
||||
mock_response.text = """
|
||||
# Simple strategy
|
||||
price = market_data.get("current_price", 0)
|
||||
if price > 50000:
|
||||
return {"action": "BUY", "confidence": 70, "rationale": "Price above threshold"}
|
||||
return {"action": "HOLD", "confidence": 50, "rationale": "Waiting"}
|
||||
"""
|
||||
|
||||
with patch.object(optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response)):
|
||||
with patch("src.evolution.optimizer.STRATEGIES_DIR", tmp_path):
|
||||
strategy_path = await optimizer.generate_strategy(failures)
|
||||
|
||||
assert strategy_path is not None
|
||||
assert strategy_path.exists()
|
||||
assert strategy_path.suffix == ".py"
|
||||
assert "class Strategy_" in strategy_path.read_text()
|
||||
assert "def evaluate" in strategy_path.read_text()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_strategy_handles_api_error(optimizer: EvolutionOptimizer) -> None:
|
||||
"""Test that generate_strategy handles Gemini API errors gracefully."""
|
||||
failures = [{"decision_id": "1", "timestamp": "2024-01-15T09:30:00+00:00"}]
|
||||
|
||||
with patch.object(
|
||||
optimizer._client.aio.models,
|
||||
"generate_content",
|
||||
side_effect=Exception("API Error"),
|
||||
):
|
||||
strategy_path = await optimizer.generate_strategy(failures)
|
||||
|
||||
assert strategy_path is None
|
||||
|
||||
|
||||
def test_get_performance_summary() -> None:
|
||||
"""Test getting performance summary from trades table."""
|
||||
# Create a temporary database with trades
|
||||
import tempfile
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
|
||||
tmp_path = tmp.name
|
||||
|
||||
conn = init_db(tmp_path)
|
||||
log_trade(conn, "005930", "BUY", 85, "Test win", quantity=10, price=70000, pnl=1000.0)
|
||||
log_trade(conn, "000660", "SELL", 90, "Test loss", quantity=5, price=100000, pnl=-500.0)
|
||||
log_trade(conn, "035420", "BUY", 80, "Test win", quantity=8, price=50000, pnl=800.0)
|
||||
conn.close()
|
||||
|
||||
# Create settings with temp database path
|
||||
settings = Settings(
|
||||
KIS_APP_KEY="test_key",
|
||||
KIS_APP_SECRET="test_secret",
|
||||
KIS_ACCOUNT_NO="12345678-01",
|
||||
GEMINI_API_KEY="test_gemini_key",
|
||||
GEMINI_MODEL="gemini-pro",
|
||||
DB_PATH=tmp_path,
|
||||
)
|
||||
|
||||
optimizer = EvolutionOptimizer(settings)
|
||||
summary = optimizer.get_performance_summary()
|
||||
|
||||
assert summary["total_trades"] == 3
|
||||
assert summary["wins"] == 2
|
||||
assert summary["losses"] == 1
|
||||
assert summary["total_pnl"] == 1300.0
|
||||
assert summary["avg_pnl"] == 433.33
|
||||
|
||||
# Clean up
|
||||
Path(tmp_path).unlink()
|
||||
|
||||
|
||||
def test_validate_strategy_success(optimizer: EvolutionOptimizer, tmp_path: Path) -> None:
|
||||
"""Test strategy validation when tests pass."""
|
||||
strategy_file = tmp_path / "test_strategy.py"
|
||||
strategy_file.write_text("# Valid strategy file")
|
||||
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = Mock(returncode=0, stdout="", stderr="")
|
||||
result = optimizer.validate_strategy(strategy_file)
|
||||
|
||||
assert result is True
|
||||
assert strategy_file.exists()
|
||||
|
||||
|
||||
def test_validate_strategy_failure(optimizer: EvolutionOptimizer, tmp_path: Path) -> None:
|
||||
"""Test strategy validation when tests fail."""
|
||||
strategy_file = tmp_path / "test_strategy.py"
|
||||
strategy_file.write_text("# Invalid strategy file")
|
||||
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = Mock(returncode=1, stdout="FAILED", stderr="")
|
||||
result = optimizer.validate_strategy(strategy_file)
|
||||
|
||||
assert result is False
|
||||
# File should be deleted on failure
|
||||
assert not strategy_file.exists()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# ABTester Tests
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_calculate_performance_basic(ab_tester: ABTester) -> None:
|
||||
"""Test basic performance calculation."""
|
||||
trades = [
|
||||
{"pnl": 1000.0},
|
||||
{"pnl": -500.0},
|
||||
{"pnl": 800.0},
|
||||
{"pnl": 200.0},
|
||||
]
|
||||
|
||||
perf = ab_tester.calculate_performance(trades, "TestStrategy")
|
||||
|
||||
assert perf.strategy_name == "TestStrategy"
|
||||
assert perf.total_trades == 4
|
||||
assert perf.wins == 3
|
||||
assert perf.losses == 1
|
||||
assert perf.total_pnl == 1500.0
|
||||
assert perf.avg_pnl == 375.0
|
||||
assert perf.win_rate == 75.0
|
||||
assert perf.sharpe_ratio is not None
|
||||
|
||||
|
||||
def test_calculate_performance_empty(ab_tester: ABTester) -> None:
|
||||
"""Test performance calculation with no trades."""
|
||||
perf = ab_tester.calculate_performance([], "EmptyStrategy")
|
||||
|
||||
assert perf.total_trades == 0
|
||||
assert perf.wins == 0
|
||||
assert perf.losses == 0
|
||||
assert perf.total_pnl == 0.0
|
||||
assert perf.avg_pnl == 0.0
|
||||
assert perf.win_rate == 0.0
|
||||
assert perf.sharpe_ratio is None
|
||||
|
||||
|
||||
def test_compare_strategies_significant_difference(ab_tester: ABTester) -> None:
|
||||
"""Test strategy comparison with significant performance difference."""
|
||||
# Strategy A: consistently profitable
|
||||
trades_a = [{"pnl": 1000.0} for _ in range(30)]
|
||||
|
||||
# Strategy B: consistently losing
|
||||
trades_b = [{"pnl": -500.0} for _ in range(30)]
|
||||
|
||||
result = ab_tester.compare_strategies(trades_a, trades_b, "Strategy A", "Strategy B")
|
||||
|
||||
# scipy returns np.True_ instead of Python bool
|
||||
assert bool(result.is_significant) is True
|
||||
assert result.winner == "Strategy A"
|
||||
assert result.p_value < 0.05
|
||||
assert result.performance_a.avg_pnl > result.performance_b.avg_pnl
|
||||
|
||||
|
||||
def test_compare_strategies_no_difference(ab_tester: ABTester) -> None:
|
||||
"""Test strategy comparison with no significant difference."""
|
||||
# Both strategies have similar performance
|
||||
trades_a = [{"pnl": 100.0}, {"pnl": -50.0}, {"pnl": 80.0}]
|
||||
trades_b = [{"pnl": 90.0}, {"pnl": -60.0}, {"pnl": 85.0}]
|
||||
|
||||
result = ab_tester.compare_strategies(trades_a, trades_b, "Strategy A", "Strategy B")
|
||||
|
||||
# With small samples and similar performance, likely not significant
|
||||
assert result.winner is None or not result.is_significant
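# NOTE (illustrative sketch): the significance behaviour above (and the np.True_
# comment) points at a two-sample test such as Welch's t-test. A minimal sketch with
# scipy, using assumed names; zero-variance samples (e.g. constant P&L lists) need
# special handling, which the real ABTester presumably provides.
from scipy import stats

def _compare_pnl_sketch(pnls_a, pnls_b, alpha=0.05):
    t_stat, p_value = stats.ttest_ind(pnls_a, pnls_b, equal_var=False)  # Welch's t-test
    return {"p_value": float(p_value), "is_significant": bool(p_value < alpha)}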
|
||||
|
||||
|
||||
def test_should_deploy_meets_criteria(ab_tester: ABTester) -> None:
|
||||
"""Test deployment decision when criteria are met."""
|
||||
# Create a winning result that meets criteria
|
||||
trades_a = [{"pnl": 1000.0} for _ in range(25)] # 100% win rate
|
||||
trades_b = [{"pnl": -500.0} for _ in range(25)]
|
||||
|
||||
result = ab_tester.compare_strategies(trades_a, trades_b, "Winner", "Loser")
|
||||
|
||||
should_deploy = ab_tester.should_deploy(result, min_win_rate=60.0, min_trades=20)
|
||||
|
||||
assert should_deploy is True
|
||||
|
||||
|
||||
def test_should_deploy_insufficient_trades(ab_tester: ABTester) -> None:
|
||||
"""Test deployment decision with insufficient trades."""
|
||||
trades_a = [{"pnl": 1000.0} for _ in range(10)] # Only 10 trades
|
||||
trades_b = [{"pnl": -500.0} for _ in range(10)]
|
||||
|
||||
result = ab_tester.compare_strategies(trades_a, trades_b, "Winner", "Loser")
|
||||
|
||||
should_deploy = ab_tester.should_deploy(result, min_win_rate=60.0, min_trades=20)
|
||||
|
||||
assert should_deploy is False
|
||||
|
||||
|
||||
def test_should_deploy_low_win_rate(ab_tester: ABTester) -> None:
|
||||
"""Test deployment decision with low win rate."""
|
||||
# Mix of wins and losses, below 60% win rate
|
||||
trades_a = [{"pnl": 100.0}] * 10 + [{"pnl": -100.0}] * 15 # 40% win rate
|
||||
trades_b = [{"pnl": -500.0} for _ in range(25)]
|
||||
|
||||
result = ab_tester.compare_strategies(trades_a, trades_b, "LowWinner", "Loser")
|
||||
|
||||
should_deploy = ab_tester.should_deploy(result, min_win_rate=60.0, min_trades=20)
|
||||
|
||||
assert should_deploy is False
|
||||
|
||||
|
||||
def test_should_deploy_not_significant(ab_tester: ABTester) -> None:
|
||||
"""Test deployment decision when difference is not significant."""
|
||||
# Use more varied data to ensure statistical insignificance
|
||||
trades_a = [{"pnl": 100.0}, {"pnl": -50.0}] * 12 + [{"pnl": 100.0}]
|
||||
trades_b = [{"pnl": 95.0}, {"pnl": -45.0}] * 12 + [{"pnl": 95.0}]
|
||||
|
||||
result = ab_tester.compare_strategies(trades_a, trades_b, "A", "B")
|
||||
|
||||
should_deploy = ab_tester.should_deploy(result, min_win_rate=60.0, min_trades=20)
|
||||
|
||||
# Not significant or not profitable enough
|
||||
# Even if significant, win rate is 50% which is below 60% threshold
|
||||
assert should_deploy is False
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# PerformanceTracker Tests
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_strategy_metrics(db_conn: sqlite3.Connection) -> None:
|
||||
"""Test getting strategy metrics."""
|
||||
# Add some trades
|
||||
log_trade(db_conn, "005930", "BUY", 85, "Win 1", quantity=10, price=70000, pnl=1000.0)
|
||||
log_trade(db_conn, "000660", "SELL", 90, "Loss 1", quantity=5, price=100000, pnl=-500.0)
|
||||
log_trade(db_conn, "035420", "BUY", 80, "Win 2", quantity=8, price=50000, pnl=800.0)
|
||||
log_trade(db_conn, "005930", "HOLD", 75, "Hold", quantity=0, price=70000, pnl=0.0)
|
||||
|
||||
tracker = PerformanceTracker(db_path=":memory:")
|
||||
# Manually set connection for testing
|
||||
tracker._db_path = db_conn
|
||||
|
||||
# Need to use the same connection
|
||||
with patch("sqlite3.connect", return_value=db_conn):
|
||||
metrics = tracker.get_strategy_metrics()
|
||||
|
||||
assert metrics.total_trades == 4
|
||||
assert metrics.wins == 2
|
||||
assert metrics.losses == 1
|
||||
assert metrics.holds == 1
|
||||
assert metrics.win_rate == 50.0
|
||||
assert metrics.total_pnl == 1300.0
|
||||
|
||||
|
||||
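The assertions in test_get_strategy_metrics pin down a simple aggregation: wins are trades with positive P&L, losses negative, HOLDs counted separately, win rate is wins over all logged trades, and total P&L is a plain sum. A minimal sketch of that reading follows; it is an illustration, not the actual PerformanceTracker query, and counting HOLDs in the denominator is one interpretation consistent with 2/4 = 50%.

```python
# Hypothetical aggregation consistent with the assertions above; the real
# PerformanceTracker likely computes this in SQL.
def aggregate(trades: list[dict]) -> dict:
    wins = sum(1 for t in trades if t["pnl"] > 0)
    losses = sum(1 for t in trades if t["pnl"] < 0)
    holds = sum(1 for t in trades if t["action"] == "HOLD")
    total = len(trades)
    return {
        "total_trades": total,
        "wins": wins,
        "losses": losses,
        "holds": holds,
        "win_rate": (wins / total * 100) if total else 0.0,
        "total_pnl": sum(t["pnl"] for t in trades),
    }


rows = [
    {"action": "BUY", "pnl": 1000.0},
    {"action": "SELL", "pnl": -500.0},
    {"action": "BUY", "pnl": 800.0},
    {"action": "HOLD", "pnl": 0.0},
]
assert aggregate(rows)["win_rate"] == 50.0
assert aggregate(rows)["total_pnl"] == 1300.0
```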
def test_calculate_improvement_trend_improving(performance_tracker: PerformanceTracker) -> None:
|
||||
"""Test improvement trend calculation for improving strategy."""
|
||||
metrics = [
|
||||
StrategyMetrics(
|
||||
strategy_name="test",
|
||||
period_start="2024-01-01",
|
||||
period_end="2024-01-07",
|
||||
total_trades=10,
|
||||
wins=5,
|
||||
losses=5,
|
||||
holds=0,
|
||||
win_rate=50.0,
|
||||
avg_pnl=100.0,
|
||||
total_pnl=1000.0,
|
||||
best_trade=500.0,
|
||||
worst_trade=-300.0,
|
||||
avg_confidence=75.0,
|
||||
),
|
||||
StrategyMetrics(
|
||||
strategy_name="test",
|
||||
period_start="2024-01-08",
|
||||
period_end="2024-01-14",
|
||||
total_trades=10,
|
||||
wins=7,
|
||||
losses=3,
|
||||
holds=0,
|
||||
win_rate=70.0,
|
||||
avg_pnl=200.0,
|
||||
total_pnl=2000.0,
|
||||
best_trade=600.0,
|
||||
worst_trade=-200.0,
|
||||
avg_confidence=80.0,
|
||||
),
|
||||
]
|
||||
|
||||
trend = performance_tracker.calculate_improvement_trend(metrics)
|
||||
|
||||
assert trend["trend"] == "improving"
|
||||
assert trend["win_rate_change"] == 20.0
|
||||
assert trend["pnl_change"] == 100.0
|
||||
assert trend["confidence_change"] == 5.0
|
||||
|
||||
|
||||
def test_calculate_improvement_trend_declining(performance_tracker: PerformanceTracker) -> None:
|
||||
"""Test improvement trend calculation for declining strategy."""
|
||||
metrics = [
|
||||
StrategyMetrics(
|
||||
strategy_name="test",
|
||||
period_start="2024-01-01",
|
||||
period_end="2024-01-07",
|
||||
total_trades=10,
|
||||
wins=7,
|
||||
losses=3,
|
||||
holds=0,
|
||||
win_rate=70.0,
|
||||
avg_pnl=200.0,
|
||||
total_pnl=2000.0,
|
||||
best_trade=600.0,
|
||||
worst_trade=-200.0,
|
||||
avg_confidence=80.0,
|
||||
),
|
||||
StrategyMetrics(
|
||||
strategy_name="test",
|
||||
period_start="2024-01-08",
|
||||
period_end="2024-01-14",
|
||||
total_trades=10,
|
||||
wins=4,
|
||||
losses=6,
|
||||
holds=0,
|
||||
win_rate=40.0,
|
||||
avg_pnl=-50.0,
|
||||
total_pnl=-500.0,
|
||||
best_trade=300.0,
|
||||
worst_trade=-400.0,
|
||||
avg_confidence=70.0,
|
||||
),
|
||||
]
|
||||
|
||||
trend = performance_tracker.calculate_improvement_trend(metrics)
|
||||
|
||||
assert trend["trend"] == "declining"
|
||||
assert trend["win_rate_change"] == -30.0
|
||||
assert trend["pnl_change"] == -250.0
|
||||
|
||||
|
||||
def test_calculate_improvement_trend_insufficient_data(performance_tracker: PerformanceTracker) -> None:
|
||||
"""Test improvement trend with insufficient data."""
|
||||
metrics = [
|
||||
StrategyMetrics(
|
||||
strategy_name="test",
|
||||
period_start="2024-01-01",
|
||||
period_end="2024-01-07",
|
||||
total_trades=10,
|
||||
wins=5,
|
||||
losses=5,
|
||||
holds=0,
|
||||
win_rate=50.0,
|
||||
avg_pnl=100.0,
|
||||
total_pnl=1000.0,
|
||||
best_trade=500.0,
|
||||
worst_trade=-300.0,
|
||||
avg_confidence=75.0,
|
||||
)
|
||||
]
|
||||
|
||||
trend = performance_tracker.calculate_improvement_trend(metrics)
|
||||
|
||||
assert trend["trend"] == "insufficient_data"
|
||||
assert trend["win_rate_change"] == 0.0
|
||||
assert trend["pnl_change"] == 0.0
|
||||
|
||||
|
||||
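The two-period trend tests above (improving, declining, insufficient data) are consistent with a first-versus-last comparison of win rate, average P&L, and average confidence. A minimal sketch of that rule, assuming calculate_improvement_trend only looks at the endpoints:

```python
# Sketch only: first-vs-last comparison consistent with the three trend tests above.
# The real calculate_improvement_trend may smooth over more periods.
from types import SimpleNamespace


def calc_trend(metrics: list) -> dict:
    if len(metrics) < 2:
        return {"trend": "insufficient_data", "win_rate_change": 0.0, "pnl_change": 0.0}
    first, last = metrics[0], metrics[-1]
    win_rate_change = last.win_rate - first.win_rate
    return {
        "trend": "improving" if win_rate_change > 0 else "declining" if win_rate_change < 0 else "stable",
        "win_rate_change": win_rate_change,
        "pnl_change": last.avg_pnl - first.avg_pnl,
        "confidence_change": last.avg_confidence - first.avg_confidence,
    }


week1 = SimpleNamespace(win_rate=50.0, avg_pnl=100.0, avg_confidence=75.0)
week2 = SimpleNamespace(win_rate=70.0, avg_pnl=200.0, avg_confidence=80.0)
assert calc_trend([week1, week2])["trend"] == "improving"
assert calc_trend([week1])["trend"] == "insufficient_data"
```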
def test_export_dashboard_json(performance_tracker: PerformanceTracker) -> None:
|
||||
"""Test exporting dashboard as JSON."""
|
||||
overall_metrics = StrategyMetrics(
|
||||
strategy_name="test",
|
||||
period_start="2024-01-01",
|
||||
period_end="2024-01-31",
|
||||
total_trades=100,
|
||||
wins=60,
|
||||
losses=40,
|
||||
holds=10,
|
||||
win_rate=60.0,
|
||||
avg_pnl=150.0,
|
||||
total_pnl=15000.0,
|
||||
best_trade=1000.0,
|
||||
worst_trade=-500.0,
|
||||
avg_confidence=80.0,
|
||||
)
|
||||
|
||||
dashboard = PerformanceDashboard(
|
||||
generated_at=datetime.now(UTC).isoformat(),
|
||||
overall_metrics=overall_metrics,
|
||||
daily_metrics=[],
|
||||
weekly_metrics=[],
|
||||
improvement_trend={"trend": "improving", "win_rate_change": 10.0},
|
||||
)
|
||||
|
||||
json_output = performance_tracker.export_dashboard_json(dashboard)
|
||||
|
||||
# Verify it's valid JSON
|
||||
data = json.loads(json_output)
|
||||
assert "generated_at" in data
|
||||
assert "overall_metrics" in data
|
||||
assert data["overall_metrics"]["total_trades"] == 100
|
||||
assert data["overall_metrics"]["win_rate"] == 60.0
|
||||
|
||||
|
||||
def test_generate_dashboard() -> None:
|
||||
"""Test generating a complete dashboard."""
|
||||
# Create tracker with temp database
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
|
||||
tmp_path = tmp.name
|
||||
|
||||
# Initialize with data
|
||||
conn = init_db(tmp_path)
|
||||
log_trade(conn, "005930", "BUY", 85, "Win", quantity=10, price=70000, pnl=1000.0)
|
||||
log_trade(conn, "000660", "SELL", 90, "Loss", quantity=5, price=100000, pnl=-500.0)
|
||||
conn.close()
|
||||
|
||||
tracker = PerformanceTracker(db_path=tmp_path)
|
||||
dashboard = tracker.generate_dashboard()
|
||||
|
||||
assert isinstance(dashboard, PerformanceDashboard)
|
||||
assert dashboard.overall_metrics.total_trades == 2
|
||||
assert len(dashboard.daily_metrics) == 7
|
||||
assert len(dashboard.weekly_metrics) == 4
|
||||
assert "trend" in dashboard.improvement_trend
|
||||
|
||||
# Clean up
|
||||
Path(tmp_path).unlink()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Integration Tests
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_evolution_pipeline(optimizer: EvolutionOptimizer, tmp_path: Path) -> None:
|
||||
"""Test the complete evolution pipeline."""
|
||||
# Add losing decisions
|
||||
logger = optimizer._decision_logger
|
||||
id1 = logger.log_decision(
|
||||
stock_code="005930",
|
||||
market="KR",
|
||||
exchange_code="KRX",
|
||||
action="BUY",
|
||||
confidence=85,
|
||||
rationale="Expected growth",
|
||||
context_snapshot={},
|
||||
input_data={},
|
||||
)
|
||||
logger.update_outcome(id1, pnl=-2000.0, accuracy=0)
|
||||
|
||||
# Mock Gemini and subprocess
|
||||
mock_response = Mock()
|
||||
mock_response.text = 'return {"action": "HOLD", "confidence": 50, "rationale": "Test"}'
|
||||
|
||||
with patch.object(optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response)):
|
||||
with patch("src.evolution.optimizer.STRATEGIES_DIR", tmp_path):
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = Mock(returncode=0, stdout="", stderr="")
|
||||
|
||||
result = await optimizer.evolve()
|
||||
|
||||
assert result is not None
|
||||
assert "title" in result
|
||||
assert "branch" in result
|
||||
assert "status" in result
|
||||
tests/test_latency_control.py (new file, 558 lines)
@@ -0,0 +1,558 @@
|
||||
"""Tests for latency control system (criticality assessment and priority queue)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from src.core.criticality import CriticalityAssessor, CriticalityLevel
|
||||
from src.core.priority_queue import PriorityTask, PriorityTaskQueue
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CriticalityAssessor Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCriticalityAssessor:
|
||||
"""Test suite for criticality assessment logic."""
|
||||
|
||||
def test_market_closed_returns_low(self) -> None:
|
||||
"""Market closed should return LOW priority."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=0.0,
|
||||
volatility_score=50.0,
|
||||
volume_surge=1.0,
|
||||
is_market_open=False,
|
||||
)
|
||||
assert level == CriticalityLevel.LOW
|
||||
|
||||
def test_very_low_volatility_returns_low(self) -> None:
|
||||
"""Very low volatility should return LOW priority."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=0.0,
|
||||
volatility_score=20.0, # Below 30.0 threshold
|
||||
volume_surge=1.0,
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.LOW
|
||||
|
||||
def test_critical_pnl_threshold_triggered(self) -> None:
|
||||
"""P&L below -2.5% should trigger CRITICAL."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=-2.6, # Below -2.5% threshold
|
||||
volatility_score=50.0,
|
||||
volume_surge=1.0,
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.CRITICAL
|
||||
|
||||
def test_critical_pnl_at_circuit_breaker_proximity(self) -> None:
|
||||
"""P&L at exactly -2.5% (near -3.0% breaker) should be CRITICAL."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=-2.5,
|
||||
volatility_score=50.0,
|
||||
volume_surge=1.0,
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.CRITICAL
|
||||
|
||||
def test_critical_price_change_positive(self) -> None:
|
||||
"""Large positive price change (>5%) should trigger CRITICAL."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=0.0,
|
||||
volatility_score=50.0,
|
||||
volume_surge=1.0,
|
||||
price_change_1m=5.5, # Above 5.0% threshold
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.CRITICAL
|
||||
|
||||
def test_critical_price_change_negative(self) -> None:
|
||||
"""Large negative price change (<-5%) should trigger CRITICAL."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=0.0,
|
||||
volatility_score=50.0,
|
||||
volume_surge=1.0,
|
||||
price_change_1m=-6.0, # Below -5.0% threshold
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.CRITICAL
|
||||
|
||||
def test_critical_volume_surge(self) -> None:
|
||||
"""Extreme volume surge (>10x) should trigger CRITICAL."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=0.0,
|
||||
volatility_score=50.0,
|
||||
volume_surge=12.0, # Above 10.0x threshold
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.CRITICAL
|
||||
|
||||
def test_high_volatility_returns_high(self) -> None:
|
||||
"""High volatility score should return HIGH priority."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=0.0,
|
||||
volatility_score=75.0, # Above 70.0 threshold
|
||||
volume_surge=1.0,
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.HIGH
|
||||
|
||||
def test_normal_conditions_return_normal(self) -> None:
|
||||
"""Normal market conditions should return NORMAL priority."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=0.5,
|
||||
volatility_score=50.0, # Between 30-70
|
||||
volume_surge=1.5,
|
||||
price_change_1m=1.0,
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.NORMAL
|
||||
|
||||
def test_custom_thresholds(self) -> None:
|
||||
"""Custom thresholds should be respected."""
|
||||
assessor = CriticalityAssessor(
|
||||
critical_pnl_threshold=-1.0,
|
||||
critical_price_change_threshold=3.0,
|
||||
critical_volume_surge_threshold=5.0,
|
||||
high_volatility_threshold=60.0,
|
||||
low_volatility_threshold=20.0,
|
||||
)
|
||||
|
||||
# Test custom P&L threshold
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=-1.1,
|
||||
volatility_score=50.0,
|
||||
volume_surge=1.0,
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.CRITICAL
|
||||
|
||||
# Test custom price change threshold
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=0.0,
|
||||
volatility_score=50.0,
|
||||
volume_surge=1.0,
|
||||
price_change_1m=3.5,
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.CRITICAL
|
||||
|
||||
def test_get_timeout_returns_correct_values(self) -> None:
|
||||
"""Timeout values should match specification."""
|
||||
assessor = CriticalityAssessor()
|
||||
|
||||
assert assessor.get_timeout(CriticalityLevel.CRITICAL) == 5.0
|
||||
assert assessor.get_timeout(CriticalityLevel.HIGH) == 30.0
|
||||
assert assessor.get_timeout(CriticalityLevel.NORMAL) == 60.0
|
||||
assert assessor.get_timeout(CriticalityLevel.LOW) is None
|
||||
|
||||
|
||||
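Read together, the assessor tests fix most of the decision order: a closed market short-circuits to LOW, any circuit-breaker-proximity condition (P&L at or below -2.5%, a one-minute move beyond ±5%, or a >10x volume surge) is CRITICAL, very low volatility is LOW, high volatility is HIGH, and everything else is NORMAL. The sketch below is one ordering consistent with those tests; the precedence where the tests are silent is an assumption, and the configurable thresholds are hard-coded here.

```python
# One plausible ordering of the checks the tests above pin down; thresholds are the
# defaults exercised by the tests and the CriticalityLevel values are assumed.
from enum import Enum


class CriticalityLevel(Enum):
    CRITICAL = 0
    HIGH = 1
    NORMAL = 2
    LOW = 3


def assess(pnl_pct: float, volatility_score: float, volume_surge: float,
           price_change_1m: float = 0.0, is_market_open: bool = True) -> CriticalityLevel:
    if not is_market_open:
        return CriticalityLevel.LOW
    if pnl_pct <= -2.5 or abs(price_change_1m) > 5.0 or volume_surge > 10.0:
        return CriticalityLevel.CRITICAL
    if volatility_score < 30.0:
        return CriticalityLevel.LOW
    if volatility_score > 70.0:
        return CriticalityLevel.HIGH
    return CriticalityLevel.NORMAL


assert assess(0.0, 50.0, 1.0, is_market_open=False) is CriticalityLevel.LOW
assert assess(-2.5, 50.0, 1.0) is CriticalityLevel.CRITICAL
assert assess(0.5, 50.0, 1.5, price_change_1m=1.0) is CriticalityLevel.NORMAL
```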
# ---------------------------------------------------------------------------
|
||||
# PriorityTaskQueue Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPriorityTaskQueue:
|
||||
"""Test suite for priority queue implementation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_task(self) -> None:
|
||||
"""Tasks should be enqueued successfully."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
success = await queue.enqueue(
|
||||
task_id="test-1",
|
||||
criticality=CriticalityLevel.NORMAL,
|
||||
task_data={"action": "test"},
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert await queue.size() == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_rejects_when_full(self) -> None:
|
||||
"""Queue should reject tasks when full."""
|
||||
queue = PriorityTaskQueue(max_size=2)
|
||||
|
||||
# Fill the queue
|
||||
await queue.enqueue("task-1", CriticalityLevel.NORMAL, {})
|
||||
await queue.enqueue("task-2", CriticalityLevel.NORMAL, {})
|
||||
|
||||
# Third task should be rejected
|
||||
success = await queue.enqueue("task-3", CriticalityLevel.NORMAL, {})
|
||||
assert success is False
|
||||
assert await queue.size() == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dequeue_returns_highest_priority(self) -> None:
|
||||
"""Dequeue should return highest priority task first."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Enqueue tasks in reverse priority order
|
||||
await queue.enqueue("low", CriticalityLevel.LOW, {"priority": 3})
|
||||
await queue.enqueue("normal", CriticalityLevel.NORMAL, {"priority": 2})
|
||||
await queue.enqueue("high", CriticalityLevel.HIGH, {"priority": 1})
|
||||
await queue.enqueue("critical", CriticalityLevel.CRITICAL, {"priority": 0})
|
||||
|
||||
# Dequeue should return CRITICAL first
|
||||
task = await queue.dequeue(timeout=1.0)
|
||||
assert task is not None
|
||||
assert task.task_id == "critical"
|
||||
assert task.priority == 0
|
||||
|
||||
# Then HIGH
|
||||
task = await queue.dequeue(timeout=1.0)
|
||||
assert task is not None
|
||||
assert task.task_id == "high"
|
||||
assert task.priority == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dequeue_fifo_within_same_priority(self) -> None:
|
||||
"""Tasks with same priority should be FIFO."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Enqueue multiple tasks with same priority
|
||||
await queue.enqueue("task-1", CriticalityLevel.NORMAL, {})
|
||||
await asyncio.sleep(0.01) # Small delay to ensure different timestamps
|
||||
await queue.enqueue("task-2", CriticalityLevel.NORMAL, {})
|
||||
await asyncio.sleep(0.01)
|
||||
await queue.enqueue("task-3", CriticalityLevel.NORMAL, {})
|
||||
|
||||
# Should dequeue in FIFO order
|
||||
task1 = await queue.dequeue(timeout=1.0)
|
||||
task2 = await queue.dequeue(timeout=1.0)
|
||||
task3 = await queue.dequeue(timeout=1.0)
|
||||
|
||||
assert task1 is not None and task1.task_id == "task-1"
|
||||
assert task2 is not None and task2.task_id == "task-2"
|
||||
assert task3 is not None and task3.task_id == "task-3"
|
||||
|
||||
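The two dequeue tests above (strict priority order, then FIFO within a level) describe exactly the ordering a heap keyed on (priority, enqueue time) provides. A generic illustration, not the project's PriorityTaskQueue internals:

```python
# Generic heap ordering demo: lower priority number first, earlier timestamp breaks ties.
import heapq
import time

heap: list[tuple[int, float, str]] = []


def push(task_id: str, priority: int) -> None:
    heapq.heappush(heap, (priority, time.monotonic(), task_id))


for name, prio in [("low", 3), ("normal", 2), ("high", 1), ("critical", 0)]:
    push(name, prio)

assert heapq.heappop(heap)[2] == "critical"
assert heapq.heappop(heap)[2] == "high"
```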
@pytest.mark.asyncio
|
||||
async def test_dequeue_returns_none_when_empty(self) -> None:
|
||||
"""Dequeue should return None when queue is empty after timeout."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
task = await queue.dequeue(timeout=0.1)
|
||||
assert task is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_with_timeout_success(self) -> None:
|
||||
"""Task execution should succeed within timeout."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Create a simple async callback
|
||||
async def test_callback() -> str:
|
||||
await asyncio.sleep(0.01)
|
||||
return "success"
|
||||
|
||||
task = PriorityTask(
|
||||
priority=0,
|
||||
timestamp=0.0,
|
||||
task_id="test",
|
||||
task_data={},
|
||||
callback=test_callback,
|
||||
)
|
||||
|
||||
result = await queue.execute_with_timeout(task, timeout=1.0)
|
||||
assert result == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_with_timeout_raises_timeout_error(self) -> None:
|
||||
"""Task execution should raise TimeoutError if exceeds timeout."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Create a slow async callback
|
||||
async def slow_callback() -> str:
|
||||
await asyncio.sleep(1.0)
|
||||
return "too slow"
|
||||
|
||||
task = PriorityTask(
|
||||
priority=0,
|
||||
timestamp=0.0,
|
||||
task_id="test",
|
||||
task_data={},
|
||||
callback=slow_callback,
|
||||
)
|
||||
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await queue.execute_with_timeout(task, timeout=0.1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_with_timeout_propagates_exceptions(self) -> None:
|
||||
"""Task execution should propagate exceptions from callback."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Create a failing async callback
|
||||
async def failing_callback() -> None:
|
||||
raise ValueError("Test error")
|
||||
|
||||
task = PriorityTask(
|
||||
priority=0,
|
||||
timestamp=0.0,
|
||||
task_id="test",
|
||||
task_data={},
|
||||
callback=failing_callback,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Test error"):
|
||||
await queue.execute_with_timeout(task, timeout=1.0)
|
||||
|
||||
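The execute_with_timeout tests above and below (result returned within the limit, TimeoutError when too slow, callback exceptions propagated, no limit when timeout is None) match the semantics of asyncio.wait_for. A minimal sketch under that assumption; the real method presumably also records the metrics checked further down:

```python
# Sketch: asyncio.wait_for already provides the behaviour these tests describe.
from __future__ import annotations

import asyncio
from collections.abc import Awaitable, Callable
from typing import Any


async def execute_with_timeout(callback: Callable[[], Awaitable[Any]], timeout: float | None) -> Any:
    if timeout is None:          # LOW priority: no time limit
        return await callback()
    return await asyncio.wait_for(callback(), timeout=timeout)
```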
@pytest.mark.asyncio
|
||||
async def test_execute_without_timeout(self) -> None:
|
||||
"""Task execution should work without timeout (LOW priority)."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
async def test_callback() -> str:
|
||||
await asyncio.sleep(0.01)
|
||||
return "success"
|
||||
|
||||
task = PriorityTask(
|
||||
priority=3,
|
||||
timestamp=0.0,
|
||||
task_id="test",
|
||||
task_data={},
|
||||
callback=test_callback,
|
||||
)
|
||||
|
||||
result = await queue.execute_with_timeout(task, timeout=None)
|
||||
assert result == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_metrics(self) -> None:
|
||||
"""Queue should track metrics correctly."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Enqueue and dequeue some tasks
|
||||
await queue.enqueue("task-1", CriticalityLevel.CRITICAL, {})
|
||||
await queue.enqueue("task-2", CriticalityLevel.HIGH, {})
|
||||
await queue.enqueue("task-3", CriticalityLevel.NORMAL, {})
|
||||
|
||||
await queue.dequeue(timeout=1.0)
|
||||
await queue.dequeue(timeout=1.0)
|
||||
|
||||
metrics = await queue.get_metrics()
|
||||
|
||||
assert metrics.total_enqueued == 3
|
||||
assert metrics.total_dequeued == 2
|
||||
assert metrics.current_size == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_time_metrics(self) -> None:
|
||||
"""Queue should track wait times per criticality level."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Enqueue tasks with different criticality
|
||||
await queue.enqueue("critical-1", CriticalityLevel.CRITICAL, {})
|
||||
await asyncio.sleep(0.05) # Add some wait time
|
||||
|
||||
await queue.dequeue(timeout=1.0)
|
||||
|
||||
metrics = await queue.get_metrics()
|
||||
|
||||
# Should have wait time metrics for CRITICAL
|
||||
assert CriticalityLevel.CRITICAL in metrics.avg_wait_time
|
||||
assert metrics.avg_wait_time[CriticalityLevel.CRITICAL] > 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_queue(self) -> None:
|
||||
"""Clear should remove all tasks from queue."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
await queue.enqueue("task-1", CriticalityLevel.NORMAL, {})
|
||||
await queue.enqueue("task-2", CriticalityLevel.NORMAL, {})
|
||||
await queue.enqueue("task-3", CriticalityLevel.NORMAL, {})
|
||||
|
||||
cleared = await queue.clear()
|
||||
|
||||
assert cleared == 3
|
||||
assert await queue.size() == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_enqueue_dequeue(self) -> None:
|
||||
"""Queue should handle concurrent operations safely."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Concurrent enqueue operations
|
||||
async def enqueue_tasks() -> None:
|
||||
for i in range(10):
|
||||
await queue.enqueue(
|
||||
f"task-{i}",
|
||||
CriticalityLevel.NORMAL,
|
||||
{"index": i},
|
||||
)
|
||||
|
||||
# Concurrent dequeue operations
|
||||
async def dequeue_tasks() -> list[str]:
|
||||
tasks = []
|
||||
for _ in range(10):
|
||||
task = await queue.dequeue(timeout=1.0)
|
||||
if task:
|
||||
tasks.append(task.task_id)
|
||||
await asyncio.sleep(0.01)
|
||||
return tasks
|
||||
|
||||
# Run both concurrently
|
||||
enqueue_task = asyncio.create_task(enqueue_tasks())
|
||||
dequeue_task = asyncio.create_task(dequeue_tasks())
|
||||
|
||||
await enqueue_task
|
||||
dequeued_ids = await dequeue_task
|
||||
|
||||
# All tasks should be processed
|
||||
assert len(dequeued_ids) == 10
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_metric_tracking(self) -> None:
|
||||
"""Queue should track timeout occurrences."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
async def slow_callback() -> str:
|
||||
await asyncio.sleep(1.0)
|
||||
return "too slow"
|
||||
|
||||
task = PriorityTask(
|
||||
priority=0,
|
||||
timestamp=0.0,
|
||||
task_id="test",
|
||||
task_data={},
|
||||
callback=slow_callback,
|
||||
)
|
||||
|
||||
try:
|
||||
await queue.execute_with_timeout(task, timeout=0.1)
|
||||
except TimeoutError:
|
||||
pass
|
||||
|
||||
metrics = await queue.get_metrics()
|
||||
assert metrics.total_timeouts == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_metric_tracking(self) -> None:
|
||||
"""Queue should track execution errors."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
async def failing_callback() -> None:
|
||||
raise ValueError("Test error")
|
||||
|
||||
task = PriorityTask(
|
||||
priority=0,
|
||||
timestamp=0.0,
|
||||
task_id="test",
|
||||
task_data={},
|
||||
callback=failing_callback,
|
||||
)
|
||||
|
||||
try:
|
||||
await queue.execute_with_timeout(task, timeout=1.0)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
metrics = await queue.get_metrics()
|
||||
assert metrics.total_errors == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLatencyControlIntegration:
|
||||
"""Integration tests for criticality assessment and priority queue."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_critical_task_bypass_queue(self) -> None:
|
||||
"""CRITICAL tasks should bypass lower priority tasks."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Add normal priority tasks
|
||||
await queue.enqueue("normal-1", CriticalityLevel.NORMAL, {})
|
||||
await queue.enqueue("normal-2", CriticalityLevel.NORMAL, {})
|
||||
|
||||
# Add critical task (should jump to front)
|
||||
await queue.enqueue("critical", CriticalityLevel.CRITICAL, {})
|
||||
|
||||
# Dequeue should return critical first
|
||||
task = await queue.dequeue(timeout=1.0)
|
||||
assert task is not None
|
||||
assert task.task_id == "critical"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_enforcement_by_criticality(self) -> None:
|
||||
"""Timeout enforcement should match criticality level."""
|
||||
assessor = CriticalityAssessor()
|
||||
|
||||
# CRITICAL should have 5s timeout
|
||||
critical_timeout = assessor.get_timeout(CriticalityLevel.CRITICAL)
|
||||
assert critical_timeout == 5.0
|
||||
|
||||
# HIGH should have 30s timeout
|
||||
high_timeout = assessor.get_timeout(CriticalityLevel.HIGH)
|
||||
assert high_timeout == 30.0
|
||||
|
||||
# NORMAL should have 60s timeout
|
||||
normal_timeout = assessor.get_timeout(CriticalityLevel.NORMAL)
|
||||
assert normal_timeout == 60.0
|
||||
|
||||
# LOW should have no timeout
|
||||
low_timeout = assessor.get_timeout(CriticalityLevel.LOW)
|
||||
assert low_timeout is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fast_path_execution_for_critical(self) -> None:
|
||||
"""CRITICAL tasks should complete quickly."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Create a fast callback simulating fast-path execution
|
||||
async def fast_path_callback() -> str:
|
||||
# Simulate simplified decision flow
|
||||
await asyncio.sleep(0.01) # Very fast execution
|
||||
return "fast_path_complete"
|
||||
|
||||
task = PriorityTask(
|
||||
priority=0, # CRITICAL
|
||||
timestamp=0.0,
|
||||
task_id="critical-fast",
|
||||
task_data={},
|
||||
callback=fast_path_callback,
|
||||
)
|
||||
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
result = await queue.execute_with_timeout(task, timeout=5.0)
|
||||
elapsed = time.time() - start
|
||||
|
||||
assert result == "fast_path_complete"
|
||||
assert elapsed < 5.0 # Should complete well under CRITICAL timeout
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graceful_degradation_when_queue_full(self) -> None:
|
||||
"""System should gracefully handle full queue."""
|
||||
queue = PriorityTaskQueue(max_size=2)
|
||||
|
||||
# Fill the queue
|
||||
await queue.enqueue("task-1", CriticalityLevel.NORMAL, {})
|
||||
await queue.enqueue("task-2", CriticalityLevel.NORMAL, {})
|
||||
|
||||
# Try to add more tasks
|
||||
success = await queue.enqueue("task-3", CriticalityLevel.NORMAL, {})
|
||||
assert success is False
|
||||
|
||||
# Queue should still function
|
||||
task = await queue.dequeue(timeout=1.0)
|
||||
assert task is not None
|
||||
|
||||
# Now we can add another task
|
||||
success = await queue.enqueue("task-4", CriticalityLevel.NORMAL, {})
|
||||
assert success is True
|
||||
tests/test_logging_config.py (new file, 117 lines)
@@ -0,0 +1,117 @@
|
||||
"""Tests for JSON structured logging configuration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from src.logging_config import JSONFormatter, setup_logging
|
||||
|
||||
|
||||
class TestJSONFormatter:
|
||||
"""Test JSONFormatter output."""
|
||||
|
||||
def test_basic_log_record(self) -> None:
|
||||
"""JSONFormatter must emit valid JSON with required fields."""
|
||||
formatter = JSONFormatter()
|
||||
record = logging.LogRecord(
|
||||
name="test.logger",
|
||||
level=logging.INFO,
|
||||
pathname="",
|
||||
lineno=0,
|
||||
msg="Hello %s",
|
||||
args=("world",),
|
||||
exc_info=None,
|
||||
)
|
||||
output = formatter.format(record)
|
||||
data = json.loads(output)
|
||||
assert data["level"] == "INFO"
|
||||
assert data["logger"] == "test.logger"
|
||||
assert data["message"] == "Hello world"
|
||||
assert "timestamp" in data
|
||||
|
||||
def test_includes_exception_info(self) -> None:
|
||||
"""JSONFormatter must include exception info when present."""
|
||||
formatter = JSONFormatter()
|
||||
try:
|
||||
raise ValueError("test error")
|
||||
except ValueError:
|
||||
exc_info = sys.exc_info()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.ERROR,
|
||||
pathname="",
|
||||
lineno=0,
|
||||
msg="oops",
|
||||
args=(),
|
||||
exc_info=exc_info,
|
||||
)
|
||||
output = formatter.format(record)
|
||||
data = json.loads(output)
|
||||
assert "exception" in data
|
||||
assert "ValueError" in data["exception"]
|
||||
|
||||
def test_extra_trading_fields_included(self) -> None:
|
||||
"""Extra trading fields attached to the record must appear in JSON."""
|
||||
formatter = JSONFormatter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="",
|
||||
lineno=0,
|
||||
msg="trade",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
record.stock_code = "005930" # type: ignore[attr-defined]
|
||||
record.action = "BUY" # type: ignore[attr-defined]
|
||||
record.confidence = 85 # type: ignore[attr-defined]
|
||||
record.pnl_pct = -1.5 # type: ignore[attr-defined]
|
||||
record.order_amount = 1_000_000 # type: ignore[attr-defined]
|
||||
output = formatter.format(record)
|
||||
data = json.loads(output)
|
||||
assert data["stock_code"] == "005930"
|
||||
assert data["action"] == "BUY"
|
||||
assert data["confidence"] == 85
|
||||
assert data["pnl_pct"] == -1.5
|
||||
assert data["order_amount"] == 1_000_000
|
||||
|
||||
def test_none_extra_fields_excluded(self) -> None:
|
||||
"""Extra fields that are None must not appear in JSON output."""
|
||||
formatter = JSONFormatter()
|
||||
record = logging.LogRecord(
|
||||
name="test",
|
||||
level=logging.INFO,
|
||||
pathname="",
|
||||
lineno=0,
|
||||
msg="no extras",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
output = formatter.format(record)
|
||||
data = json.loads(output)
|
||||
assert "stock_code" not in data
|
||||
assert "action" not in data
|
||||
assert "confidence" not in data
|
||||
|
||||
|
||||
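A formatter satisfying the four tests above can be quite small: a fixed envelope of timestamp/level/logger/message, the formatted exception when present, and a whitelist of trading fields copied only when they are set. The extra field names come from the tests; everything else below is an assumption about src.logging_config.JSONFormatter.

```python
# Hedged sketch; not the project's JSONFormatter, just one implementation that
# would satisfy the assertions above.
import json
import logging
from datetime import datetime, timezone

_EXTRA_FIELDS = ("stock_code", "action", "confidence", "pnl_pct", "order_amount")


class SketchJSONFormatter(logging.Formatter):
    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        if record.exc_info:
            payload["exception"] = self.formatException(record.exc_info)
        for field in _EXTRA_FIELDS:
            value = getattr(record, field, None)
            if value is not None:
                payload[field] = value
        return json.dumps(payload)
```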
class TestSetupLogging:
|
||||
"""Test setup_logging function."""
|
||||
|
||||
def test_configures_root_logger(self) -> None:
|
||||
"""setup_logging must attach a JSON handler to the root logger."""
|
||||
setup_logging(level=logging.DEBUG)
|
||||
root = logging.getLogger()
|
||||
json_handlers = [
|
||||
h for h in root.handlers if isinstance(h.formatter, JSONFormatter)
|
||||
]
|
||||
assert len(json_handlers) == 1
|
||||
assert root.level == logging.DEBUG
|
||||
|
||||
def test_avoids_duplicate_handlers(self) -> None:
|
||||
"""Calling setup_logging twice must not add duplicate handlers."""
|
||||
setup_logging()
|
||||
setup_logging()
|
||||
root = logging.getLogger()
|
||||
assert len(root.handlers) == 1
|
||||
tests/test_main.py (new file, 4600 lines)
File diff suppressed because it is too large
@@ -7,6 +7,7 @@ import pytest
|
||||
|
||||
from src.markets.schedule import (
|
||||
MARKETS,
|
||||
expand_market_codes,
|
||||
get_next_market_open,
|
||||
get_open_markets,
|
||||
is_market_open,
|
||||
@@ -199,3 +200,28 @@ class TestGetNextMarketOpen:
|
||||
enabled_markets=["INVALID", "KR"], now=test_time
|
||||
)
|
||||
assert market.code == "KR"
|
||||
|
||||
|
||||
class TestExpandMarketCodes:
|
||||
"""Test shorthand market expansion."""
|
||||
|
||||
def test_expand_us_shorthand(self) -> None:
|
||||
assert expand_market_codes(["US"]) == ["US_NASDAQ", "US_NYSE", "US_AMEX"]
|
||||
|
||||
def test_expand_cn_shorthand(self) -> None:
|
||||
assert expand_market_codes(["CN"]) == ["CN_SHA", "CN_SZA"]
|
||||
|
||||
def test_expand_vn_shorthand(self) -> None:
|
||||
assert expand_market_codes(["VN"]) == ["VN_HAN", "VN_HCM"]
|
||||
|
||||
def test_expand_mixed_codes(self) -> None:
|
||||
assert expand_market_codes(["KR", "US", "JP"]) == [
|
||||
"KR",
|
||||
"US_NASDAQ",
|
||||
"US_NYSE",
|
||||
"US_AMEX",
|
||||
"JP",
|
||||
]
|
||||
|
||||
def test_expand_preserves_unknown_code(self) -> None:
|
||||
assert expand_market_codes(["KR", "UNKNOWN"]) == ["KR", "UNKNOWN"]
|
||||
|
||||
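These expansion tests describe a small lookup: region shorthands map to their exchange codes and anything unrecognised passes through unchanged. A sketch covering only what the tests show (the mapping itself is otherwise an assumption):

```python
# Sketch of the expansion rule; mapping limited to what the tests above exercise.
_SHORTHANDS = {
    "US": ["US_NASDAQ", "US_NYSE", "US_AMEX"],
    "CN": ["CN_SHA", "CN_SZA"],
    "VN": ["VN_HAN", "VN_HCM"],
}


def expand_market_codes(codes: list[str]) -> list[str]:
    expanded: list[str] = []
    for code in codes:
        expanded.extend(_SHORTHANDS.get(code, [code]))
    return expanded


assert expand_market_codes(["KR", "US", "JP"]) == ["KR", "US_NASDAQ", "US_NYSE", "US_AMEX", "JP"]
```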
tests/test_overseas_broker.py (new file, 1033 lines)
File diff suppressed because it is too large

tests/test_playbook_store.py (new file, 289 lines)
@@ -0,0 +1,289 @@
|
||||
"""Tests for playbook persistence (PlaybookStore + DB schema)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
|
||||
import pytest
|
||||
|
||||
from src.db import init_db
|
||||
from src.strategy.models import (
|
||||
DayPlaybook,
|
||||
GlobalRule,
|
||||
MarketOutlook,
|
||||
PlaybookStatus,
|
||||
ScenarioAction,
|
||||
StockCondition,
|
||||
StockPlaybook,
|
||||
StockScenario,
|
||||
)
|
||||
from src.strategy.playbook_store import PlaybookStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conn():
|
||||
"""Create an in-memory DB with schema."""
|
||||
connection = init_db(":memory:")
|
||||
yield connection
|
||||
connection.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(conn) -> PlaybookStore:
|
||||
return PlaybookStore(conn)
|
||||
|
||||
|
||||
def _make_playbook(
|
||||
target_date: date = date(2026, 2, 8),
|
||||
market: str = "KR",
|
||||
outlook: MarketOutlook = MarketOutlook.NEUTRAL,
|
||||
stock_codes: list[str] | None = None,
|
||||
) -> DayPlaybook:
|
||||
"""Create a test playbook with sensible defaults."""
|
||||
if stock_codes is None:
|
||||
stock_codes = ["005930"]
|
||||
return DayPlaybook(
|
||||
date=target_date,
|
||||
market=market,
|
||||
market_outlook=outlook,
|
||||
token_count=150,
|
||||
stock_playbooks=[
|
||||
StockPlaybook(
|
||||
stock_code=code,
|
||||
scenarios=[
|
||||
StockScenario(
|
||||
condition=StockCondition(rsi_below=30.0),
|
||||
action=ScenarioAction.BUY,
|
||||
confidence=85,
|
||||
rationale=f"Oversold bounce for {code}",
|
||||
),
|
||||
],
|
||||
)
|
||||
for code in stock_codes
|
||||
],
|
||||
global_rules=[
|
||||
GlobalRule(
|
||||
condition="portfolio_pnl_pct < -2.0",
|
||||
action=ScenarioAction.REDUCE_ALL,
|
||||
rationale="Near circuit breaker",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSchema:
|
||||
def test_playbooks_table_exists(self, conn) -> None:
|
||||
row = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='playbooks'"
|
||||
).fetchone()
|
||||
assert row is not None
|
||||
|
||||
def test_unique_constraint(self, store: PlaybookStore) -> None:
|
||||
pb = _make_playbook()
|
||||
store.save(pb)
|
||||
# Saving again for same date+market should replace, not error
|
||||
pb2 = _make_playbook(stock_codes=["005930", "000660"])
|
||||
store.save(pb2)
|
||||
loaded = store.load(date(2026, 2, 8), "KR")
|
||||
assert loaded is not None
|
||||
assert loaded.stock_count == 2
|
||||
|
||||
|
||||
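The unique-constraint test above expects a second save for the same date and market to replace the first row rather than raise. One way to get that behaviour in SQLite is a UNIQUE(date, market) constraint combined with INSERT OR REPLACE, illustrated below with hypothetical column names; the real playbooks schema created by init_db may differ.

```python
# Standalone illustration of "save replaces on (date, market)"; table and column
# names are assumptions, not the project's actual schema.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """CREATE TABLE playbooks (
           id INTEGER PRIMARY KEY,
           date TEXT NOT NULL,
           market TEXT NOT NULL,
           payload TEXT NOT NULL,
           UNIQUE (date, market)
       )"""
)
for payload in ('{"stock_count": 1}', '{"stock_count": 2}'):
    conn.execute(
        "INSERT OR REPLACE INTO playbooks (date, market, payload) VALUES (?, ?, ?)",
        ("2026-02-08", "KR", payload),
    )
assert conn.execute("SELECT COUNT(*) FROM playbooks").fetchone()[0] == 1
```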
# ---------------------------------------------------------------------------
|
||||
# Save / Load
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSaveLoad:
|
||||
def test_save_and_load(self, store: PlaybookStore) -> None:
|
||||
pb = _make_playbook()
|
||||
row_id = store.save(pb)
|
||||
assert row_id > 0
|
||||
|
||||
loaded = store.load(date(2026, 2, 8), "KR")
|
||||
assert loaded is not None
|
||||
assert loaded.date == date(2026, 2, 8)
|
||||
assert loaded.market == "KR"
|
||||
assert loaded.stock_count == 1
|
||||
assert loaded.scenario_count == 1
|
||||
|
||||
def test_load_not_found(self, store: PlaybookStore) -> None:
|
||||
result = store.load(date(2026, 1, 1), "KR")
|
||||
assert result is None
|
||||
|
||||
def test_save_preserves_all_fields(self, store: PlaybookStore) -> None:
|
||||
pb = _make_playbook(
|
||||
outlook=MarketOutlook.BULLISH,
|
||||
stock_codes=["005930", "AAPL"],
|
||||
)
|
||||
store.save(pb)
|
||||
loaded = store.load(date(2026, 2, 8), "KR")
|
||||
assert loaded is not None
|
||||
assert loaded.market_outlook == MarketOutlook.BULLISH
|
||||
assert loaded.stock_count == 2
|
||||
assert loaded.global_rules[0].action == ScenarioAction.REDUCE_ALL
|
||||
assert loaded.token_count == 150
|
||||
|
||||
def test_save_different_markets(self, store: PlaybookStore) -> None:
|
||||
kr = _make_playbook(market="KR")
|
||||
us = _make_playbook(market="US", stock_codes=["AAPL"])
|
||||
store.save(kr)
|
||||
store.save(us)
|
||||
|
||||
kr_loaded = store.load(date(2026, 2, 8), "KR")
|
||||
us_loaded = store.load(date(2026, 2, 8), "US")
|
||||
assert kr_loaded is not None
|
||||
assert us_loaded is not None
|
||||
assert kr_loaded.market == "KR"
|
||||
assert us_loaded.market == "US"
|
||||
assert kr_loaded.stock_playbooks[0].stock_code == "005930"
|
||||
assert us_loaded.stock_playbooks[0].stock_code == "AAPL"
|
||||
|
||||
def test_save_different_dates(self, store: PlaybookStore) -> None:
|
||||
d1 = _make_playbook(target_date=date(2026, 2, 7))
|
||||
d2 = _make_playbook(target_date=date(2026, 2, 8))
|
||||
store.save(d1)
|
||||
store.save(d2)
|
||||
|
||||
assert store.load(date(2026, 2, 7), "KR") is not None
|
||||
assert store.load(date(2026, 2, 8), "KR") is not None
|
||||
|
||||
def test_replace_updates_data(self, store: PlaybookStore) -> None:
|
||||
pb1 = _make_playbook(outlook=MarketOutlook.BEARISH)
|
||||
store.save(pb1)
|
||||
|
||||
pb2 = _make_playbook(outlook=MarketOutlook.BULLISH)
|
||||
store.save(pb2)
|
||||
|
||||
loaded = store.load(date(2026, 2, 8), "KR")
|
||||
assert loaded is not None
|
||||
assert loaded.market_outlook == MarketOutlook.BULLISH
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStatus:
|
||||
def test_get_status(self, store: PlaybookStore) -> None:
|
||||
store.save(_make_playbook())
|
||||
status = store.get_status(date(2026, 2, 8), "KR")
|
||||
assert status == PlaybookStatus.READY
|
||||
|
||||
def test_get_status_not_found(self, store: PlaybookStore) -> None:
|
||||
assert store.get_status(date(2026, 1, 1), "KR") is None
|
||||
|
||||
def test_update_status(self, store: PlaybookStore) -> None:
|
||||
store.save(_make_playbook())
|
||||
updated = store.update_status(date(2026, 2, 8), "KR", PlaybookStatus.EXPIRED)
|
||||
assert updated is True
|
||||
|
||||
status = store.get_status(date(2026, 2, 8), "KR")
|
||||
assert status == PlaybookStatus.EXPIRED
|
||||
|
||||
def test_update_status_not_found(self, store: PlaybookStore) -> None:
|
||||
updated = store.update_status(date(2026, 1, 1), "KR", PlaybookStatus.FAILED)
|
||||
assert updated is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Match count
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMatchCount:
|
||||
def test_increment_match_count(self, store: PlaybookStore) -> None:
|
||||
store.save(_make_playbook())
|
||||
store.increment_match_count(date(2026, 2, 8), "KR")
|
||||
store.increment_match_count(date(2026, 2, 8), "KR")
|
||||
|
||||
stats = store.get_stats(date(2026, 2, 8), "KR")
|
||||
assert stats is not None
|
||||
assert stats["match_count"] == 2
|
||||
|
||||
def test_increment_not_found(self, store: PlaybookStore) -> None:
|
||||
result = store.increment_match_count(date(2026, 1, 1), "KR")
|
||||
assert result is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stats
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStats:
|
||||
def test_get_stats(self, store: PlaybookStore) -> None:
|
||||
store.save(_make_playbook())
|
||||
stats = store.get_stats(date(2026, 2, 8), "KR")
|
||||
assert stats is not None
|
||||
assert stats["status"] == "ready"
|
||||
assert stats["token_count"] == 150
|
||||
assert stats["scenario_count"] == 1
|
||||
assert stats["match_count"] == 0
|
||||
assert stats["generated_at"] != ""
|
||||
|
||||
def test_get_stats_not_found(self, store: PlaybookStore) -> None:
|
||||
assert store.get_stats(date(2026, 1, 1), "KR") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# List recent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListRecent:
|
||||
def test_list_recent(self, store: PlaybookStore) -> None:
|
||||
for day in range(5, 10):
|
||||
store.save(_make_playbook(target_date=date(2026, 2, day)))
|
||||
results = store.list_recent(market="KR", limit=3)
|
||||
assert len(results) == 3
|
||||
# Most recent first
|
||||
assert results[0]["date"] == "2026-02-09"
|
||||
assert results[2]["date"] == "2026-02-07"
|
||||
|
||||
def test_list_recent_all_markets(self, store: PlaybookStore) -> None:
|
||||
store.save(_make_playbook(market="KR"))
|
||||
store.save(_make_playbook(market="US", stock_codes=["AAPL"]))
|
||||
results = store.list_recent(market=None, limit=10)
|
||||
assert len(results) == 2
|
||||
|
||||
def test_list_recent_empty(self, store: PlaybookStore) -> None:
|
||||
results = store.list_recent(market="KR")
|
||||
assert results == []
|
||||
|
||||
def test_list_recent_filter_by_market(self, store: PlaybookStore) -> None:
|
||||
store.save(_make_playbook(market="KR"))
|
||||
store.save(_make_playbook(market="US", stock_codes=["AAPL"]))
|
||||
kr_only = store.list_recent(market="KR")
|
||||
assert len(kr_only) == 1
|
||||
assert kr_only[0]["market"] == "KR"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Delete
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDelete:
|
||||
def test_delete(self, store: PlaybookStore) -> None:
|
||||
store.save(_make_playbook())
|
||||
deleted = store.delete(date(2026, 2, 8), "KR")
|
||||
assert deleted is True
|
||||
assert store.load(date(2026, 2, 8), "KR") is None
|
||||
|
||||
def test_delete_not_found(self, store: PlaybookStore) -> None:
|
||||
deleted = store.delete(date(2026, 1, 1), "KR")
|
||||
assert deleted is False
|
||||
|
||||
def test_delete_one_market_keeps_other(self, store: PlaybookStore) -> None:
|
||||
store.save(_make_playbook(market="KR"))
|
||||
store.save(_make_playbook(market="US", stock_codes=["AAPL"]))
|
||||
store.delete(date(2026, 2, 8), "KR")
|
||||
assert store.load(date(2026, 2, 8), "KR") is None
|
||||
assert store.load(date(2026, 2, 8), "US") is not None
|
||||
tests/test_pre_market_planner.py (new file, 1000 lines)
File diff suppressed because it is too large

tests/test_scenario_engine.py (new file, 574 lines)
@@ -0,0 +1,574 @@
|
||||
"""Tests for the local scenario engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
|
||||
import pytest
|
||||
|
||||
from src.strategy.models import (
|
||||
DayPlaybook,
|
||||
GlobalRule,
|
||||
ScenarioAction,
|
||||
StockCondition,
|
||||
StockPlaybook,
|
||||
StockScenario,
|
||||
)
|
||||
from src.strategy.scenario_engine import ScenarioEngine, ScenarioMatch
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine() -> ScenarioEngine:
|
||||
return ScenarioEngine()
|
||||
|
||||
|
||||
def _scenario(
|
||||
rsi_below: float | None = None,
|
||||
rsi_above: float | None = None,
|
||||
volume_ratio_above: float | None = None,
|
||||
action: ScenarioAction = ScenarioAction.BUY,
|
||||
confidence: int = 85,
|
||||
**kwargs,
|
||||
) -> StockScenario:
|
||||
return StockScenario(
|
||||
condition=StockCondition(
|
||||
rsi_below=rsi_below,
|
||||
rsi_above=rsi_above,
|
||||
volume_ratio_above=volume_ratio_above,
|
||||
**kwargs,
|
||||
),
|
||||
action=action,
|
||||
confidence=confidence,
|
||||
rationale=f"Test scenario: {action.value}",
|
||||
)
|
||||
|
||||
|
||||
def _playbook(
|
||||
stock_code: str = "005930",
|
||||
scenarios: list[StockScenario] | None = None,
|
||||
global_rules: list[GlobalRule] | None = None,
|
||||
default_action: ScenarioAction = ScenarioAction.HOLD,
|
||||
) -> DayPlaybook:
|
||||
if scenarios is None:
|
||||
scenarios = [_scenario(rsi_below=30.0)]
|
||||
return DayPlaybook(
|
||||
date=date(2026, 2, 7),
|
||||
market="KR",
|
||||
stock_playbooks=[StockPlaybook(stock_code=stock_code, scenarios=scenarios)],
|
||||
global_rules=global_rules or [],
|
||||
default_action=default_action,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# evaluate_condition
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEvaluateCondition:
|
||||
def test_rsi_below_match(self, engine: ScenarioEngine) -> None:
|
||||
cond = StockCondition(rsi_below=30.0)
|
||||
assert engine.evaluate_condition(cond, {"rsi": 25.0})
|
||||
|
||||
def test_rsi_below_no_match(self, engine: ScenarioEngine) -> None:
|
||||
cond = StockCondition(rsi_below=30.0)
|
||||
assert not engine.evaluate_condition(cond, {"rsi": 35.0})
|
||||
|
||||
def test_rsi_above_match(self, engine: ScenarioEngine) -> None:
|
||||
cond = StockCondition(rsi_above=70.0)
|
||||
assert engine.evaluate_condition(cond, {"rsi": 75.0})
|
||||
|
||||
def test_rsi_above_no_match(self, engine: ScenarioEngine) -> None:
|
||||
cond = StockCondition(rsi_above=70.0)
|
||||
assert not engine.evaluate_condition(cond, {"rsi": 65.0})
|
||||
|
||||
def test_volume_ratio_above_match(self, engine: ScenarioEngine) -> None:
|
||||
cond = StockCondition(volume_ratio_above=3.0)
|
||||
assert engine.evaluate_condition(cond, {"volume_ratio": 4.5})
|
||||
|
||||
def test_volume_ratio_below_match(self, engine: ScenarioEngine) -> None:
|
||||
cond = StockCondition(volume_ratio_below=1.0)
|
||||
assert engine.evaluate_condition(cond, {"volume_ratio": 0.5})
|
||||
|
||||
def test_price_above_match(self, engine: ScenarioEngine) -> None:
|
||||
cond = StockCondition(price_above=50000)
|
||||
assert engine.evaluate_condition(cond, {"current_price": 55000})
|
||||
|
||||
def test_price_below_match(self, engine: ScenarioEngine) -> None:
|
||||
cond = StockCondition(price_below=50000)
|
||||
assert engine.evaluate_condition(cond, {"current_price": 45000})
|
||||
|
||||
def test_price_change_pct_above_match(self, engine: ScenarioEngine) -> None:
|
||||
cond = StockCondition(price_change_pct_above=2.0)
|
||||
assert engine.evaluate_condition(cond, {"price_change_pct": 3.5})
|
||||
|
||||
def test_price_change_pct_below_match(self, engine: ScenarioEngine) -> None:
|
||||
cond = StockCondition(price_change_pct_below=-3.0)
|
||||
assert engine.evaluate_condition(cond, {"price_change_pct": -4.0})
|
||||
|
||||
def test_multiple_conditions_and_logic(self, engine: ScenarioEngine) -> None:
|
||||
cond = StockCondition(rsi_below=30.0, volume_ratio_above=3.0)
|
||||
# Both met
|
||||
assert engine.evaluate_condition(cond, {"rsi": 25.0, "volume_ratio": 4.0})
|
||||
# Only RSI met
|
||||
assert not engine.evaluate_condition(cond, {"rsi": 25.0, "volume_ratio": 2.0})
|
||||
# Only volume met
|
||||
assert not engine.evaluate_condition(cond, {"rsi": 35.0, "volume_ratio": 4.0})
|
||||
# Neither met
|
||||
assert not engine.evaluate_condition(cond, {"rsi": 35.0, "volume_ratio": 2.0})
|
||||
|
||||
def test_empty_condition_returns_false(self, engine: ScenarioEngine) -> None:
|
||||
cond = StockCondition()
|
||||
assert not engine.evaluate_condition(cond, {"rsi": 25.0})
|
||||
|
||||
def test_missing_data_returns_false(self, engine: ScenarioEngine) -> None:
|
||||
cond = StockCondition(rsi_below=30.0)
|
||||
assert not engine.evaluate_condition(cond, {})
|
||||
|
||||
def test_none_data_returns_false(self, engine: ScenarioEngine) -> None:
|
||||
cond = StockCondition(rsi_below=30.0)
|
||||
assert not engine.evaluate_condition(cond, {"rsi": None})
|
||||
|
||||
def test_boundary_value_not_matched(self, engine: ScenarioEngine) -> None:
|
||||
"""rsi_below=30 should NOT match rsi=30 (strict less than)."""
|
||||
cond = StockCondition(rsi_below=30.0)
|
||||
assert not engine.evaluate_condition(cond, {"rsi": 30.0})
|
||||
|
||||
def test_boundary_value_above_not_matched(self, engine: ScenarioEngine) -> None:
|
||||
"""rsi_above=70 should NOT match rsi=70 (strict greater than)."""
|
||||
cond = StockCondition(rsi_above=70.0)
|
||||
assert not engine.evaluate_condition(cond, {"rsi": 70.0})
|
||||
|
||||
def test_string_value_no_exception(self, engine: ScenarioEngine) -> None:
|
||||
"""String numeric value should not raise TypeError."""
|
||||
cond = StockCondition(rsi_below=30.0)
|
||||
# "25" can be cast to float → should match
|
||||
assert engine.evaluate_condition(cond, {"rsi": "25"})
|
||||
# "35" → should not match
|
||||
assert not engine.evaluate_condition(cond, {"rsi": "35"})
|
||||
|
||||
def test_percent_string_returns_false(self, engine: ScenarioEngine) -> None:
|
||||
"""Percent string like '30%' cannot be cast to float → False, no exception."""
|
||||
cond = StockCondition(rsi_below=30.0)
|
||||
assert not engine.evaluate_condition(cond, {"rsi": "30%"})
|
||||
|
||||
def test_decimal_value_no_exception(self, engine: ScenarioEngine) -> None:
|
||||
"""Decimal values should be safely handled."""
|
||||
from decimal import Decimal
|
||||
|
||||
cond = StockCondition(rsi_below=30.0)
|
||||
assert engine.evaluate_condition(cond, {"rsi": Decimal("25.0")})
|
||||
|
||||
def test_mixed_invalid_types_no_exception(self, engine: ScenarioEngine) -> None:
|
||||
"""Various invalid types should not raise exceptions."""
|
||||
cond = StockCondition(
|
||||
rsi_below=30.0, volume_ratio_above=2.0,
|
||||
price_above=100, price_change_pct_below=-1.0,
|
||||
)
|
||||
data = {
|
||||
"rsi": [25], # list
|
||||
"volume_ratio": "bad", # non-numeric string
|
||||
"current_price": {}, # dict
|
||||
"price_change_pct": object(), # arbitrary object
|
||||
}
|
||||
# Should return False (invalid types → None → False), never raise
|
||||
assert not engine.evaluate_condition(cond, data)
|
||||
|
||||
def test_missing_key_logs_warning_once(self, caplog) -> None:
|
||||
"""Missing key warning should fire only once per key per engine instance."""
|
||||
import logging
|
||||
|
||||
eng = ScenarioEngine()
|
||||
cond = StockCondition(rsi_below=30.0)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
eng.evaluate_condition(cond, {})
|
||||
eng.evaluate_condition(cond, {})
|
||||
eng.evaluate_condition(cond, {})
|
||||
# Warning should appear exactly once despite 3 calls
|
||||
assert caplog.text.count("'rsi' but key missing") == 1
|
||||
|
||||
|
||||
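The type-robustness tests above (numeric strings compare, '30%' and lists silently fail, Decimals work) suggest every raw value goes through a _safe_float-style coercion before comparison. The helper below is a guess at that behaviour, not the engine's actual implementation.

```python
# Hypothetical coercion helper consistent with the tests above.
from __future__ import annotations

from decimal import Decimal


def safe_float(value: object) -> float | None:
    if value is None or isinstance(value, bool):
        return None
    if isinstance(value, (int, float, Decimal)):
        return float(value)
    if isinstance(value, str):
        try:
            return float(value)
        except ValueError:
            return None
    return None  # lists, dicts, arbitrary objects: the condition simply fails


assert safe_float("25") == 25.0
assert safe_float(Decimal("25.0")) == 25.0
assert safe_float("30%") is None
assert safe_float([25]) is None
```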
# ---------------------------------------------------------------------------
|
||||
# check_global_rules
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckGlobalRules:
|
||||
def test_no_rules(self, engine: ScenarioEngine) -> None:
|
||||
pb = _playbook(global_rules=[])
|
||||
result = engine.check_global_rules(pb, {"portfolio_pnl_pct": -1.0})
|
||||
assert result is None
|
||||
|
||||
def test_rule_triggered(self, engine: ScenarioEngine) -> None:
|
||||
pb = _playbook(
|
||||
global_rules=[
|
||||
GlobalRule(
|
||||
condition="portfolio_pnl_pct < -2.0",
|
||||
action=ScenarioAction.REDUCE_ALL,
|
||||
rationale="Near circuit breaker",
|
||||
),
|
||||
]
|
||||
)
|
||||
result = engine.check_global_rules(pb, {"portfolio_pnl_pct": -2.5})
|
||||
assert result is not None
|
||||
assert result.action == ScenarioAction.REDUCE_ALL
|
||||
|
||||
def test_rule_not_triggered(self, engine: ScenarioEngine) -> None:
|
||||
pb = _playbook(
|
||||
global_rules=[
|
||||
GlobalRule(
|
||||
condition="portfolio_pnl_pct < -2.0",
|
||||
action=ScenarioAction.REDUCE_ALL,
|
||||
),
|
||||
]
|
||||
)
|
||||
result = engine.check_global_rules(pb, {"portfolio_pnl_pct": -1.0})
|
||||
assert result is None
|
||||
|
||||
def test_first_rule_wins(self, engine: ScenarioEngine) -> None:
|
||||
pb = _playbook(
|
||||
global_rules=[
|
||||
GlobalRule(condition="portfolio_pnl_pct < -2.0", action=ScenarioAction.REDUCE_ALL),
|
||||
GlobalRule(condition="portfolio_pnl_pct < -1.0", action=ScenarioAction.HOLD),
|
||||
]
|
||||
)
|
||||
result = engine.check_global_rules(pb, {"portfolio_pnl_pct": -2.5})
|
||||
assert result is not None
|
||||
assert result.action == ScenarioAction.REDUCE_ALL
|
||||
|
||||
def test_greater_than_operator(self, engine: ScenarioEngine) -> None:
|
||||
pb = _playbook(
|
||||
global_rules=[
|
||||
GlobalRule(condition="volatility_index > 30", action=ScenarioAction.HOLD),
|
||||
]
|
||||
)
|
||||
result = engine.check_global_rules(pb, {"volatility_index": 35})
|
||||
assert result is not None
|
||||
|
||||
def test_missing_field_not_triggered(self, engine: ScenarioEngine) -> None:
|
||||
pb = _playbook(
|
||||
global_rules=[
|
||||
GlobalRule(condition="unknown_field < -2.0", action=ScenarioAction.REDUCE_ALL),
|
||||
]
|
||||
)
|
||||
result = engine.check_global_rules(pb, {"portfolio_pnl_pct": -5.0})
|
||||
assert result is None
|
||||
|
||||
def test_invalid_condition_format(self, engine: ScenarioEngine) -> None:
|
||||
pb = _playbook(
|
||||
global_rules=[
|
||||
GlobalRule(condition="bad format", action=ScenarioAction.HOLD),
|
||||
]
|
||||
)
|
||||
result = engine.check_global_rules(pb, {})
|
||||
assert result is None
|
||||
|
||||
def test_le_operator(self, engine: ScenarioEngine) -> None:
|
||||
pb = _playbook(
|
||||
global_rules=[
|
||||
GlobalRule(condition="portfolio_pnl_pct <= -2.0", action=ScenarioAction.REDUCE_ALL),
|
||||
]
|
||||
)
|
||||
assert engine.check_global_rules(pb, {"portfolio_pnl_pct": -2.0}) is not None
|
||||
assert engine.check_global_rules(pb, {"portfolio_pnl_pct": -1.9}) is None
|
||||
|
||||
def test_ge_operator(self, engine: ScenarioEngine) -> None:
|
||||
pb = _playbook(
|
||||
global_rules=[
|
||||
GlobalRule(condition="volatility >= 80.0", action=ScenarioAction.HOLD),
|
||||
]
|
||||
)
|
||||
assert engine.check_global_rules(pb, {"volatility": 80.0}) is not None
|
||||
assert engine.check_global_rules(pb, {"volatility": 79.9}) is None
|
||||
|
||||
|
||||
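The operator tests above (<, >, <=, >=, with malformed strings and unknown fields ignored) imply a tiny "field op number" grammar for global rule conditions. A sketch under that assumption; the engine's real parser may be stricter or richer.

```python
# Sketch of the assumed "field op number" rule grammar; missing fields or malformed
# conditions never trigger, they just return False.
import operator

_OPS = {"<=": operator.le, ">=": operator.ge, "<": operator.lt, ">": operator.gt}


def rule_triggered(condition: str, data: dict) -> bool:
    parts = condition.split()
    if len(parts) != 3 or parts[1] not in _OPS:
        return False
    field, op, threshold = parts
    if field not in data:
        return False
    try:
        return _OPS[op](float(data[field]), float(threshold))
    except (TypeError, ValueError):
        return False


assert rule_triggered("portfolio_pnl_pct < -2.0", {"portfolio_pnl_pct": -2.5})
assert rule_triggered("portfolio_pnl_pct <= -2.0", {"portfolio_pnl_pct": -2.0})
assert not rule_triggered("bad format", {})
assert not rule_triggered("unknown_field < -2.0", {"portfolio_pnl_pct": -5.0})
```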
# ---------------------------------------------------------------------------
|
||||
# evaluate (full pipeline)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEvaluate:
|
||||
def test_scenario_match(self, engine: ScenarioEngine) -> None:
|
||||
pb = _playbook(scenarios=[_scenario(rsi_below=30.0)])
|
||||
result = engine.evaluate(pb, "005930", {"rsi": 25.0}, {})
|
||||
assert result.action == ScenarioAction.BUY
|
||||
assert result.confidence == 85
|
||||
assert result.matched_scenario is not None
|
||||
|
||||
def test_no_scenario_match_returns_default(self, engine: ScenarioEngine) -> None:
|
||||
pb = _playbook(scenarios=[_scenario(rsi_below=30.0)])
|
||||
result = engine.evaluate(pb, "005930", {"rsi": 50.0}, {})
|
||||
assert result.action == ScenarioAction.HOLD
|
||||
assert result.confidence == 0
|
||||
assert result.matched_scenario is None
|
||||
|
||||
def test_stock_not_in_playbook(self, engine: ScenarioEngine) -> None:
|
||||
pb = _playbook(stock_code="005930")
|
||||
result = engine.evaluate(pb, "AAPL", {"rsi": 25.0}, {})
|
||||
assert result.action == ScenarioAction.HOLD
|
||||
assert result.confidence == 0
|
||||
|
||||
def test_global_rule_takes_priority(self, engine: ScenarioEngine) -> None:
|
||||
pb = _playbook(
|
||||
scenarios=[_scenario(rsi_below=30.0)],
|
||||
global_rules=[
|
||||
GlobalRule(
|
||||
condition="portfolio_pnl_pct < -2.0",
|
||||
action=ScenarioAction.REDUCE_ALL,
|
||||
rationale="Loss limit",
|
||||
),
|
||||
],
|
||||
)
|
||||
result = engine.evaluate(
|
||||
pb,
|
||||
"005930",
|
||||
{"rsi": 25.0}, # Would match scenario
|
||||
{"portfolio_pnl_pct": -2.5}, # But global rule triggers first
|
||||
)
|
||||
assert result.action == ScenarioAction.REDUCE_ALL
|
||||
assert result.global_rule_triggered is not None
|
||||
assert result.matched_scenario is None
|
||||
|
||||
def test_first_scenario_wins(self, engine: ScenarioEngine) -> None:
|
||||
pb = _playbook(
|
||||
scenarios=[
|
||||
_scenario(rsi_below=30.0, action=ScenarioAction.BUY, confidence=90),
|
||||
_scenario(rsi_below=25.0, action=ScenarioAction.BUY, confidence=95),
|
||||
]
|
||||
)
|
||||
result = engine.evaluate(pb, "005930", {"rsi": 20.0}, {})
|
||||
# Both match, but first wins
|
||||
assert result.confidence == 90
|
||||
|
||||
def test_sell_scenario(self, engine: ScenarioEngine) -> None:
|
||||
pb = _playbook(
|
||||
scenarios=[
|
||||
_scenario(rsi_above=75.0, action=ScenarioAction.SELL, confidence=80),
|
||||
]
|
||||
)
|
||||
result = engine.evaluate(pb, "005930", {"rsi": 80.0}, {})
|
||||
assert result.action == ScenarioAction.SELL
|
||||
|
||||
def test_empty_playbook(self, engine: ScenarioEngine) -> None:
|
||||
pb = DayPlaybook(date=date(2026, 2, 7), market="KR", stock_playbooks=[])
|
||||
result = engine.evaluate(pb, "005930", {"rsi": 25.0}, {})
|
||||
assert result.action == ScenarioAction.HOLD
|
||||
|
||||
def test_match_details_populated(self, engine: ScenarioEngine) -> None:
|
||||
pb = _playbook(scenarios=[_scenario(rsi_below=30.0, volume_ratio_above=2.0)])
|
||||
result = engine.evaluate(
|
||||
pb, "005930", {"rsi": 25.0, "volume_ratio": 3.0}, {}
|
||||
)
|
||||
assert result.match_details.get("rsi") == 25.0
|
||||
assert result.match_details.get("volume_ratio") == 3.0
|
||||
|
||||
def test_custom_default_action(self, engine: ScenarioEngine) -> None:
|
||||
pb = _playbook(
|
||||
scenarios=[_scenario(rsi_below=10.0)], # Very unlikely to match
|
||||
default_action=ScenarioAction.SELL,
|
||||
)
|
||||
result = engine.evaluate(pb, "005930", {"rsi": 50.0}, {})
|
||||
assert result.action == ScenarioAction.SELL
|
||||
|
||||
def test_multiple_stocks_in_playbook(self, engine: ScenarioEngine) -> None:
|
||||
pb = DayPlaybook(
|
||||
date=date(2026, 2, 7),
|
||||
market="US",
|
||||
stock_playbooks=[
|
||||
StockPlaybook(
|
||||
stock_code="AAPL",
|
||||
scenarios=[_scenario(rsi_below=25.0, confidence=90)],
|
||||
),
|
||||
StockPlaybook(
|
||||
stock_code="MSFT",
|
||||
scenarios=[_scenario(rsi_above=75.0, action=ScenarioAction.SELL, confidence=80)],
|
||||
),
|
||||
],
|
||||
)
|
||||
aapl = engine.evaluate(pb, "AAPL", {"rsi": 20.0}, {})
|
||||
assert aapl.action == ScenarioAction.BUY
|
||||
assert aapl.confidence == 90
|
||||
|
||||
msft = engine.evaluate(pb, "MSFT", {"rsi": 80.0}, {})
|
||||
assert msft.action == ScenarioAction.SELL
|
||||
|
||||
def test_complex_multi_condition(self, engine: ScenarioEngine) -> None:
|
||||
pb = _playbook(
|
||||
scenarios=[
|
||||
_scenario(
|
||||
rsi_below=30.0,
|
||||
volume_ratio_above=3.0,
|
||||
price_change_pct_below=-2.0,
|
||||
confidence=95,
|
||||
),
|
||||
]
|
||||
)
|
||||
# All conditions met
|
||||
result = engine.evaluate(
|
||||
pb,
|
||||
"005930",
|
||||
{"rsi": 22.0, "volume_ratio": 4.0, "price_change_pct": -3.0},
|
||||
{},
|
||||
)
|
||||
assert result.action == ScenarioAction.BUY
|
||||
assert result.confidence == 95
|
||||
|
||||
# One condition not met
|
||||
result2 = engine.evaluate(
|
||||
pb,
|
||||
"005930",
|
||||
{"rsi": 22.0, "volume_ratio": 4.0, "price_change_pct": -1.0},
|
||||
{},
|
||||
)
|
||||
assert result2.action == ScenarioAction.HOLD
|
||||
|
||||
def test_scenario_match_returns_rationale(self, engine: ScenarioEngine) -> None:
|
||||
pb = _playbook(scenarios=[_scenario(rsi_below=30.0)])
|
||||
result = engine.evaluate(pb, "005930", {"rsi": 25.0}, {})
|
||||
assert result.rationale != ""
|
||||
|
||||
def test_result_stock_code(self, engine: ScenarioEngine) -> None:
|
||||
pb = _playbook()
|
||||
result = engine.evaluate(pb, "005930", {"rsi": 25.0}, {})
|
||||
assert result.stock_code == "005930"
|
||||
|
||||
def test_match_details_normalized(self, engine: ScenarioEngine) -> None:
|
||||
"""match_details should contain _safe_float normalized values, not raw."""
|
||||
pb = _playbook(scenarios=[_scenario(rsi_below=30.0)])
|
||||
# Pass string value — should be normalized to float in match_details
|
||||
result = engine.evaluate(pb, "005930", {"rsi": "25.0"}, {})
|
||||
assert result.action == ScenarioAction.BUY
|
||||
assert result.match_details["rsi"] == 25.0
|
||||
assert isinstance(result.match_details["rsi"], float)
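# NOTE: _safe_float is referenced by the docstring above but not shown in this diff;
# a minimal normalizer with the behaviour these assertions rely on (hypothetical
# sketch, not the project's implementation) could be:
def _safe_float(value: object, default: float | None = None) -> float | None:
    """Coerce raw market-data values ('25.0', 25, 25.0) to float, else return default."""
    try:
        return float(value)  # type: ignore[arg-type]
    except (TypeError, ValueError):
        return default

# _safe_float("25.0") == 25.0; _safe_float("n/a") is None; _safe_float(None) is None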
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Position-aware condition tests (#171)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPositionAwareConditions:
|
||||
"""Tests for unrealized_pnl_pct and holding_days condition fields."""
|
||||
|
||||
def test_evaluate_condition_unrealized_pnl_above_matches(
|
||||
self, engine: ScenarioEngine
|
||||
) -> None:
|
||||
"""unrealized_pnl_pct_above should match when P&L exceeds threshold."""
|
||||
condition = StockCondition(unrealized_pnl_pct_above=3.0)
|
||||
assert engine.evaluate_condition(condition, {"unrealized_pnl_pct": 5.0}) is True
|
||||
|
||||
def test_evaluate_condition_unrealized_pnl_above_no_match(
|
||||
self, engine: ScenarioEngine
|
||||
) -> None:
|
||||
"""unrealized_pnl_pct_above should NOT match when P&L is below threshold."""
|
||||
condition = StockCondition(unrealized_pnl_pct_above=3.0)
|
||||
assert engine.evaluate_condition(condition, {"unrealized_pnl_pct": 2.0}) is False
|
||||
|
||||
def test_evaluate_condition_unrealized_pnl_below_matches(
|
||||
self, engine: ScenarioEngine
|
||||
) -> None:
|
||||
"""unrealized_pnl_pct_below should match when P&L is under threshold."""
|
||||
condition = StockCondition(unrealized_pnl_pct_below=-2.0)
|
||||
assert engine.evaluate_condition(condition, {"unrealized_pnl_pct": -3.5}) is True
|
||||
|
||||
def test_evaluate_condition_unrealized_pnl_below_no_match(
|
||||
self, engine: ScenarioEngine
|
||||
) -> None:
|
||||
"""unrealized_pnl_pct_below should NOT match when P&L is above threshold."""
|
||||
condition = StockCondition(unrealized_pnl_pct_below=-2.0)
|
||||
assert engine.evaluate_condition(condition, {"unrealized_pnl_pct": -1.0}) is False
|
||||
|
||||
def test_evaluate_condition_holding_days_above_matches(
|
||||
self, engine: ScenarioEngine
|
||||
) -> None:
|
||||
"""holding_days_above should match when position held longer than threshold."""
|
||||
condition = StockCondition(holding_days_above=5)
|
||||
assert engine.evaluate_condition(condition, {"holding_days": 7}) is True
|
||||
|
||||
def test_evaluate_condition_holding_days_above_no_match(
|
||||
self, engine: ScenarioEngine
|
||||
) -> None:
|
||||
"""holding_days_above should NOT match when position held shorter."""
|
||||
condition = StockCondition(holding_days_above=5)
|
||||
assert engine.evaluate_condition(condition, {"holding_days": 3}) is False
|
||||
|
||||
def test_evaluate_condition_holding_days_below_matches(
|
||||
self, engine: ScenarioEngine
|
||||
) -> None:
|
||||
"""holding_days_below should match when position held fewer days."""
|
||||
condition = StockCondition(holding_days_below=3)
|
||||
assert engine.evaluate_condition(condition, {"holding_days": 1}) is True
|
||||
|
||||
def test_evaluate_condition_holding_days_below_no_match(
|
||||
self, engine: ScenarioEngine
|
||||
) -> None:
|
||||
"""holding_days_below should NOT match when held more days."""
|
||||
condition = StockCondition(holding_days_below=3)
|
||||
assert engine.evaluate_condition(condition, {"holding_days": 5}) is False
|
||||
|
||||
def test_combined_pnl_and_holding_days(self, engine: ScenarioEngine) -> None:
|
||||
"""Combined position-aware conditions should AND-evaluate correctly."""
|
||||
condition = StockCondition(
|
||||
unrealized_pnl_pct_above=3.0,
|
||||
holding_days_above=5,
|
||||
)
|
||||
# Both met → match
|
||||
assert engine.evaluate_condition(
|
||||
condition,
|
||||
{"unrealized_pnl_pct": 4.5, "holding_days": 7},
|
||||
) is True
|
||||
# Only pnl met → no match
|
||||
assert engine.evaluate_condition(
|
||||
condition,
|
||||
{"unrealized_pnl_pct": 4.5, "holding_days": 3},
|
||||
) is False
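# NOTE: sketch of the AND semantics these tests exercise (the real evaluate_condition
# lives in the engine and is not reproduced here): every bound threshold must hold,
# and a missing metric fails the whole condition.
def _and_evaluate(thresholds: dict[str, tuple[str, float]], metrics: dict[str, float]) -> bool:
    """thresholds maps metric name -> ('above' | 'below', limit)."""
    for name, (direction, limit) in thresholds.items():
        value = metrics.get(name)
        if value is None:
            return False  # missing data never matches
        if direction == "above" and not value > limit:
            return False
        if direction == "below" and not value < limit:
            return False
    return True

# _and_evaluate({"unrealized_pnl_pct": ("above", 3.0), "holding_days": ("above", 5)},
#               {"unrealized_pnl_pct": 4.5, "holding_days": 3}) -> False (only one leg met)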
|
||||
|
||||
def test_missing_unrealized_pnl_does_not_match(
|
||||
self, engine: ScenarioEngine
|
||||
) -> None:
|
||||
"""Missing unrealized_pnl_pct key should not match the condition."""
|
||||
condition = StockCondition(unrealized_pnl_pct_above=3.0)
|
||||
assert engine.evaluate_condition(condition, {}) is False
|
||||
|
||||
def test_missing_holding_days_does_not_match(
|
||||
self, engine: ScenarioEngine
|
||||
) -> None:
|
||||
"""Missing holding_days key should not match the condition."""
|
||||
condition = StockCondition(holding_days_above=5)
|
||||
assert engine.evaluate_condition(condition, {}) is False
|
||||
|
||||
def test_match_details_includes_position_fields(
|
||||
self, engine: ScenarioEngine
|
||||
) -> None:
|
||||
"""match_details should include position fields when condition specifies them."""
|
||||
pb = _playbook(
|
||||
scenarios=[
|
||||
StockScenario(
|
||||
condition=StockCondition(unrealized_pnl_pct_above=3.0),
|
||||
action=ScenarioAction.SELL,
|
||||
confidence=90,
|
||||
rationale="Take profit",
|
||||
)
|
||||
]
|
||||
)
|
||||
result = engine.evaluate(
|
||||
pb,
|
||||
"005930",
|
||||
{"unrealized_pnl_pct": 5.0},
|
||||
{},
|
||||
)
|
||||
assert result.action == ScenarioAction.SELL
|
||||
assert "unrealized_pnl_pct" in result.match_details
|
||||
assert result.match_details["unrealized_pnl_pct"] == 5.0
|
||||
|
||||
def test_position_conditions_parse_from_planner(self) -> None:
|
||||
"""StockCondition should accept and store new fields from JSON parsing."""
|
||||
condition = StockCondition(
|
||||
unrealized_pnl_pct_above=3.0,
|
||||
unrealized_pnl_pct_below=None,
|
||||
holding_days_above=5,
|
||||
holding_days_below=None,
|
||||
)
|
||||
assert condition.unrealized_pnl_pct_above == 3.0
|
||||
assert condition.holding_days_above == 5
|
||||
assert condition.has_any_condition() is True
81 tests/test_scorecard.py Normal file
@@ -0,0 +1,81 @@
"""Tests for DailyScorecard model."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from src.evolution.scorecard import DailyScorecard
|
||||
|
||||
|
||||
def test_scorecard_initialization() -> None:
|
||||
scorecard = DailyScorecard(
|
||||
date="2026-02-08",
|
||||
market="KR",
|
||||
total_decisions=10,
|
||||
buys=3,
|
||||
sells=2,
|
||||
holds=5,
|
||||
total_pnl=1234.5,
|
||||
win_rate=60.0,
|
||||
avg_confidence=78.5,
|
||||
scenario_match_rate=70.0,
|
||||
top_winners=["005930", "000660"],
|
||||
top_losers=["035420"],
|
||||
lessons=["Avoid chasing breakouts"],
|
||||
cross_market_note="US volatility spillover",
|
||||
)
|
||||
|
||||
assert scorecard.market == "KR"
|
||||
assert scorecard.total_decisions == 10
|
||||
assert scorecard.total_pnl == 1234.5
|
||||
assert scorecard.top_winners == ["005930", "000660"]
|
||||
assert scorecard.lessons == ["Avoid chasing breakouts"]
|
||||
assert scorecard.cross_market_note == "US volatility spillover"
|
||||
|
||||
|
||||
def test_scorecard_defaults() -> None:
|
||||
scorecard = DailyScorecard(
|
||||
date="2026-02-08",
|
||||
market="US",
|
||||
total_decisions=0,
|
||||
buys=0,
|
||||
sells=0,
|
||||
holds=0,
|
||||
total_pnl=0.0,
|
||||
win_rate=0.0,
|
||||
avg_confidence=0.0,
|
||||
scenario_match_rate=0.0,
|
||||
)
|
||||
|
||||
assert scorecard.top_winners == []
|
||||
assert scorecard.top_losers == []
|
||||
assert scorecard.lessons == []
|
||||
assert scorecard.cross_market_note == ""
|
||||
|
||||
|
||||
def test_scorecard_list_isolation() -> None:
|
||||
a = DailyScorecard(
|
||||
date="2026-02-08",
|
||||
market="KR",
|
||||
total_decisions=1,
|
||||
buys=1,
|
||||
sells=0,
|
||||
holds=0,
|
||||
total_pnl=10.0,
|
||||
win_rate=100.0,
|
||||
avg_confidence=90.0,
|
||||
scenario_match_rate=100.0,
|
||||
)
|
||||
b = DailyScorecard(
|
||||
date="2026-02-08",
|
||||
market="US",
|
||||
total_decisions=1,
|
||||
buys=0,
|
||||
sells=1,
|
||||
holds=0,
|
||||
total_pnl=-5.0,
|
||||
win_rate=0.0,
|
||||
avg_confidence=60.0,
|
||||
scenario_match_rate=50.0,
|
||||
)
|
||||
|
||||
a.top_winners.append("005930")
|
||||
assert b.top_winners == []
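# NOTE: hedged sketch (field names mirrored from the assertions above, not the real
# DailyScorecard source). The list-isolation behaviour checked here is what you get
# when list fields use a per-instance default factory rather than a shared default:
from pydantic import BaseModel, Field

class DailyScorecardSketch(BaseModel):
    date: str
    market: str
    top_winners: list[str] = Field(default_factory=list)
    top_losers: list[str] = Field(default_factory=list)
    lessons: list[str] = Field(default_factory=list)
    cross_market_note: str = ""

# With default_factory, a.top_winners.append(...) mutates only that instance's list,
# which is exactly what test_scorecard_list_isolation asserts.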
439 tests/test_smart_scanner.py Normal file
@@ -0,0 +1,439 @@
"""Tests for SmartVolatilityScanner."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from src.analysis.smart_scanner import ScanCandidate, SmartVolatilityScanner
|
||||
from src.analysis.volatility import VolatilityAnalyzer
|
||||
from src.broker.kis_api import KISBroker
|
||||
from src.broker.overseas import OverseasBroker
|
||||
from src.config import Settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings() -> Settings:
|
||||
"""Create test settings."""
|
||||
return Settings(
|
||||
KIS_APP_KEY="test",
|
||||
KIS_APP_SECRET="test",
|
||||
KIS_ACCOUNT_NO="12345678-01",
|
||||
GEMINI_API_KEY="test",
|
||||
RSI_OVERSOLD_THRESHOLD=30,
|
||||
RSI_MOMENTUM_THRESHOLD=70,
|
||||
VOL_MULTIPLIER=2.0,
|
||||
SCANNER_TOP_N=3,
|
||||
DB_PATH=":memory:",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_broker(mock_settings: Settings) -> MagicMock:
|
||||
"""Create mock broker."""
|
||||
broker = MagicMock(spec=KISBroker)
|
||||
broker._settings = mock_settings
|
||||
broker.fetch_market_rankings = AsyncMock()
|
||||
broker.get_daily_prices = AsyncMock()
|
||||
return broker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def scanner(mock_broker: MagicMock, mock_settings: Settings) -> SmartVolatilityScanner:
|
||||
"""Create smart scanner instance."""
|
||||
analyzer = VolatilityAnalyzer()
|
||||
return SmartVolatilityScanner(
|
||||
broker=mock_broker,
|
||||
overseas_broker=None,
|
||||
volatility_analyzer=analyzer,
|
||||
settings=mock_settings,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_overseas_broker() -> MagicMock:
|
||||
"""Create mock overseas broker."""
|
||||
broker = MagicMock(spec=OverseasBroker)
|
||||
broker.get_overseas_price = AsyncMock()
|
||||
broker.fetch_overseas_rankings = AsyncMock(return_value=[])
|
||||
return broker
|
||||
|
||||
|
||||
class TestSmartVolatilityScanner:
|
||||
"""Test suite for SmartVolatilityScanner."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_domestic_prefers_volatility_with_liquidity_bonus(
|
||||
self, scanner: SmartVolatilityScanner, mock_broker: MagicMock
|
||||
) -> None:
|
||||
"""Domestic scan should score by volatility first and volume rank second."""
|
||||
fluctuation_rows = [
|
||||
{
|
||||
"stock_code": "005930",
|
||||
"name": "Samsung",
|
||||
"price": 70000,
|
||||
"volume": 5000000,
|
||||
"change_rate": -5.0,
|
||||
"volume_increase_rate": 250,
|
||||
},
|
||||
{
|
||||
"stock_code": "035420",
|
||||
"name": "NAVER",
|
||||
"price": 250000,
|
||||
"volume": 3000000,
|
||||
"change_rate": 3.0,
|
||||
"volume_increase_rate": 200,
|
||||
},
|
||||
]
|
||||
volume_rows = [
|
||||
{"stock_code": "035420", "name": "NAVER", "price": 250000, "volume": 3000000},
|
||||
{"stock_code": "005930", "name": "Samsung", "price": 70000, "volume": 5000000},
|
||||
]
|
||||
mock_broker.fetch_market_rankings.side_effect = [fluctuation_rows, volume_rows]
|
||||
mock_broker.get_daily_prices.return_value = [
|
||||
{"open": 1, "high": 1, "low": 1, "close": 1, "volume": 1000000},
|
||||
{"open": 1, "high": 1, "low": 1, "close": 1, "volume": 1000000},
|
||||
]
|
||||
|
||||
candidates = await scanner.scan()
|
||||
|
||||
assert len(candidates) >= 1
|
||||
# Samsung has higher absolute move, so it should lead despite lower volume rank bonus.
|
||||
assert candidates[0].stock_code == "005930"
|
||||
assert candidates[0].signal == "oversold"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_domestic_finds_momentum_candidate(
|
||||
self, scanner: SmartVolatilityScanner, mock_broker: MagicMock
|
||||
) -> None:
|
||||
"""Positive change should be represented as momentum signal."""
|
||||
fluctuation_rows = [
|
||||
{
|
||||
"stock_code": "035420",
|
||||
"name": "NAVER",
|
||||
"price": 250000,
|
||||
"volume": 3000000,
|
||||
"change_rate": 5.0,
|
||||
"volume_increase_rate": 300,
|
||||
},
|
||||
]
|
||||
mock_broker.fetch_market_rankings.side_effect = [fluctuation_rows, fluctuation_rows]
|
||||
mock_broker.get_daily_prices.return_value = [
|
||||
{"open": 1, "high": 1, "low": 1, "close": 1, "volume": 1000000},
|
||||
{"open": 1, "high": 1, "low": 1, "close": 1, "volume": 1000000},
|
||||
]
|
||||
|
||||
candidates = await scanner.scan()
|
||||
|
||||
assert [c.stock_code for c in candidates] == ["035420"]
|
||||
assert candidates[0].signal == "momentum"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_domestic_filters_low_volatility(
|
||||
self, scanner: SmartVolatilityScanner, mock_broker: MagicMock
|
||||
) -> None:
|
||||
"""Domestic scan should drop symbols below volatility threshold."""
|
||||
fluctuation_rows = [
|
||||
{
|
||||
"stock_code": "000660",
|
||||
"name": "SK Hynix",
|
||||
"price": 150000,
|
||||
"volume": 500000,
|
||||
"change_rate": 0.2,
|
||||
"volume_increase_rate": 50,
|
||||
},
|
||||
]
|
||||
mock_broker.fetch_market_rankings.side_effect = [fluctuation_rows, fluctuation_rows]
|
||||
mock_broker.get_daily_prices.return_value = [
|
||||
{"open": 1, "high": 150100, "low": 149900, "close": 150000, "volume": 1000000},
|
||||
{"open": 1, "high": 150100, "low": 149900, "close": 150000, "volume": 1000000},
|
||||
]
|
||||
|
||||
candidates = await scanner.scan()
|
||||
|
||||
assert len(candidates) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_uses_fallback_on_api_error(
|
||||
self, scanner: SmartVolatilityScanner, mock_broker: MagicMock
|
||||
) -> None:
|
||||
"""Domestic scan should remain operational using fallback symbols."""
|
||||
mock_broker.fetch_market_rankings.side_effect = [
|
||||
ConnectionError("API unavailable"),
|
||||
ConnectionError("API unavailable"),
|
||||
]
|
||||
mock_broker.get_daily_prices.return_value = [
|
||||
{"open": 1, "high": 103, "low": 97, "close": 100, "volume": 1000000},
|
||||
{"open": 1, "high": 103, "low": 97, "close": 100, "volume": 800000},
|
||||
]
|
||||
|
||||
candidates = await scanner.scan(fallback_stocks=["005930", "000660"])
|
||||
|
||||
assert isinstance(candidates, list)
|
||||
assert len(candidates) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_returns_top_n_only(
|
||||
self, scanner: SmartVolatilityScanner, mock_broker: MagicMock
|
||||
) -> None:
|
||||
"""Test that scan returns at most top_n candidates."""
|
||||
fluctuation_rows = [
|
||||
{
|
||||
"stock_code": f"00{i}000",
|
||||
"name": f"Stock{i}",
|
||||
"price": 10000 * i,
|
||||
"volume": 5000000,
|
||||
"change_rate": -10,
|
||||
"volume_increase_rate": 500,
|
||||
}
|
||||
for i in range(1, 10)
|
||||
]
|
||||
mock_broker.fetch_market_rankings.side_effect = [fluctuation_rows, fluctuation_rows]
|
||||
mock_broker.get_daily_prices.return_value = [
|
||||
{"open": 1, "high": 105, "low": 95, "close": 100, "volume": 1000000},
|
||||
{"open": 1, "high": 105, "low": 95, "close": 100, "volume": 900000},
|
||||
]
|
||||
|
||||
candidates = await scanner.scan()
|
||||
|
||||
# Should respect top_n limit (3)
|
||||
assert len(candidates) <= scanner.top_n
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_stock_codes(
|
||||
self, scanner: SmartVolatilityScanner
|
||||
) -> None:
|
||||
"""Test extraction of stock codes from candidates."""
|
||||
candidates = [
|
||||
ScanCandidate(
|
||||
stock_code="005930",
|
||||
name="Samsung",
|
||||
price=70000,
|
||||
volume=5000000,
|
||||
volume_ratio=2.5,
|
||||
rsi=28,
|
||||
signal="oversold",
|
||||
score=85.0,
|
||||
),
|
||||
ScanCandidate(
|
||||
stock_code="035420",
|
||||
name="NAVER",
|
||||
price=250000,
|
||||
volume=3000000,
|
||||
volume_ratio=3.0,
|
||||
rsi=75,
|
||||
signal="momentum",
|
||||
score=88.0,
|
||||
),
|
||||
]
|
||||
|
||||
codes = scanner.get_stock_codes(candidates)
|
||||
|
||||
assert codes == ["005930", "035420"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_overseas_uses_dynamic_symbols(
|
||||
self, mock_broker: MagicMock, mock_overseas_broker: MagicMock, mock_settings: Settings
|
||||
) -> None:
|
||||
"""Overseas scan should use provided dynamic universe symbols."""
|
||||
analyzer = VolatilityAnalyzer()
|
||||
scanner = SmartVolatilityScanner(
|
||||
broker=mock_broker,
|
||||
overseas_broker=mock_overseas_broker,
|
||||
volatility_analyzer=analyzer,
|
||||
settings=mock_settings,
|
||||
)
|
||||
|
||||
market = MagicMock()
|
||||
market.name = "NASDAQ"
|
||||
market.code = "US_NASDAQ"
|
||||
market.exchange_code = "NASD"
|
||||
market.is_domestic = False
|
||||
|
||||
mock_overseas_broker.get_overseas_price.side_effect = [
|
||||
{"output": {"last": "210.5", "rate": "1.6", "tvol": "1500000"}},
|
||||
{"output": {"last": "330.1", "rate": "0.2", "tvol": "900000"}},
|
||||
]
|
||||
|
||||
candidates = await scanner.scan(
|
||||
market=market,
|
||||
fallback_stocks=["AAPL", "MSFT"],
|
||||
)
|
||||
|
||||
assert [c.stock_code for c in candidates] == ["AAPL"]
|
||||
assert candidates[0].signal == "momentum"
|
||||
assert candidates[0].price == 210.5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_overseas_uses_ranking_api_first(
|
||||
self, mock_broker: MagicMock, mock_overseas_broker: MagicMock, mock_settings: Settings
|
||||
) -> None:
|
||||
"""Overseas scan should prioritize ranking API when available."""
|
||||
analyzer = VolatilityAnalyzer()
|
||||
scanner = SmartVolatilityScanner(
|
||||
broker=mock_broker,
|
||||
overseas_broker=mock_overseas_broker,
|
||||
volatility_analyzer=analyzer,
|
||||
settings=mock_settings,
|
||||
)
|
||||
market = MagicMock()
|
||||
market.name = "NASDAQ"
|
||||
market.code = "US_NASDAQ"
|
||||
market.exchange_code = "NASD"
|
||||
market.is_domestic = False
|
||||
|
||||
mock_overseas_broker.fetch_overseas_rankings.return_value = [
|
||||
{"symb": "NVDA", "last": "780.2", "rate": "2.4", "tvol": "1200000"},
|
||||
{"symb": "MSFT", "last": "420.0", "rate": "0.3", "tvol": "900000"},
|
||||
]
|
||||
|
||||
candidates = await scanner.scan(market=market, fallback_stocks=["AAPL", "TSLA"])
|
||||
|
||||
assert mock_overseas_broker.fetch_overseas_rankings.call_count >= 1
|
||||
mock_overseas_broker.get_overseas_price.assert_not_called()
|
||||
assert [c.stock_code for c in candidates] == ["NVDA"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_overseas_without_symbols_returns_empty(
|
||||
self, mock_broker: MagicMock, mock_overseas_broker: MagicMock, mock_settings: Settings
|
||||
) -> None:
|
||||
"""Overseas scan should return empty list when no symbol universe exists."""
|
||||
analyzer = VolatilityAnalyzer()
|
||||
scanner = SmartVolatilityScanner(
|
||||
broker=mock_broker,
|
||||
overseas_broker=mock_overseas_broker,
|
||||
volatility_analyzer=analyzer,
|
||||
settings=mock_settings,
|
||||
)
|
||||
market = MagicMock()
|
||||
market.name = "NASDAQ"
|
||||
market.code = "US_NASDAQ"
|
||||
market.exchange_code = "NASD"
|
||||
market.is_domestic = False
|
||||
|
||||
candidates = await scanner.scan(market=market, fallback_stocks=[])
|
||||
|
||||
assert candidates == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_overseas_picks_high_intraday_range_even_with_low_change(
|
||||
self, mock_broker: MagicMock, mock_overseas_broker: MagicMock, mock_settings: Settings
|
||||
) -> None:
|
||||
"""Volatility selection should consider intraday range, not only change rate."""
|
||||
analyzer = VolatilityAnalyzer()
|
||||
scanner = SmartVolatilityScanner(
|
||||
broker=mock_broker,
|
||||
overseas_broker=mock_overseas_broker,
|
||||
volatility_analyzer=analyzer,
|
||||
settings=mock_settings,
|
||||
)
|
||||
market = MagicMock()
|
||||
market.name = "NASDAQ"
|
||||
market.code = "US_NASDAQ"
|
||||
market.exchange_code = "NASD"
|
||||
market.is_domestic = False
|
||||
|
||||
# change rate is tiny, but high-low range is large (15%).
|
||||
mock_overseas_broker.fetch_overseas_rankings.return_value = [
|
||||
{
|
||||
"symb": "ABCD",
|
||||
"last": "100",
|
||||
"rate": "0.2",
|
||||
"high": "110",
|
||||
"low": "95",
|
||||
"tvol": "800000",
|
||||
}
|
||||
]
|
||||
|
||||
candidates = await scanner.scan(market=market, fallback_stocks=[])
|
||||
|
||||
assert [c.stock_code for c in candidates] == ["ABCD"]
class TestImpliedRSIFormula:
    """Test the implied_rsi formula in SmartVolatilityScanner (issue #181)."""

    def test_neutral_change_gives_neutral_rsi(self) -> None:
        """0% change → implied_rsi = 50 (neutral)."""
        # formula: 50 + (change_rate * 2.0)
        rsi = max(0.0, min(100.0, 50.0 + (0.0 * 2.0)))
        assert rsi == 50.0

    def test_10pct_change_gives_rsi_70(self) -> None:
        """10% upward change → implied_rsi = 70 (momentum signal)."""
        rsi = max(0.0, min(100.0, 50.0 + (10.0 * 2.0)))
        assert rsi == 70.0

    def test_minus_10pct_gives_rsi_30(self) -> None:
        """-10% change → implied_rsi = 30 (oversold signal)."""
        rsi = max(0.0, min(100.0, 50.0 + (-10.0 * 2.0)))
        assert rsi == 30.0

    def test_saturation_at_25pct(self) -> None:
        """Saturation occurs at >=25% change (not 12.5% as with old coefficient 4.0)."""
        rsi_12pct = max(0.0, min(100.0, 50.0 + (12.5 * 2.0)))
        rsi_25pct = max(0.0, min(100.0, 50.0 + (25.0 * 2.0)))
        rsi_30pct = max(0.0, min(100.0, 50.0 + (30.0 * 2.0)))
        # At 12.5% change: RSI = 75 (not 100, unlike old formula)
        assert rsi_12pct == 75.0
        # At 25%+ saturation
        assert rsi_25pct == 100.0
        assert rsi_30pct == 100.0  # Capped

    def test_negative_saturation(self) -> None:
        """Saturation at -25% gives RSI = 0."""
        rsi = max(0.0, min(100.0, 50.0 + (-25.0 * 2.0)))
        assert rsi == 0.0
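# NOTE: sketch of the formula exercised above (the scanner's own helper is not shown
# in this diff): implied RSI maps a daily change rate onto the 0-100 scale with a
# 2.0 coefficient and clamping, so saturation happens at +/-25%.
def implied_rsi(change_rate_pct: float) -> float:
    """Approximate RSI from a percentage change: 50 + 2 * change, clamped to [0, 100]."""
    return max(0.0, min(100.0, 50.0 + change_rate_pct * 2.0))

# implied_rsi(0.0) == 50.0, implied_rsi(10.0) == 70.0, implied_rsi(-30.0) == 0.0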
|
||||
|
||||
|
||||
class TestRSICalculation:
|
||||
"""Test RSI calculation in VolatilityAnalyzer."""
|
||||
|
||||
def test_rsi_oversold(self) -> None:
|
||||
"""Test RSI calculation for downtrending prices."""
|
||||
analyzer = VolatilityAnalyzer()
|
||||
|
||||
# Steadily declining prices
|
||||
prices = [100 - i * 0.5 for i in range(20)]
|
||||
rsi = analyzer.calculate_rsi(prices, period=14)
|
||||
|
||||
assert rsi < 50 # Should be oversold territory
|
||||
|
||||
def test_rsi_overbought(self) -> None:
|
||||
"""Test RSI calculation for uptrending prices."""
|
||||
analyzer = VolatilityAnalyzer()
|
||||
|
||||
# Steadily rising prices
|
||||
prices = [100 + i * 0.5 for i in range(20)]
|
||||
rsi = analyzer.calculate_rsi(prices, period=14)
|
||||
|
||||
assert rsi > 50 # Should be overbought territory
|
||||
|
||||
def test_rsi_neutral(self) -> None:
|
||||
"""Test RSI calculation for flat prices."""
|
||||
analyzer = VolatilityAnalyzer()
|
||||
|
||||
# Flat prices with small oscillation
|
||||
prices = [100 + (i % 2) * 0.1 for i in range(20)]
|
||||
rsi = analyzer.calculate_rsi(prices, period=14)
|
||||
|
||||
assert 40 < rsi < 60 # Should be near neutral
|
||||
|
||||
def test_rsi_insufficient_data(self) -> None:
|
||||
"""Test RSI returns neutral when insufficient data."""
|
||||
analyzer = VolatilityAnalyzer()
|
||||
|
||||
prices = [100, 101, 102] # Only 3 prices, need 15+
|
||||
rsi = analyzer.calculate_rsi(prices, period=14)
|
||||
|
||||
assert rsi == 50.0 # Default neutral
|
||||
|
||||
def test_rsi_all_gains(self) -> None:
|
||||
"""Test RSI returns 100 when all gains (no losses)."""
|
||||
analyzer = VolatilityAnalyzer()
|
||||
|
||||
# Monotonic increase
|
||||
prices = [100 + i for i in range(20)]
|
||||
rsi = analyzer.calculate_rsi(prices, period=14)
|
||||
|
||||
assert rsi == 100.0 # Maximum RSI
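# NOTE: hedged sketch of one common simple-average RSI variant that is consistent
# with the behaviour asserted above (neutral 50 on short input, 100 on pure gains);
# VolatilityAnalyzer may use different smoothing internally.
def rsi_simple(prices: list[float], period: int = 14) -> float:
    if len(prices) < period + 1:
        return 50.0  # not enough data -> neutral
    # consecutive differences over the last `period` intervals
    deltas = [b - a for a, b in zip(prices[-(period + 1):], prices[-period:])]
    avg_gain = sum(d for d in deltas if d > 0) / period
    avg_loss = sum(-d for d in deltas if d < 0) / period
    if avg_loss == 0:
        return 100.0 if avg_gain > 0 else 50.0  # pure gains -> 100, flat -> neutral
    rs = avg_gain / avg_loss
    return 100.0 - 100.0 / (1.0 + rs)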
32 tests/test_strategies_base.py Normal file
@@ -0,0 +1,32 @@
"""Tests for BaseStrategy abstract class."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from src.strategies.base import BaseStrategy
|
||||
|
||||
|
||||
class ConcreteStrategy(BaseStrategy):
|
||||
"""Minimal concrete strategy for testing."""
|
||||
|
||||
def evaluate(self, market_data: dict[str, Any]) -> dict[str, Any]:
|
||||
return {"action": "HOLD", "confidence": 50, "rationale": "test"}
|
||||
|
||||
|
||||
def test_base_strategy_cannot_be_instantiated() -> None:
|
||||
"""BaseStrategy cannot be instantiated directly (it's abstract)."""
|
||||
with pytest.raises(TypeError):
|
||||
BaseStrategy() # type: ignore[abstract]
|
||||
|
||||
|
||||
def test_concrete_strategy_evaluate_returns_decision() -> None:
|
||||
"""Concrete subclass must implement evaluate and return a dict."""
|
||||
strategy = ConcreteStrategy()
|
||||
result = strategy.evaluate({"close": [100.0, 101.0]})
|
||||
assert isinstance(result, dict)
|
||||
assert result["action"] == "HOLD"
|
||||
assert result["confidence"] == 50
|
||||
assert "rationale" in result
366 tests/test_strategy_models.py Normal file
@@ -0,0 +1,366 @@
"""Tests for strategy/playbook Pydantic models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from src.strategy.models import (
|
||||
CrossMarketContext,
|
||||
DayPlaybook,
|
||||
GlobalRule,
|
||||
MarketOutlook,
|
||||
PlaybookStatus,
|
||||
ScenarioAction,
|
||||
StockCondition,
|
||||
StockPlaybook,
|
||||
StockScenario,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# StockCondition
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStockCondition:
|
||||
def test_empty_condition(self) -> None:
|
||||
cond = StockCondition()
|
||||
assert not cond.has_any_condition()
|
||||
|
||||
def test_single_field(self) -> None:
|
||||
cond = StockCondition(rsi_below=30.0)
|
||||
assert cond.has_any_condition()
|
||||
|
||||
def test_multiple_fields(self) -> None:
|
||||
cond = StockCondition(rsi_below=25.0, volume_ratio_above=3.0)
|
||||
assert cond.has_any_condition()
|
||||
|
||||
def test_all_fields(self) -> None:
|
||||
cond = StockCondition(
|
||||
rsi_below=30,
|
||||
rsi_above=10,
|
||||
volume_ratio_above=2.0,
|
||||
volume_ratio_below=10.0,
|
||||
price_above=1000,
|
||||
price_below=50000,
|
||||
price_change_pct_above=-5.0,
|
||||
price_change_pct_below=5.0,
|
||||
)
|
||||
assert cond.has_any_condition()
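# NOTE: hedged sketch, not the model's actual source: has_any_condition() can simply
# report whether at least one threshold field has been set.
from typing import Optional
from pydantic import BaseModel

class StockConditionSketch(BaseModel):
    rsi_below: Optional[float] = None
    rsi_above: Optional[float] = None
    volume_ratio_above: Optional[float] = None
    price_change_pct_below: Optional[float] = None

    def has_any_condition(self) -> bool:
        """True when any threshold is non-None."""
        return any(value is not None for value in self.model_dump().values())

# StockConditionSketch().has_any_condition() is False;
# StockConditionSketch(rsi_below=30.0).has_any_condition() is True.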
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# StockScenario
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStockScenario:
|
||||
def test_valid_scenario(self) -> None:
|
||||
s = StockScenario(
|
||||
condition=StockCondition(rsi_below=25.0),
|
||||
action=ScenarioAction.BUY,
|
||||
confidence=85,
|
||||
allocation_pct=15.0,
|
||||
stop_loss_pct=-2.0,
|
||||
take_profit_pct=3.0,
|
||||
rationale="Oversold bounce expected",
|
||||
)
|
||||
assert s.action == ScenarioAction.BUY
|
||||
assert s.confidence == 85
|
||||
|
||||
def test_confidence_too_high(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
StockScenario(
|
||||
condition=StockCondition(),
|
||||
action=ScenarioAction.BUY,
|
||||
confidence=101,
|
||||
)
|
||||
|
||||
def test_confidence_too_low(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
StockScenario(
|
||||
condition=StockCondition(),
|
||||
action=ScenarioAction.BUY,
|
||||
confidence=-1,
|
||||
)
|
||||
|
||||
def test_allocation_too_high(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
StockScenario(
|
||||
condition=StockCondition(),
|
||||
action=ScenarioAction.BUY,
|
||||
confidence=80,
|
||||
allocation_pct=101.0,
|
||||
)
|
||||
|
||||
def test_stop_loss_must_be_negative(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
StockScenario(
|
||||
condition=StockCondition(),
|
||||
action=ScenarioAction.BUY,
|
||||
confidence=80,
|
||||
stop_loss_pct=1.0,
|
||||
)
|
||||
|
||||
def test_take_profit_must_be_positive(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
StockScenario(
|
||||
condition=StockCondition(),
|
||||
action=ScenarioAction.BUY,
|
||||
confidence=80,
|
||||
take_profit_pct=-1.0,
|
||||
)
|
||||
|
||||
def test_defaults(self) -> None:
|
||||
s = StockScenario(
|
||||
condition=StockCondition(),
|
||||
action=ScenarioAction.HOLD,
|
||||
confidence=50,
|
||||
)
|
||||
assert s.allocation_pct == 10.0
|
||||
assert s.stop_loss_pct == -2.0
|
||||
assert s.take_profit_pct == 3.0
|
||||
assert s.rationale == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# StockPlaybook
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStockPlaybook:
|
||||
def test_valid_playbook(self) -> None:
|
||||
pb = StockPlaybook(
|
||||
stock_code="005930",
|
||||
stock_name="Samsung Electronics",
|
||||
scenarios=[
|
||||
StockScenario(
|
||||
condition=StockCondition(rsi_below=25.0),
|
||||
action=ScenarioAction.BUY,
|
||||
confidence=85,
|
||||
),
|
||||
],
|
||||
)
|
||||
assert pb.stock_code == "005930"
|
||||
assert len(pb.scenarios) == 1
|
||||
|
||||
def test_empty_scenarios_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
StockPlaybook(
|
||||
stock_code="005930",
|
||||
scenarios=[],
|
||||
)
|
||||
|
||||
def test_multiple_scenarios(self) -> None:
|
||||
pb = StockPlaybook(
|
||||
stock_code="AAPL",
|
||||
scenarios=[
|
||||
StockScenario(
|
||||
condition=StockCondition(rsi_below=25.0),
|
||||
action=ScenarioAction.BUY,
|
||||
confidence=85,
|
||||
),
|
||||
StockScenario(
|
||||
condition=StockCondition(rsi_above=75.0),
|
||||
action=ScenarioAction.SELL,
|
||||
confidence=80,
|
||||
),
|
||||
],
|
||||
)
|
||||
assert len(pb.scenarios) == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GlobalRule
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGlobalRule:
|
||||
def test_valid_rule(self) -> None:
|
||||
rule = GlobalRule(
|
||||
condition="portfolio_pnl_pct < -2.0",
|
||||
action=ScenarioAction.REDUCE_ALL,
|
||||
rationale="Risk limit approaching",
|
||||
)
|
||||
assert rule.action == ScenarioAction.REDUCE_ALL
|
||||
|
||||
def test_hold_rule(self) -> None:
|
||||
rule = GlobalRule(
|
||||
condition="volatility_index > 30",
|
||||
action=ScenarioAction.HOLD,
|
||||
)
|
||||
assert rule.rationale == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CrossMarketContext
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCrossMarketContext:
|
||||
def test_valid_context(self) -> None:
|
||||
ctx = CrossMarketContext(
|
||||
market="US",
|
||||
date="2026-02-07",
|
||||
total_pnl=-1.5,
|
||||
win_rate=40.0,
|
||||
index_change_pct=-2.3,
|
||||
key_events=["Fed rate decision"],
|
||||
lessons=["Avoid tech sector on rate hike days"],
|
||||
)
|
||||
assert ctx.market == "US"
|
||||
assert len(ctx.key_events) == 1
|
||||
|
||||
def test_defaults(self) -> None:
|
||||
ctx = CrossMarketContext(market="KR", date="2026-02-07")
|
||||
assert ctx.total_pnl == 0.0
|
||||
assert ctx.key_events == []
|
||||
assert ctx.lessons == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DayPlaybook
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_scenario(rsi_below: float = 25.0) -> StockScenario:
|
||||
return StockScenario(
|
||||
condition=StockCondition(rsi_below=rsi_below),
|
||||
action=ScenarioAction.BUY,
|
||||
confidence=85,
|
||||
)
|
||||
|
||||
|
||||
def _make_playbook(**kwargs) -> DayPlaybook:
|
||||
defaults = {
|
||||
"date": date(2026, 2, 7),
|
||||
"market": "KR",
|
||||
"stock_playbooks": [
|
||||
StockPlaybook(stock_code="005930", scenarios=[_make_scenario()]),
|
||||
],
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return DayPlaybook(**defaults)
|
||||
|
||||
|
||||
class TestDayPlaybook:
|
||||
def test_valid_playbook(self) -> None:
|
||||
pb = _make_playbook()
|
||||
assert pb.market == "KR"
|
||||
assert pb.date == date(2026, 2, 7)
|
||||
assert pb.default_action == ScenarioAction.HOLD
|
||||
assert pb.scenario_count == 1
|
||||
assert pb.stock_count == 1
|
||||
|
||||
def test_generated_at_auto_set(self) -> None:
|
||||
pb = _make_playbook()
|
||||
assert pb.generated_at != ""
|
||||
|
||||
def test_explicit_generated_at(self) -> None:
|
||||
pb = _make_playbook(generated_at="2026-02-07T08:30:00")
|
||||
assert pb.generated_at == "2026-02-07T08:30:00"
|
||||
|
||||
def test_duplicate_stocks_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
DayPlaybook(
|
||||
date=date(2026, 2, 7),
|
||||
market="KR",
|
||||
stock_playbooks=[
|
||||
StockPlaybook(stock_code="005930", scenarios=[_make_scenario()]),
|
||||
StockPlaybook(stock_code="005930", scenarios=[_make_scenario(30)]),
|
||||
],
|
||||
)
|
||||
|
||||
def test_empty_stock_playbooks_allowed(self) -> None:
|
||||
pb = DayPlaybook(
|
||||
date=date(2026, 2, 7),
|
||||
market="KR",
|
||||
stock_playbooks=[],
|
||||
)
|
||||
assert pb.stock_count == 0
|
||||
assert pb.scenario_count == 0
|
||||
|
||||
def test_get_stock_playbook_found(self) -> None:
|
||||
pb = _make_playbook()
|
||||
result = pb.get_stock_playbook("005930")
|
||||
assert result is not None
|
||||
assert result.stock_code == "005930"
|
||||
|
||||
def test_get_stock_playbook_not_found(self) -> None:
|
||||
pb = _make_playbook()
|
||||
result = pb.get_stock_playbook("AAPL")
|
||||
assert result is None
|
||||
|
||||
def test_with_global_rules(self) -> None:
|
||||
pb = _make_playbook(
|
||||
global_rules=[
|
||||
GlobalRule(
|
||||
condition="portfolio_pnl_pct < -2.0",
|
||||
action=ScenarioAction.REDUCE_ALL,
|
||||
),
|
||||
],
|
||||
)
|
||||
assert len(pb.global_rules) == 1
|
||||
|
||||
def test_with_cross_market_context(self) -> None:
|
||||
ctx = CrossMarketContext(market="US", date="2026-02-07", total_pnl=-1.5)
|
||||
pb = _make_playbook(cross_market=ctx)
|
||||
assert pb.cross_market is not None
|
||||
assert pb.cross_market.market == "US"
|
||||
|
||||
def test_market_outlook(self) -> None:
|
||||
pb = _make_playbook(market_outlook=MarketOutlook.BEARISH)
|
||||
assert pb.market_outlook == MarketOutlook.BEARISH
|
||||
|
||||
def test_multiple_stocks_multiple_scenarios(self) -> None:
|
||||
pb = DayPlaybook(
|
||||
date=date(2026, 2, 7),
|
||||
market="US",
|
||||
stock_playbooks=[
|
||||
StockPlaybook(
|
||||
stock_code="AAPL",
|
||||
scenarios=[_make_scenario(), _make_scenario(30)],
|
||||
),
|
||||
StockPlaybook(
|
||||
stock_code="MSFT",
|
||||
scenarios=[_make_scenario()],
|
||||
),
|
||||
],
|
||||
)
|
||||
assert pb.stock_count == 2
|
||||
assert pb.scenario_count == 3
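# NOTE: hedged sketch of the counting and duplicate-rejection behaviour asserted
# above; the real DayPlaybook lives in src.strategy.models and uses typed
# StockPlaybook entries rather than plain dicts.
from pydantic import BaseModel, Field, field_validator

class DayPlaybookSketch(BaseModel):
    stock_playbooks: list[dict] = Field(default_factory=list)

    @field_validator("stock_playbooks")
    @classmethod
    def _no_duplicate_stocks(cls, v: list[dict]) -> list[dict]:
        codes = [p["stock_code"] for p in v]
        if len(codes) != len(set(codes)):
            raise ValueError("duplicate stock_code in stock_playbooks")
        return v

    @property
    def stock_count(self) -> int:
        return len(self.stock_playbooks)

    @property
    def scenario_count(self) -> int:
        return sum(len(p.get("scenarios", [])) for p in self.stock_playbooks)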
|
||||
|
||||
def test_serialization_roundtrip(self) -> None:
|
||||
pb = _make_playbook(
|
||||
market_outlook=MarketOutlook.BULLISH,
|
||||
cross_market=CrossMarketContext(market="US", date="2026-02-07"),
|
||||
)
|
||||
json_str = pb.model_dump_json()
|
||||
restored = DayPlaybook.model_validate_json(json_str)
|
||||
assert restored.market == pb.market
|
||||
assert restored.date == pb.date
|
||||
assert restored.scenario_count == pb.scenario_count
|
||||
assert restored.cross_market is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Enums
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnums:
|
||||
def test_scenario_action_values(self) -> None:
|
||||
assert ScenarioAction.BUY.value == "BUY"
|
||||
assert ScenarioAction.SELL.value == "SELL"
|
||||
assert ScenarioAction.HOLD.value == "HOLD"
|
||||
assert ScenarioAction.REDUCE_ALL.value == "REDUCE_ALL"
|
||||
|
||||
def test_market_outlook_values(self) -> None:
|
||||
assert len(MarketOutlook) == 5
|
||||
|
||||
def test_playbook_status_values(self) -> None:
|
||||
assert PlaybookStatus.READY.value == "ready"
|
||||
assert PlaybookStatus.EXPIRED.value == "expired"
|
||||
667 tests/test_telegram.py Normal file
@@ -0,0 +1,667 @@
"""Tests for Telegram notification client."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
from src.notifications.telegram_client import NotificationFilter, NotificationPriority, TelegramClient
|
||||
|
||||
|
||||
class TestTelegramClientInit:
|
||||
"""Test client initialization scenarios."""
|
||||
|
||||
def test_disabled_via_flag(self) -> None:
|
||||
"""Client disabled via enabled=False flag."""
|
||||
client = TelegramClient(
|
||||
bot_token="123:abc", chat_id="456", enabled=False
|
||||
)
|
||||
assert client._enabled is False
|
||||
|
||||
def test_disabled_missing_token(self) -> None:
|
||||
"""Client disabled when bot_token is None."""
|
||||
client = TelegramClient(bot_token=None, chat_id="456", enabled=True)
|
||||
assert client._enabled is False
|
||||
|
||||
def test_disabled_missing_chat_id(self) -> None:
|
||||
"""Client disabled when chat_id is None."""
|
||||
client = TelegramClient(bot_token="123:abc", chat_id=None, enabled=True)
|
||||
assert client._enabled is False
|
||||
|
||||
def test_enabled_with_credentials(self) -> None:
|
||||
"""Client enabled when credentials provided."""
|
||||
client = TelegramClient(
|
||||
bot_token="123:abc", chat_id="456", enabled=True
|
||||
)
|
||||
assert client._enabled is True
|
||||
|
||||
|
||||
class TestNotificationSending:
|
||||
"""Test notification sending behavior."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_success(self) -> None:
|
||||
"""send_message returns True on successful send."""
|
||||
client = TelegramClient(
|
||||
bot_token="123:abc", chat_id="456", enabled=True
|
||||
)
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
|
||||
result = await client.send_message("Test message")
|
||||
|
||||
assert result is True
|
||||
assert mock_post.call_count == 1
|
||||
|
||||
payload = mock_post.call_args.kwargs["json"]
|
||||
assert payload["chat_id"] == "456"
|
||||
assert payload["text"] == "Test message"
|
||||
assert payload["parse_mode"] == "HTML"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_disabled_client(self) -> None:
|
||||
"""send_message returns False when client disabled."""
|
||||
client = TelegramClient(enabled=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post") as mock_post:
|
||||
result = await client.send_message("Test message")
|
||||
|
||||
assert result is False
|
||||
mock_post.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_api_error(self) -> None:
|
||||
"""send_message returns False on API error."""
|
||||
client = TelegramClient(
|
||||
bot_token="123:abc", chat_id="456", enabled=True
|
||||
)
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 400
|
||||
mock_resp.text = AsyncMock(return_value="Bad Request")
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp):
|
||||
result = await client.send_message("Test message")
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_with_markdown(self) -> None:
|
||||
"""send_message supports different parse modes."""
|
||||
client = TelegramClient(
|
||||
bot_token="123:abc", chat_id="456", enabled=True
|
||||
)
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
|
||||
result = await client.send_message("*bold*", parse_mode="Markdown")
|
||||
|
||||
assert result is True
|
||||
payload = mock_post.call_args.kwargs["json"]
|
||||
assert payload["parse_mode"] == "Markdown"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_send_when_disabled(self) -> None:
|
||||
"""Notifications not sent when client disabled."""
|
||||
client = TelegramClient(enabled=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post") as mock_post:
|
||||
await client.notify_trade_execution(
|
||||
stock_code="AAPL",
|
||||
market="United States",
|
||||
action="BUY",
|
||||
quantity=10,
|
||||
price=150.0,
|
||||
confidence=85.0,
|
||||
)
|
||||
mock_post.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trade_execution_format(self) -> None:
|
||||
"""Trade notification has correct format."""
|
||||
client = TelegramClient(
|
||||
bot_token="123:abc", chat_id="456", enabled=True
|
||||
)
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
|
||||
await client.notify_trade_execution(
|
||||
stock_code="TSLA",
|
||||
market="United States",
|
||||
action="SELL",
|
||||
quantity=5,
|
||||
price=250.50,
|
||||
confidence=92.0,
|
||||
)
|
||||
|
||||
# Verify API call was made
|
||||
assert mock_post.call_count == 1
|
||||
call_args = mock_post.call_args
|
||||
|
||||
# Check payload structure
|
||||
payload = call_args.kwargs["json"]
|
||||
assert payload["chat_id"] == "456"
|
||||
assert "TSLA" in payload["text"]
|
||||
assert "SELL" in payload["text"]
|
||||
assert "5" in payload["text"]
|
||||
assert "250.50" in payload["text"]
|
||||
assert "92%" in payload["text"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_playbook_generated_format(self) -> None:
|
||||
"""Playbook generated notification has expected fields."""
|
||||
client = TelegramClient(
|
||||
bot_token="123:abc", chat_id="456", enabled=True
|
||||
)
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
|
||||
await client.notify_playbook_generated(
|
||||
market="KR",
|
||||
stock_count=4,
|
||||
scenario_count=12,
|
||||
token_count=980,
|
||||
)
|
||||
|
||||
payload = mock_post.call_args.kwargs["json"]
|
||||
assert "Playbook Generated" in payload["text"]
|
||||
assert "Market: KR" in payload["text"]
|
||||
assert "Stocks: 4" in payload["text"]
|
||||
assert "Scenarios: 12" in payload["text"]
|
||||
assert "Tokens: 980" in payload["text"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_matched_format(self) -> None:
|
||||
"""Scenario matched notification has expected fields."""
|
||||
client = TelegramClient(
|
||||
bot_token="123:abc", chat_id="456", enabled=True
|
||||
)
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
|
||||
await client.notify_scenario_matched(
|
||||
stock_code="AAPL",
|
||||
action="BUY",
|
||||
condition_summary="RSI < 30, volume_ratio > 2.0",
|
||||
confidence=88.2,
|
||||
)
|
||||
|
||||
payload = mock_post.call_args.kwargs["json"]
|
||||
assert "Scenario Matched" in payload["text"]
|
||||
assert "AAPL" in payload["text"]
|
||||
assert "Action: BUY" in payload["text"]
|
||||
assert "RSI < 30" in payload["text"]
|
||||
assert "88%" in payload["text"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_playbook_failed_format(self) -> None:
|
||||
"""Playbook failed notification has expected fields."""
|
||||
client = TelegramClient(
|
||||
bot_token="123:abc", chat_id="456", enabled=True
|
||||
)
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
|
||||
await client.notify_playbook_failed(
|
||||
market="US",
|
||||
reason="Gemini timeout",
|
||||
)
|
||||
|
||||
payload = mock_post.call_args.kwargs["json"]
|
||||
assert "Playbook Failed" in payload["text"]
|
||||
assert "Market: US" in payload["text"]
|
||||
assert "Gemini timeout" in payload["text"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_circuit_breaker_priority(self) -> None:
|
||||
"""Circuit breaker uses CRITICAL priority."""
|
||||
client = TelegramClient(
|
||||
bot_token="123:abc", chat_id="456", enabled=True
|
||||
)
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
|
||||
await client.notify_circuit_breaker(pnl_pct=-3.15, threshold=-3.0)
|
||||
|
||||
payload = mock_post.call_args.kwargs["json"]
|
||||
# CRITICAL priority has 🚨 emoji
|
||||
assert NotificationPriority.CRITICAL.emoji in payload["text"]
|
||||
assert "-3.15%" in payload["text"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_error_handling(self) -> None:
|
||||
"""API errors logged but don't crash."""
|
||||
client = TelegramClient(
|
||||
bot_token="123:abc", chat_id="456", enabled=True
|
||||
)
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 400
|
||||
mock_resp.text = AsyncMock(return_value="Bad Request")
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp):
|
||||
# Should not raise exception
|
||||
await client.notify_system_start(mode="paper", enabled_markets=["KR"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_handling(self) -> None:
|
||||
"""Timeouts logged but don't crash."""
|
||||
client = TelegramClient(
|
||||
bot_token="123:abc", chat_id="456", enabled=True
|
||||
)
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.post",
|
||||
side_effect=aiohttp.ClientError("Connection timeout"),
|
||||
):
|
||||
# Should not raise exception
|
||||
await client.notify_error(
|
||||
error_type="Test Error", error_msg="Test", context="test"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_management(self) -> None:
|
||||
"""Session created and reused correctly."""
|
||||
client = TelegramClient(
|
||||
bot_token="123:abc", chat_id="456", enabled=True
|
||||
)
|
||||
|
||||
# Session should be None initially
|
||||
assert client._session is None
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp):
|
||||
await client.notify_market_open("Korea")
|
||||
# Session should be created
|
||||
assert client._session is not None
|
||||
|
||||
session1 = client._session
|
||||
await client.notify_market_close("Korea", 1.5)
|
||||
# Same session should be reused
|
||||
assert client._session is session1
|
||||
|
||||
|
||||
class TestRateLimiting:
|
||||
"""Test rate limiter behavior."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limiter_enforced(self) -> None:
|
||||
"""Rate limiter delays rapid requests."""
|
||||
import time
|
||||
|
||||
client = TelegramClient(
|
||||
bot_token="123:abc", chat_id="456", enabled=True, rate_limit=2.0
|
||||
)
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp):
|
||||
start = time.monotonic()
|
||||
|
||||
# Send 3 messages (rate: 2/sec = 0.5s per message)
|
||||
await client.notify_market_open("Korea")
|
||||
await client.notify_market_open("United States")
|
||||
await client.notify_market_open("Japan")
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
# Should take at least 0.4 seconds (3 msgs at 2/sec with some tolerance)
|
||||
assert elapsed >= 0.4
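# NOTE: hedged sketch of a rate limiter compatible with the timing asserted above;
# the client's internal implementation may differ. Sends are spaced at least
# 1 / rate_per_sec seconds apart using a monotonic clock.
import asyncio
import time

class SimpleRateLimiter:
    def __init__(self, rate_per_sec: float) -> None:
        self._min_interval = 1.0 / rate_per_sec
        self._last_send = 0.0

    async def wait(self) -> None:
        """Sleep just long enough to keep calls at or below rate_per_sec."""
        now = time.monotonic()
        delay = self._min_interval - (now - self._last_send)
        if delay > 0:
            await asyncio.sleep(delay)
        self._last_send = time.monotonic()

# At 2.0 msgs/sec the 2nd and 3rd sends each wait ~0.5 s, so three sends take
# about 1 s, comfortably above the 0.4 s lower bound asserted above.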
|
||||
|
||||
|
||||
class TestMessagePriorities:
|
||||
"""Test priority-based messaging."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_low_priority_uses_info_emoji(self) -> None:
|
||||
"""LOW priority uses ℹ️ emoji."""
|
||||
client = TelegramClient(
|
||||
bot_token="123:abc", chat_id="456", enabled=True
|
||||
)
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
|
||||
await client.notify_market_open("Korea")
|
||||
|
||||
payload = mock_post.call_args.kwargs["json"]
|
||||
assert NotificationPriority.LOW.emoji in payload["text"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_critical_priority_uses_alarm_emoji(self) -> None:
|
||||
"""CRITICAL priority uses 🚨 emoji."""
|
||||
client = TelegramClient(
|
||||
bot_token="123:abc", chat_id="456", enabled=True
|
||||
)
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
|
||||
await client.notify_system_shutdown("Circuit breaker tripped")
|
||||
|
||||
payload = mock_post.call_args.kwargs["json"]
|
||||
assert NotificationPriority.CRITICAL.emoji in payload["text"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_playbook_generated_priority(self) -> None:
|
||||
"""Playbook generated uses MEDIUM priority emoji."""
|
||||
client = TelegramClient(
|
||||
bot_token="123:abc", chat_id="456", enabled=True
|
||||
)
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
|
||||
await client.notify_playbook_generated(
|
||||
market="KR",
|
||||
stock_count=2,
|
||||
scenario_count=4,
|
||||
token_count=123,
|
||||
)
|
||||
|
||||
payload = mock_post.call_args.kwargs["json"]
|
||||
assert NotificationPriority.MEDIUM.emoji in payload["text"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_playbook_failed_priority(self) -> None:
|
||||
"""Playbook failed uses HIGH priority emoji."""
|
||||
client = TelegramClient(
|
||||
bot_token="123:abc", chat_id="456", enabled=True
|
||||
)
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
|
||||
await client.notify_playbook_failed(
|
||||
market="KR",
|
||||
reason="Invalid JSON",
|
||||
)
|
||||
|
||||
payload = mock_post.call_args.kwargs["json"]
|
||||
assert NotificationPriority.HIGH.emoji in payload["text"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_matched_priority(self) -> None:
|
||||
"""Scenario matched uses HIGH priority emoji."""
|
||||
client = TelegramClient(
|
||||
bot_token="123:abc", chat_id="456", enabled=True
|
||||
)
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
|
||||
await client.notify_scenario_matched(
|
||||
stock_code="AAPL",
|
||||
action="BUY",
|
||||
condition_summary="RSI < 30",
|
||||
confidence=80.0,
|
||||
)
|
||||
|
||||
payload = mock_post.call_args.kwargs["json"]
|
||||
assert NotificationPriority.HIGH.emoji in payload["text"]
|
||||
|
||||
|
||||
class TestClientCleanup:
    """Test client cleanup behavior."""

    @pytest.mark.asyncio
    async def test_close_closes_session(self) -> None:
        """close() closes the HTTP session."""
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True
        )

        mock_session = AsyncMock()
        mock_session.closed = False
        mock_session.close = AsyncMock()
        client._session = mock_session

        await client.close()
        mock_session.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_close_handles_no_session(self) -> None:
        """close() handles None session gracefully."""
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True
        )

        # Should not raise exception
        await client.close()

class TestNotificationFilter:
    """Test granular notification filter behavior."""

    def test_default_filter_allows_all(self) -> None:
        """Default NotificationFilter has all flags enabled."""
        f = NotificationFilter()
        assert f.trades is True
        assert f.market_open_close is True
        assert f.fat_finger is True
        assert f.system_events is True
        assert f.playbook is True
        assert f.scenario_match is True
        assert f.errors is True

    def test_client_uses_default_filter_when_none_given(self) -> None:
        """TelegramClient creates a default NotificationFilter when none provided."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        assert isinstance(client._filter, NotificationFilter)
        assert client._filter.scenario_match is True

    def test_client_stores_provided_filter(self) -> None:
        """TelegramClient stores a custom NotificationFilter."""
        nf = NotificationFilter(scenario_match=False, trades=False)
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True, notification_filter=nf
        )
        assert client._filter.scenario_match is False
        assert client._filter.trades is False
        assert client._filter.market_open_close is True  # default still True

    @pytest.mark.asyncio
    async def test_scenario_match_filtered_does_not_send(self) -> None:
        """notify_scenario_matched skips send when scenario_match=False."""
        nf = NotificationFilter(scenario_match=False)
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True, notification_filter=nf
        )
        with patch("aiohttp.ClientSession.post") as mock_post:
            await client.notify_scenario_matched(
                stock_code="005930", action="BUY", condition_summary="rsi<30", confidence=85.0
            )
            mock_post.assert_not_called()

    @pytest.mark.asyncio
    async def test_trades_filtered_does_not_send(self) -> None:
        """notify_trade_execution skips send when trades=False."""
        nf = NotificationFilter(trades=False)
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True, notification_filter=nf
        )
        with patch("aiohttp.ClientSession.post") as mock_post:
            await client.notify_trade_execution(
                stock_code="005930", market="KR", action="BUY",
                quantity=10, price=70000.0, confidence=85.0
            )
            mock_post.assert_not_called()

    @pytest.mark.asyncio
    async def test_market_open_close_filtered_does_not_send(self) -> None:
        """notify_market_open/close skip send when market_open_close=False."""
        nf = NotificationFilter(market_open_close=False)
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True, notification_filter=nf
        )
        with patch("aiohttp.ClientSession.post") as mock_post:
            await client.notify_market_open("Korea")
            await client.notify_market_close("Korea", pnl_pct=1.5)
            mock_post.assert_not_called()

    @pytest.mark.asyncio
    async def test_circuit_breaker_always_sends_regardless_of_filter(self) -> None:
        """notify_circuit_breaker always sends (no filter flag)."""
        nf = NotificationFilter(
            trades=False, market_open_close=False, fat_finger=False,
            system_events=False, playbook=False, scenario_match=False, errors=False,
        )
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True, notification_filter=nf
        )
        mock_resp = AsyncMock()
        mock_resp.status = 200
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
            await client.notify_circuit_breaker(pnl_pct=-3.5, threshold=-3.0)
            assert mock_post.call_count == 1

    @pytest.mark.asyncio
    async def test_errors_filtered_does_not_send(self) -> None:
        """notify_error skips send when errors=False."""
        nf = NotificationFilter(errors=False)
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True, notification_filter=nf
        )
        with patch("aiohttp.ClientSession.post") as mock_post:
            await client.notify_error("TestError", "something went wrong", "KR")
            mock_post.assert_not_called()

    @pytest.mark.asyncio
    async def test_playbook_filtered_does_not_send(self) -> None:
        """notify_playbook_generated/failed skip send when playbook=False."""
        nf = NotificationFilter(playbook=False)
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True, notification_filter=nf
        )
        with patch("aiohttp.ClientSession.post") as mock_post:
            await client.notify_playbook_generated("KR", 3, 10, 1200)
            await client.notify_playbook_failed("KR", "timeout")
            mock_post.assert_not_called()

    @pytest.mark.asyncio
    async def test_system_events_filtered_does_not_send(self) -> None:
        """notify_system_start/shutdown skip send when system_events=False."""
        nf = NotificationFilter(system_events=False)
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True, notification_filter=nf
        )
        with patch("aiohttp.ClientSession.post") as mock_post:
            await client.notify_system_start("paper", ["KR"])
            await client.notify_system_shutdown("Normal shutdown")
            mock_post.assert_not_called()

    def test_set_flag_valid_key(self) -> None:
        """set_flag returns True and updates field for a known key."""
        nf = NotificationFilter()
        assert nf.set_flag("scenario", False) is True
        assert nf.scenario_match is False

    def test_set_flag_invalid_key(self) -> None:
        """set_flag returns False for an unknown key."""
        nf = NotificationFilter()
        assert nf.set_flag("unknown_key", False) is False

    def test_as_dict_keys_match_KEYS(self) -> None:
        """as_dict() returns every key defined in KEYS."""
        nf = NotificationFilter()
        d = nf.as_dict()
        assert set(d.keys()) == set(NotificationFilter.KEYS.keys())

    def test_set_notification_valid_key(self) -> None:
        """TelegramClient.set_notification toggles filter at runtime."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        assert client._filter.scenario_match is True
        assert client.set_notification("scenario", False) is True
        assert client._filter.scenario_match is False

    def test_set_notification_all_off(self) -> None:
        """set_notification('all', False) disables every filter flag."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        assert client.set_notification("all", False) is True
        for v in client.filter_status().values():
            assert v is False

    def test_set_notification_all_on(self) -> None:
        """set_notification('all', True) enables every filter flag."""
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True,
            notification_filter=NotificationFilter(
                trades=False, market_open_close=False, scenario_match=False,
                fat_finger=False, system_events=False, playbook=False, errors=False,
            ),
        )
        assert client.set_notification("all", True) is True
        for v in client.filter_status().values():
            assert v is True

    def test_set_notification_unknown_key(self) -> None:
        """set_notification returns False for an unknown key."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        assert client.set_notification("unknown", False) is False

    def test_filter_status_reflects_current_state(self) -> None:
        """filter_status() matches the current NotificationFilter state."""
        nf = NotificationFilter(trades=False, scenario_match=False)
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True, notification_filter=nf
        )
        status = client.filter_status()
        assert status["trades"] is False
        assert status["scenario"] is False
        assert status["market"] is True
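Reviewer note: the assertions above pin down the public surface of NotificationFilter (default flags, KEYS, set_flag, as_dict) without showing the class itself. Below is a minimal sketch of that contract inferred only from these tests; it is not the actual module under test, and the key spellings other than "trades", "market", and "scenario" are assumptions.

from dataclasses import dataclass


@dataclass
class NotificationFilter:
    """All notification categories are enabled by default."""

    trades: bool = True
    market_open_close: bool = True
    fat_finger: bool = True
    system_events: bool = True
    playbook: bool = True
    scenario_match: bool = True
    errors: bool = True

    # Short command keys -> field names. "trades", "market", and "scenario"
    # appear in the tests above; the remaining spellings are guesses.
    KEYS = {
        "trades": "trades",
        "market": "market_open_close",
        "fat_finger": "fat_finger",
        "system": "system_events",
        "playbook": "playbook",
        "scenario": "scenario_match",
        "errors": "errors",
    }

    def set_flag(self, key: str, value: bool) -> bool:
        # Unknown keys return False instead of raising (see test_set_flag_invalid_key).
        field_name = self.KEYS.get(key)
        if field_name is None:
            return False
        setattr(self, field_name, value)
        return True

    def as_dict(self) -> dict[str, bool]:
        # One entry per KEYS key (see test_as_dict_keys_match_KEYS).
        return {key: getattr(self, name) for key, name in self.KEYS.items()}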
tests/test_telegram_commands.py (new file, 1013 lines): diff suppressed because it is too large.
Some files were not shown because too many files have changed in this diff.