Compare commits
6 Commits
feature/is
...
feature/is
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8c05448843 | ||
|
|
87556b145e | ||
| 645c761238 | |||
|
|
033d5fcadd | ||
| 128324427f | |||
|
|
62fd4ff5e1 |
@@ -21,3 +21,8 @@ RATE_LIMIT_RPS=10.0
|
|||||||
|
|
||||||
# Trading Mode (paper / live)
|
# Trading Mode (paper / live)
|
||||||
MODE=paper
|
MODE=paper
|
||||||
|
|
||||||
|
# External Data APIs (optional — for enhanced decision-making)
|
||||||
|
# NEWS_API_KEY=your_news_api_key_here
|
||||||
|
# NEWS_API_PROVIDER=alphavantage
|
||||||
|
# MARKET_DATA_API_KEY=your_market_data_key_here
|
||||||
|
|||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -174,4 +174,7 @@ cython_debug/
|
|||||||
# PyPI configuration file
|
# PyPI configuration file
|
||||||
.pypirc
|
.pypirc
|
||||||
|
|
||||||
|
# Data files (trade logs, databases)
|
||||||
|
# But NOT src/data/ which contains source code
|
||||||
data/
|
data/
|
||||||
|
!src/data/
|
||||||
|
|||||||
348
docs/disaster_recovery.md
Normal file
348
docs/disaster_recovery.md
Normal file
@@ -0,0 +1,348 @@
|
|||||||
|
# Disaster Recovery Guide
|
||||||
|
|
||||||
|
Complete guide for backing up and restoring The Ouroboros trading system.
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
- [Backup Strategy](#backup-strategy)
|
||||||
|
- [Creating Backups](#creating-backups)
|
||||||
|
- [Restoring from Backup](#restoring-from-backup)
|
||||||
|
- [Health Monitoring](#health-monitoring)
|
||||||
|
- [Export Formats](#export-formats)
|
||||||
|
- [RTO/RPO](#rtorpo)
|
||||||
|
- [Testing Recovery](#testing-recovery)
|
||||||
|
|
||||||
|
## Backup Strategy
|
||||||
|
|
||||||
|
The system implements a 3-tier backup retention policy:
|
||||||
|
|
||||||
|
| Policy | Frequency | Retention | Purpose |
|
||||||
|
|--------|-----------|-----------|---------|
|
||||||
|
| **Daily** | Every day | 30 days | Quick recovery from recent issues |
|
||||||
|
| **Weekly** | Sunday | 1 year | Medium-term historical analysis |
|
||||||
|
| **Monthly** | 1st of month | Forever | Long-term archival |
|
||||||
|
|
||||||
|
### Storage Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
data/backups/
|
||||||
|
├── daily/ # Last 30 days
|
||||||
|
├── weekly/ # Last 52 weeks
|
||||||
|
└── monthly/ # Forever (cold storage)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Creating Backups
|
||||||
|
|
||||||
|
### Automated Backups (Recommended)
|
||||||
|
|
||||||
|
Set up a cron job to run daily:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Edit crontab
|
||||||
|
crontab -e
|
||||||
|
|
||||||
|
# Run backup at 2 AM every day
|
||||||
|
0 2 * * * cd /path/to/The-Ouroboros && ./scripts/backup.sh >> logs/backup.log 2>&1
|
||||||
|
```
|
||||||
|
|
||||||
|
### Manual Backups
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run backup script
|
||||||
|
./scripts/backup.sh
|
||||||
|
|
||||||
|
# Or use Python directly
|
||||||
|
python3 -c "
|
||||||
|
from pathlib import Path
|
||||||
|
from src.backup.scheduler import BackupScheduler, BackupPolicy
|
||||||
|
|
||||||
|
scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))
|
||||||
|
metadata = scheduler.create_backup(BackupPolicy.DAILY, verify=True)
|
||||||
|
print(f'Backup created: {metadata.file_path}')
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Export to Other Formats
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -c "
|
||||||
|
from pathlib import Path
|
||||||
|
from src.backup.exporter import BackupExporter, ExportFormat
|
||||||
|
|
||||||
|
exporter = BackupExporter('data/trade_logs.db')
|
||||||
|
results = exporter.export_all(
|
||||||
|
Path('exports'),
|
||||||
|
formats=[ExportFormat.JSON, ExportFormat.CSV],
|
||||||
|
compress=True
|
||||||
|
)
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Restoring from Backup
|
||||||
|
|
||||||
|
### Interactive Restoration
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./scripts/restore.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
The script will:
|
||||||
|
1. List available backups
|
||||||
|
2. Ask you to select one
|
||||||
|
3. Create a safety backup of current database
|
||||||
|
4. Restore the selected backup
|
||||||
|
5. Verify database integrity
|
||||||
|
|
||||||
|
### Manual Restoration
|
||||||
|
|
||||||
|
```python
|
||||||
|
from pathlib import Path
|
||||||
|
from src.backup.scheduler import BackupScheduler
|
||||||
|
|
||||||
|
scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))
|
||||||
|
|
||||||
|
# List backups
|
||||||
|
backups = scheduler.list_backups()
|
||||||
|
for backup in backups:
|
||||||
|
print(f"{backup.timestamp}: {backup.file_path}")
|
||||||
|
|
||||||
|
# Restore specific backup
|
||||||
|
scheduler.restore_backup(backups[0], verify=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Health Monitoring
|
||||||
|
|
||||||
|
### Check System Health
|
||||||
|
|
||||||
|
```python
|
||||||
|
from pathlib import Path
|
||||||
|
from src.backup.health_monitor import HealthMonitor
|
||||||
|
|
||||||
|
monitor = HealthMonitor('data/trade_logs.db', Path('data/backups'))
|
||||||
|
|
||||||
|
# Run all checks
|
||||||
|
report = monitor.get_health_report()
|
||||||
|
print(f"Overall status: {report['overall_status']}")
|
||||||
|
|
||||||
|
# Individual checks
|
||||||
|
checks = monitor.run_all_checks()
|
||||||
|
for name, result in checks.items():
|
||||||
|
print(f"{name}: {result.status.value} - {result.message}")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Health Checks
|
||||||
|
|
||||||
|
The system monitors:
|
||||||
|
|
||||||
|
- **Database Health**: Accessibility, integrity, size
|
||||||
|
- **Disk Space**: Available storage (alerts if < 10 GB)
|
||||||
|
- **Backup Recency**: Ensures backups are < 25 hours old
|
||||||
|
|
||||||
|
### Health Status Levels
|
||||||
|
|
||||||
|
- **HEALTHY**: All systems operational
|
||||||
|
- **DEGRADED**: Warning condition (e.g., low disk space)
|
||||||
|
- **UNHEALTHY**: Critical issue (e.g., database corrupted, no backups)
|
||||||
|
|
||||||
|
## Export Formats
|
||||||
|
|
||||||
|
### JSON (Human-Readable)
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"export_timestamp": "2024-01-15T10:30:00Z",
|
||||||
|
"record_count": 150,
|
||||||
|
"trades": [
|
||||||
|
{
|
||||||
|
"timestamp": "2024-01-15T09:00:00Z",
|
||||||
|
"stock_code": "005930",
|
||||||
|
"action": "BUY",
|
||||||
|
"quantity": 10,
|
||||||
|
"price": 70000.0,
|
||||||
|
"confidence": 85,
|
||||||
|
"rationale": "Strong momentum",
|
||||||
|
"pnl": 0.0
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### CSV (Analysis Tools)
|
||||||
|
|
||||||
|
Compatible with Excel, pandas, R:
|
||||||
|
|
||||||
|
```csv
|
||||||
|
timestamp,stock_code,action,quantity,price,confidence,rationale,pnl
|
||||||
|
2024-01-15T09:00:00Z,005930,BUY,10,70000.0,85,Strong momentum,0.0
|
||||||
|
```
|
||||||
|
|
||||||
|
### Parquet (Big Data)
|
||||||
|
|
||||||
|
Columnar format for Spark, DuckDB:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import pandas as pd
|
||||||
|
df = pd.read_parquet('exports/trades_20240115.parquet')
|
||||||
|
```
|
||||||
|
|
||||||
|
## RTO/RPO
|
||||||
|
|
||||||
|
### Recovery Time Objective (RTO)
|
||||||
|
|
||||||
|
**Target: < 5 minutes**
|
||||||
|
|
||||||
|
Time to restore trading operations:
|
||||||
|
1. Identify backup to restore (1 min)
|
||||||
|
2. Run restore script (2 min)
|
||||||
|
3. Verify database integrity (1 min)
|
||||||
|
4. Restart trading system (1 min)
|
||||||
|
|
||||||
|
### Recovery Point Objective (RPO)
|
||||||
|
|
||||||
|
**Target: < 24 hours**
|
||||||
|
|
||||||
|
Maximum acceptable data loss:
|
||||||
|
- Daily backups ensure ≤ 24-hour data loss
|
||||||
|
- For critical periods, run backups more frequently
|
||||||
|
|
||||||
|
## Testing Recovery
|
||||||
|
|
||||||
|
### Quarterly Recovery Test
|
||||||
|
|
||||||
|
Perform full disaster recovery test every quarter:
|
||||||
|
|
||||||
|
1. **Create test backup**
|
||||||
|
```bash
|
||||||
|
./scripts/backup.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Simulate disaster** (use test database)
|
||||||
|
```bash
|
||||||
|
cp data/trade_logs.db data/trade_logs_test.db
|
||||||
|
rm data/trade_logs_test.db # Simulate data loss
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Restore from backup**
|
||||||
|
```bash
|
||||||
|
DB_PATH=data/trade_logs_test.db ./scripts/restore.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **Verify data integrity**
|
||||||
|
```python
|
||||||
|
import sqlite3
|
||||||
|
conn = sqlite3.connect('data/trade_logs_test.db')
|
||||||
|
cursor = conn.execute('SELECT COUNT(*) FROM trades')
|
||||||
|
print(f"Restored {cursor.fetchone()[0]} trades")
|
||||||
|
```
|
||||||
|
|
||||||
|
5. **Document results** in `logs/recovery_test_YYYYMMDD.md`
|
||||||
|
|
||||||
|
### Backup Verification
|
||||||
|
|
||||||
|
Always verify backups after creation:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from pathlib import Path
|
||||||
|
from src.backup.scheduler import BackupScheduler
|
||||||
|
|
||||||
|
scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))
|
||||||
|
|
||||||
|
# Create and verify
|
||||||
|
metadata = scheduler.create_backup(BackupPolicy.DAILY, verify=True)
|
||||||
|
print(f"Checksum: {metadata.checksum}") # Should not be None
|
||||||
|
```
|
||||||
|
|
||||||
|
## Emergency Procedures
|
||||||
|
|
||||||
|
### Database Corrupted
|
||||||
|
|
||||||
|
1. Stop trading system immediately
|
||||||
|
2. Check most recent backup age: `ls -lht data/backups/daily/`
|
||||||
|
3. Restore: `./scripts/restore.sh`
|
||||||
|
4. Verify: Run health check
|
||||||
|
5. Resume trading
|
||||||
|
|
||||||
|
### Disk Full
|
||||||
|
|
||||||
|
1. Check disk space: `df -h`
|
||||||
|
2. Clean old backups: Run cleanup manually
|
||||||
|
```python
|
||||||
|
from pathlib import Path
|
||||||
|
from src.backup.scheduler import BackupScheduler
|
||||||
|
scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))
|
||||||
|
scheduler.cleanup_old_backups()
|
||||||
|
```
|
||||||
|
3. Consider archiving old monthly backups to external storage
|
||||||
|
4. Increase disk space if needed
|
||||||
|
|
||||||
|
### Lost All Backups
|
||||||
|
|
||||||
|
If local backups are lost:
|
||||||
|
1. Check if exports exist in `exports/` directory
|
||||||
|
2. Reconstruct database from CSV/JSON exports
|
||||||
|
3. If no exports: Check broker API for trade history
|
||||||
|
4. Manual reconstruction as last resort
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
1. **Test Restores Regularly**: Don't wait for disaster
|
||||||
|
2. **Monitor Disk Space**: Set up alerts at 80% usage
|
||||||
|
3. **Keep Multiple Generations**: Never delete all backups at once
|
||||||
|
4. **Verify Checksums**: Always verify backup integrity
|
||||||
|
5. **Document Changes**: Update this guide when backup strategy changes
|
||||||
|
6. **Off-Site Storage**: Consider external backup for monthly archives
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Backup Script Fails
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Check database file permissions
|
||||||
|
ls -l data/trade_logs.db
|
||||||
|
|
||||||
|
# Check disk space
|
||||||
|
df -h data/
|
||||||
|
|
||||||
|
# Run backup manually with debug
|
||||||
|
python3 -c "
|
||||||
|
import logging
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
from pathlib import Path
|
||||||
|
from src.backup.scheduler import BackupScheduler, BackupPolicy
|
||||||
|
scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))
|
||||||
|
scheduler.create_backup(BackupPolicy.DAILY, verify=True)
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Restore Fails Verification
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Check backup file integrity
|
||||||
|
python3 -c "
|
||||||
|
import sqlite3
|
||||||
|
conn = sqlite3.connect('data/backups/daily/trade_logs_daily_20240115.db')
|
||||||
|
cursor = conn.execute('PRAGMA integrity_check')
|
||||||
|
print(cursor.fetchone()[0])
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Health Check Fails
|
||||||
|
|
||||||
|
```python
|
||||||
|
from pathlib import Path
|
||||||
|
from src.backup.health_monitor import HealthMonitor
|
||||||
|
|
||||||
|
monitor = HealthMonitor('data/trade_logs.db', Path('data/backups'))
|
||||||
|
|
||||||
|
# Check each component individually
|
||||||
|
print("Database:", monitor.check_database_health())
|
||||||
|
print("Disk Space:", monitor.check_disk_space())
|
||||||
|
print("Backup Recency:", monitor.check_backup_recency())
|
||||||
|
```
|
||||||
|
|
||||||
|
## Contact
|
||||||
|
|
||||||
|
For backup/recovery issues:
|
||||||
|
- Check logs: `logs/backup.log`
|
||||||
|
- Review health status: Run health monitor
|
||||||
|
- Raise issue on GitHub if automated recovery fails
|
||||||
96
scripts/backup.sh
Normal file
96
scripts/backup.sh
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# Automated backup script for The Ouroboros trading system
|
||||||
|
# Runs daily/weekly/monthly backups
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
DB_PATH="${DB_PATH:-data/trade_logs.db}"
|
||||||
|
BACKUP_DIR="${BACKUP_DIR:-data/backups}"
|
||||||
|
PYTHON="${PYTHON:-python3}"
|
||||||
|
|
||||||
|
# Colors for output
|
||||||
|
GREEN='\033[0;32m'
|
||||||
|
YELLOW='\033[1;33m'
|
||||||
|
RED='\033[0;31m'
|
||||||
|
NC='\033[0m' # No Color
|
||||||
|
|
||||||
|
log_info() {
|
||||||
|
echo -e "${GREEN}[INFO]${NC} $1"
|
||||||
|
}
|
||||||
|
|
||||||
|
log_warn() {
|
||||||
|
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||||
|
}
|
||||||
|
|
||||||
|
log_error() {
|
||||||
|
echo -e "${RED}[ERROR]${NC} $1"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if database exists
|
||||||
|
if [ ! -f "$DB_PATH" ]; then
|
||||||
|
log_error "Database not found: $DB_PATH"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Create backup directory
|
||||||
|
mkdir -p "$BACKUP_DIR"
|
||||||
|
|
||||||
|
log_info "Starting backup process..."
|
||||||
|
log_info "Database: $DB_PATH"
|
||||||
|
log_info "Backup directory: $BACKUP_DIR"
|
||||||
|
|
||||||
|
# Determine backup policy based on day of week and month
|
||||||
|
DAY_OF_WEEK=$(date +%u) # 1-7 (Monday-Sunday)
|
||||||
|
DAY_OF_MONTH=$(date +%d)
|
||||||
|
|
||||||
|
if [ "$DAY_OF_MONTH" == "01" ]; then
|
||||||
|
POLICY="monthly"
|
||||||
|
log_info "Running MONTHLY backup (first day of month)"
|
||||||
|
elif [ "$DAY_OF_WEEK" == "7" ]; then
|
||||||
|
POLICY="weekly"
|
||||||
|
log_info "Running WEEKLY backup (Sunday)"
|
||||||
|
else
|
||||||
|
POLICY="daily"
|
||||||
|
log_info "Running DAILY backup"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Run Python backup script
|
||||||
|
$PYTHON -c "
|
||||||
|
from pathlib import Path
|
||||||
|
from src.backup.scheduler import BackupScheduler, BackupPolicy
|
||||||
|
from src.backup.health_monitor import HealthMonitor
|
||||||
|
|
||||||
|
# Create scheduler
|
||||||
|
scheduler = BackupScheduler(
|
||||||
|
db_path='$DB_PATH',
|
||||||
|
backup_dir=Path('$BACKUP_DIR')
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create backup
|
||||||
|
policy = BackupPolicy.$POLICY.upper()
|
||||||
|
metadata = scheduler.create_backup(policy, verify=True)
|
||||||
|
print(f'Backup created: {metadata.file_path}')
|
||||||
|
print(f'Size: {metadata.size_bytes / 1024 / 1024:.2f} MB')
|
||||||
|
print(f'Checksum: {metadata.checksum}')
|
||||||
|
|
||||||
|
# Cleanup old backups
|
||||||
|
removed = scheduler.cleanup_old_backups()
|
||||||
|
total_removed = sum(removed.values())
|
||||||
|
if total_removed > 0:
|
||||||
|
print(f'Removed {total_removed} old backup(s)')
|
||||||
|
|
||||||
|
# Health check
|
||||||
|
monitor = HealthMonitor('$DB_PATH', Path('$BACKUP_DIR'))
|
||||||
|
status = monitor.get_overall_status()
|
||||||
|
print(f'System health: {status.value}')
|
||||||
|
"
|
||||||
|
|
||||||
|
if [ $? -eq 0 ]; then
|
||||||
|
log_info "Backup completed successfully"
|
||||||
|
else
|
||||||
|
log_error "Backup failed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
log_info "Backup process finished"
|
||||||
111
scripts/restore.sh
Normal file
111
scripts/restore.sh
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# Restore script for The Ouroboros trading system
|
||||||
|
# Restores database from a backup file
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
DB_PATH="${DB_PATH:-data/trade_logs.db}"
|
||||||
|
BACKUP_DIR="${BACKUP_DIR:-data/backups}"
|
||||||
|
PYTHON="${PYTHON:-python3}"
|
||||||
|
|
||||||
|
# Colors for output
|
||||||
|
GREEN='\033[0;32m'
|
||||||
|
YELLOW='\033[1;33m'
|
||||||
|
RED='\033[0;31m'
|
||||||
|
NC='\033[0m' # No Color
|
||||||
|
|
||||||
|
log_info() {
|
||||||
|
echo -e "${GREEN}[INFO]${NC} $1"
|
||||||
|
}
|
||||||
|
|
||||||
|
log_warn() {
|
||||||
|
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||||
|
}
|
||||||
|
|
||||||
|
log_error() {
|
||||||
|
echo -e "${RED}[ERROR]${NC} $1"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if backup directory exists
|
||||||
|
if [ ! -d "$BACKUP_DIR" ]; then
|
||||||
|
log_error "Backup directory not found: $BACKUP_DIR"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
log_info "Available backups:"
|
||||||
|
log_info "=================="
|
||||||
|
|
||||||
|
# List available backups
|
||||||
|
$PYTHON -c "
|
||||||
|
from pathlib import Path
|
||||||
|
from src.backup.scheduler import BackupScheduler
|
||||||
|
|
||||||
|
scheduler = BackupScheduler(
|
||||||
|
db_path='$DB_PATH',
|
||||||
|
backup_dir=Path('$BACKUP_DIR')
|
||||||
|
)
|
||||||
|
|
||||||
|
backups = scheduler.list_backups()
|
||||||
|
|
||||||
|
if not backups:
|
||||||
|
print('No backups found.')
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
for i, backup in enumerate(backups, 1):
|
||||||
|
size_mb = backup.size_bytes / 1024 / 1024
|
||||||
|
print(f'{i}. [{backup.policy.value.upper()}] {backup.file_path.name}')
|
||||||
|
print(f' Date: {backup.timestamp.strftime(\"%Y-%m-%d %H:%M:%S UTC\")}')
|
||||||
|
print(f' Size: {size_mb:.2f} MB')
|
||||||
|
print()
|
||||||
|
"
|
||||||
|
|
||||||
|
# Ask user to select backup
|
||||||
|
echo ""
|
||||||
|
read -p "Enter backup number to restore (or 'q' to quit): " BACKUP_NUM
|
||||||
|
|
||||||
|
if [ "$BACKUP_NUM" == "q" ]; then
|
||||||
|
log_info "Restore cancelled"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Confirm restoration
|
||||||
|
log_warn "WARNING: This will replace the current database!"
|
||||||
|
log_warn "Current database will be backed up to: ${DB_PATH}.before_restore"
|
||||||
|
read -p "Are you sure you want to continue? (yes/no): " CONFIRM
|
||||||
|
|
||||||
|
if [ "$CONFIRM" != "yes" ]; then
|
||||||
|
log_info "Restore cancelled"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Perform restoration
|
||||||
|
$PYTHON -c "
|
||||||
|
from pathlib import Path
|
||||||
|
from src.backup.scheduler import BackupScheduler
|
||||||
|
|
||||||
|
scheduler = BackupScheduler(
|
||||||
|
db_path='$DB_PATH',
|
||||||
|
backup_dir=Path('$BACKUP_DIR')
|
||||||
|
)
|
||||||
|
|
||||||
|
backups = scheduler.list_backups()
|
||||||
|
backup_index = int('$BACKUP_NUM') - 1
|
||||||
|
|
||||||
|
if backup_index < 0 or backup_index >= len(backups):
|
||||||
|
print('Invalid backup number')
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
selected = backups[backup_index]
|
||||||
|
print(f'Restoring: {selected.file_path.name}')
|
||||||
|
|
||||||
|
scheduler.restore_backup(selected, verify=True)
|
||||||
|
print('Restore completed successfully')
|
||||||
|
"
|
||||||
|
|
||||||
|
if [ $? -eq 0 ]; then
|
||||||
|
log_info "Database restored successfully"
|
||||||
|
else
|
||||||
|
log_error "Restore failed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
21
src/backup/__init__.py
Normal file
21
src/backup/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""Backup and disaster recovery system for long-term sustainability.
|
||||||
|
|
||||||
|
This module provides:
|
||||||
|
- Automated database backups (daily, weekly, monthly)
|
||||||
|
- Multi-format exports (JSON, CSV, Parquet)
|
||||||
|
- Cloud storage integration (S3-compatible)
|
||||||
|
- Health monitoring and alerts
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.backup.exporter import BackupExporter, ExportFormat
|
||||||
|
from src.backup.scheduler import BackupScheduler, BackupPolicy
|
||||||
|
from src.backup.cloud_storage import CloudStorage, S3Config
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BackupExporter",
|
||||||
|
"ExportFormat",
|
||||||
|
"BackupScheduler",
|
||||||
|
"BackupPolicy",
|
||||||
|
"CloudStorage",
|
||||||
|
"S3Config",
|
||||||
|
]
|
||||||
274
src/backup/cloud_storage.py
Normal file
274
src/backup/cloud_storage.py
Normal file
@@ -0,0 +1,274 @@
|
|||||||
|
"""Cloud storage integration for off-site backups.
|
||||||
|
|
||||||
|
Supports S3-compatible storage providers:
|
||||||
|
- AWS S3
|
||||||
|
- MinIO
|
||||||
|
- Backblaze B2
|
||||||
|
- DigitalOcean Spaces
|
||||||
|
- Cloudflare R2
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class S3Config:
|
||||||
|
"""Configuration for S3-compatible storage."""
|
||||||
|
|
||||||
|
endpoint_url: str | None # None for AWS S3, custom URL for others
|
||||||
|
access_key: str
|
||||||
|
secret_key: str
|
||||||
|
bucket_name: str
|
||||||
|
region: str = "us-east-1"
|
||||||
|
use_ssl: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class CloudStorage:
|
||||||
|
"""Upload backups to S3-compatible cloud storage."""
|
||||||
|
|
||||||
|
def __init__(self, config: S3Config) -> None:
|
||||||
|
"""Initialize cloud storage client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: S3 configuration
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError: If boto3 is not installed
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import boto3
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"boto3 is required for cloud storage. Install with: pip install boto3"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
self.client = boto3.client(
|
||||||
|
"s3",
|
||||||
|
endpoint_url=config.endpoint_url,
|
||||||
|
aws_access_key_id=config.access_key,
|
||||||
|
aws_secret_access_key=config.secret_key,
|
||||||
|
region_name=config.region,
|
||||||
|
use_ssl=config.use_ssl,
|
||||||
|
)
|
||||||
|
|
||||||
|
def upload_file(
|
||||||
|
self,
|
||||||
|
file_path: Path,
|
||||||
|
object_key: str | None = None,
|
||||||
|
metadata: dict[str, str] | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Upload a file to cloud storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Local file to upload
|
||||||
|
object_key: S3 object key (default: filename)
|
||||||
|
metadata: Optional metadata to attach
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
S3 object key
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If file doesn't exist
|
||||||
|
Exception: If upload fails
|
||||||
|
"""
|
||||||
|
if not file_path.exists():
|
||||||
|
raise FileNotFoundError(f"File not found: {file_path}")
|
||||||
|
|
||||||
|
if object_key is None:
|
||||||
|
object_key = file_path.name
|
||||||
|
|
||||||
|
extra_args: dict[str, Any] = {}
|
||||||
|
|
||||||
|
# Add server-side encryption
|
||||||
|
extra_args["ServerSideEncryption"] = "AES256"
|
||||||
|
|
||||||
|
# Add metadata if provided
|
||||||
|
if metadata:
|
||||||
|
extra_args["Metadata"] = metadata
|
||||||
|
|
||||||
|
logger.info("Uploading %s to s3://%s/%s", file_path.name, self.config.bucket_name, object_key)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.client.upload_file(
|
||||||
|
str(file_path),
|
||||||
|
self.config.bucket_name,
|
||||||
|
object_key,
|
||||||
|
ExtraArgs=extra_args,
|
||||||
|
)
|
||||||
|
logger.info("Upload successful: %s", object_key)
|
||||||
|
return object_key
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Upload failed: %s", exc)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def download_file(self, object_key: str, local_path: Path) -> Path:
|
||||||
|
"""Download a file from cloud storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
object_key: S3 object key
|
||||||
|
local_path: Local destination path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to downloaded file
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If download fails
|
||||||
|
"""
|
||||||
|
local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
logger.info("Downloading s3://%s/%s to %s", self.config.bucket_name, object_key, local_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.client.download_file(
|
||||||
|
self.config.bucket_name,
|
||||||
|
object_key,
|
||||||
|
str(local_path),
|
||||||
|
)
|
||||||
|
logger.info("Download successful: %s", local_path)
|
||||||
|
return local_path
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Download failed: %s", exc)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def list_files(self, prefix: str = "") -> list[dict[str, Any]]:
|
||||||
|
"""List files in cloud storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prefix: Filter by object key prefix
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of file metadata dictionaries
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = self.client.list_objects_v2(
|
||||||
|
Bucket=self.config.bucket_name,
|
||||||
|
Prefix=prefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
if "Contents" not in response:
|
||||||
|
return []
|
||||||
|
|
||||||
|
files = []
|
||||||
|
for obj in response["Contents"]:
|
||||||
|
files.append(
|
||||||
|
{
|
||||||
|
"key": obj["Key"],
|
||||||
|
"size_bytes": obj["Size"],
|
||||||
|
"last_modified": obj["LastModified"],
|
||||||
|
"etag": obj["ETag"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return files
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Failed to list files: %s", exc)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def delete_file(self, object_key: str) -> None:
|
||||||
|
"""Delete a file from cloud storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
object_key: S3 object key
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If deletion fails
|
||||||
|
"""
|
||||||
|
logger.info("Deleting s3://%s/%s", self.config.bucket_name, object_key)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.client.delete_object(
|
||||||
|
Bucket=self.config.bucket_name,
|
||||||
|
Key=object_key,
|
||||||
|
)
|
||||||
|
logger.info("Deletion successful: %s", object_key)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Deletion failed: %s", exc)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_storage_stats(self) -> dict[str, Any]:
|
||||||
|
"""Get cloud storage statistics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with storage stats
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
files = self.list_files()
|
||||||
|
|
||||||
|
total_size = sum(f["size_bytes"] for f in files)
|
||||||
|
total_count = len(files)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_files": total_count,
|
||||||
|
"total_size_bytes": total_size,
|
||||||
|
"total_size_mb": total_size / 1024 / 1024,
|
||||||
|
"total_size_gb": total_size / 1024 / 1024 / 1024,
|
||||||
|
}
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Failed to get storage stats: %s", exc)
|
||||||
|
return {
|
||||||
|
"error": str(exc),
|
||||||
|
"total_files": 0,
|
||||||
|
"total_size_bytes": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def verify_connection(self) -> bool:
|
||||||
|
"""Verify connection to cloud storage.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if connection is successful
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.client.head_bucket(Bucket=self.config.bucket_name)
|
||||||
|
logger.info("Cloud storage connection verified")
|
||||||
|
return True
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Cloud storage connection failed: %s", exc)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def create_bucket_if_not_exists(self) -> None:
|
||||||
|
"""Create storage bucket if it doesn't exist.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If bucket creation fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.client.head_bucket(Bucket=self.config.bucket_name)
|
||||||
|
logger.info("Bucket already exists: %s", self.config.bucket_name)
|
||||||
|
except self.client.exceptions.NoSuchBucket:
|
||||||
|
logger.info("Creating bucket: %s", self.config.bucket_name)
|
||||||
|
if self.config.region == "us-east-1":
|
||||||
|
# us-east-1 requires special handling
|
||||||
|
self.client.create_bucket(Bucket=self.config.bucket_name)
|
||||||
|
else:
|
||||||
|
self.client.create_bucket(
|
||||||
|
Bucket=self.config.bucket_name,
|
||||||
|
CreateBucketConfiguration={"LocationConstraint": self.config.region},
|
||||||
|
)
|
||||||
|
logger.info("Bucket created successfully")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Failed to verify/create bucket: %s", exc)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def enable_versioning(self) -> None:
|
||||||
|
"""Enable versioning on the bucket.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If versioning enablement fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.client.put_bucket_versioning(
|
||||||
|
Bucket=self.config.bucket_name,
|
||||||
|
VersioningConfiguration={"Status": "Enabled"},
|
||||||
|
)
|
||||||
|
logger.info("Versioning enabled for bucket: %s", self.config.bucket_name)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Failed to enable versioning: %s", exc)
|
||||||
|
raise
|
||||||
326
src/backup/exporter.py
Normal file
326
src/backup/exporter.py
Normal file
@@ -0,0 +1,326 @@
|
|||||||
|
"""Multi-format database exporter for backups.
|
||||||
|
|
||||||
|
Supports JSON, CSV, and Parquet formats for different use cases:
|
||||||
|
- JSON: Human-readable, easy to inspect
|
||||||
|
- CSV: Analysis tools (Excel, pandas)
|
||||||
|
- Parquet: Big data tools (Spark, DuckDB)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import gzip
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sqlite3
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ExportFormat(str, Enum):
|
||||||
|
"""Supported export formats."""
|
||||||
|
|
||||||
|
JSON = "json"
|
||||||
|
CSV = "csv"
|
||||||
|
PARQUET = "parquet"
|
||||||
|
|
||||||
|
|
||||||
|
class BackupExporter:
|
||||||
|
"""Export database to multiple formats."""
|
||||||
|
|
||||||
|
def __init__(self, db_path: str) -> None:
|
||||||
|
"""Initialize the exporter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: Path to SQLite database
|
||||||
|
"""
|
||||||
|
self.db_path = db_path
|
||||||
|
|
||||||
|
def export_all(
|
||||||
|
self,
|
||||||
|
output_dir: Path,
|
||||||
|
formats: list[ExportFormat] | None = None,
|
||||||
|
compress: bool = True,
|
||||||
|
incremental_since: datetime | None = None,
|
||||||
|
) -> dict[ExportFormat, Path]:
|
||||||
|
"""Export database to multiple formats.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_dir: Directory to write export files
|
||||||
|
formats: List of formats to export (default: all)
|
||||||
|
compress: Whether to gzip compress exports
|
||||||
|
incremental_since: Only export records after this timestamp
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping format to output file path
|
||||||
|
"""
|
||||||
|
if formats is None:
|
||||||
|
formats = [ExportFormat.JSON, ExportFormat.CSV, ExportFormat.PARQUET]
|
||||||
|
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S")
|
||||||
|
|
||||||
|
results: dict[ExportFormat, Path] = {}
|
||||||
|
|
||||||
|
for fmt in formats:
|
||||||
|
try:
|
||||||
|
output_file = self._export_format(
|
||||||
|
fmt, output_dir, timestamp, compress, incremental_since
|
||||||
|
)
|
||||||
|
results[fmt] = output_file
|
||||||
|
logger.info("Exported to %s: %s", fmt.value, output_file)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Failed to export to %s: %s", fmt.value, exc)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _export_format(
|
||||||
|
self,
|
||||||
|
fmt: ExportFormat,
|
||||||
|
output_dir: Path,
|
||||||
|
timestamp: str,
|
||||||
|
compress: bool,
|
||||||
|
incremental_since: datetime | None,
|
||||||
|
) -> Path:
|
||||||
|
"""Export to a specific format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fmt: Export format
|
||||||
|
output_dir: Output directory
|
||||||
|
timestamp: Timestamp string for filename
|
||||||
|
compress: Whether to compress
|
||||||
|
incremental_since: Incremental export cutoff
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to output file
|
||||||
|
"""
|
||||||
|
if fmt == ExportFormat.JSON:
|
||||||
|
return self._export_json(output_dir, timestamp, compress, incremental_since)
|
||||||
|
elif fmt == ExportFormat.CSV:
|
||||||
|
return self._export_csv(output_dir, timestamp, compress, incremental_since)
|
||||||
|
elif fmt == ExportFormat.PARQUET:
|
||||||
|
return self._export_parquet(
|
||||||
|
output_dir, timestamp, compress, incremental_since
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported format: {fmt}")
|
||||||
|
|
||||||
|
def _get_trades(
|
||||||
|
self, incremental_since: datetime | None = None
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Fetch trades from database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
incremental_since: Only fetch trades after this timestamp
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of trade records
|
||||||
|
"""
|
||||||
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
|
||||||
|
if incremental_since:
|
||||||
|
cursor = conn.execute(
|
||||||
|
"SELECT * FROM trades WHERE timestamp > ?",
|
||||||
|
(incremental_since.isoformat(),),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cursor = conn.execute("SELECT * FROM trades")
|
||||||
|
|
||||||
|
trades = [dict(row) for row in cursor.fetchall()]
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
return trades
|
||||||
|
|
||||||
|
def _export_json(
|
||||||
|
self,
|
||||||
|
output_dir: Path,
|
||||||
|
timestamp: str,
|
||||||
|
compress: bool,
|
||||||
|
incremental_since: datetime | None,
|
||||||
|
) -> Path:
|
||||||
|
"""Export to JSON format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_dir: Output directory
|
||||||
|
timestamp: Timestamp for filename
|
||||||
|
compress: Whether to gzip
|
||||||
|
incremental_since: Incremental cutoff
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to output file
|
||||||
|
"""
|
||||||
|
trades = self._get_trades(incremental_since)
|
||||||
|
|
||||||
|
filename = f"trades_{timestamp}.json"
|
||||||
|
if compress:
|
||||||
|
filename += ".gz"
|
||||||
|
|
||||||
|
output_file = output_dir / filename
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"export_timestamp": datetime.now(UTC).isoformat(),
|
||||||
|
"incremental_since": (
|
||||||
|
incremental_since.isoformat() if incremental_since else None
|
||||||
|
),
|
||||||
|
"record_count": len(trades),
|
||||||
|
"trades": trades,
|
||||||
|
}
|
||||||
|
|
||||||
|
if compress:
|
||||||
|
with gzip.open(output_file, "wt", encoding="utf-8") as f:
|
||||||
|
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||||
|
else:
|
||||||
|
with open(output_file, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
return output_file
|
||||||
|
|
||||||
|
def _export_csv(
|
||||||
|
self,
|
||||||
|
output_dir: Path,
|
||||||
|
timestamp: str,
|
||||||
|
compress: bool,
|
||||||
|
incremental_since: datetime | None,
|
||||||
|
) -> Path:
|
||||||
|
"""Export to CSV format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_dir: Output directory
|
||||||
|
timestamp: Timestamp for filename
|
||||||
|
compress: Whether to gzip
|
||||||
|
incremental_since: Incremental cutoff
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to output file
|
||||||
|
"""
|
||||||
|
trades = self._get_trades(incremental_since)
|
||||||
|
|
||||||
|
filename = f"trades_{timestamp}.csv"
|
||||||
|
if compress:
|
||||||
|
filename += ".gz"
|
||||||
|
|
||||||
|
output_file = output_dir / filename
|
||||||
|
|
||||||
|
if not trades:
|
||||||
|
# Write empty CSV with headers
|
||||||
|
if compress:
|
||||||
|
with gzip.open(output_file, "wt", encoding="utf-8", newline="") as f:
|
||||||
|
writer = csv.writer(f)
|
||||||
|
writer.writerow(
|
||||||
|
[
|
||||||
|
"timestamp",
|
||||||
|
"stock_code",
|
||||||
|
"action",
|
||||||
|
"quantity",
|
||||||
|
"price",
|
||||||
|
"confidence",
|
||||||
|
"rationale",
|
||||||
|
"pnl",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
with open(output_file, "w", encoding="utf-8", newline="") as f:
|
||||||
|
writer = csv.writer(f)
|
||||||
|
writer.writerow(
|
||||||
|
[
|
||||||
|
"timestamp",
|
||||||
|
"stock_code",
|
||||||
|
"action",
|
||||||
|
"quantity",
|
||||||
|
"price",
|
||||||
|
"confidence",
|
||||||
|
"rationale",
|
||||||
|
"pnl",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return output_file
|
||||||
|
|
||||||
|
# Get column names from first trade
|
||||||
|
fieldnames = list(trades[0].keys())
|
||||||
|
|
||||||
|
if compress:
|
||||||
|
with gzip.open(output_file, "wt", encoding="utf-8", newline="") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||||
|
writer.writeheader()
|
||||||
|
writer.writerows(trades)
|
||||||
|
else:
|
||||||
|
with open(output_file, "w", encoding="utf-8", newline="") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||||
|
writer.writeheader()
|
||||||
|
writer.writerows(trades)
|
||||||
|
|
||||||
|
return output_file
|
||||||
|
|
||||||
|
def _export_parquet(
|
||||||
|
self,
|
||||||
|
output_dir: Path,
|
||||||
|
timestamp: str,
|
||||||
|
compress: bool,
|
||||||
|
incremental_since: datetime | None,
|
||||||
|
) -> Path:
|
||||||
|
"""Export to Parquet format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_dir: Output directory
|
||||||
|
timestamp: Timestamp for filename
|
||||||
|
compress: Whether to compress (Parquet has built-in compression)
|
||||||
|
incremental_since: Incremental cutoff
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to output file
|
||||||
|
"""
|
||||||
|
trades = self._get_trades(incremental_since)
|
||||||
|
|
||||||
|
filename = f"trades_{timestamp}.parquet"
|
||||||
|
output_file = output_dir / filename
|
||||||
|
|
||||||
|
try:
|
||||||
|
import pyarrow as pa
|
||||||
|
import pyarrow.parquet as pq
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"pyarrow is required for Parquet export. "
|
||||||
|
"Install with: pip install pyarrow"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to pyarrow table
|
||||||
|
table = pa.Table.from_pylist(trades)
|
||||||
|
|
||||||
|
# Write with compression
|
||||||
|
compression = "gzip" if compress else "none"
|
||||||
|
pq.write_table(table, output_file, compression=compression)
|
||||||
|
|
||||||
|
return output_file
|
||||||
|
|
||||||
|
def get_export_stats(self) -> dict[str, Any]:
|
||||||
|
"""Get statistics about exportable data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with data statistics
|
||||||
|
"""
|
||||||
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
stats = {}
|
||||||
|
|
||||||
|
# Total trades
|
||||||
|
cursor.execute("SELECT COUNT(*) FROM trades")
|
||||||
|
stats["total_trades"] = cursor.fetchone()[0]
|
||||||
|
|
||||||
|
# Date range
|
||||||
|
cursor.execute("SELECT MIN(timestamp), MAX(timestamp) FROM trades")
|
||||||
|
min_date, max_date = cursor.fetchone()
|
||||||
|
stats["date_range"] = {"earliest": min_date, "latest": max_date}
|
||||||
|
|
||||||
|
# Database size
|
||||||
|
cursor.execute("SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()")
|
||||||
|
stats["db_size_bytes"] = cursor.fetchone()[0]
|
||||||
|
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
return stats
|
||||||
282
src/backup/health_monitor.py
Normal file
282
src/backup/health_monitor.py
Normal file
@@ -0,0 +1,282 @@
|
|||||||
|
"""Health monitoring for backup system.
|
||||||
|
|
||||||
|
Checks:
|
||||||
|
- Database accessibility and integrity
|
||||||
|
- Disk space availability
|
||||||
|
- Backup success/failure tracking
|
||||||
|
- Self-healing capabilities
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
import sqlite3
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HealthStatus(str, Enum):
|
||||||
|
"""Health check status."""
|
||||||
|
|
||||||
|
HEALTHY = "healthy"
|
||||||
|
DEGRADED = "degraded"
|
||||||
|
UNHEALTHY = "unhealthy"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HealthCheckResult:
|
||||||
|
"""Result of a health check."""
|
||||||
|
|
||||||
|
status: HealthStatus
|
||||||
|
message: str
|
||||||
|
details: dict[str, Any] | None = None
|
||||||
|
timestamp: datetime | None = None
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
if self.timestamp is None:
|
||||||
|
self.timestamp = datetime.now(UTC)
|
||||||
|
|
||||||
|
|
||||||
|
class HealthMonitor:
|
||||||
|
"""Monitor system health and backup status."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
db_path: str,
|
||||||
|
backup_dir: Path,
|
||||||
|
min_disk_space_gb: float = 10.0,
|
||||||
|
max_backup_age_hours: int = 25, # Daily backups should be < 25 hours old
|
||||||
|
) -> None:
|
||||||
|
"""Initialize health monitor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: Path to SQLite database
|
||||||
|
backup_dir: Backup directory
|
||||||
|
min_disk_space_gb: Minimum required disk space in GB
|
||||||
|
max_backup_age_hours: Maximum acceptable backup age in hours
|
||||||
|
"""
|
||||||
|
self.db_path = Path(db_path)
|
||||||
|
self.backup_dir = backup_dir
|
||||||
|
self.min_disk_space_bytes = int(min_disk_space_gb * 1024 * 1024 * 1024)
|
||||||
|
self.max_backup_age = timedelta(hours=max_backup_age_hours)
|
||||||
|
|
||||||
|
def check_database_health(self) -> HealthCheckResult:
|
||||||
|
"""Check database accessibility and integrity.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HealthCheckResult
|
||||||
|
"""
|
||||||
|
# Check if database exists
|
||||||
|
if not self.db_path.exists():
|
||||||
|
return HealthCheckResult(
|
||||||
|
status=HealthStatus.UNHEALTHY,
|
||||||
|
message=f"Database not found: {self.db_path}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if database is accessible
|
||||||
|
try:
|
||||||
|
conn = sqlite3.connect(str(self.db_path))
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# Run integrity check
|
||||||
|
cursor.execute("PRAGMA integrity_check")
|
||||||
|
result = cursor.fetchone()[0]
|
||||||
|
|
||||||
|
if result != "ok":
|
||||||
|
conn.close()
|
||||||
|
return HealthCheckResult(
|
||||||
|
status=HealthStatus.UNHEALTHY,
|
||||||
|
message=f"Database integrity check failed: {result}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get database size
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()"
|
||||||
|
)
|
||||||
|
db_size = cursor.fetchone()[0]
|
||||||
|
|
||||||
|
# Get row counts
|
||||||
|
cursor.execute("SELECT COUNT(*) FROM trades")
|
||||||
|
trade_count = cursor.fetchone()[0]
|
||||||
|
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
return HealthCheckResult(
|
||||||
|
status=HealthStatus.HEALTHY,
|
||||||
|
message="Database is healthy",
|
||||||
|
details={
|
||||||
|
"size_bytes": db_size,
|
||||||
|
"size_mb": db_size / 1024 / 1024,
|
||||||
|
"trade_count": trade_count,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except sqlite3.Error as exc:
|
||||||
|
return HealthCheckResult(
|
||||||
|
status=HealthStatus.UNHEALTHY,
|
||||||
|
message=f"Database access error: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_disk_space(self) -> HealthCheckResult:
|
||||||
|
"""Check available disk space.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HealthCheckResult
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
stat = shutil.disk_usage(self.backup_dir)
|
||||||
|
|
||||||
|
free_gb = stat.free / 1024 / 1024 / 1024
|
||||||
|
total_gb = stat.total / 1024 / 1024 / 1024
|
||||||
|
used_percent = (stat.used / stat.total) * 100
|
||||||
|
|
||||||
|
if stat.free < self.min_disk_space_bytes:
|
||||||
|
return HealthCheckResult(
|
||||||
|
status=HealthStatus.UNHEALTHY,
|
||||||
|
message=f"Low disk space: {free_gb:.2f} GB free (minimum: {self.min_disk_space_bytes / 1024 / 1024 / 1024:.2f} GB)",
|
||||||
|
details={
|
||||||
|
"free_gb": free_gb,
|
||||||
|
"total_gb": total_gb,
|
||||||
|
"used_percent": used_percent,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
elif stat.free < self.min_disk_space_bytes * 2:
|
||||||
|
return HealthCheckResult(
|
||||||
|
status=HealthStatus.DEGRADED,
|
||||||
|
message=f"Disk space low: {free_gb:.2f} GB free",
|
||||||
|
details={
|
||||||
|
"free_gb": free_gb,
|
||||||
|
"total_gb": total_gb,
|
||||||
|
"used_percent": used_percent,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return HealthCheckResult(
|
||||||
|
status=HealthStatus.HEALTHY,
|
||||||
|
message=f"Disk space healthy: {free_gb:.2f} GB free",
|
||||||
|
details={
|
||||||
|
"free_gb": free_gb,
|
||||||
|
"total_gb": total_gb,
|
||||||
|
"used_percent": used_percent,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
return HealthCheckResult(
|
||||||
|
status=HealthStatus.UNHEALTHY,
|
||||||
|
message=f"Failed to check disk space: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_backup_recency(self) -> HealthCheckResult:
|
||||||
|
"""Check if backups are recent enough.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HealthCheckResult
|
||||||
|
"""
|
||||||
|
daily_dir = self.backup_dir / "daily"
|
||||||
|
|
||||||
|
if not daily_dir.exists():
|
||||||
|
return HealthCheckResult(
|
||||||
|
status=HealthStatus.DEGRADED,
|
||||||
|
message="Daily backup directory not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find most recent backup
|
||||||
|
backups = sorted(daily_dir.glob("*.db"), key=lambda p: p.stat().st_mtime, reverse=True)
|
||||||
|
|
||||||
|
if not backups:
|
||||||
|
return HealthCheckResult(
|
||||||
|
status=HealthStatus.UNHEALTHY,
|
||||||
|
message="No daily backups found",
|
||||||
|
)
|
||||||
|
|
||||||
|
most_recent = backups[0]
|
||||||
|
mtime = datetime.fromtimestamp(most_recent.stat().st_mtime, tz=UTC)
|
||||||
|
age = datetime.now(UTC) - mtime
|
||||||
|
|
||||||
|
if age > self.max_backup_age:
|
||||||
|
return HealthCheckResult(
|
||||||
|
status=HealthStatus.DEGRADED,
|
||||||
|
message=f"Most recent backup is {age.total_seconds() / 3600:.1f} hours old",
|
||||||
|
details={
|
||||||
|
"backup_file": most_recent.name,
|
||||||
|
"age_hours": age.total_seconds() / 3600,
|
||||||
|
"threshold_hours": self.max_backup_age.total_seconds() / 3600,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return HealthCheckResult(
|
||||||
|
status=HealthStatus.HEALTHY,
|
||||||
|
message=f"Recent backup found ({age.total_seconds() / 3600:.1f} hours old)",
|
||||||
|
details={
|
||||||
|
"backup_file": most_recent.name,
|
||||||
|
"age_hours": age.total_seconds() / 3600,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_all_checks(self) -> dict[str, HealthCheckResult]:
|
||||||
|
"""Run all health checks.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping check name to result
|
||||||
|
"""
|
||||||
|
checks = {
|
||||||
|
"database": self.check_database_health(),
|
||||||
|
"disk_space": self.check_disk_space(),
|
||||||
|
"backup_recency": self.check_backup_recency(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Log results
|
||||||
|
for check_name, result in checks.items():
|
||||||
|
if result.status == HealthStatus.UNHEALTHY:
|
||||||
|
logger.error("[%s] %s: %s", check_name, result.status.value, result.message)
|
||||||
|
elif result.status == HealthStatus.DEGRADED:
|
||||||
|
logger.warning("[%s] %s: %s", check_name, result.status.value, result.message)
|
||||||
|
else:
|
||||||
|
logger.info("[%s] %s: %s", check_name, result.status.value, result.message)
|
||||||
|
|
||||||
|
return checks
|
||||||
|
|
||||||
|
def get_overall_status(self) -> HealthStatus:
|
||||||
|
"""Get overall system health status.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HealthStatus (worst status from all checks)
|
||||||
|
"""
|
||||||
|
checks = self.run_all_checks()
|
||||||
|
|
||||||
|
# Return worst status
|
||||||
|
if any(c.status == HealthStatus.UNHEALTHY for c in checks.values()):
|
||||||
|
return HealthStatus.UNHEALTHY
|
||||||
|
elif any(c.status == HealthStatus.DEGRADED for c in checks.values()):
|
||||||
|
return HealthStatus.DEGRADED
|
||||||
|
else:
|
||||||
|
return HealthStatus.HEALTHY
|
||||||
|
|
||||||
|
def get_health_report(self) -> dict[str, Any]:
|
||||||
|
"""Get comprehensive health report.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with health report
|
||||||
|
"""
|
||||||
|
checks = self.run_all_checks()
|
||||||
|
overall = self.get_overall_status()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"overall_status": overall.value,
|
||||||
|
"timestamp": datetime.now(UTC).isoformat(),
|
||||||
|
"checks": {
|
||||||
|
name: {
|
||||||
|
"status": result.status.value,
|
||||||
|
"message": result.message,
|
||||||
|
"details": result.details,
|
||||||
|
}
|
||||||
|
for name, result in checks.items()
|
||||||
|
},
|
||||||
|
}
|
||||||
336
src/backup/scheduler.py
Normal file
336
src/backup/scheduler.py
Normal file
@@ -0,0 +1,336 @@
|
|||||||
|
"""Backup scheduler for automated database backups.
|
||||||
|
|
||||||
|
Implements backup policies:
|
||||||
|
- Daily: Keep for 30 days (hot storage)
|
||||||
|
- Weekly: Keep for 1 year (warm storage)
|
||||||
|
- Monthly: Keep forever (cold storage)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BackupPolicy(str, Enum):
|
||||||
|
"""Backup retention policies."""
|
||||||
|
|
||||||
|
DAILY = "daily"
|
||||||
|
WEEKLY = "weekly"
|
||||||
|
MONTHLY = "monthly"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BackupMetadata:
|
||||||
|
"""Metadata for a backup."""
|
||||||
|
|
||||||
|
timestamp: datetime
|
||||||
|
policy: BackupPolicy
|
||||||
|
file_path: Path
|
||||||
|
size_bytes: int
|
||||||
|
checksum: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class BackupScheduler:
|
||||||
|
"""Manage automated database backups with retention policies."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
db_path: str,
|
||||||
|
backup_dir: Path,
|
||||||
|
daily_retention_days: int = 30,
|
||||||
|
weekly_retention_days: int = 365,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the backup scheduler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: Path to SQLite database
|
||||||
|
backup_dir: Root directory for backups
|
||||||
|
daily_retention_days: Days to keep daily backups
|
||||||
|
weekly_retention_days: Days to keep weekly backups
|
||||||
|
"""
|
||||||
|
self.db_path = Path(db_path)
|
||||||
|
self.backup_dir = backup_dir
|
||||||
|
self.daily_retention = timedelta(days=daily_retention_days)
|
||||||
|
self.weekly_retention = timedelta(days=weekly_retention_days)
|
||||||
|
|
||||||
|
# Create policy-specific directories
|
||||||
|
self.daily_dir = backup_dir / "daily"
|
||||||
|
self.weekly_dir = backup_dir / "weekly"
|
||||||
|
self.monthly_dir = backup_dir / "monthly"
|
||||||
|
|
||||||
|
for d in [self.daily_dir, self.weekly_dir, self.monthly_dir]:
|
||||||
|
d.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def create_backup(
|
||||||
|
self, policy: BackupPolicy, verify: bool = True
|
||||||
|
) -> BackupMetadata:
|
||||||
|
"""Create a database backup.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
policy: Backup policy (daily/weekly/monthly)
|
||||||
|
verify: Whether to verify backup integrity
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BackupMetadata object
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If database doesn't exist
|
||||||
|
OSError: If backup fails
|
||||||
|
"""
|
||||||
|
if not self.db_path.exists():
|
||||||
|
raise FileNotFoundError(f"Database not found: {self.db_path}")
|
||||||
|
|
||||||
|
timestamp = datetime.now(UTC)
|
||||||
|
backup_filename = self._get_backup_filename(timestamp, policy)
|
||||||
|
|
||||||
|
# Determine output directory
|
||||||
|
if policy == BackupPolicy.DAILY:
|
||||||
|
output_dir = self.daily_dir
|
||||||
|
elif policy == BackupPolicy.WEEKLY:
|
||||||
|
output_dir = self.weekly_dir
|
||||||
|
else: # MONTHLY
|
||||||
|
output_dir = self.monthly_dir
|
||||||
|
|
||||||
|
backup_path = output_dir / backup_filename
|
||||||
|
|
||||||
|
# Create backup (copy database file)
|
||||||
|
logger.info("Creating %s backup: %s", policy.value, backup_path)
|
||||||
|
shutil.copy2(self.db_path, backup_path)
|
||||||
|
|
||||||
|
# Get file size
|
||||||
|
size_bytes = backup_path.stat().st_size
|
||||||
|
|
||||||
|
# Verify backup if requested
|
||||||
|
checksum = None
|
||||||
|
if verify:
|
||||||
|
checksum = self._verify_backup(backup_path)
|
||||||
|
|
||||||
|
metadata = BackupMetadata(
|
||||||
|
timestamp=timestamp,
|
||||||
|
policy=policy,
|
||||||
|
file_path=backup_path,
|
||||||
|
size_bytes=size_bytes,
|
||||||
|
checksum=checksum,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Backup created: %s (%.2f MB)",
|
||||||
|
backup_path.name,
|
||||||
|
size_bytes / 1024 / 1024,
|
||||||
|
)
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
def _get_backup_filename(self, timestamp: datetime, policy: BackupPolicy) -> str:
|
||||||
|
"""Generate backup filename.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timestamp: Backup timestamp
|
||||||
|
policy: Backup policy
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filename string
|
||||||
|
"""
|
||||||
|
ts_str = timestamp.strftime("%Y%m%d_%H%M%S")
|
||||||
|
return f"trade_logs_{policy.value}_{ts_str}.db"
|
||||||
|
|
||||||
|
def _verify_backup(self, backup_path: Path) -> str:
|
||||||
|
"""Verify backup integrity using SQLite integrity check.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
backup_path: Path to backup file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Checksum string (MD5 hash)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If integrity check fails
|
||||||
|
"""
|
||||||
|
import hashlib
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
|
# Integrity check
|
||||||
|
try:
|
||||||
|
conn = sqlite3.connect(str(backup_path))
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("PRAGMA integrity_check")
|
||||||
|
result = cursor.fetchone()[0]
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
if result != "ok":
|
||||||
|
raise RuntimeError(f"Integrity check failed: {result}")
|
||||||
|
except sqlite3.Error as exc:
|
||||||
|
raise RuntimeError(f"Failed to verify backup: {exc}")
|
||||||
|
|
||||||
|
# Calculate MD5 checksum
|
||||||
|
md5 = hashlib.md5()
|
||||||
|
with open(backup_path, "rb") as f:
|
||||||
|
for chunk in iter(lambda: f.read(8192), b""):
|
||||||
|
md5.update(chunk)
|
||||||
|
|
||||||
|
return md5.hexdigest()
|
||||||
|
|
||||||
|
def cleanup_old_backups(self) -> dict[BackupPolicy, int]:
|
||||||
|
"""Remove backups older than retention policies.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping policy to number of backups removed
|
||||||
|
"""
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
removed_counts: dict[BackupPolicy, int] = {}
|
||||||
|
|
||||||
|
# Daily backups: remove older than retention
|
||||||
|
removed_counts[BackupPolicy.DAILY] = self._cleanup_directory(
|
||||||
|
self.daily_dir, now - self.daily_retention
|
||||||
|
)
|
||||||
|
|
||||||
|
# Weekly backups: remove older than retention
|
||||||
|
removed_counts[BackupPolicy.WEEKLY] = self._cleanup_directory(
|
||||||
|
self.weekly_dir, now - self.weekly_retention
|
||||||
|
)
|
||||||
|
|
||||||
|
# Monthly backups: never remove (kept forever)
|
||||||
|
removed_counts[BackupPolicy.MONTHLY] = 0
|
||||||
|
|
||||||
|
total = sum(removed_counts.values())
|
||||||
|
if total > 0:
|
||||||
|
logger.info("Cleaned up %d old backup(s)", total)
|
||||||
|
|
||||||
|
return removed_counts
|
||||||
|
|
||||||
|
def _cleanup_directory(self, directory: Path, cutoff: datetime) -> int:
|
||||||
|
"""Remove backups older than cutoff date.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
directory: Directory to clean
|
||||||
|
cutoff: Remove files older than this
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of files removed
|
||||||
|
"""
|
||||||
|
removed = 0
|
||||||
|
|
||||||
|
for backup_file in directory.glob("*.db"):
|
||||||
|
# Get file modification time
|
||||||
|
mtime = datetime.fromtimestamp(backup_file.stat().st_mtime, tz=UTC)
|
||||||
|
|
||||||
|
if mtime < cutoff:
|
||||||
|
logger.debug("Removing old backup: %s", backup_file.name)
|
||||||
|
backup_file.unlink()
|
||||||
|
removed += 1
|
||||||
|
|
||||||
|
return removed
|
||||||
|
|
||||||
|
def list_backups(
|
||||||
|
self, policy: BackupPolicy | None = None
|
||||||
|
) -> list[BackupMetadata]:
|
||||||
|
"""List available backups.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
policy: Filter by policy (None for all)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of BackupMetadata objects
|
||||||
|
"""
|
||||||
|
backups: list[BackupMetadata] = []
|
||||||
|
|
||||||
|
policies_to_check = (
|
||||||
|
[policy] if policy else [BackupPolicy.DAILY, BackupPolicy.WEEKLY, BackupPolicy.MONTHLY]
|
||||||
|
)
|
||||||
|
|
||||||
|
for pol in policies_to_check:
|
||||||
|
if pol == BackupPolicy.DAILY:
|
||||||
|
directory = self.daily_dir
|
||||||
|
elif pol == BackupPolicy.WEEKLY:
|
||||||
|
directory = self.weekly_dir
|
||||||
|
else:
|
||||||
|
directory = self.monthly_dir
|
||||||
|
|
||||||
|
for backup_file in sorted(directory.glob("*.db")):
|
||||||
|
mtime = datetime.fromtimestamp(backup_file.stat().st_mtime, tz=UTC)
|
||||||
|
size = backup_file.stat().st_size
|
||||||
|
|
||||||
|
backups.append(
|
||||||
|
BackupMetadata(
|
||||||
|
timestamp=mtime,
|
||||||
|
policy=pol,
|
||||||
|
file_path=backup_file,
|
||||||
|
size_bytes=size,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sort by timestamp (newest first)
|
||||||
|
backups.sort(key=lambda b: b.timestamp, reverse=True)
|
||||||
|
|
||||||
|
return backups
|
||||||
|
|
||||||
|
def get_backup_stats(self) -> dict[str, Any]:
|
||||||
|
"""Get backup statistics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with backup stats
|
||||||
|
"""
|
||||||
|
stats: dict[str, Any] = {}
|
||||||
|
|
||||||
|
for policy in BackupPolicy:
|
||||||
|
if policy == BackupPolicy.DAILY:
|
||||||
|
directory = self.daily_dir
|
||||||
|
elif policy == BackupPolicy.WEEKLY:
|
||||||
|
directory = self.weekly_dir
|
||||||
|
else:
|
||||||
|
directory = self.monthly_dir
|
||||||
|
|
||||||
|
backups = list(directory.glob("*.db"))
|
||||||
|
total_size = sum(b.stat().st_size for b in backups)
|
||||||
|
|
||||||
|
stats[policy.value] = {
|
||||||
|
"count": len(backups),
|
||||||
|
"total_size_bytes": total_size,
|
||||||
|
"total_size_mb": total_size / 1024 / 1024,
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
def restore_backup(self, backup_metadata: BackupMetadata, verify: bool = True) -> None:
|
||||||
|
"""Restore database from backup.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
backup_metadata: Backup to restore
|
||||||
|
verify: Whether to verify restored database
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If backup file doesn't exist
|
||||||
|
RuntimeError: If verification fails
|
||||||
|
"""
|
||||||
|
if not backup_metadata.file_path.exists():
|
||||||
|
raise FileNotFoundError(f"Backup not found: {backup_metadata.file_path}")
|
||||||
|
|
||||||
|
# Create backup of current database
|
||||||
|
if self.db_path.exists():
|
||||||
|
backup_current = self.db_path.with_suffix(".db.before_restore")
|
||||||
|
logger.info("Backing up current database to: %s", backup_current)
|
||||||
|
shutil.copy2(self.db_path, backup_current)
|
||||||
|
|
||||||
|
# Restore backup
|
||||||
|
logger.info("Restoring backup: %s", backup_metadata.file_path.name)
|
||||||
|
shutil.copy2(backup_metadata.file_path, self.db_path)
|
||||||
|
|
||||||
|
# Verify restored database
|
||||||
|
if verify:
|
||||||
|
try:
|
||||||
|
self._verify_backup(self.db_path)
|
||||||
|
logger.info("Backup restored and verified successfully")
|
||||||
|
except RuntimeError as exc:
|
||||||
|
# Restore failed, revert to backup
|
||||||
|
if backup_current.exists():
|
||||||
|
logger.error("Restore verification failed, reverting: %s", exc)
|
||||||
|
shutil.copy2(backup_current, self.db_path)
|
||||||
|
raise
|
||||||
@@ -13,8 +13,8 @@ import hashlib
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.brain.gemini_client import TradeDecision
|
from src.brain.gemini_client import TradeDecision
|
||||||
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class CacheEntry:
|
class CacheEntry:
|
||||||
"""Cached decision with metadata."""
|
"""Cached decision with metadata."""
|
||||||
|
|
||||||
decision: TradeDecision
|
decision: "TradeDecision"
|
||||||
cached_at: float # Unix timestamp
|
cached_at: float # Unix timestamp
|
||||||
hit_count: int = 0
|
hit_count: int = 0
|
||||||
market_data_hash: str = ""
|
market_data_hash: str = ""
|
||||||
|
|||||||
@@ -6,7 +6,13 @@ JSON responses into validated TradeDecision objects.
|
|||||||
Includes token efficiency optimizations:
|
Includes token efficiency optimizations:
|
||||||
- Prompt compression and abbreviation
|
- Prompt compression and abbreviation
|
||||||
- Response caching for common scenarios
|
- Response caching for common scenarios
|
||||||
|
- Smart context selection
|
||||||
- Token usage tracking and metrics
|
- Token usage tracking and metrics
|
||||||
|
|
||||||
|
Includes external data integration:
|
||||||
|
- News sentiment analysis
|
||||||
|
- Economic calendar events
|
||||||
|
- Market indicators
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -19,9 +25,12 @@ from typing import Any
|
|||||||
|
|
||||||
from google import genai
|
from google import genai
|
||||||
|
|
||||||
|
from src.config import Settings
|
||||||
|
from src.data.news_api import NewsAPI, NewsSentiment
|
||||||
|
from src.data.economic_calendar import EconomicCalendar
|
||||||
|
from src.data.market_data import MarketData
|
||||||
from src.brain.cache import DecisionCache
|
from src.brain.cache import DecisionCache
|
||||||
from src.brain.prompt_optimizer import PromptOptimizer
|
from src.brain.prompt_optimizer import PromptOptimizer
|
||||||
from src.config import Settings
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -45,6 +54,9 @@ class GeminiClient:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
settings: Settings,
|
settings: Settings,
|
||||||
|
news_api: NewsAPI | None = None,
|
||||||
|
economic_calendar: EconomicCalendar | None = None,
|
||||||
|
market_data: MarketData | None = None,
|
||||||
enable_cache: bool = True,
|
enable_cache: bool = True,
|
||||||
enable_optimization: bool = True,
|
enable_optimization: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -53,6 +65,11 @@ class GeminiClient:
|
|||||||
self._client = genai.Client(api_key=settings.GEMINI_API_KEY)
|
self._client = genai.Client(api_key=settings.GEMINI_API_KEY)
|
||||||
self._model_name = settings.GEMINI_MODEL
|
self._model_name = settings.GEMINI_MODEL
|
||||||
|
|
||||||
|
# External data sources (optional)
|
||||||
|
self._news_api = news_api
|
||||||
|
self._economic_calendar = economic_calendar
|
||||||
|
self._market_data = market_data
|
||||||
|
|
||||||
# Token efficiency features
|
# Token efficiency features
|
||||||
self._enable_cache = enable_cache
|
self._enable_cache = enable_cache
|
||||||
self._enable_optimization = enable_optimization
|
self._enable_optimization = enable_optimization
|
||||||
@@ -64,12 +81,139 @@ class GeminiClient:
|
|||||||
self._total_decisions = 0
|
self._total_decisions = 0
|
||||||
self._total_cached_decisions = 0
|
self._total_cached_decisions = 0
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# External Data Integration
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _build_external_context(
|
||||||
|
self, stock_code: str, news_sentiment: NewsSentiment | None = None
|
||||||
|
) -> str:
|
||||||
|
"""Build external data context for the prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stock_code: Stock ticker symbol
|
||||||
|
news_sentiment: Optional pre-fetched news sentiment
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted string with external data context
|
||||||
|
"""
|
||||||
|
context_parts: list[str] = []
|
||||||
|
|
||||||
|
# News sentiment
|
||||||
|
if news_sentiment is not None:
|
||||||
|
sentiment_str = self._format_news_sentiment(news_sentiment)
|
||||||
|
if sentiment_str:
|
||||||
|
context_parts.append(sentiment_str)
|
||||||
|
elif self._news_api is not None:
|
||||||
|
# Fetch news sentiment if not provided
|
||||||
|
try:
|
||||||
|
sentiment = await self._news_api.get_news_sentiment(stock_code)
|
||||||
|
if sentiment is not None:
|
||||||
|
sentiment_str = self._format_news_sentiment(sentiment)
|
||||||
|
if sentiment_str:
|
||||||
|
context_parts.append(sentiment_str)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Failed to fetch news sentiment: %s", exc)
|
||||||
|
|
||||||
|
# Economic events
|
||||||
|
if self._economic_calendar is not None:
|
||||||
|
events_str = self._format_economic_events(stock_code)
|
||||||
|
if events_str:
|
||||||
|
context_parts.append(events_str)
|
||||||
|
|
||||||
|
# Market indicators
|
||||||
|
if self._market_data is not None:
|
||||||
|
indicators_str = self._format_market_indicators()
|
||||||
|
if indicators_str:
|
||||||
|
context_parts.append(indicators_str)
|
||||||
|
|
||||||
|
if not context_parts:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
return "EXTERNAL DATA:\n" + "\n\n".join(context_parts)
|
||||||
|
|
||||||
|
def _format_news_sentiment(self, sentiment: NewsSentiment) -> str:
|
||||||
|
"""Format news sentiment for prompt."""
|
||||||
|
if sentiment.article_count == 0:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Select top 3 most relevant articles
|
||||||
|
top_articles = sentiment.articles[:3]
|
||||||
|
|
||||||
|
lines = [
|
||||||
|
f"News Sentiment: {sentiment.avg_sentiment:.2f} "
|
||||||
|
f"(from {sentiment.article_count} articles)",
|
||||||
|
]
|
||||||
|
|
||||||
|
for i, article in enumerate(top_articles, 1):
|
||||||
|
lines.append(
|
||||||
|
f" {i}. [{article.source}] {article.title} "
|
||||||
|
f"(sentiment: {article.sentiment_score:.2f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
def _format_economic_events(self, stock_code: str) -> str:
|
||||||
|
"""Format upcoming economic events for prompt."""
|
||||||
|
if self._economic_calendar is None:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Check for upcoming high-impact events
|
||||||
|
upcoming = self._economic_calendar.get_upcoming_events(
|
||||||
|
days_ahead=7, min_impact="HIGH"
|
||||||
|
)
|
||||||
|
|
||||||
|
if upcoming.high_impact_count == 0:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
lines = [
|
||||||
|
f"Upcoming High-Impact Events: {upcoming.high_impact_count} in next 7 days"
|
||||||
|
]
|
||||||
|
|
||||||
|
if upcoming.next_major_event is not None:
|
||||||
|
event = upcoming.next_major_event
|
||||||
|
lines.append(
|
||||||
|
f" Next: {event.name} ({event.event_type}) "
|
||||||
|
f"on {event.datetime.strftime('%Y-%m-%d')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for earnings
|
||||||
|
earnings_date = self._economic_calendar.get_earnings_date(stock_code)
|
||||||
|
if earnings_date is not None:
|
||||||
|
lines.append(
|
||||||
|
f" Earnings: {stock_code} on {earnings_date.strftime('%Y-%m-%d')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
def _format_market_indicators(self) -> str:
|
||||||
|
"""Format market indicators for prompt."""
|
||||||
|
if self._market_data is None:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
indicators = self._market_data.get_market_indicators()
|
||||||
|
lines = [f"Market Sentiment: {indicators.sentiment.name}"]
|
||||||
|
|
||||||
|
# Add breadth if meaningful
|
||||||
|
if indicators.breadth.advance_decline_ratio != 1.0:
|
||||||
|
lines.append(
|
||||||
|
f"Advance/Decline Ratio: {indicators.breadth.advance_decline_ratio:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Failed to get market indicators: %s", exc)
|
||||||
|
return ""
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Prompt Construction
|
# Prompt Construction
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def build_prompt(self, market_data: dict[str, Any]) -> str:
|
async def build_prompt(
|
||||||
"""Build a structured prompt from market data.
|
self, market_data: dict[str, Any], news_sentiment: NewsSentiment | None = None
|
||||||
|
) -> str:
|
||||||
|
"""Build a structured prompt from market data and external sources.
|
||||||
|
|
||||||
The prompt instructs Gemini to return valid JSON with action,
|
The prompt instructs Gemini to return valid JSON with action,
|
||||||
confidence, and rationale fields.
|
confidence, and rationale fields.
|
||||||
@@ -97,6 +241,60 @@ class GeminiClient:
|
|||||||
|
|
||||||
market_info = "\n".join(market_info_lines)
|
market_info = "\n".join(market_info_lines)
|
||||||
|
|
||||||
|
# Add external data context if available
|
||||||
|
external_context = await self._build_external_context(
|
||||||
|
market_data["stock_code"], news_sentiment
|
||||||
|
)
|
||||||
|
if external_context:
|
||||||
|
market_info += f"\n\n{external_context}"
|
||||||
|
|
||||||
|
json_format = (
|
||||||
|
'{"action": "BUY"|"SELL"|"HOLD", '
|
||||||
|
'"confidence": <int 0-100>, "rationale": "<string>"}'
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
f"You are a professional {market_name} trading analyst.\n"
|
||||||
|
"Analyze the following market data and decide whether to "
|
||||||
|
"BUY, SELL, or HOLD.\n\n"
|
||||||
|
f"{market_info}\n\n"
|
||||||
|
"You MUST respond with ONLY valid JSON in the following format:\n"
|
||||||
|
f"{json_format}\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
"- action must be exactly one of: BUY, SELL, HOLD\n"
|
||||||
|
"- confidence must be an integer from 0 to 100\n"
|
||||||
|
"- rationale must explain your reasoning concisely\n"
|
||||||
|
"- Do NOT wrap the JSON in markdown code blocks\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
def build_prompt_sync(self, market_data: dict[str, Any]) -> str:
|
||||||
|
"""Synchronous version of build_prompt (for backward compatibility).
|
||||||
|
|
||||||
|
This version does NOT include external data integration.
|
||||||
|
Use async build_prompt() for full functionality.
|
||||||
|
"""
|
||||||
|
market_name = market_data.get("market_name", "Korean stock market")
|
||||||
|
|
||||||
|
# Build market data section dynamically based on available fields
|
||||||
|
market_info_lines = [
|
||||||
|
f"Market: {market_name}",
|
||||||
|
f"Stock Code: {market_data['stock_code']}",
|
||||||
|
f"Current Price: {market_data['current_price']}",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add orderbook if available (domestic markets)
|
||||||
|
if "orderbook" in market_data:
|
||||||
|
market_info_lines.append(
|
||||||
|
f"Orderbook: {json.dumps(market_data['orderbook'], ensure_ascii=False)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add foreigner net if non-zero
|
||||||
|
if market_data.get("foreigner_net", 0) != 0:
|
||||||
|
market_info_lines.append(
|
||||||
|
f"Foreigner Net Buy/Sell: {market_data['foreigner_net']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
market_info = "\n".join(market_info_lines)
|
||||||
|
|
||||||
json_format = (
|
json_format = (
|
||||||
'{"action": "BUY"|"SELL"|"HOLD", '
|
'{"action": "BUY"|"SELL"|"HOLD", '
|
||||||
'"confidence": <int 0-100>, "rationale": "<string>"}'
|
'"confidence": <int 0-100>, "rationale": "<string>"}'
|
||||||
@@ -177,8 +375,18 @@ class GeminiClient:
|
|||||||
# API Call
|
# API Call
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
async def decide(self, market_data: dict[str, Any]) -> TradeDecision:
|
async def decide(
|
||||||
"""Build prompt, call Gemini, and return a parsed decision."""
|
self, market_data: dict[str, Any], news_sentiment: NewsSentiment | None = None
|
||||||
|
) -> TradeDecision:
|
||||||
|
"""Build prompt, call Gemini, and return a parsed decision.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
market_data: Market data dictionary with price, orderbook, etc.
|
||||||
|
news_sentiment: Optional pre-fetched news sentiment
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Parsed TradeDecision
|
||||||
|
"""
|
||||||
# Check cache first
|
# Check cache first
|
||||||
if self._cache:
|
if self._cache:
|
||||||
cached_decision = self._cache.get(market_data)
|
cached_decision = self._cache.get(market_data)
|
||||||
@@ -206,7 +414,7 @@ class GeminiClient:
|
|||||||
if self._enable_optimization:
|
if self._enable_optimization:
|
||||||
prompt = self._optimizer.build_compressed_prompt(market_data)
|
prompt = self._optimizer.build_compressed_prompt(market_data)
|
||||||
else:
|
else:
|
||||||
prompt = self.build_prompt(market_data)
|
prompt = await self.build_prompt(market_data, news_sentiment)
|
||||||
|
|
||||||
# Estimate tokens
|
# Estimate tokens
|
||||||
token_count = self._optimizer.estimate_tokens(prompt)
|
token_count = self._optimizer.estimate_tokens(prompt)
|
||||||
|
|||||||
@@ -19,6 +19,15 @@ class Settings(BaseSettings):
|
|||||||
GEMINI_API_KEY: str
|
GEMINI_API_KEY: str
|
||||||
GEMINI_MODEL: str = "gemini-pro"
|
GEMINI_MODEL: str = "gemini-pro"
|
||||||
|
|
||||||
|
# External Data APIs (optional — for data-driven decisions)
|
||||||
|
NEWS_API_KEY: str | None = None
|
||||||
|
NEWS_API_PROVIDER: str = "alphavantage" # "alphavantage" or "newsapi"
|
||||||
|
MARKET_DATA_API_KEY: str | None = None
|
||||||
|
|
||||||
|
# Legacy field names (for backward compatibility)
|
||||||
|
ALPHA_VANTAGE_API_KEY: str | None = None
|
||||||
|
NEWSAPI_KEY: str | None = None
|
||||||
|
|
||||||
# Risk Management
|
# Risk Management
|
||||||
CIRCUIT_BREAKER_PCT: float = Field(default=-3.0, le=0.0)
|
CIRCUIT_BREAKER_PCT: float = Field(default=-3.0, le=0.0)
|
||||||
FAT_FINGER_PCT: float = Field(default=30.0, gt=0.0, le=100.0)
|
FAT_FINGER_PCT: float = Field(default=30.0, gt=0.0, le=100.0)
|
||||||
@@ -36,6 +45,15 @@ class Settings(BaseSettings):
|
|||||||
# Market selection (comma-separated market codes)
|
# Market selection (comma-separated market codes)
|
||||||
ENABLED_MARKETS: str = "KR"
|
ENABLED_MARKETS: str = "KR"
|
||||||
|
|
||||||
|
# Backup and Disaster Recovery (optional)
|
||||||
|
BACKUP_ENABLED: bool = True
|
||||||
|
BACKUP_DIR: str = "data/backups"
|
||||||
|
S3_ENDPOINT_URL: str | None = None # For MinIO, Backblaze B2, etc.
|
||||||
|
S3_ACCESS_KEY: str | None = None
|
||||||
|
S3_SECRET_KEY: str | None = None
|
||||||
|
S3_BUCKET_NAME: str | None = None
|
||||||
|
S3_REGION: str = "us-east-1"
|
||||||
|
|
||||||
model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}
|
model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
205
src/data/README.md
Normal file
205
src/data/README.md
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
# External Data Integration
|
||||||
|
|
||||||
|
This module provides objective external data sources to enhance trading decisions beyond just market prices and user input.
|
||||||
|
|
||||||
|
## Modules
|
||||||
|
|
||||||
|
### `news_api.py` - News Sentiment Analysis
|
||||||
|
|
||||||
|
Fetches real-time news for stocks with sentiment scoring.
|
||||||
|
|
||||||
|
**Features:**
|
||||||
|
- Alpha Vantage and NewsAPI.org support
|
||||||
|
- Sentiment scoring (-1.0 to +1.0)
|
||||||
|
- 5-minute caching to minimize API quota usage
|
||||||
|
- Graceful fallback when API unavailable
|
||||||
|
|
||||||
|
**Usage:**
|
||||||
|
```python
|
||||||
|
from src.data.news_api import NewsAPI
|
||||||
|
|
||||||
|
# Initialize with API key
|
||||||
|
news_api = NewsAPI(api_key="your_key", provider="alphavantage")
|
||||||
|
|
||||||
|
# Fetch news sentiment
|
||||||
|
sentiment = await news_api.get_news_sentiment("AAPL")
|
||||||
|
if sentiment:
|
||||||
|
print(f"Average sentiment: {sentiment.avg_sentiment}")
|
||||||
|
for article in sentiment.articles[:3]:
|
||||||
|
print(f"{article.title} ({article.sentiment_score})")
|
||||||
|
```
|
||||||
|
|
||||||
|
### `economic_calendar.py` - Major Economic Events
|
||||||
|
|
||||||
|
Tracks FOMC meetings, GDP releases, CPI, earnings calendars, and other market-moving events.
|
||||||
|
|
||||||
|
**Features:**
|
||||||
|
- High-impact event tracking (FOMC, GDP, CPI)
|
||||||
|
- Earnings calendar per stock
|
||||||
|
- Event proximity checking
|
||||||
|
- Hardcoded major events for 2026 (no API required)
|
||||||
|
|
||||||
|
**Usage:**
|
||||||
|
```python
|
||||||
|
from src.data.economic_calendar import EconomicCalendar
|
||||||
|
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
calendar.load_hardcoded_events()
|
||||||
|
|
||||||
|
# Get upcoming high-impact events
|
||||||
|
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="HIGH")
|
||||||
|
print(f"High-impact events: {upcoming.high_impact_count}")
|
||||||
|
|
||||||
|
# Check if near earnings
|
||||||
|
earnings_date = calendar.get_earnings_date("AAPL")
|
||||||
|
if earnings_date:
|
||||||
|
print(f"Next earnings: {earnings_date}")
|
||||||
|
|
||||||
|
# Check for high volatility period
|
||||||
|
if calendar.is_high_volatility_period(hours_ahead=24):
|
||||||
|
print("High-impact event imminent!")
|
||||||
|
```
|
||||||
|
|
||||||
|
### `market_data.py` - Market Indicators
|
||||||
|
|
||||||
|
Provides market breadth, sector performance, and sentiment indicators.
|
||||||
|
|
||||||
|
**Features:**
|
||||||
|
- Market sentiment levels (Fear & Greed equivalent)
|
||||||
|
- Market breadth (advancing/declining stocks)
|
||||||
|
- Sector performance tracking
|
||||||
|
- Fear/Greed score calculation
|
||||||
|
|
||||||
|
**Usage:**
|
||||||
|
```python
|
||||||
|
from src.data.market_data import MarketData
|
||||||
|
|
||||||
|
market_data = MarketData(api_key="your_key")
|
||||||
|
|
||||||
|
# Get market sentiment
|
||||||
|
sentiment = market_data.get_market_sentiment()
|
||||||
|
print(f"Market sentiment: {sentiment.name}")
|
||||||
|
|
||||||
|
# Get full indicators
|
||||||
|
indicators = market_data.get_market_indicators("US")
|
||||||
|
print(f"Sentiment: {indicators.sentiment.name}")
|
||||||
|
print(f"A/D Ratio: {indicators.breadth.advance_decline_ratio}")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Integration with GeminiClient
|
||||||
|
|
||||||
|
The external data sources are seamlessly integrated into the AI decision engine:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.brain.gemini_client import GeminiClient
|
||||||
|
from src.data.news_api import NewsAPI
|
||||||
|
from src.data.economic_calendar import EconomicCalendar
|
||||||
|
from src.data.market_data import MarketData
|
||||||
|
from src.config import Settings
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
|
|
||||||
|
# Initialize data sources
|
||||||
|
news_api = NewsAPI(api_key=settings.NEWS_API_KEY, provider=settings.NEWS_API_PROVIDER)
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
calendar.load_hardcoded_events()
|
||||||
|
market_data = MarketData(api_key=settings.MARKET_DATA_API_KEY)
|
||||||
|
|
||||||
|
# Create enhanced client
|
||||||
|
client = GeminiClient(
|
||||||
|
settings,
|
||||||
|
news_api=news_api,
|
||||||
|
economic_calendar=calendar,
|
||||||
|
market_data=market_data
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make decision with external context
|
||||||
|
market_data_dict = {
|
||||||
|
"stock_code": "AAPL",
|
||||||
|
"current_price": 180.0,
|
||||||
|
"market_name": "US stock market"
|
||||||
|
}
|
||||||
|
|
||||||
|
decision = await client.decide(market_data_dict)
|
||||||
|
```
|
||||||
|
|
||||||
|
The external data is automatically included in the prompt sent to Gemini:
|
||||||
|
|
||||||
|
```
|
||||||
|
Market: US stock market
|
||||||
|
Stock Code: AAPL
|
||||||
|
Current Price: 180.0
|
||||||
|
|
||||||
|
EXTERNAL DATA:
|
||||||
|
News Sentiment: 0.85 (from 10 articles)
|
||||||
|
1. [Reuters] Apple hits record high (sentiment: 0.92)
|
||||||
|
2. [Bloomberg] Strong iPhone sales (sentiment: 0.78)
|
||||||
|
3. [CNBC] Tech sector rallying (sentiment: 0.85)
|
||||||
|
|
||||||
|
Upcoming High-Impact Events: 2 in next 7 days
|
||||||
|
Next: FOMC Meeting (FOMC) on 2026-03-18
|
||||||
|
Earnings: AAPL on 2026-02-10
|
||||||
|
|
||||||
|
Market Sentiment: GREED
|
||||||
|
Advance/Decline Ratio: 2.35
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Add these to your `.env` file:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# External Data APIs (optional)
|
||||||
|
NEWS_API_KEY=your_alpha_vantage_key
|
||||||
|
NEWS_API_PROVIDER=alphavantage # or "newsapi"
|
||||||
|
MARKET_DATA_API_KEY=your_market_data_key
|
||||||
|
```
|
||||||
|
|
||||||
|
## API Recommendations
|
||||||
|
|
||||||
|
### Alpha Vantage (News)
|
||||||
|
- **Free tier:** 5 calls/min, 500 calls/day
|
||||||
|
- **Pros:** Provides sentiment scores, no credit card required
|
||||||
|
- **URL:** https://www.alphavantage.co/
|
||||||
|
|
||||||
|
### NewsAPI.org
|
||||||
|
- **Free tier:** 100 requests/day
|
||||||
|
- **Pros:** Large news coverage, easy to use
|
||||||
|
- **Cons:** No sentiment scores (we use keyword heuristics)
|
||||||
|
- **URL:** https://newsapi.org/
|
||||||
|
|
||||||
|
## Caching Strategy
|
||||||
|
|
||||||
|
To minimize API quota usage:
|
||||||
|
|
||||||
|
1. **News:** 5-minute TTL cache per stock
|
||||||
|
2. **Economic Calendar:** Loaded once at startup (hardcoded events)
|
||||||
|
3. **Market Data:** Fetched per decision (lightweight)
|
||||||
|
|
||||||
|
## Graceful Degradation
|
||||||
|
|
||||||
|
The system works gracefully without external data:
|
||||||
|
|
||||||
|
- If no API keys provided → decisions work with just market prices
|
||||||
|
- If API fails → decision continues without external context
|
||||||
|
- If cache expired → attempts refetch, falls back to no data
|
||||||
|
- Errors are logged but never block trading decisions
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
All modules have comprehensive test coverage (81%+):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pytest tests/test_data_integration.py -v --cov=src/data
|
||||||
|
```
|
||||||
|
|
||||||
|
Tests use mocks to avoid requiring real API keys.
|
||||||
|
|
||||||
|
## Future Enhancements
|
||||||
|
|
||||||
|
- Twitter/X sentiment analysis
|
||||||
|
- Reddit WallStreetBets sentiment
|
||||||
|
- Options flow data
|
||||||
|
- Insider trading activity
|
||||||
|
- Analyst upgrades/downgrades
|
||||||
|
- Real-time economic data APIs
|
||||||
5
src/data/__init__.py
Normal file
5
src/data/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""External data integration for objective decision-making."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
__all__ = ["NewsAPI", "EconomicCalendar", "MarketData"]
|
||||||
219
src/data/economic_calendar.py
Normal file
219
src/data/economic_calendar.py
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
"""Economic calendar integration for major market events.
|
||||||
|
|
||||||
|
Tracks FOMC meetings, GDP releases, CPI, earnings calendars, and other
|
||||||
|
market-moving events.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EconomicEvent:
|
||||||
|
"""Single economic event."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
event_type: str # "FOMC", "GDP", "CPI", "EARNINGS", etc.
|
||||||
|
datetime: datetime
|
||||||
|
impact: str # "HIGH", "MEDIUM", "LOW"
|
||||||
|
country: str
|
||||||
|
description: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UpcomingEvents:
|
||||||
|
"""Collection of upcoming economic events."""
|
||||||
|
|
||||||
|
events: list[EconomicEvent]
|
||||||
|
high_impact_count: int
|
||||||
|
next_major_event: EconomicEvent | None
|
||||||
|
|
||||||
|
|
||||||
|
class EconomicCalendar:
|
||||||
|
"""Economic calendar with event tracking and impact scoring."""
|
||||||
|
|
||||||
|
def __init__(self, api_key: str | None = None) -> None:
|
||||||
|
"""Initialize economic calendar.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: API key for calendar provider (None for testing/hardcoded)
|
||||||
|
"""
|
||||||
|
self._api_key = api_key
|
||||||
|
# For now, use hardcoded major events (can be extended with API)
|
||||||
|
self._events: list[EconomicEvent] = []
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def get_upcoming_events(
|
||||||
|
self, days_ahead: int = 7, min_impact: str = "MEDIUM"
|
||||||
|
) -> UpcomingEvents:
|
||||||
|
"""Get upcoming economic events within specified timeframe.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
days_ahead: Number of days to look ahead
|
||||||
|
min_impact: Minimum impact level ("LOW", "MEDIUM", "HIGH")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UpcomingEvents with filtered events
|
||||||
|
"""
|
||||||
|
now = datetime.now()
|
||||||
|
end_date = now + timedelta(days=days_ahead)
|
||||||
|
|
||||||
|
# Filter events by timeframe and impact
|
||||||
|
upcoming = [
|
||||||
|
event
|
||||||
|
for event in self._events
|
||||||
|
if now <= event.datetime <= end_date
|
||||||
|
and self._impact_level(event.impact) >= self._impact_level(min_impact)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Sort by datetime
|
||||||
|
upcoming.sort(key=lambda e: e.datetime)
|
||||||
|
|
||||||
|
# Count high-impact events
|
||||||
|
high_impact_count = sum(1 for e in upcoming if e.impact == "HIGH")
|
||||||
|
|
||||||
|
# Get next major event
|
||||||
|
next_major = None
|
||||||
|
for event in upcoming:
|
||||||
|
if event.impact == "HIGH":
|
||||||
|
next_major = event
|
||||||
|
break
|
||||||
|
|
||||||
|
return UpcomingEvents(
|
||||||
|
events=upcoming,
|
||||||
|
high_impact_count=high_impact_count,
|
||||||
|
next_major_event=next_major,
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_event(self, event: EconomicEvent) -> None:
|
||||||
|
"""Add an economic event to the calendar."""
|
||||||
|
self._events.append(event)
|
||||||
|
|
||||||
|
def clear_events(self) -> None:
|
||||||
|
"""Clear all events (useful for testing)."""
|
||||||
|
self._events.clear()
|
||||||
|
|
||||||
|
def get_earnings_date(self, stock_code: str) -> datetime | None:
|
||||||
|
"""Get next earnings date for a stock.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stock_code: Stock ticker symbol
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Next earnings datetime or None if not found
|
||||||
|
"""
|
||||||
|
now = datetime.now()
|
||||||
|
earnings_events = [
|
||||||
|
event
|
||||||
|
for event in self._events
|
||||||
|
if event.event_type == "EARNINGS"
|
||||||
|
and stock_code.upper() in event.name.upper()
|
||||||
|
and event.datetime > now
|
||||||
|
]
|
||||||
|
|
||||||
|
if not earnings_events:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Return earliest upcoming earnings
|
||||||
|
earnings_events.sort(key=lambda e: e.datetime)
|
||||||
|
return earnings_events[0].datetime
|
||||||
|
|
||||||
|
def load_hardcoded_events(self) -> None:
|
||||||
|
"""Load hardcoded major economic events for 2026.
|
||||||
|
|
||||||
|
This is a fallback when no API is available.
|
||||||
|
"""
|
||||||
|
# Major FOMC meetings in 2026 (estimated)
|
||||||
|
fomc_dates = [
|
||||||
|
datetime(2026, 3, 18),
|
||||||
|
datetime(2026, 5, 6),
|
||||||
|
datetime(2026, 6, 17),
|
||||||
|
datetime(2026, 7, 29),
|
||||||
|
datetime(2026, 9, 16),
|
||||||
|
datetime(2026, 11, 4),
|
||||||
|
datetime(2026, 12, 16),
|
||||||
|
]
|
||||||
|
|
||||||
|
for date in fomc_dates:
|
||||||
|
self.add_event(
|
||||||
|
EconomicEvent(
|
||||||
|
name="FOMC Meeting",
|
||||||
|
event_type="FOMC",
|
||||||
|
datetime=date,
|
||||||
|
impact="HIGH",
|
||||||
|
country="US",
|
||||||
|
description="Federal Reserve interest rate decision",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Quarterly GDP releases (estimated)
|
||||||
|
gdp_dates = [
|
||||||
|
datetime(2026, 4, 28),
|
||||||
|
datetime(2026, 7, 30),
|
||||||
|
datetime(2026, 10, 29),
|
||||||
|
]
|
||||||
|
|
||||||
|
for date in gdp_dates:
|
||||||
|
self.add_event(
|
||||||
|
EconomicEvent(
|
||||||
|
name="US GDP Release",
|
||||||
|
event_type="GDP",
|
||||||
|
datetime=date,
|
||||||
|
impact="HIGH",
|
||||||
|
country="US",
|
||||||
|
description="Quarterly GDP growth rate",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Monthly CPI releases (12th of each month, estimated)
|
||||||
|
for month in range(1, 13):
|
||||||
|
try:
|
||||||
|
cpi_date = datetime(2026, month, 12)
|
||||||
|
self.add_event(
|
||||||
|
EconomicEvent(
|
||||||
|
name="US CPI Release",
|
||||||
|
event_type="CPI",
|
||||||
|
datetime=cpi_date,
|
||||||
|
impact="HIGH",
|
||||||
|
country="US",
|
||||||
|
description="Consumer Price Index inflation data",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _impact_level(self, impact: str) -> int:
|
||||||
|
"""Convert impact string to numeric level."""
|
||||||
|
levels = {"LOW": 1, "MEDIUM": 2, "HIGH": 3}
|
||||||
|
return levels.get(impact.upper(), 0)
|
||||||
|
|
||||||
|
def is_high_volatility_period(self, hours_ahead: int = 24) -> bool:
|
||||||
|
"""Check if we're near a high-impact event.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hours_ahead: Number of hours to look ahead
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if high-impact event is imminent
|
||||||
|
"""
|
||||||
|
now = datetime.now()
|
||||||
|
threshold = now + timedelta(hours=hours_ahead)
|
||||||
|
|
||||||
|
for event in self._events:
|
||||||
|
if event.impact == "HIGH" and now <= event.datetime <= threshold:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
198
src/data/market_data.py
Normal file
198
src/data/market_data.py
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
"""Additional market data indicators beyond basic price data.
|
||||||
|
|
||||||
|
Provides market breadth, sector performance, and market sentiment indicators.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MarketSentiment(Enum):
|
||||||
|
"""Overall market sentiment levels."""
|
||||||
|
|
||||||
|
EXTREME_FEAR = 1
|
||||||
|
FEAR = 2
|
||||||
|
NEUTRAL = 3
|
||||||
|
GREED = 4
|
||||||
|
EXTREME_GREED = 5
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SectorPerformance:
|
||||||
|
"""Performance metrics for a market sector."""
|
||||||
|
|
||||||
|
sector_name: str
|
||||||
|
daily_change_pct: float
|
||||||
|
weekly_change_pct: float
|
||||||
|
leader_stock: str # Best performing stock in sector
|
||||||
|
laggard_stock: str # Worst performing stock in sector
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MarketBreadth:
|
||||||
|
"""Market breadth indicators."""
|
||||||
|
|
||||||
|
advancing_stocks: int
|
||||||
|
declining_stocks: int
|
||||||
|
unchanged_stocks: int
|
||||||
|
new_highs: int
|
||||||
|
new_lows: int
|
||||||
|
advance_decline_ratio: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MarketIndicators:
|
||||||
|
"""Aggregated market indicators."""
|
||||||
|
|
||||||
|
sentiment: MarketSentiment
|
||||||
|
breadth: MarketBreadth
|
||||||
|
sector_performance: list[SectorPerformance]
|
||||||
|
vix_level: float | None # Volatility index if available
|
||||||
|
|
||||||
|
|
||||||
|
class MarketData:
|
||||||
|
"""Market data provider for additional indicators."""
|
||||||
|
|
||||||
|
def __init__(self, api_key: str | None = None) -> None:
|
||||||
|
"""Initialize market data provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: API key for data provider (None for testing)
|
||||||
|
"""
|
||||||
|
self._api_key = api_key
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def get_market_sentiment(self) -> MarketSentiment:
|
||||||
|
"""Get current market sentiment level.
|
||||||
|
|
||||||
|
This is a simplified version. In production, this would integrate
|
||||||
|
with Fear & Greed Index or similar sentiment indicators.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MarketSentiment enum value
|
||||||
|
"""
|
||||||
|
# Default to neutral when API not available
|
||||||
|
if self._api_key is None:
|
||||||
|
logger.debug("No market data API key — returning NEUTRAL sentiment")
|
||||||
|
return MarketSentiment.NEUTRAL
|
||||||
|
|
||||||
|
# TODO: Integrate with actual sentiment API
|
||||||
|
return MarketSentiment.NEUTRAL
|
||||||
|
|
||||||
|
def get_market_breadth(self, market: str = "US") -> MarketBreadth | None:
|
||||||
|
"""Get market breadth indicators.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
market: Market code ("US", "KR", etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MarketBreadth object or None if unavailable
|
||||||
|
"""
|
||||||
|
if self._api_key is None:
|
||||||
|
logger.debug("No market data API key — returning None for breadth")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# TODO: Integrate with actual market breadth API
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_sector_performance(
|
||||||
|
self, market: str = "US"
|
||||||
|
) -> list[SectorPerformance]:
|
||||||
|
"""Get sector performance rankings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
market: Market code ("US", "KR", etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of SectorPerformance objects, sorted by daily change
|
||||||
|
"""
|
||||||
|
if self._api_key is None:
|
||||||
|
logger.debug("No market data API key — returning empty sector list")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# TODO: Integrate with actual sector performance API
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_market_indicators(self, market: str = "US") -> MarketIndicators:
|
||||||
|
"""Get aggregated market indicators.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
market: Market code ("US", "KR", etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MarketIndicators with all available data
|
||||||
|
"""
|
||||||
|
sentiment = self.get_market_sentiment()
|
||||||
|
breadth = self.get_market_breadth(market)
|
||||||
|
sectors = self.get_sector_performance(market)
|
||||||
|
|
||||||
|
# Default breadth if unavailable
|
||||||
|
if breadth is None:
|
||||||
|
breadth = MarketBreadth(
|
||||||
|
advancing_stocks=0,
|
||||||
|
declining_stocks=0,
|
||||||
|
unchanged_stocks=0,
|
||||||
|
new_highs=0,
|
||||||
|
new_lows=0,
|
||||||
|
advance_decline_ratio=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
return MarketIndicators(
|
||||||
|
sentiment=sentiment,
|
||||||
|
breadth=breadth,
|
||||||
|
sector_performance=sectors,
|
||||||
|
vix_level=None, # TODO: Add VIX integration
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Helper Methods
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def calculate_fear_greed_score(
|
||||||
|
self, breadth: MarketBreadth, vix: float | None = None
|
||||||
|
) -> int:
|
||||||
|
"""Calculate a simple fear/greed score (0-100).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
breadth: Market breadth data
|
||||||
|
vix: VIX level (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Score from 0 (extreme fear) to 100 (extreme greed)
|
||||||
|
"""
|
||||||
|
# Start at neutral
|
||||||
|
score = 50
|
||||||
|
|
||||||
|
# Adjust based on advance/decline ratio
|
||||||
|
if breadth.advance_decline_ratio > 1.5:
|
||||||
|
score += 20
|
||||||
|
elif breadth.advance_decline_ratio > 1.0:
|
||||||
|
score += 10
|
||||||
|
elif breadth.advance_decline_ratio < 0.5:
|
||||||
|
score -= 20
|
||||||
|
elif breadth.advance_decline_ratio < 1.0:
|
||||||
|
score -= 10
|
||||||
|
|
||||||
|
# Adjust based on new highs/lows
|
||||||
|
if breadth.new_highs > breadth.new_lows * 2:
|
||||||
|
score += 15
|
||||||
|
elif breadth.new_lows > breadth.new_highs * 2:
|
||||||
|
score -= 15
|
||||||
|
|
||||||
|
# Adjust based on VIX if available
|
||||||
|
if vix is not None:
|
||||||
|
if vix > 30: # High volatility = fear
|
||||||
|
score -= 15
|
||||||
|
elif vix < 15: # Low volatility = complacency/greed
|
||||||
|
score += 10
|
||||||
|
|
||||||
|
# Clamp to 0-100
|
||||||
|
return max(0, min(100, score))
|
||||||
316
src/data/news_api.py
Normal file
316
src/data/news_api.py
Normal file
@@ -0,0 +1,316 @@
|
|||||||
|
"""News API integration with sentiment analysis and caching.
|
||||||
|
|
||||||
|
Fetches real-time news for stocks using free-tier APIs (Alpha Vantage or NewsAPI).
|
||||||
|
Includes 5-minute caching to minimize API quota usage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Cache entries expire after 5 minutes
|
||||||
|
CACHE_TTL_SECONDS = 300
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NewsArticle:
|
||||||
|
"""Single news article with sentiment."""
|
||||||
|
|
||||||
|
title: str
|
||||||
|
summary: str
|
||||||
|
source: str
|
||||||
|
published_at: str
|
||||||
|
sentiment_score: float # -1.0 (negative) to +1.0 (positive)
|
||||||
|
url: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NewsSentiment:
|
||||||
|
"""Aggregated news sentiment for a stock."""
|
||||||
|
|
||||||
|
stock_code: str
|
||||||
|
articles: list[NewsArticle]
|
||||||
|
avg_sentiment: float # Average sentiment across all articles
|
||||||
|
article_count: int
|
||||||
|
fetched_at: float # Unix timestamp
|
||||||
|
|
||||||
|
|
||||||
|
class NewsAPI:
|
||||||
|
"""News API client with sentiment analysis and caching."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str | None = None,
|
||||||
|
provider: str = "alphavantage",
|
||||||
|
cache_ttl: int = CACHE_TTL_SECONDS,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize NewsAPI client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: API key for the news provider (None for testing)
|
||||||
|
provider: News provider ("alphavantage" or "newsapi")
|
||||||
|
cache_ttl: Cache time-to-live in seconds
|
||||||
|
"""
|
||||||
|
self._api_key = api_key
|
||||||
|
self._provider = provider
|
||||||
|
self._cache_ttl = cache_ttl
|
||||||
|
self._cache: dict[str, NewsSentiment] = {}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def get_news_sentiment(self, stock_code: str) -> NewsSentiment | None:
|
||||||
|
"""Fetch news sentiment for a stock with caching.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stock_code: Stock ticker symbol (e.g., "AAPL", "005930")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
NewsSentiment object or None if fetch fails or API unavailable
|
||||||
|
"""
|
||||||
|
# Check cache first
|
||||||
|
cached = self._get_from_cache(stock_code)
|
||||||
|
if cached is not None:
|
||||||
|
logger.debug("News cache hit for %s", stock_code)
|
||||||
|
return cached
|
||||||
|
|
||||||
|
# API key required for real requests
|
||||||
|
if self._api_key is None:
|
||||||
|
logger.warning("No news API key provided — returning None")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Fetch from API
|
||||||
|
try:
|
||||||
|
sentiment = await self._fetch_news(stock_code)
|
||||||
|
if sentiment is not None:
|
||||||
|
self._cache[stock_code] = sentiment
|
||||||
|
return sentiment
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Failed to fetch news for %s: %s", stock_code, exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def clear_cache(self) -> None:
|
||||||
|
"""Clear the news cache (useful for testing)."""
|
||||||
|
self._cache.clear()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Cache Management
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _get_from_cache(self, stock_code: str) -> NewsSentiment | None:
|
||||||
|
"""Retrieve cached sentiment if not expired."""
|
||||||
|
if stock_code not in self._cache:
|
||||||
|
return None
|
||||||
|
|
||||||
|
cached = self._cache[stock_code]
|
||||||
|
age = time.time() - cached.fetched_at
|
||||||
|
|
||||||
|
if age > self._cache_ttl:
|
||||||
|
logger.debug("News cache expired for %s (age: %.1fs)", stock_code, age)
|
||||||
|
del self._cache[stock_code]
|
||||||
|
return None
|
||||||
|
|
||||||
|
return cached
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# API Fetching
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _fetch_news(self, stock_code: str) -> NewsSentiment | None:
|
||||||
|
"""Fetch news from the provider API."""
|
||||||
|
if self._provider == "alphavantage":
|
||||||
|
return await self._fetch_alphavantage(stock_code)
|
||||||
|
elif self._provider == "newsapi":
|
||||||
|
return await self._fetch_newsapi(stock_code)
|
||||||
|
else:
|
||||||
|
logger.error("Unknown news provider: %s", self._provider)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _fetch_alphavantage(self, stock_code: str) -> NewsSentiment | None:
|
||||||
|
"""Fetch news from Alpha Vantage News Sentiment API."""
|
||||||
|
url = "https://www.alphavantage.co/query"
|
||||||
|
params = {
|
||||||
|
"function": "NEWS_SENTIMENT",
|
||||||
|
"tickers": stock_code,
|
||||||
|
"apikey": self._api_key,
|
||||||
|
"limit": 10, # Fetch top 10 articles
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(url, params=params, timeout=10) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
logger.error(
|
||||||
|
"Alpha Vantage API error: HTTP %d", resp.status
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
data = await resp.json()
|
||||||
|
return self._parse_alphavantage_response(stock_code, data)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Alpha Vantage request failed: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _fetch_newsapi(self, stock_code: str) -> NewsSentiment | None:
|
||||||
|
"""Fetch news from NewsAPI.org."""
|
||||||
|
url = "https://newsapi.org/v2/everything"
|
||||||
|
params = {
|
||||||
|
"q": stock_code,
|
||||||
|
"apiKey": self._api_key,
|
||||||
|
"pageSize": 10,
|
||||||
|
"sortBy": "publishedAt",
|
||||||
|
"language": "en",
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(url, params=params, timeout=10) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
logger.error("NewsAPI error: HTTP %d", resp.status)
|
||||||
|
return None
|
||||||
|
|
||||||
|
data = await resp.json()
|
||||||
|
return self._parse_newsapi_response(stock_code, data)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("NewsAPI request failed: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Response Parsing
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _parse_alphavantage_response(
|
||||||
|
self, stock_code: str, data: dict[str, Any]
|
||||||
|
) -> NewsSentiment | None:
|
||||||
|
"""Parse Alpha Vantage API response."""
|
||||||
|
if "feed" not in data:
|
||||||
|
logger.warning("No 'feed' key in Alpha Vantage response")
|
||||||
|
return None
|
||||||
|
|
||||||
|
articles: list[NewsArticle] = []
|
||||||
|
for item in data["feed"]:
|
||||||
|
# Extract sentiment for this specific ticker
|
||||||
|
ticker_sentiment = self._extract_ticker_sentiment(item, stock_code)
|
||||||
|
|
||||||
|
article = NewsArticle(
|
||||||
|
title=item.get("title", ""),
|
||||||
|
summary=item.get("summary", "")[:200], # Truncate long summaries
|
||||||
|
source=item.get("source", "Unknown"),
|
||||||
|
published_at=item.get("time_published", ""),
|
||||||
|
sentiment_score=ticker_sentiment,
|
||||||
|
url=item.get("url", ""),
|
||||||
|
)
|
||||||
|
articles.append(article)
|
||||||
|
|
||||||
|
if not articles:
|
||||||
|
return None
|
||||||
|
|
||||||
|
avg_sentiment = sum(a.sentiment_score for a in articles) / len(articles)
|
||||||
|
|
||||||
|
return NewsSentiment(
|
||||||
|
stock_code=stock_code,
|
||||||
|
articles=articles,
|
||||||
|
avg_sentiment=avg_sentiment,
|
||||||
|
article_count=len(articles),
|
||||||
|
fetched_at=time.time(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _extract_ticker_sentiment(
|
||||||
|
self, item: dict[str, Any], stock_code: str
|
||||||
|
) -> float:
|
||||||
|
"""Extract sentiment score for specific ticker from article."""
|
||||||
|
ticker_sentiments = item.get("ticker_sentiment", [])
|
||||||
|
for ts in ticker_sentiments:
|
||||||
|
if ts.get("ticker", "").upper() == stock_code.upper():
|
||||||
|
# Alpha Vantage provides sentiment_score as string
|
||||||
|
score_str = ts.get("ticker_sentiment_score", "0")
|
||||||
|
try:
|
||||||
|
return float(score_str)
|
||||||
|
except ValueError:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# Fallback to overall sentiment if ticker-specific not found
|
||||||
|
overall_sentiment = item.get("overall_sentiment_score", "0")
|
||||||
|
try:
|
||||||
|
return float(overall_sentiment)
|
||||||
|
except ValueError:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
def _parse_newsapi_response(
|
||||||
|
self, stock_code: str, data: dict[str, Any]
|
||||||
|
) -> NewsSentiment | None:
|
||||||
|
"""Parse NewsAPI.org response.
|
||||||
|
|
||||||
|
Note: NewsAPI doesn't provide sentiment scores, so we use a
|
||||||
|
simple heuristic based on title keywords.
|
||||||
|
"""
|
||||||
|
if data.get("status") != "ok" or "articles" not in data:
|
||||||
|
logger.warning("Invalid NewsAPI response")
|
||||||
|
return None
|
||||||
|
|
||||||
|
articles: list[NewsArticle] = []
|
||||||
|
for item in data["articles"]:
|
||||||
|
# Simple sentiment heuristic based on keywords
|
||||||
|
sentiment = self._estimate_sentiment_from_text(
|
||||||
|
item.get("title", "") + " " + item.get("description", "")
|
||||||
|
)
|
||||||
|
|
||||||
|
article = NewsArticle(
|
||||||
|
title=item.get("title", ""),
|
||||||
|
summary=item.get("description", "")[:200],
|
||||||
|
source=item.get("source", {}).get("name", "Unknown"),
|
||||||
|
published_at=item.get("publishedAt", ""),
|
||||||
|
sentiment_score=sentiment,
|
||||||
|
url=item.get("url", ""),
|
||||||
|
)
|
||||||
|
articles.append(article)
|
||||||
|
|
||||||
|
if not articles:
|
||||||
|
return None
|
||||||
|
|
||||||
|
avg_sentiment = sum(a.sentiment_score for a in articles) / len(articles)
|
||||||
|
|
||||||
|
return NewsSentiment(
|
||||||
|
stock_code=stock_code,
|
||||||
|
articles=articles,
|
||||||
|
avg_sentiment=avg_sentiment,
|
||||||
|
article_count=len(articles),
|
||||||
|
fetched_at=time.time(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _estimate_sentiment_from_text(self, text: str) -> float:
|
||||||
|
"""Simple keyword-based sentiment estimation.
|
||||||
|
|
||||||
|
This is a fallback for APIs that don't provide sentiment scores.
|
||||||
|
Returns a score between -1.0 and +1.0.
|
||||||
|
"""
|
||||||
|
text_lower = text.lower()
|
||||||
|
|
||||||
|
positive_keywords = [
|
||||||
|
"surge", "jump", "gain", "rise", "soar", "rally", "profit",
|
||||||
|
"growth", "upgrade", "beat", "strong", "bullish", "breakthrough",
|
||||||
|
]
|
||||||
|
negative_keywords = [
|
||||||
|
"plunge", "fall", "drop", "decline", "crash", "loss", "weak",
|
||||||
|
"downgrade", "miss", "bearish", "concern", "risk", "warning",
|
||||||
|
]
|
||||||
|
|
||||||
|
positive_count = sum(1 for kw in positive_keywords if kw in text_lower)
|
||||||
|
negative_count = sum(1 for kw in negative_keywords if kw in text_lower)
|
||||||
|
|
||||||
|
total = positive_count + negative_count
|
||||||
|
if total == 0:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# Normalize to -1.0 to +1.0 range
|
||||||
|
return (positive_count - negative_count) / total
|
||||||
365
tests/test_backup.py
Normal file
365
tests/test_backup.py
Normal file
@@ -0,0 +1,365 @@
|
|||||||
|
"""Tests for backup and disaster recovery system."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
import tempfile
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.backup.exporter import BackupExporter, ExportFormat
|
||||||
|
from src.backup.health_monitor import HealthMonitor, HealthStatus
|
||||||
|
from src.backup.scheduler import BackupPolicy, BackupScheduler
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_db(tmp_path: Path) -> Path:
|
||||||
|
"""Create a temporary test database."""
|
||||||
|
db_path = tmp_path / "test_trades.db"
|
||||||
|
|
||||||
|
conn = sqlite3.connect(str(db_path))
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# Create trades table
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE trades (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
timestamp TEXT NOT NULL,
|
||||||
|
stock_code TEXT NOT NULL,
|
||||||
|
action TEXT NOT NULL,
|
||||||
|
quantity INTEGER NOT NULL,
|
||||||
|
price REAL NOT NULL,
|
||||||
|
confidence INTEGER NOT NULL,
|
||||||
|
rationale TEXT,
|
||||||
|
pnl REAL DEFAULT 0.0
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Insert test data
|
||||||
|
test_trades = [
|
||||||
|
("2024-01-01T10:00:00Z", "005930", "BUY", 10, 70000.0, 85, "Test buy", 0.0),
|
||||||
|
("2024-01-01T11:00:00Z", "005930", "SELL", 10, 71000.0, 90, "Test sell", 10000.0),
|
||||||
|
("2024-01-02T10:00:00Z", "AAPL", "BUY", 5, 180.0, 88, "Tech buy", 0.0),
|
||||||
|
]
|
||||||
|
|
||||||
|
cursor.executemany(
|
||||||
|
"""
|
||||||
|
INSERT INTO trades (timestamp, stock_code, action, quantity, price, confidence, rationale, pnl)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
test_trades,
|
||||||
|
)
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
return db_path
|
||||||
|
|
||||||
|
|
||||||
|
class TestBackupExporter:
|
||||||
|
"""Test BackupExporter functionality."""
|
||||||
|
|
||||||
|
def test_exporter_init(self, temp_db: Path) -> None:
|
||||||
|
"""Test exporter initialization."""
|
||||||
|
exporter = BackupExporter(str(temp_db))
|
||||||
|
assert exporter.db_path == str(temp_db)
|
||||||
|
|
||||||
|
def test_export_json(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test JSON export."""
|
||||||
|
exporter = BackupExporter(str(temp_db))
|
||||||
|
output_dir = tmp_path / "exports"
|
||||||
|
|
||||||
|
results = exporter.export_all(
|
||||||
|
output_dir, formats=[ExportFormat.JSON], compress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ExportFormat.JSON in results
|
||||||
|
assert results[ExportFormat.JSON].exists()
|
||||||
|
assert results[ExportFormat.JSON].suffix == ".json"
|
||||||
|
|
||||||
|
def test_export_json_compressed(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test compressed JSON export."""
|
||||||
|
exporter = BackupExporter(str(temp_db))
|
||||||
|
output_dir = tmp_path / "exports"
|
||||||
|
|
||||||
|
results = exporter.export_all(
|
||||||
|
output_dir, formats=[ExportFormat.JSON], compress=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ExportFormat.JSON in results
|
||||||
|
assert results[ExportFormat.JSON].suffix == ".gz"
|
||||||
|
|
||||||
|
def test_export_csv(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test CSV export."""
|
||||||
|
exporter = BackupExporter(str(temp_db))
|
||||||
|
output_dir = tmp_path / "exports"
|
||||||
|
|
||||||
|
results = exporter.export_all(
|
||||||
|
output_dir, formats=[ExportFormat.CSV], compress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ExportFormat.CSV in results
|
||||||
|
assert results[ExportFormat.CSV].exists()
|
||||||
|
|
||||||
|
# Verify CSV content
|
||||||
|
with open(results[ExportFormat.CSV], "r") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
assert len(lines) == 4 # Header + 3 rows
|
||||||
|
|
||||||
|
def test_export_all_formats(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test exporting all formats."""
|
||||||
|
exporter = BackupExporter(str(temp_db))
|
||||||
|
output_dir = tmp_path / "exports"
|
||||||
|
|
||||||
|
# Skip Parquet if pyarrow not available
|
||||||
|
try:
|
||||||
|
import pyarrow # noqa: F401
|
||||||
|
|
||||||
|
formats = [ExportFormat.JSON, ExportFormat.CSV, ExportFormat.PARQUET]
|
||||||
|
except ImportError:
|
||||||
|
formats = [ExportFormat.JSON, ExportFormat.CSV]
|
||||||
|
|
||||||
|
results = exporter.export_all(output_dir, formats=formats, compress=False)
|
||||||
|
|
||||||
|
for fmt in formats:
|
||||||
|
assert fmt in results
|
||||||
|
assert results[fmt].exists()
|
||||||
|
|
||||||
|
def test_incremental_export(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test incremental export."""
|
||||||
|
exporter = BackupExporter(str(temp_db))
|
||||||
|
output_dir = tmp_path / "exports"
|
||||||
|
|
||||||
|
# Export only trades after Jan 2
|
||||||
|
cutoff = datetime(2024, 1, 2, tzinfo=UTC)
|
||||||
|
results = exporter.export_all(
|
||||||
|
output_dir,
|
||||||
|
formats=[ExportFormat.JSON],
|
||||||
|
compress=False,
|
||||||
|
incremental_since=cutoff,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should only have 1 trade (AAPL on Jan 2)
|
||||||
|
import json
|
||||||
|
|
||||||
|
with open(results[ExportFormat.JSON], "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
assert data["record_count"] == 1
|
||||||
|
assert data["trades"][0]["stock_code"] == "AAPL"
|
||||||
|
|
||||||
|
def test_get_export_stats(self, temp_db: Path) -> None:
|
||||||
|
"""Test export statistics."""
|
||||||
|
exporter = BackupExporter(str(temp_db))
|
||||||
|
stats = exporter.get_export_stats()
|
||||||
|
|
||||||
|
assert stats["total_trades"] == 3
|
||||||
|
assert "date_range" in stats
|
||||||
|
assert "db_size_bytes" in stats
|
||||||
|
|
||||||
|
|
||||||
|
class TestBackupScheduler:
|
||||||
|
"""Test BackupScheduler functionality."""
|
||||||
|
|
||||||
|
def test_scheduler_init(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test scheduler initialization."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||||
|
|
||||||
|
assert scheduler.db_path == temp_db
|
||||||
|
assert (backup_dir / "daily").exists()
|
||||||
|
assert (backup_dir / "weekly").exists()
|
||||||
|
assert (backup_dir / "monthly").exists()
|
||||||
|
|
||||||
|
def test_create_daily_backup(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test daily backup creation."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||||
|
|
||||||
|
metadata = scheduler.create_backup(BackupPolicy.DAILY, verify=True)
|
||||||
|
|
||||||
|
assert metadata.policy == BackupPolicy.DAILY
|
||||||
|
assert metadata.file_path.exists()
|
||||||
|
assert metadata.size_bytes > 0
|
||||||
|
assert metadata.checksum is not None
|
||||||
|
|
||||||
|
def test_create_weekly_backup(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test weekly backup creation."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||||
|
|
||||||
|
metadata = scheduler.create_backup(BackupPolicy.WEEKLY, verify=False)
|
||||||
|
|
||||||
|
assert metadata.policy == BackupPolicy.WEEKLY
|
||||||
|
assert metadata.file_path.exists()
|
||||||
|
assert metadata.checksum is None # verify=False
|
||||||
|
|
||||||
|
def test_list_backups(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test listing backups."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||||
|
|
||||||
|
scheduler.create_backup(BackupPolicy.DAILY)
|
||||||
|
scheduler.create_backup(BackupPolicy.WEEKLY)
|
||||||
|
|
||||||
|
backups = scheduler.list_backups()
|
||||||
|
assert len(backups) == 2
|
||||||
|
|
||||||
|
daily_backups = scheduler.list_backups(BackupPolicy.DAILY)
|
||||||
|
assert len(daily_backups) == 1
|
||||||
|
assert daily_backups[0].policy == BackupPolicy.DAILY
|
||||||
|
|
||||||
|
def test_cleanup_old_backups(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test cleanup of old backups."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
scheduler = BackupScheduler(str(temp_db), backup_dir, daily_retention_days=0)
|
||||||
|
|
||||||
|
# Create a backup
|
||||||
|
scheduler.create_backup(BackupPolicy.DAILY)
|
||||||
|
|
||||||
|
# Cleanup should remove it (0 day retention)
|
||||||
|
removed = scheduler.cleanup_old_backups()
|
||||||
|
assert removed[BackupPolicy.DAILY] >= 1
|
||||||
|
|
||||||
|
def test_backup_stats(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test backup statistics."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||||
|
|
||||||
|
scheduler.create_backup(BackupPolicy.DAILY)
|
||||||
|
scheduler.create_backup(BackupPolicy.MONTHLY)
|
||||||
|
|
||||||
|
stats = scheduler.get_backup_stats()
|
||||||
|
|
||||||
|
assert stats["daily"]["count"] == 1
|
||||||
|
assert stats["monthly"]["count"] == 1
|
||||||
|
assert stats["daily"]["total_size_bytes"] > 0
|
||||||
|
|
||||||
|
def test_restore_backup(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test backup restoration."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||||
|
|
||||||
|
# Create backup
|
||||||
|
metadata = scheduler.create_backup(BackupPolicy.DAILY)
|
||||||
|
|
||||||
|
# Modify database
|
||||||
|
conn = sqlite3.connect(str(temp_db))
|
||||||
|
conn.execute("DELETE FROM trades")
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
# Restore
|
||||||
|
scheduler.restore_backup(metadata, verify=True)
|
||||||
|
|
||||||
|
# Verify restoration
|
||||||
|
conn = sqlite3.connect(str(temp_db))
|
||||||
|
cursor = conn.execute("SELECT COUNT(*) FROM trades")
|
||||||
|
count = cursor.fetchone()[0]
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
assert count == 3 # Original 3 trades restored
|
||||||
|
|
||||||
|
|
||||||
|
class TestHealthMonitor:
|
||||||
|
"""Test HealthMonitor functionality."""
|
||||||
|
|
||||||
|
def test_monitor_init(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test monitor initialization."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
monitor = HealthMonitor(str(temp_db), backup_dir)
|
||||||
|
|
||||||
|
assert monitor.db_path == temp_db
|
||||||
|
|
||||||
|
def test_check_database_health_ok(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test database health check (healthy)."""
|
||||||
|
monitor = HealthMonitor(str(temp_db), tmp_path / "backups")
|
||||||
|
result = monitor.check_database_health()
|
||||||
|
|
||||||
|
assert result.status == HealthStatus.HEALTHY
|
||||||
|
assert "healthy" in result.message.lower()
|
||||||
|
assert result.details is not None
|
||||||
|
assert result.details["trade_count"] == 3
|
||||||
|
|
||||||
|
def test_check_database_health_missing(self, tmp_path: Path) -> None:
|
||||||
|
"""Test database health check (missing file)."""
|
||||||
|
non_existent = tmp_path / "missing.db"
|
||||||
|
monitor = HealthMonitor(str(non_existent), tmp_path / "backups")
|
||||||
|
result = monitor.check_database_health()
|
||||||
|
|
||||||
|
assert result.status == HealthStatus.UNHEALTHY
|
||||||
|
assert "not found" in result.message.lower()
|
||||||
|
|
||||||
|
def test_check_disk_space(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test disk space check."""
|
||||||
|
monitor = HealthMonitor(str(temp_db), tmp_path, min_disk_space_gb=0.001)
|
||||||
|
result = monitor.check_disk_space()
|
||||||
|
|
||||||
|
# Should be healthy with minimal requirement
|
||||||
|
assert result.status in [HealthStatus.HEALTHY, HealthStatus.DEGRADED]
|
||||||
|
assert result.details is not None
|
||||||
|
assert "free_gb" in result.details
|
||||||
|
|
||||||
|
def test_check_backup_recency_no_backups(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test backup recency check (no backups)."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
backup_dir.mkdir()
|
||||||
|
(backup_dir / "daily").mkdir()
|
||||||
|
|
||||||
|
monitor = HealthMonitor(str(temp_db), backup_dir)
|
||||||
|
result = monitor.check_backup_recency()
|
||||||
|
|
||||||
|
assert result.status == HealthStatus.UNHEALTHY
|
||||||
|
assert "no" in result.message.lower()
|
||||||
|
|
||||||
|
def test_check_backup_recency_recent(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test backup recency check (recent backup)."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||||
|
scheduler.create_backup(BackupPolicy.DAILY)
|
||||||
|
|
||||||
|
monitor = HealthMonitor(str(temp_db), backup_dir)
|
||||||
|
result = monitor.check_backup_recency()
|
||||||
|
|
||||||
|
assert result.status == HealthStatus.HEALTHY
|
||||||
|
assert "recent" in result.message.lower()
|
||||||
|
|
||||||
|
def test_run_all_checks(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test running all health checks."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||||
|
scheduler.create_backup(BackupPolicy.DAILY)
|
||||||
|
|
||||||
|
monitor = HealthMonitor(str(temp_db), backup_dir, min_disk_space_gb=0.001)
|
||||||
|
checks = monitor.run_all_checks()
|
||||||
|
|
||||||
|
assert "database" in checks
|
||||||
|
assert "disk_space" in checks
|
||||||
|
assert "backup_recency" in checks
|
||||||
|
assert checks["database"].status == HealthStatus.HEALTHY
|
||||||
|
|
||||||
|
def test_get_overall_status(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test overall health status."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||||
|
scheduler.create_backup(BackupPolicy.DAILY)
|
||||||
|
|
||||||
|
monitor = HealthMonitor(str(temp_db), backup_dir, min_disk_space_gb=0.001)
|
||||||
|
status = monitor.get_overall_status()
|
||||||
|
|
||||||
|
assert status in [HealthStatus.HEALTHY, HealthStatus.DEGRADED]
|
||||||
|
|
||||||
|
def test_get_health_report(self, temp_db: Path, tmp_path: Path) -> None:
|
||||||
|
"""Test health report generation."""
|
||||||
|
backup_dir = tmp_path / "backups"
|
||||||
|
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||||
|
scheduler.create_backup(BackupPolicy.DAILY)
|
||||||
|
|
||||||
|
monitor = HealthMonitor(str(temp_db), backup_dir, min_disk_space_gb=0.001)
|
||||||
|
report = monitor.get_health_report()
|
||||||
|
|
||||||
|
assert "overall_status" in report
|
||||||
|
assert "timestamp" in report
|
||||||
|
assert "checks" in report
|
||||||
|
assert len(report["checks"]) == 3
|
||||||
@@ -126,7 +126,7 @@ class TestPromptConstruction:
|
|||||||
"orderbook": {"asks": [], "bids": []},
|
"orderbook": {"asks": [], "bids": []},
|
||||||
"foreigner_net": -50000,
|
"foreigner_net": -50000,
|
||||||
}
|
}
|
||||||
prompt = client.build_prompt(market_data)
|
prompt = client.build_prompt_sync(market_data)
|
||||||
assert "005930" in prompt
|
assert "005930" in prompt
|
||||||
|
|
||||||
def test_prompt_contains_price(self, settings):
|
def test_prompt_contains_price(self, settings):
|
||||||
@@ -137,7 +137,7 @@ class TestPromptConstruction:
|
|||||||
"orderbook": {"asks": [], "bids": []},
|
"orderbook": {"asks": [], "bids": []},
|
||||||
"foreigner_net": -50000,
|
"foreigner_net": -50000,
|
||||||
}
|
}
|
||||||
prompt = client.build_prompt(market_data)
|
prompt = client.build_prompt_sync(market_data)
|
||||||
assert "72000" in prompt
|
assert "72000" in prompt
|
||||||
|
|
||||||
def test_prompt_enforces_json_output_format(self, settings):
|
def test_prompt_enforces_json_output_format(self, settings):
|
||||||
@@ -148,7 +148,7 @@ class TestPromptConstruction:
|
|||||||
"orderbook": {"asks": [], "bids": []},
|
"orderbook": {"asks": [], "bids": []},
|
||||||
"foreigner_net": 0,
|
"foreigner_net": 0,
|
||||||
}
|
}
|
||||||
prompt = client.build_prompt(market_data)
|
prompt = client.build_prompt_sync(market_data)
|
||||||
assert "JSON" in prompt
|
assert "JSON" in prompt
|
||||||
assert "action" in prompt
|
assert "action" in prompt
|
||||||
assert "confidence" in prompt
|
assert "confidence" in prompt
|
||||||
|
|||||||
673
tests/test_data_integration.py
Normal file
673
tests/test_data_integration.py
Normal file
@@ -0,0 +1,673 @@
|
|||||||
|
"""Tests for external data integration (news, economic calendar, market data)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.brain.gemini_client import GeminiClient
|
||||||
|
from src.data.economic_calendar import EconomicCalendar, EconomicEvent
|
||||||
|
from src.data.market_data import MarketBreadth, MarketData, MarketSentiment
|
||||||
|
from src.data.news_api import NewsAPI, NewsArticle, NewsSentiment
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# NewsAPI Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestNewsAPI:
|
||||||
|
"""Test news API integration with caching."""
|
||||||
|
|
||||||
|
def test_news_api_init_without_key(self):
|
||||||
|
"""NewsAPI should initialize without API key for testing."""
|
||||||
|
api = NewsAPI(api_key=None)
|
||||||
|
assert api._api_key is None
|
||||||
|
assert api._provider == "alphavantage"
|
||||||
|
assert api._cache_ttl == 300
|
||||||
|
|
||||||
|
def test_news_api_init_with_custom_settings(self):
|
||||||
|
"""NewsAPI should accept custom provider and cache TTL."""
|
||||||
|
api = NewsAPI(api_key="test_key", provider="newsapi", cache_ttl=600)
|
||||||
|
assert api._api_key == "test_key"
|
||||||
|
assert api._provider == "newsapi"
|
||||||
|
assert api._cache_ttl == 600
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_news_sentiment_without_api_key_returns_none(self):
|
||||||
|
"""Without API key, get_news_sentiment should return None."""
|
||||||
|
api = NewsAPI(api_key=None)
|
||||||
|
result = await api.get_news_sentiment("AAPL")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_hit_returns_cached_sentiment(self):
|
||||||
|
"""Cache hit should return cached sentiment without API call."""
|
||||||
|
api = NewsAPI(api_key="test_key")
|
||||||
|
|
||||||
|
# Manually populate cache
|
||||||
|
cached_sentiment = NewsSentiment(
|
||||||
|
stock_code="AAPL",
|
||||||
|
articles=[],
|
||||||
|
avg_sentiment=0.5,
|
||||||
|
article_count=0,
|
||||||
|
fetched_at=time.time(),
|
||||||
|
)
|
||||||
|
api._cache["AAPL"] = cached_sentiment
|
||||||
|
|
||||||
|
result = await api.get_news_sentiment("AAPL")
|
||||||
|
assert result is cached_sentiment
|
||||||
|
assert result.stock_code == "AAPL"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_expiry_triggers_refetch(self):
|
||||||
|
"""Expired cache entry should trigger refetch."""
|
||||||
|
api = NewsAPI(api_key="test_key", cache_ttl=1)
|
||||||
|
|
||||||
|
# Add expired cache entry
|
||||||
|
expired_sentiment = NewsSentiment(
|
||||||
|
stock_code="AAPL",
|
||||||
|
articles=[],
|
||||||
|
avg_sentiment=0.5,
|
||||||
|
article_count=0,
|
||||||
|
fetched_at=time.time() - 10, # 10 seconds ago
|
||||||
|
)
|
||||||
|
api._cache["AAPL"] = expired_sentiment
|
||||||
|
|
||||||
|
# Mock the fetch to avoid real API call
|
||||||
|
with patch.object(api, "_fetch_news", new_callable=AsyncMock) as mock_fetch:
|
||||||
|
mock_fetch.return_value = None
|
||||||
|
result = await api.get_news_sentiment("AAPL")
|
||||||
|
|
||||||
|
# Should have attempted refetch since cache expired
|
||||||
|
mock_fetch.assert_called_once_with("AAPL")
|
||||||
|
|
||||||
|
def test_clear_cache(self):
|
||||||
|
"""clear_cache should empty the cache."""
|
||||||
|
api = NewsAPI(api_key="test_key")
|
||||||
|
api._cache["AAPL"] = NewsSentiment(
|
||||||
|
stock_code="AAPL",
|
||||||
|
articles=[],
|
||||||
|
avg_sentiment=0.0,
|
||||||
|
article_count=0,
|
||||||
|
fetched_at=time.time(),
|
||||||
|
)
|
||||||
|
assert len(api._cache) == 1
|
||||||
|
|
||||||
|
api.clear_cache()
|
||||||
|
assert len(api._cache) == 0
|
||||||
|
|
||||||
|
def test_parse_alphavantage_response_with_valid_data(self):
|
||||||
|
"""Should parse Alpha Vantage response correctly."""
|
||||||
|
api = NewsAPI(api_key="test_key", provider="alphavantage")
|
||||||
|
|
||||||
|
mock_response = {
|
||||||
|
"feed": [
|
||||||
|
{
|
||||||
|
"title": "Apple hits new high",
|
||||||
|
"summary": "Apple stock surges to record levels",
|
||||||
|
"source": "Reuters",
|
||||||
|
"time_published": "2026-02-04T10:00:00",
|
||||||
|
"url": "https://example.com/1",
|
||||||
|
"ticker_sentiment": [
|
||||||
|
{"ticker": "AAPL", "ticker_sentiment_score": "0.85"}
|
||||||
|
],
|
||||||
|
"overall_sentiment_score": "0.75",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"title": "Market volatility rises",
|
||||||
|
"summary": "Tech stocks face headwinds",
|
||||||
|
"source": "Bloomberg",
|
||||||
|
"time_published": "2026-02-04T09:00:00",
|
||||||
|
"url": "https://example.com/2",
|
||||||
|
"ticker_sentiment": [
|
||||||
|
{"ticker": "AAPL", "ticker_sentiment_score": "-0.3"}
|
||||||
|
],
|
||||||
|
"overall_sentiment_score": "-0.2",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
result = api._parse_alphavantage_response("AAPL", mock_response)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.stock_code == "AAPL"
|
||||||
|
assert result.article_count == 2
|
||||||
|
assert len(result.articles) == 2
|
||||||
|
assert result.articles[0].title == "Apple hits new high"
|
||||||
|
assert result.articles[0].sentiment_score == 0.85
|
||||||
|
assert result.articles[1].sentiment_score == -0.3
|
||||||
|
# Average: (0.85 - 0.3) / 2 = 0.275
|
||||||
|
assert abs(result.avg_sentiment - 0.275) < 0.01
|
||||||
|
|
||||||
|
def test_parse_alphavantage_response_without_feed_returns_none(self):
|
||||||
|
"""Should return None if 'feed' key is missing."""
|
||||||
|
api = NewsAPI(api_key="test_key", provider="alphavantage")
|
||||||
|
result = api._parse_alphavantage_response("AAPL", {})
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_parse_newsapi_response_with_valid_data(self):
|
||||||
|
"""Should parse NewsAPI.org response correctly."""
|
||||||
|
api = NewsAPI(api_key="test_key", provider="newsapi")
|
||||||
|
|
||||||
|
mock_response = {
|
||||||
|
"status": "ok",
|
||||||
|
"articles": [
|
||||||
|
{
|
||||||
|
"title": "Apple stock surges",
|
||||||
|
"description": "Strong earnings beat expectations",
|
||||||
|
"source": {"name": "TechCrunch"},
|
||||||
|
"publishedAt": "2026-02-04T10:00:00Z",
|
||||||
|
"url": "https://example.com/1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"title": "Tech sector faces risks",
|
||||||
|
"description": "Concerns over market downturn",
|
||||||
|
"source": {"name": "CNBC"},
|
||||||
|
"publishedAt": "2026-02-04T09:00:00Z",
|
||||||
|
"url": "https://example.com/2",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = api._parse_newsapi_response("AAPL", mock_response)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.stock_code == "AAPL"
|
||||||
|
assert result.article_count == 2
|
||||||
|
assert len(result.articles) == 2
|
||||||
|
assert result.articles[0].title == "Apple stock surges"
|
||||||
|
assert result.articles[0].source == "TechCrunch"
|
||||||
|
|
||||||
|
def test_estimate_sentiment_from_text_positive(self):
|
||||||
|
"""Should detect positive sentiment from keywords."""
|
||||||
|
api = NewsAPI()
|
||||||
|
text = "Stock price surges with strong profit growth and upgrade"
|
||||||
|
sentiment = api._estimate_sentiment_from_text(text)
|
||||||
|
assert sentiment > 0.5
|
||||||
|
|
||||||
|
def test_estimate_sentiment_from_text_negative(self):
|
||||||
|
"""Should detect negative sentiment from keywords."""
|
||||||
|
api = NewsAPI()
|
||||||
|
text = "Stock plunges on weak earnings, downgrade warning"
|
||||||
|
sentiment = api._estimate_sentiment_from_text(text)
|
||||||
|
assert sentiment < -0.5
|
||||||
|
|
||||||
|
def test_estimate_sentiment_from_text_neutral(self):
|
||||||
|
"""Should return neutral sentiment without keywords."""
|
||||||
|
api = NewsAPI()
|
||||||
|
text = "Company announces quarterly report"
|
||||||
|
sentiment = api._estimate_sentiment_from_text(text)
|
||||||
|
assert abs(sentiment) < 0.1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# EconomicCalendar Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEconomicCalendar:
|
||||||
|
"""Test economic calendar functionality."""
|
||||||
|
|
||||||
|
def test_economic_calendar_init(self):
|
||||||
|
"""EconomicCalendar should initialize correctly."""
|
||||||
|
calendar = EconomicCalendar(api_key="test_key")
|
||||||
|
assert calendar._api_key == "test_key"
|
||||||
|
assert len(calendar._events) == 0
|
||||||
|
|
||||||
|
def test_add_event(self):
|
||||||
|
"""Should be able to add events to calendar."""
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
event = EconomicEvent(
|
||||||
|
name="FOMC Meeting",
|
||||||
|
event_type="FOMC",
|
||||||
|
datetime=datetime(2026, 3, 18),
|
||||||
|
impact="HIGH",
|
||||||
|
country="US",
|
||||||
|
description="Interest rate decision",
|
||||||
|
)
|
||||||
|
calendar.add_event(event)
|
||||||
|
assert len(calendar._events) == 1
|
||||||
|
assert calendar._events[0].name == "FOMC Meeting"
|
||||||
|
|
||||||
|
def test_get_upcoming_events_filters_by_timeframe(self):
|
||||||
|
"""Should only return events within specified timeframe."""
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
|
||||||
|
# Add events at different times
|
||||||
|
now = datetime.now()
|
||||||
|
calendar.add_event(
|
||||||
|
EconomicEvent(
|
||||||
|
name="Event Tomorrow",
|
||||||
|
event_type="GDP",
|
||||||
|
datetime=now + timedelta(days=1),
|
||||||
|
impact="HIGH",
|
||||||
|
country="US",
|
||||||
|
description="Test event",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
calendar.add_event(
|
||||||
|
EconomicEvent(
|
||||||
|
name="Event Next Month",
|
||||||
|
event_type="CPI",
|
||||||
|
datetime=now + timedelta(days=30),
|
||||||
|
impact="HIGH",
|
||||||
|
country="US",
|
||||||
|
description="Test event",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get events for next 7 days
|
||||||
|
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="HIGH")
|
||||||
|
assert upcoming.high_impact_count == 1
|
||||||
|
assert upcoming.events[0].name == "Event Tomorrow"
|
||||||
|
|
||||||
|
def test_get_upcoming_events_filters_by_impact(self):
|
||||||
|
"""Should filter events by minimum impact level."""
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
|
||||||
|
now = datetime.now()
|
||||||
|
calendar.add_event(
|
||||||
|
EconomicEvent(
|
||||||
|
name="High Impact Event",
|
||||||
|
event_type="FOMC",
|
||||||
|
datetime=now + timedelta(days=1),
|
||||||
|
impact="HIGH",
|
||||||
|
country="US",
|
||||||
|
description="Test",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
calendar.add_event(
|
||||||
|
EconomicEvent(
|
||||||
|
name="Low Impact Event",
|
||||||
|
event_type="OTHER",
|
||||||
|
datetime=now + timedelta(days=1),
|
||||||
|
impact="LOW",
|
||||||
|
country="US",
|
||||||
|
description="Test",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter for HIGH impact only
|
||||||
|
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="HIGH")
|
||||||
|
assert upcoming.high_impact_count == 1
|
||||||
|
assert upcoming.events[0].name == "High Impact Event"
|
||||||
|
|
||||||
|
# Filter for MEDIUM and above (should still get HIGH)
|
||||||
|
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="MEDIUM")
|
||||||
|
assert len(upcoming.events) == 1
|
||||||
|
|
||||||
|
# Filter for LOW and above (should get both)
|
||||||
|
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="LOW")
|
||||||
|
assert len(upcoming.events) == 2
|
||||||
|
|
||||||
|
def test_get_earnings_date_returns_next_earnings(self):
|
||||||
|
"""Should return next earnings date for a stock."""
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
|
||||||
|
now = datetime.now()
|
||||||
|
earnings_date = now + timedelta(days=5)
|
||||||
|
|
||||||
|
calendar.add_event(
|
||||||
|
EconomicEvent(
|
||||||
|
name="AAPL Earnings",
|
||||||
|
event_type="EARNINGS",
|
||||||
|
datetime=earnings_date,
|
||||||
|
impact="HIGH",
|
||||||
|
country="US",
|
||||||
|
description="Apple quarterly earnings",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = calendar.get_earnings_date("AAPL")
|
||||||
|
assert result == earnings_date
|
||||||
|
|
||||||
|
def test_get_earnings_date_returns_none_if_not_found(self):
|
||||||
|
"""Should return None if no earnings found for stock."""
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
result = calendar.get_earnings_date("UNKNOWN")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_load_hardcoded_events(self):
|
||||||
|
"""Should load hardcoded major economic events."""
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
calendar.load_hardcoded_events()
|
||||||
|
|
||||||
|
# Should have multiple events (FOMC, GDP, CPI)
|
||||||
|
assert len(calendar._events) > 10
|
||||||
|
|
||||||
|
# Check for FOMC events
|
||||||
|
fomc_events = [e for e in calendar._events if e.event_type == "FOMC"]
|
||||||
|
assert len(fomc_events) > 0
|
||||||
|
|
||||||
|
# Check for GDP events
|
||||||
|
gdp_events = [e for e in calendar._events if e.event_type == "GDP"]
|
||||||
|
assert len(gdp_events) > 0
|
||||||
|
|
||||||
|
# Check for CPI events
|
||||||
|
cpi_events = [e for e in calendar._events if e.event_type == "CPI"]
|
||||||
|
assert len(cpi_events) == 12 # Monthly CPI releases
|
||||||
|
|
||||||
|
def test_is_high_volatility_period_returns_true_near_high_impact(self):
|
||||||
|
"""Should return True if high-impact event is within threshold."""
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
|
||||||
|
now = datetime.now()
|
||||||
|
calendar.add_event(
|
||||||
|
EconomicEvent(
|
||||||
|
name="FOMC Meeting",
|
||||||
|
event_type="FOMC",
|
||||||
|
datetime=now + timedelta(hours=12),
|
||||||
|
impact="HIGH",
|
||||||
|
country="US",
|
||||||
|
description="Test",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert calendar.is_high_volatility_period(hours_ahead=24) is True
|
||||||
|
|
||||||
|
def test_is_high_volatility_period_returns_false_when_no_events(self):
|
||||||
|
"""Should return False if no high-impact events nearby."""
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
assert calendar.is_high_volatility_period(hours_ahead=24) is False
|
||||||
|
|
||||||
|
def test_clear_events(self):
|
||||||
|
"""Should clear all events."""
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
calendar.add_event(
|
||||||
|
EconomicEvent(
|
||||||
|
name="Test",
|
||||||
|
event_type="TEST",
|
||||||
|
datetime=datetime.now(),
|
||||||
|
impact="LOW",
|
||||||
|
country="US",
|
||||||
|
description="Test",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert len(calendar._events) == 1
|
||||||
|
|
||||||
|
calendar.clear_events()
|
||||||
|
assert len(calendar._events) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# MarketData Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMarketData:
|
||||||
|
"""Test market data indicators."""
|
||||||
|
|
||||||
|
def test_market_data_init(self):
|
||||||
|
"""MarketData should initialize correctly."""
|
||||||
|
data = MarketData(api_key="test_key")
|
||||||
|
assert data._api_key == "test_key"
|
||||||
|
|
||||||
|
def test_get_market_sentiment_without_api_key_returns_neutral(self):
|
||||||
|
"""Without API key, should return NEUTRAL sentiment."""
|
||||||
|
data = MarketData(api_key=None)
|
||||||
|
sentiment = data.get_market_sentiment()
|
||||||
|
assert sentiment == MarketSentiment.NEUTRAL
|
||||||
|
|
||||||
|
def test_get_market_breadth_without_api_key_returns_none(self):
|
||||||
|
"""Without API key, should return None for breadth."""
|
||||||
|
data = MarketData(api_key=None)
|
||||||
|
breadth = data.get_market_breadth()
|
||||||
|
assert breadth is None
|
||||||
|
|
||||||
|
def test_get_sector_performance_without_api_key_returns_empty(self):
|
||||||
|
"""Without API key, should return empty list."""
|
||||||
|
data = MarketData(api_key=None)
|
||||||
|
sectors = data.get_sector_performance()
|
||||||
|
assert sectors == []
|
||||||
|
|
||||||
|
def test_get_market_indicators_returns_defaults_without_api(self):
|
||||||
|
"""Should return default indicators without API key."""
|
||||||
|
data = MarketData(api_key=None)
|
||||||
|
indicators = data.get_market_indicators()
|
||||||
|
|
||||||
|
assert indicators.sentiment == MarketSentiment.NEUTRAL
|
||||||
|
assert indicators.breadth.advance_decline_ratio == 1.0
|
||||||
|
assert indicators.sector_performance == []
|
||||||
|
assert indicators.vix_level is None
|
||||||
|
|
||||||
|
def test_calculate_fear_greed_score_neutral_baseline(self):
|
||||||
|
"""Should return neutral score (50) for balanced market."""
|
||||||
|
data = MarketData()
|
||||||
|
breadth = MarketBreadth(
|
||||||
|
advancing_stocks=500,
|
||||||
|
declining_stocks=500,
|
||||||
|
unchanged_stocks=100,
|
||||||
|
new_highs=50,
|
||||||
|
new_lows=50,
|
||||||
|
advance_decline_ratio=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
score = data.calculate_fear_greed_score(breadth)
|
||||||
|
assert score == 50
|
||||||
|
|
||||||
|
def test_calculate_fear_greed_score_greedy_market(self):
|
||||||
|
"""Should return high score for greedy market conditions."""
|
||||||
|
data = MarketData()
|
||||||
|
breadth = MarketBreadth(
|
||||||
|
advancing_stocks=800,
|
||||||
|
declining_stocks=200,
|
||||||
|
unchanged_stocks=100,
|
||||||
|
new_highs=100,
|
||||||
|
new_lows=10,
|
||||||
|
advance_decline_ratio=4.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
score = data.calculate_fear_greed_score(breadth, vix=12.0)
|
||||||
|
assert score > 70
|
||||||
|
|
||||||
|
def test_calculate_fear_greed_score_fearful_market(self):
|
||||||
|
"""Should return low score for fearful market conditions."""
|
||||||
|
data = MarketData()
|
||||||
|
breadth = MarketBreadth(
|
||||||
|
advancing_stocks=200,
|
||||||
|
declining_stocks=800,
|
||||||
|
unchanged_stocks=100,
|
||||||
|
new_highs=10,
|
||||||
|
new_lows=100,
|
||||||
|
advance_decline_ratio=0.25,
|
||||||
|
)
|
||||||
|
|
||||||
|
score = data.calculate_fear_greed_score(breadth, vix=35.0)
|
||||||
|
assert score < 30
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GeminiClient Integration Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestGeminiClientWithExternalData:
|
||||||
|
"""Test GeminiClient integration with external data sources."""
|
||||||
|
|
||||||
|
def test_gemini_client_accepts_optional_data_sources(self, settings):
|
||||||
|
"""GeminiClient should accept optional external data sources."""
|
||||||
|
news_api = NewsAPI(api_key="test_key")
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
market_data = MarketData()
|
||||||
|
|
||||||
|
client = GeminiClient(
|
||||||
|
settings,
|
||||||
|
news_api=news_api,
|
||||||
|
economic_calendar=calendar,
|
||||||
|
market_data=market_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert client._news_api is news_api
|
||||||
|
assert client._economic_calendar is calendar
|
||||||
|
assert client._market_data is market_data
|
||||||
|
|
||||||
|
def test_gemini_client_works_without_external_data(self, settings):
|
||||||
|
"""GeminiClient should work without external data sources."""
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
assert client._news_api is None
|
||||||
|
assert client._economic_calendar is None
|
||||||
|
assert client._market_data is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_build_prompt_includes_news_sentiment(self, settings):
|
||||||
|
"""build_prompt should include news sentiment when available."""
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
|
||||||
|
market_data = {
|
||||||
|
"stock_code": "AAPL",
|
||||||
|
"current_price": 180.0,
|
||||||
|
"market_name": "US stock market",
|
||||||
|
}
|
||||||
|
|
||||||
|
sentiment = NewsSentiment(
|
||||||
|
stock_code="AAPL",
|
||||||
|
articles=[
|
||||||
|
NewsArticle(
|
||||||
|
title="Apple hits record high",
|
||||||
|
summary="Strong earnings",
|
||||||
|
source="Reuters",
|
||||||
|
published_at="2026-02-04",
|
||||||
|
sentiment_score=0.85,
|
||||||
|
url="https://example.com",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
avg_sentiment=0.85,
|
||||||
|
article_count=1,
|
||||||
|
fetched_at=time.time(),
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = await client.build_prompt(market_data, news_sentiment=sentiment)
|
||||||
|
|
||||||
|
assert "AAPL" in prompt
|
||||||
|
assert "180.0" in prompt
|
||||||
|
assert "EXTERNAL DATA" in prompt
|
||||||
|
assert "News Sentiment" in prompt
|
||||||
|
assert "0.85" in prompt
|
||||||
|
assert "Apple hits record high" in prompt
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_build_prompt_with_economic_events(self, settings):
|
||||||
|
"""build_prompt should include upcoming economic events."""
|
||||||
|
calendar = EconomicCalendar()
|
||||||
|
now = datetime.now()
|
||||||
|
calendar.add_event(
|
||||||
|
EconomicEvent(
|
||||||
|
name="FOMC Meeting",
|
||||||
|
event_type="FOMC",
|
||||||
|
datetime=now + timedelta(days=2),
|
||||||
|
impact="HIGH",
|
||||||
|
country="US",
|
||||||
|
description="Interest rate decision",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
client = GeminiClient(settings, economic_calendar=calendar)
|
||||||
|
|
||||||
|
market_data = {
|
||||||
|
"stock_code": "AAPL",
|
||||||
|
"current_price": 180.0,
|
||||||
|
"market_name": "US stock market",
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt = await client.build_prompt(market_data)
|
||||||
|
|
||||||
|
assert "EXTERNAL DATA" in prompt
|
||||||
|
assert "High-Impact Events" in prompt
|
||||||
|
assert "FOMC Meeting" in prompt
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_build_prompt_with_market_indicators(self, settings):
|
||||||
|
"""build_prompt should include market sentiment indicators."""
|
||||||
|
market_data_provider = MarketData(api_key="test_key")
|
||||||
|
|
||||||
|
# Mock the get_market_indicators to return test data
|
||||||
|
with patch.object(market_data_provider, "get_market_indicators") as mock:
|
||||||
|
mock.return_value = MagicMock(
|
||||||
|
sentiment=MarketSentiment.EXTREME_GREED,
|
||||||
|
breadth=MagicMock(advance_decline_ratio=2.5),
|
||||||
|
)
|
||||||
|
|
||||||
|
client = GeminiClient(settings, market_data=market_data_provider)
|
||||||
|
|
||||||
|
market_data = {
|
||||||
|
"stock_code": "AAPL",
|
||||||
|
"current_price": 180.0,
|
||||||
|
"market_name": "US stock market",
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt = await client.build_prompt(market_data)
|
||||||
|
|
||||||
|
assert "EXTERNAL DATA" in prompt
|
||||||
|
assert "Market Sentiment" in prompt
|
||||||
|
assert "EXTREME_GREED" in prompt
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_build_prompt_graceful_when_no_external_data(self, settings):
|
||||||
|
"""build_prompt should work gracefully without external data."""
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
|
||||||
|
market_data = {
|
||||||
|
"stock_code": "AAPL",
|
||||||
|
"current_price": 180.0,
|
||||||
|
"market_name": "US stock market",
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt = await client.build_prompt(market_data)
|
||||||
|
|
||||||
|
assert "AAPL" in prompt
|
||||||
|
assert "180.0" in prompt
|
||||||
|
# Should NOT have external data section
|
||||||
|
assert "EXTERNAL DATA" not in prompt
|
||||||
|
|
||||||
|
def test_build_prompt_sync_backward_compatibility(self, settings):
|
||||||
|
"""build_prompt_sync should maintain backward compatibility."""
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
|
||||||
|
market_data = {
|
||||||
|
"stock_code": "005930",
|
||||||
|
"current_price": 72000,
|
||||||
|
"orderbook": {"asks": [], "bids": []},
|
||||||
|
"foreigner_net": -50000,
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt = client.build_prompt_sync(market_data)
|
||||||
|
|
||||||
|
assert "005930" in prompt
|
||||||
|
assert "72000" in prompt
|
||||||
|
assert "JSON" in prompt
|
||||||
|
# Sync version should NOT have external data
|
||||||
|
assert "EXTERNAL DATA" not in prompt
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decide_with_news_sentiment_parameter(self, settings):
|
||||||
|
"""decide should accept optional news_sentiment parameter."""
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
|
||||||
|
market_data = {
|
||||||
|
"stock_code": "AAPL",
|
||||||
|
"current_price": 180.0,
|
||||||
|
"market_name": "US stock market",
|
||||||
|
}
|
||||||
|
|
||||||
|
sentiment = NewsSentiment(
|
||||||
|
stock_code="AAPL",
|
||||||
|
articles=[],
|
||||||
|
avg_sentiment=0.5,
|
||||||
|
article_count=1,
|
||||||
|
fetched_at=time.time(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock the Gemini API call
|
||||||
|
with patch.object(client._client.aio.models, "generate_content", new_callable=AsyncMock) as mock_gen:
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.text = '{"action": "BUY", "confidence": 85, "rationale": "Good news"}'
|
||||||
|
mock_gen.return_value = mock_response
|
||||||
|
|
||||||
|
decision = await client.decide(market_data, news_sentiment=sentiment)
|
||||||
|
|
||||||
|
assert decision.action == "BUY"
|
||||||
|
assert decision.confidence == 85
|
||||||
|
mock_gen.assert_called_once()
|
||||||
Reference in New Issue
Block a user