Compare commits
19 Commits
62fd4ff5e1
...
feature/is
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3dfd7c0935 | ||
| 4b2bb25d03 | |||
|
|
881bbb4240 | ||
| 5f7d61748b | |||
|
|
972e71a2f1 | ||
| 614b9939b1 | |||
|
|
6dbc2afbf4 | ||
| 6c96f9ac64 | |||
|
|
ed26915562 | ||
| 628a572c70 | |||
|
|
73e1d0a54e | ||
| b111157dc8 | |||
|
|
8c05448843 | ||
|
|
87556b145e | ||
| 645c761238 | |||
|
|
033d5fcadd | ||
| 128324427f | |||
|
|
61f5aaf4a3 | ||
|
|
4f61d5af8e |
@@ -26,3 +26,10 @@ MODE=paper
|
|||||||
# NEWS_API_KEY=your_news_api_key_here
|
# NEWS_API_KEY=your_news_api_key_here
|
||||||
# NEWS_API_PROVIDER=alphavantage
|
# NEWS_API_PROVIDER=alphavantage
|
||||||
# MARKET_DATA_API_KEY=your_market_data_key_here
|
# MARKET_DATA_API_KEY=your_market_data_key_here
|
||||||
|
|
||||||
|
# Telegram Notifications (optional)
|
||||||
|
# Get bot token from @BotFather on Telegram
|
||||||
|
# Get chat ID from @userinfobot or your chat
|
||||||
|
# TELEGRAM_BOT_TOKEN=1234567890:ABCdefGHIjklMNOpqrsTUVwxyz
|
||||||
|
# TELEGRAM_CHAT_ID=123456789
|
||||||
|
# TELEGRAM_ENABLED=true
|
||||||
|
|||||||
31
CLAUDE.md
31
CLAUDE.md
@@ -17,6 +17,34 @@ pytest -v --cov=src
|
|||||||
python -m src.main --mode=paper
|
python -m src.main --mode=paper
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Telegram Notifications (Optional)
|
||||||
|
|
||||||
|
Get real-time alerts for trades, circuit breakers, and system events via Telegram.
|
||||||
|
|
||||||
|
### Quick Setup
|
||||||
|
|
||||||
|
1. **Create bot**: Message [@BotFather](https://t.me/BotFather) on Telegram → `/newbot`
|
||||||
|
2. **Get chat ID**: Message [@userinfobot](https://t.me/userinfobot) → `/start`
|
||||||
|
3. **Configure**: Add to `.env`:
|
||||||
|
```bash
|
||||||
|
TELEGRAM_BOT_TOKEN=1234567890:ABCdefGHIjklMNOpqrsTUVwxyz
|
||||||
|
TELEGRAM_CHAT_ID=123456789
|
||||||
|
TELEGRAM_ENABLED=true
|
||||||
|
```
|
||||||
|
4. **Test**: Start bot conversation (`/start`), then run the agent
|
||||||
|
|
||||||
|
**Full documentation**: [src/notifications/README.md](src/notifications/README.md)
|
||||||
|
|
||||||
|
### What You'll Get
|
||||||
|
|
||||||
|
- 🟢 Trade execution alerts (BUY/SELL with confidence)
|
||||||
|
- 🚨 Circuit breaker trips (automatic trading halt)
|
||||||
|
- ⚠️ Fat-finger rejections (oversized orders blocked)
|
||||||
|
- ℹ️ Market open/close notifications
|
||||||
|
- 📝 System startup/shutdown status
|
||||||
|
|
||||||
|
**Fail-safe**: Notifications never crash the trading system. Missing credentials or API errors are logged but trading continues normally.
|
||||||
|
|
||||||
## Documentation
|
## Documentation
|
||||||
|
|
||||||
- **[Workflow Guide](docs/workflow.md)** — Git workflow policy and agent-based development
|
- **[Workflow Guide](docs/workflow.md)** — Git workflow policy and agent-based development
|
||||||
@@ -42,11 +70,12 @@ src/
|
|||||||
├── core/ # Risk manager (READ-ONLY)
|
├── core/ # Risk manager (READ-ONLY)
|
||||||
├── evolution/ # Self-improvement optimizer
|
├── evolution/ # Self-improvement optimizer
|
||||||
├── markets/ # Market schedules and timezone handling
|
├── markets/ # Market schedules and timezone handling
|
||||||
|
├── notifications/ # Telegram real-time alerts
|
||||||
├── db.py # SQLite trade logging
|
├── db.py # SQLite trade logging
|
||||||
├── main.py # Trading loop orchestrator
|
├── main.py # Trading loop orchestrator
|
||||||
└── config.py # Settings (from .env)
|
└── config.py # Settings (from .env)
|
||||||
|
|
||||||
tests/ # 54 tests across 4 files
|
tests/ # 273 tests across 13 files
|
||||||
docs/ # Extended documentation
|
docs/ # Extended documentation
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
48
README.md
48
README.md
@@ -29,6 +29,7 @@ KIS(한국투자증권) API로 매매하고, Google Gemini로 판단하며, 자
|
|||||||
| 브로커 | `src/broker/kis_api.py` | KIS API 비동기 래퍼 (토큰 갱신, 레이트 리미터, 해시키) |
|
| 브로커 | `src/broker/kis_api.py` | KIS API 비동기 래퍼 (토큰 갱신, 레이트 리미터, 해시키) |
|
||||||
| 두뇌 | `src/brain/gemini_client.py` | Gemini 프롬프트 구성 및 JSON 응답 파싱 |
|
| 두뇌 | `src/brain/gemini_client.py` | Gemini 프롬프트 구성 및 JSON 응답 파싱 |
|
||||||
| 방패 | `src/core/risk_manager.py` | 서킷 브레이커 + 팻 핑거 체크 |
|
| 방패 | `src/core/risk_manager.py` | 서킷 브레이커 + 팻 핑거 체크 |
|
||||||
|
| 알림 | `src/notifications/telegram_client.py` | 텔레그램 실시간 거래 알림 (선택사항) |
|
||||||
| 진화 | `src/evolution/optimizer.py` | 실패 패턴 분석 → 새 전략 생성 → 테스트 → PR |
|
| 진화 | `src/evolution/optimizer.py` | 실패 패턴 분석 → 새 전략 생성 → 테스트 → PR |
|
||||||
| DB | `src/db.py` | SQLite 거래 로그 기록 |
|
| DB | `src/db.py` | SQLite 거래 로그 기록 |
|
||||||
|
|
||||||
@@ -75,6 +76,34 @@ python -m src.main --mode=paper
|
|||||||
docker compose up -d ouroboros
|
docker compose up -d ouroboros
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## 텔레그램 알림 (선택사항)
|
||||||
|
|
||||||
|
거래 실행, 서킷 브레이커 발동, 시스템 상태 등을 텔레그램으로 실시간 알림 받을 수 있습니다.
|
||||||
|
|
||||||
|
### 빠른 설정
|
||||||
|
|
||||||
|
1. **봇 생성**: 텔레그램에서 [@BotFather](https://t.me/BotFather) 메시지 → `/newbot` 명령
|
||||||
|
2. **채팅 ID 확인**: [@userinfobot](https://t.me/userinfobot) 메시지 → `/start` 명령
|
||||||
|
3. **환경변수 설정**: `.env` 파일에 추가
|
||||||
|
```bash
|
||||||
|
TELEGRAM_BOT_TOKEN=1234567890:ABCdefGHIjklMNOpqrsTUVwxyz
|
||||||
|
TELEGRAM_CHAT_ID=123456789
|
||||||
|
TELEGRAM_ENABLED=true
|
||||||
|
```
|
||||||
|
4. **테스트**: 봇과 대화 시작 (`/start` 전송) 후 에이전트 실행
|
||||||
|
|
||||||
|
**상세 문서**: [src/notifications/README.md](src/notifications/README.md)
|
||||||
|
|
||||||
|
### 알림 종류
|
||||||
|
|
||||||
|
- 🟢 거래 체결 알림 (BUY/SELL + 신뢰도)
|
||||||
|
- 🚨 서킷 브레이커 발동 (자동 거래 중단)
|
||||||
|
- ⚠️ 팻 핑거 차단 (과도한 주문 차단)
|
||||||
|
- ℹ️ 장 시작/종료 알림
|
||||||
|
- 📝 시스템 시작/종료 상태
|
||||||
|
|
||||||
|
**안전장치**: 알림 실패해도 거래는 계속 진행됩니다. 텔레그램 API 오류나 설정 누락이 있어도 거래 시스템은 정상 작동합니다.
|
||||||
|
|
||||||
## 테스트
|
## 테스트
|
||||||
|
|
||||||
35개 테스트가 TDD 방식으로 구현 전에 먼저 작성되었습니다.
|
35개 테스트가 TDD 방식으로 구현 전에 먼저 작성되었습니다.
|
||||||
@@ -104,15 +133,16 @@ The-Ouroboros/
|
|||||||
│ ├── agents.md # AI 에이전트 페르소나 정의
|
│ ├── agents.md # AI 에이전트 페르소나 정의
|
||||||
│ └── skills.md # 사용 가능한 도구 목록
|
│ └── skills.md # 사용 가능한 도구 목록
|
||||||
├── src/
|
├── src/
|
||||||
│ ├── config.py # Pydantic 설정
|
│ ├── config.py # Pydantic 설정
|
||||||
│ ├── logging_config.py # JSON 구조화 로깅
|
│ ├── logging_config.py # JSON 구조화 로깅
|
||||||
│ ├── db.py # SQLite 거래 기록
|
│ ├── db.py # SQLite 거래 기록
|
||||||
│ ├── main.py # 비동기 거래 루프
|
│ ├── main.py # 비동기 거래 루프
|
||||||
│ ├── broker/kis_api.py # KIS API 클라이언트
|
│ ├── broker/kis_api.py # KIS API 클라이언트
|
||||||
│ ├── brain/gemini_client.py # Gemini 의사결정 엔진
|
│ ├── brain/gemini_client.py # Gemini 의사결정 엔진
|
||||||
│ ├── core/risk_manager.py # 리스크 관리
|
│ ├── core/risk_manager.py # 리스크 관리
|
||||||
│ ├── evolution/optimizer.py # 전략 진화 엔진
|
│ ├── notifications/telegram_client.py # 텔레그램 알림
|
||||||
│ └── strategies/base.py # 전략 베이스 클래스
|
│ ├── evolution/optimizer.py # 전략 진화 엔진
|
||||||
|
│ └── strategies/base.py # 전략 베이스 클래스
|
||||||
├── tests/ # TDD 테스트 스위트
|
├── tests/ # TDD 테스트 스위트
|
||||||
├── Dockerfile # 멀티스테이지 빌드
|
├── Dockerfile # 멀티스테이지 빌드
|
||||||
├── docker-compose.yml # 서비스 오케스트레이션
|
├── docker-compose.yml # 서비스 오케스트레이션
|
||||||
|
|||||||
@@ -51,7 +51,26 @@ Self-evolving AI trading agent for global stock markets via KIS (Korea Investmen
|
|||||||
- **Fat-Finger Protection**: Rejects orders exceeding 30% of available cash
|
- **Fat-Finger Protection**: Rejects orders exceeding 30% of available cash
|
||||||
- Must always be enforced, cannot be disabled
|
- Must always be enforced, cannot be disabled
|
||||||
|
|
||||||
### 4. Evolution (`src/evolution/optimizer.py`)
|
### 4. Notifications (`src/notifications/telegram_client.py`)
|
||||||
|
|
||||||
|
**TelegramClient** — Real-time event notifications via Telegram Bot API
|
||||||
|
|
||||||
|
- Sends alerts for trades, circuit breakers, fat-finger rejections, system events
|
||||||
|
- Non-blocking: failures are logged but never crash trading
|
||||||
|
- Rate-limited: 1 message/second default to respect Telegram API limits
|
||||||
|
- Auto-disabled when credentials missing
|
||||||
|
- Gracefully handles API errors, network timeouts, invalid tokens
|
||||||
|
|
||||||
|
**Notification Types:**
|
||||||
|
- Trade execution (BUY/SELL with confidence)
|
||||||
|
- Circuit breaker trips (critical alert)
|
||||||
|
- Fat-finger protection triggers (order rejection)
|
||||||
|
- Market open/close events
|
||||||
|
- System startup/shutdown status
|
||||||
|
|
||||||
|
**Setup:** See [src/notifications/README.md](../src/notifications/README.md) for bot creation and configuration.
|
||||||
|
|
||||||
|
### 5. Evolution (`src/evolution/optimizer.py`)
|
||||||
|
|
||||||
**StrategyOptimizer** — Self-improvement loop
|
**StrategyOptimizer** — Self-improvement loop
|
||||||
|
|
||||||
@@ -115,6 +134,14 @@ Self-evolving AI trading agent for global stock markets via KIS (Korea Investmen
|
|||||||
│
|
│
|
||||||
▼
|
▼
|
||||||
┌──────────────────────────────────┐
|
┌──────────────────────────────────┐
|
||||||
|
│ Notifications: Send Alert │
|
||||||
|
│ - Trade execution notification │
|
||||||
|
│ - Non-blocking (errors logged) │
|
||||||
|
│ - Rate-limited to 1/sec │
|
||||||
|
└──────────────────┬────────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌──────────────────────────────────┐
|
||||||
│ Database: Log Trade │
|
│ Database: Log Trade │
|
||||||
│ - SQLite (data/trades.db) │
|
│ - SQLite (data/trades.db) │
|
||||||
│ - Track: action, confidence, │
|
│ - Track: action, confidence, │
|
||||||
@@ -164,6 +191,11 @@ CONFIDENCE_THRESHOLD=80
|
|||||||
MAX_LOSS_PCT=3.0
|
MAX_LOSS_PCT=3.0
|
||||||
MAX_ORDER_PCT=30.0
|
MAX_ORDER_PCT=30.0
|
||||||
ENABLED_MARKETS=KR,US_NASDAQ # Comma-separated market codes
|
ENABLED_MARKETS=KR,US_NASDAQ # Comma-separated market codes
|
||||||
|
|
||||||
|
# Telegram Notifications (optional)
|
||||||
|
TELEGRAM_BOT_TOKEN=1234567890:ABCdefGHIjklMNOpqrsTUVwxyz
|
||||||
|
TELEGRAM_CHAT_ID=123456789
|
||||||
|
TELEGRAM_ENABLED=true
|
||||||
```
|
```
|
||||||
|
|
||||||
Tests use in-memory SQLite (`DB_PATH=":memory:"`) and dummy credentials via `tests/conftest.py`.
|
Tests use in-memory SQLite (`DB_PATH=":memory:"`) and dummy credentials via `tests/conftest.py`.
|
||||||
@@ -189,3 +221,12 @@ Tests use in-memory SQLite (`DB_PATH=":memory:"`) and dummy credentials via `tes
|
|||||||
- Wait until next market opens
|
- Wait until next market opens
|
||||||
- Use `get_next_market_open()` to calculate wait time
|
- Use `get_next_market_open()` to calculate wait time
|
||||||
- Sleep until market open time
|
- Sleep until market open time
|
||||||
|
|
||||||
|
### Telegram API Errors
|
||||||
|
- Log warning but continue trading
|
||||||
|
- Missing credentials → auto-disable notifications
|
||||||
|
- Network timeout → skip notification, no retry
|
||||||
|
- Invalid token → log error, trading unaffected
|
||||||
|
- Rate limit exceeded → queued via rate limiter
|
||||||
|
|
||||||
|
**Guarantee**: Notification failures never interrupt trading operations.
|
||||||
|
|||||||
348
docs/disaster_recovery.md
Normal file
348
docs/disaster_recovery.md
Normal file
@@ -0,0 +1,348 @@
|
|||||||
|
# Disaster Recovery Guide
|
||||||
|
|
||||||
|
Complete guide for backing up and restoring The Ouroboros trading system.
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
- [Backup Strategy](#backup-strategy)
|
||||||
|
- [Creating Backups](#creating-backups)
|
||||||
|
- [Restoring from Backup](#restoring-from-backup)
|
||||||
|
- [Health Monitoring](#health-monitoring)
|
||||||
|
- [Export Formats](#export-formats)
|
||||||
|
- [RTO/RPO](#rtorpo)
|
||||||
|
- [Testing Recovery](#testing-recovery)
|
||||||
|
|
||||||
|
## Backup Strategy
|
||||||
|
|
||||||
|
The system implements a 3-tier backup retention policy:
|
||||||
|
|
||||||
|
| Policy | Frequency | Retention | Purpose |
|
||||||
|
|--------|-----------|-----------|---------|
|
||||||
|
| **Daily** | Every day | 30 days | Quick recovery from recent issues |
|
||||||
|
| **Weekly** | Sunday | 1 year | Medium-term historical analysis |
|
||||||
|
| **Monthly** | 1st of month | Forever | Long-term archival |
|
||||||
|
|
||||||
|
### Storage Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
data/backups/
|
||||||
|
├── daily/ # Last 30 days
|
||||||
|
├── weekly/ # Last 52 weeks
|
||||||
|
└── monthly/ # Forever (cold storage)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Creating Backups
|
||||||
|
|
||||||
|
### Automated Backups (Recommended)
|
||||||
|
|
||||||
|
Set up a cron job to run daily:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Edit crontab
|
||||||
|
crontab -e
|
||||||
|
|
||||||
|
# Run backup at 2 AM every day
|
||||||
|
0 2 * * * cd /path/to/The-Ouroboros && ./scripts/backup.sh >> logs/backup.log 2>&1
|
||||||
|
```
|
||||||
|
|
||||||
|
### Manual Backups
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run backup script
|
||||||
|
./scripts/backup.sh
|
||||||
|
|
||||||
|
# Or use Python directly
|
||||||
|
python3 -c "
|
||||||
|
from pathlib import Path
|
||||||
|
from src.backup.scheduler import BackupScheduler, BackupPolicy
|
||||||
|
|
||||||
|
scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))
|
||||||
|
metadata = scheduler.create_backup(BackupPolicy.DAILY, verify=True)
|
||||||
|
print(f'Backup created: {metadata.file_path}')
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Export to Other Formats
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -c "
|
||||||
|
from pathlib import Path
|
||||||
|
from src.backup.exporter import BackupExporter, ExportFormat
|
||||||
|
|
||||||
|
exporter = BackupExporter('data/trade_logs.db')
|
||||||
|
results = exporter.export_all(
|
||||||
|
Path('exports'),
|
||||||
|
formats=[ExportFormat.JSON, ExportFormat.CSV],
|
||||||
|
compress=True
|
||||||
|
)
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Restoring from Backup
|
||||||
|
|
||||||
|
### Interactive Restoration
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./scripts/restore.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
The script will:
|
||||||
|
1. List available backups
|
||||||
|
2. Ask you to select one
|
||||||
|
3. Create a safety backup of current database
|
||||||
|
4. Restore the selected backup
|
||||||
|
5. Verify database integrity
|
||||||
|
|
||||||
|
### Manual Restoration
|
||||||
|
|
||||||
|
```python
|
||||||
|
from pathlib import Path
|
||||||
|
from src.backup.scheduler import BackupScheduler
|
||||||
|
|
||||||
|
scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))
|
||||||
|
|
||||||
|
# List backups
|
||||||
|
backups = scheduler.list_backups()
|
||||||
|
for backup in backups:
|
||||||
|
print(f"{backup.timestamp}: {backup.file_path}")
|
||||||
|
|
||||||
|
# Restore specific backup
|
||||||
|
scheduler.restore_backup(backups[0], verify=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Health Monitoring
|
||||||
|
|
||||||
|
### Check System Health
|
||||||
|
|
||||||
|
```python
|
||||||
|
from pathlib import Path
|
||||||
|
from src.backup.health_monitor import HealthMonitor
|
||||||
|
|
||||||
|
monitor = HealthMonitor('data/trade_logs.db', Path('data/backups'))
|
||||||
|
|
||||||
|
# Run all checks
|
||||||
|
report = monitor.get_health_report()
|
||||||
|
print(f"Overall status: {report['overall_status']}")
|
||||||
|
|
||||||
|
# Individual checks
|
||||||
|
checks = monitor.run_all_checks()
|
||||||
|
for name, result in checks.items():
|
||||||
|
print(f"{name}: {result.status.value} - {result.message}")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Health Checks
|
||||||
|
|
||||||
|
The system monitors:
|
||||||
|
|
||||||
|
- **Database Health**: Accessibility, integrity, size
|
||||||
|
- **Disk Space**: Available storage (alerts if < 10 GB)
|
||||||
|
- **Backup Recency**: Ensures backups are < 25 hours old
|
||||||
|
|
||||||
|
### Health Status Levels
|
||||||
|
|
||||||
|
- **HEALTHY**: All systems operational
|
||||||
|
- **DEGRADED**: Warning condition (e.g., low disk space)
|
||||||
|
- **UNHEALTHY**: Critical issue (e.g., database corrupted, no backups)
|
||||||
|
|
||||||
|
## Export Formats
|
||||||
|
|
||||||
|
### JSON (Human-Readable)
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"export_timestamp": "2024-01-15T10:30:00Z",
|
||||||
|
"record_count": 150,
|
||||||
|
"trades": [
|
||||||
|
{
|
||||||
|
"timestamp": "2024-01-15T09:00:00Z",
|
||||||
|
"stock_code": "005930",
|
||||||
|
"action": "BUY",
|
||||||
|
"quantity": 10,
|
||||||
|
"price": 70000.0,
|
||||||
|
"confidence": 85,
|
||||||
|
"rationale": "Strong momentum",
|
||||||
|
"pnl": 0.0
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### CSV (Analysis Tools)
|
||||||
|
|
||||||
|
Compatible with Excel, pandas, R:
|
||||||
|
|
||||||
|
```csv
|
||||||
|
timestamp,stock_code,action,quantity,price,confidence,rationale,pnl
|
||||||
|
2024-01-15T09:00:00Z,005930,BUY,10,70000.0,85,Strong momentum,0.0
|
||||||
|
```
|
||||||
|
|
||||||
|
### Parquet (Big Data)
|
||||||
|
|
||||||
|
Columnar format for Spark, DuckDB:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import pandas as pd
|
||||||
|
df = pd.read_parquet('exports/trades_20240115.parquet')
|
||||||
|
```
|
||||||
|
|
||||||
|
## RTO/RPO
|
||||||
|
|
||||||
|
### Recovery Time Objective (RTO)
|
||||||
|
|
||||||
|
**Target: < 5 minutes**
|
||||||
|
|
||||||
|
Time to restore trading operations:
|
||||||
|
1. Identify backup to restore (1 min)
|
||||||
|
2. Run restore script (2 min)
|
||||||
|
3. Verify database integrity (1 min)
|
||||||
|
4. Restart trading system (1 min)
|
||||||
|
|
||||||
|
### Recovery Point Objective (RPO)
|
||||||
|
|
||||||
|
**Target: < 24 hours**
|
||||||
|
|
||||||
|
Maximum acceptable data loss:
|
||||||
|
- Daily backups ensure ≤ 24-hour data loss
|
||||||
|
- For critical periods, run backups more frequently
|
||||||
|
|
||||||
|
## Testing Recovery
|
||||||
|
|
||||||
|
### Quarterly Recovery Test
|
||||||
|
|
||||||
|
Perform full disaster recovery test every quarter:
|
||||||
|
|
||||||
|
1. **Create test backup**
|
||||||
|
```bash
|
||||||
|
./scripts/backup.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Simulate disaster** (use test database)
|
||||||
|
```bash
|
||||||
|
cp data/trade_logs.db data/trade_logs_test.db
|
||||||
|
rm data/trade_logs_test.db # Simulate data loss
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Restore from backup**
|
||||||
|
```bash
|
||||||
|
DB_PATH=data/trade_logs_test.db ./scripts/restore.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **Verify data integrity**
|
||||||
|
```python
|
||||||
|
import sqlite3
|
||||||
|
conn = sqlite3.connect('data/trade_logs_test.db')
|
||||||
|
cursor = conn.execute('SELECT COUNT(*) FROM trades')
|
||||||
|
print(f"Restored {cursor.fetchone()[0]} trades")
|
||||||
|
```
|
||||||
|
|
||||||
|
5. **Document results** in `logs/recovery_test_YYYYMMDD.md`
|
||||||
|
|
||||||
|
### Backup Verification
|
||||||
|
|
||||||
|
Always verify backups after creation:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from pathlib import Path
|
||||||
|
from src.backup.scheduler import BackupScheduler
|
||||||
|
|
||||||
|
scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))
|
||||||
|
|
||||||
|
# Create and verify
|
||||||
|
metadata = scheduler.create_backup(BackupPolicy.DAILY, verify=True)
|
||||||
|
print(f"Checksum: {metadata.checksum}") # Should not be None
|
||||||
|
```
|
||||||
|
|
||||||
|
## Emergency Procedures
|
||||||
|
|
||||||
|
### Database Corrupted
|
||||||
|
|
||||||
|
1. Stop trading system immediately
|
||||||
|
2. Check most recent backup age: `ls -lht data/backups/daily/`
|
||||||
|
3. Restore: `./scripts/restore.sh`
|
||||||
|
4. Verify: Run health check
|
||||||
|
5. Resume trading
|
||||||
|
|
||||||
|
### Disk Full
|
||||||
|
|
||||||
|
1. Check disk space: `df -h`
|
||||||
|
2. Clean old backups: Run cleanup manually
|
||||||
|
```python
|
||||||
|
from pathlib import Path
|
||||||
|
from src.backup.scheduler import BackupScheduler
|
||||||
|
scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))
|
||||||
|
scheduler.cleanup_old_backups()
|
||||||
|
```
|
||||||
|
3. Consider archiving old monthly backups to external storage
|
||||||
|
4. Increase disk space if needed
|
||||||
|
|
||||||
|
### Lost All Backups
|
||||||
|
|
||||||
|
If local backups are lost:
|
||||||
|
1. Check if exports exist in `exports/` directory
|
||||||
|
2. Reconstruct database from CSV/JSON exports
|
||||||
|
3. If no exports: Check broker API for trade history
|
||||||
|
4. Manual reconstruction as last resort
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
1. **Test Restores Regularly**: Don't wait for disaster
|
||||||
|
2. **Monitor Disk Space**: Set up alerts at 80% usage
|
||||||
|
3. **Keep Multiple Generations**: Never delete all backups at once
|
||||||
|
4. **Verify Checksums**: Always verify backup integrity
|
||||||
|
5. **Document Changes**: Update this guide when backup strategy changes
|
||||||
|
6. **Off-Site Storage**: Consider external backup for monthly archives
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Backup Script Fails
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Check database file permissions
|
||||||
|
ls -l data/trade_logs.db
|
||||||
|
|
||||||
|
# Check disk space
|
||||||
|
df -h data/
|
||||||
|
|
||||||
|
# Run backup manually with debug
|
||||||
|
python3 -c "
|
||||||
|
import logging
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
from pathlib import Path
|
||||||
|
from src.backup.scheduler import BackupScheduler, BackupPolicy
|
||||||
|
scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))
|
||||||
|
scheduler.create_backup(BackupPolicy.DAILY, verify=True)
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Restore Fails Verification
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Check backup file integrity
|
||||||
|
python3 -c "
|
||||||
|
import sqlite3
|
||||||
|
conn = sqlite3.connect('data/backups/daily/trade_logs_daily_20240115.db')
|
||||||
|
cursor = conn.execute('PRAGMA integrity_check')
|
||||||
|
print(cursor.fetchone()[0])
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Health Check Fails
|
||||||
|
|
||||||
|
```python
|
||||||
|
from pathlib import Path
|
||||||
|
from src.backup.health_monitor import HealthMonitor
|
||||||
|
|
||||||
|
monitor = HealthMonitor('data/trade_logs.db', Path('data/backups'))
|
||||||
|
|
||||||
|
# Check each component individually
|
||||||
|
print("Database:", monitor.check_database_health())
|
||||||
|
print("Disk Space:", monitor.check_disk_space())
|
||||||
|
print("Backup Recency:", monitor.check_backup_recency())
|
||||||
|
```
|
||||||
|
|
||||||
|
## Contact
|
||||||
|
|
||||||
|
For backup/recovery issues:
|
||||||
|
- Check logs: `logs/backup.log`
|
||||||
|
- Review health status: Run health monitor
|
||||||
|
- Raise issue on GitHub if automated recovery fails
|
||||||
96
scripts/backup.sh
Normal file
96
scripts/backup.sh
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# Automated backup script for The Ouroboros trading system
|
||||||
|
# Runs daily/weekly/monthly backups
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
DB_PATH="${DB_PATH:-data/trade_logs.db}"
|
||||||
|
BACKUP_DIR="${BACKUP_DIR:-data/backups}"
|
||||||
|
PYTHON="${PYTHON:-python3}"
|
||||||
|
|
||||||
|
# Colors for output
|
||||||
|
GREEN='\033[0;32m'
|
||||||
|
YELLOW='\033[1;33m'
|
||||||
|
RED='\033[0;31m'
|
||||||
|
NC='\033[0m' # No Color
|
||||||
|
|
||||||
|
log_info() {
|
||||||
|
echo -e "${GREEN}[INFO]${NC} $1"
|
||||||
|
}
|
||||||
|
|
||||||
|
log_warn() {
|
||||||
|
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||||
|
}
|
||||||
|
|
||||||
|
log_error() {
|
||||||
|
echo -e "${RED}[ERROR]${NC} $1"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if database exists
|
||||||
|
if [ ! -f "$DB_PATH" ]; then
|
||||||
|
log_error "Database not found: $DB_PATH"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Create backup directory
|
||||||
|
mkdir -p "$BACKUP_DIR"
|
||||||
|
|
||||||
|
log_info "Starting backup process..."
|
||||||
|
log_info "Database: $DB_PATH"
|
||||||
|
log_info "Backup directory: $BACKUP_DIR"
|
||||||
|
|
||||||
|
# Determine backup policy based on day of week and month
|
||||||
|
DAY_OF_WEEK=$(date +%u) # 1-7 (Monday-Sunday)
|
||||||
|
DAY_OF_MONTH=$(date +%d)
|
||||||
|
|
||||||
|
if [ "$DAY_OF_MONTH" == "01" ]; then
|
||||||
|
POLICY="monthly"
|
||||||
|
log_info "Running MONTHLY backup (first day of month)"
|
||||||
|
elif [ "$DAY_OF_WEEK" == "7" ]; then
|
||||||
|
POLICY="weekly"
|
||||||
|
log_info "Running WEEKLY backup (Sunday)"
|
||||||
|
else
|
||||||
|
POLICY="daily"
|
||||||
|
log_info "Running DAILY backup"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Run Python backup script
|
||||||
|
$PYTHON -c "
|
||||||
|
from pathlib import Path
|
||||||
|
from src.backup.scheduler import BackupScheduler, BackupPolicy
|
||||||
|
from src.backup.health_monitor import HealthMonitor
|
||||||
|
|
||||||
|
# Create scheduler
|
||||||
|
scheduler = BackupScheduler(
|
||||||
|
db_path='$DB_PATH',
|
||||||
|
backup_dir=Path('$BACKUP_DIR')
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create backup
|
||||||
|
policy = BackupPolicy.$POLICY.upper()
|
||||||
|
metadata = scheduler.create_backup(policy, verify=True)
|
||||||
|
print(f'Backup created: {metadata.file_path}')
|
||||||
|
print(f'Size: {metadata.size_bytes / 1024 / 1024:.2f} MB')
|
||||||
|
print(f'Checksum: {metadata.checksum}')
|
||||||
|
|
||||||
|
# Cleanup old backups
|
||||||
|
removed = scheduler.cleanup_old_backups()
|
||||||
|
total_removed = sum(removed.values())
|
||||||
|
if total_removed > 0:
|
||||||
|
print(f'Removed {total_removed} old backup(s)')
|
||||||
|
|
||||||
|
# Health check
|
||||||
|
monitor = HealthMonitor('$DB_PATH', Path('$BACKUP_DIR'))
|
||||||
|
status = monitor.get_overall_status()
|
||||||
|
print(f'System health: {status.value}')
|
||||||
|
"
|
||||||
|
|
||||||
|
if [ $? -eq 0 ]; then
|
||||||
|
log_info "Backup completed successfully"
|
||||||
|
else
|
||||||
|
log_error "Backup failed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
log_info "Backup process finished"
|
||||||
111
scripts/restore.sh
Normal file
111
scripts/restore.sh
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# Restore script for The Ouroboros trading system
|
||||||
|
# Restores database from a backup file
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
DB_PATH="${DB_PATH:-data/trade_logs.db}"
|
||||||
|
BACKUP_DIR="${BACKUP_DIR:-data/backups}"
|
||||||
|
PYTHON="${PYTHON:-python3}"
|
||||||
|
|
||||||
|
# Colors for output
|
||||||
|
GREEN='\033[0;32m'
|
||||||
|
YELLOW='\033[1;33m'
|
||||||
|
RED='\033[0;31m'
|
||||||
|
NC='\033[0m' # No Color
|
||||||
|
|
||||||
|
log_info() {
|
||||||
|
echo -e "${GREEN}[INFO]${NC} $1"
|
||||||
|
}
|
||||||
|
|
||||||
|
log_warn() {
|
||||||
|
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||||
|
}
|
||||||
|
|
||||||
|
log_error() {
|
||||||
|
echo -e "${RED}[ERROR]${NC} $1"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if backup directory exists
|
||||||
|
if [ ! -d "$BACKUP_DIR" ]; then
|
||||||
|
log_error "Backup directory not found: $BACKUP_DIR"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
log_info "Available backups:"
|
||||||
|
log_info "=================="
|
||||||
|
|
||||||
|
# List available backups
|
||||||
|
$PYTHON -c "
|
||||||
|
from pathlib import Path
|
||||||
|
from src.backup.scheduler import BackupScheduler
|
||||||
|
|
||||||
|
scheduler = BackupScheduler(
|
||||||
|
db_path='$DB_PATH',
|
||||||
|
backup_dir=Path('$BACKUP_DIR')
|
||||||
|
)
|
||||||
|
|
||||||
|
backups = scheduler.list_backups()
|
||||||
|
|
||||||
|
if not backups:
|
||||||
|
print('No backups found.')
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
for i, backup in enumerate(backups, 1):
|
||||||
|
size_mb = backup.size_bytes / 1024 / 1024
|
||||||
|
print(f'{i}. [{backup.policy.value.upper()}] {backup.file_path.name}')
|
||||||
|
print(f' Date: {backup.timestamp.strftime(\"%Y-%m-%d %H:%M:%S UTC\")}')
|
||||||
|
print(f' Size: {size_mb:.2f} MB')
|
||||||
|
print()
|
||||||
|
"
|
||||||
|
|
||||||
|
# Ask user to select backup
|
||||||
|
echo ""
|
||||||
|
read -p "Enter backup number to restore (or 'q' to quit): " BACKUP_NUM
|
||||||
|
|
||||||
|
if [ "$BACKUP_NUM" == "q" ]; then
|
||||||
|
log_info "Restore cancelled"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Confirm restoration
|
||||||
|
log_warn "WARNING: This will replace the current database!"
|
||||||
|
log_warn "Current database will be backed up to: ${DB_PATH}.before_restore"
|
||||||
|
read -p "Are you sure you want to continue? (yes/no): " CONFIRM
|
||||||
|
|
||||||
|
if [ "$CONFIRM" != "yes" ]; then
|
||||||
|
log_info "Restore cancelled"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Perform restoration
|
||||||
|
$PYTHON -c "
|
||||||
|
from pathlib import Path
|
||||||
|
from src.backup.scheduler import BackupScheduler
|
||||||
|
|
||||||
|
scheduler = BackupScheduler(
|
||||||
|
db_path='$DB_PATH',
|
||||||
|
backup_dir=Path('$BACKUP_DIR')
|
||||||
|
)
|
||||||
|
|
||||||
|
backups = scheduler.list_backups()
|
||||||
|
backup_index = int('$BACKUP_NUM') - 1
|
||||||
|
|
||||||
|
if backup_index < 0 or backup_index >= len(backups):
|
||||||
|
print('Invalid backup number')
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
selected = backups[backup_index]
|
||||||
|
print(f'Restoring: {selected.file_path.name}')
|
||||||
|
|
||||||
|
scheduler.restore_backup(selected, verify=True)
|
||||||
|
print('Restore completed successfully')
|
||||||
|
"
|
||||||
|
|
||||||
|
if [ $? -eq 0 ]; then
|
||||||
|
log_info "Database restored successfully"
|
||||||
|
else
|
||||||
|
log_error "Restore failed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
21
src/backup/__init__.py
Normal file
21
src/backup/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""Backup and disaster recovery system for long-term sustainability.
|
||||||
|
|
||||||
|
This module provides:
|
||||||
|
- Automated database backups (daily, weekly, monthly)
|
||||||
|
- Multi-format exports (JSON, CSV, Parquet)
|
||||||
|
- Cloud storage integration (S3-compatible)
|
||||||
|
- Health monitoring and alerts
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.backup.exporter import BackupExporter, ExportFormat
|
||||||
|
from src.backup.scheduler import BackupScheduler, BackupPolicy
|
||||||
|
from src.backup.cloud_storage import CloudStorage, S3Config
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BackupExporter",
|
||||||
|
"ExportFormat",
|
||||||
|
"BackupScheduler",
|
||||||
|
"BackupPolicy",
|
||||||
|
"CloudStorage",
|
||||||
|
"S3Config",
|
||||||
|
]
|
||||||
274
src/backup/cloud_storage.py
Normal file
274
src/backup/cloud_storage.py
Normal file
@@ -0,0 +1,274 @@
|
|||||||
|
"""Cloud storage integration for off-site backups.
|
||||||
|
|
||||||
|
Supports S3-compatible storage providers:
|
||||||
|
- AWS S3
|
||||||
|
- MinIO
|
||||||
|
- Backblaze B2
|
||||||
|
- DigitalOcean Spaces
|
||||||
|
- Cloudflare R2
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class S3Config:
|
||||||
|
"""Configuration for S3-compatible storage."""
|
||||||
|
|
||||||
|
endpoint_url: str | None # None for AWS S3, custom URL for others
|
||||||
|
access_key: str
|
||||||
|
secret_key: str
|
||||||
|
bucket_name: str
|
||||||
|
region: str = "us-east-1"
|
||||||
|
use_ssl: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class CloudStorage:
|
||||||
|
"""Upload backups to S3-compatible cloud storage."""
|
||||||
|
|
||||||
|
def __init__(self, config: S3Config) -> None:
|
||||||
|
"""Initialize cloud storage client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: S3 configuration
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError: If boto3 is not installed
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import boto3
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"boto3 is required for cloud storage. Install with: pip install boto3"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
self.client = boto3.client(
|
||||||
|
"s3",
|
||||||
|
endpoint_url=config.endpoint_url,
|
||||||
|
aws_access_key_id=config.access_key,
|
||||||
|
aws_secret_access_key=config.secret_key,
|
||||||
|
region_name=config.region,
|
||||||
|
use_ssl=config.use_ssl,
|
||||||
|
)
|
||||||
|
|
||||||
|
def upload_file(
|
||||||
|
self,
|
||||||
|
file_path: Path,
|
||||||
|
object_key: str | None = None,
|
||||||
|
metadata: dict[str, str] | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Upload a file to cloud storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Local file to upload
|
||||||
|
object_key: S3 object key (default: filename)
|
||||||
|
metadata: Optional metadata to attach
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
S3 object key
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If file doesn't exist
|
||||||
|
Exception: If upload fails
|
||||||
|
"""
|
||||||
|
if not file_path.exists():
|
||||||
|
raise FileNotFoundError(f"File not found: {file_path}")
|
||||||
|
|
||||||
|
if object_key is None:
|
||||||
|
object_key = file_path.name
|
||||||
|
|
||||||
|
extra_args: dict[str, Any] = {}
|
||||||
|
|
||||||
|
# Add server-side encryption
|
||||||
|
extra_args["ServerSideEncryption"] = "AES256"
|
||||||
|
|
||||||
|
# Add metadata if provided
|
||||||
|
if metadata:
|
||||||
|
extra_args["Metadata"] = metadata
|
||||||
|
|
||||||
|
logger.info("Uploading %s to s3://%s/%s", file_path.name, self.config.bucket_name, object_key)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.client.upload_file(
|
||||||
|
str(file_path),
|
||||||
|
self.config.bucket_name,
|
||||||
|
object_key,
|
||||||
|
ExtraArgs=extra_args,
|
||||||
|
)
|
||||||
|
logger.info("Upload successful: %s", object_key)
|
||||||
|
return object_key
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Upload failed: %s", exc)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def download_file(self, object_key: str, local_path: Path) -> Path:
|
||||||
|
"""Download a file from cloud storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
object_key: S3 object key
|
||||||
|
local_path: Local destination path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to downloaded file
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If download fails
|
||||||
|
"""
|
||||||
|
local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
logger.info("Downloading s3://%s/%s to %s", self.config.bucket_name, object_key, local_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.client.download_file(
|
||||||
|
self.config.bucket_name,
|
||||||
|
object_key,
|
||||||
|
str(local_path),
|
||||||
|
)
|
||||||
|
logger.info("Download successful: %s", local_path)
|
||||||
|
return local_path
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Download failed: %s", exc)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def list_files(self, prefix: str = "") -> list[dict[str, Any]]:
|
||||||
|
"""List files in cloud storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prefix: Filter by object key prefix
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of file metadata dictionaries
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = self.client.list_objects_v2(
|
||||||
|
Bucket=self.config.bucket_name,
|
||||||
|
Prefix=prefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
if "Contents" not in response:
|
||||||
|
return []
|
||||||
|
|
||||||
|
files = []
|
||||||
|
for obj in response["Contents"]:
|
||||||
|
files.append(
|
||||||
|
{
|
||||||
|
"key": obj["Key"],
|
||||||
|
"size_bytes": obj["Size"],
|
||||||
|
"last_modified": obj["LastModified"],
|
||||||
|
"etag": obj["ETag"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return files
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Failed to list files: %s", exc)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def delete_file(self, object_key: str) -> None:
|
||||||
|
"""Delete a file from cloud storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
object_key: S3 object key
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If deletion fails
|
||||||
|
"""
|
||||||
|
logger.info("Deleting s3://%s/%s", self.config.bucket_name, object_key)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.client.delete_object(
|
||||||
|
Bucket=self.config.bucket_name,
|
||||||
|
Key=object_key,
|
||||||
|
)
|
||||||
|
logger.info("Deletion successful: %s", object_key)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Deletion failed: %s", exc)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_storage_stats(self) -> dict[str, Any]:
|
||||||
|
"""Get cloud storage statistics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with storage stats
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
files = self.list_files()
|
||||||
|
|
||||||
|
total_size = sum(f["size_bytes"] for f in files)
|
||||||
|
total_count = len(files)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_files": total_count,
|
||||||
|
"total_size_bytes": total_size,
|
||||||
|
"total_size_mb": total_size / 1024 / 1024,
|
||||||
|
"total_size_gb": total_size / 1024 / 1024 / 1024,
|
||||||
|
}
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Failed to get storage stats: %s", exc)
|
||||||
|
return {
|
||||||
|
"error": str(exc),
|
||||||
|
"total_files": 0,
|
||||||
|
"total_size_bytes": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def verify_connection(self) -> bool:
|
||||||
|
"""Verify connection to cloud storage.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if connection is successful
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.client.head_bucket(Bucket=self.config.bucket_name)
|
||||||
|
logger.info("Cloud storage connection verified")
|
||||||
|
return True
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Cloud storage connection failed: %s", exc)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def create_bucket_if_not_exists(self) -> None:
|
||||||
|
"""Create storage bucket if it doesn't exist.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If bucket creation fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.client.head_bucket(Bucket=self.config.bucket_name)
|
||||||
|
logger.info("Bucket already exists: %s", self.config.bucket_name)
|
||||||
|
except self.client.exceptions.NoSuchBucket:
|
||||||
|
logger.info("Creating bucket: %s", self.config.bucket_name)
|
||||||
|
if self.config.region == "us-east-1":
|
||||||
|
# us-east-1 requires special handling
|
||||||
|
self.client.create_bucket(Bucket=self.config.bucket_name)
|
||||||
|
else:
|
||||||
|
self.client.create_bucket(
|
||||||
|
Bucket=self.config.bucket_name,
|
||||||
|
CreateBucketConfiguration={"LocationConstraint": self.config.region},
|
||||||
|
)
|
||||||
|
logger.info("Bucket created successfully")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Failed to verify/create bucket: %s", exc)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def enable_versioning(self) -> None:
|
||||||
|
"""Enable versioning on the bucket.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If versioning enablement fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.client.put_bucket_versioning(
|
||||||
|
Bucket=self.config.bucket_name,
|
||||||
|
VersioningConfiguration={"Status": "Enabled"},
|
||||||
|
)
|
||||||
|
logger.info("Versioning enabled for bucket: %s", self.config.bucket_name)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Failed to enable versioning: %s", exc)
|
||||||
|
raise
|
||||||
326
src/backup/exporter.py
Normal file
326
src/backup/exporter.py
Normal file
@@ -0,0 +1,326 @@
|
|||||||
|
"""Multi-format database exporter for backups.
|
||||||
|
|
||||||
|
Supports JSON, CSV, and Parquet formats for different use cases:
|
||||||
|
- JSON: Human-readable, easy to inspect
|
||||||
|
- CSV: Analysis tools (Excel, pandas)
|
||||||
|
- Parquet: Big data tools (Spark, DuckDB)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import gzip
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sqlite3
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ExportFormat(str, Enum):
|
||||||
|
"""Supported export formats."""
|
||||||
|
|
||||||
|
JSON = "json"
|
||||||
|
CSV = "csv"
|
||||||
|
PARQUET = "parquet"
|
||||||
|
|
||||||
|
|
||||||
|
class BackupExporter:
|
||||||
|
"""Export database to multiple formats."""
|
||||||
|
|
||||||
|
def __init__(self, db_path: str) -> None:
|
||||||
|
"""Initialize the exporter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: Path to SQLite database
|
||||||
|
"""
|
||||||
|
self.db_path = db_path
|
||||||
|
|
||||||
|
def export_all(
|
||||||
|
self,
|
||||||
|
output_dir: Path,
|
||||||
|
formats: list[ExportFormat] | None = None,
|
||||||
|
compress: bool = True,
|
||||||
|
incremental_since: datetime | None = None,
|
||||||
|
) -> dict[ExportFormat, Path]:
|
||||||
|
"""Export database to multiple formats.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_dir: Directory to write export files
|
||||||
|
formats: List of formats to export (default: all)
|
||||||
|
compress: Whether to gzip compress exports
|
||||||
|
incremental_since: Only export records after this timestamp
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping format to output file path
|
||||||
|
"""
|
||||||
|
if formats is None:
|
||||||
|
formats = [ExportFormat.JSON, ExportFormat.CSV, ExportFormat.PARQUET]
|
||||||
|
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S")
|
||||||
|
|
||||||
|
results: dict[ExportFormat, Path] = {}
|
||||||
|
|
||||||
|
for fmt in formats:
|
||||||
|
try:
|
||||||
|
output_file = self._export_format(
|
||||||
|
fmt, output_dir, timestamp, compress, incremental_since
|
||||||
|
)
|
||||||
|
results[fmt] = output_file
|
||||||
|
logger.info("Exported to %s: %s", fmt.value, output_file)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Failed to export to %s: %s", fmt.value, exc)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _export_format(
|
||||||
|
self,
|
||||||
|
fmt: ExportFormat,
|
||||||
|
output_dir: Path,
|
||||||
|
timestamp: str,
|
||||||
|
compress: bool,
|
||||||
|
incremental_since: datetime | None,
|
||||||
|
) -> Path:
|
||||||
|
"""Export to a specific format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fmt: Export format
|
||||||
|
output_dir: Output directory
|
||||||
|
timestamp: Timestamp string for filename
|
||||||
|
compress: Whether to compress
|
||||||
|
incremental_since: Incremental export cutoff
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to output file
|
||||||
|
"""
|
||||||
|
if fmt == ExportFormat.JSON:
|
||||||
|
return self._export_json(output_dir, timestamp, compress, incremental_since)
|
||||||
|
elif fmt == ExportFormat.CSV:
|
||||||
|
return self._export_csv(output_dir, timestamp, compress, incremental_since)
|
||||||
|
elif fmt == ExportFormat.PARQUET:
|
||||||
|
return self._export_parquet(
|
||||||
|
output_dir, timestamp, compress, incremental_since
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported format: {fmt}")
|
||||||
|
|
||||||
|
def _get_trades(
|
||||||
|
self, incremental_since: datetime | None = None
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Fetch trades from database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
incremental_since: Only fetch trades after this timestamp
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of trade records
|
||||||
|
"""
|
||||||
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
|
||||||
|
if incremental_since:
|
||||||
|
cursor = conn.execute(
|
||||||
|
"SELECT * FROM trades WHERE timestamp > ?",
|
||||||
|
(incremental_since.isoformat(),),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cursor = conn.execute("SELECT * FROM trades")
|
||||||
|
|
||||||
|
trades = [dict(row) for row in cursor.fetchall()]
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
return trades
|
||||||
|
|
||||||
|
def _export_json(
|
||||||
|
self,
|
||||||
|
output_dir: Path,
|
||||||
|
timestamp: str,
|
||||||
|
compress: bool,
|
||||||
|
incremental_since: datetime | None,
|
||||||
|
) -> Path:
|
||||||
|
"""Export to JSON format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_dir: Output directory
|
||||||
|
timestamp: Timestamp for filename
|
||||||
|
compress: Whether to gzip
|
||||||
|
incremental_since: Incremental cutoff
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to output file
|
||||||
|
"""
|
||||||
|
trades = self._get_trades(incremental_since)
|
||||||
|
|
||||||
|
filename = f"trades_{timestamp}.json"
|
||||||
|
if compress:
|
||||||
|
filename += ".gz"
|
||||||
|
|
||||||
|
output_file = output_dir / filename
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"export_timestamp": datetime.now(UTC).isoformat(),
|
||||||
|
"incremental_since": (
|
||||||
|
incremental_since.isoformat() if incremental_since else None
|
||||||
|
),
|
||||||
|
"record_count": len(trades),
|
||||||
|
"trades": trades,
|
||||||
|
}
|
||||||
|
|
||||||
|
if compress:
|
||||||
|
with gzip.open(output_file, "wt", encoding="utf-8") as f:
|
||||||
|
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||||
|
else:
|
||||||
|
with open(output_file, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
return output_file
|
||||||
|
|
||||||
|
def _export_csv(
|
||||||
|
self,
|
||||||
|
output_dir: Path,
|
||||||
|
timestamp: str,
|
||||||
|
compress: bool,
|
||||||
|
incremental_since: datetime | None,
|
||||||
|
) -> Path:
|
||||||
|
"""Export to CSV format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_dir: Output directory
|
||||||
|
timestamp: Timestamp for filename
|
||||||
|
compress: Whether to gzip
|
||||||
|
incremental_since: Incremental cutoff
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to output file
|
||||||
|
"""
|
||||||
|
trades = self._get_trades(incremental_since)
|
||||||
|
|
||||||
|
filename = f"trades_{timestamp}.csv"
|
||||||
|
if compress:
|
||||||
|
filename += ".gz"
|
||||||
|
|
||||||
|
output_file = output_dir / filename
|
||||||
|
|
||||||
|
if not trades:
|
||||||
|
# Write empty CSV with headers
|
||||||
|
if compress:
|
||||||
|
with gzip.open(output_file, "wt", encoding="utf-8", newline="") as f:
|
||||||
|
writer = csv.writer(f)
|
||||||
|
writer.writerow(
|
||||||
|
[
|
||||||
|
"timestamp",
|
||||||
|
"stock_code",
|
||||||
|
"action",
|
||||||
|
"quantity",
|
||||||
|
"price",
|
||||||
|
"confidence",
|
||||||
|
"rationale",
|
||||||
|
"pnl",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
with open(output_file, "w", encoding="utf-8", newline="") as f:
|
||||||
|
writer = csv.writer(f)
|
||||||
|
writer.writerow(
|
||||||
|
[
|
||||||
|
"timestamp",
|
||||||
|
"stock_code",
|
||||||
|
"action",
|
||||||
|
"quantity",
|
||||||
|
"price",
|
||||||
|
"confidence",
|
||||||
|
"rationale",
|
||||||
|
"pnl",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return output_file
|
||||||
|
|
||||||
|
# Get column names from first trade
|
||||||
|
fieldnames = list(trades[0].keys())
|
||||||
|
|
||||||
|
if compress:
|
||||||
|
with gzip.open(output_file, "wt", encoding="utf-8", newline="") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||||
|
writer.writeheader()
|
||||||
|
writer.writerows(trades)
|
||||||
|
else:
|
||||||
|
with open(output_file, "w", encoding="utf-8", newline="") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||||
|
writer.writeheader()
|
||||||
|
writer.writerows(trades)
|
||||||
|
|
||||||
|
return output_file
|
||||||
|
|
||||||
|
def _export_parquet(
|
||||||
|
self,
|
||||||
|
output_dir: Path,
|
||||||
|
timestamp: str,
|
||||||
|
compress: bool,
|
||||||
|
incremental_since: datetime | None,
|
||||||
|
) -> Path:
|
||||||
|
"""Export to Parquet format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_dir: Output directory
|
||||||
|
timestamp: Timestamp for filename
|
||||||
|
compress: Whether to compress (Parquet has built-in compression)
|
||||||
|
incremental_since: Incremental cutoff
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to output file
|
||||||
|
"""
|
||||||
|
trades = self._get_trades(incremental_since)
|
||||||
|
|
||||||
|
filename = f"trades_{timestamp}.parquet"
|
||||||
|
output_file = output_dir / filename
|
||||||
|
|
||||||
|
try:
|
||||||
|
import pyarrow as pa
|
||||||
|
import pyarrow.parquet as pq
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"pyarrow is required for Parquet export. "
|
||||||
|
"Install with: pip install pyarrow"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to pyarrow table
|
||||||
|
table = pa.Table.from_pylist(trades)
|
||||||
|
|
||||||
|
# Write with compression
|
||||||
|
compression = "gzip" if compress else "none"
|
||||||
|
pq.write_table(table, output_file, compression=compression)
|
||||||
|
|
||||||
|
return output_file
|
||||||
|
|
||||||
|
def get_export_stats(self) -> dict[str, Any]:
|
||||||
|
"""Get statistics about exportable data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with data statistics
|
||||||
|
"""
|
||||||
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
stats = {}
|
||||||
|
|
||||||
|
# Total trades
|
||||||
|
cursor.execute("SELECT COUNT(*) FROM trades")
|
||||||
|
stats["total_trades"] = cursor.fetchone()[0]
|
||||||
|
|
||||||
|
# Date range
|
||||||
|
cursor.execute("SELECT MIN(timestamp), MAX(timestamp) FROM trades")
|
||||||
|
min_date, max_date = cursor.fetchone()
|
||||||
|
stats["date_range"] = {"earliest": min_date, "latest": max_date}
|
||||||
|
|
||||||
|
# Database size
|
||||||
|
cursor.execute("SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()")
|
||||||
|
stats["db_size_bytes"] = cursor.fetchone()[0]
|
||||||
|
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
return stats
|
||||||
282
src/backup/health_monitor.py
Normal file
282
src/backup/health_monitor.py
Normal file
@@ -0,0 +1,282 @@
|
|||||||
|
"""Health monitoring for backup system.
|
||||||
|
|
||||||
|
Checks:
|
||||||
|
- Database accessibility and integrity
|
||||||
|
- Disk space availability
|
||||||
|
- Backup success/failure tracking
|
||||||
|
- Self-healing capabilities
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
import sqlite3
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HealthStatus(str, Enum):
|
||||||
|
"""Health check status."""
|
||||||
|
|
||||||
|
HEALTHY = "healthy"
|
||||||
|
DEGRADED = "degraded"
|
||||||
|
UNHEALTHY = "unhealthy"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HealthCheckResult:
|
||||||
|
"""Result of a health check."""
|
||||||
|
|
||||||
|
status: HealthStatus
|
||||||
|
message: str
|
||||||
|
details: dict[str, Any] | None = None
|
||||||
|
timestamp: datetime | None = None
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
if self.timestamp is None:
|
||||||
|
self.timestamp = datetime.now(UTC)
|
||||||
|
|
||||||
|
|
||||||
|
class HealthMonitor:
|
||||||
|
"""Monitor system health and backup status."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
db_path: str,
|
||||||
|
backup_dir: Path,
|
||||||
|
min_disk_space_gb: float = 10.0,
|
||||||
|
max_backup_age_hours: int = 25, # Daily backups should be < 25 hours old
|
||||||
|
) -> None:
|
||||||
|
"""Initialize health monitor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: Path to SQLite database
|
||||||
|
backup_dir: Backup directory
|
||||||
|
min_disk_space_gb: Minimum required disk space in GB
|
||||||
|
max_backup_age_hours: Maximum acceptable backup age in hours
|
||||||
|
"""
|
||||||
|
self.db_path = Path(db_path)
|
||||||
|
self.backup_dir = backup_dir
|
||||||
|
self.min_disk_space_bytes = int(min_disk_space_gb * 1024 * 1024 * 1024)
|
||||||
|
self.max_backup_age = timedelta(hours=max_backup_age_hours)
|
||||||
|
|
||||||
|
def check_database_health(self) -> HealthCheckResult:
|
||||||
|
"""Check database accessibility and integrity.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HealthCheckResult
|
||||||
|
"""
|
||||||
|
# Check if database exists
|
||||||
|
if not self.db_path.exists():
|
||||||
|
return HealthCheckResult(
|
||||||
|
status=HealthStatus.UNHEALTHY,
|
||||||
|
message=f"Database not found: {self.db_path}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if database is accessible
|
||||||
|
try:
|
||||||
|
conn = sqlite3.connect(str(self.db_path))
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# Run integrity check
|
||||||
|
cursor.execute("PRAGMA integrity_check")
|
||||||
|
result = cursor.fetchone()[0]
|
||||||
|
|
||||||
|
if result != "ok":
|
||||||
|
conn.close()
|
||||||
|
return HealthCheckResult(
|
||||||
|
status=HealthStatus.UNHEALTHY,
|
||||||
|
message=f"Database integrity check failed: {result}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get database size
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()"
|
||||||
|
)
|
||||||
|
db_size = cursor.fetchone()[0]
|
||||||
|
|
||||||
|
# Get row counts
|
||||||
|
cursor.execute("SELECT COUNT(*) FROM trades")
|
||||||
|
trade_count = cursor.fetchone()[0]
|
||||||
|
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
return HealthCheckResult(
|
||||||
|
status=HealthStatus.HEALTHY,
|
||||||
|
message="Database is healthy",
|
||||||
|
details={
|
||||||
|
"size_bytes": db_size,
|
||||||
|
"size_mb": db_size / 1024 / 1024,
|
||||||
|
"trade_count": trade_count,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except sqlite3.Error as exc:
|
||||||
|
return HealthCheckResult(
|
||||||
|
status=HealthStatus.UNHEALTHY,
|
||||||
|
message=f"Database access error: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_disk_space(self) -> HealthCheckResult:
|
||||||
|
"""Check available disk space.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HealthCheckResult
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
stat = shutil.disk_usage(self.backup_dir)
|
||||||
|
|
||||||
|
free_gb = stat.free / 1024 / 1024 / 1024
|
||||||
|
total_gb = stat.total / 1024 / 1024 / 1024
|
||||||
|
used_percent = (stat.used / stat.total) * 100
|
||||||
|
|
||||||
|
if stat.free < self.min_disk_space_bytes:
|
||||||
|
return HealthCheckResult(
|
||||||
|
status=HealthStatus.UNHEALTHY,
|
||||||
|
message=f"Low disk space: {free_gb:.2f} GB free (minimum: {self.min_disk_space_bytes / 1024 / 1024 / 1024:.2f} GB)",
|
||||||
|
details={
|
||||||
|
"free_gb": free_gb,
|
||||||
|
"total_gb": total_gb,
|
||||||
|
"used_percent": used_percent,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
elif stat.free < self.min_disk_space_bytes * 2:
|
||||||
|
return HealthCheckResult(
|
||||||
|
status=HealthStatus.DEGRADED,
|
||||||
|
message=f"Disk space low: {free_gb:.2f} GB free",
|
||||||
|
details={
|
||||||
|
"free_gb": free_gb,
|
||||||
|
"total_gb": total_gb,
|
||||||
|
"used_percent": used_percent,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return HealthCheckResult(
|
||||||
|
status=HealthStatus.HEALTHY,
|
||||||
|
message=f"Disk space healthy: {free_gb:.2f} GB free",
|
||||||
|
details={
|
||||||
|
"free_gb": free_gb,
|
||||||
|
"total_gb": total_gb,
|
||||||
|
"used_percent": used_percent,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
return HealthCheckResult(
|
||||||
|
status=HealthStatus.UNHEALTHY,
|
||||||
|
message=f"Failed to check disk space: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_backup_recency(self) -> HealthCheckResult:
|
||||||
|
"""Check if backups are recent enough.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HealthCheckResult
|
||||||
|
"""
|
||||||
|
daily_dir = self.backup_dir / "daily"
|
||||||
|
|
||||||
|
if not daily_dir.exists():
|
||||||
|
return HealthCheckResult(
|
||||||
|
status=HealthStatus.DEGRADED,
|
||||||
|
message="Daily backup directory not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find most recent backup
|
||||||
|
backups = sorted(daily_dir.glob("*.db"), key=lambda p: p.stat().st_mtime, reverse=True)
|
||||||
|
|
||||||
|
if not backups:
|
||||||
|
return HealthCheckResult(
|
||||||
|
status=HealthStatus.UNHEALTHY,
|
||||||
|
message="No daily backups found",
|
||||||
|
)
|
||||||
|
|
||||||
|
most_recent = backups[0]
|
||||||
|
mtime = datetime.fromtimestamp(most_recent.stat().st_mtime, tz=UTC)
|
||||||
|
age = datetime.now(UTC) - mtime
|
||||||
|
|
||||||
|
if age > self.max_backup_age:
|
||||||
|
return HealthCheckResult(
|
||||||
|
status=HealthStatus.DEGRADED,
|
||||||
|
message=f"Most recent backup is {age.total_seconds() / 3600:.1f} hours old",
|
||||||
|
details={
|
||||||
|
"backup_file": most_recent.name,
|
||||||
|
"age_hours": age.total_seconds() / 3600,
|
||||||
|
"threshold_hours": self.max_backup_age.total_seconds() / 3600,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return HealthCheckResult(
|
||||||
|
status=HealthStatus.HEALTHY,
|
||||||
|
message=f"Recent backup found ({age.total_seconds() / 3600:.1f} hours old)",
|
||||||
|
details={
|
||||||
|
"backup_file": most_recent.name,
|
||||||
|
"age_hours": age.total_seconds() / 3600,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_all_checks(self) -> dict[str, HealthCheckResult]:
|
||||||
|
"""Run all health checks.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping check name to result
|
||||||
|
"""
|
||||||
|
checks = {
|
||||||
|
"database": self.check_database_health(),
|
||||||
|
"disk_space": self.check_disk_space(),
|
||||||
|
"backup_recency": self.check_backup_recency(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Log results
|
||||||
|
for check_name, result in checks.items():
|
||||||
|
if result.status == HealthStatus.UNHEALTHY:
|
||||||
|
logger.error("[%s] %s: %s", check_name, result.status.value, result.message)
|
||||||
|
elif result.status == HealthStatus.DEGRADED:
|
||||||
|
logger.warning("[%s] %s: %s", check_name, result.status.value, result.message)
|
||||||
|
else:
|
||||||
|
logger.info("[%s] %s: %s", check_name, result.status.value, result.message)
|
||||||
|
|
||||||
|
return checks
|
||||||
|
|
||||||
|
def get_overall_status(self) -> HealthStatus:
|
||||||
|
"""Get overall system health status.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HealthStatus (worst status from all checks)
|
||||||
|
"""
|
||||||
|
checks = self.run_all_checks()
|
||||||
|
|
||||||
|
# Return worst status
|
||||||
|
if any(c.status == HealthStatus.UNHEALTHY for c in checks.values()):
|
||||||
|
return HealthStatus.UNHEALTHY
|
||||||
|
elif any(c.status == HealthStatus.DEGRADED for c in checks.values()):
|
||||||
|
return HealthStatus.DEGRADED
|
||||||
|
else:
|
||||||
|
return HealthStatus.HEALTHY
|
||||||
|
|
||||||
|
def get_health_report(self) -> dict[str, Any]:
|
||||||
|
"""Get comprehensive health report.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with health report
|
||||||
|
"""
|
||||||
|
checks = self.run_all_checks()
|
||||||
|
overall = self.get_overall_status()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"overall_status": overall.value,
|
||||||
|
"timestamp": datetime.now(UTC).isoformat(),
|
||||||
|
"checks": {
|
||||||
|
name: {
|
||||||
|
"status": result.status.value,
|
||||||
|
"message": result.message,
|
||||||
|
"details": result.details,
|
||||||
|
}
|
||||||
|
for name, result in checks.items()
|
||||||
|
},
|
||||||
|
}
|
||||||
336
src/backup/scheduler.py
Normal file
336
src/backup/scheduler.py
Normal file
@@ -0,0 +1,336 @@
|
|||||||
|
"""Backup scheduler for automated database backups.
|
||||||
|
|
||||||
|
Implements backup policies:
|
||||||
|
- Daily: Keep for 30 days (hot storage)
|
||||||
|
- Weekly: Keep for 1 year (warm storage)
|
||||||
|
- Monthly: Keep forever (cold storage)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BackupPolicy(str, Enum):
|
||||||
|
"""Backup retention policies."""
|
||||||
|
|
||||||
|
DAILY = "daily"
|
||||||
|
WEEKLY = "weekly"
|
||||||
|
MONTHLY = "monthly"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BackupMetadata:
|
||||||
|
"""Metadata for a backup."""
|
||||||
|
|
||||||
|
timestamp: datetime
|
||||||
|
policy: BackupPolicy
|
||||||
|
file_path: Path
|
||||||
|
size_bytes: int
|
||||||
|
checksum: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class BackupScheduler:
|
||||||
|
"""Manage automated database backups with retention policies."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
db_path: str,
|
||||||
|
backup_dir: Path,
|
||||||
|
daily_retention_days: int = 30,
|
||||||
|
weekly_retention_days: int = 365,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the backup scheduler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: Path to SQLite database
|
||||||
|
backup_dir: Root directory for backups
|
||||||
|
daily_retention_days: Days to keep daily backups
|
||||||
|
weekly_retention_days: Days to keep weekly backups
|
||||||
|
"""
|
||||||
|
self.db_path = Path(db_path)
|
||||||
|
self.backup_dir = backup_dir
|
||||||
|
self.daily_retention = timedelta(days=daily_retention_days)
|
||||||
|
self.weekly_retention = timedelta(days=weekly_retention_days)
|
||||||
|
|
||||||
|
# Create policy-specific directories
|
||||||
|
self.daily_dir = backup_dir / "daily"
|
||||||
|
self.weekly_dir = backup_dir / "weekly"
|
||||||
|
self.monthly_dir = backup_dir / "monthly"
|
||||||
|
|
||||||
|
for d in [self.daily_dir, self.weekly_dir, self.monthly_dir]:
|
||||||
|
d.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def create_backup(
|
||||||
|
self, policy: BackupPolicy, verify: bool = True
|
||||||
|
) -> BackupMetadata:
|
||||||
|
"""Create a database backup.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
policy: Backup policy (daily/weekly/monthly)
|
||||||
|
verify: Whether to verify backup integrity
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BackupMetadata object
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If database doesn't exist
|
||||||
|
OSError: If backup fails
|
||||||
|
"""
|
||||||
|
if not self.db_path.exists():
|
||||||
|
raise FileNotFoundError(f"Database not found: {self.db_path}")
|
||||||
|
|
||||||
|
timestamp = datetime.now(UTC)
|
||||||
|
backup_filename = self._get_backup_filename(timestamp, policy)
|
||||||
|
|
||||||
|
# Determine output directory
|
||||||
|
if policy == BackupPolicy.DAILY:
|
||||||
|
output_dir = self.daily_dir
|
||||||
|
elif policy == BackupPolicy.WEEKLY:
|
||||||
|
output_dir = self.weekly_dir
|
||||||
|
else: # MONTHLY
|
||||||
|
output_dir = self.monthly_dir
|
||||||
|
|
||||||
|
backup_path = output_dir / backup_filename
|
||||||
|
|
||||||
|
# Create backup (copy database file)
|
||||||
|
logger.info("Creating %s backup: %s", policy.value, backup_path)
|
||||||
|
shutil.copy2(self.db_path, backup_path)
|
||||||
|
|
||||||
|
# Get file size
|
||||||
|
size_bytes = backup_path.stat().st_size
|
||||||
|
|
||||||
|
# Verify backup if requested
|
||||||
|
checksum = None
|
||||||
|
if verify:
|
||||||
|
checksum = self._verify_backup(backup_path)
|
||||||
|
|
||||||
|
metadata = BackupMetadata(
|
||||||
|
timestamp=timestamp,
|
||||||
|
policy=policy,
|
||||||
|
file_path=backup_path,
|
||||||
|
size_bytes=size_bytes,
|
||||||
|
checksum=checksum,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Backup created: %s (%.2f MB)",
|
||||||
|
backup_path.name,
|
||||||
|
size_bytes / 1024 / 1024,
|
||||||
|
)
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
def _get_backup_filename(self, timestamp: datetime, policy: BackupPolicy) -> str:
|
||||||
|
"""Generate backup filename.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timestamp: Backup timestamp
|
||||||
|
policy: Backup policy
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filename string
|
||||||
|
"""
|
||||||
|
ts_str = timestamp.strftime("%Y%m%d_%H%M%S")
|
||||||
|
return f"trade_logs_{policy.value}_{ts_str}.db"
|
||||||
|
|
||||||
|
def _verify_backup(self, backup_path: Path) -> str:
|
||||||
|
"""Verify backup integrity using SQLite integrity check.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
backup_path: Path to backup file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Checksum string (MD5 hash)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If integrity check fails
|
||||||
|
"""
|
||||||
|
import hashlib
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
|
# Integrity check
|
||||||
|
try:
|
||||||
|
conn = sqlite3.connect(str(backup_path))
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("PRAGMA integrity_check")
|
||||||
|
result = cursor.fetchone()[0]
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
if result != "ok":
|
||||||
|
raise RuntimeError(f"Integrity check failed: {result}")
|
||||||
|
except sqlite3.Error as exc:
|
||||||
|
raise RuntimeError(f"Failed to verify backup: {exc}")
|
||||||
|
|
||||||
|
# Calculate MD5 checksum
|
||||||
|
md5 = hashlib.md5()
|
||||||
|
with open(backup_path, "rb") as f:
|
||||||
|
for chunk in iter(lambda: f.read(8192), b""):
|
||||||
|
md5.update(chunk)
|
||||||
|
|
||||||
|
return md5.hexdigest()
|
||||||
|
|
||||||
|
def cleanup_old_backups(self) -> dict[BackupPolicy, int]:
|
||||||
|
"""Remove backups older than retention policies.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping policy to number of backups removed
|
||||||
|
"""
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
removed_counts: dict[BackupPolicy, int] = {}
|
||||||
|
|
||||||
|
# Daily backups: remove older than retention
|
||||||
|
removed_counts[BackupPolicy.DAILY] = self._cleanup_directory(
|
||||||
|
self.daily_dir, now - self.daily_retention
|
||||||
|
)
|
||||||
|
|
||||||
|
# Weekly backups: remove older than retention
|
||||||
|
removed_counts[BackupPolicy.WEEKLY] = self._cleanup_directory(
|
||||||
|
self.weekly_dir, now - self.weekly_retention
|
||||||
|
)
|
||||||
|
|
||||||
|
# Monthly backups: never remove (kept forever)
|
||||||
|
removed_counts[BackupPolicy.MONTHLY] = 0
|
||||||
|
|
||||||
|
total = sum(removed_counts.values())
|
||||||
|
if total > 0:
|
||||||
|
logger.info("Cleaned up %d old backup(s)", total)
|
||||||
|
|
||||||
|
return removed_counts
|
||||||
|
|
||||||
|
def _cleanup_directory(self, directory: Path, cutoff: datetime) -> int:
|
||||||
|
"""Remove backups older than cutoff date.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
directory: Directory to clean
|
||||||
|
cutoff: Remove files older than this
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of files removed
|
||||||
|
"""
|
||||||
|
removed = 0
|
||||||
|
|
||||||
|
for backup_file in directory.glob("*.db"):
|
||||||
|
# Get file modification time
|
||||||
|
mtime = datetime.fromtimestamp(backup_file.stat().st_mtime, tz=UTC)
|
||||||
|
|
||||||
|
if mtime < cutoff:
|
||||||
|
logger.debug("Removing old backup: %s", backup_file.name)
|
||||||
|
backup_file.unlink()
|
||||||
|
removed += 1
|
||||||
|
|
||||||
|
return removed
|
||||||
|
|
||||||
|
def list_backups(
|
||||||
|
self, policy: BackupPolicy | None = None
|
||||||
|
) -> list[BackupMetadata]:
|
||||||
|
"""List available backups.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
policy: Filter by policy (None for all)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of BackupMetadata objects
|
||||||
|
"""
|
||||||
|
backups: list[BackupMetadata] = []
|
||||||
|
|
||||||
|
policies_to_check = (
|
||||||
|
[policy] if policy else [BackupPolicy.DAILY, BackupPolicy.WEEKLY, BackupPolicy.MONTHLY]
|
||||||
|
)
|
||||||
|
|
||||||
|
for pol in policies_to_check:
|
||||||
|
if pol == BackupPolicy.DAILY:
|
||||||
|
directory = self.daily_dir
|
||||||
|
elif pol == BackupPolicy.WEEKLY:
|
||||||
|
directory = self.weekly_dir
|
||||||
|
else:
|
||||||
|
directory = self.monthly_dir
|
||||||
|
|
||||||
|
for backup_file in sorted(directory.glob("*.db")):
|
||||||
|
mtime = datetime.fromtimestamp(backup_file.stat().st_mtime, tz=UTC)
|
||||||
|
size = backup_file.stat().st_size
|
||||||
|
|
||||||
|
backups.append(
|
||||||
|
BackupMetadata(
|
||||||
|
timestamp=mtime,
|
||||||
|
policy=pol,
|
||||||
|
file_path=backup_file,
|
||||||
|
size_bytes=size,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sort by timestamp (newest first)
|
||||||
|
backups.sort(key=lambda b: b.timestamp, reverse=True)
|
||||||
|
|
||||||
|
return backups
|
||||||
|
|
||||||
|
def get_backup_stats(self) -> dict[str, Any]:
|
||||||
|
"""Get backup statistics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with backup stats
|
||||||
|
"""
|
||||||
|
stats: dict[str, Any] = {}
|
||||||
|
|
||||||
|
for policy in BackupPolicy:
|
||||||
|
if policy == BackupPolicy.DAILY:
|
||||||
|
directory = self.daily_dir
|
||||||
|
elif policy == BackupPolicy.WEEKLY:
|
||||||
|
directory = self.weekly_dir
|
||||||
|
else:
|
||||||
|
directory = self.monthly_dir
|
||||||
|
|
||||||
|
backups = list(directory.glob("*.db"))
|
||||||
|
total_size = sum(b.stat().st_size for b in backups)
|
||||||
|
|
||||||
|
stats[policy.value] = {
|
||||||
|
"count": len(backups),
|
||||||
|
"total_size_bytes": total_size,
|
||||||
|
"total_size_mb": total_size / 1024 / 1024,
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
def restore_backup(self, backup_metadata: BackupMetadata, verify: bool = True) -> None:
|
||||||
|
"""Restore database from backup.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
backup_metadata: Backup to restore
|
||||||
|
verify: Whether to verify restored database
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If backup file doesn't exist
|
||||||
|
RuntimeError: If verification fails
|
||||||
|
"""
|
||||||
|
if not backup_metadata.file_path.exists():
|
||||||
|
raise FileNotFoundError(f"Backup not found: {backup_metadata.file_path}")
|
||||||
|
|
||||||
|
# Create backup of current database
|
||||||
|
if self.db_path.exists():
|
||||||
|
backup_current = self.db_path.with_suffix(".db.before_restore")
|
||||||
|
logger.info("Backing up current database to: %s", backup_current)
|
||||||
|
shutil.copy2(self.db_path, backup_current)
|
||||||
|
|
||||||
|
# Restore backup
|
||||||
|
logger.info("Restoring backup: %s", backup_metadata.file_path.name)
|
||||||
|
shutil.copy2(backup_metadata.file_path, self.db_path)
|
||||||
|
|
||||||
|
# Verify restored database
|
||||||
|
if verify:
|
||||||
|
try:
|
||||||
|
self._verify_backup(self.db_path)
|
||||||
|
logger.info("Backup restored and verified successfully")
|
||||||
|
except RuntimeError as exc:
|
||||||
|
# Restore failed, revert to backup
|
||||||
|
if backup_current.exists():
|
||||||
|
logger.error("Restore verification failed, reverting: %s", exc)
|
||||||
|
shutil.copy2(backup_current, self.db_path)
|
||||||
|
raise
|
||||||
296
src/brain/context_selector.py
Normal file
296
src/brain/context_selector.py
Normal file
@@ -0,0 +1,296 @@
|
|||||||
|
"""Smart context selection for optimizing token usage.
|
||||||
|
|
||||||
|
This module implements intelligent selection of context layers (L1-L7) based on
|
||||||
|
decision type and market conditions:
|
||||||
|
- L7 (real-time) for normal trading decisions
|
||||||
|
- L6-L5 (daily/weekly) for strategic decisions
|
||||||
|
- L4-L1 (monthly/legacy) only for major events or policy changes
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from src.context.layer import ContextLayer
|
||||||
|
from src.context.store import ContextStore
|
||||||
|
|
||||||
|
|
||||||
|
class DecisionType(str, Enum):
|
||||||
|
"""Type of trading decision being made."""
|
||||||
|
|
||||||
|
NORMAL = "normal" # Regular trade decision
|
||||||
|
STRATEGIC = "strategic" # Strategy adjustment
|
||||||
|
MAJOR_EVENT = "major_event" # Portfolio rebalancing, policy change
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ContextSelection:
|
||||||
|
"""Selected context layers and their relevance scores."""
|
||||||
|
|
||||||
|
layers: list[ContextLayer]
|
||||||
|
relevance_scores: dict[ContextLayer, float]
|
||||||
|
total_score: float
|
||||||
|
|
||||||
|
|
||||||
|
class ContextSelector:
|
||||||
|
"""Selects optimal context layers to minimize token usage."""
|
||||||
|
|
||||||
|
def __init__(self, store: ContextStore) -> None:
|
||||||
|
"""Initialize the context selector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
store: ContextStore instance for retrieving context data
|
||||||
|
"""
|
||||||
|
self.store = store
|
||||||
|
|
||||||
|
def select_layers(
|
||||||
|
self,
|
||||||
|
decision_type: DecisionType = DecisionType.NORMAL,
|
||||||
|
include_realtime: bool = True,
|
||||||
|
) -> list[ContextLayer]:
|
||||||
|
"""Select context layers based on decision type.
|
||||||
|
|
||||||
|
Strategy:
|
||||||
|
- NORMAL: L7 (real-time) only
|
||||||
|
- STRATEGIC: L7 + L6 + L5 (real-time + daily + weekly)
|
||||||
|
- MAJOR_EVENT: All layers L1-L7
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decision_type: Type of decision being made
|
||||||
|
include_realtime: Whether to include L7 real-time data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of context layers to use (ordered by priority)
|
||||||
|
"""
|
||||||
|
if decision_type == DecisionType.NORMAL:
|
||||||
|
# Normal trading: only real-time data
|
||||||
|
return [ContextLayer.L7_REALTIME] if include_realtime else []
|
||||||
|
|
||||||
|
elif decision_type == DecisionType.STRATEGIC:
|
||||||
|
# Strategic decisions: real-time + recent history
|
||||||
|
layers = []
|
||||||
|
if include_realtime:
|
||||||
|
layers.append(ContextLayer.L7_REALTIME)
|
||||||
|
layers.extend([ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY])
|
||||||
|
return layers
|
||||||
|
|
||||||
|
else: # MAJOR_EVENT
|
||||||
|
# Major events: all layers for comprehensive context
|
||||||
|
layers = []
|
||||||
|
if include_realtime:
|
||||||
|
layers.append(ContextLayer.L7_REALTIME)
|
||||||
|
layers.extend(
|
||||||
|
[
|
||||||
|
ContextLayer.L6_DAILY,
|
||||||
|
ContextLayer.L5_WEEKLY,
|
||||||
|
ContextLayer.L4_MONTHLY,
|
||||||
|
ContextLayer.L3_QUARTERLY,
|
||||||
|
ContextLayer.L2_ANNUAL,
|
||||||
|
ContextLayer.L1_LEGACY,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return layers
|
||||||
|
|
||||||
|
def score_layer_relevance(
|
||||||
|
self,
|
||||||
|
layer: ContextLayer,
|
||||||
|
decision_type: DecisionType,
|
||||||
|
current_time: datetime | None = None,
|
||||||
|
) -> float:
|
||||||
|
"""Calculate relevance score for a context layer.
|
||||||
|
|
||||||
|
Relevance is based on:
|
||||||
|
1. Decision type (normal, strategic, major event)
|
||||||
|
2. Layer recency (L7 > L6 > ... > L1)
|
||||||
|
3. Data availability
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer: Context layer to score
|
||||||
|
decision_type: Type of decision being made
|
||||||
|
current_time: Current time (defaults to now)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Relevance score (0.0 to 1.0)
|
||||||
|
"""
|
||||||
|
if current_time is None:
|
||||||
|
current_time = datetime.now(UTC)
|
||||||
|
|
||||||
|
# Base scores by decision type
|
||||||
|
base_scores = {
|
||||||
|
DecisionType.NORMAL: {
|
||||||
|
ContextLayer.L7_REALTIME: 1.0,
|
||||||
|
ContextLayer.L6_DAILY: 0.1,
|
||||||
|
ContextLayer.L5_WEEKLY: 0.05,
|
||||||
|
ContextLayer.L4_MONTHLY: 0.01,
|
||||||
|
ContextLayer.L3_QUARTERLY: 0.0,
|
||||||
|
ContextLayer.L2_ANNUAL: 0.0,
|
||||||
|
ContextLayer.L1_LEGACY: 0.0,
|
||||||
|
},
|
||||||
|
DecisionType.STRATEGIC: {
|
||||||
|
ContextLayer.L7_REALTIME: 0.9,
|
||||||
|
ContextLayer.L6_DAILY: 0.8,
|
||||||
|
ContextLayer.L5_WEEKLY: 0.7,
|
||||||
|
ContextLayer.L4_MONTHLY: 0.3,
|
||||||
|
ContextLayer.L3_QUARTERLY: 0.2,
|
||||||
|
ContextLayer.L2_ANNUAL: 0.1,
|
||||||
|
ContextLayer.L1_LEGACY: 0.05,
|
||||||
|
},
|
||||||
|
DecisionType.MAJOR_EVENT: {
|
||||||
|
ContextLayer.L7_REALTIME: 0.7,
|
||||||
|
ContextLayer.L6_DAILY: 0.7,
|
||||||
|
ContextLayer.L5_WEEKLY: 0.7,
|
||||||
|
ContextLayer.L4_MONTHLY: 0.8,
|
||||||
|
ContextLayer.L3_QUARTERLY: 0.8,
|
||||||
|
ContextLayer.L2_ANNUAL: 0.9,
|
||||||
|
ContextLayer.L1_LEGACY: 1.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
score = base_scores[decision_type].get(layer, 0.0)
|
||||||
|
|
||||||
|
# Check data availability
|
||||||
|
latest_timeframe = self.store.get_latest_timeframe(layer)
|
||||||
|
if latest_timeframe is None:
|
||||||
|
# No data available - reduce score significantly
|
||||||
|
score *= 0.1
|
||||||
|
|
||||||
|
return score
|
||||||
|
|
||||||
|
def select_with_scoring(
|
||||||
|
self,
|
||||||
|
decision_type: DecisionType = DecisionType.NORMAL,
|
||||||
|
min_score: float = 0.5,
|
||||||
|
) -> ContextSelection:
|
||||||
|
"""Select context layers with relevance scoring.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decision_type: Type of decision being made
|
||||||
|
min_score: Minimum relevance score to include a layer
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ContextSelection with selected layers and scores
|
||||||
|
"""
|
||||||
|
all_layers = [
|
||||||
|
ContextLayer.L7_REALTIME,
|
||||||
|
ContextLayer.L6_DAILY,
|
||||||
|
ContextLayer.L5_WEEKLY,
|
||||||
|
ContextLayer.L4_MONTHLY,
|
||||||
|
ContextLayer.L3_QUARTERLY,
|
||||||
|
ContextLayer.L2_ANNUAL,
|
||||||
|
ContextLayer.L1_LEGACY,
|
||||||
|
]
|
||||||
|
|
||||||
|
scores = {
|
||||||
|
layer: self.score_layer_relevance(layer, decision_type) for layer in all_layers
|
||||||
|
}
|
||||||
|
|
||||||
|
# Filter by minimum score
|
||||||
|
selected_layers = [layer for layer, score in scores.items() if score >= min_score]
|
||||||
|
|
||||||
|
# Sort by score (descending)
|
||||||
|
selected_layers.sort(key=lambda layer: scores[layer], reverse=True)
|
||||||
|
|
||||||
|
total_score = sum(scores[layer] for layer in selected_layers)
|
||||||
|
|
||||||
|
return ContextSelection(
|
||||||
|
layers=selected_layers,
|
||||||
|
relevance_scores=scores,
|
||||||
|
total_score=total_score,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_context_data(
|
||||||
|
self,
|
||||||
|
layers: list[ContextLayer],
|
||||||
|
max_items_per_layer: int = 10,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Retrieve context data for selected layers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layers: List of context layers to retrieve
|
||||||
|
max_items_per_layer: Maximum number of items per layer
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with context data organized by layer
|
||||||
|
"""
|
||||||
|
result: dict[str, Any] = {}
|
||||||
|
|
||||||
|
for layer in layers:
|
||||||
|
# Get latest timeframe for this layer
|
||||||
|
latest_timeframe = self.store.get_latest_timeframe(layer)
|
||||||
|
if latest_timeframe:
|
||||||
|
# Get all contexts for latest timeframe
|
||||||
|
contexts = self.store.get_all_contexts(layer, latest_timeframe)
|
||||||
|
|
||||||
|
# Limit number of items
|
||||||
|
if len(contexts) > max_items_per_layer:
|
||||||
|
# Keep only first N items
|
||||||
|
contexts = dict(list(contexts.items())[:max_items_per_layer])
|
||||||
|
|
||||||
|
result[layer.value] = contexts
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def estimate_context_tokens(self, context_data: dict[str, Any]) -> int:
|
||||||
|
"""Estimate total tokens for context data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_data: Context data dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Estimated token count
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
|
||||||
|
from src.brain.prompt_optimizer import PromptOptimizer
|
||||||
|
|
||||||
|
# Serialize to JSON and estimate tokens
|
||||||
|
json_str = json.dumps(context_data, ensure_ascii=False)
|
||||||
|
return PromptOptimizer.estimate_tokens(json_str)
|
||||||
|
|
||||||
|
def optimize_context_for_budget(
|
||||||
|
self,
|
||||||
|
decision_type: DecisionType,
|
||||||
|
max_tokens: int,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Select and retrieve context data within a token budget.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decision_type: Type of decision being made
|
||||||
|
max_tokens: Maximum token budget for context
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optimized context data within budget
|
||||||
|
"""
|
||||||
|
# Start with minimal selection
|
||||||
|
selection = self.select_with_scoring(decision_type, min_score=0.5)
|
||||||
|
|
||||||
|
# Retrieve data
|
||||||
|
context_data = self.get_context_data(selection.layers)
|
||||||
|
|
||||||
|
# Check if within budget
|
||||||
|
estimated_tokens = self.estimate_context_tokens(context_data)
|
||||||
|
|
||||||
|
if estimated_tokens <= max_tokens:
|
||||||
|
return context_data
|
||||||
|
|
||||||
|
# If over budget, progressively reduce
|
||||||
|
# 1. Reduce items per layer
|
||||||
|
for max_items in [5, 3, 1]:
|
||||||
|
context_data = self.get_context_data(selection.layers, max_items)
|
||||||
|
estimated_tokens = self.estimate_context_tokens(context_data)
|
||||||
|
if estimated_tokens <= max_tokens:
|
||||||
|
return context_data
|
||||||
|
|
||||||
|
# 2. Remove lower-priority layers
|
||||||
|
for min_score in [0.6, 0.7, 0.8, 0.9]:
|
||||||
|
selection = self.select_with_scoring(decision_type, min_score=min_score)
|
||||||
|
context_data = self.get_context_data(selection.layers, max_items_per_layer=1)
|
||||||
|
estimated_tokens = self.estimate_context_tokens(context_data)
|
||||||
|
if estimated_tokens <= max_tokens:
|
||||||
|
return context_data
|
||||||
|
|
||||||
|
# Last resort: return only L7 with minimal data
|
||||||
|
return self.get_context_data([ContextLayer.L7_REALTIME], max_items_per_layer=1)
|
||||||
@@ -7,7 +7,12 @@ Includes token efficiency optimizations:
|
|||||||
- Prompt compression and abbreviation
|
- Prompt compression and abbreviation
|
||||||
- Response caching for common scenarios
|
- Response caching for common scenarios
|
||||||
- Smart context selection
|
- Smart context selection
|
||||||
- Token usage tracking
|
- Token usage tracking and metrics
|
||||||
|
|
||||||
|
Includes external data integration:
|
||||||
|
- News sentiment analysis
|
||||||
|
- Economic calendar events
|
||||||
|
- Market indicators
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -15,7 +20,7 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from google import genai
|
from google import genai
|
||||||
|
|||||||
267
src/brain/prompt_optimizer.py
Normal file
267
src/brain/prompt_optimizer.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
"""Prompt optimization utilities for reducing token usage.
|
||||||
|
|
||||||
|
This module provides tools to compress prompts while maintaining decision quality:
|
||||||
|
- Token counting
|
||||||
|
- Text compression and abbreviation
|
||||||
|
- Template-based prompts with variable slots
|
||||||
|
- Priority-based context truncation
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
# Abbreviation mapping for common terms
|
||||||
|
ABBREVIATIONS = {
|
||||||
|
"price": "P",
|
||||||
|
"volume": "V",
|
||||||
|
"current": "cur",
|
||||||
|
"previous": "prev",
|
||||||
|
"change": "chg",
|
||||||
|
"percentage": "pct",
|
||||||
|
"market": "mkt",
|
||||||
|
"orderbook": "ob",
|
||||||
|
"foreigner": "fgn",
|
||||||
|
"buy": "B",
|
||||||
|
"sell": "S",
|
||||||
|
"hold": "H",
|
||||||
|
"confidence": "conf",
|
||||||
|
"rationale": "reason",
|
||||||
|
"action": "act",
|
||||||
|
"net": "net",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Reverse mapping for decompression
|
||||||
|
REVERSE_ABBREVIATIONS = {v: k for k, v in ABBREVIATIONS.items()}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TokenMetrics:
|
||||||
|
"""Metrics about token usage in a prompt."""
|
||||||
|
|
||||||
|
char_count: int
|
||||||
|
word_count: int
|
||||||
|
estimated_tokens: int # Rough estimate: ~4 chars per token
|
||||||
|
compression_ratio: float = 1.0 # Original / Compressed
|
||||||
|
|
||||||
|
|
||||||
|
class PromptOptimizer:
|
||||||
|
"""Optimizes prompts to reduce token usage while maintaining quality."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def estimate_tokens(text: str) -> int:
|
||||||
|
"""Estimate token count for text.
|
||||||
|
|
||||||
|
Uses a simple heuristic: ~4 characters per token for English.
|
||||||
|
This is approximate but sufficient for optimization purposes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Input text to estimate tokens for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Estimated token count
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return 0
|
||||||
|
# Simple estimate: 1 token ≈ 4 characters
|
||||||
|
return max(1, len(text) // 4)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def count_tokens(text: str) -> TokenMetrics:
|
||||||
|
"""Count various metrics for a text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Input text to analyze
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TokenMetrics with character, word, and estimated token counts
|
||||||
|
"""
|
||||||
|
char_count = len(text)
|
||||||
|
word_count = len(text.split())
|
||||||
|
estimated_tokens = PromptOptimizer.estimate_tokens(text)
|
||||||
|
|
||||||
|
return TokenMetrics(
|
||||||
|
char_count=char_count,
|
||||||
|
word_count=word_count,
|
||||||
|
estimated_tokens=estimated_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def compress_json(data: dict[str, Any]) -> str:
|
||||||
|
"""Compress JSON by removing whitespace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Dictionary to serialize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Compact JSON string without whitespace
|
||||||
|
"""
|
||||||
|
return json.dumps(data, separators=(",", ":"), ensure_ascii=False)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def abbreviate_text(text: str, aggressive: bool = False) -> str:
|
||||||
|
"""Apply abbreviations to reduce text length.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Input text to abbreviate
|
||||||
|
aggressive: If True, apply more aggressive compression
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Abbreviated text
|
||||||
|
"""
|
||||||
|
result = text
|
||||||
|
|
||||||
|
# Apply word-level abbreviations (case-insensitive)
|
||||||
|
for full, abbr in ABBREVIATIONS.items():
|
||||||
|
# Word boundaries to avoid partial replacements
|
||||||
|
pattern = r"\b" + re.escape(full) + r"\b"
|
||||||
|
result = re.sub(pattern, abbr, result, flags=re.IGNORECASE)
|
||||||
|
|
||||||
|
if aggressive:
|
||||||
|
# Remove articles and filler words
|
||||||
|
result = re.sub(r"\b(a|an|the)\b", "", result, flags=re.IGNORECASE)
|
||||||
|
result = re.sub(r"\b(is|are|was|were)\b", "", result, flags=re.IGNORECASE)
|
||||||
|
# Collapse multiple spaces
|
||||||
|
result = re.sub(r"\s+", " ", result)
|
||||||
|
|
||||||
|
return result.strip()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_compressed_prompt(
|
||||||
|
market_data: dict[str, Any],
|
||||||
|
include_instructions: bool = True,
|
||||||
|
max_length: int | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Build a compressed prompt from market data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
market_data: Market data dictionary with stock info
|
||||||
|
include_instructions: Whether to include full instructions
|
||||||
|
max_length: Maximum character length (truncates if needed)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Compressed prompt string
|
||||||
|
"""
|
||||||
|
# Abbreviated market name
|
||||||
|
market_name = market_data.get("market_name", "KR")
|
||||||
|
if "Korea" in market_name:
|
||||||
|
market_name = "KR"
|
||||||
|
elif "United States" in market_name or "US" in market_name:
|
||||||
|
market_name = "US"
|
||||||
|
|
||||||
|
# Core data - always included
|
||||||
|
core_info = {
|
||||||
|
"mkt": market_name,
|
||||||
|
"code": market_data["stock_code"],
|
||||||
|
"P": market_data["current_price"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Optional fields
|
||||||
|
if "orderbook" in market_data and market_data["orderbook"]:
|
||||||
|
ob = market_data["orderbook"]
|
||||||
|
# Compress orderbook: keep only top 3 levels
|
||||||
|
compressed_ob = {
|
||||||
|
"bid": ob.get("bid", [])[:3],
|
||||||
|
"ask": ob.get("ask", [])[:3],
|
||||||
|
}
|
||||||
|
core_info["ob"] = compressed_ob
|
||||||
|
|
||||||
|
if market_data.get("foreigner_net", 0) != 0:
|
||||||
|
core_info["fgn_net"] = market_data["foreigner_net"]
|
||||||
|
|
||||||
|
# Compress to JSON
|
||||||
|
data_str = PromptOptimizer.compress_json(core_info)
|
||||||
|
|
||||||
|
if include_instructions:
|
||||||
|
# Minimal instructions
|
||||||
|
prompt = (
|
||||||
|
f"{market_name} trader. Analyze:\n{data_str}\n\n"
|
||||||
|
'Return JSON: {"act":"BUY"|"SELL"|"HOLD","conf":<0-100>,"reason":"<text>"}\n'
|
||||||
|
"Rules: act=BUY/SELL/HOLD, conf=0-100, reason=concise. No markdown."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Data only (for cached contexts where instructions are known)
|
||||||
|
prompt = data_str
|
||||||
|
|
||||||
|
# Truncate if needed
|
||||||
|
if max_length and len(prompt) > max_length:
|
||||||
|
prompt = prompt[:max_length] + "..."
|
||||||
|
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def truncate_context(
|
||||||
|
context: dict[str, Any],
|
||||||
|
max_tokens: int,
|
||||||
|
priority_keys: list[str] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Truncate context data to fit within token budget.
|
||||||
|
|
||||||
|
Keeps high-priority keys first, then truncates less important data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context dictionary to truncate
|
||||||
|
max_tokens: Maximum token budget
|
||||||
|
priority_keys: List of keys to keep (in order of priority)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Truncated context dictionary
|
||||||
|
"""
|
||||||
|
if not context:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if priority_keys is None:
|
||||||
|
priority_keys = []
|
||||||
|
|
||||||
|
result: dict[str, Any] = {}
|
||||||
|
current_tokens = 0
|
||||||
|
|
||||||
|
# Add priority keys first
|
||||||
|
for key in priority_keys:
|
||||||
|
if key in context:
|
||||||
|
value_str = json.dumps(context[key])
|
||||||
|
tokens = PromptOptimizer.estimate_tokens(value_str)
|
||||||
|
|
||||||
|
if current_tokens + tokens <= max_tokens:
|
||||||
|
result[key] = context[key]
|
||||||
|
current_tokens += tokens
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Add remaining keys if space available
|
||||||
|
for key, value in context.items():
|
||||||
|
if key in result:
|
||||||
|
continue
|
||||||
|
|
||||||
|
value_str = json.dumps(value)
|
||||||
|
tokens = PromptOptimizer.estimate_tokens(value_str)
|
||||||
|
|
||||||
|
if current_tokens + tokens <= max_tokens:
|
||||||
|
result[key] = value
|
||||||
|
current_tokens += tokens
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def calculate_compression_ratio(original: str, compressed: str) -> float:
|
||||||
|
"""Calculate compression ratio between original and compressed text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
original: Original text
|
||||||
|
compressed: Compressed text
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Compression ratio (original_tokens / compressed_tokens)
|
||||||
|
"""
|
||||||
|
original_tokens = PromptOptimizer.estimate_tokens(original)
|
||||||
|
compressed_tokens = PromptOptimizer.estimate_tokens(compressed)
|
||||||
|
|
||||||
|
if compressed_tokens == 0:
|
||||||
|
return 1.0
|
||||||
|
|
||||||
|
return original_tokens / compressed_tokens
|
||||||
@@ -24,6 +24,10 @@ class Settings(BaseSettings):
|
|||||||
NEWS_API_PROVIDER: str = "alphavantage" # "alphavantage" or "newsapi"
|
NEWS_API_PROVIDER: str = "alphavantage" # "alphavantage" or "newsapi"
|
||||||
MARKET_DATA_API_KEY: str | None = None
|
MARKET_DATA_API_KEY: str | None = None
|
||||||
|
|
||||||
|
# Legacy field names (for backward compatibility)
|
||||||
|
ALPHA_VANTAGE_API_KEY: str | None = None
|
||||||
|
NEWSAPI_KEY: str | None = None
|
||||||
|
|
||||||
# Risk Management
|
# Risk Management
|
||||||
CIRCUIT_BREAKER_PCT: float = Field(default=-3.0, le=0.0)
|
CIRCUIT_BREAKER_PCT: float = Field(default=-3.0, le=0.0)
|
||||||
FAT_FINGER_PCT: float = Field(default=30.0, gt=0.0, le=100.0)
|
FAT_FINGER_PCT: float = Field(default=30.0, gt=0.0, le=100.0)
|
||||||
@@ -41,6 +45,20 @@ class Settings(BaseSettings):
|
|||||||
# Market selection (comma-separated market codes)
|
# Market selection (comma-separated market codes)
|
||||||
ENABLED_MARKETS: str = "KR"
|
ENABLED_MARKETS: str = "KR"
|
||||||
|
|
||||||
|
# Backup and Disaster Recovery (optional)
|
||||||
|
BACKUP_ENABLED: bool = True
|
||||||
|
BACKUP_DIR: str = "data/backups"
|
||||||
|
S3_ENDPOINT_URL: str | None = None # For MinIO, Backblaze B2, etc.
|
||||||
|
S3_ACCESS_KEY: str | None = None
|
||||||
|
S3_SECRET_KEY: str | None = None
|
||||||
|
S3_BUCKET_NAME: str | None = None
|
||||||
|
S3_REGION: str = "us-east-1"
|
||||||
|
|
||||||
|
# Telegram Notifications (optional)
|
||||||
|
TELEGRAM_BOT_TOKEN: str | None = None
|
||||||
|
TELEGRAM_CHAT_ID: str | None = None
|
||||||
|
TELEGRAM_ENABLED: bool = True
|
||||||
|
|
||||||
model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}
|
model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
328
src/context/summarizer.py
Normal file
328
src/context/summarizer.py
Normal file
@@ -0,0 +1,328 @@
|
|||||||
|
"""Context summarization for efficient historical data representation.
|
||||||
|
|
||||||
|
This module summarizes old context data instead of including raw details:
|
||||||
|
- Key metrics only (averages, trends, not details)
|
||||||
|
- Rolling window (keep last N days detailed, summarize older)
|
||||||
|
- Aggregate historical data efficiently
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from src.context.layer import ContextLayer
|
||||||
|
from src.context.store import ContextStore
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SummaryStats:
|
||||||
|
"""Statistical summary of historical data."""
|
||||||
|
|
||||||
|
count: int
|
||||||
|
mean: float | None = None
|
||||||
|
min: float | None = None
|
||||||
|
max: float | None = None
|
||||||
|
std: float | None = None
|
||||||
|
trend: str | None = None # "up", "down", "flat"
|
||||||
|
|
||||||
|
|
||||||
|
class ContextSummarizer:
|
||||||
|
"""Summarizes historical context data to reduce token usage."""
|
||||||
|
|
||||||
|
def __init__(self, store: ContextStore) -> None:
|
||||||
|
"""Initialize the context summarizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
store: ContextStore instance for retrieving context data
|
||||||
|
"""
|
||||||
|
self.store = store
|
||||||
|
|
||||||
|
def summarize_numeric_values(self, values: list[float]) -> SummaryStats:
|
||||||
|
"""Summarize a list of numeric values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
values: List of numeric values to summarize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SummaryStats with mean, min, max, std, and trend
|
||||||
|
"""
|
||||||
|
if not values:
|
||||||
|
return SummaryStats(count=0)
|
||||||
|
|
||||||
|
count = len(values)
|
||||||
|
mean = sum(values) / count
|
||||||
|
min_val = min(values)
|
||||||
|
max_val = max(values)
|
||||||
|
|
||||||
|
# Calculate standard deviation
|
||||||
|
if count > 1:
|
||||||
|
variance = sum((x - mean) ** 2 for x in values) / (count - 1)
|
||||||
|
std = variance**0.5
|
||||||
|
else:
|
||||||
|
std = 0.0
|
||||||
|
|
||||||
|
# Determine trend
|
||||||
|
trend = "flat"
|
||||||
|
if count >= 3:
|
||||||
|
# Simple trend: compare first third vs last third
|
||||||
|
first_third = values[: count // 3]
|
||||||
|
last_third = values[-(count // 3) :]
|
||||||
|
first_avg = sum(first_third) / len(first_third)
|
||||||
|
last_avg = sum(last_third) / len(last_third)
|
||||||
|
|
||||||
|
# Trend threshold: 5% change
|
||||||
|
threshold = 0.05 * abs(first_avg) if first_avg != 0 else 0.01
|
||||||
|
|
||||||
|
if last_avg > first_avg + threshold:
|
||||||
|
trend = "up"
|
||||||
|
elif last_avg < first_avg - threshold:
|
||||||
|
trend = "down"
|
||||||
|
|
||||||
|
return SummaryStats(
|
||||||
|
count=count,
|
||||||
|
mean=round(mean, 4),
|
||||||
|
min=round(min_val, 4),
|
||||||
|
max=round(max_val, 4),
|
||||||
|
std=round(std, 4),
|
||||||
|
trend=trend,
|
||||||
|
)
|
||||||
|
|
||||||
|
def summarize_layer(
|
||||||
|
self,
|
||||||
|
layer: ContextLayer,
|
||||||
|
start_date: datetime | None = None,
|
||||||
|
end_date: datetime | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Summarize all context data for a layer within a date range.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer: Context layer to summarize
|
||||||
|
start_date: Start date (inclusive), None for all
|
||||||
|
end_date: End date (inclusive), None for now
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with summarized metrics
|
||||||
|
"""
|
||||||
|
if end_date is None:
|
||||||
|
end_date = datetime.now(UTC)
|
||||||
|
|
||||||
|
# Get all contexts for this layer
|
||||||
|
all_contexts = self.store.get_all_contexts(layer)
|
||||||
|
|
||||||
|
if not all_contexts:
|
||||||
|
return {"summary": "No data available", "count": 0}
|
||||||
|
|
||||||
|
# Group numeric values by key
|
||||||
|
numeric_data: dict[str, list[float]] = {}
|
||||||
|
text_data: dict[str, list[str]] = {}
|
||||||
|
|
||||||
|
for key, value in all_contexts.items():
|
||||||
|
# Try to extract numeric values
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
if key not in numeric_data:
|
||||||
|
numeric_data[key] = []
|
||||||
|
numeric_data[key].append(float(value))
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
# Extract numeric fields from dict
|
||||||
|
for subkey, subvalue in value.items():
|
||||||
|
if isinstance(subvalue, (int, float)):
|
||||||
|
full_key = f"{key}.{subkey}"
|
||||||
|
if full_key not in numeric_data:
|
||||||
|
numeric_data[full_key] = []
|
||||||
|
numeric_data[full_key].append(float(subvalue))
|
||||||
|
elif isinstance(value, str):
|
||||||
|
if key not in text_data:
|
||||||
|
text_data[key] = []
|
||||||
|
text_data[key].append(value)
|
||||||
|
|
||||||
|
# Summarize numeric data
|
||||||
|
summary: dict[str, Any] = {}
|
||||||
|
|
||||||
|
for key, values in numeric_data.items():
|
||||||
|
stats = self.summarize_numeric_values(values)
|
||||||
|
summary[key] = {
|
||||||
|
"count": stats.count,
|
||||||
|
"avg": stats.mean,
|
||||||
|
"range": [stats.min, stats.max],
|
||||||
|
"trend": stats.trend,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Summarize text data (just counts)
|
||||||
|
for key, values in text_data.items():
|
||||||
|
summary[f"{key}_count"] = len(values)
|
||||||
|
|
||||||
|
summary["total_entries"] = len(all_contexts)
|
||||||
|
|
||||||
|
return summary
|
||||||
|
|
||||||
|
def rolling_window_summary(
|
||||||
|
self,
|
||||||
|
layer: ContextLayer,
|
||||||
|
window_days: int = 30,
|
||||||
|
summarize_older: bool = True,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Create a rolling window summary.
|
||||||
|
|
||||||
|
Recent data (within window) is kept detailed.
|
||||||
|
Older data is summarized to key metrics.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer: Context layer to summarize
|
||||||
|
window_days: Number of days to keep detailed
|
||||||
|
summarize_older: Whether to summarize data older than window
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with recent (detailed) and historical (summary) data
|
||||||
|
"""
|
||||||
|
result: dict[str, Any] = {
|
||||||
|
"window_days": window_days,
|
||||||
|
"recent_data": {},
|
||||||
|
"historical_summary": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get all contexts
|
||||||
|
all_contexts = self.store.get_all_contexts(layer)
|
||||||
|
|
||||||
|
recent_values: dict[str, list[float]] = {}
|
||||||
|
historical_values: dict[str, list[float]] = {}
|
||||||
|
|
||||||
|
for key, value in all_contexts.items():
|
||||||
|
# For simplicity, treat all numeric values
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
# Note: We don't have timestamps in context keys
|
||||||
|
# This is a simplified implementation
|
||||||
|
# In practice, would need to check timeframe field
|
||||||
|
|
||||||
|
# For now, put recent data in window
|
||||||
|
if key not in recent_values:
|
||||||
|
recent_values[key] = []
|
||||||
|
recent_values[key].append(float(value))
|
||||||
|
|
||||||
|
# Detailed recent data
|
||||||
|
result["recent_data"] = {key: values[-10:] for key, values in recent_values.items()}
|
||||||
|
|
||||||
|
# Summarized historical data
|
||||||
|
if summarize_older:
|
||||||
|
for key, values in historical_values.items():
|
||||||
|
stats = self.summarize_numeric_values(values)
|
||||||
|
result["historical_summary"][key] = {
|
||||||
|
"count": stats.count,
|
||||||
|
"avg": stats.mean,
|
||||||
|
"trend": stats.trend,
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def aggregate_to_higher_layer(
|
||||||
|
self,
|
||||||
|
source_layer: ContextLayer,
|
||||||
|
target_layer: ContextLayer,
|
||||||
|
metric_key: str,
|
||||||
|
aggregation_func: str = "mean",
|
||||||
|
) -> float | None:
|
||||||
|
"""Aggregate data from source layer to target layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source_layer: Source context layer (more granular)
|
||||||
|
target_layer: Target context layer (less granular)
|
||||||
|
metric_key: Key of metric to aggregate
|
||||||
|
aggregation_func: Aggregation function ("mean", "sum", "max", "min")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Aggregated value, or None if no data available
|
||||||
|
"""
|
||||||
|
# Get all contexts from source layer
|
||||||
|
source_contexts = self.store.get_all_contexts(source_layer)
|
||||||
|
|
||||||
|
# Extract values for metric_key
|
||||||
|
values = []
|
||||||
|
for key, value in source_contexts.items():
|
||||||
|
if key == metric_key and isinstance(value, (int, float)):
|
||||||
|
values.append(float(value))
|
||||||
|
elif isinstance(value, dict) and metric_key in value:
|
||||||
|
subvalue = value[metric_key]
|
||||||
|
if isinstance(subvalue, (int, float)):
|
||||||
|
values.append(float(subvalue))
|
||||||
|
|
||||||
|
if not values:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Apply aggregation function
|
||||||
|
if aggregation_func == "mean":
|
||||||
|
return sum(values) / len(values)
|
||||||
|
elif aggregation_func == "sum":
|
||||||
|
return sum(values)
|
||||||
|
elif aggregation_func == "max":
|
||||||
|
return max(values)
|
||||||
|
elif aggregation_func == "min":
|
||||||
|
return min(values)
|
||||||
|
else:
|
||||||
|
return sum(values) / len(values) # Default to mean
|
||||||
|
|
||||||
|
def create_compact_summary(
|
||||||
|
self,
|
||||||
|
layers: list[ContextLayer],
|
||||||
|
top_n_metrics: int = 5,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Create a compact summary across multiple layers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layers: List of context layers to summarize
|
||||||
|
top_n_metrics: Number of top metrics to include per layer
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Compact summary dictionary
|
||||||
|
"""
|
||||||
|
summary: dict[str, Any] = {}
|
||||||
|
|
||||||
|
for layer in layers:
|
||||||
|
layer_summary = self.summarize_layer(layer)
|
||||||
|
|
||||||
|
# Keep only top N metrics (by count/relevance)
|
||||||
|
metrics = []
|
||||||
|
for key, value in layer_summary.items():
|
||||||
|
if isinstance(value, dict) and "count" in value:
|
||||||
|
metrics.append((key, value, value["count"]))
|
||||||
|
|
||||||
|
# Sort by count (descending)
|
||||||
|
metrics.sort(key=lambda x: x[2], reverse=True)
|
||||||
|
|
||||||
|
# Keep top N
|
||||||
|
top_metrics = {m[0]: m[1] for m in metrics[:top_n_metrics]}
|
||||||
|
|
||||||
|
summary[layer.value] = top_metrics
|
||||||
|
|
||||||
|
return summary
|
||||||
|
|
||||||
|
def format_summary_for_prompt(self, summary: dict[str, Any]) -> str:
|
||||||
|
"""Format summary for inclusion in a prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
summary: Summary dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted string for prompt
|
||||||
|
"""
|
||||||
|
lines = []
|
||||||
|
|
||||||
|
for layer, metrics in summary.items():
|
||||||
|
if not metrics:
|
||||||
|
continue
|
||||||
|
|
||||||
|
lines.append(f"{layer}:")
|
||||||
|
for key, value in metrics.items():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
# Format as: key: avg=X, trend=Y
|
||||||
|
parts = []
|
||||||
|
if "avg" in value and value["avg"] is not None:
|
||||||
|
parts.append(f"avg={value['avg']:.2f}")
|
||||||
|
if "trend" in value and value["trend"]:
|
||||||
|
parts.append(f"trend={value['trend']}")
|
||||||
|
if parts:
|
||||||
|
lines.append(f" {key}: {', '.join(parts)}")
|
||||||
|
else:
|
||||||
|
lines.append(f" {key}: {value}")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
@@ -23,7 +23,7 @@ from google import genai
|
|||||||
|
|
||||||
from src.config import Settings
|
from src.config import Settings
|
||||||
from src.db import init_db
|
from src.db import init_db
|
||||||
from src.logging.decision_logger import DecisionLog, DecisionLogger
|
from src.logging.decision_logger import DecisionLogger
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
105
src/main.py
105
src/main.py
@@ -10,6 +10,7 @@ import argparse
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import signal
|
import signal
|
||||||
|
import sys
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -21,13 +22,14 @@ from src.broker.overseas import OverseasBroker
|
|||||||
from src.config import Settings
|
from src.config import Settings
|
||||||
from src.context.layer import ContextLayer
|
from src.context.layer import ContextLayer
|
||||||
from src.context.store import ContextStore
|
from src.context.store import ContextStore
|
||||||
from src.core.criticality import CriticalityAssessor, CriticalityLevel
|
from src.core.criticality import CriticalityAssessor
|
||||||
from src.core.priority_queue import PriorityTaskQueue
|
from src.core.priority_queue import PriorityTaskQueue
|
||||||
from src.core.risk_manager import CircuitBreakerTripped, RiskManager
|
from src.core.risk_manager import CircuitBreakerTripped, FatFingerRejected, RiskManager
|
||||||
from src.db import init_db, log_trade
|
from src.db import init_db, log_trade
|
||||||
from src.logging.decision_logger import DecisionLogger
|
from src.logging.decision_logger import DecisionLogger
|
||||||
from src.logging_config import setup_logging
|
from src.logging_config import setup_logging
|
||||||
from src.markets.schedule import MarketInfo, get_next_market_open, get_open_markets
|
from src.markets.schedule import MarketInfo, get_next_market_open, get_open_markets
|
||||||
|
from src.notifications.telegram_client import TelegramClient
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -62,6 +64,7 @@ async def trading_cycle(
|
|||||||
decision_logger: DecisionLogger,
|
decision_logger: DecisionLogger,
|
||||||
context_store: ContextStore,
|
context_store: ContextStore,
|
||||||
criticality_assessor: CriticalityAssessor,
|
criticality_assessor: CriticalityAssessor,
|
||||||
|
telegram: TelegramClient,
|
||||||
market: MarketInfo,
|
market: MarketInfo,
|
||||||
stock_code: str,
|
stock_code: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -92,9 +95,17 @@ async def trading_cycle(
|
|||||||
balance_data = await overseas_broker.get_overseas_balance(market.exchange_code)
|
balance_data = await overseas_broker.get_overseas_balance(market.exchange_code)
|
||||||
|
|
||||||
output2 = balance_data.get("output2", [{}])
|
output2 = balance_data.get("output2", [{}])
|
||||||
total_eval = float(output2[0].get("frcr_evlu_tota", "0")) if output2 else 0
|
# Handle both list and dict response formats
|
||||||
total_cash = float(output2[0].get("frcr_dncl_amt_2", "0")) if output2 else 0
|
if isinstance(output2, list) and output2:
|
||||||
purchase_total = float(output2[0].get("frcr_buy_amt_smtl", "0")) if output2 else 0
|
balance_info = output2[0]
|
||||||
|
elif isinstance(output2, dict):
|
||||||
|
balance_info = output2
|
||||||
|
else:
|
||||||
|
balance_info = {}
|
||||||
|
|
||||||
|
total_eval = float(balance_info.get("frcr_evlu_tota", "0") or "0")
|
||||||
|
total_cash = float(balance_info.get("frcr_dncl_amt_2", "0") or "0")
|
||||||
|
purchase_total = float(balance_info.get("frcr_buy_amt_smtl", "0") or "0")
|
||||||
|
|
||||||
current_price = float(price_data.get("output", {}).get("last", "0"))
|
current_price = float(price_data.get("output", {}).get("last", "0"))
|
||||||
foreigner_net = 0.0 # Not available for overseas
|
foreigner_net = 0.0 # Not available for overseas
|
||||||
@@ -199,11 +210,23 @@ async def trading_cycle(
|
|||||||
order_amount = current_price * quantity
|
order_amount = current_price * quantity
|
||||||
|
|
||||||
# 4. Risk check BEFORE order
|
# 4. Risk check BEFORE order
|
||||||
risk.validate_order(
|
try:
|
||||||
current_pnl_pct=pnl_pct,
|
risk.validate_order(
|
||||||
order_amount=order_amount,
|
current_pnl_pct=pnl_pct,
|
||||||
total_cash=total_cash,
|
order_amount=order_amount,
|
||||||
)
|
total_cash=total_cash,
|
||||||
|
)
|
||||||
|
except FatFingerRejected as exc:
|
||||||
|
try:
|
||||||
|
await telegram.notify_fat_finger(
|
||||||
|
stock_code=stock_code,
|
||||||
|
order_amount=exc.order_amount,
|
||||||
|
total_cash=exc.total_cash,
|
||||||
|
max_pct=exc.max_pct,
|
||||||
|
)
|
||||||
|
except Exception as notify_exc:
|
||||||
|
logger.warning("Fat finger notification failed: %s", notify_exc)
|
||||||
|
raise # Re-raise to prevent trade
|
||||||
|
|
||||||
# 5. Send order
|
# 5. Send order
|
||||||
if market.is_domestic:
|
if market.is_domestic:
|
||||||
@@ -223,6 +246,19 @@ async def trading_cycle(
|
|||||||
)
|
)
|
||||||
logger.info("Order result: %s", result.get("msg1", "OK"))
|
logger.info("Order result: %s", result.get("msg1", "OK"))
|
||||||
|
|
||||||
|
# 5.5. Notify trade execution
|
||||||
|
try:
|
||||||
|
await telegram.notify_trade_execution(
|
||||||
|
stock_code=stock_code,
|
||||||
|
market=market.name,
|
||||||
|
action=decision.action,
|
||||||
|
quantity=quantity,
|
||||||
|
price=current_price,
|
||||||
|
confidence=decision.confidence,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Telegram notification failed: %s", exc)
|
||||||
|
|
||||||
# 6. Log trade
|
# 6. Log trade
|
||||||
log_trade(
|
log_trade(
|
||||||
conn=db_conn,
|
conn=db_conn,
|
||||||
@@ -266,6 +302,13 @@ async def run(settings: Settings) -> None:
|
|||||||
decision_logger = DecisionLogger(db_conn)
|
decision_logger = DecisionLogger(db_conn)
|
||||||
context_store = ContextStore(db_conn)
|
context_store = ContextStore(db_conn)
|
||||||
|
|
||||||
|
# Initialize Telegram notifications
|
||||||
|
telegram = TelegramClient(
|
||||||
|
bot_token=settings.TELEGRAM_BOT_TOKEN,
|
||||||
|
chat_id=settings.TELEGRAM_CHAT_ID,
|
||||||
|
enabled=settings.TELEGRAM_ENABLED,
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize volatility hunter
|
# Initialize volatility hunter
|
||||||
volatility_analyzer = VolatilityAnalyzer(min_volume_surge=2.0, min_price_change=1.0)
|
volatility_analyzer = VolatilityAnalyzer(min_volume_surge=2.0, min_price_change=1.0)
|
||||||
market_scanner = MarketScanner(
|
market_scanner = MarketScanner(
|
||||||
@@ -289,6 +332,9 @@ async def run(settings: Settings) -> None:
|
|||||||
# Track last scan time for each market
|
# Track last scan time for each market
|
||||||
last_scan_time: dict[str, float] = {}
|
last_scan_time: dict[str, float] = {}
|
||||||
|
|
||||||
|
# Track market open/close state for notifications
|
||||||
|
_market_states: dict[str, bool] = {} # market_code -> is_open
|
||||||
|
|
||||||
shutdown = asyncio.Event()
|
shutdown = asyncio.Event()
|
||||||
|
|
||||||
def _signal_handler() -> None:
|
def _signal_handler() -> None:
|
||||||
@@ -302,12 +348,31 @@ async def run(settings: Settings) -> None:
|
|||||||
logger.info("The Ouroboros is alive. Mode: %s", settings.MODE)
|
logger.info("The Ouroboros is alive. Mode: %s", settings.MODE)
|
||||||
logger.info("Enabled markets: %s", settings.enabled_market_list)
|
logger.info("Enabled markets: %s", settings.enabled_market_list)
|
||||||
|
|
||||||
|
# Notify system startup
|
||||||
|
try:
|
||||||
|
await telegram.notify_system_start(settings.MODE, settings.enabled_market_list)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("System startup notification failed: %s", exc)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while not shutdown.is_set():
|
while not shutdown.is_set():
|
||||||
# Get currently open markets
|
# Get currently open markets
|
||||||
open_markets = get_open_markets(settings.enabled_market_list)
|
open_markets = get_open_markets(settings.enabled_market_list)
|
||||||
|
|
||||||
if not open_markets:
|
if not open_markets:
|
||||||
|
# Notify market close for any markets that were open
|
||||||
|
for market_code, is_open in list(_market_states.items()):
|
||||||
|
if is_open:
|
||||||
|
try:
|
||||||
|
from src.markets.schedule import MARKETS
|
||||||
|
|
||||||
|
market_info = MARKETS.get(market_code)
|
||||||
|
if market_info:
|
||||||
|
await telegram.notify_market_close(market_info.name, 0.0)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Market close notification failed: %s", exc)
|
||||||
|
_market_states[market_code] = False
|
||||||
|
|
||||||
# No markets open — wait until next market opens
|
# No markets open — wait until next market opens
|
||||||
try:
|
try:
|
||||||
next_market, next_open_time = get_next_market_open(
|
next_market, next_open_time = get_next_market_open(
|
||||||
@@ -333,6 +398,14 @@ async def run(settings: Settings) -> None:
|
|||||||
if shutdown.is_set():
|
if shutdown.is_set():
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Notify market open if it just opened
|
||||||
|
if not _market_states.get(market.code, False):
|
||||||
|
try:
|
||||||
|
await telegram.notify_market_open(market.name)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Market open notification failed: %s", exc)
|
||||||
|
_market_states[market.code] = True
|
||||||
|
|
||||||
# Volatility Hunter: Scan market periodically to update watchlist
|
# Volatility Hunter: Scan market periodically to update watchlist
|
||||||
now_timestamp = asyncio.get_event_loop().time()
|
now_timestamp = asyncio.get_event_loop().time()
|
||||||
last_scan = last_scan_time.get(market.code, 0.0)
|
last_scan = last_scan_time.get(market.code, 0.0)
|
||||||
@@ -391,12 +464,22 @@ async def run(settings: Settings) -> None:
|
|||||||
decision_logger,
|
decision_logger,
|
||||||
context_store,
|
context_store,
|
||||||
criticality_assessor,
|
criticality_assessor,
|
||||||
|
telegram,
|
||||||
market,
|
market,
|
||||||
stock_code,
|
stock_code,
|
||||||
)
|
)
|
||||||
break # Success — exit retry loop
|
break # Success — exit retry loop
|
||||||
except CircuitBreakerTripped:
|
except CircuitBreakerTripped as exc:
|
||||||
logger.critical("Circuit breaker tripped — shutting down")
|
logger.critical("Circuit breaker tripped — shutting down")
|
||||||
|
try:
|
||||||
|
await telegram.notify_circuit_breaker(
|
||||||
|
pnl_pct=exc.pnl_pct,
|
||||||
|
threshold=exc.threshold,
|
||||||
|
)
|
||||||
|
except Exception as notify_exc:
|
||||||
|
logger.warning(
|
||||||
|
"Circuit breaker notification failed: %s", notify_exc
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
except ConnectionError as exc:
|
except ConnectionError as exc:
|
||||||
if attempt < MAX_CONNECTION_RETRIES:
|
if attempt < MAX_CONNECTION_RETRIES:
|
||||||
|
|||||||
213
src/notifications/README.md
Normal file
213
src/notifications/README.md
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
# Telegram Notifications
|
||||||
|
|
||||||
|
Real-time trading event notifications via Telegram Bot API.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
### 1. Create a Telegram Bot
|
||||||
|
|
||||||
|
1. Open Telegram and message [@BotFather](https://t.me/BotFather)
|
||||||
|
2. Send `/newbot` command
|
||||||
|
3. Follow prompts to name your bot
|
||||||
|
4. Save the **bot token** (looks like `1234567890:ABCdefGHIjklMNOpqrsTUVwxyz`)
|
||||||
|
|
||||||
|
### 2. Get Your Chat ID
|
||||||
|
|
||||||
|
**Option A: Using @userinfobot**
|
||||||
|
1. Message [@userinfobot](https://t.me/userinfobot) on Telegram
|
||||||
|
2. Send `/start`
|
||||||
|
3. Save your numeric **chat ID** (e.g., `123456789`)
|
||||||
|
|
||||||
|
**Option B: Using @RawDataBot**
|
||||||
|
1. Message [@RawDataBot](https://t.me/rawdatabot) on Telegram
|
||||||
|
2. Look for `"id":` in the JSON response
|
||||||
|
3. Save your numeric **chat ID**
|
||||||
|
|
||||||
|
### 3. Configure Environment
|
||||||
|
|
||||||
|
Add to your `.env` file:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
TELEGRAM_BOT_TOKEN=1234567890:ABCdefGHIjklMNOpqrsTUVwxyz
|
||||||
|
TELEGRAM_CHAT_ID=123456789
|
||||||
|
TELEGRAM_ENABLED=true
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Test the Bot
|
||||||
|
|
||||||
|
Start a conversation with your bot on Telegram first (send `/start`), then run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m src.main --mode=paper
|
||||||
|
```
|
||||||
|
|
||||||
|
You should receive a startup notification.
|
||||||
|
|
||||||
|
## Message Examples
|
||||||
|
|
||||||
|
### Trade Execution
|
||||||
|
```
|
||||||
|
🟢 BUY
|
||||||
|
Symbol: AAPL (United States)
|
||||||
|
Quantity: 10 shares
|
||||||
|
Price: 150.25
|
||||||
|
Confidence: 85%
|
||||||
|
```
|
||||||
|
|
||||||
|
### Circuit Breaker
|
||||||
|
```
|
||||||
|
🚨 CIRCUIT BREAKER TRIPPED
|
||||||
|
P&L: -3.15% (threshold: -3.0%)
|
||||||
|
Trading halted for safety
|
||||||
|
```
|
||||||
|
|
||||||
|
### Fat-Finger Protection
|
||||||
|
```
|
||||||
|
⚠️ Fat-Finger Protection
|
||||||
|
Order rejected: TSLA
|
||||||
|
Attempted: 45.0% of cash
|
||||||
|
Max allowed: 30%
|
||||||
|
Amount: 45,000 / 100,000
|
||||||
|
```
|
||||||
|
|
||||||
|
### Market Open/Close
|
||||||
|
```
|
||||||
|
ℹ️ Market Open
|
||||||
|
Korea trading session started
|
||||||
|
|
||||||
|
ℹ️ Market Close
|
||||||
|
Korea trading session ended
|
||||||
|
📈 P&L: +1.25%
|
||||||
|
```
|
||||||
|
|
||||||
|
### System Status
|
||||||
|
```
|
||||||
|
📝 System Started
|
||||||
|
Mode: PAPER
|
||||||
|
Markets: KRX, NASDAQ
|
||||||
|
|
||||||
|
System Shutdown
|
||||||
|
Normal shutdown
|
||||||
|
```
|
||||||
|
|
||||||
|
## Notification Priorities
|
||||||
|
|
||||||
|
| Priority | Emoji | Use Case |
|
||||||
|
|----------|-------|----------|
|
||||||
|
| LOW | ℹ️ | Market open/close |
|
||||||
|
| MEDIUM | 📊 | Trade execution, system start/stop |
|
||||||
|
| HIGH | ⚠️ | Fat-finger protection, errors |
|
||||||
|
| CRITICAL | 🚨 | Circuit breaker trips |
|
||||||
|
|
||||||
|
## Rate Limiting
|
||||||
|
|
||||||
|
- Default: 1 message per second
|
||||||
|
- Prevents hitting Telegram's global rate limits
|
||||||
|
- Configurable via `rate_limit` parameter
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### No notifications received
|
||||||
|
|
||||||
|
1. **Check bot configuration**
|
||||||
|
```bash
|
||||||
|
# Verify env variables are set
|
||||||
|
grep TELEGRAM .env
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Start conversation with bot**
|
||||||
|
- Open bot in Telegram
|
||||||
|
- Send `/start` command
|
||||||
|
- Bot cannot message users who haven't started a conversation
|
||||||
|
|
||||||
|
3. **Check logs**
|
||||||
|
```bash
|
||||||
|
# Look for Telegram-related errors
|
||||||
|
python -m src.main --mode=paper 2>&1 | grep -i telegram
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **Verify bot token**
|
||||||
|
```bash
|
||||||
|
curl https://api.telegram.org/bot<YOUR_TOKEN>/getMe
|
||||||
|
# Should return bot info (not 401 error)
|
||||||
|
```
|
||||||
|
|
||||||
|
5. **Verify chat ID**
|
||||||
|
```bash
|
||||||
|
curl -X POST https://api.telegram.org/bot<YOUR_TOKEN>/sendMessage \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{"chat_id": "<YOUR_CHAT_ID>", "text": "Test"}'
|
||||||
|
# Should send a test message
|
||||||
|
```
|
||||||
|
|
||||||
|
### Notifications delayed
|
||||||
|
|
||||||
|
- Check rate limiter settings
|
||||||
|
- Verify network connection
|
||||||
|
- Look for timeout errors in logs
|
||||||
|
|
||||||
|
### "Chat not found" error
|
||||||
|
|
||||||
|
- Incorrect chat ID
|
||||||
|
- Bot blocked by user
|
||||||
|
- Need to send `/start` to bot first
|
||||||
|
|
||||||
|
### "Unauthorized" error
|
||||||
|
|
||||||
|
- Invalid bot token
|
||||||
|
- Token revoked (regenerate with @BotFather)
|
||||||
|
|
||||||
|
## Graceful Degradation
|
||||||
|
|
||||||
|
The system works without Telegram notifications:
|
||||||
|
|
||||||
|
- Missing credentials → notifications disabled automatically
|
||||||
|
- API errors → logged but trading continues
|
||||||
|
- Network timeouts → trading loop unaffected
|
||||||
|
- Rate limiting → messages queued, trading proceeds
|
||||||
|
|
||||||
|
**Notifications never crash the trading system.**
|
||||||
|
|
||||||
|
## Security Notes
|
||||||
|
|
||||||
|
- Never commit `.env` file with credentials
|
||||||
|
- Bot token grants full bot control
|
||||||
|
- Chat ID is not sensitive (just a number)
|
||||||
|
- Messages are sent over HTTPS
|
||||||
|
- No trading credentials in notifications
|
||||||
|
|
||||||
|
## Advanced Usage
|
||||||
|
|
||||||
|
### Group Notifications
|
||||||
|
|
||||||
|
1. Add bot to Telegram group
|
||||||
|
2. Get group chat ID (negative number like `-123456789`)
|
||||||
|
3. Use group chat ID in `TELEGRAM_CHAT_ID`
|
||||||
|
|
||||||
|
### Multiple Recipients
|
||||||
|
|
||||||
|
Create multiple bots or use a broadcast group with multiple members.
|
||||||
|
|
||||||
|
### Custom Rate Limits
|
||||||
|
|
||||||
|
Not currently exposed in config, but can be modified in code:
|
||||||
|
|
||||||
|
```python
|
||||||
|
telegram = TelegramClient(
|
||||||
|
bot_token=settings.TELEGRAM_BOT_TOKEN,
|
||||||
|
chat_id=settings.TELEGRAM_CHAT_ID,
|
||||||
|
rate_limit=2.0, # 2 messages per second
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## API Reference
|
||||||
|
|
||||||
|
See `telegram_client.py` for full API documentation.
|
||||||
|
|
||||||
|
Key methods:
|
||||||
|
- `notify_trade_execution()` - Trade alerts
|
||||||
|
- `notify_circuit_breaker()` - Emergency stops
|
||||||
|
- `notify_fat_finger()` - Order rejections
|
||||||
|
- `notify_market_open/close()` - Session tracking
|
||||||
|
- `notify_system_start/shutdown()` - Lifecycle events
|
||||||
|
- `notify_error()` - Error alerts
|
||||||
5
src/notifications/__init__.py
Normal file
5
src/notifications/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Real-time notifications for trading events."""
|
||||||
|
|
||||||
|
from src.notifications.telegram_client import TelegramClient
|
||||||
|
|
||||||
|
__all__ = ["TelegramClient"]
|
||||||
325
src/notifications/telegram_client.py
Normal file
325
src/notifications/telegram_client.py
Normal file
@@ -0,0 +1,325 @@
|
|||||||
|
"""Telegram notification client for real-time trading alerts."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class NotificationPriority(Enum):
|
||||||
|
"""Priority levels for notifications with emoji indicators."""
|
||||||
|
|
||||||
|
LOW = ("ℹ️", "info")
|
||||||
|
MEDIUM = ("📊", "medium")
|
||||||
|
HIGH = ("⚠️", "warning")
|
||||||
|
CRITICAL = ("🚨", "critical")
|
||||||
|
|
||||||
|
def __init__(self, emoji: str, label: str) -> None:
|
||||||
|
self.emoji = emoji
|
||||||
|
self.label = label
|
||||||
|
|
||||||
|
|
||||||
|
class LeakyBucket:
|
||||||
|
"""Rate limiter using leaky bucket algorithm."""
|
||||||
|
|
||||||
|
def __init__(self, rate: float, capacity: int = 1) -> None:
|
||||||
|
"""
|
||||||
|
Initialize rate limiter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rate: Maximum requests per second
|
||||||
|
capacity: Bucket capacity (burst size)
|
||||||
|
"""
|
||||||
|
self._rate = rate
|
||||||
|
self._capacity = capacity
|
||||||
|
self._tokens = float(capacity)
|
||||||
|
self._last_update = time.monotonic()
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def acquire(self) -> None:
|
||||||
|
"""Wait until a token is available, then consume it."""
|
||||||
|
async with self._lock:
|
||||||
|
now = time.monotonic()
|
||||||
|
elapsed = now - self._last_update
|
||||||
|
self._tokens = min(self._capacity, self._tokens + elapsed * self._rate)
|
||||||
|
self._last_update = now
|
||||||
|
|
||||||
|
if self._tokens < 1.0:
|
||||||
|
wait_time = (1.0 - self._tokens) / self._rate
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
self._tokens = 0.0
|
||||||
|
else:
|
||||||
|
self._tokens -= 1.0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NotificationMessage:
|
||||||
|
"""Internal notification message structure."""
|
||||||
|
|
||||||
|
priority: NotificationPriority
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
class TelegramClient:
|
||||||
|
"""Telegram Bot API client for sending trading notifications."""
|
||||||
|
|
||||||
|
API_BASE = "https://api.telegram.org/bot{token}"
|
||||||
|
DEFAULT_TIMEOUT = 5.0 # seconds
|
||||||
|
DEFAULT_RATE = 1.0 # messages per second
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
bot_token: str | None = None,
|
||||||
|
chat_id: str | None = None,
|
||||||
|
enabled: bool = True,
|
||||||
|
rate_limit: float = DEFAULT_RATE,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize Telegram client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bot_token: Telegram bot token from @BotFather
|
||||||
|
chat_id: Target chat ID (user or group)
|
||||||
|
enabled: Enable/disable notifications globally
|
||||||
|
rate_limit: Maximum messages per second
|
||||||
|
"""
|
||||||
|
self._bot_token = bot_token
|
||||||
|
self._chat_id = chat_id
|
||||||
|
self._enabled = enabled
|
||||||
|
self._rate_limiter = LeakyBucket(rate=rate_limit)
|
||||||
|
self._session: aiohttp.ClientSession | None = None
|
||||||
|
|
||||||
|
if not enabled:
|
||||||
|
logger.info("Telegram notifications disabled via configuration")
|
||||||
|
elif bot_token is None or chat_id is None:
|
||||||
|
logger.warning(
|
||||||
|
"Telegram notifications disabled (missing bot_token or chat_id)"
|
||||||
|
)
|
||||||
|
self._enabled = False
|
||||||
|
else:
|
||||||
|
logger.info("Telegram notifications enabled for chat_id=%s", chat_id)
|
||||||
|
|
||||||
|
def _get_session(self) -> aiohttp.ClientSession:
|
||||||
|
"""Get or create aiohttp session."""
|
||||||
|
if self._session is None or self._session.closed:
|
||||||
|
self._session = aiohttp.ClientSession(
|
||||||
|
timeout=aiohttp.ClientTimeout(total=self.DEFAULT_TIMEOUT)
|
||||||
|
)
|
||||||
|
return self._session
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close HTTP session."""
|
||||||
|
if self._session is not None and not self._session.closed:
|
||||||
|
await self._session.close()
|
||||||
|
|
||||||
|
async def _send_notification(self, msg: NotificationMessage) -> None:
|
||||||
|
"""
|
||||||
|
Send notification to Telegram with graceful degradation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
msg: Notification message to send
|
||||||
|
"""
|
||||||
|
if not self._enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._rate_limiter.acquire()
|
||||||
|
|
||||||
|
formatted_message = f"{msg.priority.emoji} {msg.message}"
|
||||||
|
url = f"{self.API_BASE.format(token=self._bot_token)}/sendMessage"
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"chat_id": self._chat_id,
|
||||||
|
"text": formatted_message,
|
||||||
|
"parse_mode": "HTML",
|
||||||
|
}
|
||||||
|
|
||||||
|
session = self._get_session()
|
||||||
|
async with session.post(url, json=payload) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
error_text = await resp.text()
|
||||||
|
logger.error(
|
||||||
|
"Telegram API error (status=%d): %s", resp.status, error_text
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug("Telegram notification sent: %s", msg.message[:50])
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error("Telegram notification timeout")
|
||||||
|
except aiohttp.ClientError as exc:
|
||||||
|
logger.error("Telegram notification failed: %s", exc)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Unexpected error sending notification: %s", exc)
|
||||||
|
|
||||||
|
async def notify_trade_execution(
|
||||||
|
self,
|
||||||
|
stock_code: str,
|
||||||
|
market: str,
|
||||||
|
action: str,
|
||||||
|
quantity: int,
|
||||||
|
price: float,
|
||||||
|
confidence: float,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Notify trade execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stock_code: Stock ticker symbol
|
||||||
|
market: Market name (e.g., "Korea", "United States")
|
||||||
|
action: "BUY" or "SELL"
|
||||||
|
quantity: Number of shares
|
||||||
|
price: Execution price
|
||||||
|
confidence: AI confidence level (0-100)
|
||||||
|
"""
|
||||||
|
emoji = "🟢" if action == "BUY" else "🔴"
|
||||||
|
message = (
|
||||||
|
f"<b>{emoji} {action}</b>\n"
|
||||||
|
f"Symbol: <code>{stock_code}</code> ({market})\n"
|
||||||
|
f"Quantity: {quantity:,} shares\n"
|
||||||
|
f"Price: {price:,.2f}\n"
|
||||||
|
f"Confidence: {confidence:.0f}%"
|
||||||
|
)
|
||||||
|
await self._send_notification(
|
||||||
|
NotificationMessage(priority=NotificationPriority.MEDIUM, message=message)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def notify_market_open(self, market_name: str) -> None:
|
||||||
|
"""
|
||||||
|
Notify market opening.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
market_name: Name of the market (e.g., "Korea", "United States")
|
||||||
|
"""
|
||||||
|
message = f"<b>Market Open</b>\n{market_name} trading session started"
|
||||||
|
await self._send_notification(
|
||||||
|
NotificationMessage(priority=NotificationPriority.LOW, message=message)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def notify_market_close(self, market_name: str, pnl_pct: float) -> None:
|
||||||
|
"""
|
||||||
|
Notify market closing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
market_name: Name of the market
|
||||||
|
pnl_pct: Final P&L percentage for the session
|
||||||
|
"""
|
||||||
|
pnl_sign = "+" if pnl_pct >= 0 else ""
|
||||||
|
pnl_emoji = "📈" if pnl_pct >= 0 else "📉"
|
||||||
|
message = (
|
||||||
|
f"<b>Market Close</b>\n"
|
||||||
|
f"{market_name} trading session ended\n"
|
||||||
|
f"{pnl_emoji} P&L: {pnl_sign}{pnl_pct:.2f}%"
|
||||||
|
)
|
||||||
|
await self._send_notification(
|
||||||
|
NotificationMessage(priority=NotificationPriority.LOW, message=message)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def notify_circuit_breaker(
|
||||||
|
self, pnl_pct: float, threshold: float
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Notify circuit breaker activation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pnl_pct: Current P&L percentage
|
||||||
|
threshold: Circuit breaker threshold
|
||||||
|
"""
|
||||||
|
message = (
|
||||||
|
f"<b>CIRCUIT BREAKER TRIPPED</b>\n"
|
||||||
|
f"P&L: {pnl_pct:.2f}% (threshold: {threshold:.1f}%)\n"
|
||||||
|
f"Trading halted for safety"
|
||||||
|
)
|
||||||
|
await self._send_notification(
|
||||||
|
NotificationMessage(priority=NotificationPriority.CRITICAL, message=message)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def notify_fat_finger(
|
||||||
|
self,
|
||||||
|
stock_code: str,
|
||||||
|
order_amount: float,
|
||||||
|
total_cash: float,
|
||||||
|
max_pct: float,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Notify fat-finger protection rejection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stock_code: Stock ticker symbol
|
||||||
|
order_amount: Attempted order amount
|
||||||
|
total_cash: Total available cash
|
||||||
|
max_pct: Maximum allowed percentage
|
||||||
|
"""
|
||||||
|
attempted_pct = (order_amount / total_cash) * 100 if total_cash > 0 else 0
|
||||||
|
message = (
|
||||||
|
f"<b>Fat-Finger Protection</b>\n"
|
||||||
|
f"Order rejected: <code>{stock_code}</code>\n"
|
||||||
|
f"Attempted: {attempted_pct:.1f}% of cash\n"
|
||||||
|
f"Max allowed: {max_pct:.0f}%\n"
|
||||||
|
f"Amount: {order_amount:,.0f} / {total_cash:,.0f}"
|
||||||
|
)
|
||||||
|
await self._send_notification(
|
||||||
|
NotificationMessage(priority=NotificationPriority.HIGH, message=message)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def notify_system_start(
|
||||||
|
self, mode: str, enabled_markets: list[str]
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Notify system startup.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mode: Trading mode ("paper" or "live")
|
||||||
|
enabled_markets: List of enabled market codes
|
||||||
|
"""
|
||||||
|
mode_emoji = "📝" if mode == "paper" else "💰"
|
||||||
|
markets_str = ", ".join(enabled_markets)
|
||||||
|
message = (
|
||||||
|
f"<b>{mode_emoji} System Started</b>\n"
|
||||||
|
f"Mode: {mode.upper()}\n"
|
||||||
|
f"Markets: {markets_str}"
|
||||||
|
)
|
||||||
|
await self._send_notification(
|
||||||
|
NotificationMessage(priority=NotificationPriority.MEDIUM, message=message)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def notify_system_shutdown(self, reason: str) -> None:
|
||||||
|
"""
|
||||||
|
Notify system shutdown.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reason: Reason for shutdown (e.g., "Normal shutdown", "Circuit breaker")
|
||||||
|
"""
|
||||||
|
message = f"<b>System Shutdown</b>\n{reason}"
|
||||||
|
priority = (
|
||||||
|
NotificationPriority.CRITICAL
|
||||||
|
if "circuit breaker" in reason.lower()
|
||||||
|
else NotificationPriority.MEDIUM
|
||||||
|
)
|
||||||
|
await self._send_notification(
|
||||||
|
NotificationMessage(priority=priority, message=message)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def notify_error(
|
||||||
|
self, error_type: str, error_msg: str, context: str
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Notify system error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error_type: Type of error (e.g., "Connection Error")
|
||||||
|
error_msg: Error message
|
||||||
|
context: Error context (e.g., stock code, market)
|
||||||
|
"""
|
||||||
|
message = (
|
||||||
|
f"<b>Error: {error_type}</b>\n"
|
||||||
|
f"Context: {context}\n"
|
||||||
|
f"Message: {error_msg[:200]}" # Truncate long errors
|
||||||
|
)
|
||||||
|
await self._send_notification(
|
||||||
|
NotificationMessage(priority=NotificationPriority.HIGH, message=message)
|
||||||
|
)
|
||||||
365
tests/test_backup.py
Normal file
365
tests/test_backup.py
Normal file
@@ -0,0 +1,365 @@
|
|||||||
|
"""Tests for backup and disaster recovery system."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
import tempfile
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.backup.exporter import BackupExporter, ExportFormat
|
||||||
|
from src.backup.health_monitor import HealthMonitor, HealthStatus
|
||||||
|
from src.backup.scheduler import BackupPolicy, BackupScheduler
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_db(tmp_path: Path) -> Path:
|
||||||
|
"""Create a temporary test database."""
|
||||||
|
db_path = tmp_path / "test_trades.db"
|
||||||
|
|
||||||
|
conn = sqlite3.connect(str(db_path))
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# Create trades table
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE trades (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
timestamp TEXT NOT NULL,
|
||||||
|
stock_code TEXT NOT NULL,
|
||||||
|
action TEXT NOT NULL,
|
||||||
|
quantity INTEGER NOT NULL,
|
||||||
|
price REAL NOT NULL,
|
||||||
|
confidence INTEGER NOT NULL,
|
||||||
|
rationale TEXT,
|
||||||
|
pnl REAL DEFAULT 0.0
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Insert test data
|
||||||
|
test_trades = [
|
||||||
|
("2024-01-01T10:00:00Z", "005930", "BUY", 10, 70000.0, 85, "Test buy", 0.0),
|
||||||
|
("2024-01-01T11:00:00Z", "005930", "SELL", 10, 71000.0, 90, "Test sell", 10000.0),
|
||||||
|
("2024-01-02T10:00:00Z", "AAPL", "BUY", 5, 180.0, 88, "Tech buy", 0.0),
|
||||||
|
]
|
||||||
|
|
||||||
|
cursor.executemany(
|
||||||
|
"""
|
||||||
|
INSERT INTO trades (timestamp, stock_code, action, quantity, price, confidence, rationale, pnl)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
test_trades,
|
||||||
|
)
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
return db_path
|
||||||
|
|
||||||
|
|
||||||
|
class TestBackupExporter:
|
||||||
|
"""Test BackupExporter functionality."""
|
||||||
|
|
||||||
|
def test_exporter_init(self, temp_db: Path) -> None:
|
||||||
|
"""Test exporter initialization."""
|
||||||
|
exporter = BackupExporter(str(temp_db))
|
||||||
|
assert exporter.db_path == str(temp_db)
|
||||||
|
|
||||||
|
def test_export_json(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test JSON export."""
|
||||||
|
exporter = BackupExporter(str(temp_db))
|
||||||
|
output_dir = tmp_path / "exports"
|
||||||
|
|
||||||
|
results = exporter.export_all(
|
||||||
|
output_dir, formats=[ExportFormat.JSON], compress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ExportFormat.JSON in results
|
||||||
|
assert results[ExportFormat.JSON].exists()
|
||||||
|
assert results[ExportFormat.JSON].suffix == ".json"
|
||||||
|
|
||||||
|
def test_export_json_compressed(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test compressed JSON export."""
|
||||||
|
exporter = BackupExporter(str(temp_db))
|
||||||
|
output_dir = tmp_path / "exports"
|
||||||
|
|
||||||
|
results = exporter.export_all(
|
||||||
|
output_dir, formats=[ExportFormat.JSON], compress=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ExportFormat.JSON in results
|
||||||
|
assert results[ExportFormat.JSON].suffix == ".gz"
|
||||||
|
|
||||||
|
def test_export_csv(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test CSV export."""
|
||||||
|
exporter = BackupExporter(str(temp_db))
|
||||||
|
output_dir = tmp_path / "exports"
|
||||||
|
|
||||||
|
results = exporter.export_all(
|
||||||
|
output_dir, formats=[ExportFormat.CSV], compress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ExportFormat.CSV in results
|
||||||
|
assert results[ExportFormat.CSV].exists()
|
||||||
|
|
||||||
|
# Verify CSV content
|
||||||
|
with open(results[ExportFormat.CSV], "r") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
assert len(lines) == 4 # Header + 3 rows
|
||||||
|
|
||||||
|
def test_export_all_formats(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test exporting all formats."""
|
||||||
|
exporter = BackupExporter(str(temp_db))
|
||||||
|
output_dir = tmp_path / "exports"
|
||||||
|
|
||||||
|
# Skip Parquet if pyarrow not available
|
||||||
|
try:
|
||||||
|
import pyarrow # noqa: F401
|
||||||
|
|
||||||
|
formats = [ExportFormat.JSON, ExportFormat.CSV, ExportFormat.PARQUET]
|
||||||
|
except ImportError:
|
||||||
|
formats = [ExportFormat.JSON, ExportFormat.CSV]
|
||||||
|
|
||||||
|
results = exporter.export_all(output_dir, formats=formats, compress=False)
|
||||||
|
|
||||||
|
for fmt in formats:
|
||||||
|
assert fmt in results
|
||||||
|
assert results[fmt].exists()
|
||||||
|
|
||||||
|
def test_incremental_export(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test incremental export."""
|
||||||
|
exporter = BackupExporter(str(temp_db))
|
||||||
|
output_dir = tmp_path / "exports"
|
||||||
|
|
||||||
|
# Export only trades after Jan 2
|
||||||
|
cutoff = datetime(2024, 1, 2, tzinfo=UTC)
|
||||||
|
results = exporter.export_all(
|
||||||
|
output_dir,
|
||||||
|
formats=[ExportFormat.JSON],
|
||||||
|
compress=False,
|
||||||
|
incremental_since=cutoff,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should only have 1 trade (AAPL on Jan 2)
|
||||||
|
import json
|
||||||
|
|
||||||
|
with open(results[ExportFormat.JSON], "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
assert data["record_count"] == 1
|
||||||
|
assert data["trades"][0]["stock_code"] == "AAPL"
|
||||||
|
|
||||||
|
def test_get_export_stats(self, temp_db: Path) -> None:
|
||||||
|
"""Test export statistics."""
|
||||||
|
exporter = BackupExporter(str(temp_db))
|
||||||
|
stats = exporter.get_export_stats()
|
||||||
|
|
||||||
|
assert stats["total_trades"] == 3
|
||||||
|
assert "date_range" in stats
|
||||||
|
assert "db_size_bytes" in stats
|
||||||
|
|
||||||
|
|
||||||
|
class TestBackupScheduler:
|
||||||
|
"""Test BackupScheduler functionality."""
|
||||||
|
|
||||||
|
def test_scheduler_init(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test scheduler initialization."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||||
|
|
||||||
|
assert scheduler.db_path == temp_db
|
||||||
|
assert (backup_dir / "daily").exists()
|
||||||
|
assert (backup_dir / "weekly").exists()
|
||||||
|
assert (backup_dir / "monthly").exists()
|
||||||
|
|
||||||
|
def test_create_daily_backup(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test daily backup creation."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||||
|
|
||||||
|
metadata = scheduler.create_backup(BackupPolicy.DAILY, verify=True)
|
||||||
|
|
||||||
|
assert metadata.policy == BackupPolicy.DAILY
|
||||||
|
assert metadata.file_path.exists()
|
||||||
|
assert metadata.size_bytes > 0
|
||||||
|
assert metadata.checksum is not None
|
||||||
|
|
||||||
|
def test_create_weekly_backup(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test weekly backup creation."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||||
|
|
||||||
|
metadata = scheduler.create_backup(BackupPolicy.WEEKLY, verify=False)
|
||||||
|
|
||||||
|
assert metadata.policy == BackupPolicy.WEEKLY
|
||||||
|
assert metadata.file_path.exists()
|
||||||
|
assert metadata.checksum is None # verify=False
|
||||||
|
|
||||||
|
def test_list_backups(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test listing backups."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||||
|
|
||||||
|
scheduler.create_backup(BackupPolicy.DAILY)
|
||||||
|
scheduler.create_backup(BackupPolicy.WEEKLY)
|
||||||
|
|
||||||
|
backups = scheduler.list_backups()
|
||||||
|
assert len(backups) == 2
|
||||||
|
|
||||||
|
daily_backups = scheduler.list_backups(BackupPolicy.DAILY)
|
||||||
|
assert len(daily_backups) == 1
|
||||||
|
assert daily_backups[0].policy == BackupPolicy.DAILY
|
||||||
|
|
||||||
|
def test_cleanup_old_backups(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test cleanup of old backups."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
scheduler = BackupScheduler(str(temp_db), backup_dir, daily_retention_days=0)
|
||||||
|
|
||||||
|
# Create a backup
|
||||||
|
scheduler.create_backup(BackupPolicy.DAILY)
|
||||||
|
|
||||||
|
# Cleanup should remove it (0 day retention)
|
||||||
|
removed = scheduler.cleanup_old_backups()
|
||||||
|
assert removed[BackupPolicy.DAILY] >= 1
|
||||||
|
|
||||||
|
def test_backup_stats(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test backup statistics."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||||
|
|
||||||
|
scheduler.create_backup(BackupPolicy.DAILY)
|
||||||
|
scheduler.create_backup(BackupPolicy.MONTHLY)
|
||||||
|
|
||||||
|
stats = scheduler.get_backup_stats()
|
||||||
|
|
||||||
|
assert stats["daily"]["count"] == 1
|
||||||
|
assert stats["monthly"]["count"] == 1
|
||||||
|
assert stats["daily"]["total_size_bytes"] > 0
|
||||||
|
|
||||||
|
def test_restore_backup(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test backup restoration."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||||
|
|
||||||
|
# Create backup
|
||||||
|
metadata = scheduler.create_backup(BackupPolicy.DAILY)
|
||||||
|
|
||||||
|
# Modify database
|
||||||
|
conn = sqlite3.connect(str(temp_db))
|
||||||
|
conn.execute("DELETE FROM trades")
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
# Restore
|
||||||
|
scheduler.restore_backup(metadata, verify=True)
|
||||||
|
|
||||||
|
# Verify restoration
|
||||||
|
conn = sqlite3.connect(str(temp_db))
|
||||||
|
cursor = conn.execute("SELECT COUNT(*) FROM trades")
|
||||||
|
count = cursor.fetchone()[0]
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
assert count == 3 # Original 3 trades restored
|
||||||
|
|
||||||
|
|
||||||
|
class TestHealthMonitor:
|
||||||
|
"""Test HealthMonitor functionality."""
|
||||||
|
|
||||||
|
def test_monitor_init(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test monitor initialization."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
monitor = HealthMonitor(str(temp_db), backup_dir)
|
||||||
|
|
||||||
|
assert monitor.db_path == temp_db
|
||||||
|
|
||||||
|
def test_check_database_health_ok(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test database health check (healthy)."""
|
||||||
|
monitor = HealthMonitor(str(temp_db), tmp_path / "backups")
|
||||||
|
result = monitor.check_database_health()
|
||||||
|
|
||||||
|
assert result.status == HealthStatus.HEALTHY
|
||||||
|
assert "healthy" in result.message.lower()
|
||||||
|
assert result.details is not None
|
||||||
|
assert result.details["trade_count"] == 3
|
||||||
|
|
||||||
|
def test_check_database_health_missing(self, tmp_path: Path) -> None:
|
||||||
|
"""Test database health check (missing file)."""
|
||||||
|
non_existent = tmp_path / "missing.db"
|
||||||
|
monitor = HealthMonitor(str(non_existent), tmp_path / "backups")
|
||||||
|
result = monitor.check_database_health()
|
||||||
|
|
||||||
|
assert result.status == HealthStatus.UNHEALTHY
|
||||||
|
assert "not found" in result.message.lower()
|
||||||
|
|
||||||
|
def test_check_disk_space(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test disk space check."""
|
||||||
|
monitor = HealthMonitor(str(temp_db), tmp_path, min_disk_space_gb=0.001)
|
||||||
|
result = monitor.check_disk_space()
|
||||||
|
|
||||||
|
# Should be healthy with minimal requirement
|
||||||
|
assert result.status in [HealthStatus.HEALTHY, HealthStatus.DEGRADED]
|
||||||
|
assert result.details is not None
|
||||||
|
assert "free_gb" in result.details
|
||||||
|
|
||||||
|
def test_check_backup_recency_no_backups(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test backup recency check (no backups)."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
backup_dir.mkdir()
|
||||||
|
(backup_dir / "daily").mkdir()
|
||||||
|
|
||||||
|
monitor = HealthMonitor(str(temp_db), backup_dir)
|
||||||
|
result = monitor.check_backup_recency()
|
||||||
|
|
||||||
|
assert result.status == HealthStatus.UNHEALTHY
|
||||||
|
assert "no" in result.message.lower()
|
||||||
|
|
||||||
|
def test_check_backup_recency_recent(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test backup recency check (recent backup)."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||||
|
scheduler.create_backup(BackupPolicy.DAILY)
|
||||||
|
|
||||||
|
monitor = HealthMonitor(str(temp_db), backup_dir)
|
||||||
|
result = monitor.check_backup_recency()
|
||||||
|
|
||||||
|
assert result.status == HealthStatus.HEALTHY
|
||||||
|
assert "recent" in result.message.lower()
|
||||||
|
|
||||||
|
def test_run_all_checks(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test running all health checks."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||||
|
scheduler.create_backup(BackupPolicy.DAILY)
|
||||||
|
|
||||||
|
monitor = HealthMonitor(str(temp_db), backup_dir, min_disk_space_gb=0.001)
|
||||||
|
checks = monitor.run_all_checks()
|
||||||
|
|
||||||
|
assert "database" in checks
|
||||||
|
assert "disk_space" in checks
|
||||||
|
assert "backup_recency" in checks
|
||||||
|
assert checks["database"].status == HealthStatus.HEALTHY
|
||||||
|
|
||||||
|
def test_get_overall_status(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test overall health status."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||||
|
scheduler.create_backup(BackupPolicy.DAILY)
|
||||||
|
|
||||||
|
monitor = HealthMonitor(str(temp_db), backup_dir, min_disk_space_gb=0.001)
|
||||||
|
status = monitor.get_overall_status()
|
||||||
|
|
||||||
|
assert status in [HealthStatus.HEALTHY, HealthStatus.DEGRADED]
|
||||||
|
|
||||||
|
def test_get_health_report(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test health report generation."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||||
|
scheduler.create_backup(BackupPolicy.DAILY)
|
||||||
|
|
||||||
|
monitor = HealthMonitor(str(temp_db), backup_dir, min_disk_space_gb=0.001)
|
||||||
|
report = monitor.get_health_report()
|
||||||
|
|
||||||
|
assert "overall_status" in report
|
||||||
|
assert "timestamp" in report
|
||||||
|
assert "checks" in report
|
||||||
|
assert len(report["checks"]) == 3
|
||||||
@@ -11,15 +11,15 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import tempfile
|
import tempfile
|
||||||
from datetime import UTC, datetime, timedelta
|
from datetime import UTC, datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from src.config import Settings
|
from src.config import Settings
|
||||||
from src.db import init_db, log_trade
|
from src.db import init_db, log_trade
|
||||||
from src.evolution.ab_test import ABTester, ABTestResult, StrategyPerformance
|
from src.evolution.ab_test import ABTester
|
||||||
from src.evolution.optimizer import EvolutionOptimizer
|
from src.evolution.optimizer import EvolutionOptimizer
|
||||||
from src.evolution.performance_tracker import (
|
from src.evolution.performance_tracker import (
|
||||||
PerformanceDashboard,
|
PerformanceDashboard,
|
||||||
@@ -28,7 +28,6 @@ from src.evolution.performance_tracker import (
|
|||||||
)
|
)
|
||||||
from src.logging.decision_logger import DecisionLogger
|
from src.logging.decision_logger import DecisionLogger
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Fixtures
|
# Fixtures
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|||||||
561
tests/test_main.py
Normal file
561
tests/test_main.py
Normal file
@@ -0,0 +1,561 @@
|
|||||||
|
"""Tests for main trading loop telegram integration."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.core.risk_manager import CircuitBreakerTripped, FatFingerRejected
|
||||||
|
from src.main import trading_cycle
|
||||||
|
|
||||||
|
|
||||||
|
class TestTradingCycleTelegramIntegration:
|
||||||
|
"""Test telegram notifications in trading_cycle function."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_broker(self) -> MagicMock:
|
||||||
|
"""Create mock broker."""
|
||||||
|
broker = MagicMock()
|
||||||
|
broker.get_orderbook = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"output1": {
|
||||||
|
"stck_prpr": "50000",
|
||||||
|
"frgn_ntby_qty": "100",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
broker.get_balance = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"output2": [
|
||||||
|
{
|
||||||
|
"tot_evlu_amt": "10000000",
|
||||||
|
"dnca_tot_amt": "5000000",
|
||||||
|
"pchs_amt_smtl_amt": "5000000",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
broker.send_order = AsyncMock(return_value={"msg1": "OK"})
|
||||||
|
return broker
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_overseas_broker(self) -> MagicMock:
|
||||||
|
"""Create mock overseas broker."""
|
||||||
|
broker = MagicMock()
|
||||||
|
return broker
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_brain(self) -> MagicMock:
|
||||||
|
"""Create mock brain that decides to buy."""
|
||||||
|
brain = MagicMock()
|
||||||
|
decision = MagicMock()
|
||||||
|
decision.action = "BUY"
|
||||||
|
decision.confidence = 85
|
||||||
|
decision.rationale = "Test buy"
|
||||||
|
brain.decide = AsyncMock(return_value=decision)
|
||||||
|
return brain
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_risk(self) -> MagicMock:
|
||||||
|
"""Create mock risk manager."""
|
||||||
|
risk = MagicMock()
|
||||||
|
risk.validate_order = MagicMock()
|
||||||
|
return risk
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_db(self) -> MagicMock:
|
||||||
|
"""Create mock database connection."""
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_decision_logger(self) -> MagicMock:
|
||||||
|
"""Create mock decision logger."""
|
||||||
|
logger = MagicMock()
|
||||||
|
logger.log_decision = MagicMock()
|
||||||
|
return logger
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_context_store(self) -> MagicMock:
|
||||||
|
"""Create mock context store."""
|
||||||
|
store = MagicMock()
|
||||||
|
store.get_latest_timeframe = MagicMock(return_value=None)
|
||||||
|
return store
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_criticality_assessor(self) -> MagicMock:
|
||||||
|
"""Create mock criticality assessor."""
|
||||||
|
assessor = MagicMock()
|
||||||
|
assessor.assess_market_conditions = MagicMock(
|
||||||
|
return_value=MagicMock(value="NORMAL")
|
||||||
|
)
|
||||||
|
assessor.get_timeout = MagicMock(return_value=5.0)
|
||||||
|
return assessor
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_telegram(self) -> MagicMock:
|
||||||
|
"""Create mock telegram client."""
|
||||||
|
telegram = MagicMock()
|
||||||
|
telegram.notify_trade_execution = AsyncMock()
|
||||||
|
telegram.notify_fat_finger = AsyncMock()
|
||||||
|
telegram.notify_circuit_breaker = AsyncMock()
|
||||||
|
return telegram
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_market(self) -> MagicMock:
|
||||||
|
"""Create mock market info."""
|
||||||
|
market = MagicMock()
|
||||||
|
market.name = "Korea"
|
||||||
|
market.code = "KR"
|
||||||
|
market.exchange_code = "KRX"
|
||||||
|
market.is_domestic = True
|
||||||
|
return market
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trade_execution_notification_sent(
|
||||||
|
self,
|
||||||
|
mock_broker: MagicMock,
|
||||||
|
mock_overseas_broker: MagicMock,
|
||||||
|
mock_brain: MagicMock,
|
||||||
|
mock_risk: MagicMock,
|
||||||
|
mock_db: MagicMock,
|
||||||
|
mock_decision_logger: MagicMock,
|
||||||
|
mock_context_store: MagicMock,
|
||||||
|
mock_criticality_assessor: MagicMock,
|
||||||
|
mock_telegram: MagicMock,
|
||||||
|
mock_market: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test telegram notification sent on trade execution."""
|
||||||
|
with patch("src.main.log_trade"):
|
||||||
|
await trading_cycle(
|
||||||
|
broker=mock_broker,
|
||||||
|
overseas_broker=mock_overseas_broker,
|
||||||
|
brain=mock_brain,
|
||||||
|
risk=mock_risk,
|
||||||
|
db_conn=mock_db,
|
||||||
|
decision_logger=mock_decision_logger,
|
||||||
|
context_store=mock_context_store,
|
||||||
|
criticality_assessor=mock_criticality_assessor,
|
||||||
|
telegram=mock_telegram,
|
||||||
|
market=mock_market,
|
||||||
|
stock_code="005930",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify notification was sent
|
||||||
|
mock_telegram.notify_trade_execution.assert_called_once()
|
||||||
|
call_kwargs = mock_telegram.notify_trade_execution.call_args.kwargs
|
||||||
|
assert call_kwargs["stock_code"] == "005930"
|
||||||
|
assert call_kwargs["market"] == "Korea"
|
||||||
|
assert call_kwargs["action"] == "BUY"
|
||||||
|
assert call_kwargs["confidence"] == 85
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trade_execution_notification_failure_doesnt_crash(
|
||||||
|
self,
|
||||||
|
mock_broker: MagicMock,
|
||||||
|
mock_overseas_broker: MagicMock,
|
||||||
|
mock_brain: MagicMock,
|
||||||
|
mock_risk: MagicMock,
|
||||||
|
mock_db: MagicMock,
|
||||||
|
mock_decision_logger: MagicMock,
|
||||||
|
mock_context_store: MagicMock,
|
||||||
|
mock_criticality_assessor: MagicMock,
|
||||||
|
mock_telegram: MagicMock,
|
||||||
|
mock_market: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test trading continues even if notification fails."""
|
||||||
|
# Make notification fail
|
||||||
|
mock_telegram.notify_trade_execution.side_effect = Exception("API error")
|
||||||
|
|
||||||
|
with patch("src.main.log_trade"):
|
||||||
|
# Should not raise exception
|
||||||
|
await trading_cycle(
|
||||||
|
broker=mock_broker,
|
||||||
|
overseas_broker=mock_overseas_broker,
|
||||||
|
brain=mock_brain,
|
||||||
|
risk=mock_risk,
|
||||||
|
db_conn=mock_db,
|
||||||
|
decision_logger=mock_decision_logger,
|
||||||
|
context_store=mock_context_store,
|
||||||
|
criticality_assessor=mock_criticality_assessor,
|
||||||
|
telegram=mock_telegram,
|
||||||
|
market=mock_market,
|
||||||
|
stock_code="005930",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify notification was attempted
|
||||||
|
mock_telegram.notify_trade_execution.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fat_finger_notification_sent(
|
||||||
|
self,
|
||||||
|
mock_broker: MagicMock,
|
||||||
|
mock_overseas_broker: MagicMock,
|
||||||
|
mock_brain: MagicMock,
|
||||||
|
mock_risk: MagicMock,
|
||||||
|
mock_db: MagicMock,
|
||||||
|
mock_decision_logger: MagicMock,
|
||||||
|
mock_context_store: MagicMock,
|
||||||
|
mock_criticality_assessor: MagicMock,
|
||||||
|
mock_telegram: MagicMock,
|
||||||
|
mock_market: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test telegram notification sent on fat finger rejection."""
|
||||||
|
# Make risk manager reject the order
|
||||||
|
mock_risk.validate_order.side_effect = FatFingerRejected(
|
||||||
|
order_amount=2000000,
|
||||||
|
total_cash=5000000,
|
||||||
|
max_pct=30.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("src.main.log_trade"):
|
||||||
|
with pytest.raises(FatFingerRejected):
|
||||||
|
await trading_cycle(
|
||||||
|
broker=mock_broker,
|
||||||
|
overseas_broker=mock_overseas_broker,
|
||||||
|
brain=mock_brain,
|
||||||
|
risk=mock_risk,
|
||||||
|
db_conn=mock_db,
|
||||||
|
decision_logger=mock_decision_logger,
|
||||||
|
context_store=mock_context_store,
|
||||||
|
criticality_assessor=mock_criticality_assessor,
|
||||||
|
telegram=mock_telegram,
|
||||||
|
market=mock_market,
|
||||||
|
stock_code="005930",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify notification was sent
|
||||||
|
mock_telegram.notify_fat_finger.assert_called_once()
|
||||||
|
call_kwargs = mock_telegram.notify_fat_finger.call_args.kwargs
|
||||||
|
assert call_kwargs["stock_code"] == "005930"
|
||||||
|
assert call_kwargs["order_amount"] == 2000000
|
||||||
|
assert call_kwargs["total_cash"] == 5000000
|
||||||
|
assert call_kwargs["max_pct"] == 30.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fat_finger_notification_failure_still_raises(
|
||||||
|
self,
|
||||||
|
mock_broker: MagicMock,
|
||||||
|
mock_overseas_broker: MagicMock,
|
||||||
|
mock_brain: MagicMock,
|
||||||
|
mock_risk: MagicMock,
|
||||||
|
mock_db: MagicMock,
|
||||||
|
mock_decision_logger: MagicMock,
|
||||||
|
mock_context_store: MagicMock,
|
||||||
|
mock_criticality_assessor: MagicMock,
|
||||||
|
mock_telegram: MagicMock,
|
||||||
|
mock_market: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test fat finger exception still raised even if notification fails."""
|
||||||
|
# Make risk manager reject the order
|
||||||
|
mock_risk.validate_order.side_effect = FatFingerRejected(
|
||||||
|
order_amount=2000000,
|
||||||
|
total_cash=5000000,
|
||||||
|
max_pct=30.0,
|
||||||
|
)
|
||||||
|
# Make notification fail
|
||||||
|
mock_telegram.notify_fat_finger.side_effect = Exception("API error")
|
||||||
|
|
||||||
|
with patch("src.main.log_trade"):
|
||||||
|
with pytest.raises(FatFingerRejected):
|
||||||
|
await trading_cycle(
|
||||||
|
broker=mock_broker,
|
||||||
|
overseas_broker=mock_overseas_broker,
|
||||||
|
brain=mock_brain,
|
||||||
|
risk=mock_risk,
|
||||||
|
db_conn=mock_db,
|
||||||
|
decision_logger=mock_decision_logger,
|
||||||
|
context_store=mock_context_store,
|
||||||
|
criticality_assessor=mock_criticality_assessor,
|
||||||
|
telegram=mock_telegram,
|
||||||
|
market=mock_market,
|
||||||
|
stock_code="005930",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify notification was attempted
|
||||||
|
mock_telegram.notify_fat_finger.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_notification_on_hold_decision(
|
||||||
|
self,
|
||||||
|
mock_broker: MagicMock,
|
||||||
|
mock_overseas_broker: MagicMock,
|
||||||
|
mock_brain: MagicMock,
|
||||||
|
mock_risk: MagicMock,
|
||||||
|
mock_db: MagicMock,
|
||||||
|
mock_decision_logger: MagicMock,
|
||||||
|
mock_context_store: MagicMock,
|
||||||
|
mock_criticality_assessor: MagicMock,
|
||||||
|
mock_telegram: MagicMock,
|
||||||
|
mock_market: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test no trade notification sent when decision is HOLD."""
|
||||||
|
# Change brain decision to HOLD
|
||||||
|
decision = MagicMock()
|
||||||
|
decision.action = "HOLD"
|
||||||
|
decision.confidence = 50
|
||||||
|
decision.rationale = "Insufficient signal"
|
||||||
|
mock_brain.decide = AsyncMock(return_value=decision)
|
||||||
|
|
||||||
|
with patch("src.main.log_trade"):
|
||||||
|
await trading_cycle(
|
||||||
|
broker=mock_broker,
|
||||||
|
overseas_broker=mock_overseas_broker,
|
||||||
|
brain=mock_brain,
|
||||||
|
risk=mock_risk,
|
||||||
|
db_conn=mock_db,
|
||||||
|
decision_logger=mock_decision_logger,
|
||||||
|
context_store=mock_context_store,
|
||||||
|
criticality_assessor=mock_criticality_assessor,
|
||||||
|
telegram=mock_telegram,
|
||||||
|
market=mock_market,
|
||||||
|
stock_code="005930",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify no trade notification sent
|
||||||
|
mock_telegram.notify_trade_execution.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunFunctionTelegramIntegration:
|
||||||
|
"""Test telegram notifications in run function."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_circuit_breaker_notification_sent(self) -> None:
|
||||||
|
"""Test telegram notification sent when circuit breaker trips."""
|
||||||
|
mock_telegram = MagicMock()
|
||||||
|
mock_telegram.notify_circuit_breaker = AsyncMock()
|
||||||
|
|
||||||
|
# Simulate circuit breaker exception
|
||||||
|
exc = CircuitBreakerTripped(pnl_pct=-3.5, threshold=-3.0)
|
||||||
|
|
||||||
|
# Test the notification logic
|
||||||
|
try:
|
||||||
|
await mock_telegram.notify_circuit_breaker(
|
||||||
|
pnl_pct=exc.pnl_pct,
|
||||||
|
threshold=exc.threshold,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Ignore errors in notification
|
||||||
|
|
||||||
|
# Verify notification was called
|
||||||
|
mock_telegram.notify_circuit_breaker.assert_called_once_with(
|
||||||
|
pnl_pct=-3.5,
|
||||||
|
threshold=-3.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOverseasBalanceParsing:
|
||||||
|
"""Test overseas balance output2 parsing handles different formats."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_overseas_broker_with_list(self) -> MagicMock:
|
||||||
|
"""Create mock overseas broker returning list format."""
|
||||||
|
broker = MagicMock()
|
||||||
|
broker.get_overseas_price = AsyncMock(
|
||||||
|
return_value={"output": {"last": "150.50"}}
|
||||||
|
)
|
||||||
|
broker.get_overseas_balance = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"output2": [
|
||||||
|
{
|
||||||
|
"frcr_evlu_tota": "10000.00",
|
||||||
|
"frcr_dncl_amt_2": "5000.00",
|
||||||
|
"frcr_buy_amt_smtl": "4500.00",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return broker
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_overseas_broker_with_dict(self) -> MagicMock:
|
||||||
|
"""Create mock overseas broker returning dict format."""
|
||||||
|
broker = MagicMock()
|
||||||
|
broker.get_overseas_price = AsyncMock(
|
||||||
|
return_value={"output": {"last": "150.50"}}
|
||||||
|
)
|
||||||
|
broker.get_overseas_balance = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"output2": {
|
||||||
|
"frcr_evlu_tota": "10000.00",
|
||||||
|
"frcr_dncl_amt_2": "5000.00",
|
||||||
|
"frcr_buy_amt_smtl": "4500.00",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return broker
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_overseas_broker_with_empty(self) -> MagicMock:
|
||||||
|
"""Create mock overseas broker returning empty output2."""
|
||||||
|
broker = MagicMock()
|
||||||
|
broker.get_overseas_price = AsyncMock(
|
||||||
|
return_value={"output": {"last": "150.50"}}
|
||||||
|
)
|
||||||
|
broker.get_overseas_balance = AsyncMock(return_value={"output2": []})
|
||||||
|
return broker
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_domestic_broker(self) -> MagicMock:
|
||||||
|
"""Create minimal mock domestic broker."""
|
||||||
|
broker = MagicMock()
|
||||||
|
return broker
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_overseas_market(self) -> MagicMock:
|
||||||
|
"""Create mock overseas market info."""
|
||||||
|
market = MagicMock()
|
||||||
|
market.name = "NASDAQ"
|
||||||
|
market.code = "US_NASDAQ"
|
||||||
|
market.exchange_code = "NASD"
|
||||||
|
market.is_domestic = False
|
||||||
|
return market
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_brain_hold(self) -> MagicMock:
|
||||||
|
"""Create mock brain that always holds."""
|
||||||
|
brain = MagicMock()
|
||||||
|
decision = MagicMock()
|
||||||
|
decision.action = "HOLD"
|
||||||
|
decision.confidence = 50
|
||||||
|
decision.rationale = "Testing balance parsing"
|
||||||
|
brain.decide = AsyncMock(return_value=decision)
|
||||||
|
return brain
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_risk(self) -> MagicMock:
|
||||||
|
"""Create mock risk manager."""
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_db(self) -> MagicMock:
|
||||||
|
"""Create mock database."""
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_decision_logger(self) -> MagicMock:
|
||||||
|
"""Create mock decision logger."""
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_context_store(self) -> MagicMock:
|
||||||
|
"""Create mock context store."""
|
||||||
|
store = MagicMock()
|
||||||
|
store.get_latest_timeframe = MagicMock(return_value=None)
|
||||||
|
return store
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_criticality_assessor(self) -> MagicMock:
|
||||||
|
"""Create mock criticality assessor."""
|
||||||
|
assessor = MagicMock()
|
||||||
|
assessor.assess_market_conditions = MagicMock(
|
||||||
|
return_value=MagicMock(value="NORMAL")
|
||||||
|
)
|
||||||
|
assessor.get_timeout = MagicMock(return_value=5.0)
|
||||||
|
return assessor
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_telegram(self) -> MagicMock:
|
||||||
|
"""Create mock telegram client."""
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_overseas_balance_list_format(
|
||||||
|
self,
|
||||||
|
mock_domestic_broker: MagicMock,
|
||||||
|
mock_overseas_broker_with_list: MagicMock,
|
||||||
|
mock_brain_hold: MagicMock,
|
||||||
|
mock_risk: MagicMock,
|
||||||
|
mock_db: MagicMock,
|
||||||
|
mock_decision_logger: MagicMock,
|
||||||
|
mock_context_store: MagicMock,
|
||||||
|
mock_criticality_assessor: MagicMock,
|
||||||
|
mock_telegram: MagicMock,
|
||||||
|
mock_overseas_market: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test overseas balance parsing with list format (output2=[{...}])."""
|
||||||
|
with patch("src.main.log_trade"):
|
||||||
|
# Should not raise KeyError
|
||||||
|
await trading_cycle(
|
||||||
|
broker=mock_domestic_broker,
|
||||||
|
overseas_broker=mock_overseas_broker_with_list,
|
||||||
|
brain=mock_brain_hold,
|
||||||
|
risk=mock_risk,
|
||||||
|
db_conn=mock_db,
|
||||||
|
decision_logger=mock_decision_logger,
|
||||||
|
context_store=mock_context_store,
|
||||||
|
criticality_assessor=mock_criticality_assessor,
|
||||||
|
telegram=mock_telegram,
|
||||||
|
market=mock_overseas_market,
|
||||||
|
stock_code="AAPL",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify balance API was called
|
||||||
|
mock_overseas_broker_with_list.get_overseas_balance.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_overseas_balance_dict_format(
|
||||||
|
self,
|
||||||
|
mock_domestic_broker: MagicMock,
|
||||||
|
mock_overseas_broker_with_dict: MagicMock,
|
||||||
|
mock_brain_hold: MagicMock,
|
||||||
|
mock_risk: MagicMock,
|
||||||
|
mock_db: MagicMock,
|
||||||
|
mock_decision_logger: MagicMock,
|
||||||
|
mock_context_store: MagicMock,
|
||||||
|
mock_criticality_assessor: MagicMock,
|
||||||
|
mock_telegram: MagicMock,
|
||||||
|
mock_overseas_market: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test overseas balance parsing with dict format (output2={...})."""
|
||||||
|
with patch("src.main.log_trade"):
|
||||||
|
# Should not raise KeyError
|
||||||
|
await trading_cycle(
|
||||||
|
broker=mock_domestic_broker,
|
||||||
|
overseas_broker=mock_overseas_broker_with_dict,
|
||||||
|
brain=mock_brain_hold,
|
||||||
|
risk=mock_risk,
|
||||||
|
db_conn=mock_db,
|
||||||
|
decision_logger=mock_decision_logger,
|
||||||
|
context_store=mock_context_store,
|
||||||
|
criticality_assessor=mock_criticality_assessor,
|
||||||
|
telegram=mock_telegram,
|
||||||
|
market=mock_overseas_market,
|
||||||
|
stock_code="AAPL",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify balance API was called
|
||||||
|
mock_overseas_broker_with_dict.get_overseas_balance.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_overseas_balance_empty_format(
|
||||||
|
self,
|
||||||
|
mock_domestic_broker: MagicMock,
|
||||||
|
mock_overseas_broker_with_empty: MagicMock,
|
||||||
|
mock_brain_hold: MagicMock,
|
||||||
|
mock_risk: MagicMock,
|
||||||
|
mock_db: MagicMock,
|
||||||
|
mock_decision_logger: MagicMock,
|
||||||
|
mock_context_store: MagicMock,
|
||||||
|
mock_criticality_assessor: MagicMock,
|
||||||
|
mock_telegram: MagicMock,
|
||||||
|
mock_overseas_market: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test overseas balance parsing with empty output2."""
|
||||||
|
with patch("src.main.log_trade"):
|
||||||
|
# Should not raise KeyError, should default to 0
|
||||||
|
await trading_cycle(
|
||||||
|
broker=mock_domestic_broker,
|
||||||
|
overseas_broker=mock_overseas_broker_with_empty,
|
||||||
|
brain=mock_brain_hold,
|
||||||
|
risk=mock_risk,
|
||||||
|
db_conn=mock_db,
|
||||||
|
decision_logger=mock_decision_logger,
|
||||||
|
context_store=mock_context_store,
|
||||||
|
criticality_assessor=mock_criticality_assessor,
|
||||||
|
telegram=mock_telegram,
|
||||||
|
market=mock_overseas_market,
|
||||||
|
stock_code="AAPL",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify balance API was called
|
||||||
|
mock_overseas_broker_with_empty.get_overseas_balance.assert_called_once()
|
||||||
269
tests/test_telegram.py
Normal file
269
tests/test_telegram.py
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
"""Tests for Telegram notification client."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.notifications.telegram_client import NotificationPriority, TelegramClient
|
||||||
|
|
||||||
|
|
||||||
|
class TestTelegramClientInit:
|
||||||
|
"""Test client initialization scenarios."""
|
||||||
|
|
||||||
|
def test_disabled_via_flag(self) -> None:
|
||||||
|
"""Client disabled via enabled=False flag."""
|
||||||
|
client = TelegramClient(
|
||||||
|
bot_token="123:abc", chat_id="456", enabled=False
|
||||||
|
)
|
||||||
|
assert client._enabled is False
|
||||||
|
|
||||||
|
def test_disabled_missing_token(self) -> None:
|
||||||
|
"""Client disabled when bot_token is None."""
|
||||||
|
client = TelegramClient(bot_token=None, chat_id="456", enabled=True)
|
||||||
|
assert client._enabled is False
|
||||||
|
|
||||||
|
def test_disabled_missing_chat_id(self) -> None:
|
||||||
|
"""Client disabled when chat_id is None."""
|
||||||
|
client = TelegramClient(bot_token="123:abc", chat_id=None, enabled=True)
|
||||||
|
assert client._enabled is False
|
||||||
|
|
||||||
|
def test_enabled_with_credentials(self) -> None:
|
||||||
|
"""Client enabled when credentials provided."""
|
||||||
|
client = TelegramClient(
|
||||||
|
bot_token="123:abc", chat_id="456", enabled=True
|
||||||
|
)
|
||||||
|
assert client._enabled is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestNotificationSending:
|
||||||
|
"""Test notification sending behavior."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_send_when_disabled(self) -> None:
|
||||||
|
"""Notifications not sent when client disabled."""
|
||||||
|
client = TelegramClient(enabled=False)
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession.post") as mock_post:
|
||||||
|
await client.notify_trade_execution(
|
||||||
|
stock_code="AAPL",
|
||||||
|
market="United States",
|
||||||
|
action="BUY",
|
||||||
|
quantity=10,
|
||||||
|
price=150.0,
|
||||||
|
confidence=85.0,
|
||||||
|
)
|
||||||
|
mock_post.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trade_execution_format(self) -> None:
|
||||||
|
"""Trade notification has correct format."""
|
||||||
|
client = TelegramClient(
|
||||||
|
bot_token="123:abc", chat_id="456", enabled=True
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_resp = AsyncMock()
|
||||||
|
mock_resp.status = 200
|
||||||
|
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||||
|
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
|
||||||
|
await client.notify_trade_execution(
|
||||||
|
stock_code="TSLA",
|
||||||
|
market="United States",
|
||||||
|
action="SELL",
|
||||||
|
quantity=5,
|
||||||
|
price=250.50,
|
||||||
|
confidence=92.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify API call was made
|
||||||
|
assert mock_post.call_count == 1
|
||||||
|
call_args = mock_post.call_args
|
||||||
|
|
||||||
|
# Check payload structure
|
||||||
|
payload = call_args.kwargs["json"]
|
||||||
|
assert payload["chat_id"] == "456"
|
||||||
|
assert "TSLA" in payload["text"]
|
||||||
|
assert "SELL" in payload["text"]
|
||||||
|
assert "5" in payload["text"]
|
||||||
|
assert "250.50" in payload["text"]
|
||||||
|
assert "92%" in payload["text"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_circuit_breaker_priority(self) -> None:
|
||||||
|
"""Circuit breaker uses CRITICAL priority."""
|
||||||
|
client = TelegramClient(
|
||||||
|
bot_token="123:abc", chat_id="456", enabled=True
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_resp = AsyncMock()
|
||||||
|
mock_resp.status = 200
|
||||||
|
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||||
|
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
|
||||||
|
await client.notify_circuit_breaker(pnl_pct=-3.15, threshold=-3.0)
|
||||||
|
|
||||||
|
payload = mock_post.call_args.kwargs["json"]
|
||||||
|
# CRITICAL priority has 🚨 emoji
|
||||||
|
assert NotificationPriority.CRITICAL.emoji in payload["text"]
|
||||||
|
assert "-3.15%" in payload["text"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_api_error_handling(self) -> None:
|
||||||
|
"""API errors logged but don't crash."""
|
||||||
|
client = TelegramClient(
|
||||||
|
bot_token="123:abc", chat_id="456", enabled=True
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_resp = AsyncMock()
|
||||||
|
mock_resp.status = 400
|
||||||
|
mock_resp.text = AsyncMock(return_value="Bad Request")
|
||||||
|
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||||
|
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession.post", return_value=mock_resp):
|
||||||
|
# Should not raise exception
|
||||||
|
await client.notify_system_start(mode="paper", enabled_markets=["KR"])
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_timeout_handling(self) -> None:
|
||||||
|
"""Timeouts logged but don't crash."""
|
||||||
|
client = TelegramClient(
|
||||||
|
bot_token="123:abc", chat_id="456", enabled=True
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"aiohttp.ClientSession.post",
|
||||||
|
side_effect=aiohttp.ClientError("Connection timeout"),
|
||||||
|
):
|
||||||
|
# Should not raise exception
|
||||||
|
await client.notify_error(
|
||||||
|
error_type="Test Error", error_msg="Test", context="test"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_session_management(self) -> None:
|
||||||
|
"""Session created and reused correctly."""
|
||||||
|
client = TelegramClient(
|
||||||
|
bot_token="123:abc", chat_id="456", enabled=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Session should be None initially
|
||||||
|
assert client._session is None
|
||||||
|
|
||||||
|
mock_resp = AsyncMock()
|
||||||
|
mock_resp.status = 200
|
||||||
|
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||||
|
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession.post", return_value=mock_resp):
|
||||||
|
await client.notify_market_open("Korea")
|
||||||
|
# Session should be created
|
||||||
|
assert client._session is not None
|
||||||
|
|
||||||
|
session1 = client._session
|
||||||
|
await client.notify_market_close("Korea", 1.5)
|
||||||
|
# Same session should be reused
|
||||||
|
assert client._session is session1
|
||||||
|
|
||||||
|
|
||||||
|
class TestRateLimiting:
|
||||||
|
"""Test rate limiter behavior."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rate_limiter_enforced(self) -> None:
|
||||||
|
"""Rate limiter delays rapid requests."""
|
||||||
|
import time
|
||||||
|
|
||||||
|
client = TelegramClient(
|
||||||
|
bot_token="123:abc", chat_id="456", enabled=True, rate_limit=2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_resp = AsyncMock()
|
||||||
|
mock_resp.status = 200
|
||||||
|
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||||
|
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession.post", return_value=mock_resp):
|
||||||
|
start = time.monotonic()
|
||||||
|
|
||||||
|
# Send 3 messages (rate: 2/sec = 0.5s per message)
|
||||||
|
await client.notify_market_open("Korea")
|
||||||
|
await client.notify_market_open("United States")
|
||||||
|
await client.notify_market_open("Japan")
|
||||||
|
|
||||||
|
elapsed = time.monotonic() - start
|
||||||
|
|
||||||
|
# Should take at least 0.4 seconds (3 msgs at 2/sec with some tolerance)
|
||||||
|
assert elapsed >= 0.4
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessagePriorities:
|
||||||
|
"""Test priority-based messaging."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_low_priority_uses_info_emoji(self) -> None:
|
||||||
|
"""LOW priority uses ℹ️ emoji."""
|
||||||
|
client = TelegramClient(
|
||||||
|
bot_token="123:abc", chat_id="456", enabled=True
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_resp = AsyncMock()
|
||||||
|
mock_resp.status = 200
|
||||||
|
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||||
|
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
|
||||||
|
await client.notify_market_open("Korea")
|
||||||
|
|
||||||
|
payload = mock_post.call_args.kwargs["json"]
|
||||||
|
assert NotificationPriority.LOW.emoji in payload["text"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_critical_priority_uses_alarm_emoji(self) -> None:
|
||||||
|
"""CRITICAL priority uses 🚨 emoji."""
|
||||||
|
client = TelegramClient(
|
||||||
|
bot_token="123:abc", chat_id="456", enabled=True
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_resp = AsyncMock()
|
||||||
|
mock_resp.status = 200
|
||||||
|
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||||
|
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
|
||||||
|
await client.notify_system_shutdown("Circuit breaker tripped")
|
||||||
|
|
||||||
|
payload = mock_post.call_args.kwargs["json"]
|
||||||
|
assert NotificationPriority.CRITICAL.emoji in payload["text"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestClientCleanup:
|
||||||
|
"""Test client cleanup behavior."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_closes_session(self) -> None:
|
||||||
|
"""close() closes the HTTP session."""
|
||||||
|
client = TelegramClient(
|
||||||
|
bot_token="123:abc", chat_id="456", enabled=True
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_session.closed = False
|
||||||
|
mock_session.close = AsyncMock()
|
||||||
|
client._session = mock_session
|
||||||
|
|
||||||
|
await client.close()
|
||||||
|
mock_session.close.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_handles_no_session(self) -> None:
|
||||||
|
"""close() handles None session gracefully."""
|
||||||
|
client = TelegramClient(
|
||||||
|
bot_token="123:abc", chat_id="456", enabled=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not raise exception
|
||||||
|
await client.close()
|
||||||
663
tests/test_token_efficiency.py
Normal file
663
tests/test_token_efficiency.py
Normal file
@@ -0,0 +1,663 @@
|
|||||||
|
"""Tests for token efficiency optimization components.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- Prompt compression and optimization
|
||||||
|
- Context selection logic
|
||||||
|
- Summarization
|
||||||
|
- Caching
|
||||||
|
- Token reduction metrics
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.brain.cache import DecisionCache
|
||||||
|
from src.brain.context_selector import ContextSelector, DecisionType
|
||||||
|
from src.brain.gemini_client import TradeDecision
|
||||||
|
from src.brain.prompt_optimizer import PromptOptimizer, TokenMetrics
|
||||||
|
from src.context.layer import ContextLayer
|
||||||
|
from src.context.store import ContextStore
|
||||||
|
from src.context.summarizer import ContextSummarizer, SummaryStats
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Prompt Optimizer Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptOptimizer:
|
||||||
|
"""Tests for PromptOptimizer."""
|
||||||
|
|
||||||
|
def test_estimate_tokens(self):
|
||||||
|
"""Test token estimation."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
# Empty text
|
||||||
|
assert optimizer.estimate_tokens("") == 0
|
||||||
|
|
||||||
|
# Short text (4 chars = 1 token estimate)
|
||||||
|
assert optimizer.estimate_tokens("test") == 1
|
||||||
|
|
||||||
|
# Longer text
|
||||||
|
text = "This is a longer piece of text for testing token estimation."
|
||||||
|
tokens = optimizer.estimate_tokens(text)
|
||||||
|
assert tokens > 0
|
||||||
|
assert tokens == len(text) // 4
|
||||||
|
|
||||||
|
def test_count_tokens(self):
|
||||||
|
"""Test token counting metrics."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
text = "Hello world, this is a test."
|
||||||
|
metrics = optimizer.count_tokens(text)
|
||||||
|
|
||||||
|
assert isinstance(metrics, TokenMetrics)
|
||||||
|
assert metrics.char_count == len(text)
|
||||||
|
assert metrics.word_count == 6
|
||||||
|
assert metrics.estimated_tokens > 0
|
||||||
|
|
||||||
|
def test_compress_json(self):
|
||||||
|
"""Test JSON compression."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"action": "BUY",
|
||||||
|
"confidence": 85,
|
||||||
|
"rationale": "Strong uptrend",
|
||||||
|
}
|
||||||
|
|
||||||
|
compressed = optimizer.compress_json(data)
|
||||||
|
|
||||||
|
# Should have no newlines and minimal whitespace
|
||||||
|
assert "\n" not in compressed
|
||||||
|
# Note: JSON values may contain spaces (e.g., "Strong uptrend")
|
||||||
|
# but there should be no spaces around separators
|
||||||
|
assert ": " not in compressed
|
||||||
|
assert ", " not in compressed
|
||||||
|
|
||||||
|
# Should be valid JSON
|
||||||
|
import json
|
||||||
|
|
||||||
|
parsed = json.loads(compressed)
|
||||||
|
assert parsed == data
|
||||||
|
|
||||||
|
def test_abbreviate_text(self):
|
||||||
|
"""Test text abbreviation."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
text = "The current price is high and volume is increasing."
|
||||||
|
abbreviated = optimizer.abbreviate_text(text)
|
||||||
|
|
||||||
|
# Should contain abbreviations
|
||||||
|
assert "cur" in abbreviated or "P" in abbreviated
|
||||||
|
assert len(abbreviated) <= len(text)
|
||||||
|
|
||||||
|
def test_abbreviate_text_aggressive(self):
|
||||||
|
"""Test aggressive text abbreviation."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
text = "The price is increasing and the volume is high."
|
||||||
|
abbreviated = optimizer.abbreviate_text(text, aggressive=True)
|
||||||
|
|
||||||
|
# Should be shorter
|
||||||
|
assert len(abbreviated) < len(text)
|
||||||
|
|
||||||
|
# Should have removed articles
|
||||||
|
assert "the" not in abbreviated.lower()
|
||||||
|
|
||||||
|
def test_build_compressed_prompt(self):
|
||||||
|
"""Test compressed prompt building."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
market_data = {
|
||||||
|
"stock_code": "005930",
|
||||||
|
"current_price": 75000,
|
||||||
|
"market_name": "Korean stock market",
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt = optimizer.build_compressed_prompt(market_data)
|
||||||
|
|
||||||
|
# Should be much shorter than original
|
||||||
|
assert len(prompt) < 300
|
||||||
|
assert "005930" in prompt
|
||||||
|
assert "75000" in prompt
|
||||||
|
|
||||||
|
def test_build_compressed_prompt_no_instructions(self):
|
||||||
|
"""Test compressed prompt without instructions."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
market_data = {
|
||||||
|
"stock_code": "AAPL",
|
||||||
|
"current_price": 150.5,
|
||||||
|
"market_name": "United States",
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt = optimizer.build_compressed_prompt(market_data, include_instructions=False)
|
||||||
|
|
||||||
|
# Should be very short (data only)
|
||||||
|
assert len(prompt) < 100
|
||||||
|
assert "AAPL" in prompt
|
||||||
|
|
||||||
|
def test_truncate_context(self):
|
||||||
|
"""Test context truncation."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
context = {
|
||||||
|
"price": 100.5,
|
||||||
|
"volume": 1000000,
|
||||||
|
"sentiment": 0.8,
|
||||||
|
"extra_data": "Some long text that should be truncated",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Truncate to small budget
|
||||||
|
truncated = optimizer.truncate_context(context, max_tokens=10)
|
||||||
|
|
||||||
|
# Should have fewer keys
|
||||||
|
assert len(truncated) <= len(context)
|
||||||
|
|
||||||
|
def test_truncate_context_with_priority(self):
|
||||||
|
"""Test context truncation with priority keys."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
context = {
|
||||||
|
"price": 100.5,
|
||||||
|
"volume": 1000000,
|
||||||
|
"sentiment": 0.8,
|
||||||
|
"extra_data": "Some data",
|
||||||
|
}
|
||||||
|
|
||||||
|
priority_keys = ["price", "sentiment"]
|
||||||
|
truncated = optimizer.truncate_context(context, max_tokens=20, priority_keys=priority_keys)
|
||||||
|
|
||||||
|
# Priority keys should be included
|
||||||
|
assert "price" in truncated
|
||||||
|
assert "sentiment" in truncated
|
||||||
|
|
||||||
|
def test_calculate_compression_ratio(self):
|
||||||
|
"""Test compression ratio calculation."""
|
||||||
|
optimizer = PromptOptimizer()
|
||||||
|
|
||||||
|
original = "This is a very long piece of text that should be compressed significantly."
|
||||||
|
compressed = "Short text"
|
||||||
|
|
||||||
|
ratio = optimizer.calculate_compression_ratio(original, compressed)
|
||||||
|
|
||||||
|
# Ratio should be > 1 (original is longer)
|
||||||
|
assert ratio > 1.0
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Context Selector Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextSelector:
|
||||||
|
"""Tests for ContextSelector."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def store(self):
|
||||||
|
"""Create in-memory ContextStore."""
|
||||||
|
conn = sqlite3.connect(":memory:")
|
||||||
|
# Create tables
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE context_metadata (
|
||||||
|
layer TEXT PRIMARY KEY,
|
||||||
|
description TEXT,
|
||||||
|
retention_days INTEGER,
|
||||||
|
aggregation_source TEXT
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE contexts (
|
||||||
|
layer TEXT,
|
||||||
|
timeframe TEXT,
|
||||||
|
key TEXT,
|
||||||
|
value TEXT,
|
||||||
|
created_at TEXT,
|
||||||
|
updated_at TEXT,
|
||||||
|
PRIMARY KEY (layer, timeframe, key)
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
return ContextStore(conn)
|
||||||
|
|
||||||
|
def test_select_layers_normal(self, store):
|
||||||
|
"""Test layer selection for normal decisions."""
|
||||||
|
selector = ContextSelector(store)
|
||||||
|
|
||||||
|
layers = selector.select_layers(DecisionType.NORMAL)
|
||||||
|
|
||||||
|
# Should only select L7 (real-time)
|
||||||
|
assert layers == [ContextLayer.L7_REALTIME]
|
||||||
|
|
||||||
|
def test_select_layers_strategic(self, store):
|
||||||
|
"""Test layer selection for strategic decisions."""
|
||||||
|
selector = ContextSelector(store)
|
||||||
|
|
||||||
|
layers = selector.select_layers(DecisionType.STRATEGIC)
|
||||||
|
|
||||||
|
# Should select L7 + L6 + L5
|
||||||
|
assert ContextLayer.L7_REALTIME in layers
|
||||||
|
assert ContextLayer.L6_DAILY in layers
|
||||||
|
assert ContextLayer.L5_WEEKLY in layers
|
||||||
|
assert len(layers) == 3
|
||||||
|
|
||||||
|
def test_select_layers_major_event(self, store):
|
||||||
|
"""Test layer selection for major events."""
|
||||||
|
selector = ContextSelector(store)
|
||||||
|
|
||||||
|
layers = selector.select_layers(DecisionType.MAJOR_EVENT)
|
||||||
|
|
||||||
|
# Should select all layers
|
||||||
|
assert len(layers) == 7
|
||||||
|
assert ContextLayer.L1_LEGACY in layers
|
||||||
|
assert ContextLayer.L7_REALTIME in layers
|
||||||
|
|
||||||
|
def test_score_layer_relevance(self, store):
|
||||||
|
"""Test layer relevance scoring."""
|
||||||
|
selector = ContextSelector(store)
|
||||||
|
|
||||||
|
# Add some data first so scores aren't penalized
|
||||||
|
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
|
||||||
|
store.set_context(ContextLayer.L1_LEGACY, "legacy", "lesson", "test")
|
||||||
|
|
||||||
|
# L7 should have high score for normal decisions
|
||||||
|
score = selector.score_layer_relevance(ContextLayer.L7_REALTIME, DecisionType.NORMAL)
|
||||||
|
assert score == 1.0
|
||||||
|
|
||||||
|
# L1 should have low score for normal decisions
|
||||||
|
score = selector.score_layer_relevance(ContextLayer.L1_LEGACY, DecisionType.NORMAL)
|
||||||
|
assert score == 0.0
|
||||||
|
|
||||||
|
# L1 should have high score for major events
|
||||||
|
score = selector.score_layer_relevance(ContextLayer.L1_LEGACY, DecisionType.MAJOR_EVENT)
|
||||||
|
assert score == 1.0
|
||||||
|
|
||||||
|
def test_select_with_scoring(self, store):
|
||||||
|
"""Test selection with relevance scoring."""
|
||||||
|
selector = ContextSelector(store)
|
||||||
|
|
||||||
|
# Add data so layers aren't penalized
|
||||||
|
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
|
||||||
|
|
||||||
|
selection = selector.select_with_scoring(DecisionType.NORMAL, min_score=0.5)
|
||||||
|
|
||||||
|
# Should only select high-relevance layers
|
||||||
|
assert len(selection.layers) >= 1
|
||||||
|
assert ContextLayer.L7_REALTIME in selection.layers
|
||||||
|
assert all(selection.relevance_scores[layer] >= 0.5 for layer in selection.layers)
|
||||||
|
|
||||||
|
def test_get_context_data(self, store):
|
||||||
|
"""Test context data retrieval."""
|
||||||
|
selector = ContextSelector(store)
|
||||||
|
|
||||||
|
# Add some test data
|
||||||
|
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
|
||||||
|
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "volume", 1000000)
|
||||||
|
|
||||||
|
context_data = selector.get_context_data([ContextLayer.L7_REALTIME])
|
||||||
|
|
||||||
|
# Should retrieve data
|
||||||
|
assert "L7_REALTIME" in context_data
|
||||||
|
assert "price" in context_data["L7_REALTIME"]
|
||||||
|
assert context_data["L7_REALTIME"]["price"] == 100.5
|
||||||
|
|
||||||
|
def test_estimate_context_tokens(self, store):
|
||||||
|
"""Test context token estimation."""
|
||||||
|
selector = ContextSelector(store)
|
||||||
|
|
||||||
|
context_data = {
|
||||||
|
"L7_REALTIME": {"price": 100.5, "volume": 1000000},
|
||||||
|
"L6_DAILY": {"avg_price": 99.8, "avg_volume": 950000},
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens = selector.estimate_context_tokens(context_data)
|
||||||
|
|
||||||
|
# Should estimate tokens
|
||||||
|
assert tokens > 0
|
||||||
|
|
||||||
|
def test_optimize_context_for_budget(self, store):
|
||||||
|
"""Test context optimization for token budget."""
|
||||||
|
selector = ContextSelector(store)
|
||||||
|
|
||||||
|
# Add test data
|
||||||
|
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
|
||||||
|
|
||||||
|
# Get optimized context within budget
|
||||||
|
context = selector.optimize_context_for_budget(DecisionType.NORMAL, max_tokens=50)
|
||||||
|
|
||||||
|
# Should return data within budget
|
||||||
|
tokens = selector.estimate_context_tokens(context)
|
||||||
|
assert tokens <= 50
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Context Summarizer Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextSummarizer:
|
||||||
|
"""Tests for ContextSummarizer."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def store(self):
|
||||||
|
"""Create in-memory ContextStore."""
|
||||||
|
conn = sqlite3.connect(":memory:")
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE context_metadata (
|
||||||
|
layer TEXT PRIMARY KEY,
|
||||||
|
description TEXT,
|
||||||
|
retention_days INTEGER,
|
||||||
|
aggregation_source TEXT
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE contexts (
|
||||||
|
layer TEXT,
|
||||||
|
timeframe TEXT,
|
||||||
|
key TEXT,
|
||||||
|
value TEXT,
|
||||||
|
created_at TEXT,
|
||||||
|
updated_at TEXT,
|
||||||
|
PRIMARY KEY (layer, timeframe, key)
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
return ContextStore(conn)
|
||||||
|
|
||||||
|
def test_summarize_numeric_values(self, store):
|
||||||
|
"""Test numeric value summarization."""
|
||||||
|
summarizer = ContextSummarizer(store)
|
||||||
|
|
||||||
|
values = [10.0, 20.0, 30.0, 40.0, 50.0]
|
||||||
|
stats = summarizer.summarize_numeric_values(values)
|
||||||
|
|
||||||
|
assert isinstance(stats, SummaryStats)
|
||||||
|
assert stats.count == 5
|
||||||
|
assert stats.mean == 30.0
|
||||||
|
assert stats.min == 10.0
|
||||||
|
assert stats.max == 50.0
|
||||||
|
assert stats.std is not None
|
||||||
|
|
||||||
|
def test_summarize_numeric_values_trend(self, store):
|
||||||
|
"""Test trend detection in numeric values."""
|
||||||
|
summarizer = ContextSummarizer(store)
|
||||||
|
|
||||||
|
# Uptrend
|
||||||
|
values_up = [10.0, 15.0, 20.0, 25.0, 30.0, 35.0]
|
||||||
|
stats_up = summarizer.summarize_numeric_values(values_up)
|
||||||
|
assert stats_up.trend == "up"
|
||||||
|
|
||||||
|
# Downtrend
|
||||||
|
values_down = [35.0, 30.0, 25.0, 20.0, 15.0, 10.0]
|
||||||
|
stats_down = summarizer.summarize_numeric_values(values_down)
|
||||||
|
assert stats_down.trend == "down"
|
||||||
|
|
||||||
|
# Flat
|
||||||
|
values_flat = [20.0, 20.1, 19.9, 20.0, 20.1, 19.9]
|
||||||
|
stats_flat = summarizer.summarize_numeric_values(values_flat)
|
||||||
|
assert stats_flat.trend == "flat"
|
||||||
|
|
||||||
|
def test_summarize_layer(self, store):
|
||||||
|
"""Test layer summarization."""
|
||||||
|
summarizer = ContextSummarizer(store)
|
||||||
|
|
||||||
|
# Add test data
|
||||||
|
store.set_context(ContextLayer.L6_DAILY, "2026-02-04", "price", 100.5)
|
||||||
|
store.set_context(ContextLayer.L6_DAILY, "2026-02-04", "volume", 1000000)
|
||||||
|
|
||||||
|
summary = summarizer.summarize_layer(ContextLayer.L6_DAILY)
|
||||||
|
|
||||||
|
# Should have summary
|
||||||
|
assert "total_entries" in summary
|
||||||
|
assert summary["total_entries"] > 0
|
||||||
|
|
||||||
|
def test_create_compact_summary(self, store):
|
||||||
|
"""Test compact summary creation."""
|
||||||
|
summarizer = ContextSummarizer(store)
|
||||||
|
|
||||||
|
# Add test data
|
||||||
|
store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
|
||||||
|
|
||||||
|
layers = [ContextLayer.L7_REALTIME, ContextLayer.L6_DAILY]
|
||||||
|
summary = summarizer.create_compact_summary(layers, top_n_metrics=3)
|
||||||
|
|
||||||
|
# Should have summaries for layers
|
||||||
|
assert "L7_REALTIME" in summary
|
||||||
|
|
||||||
|
def test_format_summary_for_prompt(self, store):
|
||||||
|
"""Test summary formatting for prompt."""
|
||||||
|
summarizer = ContextSummarizer(store)
|
||||||
|
|
||||||
|
summary = {
|
||||||
|
"L7_REALTIME": {
|
||||||
|
"price": {"avg": 100.5, "trend": "up"},
|
||||||
|
"volume": {"avg": 1000000, "trend": "flat"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
formatted = summarizer.format_summary_for_prompt(summary)
|
||||||
|
|
||||||
|
# Should be formatted string
|
||||||
|
assert isinstance(formatted, str)
|
||||||
|
assert "L7_REALTIME" in formatted
|
||||||
|
assert "100.5" in formatted or "100.50" in formatted
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Decision Cache Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestDecisionCache:
|
||||||
|
"""Tests for DecisionCache."""
|
||||||
|
|
||||||
|
def test_cache_init(self):
|
||||||
|
"""Test cache initialization."""
|
||||||
|
cache = DecisionCache(ttl_seconds=60, max_size=100)
|
||||||
|
|
||||||
|
assert cache.ttl_seconds == 60
|
||||||
|
assert cache.max_size == 100
|
||||||
|
|
||||||
|
def test_cache_miss(self):
|
||||||
|
"""Test cache miss."""
|
||||||
|
cache = DecisionCache()
|
||||||
|
|
||||||
|
market_data = {"stock_code": "005930", "current_price": 75000}
|
||||||
|
|
||||||
|
decision = cache.get(market_data)
|
||||||
|
|
||||||
|
# Should be None (cache miss)
|
||||||
|
assert decision is None
|
||||||
|
|
||||||
|
metrics = cache.get_metrics()
|
||||||
|
assert metrics.cache_misses == 1
|
||||||
|
assert metrics.cache_hits == 0
|
||||||
|
|
||||||
|
def test_cache_hit(self):
|
||||||
|
"""Test cache hit."""
|
||||||
|
cache = DecisionCache()
|
||||||
|
|
||||||
|
market_data = {"stock_code": "005930", "current_price": 75000}
|
||||||
|
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
|
||||||
|
|
||||||
|
# Set cache
|
||||||
|
cache.set(market_data, decision)
|
||||||
|
|
||||||
|
# Get from cache
|
||||||
|
cached = cache.get(market_data)
|
||||||
|
|
||||||
|
assert cached is not None
|
||||||
|
assert cached.action == "HOLD"
|
||||||
|
assert cached.confidence == 50
|
||||||
|
|
||||||
|
metrics = cache.get_metrics()
|
||||||
|
assert metrics.cache_hits == 1
|
||||||
|
|
||||||
|
def test_cache_ttl_expiration(self):
|
||||||
|
"""Test cache TTL expiration."""
|
||||||
|
cache = DecisionCache(ttl_seconds=1) # 1 second TTL
|
||||||
|
|
||||||
|
market_data = {"stock_code": "005930", "current_price": 75000}
|
||||||
|
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
|
||||||
|
|
||||||
|
# Set cache
|
||||||
|
cache.set(market_data, decision)
|
||||||
|
|
||||||
|
# Should hit immediately
|
||||||
|
cached = cache.get(market_data)
|
||||||
|
assert cached is not None
|
||||||
|
|
||||||
|
# Wait for expiration
|
||||||
|
time.sleep(1.1)
|
||||||
|
|
||||||
|
# Should miss after expiration
|
||||||
|
cached = cache.get(market_data)
|
||||||
|
assert cached is None
|
||||||
|
|
||||||
|
def test_cache_max_size(self):
|
||||||
|
"""Test cache max size eviction."""
|
||||||
|
cache = DecisionCache(max_size=2)
|
||||||
|
|
||||||
|
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
|
||||||
|
|
||||||
|
# Add 3 entries (exceeds max_size)
|
||||||
|
for i in range(3):
|
||||||
|
market_data = {"stock_code": f"00{i}", "current_price": 1000 * i}
|
||||||
|
cache.set(market_data, decision)
|
||||||
|
|
||||||
|
metrics = cache.get_metrics()
|
||||||
|
|
||||||
|
# Should have evicted 1 entry
|
||||||
|
assert metrics.total_entries == 2
|
||||||
|
assert metrics.evictions == 1
|
||||||
|
|
||||||
|
def test_invalidate_all(self):
|
||||||
|
"""Test invalidate all cache entries."""
|
||||||
|
cache = DecisionCache()
|
||||||
|
|
||||||
|
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
|
||||||
|
|
||||||
|
# Add entries
|
||||||
|
for i in range(3):
|
||||||
|
market_data = {"stock_code": f"00{i}", "current_price": 1000}
|
||||||
|
cache.set(market_data, decision)
|
||||||
|
|
||||||
|
# Invalidate all
|
||||||
|
count = cache.invalidate()
|
||||||
|
|
||||||
|
assert count == 3
|
||||||
|
|
||||||
|
metrics = cache.get_metrics()
|
||||||
|
assert metrics.total_entries == 0
|
||||||
|
|
||||||
|
def test_invalidate_by_stock(self):
|
||||||
|
"""Test invalidate cache by stock code."""
|
||||||
|
cache = DecisionCache()
|
||||||
|
|
||||||
|
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
|
||||||
|
|
||||||
|
# Add entries for different stocks
|
||||||
|
cache.set({"stock_code": "005930", "current_price": 75000}, decision)
|
||||||
|
cache.set({"stock_code": "000660", "current_price": 50000}, decision)
|
||||||
|
|
||||||
|
# Invalidate specific stock
|
||||||
|
count = cache.invalidate("005930")
|
||||||
|
|
||||||
|
assert count >= 1
|
||||||
|
|
||||||
|
# Other stock should still be cached
|
||||||
|
cached = cache.get({"stock_code": "000660", "current_price": 50000})
|
||||||
|
assert cached is not None
|
||||||
|
|
||||||
|
def test_cleanup_expired(self):
|
||||||
|
"""Test cleanup of expired entries."""
|
||||||
|
cache = DecisionCache(ttl_seconds=1)
|
||||||
|
|
||||||
|
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
|
||||||
|
|
||||||
|
# Add entry
|
||||||
|
cache.set({"stock_code": "005930", "current_price": 75000}, decision)
|
||||||
|
|
||||||
|
# Wait for expiration
|
||||||
|
time.sleep(1.1)
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
count = cache.cleanup_expired()
|
||||||
|
|
||||||
|
assert count == 1
|
||||||
|
|
||||||
|
metrics = cache.get_metrics()
|
||||||
|
assert metrics.total_entries == 0
|
||||||
|
|
||||||
|
def test_should_cache_decision(self):
|
||||||
|
"""Test decision caching criteria."""
|
||||||
|
cache = DecisionCache()
|
||||||
|
|
||||||
|
# HOLD decisions should be cached
|
||||||
|
hold_decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
|
||||||
|
assert cache.should_cache_decision(hold_decision) is True
|
||||||
|
|
||||||
|
# High confidence BUY should be cached
|
||||||
|
buy_decision = TradeDecision(action="BUY", confidence=95, rationale="Test")
|
||||||
|
assert cache.should_cache_decision(buy_decision) is True
|
||||||
|
|
||||||
|
# Low confidence BUY should not be cached
|
||||||
|
low_conf_buy = TradeDecision(action="BUY", confidence=60, rationale="Test")
|
||||||
|
assert cache.should_cache_decision(low_conf_buy) is False
|
||||||
|
|
||||||
|
def test_cache_hit_rate(self):
|
||||||
|
"""Test cache hit rate calculation."""
|
||||||
|
cache = DecisionCache()
|
||||||
|
|
||||||
|
decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
|
||||||
|
market_data = {"stock_code": "005930", "current_price": 75000}
|
||||||
|
|
||||||
|
# First request (miss)
|
||||||
|
cache.get(market_data)
|
||||||
|
|
||||||
|
# Set cache
|
||||||
|
cache.set(market_data, decision)
|
||||||
|
|
||||||
|
# Second request (hit)
|
||||||
|
cache.get(market_data)
|
||||||
|
|
||||||
|
# Third request (hit)
|
||||||
|
cache.get(market_data)
|
||||||
|
|
||||||
|
metrics = cache.get_metrics()
|
||||||
|
|
||||||
|
# 1 miss, 2 hits out of 3 requests
|
||||||
|
assert metrics.total_requests == 3
|
||||||
|
assert metrics.cache_hits == 2
|
||||||
|
assert metrics.cache_misses == 1
|
||||||
|
assert metrics.hit_rate == pytest.approx(2 / 3)
|
||||||
|
|
||||||
|
def test_reset_metrics(self):
|
||||||
|
"""Test metrics reset."""
|
||||||
|
cache = DecisionCache()
|
||||||
|
|
||||||
|
market_data = {"stock_code": "005930", "current_price": 75000}
|
||||||
|
|
||||||
|
# Generate some activity
|
||||||
|
cache.get(market_data)
|
||||||
|
cache.get(market_data)
|
||||||
|
|
||||||
|
# Reset
|
||||||
|
cache.reset_metrics()
|
||||||
|
|
||||||
|
metrics = cache.get_metrics()
|
||||||
|
assert metrics.total_requests == 0
|
||||||
|
assert metrics.cache_hits == 0
|
||||||
|
assert metrics.cache_misses == 0
|
||||||
Reference in New Issue
Block a user