# Compare commits

`feature/is` ... `feature/is` (64 commits):

`1210c17989`, `c43660a58c`, `7fd48c7764`, `a105bb7c1a`, `1a34a74232`, `a82a167915`, `7725e7a8de`, `f0ae25c533`, `27f581f17d`, `18a098d9a6`, `d2b07326ed`, `1c5eadc23b`, `10ff718045`, `0ca3fe9f5d`, `462f8763ab`, `57a45a24cb`, `a7696568cc`, `70701bf73a`, `20dbd94892`, `48a99962e3`, `ee66ecc305`, `065c9daaad`, `c76b9d5c15`, `259f9d2e24`, `8e715c55cd`, `0057de4d12`, `71ac59794e`, `be04820b00`, `10b6e34d44`, `58f1106dbd`, `cf5072cced`, `702653e52e`, `db0d966a6a`, `a56adcd342`, `eaf509a895`, `854931bed2`, `33b5ff5e54`, `3923d03650`, `c57ccc4bca`, `cb2e3fae57`, `5e4c68c9d8`, `95f540e5df`, `0087a6b20a`, `3dfd7c0935`, `4b2bb25d03`, `881bbb4240`, `5f7d61748b`, `972e71a2f1`, `614b9939b1`, `6dbc2afbf4`, `6c96f9ac64`, `ed26915562`, `628a572c70`, `73e1d0a54e`, `b111157dc8`, `8c05448843`, `87556b145e`, `645c761238`, `033d5fcadd`, `128324427f`, `61f5aaf4a3`, `4f61d5af8e`, `62fd4ff5e1`, `f40f19e735`
**.env.example** (17 lines changed)
```diff
@@ -16,8 +16,21 @@ CONFIDENCE_THRESHOLD=80
 # Database
 DB_PATH=data/trade_logs.db
 
-# Rate Limiting
-RATE_LIMIT_RPS=10.0
+# Rate Limiting (requests per second for KIS API)
+# Reduced to 5.0 to avoid "초당 거래건수 초과" (per-second request limit exceeded) errors (EGW00201)
+RATE_LIMIT_RPS=5.0
 
 # Trading Mode (paper / live)
 MODE=paper
+
+# External Data APIs (optional — for enhanced decision-making)
+# NEWS_API_KEY=your_news_api_key_here
+# NEWS_API_PROVIDER=alphavantage
+# MARKET_DATA_API_KEY=your_market_data_key_here
+
+# Telegram Notifications (optional)
+# Get bot token from @BotFather on Telegram
+# Get chat ID from @userinfobot or your chat
+# TELEGRAM_BOT_TOKEN=1234567890:ABCdefGHIjklMNOpqrsTUVwxyz
+# TELEGRAM_CHAT_ID=123456789
+# TELEGRAM_ENABLED=true
```
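`RATE_LIMIT_RPS` is consumed by the broker's rate limiter. As a rough sketch of how a requests-per-second cap like this is usually enforced (illustrative only, not the actual `kis_api.py` implementation):

```python
import asyncio
import time


class RateLimiter:
    """Pace calls so no more than `rps` start per second (illustrative sketch)."""

    def __init__(self, rps: float) -> None:
        self._interval = 1.0 / rps   # e.g. RATE_LIMIT_RPS=5.0 -> 0.2s between calls
        self._next_allowed = 0.0
        self._lock = asyncio.Lock()

    async def acquire(self) -> None:
        async with self._lock:
            now = time.monotonic()
            wait = self._next_allowed - now
            self._next_allowed = max(now, self._next_allowed) + self._interval
        if wait > 0:
            await asyncio.sleep(wait)  # stay under the per-second cap
```

At `RATE_LIMIT_RPS=5.0` this spaces calls at least 200 ms apart, which is what keeps the KIS gateway from returning EGW00201.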
**.gitignore** (vendored, 3 lines changed)
```diff
@@ -174,4 +174,7 @@ cython_debug/
 # PyPI configuration file
 .pypirc
+
+# Data files (trade logs, databases)
+# But NOT src/data/ which contains source code
+data/
+!src/data/
```
**CLAUDE.md** (75 lines changed)
@@ -17,6 +17,67 @@ pytest -v --cov=src

```bash
python -m src.main --mode=paper
```

## Telegram Notifications (Optional)

Get real-time alerts for trades, circuit breakers, and system events via Telegram.

### Quick Setup

1. **Create bot**: Message [@BotFather](https://t.me/BotFather) on Telegram → `/newbot`
2. **Get chat ID**: Message [@userinfobot](https://t.me/userinfobot) → `/start`
3. **Configure**: Add to `.env`:
   ```bash
   TELEGRAM_BOT_TOKEN=1234567890:ABCdefGHIjklMNOpqrsTUVwxyz
   TELEGRAM_CHAT_ID=123456789
   TELEGRAM_ENABLED=true
   ```
4. **Test**: Start bot conversation (`/start`), then run the agent

**Full documentation**: [src/notifications/README.md](src/notifications/README.md)

### What You'll Get

- 🟢 Trade execution alerts (BUY/SELL with confidence)
- 🚨 Circuit breaker trips (automatic trading halt)
- ⚠️ Fat-finger rejections (oversized orders blocked)
- ℹ️ Market open/close notifications
- 📝 System startup/shutdown status

**Fail-safe**: Notifications never crash the trading system. Missing credentials or API errors are logged but trading continues normally.

## Smart Volatility Scanner (Optional)

Python-first filtering pipeline that reduces Gemini API calls by pre-filtering stocks using technical indicators.

### How It Works

1. **Fetch Rankings** — KIS API volume surge rankings (top 30 stocks)
2. **Python Filter** — RSI + volume ratio calculations (no AI; sketched after this section)
   - Volume > 200% of previous day
   - RSI(14) < 30 (oversold) OR RSI(14) > 70 (momentum)
3. **AI Judgment** — Only qualified candidates (1-3 stocks) sent to Gemini

### Configuration

Add to `.env` (optional, has sensible defaults):
```bash
RSI_OVERSOLD_THRESHOLD=30   # 0-50, default 30
RSI_MOMENTUM_THRESHOLD=70   # 50-100, default 70
VOL_MULTIPLIER=2.0          # Volume threshold (2.0 = 200%)
SCANNER_TOP_N=3             # Max candidates per scan
```

### Benefits

- **Reduces API costs** — Process 1-3 stocks instead of 20-30
- **Python-based filtering** — Fast technical analysis before AI
- **Evolution-ready** — Selection context logged for strategy optimization
- **Fault-tolerant** — Falls back to static watchlist on API failure

### Realtime Mode Only

Smart Scanner runs in `TRADE_MODE=realtime` only. Daily mode uses static watchlists for batch efficiency.
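The two-condition filter above reduces to a small predicate. A minimal sketch (the function name is mine; the real logic lives in `src/analysis/smart_scanner.py`):

```python
def qualifies(rsi: float, volume_ratio: float,
              oversold: float = 30.0, momentum: float = 70.0,
              vol_multiplier: float = 2.0) -> bool:
    """A stock qualifies when volume surges AND RSI sits at an extreme."""
    volume_ok = volume_ratio >= vol_multiplier       # > 200% of previous day
    rsi_extreme = rsi < oversold or rsi > momentum   # oversold OR momentum
    return volume_ok and rsi_extreme


assert qualifies(rsi=25.0, volume_ratio=2.5) is True    # oversold + surge
assert qualifies(rsi=55.0, volume_ratio=3.0) is False   # RSI not extreme
```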
## Documentation

- **[Workflow Guide](docs/workflow.md)** — Git workflow policy and agent-based development

@@ -25,6 +86,7 @@ python -m src.main --mode=paper

- **[Context Tree](docs/context-tree.md)** — L1-L7 hierarchical memory system
- **[Testing](docs/testing.md)** — Test structure, coverage requirements, writing tests
- **[Agent Policies](docs/agents.md)** — Prime directives, constraints, prohibited actions
- **[Requirements Log](docs/requirements-log.md)** — User requirements and feedback tracking

## Core Principles

@@ -33,20 +95,31 @@ python -m src.main --mode=paper

3. **Issue-Driven Development** — All work goes through Gitea issues → feature branches → PRs
4. **Agent Specialization** — Use dedicated agents for design, coding, testing, docs, review

## Requirements Management

User requirements and feedback are tracked in [docs/requirements-log.md](docs/requirements-log.md):

- New requirements are added chronologically with dates
- Code changes should reference related requirements
- Helps maintain project evolution aligned with user needs
- Preserves context across conversations and development cycles

## Project Structure

```diff
 src/
 ├── analysis/        # Technical analysis (RSI, volatility, smart scanner)
 ├── broker/          # KIS API client (domestic + overseas)
 ├── brain/           # Gemini AI decision engine
 ├── core/            # Risk manager (READ-ONLY)
 ├── evolution/       # Self-improvement optimizer
 ├── markets/         # Market schedules and timezone handling
 ├── notifications/   # Telegram real-time alerts
 ├── db.py            # SQLite trade logging
 ├── main.py          # Trading loop orchestrator
 └── config.py        # Settings (from .env)
 
-tests/               # 54 tests across 4 files
+tests/               # 343 tests across 14 files
 docs/                # Extended documentation
```
**README.md** (48 lines changed)
@@ -29,6 +29,7 @@

```diff
 | Broker | `src/broker/kis_api.py` | Async KIS API wrapper (token refresh, rate limiter, hashkey) |
 | Brain | `src/brain/gemini_client.py` | Gemini prompt construction and JSON response parsing |
 | Shield | `src/core/risk_manager.py` | Circuit breaker + fat-finger checks |
+| Notifications | `src/notifications/telegram_client.py` | Real-time Telegram trade alerts (optional) |
 | Evolution | `src/evolution/optimizer.py` | Failure-pattern analysis → new strategy generation → testing → PR |
 | DB | `src/db.py` | SQLite trade logging |
```

@@ -75,6 +76,34 @@ python -m src.main --mode=paper

```bash
docker compose up -d ouroboros
```

## Telegram Notifications (optional)

Receive real-time Telegram alerts for trade executions, circuit breaker trips, and system status.

### Quick Setup

1. **Create a bot**: Message [@BotFather](https://t.me/BotFather) on Telegram → send `/newbot`
2. **Get your chat ID**: Message [@userinfobot](https://t.me/userinfobot) → send `/start`
3. **Set environment variables**: Add to `.env`:
   ```bash
   TELEGRAM_BOT_TOKEN=1234567890:ABCdefGHIjklMNOpqrsTUVwxyz
   TELEGRAM_CHAT_ID=123456789
   TELEGRAM_ENABLED=true
   ```
4. **Test**: Start a conversation with the bot (send `/start`), then run the agent

**Full documentation**: [src/notifications/README.md](src/notifications/README.md)

### Notification Types

- 🟢 Trade execution alerts (BUY/SELL + confidence)
- 🚨 Circuit breaker trips (automatic trading halt)
- ⚠️ Fat-finger blocks (oversized orders rejected)
- ℹ️ Market open/close alerts
- 📝 System startup/shutdown status

**Fail-safe**: Trading continues even if notifications fail. Telegram API errors or missing configuration never affect the trading system.
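For reference, the Bot API call behind these alerts is a single HTTP POST. A standalone sketch (not the project's `telegram_client.py`; it reads the same `.env` names shown above, and `requests` is just one convenient HTTP client):

```python
import os

import requests


def send_telegram(text: str) -> bool:
    """Send one message via the Telegram Bot API; never raise (fail-safe)."""
    token = os.getenv("TELEGRAM_BOT_TOKEN")
    chat_id = os.getenv("TELEGRAM_CHAT_ID")
    if not token or not chat_id:
        return False  # missing credentials: silently skip, trading continues
    try:
        resp = requests.post(
            f"https://api.telegram.org/bot{token}/sendMessage",
            json={"chat_id": chat_id, "text": text},
            timeout=5,
        )
        return resp.ok
    except requests.RequestException:
        return False  # network/API errors must never interrupt trading
```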
## Tests

35 tests were written first, TDD-style, before the implementation.
```diff
@@ -104,15 +133,16 @@ The-Ouroboros/
 │   ├── agents.md                        # AI agent persona definitions
 │   └── skills.md                        # List of available tools
 ├── src/
-│   ├── config.py                        # Pydantic settings
-│   ├── logging_config.py                # Structured JSON logging
-│   ├── db.py                            # SQLite trade records
-│   ├── main.py                          # Async trading loop
-│   ├── broker/kis_api.py                # KIS API client
-│   ├── brain/gemini_client.py           # Gemini decision engine
-│   ├── core/risk_manager.py             # Risk management
-│   ├── evolution/optimizer.py           # Strategy evolution engine
-│   └── strategies/base.py               # Strategy base class
+│   ├── config.py                        # Pydantic settings
+│   ├── logging_config.py                # Structured JSON logging
+│   ├── db.py                            # SQLite trade records
+│   ├── main.py                          # Async trading loop
+│   ├── broker/kis_api.py                # KIS API client
+│   ├── brain/gemini_client.py           # Gemini decision engine
+│   ├── core/risk_manager.py             # Risk management
+│   ├── notifications/telegram_client.py # Telegram notifications
+│   ├── evolution/optimizer.py           # Strategy evolution engine
+│   └── strategies/base.py               # Strategy base class
 ├── tests/                               # TDD test suite
 ├── Dockerfile                           # Multi-stage build
 ├── docker-compose.yml                   # Service orchestration
```
**docs/agent-constraints.md** (new file, 45 lines)
@@ -0,0 +1,45 @@

# Agent Constraints

This document records **persistent behavioral constraints** for agents working on this repository.
It is distinct from `docs/requirements-log.md`, which records **project/product requirements**.

## Scope

- Applies to all AI agents and automation that modify this repo.
- Supplements (does not replace) `docs/agents.md` and `docs/workflow.md`.

## Persistent Rules

1. **Workflow enforcement**
   - Follow `docs/workflow.md` for all changes.
   - Create a Gitea issue before any code or documentation change.
   - Work on a feature branch `feature/issue-{N}-{short-description}` and open a PR.
   - Never commit directly to `main`.

2. **Document-first routing**
   - When performing work, consult relevant `docs/` files *before* making changes.
   - Route decisions to the documented policy whenever applicable.
   - If guidance conflicts, prefer the stricter/safety-first rule and note it in the PR.

3. **Docs with code**
   - Any code change must be accompanied by relevant documentation updates.
   - If no doc update is needed, state the reason explicitly in the PR.

4. **Session-persistent user constraints**
   - If the user requests that a behavior should persist across sessions, record it here (or in a dedicated policy doc) and reference it when working.
   - Keep entries short and concrete, with dates.

## Change Control

- Changes to this file follow the same workflow as code changes.
- Keep the history chronological and minimize rewording of existing entries.

## History

### 2026-02-08

- Always enforce Gitea workflow: issue -> feature branch -> PR before changes.
- When work requires guidance, consult the relevant `docs/` policies first.
- Any code change must be accompanied by relevant documentation updates.
- Persist user constraints across sessions by recording them in this document.
@@ -2,7 +2,42 @@

## Overview

```diff
-Self-evolving AI trading agent for global stock markets via KIS (Korea Investment & Securities) API. The main loop in `src/main.py` orchestrates four components in a 60-second cycle per stock across multiple markets.
+Self-evolving AI trading agent for global stock markets via KIS (Korea Investment & Securities) API. The main loop in `src/main.py` orchestrates four components across multiple markets with two trading modes: daily (batch API calls) or realtime (per-stock decisions).
```

## Trading Modes

The system supports two trading frequency modes controlled by the `TRADE_MODE` environment variable:

### Daily Mode (default)

Optimized for Gemini Free tier API limits (20 calls/day):

- **Batch decisions**: 1 API call per market per session
- **Fixed schedule**: 4 sessions per day at 6-hour intervals (configurable)
- **API efficiency**: Processes all stocks in a market simultaneously
- **Use case**: Free tier users, cost-conscious deployments
- **Configuration**:
  ```bash
  TRADE_MODE=daily
  DAILY_SESSIONS=4          # Sessions per day (1-10)
  SESSION_INTERVAL_HOURS=6  # Hours between sessions (1-24)
  ```

**Example**: With 2 markets (US, KR) and 4 sessions/day = 8 API calls/day (within the 20-call limit)
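That budget arithmetic generalizes. A quick sketch (the parameter names mirror the `.env` keys; the function itself is illustrative, not project code):

```python
def daily_gemini_calls(enabled_markets: int, daily_sessions: int) -> int:
    """Daily mode: one batched Gemini call per market per session."""
    return enabled_markets * daily_sessions


# 2 markets (US, KR) x 4 sessions/day = 8 calls/day, inside the 20-call free tier
calls = daily_gemini_calls(enabled_markets=2, daily_sessions=4)
assert calls == 8 and calls <= 20
```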
### Realtime Mode

High-frequency trading with individual stock analysis:

- **Per-stock decisions**: 1 API call per stock per cycle
- **60-second interval**: Continuous monitoring
- **Use case**: Production deployments with Gemini paid tier
- **Configuration**:
  ```bash
  TRADE_MODE=realtime
  ```

**Note**: Realtime mode requires a Gemini API subscription due to high call volume.

## Core Components
@@ -29,7 +64,39 @@

- `get_open_markets()` returns currently active markets
- `get_next_market_open()` finds next market to open and when

**New API Methods** (added in v0.9.0):

- `fetch_market_rankings()` — Fetch volume surge rankings from KIS API
- `get_daily_prices()` — Fetch OHLCV history for technical analysis

```diff
-### 2. Brain (`src/brain/gemini_client.py`)
+### 2. Analysis (`src/analysis/`)
```

**VolatilityAnalyzer** (`volatility.py`) — Technical indicator calculations

- ATR (Average True Range) for volatility measurement
- RSI (Relative Strength Index) using Wilder's smoothing method
- Price change percentages across multiple timeframes
- Volume surge ratios and price-volume divergence
- Momentum scoring (0-100 scale)
- Breakout/breakdown pattern detection

**SmartVolatilityScanner** (`smart_scanner.py`) — Python-first filtering pipeline

- **Step 1**: Fetch volume rankings from KIS API (top 30 stocks)
- **Step 2**: Calculate RSI and volume ratio for each stock
- **Step 3**: Apply filters:
  - Volume ratio >= `VOL_MULTIPLIER` (default 2.0x previous day)
  - RSI < `RSI_OVERSOLD_THRESHOLD` (30) OR RSI > `RSI_MOMENTUM_THRESHOLD` (70)
- **Step 4**: Score candidates by RSI extremity (60%) + volume surge (40%), as sketched below
- **Step 5**: Return top N candidates (default 3) for AI analysis
- **Fallback**: Uses static watchlist if ranking API unavailable
- **Realtime mode only**: Daily mode uses batch processing for API efficiency

**Benefits:**
- Reduces Gemini API calls from 20-30 stocks to 1-3 qualified candidates
- Fast Python-based filtering before expensive AI judgment
- Logs selection context (RSI, volume_ratio, signal, score) for Evolution system
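The Step 4 weighting is easy to state directly in code. A minimal sketch of the composite score, mirroring the formula in `src/analysis/smart_scanner.py` shown further down this diff (the example values are made up):

```python
def candidate_score(rsi: float, volume_ratio: float) -> float:
    """Composite 0-100 score: RSI extremity weighted 60%, volume surge 40%."""
    rsi_extremity = abs(rsi - 50) / 50         # 0-1: distance from neutral RSI
    volume_score = min(volume_ratio / 5, 1.0)  # 0-1: capped at a 5x surge
    return (rsi_extremity * 0.6 + volume_score * 0.4) * 100


# e.g. RSI 25 (deeply oversold) with a 3x volume surge scores 54.0
assert abs(candidate_score(25.0, 3.0) - 54.0) < 1e-9
```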
### 3. Brain (`src/brain/gemini_client.py`)

**GeminiClient** — AI decision engine powered by Google Gemini

@@ -39,7 +106,7 @@

- Falls back to safe HOLD on any parse/API error
- Handles markdown-wrapped JSON, malformed responses, invalid actions
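A sketch of that fail-safe parsing behavior (illustrative; the function and constant names are mine, not the project's):

```python
import json

SAFE_HOLD = {"action": "HOLD", "confidence": 0, "rationale": "fallback"}


def parse_decision(raw: str) -> dict:
    """Parse a Gemini reply into a decision dict; fall back to HOLD on any problem."""
    text = raw.strip()
    if text.startswith("```"):          # strip markdown code fences
        text = text.strip("`")
        text = text.removeprefix("json").strip()
    try:
        decision = json.loads(text)
    except (json.JSONDecodeError, TypeError):
        return SAFE_HOLD                # malformed response
    if decision.get("action") not in {"BUY", "SELL", "HOLD"}:
        return SAFE_HOLD                # invalid action
    return decision
```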
```diff
-### 3. Risk Manager (`src/core/risk_manager.py`)
+### 4. Risk Manager (`src/core/risk_manager.py`)
```

**RiskManager** — Safety circuit breaker and order validation

@@ -51,7 +118,26 @@

- **Fat-Finger Protection**: Rejects orders exceeding 30% of available cash (sketched below)
- Must always be enforced, cannot be disabled
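The fat-finger rule is a one-line predicate. A sketch (the 30% cap corresponds to `MAX_ORDER_PCT=30.0` in the configuration section below; the function name is illustrative):

```python
def order_allowed(order_value: float, available_cash: float,
                  max_order_pct: float = 30.0) -> bool:
    """Reject any order larger than max_order_pct of available cash."""
    return order_value <= available_cash * (max_order_pct / 100.0)


assert order_allowed(29_000, 100_000) is True    # 29% of cash: fine
assert order_allowed(31_000, 100_000) is False   # 31% of cash: rejected
```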
```diff
-### 4. Evolution (`src/evolution/optimizer.py`)
+### 5. Notifications (`src/notifications/telegram_client.py`)
```

**TelegramClient** — Real-time event notifications via Telegram Bot API

- Sends alerts for trades, circuit breakers, fat-finger rejections, system events
- Non-blocking: failures are logged but never crash trading
- Rate-limited: 1 message/second by default to respect Telegram API limits
- Auto-disabled when credentials are missing
- Gracefully handles API errors, network timeouts, invalid tokens

**Notification Types:**
- Trade execution (BUY/SELL with confidence)
- Circuit breaker trips (critical alert)
- Fat-finger protection triggers (order rejection)
- Market open/close events
- System startup/shutdown status

**Setup:** See [src/notifications/README.md](../src/notifications/README.md) for bot creation and configuration.
### 6. Evolution (`src/evolution/optimizer.py`)

**StrategyOptimizer** — Self-improvement loop

@@ -63,9 +149,11 @@

## Data Flow

### Realtime Mode (with Smart Scanner)

```diff
 ┌─────────────────────────────────────────────────────────────┐
-│ Main Loop (60s cycle per stock, per market)                 │
+│ Main Loop (60s cycle per market)                            │
 └─────────────────────────────────────────────────────────────┘
                    │
                    ▼
@@ -78,6 +166,21 @@
                    │
                    ▼
+┌──────────────────────────────────┐
+│ Smart Scanner (Python-first)     │
+│ - Fetch volume rankings (KIS)    │
+│ - Get 20d price history per stock│
+│ - Calculate RSI(14) + vol ratio  │
+│ - Filter: vol>2x AND RSI extreme │
+│ - Return top 3 qualified stocks  │
+└──────────────────┬───────────────┘
+                   │
+                   ▼
+┌──────────────────────────────────┐
+│ For Each Qualified Candidate     │
+└──────────────────┬───────────────┘
                    │
                    ▼
 ┌──────────────────────────────────┐
 │ Broker: Fetch Market Data        │
 │ - Domestic: orderbook + balance  │
 │ - Overseas: price + balance      │
@@ -91,7 +194,7 @@
                    │
                    ▼
 ┌──────────────────────────────────┐
-│ Brain: Get Decision              │
+│ Brain: Get Decision (AI)         │
 │ - Build prompt with market data  │
 │ - Call Gemini API                │
 │ - Parse JSON response            │
@@ -115,10 +218,21 @@
                    │
                    ▼
+┌──────────────────────────────────┐
+│ Notifications: Send Alert        │
+│ - Trade execution notification   │
+│ - Non-blocking (errors logged)   │
+│ - Rate-limited to 1/sec          │
+└──────────────────┬───────────────┘
                    │
                    ▼
 ┌──────────────────────────────────┐
 │ Database: Log Trade              │
 │ - SQLite (data/trades.db)        │
 │ - Track: action, confidence,     │
 │   rationale, market, exchange    │
+│ - NEW: selection_context (JSON)  │
+│   - RSI, volume_ratio, signal    │
+│   - For Evolution optimization   │
 └──────────────────────────────────┘
```
@@ -138,11 +252,24 @@ CREATE TABLE trades (

```diff
     price REAL,
     pnl REAL DEFAULT 0.0,
     market TEXT DEFAULT 'KR',          -- KR | US_NASDAQ | JP | etc.
-    exchange_code TEXT DEFAULT 'KRX'   -- KRX | NASD | NYSE | etc.
+    exchange_code TEXT DEFAULT 'KRX',  -- KRX | NASD | NYSE | etc.
+    selection_context TEXT             -- JSON: {rsi, volume_ratio, signal, score}
 );
```

**Selection Context** (new in v0.9.0): Stores scanner selection criteria as JSON:

```json
{
  "rsi": 28.5,
  "volume_ratio": 2.7,
  "signal": "oversold",
  "score": 85.2
}
```

Enables the Evolution system to analyze the correlation between selection criteria and trade outcomes.

```diff
-Auto-migration: Adds `market` and `exchange_code` columns if missing for backward compatibility.
+Auto-migration: Adds `market`, `exchange_code`, and `selection_context` columns if missing for backward compatibility.
```
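A sketch of what such an auto-migration typically looks like in SQLite (illustrative; the project's actual `src/db.py` logic may differ):

```python
import sqlite3


def ensure_column(conn: sqlite3.Connection, table: str, column: str, decl: str) -> None:
    """Add `column` to `table` if it is not already present (idempotent)."""
    existing = {row[1] for row in conn.execute(f"PRAGMA table_info({table})")}
    if column not in existing:
        conn.execute(f"ALTER TABLE {table} ADD COLUMN {column} {decl}")


conn = sqlite3.connect("data/trade_logs.db")
for name, decl in [
    ("market", "TEXT DEFAULT 'KR'"),
    ("exchange_code", "TEXT DEFAULT 'KRX'"),
    ("selection_context", "TEXT"),
]:
    ensure_column(conn, "trades", name, decl)
conn.commit()
```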
## Configuration

@@ -164,6 +291,22 @@ CONFIDENCE_THRESHOLD=80

```bash
MAX_LOSS_PCT=3.0
MAX_ORDER_PCT=30.0
ENABLED_MARKETS=KR,US_NASDAQ    # Comma-separated market codes

# Trading Mode (API efficiency)
TRADE_MODE=daily                # daily | realtime
DAILY_SESSIONS=4                # Sessions per day (daily mode only)
SESSION_INTERVAL_HOURS=6        # Hours between sessions (daily mode only)

# Telegram Notifications (optional)
TELEGRAM_BOT_TOKEN=1234567890:ABCdefGHIjklMNOpqrsTUVwxyz
TELEGRAM_CHAT_ID=123456789
TELEGRAM_ENABLED=true

# Smart Scanner (optional, realtime mode only)
RSI_OVERSOLD_THRESHOLD=30       # 0-50, oversold threshold
RSI_MOMENTUM_THRESHOLD=70       # 50-100, momentum threshold
VOL_MULTIPLIER=2.0              # Minimum volume ratio (2.0 = 200%)
SCANNER_TOP_N=3                 # Max qualified candidates per scan
```

Tests use in-memory SQLite (`DB_PATH=":memory:"`) and dummy credentials via `tests/conftest.py`.
@@ -189,3 +332,12 @@

- Wait until next market opens
- Use `get_next_market_open()` to calculate wait time
- Sleep until market open time

### Telegram API Errors

- Log warning but continue trading
- Missing credentials → auto-disable notifications
- Network timeout → skip notification, no retry
- Invalid token → log error, trading unaffected
- Rate limit exceeded → queued via rate limiter

**Guarantee**: Notification failures never interrupt trading operations.
**docs/disaster_recovery.md** (new file, 348 lines)
@@ -0,0 +1,348 @@

# Disaster Recovery Guide

Complete guide for backing up and restoring The Ouroboros trading system.

## Table of Contents

- [Backup Strategy](#backup-strategy)
- [Creating Backups](#creating-backups)
- [Restoring from Backup](#restoring-from-backup)
- [Health Monitoring](#health-monitoring)
- [Export Formats](#export-formats)
- [RTO/RPO](#rtorpo)
- [Testing Recovery](#testing-recovery)

## Backup Strategy

The system implements a 3-tier backup retention policy:

| Policy | Frequency | Retention | Purpose |
|--------|-----------|-----------|---------|
| **Daily** | Every day | 30 days | Quick recovery from recent issues |
| **Weekly** | Sunday | 1 year | Medium-term historical analysis |
| **Monthly** | 1st of month | Forever | Long-term archival |

### Storage Structure

```
data/backups/
├── daily/    # Last 30 days
├── weekly/   # Last 52 weeks
└── monthly/  # Forever (cold storage)
```

## Creating Backups

### Automated Backups (Recommended)

Set up a cron job to run daily:

```bash
# Edit crontab
crontab -e

# Run backup at 2 AM every day
0 2 * * * cd /path/to/The-Ouroboros && ./scripts/backup.sh >> logs/backup.log 2>&1
```

### Manual Backups

```bash
# Run backup script
./scripts/backup.sh

# Or use Python directly
python3 -c "
from pathlib import Path
from src.backup.scheduler import BackupScheduler, BackupPolicy

scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))
metadata = scheduler.create_backup(BackupPolicy.DAILY, verify=True)
print(f'Backup created: {metadata.file_path}')
"
```

### Export to Other Formats

```bash
python3 -c "
from pathlib import Path
from src.backup.exporter import BackupExporter, ExportFormat

exporter = BackupExporter('data/trade_logs.db')
results = exporter.export_all(
    Path('exports'),
    formats=[ExportFormat.JSON, ExportFormat.CSV],
    compress=True
)
"
```

## Restoring from Backup

### Interactive Restoration

```bash
./scripts/restore.sh
```

The script will:
1. List available backups
2. Ask you to select one
3. Create a safety backup of the current database
4. Restore the selected backup
5. Verify database integrity

### Manual Restoration

```python
from pathlib import Path
from src.backup.scheduler import BackupScheduler

scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))

# List backups
backups = scheduler.list_backups()
for backup in backups:
    print(f"{backup.timestamp}: {backup.file_path}")

# Restore specific backup
scheduler.restore_backup(backups[0], verify=True)
```

## Health Monitoring

### Check System Health

```python
from pathlib import Path
from src.backup.health_monitor import HealthMonitor

monitor = HealthMonitor('data/trade_logs.db', Path('data/backups'))

# Run all checks
report = monitor.get_health_report()
print(f"Overall status: {report['overall_status']}")

# Individual checks
checks = monitor.run_all_checks()
for name, result in checks.items():
    print(f"{name}: {result.status.value} - {result.message}")
```

### Health Checks

The system monitors:

- **Database Health**: Accessibility, integrity, size
- **Disk Space**: Available storage (alerts if < 10 GB)
- **Backup Recency**: Ensures backups are < 25 hours old

### Health Status Levels

- **HEALTHY**: All systems operational
- **DEGRADED**: Warning condition (e.g., low disk space)
- **UNHEALTHY**: Critical issue (e.g., database corrupted, no backups)

## Export Formats

### JSON (Human-Readable)

```json
{
  "export_timestamp": "2024-01-15T10:30:00Z",
  "record_count": 150,
  "trades": [
    {
      "timestamp": "2024-01-15T09:00:00Z",
      "stock_code": "005930",
      "action": "BUY",
      "quantity": 10,
      "price": 70000.0,
      "confidence": 85,
      "rationale": "Strong momentum",
      "pnl": 0.0
    }
  ]
}
```

### CSV (Analysis Tools)

Compatible with Excel, pandas, R:

```csv
timestamp,stock_code,action,quantity,price,confidence,rationale,pnl
2024-01-15T09:00:00Z,005930,BUY,10,70000.0,85,Strong momentum,0.0
```

### Parquet (Big Data)

Columnar format for Spark, DuckDB:

```python
import pandas as pd

df = pd.read_parquet('exports/trades_20240115.parquet')
```

## RTO/RPO

### Recovery Time Objective (RTO)

**Target: < 5 minutes**

Time to restore trading operations:
1. Identify the backup to restore (1 min)
2. Run the restore script (2 min)
3. Verify database integrity (1 min)
4. Restart the trading system (1 min)

### Recovery Point Objective (RPO)

**Target: < 24 hours**

Maximum acceptable data loss:
- Daily backups ensure ≤ 24-hour data loss
- For critical periods, run backups more frequently

## Testing Recovery

### Quarterly Recovery Test

Perform a full disaster recovery test every quarter:

1. **Create test backup**
   ```bash
   ./scripts/backup.sh
   ```

2. **Simulate disaster** (use a test database)
   ```bash
   cp data/trade_logs.db data/trade_logs_test.db
   rm data/trade_logs_test.db   # Simulate data loss
   ```

3. **Restore from backup**
   ```bash
   DB_PATH=data/trade_logs_test.db ./scripts/restore.sh
   ```

4. **Verify data integrity**
   ```python
   import sqlite3

   conn = sqlite3.connect('data/trade_logs_test.db')
   cursor = conn.execute('SELECT COUNT(*) FROM trades')
   print(f"Restored {cursor.fetchone()[0]} trades")
   ```

5. **Document results** in `logs/recovery_test_YYYYMMDD.md`

### Backup Verification

Always verify backups after creation:

```python
from pathlib import Path
from src.backup.scheduler import BackupPolicy, BackupScheduler  # BackupPolicy is needed below

scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))

# Create and verify
metadata = scheduler.create_backup(BackupPolicy.DAILY, verify=True)
print(f"Checksum: {metadata.checksum}")  # Should not be None
```
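The checksum printed above is typically a content digest of the backup file; this diff does not show the scheduler's actual algorithm, so as an assumption-labeled sketch, a SHA-256 digest would look like:

```python
import hashlib
from pathlib import Path


def file_checksum(path: Path) -> str:
    """SHA-256 digest of a file, read in 1 MiB chunks to bound memory use."""
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()
```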
## Emergency Procedures

### Database Corrupted

1. Stop the trading system immediately
2. Check the most recent backup age: `ls -lht data/backups/daily/`
3. Restore: `./scripts/restore.sh`
4. Verify: run a health check
5. Resume trading

### Disk Full

1. Check disk space: `df -h`
2. Clean old backups: run cleanup manually
   ```python
   from pathlib import Path
   from src.backup.scheduler import BackupScheduler

   scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))
   scheduler.cleanup_old_backups()
   ```
3. Consider archiving old monthly backups to external storage
4. Increase disk space if needed

### Lost All Backups

If local backups are lost:
1. Check if exports exist in the `exports/` directory
2. Reconstruct the database from CSV/JSON exports
3. If no exports: check the broker API for trade history
4. Manual reconstruction as a last resort

## Best Practices

1. **Test Restores Regularly**: Don't wait for a disaster
2. **Monitor Disk Space**: Set up alerts at 80% usage
3. **Keep Multiple Generations**: Never delete all backups at once
4. **Verify Checksums**: Always verify backup integrity
5. **Document Changes**: Update this guide when the backup strategy changes
6. **Off-Site Storage**: Consider external backup for monthly archives

## Troubleshooting

### Backup Script Fails

```bash
# Check database file permissions
ls -l data/trade_logs.db

# Check disk space
df -h data/

# Run backup manually with debug logging
python3 -c "
import logging
logging.basicConfig(level=logging.DEBUG)
from pathlib import Path
from src.backup.scheduler import BackupScheduler, BackupPolicy

scheduler = BackupScheduler('data/trade_logs.db', Path('data/backups'))
scheduler.create_backup(BackupPolicy.DAILY, verify=True)
"
```

### Restore Fails Verification

```bash
# Check backup file integrity
python3 -c "
import sqlite3
conn = sqlite3.connect('data/backups/daily/trade_logs_daily_20240115.db')
cursor = conn.execute('PRAGMA integrity_check')
print(cursor.fetchone()[0])
"
```

### Health Check Fails

```python
from pathlib import Path
from src.backup.health_monitor import HealthMonitor

monitor = HealthMonitor('data/trade_logs.db', Path('data/backups'))

# Check each component individually
print("Database:", monitor.check_database_health())
print("Disk Space:", monitor.check_disk_space())
print("Backup Recency:", monitor.check_backup_recency())
```

## Contact

For backup/recovery issues:
- Check logs: `logs/backup.log`
- Review health status: run the health monitor
- Raise an issue on GitHub if automated recovery fails
**docs/requirements-log.md** (new file, 66 lines)
@@ -0,0 +1,66 @@

# Requirements Log

A record of user requirements that drive the project's evolution.

This document records, in chronological order, requirements and feedback from conversations with the user.
When a new requirement comes up, add it together with the date.

---

## 2026-02-05

### API Efficiency

- The Gemini API is a scarce resource; batch calls are needed instead of one call per stock
- Given the free-tier limit (20 calls/day), switch to a mode that trades a few times per day
- Analyze multiple stocks at once with batched API calls

### Trading Modes

- **Daily Mode**: 4 trading sessions per day (6-hour intervals), free-tier compatible
- **Realtime Mode**: realtime trading at 60-second intervals, paid subscription required
- Select the mode via the `TRADE_MODE` environment variable

### Evolution System

- Record user conversations in documents so intent is preserved going forward
- Prompt quality validation will be handled as a separate issue

### Documentation

- Always pay attention to code documentation: system structure, per-feature descriptions, etc.
- Updating the related documents is mandatory whenever a new feature is added

---

## 2026-02-06

### Smart Volatility Scanner (Python-First, AI-Last pipeline)

**Background:**
- Iterating over a static stock list is inefficient
- Market leaders should be detected automatically via KIS API volume rankings
- Python-based technical analysis should filter stocks before any Gemini API call

**Requirements:**
1. Integrate the KIS volume-ranking API (`fetch_market_rankings`)
2. Add a daily price history API (`get_daily_prices`)
3. Implement RSI(14) calculation (Wilder's smoothing method)
4. Filter conditions:
   - Volume > 200% of the previous day (`VOL_MULTIPLIER`)
   - RSI < 30 (oversold) OR RSI > 70 (momentum)
5. Pass only the top 1-3 qualified stocks to Gemini
6. Record the selection context (RSI, volume_ratio, signal, score) in the database

**Implementation:**
- `src/analysis/smart_scanner.py`: SmartVolatilityScanner class
- `src/analysis/volatility.py`: added the calculate_rsi() method
- `src/broker/kis_api.py`: 2 new API methods
- `src/db.py`: added the selection_context column
- Configurable thresholds: RSI_OVERSOLD_THRESHOLD, RSI_MOMENTUM_THRESHOLD, VOL_MULTIPLIER, SCANNER_TOP_N

**Impact:**
- Gemini API calls reduced from 20-30 to 1-3
- Fast Python-based filtering → lower cost
- Selection criteria are tracked → the Evolution system can optimize against them
- Automatic fallback to the static watchlist on API failure

**Note:** Realtime mode only. Daily mode uses a static watchlist for batch efficiency.

**Issue/PR:** #76, #77
**scripts/backup.sh** (new file, 96 lines)
@@ -0,0 +1,96 @@

```bash
#!/usr/bin/env bash
# Automated backup script for The Ouroboros trading system
# Runs daily/weekly/monthly backups

set -euo pipefail

# Configuration
DB_PATH="${DB_PATH:-data/trade_logs.db}"
BACKUP_DIR="${BACKUP_DIR:-data/backups}"
PYTHON="${PYTHON:-python3}"

# Colors for output
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
RED='\033[0;31m'
NC='\033[0m' # No Color

log_info() {
    echo -e "${GREEN}[INFO]${NC} $1"
}

log_warn() {
    echo -e "${YELLOW}[WARN]${NC} $1"
}

log_error() {
    echo -e "${RED}[ERROR]${NC} $1"
}

# Check if database exists
if [ ! -f "$DB_PATH" ]; then
    log_error "Database not found: $DB_PATH"
    exit 1
fi

# Create backup directory
mkdir -p "$BACKUP_DIR"

log_info "Starting backup process..."
log_info "Database: $DB_PATH"
log_info "Backup directory: $BACKUP_DIR"

# Determine backup policy based on day of week and month
DAY_OF_WEEK=$(date +%u)   # 1-7 (Monday-Sunday)
DAY_OF_MONTH=$(date +%d)

if [ "$DAY_OF_MONTH" == "01" ]; then
    POLICY="monthly"
    log_info "Running MONTHLY backup (first day of month)"
elif [ "$DAY_OF_WEEK" == "7" ]; then
    POLICY="weekly"
    log_info "Running WEEKLY backup (Sunday)"
else
    POLICY="daily"
    log_info "Running DAILY backup"
fi

# Run Python backup script
$PYTHON -c "
from pathlib import Path
from src.backup.scheduler import BackupScheduler, BackupPolicy
from src.backup.health_monitor import HealthMonitor

# Create scheduler
scheduler = BackupScheduler(
    db_path='$DB_PATH',
    backup_dir=Path('$BACKUP_DIR')
)

# Create backup (look the policy up by name, e.g. 'daily' -> BackupPolicy.DAILY)
policy = BackupPolicy['$POLICY'.upper()]
metadata = scheduler.create_backup(policy, verify=True)
print(f'Backup created: {metadata.file_path}')
print(f'Size: {metadata.size_bytes / 1024 / 1024:.2f} MB')
print(f'Checksum: {metadata.checksum}')

# Cleanup old backups
removed = scheduler.cleanup_old_backups()
total_removed = sum(removed.values())
if total_removed > 0:
    print(f'Removed {total_removed} old backup(s)')

# Health check
monitor = HealthMonitor('$DB_PATH', Path('$BACKUP_DIR'))
status = monitor.get_overall_status()
print(f'System health: {status.value}')
"

if [ $? -eq 0 ]; then
    log_info "Backup completed successfully"
else
    log_error "Backup failed"
    exit 1
fi

log_info "Backup process finished"
```
**scripts/restore.sh** (new file, 111 lines)
@@ -0,0 +1,111 @@

```bash
#!/usr/bin/env bash
# Restore script for The Ouroboros trading system
# Restores database from a backup file

set -euo pipefail

# Configuration
DB_PATH="${DB_PATH:-data/trade_logs.db}"
BACKUP_DIR="${BACKUP_DIR:-data/backups}"
PYTHON="${PYTHON:-python3}"

# Colors for output
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
RED='\033[0;31m'
NC='\033[0m' # No Color

log_info() {
    echo -e "${GREEN}[INFO]${NC} $1"
}

log_warn() {
    echo -e "${YELLOW}[WARN]${NC} $1"
}

log_error() {
    echo -e "${RED}[ERROR]${NC} $1"
}

# Check if backup directory exists
if [ ! -d "$BACKUP_DIR" ]; then
    log_error "Backup directory not found: $BACKUP_DIR"
    exit 1
fi

log_info "Available backups:"
log_info "=================="

# List available backups
$PYTHON -c "
from pathlib import Path
from src.backup.scheduler import BackupScheduler

scheduler = BackupScheduler(
    db_path='$DB_PATH',
    backup_dir=Path('$BACKUP_DIR')
)

backups = scheduler.list_backups()

if not backups:
    print('No backups found.')
    exit(1)

for i, backup in enumerate(backups, 1):
    size_mb = backup.size_bytes / 1024 / 1024
    print(f'{i}. [{backup.policy.value.upper()}] {backup.file_path.name}')
    print(f'   Date: {backup.timestamp.strftime(\"%Y-%m-%d %H:%M:%S UTC\")}')
    print(f'   Size: {size_mb:.2f} MB')
    print()
"

# Ask user to select backup
echo ""
read -p "Enter backup number to restore (or 'q' to quit): " BACKUP_NUM

if [ "$BACKUP_NUM" == "q" ]; then
    log_info "Restore cancelled"
    exit 0
fi

# Confirm restoration
log_warn "WARNING: This will replace the current database!"
log_warn "Current database will be backed up to: ${DB_PATH}.before_restore"
read -p "Are you sure you want to continue? (yes/no): " CONFIRM

if [ "$CONFIRM" != "yes" ]; then
    log_info "Restore cancelled"
    exit 0
fi

# Perform restoration
$PYTHON -c "
from pathlib import Path
from src.backup.scheduler import BackupScheduler

scheduler = BackupScheduler(
    db_path='$DB_PATH',
    backup_dir=Path('$BACKUP_DIR')
)

backups = scheduler.list_backups()
backup_index = int('$BACKUP_NUM') - 1

if backup_index < 0 or backup_index >= len(backups):
    print('Invalid backup number')
    exit(1)

selected = backups[backup_index]
print(f'Restoring: {selected.file_path.name}')

scheduler.restore_backup(selected, verify=True)
print('Restore completed successfully')
"

if [ $? -eq 0 ]; then
    log_info "Database restored successfully"
else
    log_error "Restore failed"
    exit 1
fi
```
```diff
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 from src.analysis.scanner import MarketScanner
+from src.analysis.smart_scanner import ScanCandidate, SmartVolatilityScanner
 from src.analysis.volatility import VolatilityAnalyzer
 
-__all__ = ["VolatilityAnalyzer", "MarketScanner"]
+__all__ = ["VolatilityAnalyzer", "MarketScanner", "SmartVolatilityScanner", "ScanCandidate"]
```
```diff
@@ -42,6 +42,7 @@ class MarketScanner:
         volatility_analyzer: VolatilityAnalyzer,
         context_store: ContextStore,
         top_n: int = 5,
+        max_concurrent_scans: int = 1,
     ) -> None:
         """Initialize the market scanner.
 
@@ -51,12 +52,14 @@ class MarketScanner:
             volatility_analyzer: Volatility analyzer instance
             context_store: Context store for L7 real-time data
             top_n: Number of top movers to return per market (default 5)
+            max_concurrent_scans: Max concurrent stock scans (default 1, fully serialized)
         """
         self.broker = broker
         self.overseas_broker = overseas_broker
         self.analyzer = volatility_analyzer
         self.context_store = context_store
         self.top_n = top_n
+        self._scan_semaphore = asyncio.Semaphore(max_concurrent_scans)
 
     async def scan_stock(
         self,
@@ -83,8 +86,8 @@ class MarketScanner:
         # Convert to orderbook-like structure
         orderbook = {
             "output1": {
-                "stck_prpr": price_data.get("output", {}).get("last", "0"),
-                "acml_vol": price_data.get("output", {}).get("tvol", "0"),
+                "stck_prpr": price_data.get("output", {}).get("last", "0") or "0",
+                "acml_vol": price_data.get("output", {}).get("tvol", "0") or "0",
             }
         }
@@ -139,8 +142,12 @@ class MarketScanner:
 
         logger.info("Scanning %s market (%d stocks)", market.name, len(stock_codes))
 
-        # Scan all stocks concurrently (with rate limiting handled by broker)
-        tasks = [self.scan_stock(code, market) for code in stock_codes]
+        # Scan stocks with bounded concurrency to prevent API rate limit burst
+        async def _bounded_scan(code: str) -> VolatilityMetrics | None:
+            async with self._scan_semaphore:
+                return await self.scan_stock(code, market)
+
+        tasks = [_bounded_scan(code) for code in stock_codes]
         results = await asyncio.gather(*tasks)
 
         # Filter out failures and sort by momentum score
```
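The bounded-concurrency pattern introduced above is worth seeing on its own. A self-contained sketch (generic names, not the project's code):

```python
import asyncio


async def gather_bounded(coros, limit: int = 1):
    """Run coroutines concurrently, but never more than `limit` at once."""
    semaphore = asyncio.Semaphore(limit)

    async def _run(coro):
        async with semaphore:
            return await coro

    return await asyncio.gather(*(_run(c) for c in coros))


async def main() -> None:
    async def fetch(i: int) -> int:
        await asyncio.sleep(0.1)  # stand-in for an API call
        return i

    # At limit=3 this takes ~0.4s instead of ~0.1s, trading speed for burst safety
    print(await gather_bounded([fetch(i) for i in range(10)], limit=3))


asyncio.run(main())
```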
**src/analysis/smart_scanner.py** (new file, 192 lines)
@@ -0,0 +1,192 @@

```python
"""Smart Volatility Scanner with RSI and volume filters.

Fetches market rankings from KIS API and applies technical filters
to identify high-probability trading candidates.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Any

from src.analysis.volatility import VolatilityAnalyzer
from src.broker.kis_api import KISBroker
from src.config import Settings

logger = logging.getLogger(__name__)


@dataclass
class ScanCandidate:
    """A qualified candidate from the smart scanner."""

    stock_code: str
    name: str
    price: float
    volume: float
    volume_ratio: float  # Current volume / previous day volume
    rsi: float
    signal: str  # "oversold" or "momentum"
    score: float  # Composite score for ranking


class SmartVolatilityScanner:
    """Scans market rankings and applies RSI/volume filters.

    Flow:
    1. Fetch volume rankings from KIS API
    2. For each ranked stock, fetch daily prices
    3. Calculate RSI and volume ratio
    4. Apply filters: volume > VOL_MULTIPLIER AND (RSI < 30 OR RSI > 70)
    5. Return top N qualified candidates
    """

    def __init__(
        self,
        broker: KISBroker,
        volatility_analyzer: VolatilityAnalyzer,
        settings: Settings,
    ) -> None:
        """Initialize the smart scanner.

        Args:
            broker: KIS broker for API calls
            volatility_analyzer: Analyzer for RSI calculation
            settings: Application settings
        """
        self.broker = broker
        self.analyzer = volatility_analyzer
        self.settings = settings

        # Extract scanner settings
        self.rsi_oversold = settings.RSI_OVERSOLD_THRESHOLD
        self.rsi_momentum = settings.RSI_MOMENTUM_THRESHOLD
        self.vol_multiplier = settings.VOL_MULTIPLIER
        self.top_n = settings.SCANNER_TOP_N

    async def scan(
        self,
        fallback_stocks: list[str] | None = None,
    ) -> list[ScanCandidate]:
        """Execute smart scan and return qualified candidates.

        Args:
            fallback_stocks: Stock codes to use if ranking API fails

        Returns:
            List of ScanCandidate, sorted by score, up to top_n items
        """
        # Step 1: Fetch rankings
        try:
            rankings = await self.broker.fetch_market_rankings(
                ranking_type="volume",
                limit=30,  # Fetch more than needed for filtering
            )
            logger.info("Fetched %d stocks from volume rankings", len(rankings))
        except ConnectionError as exc:
            logger.warning("Ranking API failed, using fallback: %s", exc)
            if fallback_stocks:
                # Create minimal ranking data for fallback
                rankings = [
                    {
                        "stock_code": code,
                        "name": code,
                        "price": 0,
                        "volume": 0,
                        "change_rate": 0,
                        "volume_increase_rate": 0,
                    }
                    for code in fallback_stocks
                ]
            else:
                return []

        # Step 2: Analyze each stock
        candidates: list[ScanCandidate] = []

        for stock in rankings:
            stock_code = stock["stock_code"]
            if not stock_code:
                continue

            try:
                # Fetch daily prices for RSI calculation
                daily_prices = await self.broker.get_daily_prices(stock_code, days=20)

                if len(daily_prices) < 15:  # Need at least 14+1 for RSI
                    logger.debug("Insufficient price history for %s", stock_code)
                    continue

                # Calculate RSI
                close_prices = [p["close"] for p in daily_prices]
                rsi = self.analyzer.calculate_rsi(close_prices, period=14)

                # Calculate volume ratio (today vs previous day); define
                # current_volume up front so both branches (and the candidate
                # below) can use it
                current_volume = stock.get("volume", 0) or daily_prices[-1]["volume"]
                if len(daily_prices) >= 2:
                    prev_day_volume = daily_prices[-2]["volume"]
                    volume_ratio = (
                        current_volume / prev_day_volume if prev_day_volume > 0 else 1.0
                    )
                else:
                    volume_ratio = stock.get("volume_increase_rate", 0) / 100 + 1  # Fallback

                # Apply filters
                volume_qualified = volume_ratio >= self.vol_multiplier
                rsi_oversold = rsi < self.rsi_oversold
                rsi_momentum = rsi > self.rsi_momentum

                if volume_qualified and (rsi_oversold or rsi_momentum):
                    signal = "oversold" if rsi_oversold else "momentum"

                    # Calculate composite score
                    # Higher score for: extreme RSI + high volume
                    rsi_extremity = abs(rsi - 50) / 50  # 0-1 scale
                    volume_score = min(volume_ratio / 5, 1.0)  # Cap at 5x
                    score = (rsi_extremity * 0.6 + volume_score * 0.4) * 100

                    candidates.append(
                        ScanCandidate(
                            stock_code=stock_code,
                            name=stock.get("name", stock_code),
                            price=stock.get("price", daily_prices[-1]["close"]),
                            volume=current_volume,
                            volume_ratio=volume_ratio,
                            rsi=rsi,
                            signal=signal,
                            score=score,
                        )
                    )

                    logger.info(
                        "Qualified: %s (%s) RSI=%.1f vol=%.1fx signal=%s score=%.1f",
                        stock_code,
                        stock.get("name", ""),
                        rsi,
                        volume_ratio,
                        signal,
                        score,
                    )

            except ConnectionError as exc:
                logger.warning("Failed to analyze %s: %s", stock_code, exc)
                continue
            except Exception as exc:
                logger.error("Unexpected error analyzing %s: %s", stock_code, exc)
                continue

        # Sort by score and return top N
        candidates.sort(key=lambda c: c.score, reverse=True)
        return candidates[: self.top_n]

    def get_stock_codes(self, candidates: list[ScanCandidate]) -> list[str]:
        """Extract stock codes from candidates for watchlist update.

        Args:
            candidates: List of scan candidates

        Returns:
            List of stock codes
        """
        return [c.stock_code for c in candidates]
```
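One detail worth spelling out is the fallback conversion from the ranking API's percentage field to a plain ratio. A tiny sketch of that arithmetic (function name is mine):

```python
def ratio_from_increase_rate(volume_increase_rate: float) -> float:
    """Map a percentage increase (e.g. 150 == +150%) to a plain ratio (2.5x)."""
    return volume_increase_rate / 100 + 1


assert ratio_from_increase_rate(150) == 2.5   # +150% means 2.5x yesterday's volume
assert ratio_from_increase_rate(0) == 1.0     # no change means 1.0x
```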
```diff
@@ -124,6 +124,54 @@ class VolatilityAnalyzer:
             return 1.0
         return current_volume / avg_volume
 
+    def calculate_rsi(
+        self,
+        close_prices: list[float],
+        period: int = 14,
+    ) -> float:
+        """Calculate Relative Strength Index (RSI) using Wilder's smoothing.
+
+        Args:
+            close_prices: List of closing prices (oldest to newest, minimum period+1 values)
+            period: RSI period (default 14)
+
+        Returns:
+            RSI value between 0 and 100, or 50.0 (neutral) if insufficient data
+
+        Examples:
+            >>> analyzer = VolatilityAnalyzer()
+            >>> prices = [100 - i * 0.5 for i in range(20)]  # Downtrend
+            >>> rsi = analyzer.calculate_rsi(prices)
+            >>> assert rsi < 50  # Oversold territory
+        """
+        if len(close_prices) < period + 1:
+            return 50.0  # Neutral RSI if insufficient data
+
+        # Calculate price changes
+        changes = [close_prices[i] - close_prices[i - 1] for i in range(1, len(close_prices))]
+
+        # Separate gains and losses
+        gains = [max(0.0, change) for change in changes]
+        losses = [max(0.0, -change) for change in changes]
+
+        # Calculate initial average gain/loss (simple average for first period)
+        avg_gain = sum(gains[:period]) / period
+        avg_loss = sum(losses[:period]) / period
+
+        # Apply Wilder's smoothing for remaining periods
+        for i in range(period, len(changes)):
+            avg_gain = (avg_gain * (period - 1) + gains[i]) / period
+            avg_loss = (avg_loss * (period - 1) + losses[i]) / period
+
+        # Calculate RS and RSI
+        if avg_loss == 0:
+            return 100.0  # All gains, maximum RSI
+
+        rs = avg_gain / avg_loss
+        rsi = 100 - (100 / (1 + rs))
+
+        return rsi
+
     def calculate_pv_divergence(
         self,
         price_change: float,
```
**src/backup/__init__.py** (new file, 21 lines)
@@ -0,0 +1,21 @@

```python
"""Backup and disaster recovery system for long-term sustainability.

This module provides:
- Automated database backups (daily, weekly, monthly)
- Multi-format exports (JSON, CSV, Parquet)
- Cloud storage integration (S3-compatible)
- Health monitoring and alerts
"""

from src.backup.exporter import BackupExporter, ExportFormat
from src.backup.scheduler import BackupScheduler, BackupPolicy
from src.backup.cloud_storage import CloudStorage, S3Config

__all__ = [
    "BackupExporter",
    "ExportFormat",
    "BackupScheduler",
    "BackupPolicy",
    "CloudStorage",
    "S3Config",
]
```
**src/backup/cloud_storage.py** (new file, 274 lines)
@@ -0,0 +1,274 @@
|
||||
"""Cloud storage integration for off-site backups.
|
||||
|
||||
Supports S3-compatible storage providers:
|
||||
- AWS S3
|
||||
- MinIO
|
||||
- Backblaze B2
|
||||
- DigitalOcean Spaces
|
||||
- Cloudflare R2
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class S3Config:
|
||||
"""Configuration for S3-compatible storage."""
|
||||
|
||||
endpoint_url: str | None # None for AWS S3, custom URL for others
|
||||
access_key: str
|
||||
secret_key: str
|
||||
bucket_name: str
|
||||
region: str = "us-east-1"
|
||||
use_ssl: bool = True
|
||||
|
||||
|
||||
class CloudStorage:
|
||||
"""Upload backups to S3-compatible cloud storage."""
|
||||
|
||||
def __init__(self, config: S3Config) -> None:
|
||||
"""Initialize cloud storage client.
|
||||
|
||||
Args:
|
||||
config: S3 configuration
|
||||
|
||||
Raises:
|
||||
ImportError: If boto3 is not installed
|
||||
"""
|
||||
try:
|
||||
import boto3
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"boto3 is required for cloud storage. Install with: pip install boto3"
|
||||
)
|
||||
|
||||
        self.config = config
        self.client = boto3.client(
            "s3",
            endpoint_url=config.endpoint_url,
            aws_access_key_id=config.access_key,
            aws_secret_access_key=config.secret_key,
            region_name=config.region,
            use_ssl=config.use_ssl,
        )

    def upload_file(
        self,
        file_path: Path,
        object_key: str | None = None,
        metadata: dict[str, str] | None = None,
    ) -> str:
        """Upload a file to cloud storage.

        Args:
            file_path: Local file to upload
            object_key: S3 object key (default: filename)
            metadata: Optional metadata to attach

        Returns:
            S3 object key

        Raises:
            FileNotFoundError: If file doesn't exist
            Exception: If upload fails
        """
        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")

        if object_key is None:
            object_key = file_path.name

        extra_args: dict[str, Any] = {}

        # Add server-side encryption
        extra_args["ServerSideEncryption"] = "AES256"

        # Add metadata if provided
        if metadata:
            extra_args["Metadata"] = metadata

        logger.info("Uploading %s to s3://%s/%s", file_path.name, self.config.bucket_name, object_key)

        try:
            self.client.upload_file(
                str(file_path),
                self.config.bucket_name,
                object_key,
                ExtraArgs=extra_args,
            )
            logger.info("Upload successful: %s", object_key)
            return object_key
        except Exception as exc:
            logger.error("Upload failed: %s", exc)
            raise

    def download_file(self, object_key: str, local_path: Path) -> Path:
        """Download a file from cloud storage.

        Args:
            object_key: S3 object key
            local_path: Local destination path

        Returns:
            Path to downloaded file

        Raises:
            Exception: If download fails
        """
        local_path.parent.mkdir(parents=True, exist_ok=True)

        logger.info("Downloading s3://%s/%s to %s", self.config.bucket_name, object_key, local_path)

        try:
            self.client.download_file(
                self.config.bucket_name,
                object_key,
                str(local_path),
            )
            logger.info("Download successful: %s", local_path)
            return local_path
        except Exception as exc:
            logger.error("Download failed: %s", exc)
            raise

    def list_files(self, prefix: str = "") -> list[dict[str, Any]]:
        """List files in cloud storage.

        Args:
            prefix: Filter by object key prefix

        Returns:
            List of file metadata dictionaries
        """
        try:
            response = self.client.list_objects_v2(
                Bucket=self.config.bucket_name,
                Prefix=prefix,
            )

            if "Contents" not in response:
                return []

            files = []
            for obj in response["Contents"]:
                files.append(
                    {
                        "key": obj["Key"],
                        "size_bytes": obj["Size"],
                        "last_modified": obj["LastModified"],
                        "etag": obj["ETag"],
                    }
                )

            return files
        except Exception as exc:
            logger.error("Failed to list files: %s", exc)
            raise

    def delete_file(self, object_key: str) -> None:
        """Delete a file from cloud storage.

        Args:
            object_key: S3 object key

        Raises:
            Exception: If deletion fails
        """
        logger.info("Deleting s3://%s/%s", self.config.bucket_name, object_key)

        try:
            self.client.delete_object(
                Bucket=self.config.bucket_name,
                Key=object_key,
            )
            logger.info("Deletion successful: %s", object_key)
        except Exception as exc:
            logger.error("Deletion failed: %s", exc)
            raise

    def get_storage_stats(self) -> dict[str, Any]:
        """Get cloud storage statistics.

        Returns:
            Dictionary with storage stats
        """
        try:
            files = self.list_files()

            total_size = sum(f["size_bytes"] for f in files)
            total_count = len(files)

            return {
                "total_files": total_count,
                "total_size_bytes": total_size,
                "total_size_mb": total_size / 1024 / 1024,
                "total_size_gb": total_size / 1024 / 1024 / 1024,
            }
        except Exception as exc:
            logger.error("Failed to get storage stats: %s", exc)
            return {
                "error": str(exc),
                "total_files": 0,
                "total_size_bytes": 0,
            }

    def verify_connection(self) -> bool:
        """Verify connection to cloud storage.

        Returns:
            True if connection is successful
        """
        try:
            self.client.head_bucket(Bucket=self.config.bucket_name)
            logger.info("Cloud storage connection verified")
            return True
        except Exception as exc:
            logger.error("Cloud storage connection failed: %s", exc)
            return False

    def create_bucket_if_not_exists(self) -> None:
        """Create storage bucket if it doesn't exist.

        Raises:
            Exception: If bucket creation fails
        """
        try:
            self.client.head_bucket(Bucket=self.config.bucket_name)
            logger.info("Bucket already exists: %s", self.config.bucket_name)
            return
        except self.client.exceptions.ClientError as exc:
            # head_bucket raises a generic ClientError with a 404 code for a
            # missing bucket (not NoSuchBucket), so inspect the error code.
            error_code = exc.response.get("Error", {}).get("Code", "")
            if error_code not in ("404", "NoSuchBucket"):
                logger.error("Failed to verify bucket: %s", exc)
                raise

        try:
            logger.info("Creating bucket: %s", self.config.bucket_name)
            if self.config.region == "us-east-1":
                # us-east-1 must not pass a LocationConstraint
                self.client.create_bucket(Bucket=self.config.bucket_name)
            else:
                self.client.create_bucket(
                    Bucket=self.config.bucket_name,
                    CreateBucketConfiguration={"LocationConstraint": self.config.region},
                )
            logger.info("Bucket created successfully")
        except Exception as exc:
            logger.error("Failed to create bucket: %s", exc)
            raise

    def enable_versioning(self) -> None:
        """Enable versioning on the bucket.

        Raises:
            Exception: If versioning enablement fails
        """
        try:
            self.client.put_bucket_versioning(
                Bucket=self.config.bucket_name,
                VersioningConfiguration={"Status": "Enabled"},
            )
            logger.info("Versioning enabled for bucket: %s", self.config.bucket_name)
        except Exception as exc:
            logger.error("Failed to enable versioning: %s", exc)
            raise
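A minimal usage sketch for `CloudStorage`. The module path, MinIO endpoint, and credentials below are illustrative assumptions, not values defined in this repository:

```python
from pathlib import Path

from src.backup.cloud_storage import CloudStorage, S3Config  # assumed module path

# Placeholder credentials for a local MinIO instance; use endpoint_url=None for AWS S3.
config = S3Config(
    endpoint_url="http://localhost:9000",
    access_key="minioadmin",
    secret_key="minioadmin",
    bucket_name="trade-backups",
)
storage = CloudStorage(config)
if storage.verify_connection():
    key = storage.upload_file(Path("data/trade_logs.db"), metadata={"source": "manual"})
    print(storage.get_storage_stats())
```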
326
src/backup/exporter.py
Normal file
@@ -0,0 +1,326 @@
"""Multi-format database exporter for backups.

Supports JSON, CSV, and Parquet formats for different use cases:
- JSON: Human-readable, easy to inspect
- CSV: Analysis tools (Excel, pandas)
- Parquet: Big data tools (Spark, DuckDB)
"""

from __future__ import annotations

import csv
import gzip
import json
import logging
import sqlite3
from datetime import UTC, datetime
from enum import Enum
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)


class ExportFormat(str, Enum):
    """Supported export formats."""

    JSON = "json"
    CSV = "csv"
    PARQUET = "parquet"


class BackupExporter:
    """Export database to multiple formats."""

    def __init__(self, db_path: str) -> None:
        """Initialize the exporter.

        Args:
            db_path: Path to SQLite database
        """
        self.db_path = db_path

    def export_all(
        self,
        output_dir: Path,
        formats: list[ExportFormat] | None = None,
        compress: bool = True,
        incremental_since: datetime | None = None,
    ) -> dict[ExportFormat, Path]:
        """Export database to multiple formats.

        Args:
            output_dir: Directory to write export files
            formats: List of formats to export (default: all)
            compress: Whether to gzip compress exports
            incremental_since: Only export records after this timestamp

        Returns:
            Dictionary mapping format to output file path
        """
        if formats is None:
            formats = [ExportFormat.JSON, ExportFormat.CSV, ExportFormat.PARQUET]

        output_dir.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S")

        results: dict[ExportFormat, Path] = {}

        for fmt in formats:
            try:
                output_file = self._export_format(
                    fmt, output_dir, timestamp, compress, incremental_since
                )
                results[fmt] = output_file
                logger.info("Exported to %s: %s", fmt.value, output_file)
            except Exception as exc:
                logger.error("Failed to export to %s: %s", fmt.value, exc)

        return results

    def _export_format(
        self,
        fmt: ExportFormat,
        output_dir: Path,
        timestamp: str,
        compress: bool,
        incremental_since: datetime | None,
    ) -> Path:
        """Export to a specific format.

        Args:
            fmt: Export format
            output_dir: Output directory
            timestamp: Timestamp string for filename
            compress: Whether to compress
            incremental_since: Incremental export cutoff

        Returns:
            Path to output file
        """
        if fmt == ExportFormat.JSON:
            return self._export_json(output_dir, timestamp, compress, incremental_since)
        elif fmt == ExportFormat.CSV:
            return self._export_csv(output_dir, timestamp, compress, incremental_since)
        elif fmt == ExportFormat.PARQUET:
            return self._export_parquet(
                output_dir, timestamp, compress, incremental_since
            )
        else:
            raise ValueError(f"Unsupported format: {fmt}")

    def _get_trades(
        self, incremental_since: datetime | None = None
    ) -> list[dict[str, Any]]:
        """Fetch trades from database.

        Args:
            incremental_since: Only fetch trades after this timestamp

        Returns:
            List of trade records
        """
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row

        if incremental_since:
            cursor = conn.execute(
                "SELECT * FROM trades WHERE timestamp > ?",
                (incremental_since.isoformat(),),
            )
        else:
            cursor = conn.execute("SELECT * FROM trades")

        trades = [dict(row) for row in cursor.fetchall()]
        conn.close()

        return trades

    def _export_json(
        self,
        output_dir: Path,
        timestamp: str,
        compress: bool,
        incremental_since: datetime | None,
    ) -> Path:
        """Export to JSON format.

        Args:
            output_dir: Output directory
            timestamp: Timestamp for filename
            compress: Whether to gzip
            incremental_since: Incremental cutoff

        Returns:
            Path to output file
        """
        trades = self._get_trades(incremental_since)

        filename = f"trades_{timestamp}.json"
        if compress:
            filename += ".gz"

        output_file = output_dir / filename

        data = {
            "export_timestamp": datetime.now(UTC).isoformat(),
            "incremental_since": (
                incremental_since.isoformat() if incremental_since else None
            ),
            "record_count": len(trades),
            "trades": trades,
        }

        if compress:
            with gzip.open(output_file, "wt", encoding="utf-8") as f:
                json.dump(data, f, indent=2, ensure_ascii=False)
        else:
            with open(output_file, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2, ensure_ascii=False)

        return output_file

    def _export_csv(
        self,
        output_dir: Path,
        timestamp: str,
        compress: bool,
        incremental_since: datetime | None,
    ) -> Path:
        """Export to CSV format.

        Args:
            output_dir: Output directory
            timestamp: Timestamp for filename
            compress: Whether to gzip
            incremental_since: Incremental cutoff

        Returns:
            Path to output file
        """
        trades = self._get_trades(incremental_since)

        filename = f"trades_{timestamp}.csv"
        if compress:
            filename += ".gz"

        output_file = output_dir / filename

        if not trades:
            # No records: write an empty CSV with headers only
            fieldnames = [
                "timestamp",
                "stock_code",
                "action",
                "quantity",
                "price",
                "confidence",
                "rationale",
                "pnl",
            ]
        else:
            # Get column names from the first trade record
            fieldnames = list(trades[0].keys())

        # gzip.open and open share the same text-mode signature
        open_fn = gzip.open if compress else open
        with open_fn(output_file, "wt", encoding="utf-8", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(trades)

        return output_file

    def _export_parquet(
        self,
        output_dir: Path,
        timestamp: str,
        compress: bool,
        incremental_since: datetime | None,
    ) -> Path:
        """Export to Parquet format.

        Args:
            output_dir: Output directory
            timestamp: Timestamp for filename
            compress: Whether to compress (Parquet has built-in compression)
            incremental_since: Incremental cutoff

        Returns:
            Path to output file
        """
        trades = self._get_trades(incremental_since)

        filename = f"trades_{timestamp}.parquet"
        output_file = output_dir / filename

        try:
            import pyarrow as pa
            import pyarrow.parquet as pq
        except ImportError as exc:
            raise ImportError(
                "pyarrow is required for Parquet export. "
                "Install with: pip install pyarrow"
            ) from exc

        # Convert to a pyarrow table
        table = pa.Table.from_pylist(trades)

        # Write with compression
        compression = "gzip" if compress else "none"
        pq.write_table(table, output_file, compression=compression)

        return output_file

    def get_export_stats(self) -> dict[str, Any]:
        """Get statistics about exportable data.

        Returns:
            Dictionary with data statistics
        """
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        stats = {}

        # Total trades
        cursor.execute("SELECT COUNT(*) FROM trades")
        stats["total_trades"] = cursor.fetchone()[0]

        # Date range
        cursor.execute("SELECT MIN(timestamp), MAX(timestamp) FROM trades")
        min_date, max_date = cursor.fetchone()
        stats["date_range"] = {"earliest": min_date, "latest": max_date}

        # Database size
        cursor.execute("SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()")
        stats["db_size_bytes"] = cursor.fetchone()[0]

        conn.close()

        return stats
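A short usage sketch for `BackupExporter`; the `exports/` directory is an arbitrary placeholder, and the database path mirrors the `DB_PATH` default:

```python
from datetime import UTC, datetime, timedelta
from pathlib import Path

from src.backup.exporter import BackupExporter, ExportFormat

exporter = BackupExporter("data/trade_logs.db")

# Incremental export of the last 24 hours to compressed JSON and CSV.
since = datetime.now(UTC) - timedelta(hours=24)
results = exporter.export_all(
    Path("exports"),
    formats=[ExportFormat.JSON, ExportFormat.CSV],
    compress=True,
    incremental_since=since,
)
print({fmt.value: str(path) for fmt, path in results.items()})
```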
282
src/backup/health_monitor.py
Normal file
@@ -0,0 +1,282 @@
"""Health monitoring for backup system.

Checks:
- Database accessibility and integrity
- Disk space availability
- Backup success/failure tracking
- Self-healing capabilities
"""

from __future__ import annotations

import logging
import shutil
import sqlite3
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from enum import Enum
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)


class HealthStatus(str, Enum):
    """Health check status."""

    HEALTHY = "healthy"
    DEGRADED = "degraded"
    UNHEALTHY = "unhealthy"


@dataclass
class HealthCheckResult:
    """Result of a health check."""

    status: HealthStatus
    message: str
    details: dict[str, Any] | None = None
    timestamp: datetime | None = None

    def __post_init__(self) -> None:
        if self.timestamp is None:
            self.timestamp = datetime.now(UTC)


class HealthMonitor:
    """Monitor system health and backup status."""

    def __init__(
        self,
        db_path: str,
        backup_dir: Path,
        min_disk_space_gb: float = 10.0,
        max_backup_age_hours: int = 25,  # Daily backups should be < 25 hours old
    ) -> None:
        """Initialize health monitor.

        Args:
            db_path: Path to SQLite database
            backup_dir: Backup directory
            min_disk_space_gb: Minimum required disk space in GB
            max_backup_age_hours: Maximum acceptable backup age in hours
        """
        self.db_path = Path(db_path)
        self.backup_dir = backup_dir
        self.min_disk_space_bytes = int(min_disk_space_gb * 1024 * 1024 * 1024)
        self.max_backup_age = timedelta(hours=max_backup_age_hours)

    def check_database_health(self) -> HealthCheckResult:
        """Check database accessibility and integrity.

        Returns:
            HealthCheckResult
        """
        # Check if database exists
        if not self.db_path.exists():
            return HealthCheckResult(
                status=HealthStatus.UNHEALTHY,
                message=f"Database not found: {self.db_path}",
            )

        # Check if database is accessible
        try:
            conn = sqlite3.connect(str(self.db_path))
            cursor = conn.cursor()

            # Run integrity check
            cursor.execute("PRAGMA integrity_check")
            result = cursor.fetchone()[0]

            if result != "ok":
                conn.close()
                return HealthCheckResult(
                    status=HealthStatus.UNHEALTHY,
                    message=f"Database integrity check failed: {result}",
                )

            # Get database size
            cursor.execute(
                "SELECT page_count * page_size FROM pragma_page_count(), pragma_page_size()"
            )
            db_size = cursor.fetchone()[0]

            # Get row counts
            cursor.execute("SELECT COUNT(*) FROM trades")
            trade_count = cursor.fetchone()[0]

            conn.close()

            return HealthCheckResult(
                status=HealthStatus.HEALTHY,
                message="Database is healthy",
                details={
                    "size_bytes": db_size,
                    "size_mb": db_size / 1024 / 1024,
                    "trade_count": trade_count,
                },
            )

        except sqlite3.Error as exc:
            return HealthCheckResult(
                status=HealthStatus.UNHEALTHY,
                message=f"Database access error: {exc}",
            )

    def check_disk_space(self) -> HealthCheckResult:
        """Check available disk space.

        Returns:
            HealthCheckResult
        """
        try:
            stat = shutil.disk_usage(self.backup_dir)

            free_gb = stat.free / 1024 / 1024 / 1024
            total_gb = stat.total / 1024 / 1024 / 1024
            used_percent = (stat.used / stat.total) * 100

            if stat.free < self.min_disk_space_bytes:
                return HealthCheckResult(
                    status=HealthStatus.UNHEALTHY,
                    message=f"Low disk space: {free_gb:.2f} GB free (minimum: {self.min_disk_space_bytes / 1024 / 1024 / 1024:.2f} GB)",
                    details={
                        "free_gb": free_gb,
                        "total_gb": total_gb,
                        "used_percent": used_percent,
                    },
                )
            elif stat.free < self.min_disk_space_bytes * 2:
                return HealthCheckResult(
                    status=HealthStatus.DEGRADED,
                    message=f"Disk space low: {free_gb:.2f} GB free",
                    details={
                        "free_gb": free_gb,
                        "total_gb": total_gb,
                        "used_percent": used_percent,
                    },
                )
            else:
                return HealthCheckResult(
                    status=HealthStatus.HEALTHY,
                    message=f"Disk space healthy: {free_gb:.2f} GB free",
                    details={
                        "free_gb": free_gb,
                        "total_gb": total_gb,
                        "used_percent": used_percent,
                    },
                )

        except Exception as exc:
            return HealthCheckResult(
                status=HealthStatus.UNHEALTHY,
                message=f"Failed to check disk space: {exc}",
            )

    def check_backup_recency(self) -> HealthCheckResult:
        """Check if backups are recent enough.

        Returns:
            HealthCheckResult
        """
        daily_dir = self.backup_dir / "daily"

        if not daily_dir.exists():
            return HealthCheckResult(
                status=HealthStatus.DEGRADED,
                message="Daily backup directory not found",
            )

        # Find most recent backup
        backups = sorted(daily_dir.glob("*.db"), key=lambda p: p.stat().st_mtime, reverse=True)

        if not backups:
            return HealthCheckResult(
                status=HealthStatus.UNHEALTHY,
                message="No daily backups found",
            )

        most_recent = backups[0]
        mtime = datetime.fromtimestamp(most_recent.stat().st_mtime, tz=UTC)
        age = datetime.now(UTC) - mtime

        if age > self.max_backup_age:
            return HealthCheckResult(
                status=HealthStatus.DEGRADED,
                message=f"Most recent backup is {age.total_seconds() / 3600:.1f} hours old",
                details={
                    "backup_file": most_recent.name,
                    "age_hours": age.total_seconds() / 3600,
                    "threshold_hours": self.max_backup_age.total_seconds() / 3600,
                },
            )
        else:
            return HealthCheckResult(
                status=HealthStatus.HEALTHY,
                message=f"Recent backup found ({age.total_seconds() / 3600:.1f} hours old)",
                details={
                    "backup_file": most_recent.name,
                    "age_hours": age.total_seconds() / 3600,
                },
            )

    def run_all_checks(self) -> dict[str, HealthCheckResult]:
        """Run all health checks.

        Returns:
            Dictionary mapping check name to result
        """
        checks = {
            "database": self.check_database_health(),
            "disk_space": self.check_disk_space(),
            "backup_recency": self.check_backup_recency(),
        }

        # Log results
        for check_name, result in checks.items():
            if result.status == HealthStatus.UNHEALTHY:
                logger.error("[%s] %s: %s", check_name, result.status.value, result.message)
            elif result.status == HealthStatus.DEGRADED:
                logger.warning("[%s] %s: %s", check_name, result.status.value, result.message)
            else:
                logger.info("[%s] %s: %s", check_name, result.status.value, result.message)

        return checks

    def get_overall_status(self) -> HealthStatus:
        """Get overall system health status.

        Returns:
            HealthStatus (worst status from all checks)
        """
        checks = self.run_all_checks()

        # Return worst status
        if any(c.status == HealthStatus.UNHEALTHY for c in checks.values()):
            return HealthStatus.UNHEALTHY
        elif any(c.status == HealthStatus.DEGRADED for c in checks.values()):
            return HealthStatus.DEGRADED
        else:
            return HealthStatus.HEALTHY

    def get_health_report(self) -> dict[str, Any]:
        """Get comprehensive health report.

        Returns:
            Dictionary with health report
        """
        checks = self.run_all_checks()

        # Derive the overall status from this run's results instead of
        # calling get_overall_status(), which would re-run every check.
        if any(c.status == HealthStatus.UNHEALTHY for c in checks.values()):
            overall = HealthStatus.UNHEALTHY
        elif any(c.status == HealthStatus.DEGRADED for c in checks.values()):
            overall = HealthStatus.DEGRADED
        else:
            overall = HealthStatus.HEALTHY

        return {
            "overall_status": overall.value,
            "timestamp": datetime.now(UTC).isoformat(),
            "checks": {
                name: {
                    "status": result.status.value,
                    "message": result.message,
                    "details": result.details,
                }
                for name, result in checks.items()
            },
        }
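A usage sketch for `HealthMonitor`; the `backups/` directory is a placeholder and should match the scheduler's `backup_dir`:

```python
from pathlib import Path

from src.backup.health_monitor import HealthMonitor, HealthStatus

monitor = HealthMonitor(
    db_path="data/trade_logs.db",
    backup_dir=Path("backups"),
    min_disk_space_gb=10.0,
)

report = monitor.get_health_report()
if report["overall_status"] != HealthStatus.HEALTHY.value:
    print(report)  # e.g. forward to the Telegram notifier
```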
336
src/backup/scheduler.py
Normal file
@@ -0,0 +1,336 @@
"""Backup scheduler for automated database backups.

Implements backup policies:
- Daily: Keep for 30 days (hot storage)
- Weekly: Keep for 1 year (warm storage)
- Monthly: Keep forever (cold storage)
"""

from __future__ import annotations

import logging
import shutil
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from enum import Enum
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)


class BackupPolicy(str, Enum):
    """Backup retention policies."""

    DAILY = "daily"
    WEEKLY = "weekly"
    MONTHLY = "monthly"


@dataclass
class BackupMetadata:
    """Metadata for a backup."""

    timestamp: datetime
    policy: BackupPolicy
    file_path: Path
    size_bytes: int
    checksum: str | None = None


class BackupScheduler:
    """Manage automated database backups with retention policies."""

    def __init__(
        self,
        db_path: str,
        backup_dir: Path,
        daily_retention_days: int = 30,
        weekly_retention_days: int = 365,
    ) -> None:
        """Initialize the backup scheduler.

        Args:
            db_path: Path to SQLite database
            backup_dir: Root directory for backups
            daily_retention_days: Days to keep daily backups
            weekly_retention_days: Days to keep weekly backups
        """
        self.db_path = Path(db_path)
        self.backup_dir = backup_dir
        self.daily_retention = timedelta(days=daily_retention_days)
        self.weekly_retention = timedelta(days=weekly_retention_days)

        # Create policy-specific directories
        self.daily_dir = backup_dir / "daily"
        self.weekly_dir = backup_dir / "weekly"
        self.monthly_dir = backup_dir / "monthly"

        for d in [self.daily_dir, self.weekly_dir, self.monthly_dir]:
            d.mkdir(parents=True, exist_ok=True)

    def create_backup(
        self, policy: BackupPolicy, verify: bool = True
    ) -> BackupMetadata:
        """Create a database backup.

        Args:
            policy: Backup policy (daily/weekly/monthly)
            verify: Whether to verify backup integrity

        Returns:
            BackupMetadata object

        Raises:
            FileNotFoundError: If database doesn't exist
            OSError: If backup fails
        """
        if not self.db_path.exists():
            raise FileNotFoundError(f"Database not found: {self.db_path}")

        timestamp = datetime.now(UTC)
        backup_filename = self._get_backup_filename(timestamp, policy)

        # Determine output directory
        if policy == BackupPolicy.DAILY:
            output_dir = self.daily_dir
        elif policy == BackupPolicy.WEEKLY:
            output_dir = self.weekly_dir
        else:  # MONTHLY
            output_dir = self.monthly_dir

        backup_path = output_dir / backup_filename

        # Create backup (copy database file)
        logger.info("Creating %s backup: %s", policy.value, backup_path)
        shutil.copy2(self.db_path, backup_path)

        # Get file size
        size_bytes = backup_path.stat().st_size

        # Verify backup if requested
        checksum = None
        if verify:
            checksum = self._verify_backup(backup_path)

        metadata = BackupMetadata(
            timestamp=timestamp,
            policy=policy,
            file_path=backup_path,
            size_bytes=size_bytes,
            checksum=checksum,
        )

        logger.info(
            "Backup created: %s (%.2f MB)",
            backup_path.name,
            size_bytes / 1024 / 1024,
        )

        return metadata

    def _get_backup_filename(self, timestamp: datetime, policy: BackupPolicy) -> str:
        """Generate backup filename.

        Args:
            timestamp: Backup timestamp
            policy: Backup policy

        Returns:
            Filename string
        """
        ts_str = timestamp.strftime("%Y%m%d_%H%M%S")
        return f"trade_logs_{policy.value}_{ts_str}.db"

    def _verify_backup(self, backup_path: Path) -> str:
        """Verify backup integrity using SQLite integrity check.

        Args:
            backup_path: Path to backup file

        Returns:
            Checksum string (MD5 hash)

        Raises:
            RuntimeError: If integrity check fails
        """
        import hashlib
        import sqlite3

        # Integrity check
        try:
            conn = sqlite3.connect(str(backup_path))
            cursor = conn.cursor()
            cursor.execute("PRAGMA integrity_check")
            result = cursor.fetchone()[0]
            conn.close()

            if result != "ok":
                raise RuntimeError(f"Integrity check failed: {result}")
        except sqlite3.Error as exc:
            raise RuntimeError(f"Failed to verify backup: {exc}") from exc

        # Calculate MD5 checksum
        md5 = hashlib.md5()
        with open(backup_path, "rb") as f:
            for chunk in iter(lambda: f.read(8192), b""):
                md5.update(chunk)

        return md5.hexdigest()

    def cleanup_old_backups(self) -> dict[BackupPolicy, int]:
        """Remove backups older than retention policies.

        Returns:
            Dictionary mapping policy to number of backups removed
        """
        now = datetime.now(UTC)
        removed_counts: dict[BackupPolicy, int] = {}

        # Daily backups: remove older than retention
        removed_counts[BackupPolicy.DAILY] = self._cleanup_directory(
            self.daily_dir, now - self.daily_retention
        )

        # Weekly backups: remove older than retention
        removed_counts[BackupPolicy.WEEKLY] = self._cleanup_directory(
            self.weekly_dir, now - self.weekly_retention
        )

        # Monthly backups: never remove (kept forever)
        removed_counts[BackupPolicy.MONTHLY] = 0

        total = sum(removed_counts.values())
        if total > 0:
            logger.info("Cleaned up %d old backup(s)", total)

        return removed_counts

    def _cleanup_directory(self, directory: Path, cutoff: datetime) -> int:
        """Remove backups older than cutoff date.

        Args:
            directory: Directory to clean
            cutoff: Remove files older than this

        Returns:
            Number of files removed
        """
        removed = 0

        for backup_file in directory.glob("*.db"):
            # Get file modification time
            mtime = datetime.fromtimestamp(backup_file.stat().st_mtime, tz=UTC)

            if mtime < cutoff:
                logger.debug("Removing old backup: %s", backup_file.name)
                backup_file.unlink()
                removed += 1

        return removed

    def list_backups(
        self, policy: BackupPolicy | None = None
    ) -> list[BackupMetadata]:
        """List available backups.

        Args:
            policy: Filter by policy (None for all)

        Returns:
            List of BackupMetadata objects
        """
        backups: list[BackupMetadata] = []

        policies_to_check = (
            [policy] if policy else [BackupPolicy.DAILY, BackupPolicy.WEEKLY, BackupPolicy.MONTHLY]
        )

        for pol in policies_to_check:
            if pol == BackupPolicy.DAILY:
                directory = self.daily_dir
            elif pol == BackupPolicy.WEEKLY:
                directory = self.weekly_dir
            else:
                directory = self.monthly_dir

            for backup_file in sorted(directory.glob("*.db")):
                mtime = datetime.fromtimestamp(backup_file.stat().st_mtime, tz=UTC)
                size = backup_file.stat().st_size

                backups.append(
                    BackupMetadata(
                        timestamp=mtime,
                        policy=pol,
                        file_path=backup_file,
                        size_bytes=size,
                    )
                )

        # Sort by timestamp (newest first)
        backups.sort(key=lambda b: b.timestamp, reverse=True)

        return backups

    def get_backup_stats(self) -> dict[str, Any]:
        """Get backup statistics.

        Returns:
            Dictionary with backup stats
        """
        stats: dict[str, Any] = {}

        for policy in BackupPolicy:
            if policy == BackupPolicy.DAILY:
                directory = self.daily_dir
            elif policy == BackupPolicy.WEEKLY:
                directory = self.weekly_dir
            else:
                directory = self.monthly_dir

            backups = list(directory.glob("*.db"))
            total_size = sum(b.stat().st_size for b in backups)

            stats[policy.value] = {
                "count": len(backups),
                "total_size_bytes": total_size,
                "total_size_mb": total_size / 1024 / 1024,
            }

        return stats

    def restore_backup(self, backup_metadata: BackupMetadata, verify: bool = True) -> None:
        """Restore database from backup.

        Args:
            backup_metadata: Backup to restore
            verify: Whether to verify restored database

        Raises:
            FileNotFoundError: If backup file doesn't exist
            RuntimeError: If verification fails
        """
        if not backup_metadata.file_path.exists():
            raise FileNotFoundError(f"Backup not found: {backup_metadata.file_path}")

        # Back up the current database (if one exists) so we can revert
        backup_current: Path | None = None
        if self.db_path.exists():
            backup_current = self.db_path.with_suffix(".db.before_restore")
            logger.info("Backing up current database to: %s", backup_current)
            shutil.copy2(self.db_path, backup_current)

        # Restore backup
        logger.info("Restoring backup: %s", backup_metadata.file_path.name)
        shutil.copy2(backup_metadata.file_path, self.db_path)

        # Verify restored database
        if verify:
            try:
                self._verify_backup(self.db_path)
                logger.info("Backup restored and verified successfully")
            except RuntimeError as exc:
                # Verification failed: revert to the pre-restore copy if one was made
                if backup_current is not None and backup_current.exists():
                    logger.error("Restore verification failed, reverting: %s", exc)
                    shutil.copy2(backup_current, self.db_path)
                raise
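A usage sketch for `BackupScheduler`, modelling the typical daily job (back up, then prune). The `backups/` root is a placeholder:

```python
from pathlib import Path

from src.backup.scheduler import BackupPolicy, BackupScheduler

scheduler = BackupScheduler("data/trade_logs.db", Path("backups"))

meta = scheduler.create_backup(BackupPolicy.DAILY, verify=True)
print(f"{meta.file_path.name}: {meta.size_bytes} bytes, md5={meta.checksum}")

removed = scheduler.cleanup_old_backups()
print(f"Pruned {sum(removed.values())} expired backup(s)")
```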
293
src/brain/cache.py
Normal file
@@ -0,0 +1,293 @@
"""Response caching system for reducing redundant LLM calls.

This module provides caching for common trading scenarios:
- TTL-based cache invalidation
- Cache key based on market conditions
- Cache hit rate monitoring
- Special handling for HOLD decisions in quiet markets
"""

from __future__ import annotations

import hashlib
import json
import logging
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from src.brain.gemini_client import TradeDecision

logger = logging.getLogger(__name__)


@dataclass
class CacheEntry:
    """Cached decision with metadata."""

    decision: TradeDecision
    cached_at: float  # Unix timestamp
    hit_count: int = 0
    market_data_hash: str = ""


@dataclass
class CacheMetrics:
    """Metrics for cache performance monitoring."""

    total_requests: int = 0
    cache_hits: int = 0
    cache_misses: int = 0
    evictions: int = 0
    total_entries: int = 0

    @property
    def hit_rate(self) -> float:
        """Calculate cache hit rate."""
        if self.total_requests == 0:
            return 0.0
        return self.cache_hits / self.total_requests

    def to_dict(self) -> dict[str, Any]:
        """Convert metrics to dictionary."""
        return {
            "total_requests": self.total_requests,
            "cache_hits": self.cache_hits,
            "cache_misses": self.cache_misses,
            "hit_rate": self.hit_rate,
            "evictions": self.evictions,
            "total_entries": self.total_entries,
        }


class DecisionCache:
    """TTL-based cache for trade decisions."""

    def __init__(self, ttl_seconds: int = 300, max_size: int = 1000) -> None:
        """Initialize the decision cache.

        Args:
            ttl_seconds: Time-to-live for cache entries in seconds (default: 5 minutes)
            max_size: Maximum number of cache entries
        """
        self.ttl_seconds = ttl_seconds
        self.max_size = max_size
        self._cache: dict[str, CacheEntry] = {}
        self._metrics = CacheMetrics()

    def _generate_cache_key(self, market_data: dict[str, Any]) -> str:
        """Generate cache key from market data.

        Key is based on:
        - Stock code
        - Current price (rounded to reduce sensitivity)
        - Market conditions (orderbook snapshot)

        Args:
            market_data: Market data dictionary

        Returns:
            Cache key string
        """
        # Extract key components
        stock_code = market_data.get("stock_code", "UNKNOWN")
        current_price = market_data.get("current_price", 0)

        # Round price to reduce sensitivity (cache hits for similar prices)
        # For prices > 1000, round to nearest 10
        # For prices < 1000, round to nearest 1
        if current_price > 1000:
            price_rounded = round(current_price / 10) * 10
        else:
            price_rounded = round(current_price)

        # Include orderbook snapshot (if available)
        orderbook_key = ""
        if "orderbook" in market_data and market_data["orderbook"]:
            ob = market_data["orderbook"]
            # Just use bid/ask spread as indicator
            if "bid" in ob and "ask" in ob and ob["bid"] and ob["ask"]:
                bid_price = ob["bid"][0].get("price", 0) if ob["bid"] else 0
                ask_price = ob["ask"][0].get("price", 0) if ob["ask"] else 0
                spread = ask_price - bid_price
                orderbook_key = f"_spread{spread}"

        # Generate cache key
        key_str = f"{stock_code}_{price_rounded}{orderbook_key}"

        return key_str

    def _generate_market_hash(self, market_data: dict[str, Any]) -> str:
        """Generate hash of full market data for invalidation checks.

        Args:
            market_data: Market data dictionary

        Returns:
            Hash string
        """
        # Create stable JSON representation
        stable_json = json.dumps(market_data, sort_keys=True, ensure_ascii=False)
        return hashlib.md5(stable_json.encode()).hexdigest()

    def get(self, market_data: dict[str, Any]) -> TradeDecision | None:
        """Retrieve cached decision if valid.

        Args:
            market_data: Market data dictionary

        Returns:
            Cached TradeDecision if valid, None otherwise
        """
        self._metrics.total_requests += 1

        cache_key = self._generate_cache_key(market_data)

        if cache_key not in self._cache:
            self._metrics.cache_misses += 1
            return None

        entry = self._cache[cache_key]
        current_time = time.time()

        # Check TTL
        if current_time - entry.cached_at > self.ttl_seconds:
            # Expired
            del self._cache[cache_key]
            self._metrics.cache_misses += 1
            self._metrics.evictions += 1
            logger.debug("Cache expired for key: %s", cache_key)
            return None

        # Cache hit
        entry.hit_count += 1
        self._metrics.cache_hits += 1
        logger.debug("Cache hit for key: %s (hits: %d)", cache_key, entry.hit_count)

        return entry.decision

    def set(
        self,
        market_data: dict[str, Any],
        decision: TradeDecision,
    ) -> None:
        """Store decision in cache.

        Args:
            market_data: Market data dictionary
            decision: TradeDecision to cache
        """
        cache_key = self._generate_cache_key(market_data)
        market_hash = self._generate_market_hash(market_data)

        # Enforce max size (evict oldest if full)
        if len(self._cache) >= self.max_size:
            # Find oldest entry
            oldest_key = min(self._cache.keys(), key=lambda k: self._cache[k].cached_at)
            del self._cache[oldest_key]
            self._metrics.evictions += 1
            logger.debug("Cache full, evicted key: %s", oldest_key)

        # Store entry
        entry = CacheEntry(
            decision=decision,
            cached_at=time.time(),
            market_data_hash=market_hash,
        )
        self._cache[cache_key] = entry
        self._metrics.total_entries = len(self._cache)

        logger.debug("Cached decision for key: %s", cache_key)

    def invalidate(self, stock_code: str | None = None) -> int:
        """Invalidate cache entries.

        Args:
            stock_code: Specific stock code to invalidate, or None for all

        Returns:
            Number of entries invalidated
        """
        if stock_code is None:
            # Clear all
            count = len(self._cache)
            self._cache.clear()
            self._metrics.evictions += count
            self._metrics.total_entries = 0
            logger.info("Invalidated all cache entries (%d)", count)
            return count

        # Invalidate specific stock
        keys_to_remove = [k for k in self._cache.keys() if k.startswith(f"{stock_code}_")]
        count = len(keys_to_remove)

        for key in keys_to_remove:
            del self._cache[key]

        self._metrics.evictions += count
        self._metrics.total_entries = len(self._cache)
        logger.info("Invalidated %d cache entries for stock: %s", count, stock_code)

        return count

    def cleanup_expired(self) -> int:
        """Remove expired entries from cache.

        Returns:
            Number of entries removed
        """
        current_time = time.time()
        expired_keys = [
            k
            for k, v in self._cache.items()
            if current_time - v.cached_at > self.ttl_seconds
        ]

        count = len(expired_keys)
        for key in expired_keys:
            del self._cache[key]

        self._metrics.evictions += count
        self._metrics.total_entries = len(self._cache)

        if count > 0:
            logger.debug("Cleaned up %d expired cache entries", count)

        return count

    def get_metrics(self) -> CacheMetrics:
        """Get current cache metrics.

        Returns:
            CacheMetrics object with current statistics
        """
        return self._metrics

    def reset_metrics(self) -> None:
        """Reset cache metrics."""
        self._metrics = CacheMetrics(total_entries=len(self._cache))
        logger.info("Cache metrics reset")

    def should_cache_decision(self, decision: TradeDecision) -> bool:
        """Determine if a decision should be cached.

        HOLD decisions with low confidence are good candidates for caching,
        as they're likely to recur in quiet markets.

        Args:
            decision: TradeDecision to evaluate

        Returns:
            True if decision should be cached
        """
        # Cache HOLD decisions (common in quiet markets)
        if decision.action == "HOLD":
            return True

        # Cache high-confidence decisions (stable signals)
        if decision.confidence >= 90:
            return True

        # Don't cache low-confidence BUY/SELL (volatile signals)
        return False
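A usage sketch for `DecisionCache`. It assumes `TradeDecision` is constructible with the keyword fields its dataclass definition in `src/brain/gemini_client.py` suggests; the stock code and price are placeholders:

```python
from src.brain.cache import DecisionCache
from src.brain.gemini_client import TradeDecision

cache = DecisionCache(ttl_seconds=300, max_size=1000)
market_data = {"stock_code": "005930", "current_price": 71250}

decision = cache.get(market_data)  # None on a cold cache
if decision is None:
    decision = TradeDecision(action="HOLD", confidence=55, rationale="Quiet market")
    if cache.should_cache_decision(decision):  # HOLD decisions are cached
        cache.set(market_data, decision)

print(cache.get_metrics().to_dict())
```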
296
src/brain/context_selector.py
Normal file
@@ -0,0 +1,296 @@
"""Smart context selection for optimizing token usage.

This module implements intelligent selection of context layers (L1-L7) based on
decision type and market conditions:
- L7 (real-time) for normal trading decisions
- L6-L5 (daily/weekly) for strategic decisions
- L4-L1 (monthly/legacy) only for major events or policy changes
"""

from __future__ import annotations

from dataclasses import dataclass
from datetime import UTC, datetime
from enum import Enum
from typing import Any

from src.context.layer import ContextLayer
from src.context.store import ContextStore


class DecisionType(str, Enum):
    """Type of trading decision being made."""

    NORMAL = "normal"  # Regular trade decision
    STRATEGIC = "strategic"  # Strategy adjustment
    MAJOR_EVENT = "major_event"  # Portfolio rebalancing, policy change


@dataclass(frozen=True)
class ContextSelection:
    """Selected context layers and their relevance scores."""

    layers: list[ContextLayer]
    relevance_scores: dict[ContextLayer, float]
    total_score: float


class ContextSelector:
    """Selects optimal context layers to minimize token usage."""

    def __init__(self, store: ContextStore) -> None:
        """Initialize the context selector.

        Args:
            store: ContextStore instance for retrieving context data
        """
        self.store = store

    def select_layers(
        self,
        decision_type: DecisionType = DecisionType.NORMAL,
        include_realtime: bool = True,
    ) -> list[ContextLayer]:
        """Select context layers based on decision type.

        Strategy:
        - NORMAL: L7 (real-time) only
        - STRATEGIC: L7 + L6 + L5 (real-time + daily + weekly)
        - MAJOR_EVENT: All layers L1-L7

        Args:
            decision_type: Type of decision being made
            include_realtime: Whether to include L7 real-time data

        Returns:
            List of context layers to use (ordered by priority)
        """
        if decision_type == DecisionType.NORMAL:
            # Normal trading: only real-time data
            return [ContextLayer.L7_REALTIME] if include_realtime else []

        elif decision_type == DecisionType.STRATEGIC:
            # Strategic decisions: real-time + recent history
            layers = []
            if include_realtime:
                layers.append(ContextLayer.L7_REALTIME)
            layers.extend([ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY])
            return layers

        else:  # MAJOR_EVENT
            # Major events: all layers for comprehensive context
            layers = []
            if include_realtime:
                layers.append(ContextLayer.L7_REALTIME)
            layers.extend(
                [
                    ContextLayer.L6_DAILY,
                    ContextLayer.L5_WEEKLY,
                    ContextLayer.L4_MONTHLY,
                    ContextLayer.L3_QUARTERLY,
                    ContextLayer.L2_ANNUAL,
                    ContextLayer.L1_LEGACY,
                ]
            )
            return layers

    def score_layer_relevance(
        self,
        layer: ContextLayer,
        decision_type: DecisionType,
        current_time: datetime | None = None,
    ) -> float:
        """Calculate relevance score for a context layer.

        Relevance is based on:
        1. Decision type (normal, strategic, major event)
        2. Layer recency (L7 > L6 > ... > L1)
        3. Data availability

        Args:
            layer: Context layer to score
            decision_type: Type of decision being made
            current_time: Current time (defaults to now)

        Returns:
            Relevance score (0.0 to 1.0)
        """
        if current_time is None:
            current_time = datetime.now(UTC)

        # Base scores by decision type
        base_scores = {
            DecisionType.NORMAL: {
                ContextLayer.L7_REALTIME: 1.0,
                ContextLayer.L6_DAILY: 0.1,
                ContextLayer.L5_WEEKLY: 0.05,
                ContextLayer.L4_MONTHLY: 0.01,
                ContextLayer.L3_QUARTERLY: 0.0,
                ContextLayer.L2_ANNUAL: 0.0,
                ContextLayer.L1_LEGACY: 0.0,
            },
            DecisionType.STRATEGIC: {
                ContextLayer.L7_REALTIME: 0.9,
                ContextLayer.L6_DAILY: 0.8,
                ContextLayer.L5_WEEKLY: 0.7,
                ContextLayer.L4_MONTHLY: 0.3,
                ContextLayer.L3_QUARTERLY: 0.2,
                ContextLayer.L2_ANNUAL: 0.1,
                ContextLayer.L1_LEGACY: 0.05,
            },
            DecisionType.MAJOR_EVENT: {
                ContextLayer.L7_REALTIME: 0.7,
                ContextLayer.L6_DAILY: 0.7,
                ContextLayer.L5_WEEKLY: 0.7,
                ContextLayer.L4_MONTHLY: 0.8,
                ContextLayer.L3_QUARTERLY: 0.8,
                ContextLayer.L2_ANNUAL: 0.9,
                ContextLayer.L1_LEGACY: 1.0,
            },
        }

        score = base_scores[decision_type].get(layer, 0.0)

        # Check data availability
        latest_timeframe = self.store.get_latest_timeframe(layer)
        if latest_timeframe is None:
            # No data available - reduce score significantly
            score *= 0.1

        return score

    def select_with_scoring(
        self,
        decision_type: DecisionType = DecisionType.NORMAL,
        min_score: float = 0.5,
    ) -> ContextSelection:
        """Select context layers with relevance scoring.

        Args:
            decision_type: Type of decision being made
            min_score: Minimum relevance score to include a layer

        Returns:
            ContextSelection with selected layers and scores
        """
        all_layers = [
            ContextLayer.L7_REALTIME,
            ContextLayer.L6_DAILY,
            ContextLayer.L5_WEEKLY,
            ContextLayer.L4_MONTHLY,
            ContextLayer.L3_QUARTERLY,
            ContextLayer.L2_ANNUAL,
            ContextLayer.L1_LEGACY,
        ]

        scores = {
            layer: self.score_layer_relevance(layer, decision_type) for layer in all_layers
        }

        # Filter by minimum score
        selected_layers = [layer for layer, score in scores.items() if score >= min_score]

        # Sort by score (descending)
        selected_layers.sort(key=lambda layer: scores[layer], reverse=True)

        total_score = sum(scores[layer] for layer in selected_layers)

        return ContextSelection(
            layers=selected_layers,
            relevance_scores=scores,
            total_score=total_score,
        )

    def get_context_data(
        self,
        layers: list[ContextLayer],
        max_items_per_layer: int = 10,
    ) -> dict[str, Any]:
        """Retrieve context data for selected layers.

        Args:
            layers: List of context layers to retrieve
            max_items_per_layer: Maximum number of items per layer

        Returns:
            Dictionary with context data organized by layer
        """
        result: dict[str, Any] = {}

        for layer in layers:
            # Get latest timeframe for this layer
            latest_timeframe = self.store.get_latest_timeframe(layer)
            if latest_timeframe:
                # Get all contexts for latest timeframe
                contexts = self.store.get_all_contexts(layer, latest_timeframe)

                # Limit number of items
                if len(contexts) > max_items_per_layer:
                    # Keep only first N items
                    contexts = dict(list(contexts.items())[:max_items_per_layer])

                result[layer.value] = contexts

        return result

    def estimate_context_tokens(self, context_data: dict[str, Any]) -> int:
        """Estimate total tokens for context data.

        Args:
            context_data: Context data dictionary

        Returns:
            Estimated token count
        """
        import json

        from src.brain.prompt_optimizer import PromptOptimizer

        # Serialize to JSON and estimate tokens
        json_str = json.dumps(context_data, ensure_ascii=False)
        return PromptOptimizer.estimate_tokens(json_str)

    def optimize_context_for_budget(
        self,
        decision_type: DecisionType,
        max_tokens: int,
    ) -> dict[str, Any]:
        """Select and retrieve context data within a token budget.

        Args:
            decision_type: Type of decision being made
            max_tokens: Maximum token budget for context

        Returns:
            Optimized context data within budget
        """
        # Start with minimal selection
        selection = self.select_with_scoring(decision_type, min_score=0.5)

        # Retrieve data
        context_data = self.get_context_data(selection.layers)

        # Check if within budget
        estimated_tokens = self.estimate_context_tokens(context_data)

        if estimated_tokens <= max_tokens:
            return context_data

        # If over budget, progressively reduce
        # 1. Reduce items per layer
        for max_items in [5, 3, 1]:
            context_data = self.get_context_data(selection.layers, max_items)
            estimated_tokens = self.estimate_context_tokens(context_data)
            if estimated_tokens <= max_tokens:
                return context_data

        # 2. Remove lower-priority layers
        for min_score in [0.6, 0.7, 0.8, 0.9]:
            selection = self.select_with_scoring(decision_type, min_score=min_score)
            context_data = self.get_context_data(selection.layers, max_items_per_layer=1)
            estimated_tokens = self.estimate_context_tokens(context_data)
            if estimated_tokens <= max_tokens:
                return context_data

        # Last resort: return only L7 with minimal data
        return self.get_context_data([ContextLayer.L7_REALTIME], max_items_per_layer=1)
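A usage sketch for `ContextSelector`. The `ContextStore()` call assumes a no-argument constructor, which may not match the real signature:

```python
from src.brain.context_selector import ContextSelector, DecisionType
from src.context.store import ContextStore

selector = ContextSelector(ContextStore())  # constructor signature assumed

# Strategic decision: expect L7 + L6 + L5, ordered by relevance score.
selection = selector.select_with_scoring(DecisionType.STRATEGIC, min_score=0.5)
print([layer.value for layer in selection.layers], selection.total_score)

# Retrieve context trimmed to a fixed token budget.
context = selector.optimize_context_for_budget(DecisionType.STRATEGIC, max_tokens=2000)
```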
@@ -2,6 +2,17 @@

Constructs prompts from market data, calls Gemini, and parses structured
JSON responses into validated TradeDecision objects.

Includes token efficiency optimizations:
- Prompt compression and abbreviation
- Response caching for common scenarios
- Smart context selection
- Token usage tracking and metrics

Includes external data integration:
- News sentiment analysis
- Economic calendar events
- Market indicators
"""

from __future__ import annotations
@@ -15,6 +26,11 @@ from typing import Any
from google import genai

from src.config import Settings
from src.data.news_api import NewsAPI, NewsSentiment
from src.data.economic_calendar import EconomicCalendar
from src.data.market_data import MarketData
from src.brain.cache import DecisionCache
from src.brain.prompt_optimizer import PromptOptimizer

logger = logging.getLogger(__name__)

@@ -28,23 +44,176 @@ class TradeDecision:
    action: str  # "BUY" | "SELL" | "HOLD"
    confidence: int  # 0-100
    rationale: str
    token_count: int = 0  # Estimated tokens used
    cached: bool = False  # Whether decision came from cache


class GeminiClient:
    """Wraps the Gemini API for trade decision-making."""

    def __init__(self, settings: Settings) -> None:
    def __init__(
        self,
        settings: Settings,
        news_api: NewsAPI | None = None,
        economic_calendar: EconomicCalendar | None = None,
        market_data: MarketData | None = None,
        enable_cache: bool = True,
        enable_optimization: bool = True,
    ) -> None:
        self._settings = settings
        self._confidence_threshold = settings.CONFIDENCE_THRESHOLD
        self._client = genai.Client(api_key=settings.GEMINI_API_KEY)
        self._model_name = settings.GEMINI_MODEL

        # External data sources (optional)
        self._news_api = news_api
        self._economic_calendar = economic_calendar
        self._market_data = market_data

        # Token efficiency features
        self._enable_cache = enable_cache
        self._enable_optimization = enable_optimization
        self._cache = DecisionCache(ttl_seconds=300) if enable_cache else None
        self._optimizer = PromptOptimizer()

        # Token usage metrics
        self._total_tokens_used = 0
        self._total_decisions = 0
        self._total_cached_decisions = 0

    # ------------------------------------------------------------------
    # External Data Integration
    # ------------------------------------------------------------------

    async def _build_external_context(
        self, stock_code: str, news_sentiment: NewsSentiment | None = None
    ) -> str:
        """Build external data context for the prompt.

        Args:
            stock_code: Stock ticker symbol
            news_sentiment: Optional pre-fetched news sentiment

        Returns:
            Formatted string with external data context
        """
        context_parts: list[str] = []

        # News sentiment
        if news_sentiment is not None:
            sentiment_str = self._format_news_sentiment(news_sentiment)
            if sentiment_str:
                context_parts.append(sentiment_str)
        elif self._news_api is not None:
            # Fetch news sentiment if not provided
            try:
                sentiment = await self._news_api.get_news_sentiment(stock_code)
                if sentiment is not None:
                    sentiment_str = self._format_news_sentiment(sentiment)
                    if sentiment_str:
                        context_parts.append(sentiment_str)
            except Exception as exc:
                logger.warning("Failed to fetch news sentiment: %s", exc)

        # Economic events
        if self._economic_calendar is not None:
            events_str = self._format_economic_events(stock_code)
            if events_str:
                context_parts.append(events_str)

        # Market indicators
        if self._market_data is not None:
            indicators_str = self._format_market_indicators()
            if indicators_str:
                context_parts.append(indicators_str)

        if not context_parts:
            return ""

        return "EXTERNAL DATA:\n" + "\n\n".join(context_parts)

    def _format_news_sentiment(self, sentiment: NewsSentiment) -> str:
        """Format news sentiment for prompt."""
        if sentiment.article_count == 0:
            return ""

        # Select top 3 most relevant articles
        top_articles = sentiment.articles[:3]

        lines = [
            f"News Sentiment: {sentiment.avg_sentiment:.2f} "
            f"(from {sentiment.article_count} articles)",
        ]

        for i, article in enumerate(top_articles, 1):
            lines.append(
                f" {i}. [{article.source}] {article.title} "
                f"(sentiment: {article.sentiment_score:.2f})"
            )

        return "\n".join(lines)

    def _format_economic_events(self, stock_code: str) -> str:
        """Format upcoming economic events for prompt."""
        if self._economic_calendar is None:
            return ""

        # Check for upcoming high-impact events
        upcoming = self._economic_calendar.get_upcoming_events(
            days_ahead=7, min_impact="HIGH"
        )

        if upcoming.high_impact_count == 0:
            return ""

        lines = [
            f"Upcoming High-Impact Events: {upcoming.high_impact_count} in next 7 days"
        ]

        if upcoming.next_major_event is not None:
            event = upcoming.next_major_event
            lines.append(
                f" Next: {event.name} ({event.event_type}) "
                f"on {event.datetime.strftime('%Y-%m-%d')}"
            )

        # Check for earnings
        earnings_date = self._economic_calendar.get_earnings_date(stock_code)
        if earnings_date is not None:
            lines.append(
                f" Earnings: {stock_code} on {earnings_date.strftime('%Y-%m-%d')}"
            )

        return "\n".join(lines)

    def _format_market_indicators(self) -> str:
        """Format market indicators for prompt."""
        if self._market_data is None:
            return ""

        try:
            indicators = self._market_data.get_market_indicators()
            lines = [f"Market Sentiment: {indicators.sentiment.name}"]

            # Add breadth if meaningful
            if indicators.breadth.advance_decline_ratio != 1.0:
                lines.append(
                    f"Advance/Decline Ratio: {indicators.breadth.advance_decline_ratio:.2f}"
                )

            return "\n".join(lines)
        except Exception as exc:
            logger.warning("Failed to get market indicators: %s", exc)
            return ""

    # ------------------------------------------------------------------
    # Prompt Construction
    # ------------------------------------------------------------------

    def build_prompt(self, market_data: dict[str, Any]) -> str:
        """Build a structured prompt from market data.
    async def build_prompt(
        self, market_data: dict[str, Any], news_sentiment: NewsSentiment | None = None
    ) -> str:
        """Build a structured prompt from market data and external sources.

        The prompt instructs Gemini to return valid JSON with action,
        confidence, and rationale fields.
@@ -72,6 +241,60 @@ class GeminiClient:

        market_info = "\n".join(market_info_lines)

        # Add external data context if available
        external_context = await self._build_external_context(
            market_data["stock_code"], news_sentiment
        )
        if external_context:
            market_info += f"\n\n{external_context}"

        json_format = (
            '{"action": "BUY"|"SELL"|"HOLD", '
            '"confidence": <int 0-100>, "rationale": "<string>"}'
        )
        return (
            f"You are a professional {market_name} trading analyst.\n"
            "Analyze the following market data and decide whether to "
            "BUY, SELL, or HOLD.\n\n"
            f"{market_info}\n\n"
            "You MUST respond with ONLY valid JSON in the following format:\n"
            f"{json_format}\n\n"
            "Rules:\n"
            "- action must be exactly one of: BUY, SELL, HOLD\n"
            "- confidence must be an integer from 0 to 100\n"
            "- rationale must explain your reasoning concisely\n"
            "- Do NOT wrap the JSON in markdown code blocks\n"
        )

    def build_prompt_sync(self, market_data: dict[str, Any]) -> str:
        """Synchronous version of build_prompt (for backward compatibility).

        This version does NOT include external data integration.
        Use async build_prompt() for full functionality.
        """
        market_name = market_data.get("market_name", "Korean stock market")

        # Build market data section dynamically based on available fields
        market_info_lines = [
            f"Market: {market_name}",
            f"Stock Code: {market_data['stock_code']}",
|
||||
f"Current Price: {market_data['current_price']}",
|
||||
]
|
||||
|
||||
# Add orderbook if available (domestic markets)
|
||||
if "orderbook" in market_data:
|
||||
market_info_lines.append(
|
||||
f"Orderbook: {json.dumps(market_data['orderbook'], ensure_ascii=False)}"
|
||||
)
|
||||
|
||||
# Add foreigner net if non-zero
|
||||
if market_data.get("foreigner_net", 0) != 0:
|
||||
market_info_lines.append(
|
||||
f"Foreigner Net Buy/Sell: {market_data['foreigner_net']}"
|
||||
)
|
||||
|
||||
market_info = "\n".join(market_info_lines)
|
||||
|
||||
json_format = (
|
||||
'{"action": "BUY"|"SELL"|"HOLD", '
|
||||
'"confidence": <int 0-100>, "rationale": "<string>"}'
|
||||
@@ -152,28 +375,383 @@ class GeminiClient:
|
||||
# API Call
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def decide(self, market_data: dict[str, Any]) -> TradeDecision:
|
||||
"""Build prompt, call Gemini, and return a parsed decision."""
|
||||
prompt = self.build_prompt(market_data)
|
||||
logger.info("Requesting trade decision from Gemini")
|
||||
async def decide(
|
||||
self, market_data: dict[str, Any], news_sentiment: NewsSentiment | None = None
|
||||
) -> TradeDecision:
|
||||
"""Build prompt, call Gemini, and return a parsed decision.
|
||||
|
||||
Args:
|
||||
market_data: Market data dictionary with price, orderbook, etc.
|
||||
news_sentiment: Optional pre-fetched news sentiment
|
||||
|
||||
Returns:
|
||||
Parsed TradeDecision
|
||||
"""
|
||||
# Check cache first
|
||||
if self._cache:
|
||||
cached_decision = self._cache.get(market_data)
|
||||
if cached_decision:
|
||||
self._total_cached_decisions += 1
|
||||
self._total_decisions += 1
|
||||
logger.info(
|
||||
"Cache hit for decision",
|
||||
extra={
|
||||
"action": cached_decision.action,
|
||||
"confidence": cached_decision.confidence,
|
||||
"cache_hit_rate": self.get_cache_hit_rate(),
|
||||
},
|
||||
)
|
||||
# Return cached decision with cached flag
|
||||
return TradeDecision(
|
||||
action=cached_decision.action,
|
||||
confidence=cached_decision.confidence,
|
||||
rationale=cached_decision.rationale,
|
||||
token_count=0,
|
||||
cached=True,
|
||||
)
|
||||
|
||||
# Build optimized prompt
|
||||
if self._enable_optimization:
|
||||
prompt = self._optimizer.build_compressed_prompt(market_data)
|
||||
else:
|
||||
prompt = await self.build_prompt(market_data, news_sentiment)
|
||||
|
||||
# Estimate tokens
|
||||
token_count = self._optimizer.estimate_tokens(prompt)
|
||||
self._total_tokens_used += token_count
|
||||
|
||||
logger.info(
|
||||
"Requesting trade decision from Gemini",
|
||||
extra={"estimated_tokens": token_count, "optimized": self._enable_optimization},
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self._client.aio.models.generate_content(
|
||||
model=self._model_name, contents=prompt,
|
||||
model=self._model_name,
|
||||
contents=prompt,
|
||||
)
|
||||
raw = response.text
|
||||
except Exception as exc:
|
||||
logger.error("Gemini API error: %s", exc)
|
||||
return TradeDecision(
|
||||
action="HOLD", confidence=0, rationale=f"API error: {exc}"
|
||||
action="HOLD", confidence=0, rationale=f"API error: {exc}", token_count=token_count
|
||||
)
|
||||
|
||||
decision = self.parse_response(raw)
|
||||
self._total_decisions += 1
|
||||
|
||||
# Add token count to decision
|
||||
decision_with_tokens = TradeDecision(
|
||||
action=decision.action,
|
||||
confidence=decision.confidence,
|
||||
rationale=decision.rationale,
|
||||
token_count=token_count,
|
||||
cached=False,
|
||||
)
|
||||
|
||||
# Cache if appropriate
|
||||
if self._cache and self._cache.should_cache_decision(decision):
|
||||
self._cache.set(market_data, decision)
|
||||
|
||||
logger.info(
|
||||
"Gemini decision",
|
||||
extra={
|
||||
"action": decision.action,
|
||||
"confidence": decision.confidence,
|
||||
"tokens": token_count,
|
||||
"avg_tokens": self.get_avg_tokens_per_decision(),
|
||||
},
|
||||
)
|
||||
return decision
|
||||
|
||||
return decision_with_tokens
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Token Efficiency Metrics
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_token_metrics(self) -> dict[str, Any]:
|
||||
"""Get token usage metrics.
|
||||
|
||||
Returns:
|
||||
Dictionary with token usage statistics
|
||||
"""
|
||||
metrics = {
|
||||
"total_tokens_used": self._total_tokens_used,
|
||||
"total_decisions": self._total_decisions,
|
||||
"total_cached_decisions": self._total_cached_decisions,
|
||||
"avg_tokens_per_decision": self.get_avg_tokens_per_decision(),
|
||||
"cache_hit_rate": self.get_cache_hit_rate(),
|
||||
}
|
||||
|
||||
if self._cache:
|
||||
cache_metrics = self._cache.get_metrics()
|
||||
metrics["cache_metrics"] = cache_metrics.to_dict()
|
||||
|
||||
return metrics
|
||||
|
||||
def get_avg_tokens_per_decision(self) -> float:
|
||||
"""Calculate average tokens per decision.
|
||||
|
||||
Returns:
|
||||
Average tokens per decision
|
||||
"""
|
||||
if self._total_decisions == 0:
|
||||
return 0.0
|
||||
return self._total_tokens_used / self._total_decisions
|
||||
|
||||
def get_cache_hit_rate(self) -> float:
|
||||
"""Calculate cache hit rate.
|
||||
|
||||
Returns:
|
||||
Cache hit rate (0.0 to 1.0)
|
||||
"""
|
||||
if self._total_decisions == 0:
|
||||
return 0.0
|
||||
return self._total_cached_decisions / self._total_decisions
|
||||
|
||||
def reset_metrics(self) -> None:
|
||||
"""Reset token usage metrics."""
|
||||
self._total_tokens_used = 0
|
||||
self._total_decisions = 0
|
||||
self._total_cached_decisions = 0
|
||||
if self._cache:
|
||||
self._cache.reset_metrics()
|
||||
logger.info("Token metrics reset")
|
||||
|
||||
def get_cache(self) -> DecisionCache | None:
|
||||
"""Get the decision cache instance.
|
||||
|
||||
Returns:
|
||||
DecisionCache instance or None if caching disabled
|
||||
"""
|
||||
return self._cache
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Batch Decision Making (for daily trading mode)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def decide_batch(
|
||||
self, stocks_data: list[dict[str, Any]]
|
||||
) -> dict[str, TradeDecision]:
|
||||
"""Make decisions for multiple stocks in a single API call.
|
||||
|
||||
This is designed for daily trading mode to minimize API usage
|
||||
when working with Gemini Free tier (20 calls/day limit).
|
||||
|
||||
Args:
|
||||
stocks_data: List of market data dictionaries, each with:
|
||||
- stock_code: Stock ticker
|
||||
- current_price: Current price
|
||||
- market_name: Market name (optional)
|
||||
- foreigner_net: Foreigner net buy/sell (optional)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping stock_code to TradeDecision
|
||||
|
||||
Example:
|
||||
>>> stocks_data = [
|
||||
... {"stock_code": "AAPL", "current_price": 185.5},
|
||||
... {"stock_code": "MSFT", "current_price": 420.0},
|
||||
... ]
|
||||
>>> decisions = await client.decide_batch(stocks_data)
|
||||
>>> decisions["AAPL"].action
|
||||
'BUY'
|
||||
"""
|
||||
if not stocks_data:
|
||||
return {}
|
||||
|
||||
# Build compressed batch prompt
|
||||
market_name = stocks_data[0].get("market_name", "stock market")
|
||||
|
||||
# Format stock data as compact JSON array
|
||||
compact_stocks = []
|
||||
for stock in stocks_data:
|
||||
compact = {
|
||||
"code": stock["stock_code"],
|
||||
"price": stock["current_price"],
|
||||
}
|
||||
if stock.get("foreigner_net", 0) != 0:
|
||||
compact["frgn"] = stock["foreigner_net"]
|
||||
compact_stocks.append(compact)
|
||||
|
||||
data_str = json.dumps(compact_stocks, ensure_ascii=False)
|
||||
|
||||
prompt = (
|
||||
f"You are a professional {market_name} trading analyst.\n"
|
||||
"Analyze the following stocks and decide whether to BUY, SELL, or HOLD each one.\n\n"
|
||||
f"Stock Data: {data_str}\n\n"
|
||||
"You MUST respond with ONLY a valid JSON array in this format:\n"
|
||||
'[{"code": "AAPL", "action": "BUY", "confidence": 85, "rationale": "..."},\n'
|
||||
' {"code": "MSFT", "action": "HOLD", "confidence": 50, "rationale": "..."}, ...]\n\n'
|
||||
"Rules:\n"
|
||||
"- Return one decision object per stock\n"
|
||||
"- action must be exactly: BUY, SELL, or HOLD\n"
|
||||
"- confidence must be 0-100\n"
|
||||
"- rationale should be concise (1-2 sentences)\n"
|
||||
"- Do NOT wrap JSON in markdown code blocks\n"
|
||||
)
|
||||
|
||||
# Estimate tokens
|
||||
token_count = self._optimizer.estimate_tokens(prompt)
|
||||
self._total_tokens_used += token_count
|
||||
|
||||
logger.info(
|
||||
"Requesting batch decision for %d stocks from Gemini",
|
||||
len(stocks_data),
|
||||
extra={"estimated_tokens": token_count},
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self._client.aio.models.generate_content(
|
||||
model=self._model_name,
|
||||
contents=prompt,
|
||||
)
|
||||
raw = response.text
|
||||
except Exception as exc:
|
||||
logger.error("Gemini API error in batch decision: %s", exc)
|
||||
# Return HOLD for all stocks on API error
|
||||
return {
|
||||
stock["stock_code"]: TradeDecision(
|
||||
action="HOLD",
|
||||
confidence=0,
|
||||
rationale=f"API error: {exc}",
|
||||
token_count=token_count,
|
||||
cached=False,
|
||||
)
|
||||
for stock in stocks_data
|
||||
}
|
||||
|
||||
# Parse batch response
|
||||
return self._parse_batch_response(raw, stocks_data, token_count)
|
||||
|
||||
def _parse_batch_response(
|
||||
self, raw: str, stocks_data: list[dict[str, Any]], token_count: int
|
||||
) -> dict[str, TradeDecision]:
|
||||
"""Parse batch response into a dictionary of decisions.
|
||||
|
||||
Args:
|
||||
raw: Raw response from Gemini
|
||||
stocks_data: Original stock data list
|
||||
token_count: Token count for the request
|
||||
|
||||
Returns:
|
||||
Dictionary mapping stock_code to TradeDecision
|
||||
"""
|
||||
if not raw or not raw.strip():
|
||||
logger.warning("Empty batch response from Gemini — defaulting all to HOLD")
|
||||
return {
|
||||
stock["stock_code"]: TradeDecision(
|
||||
action="HOLD",
|
||||
confidence=0,
|
||||
rationale="Empty response",
|
||||
token_count=0,
|
||||
cached=False,
|
||||
)
|
||||
for stock in stocks_data
|
||||
}
|
||||
|
||||
# Strip markdown code fences if present
|
||||
cleaned = raw.strip()
|
||||
match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", cleaned, re.DOTALL)
|
||||
if match:
|
||||
cleaned = match.group(1).strip()
|
||||
|
||||
try:
|
||||
data = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Malformed JSON in batch response — defaulting all to HOLD")
|
||||
return {
|
||||
stock["stock_code"]: TradeDecision(
|
||||
action="HOLD",
|
||||
confidence=0,
|
||||
rationale="Malformed JSON response",
|
||||
token_count=0,
|
||||
cached=False,
|
||||
)
|
||||
for stock in stocks_data
|
||||
}
|
||||
|
||||
if not isinstance(data, list):
|
||||
logger.warning("Batch response is not a JSON array — defaulting all to HOLD")
|
||||
return {
|
||||
stock["stock_code"]: TradeDecision(
|
||||
action="HOLD",
|
||||
confidence=0,
|
||||
rationale="Invalid response format",
|
||||
token_count=0,
|
||||
cached=False,
|
||||
)
|
||||
for stock in stocks_data
|
||||
}
|
||||
|
||||
# Build decision map
|
||||
decisions: dict[str, TradeDecision] = {}
|
||||
stock_codes = {stock["stock_code"] for stock in stocks_data}
|
||||
|
||||
for item in data:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
code = item.get("code")
|
||||
if not code or code not in stock_codes:
|
||||
continue
|
||||
|
||||
# Validate required fields
|
||||
if not all(k in item for k in ("action", "confidence", "rationale")):
|
||||
logger.warning("Missing fields for %s — using HOLD", code)
|
||||
decisions[code] = TradeDecision(
|
||||
action="HOLD",
|
||||
confidence=0,
|
||||
rationale="Missing required fields",
|
||||
token_count=0,
|
||||
cached=False,
|
||||
)
|
||||
continue
|
||||
|
||||
action = str(item["action"]).upper()
|
||||
if action not in VALID_ACTIONS:
|
||||
logger.warning("Invalid action '%s' for %s — forcing HOLD", action, code)
|
||||
action = "HOLD"
|
||||
|
||||
confidence = int(item["confidence"])
|
||||
rationale = str(item["rationale"])
|
||||
|
||||
# Enforce confidence threshold
|
||||
if confidence < self._confidence_threshold:
|
||||
logger.info(
|
||||
"Confidence %d < threshold %d for %s — forcing HOLD",
|
||||
confidence,
|
||||
self._confidence_threshold,
|
||||
code,
|
||||
)
|
||||
action = "HOLD"
|
||||
|
||||
decisions[code] = TradeDecision(
|
||||
action=action,
|
||||
confidence=confidence,
|
||||
rationale=rationale,
|
||||
token_count=token_count // len(stocks_data), # Split token cost
|
||||
cached=False,
|
||||
)
|
||||
self._total_decisions += 1
|
||||
|
||||
# Fill in missing stocks with HOLD
|
||||
for stock in stocks_data:
|
||||
code = stock["stock_code"]
|
||||
if code not in decisions:
|
||||
logger.warning("No decision for %s in batch response — using HOLD", code)
|
||||
decisions[code] = TradeDecision(
|
||||
action="HOLD",
|
||||
confidence=0,
|
||||
rationale="Not found in batch response",
|
||||
token_count=0,
|
||||
cached=False,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Batch decision completed for %d stocks",
|
||||
len(decisions),
|
||||
extra={"tokens": token_count},
|
||||
)
|
||||
|
||||
return decisions
|
||||
|
||||
267
src/brain/prompt_optimizer.py
Normal file
@@ -0,0 +1,267 @@
"""Prompt optimization utilities for reducing token usage.
|
||||
|
||||
This module provides tools to compress prompts while maintaining decision quality:
|
||||
- Token counting
|
||||
- Text compression and abbreviation
|
||||
- Template-based prompts with variable slots
|
||||
- Priority-based context truncation
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
# Abbreviation mapping for common terms
|
||||
ABBREVIATIONS = {
|
||||
"price": "P",
|
||||
"volume": "V",
|
||||
"current": "cur",
|
||||
"previous": "prev",
|
||||
"change": "chg",
|
||||
"percentage": "pct",
|
||||
"market": "mkt",
|
||||
"orderbook": "ob",
|
||||
"foreigner": "fgn",
|
||||
"buy": "B",
|
||||
"sell": "S",
|
||||
"hold": "H",
|
||||
"confidence": "conf",
|
||||
"rationale": "reason",
|
||||
"action": "act",
|
||||
"net": "net",
|
||||
}
|
||||
|
||||
# Reverse mapping for decompression
|
||||
REVERSE_ABBREVIATIONS = {v: k for k, v in ABBREVIATIONS.items()}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TokenMetrics:
|
||||
"""Metrics about token usage in a prompt."""
|
||||
|
||||
char_count: int
|
||||
word_count: int
|
||||
estimated_tokens: int # Rough estimate: ~4 chars per token
|
||||
compression_ratio: float = 1.0 # Original / Compressed
|
||||
|
||||
|
||||
class PromptOptimizer:
|
||||
"""Optimizes prompts to reduce token usage while maintaining quality."""
|
||||
|
||||
@staticmethod
|
||||
def estimate_tokens(text: str) -> int:
|
||||
"""Estimate token count for text.
|
||||
|
||||
Uses a simple heuristic: ~4 characters per token for English.
|
||||
This is approximate but sufficient for optimization purposes.
|
||||
|
||||
Args:
|
||||
text: Input text to estimate tokens for
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
# Simple estimate: 1 token ≈ 4 characters
|
||||
return max(1, len(text) // 4)
|
||||
|
||||
@staticmethod
|
||||
def count_tokens(text: str) -> TokenMetrics:
|
||||
"""Count various metrics for a text.
|
||||
|
||||
Args:
|
||||
text: Input text to analyze
|
||||
|
||||
Returns:
|
||||
TokenMetrics with character, word, and estimated token counts
|
||||
"""
|
||||
char_count = len(text)
|
||||
word_count = len(text.split())
|
||||
estimated_tokens = PromptOptimizer.estimate_tokens(text)
|
||||
|
||||
return TokenMetrics(
|
||||
char_count=char_count,
|
||||
word_count=word_count,
|
||||
estimated_tokens=estimated_tokens,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def compress_json(data: dict[str, Any]) -> str:
|
||||
"""Compress JSON by removing whitespace.
|
||||
|
||||
Args:
|
||||
data: Dictionary to serialize
|
||||
|
||||
Returns:
|
||||
Compact JSON string without whitespace
|
||||
"""
|
||||
return json.dumps(data, separators=(",", ":"), ensure_ascii=False)
|
||||
|
||||
@staticmethod
|
||||
def abbreviate_text(text: str, aggressive: bool = False) -> str:
|
||||
"""Apply abbreviations to reduce text length.
|
||||
|
||||
Args:
|
||||
text: Input text to abbreviate
|
||||
aggressive: If True, apply more aggressive compression
|
||||
|
||||
Returns:
|
||||
Abbreviated text
|
||||
"""
|
||||
result = text
|
||||
|
||||
# Apply word-level abbreviations (case-insensitive)
|
||||
for full, abbr in ABBREVIATIONS.items():
|
||||
# Word boundaries to avoid partial replacements
|
||||
pattern = r"\b" + re.escape(full) + r"\b"
|
||||
result = re.sub(pattern, abbr, result, flags=re.IGNORECASE)
|
||||
|
||||
if aggressive:
|
||||
# Remove articles and filler words
|
||||
result = re.sub(r"\b(a|an|the)\b", "", result, flags=re.IGNORECASE)
|
||||
result = re.sub(r"\b(is|are|was|were)\b", "", result, flags=re.IGNORECASE)
|
||||
# Collapse multiple spaces
|
||||
result = re.sub(r"\s+", " ", result)
|
||||
|
||||
return result.strip()
|
||||
|
||||
@staticmethod
|
||||
def build_compressed_prompt(
|
||||
market_data: dict[str, Any],
|
||||
include_instructions: bool = True,
|
||||
max_length: int | None = None,
|
||||
) -> str:
|
||||
"""Build a compressed prompt from market data.
|
||||
|
||||
Args:
|
||||
market_data: Market data dictionary with stock info
|
||||
include_instructions: Whether to include full instructions
|
||||
max_length: Maximum character length (truncates if needed)
|
||||
|
||||
Returns:
|
||||
Compressed prompt string
|
||||
"""
|
||||
# Abbreviated market name
|
||||
market_name = market_data.get("market_name", "KR")
|
||||
if "Korea" in market_name:
|
||||
market_name = "KR"
|
||||
elif "United States" in market_name or "US" in market_name:
|
||||
market_name = "US"
|
||||
|
||||
# Core data - always included
|
||||
core_info = {
|
||||
"mkt": market_name,
|
||||
"code": market_data["stock_code"],
|
||||
"P": market_data["current_price"],
|
||||
}
|
||||
|
||||
# Optional fields
|
||||
if "orderbook" in market_data and market_data["orderbook"]:
|
||||
ob = market_data["orderbook"]
|
||||
# Compress orderbook: keep only top 3 levels
|
||||
compressed_ob = {
|
||||
"bid": ob.get("bid", [])[:3],
|
||||
"ask": ob.get("ask", [])[:3],
|
||||
}
|
||||
core_info["ob"] = compressed_ob
|
||||
|
||||
if market_data.get("foreigner_net", 0) != 0:
|
||||
core_info["fgn_net"] = market_data["foreigner_net"]
|
||||
|
||||
# Compress to JSON
|
||||
data_str = PromptOptimizer.compress_json(core_info)
|
||||
|
||||
if include_instructions:
|
||||
# Minimal instructions
|
||||
prompt = (
|
||||
f"{market_name} trader. Analyze:\n{data_str}\n\n"
|
||||
'Return JSON: {"act":"BUY"|"SELL"|"HOLD","conf":<0-100>,"reason":"<text>"}\n'
|
||||
"Rules: act=BUY/SELL/HOLD, conf=0-100, reason=concise. No markdown."
|
||||
)
|
||||
else:
|
||||
# Data only (for cached contexts where instructions are known)
|
||||
prompt = data_str
|
||||
|
||||
# Truncate if needed
|
||||
if max_length and len(prompt) > max_length:
|
||||
prompt = prompt[:max_length] + "..."
|
||||
|
||||
return prompt
|
||||
|
||||
@staticmethod
|
||||
def truncate_context(
|
||||
context: dict[str, Any],
|
||||
max_tokens: int,
|
||||
priority_keys: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Truncate context data to fit within token budget.
|
||||
|
||||
Keeps high-priority keys first, then truncates less important data.
|
||||
|
||||
Args:
|
||||
context: Context dictionary to truncate
|
||||
max_tokens: Maximum token budget
|
||||
priority_keys: List of keys to keep (in order of priority)
|
||||
|
||||
Returns:
|
||||
Truncated context dictionary
|
||||
"""
|
||||
if not context:
|
||||
return {}
|
||||
|
||||
if priority_keys is None:
|
||||
priority_keys = []
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
current_tokens = 0
|
||||
|
||||
# Add priority keys first
|
||||
for key in priority_keys:
|
||||
if key in context:
|
||||
value_str = json.dumps(context[key])
|
||||
tokens = PromptOptimizer.estimate_tokens(value_str)
|
||||
|
||||
if current_tokens + tokens <= max_tokens:
|
||||
result[key] = context[key]
|
||||
current_tokens += tokens
|
||||
else:
|
||||
break
|
||||
|
||||
# Add remaining keys if space available
|
||||
for key, value in context.items():
|
||||
if key in result:
|
||||
continue
|
||||
|
||||
value_str = json.dumps(value)
|
||||
tokens = PromptOptimizer.estimate_tokens(value_str)
|
||||
|
||||
if current_tokens + tokens <= max_tokens:
|
||||
result[key] = value
|
||||
current_tokens += tokens
|
||||
else:
|
||||
break
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def calculate_compression_ratio(original: str, compressed: str) -> float:
|
||||
"""Calculate compression ratio between original and compressed text.
|
||||
|
||||
Args:
|
||||
original: Original text
|
||||
compressed: Compressed text
|
||||
|
||||
Returns:
|
||||
Compression ratio (original_tokens / compressed_tokens)
|
||||
"""
|
||||
original_tokens = PromptOptimizer.estimate_tokens(original)
|
||||
compressed_tokens = PromptOptimizer.estimate_tokens(compressed)
|
||||
|
||||
if compressed_tokens == 0:
|
||||
return 1.0
|
||||
|
||||
return original_tokens / compressed_tokens
|
||||
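For reference, the optimizer above is typically driven like this; a quick usage sketch (the stock code and values are illustrative, not taken from the codebase):

```python
from src.brain.prompt_optimizer import PromptOptimizer

market_data = {"stock_code": "005930", "current_price": 71000, "market_name": "Korea"}

# Compressed prompt plus a rough ~4-chars-per-token estimate
prompt = PromptOptimizer.build_compressed_prompt(market_data)
metrics = PromptOptimizer.count_tokens(prompt)

# How much the JSON compaction saved versus a verbose prompt
ratio = PromptOptimizer.calculate_compression_ratio(
    original="Market: Korea\nStock Code: 005930\nCurrent Price: 71000",
    compressed=prompt,
)
print(metrics.estimated_tokens, round(ratio, 2))
```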
@@ -55,6 +55,9 @@ class KISBroker:
        self._session: aiohttp.ClientSession | None = None
        self._access_token: str | None = None
        self._token_expires_at: float = 0.0
        self._token_lock = asyncio.Lock()
        self._last_refresh_attempt: float = 0.0
        self._refresh_cooldown: float = 60.0  # Seconds (matches KIS 1/minute limit)
        self._rate_limiter = LeakyBucket(settings.RATE_LIMIT_RPS)

    def _get_session(self) -> aiohttp.ClientSession:
@@ -80,30 +83,54 @@ class KISBroker:
    # ------------------------------------------------------------------

    async def _ensure_token(self) -> str:
        """Return a valid access token, refreshing if expired."""
        """Return a valid access token, refreshing if expired.

        Uses a lock to prevent concurrent token refresh attempts that would
        hit the API's 1-per-minute rate limit (EGW00133).
        """
        # Fast path: check without lock
        now = asyncio.get_event_loop().time()
        if self._access_token and now < self._token_expires_at:
            return self._access_token

        logger.info("Refreshing KIS access token")
        session = self._get_session()
        url = f"{self._base_url}/oauth2/tokenP"
        body = {
            "grant_type": "client_credentials",
            "appkey": self._app_key,
            "appsecret": self._app_secret,
        }
        # Slow path: acquire lock and refresh
        async with self._token_lock:
            # Re-check after acquiring lock (another coroutine may have refreshed)
            now = asyncio.get_event_loop().time()
            if self._access_token and now < self._token_expires_at:
                return self._access_token

        async with session.post(url, json=body) as resp:
            if resp.status != 200:
                text = await resp.text()
                raise ConnectionError(f"Token refresh failed ({resp.status}): {text}")
            data = await resp.json()
            # Check cooldown period (prevents hitting EGW00133: 1/minute limit)
            time_since_last_attempt = now - self._last_refresh_attempt
            if time_since_last_attempt < self._refresh_cooldown:
                remaining = self._refresh_cooldown - time_since_last_attempt
                error_msg = (
                    f"Token refresh on cooldown. "
                    f"Retry in {remaining:.1f}s (KIS allows 1/minute)"
                )
                logger.warning(error_msg)
                raise ConnectionError(error_msg)

        self._access_token = data["access_token"]
        self._token_expires_at = now + data.get("expires_in", 86400) - 60  # 1-min buffer
        logger.info("Token refreshed successfully")
        return self._access_token
            logger.info("Refreshing KIS access token")
            self._last_refresh_attempt = now
            session = self._get_session()
            url = f"{self._base_url}/oauth2/tokenP"
            body = {
                "grant_type": "client_credentials",
                "appkey": self._app_key,
                "appsecret": self._app_secret,
            }

            async with session.post(url, json=body) as resp:
                if resp.status != 200:
                    text = await resp.text()
                    raise ConnectionError(f"Token refresh failed ({resp.status}): {text}")
                data = await resp.json()

            self._access_token = data["access_token"]
            self._token_expires_at = now + data.get("expires_in", 86400) - 60  # 1-min buffer
            logger.info("Token refreshed successfully")
            return self._access_token

    # ------------------------------------------------------------------
    # Hash Key (required for POST bodies)
@@ -111,6 +138,7 @@ class KISBroker:

    async def _get_hash_key(self, body: dict[str, Any]) -> str:
        """Request a hash key from KIS for POST request body signing."""
        await self._rate_limiter.acquire()
        session = self._get_session()
        url = f"{self._base_url}/uapi/hashkey"
        headers = {
@@ -252,3 +280,153 @@ class KISBroker:
            return data
        except (TimeoutError, aiohttp.ClientError) as exc:
            raise ConnectionError(f"Network error sending order: {exc}") from exc

    async def fetch_market_rankings(
        self,
        ranking_type: str = "volume",
        limit: int = 30,
    ) -> list[dict[str, Any]]:
        """Fetch market rankings from KIS API.

        Args:
            ranking_type: Type of ranking ("volume" or "fluctuation")
            limit: Maximum number of results to return

        Returns:
            List of stock data dicts with keys: stock_code, name, price, volume,
            change_rate, volume_increase_rate

        Raises:
            ConnectionError: If API request fails
        """
        await self._rate_limiter.acquire()
        session = self._get_session()

        # TR_ID for volume ranking
        tr_id = "FHPST01710000" if ranking_type == "volume" else "FHPST01710100"
        headers = await self._auth_headers(tr_id)

        params = {
            "FID_COND_MRKT_DIV_CODE": "J",  # Stock/ETF/ETN
            "FID_COND_SCR_DIV_CODE": "20001",  # Volume surge
            "FID_INPUT_ISCD": "0000",  # All stocks
            "FID_DIV_CLS_CODE": "0",  # All types
            "FID_BLNG_CLS_CODE": "0",
            "FID_TRGT_CLS_CODE": "111111111",
            "FID_TRGT_EXLS_CLS_CODE": "000000",
            "FID_INPUT_PRICE_1": "0",
            "FID_INPUT_PRICE_2": "0",
            "FID_VOL_CNT": "0",
            "FID_INPUT_DATE_1": "",
        }

        url = f"{self._base_url}/uapi/domestic-stock/v1/quotations/volume-rank"

        try:
            async with session.get(url, headers=headers, params=params) as resp:
                if resp.status != 200:
                    text = await resp.text()
                    raise ConnectionError(
                        f"fetch_market_rankings failed ({resp.status}): {text}"
                    )
                data = await resp.json()

            # Parse response - output is a list of ranked stocks
            def _safe_float(value: str | float | None, default: float = 0.0) -> float:
                if value is None or value == "":
                    return default
                try:
                    return float(value)
                except (ValueError, TypeError):
                    return default

            rankings = []
            for item in data.get("output", [])[:limit]:
                rankings.append({
                    "stock_code": item.get("mksc_shrn_iscd", ""),
                    "name": item.get("hts_kor_isnm", ""),
                    "price": _safe_float(item.get("stck_prpr", "0")),
                    "volume": _safe_float(item.get("acml_vol", "0")),
                    "change_rate": _safe_float(item.get("prdy_ctrt", "0")),
                    "volume_increase_rate": _safe_float(item.get("vol_inrt", "0")),
                })
            return rankings

        except (TimeoutError, aiohttp.ClientError) as exc:
            raise ConnectionError(f"Network error fetching rankings: {exc}") from exc

    async def get_daily_prices(
        self,
        stock_code: str,
        days: int = 20,
    ) -> list[dict[str, Any]]:
        """Fetch daily OHLCV price history for a stock.

        Args:
            stock_code: 6-digit stock code
            days: Number of trading days to fetch (default 20 for RSI calculation)

        Returns:
            List of daily price dicts with keys: date, open, high, low, close, volume
            Sorted oldest to newest

        Raises:
            ConnectionError: If API request fails
        """
        await self._rate_limiter.acquire()
        session = self._get_session()

        headers = await self._auth_headers("FHKST03010100")

        # Calculate date range (today and N days ago)
        from datetime import datetime, timedelta
        end_date = datetime.now().strftime("%Y%m%d")
        start_date = (datetime.now() - timedelta(days=days + 10)).strftime("%Y%m%d")

        params = {
            "FID_COND_MRKT_DIV_CODE": "J",
            "FID_INPUT_ISCD": stock_code,
            "FID_INPUT_DATE_1": start_date,
            "FID_INPUT_DATE_2": end_date,
            "FID_PERIOD_DIV_CODE": "D",  # Daily
            "FID_ORG_ADJ_PRC": "0",  # Adjusted price
        }

        url = f"{self._base_url}/uapi/domestic-stock/v1/quotations/inquire-daily-itemchartprice"

        try:
            async with session.get(url, headers=headers, params=params) as resp:
                if resp.status != 200:
                    text = await resp.text()
                    raise ConnectionError(
                        f"get_daily_prices failed ({resp.status}): {text}"
                    )
                data = await resp.json()

            # Parse response
            def _safe_float(value: str | float | None, default: float = 0.0) -> float:
                if value is None or value == "":
                    return default
                try:
                    return float(value)
                except (ValueError, TypeError):
                    return default

            prices = []
            for item in data.get("output2", []):
                prices.append({
                    "date": item.get("stck_bsop_date", ""),
                    "open": _safe_float(item.get("stck_oprc", "0")),
                    "high": _safe_float(item.get("stck_hgpr", "0")),
                    "low": _safe_float(item.get("stck_lwpr", "0")),
                    "close": _safe_float(item.get("stck_clpr", "0")),
                    "volume": _safe_float(item.get("acml_vol", "0")),
                })

            # Sort oldest to newest (KIS returns newest first)
            prices.reverse()

            # Keep only the most recent N days (slicing from the front would
            # return the oldest days of the over-fetched window instead)
            return prices[-days:]

        except (TimeoutError, aiohttp.ClientError) as exc:
            raise ConnectionError(f"Network error fetching daily prices: {exc}") from exc
@@ -19,22 +19,69 @@ class Settings(BaseSettings):
    GEMINI_API_KEY: str
    GEMINI_MODEL: str = "gemini-pro"

    # External Data APIs (optional — for data-driven decisions)
    NEWS_API_KEY: str | None = None
    NEWS_API_PROVIDER: str = "alphavantage"  # "alphavantage" or "newsapi"
    MARKET_DATA_API_KEY: str | None = None

    # Legacy field names (for backward compatibility)
    ALPHA_VANTAGE_API_KEY: str | None = None
    NEWSAPI_KEY: str | None = None

    # Risk Management
    CIRCUIT_BREAKER_PCT: float = Field(default=-3.0, le=0.0)
    FAT_FINGER_PCT: float = Field(default=30.0, gt=0.0, le=100.0)
    CONFIDENCE_THRESHOLD: int = Field(default=80, ge=0, le=100)

    # Smart Scanner Configuration
    RSI_OVERSOLD_THRESHOLD: int = Field(default=30, ge=0, le=50)
    RSI_MOMENTUM_THRESHOLD: int = Field(default=70, ge=50, le=100)
    VOL_MULTIPLIER: float = Field(default=2.0, gt=1.0, le=10.0)
    SCANNER_TOP_N: int = Field(default=3, ge=1, le=10)

    # Database
    DB_PATH: str = "data/trade_logs.db"

    # Rate Limiting (requests per second for KIS API)
    RATE_LIMIT_RPS: float = 10.0
    # Conservative limit to avoid EGW00201 "초당 거래건수 초과"
    # (transactions-per-second exceeded) errors.
    # KIS API real limit is ~2 RPS; 2.0 provides maximum safety.
    RATE_LIMIT_RPS: float = 2.0

    # Trading mode
    MODE: str = Field(default="paper", pattern="^(paper|live)$")

    # Trading frequency mode (daily = batch API calls, realtime = per-stock calls)
    TRADE_MODE: str = Field(default="daily", pattern="^(daily|realtime)$")
    DAILY_SESSIONS: int = Field(default=4, ge=1, le=10)
    SESSION_INTERVAL_HOURS: int = Field(default=6, ge=1, le=24)

    # Pre-Market Planner
    PRE_MARKET_MINUTES: int = Field(default=30, ge=10, le=120)
    MAX_SCENARIOS_PER_STOCK: int = Field(default=5, ge=1, le=10)
    PLANNER_TIMEOUT_SECONDS: int = Field(default=60, ge=10, le=300)
    DEFENSIVE_PLAYBOOK_ON_FAILURE: bool = True
    RESCAN_INTERVAL_SECONDS: int = Field(default=300, ge=60, le=900)

    # Market selection (comma-separated market codes)
    ENABLED_MARKETS: str = "KR"
    ENABLED_MARKETS: str = "KR,US"

    # Backup and Disaster Recovery (optional)
    BACKUP_ENABLED: bool = True
    BACKUP_DIR: str = "data/backups"
    S3_ENDPOINT_URL: str | None = None  # For MinIO, Backblaze B2, etc.
    S3_ACCESS_KEY: str | None = None
    S3_SECRET_KEY: str | None = None
    S3_BUCKET_NAME: str | None = None
    S3_REGION: str = "us-east-1"

    # Telegram Notifications (optional)
    TELEGRAM_BOT_TOKEN: str | None = None
    TELEGRAM_CHAT_ID: str | None = None
    TELEGRAM_ENABLED: bool = True

    # Telegram Commands (optional)
    TELEGRAM_COMMANDS_ENABLED: bool = True
    TELEGRAM_POLLING_INTERVAL: float = 1.0  # seconds

    model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}
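These pydantic-settings fields are populated from `.env` at construction time (per `model_config` above); a minimal usage sketch, assuming the class is importable as `src.config.Settings` as the data README's examples suggest:

```python
from src.config import Settings

settings = Settings()  # Reads values from .env via model_config

# Field validators enforce the declared constraints at load time
assert settings.MODE in ("paper", "live")
print(settings.RATE_LIMIT_RPS)  # 2.0 unless overridden in .env
```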
328
src/context/summarizer.py
Normal file
@@ -0,0 +1,328 @@
"""Context summarization for efficient historical data representation.
|
||||
|
||||
This module summarizes old context data instead of including raw details:
|
||||
- Key metrics only (averages, trends, not details)
|
||||
- Rolling window (keep last N days detailed, summarize older)
|
||||
- Aggregate historical data efficiently
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from src.context.layer import ContextLayer
|
||||
from src.context.store import ContextStore
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SummaryStats:
|
||||
"""Statistical summary of historical data."""
|
||||
|
||||
count: int
|
||||
mean: float | None = None
|
||||
min: float | None = None
|
||||
max: float | None = None
|
||||
std: float | None = None
|
||||
trend: str | None = None # "up", "down", "flat"
|
||||
|
||||
|
||||
class ContextSummarizer:
|
||||
"""Summarizes historical context data to reduce token usage."""
|
||||
|
||||
def __init__(self, store: ContextStore) -> None:
|
||||
"""Initialize the context summarizer.
|
||||
|
||||
Args:
|
||||
store: ContextStore instance for retrieving context data
|
||||
"""
|
||||
self.store = store
|
||||
|
||||
def summarize_numeric_values(self, values: list[float]) -> SummaryStats:
|
||||
"""Summarize a list of numeric values.
|
||||
|
||||
Args:
|
||||
values: List of numeric values to summarize
|
||||
|
||||
Returns:
|
||||
SummaryStats with mean, min, max, std, and trend
|
||||
"""
|
||||
if not values:
|
||||
return SummaryStats(count=0)
|
||||
|
||||
count = len(values)
|
||||
mean = sum(values) / count
|
||||
min_val = min(values)
|
||||
max_val = max(values)
|
||||
|
||||
# Calculate standard deviation
|
||||
if count > 1:
|
||||
variance = sum((x - mean) ** 2 for x in values) / (count - 1)
|
||||
std = variance**0.5
|
||||
else:
|
||||
std = 0.0
|
||||
|
||||
# Determine trend
|
||||
trend = "flat"
|
||||
if count >= 3:
|
||||
# Simple trend: compare first third vs last third
|
||||
first_third = values[: count // 3]
|
||||
last_third = values[-(count // 3) :]
|
||||
first_avg = sum(first_third) / len(first_third)
|
||||
last_avg = sum(last_third) / len(last_third)
|
||||
|
||||
# Trend threshold: 5% change
|
||||
threshold = 0.05 * abs(first_avg) if first_avg != 0 else 0.01
|
||||
|
||||
if last_avg > first_avg + threshold:
|
||||
trend = "up"
|
||||
elif last_avg < first_avg - threshold:
|
||||
trend = "down"
|
||||
|
||||
return SummaryStats(
|
||||
count=count,
|
||||
mean=round(mean, 4),
|
||||
min=round(min_val, 4),
|
||||
max=round(max_val, 4),
|
||||
std=round(std, 4),
|
||||
trend=trend,
|
||||
)
|
||||
|
||||
def summarize_layer(
|
||||
self,
|
||||
layer: ContextLayer,
|
||||
start_date: datetime | None = None,
|
||||
end_date: datetime | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Summarize all context data for a layer within a date range.
|
||||
|
||||
Args:
|
||||
layer: Context layer to summarize
|
||||
start_date: Start date (inclusive), None for all
|
||||
end_date: End date (inclusive), None for now
|
||||
|
||||
Returns:
|
||||
Dictionary with summarized metrics
|
||||
"""
|
||||
if end_date is None:
|
||||
end_date = datetime.now(UTC)
|
||||
|
||||
# Get all contexts for this layer
|
||||
all_contexts = self.store.get_all_contexts(layer)
|
||||
|
||||
if not all_contexts:
|
||||
return {"summary": "No data available", "count": 0}
|
||||
|
||||
# Group numeric values by key
|
||||
numeric_data: dict[str, list[float]] = {}
|
||||
text_data: dict[str, list[str]] = {}
|
||||
|
||||
for key, value in all_contexts.items():
|
||||
# Try to extract numeric values
|
||||
if isinstance(value, (int, float)):
|
||||
if key not in numeric_data:
|
||||
numeric_data[key] = []
|
||||
numeric_data[key].append(float(value))
|
||||
elif isinstance(value, dict):
|
||||
# Extract numeric fields from dict
|
||||
for subkey, subvalue in value.items():
|
||||
if isinstance(subvalue, (int, float)):
|
||||
full_key = f"{key}.{subkey}"
|
||||
if full_key not in numeric_data:
|
||||
numeric_data[full_key] = []
|
||||
numeric_data[full_key].append(float(subvalue))
|
||||
elif isinstance(value, str):
|
||||
if key not in text_data:
|
||||
text_data[key] = []
|
||||
text_data[key].append(value)
|
||||
|
||||
# Summarize numeric data
|
||||
summary: dict[str, Any] = {}
|
||||
|
||||
for key, values in numeric_data.items():
|
||||
stats = self.summarize_numeric_values(values)
|
||||
summary[key] = {
|
||||
"count": stats.count,
|
||||
"avg": stats.mean,
|
||||
"range": [stats.min, stats.max],
|
||||
"trend": stats.trend,
|
||||
}
|
||||
|
||||
# Summarize text data (just counts)
|
||||
for key, values in text_data.items():
|
||||
summary[f"{key}_count"] = len(values)
|
||||
|
||||
summary["total_entries"] = len(all_contexts)
|
||||
|
||||
return summary
|
||||
|
||||
def rolling_window_summary(
|
||||
self,
|
||||
layer: ContextLayer,
|
||||
window_days: int = 30,
|
||||
summarize_older: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a rolling window summary.
|
||||
|
||||
Recent data (within window) is kept detailed.
|
||||
Older data is summarized to key metrics.
|
||||
|
||||
Args:
|
||||
layer: Context layer to summarize
|
||||
window_days: Number of days to keep detailed
|
||||
summarize_older: Whether to summarize data older than window
|
||||
|
||||
Returns:
|
||||
Dictionary with recent (detailed) and historical (summary) data
|
||||
"""
|
||||
result: dict[str, Any] = {
|
||||
"window_days": window_days,
|
||||
"recent_data": {},
|
||||
"historical_summary": {},
|
||||
}
|
||||
|
||||
# Get all contexts
|
||||
all_contexts = self.store.get_all_contexts(layer)
|
||||
|
||||
recent_values: dict[str, list[float]] = {}
|
||||
historical_values: dict[str, list[float]] = {}
|
||||
|
||||
for key, value in all_contexts.items():
|
||||
# For simplicity, treat all numeric values
|
||||
if isinstance(value, (int, float)):
|
||||
# Note: We don't have timestamps in context keys
|
||||
# This is a simplified implementation
|
||||
# In practice, would need to check timeframe field
|
||||
|
||||
# For now, put recent data in window
|
||||
if key not in recent_values:
|
||||
recent_values[key] = []
|
||||
recent_values[key].append(float(value))
|
||||
|
||||
# Detailed recent data
|
||||
result["recent_data"] = {key: values[-10:] for key, values in recent_values.items()}
|
||||
|
||||
# Summarized historical data
|
||||
if summarize_older:
|
||||
for key, values in historical_values.items():
|
||||
stats = self.summarize_numeric_values(values)
|
||||
result["historical_summary"][key] = {
|
||||
"count": stats.count,
|
||||
"avg": stats.mean,
|
||||
"trend": stats.trend,
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def aggregate_to_higher_layer(
|
||||
self,
|
||||
source_layer: ContextLayer,
|
||||
target_layer: ContextLayer,
|
||||
metric_key: str,
|
||||
aggregation_func: str = "mean",
|
||||
) -> float | None:
|
||||
"""Aggregate data from source layer to target layer.
|
||||
|
||||
Args:
|
||||
source_layer: Source context layer (more granular)
|
||||
target_layer: Target context layer (less granular)
|
||||
metric_key: Key of metric to aggregate
|
||||
aggregation_func: Aggregation function ("mean", "sum", "max", "min")
|
||||
|
||||
Returns:
|
||||
Aggregated value, or None if no data available
|
||||
"""
|
||||
# Get all contexts from source layer
|
||||
source_contexts = self.store.get_all_contexts(source_layer)
|
||||
|
||||
# Extract values for metric_key
|
||||
values = []
|
||||
for key, value in source_contexts.items():
|
||||
if key == metric_key and isinstance(value, (int, float)):
|
||||
values.append(float(value))
|
||||
elif isinstance(value, dict) and metric_key in value:
|
||||
subvalue = value[metric_key]
|
||||
if isinstance(subvalue, (int, float)):
|
||||
values.append(float(subvalue))
|
||||
|
||||
if not values:
|
||||
return None
|
||||
|
||||
# Apply aggregation function
|
||||
if aggregation_func == "mean":
|
||||
return sum(values) / len(values)
|
||||
elif aggregation_func == "sum":
|
||||
return sum(values)
|
||||
elif aggregation_func == "max":
|
||||
return max(values)
|
||||
elif aggregation_func == "min":
|
||||
return min(values)
|
||||
else:
|
||||
return sum(values) / len(values) # Default to mean
|
||||
|
||||
def create_compact_summary(
|
||||
self,
|
||||
layers: list[ContextLayer],
|
||||
top_n_metrics: int = 5,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a compact summary across multiple layers.
|
||||
|
||||
Args:
|
||||
layers: List of context layers to summarize
|
||||
top_n_metrics: Number of top metrics to include per layer
|
||||
|
||||
Returns:
|
||||
Compact summary dictionary
|
||||
"""
|
||||
summary: dict[str, Any] = {}
|
||||
|
||||
for layer in layers:
|
||||
layer_summary = self.summarize_layer(layer)
|
||||
|
||||
# Keep only top N metrics (by count/relevance)
|
||||
metrics = []
|
||||
for key, value in layer_summary.items():
|
||||
if isinstance(value, dict) and "count" in value:
|
||||
metrics.append((key, value, value["count"]))
|
||||
|
||||
# Sort by count (descending)
|
||||
metrics.sort(key=lambda x: x[2], reverse=True)
|
||||
|
||||
# Keep top N
|
||||
top_metrics = {m[0]: m[1] for m in metrics[:top_n_metrics]}
|
||||
|
||||
summary[layer.value] = top_metrics
|
||||
|
||||
return summary
|
||||
|
||||
def format_summary_for_prompt(self, summary: dict[str, Any]) -> str:
|
||||
"""Format summary for inclusion in a prompt.
|
||||
|
||||
Args:
|
||||
summary: Summary dictionary
|
||||
|
||||
Returns:
|
||||
Formatted string for prompt
|
||||
"""
|
||||
lines = []
|
||||
|
||||
for layer, metrics in summary.items():
|
||||
if not metrics:
|
||||
continue
|
||||
|
||||
lines.append(f"{layer}:")
|
||||
for key, value in metrics.items():
|
||||
if isinstance(value, dict):
|
||||
# Format as: key: avg=X, trend=Y
|
||||
parts = []
|
||||
if "avg" in value and value["avg"] is not None:
|
||||
parts.append(f"avg={value['avg']:.2f}")
|
||||
if "trend" in value and value["trend"]:
|
||||
parts.append(f"trend={value['trend']}")
|
||||
if parts:
|
||||
lines.append(f" {key}: {', '.join(parts)}")
|
||||
else:
|
||||
lines.append(f" {key}: {value}")
|
||||
|
||||
return "\n".join(lines)
|
||||
205
src/data/README.md
Normal file
@@ -0,0 +1,205 @@
# External Data Integration

This module provides objective external data sources to enhance trading decisions beyond just market prices and user input.

## Modules

### `news_api.py` - News Sentiment Analysis

Fetches real-time news for stocks with sentiment scoring.

**Features:**
- Alpha Vantage and NewsAPI.org support
- Sentiment scoring (-1.0 to +1.0)
- 5-minute caching to minimize API quota usage
- Graceful fallback when API unavailable

**Usage:**
```python
from src.data.news_api import NewsAPI

# Initialize with API key
news_api = NewsAPI(api_key="your_key", provider="alphavantage")

# Fetch news sentiment
sentiment = await news_api.get_news_sentiment("AAPL")
if sentiment:
    print(f"Average sentiment: {sentiment.avg_sentiment}")
    for article in sentiment.articles[:3]:
        print(f"{article.title} ({article.sentiment_score})")
```

### `economic_calendar.py` - Major Economic Events

Tracks FOMC meetings, GDP releases, CPI, earnings calendars, and other market-moving events.

**Features:**
- High-impact event tracking (FOMC, GDP, CPI)
- Earnings calendar per stock
- Event proximity checking
- Hardcoded major events for 2026 (no API required)

**Usage:**
```python
from src.data.economic_calendar import EconomicCalendar

calendar = EconomicCalendar()
calendar.load_hardcoded_events()

# Get upcoming high-impact events
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="HIGH")
print(f"High-impact events: {upcoming.high_impact_count}")

# Check if near earnings
earnings_date = calendar.get_earnings_date("AAPL")
if earnings_date:
    print(f"Next earnings: {earnings_date}")

# Check for high volatility period
if calendar.is_high_volatility_period(hours_ahead=24):
    print("High-impact event imminent!")
```

### `market_data.py` - Market Indicators

Provides market breadth, sector performance, and sentiment indicators.

**Features:**
- Market sentiment levels (Fear & Greed equivalent)
- Market breadth (advancing/declining stocks)
- Sector performance tracking
- Fear/Greed score calculation

**Usage:**
```python
from src.data.market_data import MarketData

market_data = MarketData(api_key="your_key")

# Get market sentiment
sentiment = market_data.get_market_sentiment()
print(f"Market sentiment: {sentiment.name}")

# Get full indicators
indicators = market_data.get_market_indicators("US")
print(f"Sentiment: {indicators.sentiment.name}")
print(f"A/D Ratio: {indicators.breadth.advance_decline_ratio}")
```

## Integration with GeminiClient

The external data sources are seamlessly integrated into the AI decision engine:

```python
from src.brain.gemini_client import GeminiClient
from src.data.news_api import NewsAPI
from src.data.economic_calendar import EconomicCalendar
from src.data.market_data import MarketData
from src.config import Settings

settings = Settings()

# Initialize data sources
news_api = NewsAPI(api_key=settings.NEWS_API_KEY, provider=settings.NEWS_API_PROVIDER)
calendar = EconomicCalendar()
calendar.load_hardcoded_events()
market_data = MarketData(api_key=settings.MARKET_DATA_API_KEY)

# Create enhanced client
client = GeminiClient(
    settings,
    news_api=news_api,
    economic_calendar=calendar,
    market_data=market_data
)

# Make decision with external context
market_data_dict = {
    "stock_code": "AAPL",
    "current_price": 180.0,
    "market_name": "US stock market"
}

decision = await client.decide(market_data_dict)
```

The external data is automatically included in the prompt sent to Gemini:

```
Market: US stock market
Stock Code: AAPL
Current Price: 180.0

EXTERNAL DATA:
News Sentiment: 0.85 (from 10 articles)
  1. [Reuters] Apple hits record high (sentiment: 0.92)
  2. [Bloomberg] Strong iPhone sales (sentiment: 0.78)
  3. [CNBC] Tech sector rallying (sentiment: 0.85)

Upcoming High-Impact Events: 2 in next 7 days
  Next: FOMC Meeting (FOMC) on 2026-03-18
  Earnings: AAPL on 2026-02-10

Market Sentiment: GREED
Advance/Decline Ratio: 2.35
```

## Configuration

Add these to your `.env` file:

```bash
# External Data APIs (optional)
NEWS_API_KEY=your_alpha_vantage_key
NEWS_API_PROVIDER=alphavantage  # or "newsapi"
MARKET_DATA_API_KEY=your_market_data_key
```

## API Recommendations

### Alpha Vantage (News)
- **Free tier:** 5 calls/min, 500 calls/day
- **Pros:** Provides sentiment scores, no credit card required
- **URL:** https://www.alphavantage.co/

### NewsAPI.org
- **Free tier:** 100 requests/day
- **Pros:** Large news coverage, easy to use
- **Cons:** No sentiment scores (we use keyword heuristics; see the sketch below)
- **URL:** https://newsapi.org/
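A minimal sketch of the keyword-heuristic idea (the word lists and scoring below are illustrative; the actual heuristic lives in `news_api.py` and may differ):

```python
# Illustrative keyword-based sentiment scorer (hypothetical word lists).
POSITIVE = {"beat", "surge", "record", "upgrade", "rally", "strong"}
NEGATIVE = {"miss", "plunge", "downgrade", "lawsuit", "recall", "weak"}


def keyword_sentiment(headline: str) -> float:
    """Score a headline in [-1.0, +1.0] from keyword hits."""
    words = headline.lower().split()
    pos = sum(w in POSITIVE for w in words)
    neg = sum(w in NEGATIVE for w in words)
    total = pos + neg
    if total == 0:
        return 0.0  # Neutral when no keywords match
    return (pos - neg) / total
```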
## Caching Strategy

To minimize API quota usage:

1. **News:** 5-minute TTL cache per stock (sketched below)
2. **Economic Calendar:** Loaded once at startup (hardcoded events)
3. **Market Data:** Fetched per decision (lightweight)
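The news cache in item 1 is a simple per-key TTL map; a minimal sketch of the pattern (the class and field names are illustrative, not the project's actual implementation):

```python
import time
from typing import Any


class TTLCache:
    """Per-key cache whose entries expire after a fixed time-to-live."""

    def __init__(self, ttl_seconds: float = 300.0) -> None:  # 300s = 5 minutes
        self._ttl = ttl_seconds
        self._entries: dict[str, tuple[float, Any]] = {}

    def get(self, key: str) -> Any | None:
        entry = self._entries.get(key)
        if entry is None:
            return None
        stored_at, value = entry
        if time.monotonic() - stored_at > self._ttl:
            del self._entries[key]  # Drop the expired entry
            return None
        return value

    def set(self, key: str, value: Any) -> None:
        self._entries[key] = (time.monotonic(), value)
```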
## Graceful Degradation

The system works gracefully without external data:

- If no API keys are provided → decisions work with just market prices
- If an API call fails → the decision continues without external context
- If the cache has expired → a refetch is attempted, falling back to no data
- Errors are logged but never block trading decisions (see the sketch below)
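These guarantees reduce to one pattern: wrap every external fetch so a failure degrades to "no data" rather than an exception, as `GeminiClient._build_external_context` does with its try/except around the news fetch. A minimal sketch of that pattern (the helper name is illustrative):

```python
import logging
from collections.abc import Awaitable, Callable

logger = logging.getLogger(__name__)


async def fetch_optional(fetcher: Callable[[], Awaitable[object]]) -> object | None:
    """Run an external data fetch; on any failure, log and return None."""
    try:
        return await fetcher()
    except Exception as exc:
        # Log, but never let an external data failure block a trading decision
        logger.warning("External data fetch failed: %s", exc)
        return None
```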
## Testing

All modules have comprehensive test coverage (81%+):

```bash
pytest tests/test_data_integration.py -v --cov=src/data
```

Tests use mocks to avoid requiring real API keys.

## Future Enhancements

- Twitter/X sentiment analysis
- Reddit WallStreetBets sentiment
- Options flow data
- Insider trading activity
- Analyst upgrades/downgrades
- Real-time economic data APIs
10
src/data/__init__.py
Normal file
@@ -0,0 +1,10 @@
"""External data integration for objective decision-making."""

from __future__ import annotations

# Imports are needed so the names listed in __all__ actually resolve.
from src.data.economic_calendar import EconomicCalendar
from src.data.market_data import MarketData
from src.data.news_api import NewsAPI

__all__ = ["NewsAPI", "EconomicCalendar", "MarketData"]
219
src/data/economic_calendar.py
Normal file
@@ -0,0 +1,219 @@
"""Economic calendar integration for major market events.
|
||||
|
||||
Tracks FOMC meetings, GDP releases, CPI, earnings calendars, and other
|
||||
market-moving events.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EconomicEvent:
|
||||
"""Single economic event."""
|
||||
|
||||
name: str
|
||||
event_type: str # "FOMC", "GDP", "CPI", "EARNINGS", etc.
|
||||
datetime: datetime
|
||||
impact: str # "HIGH", "MEDIUM", "LOW"
|
||||
country: str
|
||||
description: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpcomingEvents:
|
||||
"""Collection of upcoming economic events."""
|
||||
|
||||
events: list[EconomicEvent]
|
||||
high_impact_count: int
|
||||
next_major_event: EconomicEvent | None
|
||||
|
||||
|
||||
class EconomicCalendar:
|
||||
"""Economic calendar with event tracking and impact scoring."""
|
||||
|
||||
def __init__(self, api_key: str | None = None) -> None:
|
||||
"""Initialize economic calendar.
|
||||
|
||||
Args:
|
||||
api_key: API key for calendar provider (None for testing/hardcoded)
|
||||
"""
|
||||
self._api_key = api_key
|
||||
# For now, use hardcoded major events (can be extended with API)
|
||||
self._events: list[EconomicEvent] = []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_upcoming_events(
|
||||
self, days_ahead: int = 7, min_impact: str = "MEDIUM"
|
||||
) -> UpcomingEvents:
|
||||
"""Get upcoming economic events within specified timeframe.
|
||||
|
||||
Args:
|
||||
days_ahead: Number of days to look ahead
|
||||
min_impact: Minimum impact level ("LOW", "MEDIUM", "HIGH")
|
||||
|
||||
Returns:
|
||||
UpcomingEvents with filtered events
|
||||
"""
|
||||
now = datetime.now()
|
||||
end_date = now + timedelta(days=days_ahead)
|
||||
|
||||
# Filter events by timeframe and impact
|
||||
upcoming = [
|
||||
event
|
||||
for event in self._events
|
||||
if now <= event.datetime <= end_date
|
||||
and self._impact_level(event.impact) >= self._impact_level(min_impact)
|
||||
]
|
||||
|
||||
# Sort by datetime
|
||||
upcoming.sort(key=lambda e: e.datetime)
|
||||
|
||||
# Count high-impact events
|
||||
high_impact_count = sum(1 for e in upcoming if e.impact == "HIGH")
|
||||
|
||||
# Get next major event
|
||||
next_major = None
|
||||
for event in upcoming:
|
||||
if event.impact == "HIGH":
|
||||
next_major = event
|
||||
break
|
||||
|
||||
return UpcomingEvents(
|
||||
events=upcoming,
|
||||
high_impact_count=high_impact_count,
|
||||
next_major_event=next_major,
|
||||
)
|
||||
|
||||
def add_event(self, event: EconomicEvent) -> None:
|
||||
"""Add an economic event to the calendar."""
|
||||
self._events.append(event)
|
||||
|
||||
def clear_events(self) -> None:
|
||||
"""Clear all events (useful for testing)."""
|
||||
self._events.clear()
|
||||
|
||||
def get_earnings_date(self, stock_code: str) -> datetime | None:
|
||||
"""Get next earnings date for a stock.
|
||||
|
||||
Args:
|
||||
stock_code: Stock ticker symbol
|
||||
|
||||
Returns:
|
||||
Next earnings datetime or None if not found
|
||||
"""
|
||||
now = datetime.now()
|
||||
earnings_events = [
|
||||
event
|
||||
for event in self._events
|
||||
if event.event_type == "EARNINGS"
|
||||
and stock_code.upper() in event.name.upper()
|
||||
and event.datetime > now
|
||||
]
|
||||
|
||||
if not earnings_events:
|
||||
return None
|
||||
|
||||
# Return earliest upcoming earnings
|
||||
earnings_events.sort(key=lambda e: e.datetime)
|
||||
return earnings_events[0].datetime
|
||||
|
||||
def load_hardcoded_events(self) -> None:
|
||||
"""Load hardcoded major economic events for 2026.
|
||||
|
||||
This is a fallback when no API is available.
|
||||
"""
|
||||
# Major FOMC meetings in 2026 (estimated)
|
||||
fomc_dates = [
|
||||
datetime(2026, 3, 18),
|
||||
datetime(2026, 5, 6),
|
||||
datetime(2026, 6, 17),
|
||||
datetime(2026, 7, 29),
|
||||
datetime(2026, 9, 16),
|
||||
datetime(2026, 11, 4),
|
||||
datetime(2026, 12, 16),
|
||||
]
|
||||
|
||||
for date in fomc_dates:
|
||||
self.add_event(
|
||||
EconomicEvent(
|
||||
name="FOMC Meeting",
|
||||
event_type="FOMC",
|
||||
datetime=date,
|
||||
impact="HIGH",
|
||||
country="US",
|
||||
description="Federal Reserve interest rate decision",
|
||||
)
|
||||
)
|
||||
|
||||
# Quarterly GDP releases (estimated)
|
||||
gdp_dates = [
|
||||
datetime(2026, 4, 28),
|
||||
datetime(2026, 7, 30),
|
||||
datetime(2026, 10, 29),
|
||||
]
|
||||
|
||||
for date in gdp_dates:
|
||||
self.add_event(
|
||||
EconomicEvent(
|
||||
name="US GDP Release",
|
||||
event_type="GDP",
|
||||
datetime=date,
|
||||
impact="HIGH",
|
||||
country="US",
|
||||
description="Quarterly GDP growth rate",
|
||||
)
|
||||
)
|
||||
|
||||
# Monthly CPI releases (12th of each month, estimated)
|
||||
for month in range(1, 13):
|
||||
try:
|
||||
cpi_date = datetime(2026, month, 12)
|
||||
self.add_event(
|
||||
EconomicEvent(
|
||||
name="US CPI Release",
|
||||
event_type="CPI",
|
||||
datetime=cpi_date,
|
||||
impact="HIGH",
|
||||
country="US",
|
||||
description="Consumer Price Index inflation data",
|
||||
)
|
||||
)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _impact_level(self, impact: str) -> int:
|
||||
"""Convert impact string to numeric level."""
|
||||
levels = {"LOW": 1, "MEDIUM": 2, "HIGH": 3}
|
||||
return levels.get(impact.upper(), 0)
|
||||
|
||||
def is_high_volatility_period(self, hours_ahead: int = 24) -> bool:
|
||||
"""Check if we're near a high-impact event.
|
||||
|
||||
Args:
|
||||
hours_ahead: Number of hours to look ahead
|
||||
|
||||
Returns:
|
||||
True if high-impact event is imminent
|
||||
"""
|
||||
now = datetime.now()
|
||||
threshold = now + timedelta(hours=hours_ahead)
|
||||
|
||||
for event in self._events:
|
||||
if event.impact == "HIGH" and now <= event.datetime <= threshold:
|
||||
return True
|
||||
|
||||
return False
|
||||
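A usage sketch for the calendar above, using only the hardcoded 2026 events (no API key required):

```python
from src.data.economic_calendar import EconomicCalendar

calendar = EconomicCalendar()
calendar.load_hardcoded_events()  # FOMC, GDP, and CPI dates for 2026

# High-impact events in the next 7 days, sorted by datetime.
upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="HIGH")
print(upcoming.high_impact_count, upcoming.next_major_event)

# True if a HIGH-impact event lands within the next 24 hours.
if calendar.is_high_volatility_period(hours_ahead=24):
    print("Major event imminent")
```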
198
src/data/market_data.py
Normal file
198
src/data/market_data.py
Normal file
@@ -0,0 +1,198 @@
"""Additional market data indicators beyond basic price data.

Provides market breadth, sector performance, and market sentiment indicators.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from enum import Enum

logger = logging.getLogger(__name__)


class MarketSentiment(Enum):
    """Overall market sentiment levels."""

    EXTREME_FEAR = 1
    FEAR = 2
    NEUTRAL = 3
    GREED = 4
    EXTREME_GREED = 5


@dataclass
class SectorPerformance:
    """Performance metrics for a market sector."""

    sector_name: str
    daily_change_pct: float
    weekly_change_pct: float
    leader_stock: str  # Best performing stock in sector
    laggard_stock: str  # Worst performing stock in sector


@dataclass
class MarketBreadth:
    """Market breadth indicators."""

    advancing_stocks: int
    declining_stocks: int
    unchanged_stocks: int
    new_highs: int
    new_lows: int
    advance_decline_ratio: float


@dataclass
class MarketIndicators:
    """Aggregated market indicators."""

    sentiment: MarketSentiment
    breadth: MarketBreadth
    sector_performance: list[SectorPerformance]
    vix_level: float | None  # Volatility index if available


class MarketData:
    """Market data provider for additional indicators."""

    def __init__(self, api_key: str | None = None) -> None:
        """Initialize market data provider.

        Args:
            api_key: API key for data provider (None for testing)
        """
        self._api_key = api_key

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get_market_sentiment(self) -> MarketSentiment:
        """Get current market sentiment level.

        This is a simplified version. In production, this would integrate
        with Fear & Greed Index or similar sentiment indicators.

        Returns:
            MarketSentiment enum value
        """
        # Default to neutral when API not available
        if self._api_key is None:
            logger.debug("No market data API key — returning NEUTRAL sentiment")
            return MarketSentiment.NEUTRAL

        # TODO: Integrate with actual sentiment API
        return MarketSentiment.NEUTRAL

    def get_market_breadth(self, market: str = "US") -> MarketBreadth | None:
        """Get market breadth indicators.

        Args:
            market: Market code ("US", "KR", etc.)

        Returns:
            MarketBreadth object or None if unavailable
        """
        if self._api_key is None:
            logger.debug("No market data API key — returning None for breadth")
            return None

        # TODO: Integrate with actual market breadth API
        return None

    def get_sector_performance(
        self, market: str = "US"
    ) -> list[SectorPerformance]:
        """Get sector performance rankings.

        Args:
            market: Market code ("US", "KR", etc.)

        Returns:
            List of SectorPerformance objects, sorted by daily change
        """
        if self._api_key is None:
            logger.debug("No market data API key — returning empty sector list")
            return []

        # TODO: Integrate with actual sector performance API
        return []

    def get_market_indicators(self, market: str = "US") -> MarketIndicators:
        """Get aggregated market indicators.

        Args:
            market: Market code ("US", "KR", etc.)

        Returns:
            MarketIndicators with all available data
        """
        sentiment = self.get_market_sentiment()
        breadth = self.get_market_breadth(market)
        sectors = self.get_sector_performance(market)

        # Default breadth if unavailable
        if breadth is None:
            breadth = MarketBreadth(
                advancing_stocks=0,
                declining_stocks=0,
                unchanged_stocks=0,
                new_highs=0,
                new_lows=0,
                advance_decline_ratio=1.0,
            )

        return MarketIndicators(
            sentiment=sentiment,
            breadth=breadth,
            sector_performance=sectors,
            vix_level=None,  # TODO: Add VIX integration
        )

    # ------------------------------------------------------------------
    # Helper Methods
    # ------------------------------------------------------------------

    def calculate_fear_greed_score(
        self, breadth: MarketBreadth, vix: float | None = None
    ) -> int:
        """Calculate a simple fear/greed score (0-100).

        Args:
            breadth: Market breadth data
            vix: VIX level (optional)

        Returns:
            Score from 0 (extreme fear) to 100 (extreme greed)
        """
        # Start at neutral
        score = 50

        # Adjust based on advance/decline ratio
        if breadth.advance_decline_ratio > 1.5:
            score += 20
        elif breadth.advance_decline_ratio > 1.0:
            score += 10
        elif breadth.advance_decline_ratio < 0.5:
            score -= 20
        elif breadth.advance_decline_ratio < 1.0:
            score -= 10

        # Adjust based on new highs/lows
        if breadth.new_highs > breadth.new_lows * 2:
            score += 15
        elif breadth.new_lows > breadth.new_highs * 2:
            score -= 15

        # Adjust based on VIX if available
        if vix is not None:
            if vix > 30:  # High volatility = fear
                score -= 15
            elif vix < 15:  # Low volatility = complacency/greed
                score += 10

        # Clamp to 0-100
        return max(0, min(100, score))
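A worked example of the scoring arithmetic above: starting from the neutral 50, a 2.0 advance/decline ratio adds 20, new highs at more than twice new lows add 15, and a VIX above 30 subtracts 15, giving 70:

```python
from src.data.market_data import MarketBreadth, MarketData

provider = MarketData()  # no API key needed for the pure scoring helper
breadth = MarketBreadth(
    advancing_stocks=600,
    declining_stocks=300,
    unchanged_stocks=100,
    new_highs=120,
    new_lows=30,
    advance_decline_ratio=2.0,
)
score = provider.calculate_fear_greed_score(breadth, vix=32.0)
assert score == 50 + 20 + 15 - 15  # == 70, leaning greedy
```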
316
src/data/news_api.py
Normal file
316
src/data/news_api.py
Normal file
@@ -0,0 +1,316 @@
"""News API integration with sentiment analysis and caching.

Fetches real-time news for stocks using free-tier APIs (Alpha Vantage or NewsAPI).
Includes 5-minute caching to minimize API quota usage.
"""

from __future__ import annotations

import logging
import time
from dataclasses import dataclass
from typing import Any

import aiohttp

logger = logging.getLogger(__name__)

# Cache entries expire after 5 minutes
CACHE_TTL_SECONDS = 300


@dataclass
class NewsArticle:
    """Single news article with sentiment."""

    title: str
    summary: str
    source: str
    published_at: str
    sentiment_score: float  # -1.0 (negative) to +1.0 (positive)
    url: str


@dataclass
class NewsSentiment:
    """Aggregated news sentiment for a stock."""

    stock_code: str
    articles: list[NewsArticle]
    avg_sentiment: float  # Average sentiment across all articles
    article_count: int
    fetched_at: float  # Unix timestamp


class NewsAPI:
    """News API client with sentiment analysis and caching."""

    def __init__(
        self,
        api_key: str | None = None,
        provider: str = "alphavantage",
        cache_ttl: int = CACHE_TTL_SECONDS,
    ) -> None:
        """Initialize NewsAPI client.

        Args:
            api_key: API key for the news provider (None for testing)
            provider: News provider ("alphavantage" or "newsapi")
            cache_ttl: Cache time-to-live in seconds
        """
        self._api_key = api_key
        self._provider = provider
        self._cache_ttl = cache_ttl
        self._cache: dict[str, NewsSentiment] = {}

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def get_news_sentiment(self, stock_code: str) -> NewsSentiment | None:
        """Fetch news sentiment for a stock with caching.

        Args:
            stock_code: Stock ticker symbol (e.g., "AAPL", "005930")

        Returns:
            NewsSentiment object or None if fetch fails or API unavailable
        """
        # Check cache first
        cached = self._get_from_cache(stock_code)
        if cached is not None:
            logger.debug("News cache hit for %s", stock_code)
            return cached

        # API key required for real requests
        if self._api_key is None:
            logger.warning("No news API key provided — returning None")
            return None

        # Fetch from API
        try:
            sentiment = await self._fetch_news(stock_code)
            if sentiment is not None:
                self._cache[stock_code] = sentiment
            return sentiment
        except Exception as exc:
            logger.error("Failed to fetch news for %s: %s", stock_code, exc)
            return None

    def clear_cache(self) -> None:
        """Clear the news cache (useful for testing)."""
        self._cache.clear()

    # ------------------------------------------------------------------
    # Cache Management
    # ------------------------------------------------------------------

    def _get_from_cache(self, stock_code: str) -> NewsSentiment | None:
        """Retrieve cached sentiment if not expired."""
        if stock_code not in self._cache:
            return None

        cached = self._cache[stock_code]
        age = time.time() - cached.fetched_at

        if age > self._cache_ttl:
            logger.debug("News cache expired for %s (age: %.1fs)", stock_code, age)
            del self._cache[stock_code]
            return None

        return cached

    # ------------------------------------------------------------------
    # API Fetching
    # ------------------------------------------------------------------

    async def _fetch_news(self, stock_code: str) -> NewsSentiment | None:
        """Fetch news from the provider API."""
        if self._provider == "alphavantage":
            return await self._fetch_alphavantage(stock_code)
        elif self._provider == "newsapi":
            return await self._fetch_newsapi(stock_code)
        else:
            logger.error("Unknown news provider: %s", self._provider)
            return None

    async def _fetch_alphavantage(self, stock_code: str) -> NewsSentiment | None:
        """Fetch news from Alpha Vantage News Sentiment API."""
        url = "https://www.alphavantage.co/query"
        params = {
            "function": "NEWS_SENTIMENT",
            "tickers": stock_code,
            "apikey": self._api_key,
            "limit": 10,  # Fetch top 10 articles
        }

        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url, params=params, timeout=10) as resp:
                    if resp.status != 200:
                        logger.error(
                            "Alpha Vantage API error: HTTP %d", resp.status
                        )
                        return None

                    data = await resp.json()
                    return self._parse_alphavantage_response(stock_code, data)

        except Exception as exc:
            logger.error("Alpha Vantage request failed: %s", exc)
            return None

    async def _fetch_newsapi(self, stock_code: str) -> NewsSentiment | None:
        """Fetch news from NewsAPI.org."""
        url = "https://newsapi.org/v2/everything"
        params = {
            "q": stock_code,
            "apiKey": self._api_key,
            "pageSize": 10,
            "sortBy": "publishedAt",
            "language": "en",
        }

        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url, params=params, timeout=10) as resp:
                    if resp.status != 200:
                        logger.error("NewsAPI error: HTTP %d", resp.status)
                        return None

                    data = await resp.json()
                    return self._parse_newsapi_response(stock_code, data)

        except Exception as exc:
            logger.error("NewsAPI request failed: %s", exc)
            return None

    # ------------------------------------------------------------------
    # Response Parsing
    # ------------------------------------------------------------------

    def _parse_alphavantage_response(
        self, stock_code: str, data: dict[str, Any]
    ) -> NewsSentiment | None:
        """Parse Alpha Vantage API response."""
        if "feed" not in data:
            logger.warning("No 'feed' key in Alpha Vantage response")
            return None

        articles: list[NewsArticle] = []
        for item in data["feed"]:
            # Extract sentiment for this specific ticker
            ticker_sentiment = self._extract_ticker_sentiment(item, stock_code)

            article = NewsArticle(
                title=item.get("title", ""),
                summary=item.get("summary", "")[:200],  # Truncate long summaries
                source=item.get("source", "Unknown"),
                published_at=item.get("time_published", ""),
                sentiment_score=ticker_sentiment,
                url=item.get("url", ""),
            )
            articles.append(article)

        if not articles:
            return None

        avg_sentiment = sum(a.sentiment_score for a in articles) / len(articles)

        return NewsSentiment(
            stock_code=stock_code,
            articles=articles,
            avg_sentiment=avg_sentiment,
            article_count=len(articles),
            fetched_at=time.time(),
        )

    def _extract_ticker_sentiment(
        self, item: dict[str, Any], stock_code: str
    ) -> float:
        """Extract sentiment score for specific ticker from article."""
        ticker_sentiments = item.get("ticker_sentiment", [])
        for ts in ticker_sentiments:
            if ts.get("ticker", "").upper() == stock_code.upper():
                # Alpha Vantage provides sentiment_score as string
                score_str = ts.get("ticker_sentiment_score", "0")
                try:
                    return float(score_str)
                except ValueError:
                    return 0.0

        # Fallback to overall sentiment if ticker-specific not found
        overall_sentiment = item.get("overall_sentiment_score", "0")
        try:
            return float(overall_sentiment)
        except ValueError:
            return 0.0

    def _parse_newsapi_response(
        self, stock_code: str, data: dict[str, Any]
    ) -> NewsSentiment | None:
        """Parse NewsAPI.org response.

        Note: NewsAPI doesn't provide sentiment scores, so we use a
        simple heuristic based on title keywords.
        """
        if data.get("status") != "ok" or "articles" not in data:
            logger.warning("Invalid NewsAPI response")
            return None

        articles: list[NewsArticle] = []
        for item in data["articles"]:
            # Simple sentiment heuristic based on keywords
            # ("or ''" guards: NewsAPI returns null for missing fields)
            sentiment = self._estimate_sentiment_from_text(
                (item.get("title") or "") + " " + (item.get("description") or "")
            )

            article = NewsArticle(
                title=item.get("title") or "",
                summary=(item.get("description") or "")[:200],
                source=(item.get("source") or {}).get("name", "Unknown"),
                published_at=item.get("publishedAt", ""),
                sentiment_score=sentiment,
                url=item.get("url") or "",
            )
            articles.append(article)

        if not articles:
            return None

        avg_sentiment = sum(a.sentiment_score for a in articles) / len(articles)

        return NewsSentiment(
            stock_code=stock_code,
            articles=articles,
            avg_sentiment=avg_sentiment,
            article_count=len(articles),
            fetched_at=time.time(),
        )

    def _estimate_sentiment_from_text(self, text: str) -> float:
        """Simple keyword-based sentiment estimation.

        This is a fallback for APIs that don't provide sentiment scores.
        Returns a score between -1.0 and +1.0.
        """
        text_lower = text.lower()

        positive_keywords = [
            "surge", "jump", "gain", "rise", "soar", "rally", "profit",
            "growth", "upgrade", "beat", "strong", "bullish", "breakthrough",
        ]
        negative_keywords = [
            "plunge", "fall", "drop", "decline", "crash", "loss", "weak",
            "downgrade", "miss", "bearish", "concern", "risk", "warning",
        ]

        positive_count = sum(1 for kw in positive_keywords if kw in text_lower)
        negative_count = sum(1 for kw in negative_keywords if kw in text_lower)

        total = positive_count + negative_count
        if total == 0:
            return 0.0

        # Normalize to -1.0 to +1.0 range
        return (positive_count - negative_count) / total
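A worked example of the keyword heuristic above, calling the private helper directly for illustration: three positive hits ("surge", "profit", "beat") against one negative hit ("risk") normalise to (3 - 1) / 4 = 0.5:

```python
from src.data.news_api import NewsAPI

news = NewsAPI()  # the heuristic needs no API key
text = "Stocks surge after profit beat, but analysts warn of risk"
# Positive hits: "surge", "profit", "beat"; negative hits: "risk".
assert news._estimate_sentiment_from_text(text) == (3 - 1) / 4  # 0.5
```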
28
src/db.py
28
src/db.py
@@ -2,6 +2,7 @@

from __future__ import annotations

import json
import sqlite3
from datetime import UTC, datetime
from pathlib import Path
@@ -38,6 +39,8 @@ def init_db(db_path: str) -> sqlite3.Connection:
        conn.execute("ALTER TABLE trades ADD COLUMN market TEXT DEFAULT 'KR'")
    if "exchange_code" not in columns:
        conn.execute("ALTER TABLE trades ADD COLUMN exchange_code TEXT DEFAULT 'KRX'")
    if "selection_context" not in columns:
        conn.execute("ALTER TABLE trades ADD COLUMN selection_context TEXT")

    # Context tree tables for multi-layered memory management
    conn.execute(
@@ -118,15 +121,33 @@ def log_trade(
    pnl: float = 0.0,
    market: str = "KR",
    exchange_code: str = "KRX",
    selection_context: dict[str, Any] | None = None,
) -> None:
    """Insert a trade record into the database."""
    """Insert a trade record into the database.

    Args:
        conn: Database connection
        stock_code: Stock code
        action: Trade action (BUY/SELL/HOLD)
        confidence: Confidence level (0-100)
        rationale: AI decision rationale
        quantity: Number of shares
        price: Trade price
        pnl: Profit/loss
        market: Market code
        exchange_code: Exchange code
        selection_context: Scanner selection data (RSI, volume_ratio, signal, score)
    """
    # Serialize selection context to JSON
    context_json = json.dumps(selection_context) if selection_context else None

    conn.execute(
        """
        INSERT INTO trades (
            timestamp, stock_code, action, confidence, rationale,
            quantity, price, pnl, market, exchange_code
            quantity, price, pnl, market, exchange_code, selection_context
        )
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """,
        (
            datetime.now(UTC).isoformat(),
@@ -139,6 +160,7 @@ def log_trade(
            pnl,
            market,
            exchange_code,
            context_json,
        ),
    )
    conn.commit()

@@ -23,7 +23,7 @@ from google import genai

from src.config import Settings
from src.db import init_db
from src.logging.decision_logger import DecisionLog, DecisionLogger
from src.logging.decision_logger import DecisionLogger

logger = logging.getLogger(__name__)

870
src/main.py
870
src/main.py
@@ -10,10 +10,12 @@ import argparse
import asyncio
import logging
import signal
import sys
from datetime import UTC, datetime
from typing import Any

from src.analysis.scanner import MarketScanner
from src.analysis.smart_scanner import ScanCandidate, SmartVolatilityScanner
from src.analysis.volatility import VolatilityAnalyzer
from src.brain.gemini_client import GeminiClient
from src.broker.kis_api import KISBroker
@@ -21,36 +23,53 @@ from src.broker.overseas import OverseasBroker
from src.config import Settings
from src.context.layer import ContextLayer
from src.context.store import ContextStore
from src.core.criticality import CriticalityAssessor, CriticalityLevel
from src.core.criticality import CriticalityAssessor
from src.core.priority_queue import PriorityTaskQueue
from src.core.risk_manager import CircuitBreakerTripped, RiskManager
from src.core.risk_manager import CircuitBreakerTripped, FatFingerRejected, RiskManager
from src.db import init_db, log_trade
from src.logging.decision_logger import DecisionLogger
from src.logging_config import setup_logging
from src.markets.schedule import MarketInfo, get_next_market_open, get_open_markets
from src.notifications.telegram_client import TelegramClient, TelegramCommandHandler

logger = logging.getLogger(__name__)

# Target stock codes to monitor per market
WATCHLISTS = {
    "KR": ["005930", "000660", "035420"],  # Samsung, SK Hynix, NAVER
    "US_NASDAQ": ["AAPL", "MSFT", "GOOGL"],  # Example US stocks
    "US_NYSE": ["JPM", "BAC"],  # Example NYSE stocks
    "JP": ["7203", "6758"],  # Toyota, Sony
}

def safe_float(value: str | float | None, default: float = 0.0) -> float:
    """Convert to float, handling empty strings and None.

    Args:
        value: Value to convert (string, float, or None)
        default: Default value if conversion fails

    Returns:
        Converted float or default value

    Examples:
        >>> safe_float("123.45")
        123.45
        >>> safe_float("")
        0.0
        >>> safe_float(None)
        0.0
        >>> safe_float("invalid", 99.0)
        99.0
    """
    if value is None or value == "":
        return default
    try:
        return float(value)
    except (ValueError, TypeError):
        return default


TRADE_INTERVAL_SECONDS = 60
SCAN_INTERVAL_SECONDS = 60  # Scan markets every 60 seconds
MAX_CONNECTION_RETRIES = 3

# Full stock universe per market (for scanning)
# In production, this would be loaded from a database or API
STOCK_UNIVERSE = {
    "KR": ["005930", "000660", "035420", "051910", "005380", "005490"],
    "US_NASDAQ": ["AAPL", "MSFT", "GOOGL", "AMZN", "NVDA", "TSLA"],
    "US_NYSE": ["JPM", "BAC", "XOM", "JNJ", "V"],
    "JP": ["7203", "6758", "9984", "6861"],
}
# Daily trading mode constants (for Free tier API efficiency)
DAILY_TRADE_SESSIONS = 4  # Number of trading sessions per day
TRADE_SESSION_INTERVAL_HOURS = 6  # Hours between sessions


async def trading_cycle(
@@ -62,8 +81,10 @@ async def trading_cycle(
    decision_logger: DecisionLogger,
    context_store: ContextStore,
    criticality_assessor: CriticalityAssessor,
    telegram: TelegramClient,
    market: MarketInfo,
    stock_code: str,
    scan_candidates: dict[str, ScanCandidate],
) -> None:
    """Execute one trading cycle for a single stock."""
    cycle_start_time = asyncio.get_event_loop().time()
@@ -74,16 +95,16 @@ async def trading_cycle(
        balance_data = await broker.get_balance()

        output2 = balance_data.get("output2", [{}])
        total_eval = float(output2[0].get("tot_evlu_amt", "0")) if output2 else 0
        total_cash = float(
        total_eval = safe_float(output2[0].get("tot_evlu_amt", "0")) if output2 else 0
        total_cash = safe_float(
            balance_data.get("output2", [{}])[0].get("dnca_tot_amt", "0")
            if output2
            else "0"
        )
        purchase_total = float(output2[0].get("pchs_amt_smtl_amt", "0")) if output2 else 0
        purchase_total = safe_float(output2[0].get("pchs_amt_smtl_amt", "0")) if output2 else 0

        current_price = float(orderbook.get("output1", {}).get("stck_prpr", "0"))
        foreigner_net = float(orderbook.get("output1", {}).get("frgn_ntby_qty", "0"))
        current_price = safe_float(orderbook.get("output1", {}).get("stck_prpr", "0"))
        foreigner_net = safe_float(orderbook.get("output1", {}).get("frgn_ntby_qty", "0"))
    else:
        # Overseas market
        price_data = await overseas_broker.get_overseas_price(
@@ -92,11 +113,19 @@ async def trading_cycle(
        balance_data = await overseas_broker.get_overseas_balance(market.exchange_code)

        output2 = balance_data.get("output2", [{}])
        total_eval = float(output2[0].get("frcr_evlu_tota", "0")) if output2 else 0
        total_cash = float(output2[0].get("frcr_dncl_amt_2", "0")) if output2 else 0
        purchase_total = float(output2[0].get("frcr_buy_amt_smtl", "0")) if output2 else 0
        # Handle both list and dict response formats
        if isinstance(output2, list) and output2:
            balance_info = output2[0]
        elif isinstance(output2, dict):
            balance_info = output2
        else:
            balance_info = {}

        current_price = float(price_data.get("output", {}).get("last", "0"))
        total_eval = safe_float(balance_info.get("frcr_evlu_tota", "0") or "0")
        total_cash = safe_float(balance_info.get("frcr_dncl_amt_2", "0") or "0")
        purchase_total = safe_float(balance_info.get("frcr_buy_amt_smtl", "0") or "0")

        current_price = safe_float(price_data.get("output", {}).get("last", "0"))
        foreigner_net = 0.0  # Not available for overseas

    # Calculate daily P&L %
@@ -199,11 +228,23 @@ async def trading_cycle(
        order_amount = current_price * quantity

        # 4. Risk check BEFORE order
        risk.validate_order(
            current_pnl_pct=pnl_pct,
            order_amount=order_amount,
            total_cash=total_cash,
        )
        try:
            risk.validate_order(
                current_pnl_pct=pnl_pct,
                order_amount=order_amount,
                total_cash=total_cash,
            )
        except FatFingerRejected as exc:
            try:
                await telegram.notify_fat_finger(
                    stock_code=stock_code,
                    order_amount=exc.order_amount,
                    total_cash=exc.total_cash,
                    max_pct=exc.max_pct,
                )
            except Exception as notify_exc:
                logger.warning("Fat finger notification failed: %s", notify_exc)
            raise  # Re-raise to prevent trade

        # 5. Send order
        if market.is_domestic:
@@ -223,7 +264,30 @@ async def trading_cycle(
        )
        logger.info("Order result: %s", result.get("msg1", "OK"))

        # 6. Log trade
        # 5.5. Notify trade execution
        try:
            await telegram.notify_trade_execution(
                stock_code=stock_code,
                market=market.name,
                action=decision.action,
                quantity=quantity,
                price=current_price,
                confidence=decision.confidence,
            )
        except Exception as exc:
            logger.warning("Telegram notification failed: %s", exc)

        # 6. Log trade with selection context
        selection_context = None
        if stock_code in scan_candidates:
            candidate = scan_candidates[stock_code]
            selection_context = {
                "rsi": candidate.rsi,
                "volume_ratio": candidate.volume_ratio,
                "signal": candidate.signal,
                "score": candidate.score,
            }

        log_trade(
            conn=db_conn,
            stock_code=stock_code,
@@ -232,6 +296,7 @@ async def trading_cycle(
            rationale=decision.rationale,
            market=market.code,
            exchange_code=market.exchange_code,
            selection_context=selection_context,
        )

    # 7. Latency monitoring
@@ -256,6 +321,246 @@ async def trading_cycle(
    )


async def run_daily_session(
    broker: KISBroker,
    overseas_broker: OverseasBroker,
    brain: GeminiClient,
    risk: RiskManager,
    db_conn: Any,
    decision_logger: DecisionLogger,
    context_store: ContextStore,
    criticality_assessor: CriticalityAssessor,
    telegram: TelegramClient,
    settings: Settings,
    smart_scanner: SmartVolatilityScanner | None = None,
) -> None:
    """Execute one daily trading session.

    Designed for API efficiency with Gemini Free tier:
    - Batch decision making (1 API call per market)
    - Runs N times per day at fixed intervals
    - Minimizes API usage while maintaining trading capability
    """
    # Get currently open markets
    open_markets = get_open_markets(settings.enabled_market_list)

    if not open_markets:
        logger.info("No markets open for this session")
        return

    logger.info("Starting daily trading session for %d markets", len(open_markets))

    # Process each open market
    for market in open_markets:
        # Dynamic stock discovery via scanner (no static watchlists)
        try:
            candidates = await smart_scanner.scan()
            watchlist = [c.stock_code for c in candidates] if candidates else []
        except Exception as exc:
            logger.error("Smart Scanner failed for %s: %s", market.name, exc)
            watchlist = []

        if not watchlist:
            logger.info("No scanner candidates for market %s — skipping", market.code)
            continue

        logger.info("Processing market: %s (%d stocks)", market.name, len(watchlist))

        # Collect market data for all stocks from scanner
        stocks_data = []
        for stock_code in watchlist:
            try:
                if market.is_domestic:
                    orderbook = await broker.get_orderbook(stock_code)
                    current_price = safe_float(orderbook.get("output1", {}).get("stck_prpr", "0"))
                    foreigner_net = safe_float(
                        orderbook.get("output1", {}).get("frgn_ntby_qty", "0")
                    )
                else:
                    price_data = await overseas_broker.get_overseas_price(
                        market.exchange_code, stock_code
                    )
                    current_price = safe_float(price_data.get("output", {}).get("last", "0"))
                    foreigner_net = 0.0

                stocks_data.append(
                    {
                        "stock_code": stock_code,
                        "market_name": market.name,
                        "current_price": current_price,
                        "foreigner_net": foreigner_net,
                    }
                )
            except Exception as exc:
                logger.error("Failed to fetch data for %s: %s", stock_code, exc)
                continue

        if not stocks_data:
            logger.warning("No valid stock data for market %s", market.code)
            continue

        # Get batch decisions (1 API call for all stocks in this market)
        logger.info("Requesting batch decision for %d stocks in %s", len(stocks_data), market.name)
        decisions = await brain.decide_batch(stocks_data)

        # Get balance data once for the market
        if market.is_domestic:
            balance_data = await broker.get_balance()
            output2 = balance_data.get("output2", [{}])
            total_eval = safe_float(output2[0].get("tot_evlu_amt", "0")) if output2 else 0
            total_cash = safe_float(output2[0].get("dnca_tot_amt", "0")) if output2 else 0
            purchase_total = safe_float(output2[0].get("pchs_amt_smtl_amt", "0")) if output2 else 0
        else:
            balance_data = await overseas_broker.get_overseas_balance(market.exchange_code)
            output2 = balance_data.get("output2", [{}])
            if isinstance(output2, list) and output2:
                balance_info = output2[0]
            elif isinstance(output2, dict):
                balance_info = output2
            else:
                balance_info = {}

            total_eval = safe_float(balance_info.get("frcr_evlu_tota", "0") or "0")
            total_cash = safe_float(balance_info.get("frcr_dncl_amt_2", "0") or "0")
            purchase_total = safe_float(balance_info.get("frcr_buy_amt_smtl", "0") or "0")

        # Calculate daily P&L %
        pnl_pct = (
            ((total_eval - purchase_total) / purchase_total * 100) if purchase_total > 0 else 0.0
        )

        # Execute decisions for each stock
        for stock_data in stocks_data:
            stock_code = stock_data["stock_code"]
            decision = decisions.get(stock_code)

            if not decision:
                logger.warning("No decision for %s — skipping", stock_code)
                continue

            logger.info(
                "Decision for %s (%s): %s (confidence=%d)",
                stock_code,
                market.name,
                decision.action,
                decision.confidence,
            )

            # Log decision
            context_snapshot = {
                "L1": {
                    "current_price": stock_data["current_price"],
                    "foreigner_net": stock_data["foreigner_net"],
                },
                "L2": {
                    "total_eval": total_eval,
                    "total_cash": total_cash,
                    "purchase_total": purchase_total,
                    "pnl_pct": pnl_pct,
                },
            }
            input_data = {
                "current_price": stock_data["current_price"],
                "foreigner_net": stock_data["foreigner_net"],
                "total_eval": total_eval,
                "total_cash": total_cash,
                "pnl_pct": pnl_pct,
            }

            decision_logger.log_decision(
                stock_code=stock_code,
                market=market.code,
                exchange_code=market.exchange_code,
                action=decision.action,
                confidence=decision.confidence,
                rationale=decision.rationale,
                context_snapshot=context_snapshot,
                input_data=input_data,
            )

            # Execute if actionable
            if decision.action in ("BUY", "SELL"):
                quantity = 1
                order_amount = stock_data["current_price"] * quantity

                # Risk check
                try:
                    risk.validate_order(
                        current_pnl_pct=pnl_pct,
                        order_amount=order_amount,
                        total_cash=total_cash,
                    )
                except FatFingerRejected as exc:
                    try:
                        await telegram.notify_fat_finger(
                            stock_code=stock_code,
                            order_amount=exc.order_amount,
                            total_cash=exc.total_cash,
                            max_pct=exc.max_pct,
                        )
                    except Exception as notify_exc:
                        logger.warning("Fat finger notification failed: %s", notify_exc)
                    continue  # Skip this order
                except CircuitBreakerTripped as exc:
                    logger.critical("Circuit breaker tripped — stopping session")
                    try:
                        await telegram.notify_circuit_breaker(
                            pnl_pct=exc.pnl_pct,
                            threshold=exc.threshold,
                        )
                    except Exception as notify_exc:
                        logger.warning("Circuit breaker notification failed: %s", notify_exc)
                    raise

                # Send order
                try:
                    if market.is_domestic:
                        result = await broker.send_order(
                            stock_code=stock_code,
                            order_type=decision.action,
                            quantity=quantity,
                            price=0,  # market order
                        )
                    else:
                        result = await overseas_broker.send_overseas_order(
                            exchange_code=market.exchange_code,
                            stock_code=stock_code,
                            order_type=decision.action,
                            quantity=quantity,
                            price=0.0,  # market order
                        )
                    logger.info("Order result: %s", result.get("msg1", "OK"))

                    # Notify trade execution
                    try:
                        await telegram.notify_trade_execution(
                            stock_code=stock_code,
                            market=market.name,
                            action=decision.action,
                            quantity=quantity,
                            price=stock_data["current_price"],
                            confidence=decision.confidence,
                        )
                    except Exception as exc:
                        logger.warning("Telegram notification failed: %s", exc)
                except Exception as exc:
                    logger.error("Order execution failed for %s: %s", stock_code, exc)
                    continue

                # Log trade
                log_trade(
                    conn=db_conn,
                    stock_code=stock_code,
                    action=decision.action,
                    confidence=decision.confidence,
                    rationale=decision.rationale,
                    market=market.code,
                    exchange_code=market.exchange_code,
                )

    logger.info("Daily trading session completed")


async def run(settings: Settings) -> None:
    """Main async loop — iterate over open markets on a timer."""
    broker = KISBroker(settings)
@@ -266,6 +571,149 @@ async def run(settings: Settings) -> None:
    decision_logger = DecisionLogger(db_conn)
    context_store = ContextStore(db_conn)

    # Initialize Telegram notifications
    telegram = TelegramClient(
        bot_token=settings.TELEGRAM_BOT_TOKEN,
        chat_id=settings.TELEGRAM_CHAT_ID,
        enabled=settings.TELEGRAM_ENABLED,
    )

    # Initialize Telegram command handler
    command_handler = TelegramCommandHandler(telegram)

    # Register basic commands
    async def handle_help() -> None:
        """Handle /help command."""
        message = (
            "<b>📖 Available Commands</b>\n\n"
            "/help - Show available commands\n"
            "/status - Trading status (mode, markets, P&L)\n"
            "/positions - Current holdings\n"
            "/stop - Pause trading\n"
            "/resume - Resume trading"
        )
        await telegram.send_message(message)

    async def handle_stop() -> None:
        """Handle /stop command - pause trading."""
        if not pause_trading.is_set():
            await telegram.send_message("⏸️ Trading is already paused")
            return

        pause_trading.clear()
        logger.info("Trading paused via Telegram command")
        await telegram.send_message(
            "<b>⏸️ Trading Paused</b>\n\n"
            "All trading operations have been suspended.\n"
            "Use /resume to restart trading."
        )

    async def handle_resume() -> None:
        """Handle /resume command - resume trading."""
        if pause_trading.is_set():
            await telegram.send_message("▶️ Trading is already active")
            return

        pause_trading.set()
        logger.info("Trading resumed via Telegram command")
        await telegram.send_message(
            "<b>▶️ Trading Resumed</b>\n\n"
            "Trading operations have been restarted."
        )

    async def handle_status() -> None:
        """Handle /status command - show trading status."""
        try:
            # Get trading status
            trading_status = "Active" if pause_trading.is_set() else "Paused"

            # Calculate P&L from balance data
            try:
                balance = await broker.get_balance()
                output2 = balance.get("output2", [{}])
                if output2:
                    total_eval = safe_float(output2[0].get("tot_evlu_amt", "0"))
                    purchase_total = safe_float(output2[0].get("pchs_amt_smtl_amt", "0"))
                    current_pnl = (
                        ((total_eval - purchase_total) / purchase_total * 100)
                        if purchase_total > 0
                        else 0.0
                    )
                    pnl_str = f"{current_pnl:+.2f}%"
                else:
                    pnl_str = "N/A"
            except Exception as exc:
                logger.warning("Failed to get P&L: %s", exc)
                pnl_str = "N/A"

            # Format market list
            markets_str = ", ".join(settings.enabled_market_list)

            message = (
                "<b>📊 Trading Status</b>\n\n"
                f"<b>Mode:</b> {settings.MODE.upper()}\n"
                f"<b>Markets:</b> {markets_str}\n"
                f"<b>Trading:</b> {trading_status}\n\n"
                f"<b>Current P&L:</b> {pnl_str}\n"
                f"<b>Circuit Breaker:</b> {risk._cb_threshold:.1f}%"
            )
            await telegram.send_message(message)

        except Exception as exc:
            logger.error("Error in /status handler: %s", exc)
            await telegram.send_message(
                "<b>⚠️ Error</b>\n\nFailed to retrieve trading status."
            )

    async def handle_positions() -> None:
        """Handle /positions command - show account summary."""
        try:
            # Get account balance
            balance = await broker.get_balance()
            output2 = balance.get("output2", [{}])

            if not output2:
                await telegram.send_message(
                    "<b>💼 Account Summary</b>\n\n"
                    "No balance information available."
                )
                return

            # Extract account-level data
            total_eval = safe_float(output2[0].get("tot_evlu_amt", "0"))
            total_cash = safe_float(output2[0].get("dnca_tot_amt", "0"))
            purchase_total = safe_float(output2[0].get("pchs_amt_smtl_amt", "0"))

            # Calculate P&L
            pnl_pct = (
                ((total_eval - purchase_total) / purchase_total * 100)
                if purchase_total > 0
                else 0.0
            )
            pnl_sign = "+" if pnl_pct >= 0 else ""

            message = (
                "<b>💼 Account Summary</b>\n\n"
                f"<b>Total Evaluation:</b> ₩{total_eval:,.0f}\n"
                f"<b>Available Cash:</b> ₩{total_cash:,.0f}\n"
                f"<b>Purchase Total:</b> ₩{purchase_total:,.0f}\n"
                f"<b>P&L:</b> {pnl_sign}{pnl_pct:.2f}%\n\n"
                "<i>Note: Individual position details require API enhancement</i>"
            )
            await telegram.send_message(message)

        except Exception as exc:
            logger.error("Error in /positions handler: %s", exc)
            await telegram.send_message(
                "<b>⚠️ Error</b>\n\nFailed to retrieve positions."
            )

    command_handler.register_command("help", handle_help)
    command_handler.register_command("stop", handle_stop)
    command_handler.register_command("resume", handle_resume)
    command_handler.register_command("status", handle_status)
    command_handler.register_command("positions", handle_positions)

    # Initialize volatility hunter
    volatility_analyzer = VolatilityAnalyzer(min_volume_surge=2.0, min_price_change=1.0)
    market_scanner = MarketScanner(
@@ -274,8 +722,22 @@ async def run(settings: Settings) -> None:
        volatility_analyzer=volatility_analyzer,
        context_store=context_store,
        top_n=5,
        max_concurrent_scans=1,  # Fully serialized to avoid EGW00201
    )

    # Initialize smart scanner (Python-first, AI-last pipeline)
    smart_scanner = SmartVolatilityScanner(
        broker=broker,
        volatility_analyzer=volatility_analyzer,
        settings=settings,
    )

    # Track scan candidates for selection context logging
    scan_candidates: dict[str, ScanCandidate] = {}  # stock_code -> candidate

    # Active stocks per market (dynamically discovered by scanner)
    active_stocks: dict[str, list[str]] = {}  # market_code -> [stock_codes]

    # Initialize latency control system
    criticality_assessor = CriticalityAssessor(
        critical_pnl_threshold=-2.5,  # Near circuit breaker at -3.0%
@@ -289,7 +751,13 @@ async def run(settings: Settings) -> None:
    # Track last scan time for each market
    last_scan_time: dict[str, float] = {}

    # Track market open/close state for notifications
    _market_states: dict[str, bool] = {}  # market_code -> is_open

    # Trading control events
    shutdown = asyncio.Event()
    pause_trading = asyncio.Event()
    pause_trading.set()  # Default: trading enabled

    def _signal_handler() -> None:
        logger.info("Shutdown signal received")
@@ -299,145 +767,245 @@ async def run(settings: Settings) -> None:
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, _signal_handler)

    logger.info("The Ouroboros is alive. Mode: %s", settings.MODE)
    logger.info("The Ouroboros is alive. Mode: %s, Trading: %s", settings.MODE, settings.TRADE_MODE)
    logger.info("Enabled markets: %s", settings.enabled_market_list)

    # Notify system startup
    try:
        while not shutdown.is_set():
            # Get currently open markets
            open_markets = get_open_markets(settings.enabled_market_list)
        await telegram.notify_system_start(settings.MODE, settings.enabled_market_list)
    except Exception as exc:
        logger.warning("System startup notification failed: %s", exc)

    # Start command handler
    try:
        await command_handler.start_polling()
    except Exception as exc:
        logger.warning("Failed to start command handler: %s", exc)

    try:
        # Branch based on trading mode
        if settings.TRADE_MODE == "daily":
            # Daily trading mode: batch decisions at fixed intervals
            logger.info(
                "Daily trading mode: %d sessions every %d hours",
                settings.DAILY_SESSIONS,
                settings.SESSION_INTERVAL_HOURS,
            )

            session_interval = settings.SESSION_INTERVAL_HOURS * 3600  # Convert to seconds

            while not shutdown.is_set():
                # Wait for trading to be unpaused
                await pause_trading.wait()

                if not open_markets:
                    # No markets open — wait until next market opens
                    try:
                        next_market, next_open_time = get_next_market_open(
                            settings.enabled_market_list
                try:
                    await run_daily_session(
                        broker,
                        overseas_broker,
                        brain,
                        risk,
                        db_conn,
                        decision_logger,
                        context_store,
                        criticality_assessor,
                        telegram,
                        settings,
                        smart_scanner=smart_scanner,
                    )
                        now = datetime.now(UTC)
                        wait_seconds = (next_open_time - now).total_seconds()
                        logger.info(
                            "No markets open. Next market: %s, opens in %.1f hours",
                            next_market.name,
                            wait_seconds / 3600,
                        )
                        await asyncio.wait_for(shutdown.wait(), timeout=wait_seconds)
                    except TimeoutError:
                        continue  # Market should be open now
                    except ValueError as exc:
                        logger.error("Failed to find next market open: %s", exc)
                        await asyncio.sleep(TRADE_INTERVAL_SECONDS)
                        continue

                # Process each open market
                for market in open_markets:
                    if shutdown.is_set():
                except CircuitBreakerTripped:
                    logger.critical("Circuit breaker tripped — shutting down")
                    shutdown.set()
                    break
                except Exception as exc:
                    logger.exception("Daily session error: %s", exc)

                    # Volatility Hunter: Scan market periodically to update watchlist
                    now_timestamp = asyncio.get_event_loop().time()
                    last_scan = last_scan_time.get(market.code, 0.0)
                    if now_timestamp - last_scan >= SCAN_INTERVAL_SECONDS:
                # Wait for next session or shutdown
                logger.info("Next session in %.1f hours", session_interval / 3600)
                try:
                    await asyncio.wait_for(shutdown.wait(), timeout=session_interval)
                except TimeoutError:
                    pass  # Normal — time for next session

        else:
            # Realtime trading mode: original per-stock loop
            logger.info("Realtime trading mode: 60s interval per stock")

            while not shutdown.is_set():
                # Wait for trading to be unpaused
                await pause_trading.wait()

                # Get currently open markets
                open_markets = get_open_markets(settings.enabled_market_list)

                if not open_markets:
                    # Notify market close for any markets that were open
                    for market_code, is_open in list(_market_states.items()):
                        if is_open:
                            try:
                                from src.markets.schedule import MARKETS

                                market_info = MARKETS.get(market_code)
                                if market_info:
                                    await telegram.notify_market_close(market_info.name, 0.0)
                            except Exception as exc:
                                logger.warning("Market close notification failed: %s", exc)
                            _market_states[market_code] = False

                    # No markets open — wait until next market opens
                    try:
                        # Scan all stocks in the universe
                        stock_universe = STOCK_UNIVERSE.get(market.code, [])
                        if stock_universe:
                            logger.info("Volatility Hunter: Scanning %s market", market.name)
                            scan_result = await market_scanner.scan_market(
                                market, stock_universe
                            )

                            # Update watchlist with top movers
                            current_watchlist = WATCHLISTS.get(market.code, [])
                            updated_watchlist = market_scanner.get_updated_watchlist(
                                current_watchlist,
                                scan_result,
                                max_replacements=2,
                            )
                            WATCHLISTS[market.code] = updated_watchlist

                            logger.info(
                                "Volatility Hunter: Watchlist updated for %s (%d top movers, %d breakouts)",
                                market.name,
                                len(scan_result.top_movers),
                                len(scan_result.breakouts),
                            )

                            last_scan_time[market.code] = now_timestamp
                    except Exception as exc:
                        logger.error("Volatility Hunter scan failed for %s: %s", market.name, exc)

                    # Get watchlist for this market
                    watchlist = WATCHLISTS.get(market.code, [])
                    if not watchlist:
                        logger.debug("No watchlist for market %s", market.code)
                        next_market, next_open_time = get_next_market_open(
                            settings.enabled_market_list
                        )
                        now = datetime.now(UTC)
                        wait_seconds = (next_open_time - now).total_seconds()
                        logger.info(
                            "No markets open. Next market: %s, opens in %.1f hours",
                            next_market.name,
                            wait_seconds / 3600,
                        )
                        await asyncio.wait_for(shutdown.wait(), timeout=wait_seconds)
                    except TimeoutError:
                        continue  # Market should be open now
                    except ValueError as exc:
                        logger.error("Failed to find next market open: %s", exc)
                        await asyncio.sleep(TRADE_INTERVAL_SECONDS)
                        continue

                logger.info("Processing market: %s (%d stocks)", market.name, len(watchlist))

                # Process each stock in the watchlist
                for stock_code in watchlist:
                # Process each open market
                for market in open_markets:
                    if shutdown.is_set():
                        break

                    # Retry logic for connection errors
                    for attempt in range(1, MAX_CONNECTION_RETRIES + 1):
                    # Notify market open if it just opened
                    if not _market_states.get(market.code, False):
                        try:
                            await trading_cycle(
                                broker,
                                overseas_broker,
                                brain,
                                risk,
                                db_conn,
                                decision_logger,
                                context_store,
                                criticality_assessor,
                                market,
                                stock_code,
                            )
                            break  # Success — exit retry loop
                        except CircuitBreakerTripped:
                            logger.critical("Circuit breaker tripped — shutting down")
                            raise
                        except ConnectionError as exc:
                            if attempt < MAX_CONNECTION_RETRIES:
                                logger.warning(
                                    "Connection error for %s (attempt %d/%d): %s",
                                    stock_code,
                                    attempt,
                                    MAX_CONNECTION_RETRIES,
                                    exc,
                                )
                                await asyncio.sleep(2**attempt)  # Exponential backoff
                            else:
                                logger.error(
                                    "Connection error for %s (all retries exhausted): %s",
                                    stock_code,
                                    exc,
                                )
                                break  # Give up on this stock
                            await telegram.notify_market_open(market.name)
                        except Exception as exc:
                            logger.exception("Unexpected error for %s: %s", stock_code, exc)
                            break  # Don't retry on unexpected errors
                            logger.warning("Market open notification failed: %s", exc)
                        _market_states[market.code] = True

                    # Log priority queue metrics periodically
                    metrics = await priority_queue.get_metrics()
                    if metrics.total_enqueued > 0:
                        logger.info(
                            "Priority queue metrics: enqueued=%d, dequeued=%d, size=%d, timeouts=%d, errors=%d",
                            metrics.total_enqueued,
                            metrics.total_dequeued,
                            metrics.current_size,
                            metrics.total_timeouts,
                            metrics.total_errors,
                        )
                    # Smart Scanner: dynamic stock discovery (no static watchlists)
                    now_timestamp = asyncio.get_event_loop().time()
                    last_scan = last_scan_time.get(market.code, 0.0)
                    if now_timestamp - last_scan >= SCAN_INTERVAL_SECONDS:
                        try:
                            logger.info("Smart Scanner: Scanning %s market", market.name)

                    # Wait for next cycle or shutdown
                    try:
                        await asyncio.wait_for(shutdown.wait(), timeout=TRADE_INTERVAL_SECONDS)
                    except TimeoutError:
                        pass  # Normal — timeout means it's time for next cycle
                            candidates = await smart_scanner.scan()

                            if candidates:
                                # Use scanner results directly as trading candidates
                                active_stocks[market.code] = smart_scanner.get_stock_codes(
                                    candidates
                                )

                                # Store candidates for selection context logging
                                for candidate in candidates:
                                    scan_candidates[candidate.stock_code] = candidate

                                logger.info(
                                    "Smart Scanner: Found %d candidates for %s: %s",
                                    len(candidates),
                                    market.name,
                                    [f"{c.stock_code}(RSI={c.rsi:.0f})" for c in candidates],
                                )
                            else:
                                logger.info(
                                    "Smart Scanner: No candidates for %s — no trades", market.name
                                )
                                active_stocks[market.code] = []

                            last_scan_time[market.code] = now_timestamp

                        except Exception as exc:
                            logger.error("Smart Scanner failed for %s: %s", market.name, exc)

                    # Get active stocks from scanner (dynamic, no static fallback)
                    stock_codes = active_stocks.get(market.code, [])
                    if not stock_codes:
                        logger.debug("No active stocks for market %s", market.code)
                        continue

                    logger.info("Processing market: %s (%d stocks)", market.name, len(stock_codes))

                    # Process each stock from scanner results
                    for stock_code in stock_codes:
                        if shutdown.is_set():
                            break

                        # Retry logic for connection errors
                        for attempt in range(1, MAX_CONNECTION_RETRIES + 1):
                            try:
                                await trading_cycle(
                                    broker,
                                    overseas_broker,
                                    brain,
                                    risk,
                                    db_conn,
                                    decision_logger,
                                    context_store,
                                    criticality_assessor,
                                    telegram,
                                    market,
                                    stock_code,
                                    scan_candidates,
                                )
                                break  # Success — exit retry loop
                            except CircuitBreakerTripped as exc:
                                logger.critical("Circuit breaker tripped — shutting down")
                                try:
                                    await telegram.notify_circuit_breaker(
                                        pnl_pct=exc.pnl_pct,
                                        threshold=exc.threshold,
                                    )
                                except Exception as notify_exc:
                                    logger.warning(
                                        "Circuit breaker notification failed: %s", notify_exc
                                    )
                                raise
                            except ConnectionError as exc:
                                if attempt < MAX_CONNECTION_RETRIES:
                                    logger.warning(
                                        "Connection error for %s (attempt %d/%d): %s",
                                        stock_code,
                                        attempt,
                                        MAX_CONNECTION_RETRIES,
                                        exc,
                                    )
                                    await asyncio.sleep(2**attempt)  # Exponential backoff
                                else:
                                    logger.error(
                                        "Connection error for %s (all retries exhausted): %s",
                                        stock_code,
                                        exc,
                                    )
                                    break  # Give up on this stock
                            except Exception as exc:
                                logger.exception("Unexpected error for %s: %s", stock_code, exc)
                                break  # Don't retry on unexpected errors

                # Log priority queue metrics periodically
                metrics = await priority_queue.get_metrics()
                if metrics.total_enqueued > 0:
                    logger.info(
                        "Priority queue metrics: enqueued=%d, dequeued=%d, size=%d, timeouts=%d, errors=%d",
                        metrics.total_enqueued,
                        metrics.total_dequeued,
                        metrics.current_size,
                        metrics.total_timeouts,
                        metrics.total_errors,
                    )

                # Wait for next cycle or shutdown
                try:
                    await asyncio.wait_for(shutdown.wait(), timeout=TRADE_INTERVAL_SECONDS)
                except TimeoutError:
                    pass  # Normal — timeout means it's time for next cycle
    finally:
        # Clean up resources
        await command_handler.stop_polling()
        await broker.close()
        await telegram.close()
        db_conn.close()
        logger.info("The Ouroboros rests.")

350
src/notifications/README.md
Normal file
@@ -0,0 +1,350 @@
# Telegram Notifications

Real-time trading event notifications via the Telegram Bot API.

## Setup

### 1. Create a Telegram Bot

1. Open Telegram and message [@BotFather](https://t.me/BotFather)
2. Send the `/newbot` command
3. Follow the prompts to name your bot
4. Save the **bot token** (looks like `1234567890:ABCdefGHIjklMNOpqrsTUVwxyz`)

### 2. Get Your Chat ID

**Option A: Using @userinfobot**
1. Message [@userinfobot](https://t.me/userinfobot) on Telegram
2. Send `/start`
3. Save your numeric **chat ID** (e.g., `123456789`)

**Option B: Using @RawDataBot**
1. Message [@RawDataBot](https://t.me/rawdatabot) on Telegram
2. Look for `"id":` in the JSON response
3. Save your numeric **chat ID**

### 3. Configure Environment

Add to your `.env` file:

```bash
TELEGRAM_BOT_TOKEN=1234567890:ABCdefGHIjklMNOpqrsTUVwxyz
TELEGRAM_CHAT_ID=123456789
TELEGRAM_ENABLED=true
```

### 4. Test the Bot

Start a conversation with your bot on Telegram first (send `/start`), then run:

```bash
python -m src.main --mode=paper
```

You should receive a startup notification.

## Message Examples

### Trade Execution
```
🟢 BUY
Symbol: AAPL (United States)
Quantity: 10 shares
Price: 150.25
Confidence: 85%
```

### Circuit Breaker
```
🚨 CIRCUIT BREAKER TRIPPED
P&L: -3.15% (threshold: -3.0%)
Trading halted for safety
```

### Fat-Finger Protection
```
⚠️ Fat-Finger Protection
Order rejected: TSLA
Attempted: 45.0% of cash
Max allowed: 30%
Amount: 45,000 / 100,000
```

### Market Open/Close
```
ℹ️ Market Open
Korea trading session started

ℹ️ Market Close
Korea trading session ended
📈 P&L: +1.25%
```

### System Status
```
📝 System Started
Mode: PAPER
Markets: KRX, NASDAQ

System Shutdown
Normal shutdown
```

## Notification Priorities

| Priority | Emoji | Use Case |
|----------|-------|----------|
| LOW | ℹ️ | Market open/close |
| MEDIUM | 📊 | Trade execution, system start/stop |
| HIGH | ⚠️ | Fat-finger protection, errors |
| CRITICAL | 🚨 | Circuit breaker trips |

## Rate Limiting

- Default: 1 message per second
- Prevents hitting Telegram's global rate limits
- Configurable via the `rate_limit` parameter (see the sketch below)
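
To make the pacing concrete, here is a minimal sketch (not part of the package) that drives the `LeakyBucket` limiter from `telegram_client.py` directly, with a print in place of a real API call; the timings are approximate:

```python
import asyncio
import time

from src.notifications.telegram_client import LeakyBucket


async def demo() -> None:
    bucket = LeakyBucket(rate=1.0)  # the default: 1 message per second
    start = time.monotonic()
    for i in range(3):
        await bucket.acquire()  # blocks until the next send slot opens
        print(f"message {i} cleared at t={time.monotonic() - start:.1f}s")


asyncio.run(demo())  # roughly t=0.0s, t=1.0s, t=2.0s
```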

## Troubleshooting

### No notifications received

1. **Check bot configuration**
   ```bash
   # Verify env variables are set
   grep TELEGRAM .env
   ```

2. **Start a conversation with the bot**
   - Open the bot in Telegram
   - Send the `/start` command
   - The bot cannot message users who haven't started a conversation

3. **Check logs**
   ```bash
   # Look for Telegram-related errors
   python -m src.main --mode=paper 2>&1 | grep -i telegram
   ```

4. **Verify bot token**
   ```bash
   curl https://api.telegram.org/bot<YOUR_TOKEN>/getMe
   # Should return bot info (not a 401 error)
   ```

5. **Verify chat ID**
   ```bash
   curl -X POST https://api.telegram.org/bot<YOUR_TOKEN>/sendMessage \
     -H 'Content-Type: application/json' \
     -d '{"chat_id": "<YOUR_CHAT_ID>", "text": "Test"}'
   # Should send a test message
   ```

### Notifications delayed

- Check rate limiter settings
- Verify network connection
- Look for timeout errors in logs

### "Chat not found" error

- Incorrect chat ID
- Bot blocked by user
- Need to send `/start` to the bot first

### "Unauthorized" error

- Invalid bot token
- Token revoked (regenerate with @BotFather)

## Graceful Degradation

The system works without Telegram notifications:

- Missing credentials → notifications disabled automatically
- API errors → logged, but trading continues
- Network timeouts → trading loop unaffected
- Rate limiting → messages queued, trading proceeds

**Notifications never crash the trading system.**
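
A minimal illustration of that contract, assuming no credentials are configured:

```python
import asyncio

from src.notifications.telegram_client import TelegramClient


async def demo() -> None:
    client = TelegramClient(bot_token=None, chat_id=None)  # credentials missing
    ok = await client.send_message("hello")  # no exception is raised
    print(ok)  # False: the send is skipped and the reason is logged
    await client.close()


asyncio.run(demo())
```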

## Security Notes

- Never commit a `.env` file with credentials
- The bot token grants full control of the bot
- The chat ID is not sensitive (it's just a number)
- Messages are sent over HTTPS
- No trading credentials appear in notifications

## Advanced Usage

### Group Notifications

1. Add the bot to a Telegram group
2. Get the group chat ID (a negative number like `-123456789`)
3. Use the group chat ID in `TELEGRAM_CHAT_ID` (example below)
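
For example (token as above; note the leading minus on group IDs):

```bash
# Group chat IDs are negative
TELEGRAM_CHAT_ID=-123456789
TELEGRAM_ENABLED=true
```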

### Multiple Recipients

Create multiple bots, or use a broadcast group with multiple members; a small fan-out wrapper also works (see the sketch below).
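
One possible shape for such a fan-out, sketched here as a hypothetical helper (not part of the package):

```python
from src.notifications.telegram_client import TelegramClient


class Broadcast:
    """Hypothetical helper that fans one message out to several clients."""

    def __init__(self, clients: list[TelegramClient]) -> None:
        self._clients = clients

    async def send(self, text: str) -> None:
        for client in self._clients:
            # send_message() returns False on failure instead of raising,
            # so one dead recipient never blocks the others
            await client.send_message(text)
```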

### Custom Rate Limits

The rate limit is not currently exposed in config, but it can be changed in code:

```python
telegram = TelegramClient(
    bot_token=settings.TELEGRAM_BOT_TOKEN,
    chat_id=settings.TELEGRAM_CHAT_ID,
    rate_limit=2.0,  # 2 messages per second
)
```

## Bidirectional Commands

Control your trading bot remotely via Telegram commands. The bot not only sends notifications but also accepts commands for real-time control.

### Available Commands

| Command | Description |
|---------|-------------|
| `/start` | Welcome message with quick start guide |
| `/help` | List all available commands |
| `/status` | Current trading status (mode, markets, P&L, circuit breaker) |
| `/positions` | View current holdings grouped by market |
| `/stop` | Pause all trading operations |
| `/resume` | Resume trading operations |

### Command Examples

**Check Trading Status**
```
You: /status

Bot:
📊 Trading Status

Mode: PAPER
Markets: Korea, United States
Trading: Active

Current P&L: +2.50%
Circuit Breaker: -3.0%
```

**View Holdings**
```
You: /positions

Bot:
💼 Current Holdings

🇰🇷 Korea
• 005930: 10 shares @ 70,000
• 035420: 5 shares @ 200,000

🇺🇸 Overseas
• AAPL: 15 shares @ 175
• TSLA: 8 shares @ 245

Cash: ₩5,000,000
```

**Pause Trading**
```
You: /stop

Bot:
⏸️ Trading Paused

All trading operations have been suspended.
Use /resume to restart trading.
```

**Resume Trading**
```
You: /resume

Bot:
▶️ Trading Resumed

Trading operations have been restarted.
```

### Security

**Chat ID Verification**
- Commands are only accepted from the configured `TELEGRAM_CHAT_ID`
- Unauthorized users receive no response
- Command attempts from wrong chat IDs are logged

**Authorization Required**
- Only the bot owner (the chat ID in `.env`) can control trading
- Unauthorized users have no way to discover or use commands
- All command executions are logged for audit

### Configuration

Add to your `.env` file:

```bash
# Commands are enabled by default
TELEGRAM_COMMANDS_ENABLED=true

# Polling interval (seconds) - how often to check for commands
TELEGRAM_POLLING_INTERVAL=1.0
```

To disable commands but keep notifications:
```bash
TELEGRAM_COMMANDS_ENABLED=false
```

### How It Works

1. **Long polling**: The bot checks the Telegram API every second for new messages
2. **Command parsing**: Messages starting with `/` are parsed as commands
3. **Authentication**: The chat ID is verified before any command is executed
4. **Execution**: The command handler is called with the current bot state
5. **Response**: The result is sent back via Telegram (see the sketch below)
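
A hedged end-to-end sketch of that flow, wiring a custom `/status` reply through `TelegramCommandHandler` (token and chat ID are placeholders, and the sleep stands in for the real trading loop):

```python
import asyncio

from src.notifications.telegram_client import TelegramClient, TelegramCommandHandler


async def main() -> None:
    client = TelegramClient(bot_token="1234567890:ABCdefGHIjklMNOpqrsTUVwxyz",
                            chat_id="123456789")
    handler = TelegramCommandHandler(client, polling_interval=1.0)

    async def on_status() -> None:
        # Steps 4-5: the handler runs and the reply goes back via Telegram
        await client.send_message("📊 <b>Trading Status</b>\nMode: PAPER")

    handler.register_command("status", on_status)  # invoked for "/status"
    await handler.start_polling()  # steps 1-3 run in a background task
    await asyncio.sleep(60)        # stand-in for the real trading loop
    await handler.stop_polling()
    await client.close()


asyncio.run(main())
```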

### Error Handling

- Command parsing errors → "Unknown command" response
- API failures → graceful degradation, error logged
- Invalid state → appropriate message (e.g., "Trading is already paused")
- Trading loop isolation → command errors never crash trading

### Troubleshooting Commands

**Commands not responding**
1. Check `TELEGRAM_COMMANDS_ENABLED=true` in `.env`
2. Verify you started the conversation with `/start`
3. Check logs for command handler errors
4. Confirm the chat ID matches the `.env` configuration

**Wrong chat ID**
- Commands from unauthorized chats are silently ignored
- Check logs for "unauthorized chat_id" warnings

**Delayed responses**
- The polling interval is 1 second by default
- Network latency may add delay
- Check the `TELEGRAM_POLLING_INTERVAL` setting

## API Reference

See `telegram_client.py` for full API documentation.

### Notification Methods
- `notify_trade_execution()` - Trade alerts
- `notify_circuit_breaker()` - Emergency stops
- `notify_fat_finger()` - Order rejections
- `notify_market_open/close()` - Session tracking
- `notify_system_start/shutdown()` - Lifecycle events
- `notify_error()` - Error alerts

### Command Handler
- `TelegramCommandHandler` - Bidirectional command processing
- `register_command()` - Register custom command handlers
- `start_polling()` / `stop_polling()` - Lifecycle management
5
src/notifications/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""Real-time notifications for trading events."""

from src.notifications.telegram_client import TelegramClient

__all__ = ["TelegramClient"]
511
src/notifications/telegram_client.py
Normal file
@@ -0,0 +1,511 @@
"""Telegram notification client for real-time trading alerts."""

import asyncio
import logging
import time
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from enum import Enum

import aiohttp

logger = logging.getLogger(__name__)


class NotificationPriority(Enum):
    """Priority levels for notifications with emoji indicators."""

    LOW = ("ℹ️", "info")
    MEDIUM = ("📊", "medium")
    HIGH = ("⚠️", "warning")
    CRITICAL = ("🚨", "critical")

    def __init__(self, emoji: str, label: str) -> None:
        self.emoji = emoji
        self.label = label


class LeakyBucket:
    """Rate limiter using the leaky bucket algorithm."""

    def __init__(self, rate: float, capacity: int = 1) -> None:
        """
        Initialize rate limiter.

        Args:
            rate: Maximum requests per second
            capacity: Bucket capacity (burst size)
        """
        self._rate = rate
        self._capacity = capacity
        self._tokens = float(capacity)
        self._last_update = time.monotonic()
        self._lock = asyncio.Lock()

    async def acquire(self) -> None:
        """Wait until a token is available, then consume it."""
        async with self._lock:
            now = time.monotonic()
            elapsed = now - self._last_update
            self._tokens = min(self._capacity, self._tokens + elapsed * self._rate)
            self._last_update = now

            if self._tokens < 1.0:
                wait_time = (1.0 - self._tokens) / self._rate
                await asyncio.sleep(wait_time)
                # Reset the clock after sleeping so the waited interval is not
                # credited to the next caller as well (which would briefly
                # allow two sends in the same slot)
                self._last_update = time.monotonic()
                self._tokens = 0.0
            else:
                self._tokens -= 1.0


@dataclass
class NotificationMessage:
    """Internal notification message structure."""

    priority: NotificationPriority
    message: str


class TelegramClient:
    """Telegram Bot API client for sending trading notifications."""

    API_BASE = "https://api.telegram.org/bot{token}"
    DEFAULT_TIMEOUT = 5.0  # seconds
    DEFAULT_RATE = 1.0  # messages per second

    def __init__(
        self,
        bot_token: str | None = None,
        chat_id: str | None = None,
        enabled: bool = True,
        rate_limit: float = DEFAULT_RATE,
    ) -> None:
        """
        Initialize Telegram client.

        Args:
            bot_token: Telegram bot token from @BotFather
            chat_id: Target chat ID (user or group)
            enabled: Enable/disable notifications globally
            rate_limit: Maximum messages per second
        """
        self._bot_token = bot_token
        self._chat_id = chat_id
        self._enabled = enabled
        self._rate_limiter = LeakyBucket(rate=rate_limit)
        self._session: aiohttp.ClientSession | None = None

        if not enabled:
            logger.info("Telegram notifications disabled via configuration")
        elif bot_token is None or chat_id is None:
            logger.warning(
                "Telegram notifications disabled (missing bot_token or chat_id)"
            )
            self._enabled = False
        else:
            logger.info("Telegram notifications enabled for chat_id=%s", chat_id)

    def _get_session(self) -> aiohttp.ClientSession:
        """Get or create the aiohttp session."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=self.DEFAULT_TIMEOUT)
            )
        return self._session

    async def close(self) -> None:
        """Close the HTTP session."""
        if self._session is not None and not self._session.closed:
            await self._session.close()

    async def send_message(self, text: str, parse_mode: str = "HTML") -> bool:
        """
        Send a generic text message to Telegram.

        Args:
            text: Message text to send
            parse_mode: Parse mode for formatting (HTML or Markdown)

        Returns:
            True if the message was sent successfully, False otherwise
        """
        if not self._enabled:
            return False

        try:
            await self._rate_limiter.acquire()

            url = f"{self.API_BASE.format(token=self._bot_token)}/sendMessage"
            payload = {
                "chat_id": self._chat_id,
                "text": text,
                "parse_mode": parse_mode,
            }

            session = self._get_session()
            async with session.post(url, json=payload) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(
                        "Telegram API error (status=%d): %s", resp.status, error_text
                    )
                    return False
                logger.debug("Telegram message sent: %s", text[:50])
                return True

        except asyncio.TimeoutError:
            logger.error("Telegram message timeout")
            return False
        except aiohttp.ClientError as exc:
            logger.error("Telegram message failed: %s", exc)
            return False
        except Exception as exc:
            logger.error("Unexpected error sending message: %s", exc)
            return False

    async def _send_notification(self, msg: NotificationMessage) -> None:
        """
        Send a notification to Telegram with graceful degradation.

        Args:
            msg: Notification message to send
        """
        formatted_message = f"{msg.priority.emoji} {msg.message}"
        await self.send_message(formatted_message)

    async def notify_trade_execution(
        self,
        stock_code: str,
        market: str,
        action: str,
        quantity: int,
        price: float,
        confidence: float,
    ) -> None:
        """
        Notify trade execution.

        Args:
            stock_code: Stock ticker symbol
            market: Market name (e.g., "Korea", "United States")
            action: "BUY" or "SELL"
            quantity: Number of shares
            price: Execution price
            confidence: AI confidence level (0-100)
        """
        emoji = "🟢" if action == "BUY" else "🔴"
        message = (
            f"<b>{emoji} {action}</b>\n"
            f"Symbol: <code>{stock_code}</code> ({market})\n"
            f"Quantity: {quantity:,} shares\n"
            f"Price: {price:,.2f}\n"
            f"Confidence: {confidence:.0f}%"
        )
        await self._send_notification(
            NotificationMessage(priority=NotificationPriority.MEDIUM, message=message)
        )

    async def notify_market_open(self, market_name: str) -> None:
        """
        Notify market opening.

        Args:
            market_name: Name of the market (e.g., "Korea", "United States")
        """
        message = f"<b>Market Open</b>\n{market_name} trading session started"
        await self._send_notification(
            NotificationMessage(priority=NotificationPriority.LOW, message=message)
        )

    async def notify_market_close(self, market_name: str, pnl_pct: float) -> None:
        """
        Notify market closing.

        Args:
            market_name: Name of the market
            pnl_pct: Final P&L percentage for the session
        """
        pnl_sign = "+" if pnl_pct >= 0 else ""
        pnl_emoji = "📈" if pnl_pct >= 0 else "📉"
        message = (
            f"<b>Market Close</b>\n"
            f"{market_name} trading session ended\n"
            f"{pnl_emoji} P&L: {pnl_sign}{pnl_pct:.2f}%"
        )
        await self._send_notification(
            NotificationMessage(priority=NotificationPriority.LOW, message=message)
        )

    async def notify_circuit_breaker(
        self, pnl_pct: float, threshold: float
    ) -> None:
        """
        Notify circuit breaker activation.

        Args:
            pnl_pct: Current P&L percentage
            threshold: Circuit breaker threshold
        """
        message = (
            f"<b>CIRCUIT BREAKER TRIPPED</b>\n"
            f"P&L: {pnl_pct:.2f}% (threshold: {threshold:.1f}%)\n"
            f"Trading halted for safety"
        )
        await self._send_notification(
            NotificationMessage(priority=NotificationPriority.CRITICAL, message=message)
        )

    async def notify_fat_finger(
        self,
        stock_code: str,
        order_amount: float,
        total_cash: float,
        max_pct: float,
    ) -> None:
        """
        Notify fat-finger protection rejection.

        Args:
            stock_code: Stock ticker symbol
            order_amount: Attempted order amount
            total_cash: Total available cash
            max_pct: Maximum allowed percentage
        """
        attempted_pct = (order_amount / total_cash) * 100 if total_cash > 0 else 0
        message = (
            f"<b>Fat-Finger Protection</b>\n"
            f"Order rejected: <code>{stock_code}</code>\n"
            f"Attempted: {attempted_pct:.1f}% of cash\n"
            f"Max allowed: {max_pct:.0f}%\n"
            f"Amount: {order_amount:,.0f} / {total_cash:,.0f}"
        )
        await self._send_notification(
            NotificationMessage(priority=NotificationPriority.HIGH, message=message)
        )

    async def notify_system_start(
        self, mode: str, enabled_markets: list[str]
    ) -> None:
        """
        Notify system startup.

        Args:
            mode: Trading mode ("paper" or "live")
            enabled_markets: List of enabled market codes
        """
        mode_emoji = "📝" if mode == "paper" else "💰"
        markets_str = ", ".join(enabled_markets)
        message = (
            f"<b>{mode_emoji} System Started</b>\n"
            f"Mode: {mode.upper()}\n"
            f"Markets: {markets_str}"
        )
        await self._send_notification(
            NotificationMessage(priority=NotificationPriority.MEDIUM, message=message)
        )

    async def notify_system_shutdown(self, reason: str) -> None:
        """
        Notify system shutdown.

        Args:
            reason: Reason for shutdown (e.g., "Normal shutdown", "Circuit breaker")
        """
        message = f"<b>System Shutdown</b>\n{reason}"
        priority = (
            NotificationPriority.CRITICAL
            if "circuit breaker" in reason.lower()
            else NotificationPriority.MEDIUM
        )
        await self._send_notification(
            NotificationMessage(priority=priority, message=message)
        )

    async def notify_error(
        self, error_type: str, error_msg: str, context: str
    ) -> None:
        """
        Notify system error.

        Args:
            error_type: Type of error (e.g., "Connection Error")
            error_msg: Error message
            context: Error context (e.g., stock code, market)
        """
        message = (
            f"<b>Error: {error_type}</b>\n"
            f"Context: {context}\n"
            f"Message: {error_msg[:200]}"  # Truncate long errors
        )
        await self._send_notification(
            NotificationMessage(priority=NotificationPriority.HIGH, message=message)
        )


class TelegramCommandHandler:
    """Handles incoming Telegram commands via long polling."""

    def __init__(
        self, client: TelegramClient, polling_interval: float = 1.0
    ) -> None:
        """
        Initialize command handler.

        Args:
            client: TelegramClient instance for sending responses
            polling_interval: Polling interval in seconds
        """
        self._client = client
        self._polling_interval = polling_interval
        self._commands: dict[str, Callable[[], Awaitable[None]]] = {}
        self._last_update_id = 0
        self._polling_task: asyncio.Task[None] | None = None
        self._running = False

    def register_command(
        self, command: str, handler: Callable[[], Awaitable[None]]
    ) -> None:
        """
        Register a command handler.

        Args:
            command: Command name (without the leading slash, e.g., "start")
            handler: Async function to handle the command
        """
        self._commands[command] = handler
        logger.debug("Registered command handler: /%s", command)

    async def start_polling(self) -> None:
        """Start long polling for commands."""
        if self._running:
            logger.warning("Command handler already running")
            return

        if not self._client._enabled:
            logger.info("Command handler disabled (TelegramClient disabled)")
            return

        self._running = True
        self._polling_task = asyncio.create_task(self._poll_loop())
        logger.info("Started Telegram command polling")

    async def stop_polling(self) -> None:
        """Stop polling and cancel pending tasks."""
        if not self._running:
            return

        self._running = False
        if self._polling_task:
            self._polling_task.cancel()
            try:
                await self._polling_task
            except asyncio.CancelledError:
                pass
        logger.info("Stopped Telegram command polling")

    async def _poll_loop(self) -> None:
        """Main polling loop that fetches updates."""
        while self._running:
            try:
                updates = await self._get_updates()
                for update in updates:
                    await self._handle_update(update)
            except asyncio.CancelledError:
                break
            except Exception as exc:
                logger.error("Error in polling loop: %s", exc)

            await asyncio.sleep(self._polling_interval)

    async def _get_updates(self) -> list[dict]:
        """
        Fetch updates from the Telegram API.

        Returns:
            List of update objects
        """
        try:
            url = f"{self._client.API_BASE.format(token=self._client._bot_token)}/getUpdates"
            payload = {
                "offset": self._last_update_id + 1,
                "timeout": int(self._polling_interval),
                "allowed_updates": ["message"],
            }

            session = self._client._get_session()
            async with session.post(url, json=payload) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error(
                        "getUpdates API error (status=%d): %s", resp.status, error_text
                    )
                    return []

                data = await resp.json()
                if not data.get("ok"):
                    logger.error("getUpdates returned ok=false: %s", data)
                    return []

                updates = data.get("result", [])
                if updates:
                    self._last_update_id = updates[-1]["update_id"]

                return updates

        except asyncio.TimeoutError:
            logger.debug("getUpdates timeout (normal)")
            return []
        except aiohttp.ClientError as exc:
            logger.error("getUpdates failed: %s", exc)
            return []
        except Exception as exc:
            logger.error("Unexpected error in _get_updates: %s", exc)
            return []

    async def _handle_update(self, update: dict) -> None:
        """
        Parse and handle a single update.

        Args:
            update: Update object from the Telegram API
        """
        try:
            message = update.get("message")
            if not message:
                return

            # Verify chat_id matches the configured chat
            chat_id = str(message.get("chat", {}).get("id", ""))
            if chat_id != self._client._chat_id:
                logger.warning(
                    "Ignoring command from unauthorized chat_id: %s", chat_id
                )
                return

            # Extract command text
            text = message.get("text", "").strip()
            if not text.startswith("/"):
                return

            # Parse command (remove leading slash and extract command name)
            command_parts = text[1:].split()
            if not command_parts:
                return

            # Remove @botname suffix if present (for group chats)
            command_name = command_parts[0].split("@")[0]

            # Execute handler
            handler = self._commands.get(command_name)
            if handler:
                logger.info("Executing command: /%s", command_name)
                await handler()
            else:
                logger.debug("Unknown command: /%s", command_name)
                await self._client.send_message(
                    f"Unknown command: /{command_name}\nUse /help to see available commands."
                )

        except Exception as exc:
            logger.error("Error handling update: %s", exc)
            # Don't crash the polling loop on handler errors
0
src/strategy/__init__.py
Normal file
164
src/strategy/models.py
Normal file
@@ -0,0 +1,164 @@
"""Pydantic models for pre-market scenario planning.

Defines the data contracts for the proactive strategy system:
- AI generates DayPlaybook before market open (structured JSON scenarios)
- Local ScenarioEngine matches conditions during market hours (no API calls)
"""

from __future__ import annotations

from datetime import UTC, date, datetime
from enum import Enum

from pydantic import BaseModel, Field, field_validator


class ScenarioAction(str, Enum):
    """Actions that can be taken by scenarios."""

    BUY = "BUY"
    SELL = "SELL"
    HOLD = "HOLD"
    REDUCE_ALL = "REDUCE_ALL"


class MarketOutlook(str, Enum):
    """AI's assessment of market direction."""

    BULLISH = "bullish"
    NEUTRAL_TO_BULLISH = "neutral_to_bullish"
    NEUTRAL = "neutral"
    NEUTRAL_TO_BEARISH = "neutral_to_bearish"
    BEARISH = "bearish"


class PlaybookStatus(str, Enum):
    """Lifecycle status of a playbook."""

    PENDING = "pending"
    READY = "ready"
    FAILED = "failed"
    EXPIRED = "expired"


class StockCondition(BaseModel):
    """Condition fields for scenario matching (all optional, AND-combined).

    The ScenarioEngine evaluates all non-None fields as AND conditions.
    A condition matches only if ALL specified fields are satisfied.
    """

    rsi_below: float | None = None
    rsi_above: float | None = None
    volume_ratio_above: float | None = None
    volume_ratio_below: float | None = None
    price_above: float | None = None
    price_below: float | None = None
    price_change_pct_above: float | None = None
    price_change_pct_below: float | None = None

    def has_any_condition(self) -> bool:
        """Check if at least one condition field is set."""
        return any(
            v is not None
            for v in (
                self.rsi_below,
                self.rsi_above,
                self.volume_ratio_above,
                self.volume_ratio_below,
                self.price_above,
                self.price_below,
                self.price_change_pct_above,
                self.price_change_pct_below,
            )
        )
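
    # Illustrative example: every non-None field is one clause of an AND, so
    #     StockCondition(rsi_below=30.0, volume_ratio_above=2.0)
    # matches only a snapshot where rsi < 30.0 AND volume_ratio > 2.0;
    # fields left as None are skipped entirely.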


class StockScenario(BaseModel):
    """A single condition-action rule for one stock."""

    condition: StockCondition
    action: ScenarioAction
    confidence: int = Field(ge=0, le=100)
    allocation_pct: float = Field(ge=0, le=100, default=10.0)
    stop_loss_pct: float = Field(le=0, default=-2.0)
    take_profit_pct: float = Field(ge=0, default=3.0)
    rationale: str = ""


class StockPlaybook(BaseModel):
    """All scenarios for a single stock (ordered by priority)."""

    stock_code: str
    stock_name: str = ""
    scenarios: list[StockScenario] = Field(min_length=1)


class GlobalRule(BaseModel):
    """Portfolio-level rule (checked before stock-level scenarios)."""

    condition: str  # e.g. "portfolio_pnl_pct < -2.0"
    action: ScenarioAction
    rationale: str = ""


class CrossMarketContext(BaseModel):
    """Summary of another market's state for cross-market awareness."""

    market: str  # e.g. "US" or "KR"
    date: str
    total_pnl: float = 0.0
    win_rate: float = 0.0
    index_change_pct: float = 0.0  # e.g. KOSPI or S&P 500 change
    key_events: list[str] = Field(default_factory=list)
    lessons: list[str] = Field(default_factory=list)


class DayPlaybook(BaseModel):
    """Complete playbook for a single trading day in a single market.

    Generated by PreMarketPlanner (1 Gemini call per market per day).
    Consumed by ScenarioEngine during market hours (0 API calls).
    """

    date: date
    market: str  # "KR" or "US"
    market_outlook: MarketOutlook = MarketOutlook.NEUTRAL
    generated_at: str = ""  # ISO timestamp
    gemini_model: str = ""
    token_count: int = 0
    global_rules: list[GlobalRule] = Field(default_factory=list)
    stock_playbooks: list[StockPlaybook] = Field(default_factory=list)
    default_action: ScenarioAction = ScenarioAction.HOLD
    context_summary: dict = Field(default_factory=dict)
    cross_market: CrossMarketContext | None = None

    @field_validator("stock_playbooks")
    @classmethod
    def validate_unique_stocks(cls, v: list[StockPlaybook]) -> list[StockPlaybook]:
        codes = [pb.stock_code for pb in v]
        if len(codes) != len(set(codes)):
            raise ValueError("Duplicate stock codes in playbook")
        return v

    def get_stock_playbook(self, stock_code: str) -> StockPlaybook | None:
        """Find the playbook for a specific stock."""
        for pb in self.stock_playbooks:
            if pb.stock_code == stock_code:
                return pb
        return None

    @property
    def scenario_count(self) -> int:
        """Total number of scenarios across all stocks."""
        return sum(len(pb.scenarios) for pb in self.stock_playbooks)

    @property
    def stock_count(self) -> int:
        """Number of stocks with scenarios."""
        return len(self.stock_playbooks)

    def model_post_init(self, __context: object) -> None:
        """Set generated_at if not provided."""
        if not self.generated_at:
            self.generated_at = datetime.now(UTC).isoformat()
365
tests/test_backup.py
Normal file
@@ -0,0 +1,365 @@
"""Tests for backup and disaster recovery system."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import tempfile
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from src.backup.exporter import BackupExporter, ExportFormat
|
||||
from src.backup.health_monitor import HealthMonitor, HealthStatus
|
||||
from src.backup.scheduler import BackupPolicy, BackupScheduler
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db(tmp_path: Path) -> Path:
|
||||
"""Create a temporary test database."""
|
||||
db_path = tmp_path / "test_trades.db"
|
||||
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Create trades table
|
||||
cursor.execute("""
|
||||
CREATE TABLE trades (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
timestamp TEXT NOT NULL,
|
||||
stock_code TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
quantity INTEGER NOT NULL,
|
||||
price REAL NOT NULL,
|
||||
confidence INTEGER NOT NULL,
|
||||
rationale TEXT,
|
||||
pnl REAL DEFAULT 0.0
|
||||
)
|
||||
""")
|
||||
|
||||
# Insert test data
|
||||
test_trades = [
|
||||
("2024-01-01T10:00:00Z", "005930", "BUY", 10, 70000.0, 85, "Test buy", 0.0),
|
||||
("2024-01-01T11:00:00Z", "005930", "SELL", 10, 71000.0, 90, "Test sell", 10000.0),
|
||||
("2024-01-02T10:00:00Z", "AAPL", "BUY", 5, 180.0, 88, "Tech buy", 0.0),
|
||||
]
|
||||
|
||||
cursor.executemany(
|
||||
"""
|
||||
INSERT INTO trades (timestamp, stock_code, action, quantity, price, confidence, rationale, pnl)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
test_trades,
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
return db_path
|
||||
|
||||
|
||||
class TestBackupExporter:
|
||||
"""Test BackupExporter functionality."""
|
||||
|
||||
def test_exporter_init(self, temp_db: Path) -> None:
|
||||
"""Test exporter initialization."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
assert exporter.db_path == str(temp_db)
|
||||
|
||||
def test_export_json(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test JSON export."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
output_dir = tmp_path / "exports"
|
||||
|
||||
results = exporter.export_all(
|
||||
output_dir, formats=[ExportFormat.JSON], compress=False
|
||||
)
|
||||
|
||||
assert ExportFormat.JSON in results
|
||||
assert results[ExportFormat.JSON].exists()
|
||||
assert results[ExportFormat.JSON].suffix == ".json"
|
||||
|
||||
def test_export_json_compressed(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test compressed JSON export."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
output_dir = tmp_path / "exports"
|
||||
|
||||
results = exporter.export_all(
|
||||
output_dir, formats=[ExportFormat.JSON], compress=True
|
||||
)
|
||||
|
||||
assert ExportFormat.JSON in results
|
||||
assert results[ExportFormat.JSON].suffix == ".gz"
|
||||
|
||||
def test_export_csv(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test CSV export."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
output_dir = tmp_path / "exports"
|
||||
|
||||
results = exporter.export_all(
|
||||
output_dir, formats=[ExportFormat.CSV], compress=False
|
||||
)
|
||||
|
||||
assert ExportFormat.CSV in results
|
||||
assert results[ExportFormat.CSV].exists()
|
||||
|
||||
# Verify CSV content
|
||||
with open(results[ExportFormat.CSV], "r") as f:
|
||||
lines = f.readlines()
|
||||
assert len(lines) == 4 # Header + 3 rows
|
||||
|
||||
def test_export_all_formats(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test exporting all formats."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
output_dir = tmp_path / "exports"
|
||||
|
||||
# Skip Parquet if pyarrow not available
|
||||
try:
|
||||
import pyarrow # noqa: F401
|
||||
|
||||
formats = [ExportFormat.JSON, ExportFormat.CSV, ExportFormat.PARQUET]
|
||||
except ImportError:
|
||||
formats = [ExportFormat.JSON, ExportFormat.CSV]
|
||||
|
||||
results = exporter.export_all(output_dir, formats=formats, compress=False)
|
||||
|
||||
for fmt in formats:
|
||||
assert fmt in results
|
||||
assert results[fmt].exists()
|
||||
|
||||
def test_incremental_export(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test incremental export."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
output_dir = tmp_path / "exports"
|
||||
|
||||
# Export only trades after Jan 2
|
||||
cutoff = datetime(2024, 1, 2, tzinfo=UTC)
|
||||
results = exporter.export_all(
|
||||
output_dir,
|
||||
formats=[ExportFormat.JSON],
|
||||
compress=False,
|
||||
incremental_since=cutoff,
|
||||
)
|
||||
|
||||
# Should only have 1 trade (AAPL on Jan 2)
|
||||
import json
|
||||
|
||||
with open(results[ExportFormat.JSON], "r") as f:
|
||||
data = json.load(f)
|
||||
assert data["record_count"] == 1
|
||||
assert data["trades"][0]["stock_code"] == "AAPL"
|
||||
|
||||
def test_get_export_stats(self, temp_db: Path) -> None:
|
||||
"""Test export statistics."""
|
||||
exporter = BackupExporter(str(temp_db))
|
||||
stats = exporter.get_export_stats()
|
||||
|
||||
assert stats["total_trades"] == 3
|
||||
assert "date_range" in stats
|
||||
assert "db_size_bytes" in stats
|
||||
|
||||
|
||||
class TestBackupScheduler:
|
||||
"""Test BackupScheduler functionality."""
|
||||
|
||||
def test_scheduler_init(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test scheduler initialization."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
|
||||
assert scheduler.db_path == temp_db
|
||||
assert (backup_dir / "daily").exists()
|
||||
assert (backup_dir / "weekly").exists()
|
||||
assert (backup_dir / "monthly").exists()
|
||||
|
||||
def test_create_daily_backup(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test daily backup creation."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
|
||||
metadata = scheduler.create_backup(BackupPolicy.DAILY, verify=True)
|
||||
|
||||
assert metadata.policy == BackupPolicy.DAILY
|
||||
assert metadata.file_path.exists()
|
||||
assert metadata.size_bytes > 0
|
||||
assert metadata.checksum is not None
|
||||
|
||||
def test_create_weekly_backup(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test weekly backup creation."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
|
||||
metadata = scheduler.create_backup(BackupPolicy.WEEKLY, verify=False)
|
||||
|
||||
assert metadata.policy == BackupPolicy.WEEKLY
|
||||
assert metadata.file_path.exists()
|
||||
assert metadata.checksum is None # verify=False
|
||||
|
||||
def test_list_backups(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test listing backups."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
|
||||
scheduler.create_backup(BackupPolicy.DAILY)
|
||||
scheduler.create_backup(BackupPolicy.WEEKLY)
|
||||
|
||||
backups = scheduler.list_backups()
|
||||
assert len(backups) == 2
|
||||
|
||||
daily_backups = scheduler.list_backups(BackupPolicy.DAILY)
|
||||
assert len(daily_backups) == 1
|
||||
assert daily_backups[0].policy == BackupPolicy.DAILY
|
||||
|
||||
def test_cleanup_old_backups(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test cleanup of old backups."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir, daily_retention_days=0)
|
||||
|
||||
# Create a backup
|
||||
scheduler.create_backup(BackupPolicy.DAILY)
|
||||
|
||||
# Cleanup should remove it (0 day retention)
|
||||
removed = scheduler.cleanup_old_backups()
|
||||
assert removed[BackupPolicy.DAILY] >= 1
|
||||
|
||||
def test_backup_stats(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test backup statistics."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
|
||||
scheduler.create_backup(BackupPolicy.DAILY)
|
||||
scheduler.create_backup(BackupPolicy.MONTHLY)
|
||||
|
||||
stats = scheduler.get_backup_stats()
|
||||
|
||||
assert stats["daily"]["count"] == 1
|
||||
assert stats["monthly"]["count"] == 1
|
||||
assert stats["daily"]["total_size_bytes"] > 0
|
||||
|
||||
def test_restore_backup(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test backup restoration."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
|
||||
# Create backup
|
||||
metadata = scheduler.create_backup(BackupPolicy.DAILY)
|
||||
|
||||
# Modify database
|
||||
conn = sqlite3.connect(str(temp_db))
|
||||
conn.execute("DELETE FROM trades")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
# Restore
|
||||
scheduler.restore_backup(metadata, verify=True)
|
||||
|
||||
# Verify restoration
|
||||
conn = sqlite3.connect(str(temp_db))
|
||||
cursor = conn.execute("SELECT COUNT(*) FROM trades")
|
||||
count = cursor.fetchone()[0]
|
||||
conn.close()
|
||||
|
||||
assert count == 3 # Original 3 trades restored
|
||||
|
||||
|
||||
class TestHealthMonitor:
|
||||
"""Test HealthMonitor functionality."""
|
||||
|
||||
def test_monitor_init(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test monitor initialization."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
monitor = HealthMonitor(str(temp_db), backup_dir)
|
||||
|
||||
assert monitor.db_path == temp_db
|
||||
|
||||
def test_check_database_health_ok(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test database health check (healthy)."""
|
||||
monitor = HealthMonitor(str(temp_db), tmp_path / "backups")
|
||||
result = monitor.check_database_health()
|
||||
|
||||
assert result.status == HealthStatus.HEALTHY
|
||||
assert "healthy" in result.message.lower()
|
||||
assert result.details is not None
|
||||
assert result.details["trade_count"] == 3
|
||||
|
||||
def test_check_database_health_missing(self, tmp_path: Path) -> None:
|
||||
"""Test database health check (missing file)."""
|
||||
non_existent = tmp_path / "missing.db"
|
||||
monitor = HealthMonitor(str(non_existent), tmp_path / "backups")
|
||||
result = monitor.check_database_health()
|
||||
|
||||
assert result.status == HealthStatus.UNHEALTHY
|
||||
assert "not found" in result.message.lower()
|
||||
|
||||
def test_check_disk_space(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test disk space check."""
|
||||
monitor = HealthMonitor(str(temp_db), tmp_path, min_disk_space_gb=0.001)
|
||||
result = monitor.check_disk_space()
|
||||
|
||||
# Should be healthy with minimal requirement
|
||||
assert result.status in [HealthStatus.HEALTHY, HealthStatus.DEGRADED]
|
||||
assert result.details is not None
|
||||
assert "free_gb" in result.details
|
||||
|
||||
def test_check_backup_recency_no_backups(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test backup recency check (no backups)."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
backup_dir.mkdir()
|
||||
(backup_dir / "daily").mkdir()
|
||||
|
||||
monitor = HealthMonitor(str(temp_db), backup_dir)
|
||||
result = monitor.check_backup_recency()
|
||||
|
||||
assert result.status == HealthStatus.UNHEALTHY
|
||||
assert "no" in result.message.lower()
|
||||
|
||||
def test_check_backup_recency_recent(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test backup recency check (recent backup)."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
scheduler.create_backup(BackupPolicy.DAILY)
|
||||
|
||||
monitor = HealthMonitor(str(temp_db), backup_dir)
|
||||
result = monitor.check_backup_recency()
|
||||
|
||||
assert result.status == HealthStatus.HEALTHY
|
||||
assert "recent" in result.message.lower()
|
||||
|
||||
def test_run_all_checks(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test running all health checks."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
scheduler.create_backup(BackupPolicy.DAILY)
|
||||
|
||||
monitor = HealthMonitor(str(temp_db), backup_dir, min_disk_space_gb=0.001)
|
||||
checks = monitor.run_all_checks()
|
||||
|
||||
assert "database" in checks
|
||||
assert "disk_space" in checks
|
||||
assert "backup_recency" in checks
|
||||
assert checks["database"].status == HealthStatus.HEALTHY
|
||||
|
||||
def test_get_overall_status(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test overall health status."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
scheduler.create_backup(BackupPolicy.DAILY)
|
||||
|
||||
monitor = HealthMonitor(str(temp_db), backup_dir, min_disk_space_gb=0.001)
|
||||
status = monitor.get_overall_status()
|
||||
|
||||
assert status in [HealthStatus.HEALTHY, HealthStatus.DEGRADED]
|
||||
|
||||
def test_get_health_report(self, temp_db: Path, tmp_path: Path) -> None:
|
||||
"""Test health report generation."""
|
||||
backup_dir = tmp_path / "backups"
|
||||
scheduler = BackupScheduler(str(temp_db), backup_dir)
|
||||
scheduler.create_backup(BackupPolicy.DAILY)
|
||||
|
||||
monitor = HealthMonitor(str(temp_db), backup_dir, min_disk_space_gb=0.001)
|
||||
report = monitor.get_health_report()
|
||||
|
||||
assert "overall_status" in report
|
||||
assert "timestamp" in report
|
||||
assert "checks" in report
|
||||
assert len(report["checks"]) == 3
|
||||
@@ -126,7 +126,7 @@ class TestPromptConstruction:
|
||||
"orderbook": {"asks": [], "bids": []},
|
||||
"foreigner_net": -50000,
|
||||
}
|
||||
prompt = client.build_prompt(market_data)
|
||||
prompt = client.build_prompt_sync(market_data)
|
||||
assert "005930" in prompt
|
||||
|
||||
def test_prompt_contains_price(self, settings):
|
||||
@@ -137,7 +137,7 @@ class TestPromptConstruction:
|
||||
"orderbook": {"asks": [], "bids": []},
|
||||
"foreigner_net": -50000,
|
||||
}
|
||||
prompt = client.build_prompt(market_data)
|
||||
prompt = client.build_prompt_sync(market_data)
|
||||
assert "72000" in prompt
|
||||
|
||||
def test_prompt_enforces_json_output_format(self, settings):
|
||||
@@ -148,7 +148,125 @@ class TestPromptConstruction:
|
||||
"orderbook": {"asks": [], "bids": []},
|
||||
"foreigner_net": 0,
|
||||
}
|
||||
prompt = client.build_prompt(market_data)
|
||||
prompt = client.build_prompt_sync(market_data)
|
||||
assert "JSON" in prompt
|
||||
assert "action" in prompt
|
||||
assert "confidence" in prompt
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Batch Decision Making
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBatchDecisionParsing:
|
||||
"""Batch response parser must handle JSON arrays correctly."""
|
||||
|
||||
def test_parse_valid_batch_response(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [
|
||||
{"stock_code": "AAPL", "current_price": 185.5},
|
||||
{"stock_code": "MSFT", "current_price": 420.0},
|
||||
]
|
||||
raw = """[
|
||||
{"code": "AAPL", "action": "BUY", "confidence": 85, "rationale": "Strong momentum"},
|
||||
{"code": "MSFT", "action": "HOLD", "confidence": 50, "rationale": "Wait for earnings"}
|
||||
]"""
|
||||
|
||||
decisions = client._parse_batch_response(raw, stocks_data, token_count=100)
|
||||
|
||||
assert len(decisions) == 2
|
||||
assert decisions["AAPL"].action == "BUY"
|
||||
assert decisions["AAPL"].confidence == 85
|
||||
assert decisions["MSFT"].action == "HOLD"
|
||||
assert decisions["MSFT"].confidence == 50
|
||||
|
||||
def test_parse_batch_with_markdown_wrapper(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [{"stock_code": "AAPL", "current_price": 185.5}]
|
||||
raw = """```json
|
||||
[{"code": "AAPL", "action": "BUY", "confidence": 90, "rationale": "Good"}]
|
||||
```"""
|
||||
|
||||
decisions = client._parse_batch_response(raw, stocks_data, token_count=100)
|
||||
|
||||
assert decisions["AAPL"].action == "BUY"
|
||||
assert decisions["AAPL"].confidence == 90
|
||||
|
||||
def test_parse_batch_empty_response_returns_hold_for_all(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [
|
||||
{"stock_code": "AAPL", "current_price": 185.5},
|
||||
{"stock_code": "MSFT", "current_price": 420.0},
|
||||
]
|
||||
|
||||
decisions = client._parse_batch_response("", stocks_data, token_count=100)
|
||||
|
||||
assert len(decisions) == 2
|
||||
assert decisions["AAPL"].action == "HOLD"
|
||||
assert decisions["AAPL"].confidence == 0
|
||||
assert decisions["MSFT"].action == "HOLD"
|
||||
|
||||
def test_parse_batch_malformed_json_returns_hold_for_all(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [{"stock_code": "AAPL", "current_price": 185.5}]
|
||||
raw = "This is not JSON"
|
||||
|
||||
decisions = client._parse_batch_response(raw, stocks_data, token_count=100)
|
||||
|
||||
assert decisions["AAPL"].action == "HOLD"
|
||||
assert decisions["AAPL"].confidence == 0
|
||||
|
||||
def test_parse_batch_not_array_returns_hold_for_all(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [{"stock_code": "AAPL", "current_price": 185.5}]
|
||||
raw = '{"code": "AAPL", "action": "BUY", "confidence": 90, "rationale": "Good"}'
|
||||
|
||||
decisions = client._parse_batch_response(raw, stocks_data, token_count=100)
|
||||
|
||||
assert decisions["AAPL"].action == "HOLD"
|
||||
assert decisions["AAPL"].confidence == 0
|
||||
|
||||
def test_parse_batch_missing_stock_gets_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [
|
||||
{"stock_code": "AAPL", "current_price": 185.5},
|
||||
{"stock_code": "MSFT", "current_price": 420.0},
|
||||
]
|
||||
# Response only has AAPL, MSFT is missing
|
||||
raw = '[{"code": "AAPL", "action": "BUY", "confidence": 85, "rationale": "Good"}]'
|
||||
|
||||
decisions = client._parse_batch_response(raw, stocks_data, token_count=100)
|
||||
|
||||
assert decisions["AAPL"].action == "BUY"
|
||||
assert decisions["MSFT"].action == "HOLD"
|
||||
assert decisions["MSFT"].confidence == 0
|
||||
|
||||
def test_parse_batch_invalid_action_becomes_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [{"stock_code": "AAPL", "current_price": 185.5}]
|
||||
raw = '[{"code": "AAPL", "action": "YOLO", "confidence": 90, "rationale": "Moon"}]'
|
||||
|
||||
decisions = client._parse_batch_response(raw, stocks_data, token_count=100)
|
||||
|
||||
assert decisions["AAPL"].action == "HOLD"
|
||||
|
||||
def test_parse_batch_low_confidence_becomes_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [{"stock_code": "AAPL", "current_price": 185.5}]
|
||||
raw = '[{"code": "AAPL", "action": "BUY", "confidence": 65, "rationale": "Weak"}]'
|
||||
|
||||
decisions = client._parse_batch_response(raw, stocks_data, token_count=100)
|
||||
|
||||
assert decisions["AAPL"].action == "HOLD"
|
||||
assert decisions["AAPL"].confidence == 65
|
||||
|
||||
def test_parse_batch_missing_fields_gets_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
stocks_data = [{"stock_code": "AAPL", "current_price": 185.5}]
|
||||
raw = '[{"code": "AAPL", "action": "BUY"}]' # Missing confidence and rationale
|
||||
|
||||
decisions = client._parse_batch_response(raw, stocks_data, token_count=100)
|
||||
|
||||
assert decisions["AAPL"].action == "HOLD"
|
||||
assert decisions["AAPL"].confidence == 0
|
||||
|
||||
@@ -49,6 +49,110 @@ class TestTokenManagement:
        await broker.close()

    @pytest.mark.asyncio
    async def test_concurrent_token_refresh_calls_api_once(self, settings):
        """Multiple concurrent token requests should only call API once."""
        broker = KISBroker(settings)

        # Track how many times the mock API is called
        call_count = [0]

        def create_mock_resp():
            call_count[0] += 1
            mock_resp = AsyncMock()
            mock_resp.status = 200
            mock_resp.json = AsyncMock(
                return_value={
                    "access_token": "tok_concurrent",
                    "token_type": "Bearer",
                    "expires_in": 86400,
                }
            )
            mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
            mock_resp.__aexit__ = AsyncMock(return_value=False)
            return mock_resp

        # Use side_effect (not return_value) so the factory runs once per POST;
        # with return_value the factory would run eagerly a single time and the
        # call counter could never detect duplicate refreshes.
        with patch(
            "aiohttp.ClientSession.post",
            side_effect=lambda *args, **kwargs: create_mock_resp(),
        ):
            # Launch 5 concurrent token requests
            tokens = await asyncio.gather(
                broker._ensure_token(),
                broker._ensure_token(),
                broker._ensure_token(),
                broker._ensure_token(),
                broker._ensure_token(),
            )

            # All should get the same token
            assert all(t == "tok_concurrent" for t in tokens)
            # API should be called only once (due to lock)
            assert call_count[0] == 1

        await broker.close()

    @pytest.mark.asyncio
    async def test_token_refresh_cooldown_prevents_rapid_retries(self, settings):
        """Token refresh should enforce cooldown after failure (issue #54)."""
        broker = KISBroker(settings)
        broker._refresh_cooldown = 2.0  # Short cooldown for testing

        # First refresh attempt fails with 403 (EGW00133)
        # KIS error text translates to: "Access token issuance: please try
        # again shortly (once per minute)"
        mock_resp_403 = AsyncMock()
        mock_resp_403.status = 403
        mock_resp_403.text = AsyncMock(
            return_value='{"error_code":"EGW00133","error_description":"접근토큰 발급 잠시 후 다시 시도하세요(1분당 1회)"}'
        )
        mock_resp_403.__aenter__ = AsyncMock(return_value=mock_resp_403)
        mock_resp_403.__aexit__ = AsyncMock(return_value=False)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp_403):
            # First attempt should fail with 403
            with pytest.raises(ConnectionError, match="Token refresh failed"):
                await broker._ensure_token()

            # Second attempt within cooldown should fail with cooldown error
            with pytest.raises(ConnectionError, match="Token refresh on cooldown"):
                await broker._ensure_token()

        await broker.close()

    @pytest.mark.asyncio
    async def test_token_refresh_allowed_after_cooldown(self, settings):
        """Token refresh should be allowed after cooldown period expires."""
        broker = KISBroker(settings)
        broker._refresh_cooldown = 0.1  # Very short cooldown for testing

        # First attempt fails
        mock_resp_403 = AsyncMock()
        mock_resp_403.status = 403
        mock_resp_403.text = AsyncMock(return_value='{"error_code":"EGW00133"}')
        mock_resp_403.__aenter__ = AsyncMock(return_value=mock_resp_403)
        mock_resp_403.__aexit__ = AsyncMock(return_value=False)

        # Second attempt succeeds
        mock_resp_200 = AsyncMock()
        mock_resp_200.status = 200
        mock_resp_200.json = AsyncMock(
            return_value={
                "access_token": "tok_after_cooldown",
                "expires_in": 86400,
            }
        )
        mock_resp_200.__aenter__ = AsyncMock(return_value=mock_resp_200)
        mock_resp_200.__aexit__ = AsyncMock(return_value=False)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp_403):
            with pytest.raises(ConnectionError, match="Token refresh failed"):
                await broker._ensure_token()

        # Wait for cooldown to expire
        await asyncio.sleep(0.15)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp_200):
            token = await broker._ensure_token()
            assert token == "tok_after_cooldown"

        await broker.close()


# ---------------------------------------------------------------------------
# Network Error Handling
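
These tests pin down two behaviors of `KISBroker._ensure_token`: concurrent callers share a single refresh (serialized by a lock), and a failed refresh opens a cooldown window before the next attempt. A minimal sketch of that pattern, assuming hypothetical attribute names and a placeholder `_request_token()` (the real implementation lives in `src/broker/kis_api.py` and may differ):

```python
import asyncio
import time


class TokenManagerSketch:
    """Illustrative only: single-flight token refresh with a failure cooldown."""

    def __init__(self, refresh_cooldown: float = 60.0) -> None:
        self._access_token: str | None = None
        self._token_expires_at: float = 0.0
        self._refresh_cooldown = refresh_cooldown
        self._last_failed_refresh: float = 0.0
        self._lock = asyncio.Lock()

    async def _ensure_token(self) -> str:
        async with self._lock:  # concurrent callers queue here, so one refresh runs
            if self._access_token and time.monotonic() < self._token_expires_at:
                return self._access_token  # later waiters reuse the fresh token

            if time.monotonic() - self._last_failed_refresh < self._refresh_cooldown:
                raise ConnectionError("Token refresh on cooldown")

            try:
                payload = await self._request_token()  # hypothetical HTTP call
            except Exception as exc:
                self._last_failed_refresh = time.monotonic()
                raise ConnectionError("Token refresh failed") from exc

            self._access_token = payload["access_token"]
            self._token_expires_at = time.monotonic() + payload["expires_in"]
            return self._access_token

    async def _request_token(self) -> dict:
        raise NotImplementedError  # stands in for the aiohttp POST in the real broker
```

Because the expiry check is repeated inside the lock, the five gathered callers in the test above resolve to one API call and five identical tokens.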
@@ -107,6 +211,38 @@ class TestRateLimiter:
        await broker._rate_limiter.acquire()
        await broker.close()

    @pytest.mark.asyncio
    async def test_send_order_acquires_rate_limiter_twice(self, settings):
        """send_order must acquire rate limiter for both hash key and order call."""
        broker = KISBroker(settings)
        broker._access_token = "tok"
        broker._token_expires_at = asyncio.get_event_loop().time() + 3600

        # Mock hash key response
        mock_hash_resp = AsyncMock()
        mock_hash_resp.status = 200
        mock_hash_resp.json = AsyncMock(return_value={"HASH": "abc123"})
        mock_hash_resp.__aenter__ = AsyncMock(return_value=mock_hash_resp)
        mock_hash_resp.__aexit__ = AsyncMock(return_value=False)

        # Mock order response
        mock_order_resp = AsyncMock()
        mock_order_resp.status = 200
        mock_order_resp.json = AsyncMock(return_value={"rt_cd": "0"})
        mock_order_resp.__aenter__ = AsyncMock(return_value=mock_order_resp)
        mock_order_resp.__aexit__ = AsyncMock(return_value=False)

        with patch(
            "aiohttp.ClientSession.post", side_effect=[mock_hash_resp, mock_order_resp]
        ):
            with patch.object(
                broker._rate_limiter, "acquire", new_callable=AsyncMock
            ) as mock_acquire:
                await broker.send_order("005930", "BUY", 1, 50000)
                assert mock_acquire.call_count == 2

        await broker.close()


# ---------------------------------------------------------------------------
# Hash Key Generation
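
The "acquires twice" assertion exists because a single order is really two KIS API calls: the hash-key request and the order itself, and each must consume a rate-limit slot. A minimal interval-based limiter sketch, assuming a simple requests-per-second model (the real `RateLimiter` used by the broker may be implemented differently):

```python
import asyncio
import time


class IntervalRateLimiterSketch:
    """Illustrative only: space calls at least 1/rps seconds apart."""

    def __init__(self, rps: float = 5.0) -> None:
        self._min_interval = 1.0 / rps
        self._last_call = 0.0
        self._lock = asyncio.Lock()

    async def acquire(self) -> None:
        async with self._lock:
            wait = self._min_interval - (time.monotonic() - self._last_call)
            if wait > 0:
                await asyncio.sleep(wait)  # throttle to avoid EGW00201 bursts
            self._last_call = time.monotonic()
```

Routing every outbound call, including the hash-key request, through `acquire()` is exactly what the two tests above verify.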
@@ -136,3 +272,27 @@ class TestHashKey:
        assert len(hash_key) > 0

        await broker.close()

    @pytest.mark.asyncio
    async def test_hash_key_acquires_rate_limiter(self, settings):
        """_get_hash_key must go through the rate limiter to prevent burst."""
        broker = KISBroker(settings)
        broker._access_token = "tok"
        broker._token_expires_at = asyncio.get_event_loop().time() + 3600

        body = {"CANO": "12345678", "ACNT_PRDT_CD": "01"}

        mock_resp = AsyncMock()
        mock_resp.status = 200
        mock_resp.json = AsyncMock(return_value={"HASH": "abc123hash"})
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp):
            with patch.object(
                broker._rate_limiter, "acquire", new_callable=AsyncMock
            ) as mock_acquire:
                await broker._get_hash_key(body)
                mock_acquire.assert_called_once()

        await broker.close()
673  tests/test_data_integration.py  Normal file
@@ -0,0 +1,673 @@
"""Tests for external data integration (news, economic calendar, market data)."""

from __future__ import annotations

import time
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from src.brain.gemini_client import GeminiClient
from src.data.economic_calendar import EconomicCalendar, EconomicEvent
from src.data.market_data import MarketBreadth, MarketData, MarketSentiment
from src.data.news_api import NewsAPI, NewsArticle, NewsSentiment

# ---------------------------------------------------------------------------
# NewsAPI Tests
# ---------------------------------------------------------------------------

class TestNewsAPI:
    """Test news API integration with caching."""

    def test_news_api_init_without_key(self):
        """NewsAPI should initialize without API key for testing."""
        api = NewsAPI(api_key=None)
        assert api._api_key is None
        assert api._provider == "alphavantage"
        assert api._cache_ttl == 300

    def test_news_api_init_with_custom_settings(self):
        """NewsAPI should accept custom provider and cache TTL."""
        api = NewsAPI(api_key="test_key", provider="newsapi", cache_ttl=600)
        assert api._api_key == "test_key"
        assert api._provider == "newsapi"
        assert api._cache_ttl == 600

    @pytest.mark.asyncio
    async def test_get_news_sentiment_without_api_key_returns_none(self):
        """Without API key, get_news_sentiment should return None."""
        api = NewsAPI(api_key=None)
        result = await api.get_news_sentiment("AAPL")
        assert result is None

    @pytest.mark.asyncio
    async def test_cache_hit_returns_cached_sentiment(self):
        """Cache hit should return cached sentiment without API call."""
        api = NewsAPI(api_key="test_key")

        # Manually populate cache
        cached_sentiment = NewsSentiment(
            stock_code="AAPL",
            articles=[],
            avg_sentiment=0.5,
            article_count=0,
            fetched_at=time.time(),
        )
        api._cache["AAPL"] = cached_sentiment

        result = await api.get_news_sentiment("AAPL")
        assert result is cached_sentiment
        assert result.stock_code == "AAPL"

    @pytest.mark.asyncio
    async def test_cache_expiry_triggers_refetch(self):
        """Expired cache entry should trigger refetch."""
        api = NewsAPI(api_key="test_key", cache_ttl=1)

        # Add expired cache entry
        expired_sentiment = NewsSentiment(
            stock_code="AAPL",
            articles=[],
            avg_sentiment=0.5,
            article_count=0,
            fetched_at=time.time() - 10,  # 10 seconds ago
        )
        api._cache["AAPL"] = expired_sentiment

        # Mock the fetch to avoid real API call
        with patch.object(api, "_fetch_news", new_callable=AsyncMock) as mock_fetch:
            mock_fetch.return_value = None
            result = await api.get_news_sentiment("AAPL")

        # Should have attempted refetch since cache expired
        mock_fetch.assert_called_once_with("AAPL")

    def test_clear_cache(self):
        """clear_cache should empty the cache."""
        api = NewsAPI(api_key="test_key")
        api._cache["AAPL"] = NewsSentiment(
            stock_code="AAPL",
            articles=[],
            avg_sentiment=0.0,
            article_count=0,
            fetched_at=time.time(),
        )
        assert len(api._cache) == 1

        api.clear_cache()
        assert len(api._cache) == 0

    def test_parse_alphavantage_response_with_valid_data(self):
        """Should parse Alpha Vantage response correctly."""
        api = NewsAPI(api_key="test_key", provider="alphavantage")

        mock_response = {
            "feed": [
                {
                    "title": "Apple hits new high",
                    "summary": "Apple stock surges to record levels",
                    "source": "Reuters",
                    "time_published": "2026-02-04T10:00:00",
                    "url": "https://example.com/1",
                    "ticker_sentiment": [
                        {"ticker": "AAPL", "ticker_sentiment_score": "0.85"}
                    ],
                    "overall_sentiment_score": "0.75",
                },
                {
                    "title": "Market volatility rises",
                    "summary": "Tech stocks face headwinds",
                    "source": "Bloomberg",
                    "time_published": "2026-02-04T09:00:00",
                    "url": "https://example.com/2",
                    "ticker_sentiment": [
                        {"ticker": "AAPL", "ticker_sentiment_score": "-0.3"}
                    ],
                    "overall_sentiment_score": "-0.2",
                },
            ]
        }

        result = api._parse_alphavantage_response("AAPL", mock_response)

        assert result is not None
        assert result.stock_code == "AAPL"
        assert result.article_count == 2
        assert len(result.articles) == 2
        assert result.articles[0].title == "Apple hits new high"
        assert result.articles[0].sentiment_score == 0.85
        assert result.articles[1].sentiment_score == -0.3
        # Average: (0.85 - 0.3) / 2 = 0.275
        assert abs(result.avg_sentiment - 0.275) < 0.01

    def test_parse_alphavantage_response_without_feed_returns_none(self):
        """Should return None if 'feed' key is missing."""
        api = NewsAPI(api_key="test_key", provider="alphavantage")
        result = api._parse_alphavantage_response("AAPL", {})
        assert result is None

    def test_parse_newsapi_response_with_valid_data(self):
        """Should parse NewsAPI.org response correctly."""
        api = NewsAPI(api_key="test_key", provider="newsapi")

        mock_response = {
            "status": "ok",
            "articles": [
                {
                    "title": "Apple stock surges",
                    "description": "Strong earnings beat expectations",
                    "source": {"name": "TechCrunch"},
                    "publishedAt": "2026-02-04T10:00:00Z",
                    "url": "https://example.com/1",
                },
                {
                    "title": "Tech sector faces risks",
                    "description": "Concerns over market downturn",
                    "source": {"name": "CNBC"},
                    "publishedAt": "2026-02-04T09:00:00Z",
                    "url": "https://example.com/2",
                },
            ],
        }

        result = api._parse_newsapi_response("AAPL", mock_response)

        assert result is not None
        assert result.stock_code == "AAPL"
        assert result.article_count == 2
        assert len(result.articles) == 2
        assert result.articles[0].title == "Apple stock surges"
        assert result.articles[0].source == "TechCrunch"

    def test_estimate_sentiment_from_text_positive(self):
        """Should detect positive sentiment from keywords."""
        api = NewsAPI()
        text = "Stock price surges with strong profit growth and upgrade"
        sentiment = api._estimate_sentiment_from_text(text)
        assert sentiment > 0.5

    def test_estimate_sentiment_from_text_negative(self):
        """Should detect negative sentiment from keywords."""
        api = NewsAPI()
        text = "Stock plunges on weak earnings, downgrade warning"
        sentiment = api._estimate_sentiment_from_text(text)
        assert sentiment < -0.5

    def test_estimate_sentiment_from_text_neutral(self):
        """Should return neutral sentiment without keywords."""
        api = NewsAPI()
        text = "Company announces quarterly report"
        sentiment = api._estimate_sentiment_from_text(text)
        assert abs(sentiment) < 0.1

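The cache tests above imply a simple TTL scheme keyed by stock code: a hit younger than `cache_ttl` seconds returns the stored `NewsSentiment`, anything older falls through to the fetch. A sketch of that lookup, assuming the attribute names the tests poke at (`_api_key`, `_cache`, `_cache_ttl`) and the `_fetch_news` helper they patch:

```python
import time


async def get_news_sentiment_sketch(api, stock_code: str):
    """Illustrative only: TTL cache in front of the news fetch."""
    if api._api_key is None:
        return None  # no key configured: skip news entirely

    cached = api._cache.get(stock_code)
    if cached is not None and time.time() - cached.fetched_at < api._cache_ttl:
        return cached  # fresh enough: no API call

    sentiment = await api._fetch_news(stock_code)  # may return None on failure
    if sentiment is not None:
        api._cache[stock_code] = sentiment
    return sentiment
```
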
# ---------------------------------------------------------------------------
# EconomicCalendar Tests
# ---------------------------------------------------------------------------


class TestEconomicCalendar:
    """Test economic calendar functionality."""

    def test_economic_calendar_init(self):
        """EconomicCalendar should initialize correctly."""
        calendar = EconomicCalendar(api_key="test_key")
        assert calendar._api_key == "test_key"
        assert len(calendar._events) == 0

    def test_add_event(self):
        """Should be able to add events to calendar."""
        calendar = EconomicCalendar()
        event = EconomicEvent(
            name="FOMC Meeting",
            event_type="FOMC",
            datetime=datetime(2026, 3, 18),
            impact="HIGH",
            country="US",
            description="Interest rate decision",
        )
        calendar.add_event(event)
        assert len(calendar._events) == 1
        assert calendar._events[0].name == "FOMC Meeting"

    def test_get_upcoming_events_filters_by_timeframe(self):
        """Should only return events within specified timeframe."""
        calendar = EconomicCalendar()

        # Add events at different times
        now = datetime.now()
        calendar.add_event(
            EconomicEvent(
                name="Event Tomorrow",
                event_type="GDP",
                datetime=now + timedelta(days=1),
                impact="HIGH",
                country="US",
                description="Test event",
            )
        )
        calendar.add_event(
            EconomicEvent(
                name="Event Next Month",
                event_type="CPI",
                datetime=now + timedelta(days=30),
                impact="HIGH",
                country="US",
                description="Test event",
            )
        )

        # Get events for next 7 days
        upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="HIGH")
        assert upcoming.high_impact_count == 1
        assert upcoming.events[0].name == "Event Tomorrow"

    def test_get_upcoming_events_filters_by_impact(self):
        """Should filter events by minimum impact level."""
        calendar = EconomicCalendar()

        now = datetime.now()
        calendar.add_event(
            EconomicEvent(
                name="High Impact Event",
                event_type="FOMC",
                datetime=now + timedelta(days=1),
                impact="HIGH",
                country="US",
                description="Test",
            )
        )
        calendar.add_event(
            EconomicEvent(
                name="Low Impact Event",
                event_type="OTHER",
                datetime=now + timedelta(days=1),
                impact="LOW",
                country="US",
                description="Test",
            )
        )

        # Filter for HIGH impact only
        upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="HIGH")
        assert upcoming.high_impact_count == 1
        assert upcoming.events[0].name == "High Impact Event"

        # Filter for MEDIUM and above (should still get HIGH)
        upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="MEDIUM")
        assert len(upcoming.events) == 1

        # Filter for LOW and above (should get both)
        upcoming = calendar.get_upcoming_events(days_ahead=7, min_impact="LOW")
        assert len(upcoming.events) == 2

    def test_get_earnings_date_returns_next_earnings(self):
        """Should return next earnings date for a stock."""
        calendar = EconomicCalendar()

        now = datetime.now()
        earnings_date = now + timedelta(days=5)

        calendar.add_event(
            EconomicEvent(
                name="AAPL Earnings",
                event_type="EARNINGS",
                datetime=earnings_date,
                impact="HIGH",
                country="US",
                description="Apple quarterly earnings",
            )
        )

        result = calendar.get_earnings_date("AAPL")
        assert result == earnings_date

    def test_get_earnings_date_returns_none_if_not_found(self):
        """Should return None if no earnings found for stock."""
        calendar = EconomicCalendar()
        result = calendar.get_earnings_date("UNKNOWN")
        assert result is None

    def test_load_hardcoded_events(self):
        """Should load hardcoded major economic events."""
        calendar = EconomicCalendar()
        calendar.load_hardcoded_events()

        # Should have multiple events (FOMC, GDP, CPI)
        assert len(calendar._events) > 10

        # Check for FOMC events
        fomc_events = [e for e in calendar._events if e.event_type == "FOMC"]
        assert len(fomc_events) > 0

        # Check for GDP events
        gdp_events = [e for e in calendar._events if e.event_type == "GDP"]
        assert len(gdp_events) > 0

        # Check for CPI events
        cpi_events = [e for e in calendar._events if e.event_type == "CPI"]
        assert len(cpi_events) == 12  # Monthly CPI releases

    def test_is_high_volatility_period_returns_true_near_high_impact(self):
        """Should return True if high-impact event is within threshold."""
        calendar = EconomicCalendar()

        now = datetime.now()
        calendar.add_event(
            EconomicEvent(
                name="FOMC Meeting",
                event_type="FOMC",
                datetime=now + timedelta(hours=12),
                impact="HIGH",
                country="US",
                description="Test",
            )
        )

        assert calendar.is_high_volatility_period(hours_ahead=24) is True

    def test_is_high_volatility_period_returns_false_when_no_events(self):
        """Should return False if no high-impact events nearby."""
        calendar = EconomicCalendar()
        assert calendar.is_high_volatility_period(hours_ahead=24) is False

    def test_clear_events(self):
        """Should clear all events."""
        calendar = EconomicCalendar()
        calendar.add_event(
            EconomicEvent(
                name="Test",
                event_type="TEST",
                datetime=datetime.now(),
                impact="LOW",
                country="US",
                description="Test",
            )
        )
        assert len(calendar._events) == 1

        calendar.clear_events()
        assert len(calendar._events) == 0

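The filtering tests above pin down an ordering over impact levels (LOW < MEDIUM < HIGH) plus a `days_ahead` time window. A sketch of the selection logic, assuming the `_events` list and a result shape with the fields the tests read (`events`, `high_impact_count`); the real `get_upcoming_events` may differ:

```python
from datetime import datetime, timedelta

_IMPACT_RANK = {"LOW": 0, "MEDIUM": 1, "HIGH": 2}


def get_upcoming_events_sketch(events, days_ahead: int = 7, min_impact: str = "LOW"):
    """Illustrative only: window + minimum-impact filter over calendar events."""
    now = datetime.now()
    horizon = now + timedelta(days=days_ahead)
    selected = [
        e
        for e in events
        if now <= e.datetime <= horizon
        and _IMPACT_RANK[e.impact] >= _IMPACT_RANK[min_impact]
    ]
    high_impact_count = sum(1 for e in selected if e.impact == "HIGH")
    return selected, high_impact_count
```
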
# ---------------------------------------------------------------------------
# MarketData Tests
# ---------------------------------------------------------------------------


class TestMarketData:
    """Test market data indicators."""

    def test_market_data_init(self):
        """MarketData should initialize correctly."""
        data = MarketData(api_key="test_key")
        assert data._api_key == "test_key"

    def test_get_market_sentiment_without_api_key_returns_neutral(self):
        """Without API key, should return NEUTRAL sentiment."""
        data = MarketData(api_key=None)
        sentiment = data.get_market_sentiment()
        assert sentiment == MarketSentiment.NEUTRAL

    def test_get_market_breadth_without_api_key_returns_none(self):
        """Without API key, should return None for breadth."""
        data = MarketData(api_key=None)
        breadth = data.get_market_breadth()
        assert breadth is None

    def test_get_sector_performance_without_api_key_returns_empty(self):
        """Without API key, should return empty list."""
        data = MarketData(api_key=None)
        sectors = data.get_sector_performance()
        assert sectors == []

    def test_get_market_indicators_returns_defaults_without_api(self):
        """Should return default indicators without API key."""
        data = MarketData(api_key=None)
        indicators = data.get_market_indicators()

        assert indicators.sentiment == MarketSentiment.NEUTRAL
        assert indicators.breadth.advance_decline_ratio == 1.0
        assert indicators.sector_performance == []
        assert indicators.vix_level is None

    def test_calculate_fear_greed_score_neutral_baseline(self):
        """Should return neutral score (50) for balanced market."""
        data = MarketData()
        breadth = MarketBreadth(
            advancing_stocks=500,
            declining_stocks=500,
            unchanged_stocks=100,
            new_highs=50,
            new_lows=50,
            advance_decline_ratio=1.0,
        )

        score = data.calculate_fear_greed_score(breadth)
        assert score == 50

    def test_calculate_fear_greed_score_greedy_market(self):
        """Should return high score for greedy market conditions."""
        data = MarketData()
        breadth = MarketBreadth(
            advancing_stocks=800,
            declining_stocks=200,
            unchanged_stocks=100,
            new_highs=100,
            new_lows=10,
            advance_decline_ratio=4.0,
        )

        score = data.calculate_fear_greed_score(breadth, vix=12.0)
        assert score > 70

    def test_calculate_fear_greed_score_fearful_market(self):
        """Should return low score for fearful market conditions."""
        data = MarketData()
        breadth = MarketBreadth(
            advancing_stocks=200,
            declining_stocks=800,
            unchanged_stocks=100,
            new_highs=10,
            new_lows=100,
            advance_decline_ratio=0.25,
        )

        score = data.calculate_fear_greed_score(breadth, vix=35.0)
        assert score < 30

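The three score tests bracket a heuristic: a balanced market (A/D ratio 1.0, equal highs and lows, no VIX input) lands exactly on 50, broad advances with a low VIX push above 70, and broad declines with a high VIX drop below 30. One scoring scheme consistent with those brackets, as a sketch only (the actual weights in `src/data/market_data.py` may differ):

```python
def fear_greed_score_sketch(breadth, vix: float | None = None) -> int:
    """Illustrative only: 0 = extreme fear, 50 = neutral, 100 = extreme greed."""
    score = 50.0

    # Breadth: an advance/decline ratio of 1.0 is neutral.
    score += min(max((breadth.advance_decline_ratio - 1.0) * 10.0, -20.0), 20.0)

    # New highs vs. new lows.
    total = breadth.new_highs + breadth.new_lows
    if total > 0:
        score += ((breadth.new_highs - breadth.new_lows) / total) * 15.0

    # VIX: around 20 is neutral; lower implies greed, higher implies fear.
    if vix is not None:
        score += min(max((20.0 - vix) * 1.5, -20.0), 20.0)

    return int(min(max(score, 0.0), 100.0))
```

Checking it against the tests: the neutral fixture yields exactly 50, the greedy fixture about 94, and the fearful fixture about 10.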
# ---------------------------------------------------------------------------
# GeminiClient Integration Tests
# ---------------------------------------------------------------------------


class TestGeminiClientWithExternalData:
    """Test GeminiClient integration with external data sources."""

    def test_gemini_client_accepts_optional_data_sources(self, settings):
        """GeminiClient should accept optional external data sources."""
        news_api = NewsAPI(api_key="test_key")
        calendar = EconomicCalendar()
        market_data = MarketData()

        client = GeminiClient(
            settings,
            news_api=news_api,
            economic_calendar=calendar,
            market_data=market_data,
        )

        assert client._news_api is news_api
        assert client._economic_calendar is calendar
        assert client._market_data is market_data

    def test_gemini_client_works_without_external_data(self, settings):
        """GeminiClient should work without external data sources."""
        client = GeminiClient(settings)
        assert client._news_api is None
        assert client._economic_calendar is None
        assert client._market_data is None

    @pytest.mark.asyncio
    async def test_build_prompt_includes_news_sentiment(self, settings):
        """build_prompt should include news sentiment when available."""
        client = GeminiClient(settings)

        market_data = {
            "stock_code": "AAPL",
            "current_price": 180.0,
            "market_name": "US stock market",
        }

        sentiment = NewsSentiment(
            stock_code="AAPL",
            articles=[
                NewsArticle(
                    title="Apple hits record high",
                    summary="Strong earnings",
                    source="Reuters",
                    published_at="2026-02-04",
                    sentiment_score=0.85,
                    url="https://example.com",
                )
            ],
            avg_sentiment=0.85,
            article_count=1,
            fetched_at=time.time(),
        )

        prompt = await client.build_prompt(market_data, news_sentiment=sentiment)

        assert "AAPL" in prompt
        assert "180.0" in prompt
        assert "EXTERNAL DATA" in prompt
        assert "News Sentiment" in prompt
        assert "0.85" in prompt
        assert "Apple hits record high" in prompt

    @pytest.mark.asyncio
    async def test_build_prompt_with_economic_events(self, settings):
        """build_prompt should include upcoming economic events."""
        calendar = EconomicCalendar()
        now = datetime.now()
        calendar.add_event(
            EconomicEvent(
                name="FOMC Meeting",
                event_type="FOMC",
                datetime=now + timedelta(days=2),
                impact="HIGH",
                country="US",
                description="Interest rate decision",
            )
        )

        client = GeminiClient(settings, economic_calendar=calendar)

        market_data = {
            "stock_code": "AAPL",
            "current_price": 180.0,
            "market_name": "US stock market",
        }

        prompt = await client.build_prompt(market_data)

        assert "EXTERNAL DATA" in prompt
        assert "High-Impact Events" in prompt
        assert "FOMC Meeting" in prompt

    @pytest.mark.asyncio
    async def test_build_prompt_with_market_indicators(self, settings):
        """build_prompt should include market sentiment indicators."""
        market_data_provider = MarketData(api_key="test_key")

        # Mock the get_market_indicators to return test data
        with patch.object(market_data_provider, "get_market_indicators") as mock:
            mock.return_value = MagicMock(
                sentiment=MarketSentiment.EXTREME_GREED,
                breadth=MagicMock(advance_decline_ratio=2.5),
            )

            client = GeminiClient(settings, market_data=market_data_provider)

            market_data = {
                "stock_code": "AAPL",
                "current_price": 180.0,
                "market_name": "US stock market",
            }

            prompt = await client.build_prompt(market_data)

            assert "EXTERNAL DATA" in prompt
            assert "Market Sentiment" in prompt
            assert "EXTREME_GREED" in prompt

    @pytest.mark.asyncio
    async def test_build_prompt_graceful_when_no_external_data(self, settings):
        """build_prompt should work gracefully without external data."""
        client = GeminiClient(settings)

        market_data = {
            "stock_code": "AAPL",
            "current_price": 180.0,
            "market_name": "US stock market",
        }

        prompt = await client.build_prompt(market_data)

        assert "AAPL" in prompt
        assert "180.0" in prompt
        # Should NOT have external data section
        assert "EXTERNAL DATA" not in prompt

    def test_build_prompt_sync_backward_compatibility(self, settings):
        """build_prompt_sync should maintain backward compatibility."""
        client = GeminiClient(settings)

        market_data = {
            "stock_code": "005930",
            "current_price": 72000,
            "orderbook": {"asks": [], "bids": []},
            "foreigner_net": -50000,
        }

        prompt = client.build_prompt_sync(market_data)

        assert "005930" in prompt
        assert "72000" in prompt
        assert "JSON" in prompt
        # Sync version should NOT have external data
        assert "EXTERNAL DATA" not in prompt

    @pytest.mark.asyncio
    async def test_decide_with_news_sentiment_parameter(self, settings):
        """decide should accept optional news_sentiment parameter."""
        client = GeminiClient(settings)

        market_data = {
            "stock_code": "AAPL",
            "current_price": 180.0,
            "market_name": "US stock market",
        }

        sentiment = NewsSentiment(
            stock_code="AAPL",
            articles=[],
            avg_sentiment=0.5,
            article_count=1,
            fetched_at=time.time(),
        )

        # Mock the Gemini API call
        with patch.object(
            client._client.aio.models, "generate_content", new_callable=AsyncMock
        ) as mock_gen:
            mock_response = MagicMock()
            mock_response.text = (
                '{"action": "BUY", "confidence": 85, "rationale": "Good news"}'
            )
            mock_gen.return_value = mock_response

            decision = await client.decide(market_data, news_sentiment=sentiment)

            assert decision.action == "BUY"
            assert decision.confidence == 85
            mock_gen.assert_called_once()
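
Taken together, these prompt tests specify an append-only contract: the base prompt always carries the stock code and price, and an `EXTERNAL DATA` block is appended only when at least one optional source yields something. A sketch of that assembly, with hypothetical section labels chosen to match the assertions (the real `build_prompt` in `src/brain/gemini_client.py` may format this differently):

```python
def append_external_data_sketch(
    base_prompt: str,
    news_sentiment=None,
    upcoming_events=None,
    market_indicators=None,
) -> str:
    """Illustrative only: append an EXTERNAL DATA block when any source has data."""
    sections: list[str] = []

    if news_sentiment is not None:
        lines = [f"News Sentiment: {news_sentiment.avg_sentiment:.2f}"]
        lines += [f"- {a.title}" for a in news_sentiment.articles]
        sections.append("\n".join(lines))

    if upcoming_events:
        lines = ["High-Impact Events:"]
        lines += [f"- {e.name} ({e.impact})" for e in upcoming_events]
        sections.append("\n".join(lines))

    if market_indicators is not None:
        sections.append(f"Market Sentiment: {market_indicators.sentiment.name}")

    if not sections:
        return base_prompt  # no EXTERNAL DATA section at all, as the tests require

    return base_prompt + "\n\n=== EXTERNAL DATA ===\n" + "\n\n".join(sections)
```
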
@@ -11,15 +11,15 @@ from __future__ import annotations
 import json
 import sqlite3
 import tempfile
-from datetime import UTC, datetime, timedelta
+from datetime import UTC, datetime
 from pathlib import Path
-from unittest.mock import AsyncMock, MagicMock, Mock, patch
+from unittest.mock import AsyncMock, Mock, patch
 
 import pytest
 
 from src.config import Settings
 from src.db import init_db, log_trade
-from src.evolution.ab_test import ABTester, ABTestResult, StrategyPerformance
+from src.evolution.ab_test import ABTester
 from src.evolution.optimizer import EvolutionOptimizer
 from src.evolution.performance_tracker import (
     PerformanceDashboard,
@@ -28,7 +28,6 @@ from src.evolution.performance_tracker import (
 )
 from src.logging.decision_logger import DecisionLogger
 
-
 # ------------------------------------------------------------------
 # Fixtures
 # ------------------------------------------------------------------
660  tests/test_main.py  Normal file
@@ -0,0 +1,660 @@
"""Tests for main trading loop telegram integration."""

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from src.core.risk_manager import CircuitBreakerTripped, FatFingerRejected
from src.main import safe_float, trading_cycle


class TestSafeFloat:
    """Test safe_float() helper function."""

    def test_converts_valid_string(self):
        """Test conversion of valid numeric string."""
        assert safe_float("123.45") == 123.45
        assert safe_float("0") == 0.0
        assert safe_float("-99.9") == -99.9

    def test_handles_empty_string(self):
        """Test empty string returns default."""
        assert safe_float("") == 0.0
        assert safe_float("", 99.0) == 99.0

    def test_handles_none(self):
        """Test None returns default."""
        assert safe_float(None) == 0.0
        assert safe_float(None, 42.0) == 42.0

    def test_handles_invalid_string(self):
        """Test invalid string returns default."""
        assert safe_float("invalid") == 0.0
        assert safe_float("not_a_number", 100.0) == 100.0
        assert safe_float("12.34.56") == 0.0

    def test_handles_float_input(self):
        """Test float input passes through."""
        assert safe_float(123.45) == 123.45
        assert safe_float(0.0) == 0.0

    def test_custom_default(self):
        """Test custom default value."""
        assert safe_float("", -1.0) == -1.0
        assert safe_float(None, 999.0) == 999.0

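`safe_float` guards the many KIS API fields that arrive as strings and are sometimes empty (see the empty-price test for issue #49 further down). The tests above fully determine its contract; a sketch matching them, though the actual version in `src/main.py` may differ in detail:

```python
def safe_float(value, default: float = 0.0) -> float:
    """Illustrative only: coerce API string fields to float, falling back on default."""
    if value is None or value == "":
        return default
    try:
        return float(value)
    except (TypeError, ValueError):
        return default
```
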
class TestTradingCycleTelegramIntegration:
    """Test telegram notifications in trading_cycle function."""

    @pytest.fixture
    def mock_broker(self) -> MagicMock:
        """Create mock broker."""
        broker = MagicMock()
        broker.get_orderbook = AsyncMock(
            return_value={
                "output1": {
                    "stck_prpr": "50000",
                    "frgn_ntby_qty": "100",
                }
            }
        )
        broker.get_balance = AsyncMock(
            return_value={
                "output2": [
                    {
                        "tot_evlu_amt": "10000000",
                        "dnca_tot_amt": "5000000",
                        "pchs_amt_smtl_amt": "5000000",
                    }
                ]
            }
        )
        broker.send_order = AsyncMock(return_value={"msg1": "OK"})
        return broker

    @pytest.fixture
    def mock_overseas_broker(self) -> MagicMock:
        """Create mock overseas broker."""
        broker = MagicMock()
        return broker

    @pytest.fixture
    def mock_brain(self) -> MagicMock:
        """Create mock brain that decides to buy."""
        brain = MagicMock()
        decision = MagicMock()
        decision.action = "BUY"
        decision.confidence = 85
        decision.rationale = "Test buy"
        brain.decide = AsyncMock(return_value=decision)
        return brain

    @pytest.fixture
    def mock_risk(self) -> MagicMock:
        """Create mock risk manager."""
        risk = MagicMock()
        risk.validate_order = MagicMock()
        return risk

    @pytest.fixture
    def mock_db(self) -> MagicMock:
        """Create mock database connection."""
        return MagicMock()

    @pytest.fixture
    def mock_decision_logger(self) -> MagicMock:
        """Create mock decision logger."""
        logger = MagicMock()
        logger.log_decision = MagicMock()
        return logger

    @pytest.fixture
    def mock_context_store(self) -> MagicMock:
        """Create mock context store."""
        store = MagicMock()
        store.get_latest_timeframe = MagicMock(return_value=None)
        return store

    @pytest.fixture
    def mock_criticality_assessor(self) -> MagicMock:
        """Create mock criticality assessor."""
        assessor = MagicMock()
        assessor.assess_market_conditions = MagicMock(
            return_value=MagicMock(value="NORMAL")
        )
        assessor.get_timeout = MagicMock(return_value=5.0)
        return assessor

    @pytest.fixture
    def mock_telegram(self) -> MagicMock:
        """Create mock telegram client."""
        telegram = MagicMock()
        telegram.notify_trade_execution = AsyncMock()
        telegram.notify_fat_finger = AsyncMock()
        telegram.notify_circuit_breaker = AsyncMock()
        return telegram

    @pytest.fixture
    def mock_market(self) -> MagicMock:
        """Create mock market info."""
        market = MagicMock()
        market.name = "Korea"
        market.code = "KR"
        market.exchange_code = "KRX"
        market.is_domestic = True
        return market

    @pytest.mark.asyncio
    async def test_trade_execution_notification_sent(
        self,
        mock_broker: MagicMock,
        mock_overseas_broker: MagicMock,
        mock_brain: MagicMock,
        mock_risk: MagicMock,
        mock_db: MagicMock,
        mock_decision_logger: MagicMock,
        mock_context_store: MagicMock,
        mock_criticality_assessor: MagicMock,
        mock_telegram: MagicMock,
        mock_market: MagicMock,
    ) -> None:
        """Test telegram notification sent on trade execution."""
        with patch("src.main.log_trade"):
            await trading_cycle(
                broker=mock_broker,
                overseas_broker=mock_overseas_broker,
                brain=mock_brain,
                risk=mock_risk,
                db_conn=mock_db,
                decision_logger=mock_decision_logger,
                context_store=mock_context_store,
                criticality_assessor=mock_criticality_assessor,
                telegram=mock_telegram,
                market=mock_market,
                stock_code="005930",
                scan_candidates={},
            )

        # Verify notification was sent
        mock_telegram.notify_trade_execution.assert_called_once()
        call_kwargs = mock_telegram.notify_trade_execution.call_args.kwargs
        assert call_kwargs["stock_code"] == "005930"
        assert call_kwargs["market"] == "Korea"
        assert call_kwargs["action"] == "BUY"
        assert call_kwargs["confidence"] == 85

    @pytest.mark.asyncio
    async def test_trade_execution_notification_failure_doesnt_crash(
        self,
        mock_broker: MagicMock,
        mock_overseas_broker: MagicMock,
        mock_brain: MagicMock,
        mock_risk: MagicMock,
        mock_db: MagicMock,
        mock_decision_logger: MagicMock,
        mock_context_store: MagicMock,
        mock_criticality_assessor: MagicMock,
        mock_telegram: MagicMock,
        mock_market: MagicMock,
    ) -> None:
        """Test trading continues even if notification fails."""
        # Make notification fail
        mock_telegram.notify_trade_execution.side_effect = Exception("API error")

        with patch("src.main.log_trade"):
            # Should not raise exception
            await trading_cycle(
                broker=mock_broker,
                overseas_broker=mock_overseas_broker,
                brain=mock_brain,
                risk=mock_risk,
                db_conn=mock_db,
                decision_logger=mock_decision_logger,
                context_store=mock_context_store,
                criticality_assessor=mock_criticality_assessor,
                telegram=mock_telegram,
                market=mock_market,
                stock_code="005930",
                scan_candidates={},
            )

        # Verify notification was attempted
        mock_telegram.notify_trade_execution.assert_called_once()

    @pytest.mark.asyncio
    async def test_fat_finger_notification_sent(
        self,
        mock_broker: MagicMock,
        mock_overseas_broker: MagicMock,
        mock_brain: MagicMock,
        mock_risk: MagicMock,
        mock_db: MagicMock,
        mock_decision_logger: MagicMock,
        mock_context_store: MagicMock,
        mock_criticality_assessor: MagicMock,
        mock_telegram: MagicMock,
        mock_market: MagicMock,
    ) -> None:
        """Test telegram notification sent on fat finger rejection."""
        # Make risk manager reject the order
        mock_risk.validate_order.side_effect = FatFingerRejected(
            order_amount=2000000,
            total_cash=5000000,
            max_pct=30.0,
        )

        with patch("src.main.log_trade"):
            with pytest.raises(FatFingerRejected):
                await trading_cycle(
                    broker=mock_broker,
                    overseas_broker=mock_overseas_broker,
                    brain=mock_brain,
                    risk=mock_risk,
                    db_conn=mock_db,
                    decision_logger=mock_decision_logger,
                    context_store=mock_context_store,
                    criticality_assessor=mock_criticality_assessor,
                    telegram=mock_telegram,
                    market=mock_market,
                    stock_code="005930",
                    scan_candidates={},
                )

        # Verify notification was sent
        mock_telegram.notify_fat_finger.assert_called_once()
        call_kwargs = mock_telegram.notify_fat_finger.call_args.kwargs
        assert call_kwargs["stock_code"] == "005930"
        assert call_kwargs["order_amount"] == 2000000
        assert call_kwargs["total_cash"] == 5000000
        assert call_kwargs["max_pct"] == 30.0

    @pytest.mark.asyncio
    async def test_fat_finger_notification_failure_still_raises(
        self,
        mock_broker: MagicMock,
        mock_overseas_broker: MagicMock,
        mock_brain: MagicMock,
        mock_risk: MagicMock,
        mock_db: MagicMock,
        mock_decision_logger: MagicMock,
        mock_context_store: MagicMock,
        mock_criticality_assessor: MagicMock,
        mock_telegram: MagicMock,
        mock_market: MagicMock,
    ) -> None:
        """Test fat finger exception still raised even if notification fails."""
        # Make risk manager reject the order
        mock_risk.validate_order.side_effect = FatFingerRejected(
            order_amount=2000000,
            total_cash=5000000,
            max_pct=30.0,
        )
        # Make notification fail
        mock_telegram.notify_fat_finger.side_effect = Exception("API error")

        with patch("src.main.log_trade"):
            with pytest.raises(FatFingerRejected):
                await trading_cycle(
                    broker=mock_broker,
                    overseas_broker=mock_overseas_broker,
                    brain=mock_brain,
                    risk=mock_risk,
                    db_conn=mock_db,
                    decision_logger=mock_decision_logger,
                    context_store=mock_context_store,
                    criticality_assessor=mock_criticality_assessor,
                    telegram=mock_telegram,
                    market=mock_market,
                    stock_code="005930",
                    scan_candidates={},
                )

        # Verify notification was attempted
        mock_telegram.notify_fat_finger.assert_called_once()

    @pytest.mark.asyncio
    async def test_no_notification_on_hold_decision(
        self,
        mock_broker: MagicMock,
        mock_overseas_broker: MagicMock,
        mock_brain: MagicMock,
        mock_risk: MagicMock,
        mock_db: MagicMock,
        mock_decision_logger: MagicMock,
        mock_context_store: MagicMock,
        mock_criticality_assessor: MagicMock,
        mock_telegram: MagicMock,
        mock_market: MagicMock,
    ) -> None:
        """Test no trade notification sent when decision is HOLD."""
        # Change brain decision to HOLD
        decision = MagicMock()
        decision.action = "HOLD"
        decision.confidence = 50
        decision.rationale = "Insufficient signal"
        mock_brain.decide = AsyncMock(return_value=decision)

        with patch("src.main.log_trade"):
            await trading_cycle(
                broker=mock_broker,
                overseas_broker=mock_overseas_broker,
                brain=mock_brain,
                risk=mock_risk,
                db_conn=mock_db,
                decision_logger=mock_decision_logger,
                context_store=mock_context_store,
                criticality_assessor=mock_criticality_assessor,
                telegram=mock_telegram,
                market=mock_market,
                stock_code="005930",
                scan_candidates={},
            )

        # Verify no trade notification sent
        mock_telegram.notify_trade_execution.assert_not_called()


class TestRunFunctionTelegramIntegration:
    """Test telegram notifications in run function."""

    @pytest.mark.asyncio
    async def test_circuit_breaker_notification_sent(self) -> None:
        """Test telegram notification sent when circuit breaker trips."""
        mock_telegram = MagicMock()
        mock_telegram.notify_circuit_breaker = AsyncMock()

        # Simulate circuit breaker exception
        exc = CircuitBreakerTripped(pnl_pct=-3.5, threshold=-3.0)

        # Test the notification logic
        try:
            await mock_telegram.notify_circuit_breaker(
                pnl_pct=exc.pnl_pct,
                threshold=exc.threshold,
            )
        except Exception:
            pass  # Ignore errors in notification

        # Verify notification was called
        mock_telegram.notify_circuit_breaker.assert_called_once_with(
            pnl_pct=-3.5,
            threshold=-3.0,
        )


class TestOverseasBalanceParsing:
    """Test overseas balance output2 parsing handles different formats."""

    @pytest.fixture
    def mock_overseas_broker_with_list(self) -> MagicMock:
        """Create mock overseas broker returning list format."""
        broker = MagicMock()
        broker.get_overseas_price = AsyncMock(
            return_value={"output": {"last": "150.50"}}
        )
        broker.get_overseas_balance = AsyncMock(
            return_value={
                "output2": [
                    {
                        "frcr_evlu_tota": "10000.00",
                        "frcr_dncl_amt_2": "5000.00",
                        "frcr_buy_amt_smtl": "4500.00",
                    }
                ]
            }
        )
        return broker

    @pytest.fixture
    def mock_overseas_broker_with_dict(self) -> MagicMock:
        """Create mock overseas broker returning dict format."""
        broker = MagicMock()
        broker.get_overseas_price = AsyncMock(
            return_value={"output": {"last": "150.50"}}
        )
        broker.get_overseas_balance = AsyncMock(
            return_value={
                "output2": {
                    "frcr_evlu_tota": "10000.00",
                    "frcr_dncl_amt_2": "5000.00",
                    "frcr_buy_amt_smtl": "4500.00",
                }
            }
        )
        return broker

    @pytest.fixture
    def mock_overseas_broker_with_empty(self) -> MagicMock:
        """Create mock overseas broker returning empty output2."""
        broker = MagicMock()
        broker.get_overseas_price = AsyncMock(
            return_value={"output": {"last": "150.50"}}
        )
        broker.get_overseas_balance = AsyncMock(return_value={"output2": []})
        return broker

    @pytest.fixture
    def mock_overseas_broker_with_empty_price(self) -> MagicMock:
        """Create mock overseas broker returning empty string for price."""
        broker = MagicMock()
        broker.get_overseas_price = AsyncMock(
            return_value={"output": {"last": ""}}  # Empty string
        )
        broker.get_overseas_balance = AsyncMock(
            return_value={
                "output2": [
                    {
                        "frcr_evlu_tota": "10000.00",
                        "frcr_dncl_amt_2": "5000.00",
                        "frcr_buy_amt_smtl": "4500.00",
                    }
                ]
            }
        )
        return broker

    @pytest.fixture
    def mock_domestic_broker(self) -> MagicMock:
        """Create minimal mock domestic broker."""
        broker = MagicMock()
        return broker

    @pytest.fixture
    def mock_overseas_market(self) -> MagicMock:
        """Create mock overseas market info."""
        market = MagicMock()
        market.name = "NASDAQ"
        market.code = "US_NASDAQ"
        market.exchange_code = "NASD"
        market.is_domestic = False
        return market

    @pytest.fixture
    def mock_brain_hold(self) -> MagicMock:
        """Create mock brain that always holds."""
        brain = MagicMock()
        decision = MagicMock()
        decision.action = "HOLD"
        decision.confidence = 50
        decision.rationale = "Testing balance parsing"
        brain.decide = AsyncMock(return_value=decision)
        return brain

    @pytest.fixture
    def mock_risk(self) -> MagicMock:
        """Create mock risk manager."""
        return MagicMock()

    @pytest.fixture
    def mock_db(self) -> MagicMock:
        """Create mock database."""
        return MagicMock()

    @pytest.fixture
    def mock_decision_logger(self) -> MagicMock:
        """Create mock decision logger."""
        return MagicMock()

    @pytest.fixture
    def mock_context_store(self) -> MagicMock:
        """Create mock context store."""
        store = MagicMock()
        store.get_latest_timeframe = MagicMock(return_value=None)
        return store

    @pytest.fixture
    def mock_criticality_assessor(self) -> MagicMock:
        """Create mock criticality assessor."""
        assessor = MagicMock()
        assessor.assess_market_conditions = MagicMock(
            return_value=MagicMock(value="NORMAL")
        )
        assessor.get_timeout = MagicMock(return_value=5.0)
        return assessor

    @pytest.fixture
    def mock_telegram(self) -> MagicMock:
        """Create mock telegram client."""
        return MagicMock()

    @pytest.mark.asyncio
    async def test_overseas_balance_list_format(
        self,
        mock_domestic_broker: MagicMock,
        mock_overseas_broker_with_list: MagicMock,
        mock_brain_hold: MagicMock,
        mock_risk: MagicMock,
        mock_db: MagicMock,
        mock_decision_logger: MagicMock,
        mock_context_store: MagicMock,
        mock_criticality_assessor: MagicMock,
        mock_telegram: MagicMock,
        mock_overseas_market: MagicMock,
    ) -> None:
        """Test overseas balance parsing with list format (output2=[{...}])."""
        with patch("src.main.log_trade"):
            # Should not raise KeyError
            await trading_cycle(
                broker=mock_domestic_broker,
                overseas_broker=mock_overseas_broker_with_list,
                brain=mock_brain_hold,
                risk=mock_risk,
                db_conn=mock_db,
                decision_logger=mock_decision_logger,
                context_store=mock_context_store,
                criticality_assessor=mock_criticality_assessor,
                telegram=mock_telegram,
                market=mock_overseas_market,
                stock_code="AAPL",
                scan_candidates={},
            )

        # Verify balance API was called
        mock_overseas_broker_with_list.get_overseas_balance.assert_called_once()

    @pytest.mark.asyncio
    async def test_overseas_balance_dict_format(
        self,
        mock_domestic_broker: MagicMock,
        mock_overseas_broker_with_dict: MagicMock,
        mock_brain_hold: MagicMock,
        mock_risk: MagicMock,
        mock_db: MagicMock,
        mock_decision_logger: MagicMock,
        mock_context_store: MagicMock,
        mock_criticality_assessor: MagicMock,
        mock_telegram: MagicMock,
        mock_overseas_market: MagicMock,
    ) -> None:
        """Test overseas balance parsing with dict format (output2={...})."""
        with patch("src.main.log_trade"):
            # Should not raise KeyError
            await trading_cycle(
                broker=mock_domestic_broker,
                overseas_broker=mock_overseas_broker_with_dict,
                brain=mock_brain_hold,
                risk=mock_risk,
                db_conn=mock_db,
                decision_logger=mock_decision_logger,
                context_store=mock_context_store,
                criticality_assessor=mock_criticality_assessor,
                telegram=mock_telegram,
                market=mock_overseas_market,
                stock_code="AAPL",
                scan_candidates={},
            )

        # Verify balance API was called
        mock_overseas_broker_with_dict.get_overseas_balance.assert_called_once()

    @pytest.mark.asyncio
    async def test_overseas_balance_empty_format(
        self,
        mock_domestic_broker: MagicMock,
        mock_overseas_broker_with_empty: MagicMock,
        mock_brain_hold: MagicMock,
        mock_risk: MagicMock,
        mock_db: MagicMock,
        mock_decision_logger: MagicMock,
        mock_context_store: MagicMock,
        mock_criticality_assessor: MagicMock,
        mock_telegram: MagicMock,
        mock_overseas_market: MagicMock,
    ) -> None:
        """Test overseas balance parsing with empty output2."""
        with patch("src.main.log_trade"):
            # Should not raise KeyError, should default to 0
            await trading_cycle(
                broker=mock_domestic_broker,
                overseas_broker=mock_overseas_broker_with_empty,
                brain=mock_brain_hold,
                risk=mock_risk,
                db_conn=mock_db,
                decision_logger=mock_decision_logger,
                context_store=mock_context_store,
                criticality_assessor=mock_criticality_assessor,
                telegram=mock_telegram,
                market=mock_overseas_market,
                stock_code="AAPL",
                scan_candidates={},
            )

        # Verify balance API was called
        mock_overseas_broker_with_empty.get_overseas_balance.assert_called_once()

    @pytest.mark.asyncio
    async def test_overseas_price_empty_string(
        self,
        mock_domestic_broker: MagicMock,
        mock_overseas_broker_with_empty_price: MagicMock,
        mock_brain_hold: MagicMock,
        mock_risk: MagicMock,
        mock_db: MagicMock,
        mock_decision_logger: MagicMock,
        mock_context_store: MagicMock,
        mock_criticality_assessor: MagicMock,
        mock_telegram: MagicMock,
        mock_overseas_market: MagicMock,
    ) -> None:
        """Test overseas price parsing with empty string (issue #49)."""
        with patch("src.main.log_trade"):
            # Should not raise ValueError, should default to 0.0
            await trading_cycle(
                broker=mock_domestic_broker,
                overseas_broker=mock_overseas_broker_with_empty_price,
                brain=mock_brain_hold,
                risk=mock_risk,
                db_conn=mock_db,
                decision_logger=mock_decision_logger,
                context_store=mock_context_store,
                criticality_assessor=mock_criticality_assessor,
                telegram=mock_telegram,
                market=mock_overseas_market,
                stock_code="AAPL",
                scan_candidates={},
            )

        # Verify price API was called
        mock_overseas_broker_with_empty_price.get_overseas_price.assert_called_once()
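
The four balance-parsing tests cover the three shapes `output2` arrives in from the overseas balance endpoint: a list of rows, a bare dict, or empty. A normalization sketch consistent with them (field names taken from the mocks; the actual parsing inside `trading_cycle` may differ):

```python
def parse_overseas_balance_sketch(resp: dict) -> dict:
    """Illustrative only: normalize output2 across list / dict / empty shapes."""
    output2 = resp.get("output2", {})
    if isinstance(output2, list):
        row = output2[0] if output2 else {}  # empty list defaults to zeros
    else:
        row = output2 or {}

    # `or 0.0` covers both missing keys and empty-string values.
    return {
        "total_value": float(row.get("frcr_evlu_tota") or 0.0),
        "cash": float(row.get("frcr_dncl_amt_2") or 0.0),
        "purchase_amount": float(row.get("frcr_buy_amt_smtl") or 0.0),
    }
```
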
377  tests/test_smart_scanner.py  Normal file
@@ -0,0 +1,377 @@
"""Tests for SmartVolatilityScanner."""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock

import pytest

from src.analysis.smart_scanner import ScanCandidate, SmartVolatilityScanner
from src.analysis.volatility import VolatilityAnalyzer
from src.broker.kis_api import KISBroker
from src.config import Settings


@pytest.fixture
def mock_settings() -> Settings:
    """Create test settings."""
    return Settings(
        KIS_APP_KEY="test",
        KIS_APP_SECRET="test",
        KIS_ACCOUNT_NO="12345678-01",
        GEMINI_API_KEY="test",
        RSI_OVERSOLD_THRESHOLD=30,
        RSI_MOMENTUM_THRESHOLD=70,
        VOL_MULTIPLIER=2.0,
        SCANNER_TOP_N=3,
        DB_PATH=":memory:",
    )


@pytest.fixture
def mock_broker(mock_settings: Settings) -> MagicMock:
    """Create mock broker."""
    broker = MagicMock(spec=KISBroker)
    broker._settings = mock_settings
    broker.fetch_market_rankings = AsyncMock()
    broker.get_daily_prices = AsyncMock()
    return broker


@pytest.fixture
def scanner(mock_broker: MagicMock, mock_settings: Settings) -> SmartVolatilityScanner:
    """Create smart scanner instance."""
    analyzer = VolatilityAnalyzer()
    return SmartVolatilityScanner(
        broker=mock_broker,
        volatility_analyzer=analyzer,
        settings=mock_settings,
    )

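The tests below encode a two-gate filter followed by a top-N cut: a candidate must show a volume spike (ratio at or above `VOL_MULTIPLIER`) and an extreme RSI (oversold below 30 or momentum above 70), and only the strongest `SCANNER_TOP_N` survive. A sketch of that gating, using the signal labels the tests assert on (names assumed; the real scoring lives in src/analysis/smart_scanner.py):

```python
def classify_candidate_sketch(
    rsi: float,
    volume_ratio: float,
    vol_multiplier: float = 2.0,
    rsi_oversold: float = 30.0,
    rsi_momentum: float = 70.0,
) -> str | None:
    """Illustrative only: return a signal label, or None if filtered out."""
    if volume_ratio < vol_multiplier:
        return None  # volume gate: a spike is required
    if rsi < rsi_oversold:
        return "oversold"
    if rsi > rsi_momentum:
        return "momentum"
    return None  # neutral RSI gate
```
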
class TestSmartVolatilityScanner:
    """Test suite for SmartVolatilityScanner."""

    @pytest.mark.asyncio
    async def test_scan_finds_oversold_candidates(
        self, scanner: SmartVolatilityScanner, mock_broker: MagicMock
    ) -> None:
        """Test that scanner identifies oversold stocks with high volume."""
        # Mock rankings
        mock_broker.fetch_market_rankings.return_value = [
            {
                "stock_code": "005930",
                "name": "Samsung",
                "price": 70000,
                "volume": 5000000,
                "change_rate": -3.5,
                "volume_increase_rate": 250,
            },
        ]

        # Mock daily prices - trending down (oversold)
        prices = []
        for i in range(20):
            prices.append({
                "date": f"202602{i + 1:02d}",  # valid YYYYMMDD
                "open": 75000 - i * 200,
                "high": 75500 - i * 200,
                "low": 74500 - i * 200,
                "close": 75000 - i * 250,  # Steady decline
                "volume": 2000000,
            })
        mock_broker.get_daily_prices.return_value = prices

        candidates = await scanner.scan()

        # Should find at least one candidate (depending on exact RSI calculation)
        mock_broker.fetch_market_rankings.assert_called_once()
        mock_broker.get_daily_prices.assert_called_once_with("005930", days=20)

        # If qualified, should have oversold signal
        if candidates:
            assert candidates[0].signal in ["oversold", "momentum"]
            assert candidates[0].volume_ratio >= scanner.vol_multiplier

    @pytest.mark.asyncio
    async def test_scan_finds_momentum_candidates(
        self, scanner: SmartVolatilityScanner, mock_broker: MagicMock
    ) -> None:
        """Test that scanner identifies momentum stocks with high volume."""
        mock_broker.fetch_market_rankings.return_value = [
            {
                "stock_code": "035420",
                "name": "NAVER",
                "price": 250000,
                "volume": 3000000,
                "change_rate": 5.0,
                "volume_increase_rate": 300,
            },
        ]

        # Mock daily prices - trending up (momentum)
        prices = []
        for i in range(20):
            prices.append({
                "date": f"202602{i + 1:02d}",
                "open": 230000 + i * 500,
                "high": 231000 + i * 500,
                "low": 229000 + i * 500,
                "close": 230500 + i * 500,  # Steady rise
                "volume": 1000000,
            })
        mock_broker.get_daily_prices.return_value = prices

        candidates = await scanner.scan()

        mock_broker.fetch_market_rankings.assert_called_once()

    @pytest.mark.asyncio
    async def test_scan_filters_low_volume(
        self, scanner: SmartVolatilityScanner, mock_broker: MagicMock
    ) -> None:
        """Test that stocks with low volume ratio are filtered out."""
        mock_broker.fetch_market_rankings.return_value = [
            {
                "stock_code": "000660",
                "name": "SK Hynix",
                "price": 150000,
                "volume": 500000,
                "change_rate": -5.0,
                "volume_increase_rate": 50,  # Only 50% increase (< 200%)
            },
        ]

        # Low volume
        prices = []
        for i in range(20):
            prices.append({
                "date": f"202602{i + 1:02d}",
                "open": 150000 - i * 100,
                "high": 151000 - i * 100,
                "low": 149000 - i * 100,
                "close": 150000 - i * 150,  # Declining (would be oversold)
                "volume": 1000000,  # Current 500k < 2x prev day 1M
            })
        mock_broker.get_daily_prices.return_value = prices

        candidates = await scanner.scan()

        # Should be filtered out due to low volume ratio
        assert len(candidates) == 0

    @pytest.mark.asyncio
    async def test_scan_filters_neutral_rsi(
        self, scanner: SmartVolatilityScanner, mock_broker: MagicMock
    ) -> None:
        """Test that stocks with neutral RSI are filtered out."""
        mock_broker.fetch_market_rankings.return_value = [
            {
                "stock_code": "051910",
                "name": "LG Chem",
                "price": 500000,
                "volume": 3000000,
                "change_rate": 0.5,
                "volume_increase_rate": 300,  # High volume
            },
        ]

        # Flat prices (neutral RSI ~50)
        prices = []
        for i in range(20):
            prices.append({
                "date": f"202602{i + 1:02d}",
                "open": 500000 + (i % 2) * 100,  # Small oscillation
                "high": 500500,
                "low": 499500,
                "close": 500000 + (i % 2) * 50,
                "volume": 1000000,
            })
        mock_broker.get_daily_prices.return_value = prices

        candidates = await scanner.scan()

        # Should be filtered out (RSI ~50, not < 30 or > 70)
        assert len(candidates) == 0

    @pytest.mark.asyncio
    async def test_scan_uses_fallback_on_api_error(
        self, scanner: SmartVolatilityScanner, mock_broker: MagicMock
    ) -> None:
        """Test fallback to static list when ranking API fails."""
        mock_broker.fetch_market_rankings.side_effect = ConnectionError("API unavailable")

        # Fallback stocks should still be analyzed
        prices = []
        for i in range(20):
            prices.append({
                "date": f"202602{i + 1:02d}",
                "open": 50000 - i * 50,
                "high": 51000 - i * 50,
                "low": 49000 - i * 50,
                "close": 50000 - i * 75,  # Declining
                "volume": 1000000,
            })
        mock_broker.get_daily_prices.return_value = prices

        candidates = await scanner.scan(fallback_stocks=["005930", "000660"])

        # Should not crash
        assert isinstance(candidates, list)

    @pytest.mark.asyncio
    async def test_scan_returns_top_n_only(
        self, scanner: SmartVolatilityScanner, mock_broker: MagicMock
    ) -> None:
        """Test that scan returns at most top_n candidates."""
        # Return many stocks
        mock_broker.fetch_market_rankings.return_value = [
            {
                "stock_code": f"00{i}000",
                "name": f"Stock{i}",
                "price": 10000 * i,
                "volume": 5000000,
                "change_rate": -10,
                "volume_increase_rate": 500,
            }
            for i in range(1, 10)
        ]

        # All oversold with high volume; accept the days kwarg the broker
        # passes, since a one-argument side_effect function would raise.
        def make_prices(code: str, days: int = 20) -> list[dict]:
            prices = []
            for i in range(days):
                prices.append({
                    "date": f"202602{i + 1:02d}",
                    "open": 10000 - i * 100,
                    "high": 10500 - i * 100,
                    "low": 9500 - i * 100,
                    "close": 10000 - i * 150,
                    "volume": 1000000,
                })
            return prices

        mock_broker.get_daily_prices.side_effect = make_prices

        candidates = await scanner.scan()

        # Should respect top_n limit (3)
|
||||
assert len(candidates) <= scanner.top_n
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_skips_insufficient_price_history(
|
||||
self, scanner: SmartVolatilityScanner, mock_broker: MagicMock
|
||||
) -> None:
|
||||
"""Test that stocks with insufficient history are skipped."""
|
||||
mock_broker.fetch_market_rankings.return_value = [
|
||||
{
|
||||
"stock_code": "005930",
|
||||
"name": "Samsung",
|
||||
"price": 70000,
|
||||
"volume": 5000000,
|
||||
"change_rate": -5.0,
|
||||
"volume_increase_rate": 300,
|
||||
},
|
||||
]
|
||||
|
||||
# Only 5 days of data (need 15+ for RSI)
|
||||
mock_broker.get_daily_prices.return_value = [
|
||||
{
|
||||
"date": f"2026020{i:02d}",
|
||||
"open": 70000,
|
||||
"high": 71000,
|
||||
"low": 69000,
|
||||
"close": 70000,
|
||||
"volume": 2000000,
|
||||
}
|
||||
for i in range(5)
|
||||
]
|
||||
|
||||
candidates = await scanner.scan()
|
||||
|
||||
# Should skip due to insufficient data
|
||||
assert len(candidates) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_stock_codes(
|
||||
self, scanner: SmartVolatilityScanner
|
||||
) -> None:
|
||||
"""Test extraction of stock codes from candidates."""
|
||||
candidates = [
|
||||
ScanCandidate(
|
||||
stock_code="005930",
|
||||
name="Samsung",
|
||||
price=70000,
|
||||
volume=5000000,
|
||||
volume_ratio=2.5,
|
||||
rsi=28,
|
||||
signal="oversold",
|
||||
score=85.0,
|
||||
),
|
||||
ScanCandidate(
|
||||
stock_code="035420",
|
||||
name="NAVER",
|
||||
price=250000,
|
||||
volume=3000000,
|
||||
volume_ratio=3.0,
|
||||
rsi=75,
|
||||
signal="momentum",
|
||||
score=88.0,
|
||||
),
|
||||
]
|
||||
|
||||
codes = scanner.get_stock_codes(candidates)
|
||||
|
||||
assert codes == ["005930", "035420"]
|
||||
|
||||
|
||||
class TestRSICalculation:
|
||||
"""Test RSI calculation in VolatilityAnalyzer."""
|
||||
|
||||
def test_rsi_oversold(self) -> None:
|
||||
"""Test RSI calculation for downtrending prices."""
|
||||
analyzer = VolatilityAnalyzer()
|
||||
|
||||
# Steadily declining prices
|
||||
prices = [100 - i * 0.5 for i in range(20)]
|
||||
rsi = analyzer.calculate_rsi(prices, period=14)
|
||||
|
||||
assert rsi < 50 # Should be oversold territory
|
||||
|
||||
def test_rsi_overbought(self) -> None:
|
||||
"""Test RSI calculation for uptrending prices."""
|
||||
analyzer = VolatilityAnalyzer()
|
||||
|
||||
# Steadily rising prices
|
||||
prices = [100 + i * 0.5 for i in range(20)]
|
||||
rsi = analyzer.calculate_rsi(prices, period=14)
|
||||
|
||||
assert rsi > 50 # Should be overbought territory
|
||||
|
||||
def test_rsi_neutral(self) -> None:
|
||||
"""Test RSI calculation for flat prices."""
|
||||
analyzer = VolatilityAnalyzer()
|
||||
|
||||
# Flat prices with small oscillation
|
||||
prices = [100 + (i % 2) * 0.1 for i in range(20)]
|
||||
rsi = analyzer.calculate_rsi(prices, period=14)
|
||||
|
||||
assert 40 < rsi < 60 # Should be near neutral
|
||||
|
||||
def test_rsi_insufficient_data(self) -> None:
|
||||
"""Test RSI returns neutral when insufficient data."""
|
||||
analyzer = VolatilityAnalyzer()
|
||||
|
||||
prices = [100, 101, 102] # Only 3 prices, need 15+
|
||||
rsi = analyzer.calculate_rsi(prices, period=14)
|
||||
|
||||
assert rsi == 50.0 # Default neutral
|
||||
|
||||
def test_rsi_all_gains(self) -> None:
|
||||
"""Test RSI returns 100 when all gains (no losses)."""
|
||||
analyzer = VolatilityAnalyzer()
|
||||
|
||||
# Monotonic increase
|
||||
prices = [100 + i for i in range(20)]
|
||||
rsi = analyzer.calculate_rsi(prices, period=14)
|
||||
|
||||
assert rsi == 100.0 # Maximum RSI
|
||||
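The `TestRSICalculation` suite above fully pins down the edge cases: neutral 50.0 when history is shorter than `period + 1`, and a 100.0 cap when there are no losing days. A minimal simple-average RSI consistent with those assertions — an illustrative sketch, not the actual `VolatilityAnalyzer.calculate_rsi` source:

```python
def calculate_rsi(prices: list[float], period: int = 14) -> float:
    """Simple-average RSI; neutral 50.0 when history is too short."""
    if len(prices) < period + 1:
        return 50.0  # matches test_rsi_insufficient_data
    deltas = [b - a for a, b in zip(prices, prices[1:])]
    recent = deltas[-period:]
    avg_gain = sum(d for d in recent if d > 0) / period
    avg_loss = sum(-d for d in recent if d < 0) / period
    if avg_loss == 0:
        return 100.0  # matches test_rsi_all_gains
    return 100.0 - 100.0 / (1.0 + avg_gain / avg_loss)
```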
366
tests/test_strategy_models.py
Normal file
366
tests/test_strategy_models.py
Normal file
@@ -0,0 +1,366 @@
"""Tests for strategy/playbook Pydantic models."""

from __future__ import annotations

from datetime import date

import pytest
from pydantic import ValidationError

from src.strategy.models import (
    CrossMarketContext,
    DayPlaybook,
    GlobalRule,
    MarketOutlook,
    PlaybookStatus,
    ScenarioAction,
    StockCondition,
    StockPlaybook,
    StockScenario,
)


# ---------------------------------------------------------------------------
# StockCondition
# ---------------------------------------------------------------------------


class TestStockCondition:
    def test_empty_condition(self) -> None:
        cond = StockCondition()
        assert not cond.has_any_condition()

    def test_single_field(self) -> None:
        cond = StockCondition(rsi_below=30.0)
        assert cond.has_any_condition()

    def test_multiple_fields(self) -> None:
        cond = StockCondition(rsi_below=25.0, volume_ratio_above=3.0)
        assert cond.has_any_condition()

    def test_all_fields(self) -> None:
        cond = StockCondition(
            rsi_below=30,
            rsi_above=10,
            volume_ratio_above=2.0,
            volume_ratio_below=10.0,
            price_above=1000,
            price_below=50000,
            price_change_pct_above=-5.0,
            price_change_pct_below=5.0,
        )
        assert cond.has_any_condition()


# ---------------------------------------------------------------------------
# StockScenario
# ---------------------------------------------------------------------------


class TestStockScenario:
    def test_valid_scenario(self) -> None:
        s = StockScenario(
            condition=StockCondition(rsi_below=25.0),
            action=ScenarioAction.BUY,
            confidence=85,
            allocation_pct=15.0,
            stop_loss_pct=-2.0,
            take_profit_pct=3.0,
            rationale="Oversold bounce expected",
        )
        assert s.action == ScenarioAction.BUY
        assert s.confidence == 85

    def test_confidence_too_high(self) -> None:
        with pytest.raises(ValidationError):
            StockScenario(
                condition=StockCondition(),
                action=ScenarioAction.BUY,
                confidence=101,
            )

    def test_confidence_too_low(self) -> None:
        with pytest.raises(ValidationError):
            StockScenario(
                condition=StockCondition(),
                action=ScenarioAction.BUY,
                confidence=-1,
            )

    def test_allocation_too_high(self) -> None:
        with pytest.raises(ValidationError):
            StockScenario(
                condition=StockCondition(),
                action=ScenarioAction.BUY,
                confidence=80,
                allocation_pct=101.0,
            )

    def test_stop_loss_must_be_negative(self) -> None:
        with pytest.raises(ValidationError):
            StockScenario(
                condition=StockCondition(),
                action=ScenarioAction.BUY,
                confidence=80,
                stop_loss_pct=1.0,
            )

    def test_take_profit_must_be_positive(self) -> None:
        with pytest.raises(ValidationError):
            StockScenario(
                condition=StockCondition(),
                action=ScenarioAction.BUY,
                confidence=80,
                take_profit_pct=-1.0,
            )

    def test_defaults(self) -> None:
        s = StockScenario(
            condition=StockCondition(),
            action=ScenarioAction.HOLD,
            confidence=50,
        )
        assert s.allocation_pct == 10.0
        assert s.stop_loss_pct == -2.0
        assert s.take_profit_pct == 3.0
        assert s.rationale == ""


# ---------------------------------------------------------------------------
# StockPlaybook
# ---------------------------------------------------------------------------


class TestStockPlaybook:
    def test_valid_playbook(self) -> None:
        pb = StockPlaybook(
            stock_code="005930",
            stock_name="Samsung Electronics",
            scenarios=[
                StockScenario(
                    condition=StockCondition(rsi_below=25.0),
                    action=ScenarioAction.BUY,
                    confidence=85,
                ),
            ],
        )
        assert pb.stock_code == "005930"
        assert len(pb.scenarios) == 1

    def test_empty_scenarios_rejected(self) -> None:
        with pytest.raises(ValidationError):
            StockPlaybook(
                stock_code="005930",
                scenarios=[],
            )

    def test_multiple_scenarios(self) -> None:
        pb = StockPlaybook(
            stock_code="AAPL",
            scenarios=[
                StockScenario(
                    condition=StockCondition(rsi_below=25.0),
                    action=ScenarioAction.BUY,
                    confidence=85,
                ),
                StockScenario(
                    condition=StockCondition(rsi_above=75.0),
                    action=ScenarioAction.SELL,
                    confidence=80,
                ),
            ],
        )
        assert len(pb.scenarios) == 2


# ---------------------------------------------------------------------------
# GlobalRule
# ---------------------------------------------------------------------------


class TestGlobalRule:
    def test_valid_rule(self) -> None:
        rule = GlobalRule(
            condition="portfolio_pnl_pct < -2.0",
            action=ScenarioAction.REDUCE_ALL,
            rationale="Risk limit approaching",
        )
        assert rule.action == ScenarioAction.REDUCE_ALL

    def test_hold_rule(self) -> None:
        rule = GlobalRule(
            condition="volatility_index > 30",
            action=ScenarioAction.HOLD,
        )
        assert rule.rationale == ""


# ---------------------------------------------------------------------------
# CrossMarketContext
# ---------------------------------------------------------------------------


class TestCrossMarketContext:
    def test_valid_context(self) -> None:
        ctx = CrossMarketContext(
            market="US",
            date="2026-02-07",
            total_pnl=-1.5,
            win_rate=40.0,
            index_change_pct=-2.3,
            key_events=["Fed rate decision"],
            lessons=["Avoid tech sector on rate hike days"],
        )
        assert ctx.market == "US"
        assert len(ctx.key_events) == 1

    def test_defaults(self) -> None:
        ctx = CrossMarketContext(market="KR", date="2026-02-07")
        assert ctx.total_pnl == 0.0
        assert ctx.key_events == []
        assert ctx.lessons == []


# ---------------------------------------------------------------------------
# DayPlaybook
# ---------------------------------------------------------------------------


def _make_scenario(rsi_below: float = 25.0) -> StockScenario:
    return StockScenario(
        condition=StockCondition(rsi_below=rsi_below),
        action=ScenarioAction.BUY,
        confidence=85,
    )


def _make_playbook(**kwargs) -> DayPlaybook:
    defaults = {
        "date": date(2026, 2, 7),
        "market": "KR",
        "stock_playbooks": [
            StockPlaybook(stock_code="005930", scenarios=[_make_scenario()]),
        ],
    }
    defaults.update(kwargs)
    return DayPlaybook(**defaults)


class TestDayPlaybook:
    def test_valid_playbook(self) -> None:
        pb = _make_playbook()
        assert pb.market == "KR"
        assert pb.date == date(2026, 2, 7)
        assert pb.default_action == ScenarioAction.HOLD
        assert pb.scenario_count == 1
        assert pb.stock_count == 1

    def test_generated_at_auto_set(self) -> None:
        pb = _make_playbook()
        assert pb.generated_at != ""

    def test_explicit_generated_at(self) -> None:
        pb = _make_playbook(generated_at="2026-02-07T08:30:00")
        assert pb.generated_at == "2026-02-07T08:30:00"

    def test_duplicate_stocks_rejected(self) -> None:
        with pytest.raises(ValidationError):
            DayPlaybook(
                date=date(2026, 2, 7),
                market="KR",
                stock_playbooks=[
                    StockPlaybook(stock_code="005930", scenarios=[_make_scenario()]),
                    StockPlaybook(stock_code="005930", scenarios=[_make_scenario(30)]),
                ],
            )

    def test_empty_stock_playbooks_allowed(self) -> None:
        pb = DayPlaybook(
            date=date(2026, 2, 7),
            market="KR",
            stock_playbooks=[],
        )
        assert pb.stock_count == 0
        assert pb.scenario_count == 0

    def test_get_stock_playbook_found(self) -> None:
        pb = _make_playbook()
        result = pb.get_stock_playbook("005930")
        assert result is not None
        assert result.stock_code == "005930"

    def test_get_stock_playbook_not_found(self) -> None:
        pb = _make_playbook()
        result = pb.get_stock_playbook("AAPL")
        assert result is None

    def test_with_global_rules(self) -> None:
        pb = _make_playbook(
            global_rules=[
                GlobalRule(
                    condition="portfolio_pnl_pct < -2.0",
                    action=ScenarioAction.REDUCE_ALL,
                ),
            ],
        )
        assert len(pb.global_rules) == 1

    def test_with_cross_market_context(self) -> None:
        ctx = CrossMarketContext(market="US", date="2026-02-07", total_pnl=-1.5)
        pb = _make_playbook(cross_market=ctx)
        assert pb.cross_market is not None
        assert pb.cross_market.market == "US"

    def test_market_outlook(self) -> None:
        pb = _make_playbook(market_outlook=MarketOutlook.BEARISH)
        assert pb.market_outlook == MarketOutlook.BEARISH

    def test_multiple_stocks_multiple_scenarios(self) -> None:
        pb = DayPlaybook(
            date=date(2026, 2, 7),
            market="US",
            stock_playbooks=[
                StockPlaybook(
                    stock_code="AAPL",
                    scenarios=[_make_scenario(), _make_scenario(30)],
                ),
                StockPlaybook(
                    stock_code="MSFT",
                    scenarios=[_make_scenario()],
                ),
            ],
        )
        assert pb.stock_count == 2
        assert pb.scenario_count == 3

    def test_serialization_roundtrip(self) -> None:
        pb = _make_playbook(
            market_outlook=MarketOutlook.BULLISH,
            cross_market=CrossMarketContext(market="US", date="2026-02-07"),
        )
        json_str = pb.model_dump_json()
        restored = DayPlaybook.model_validate_json(json_str)
        assert restored.market == pb.market
        assert restored.date == pb.date
        assert restored.scenario_count == pb.scenario_count
        assert restored.cross_market is not None


# ---------------------------------------------------------------------------
# Enums
# ---------------------------------------------------------------------------


class TestEnums:
    def test_scenario_action_values(self) -> None:
        assert ScenarioAction.BUY.value == "BUY"
        assert ScenarioAction.SELL.value == "SELL"
        assert ScenarioAction.HOLD.value == "HOLD"
        assert ScenarioAction.REDUCE_ALL.value == "REDUCE_ALL"

    def test_market_outlook_values(self) -> None:
        assert len(MarketOutlook) == 5

    def test_playbook_status_values(self) -> None:
        assert PlaybookStatus.READY.value == "ready"
        assert PlaybookStatus.EXPIRED.value == "expired"
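Taken together, the `ValidationError` cases above imply field constraints roughly like the following Pydantic sketch. The bounds and defaults are inferred from the tests; the actual `src/strategy/models.py` may express them differently:

```python
from pydantic import BaseModel, Field

class StockScenarioSketch(BaseModel):
    """Hypothetical reconstruction of the constraints the tests assert."""
    condition: StockCondition
    action: ScenarioAction
    confidence: int = Field(ge=0, le=100)                   # 101 and -1 are rejected
    allocation_pct: float = Field(default=10.0, le=100.0)   # 101.0 is rejected
    stop_loss_pct: float = Field(default=-2.0, lt=0)        # must be negative
    take_profit_pct: float = Field(default=3.0, gt=0)       # must be positive
    rationale: str = ""
```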
339
tests/test_telegram.py
Normal file
339
tests/test_telegram.py
Normal file
@@ -0,0 +1,339 @@
"""Tests for Telegram notification client."""

from unittest.mock import AsyncMock, patch

import aiohttp
import pytest

from src.notifications.telegram_client import NotificationPriority, TelegramClient


class TestTelegramClientInit:
    """Test client initialization scenarios."""

    def test_disabled_via_flag(self) -> None:
        """Client disabled via enabled=False flag."""
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=False
        )
        assert client._enabled is False

    def test_disabled_missing_token(self) -> None:
        """Client disabled when bot_token is None."""
        client = TelegramClient(bot_token=None, chat_id="456", enabled=True)
        assert client._enabled is False

    def test_disabled_missing_chat_id(self) -> None:
        """Client disabled when chat_id is None."""
        client = TelegramClient(bot_token="123:abc", chat_id=None, enabled=True)
        assert client._enabled is False

    def test_enabled_with_credentials(self) -> None:
        """Client enabled when credentials provided."""
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True
        )
        assert client._enabled is True


class TestNotificationSending:
    """Test notification sending behavior."""

    @pytest.mark.asyncio
    async def test_send_message_success(self) -> None:
        """send_message returns True on successful send."""
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True
        )

        mock_resp = AsyncMock()
        mock_resp.status = 200
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
            result = await client.send_message("Test message")

        assert result is True
        assert mock_post.call_count == 1

        payload = mock_post.call_args.kwargs["json"]
        assert payload["chat_id"] == "456"
        assert payload["text"] == "Test message"
        assert payload["parse_mode"] == "HTML"

    @pytest.mark.asyncio
    async def test_send_message_disabled_client(self) -> None:
        """send_message returns False when client disabled."""
        client = TelegramClient(enabled=False)

        with patch("aiohttp.ClientSession.post") as mock_post:
            result = await client.send_message("Test message")

        assert result is False
        mock_post.assert_not_called()

    @pytest.mark.asyncio
    async def test_send_message_api_error(self) -> None:
        """send_message returns False on API error."""
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True
        )

        mock_resp = AsyncMock()
        mock_resp.status = 400
        mock_resp.text = AsyncMock(return_value="Bad Request")
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp):
            result = await client.send_message("Test message")
            assert result is False

    @pytest.mark.asyncio
    async def test_send_message_with_markdown(self) -> None:
        """send_message supports different parse modes."""
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True
        )

        mock_resp = AsyncMock()
        mock_resp.status = 200
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
            result = await client.send_message("*bold*", parse_mode="Markdown")

        assert result is True
        payload = mock_post.call_args.kwargs["json"]
        assert payload["parse_mode"] == "Markdown"

    @pytest.mark.asyncio
    async def test_no_send_when_disabled(self) -> None:
        """Notifications not sent when client disabled."""
        client = TelegramClient(enabled=False)

        with patch("aiohttp.ClientSession.post") as mock_post:
            await client.notify_trade_execution(
                stock_code="AAPL",
                market="United States",
                action="BUY",
                quantity=10,
                price=150.0,
                confidence=85.0,
            )
            mock_post.assert_not_called()

    @pytest.mark.asyncio
    async def test_trade_execution_format(self) -> None:
        """Trade notification has correct format."""
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True
        )

        mock_resp = AsyncMock()
        mock_resp.status = 200
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
            await client.notify_trade_execution(
                stock_code="TSLA",
                market="United States",
                action="SELL",
                quantity=5,
                price=250.50,
                confidence=92.0,
            )

        # Verify API call was made
        assert mock_post.call_count == 1
        call_args = mock_post.call_args

        # Check payload structure
        payload = call_args.kwargs["json"]
        assert payload["chat_id"] == "456"
        assert "TSLA" in payload["text"]
        assert "SELL" in payload["text"]
        assert "5" in payload["text"]
        assert "250.50" in payload["text"]
        assert "92%" in payload["text"]

    @pytest.mark.asyncio
    async def test_circuit_breaker_priority(self) -> None:
        """Circuit breaker uses CRITICAL priority."""
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True
        )

        mock_resp = AsyncMock()
        mock_resp.status = 200
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
            await client.notify_circuit_breaker(pnl_pct=-3.15, threshold=-3.0)

        payload = mock_post.call_args.kwargs["json"]
        # CRITICAL priority has 🚨 emoji
        assert NotificationPriority.CRITICAL.emoji in payload["text"]
        assert "-3.15%" in payload["text"]

    @pytest.mark.asyncio
    async def test_api_error_handling(self) -> None:
        """API errors logged but don't crash."""
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True
        )

        mock_resp = AsyncMock()
        mock_resp.status = 400
        mock_resp.text = AsyncMock(return_value="Bad Request")
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp):
            # Should not raise exception
            await client.notify_system_start(mode="paper", enabled_markets=["KR"])

    @pytest.mark.asyncio
    async def test_timeout_handling(self) -> None:
        """Timeouts logged but don't crash."""
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True
        )

        with patch(
            "aiohttp.ClientSession.post",
            side_effect=aiohttp.ClientError("Connection timeout"),
        ):
            # Should not raise exception
            await client.notify_error(
                error_type="Test Error", error_msg="Test", context="test"
            )

    @pytest.mark.asyncio
    async def test_session_management(self) -> None:
        """Session created and reused correctly."""
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True
        )

        # Session should be None initially
        assert client._session is None

        mock_resp = AsyncMock()
        mock_resp.status = 200
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp):
            await client.notify_market_open("Korea")
            # Session should be created
            assert client._session is not None

            session1 = client._session
            await client.notify_market_close("Korea", 1.5)
            # Same session should be reused
            assert client._session is session1


class TestRateLimiting:
    """Test rate limiter behavior."""

    @pytest.mark.asyncio
    async def test_rate_limiter_enforced(self) -> None:
        """Rate limiter delays rapid requests."""
        import time

        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True, rate_limit=2.0
        )

        mock_resp = AsyncMock()
        mock_resp.status = 200
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp):
            start = time.monotonic()

            # Send 3 messages (rate: 2/sec = 0.5s per message)
            await client.notify_market_open("Korea")
            await client.notify_market_open("United States")
            await client.notify_market_open("Japan")

            elapsed = time.monotonic() - start

        # Should take at least 0.4 seconds (3 msgs at 2/sec with some tolerance)
        assert elapsed >= 0.4


class TestMessagePriorities:
    """Test priority-based messaging."""

    @pytest.mark.asyncio
    async def test_low_priority_uses_info_emoji(self) -> None:
        """LOW priority uses ℹ️ emoji."""
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True
        )

        mock_resp = AsyncMock()
        mock_resp.status = 200
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
            await client.notify_market_open("Korea")

        payload = mock_post.call_args.kwargs["json"]
        assert NotificationPriority.LOW.emoji in payload["text"]

    @pytest.mark.asyncio
    async def test_critical_priority_uses_alarm_emoji(self) -> None:
        """CRITICAL priority uses 🚨 emoji."""
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True
        )

        mock_resp = AsyncMock()
        mock_resp.status = 200
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
            await client.notify_system_shutdown("Circuit breaker tripped")

        payload = mock_post.call_args.kwargs["json"]
        assert NotificationPriority.CRITICAL.emoji in payload["text"]


class TestClientCleanup:
    """Test client cleanup behavior."""

    @pytest.mark.asyncio
    async def test_close_closes_session(self) -> None:
        """close() closes the HTTP session."""
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True
        )

        mock_session = AsyncMock()
        mock_session.closed = False
        mock_session.close = AsyncMock()
        client._session = mock_session

        await client.close()
        mock_session.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_close_handles_no_session(self) -> None:
        """close() handles None session gracefully."""
        client = TelegramClient(
            bot_token="123:abc", chat_id="456", enabled=True
        )

        # Should not raise exception
        await client.close()
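The timing assertion in `TestRateLimiting` above implies an interval-style limiter: at `rate_limit=2.0`, sends are spaced at least 0.5 s apart. A minimal sketch of that mechanism (names are assumptions; the real `TelegramClient` internals may differ):

```python
import asyncio
import time

class IntervalRateLimiter:
    """Enforce a minimum gap of 1/rate seconds between sends."""

    def __init__(self, rate: float) -> None:
        self._min_interval = 1.0 / rate
        self._last_send = 0.0

    async def wait(self) -> None:
        # Sleep only for whatever remains of the minimum interval.
        delay = self._min_interval - (time.monotonic() - self._last_send)
        if delay > 0:
            await asyncio.sleep(delay)
        self._last_send = time.monotonic()
```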
777
tests/test_telegram_commands.py
Normal file
777
tests/test_telegram_commands.py
Normal file
@@ -0,0 +1,777 @@
"""Tests for Telegram command handler."""

from unittest.mock import AsyncMock, patch

import pytest

from src.notifications.telegram_client import TelegramClient, TelegramCommandHandler


class TestCommandHandlerInit:
    """Test command handler initialization."""

    def test_init_with_client(self) -> None:
        """Handler initializes with TelegramClient."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        assert handler._client is client
        assert handler._polling_interval == 1.0
        assert handler._commands == {}
        assert handler._running is False

    def test_custom_polling_interval(self) -> None:
        """Handler accepts custom polling interval."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client, polling_interval=2.5)

        assert handler._polling_interval == 2.5


class TestCommandRegistration:
    """Test command registration."""

    @pytest.mark.asyncio
    async def test_register_command(self) -> None:
        """Commands can be registered."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        async def test_handler() -> None:
            pass

        handler.register_command("test", test_handler)

        assert "test" in handler._commands
        assert handler._commands["test"] is test_handler

    @pytest.mark.asyncio
    async def test_register_multiple_commands(self) -> None:
        """Multiple commands can be registered."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        async def handler1() -> None:
            pass

        async def handler2() -> None:
            pass

        handler.register_command("start", handler1)
        handler.register_command("help", handler2)

        assert len(handler._commands) == 2
        assert handler._commands["start"] is handler1
        assert handler._commands["help"] is handler2


class TestPollingLifecycle:
    """Test polling start/stop."""

    @pytest.mark.asyncio
    async def test_start_polling(self) -> None:
        """Polling can be started."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        with patch.object(handler, "_poll_loop", new_callable=AsyncMock):
            await handler.start_polling()

            assert handler._running is True
            assert handler._polling_task is not None

            await handler.stop_polling()

    @pytest.mark.asyncio
    async def test_start_polling_disabled_client(self) -> None:
        """Polling not started when client disabled."""
        client = TelegramClient(enabled=False)
        handler = TelegramCommandHandler(client)

        await handler.start_polling()

        assert handler._running is False
        assert handler._polling_task is None

    @pytest.mark.asyncio
    async def test_stop_polling(self) -> None:
        """Polling can be stopped."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        with patch.object(handler, "_poll_loop", new_callable=AsyncMock):
            await handler.start_polling()
            await handler.stop_polling()

        assert handler._running is False

    @pytest.mark.asyncio
    async def test_double_start_ignored(self) -> None:
        """Starting already running handler is ignored."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        with patch.object(handler, "_poll_loop", new_callable=AsyncMock):
            await handler.start_polling()
            task1 = handler._polling_task

            await handler.start_polling()  # Second start
            task2 = handler._polling_task

            # Should be the same task
            assert task1 is task2

            await handler.stop_polling()


class TestUpdateHandling:
    """Test update parsing and handling."""

    @pytest.mark.asyncio
    async def test_handle_valid_command(self) -> None:
        """Valid commands are executed."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        executed = False

        async def test_command() -> None:
            nonlocal executed
            executed = True

        handler.register_command("test", test_command)

        update = {
            "update_id": 1,
            "message": {
                "chat": {"id": 456},
                "text": "/test",
            },
        }

        await handler._handle_update(update)
        assert executed is True

    @pytest.mark.asyncio
    async def test_handle_unknown_command(self) -> None:
        """Unknown commands send help message."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        mock_resp = AsyncMock()
        mock_resp.status = 200
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
            update = {
                "update_id": 1,
                "message": {
                    "chat": {"id": 456},
                    "text": "/unknown",
                },
            }

            await handler._handle_update(update)

        # Should send error message
        assert mock_post.call_count == 1
        payload = mock_post.call_args.kwargs["json"]
        assert "Unknown command" in payload["text"]
        assert "/unknown" in payload["text"]

    @pytest.mark.asyncio
    async def test_ignore_unauthorized_chat(self) -> None:
        """Commands from unauthorized chats are ignored."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        executed = False

        async def test_command() -> None:
            nonlocal executed
            executed = True

        handler.register_command("test", test_command)

        update = {
            "update_id": 1,
            "message": {
                "chat": {"id": 999},  # Wrong chat_id
                "text": "/test",
            },
        }

        await handler._handle_update(update)
        assert executed is False

    @pytest.mark.asyncio
    async def test_ignore_non_command_text(self) -> None:
        """Non-command text is ignored."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        executed = False

        async def test_command() -> None:
            nonlocal executed
            executed = True

        handler.register_command("test", test_command)

        update = {
            "update_id": 1,
            "message": {
                "chat": {"id": 456},
                "text": "Hello, not a command",
            },
        }

        await handler._handle_update(update)
        assert executed is False

    @pytest.mark.asyncio
    async def test_handle_command_with_botname(self) -> None:
        """Commands with @botname suffix are handled correctly."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        executed = False

        async def test_command() -> None:
            nonlocal executed
            executed = True

        handler.register_command("start", test_command)

        update = {
            "update_id": 1,
            "message": {
                "chat": {"id": 456},
                "text": "/start@mybot",
            },
        }

        await handler._handle_update(update)
        assert executed is True

    @pytest.mark.asyncio
    async def test_handle_update_error_isolation(self) -> None:
        """Errors in handlers don't crash the system."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        async def failing_command() -> None:
            raise ValueError("Test error")

        handler.register_command("fail", failing_command)

        update = {
            "update_id": 1,
            "message": {
                "chat": {"id": 456},
                "text": "/fail",
            },
        }

        # Should not raise exception
        await handler._handle_update(update)


class TestTradingControlCommands:
    """Test trading control commands."""

    @pytest.mark.asyncio
    async def test_stop_command_pauses_trading(self) -> None:
        """Stop command clears pause event."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        # Create mock pause event
        import asyncio

        pause_event = asyncio.Event()
        pause_event.set()  # Initially active

        mock_resp = AsyncMock()
        mock_resp.status = 200
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        async def mock_stop() -> None:
            """Mock /stop handler."""
            if not pause_event.is_set():
                await client.send_message("⏸️ Trading is already paused")
                return

            pause_event.clear()
            await client.send_message(
                "<b>⏸️ Trading Paused</b>\n\n"
                "All trading operations have been suspended.\n"
                "Use /resume to restart trading."
            )

        handler.register_command("stop", mock_stop)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
            update = {
                "update_id": 1,
                "message": {
                    "chat": {"id": 456},
                    "text": "/stop",
                },
            }

            await handler._handle_update(update)

        # Verify pause event was cleared
        assert not pause_event.is_set()

        # Verify message was sent
        assert mock_post.call_count == 1
        payload = mock_post.call_args.kwargs["json"]
        assert "Trading Paused" in payload["text"]

    @pytest.mark.asyncio
    async def test_resume_command_resumes_trading(self) -> None:
        """Resume command sets pause event."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        # Create mock pause event (initially paused)
        import asyncio

        pause_event = asyncio.Event()
        pause_event.clear()  # Initially paused

        mock_resp = AsyncMock()
        mock_resp.status = 200
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        async def mock_resume() -> None:
            """Mock /resume handler."""
            if pause_event.is_set():
                await client.send_message("▶️ Trading is already active")
                return

            pause_event.set()
            await client.send_message(
                "<b>▶️ Trading Resumed</b>\n\n"
                "Trading operations have been restarted."
            )

        handler.register_command("resume", mock_resume)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
            update = {
                "update_id": 1,
                "message": {
                    "chat": {"id": 456},
                    "text": "/resume",
                },
            }

            await handler._handle_update(update)

        # Verify pause event was set
        assert pause_event.is_set()

        # Verify message was sent
        assert mock_post.call_count == 1
        payload = mock_post.call_args.kwargs["json"]
        assert "Trading Resumed" in payload["text"]

    @pytest.mark.asyncio
    async def test_stop_when_already_paused(self) -> None:
        """Stop command when already paused sends appropriate message."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        # Create mock pause event (already paused)
        import asyncio

        pause_event = asyncio.Event()
        pause_event.clear()

        mock_resp = AsyncMock()
        mock_resp.status = 200
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        async def mock_stop() -> None:
            """Mock /stop handler."""
            if not pause_event.is_set():
                await client.send_message("⏸️ Trading is already paused")
                return

            pause_event.clear()

        handler.register_command("stop", mock_stop)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
            update = {
                "update_id": 1,
                "message": {
                    "chat": {"id": 456},
                    "text": "/stop",
                },
            }

            await handler._handle_update(update)

        # Verify message was sent
        payload = mock_post.call_args.kwargs["json"]
        assert "already paused" in payload["text"]

    @pytest.mark.asyncio
    async def test_resume_when_already_active(self) -> None:
        """Resume command when already active sends appropriate message."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        # Create mock pause event (already active)
        import asyncio

        pause_event = asyncio.Event()
        pause_event.set()

        mock_resp = AsyncMock()
        mock_resp.status = 200
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        async def mock_resume() -> None:
            """Mock /resume handler."""
            if pause_event.is_set():
                await client.send_message("▶️ Trading is already active")
                return

            pause_event.set()

        handler.register_command("resume", mock_resume)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
            update = {
                "update_id": 1,
                "message": {
                    "chat": {"id": 456},
                    "text": "/resume",
                },
            }

            await handler._handle_update(update)

        # Verify message was sent
        payload = mock_post.call_args.kwargs["json"]
        assert "already active" in payload["text"]


class TestStatusCommands:
    """Test status query commands."""

    @pytest.mark.asyncio
    async def test_status_command_shows_trading_info(self) -> None:
        """Status command displays mode, markets, and P&L."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        mock_resp = AsyncMock()
        mock_resp.status = 200
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        async def mock_status() -> None:
            """Mock /status handler."""
            message = (
                "<b>📊 Trading Status</b>\n\n"
                "<b>Mode:</b> PAPER\n"
                "<b>Markets:</b> Korea, United States\n"
                "<b>Trading:</b> Active\n\n"
                "<b>Current P&L:</b> +2.50%\n"
                "<b>Circuit Breaker:</b> -3.0%"
            )
            await client.send_message(message)

        handler.register_command("status", mock_status)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
            update = {
                "update_id": 1,
                "message": {
                    "chat": {"id": 456},
                    "text": "/status",
                },
            }

            await handler._handle_update(update)

        # Verify message was sent
        assert mock_post.call_count == 1
        payload = mock_post.call_args.kwargs["json"]
        assert "Trading Status" in payload["text"]
        assert "PAPER" in payload["text"]
        assert "P&L" in payload["text"]

    @pytest.mark.asyncio
    async def test_status_command_error_handling(self) -> None:
        """Status command handles errors gracefully."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        mock_resp = AsyncMock()
        mock_resp.status = 200
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        async def mock_status_error() -> None:
            """Mock /status handler with error."""
            await client.send_message(
                "<b>⚠️ Error</b>\n\nFailed to retrieve trading status."
            )

        handler.register_command("status", mock_status_error)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
            update = {
                "update_id": 1,
                "message": {
                    "chat": {"id": 456},
                    "text": "/status",
                },
            }

            await handler._handle_update(update)

        # Should send error message
        payload = mock_post.call_args.kwargs["json"]
        assert "Error" in payload["text"]

    @pytest.mark.asyncio
    async def test_positions_command_shows_holdings(self) -> None:
        """Positions command displays account summary."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        mock_resp = AsyncMock()
        mock_resp.status = 200
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        async def mock_positions() -> None:
            """Mock /positions handler."""
            message = (
                "<b>💼 Account Summary</b>\n\n"
                "<b>Total Evaluation:</b> ₩10,500,000\n"
                "<b>Available Cash:</b> ₩5,000,000\n"
                "<b>Purchase Total:</b> ₩10,000,000\n"
                "<b>P&L:</b> +5.00%\n\n"
                "<i>Note: Individual position details require API enhancement</i>"
            )
            await client.send_message(message)

        handler.register_command("positions", mock_positions)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
            update = {
                "update_id": 1,
                "message": {
                    "chat": {"id": 456},
                    "text": "/positions",
                },
            }

            await handler._handle_update(update)

        # Verify message was sent
        assert mock_post.call_count == 1
        payload = mock_post.call_args.kwargs["json"]
        assert "Account Summary" in payload["text"]
        assert "Total Evaluation" in payload["text"]
        assert "P&L" in payload["text"]

    @pytest.mark.asyncio
    async def test_positions_command_empty_holdings(self) -> None:
        """Positions command handles empty portfolio."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        mock_resp = AsyncMock()
        mock_resp.status = 200
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        async def mock_positions_empty() -> None:
            """Mock /positions handler with no positions."""
            message = (
                "<b>💼 Account Summary</b>\n\n"
                "No balance information available."
            )
            await client.send_message(message)

        handler.register_command("positions", mock_positions_empty)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
            update = {
                "update_id": 1,
                "message": {
                    "chat": {"id": 456},
                    "text": "/positions",
                },
            }

            await handler._handle_update(update)

        # Verify message was sent
        payload = mock_post.call_args.kwargs["json"]
        assert "No balance information available" in payload["text"]

    @pytest.mark.asyncio
    async def test_positions_command_error_handling(self) -> None:
        """Positions command handles errors gracefully."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        mock_resp = AsyncMock()
        mock_resp.status = 200
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        async def mock_positions_error() -> None:
            """Mock /positions handler with error."""
            await client.send_message(
                "<b>⚠️ Error</b>\n\nFailed to retrieve positions."
            )

        handler.register_command("positions", mock_positions_error)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
            update = {
                "update_id": 1,
                "message": {
                    "chat": {"id": 456},
                    "text": "/positions",
                },
            }

            await handler._handle_update(update)

        # Should send error message
        payload = mock_post.call_args.kwargs["json"]
        assert "Error" in payload["text"]


class TestBasicCommands:
    """Test basic command implementations."""

    @pytest.mark.asyncio
    async def test_help_command_content(self) -> None:
        """Help command lists all available commands."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        mock_resp = AsyncMock()
        mock_resp.status = 200
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        async def mock_help() -> None:
            """Mock /help handler."""
            message = (
                "<b>📖 Available Commands</b>\n\n"
                "/help - Show available commands\n"
                "/status - Trading status (mode, markets, P&L)\n"
                "/positions - Current holdings\n"
                "/stop - Pause trading\n"
                "/resume - Resume trading"
            )
            await client.send_message(message)

        handler.register_command("help", mock_help)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
            update = {
                "update_id": 1,
                "message": {
                    "chat": {"id": 456},
                    "text": "/help",
                },
            }

            await handler._handle_update(update)

        # Verify message was sent
        assert mock_post.call_count == 1
        payload = mock_post.call_args.kwargs["json"]
        assert "Available Commands" in payload["text"]
        assert "/help" in payload["text"]
        assert "/status" in payload["text"]
        assert "/positions" in payload["text"]
        assert "/stop" in payload["text"]
        assert "/resume" in payload["text"]


class TestGetUpdates:
    """Test getUpdates API interaction."""

    @pytest.mark.asyncio
    async def test_get_updates_success(self) -> None:
        """getUpdates fetches and parses updates."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        mock_resp = AsyncMock()
        mock_resp.status = 200
        mock_resp.json = AsyncMock(
            return_value={
                "ok": True,
                "result": [
                    {"update_id": 1, "message": {"text": "/test"}},
                    {"update_id": 2, "message": {"text": "/help"}},
                ],
            }
        )
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp):
            updates = await handler._get_updates()

        assert len(updates) == 2
        assert updates[0]["update_id"] == 1
        assert updates[1]["update_id"] == 2
        assert handler._last_update_id == 2

    @pytest.mark.asyncio
    async def test_get_updates_api_error(self) -> None:
        """getUpdates handles API errors gracefully."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        mock_resp = AsyncMock()
        mock_resp.status = 400
        mock_resp.text = AsyncMock(return_value="Bad Request")
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp):
            updates = await handler._get_updates()

        assert updates == []

    @pytest.mark.asyncio
    async def test_get_updates_empty_result(self) -> None:
        """getUpdates handles empty results."""
        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
        handler = TelegramCommandHandler(client)

        mock_resp = AsyncMock()
        mock_resp.status = 200
        mock_resp.json = AsyncMock(return_value={"ok": True, "result": []})
        mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
        mock_resp.__aexit__ = AsyncMock(return_value=False)

        with patch("aiohttp.ClientSession.post", return_value=mock_resp):
            updates = await handler._get_updates()

        assert updates == []
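The `_last_update_id` bookkeeping verified in `TestGetUpdates` is Telegram's standard long-polling contract: passing `offset = last_update_id + 1` to `getUpdates` acknowledges everything already processed so it is not redelivered. A rough sketch of one poll iteration (the helper name and structure are assumptions, not the handler's actual code):

```python
import aiohttp

async def poll_once(
    session: aiohttp.ClientSession, token: str, last_update_id: int
) -> tuple[list[dict], int]:
    """Fetch new updates and advance the offset past the newest one."""
    url = f"https://api.telegram.org/bot{token}/getUpdates"
    async with session.post(url, json={"offset": last_update_id + 1}) as resp:
        if resp.status != 200:
            return [], last_update_id  # leave the offset unchanged on errors
        data = await resp.json()
    updates = data.get("result", [])
    if updates:
        last_update_id = updates[-1]["update_id"]
    return updates, last_update_id
```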
663
tests/test_token_efficiency.py
Normal file
663
tests/test_token_efficiency.py
Normal file
@@ -0,0 +1,663 @@
"""Tests for token efficiency optimization components.

Tests cover:
- Prompt compression and optimization
- Context selection logic
- Summarization
- Caching
- Token reduction metrics
"""

from __future__ import annotations

import sqlite3
import time

import pytest

from src.brain.cache import DecisionCache
from src.brain.context_selector import ContextSelector, DecisionType
from src.brain.gemini_client import TradeDecision
from src.brain.prompt_optimizer import PromptOptimizer, TokenMetrics
from src.context.layer import ContextLayer
from src.context.store import ContextStore
from src.context.summarizer import ContextSummarizer, SummaryStats

# ============================================================================
# Prompt Optimizer Tests
# ============================================================================


class TestPromptOptimizer:
    """Tests for PromptOptimizer."""

    def test_estimate_tokens(self):
        """Test token estimation."""
        optimizer = PromptOptimizer()

        # Empty text
        assert optimizer.estimate_tokens("") == 0

        # Short text (4 chars = 1 token estimate)
        assert optimizer.estimate_tokens("test") == 1

        # Longer text
        text = "This is a longer piece of text for testing token estimation."
        tokens = optimizer.estimate_tokens(text)
        assert tokens > 0
        assert tokens == len(text) // 4

    def test_count_tokens(self):
        """Test token counting metrics."""
        optimizer = PromptOptimizer()

        text = "Hello world, this is a test."
        metrics = optimizer.count_tokens(text)

        assert isinstance(metrics, TokenMetrics)
        assert metrics.char_count == len(text)
        assert metrics.word_count == 6
        assert metrics.estimated_tokens > 0

    def test_compress_json(self):
        """Test JSON compression."""
        optimizer = PromptOptimizer()

        data = {
            "action": "BUY",
            "confidence": 85,
            "rationale": "Strong uptrend",
        }

        compressed = optimizer.compress_json(data)

        # Should have no newlines and minimal whitespace
        assert "\n" not in compressed
        # Note: JSON values may contain spaces (e.g., "Strong uptrend")
        # but there should be no spaces around separators
        assert ": " not in compressed
        assert ", " not in compressed

        # Should be valid JSON
        import json

        parsed = json.loads(compressed)
        assert parsed == data

    def test_abbreviate_text(self):
        """Test text abbreviation."""
        optimizer = PromptOptimizer()

        text = "The current price is high and volume is increasing."
        abbreviated = optimizer.abbreviate_text(text)

        # Should contain abbreviations
        assert "cur" in abbreviated or "P" in abbreviated
        assert len(abbreviated) <= len(text)

    def test_abbreviate_text_aggressive(self):
        """Test aggressive text abbreviation."""
        optimizer = PromptOptimizer()

        text = "The price is increasing and the volume is high."
        abbreviated = optimizer.abbreviate_text(text, aggressive=True)

        # Should be shorter
        assert len(abbreviated) < len(text)

        # Should have removed articles
        assert "the" not in abbreviated.lower()

    def test_build_compressed_prompt(self):
        """Test compressed prompt building."""
        optimizer = PromptOptimizer()

        market_data = {
            "stock_code": "005930",
            "current_price": 75000,
            "market_name": "Korean stock market",
        }

        prompt = optimizer.build_compressed_prompt(market_data)

        # Should be much shorter than original
        assert len(prompt) < 300
        assert "005930" in prompt
        assert "75000" in prompt

    def test_build_compressed_prompt_no_instructions(self):
        """Test compressed prompt without instructions."""
        optimizer = PromptOptimizer()

        market_data = {
            "stock_code": "AAPL",
            "current_price": 150.5,
            "market_name": "United States",
        }

        prompt = optimizer.build_compressed_prompt(market_data, include_instructions=False)

        # Should be very short (data only)
        assert len(prompt) < 100
        assert "AAPL" in prompt

    def test_truncate_context(self):
        """Test context truncation."""
        optimizer = PromptOptimizer()

        context = {
            "price": 100.5,
            "volume": 1000000,
            "sentiment": 0.8,
            "extra_data": "Some long text that should be truncated",
        }

        # Truncate to small budget
        truncated = optimizer.truncate_context(context, max_tokens=10)

        # Should have fewer keys
        assert len(truncated) <= len(context)

    def test_truncate_context_with_priority(self):
        """Test context truncation with priority keys."""
        optimizer = PromptOptimizer()

        context = {
            "price": 100.5,
            "volume": 1000000,
            "sentiment": 0.8,
            "extra_data": "Some data",
        }

        priority_keys = ["price", "sentiment"]
        truncated = optimizer.truncate_context(context, max_tokens=20, priority_keys=priority_keys)

        # Priority keys should be included
        assert "price" in truncated
        assert "sentiment" in truncated

    def test_calculate_compression_ratio(self):
        """Test compression ratio calculation."""
        optimizer = PromptOptimizer()

        original = "This is a very long piece of text that should be compressed significantly."
        compressed = "Short text"

        ratio = optimizer.calculate_compression_ratio(original, compressed)

        # Ratio should be > 1 (original is longer)
        assert ratio > 1.0
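The `test_compress_json` assertions (no newlines, no spaces around separators, lossless round-trip) are exactly what `json.dumps` with compact separators provides, so `compress_json` likely reduces to a one-liner. A sketch, assuming no additional key filtering happens:

```python
# Minimal sketch of a compress_json that satisfies the test's assertions.
import json
from typing import Any


def compress_json(data: dict[str, Any]) -> str:
    # separators=(",", ":") removes the default ", " and ": " padding, and
    # json.dumps never emits newlines unless indent is requested.
    return json.dumps(data, separators=(",", ":"), ensure_ascii=False)


compact = compress_json({"action": "BUY", "confidence": 85})
assert compact == '{"action":"BUY","confidence":85}'
assert json.loads(compact) == {"action": "BUY", "confidence": 85}
```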
# ============================================================================
# Context Selector Tests
# ============================================================================


class TestContextSelector:
    """Tests for ContextSelector."""

    @pytest.fixture
    def store(self):
        """Create in-memory ContextStore."""
        conn = sqlite3.connect(":memory:")
        # Create tables
        conn.execute(
            """
            CREATE TABLE context_metadata (
                layer TEXT PRIMARY KEY,
                description TEXT,
                retention_days INTEGER,
                aggregation_source TEXT
            )
            """
        )
        conn.execute(
            """
            CREATE TABLE contexts (
                layer TEXT,
                timeframe TEXT,
                key TEXT,
                value TEXT,
                created_at TEXT,
                updated_at TEXT,
                PRIMARY KEY (layer, timeframe, key)
            )
            """
        )
        conn.commit()
        return ContextStore(conn)

    def test_select_layers_normal(self, store):
        """Test layer selection for normal decisions."""
        selector = ContextSelector(store)

        layers = selector.select_layers(DecisionType.NORMAL)

        # Should only select L7 (real-time)
        assert layers == [ContextLayer.L7_REALTIME]

    def test_select_layers_strategic(self, store):
        """Test layer selection for strategic decisions."""
        selector = ContextSelector(store)

        layers = selector.select_layers(DecisionType.STRATEGIC)

        # Should select L7 + L6 + L5
        assert ContextLayer.L7_REALTIME in layers
        assert ContextLayer.L6_DAILY in layers
        assert ContextLayer.L5_WEEKLY in layers
        assert len(layers) == 3

    def test_select_layers_major_event(self, store):
        """Test layer selection for major events."""
        selector = ContextSelector(store)

        layers = selector.select_layers(DecisionType.MAJOR_EVENT)

        # Should select all layers
        assert len(layers) == 7
        assert ContextLayer.L1_LEGACY in layers
        assert ContextLayer.L7_REALTIME in layers

    def test_score_layer_relevance(self, store):
        """Test layer relevance scoring."""
        selector = ContextSelector(store)

        # Add some data first so scores aren't penalized
        store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
        store.set_context(ContextLayer.L1_LEGACY, "legacy", "lesson", "test")

        # L7 should have high score for normal decisions
        score = selector.score_layer_relevance(ContextLayer.L7_REALTIME, DecisionType.NORMAL)
        assert score == 1.0

        # L1 should have low score for normal decisions
        score = selector.score_layer_relevance(ContextLayer.L1_LEGACY, DecisionType.NORMAL)
        assert score == 0.0

        # L1 should have high score for major events
        score = selector.score_layer_relevance(ContextLayer.L1_LEGACY, DecisionType.MAJOR_EVENT)
        assert score == 1.0

    def test_select_with_scoring(self, store):
        """Test selection with relevance scoring."""
        selector = ContextSelector(store)

        # Add data so layers aren't penalized
        store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)

        selection = selector.select_with_scoring(DecisionType.NORMAL, min_score=0.5)

        # Should only select high-relevance layers
        assert len(selection.layers) >= 1
        assert ContextLayer.L7_REALTIME in selection.layers
        assert all(selection.relevance_scores[layer] >= 0.5 for layer in selection.layers)

    def test_get_context_data(self, store):
        """Test context data retrieval."""
        selector = ContextSelector(store)

        # Add some test data
        store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)
        store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "volume", 1000000)

        context_data = selector.get_context_data([ContextLayer.L7_REALTIME])

        # Should retrieve data
        assert "L7_REALTIME" in context_data
        assert "price" in context_data["L7_REALTIME"]
        assert context_data["L7_REALTIME"]["price"] == 100.5

    def test_estimate_context_tokens(self, store):
        """Test context token estimation."""
        selector = ContextSelector(store)

        context_data = {
            "L7_REALTIME": {"price": 100.5, "volume": 1000000},
            "L6_DAILY": {"avg_price": 99.8, "avg_volume": 950000},
        }

        tokens = selector.estimate_context_tokens(context_data)

        # Should estimate tokens
        assert tokens > 0

    def test_optimize_context_for_budget(self, store):
        """Test context optimization for token budget."""
        selector = ContextSelector(store)

        # Add test data
        store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)

        # Get optimized context within budget
        context = selector.optimize_context_for_budget(DecisionType.NORMAL, max_tokens=50)

        # Should return data within budget
        tokens = selector.estimate_context_tokens(context)
        assert tokens <= 50
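The three `select_layers` tests fix a simple escalation policy: routine decisions read only the real-time layer, strategic ones add the daily and weekly layers, and major events pull all seven. A sketch of that mapping; the members `L1_LEGACY`, `L5_WEEKLY`, `L6_DAILY`, and `L7_REALTIME` appear in the tests, while the middle layer names here are assumptions:

```python
# Sketch of the DecisionType -> layer mapping the tests above assert.
from enum import Enum


class ContextLayer(Enum):
    L1_LEGACY = 1
    L2_YEARLY = 2      # assumed name
    L3_QUARTERLY = 3   # assumed name
    L4_MONTHLY = 4     # assumed name
    L5_WEEKLY = 5
    L6_DAILY = 6
    L7_REALTIME = 7


class DecisionType(Enum):
    NORMAL = "normal"
    STRATEGIC = "strategic"
    MAJOR_EVENT = "major_event"


def select_layers(decision_type: DecisionType) -> list[ContextLayer]:
    if decision_type is DecisionType.NORMAL:
        return [ContextLayer.L7_REALTIME]  # cheapest path: real-time only
    if decision_type is DecisionType.STRATEGIC:
        return [ContextLayer.L7_REALTIME, ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY]
    return list(ContextLayer)  # major events read all 7 layers
```

Reading fewer layers for routine decisions is where the token savings come from; the full seven-layer read is reserved for events that justify the extra context.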
# ============================================================================
# Context Summarizer Tests
# ============================================================================


class TestContextSummarizer:
    """Tests for ContextSummarizer."""

    @pytest.fixture
    def store(self):
        """Create in-memory ContextStore."""
        conn = sqlite3.connect(":memory:")
        conn.execute(
            """
            CREATE TABLE context_metadata (
                layer TEXT PRIMARY KEY,
                description TEXT,
                retention_days INTEGER,
                aggregation_source TEXT
            )
            """
        )
        conn.execute(
            """
            CREATE TABLE contexts (
                layer TEXT,
                timeframe TEXT,
                key TEXT,
                value TEXT,
                created_at TEXT,
                updated_at TEXT,
                PRIMARY KEY (layer, timeframe, key)
            )
            """
        )
        conn.commit()
        return ContextStore(conn)

    def test_summarize_numeric_values(self, store):
        """Test numeric value summarization."""
        summarizer = ContextSummarizer(store)

        values = [10.0, 20.0, 30.0, 40.0, 50.0]
        stats = summarizer.summarize_numeric_values(values)

        assert isinstance(stats, SummaryStats)
        assert stats.count == 5
        assert stats.mean == 30.0
        assert stats.min == 10.0
        assert stats.max == 50.0
        assert stats.std is not None

    def test_summarize_numeric_values_trend(self, store):
        """Test trend detection in numeric values."""
        summarizer = ContextSummarizer(store)

        # Uptrend
        values_up = [10.0, 15.0, 20.0, 25.0, 30.0, 35.0]
        stats_up = summarizer.summarize_numeric_values(values_up)
        assert stats_up.trend == "up"

        # Downtrend
        values_down = [35.0, 30.0, 25.0, 20.0, 15.0, 10.0]
        stats_down = summarizer.summarize_numeric_values(values_down)
        assert stats_down.trend == "down"

        # Flat
        values_flat = [20.0, 20.1, 19.9, 20.0, 20.1, 19.9]
        stats_flat = summarizer.summarize_numeric_values(values_flat)
        assert stats_flat.trend == "flat"

    def test_summarize_layer(self, store):
        """Test layer summarization."""
        summarizer = ContextSummarizer(store)

        # Add test data
        store.set_context(ContextLayer.L6_DAILY, "2026-02-04", "price", 100.5)
        store.set_context(ContextLayer.L6_DAILY, "2026-02-04", "volume", 1000000)

        summary = summarizer.summarize_layer(ContextLayer.L6_DAILY)

        # Should have summary
        assert "total_entries" in summary
        assert summary["total_entries"] > 0

    def test_create_compact_summary(self, store):
        """Test compact summary creation."""
        summarizer = ContextSummarizer(store)

        # Add test data
        store.set_context(ContextLayer.L7_REALTIME, "2026-02-04", "price", 100.5)

        layers = [ContextLayer.L7_REALTIME, ContextLayer.L6_DAILY]
        summary = summarizer.create_compact_summary(layers, top_n_metrics=3)

        # Should have summaries for layers
        assert "L7_REALTIME" in summary

    def test_format_summary_for_prompt(self, store):
        """Test summary formatting for prompt."""
        summarizer = ContextSummarizer(store)

        summary = {
            "L7_REALTIME": {
                "price": {"avg": 100.5, "trend": "up"},
                "volume": {"avg": 1000000, "trend": "flat"},
            }
        }

        formatted = summarizer.format_summary_for_prompt(summary)

        # Should be formatted string
        assert isinstance(formatted, str)
        assert "L7_REALTIME" in formatted
        assert "100.5" in formatted or "100.50" in formatted
# ============================================================================
# Decision Cache Tests
# ============================================================================


class TestDecisionCache:
    """Tests for DecisionCache."""

    def test_cache_init(self):
        """Test cache initialization."""
        cache = DecisionCache(ttl_seconds=60, max_size=100)

        assert cache.ttl_seconds == 60
        assert cache.max_size == 100

    def test_cache_miss(self):
        """Test cache miss."""
        cache = DecisionCache()

        market_data = {"stock_code": "005930", "current_price": 75000}

        decision = cache.get(market_data)

        # Should be None (cache miss)
        assert decision is None

        metrics = cache.get_metrics()
        assert metrics.cache_misses == 1
        assert metrics.cache_hits == 0

    def test_cache_hit(self):
        """Test cache hit."""
        cache = DecisionCache()

        market_data = {"stock_code": "005930", "current_price": 75000}
        decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")

        # Set cache
        cache.set(market_data, decision)

        # Get from cache
        cached = cache.get(market_data)

        assert cached is not None
        assert cached.action == "HOLD"
        assert cached.confidence == 50

        metrics = cache.get_metrics()
        assert metrics.cache_hits == 1

    def test_cache_ttl_expiration(self):
        """Test cache TTL expiration."""
        cache = DecisionCache(ttl_seconds=1)  # 1 second TTL

        market_data = {"stock_code": "005930", "current_price": 75000}
        decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")

        # Set cache
        cache.set(market_data, decision)

        # Should hit immediately
        cached = cache.get(market_data)
        assert cached is not None

        # Wait for expiration
        time.sleep(1.1)

        # Should miss after expiration
        cached = cache.get(market_data)
        assert cached is None

    def test_cache_max_size(self):
        """Test cache max size eviction."""
        cache = DecisionCache(max_size=2)

        decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")

        # Add 3 entries (exceeds max_size)
        for i in range(3):
            market_data = {"stock_code": f"00{i}", "current_price": 1000 * i}
            cache.set(market_data, decision)

        metrics = cache.get_metrics()

        # Should have evicted 1 entry
        assert metrics.total_entries == 2
        assert metrics.evictions == 1

    def test_invalidate_all(self):
        """Test invalidate all cache entries."""
        cache = DecisionCache()

        decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")

        # Add entries
        for i in range(3):
            market_data = {"stock_code": f"00{i}", "current_price": 1000}
            cache.set(market_data, decision)

        # Invalidate all
        count = cache.invalidate()

        assert count == 3

        metrics = cache.get_metrics()
        assert metrics.total_entries == 0

    def test_invalidate_by_stock(self):
        """Test invalidate cache by stock code."""
        cache = DecisionCache()

        decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")

        # Add entries for different stocks
        cache.set({"stock_code": "005930", "current_price": 75000}, decision)
        cache.set({"stock_code": "000660", "current_price": 50000}, decision)

        # Invalidate specific stock
        count = cache.invalidate("005930")

        assert count >= 1

        # Other stock should still be cached
        cached = cache.get({"stock_code": "000660", "current_price": 50000})
        assert cached is not None

    def test_cleanup_expired(self):
        """Test cleanup of expired entries."""
        cache = DecisionCache(ttl_seconds=1)

        decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")

        # Add entry
        cache.set({"stock_code": "005930", "current_price": 75000}, decision)

        # Wait for expiration
        time.sleep(1.1)

        # Cleanup
        count = cache.cleanup_expired()

        assert count == 1

        metrics = cache.get_metrics()
        assert metrics.total_entries == 0

    def test_should_cache_decision(self):
        """Test decision caching criteria."""
        cache = DecisionCache()

        # HOLD decisions should be cached
        hold_decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
        assert cache.should_cache_decision(hold_decision) is True

        # High confidence BUY should be cached
        buy_decision = TradeDecision(action="BUY", confidence=95, rationale="Test")
        assert cache.should_cache_decision(buy_decision) is True

        # Low confidence BUY should not be cached
        low_conf_buy = TradeDecision(action="BUY", confidence=60, rationale="Test")
        assert cache.should_cache_decision(low_conf_buy) is False

    def test_cache_hit_rate(self):
        """Test cache hit rate calculation."""
        cache = DecisionCache()

        decision = TradeDecision(action="HOLD", confidence=50, rationale="Test")
        market_data = {"stock_code": "005930", "current_price": 75000}

        # First request (miss)
        cache.get(market_data)

        # Set cache
        cache.set(market_data, decision)

        # Second request (hit)
        cache.get(market_data)

        # Third request (hit)
        cache.get(market_data)

        metrics = cache.get_metrics()

        # 1 miss, 2 hits out of 3 requests
        assert metrics.total_requests == 3
        assert metrics.cache_hits == 2
        assert metrics.cache_misses == 1
        assert metrics.hit_rate == pytest.approx(2 / 3)

    def test_reset_metrics(self):
        """Test metrics reset."""
        cache = DecisionCache()

        market_data = {"stock_code": "005930", "current_price": 75000}

        # Generate some activity
        cache.get(market_data)
        cache.get(market_data)

        # Reset
        cache.reset_metrics()

        metrics = cache.get_metrics()
        assert metrics.total_requests == 0
        assert metrics.cache_hits == 0
        assert metrics.cache_misses == 0
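Taken together, these tests imply a cache keyed on the market snapshot (identical `market_data` hits, per-stock invalidation works), with TTL expiry checked on read and size-based eviction on write. A sketch of that shape, where the key format and FIFO eviction order are assumptions and hit/miss bookkeeping is omitted:

```python
# Illustrative sketch; key layout "<stock_code>:<current_price>" and FIFO
# eviction are assumptions. Metrics tracking is omitted for brevity.
from __future__ import annotations

import time
from collections import OrderedDict
from typing import Any


class DecisionCacheSketch:
    def __init__(self, ttl_seconds: int = 60, max_size: int = 100) -> None:
        self.ttl_seconds = ttl_seconds
        self.max_size = max_size
        self._entries: OrderedDict[str, tuple[float, Any]] = OrderedDict()

    def _key(self, market_data: dict[str, Any]) -> str:
        return f"{market_data['stock_code']}:{market_data['current_price']}"

    def set(self, market_data: dict[str, Any], decision: Any) -> None:
        if len(self._entries) >= self.max_size:
            self._entries.popitem(last=False)  # evict the oldest entry
        self._entries[self._key(market_data)] = (time.monotonic(), decision)

    def get(self, market_data: dict[str, Any]) -> Any | None:
        key = self._key(market_data)
        item = self._entries.get(key)
        if item is None:
            return None
        stored_at, decision = item
        if time.monotonic() - stored_at > self.ttl_seconds:
            del self._entries[key]  # expired: behaves as a miss
            return None
        return decision

    def invalidate(self, stock_code: str | None = None) -> int:
        keys = [
            k for k in self._entries
            if stock_code is None or k.startswith(f"{stock_code}:")
        ]
        for k in keys:
            del self._entries[k]
        return len(keys)

    def should_cache_decision(self, decision: Any) -> bool:
        # HOLD is always cacheable; directional trades only at high
        # confidence. The 90 threshold is an assumption (the tests only pin
        # that confidence 95 caches and 60 does not).
        return decision.action == "HOLD" or decision.confidence >= 90
```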
@@ -2,6 +2,7 @@

from __future__ import annotations

import asyncio
import sqlite3
from typing import Any
from unittest.mock import AsyncMock
@@ -338,6 +339,28 @@ class TestMarketScanner:
        assert metrics.stock_code == "AAPL"
        assert metrics.current_price == 150.50

    @pytest.mark.asyncio
    async def test_scan_stock_overseas_empty_price(
        self,
        scanner: MarketScanner,
        mock_overseas_broker: OverseasBroker,
        context_store: ContextStore,
    ) -> None:
        """Test scanning overseas stock with empty price string (issue #49)."""
        mock_overseas_broker.get_overseas_price.return_value = {
            "output": {
                "last": "",  # Empty string
                "tvol": "",  # Empty string
            }
        }

        market = MARKETS["US_NASDAQ"]
        metrics = await scanner.scan_stock("AAPL", market)

        assert metrics is not None
        assert metrics.stock_code == "AAPL"
        assert metrics.current_price == 0.0  # Should default to 0.0
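This regression test exists because KIS overseas quote responses can carry empty strings in `last` and `tvol`, and a bare `float("")` raises `ValueError`. The fix presumably funnels those fields through a tolerant parser; the helper below is hypothetical, since the production code path is not shown in this diff:

```python
# Hypothetical helper; the actual scanner code is not part of this diff.
from __future__ import annotations


def _safe_float(raw: str | None, default: float = 0.0) -> float:
    """Parse a KIS price field, tolerating empty strings and junk values."""
    try:
        return float(raw) if raw else default
    except (TypeError, ValueError):
        return default


assert _safe_float("") == 0.0        # empty string -> 0.0, as the test expects
assert _safe_float("150.50") == 150.5
assert _safe_float(None) == 0.0
```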

    @pytest.mark.asyncio
    async def test_scan_stock_error_handling(
        self,
@@ -509,3 +532,45 @@ class TestMarketScanner:
        new_additions = [code for code in updated if code not in current_watchlist]
        assert len(new_additions) <= 1
        assert len(updated) == len(current_watchlist)

    @pytest.mark.asyncio
    async def test_scan_market_respects_concurrency_limit(
        self,
        mock_broker: KISBroker,
        mock_overseas_broker: OverseasBroker,
        volatility_analyzer: VolatilityAnalyzer,
        context_store: ContextStore,
    ) -> None:
        """scan_market should limit concurrent scans to max_concurrent_scans."""
        max_concurrent = 2
        scanner = MarketScanner(
            broker=mock_broker,
            overseas_broker=mock_overseas_broker,
            volatility_analyzer=volatility_analyzer,
            context_store=context_store,
            top_n=5,
            max_concurrent_scans=max_concurrent,
        )

        # Track peak concurrency
        active_count = 0
        peak_count = 0

        original_scan = scanner.scan_stock

        async def tracking_scan(code: str, market: Any) -> VolatilityMetrics:
            nonlocal active_count, peak_count
            active_count += 1
            peak_count = max(peak_count, active_count)
            await asyncio.sleep(0.05)  # Simulate API call duration
            active_count -= 1
            return VolatilityMetrics(code, 50000, 500, 1.0, 1.0, 1.0, 1.0, 10.0, 50.0)

        scanner.scan_stock = tracking_scan  # type: ignore[method-assign]

        market = MARKETS["KR"]
        stock_codes = ["001", "002", "003", "004", "005", "006"]

        await scanner.scan_market(market, stock_codes)

        assert peak_count <= max_concurrent
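The test verifies only the externally visible property: with `max_concurrent_scans=2` and six symbols, no more than two `scan_stock` calls are ever in flight. The standard way to get that is an `asyncio.Semaphore` wrapped around each scan before `gather`, as in this sketch (the scanner's internals are assumed):

```python
# Sketch of a semaphore-bounded scan_market; the details are assumptions.
import asyncio
from typing import Any, Awaitable, Callable


async def scan_market_sketch(
    scan_stock: Callable[[str, Any], Awaitable[Any]],
    stock_codes: list[str],
    market: Any,
    max_concurrent_scans: int,
) -> list[Any]:
    semaphore = asyncio.Semaphore(max_concurrent_scans)

    async def bounded_scan(code: str) -> Any:
        async with semaphore:  # at most max_concurrent_scans holders at once
            return await scan_stock(code, market)

    # gather launches every task, but the semaphore gates actual execution.
    return await asyncio.gather(*(bounded_scan(code) for code in stock_codes))
```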