Compare commits

..

1 Commits

Author SHA1 Message Date
agentson
e7de9aa894 Add scanner diagnostics for zero-candidate trade stalls 2026-02-18 00:11:53 +09:00
38 changed files with 433 additions and 10937 deletions

View File

@@ -1,82 +1,36 @@
# ============================================================
# The Ouroboros — Environment Configuration
# ============================================================
# Copy this file to .env and fill in your values.
# Lines starting with # are comments.
# ============================================================
# Korea Investment Securities API # Korea Investment Securities API
# ============================================================
KIS_APP_KEY=your_app_key_here KIS_APP_KEY=your_app_key_here
KIS_APP_SECRET=your_app_secret_here KIS_APP_SECRET=your_app_secret_here
KIS_ACCOUNT_NO=12345678-01 KIS_ACCOUNT_NO=12345678-01
KIS_BASE_URL=https://openapivts.koreainvestment.com:9443
# Paper trading (VTS): https://openapivts.koreainvestment.com:29443
# Live trading: https://openapi.koreainvestment.com:9443
KIS_BASE_URL=https://openapivts.koreainvestment.com:29443
# ============================================================
# Trading Mode
# ============================================================
# paper = 모의투자 (safe for testing), live = 실전투자 (real money)
MODE=paper
# daily = batch per session, realtime = per-stock continuous scan
TRADE_MODE=daily
# Comma-separated market codes: KR, US, JP, HK, CN, VN
ENABLED_MARKETS=KR,US
# Simulated USD cash for paper (VTS) overseas trading.
# VTS overseas balance API often returns 0; this value is used as fallback.
# Set to 0 to disable fallback (not used in live mode).
PAPER_OVERSEAS_CASH=50000.0
# ============================================================
# Google Gemini # Google Gemini
# ============================================================
GEMINI_API_KEY=your_gemini_api_key_here GEMINI_API_KEY=your_gemini_api_key_here
# Recommended: gemini-2.0-flash-exp or gemini-1.5-pro GEMINI_MODEL=gemini-pro
GEMINI_MODEL=gemini-2.0-flash-exp
# ============================================================
# Risk Management # Risk Management
# ============================================================
CIRCUIT_BREAKER_PCT=-3.0 CIRCUIT_BREAKER_PCT=-3.0
FAT_FINGER_PCT=30.0 FAT_FINGER_PCT=30.0
CONFIDENCE_THRESHOLD=80 CONFIDENCE_THRESHOLD=80
# ============================================================
# Database # Database
# ============================================================
DB_PATH=data/trade_logs.db DB_PATH=data/trade_logs.db
# ============================================================ # Rate Limiting (requests per second for KIS API)
# Rate Limiting # Reduced to 5.0 to avoid "초당 거래건수 초과" errors (EGW00201)
# ============================================================ RATE_LIMIT_RPS=5.0
# KIS API real limit is ~2 RPS. Keep at 2.0 for maximum safety.
# Increasing this risks EGW00201 "초당 거래건수 초과" errors.
RATE_LIMIT_RPS=2.0
# ============================================================ # Trading Mode (paper / live)
# External Data APIs (optional) MODE=paper
# ============================================================
# External Data APIs (optional — for enhanced decision-making)
# NEWS_API_KEY=your_news_api_key_here # NEWS_API_KEY=your_news_api_key_here
# NEWS_API_PROVIDER=alphavantage # NEWS_API_PROVIDER=alphavantage
# MARKET_DATA_API_KEY=your_market_data_key_here # MARKET_DATA_API_KEY=your_market_data_key_here
# ============================================================
# Telegram Notifications (optional) # Telegram Notifications (optional)
# ============================================================
# Get bot token from @BotFather on Telegram # Get bot token from @BotFather on Telegram
# Get chat ID from @userinfobot or your chat # Get chat ID from @userinfobot or your chat
# TELEGRAM_BOT_TOKEN=1234567890:ABCdefGHIjklMNOpqrsTUVwxyz # TELEGRAM_BOT_TOKEN=1234567890:ABCdefGHIjklMNOpqrsTUVwxyz
# TELEGRAM_CHAT_ID=123456789 # TELEGRAM_CHAT_ID=123456789
# TELEGRAM_ENABLED=true # TELEGRAM_ENABLED=true
# ============================================================
# Dashboard (optional)
# ============================================================
# DASHBOARD_ENABLED=false
# DASHBOARD_HOST=127.0.0.1
# DASHBOARD_PORT=8080

View File

@@ -94,7 +94,6 @@ Smart Scanner runs in `TRADE_MODE=realtime` only. Daily mode uses static watchli
- **[Testing](docs/testing.md)** — Test structure, coverage requirements, writing tests - **[Testing](docs/testing.md)** — Test structure, coverage requirements, writing tests
- **[Agent Policies](docs/agents.md)** — Prime directives, constraints, prohibited actions - **[Agent Policies](docs/agents.md)** — Prime directives, constraints, prohibited actions
- **[Requirements Log](docs/requirements-log.md)** — User requirements and feedback tracking - **[Requirements Log](docs/requirements-log.md)** — User requirements and feedback tracking
- **[Live Trading Checklist](docs/live-trading-checklist.md)** — 모의→실전 전환 체크리스트
## Core Principles ## Core Principles
@@ -171,7 +170,7 @@ Markets auto-detected based on timezone and enabled in `ENABLED_MARKETS` env var
- `src/core/risk_manager.py` is **READ-ONLY** — changes require human approval - `src/core/risk_manager.py` is **READ-ONLY** — changes require human approval
- Circuit breaker at -3.0% P&L — may only be made **stricter** - Circuit breaker at -3.0% P&L — may only be made **stricter**
- Fat-finger protection: max 30% of cash per order — always enforced - Fat-finger protection: max 30% of cash per order — always enforced
- Confidence 임계값 (market_outlook별, 낮출 수 없음): BEARISH ≥ 90, NEUTRAL/기본 ≥ 80, BULLISH ≥ 75 - Confidence < 80 → force HOLD — cannot be weakened
- All code changes → corresponding tests → coverage ≥ 80% - All code changes → corresponding tests → coverage ≥ 80%
## Contributing ## Contributing

View File

@@ -192,27 +192,6 @@ When `TELEGRAM_COMMANDS_ENABLED=true` (default), the bot accepts these interacti
Commands are only processed from the authorized `TELEGRAM_CHAT_ID`. Commands are only processed from the authorized `TELEGRAM_CHAT_ID`.
## KIS API TR_ID 참조 문서
**TR_ID를 추가하거나 수정할 때 반드시 공식 문서를 먼저 확인할 것.**
공식 문서: `docs/한국투자증권_오픈API_전체문서_20260221_030000.xlsx`
> ⚠️ 커뮤니티 블로그, GitHub 예제 등 비공식 자료의 TR_ID는 오래되거나 틀릴 수 있음.
> 실제로 `VTTT1006U`(미국 매도 — 잘못됨)가 오랫동안 코드에 남아있던 사례가 있음 (Issue #189).
### 주요 TR_ID 목록
| 구분 | 모의투자 TR_ID | 실전투자 TR_ID | 시트명 |
|------|---------------|---------------|--------|
| 해외주식 매수 (미국) | `VTTT1002U` | `TTTT1002U` | 해외주식 주문 |
| 해외주식 매도 (미국) | `VTTT1001U` | `TTTT1006U` | 해외주식 주문 |
새로운 TR_ID가 필요할 때:
1. 위 xlsx 파일에서 해당 거래 유형의 시트를 찾는다.
2. 모의투자(`VTTT`) / 실전투자(`TTTT`) 컬럼을 구분하여 정확한 값을 사용한다.
3. 코드에 출처 주석을 남긴다: `# Source: 한국투자증권_오픈API_전체문서 — '<시트명>' 시트`
## Environment Setup ## Environment Setup
```bash ```bash

View File

@@ -0,0 +1,29 @@
# Issue: Realtime 모드에서 거래가 지속적으로 0건
## Summary
`realtime` 실행 중 주문 단계까지 진입하지 못하고, 스캐너 단계에서 후보가 0건으로 반복 종료된다.
## Observed
- 로그에서 반복적으로 `Smart Scanner: No candidates ... — no trades` 출력
- 해외 시장에서 `Overseas ranking endpoint unavailable (404)` 다수 발생
- fallback 심볼 스캔도 `0 candidates`로 종료
- `data/trade_logs.db` 기준 최근 구간에 `BUY/SELL` 없음
## Impact
- 매매 전략 품질과 무관하게 주문 경로가 실행되지 않아 실질 거래 불가
- 장애 원인을 로그만으로 즉시 분해하기 어려움
## Root-Cause Hypothesis
- 스캐너 필터(가격/변동성) 단계에서 대부분 탈락
- 해외 랭킹 API 불가 시 입력 유니버스가 빈 상태가 되어 후보 생성 실패
- 기존 로그는 최종 결과(0 candidates)만 보여 원인별 분해가 어려움
## Acceptance Criteria
- 스캔 1회마다 탈락 사유가 구조화되어 로그에 남아야 함
- 국내/해외(랭킹/폴백) 경로 모두 동일한 진단 지표를 제공해야 함
- 운영자가 로그만 보고 `왜 0 candidates인지`를 즉시 판단 가능해야 함
## Scope
- 이번 이슈는 **진단 가능성 개선(Observability)** 에 한정
- 후보 생성 전략 변경(기본 유니버스 강제 추가 등)은 별도 이슈로 분리

View File

@@ -1,131 +0,0 @@
# 실전 전환 체크리스트
모의 거래(paper)에서 실전(live)으로 전환하기 전에 아래 항목을 **순서대로** 모두 확인하세요.
---
## 1. 사전 조건
### 1-1. KIS OpenAPI 실전 계좌 준비
- [ ] 한국투자증권 계좌 개설 완료 (일반 위탁 계좌)
- [ ] OpenAPI 실전 사용 신청 (KIS 홈페이지 → Open API → 서비스 신청)
- [ ] 실전용 APP_KEY / APP_SECRET 발급 완료
- [ ] KIS_ACCOUNT_NO 형식 확인: `XXXXXXXX-XX` (8자리-2자리)
### 1-2. 리스크 파라미터 검토
- [ ] `CIRCUIT_BREAKER_PCT` 확인: 기본값 -3.0% (더 엄격하게 조정 권장)
- [ ] `FAT_FINGER_PCT` 확인: 기본값 30.0% (1회 주문 최대 잔고 대비 %)
- [ ] `CONFIDENCE_THRESHOLD` 확인: BEARISH ≥ 90, NEUTRAL ≥ 80, BULLISH ≥ 75
- [ ] 초기 투자금 결정 및 해외 주식 운용 한도 설정
### 1-3. 시스템 요건
- [ ] 커버리지 80% 이상 유지 확인: `pytest --cov=src`
- [ ] 타입 체크 통과: `mypy src/ --strict`
- [ ] Lint 통과: `ruff check src/ tests/`
---
## 2. 환경 설정
### 2-1. `.env` 파일 수정
```bash
# 1. KIS 실전 URL로 변경 (모의: openapivts 포트 29443)
KIS_BASE_URL=https://openapi.koreainvestment.com:9443
# 2. 실전 APP_KEY / APP_SECRET으로 교체
KIS_APP_KEY=<실전_APP_KEY>
KIS_APP_SECRET=<실전_APP_SECRET>
KIS_ACCOUNT_NO=<실전_계좌번호>
# 3. 모드를 live로 변경
MODE=live
# 4. PAPER_OVERSEAS_CASH 비활성화 (live 모드에선 무시되지만 명시적으로 0 설정)
PAPER_OVERSEAS_CASH=0
```
> ⚠️ `KIS_BASE_URL` 포트 주의:
> - **모의(VTS)**: `https://openapivts.koreainvestment.com:29443`
> - **실전**: `https://openapi.koreainvestment.com:9443`
### 2-2. TR_ID 자동 분기 확인
아래 TR_ID는 `MODE` 값에 따라 코드에서 **자동으로 선택**됩니다.
별도 설정 불필요하나, 문제 발생 시 아래 표를 참조하세요.
| 구분 | 모의 TR_ID | 실전 TR_ID |
|------|-----------|-----------|
| 국내 잔고 조회 | `VTTC8434R` | `TTTC8434R` |
| 국내 현금 매수 | `VTTC0012U` | `TTTC0012U` |
| 국내 현금 매도 | `VTTC0011U` | `TTTC0011U` |
| 해외 잔고 조회 | `VTTS3012R` | `TTTS3012R` |
| 해외 매수 | `VTTT1002U` | `TTTT1002U` |
| 해외 매도 | `VTTT1001U` | `TTTT1006U` |
> **출처**: `docs/한국투자증권_오픈API_전체문서_20260221_030000.xlsx` (공식 문서 기준)
---
## 3. 최종 확인
### 3-1. 실전 시작 전 점검
- [ ] DB 백업 완료: `data/trade_logs.db``data/backups/`
- [ ] Telegram 알림 설정 확인 (실전에서는 알림이 더욱 중요)
- [ ] 소액으로 첫 거래 진행 후 TR_ID/계좌 정상 동작 확인
### 3-2. 실행 명령
```bash
# 실전 모드로 실행
python -m src.main --mode=live
# 대시보드 함께 실행 (별도 터미널에서 모니터링)
python -m src.main --mode=live --dashboard
```
### 3-3. 실전 시작 직후 확인 사항
- [ ] 로그에 `MODE=live` 출력 확인
- [ ] 첫 잔고 조회 성공 (ConnectionError 없음)
- [ ] Telegram 알림 수신 확인 ("System started")
- [ ] 첫 주문 후 KIS 앱에서 체결 내역 확인
---
## 4. 비상 정지 방법
### 즉각 정지
```bash
# 터미널에서 Ctrl+C (정상 종료 트리거)
# 또는 Telegram 봇 명령:
/stop
```
### Circuit Breaker 발동 시
- CB가 발동되면 자동으로 거래 중단 및 Telegram 알림 전송
- CB 임계값: `CIRCUIT_BREAKER_PCT` (기본 -3.0%)
- **임계값은 엄격하게만 조정 가능** (더 낮은 음수 값으로만 변경)
---
## 5. 롤백 절차
실전 전환 후 문제 발생 시:
```bash
# 1. 즉시 .env에서 MODE=paper로 복원
# 2. 재시작
python -m src.main --mode=paper
# 3. DB에서 최근 거래 확인
sqlite3 data/trade_logs.db "SELECT * FROM trades ORDER BY id DESC LIMIT 20;"
```
---
## 관련 문서
- [시스템 아키텍처](architecture.md)
- [워크플로우 가이드](workflow.md)
- [재해 복구](disaster_recovery.md)
- [Agent 제약 조건](agents.md)

View File

@@ -0,0 +1,32 @@
# PR: Smart Scanner 진단 로그 추가 (0 candidates 원인 분해)
## Linked Issue
- `docs/issues/ISSUE-2026-02-17-no-trades-zero-candidates.md`
## What Changed
- `src/analysis/smart_scanner.py`에 스캔 진단 카운터 추가
- 국내 스캔 진단 로그 추가
- 해외 랭킹 스캔 진단 로그 추가
- 해외 fallback 심볼 스캔 진단 로그 추가
## Diagnostics Keys
- `total_rows`
- `missing_code`
- `invalid_price`
- `low_volatility`
- `connection_error` (해당 경로에서만)
- `unexpected_error` (해당 경로에서만)
- `qualified`
## Expected Log Examples
- `Domestic scan diagnostics: {...}`
- `Overseas ranking scan diagnostics for US_NASDAQ: {...}`
- `Overseas fallback scan diagnostics for US_NYSE: {...}`
## Out of Scope
- 해외 랭킹 404 시 기본 심볼 유니버스 강제 주입
- 국내 경로 fallback 정책 변경
## Validation
- `.venv/bin/python -m py_compile src/analysis/smart_scanner.py`

View File

@@ -7,32 +7,6 @@
--- ---
## 2026-02-21
### 거래 상태 확인 중 발견된 버그 (#187)
- 거래 상태 점검 요청 → SELL 주문(손절/익절)이 Fat Finger에 막혀 전혀 실행 안 됨 발견
- **#187 (Critical)**: SELL 주문에서 Fat Finger 오탐 — `order_amount/total_cash > 30%`가 SELL에도 적용되어 대형 포지션 매도 불가
- JELD stop-loss -6.20% → 차단, RXT take-profit +46.13% → 차단
- 수정: SELL은 `check_circuit_breaker`만 호출, `validate_order`(Fat Finger 포함) 미호출
---
## 2026-02-20
### 지속적 모니터링 및 개선점 도출 (이슈 #178~#182)
- Dashboard 포함해서 실행하며 간헐적 문제 모니터링 및 개선점 자동 도출 요청
- 모니터링 결과 발견된 이슈 목록:
- **#178**: uvicorn 미설치 → dashboard 미작동 + 오해의 소지 있는 시작 로그 → uvicorn 설치 완료
- **#179 (Critical)**: 잔액 부족 주문 실패 후 매 사이클마다 무한 재시도 (MLECW 20분 이상 반복)
- **#180**: 다중 인스턴스 실행 시 Telegram 409 충돌
- **#181**: implied_rsi 공식 포화 문제 (change_rate≥12.5% → RSI=100)
- **#182 (Critical)**: 보유 종목이 SmartScanner 변동성 필터에 걸려 SELL 신호 미생성 → SELL 체결 0건, 잔고 소진
- 요구사항: 모니터링 자동화 및 주기적 개선점 리포트 도출
---
## 2026-02-05 ## 2026-02-05
### API 효율화 ### API 효율화
@@ -191,167 +165,3 @@
**효과:** **효과:**
- 국내/해외 스캐너 기준이 변동성 중심으로 일관화 - 국내/해외 스캐너 기준이 변동성 중심으로 일관화
- 고변동 구간에서 자동 익스포저 축소, 저변동 구간에서 과소진입 완화 - 고변동 구간에서 자동 익스포저 축소, 저변동 구간에서 과소진입 완화
## 2026-02-18
### KIS 해외 랭킹 API 404 에러 수정
**배경:**
- KIS 해외주식 랭킹 API(`fetch_overseas_rankings`)가 모든 거래소에서 HTTP 404를 반환
- Smart Scanner가 해외 시장 후보 종목을 찾지 못해 거래가 전혀 실행되지 않음
**근본 원인:**
- TR_ID, API 경로, 거래소 코드가 모두 KIS 공식 문서와 불일치
**구현 결과:**
- `src/config.py`: TR_ID/Path 기본값을 KIS 공식 스펙으로 수정
- `src/broker/overseas.py`: 랭킹 API 전용 거래소 코드 매핑 추가 (NASD→NAS, NYSE→NYS, AMEX→AMS), 올바른 API 파라미터 사용
- `tests/test_overseas_broker.py`: 19개 단위 테스트 추가
**효과:**
- 해외 시장 랭킹 스캔이 정상 동작하여 Smart Scanner가 후보 종목 탐지 가능
### Gemini prompt_override 미적용 버그 수정
**배경:**
- `run_overnight` 실행 시 모든 시장에서 Playbook 생성 실패 (`JSONDecodeError`)
- defensive playbook으로 폴백되어 모든 종목이 HOLD 처리
**근본 원인:**
- `pre_market_planner.py``market_data["prompt_override"]`에 Playbook 전용 프롬프트를 넣어 `gemini.decide()` 호출
- `gemini_client.py``decide()` 메서드가 `prompt_override` 키를 전혀 확인하지 않고 항상 일반 트레이드 결정 프롬프트 생성
- Gemini가 Playbook JSON 대신 일반 트레이드 결정을 반환하여 파싱 실패
**구현 결과:**
- `src/brain/gemini_client.py`: `decide()` 메서드에서 `prompt_override` 우선 사용 로직 추가
- `tests/test_brain.py`: 3개 테스트 추가 (override 전달, optimization 우회, 미지정 시 기존 동작 유지)
**이슈/PR:** #143
### 미국장 거래 미실행 근본 원인 분석 및 수정 (자율 실행 세션)
**배경:**
- 사용자 요청: "미국장 열면 프로그램 돌려서 거래 한 번도 못 한 거 꼭 원인 찾아서 해결해줘"
- 프로그램을 미국장 개장(9:30 AM EST) 전부터 실행하여 실시간 로그를 분석
**발견된 근본 원인 #1: Defensive Playbook — BUY 조건 없음**
- Gemini free tier (20 RPD) 소진 → `generate_playbook()` 실패 → `_defensive_playbook()` 폴백
- Defensive playbook은 `price_change_pct_below: -3.0 → SELL` 조건만 존재, BUY 조건 없음
- ScenarioEngine이 항상 HOLD 반환 → 거래 0건
**수정 #1 (PR #146, Issue #145):**
- `src/strategy/pre_market_planner.py`: `_smart_fallback_playbook()` 메서드 추가
- 스캐너 signal 기반 BUY 조건 생성: `momentum → volume_ratio_above`, `oversold → rsi_below`
- 기존 defensive stop-loss SELL 조건 유지
- Gemini 실패 시 defensive → smart fallback으로 전환
- 테스트 10개 추가
**발견된 근본 원인 #2: 가격 API 거래소 코드 불일치 + VTS 잔고 API 오류**
실제 로그:
```
Scenario matched for MRNX: BUY (confidence=80) ✓
Decision for EWUS (NYSE American): BUY (confidence=80) ✓
Skip BUY APLZ (NYSE American): no affordable quantity (cash=0.00, price=0.00) ✗
```
- `get_overseas_price()`: `NASD`/`NYSE`/`AMEX` 전송 → API가 `NAS`/`NYS`/`AMS` 기대 → 빈 응답 → `price=0`
- `VTTS3012R` 잔고 API: "ERROR : INPUT INVALID_CHECK_ACNO" → `total_cash=0`
- 결과: `_determine_order_quantity()` 가 0 반환 → 주문 건너뜀
**수정 #2 (PR #148, Issue #147):**
- `src/broker/overseas.py`: `_PRICE_EXCHANGE_MAP = _RANKING_EXCHANGE_MAP` 추가, 가격 API에 매핑 적용
- `src/config.py`: `PAPER_OVERSEAS_CASH: float = Field(default=50000.0)` — paper 모드 시뮬레이션 잔고
- `src/main.py`: 잔고 0일 때 PAPER_OVERSEAS_CASH 폴백, 가격 0일 때 candidate.price 폴백
- 테스트 8개 추가
**효과:**
- BUY 결정 → 실제 주문 전송까지의 파이프라인이 완전히 동작
- Paper 모드에서 KIS VTS 해외 잔고 API 오류에 관계없이 시뮬레이션 거래 가능
**이슈/PR:** #145, #146, #147, #148
### 해외주식 시장가 주문 거부 수정 (Fix #3, 연속 발견)
**배경:**
- Fix #147 적용 후 주문 전송 시작 → KIS VTS가 거부: "지정가만 가능한 상품입니다"
**근본 원인:**
- `trading_cycle()`, `run_daily_session()` 양쪽에서 `send_overseas_order(price=0.0)` 하드코딩
- `price=0``ORD_DVSN="01"` (시장가) 전송 → KIS VTS 거부
- Fix #147에서 이미 `current_price`를 올바르게 계산했으나 주문 시 미사용
**구현 결과:**
- `src/main.py`: 두 곳에서 `price=0.0``price=current_price`/`price=stock_data["current_price"]`
- `tests/test_main.py`: 회귀 테스트 `test_overseas_buy_order_uses_limit_price` 추가
**최종 확인 로그:**
```
Order result: 모의투자 매수주문이 완료 되었습니다. ✓
```
**이슈/PR:** #149, #150
---
## 2026-02-23
### 국내주식 지정가 전환 및 미체결 처리 (#232)
**배경:**
- 해외주식은 #211에서 지정가로 전환했으나 국내주식은 여전히 `price=0` (시장가)
- KRX도 지정가 주문 사용 시 동일한 미체결 위험이 존재
- 지정가 전환 + 미체결 처리를 함께 구현
**구현 내용:**
1. `src/broker/kis_api.py`
- `get_domestic_pending_orders()`: 모의 즉시 `[]`, 실전 `TTTC0084R` GET
- `cancel_domestic_order()`: 실전 `TTTC0013U` / 모의 `VTTC0013U`, hashkey 필수
2. `src/main.py`
- import `kr_round_down` 추가
- `trading_cycle`, `run_daily_session` 국내 주문 `price=0` → 지정가:
BUY +0.2% / SELL -0.2%, `kr_round_down` KRX 틱 반올림 적용
- `handle_domestic_pending_orders` 함수: BUY→취소+쿨다운, SELL→취소+재주문(-0.4%, 최대1회)
- daily/realtime 두 모드에서 domestic pending 체크 호출 추가
3. 테스트 14개 추가:
- `TestGetDomesticPendingOrders` (3), `TestCancelDomesticOrder` (5)
- `TestHandleDomesticPendingOrders` (4), `TestDomesticLimitOrderPrice` (2)
**이슈/PR:** #232, PR #233
---
## 2026-02-24
### 해외잔고 ghost position 수정 — '모의투자 잔고내역이 없습니다' 반복 방지 (#235)
**배경:**
- 모의투자 실행 시 MLECW, KNRX, NBY, SNSE 등 만료/정지된 종목에 대해
`모의투자 잔고내역이 없습니다` 오류가 매 사이클 반복됨
**근본 원인:**
1. `ovrs_cblc_qty` (해외잔고수량, 총 보유) vs `ord_psbl_qty` (주문가능수량, 실제 매도 가능)
- 기존 코드: `ovrs_cblc_qty` 우선 사용 → 만료 Warrant가 `ovrs_cblc_qty=289456`이지만 실제 `ord_psbl_qty=0`
- startup sync / build_overseas_symbol_universe가 이 종목들을 포지션으로 기록
2. SELL 실패 시 DB 포지션이 닫히지 않아 다음 사이클에서도 재시도 (무한 반복)
**구현 내용:**
1. `src/main.py``_extract_held_codes_from_balance`, `_extract_held_qty_from_balance`
- 해외 잔고 필드 우선순위 변경: `ord_psbl_qty``ovrs_cblc_qty``hldg_qty` (fallback 유지)
- KIS 공식 문서(VTTS3012R) 기준: `ord_psbl_qty`가 실제 매도 가능 수량
2. `src/main.py``trading_cycle` ghost-close 처리
- 해외 SELL이 `잔고내역이 없습니다`로 실패 시 DB 포지션을 `[ghost-close]` SELL로 종료
- exchange code 불일치 등 예외 상황에서 무한 반복 방지
3. 테스트 7개 추가:
- `TestExtractHeldQtyFromBalance` 3개: ord_psbl_qty 우선, 0이면 0 반환, fallback
- `TestExtractHeldCodesFromBalance` 2개: ord_psbl_qty=0인 종목 제외, fallback
- `TestOverseasGhostPositionClose` 2개: ghost-close 로그 확인, 일반 오류 무시
**이슈/PR:** #235, PR #236

View File

@@ -23,7 +23,7 @@ if [ -z "${APP_CMD:-}" ]; then
dashboard_port="${DASHBOARD_PORT:-8080}" dashboard_port="${DASHBOARD_PORT:-8080}"
APP_CMD="DASHBOARD_PORT=$dashboard_port $PYTHON_BIN -m src.main --mode=live --dashboard" APP_CMD="DASHBOARD_PORT=$dashboard_port $PYTHON_BIN -m src.main --mode=paper --dashboard"
fi fi
mkdir -p "$LOG_DIR" mkdir -p "$LOG_DIR"

View File

@@ -128,6 +128,16 @@ class SmartVolatilityScanner:
if not fluct_rows: if not fluct_rows:
return [] return []
diagnostics: dict[str, int | float] = {
"total_rows": len(fluct_rows),
"missing_code": 0,
"invalid_price": 0,
"low_volatility": 0,
"connection_error": 0,
"unexpected_error": 0,
"qualified": 0,
}
volume_rank_bonus: dict[str, float] = {} volume_rank_bonus: dict[str, float] = {}
for idx, row in enumerate(volume_rows): for idx, row in enumerate(volume_rows):
code = _extract_stock_code(row) code = _extract_stock_code(row)
@@ -139,6 +149,7 @@ class SmartVolatilityScanner:
for stock in fluct_rows: for stock in fluct_rows:
stock_code = _extract_stock_code(stock) stock_code = _extract_stock_code(stock)
if not stock_code: if not stock_code:
diagnostics["missing_code"] += 1
continue continue
try: try:
@@ -168,14 +179,18 @@ class SmartVolatilityScanner:
volume_ratio = max(volume_ratio, volume / prev_day_volume) volume_ratio = max(volume_ratio, volume / prev_day_volume)
volatility_pct = max(abs(change_rate), intraday_range_pct) volatility_pct = max(abs(change_rate), intraday_range_pct)
if price <= 0 or volatility_pct < 0.8: if price <= 0:
diagnostics["invalid_price"] += 1
continue
if volatility_pct < 0.8:
diagnostics["low_volatility"] += 1
continue continue
volatility_score = min(volatility_pct / 10.0, 1.0) * 85.0 volatility_score = min(volatility_pct / 10.0, 1.0) * 85.0
liquidity_score = volume_rank_bonus.get(stock_code, 0.0) liquidity_score = volume_rank_bonus.get(stock_code, 0.0)
score = min(100.0, volatility_score + liquidity_score) score = min(100.0, volatility_score + liquidity_score)
signal = "momentum" if change_rate >= 0 else "oversold" signal = "momentum" if change_rate >= 0 else "oversold"
implied_rsi = max(0.0, min(100.0, 50.0 + (change_rate * 2.0))) implied_rsi = max(0.0, min(100.0, 50.0 + (change_rate * 4.0)))
candidates.append( candidates.append(
ScanCandidate( ScanCandidate(
@@ -189,14 +204,22 @@ class SmartVolatilityScanner:
score=score, score=score,
) )
) )
diagnostics["qualified"] += 1
except ConnectionError as exc: except ConnectionError as exc:
diagnostics["connection_error"] += 1
logger.warning("Failed to analyze %s: %s", stock_code, exc) logger.warning("Failed to analyze %s: %s", stock_code, exc)
continue continue
except Exception as exc: except Exception as exc:
diagnostics["unexpected_error"] += 1
logger.error("Unexpected error analyzing %s: %s", stock_code, exc) logger.error("Unexpected error analyzing %s: %s", stock_code, exc)
continue continue
logger.info(
"Domestic scan diagnostics: %s (volatility_threshold=0.8, top_n=%d)",
diagnostics,
self.top_n,
)
logger.info("Domestic ranking scan found %d candidates", len(candidates)) logger.info("Domestic ranking scan found %d candidates", len(candidates))
candidates.sort(key=lambda c: c.score, reverse=True) candidates.sort(key=lambda c: c.score, reverse=True)
return candidates[: self.top_n] return candidates[: self.top_n]
@@ -242,6 +265,14 @@ class SmartVolatilityScanner:
if not fluct_rows: if not fluct_rows:
return [] return []
diagnostics: dict[str, int | float] = {
"total_rows": len(fluct_rows),
"missing_code": 0,
"invalid_price": 0,
"low_volatility": 0,
"qualified": 0,
}
volume_rank_bonus: dict[str, float] = {} volume_rank_bonus: dict[str, float] = {}
try: try:
volume_rows = await self.overseas_broker.fetch_overseas_rankings( volume_rows = await self.overseas_broker.fetch_overseas_rankings(
@@ -266,6 +297,7 @@ class SmartVolatilityScanner:
for row in fluct_rows: for row in fluct_rows:
stock_code = _extract_stock_code(row) stock_code = _extract_stock_code(row)
if not stock_code: if not stock_code:
diagnostics["missing_code"] += 1
continue continue
price = _extract_last_price(row) price = _extract_last_price(row)
@@ -275,14 +307,18 @@ class SmartVolatilityScanner:
volatility_pct = max(abs(change_rate), intraday_range_pct) volatility_pct = max(abs(change_rate), intraday_range_pct)
# Volatility-first filter (not simple gainers/value ranking). # Volatility-first filter (not simple gainers/value ranking).
if price <= 0 or volatility_pct < 0.8: if price <= 0:
diagnostics["invalid_price"] += 1
continue
if volatility_pct < 0.8:
diagnostics["low_volatility"] += 1
continue continue
volatility_score = min(volatility_pct / 10.0, 1.0) * 85.0 volatility_score = min(volatility_pct / 10.0, 1.0) * 85.0
liquidity_score = volume_rank_bonus.get(stock_code, 0.0) liquidity_score = volume_rank_bonus.get(stock_code, 0.0)
score = min(100.0, volatility_score + liquidity_score) score = min(100.0, volatility_score + liquidity_score)
signal = "momentum" if change_rate >= 0 else "oversold" signal = "momentum" if change_rate >= 0 else "oversold"
implied_rsi = max(0.0, min(100.0, 50.0 + (change_rate * 2.0))) implied_rsi = max(0.0, min(100.0, 50.0 + (change_rate * 4.0)))
candidates.append( candidates.append(
ScanCandidate( ScanCandidate(
stock_code=stock_code, stock_code=stock_code,
@@ -295,7 +331,14 @@ class SmartVolatilityScanner:
score=score, score=score,
) )
) )
diagnostics["qualified"] += 1
logger.info(
"Overseas ranking scan diagnostics for %s: %s (volatility_threshold=0.8, top_n=%d)",
market.code,
diagnostics,
self.top_n,
)
if candidates: if candidates:
logger.info( logger.info(
"Overseas ranking scan found %d candidates for %s", "Overseas ranking scan found %d candidates for %s",
@@ -320,6 +363,14 @@ class SmartVolatilityScanner:
len(symbols), len(symbols),
market.name, market.name,
) )
diagnostics: dict[str, int | float] = {
"total_rows": len(symbols),
"invalid_price": 0,
"low_volatility": 0,
"connection_error": 0,
"unexpected_error": 0,
"qualified": 0,
}
candidates: list[ScanCandidate] = [] candidates: list[ScanCandidate] = []
for stock_code in symbols: for stock_code in symbols:
try: try:
@@ -333,12 +384,16 @@ class SmartVolatilityScanner:
intraday_range_pct = _extract_intraday_range_pct(output, price) intraday_range_pct = _extract_intraday_range_pct(output, price)
volatility_pct = max(abs(change_rate), intraday_range_pct) volatility_pct = max(abs(change_rate), intraday_range_pct)
if price <= 0 or volatility_pct < 0.8: if price <= 0:
diagnostics["invalid_price"] += 1
continue
if volatility_pct < 0.8:
diagnostics["low_volatility"] += 1
continue continue
score = min(volatility_pct / 10.0, 1.0) * 100.0 score = min(volatility_pct / 10.0, 1.0) * 100.0
signal = "momentum" if change_rate >= 0 else "oversold" signal = "momentum" if change_rate >= 0 else "oversold"
implied_rsi = max(0.0, min(100.0, 50.0 + (change_rate * 2.0))) implied_rsi = max(0.0, min(100.0, 50.0 + (change_rate * 4.0)))
candidates.append( candidates.append(
ScanCandidate( ScanCandidate(
stock_code=stock_code, stock_code=stock_code,
@@ -351,10 +406,19 @@ class SmartVolatilityScanner:
score=score, score=score,
) )
) )
diagnostics["qualified"] += 1
except ConnectionError as exc: except ConnectionError as exc:
diagnostics["connection_error"] += 1
logger.warning("Failed to analyze overseas %s: %s", stock_code, exc) logger.warning("Failed to analyze overseas %s: %s", stock_code, exc)
except Exception as exc: except Exception as exc:
diagnostics["unexpected_error"] += 1
logger.error("Unexpected error analyzing overseas %s: %s", stock_code, exc) logger.error("Unexpected error analyzing overseas %s: %s", stock_code, exc)
logger.info(
"Overseas fallback scan diagnostics for %s: %s (volatility_threshold=0.8, top_n=%d)",
market.code,
diagnostics,
self.top_n,
)
logger.info( logger.info(
"Overseas symbol fallback scan found %d candidates for %s", "Overseas symbol fallback scan found %d candidates for %s",
len(candidates), len(candidates),

View File

@@ -346,10 +346,8 @@ class GeminiClient:
# Validate required fields # Validate required fields
if not all(k in data for k in ("action", "confidence", "rationale")): if not all(k in data for k in ("action", "confidence", "rationale")):
logger.warning("Missing fields in Gemini response — defaulting to HOLD") logger.warning("Missing fields in Gemini response — defaulting to HOLD")
# Preserve raw text in rationale so prompt_override callers (e.g. pre_market_planner)
# can extract their own JSON format from decision.rationale (#245)
return TradeDecision( return TradeDecision(
action="HOLD", confidence=0, rationale=raw action="HOLD", confidence=0, rationale="Missing required fields"
) )
action = str(data["action"]).upper() action = str(data["action"]).upper()
@@ -412,10 +410,8 @@ class GeminiClient:
cached=True, cached=True,
) )
# Build prompt (prompt_override takes priority for callers like pre_market_planner) # Build optimized prompt
if "prompt_override" in market_data: if self._enable_optimization:
prompt = market_data["prompt_override"]
elif self._enable_optimization:
prompt = self._optimizer.build_compressed_prompt(market_data) prompt = self._optimizer.build_compressed_prompt(market_data)
else: else:
prompt = await self.build_prompt(market_data, news_sentiment) prompt = await self.build_prompt(market_data, news_sentiment)
@@ -441,18 +437,6 @@ class GeminiClient:
action="HOLD", confidence=0, rationale=f"API error: {exc}", token_count=token_count action="HOLD", confidence=0, rationale=f"API error: {exc}", token_count=token_count
) )
# prompt_override callers (e.g. pre_market_planner) expect raw text back,
# not a parsed TradeDecision. Skip parse_response to avoid spurious
# "Missing fields" warnings and return the raw response directly. (#247)
if "prompt_override" in market_data:
logger.info(
"Gemini raw response received (prompt_override, tokens=%d)", token_count
)
# Not a trade decision — don't inflate _total_decisions metrics
return TradeDecision(
action="HOLD", confidence=0, rationale=raw, token_count=token_count
)
decision = self.parse_response(raw) decision = self.parse_response(raw)
self._total_decisions += 1 self._total_decisions += 1

View File

@@ -179,8 +179,8 @@ class PromptOptimizer:
# Minimal instructions # Minimal instructions
prompt = ( prompt = (
f"{market_name} trader. Analyze:\n{data_str}\n\n" f"{market_name} trader. Analyze:\n{data_str}\n\n"
'Return JSON: {"action":"BUY"|"SELL"|"HOLD","confidence":<0-100>,"rationale":"<text>"}\n' 'Return JSON: {"act":"BUY"|"SELL"|"HOLD","conf":<0-100>,"reason":"<text>"}\n'
"Rules: action=BUY/SELL/HOLD, confidence=0-100, rationale=concise. No markdown." "Rules: act=BUY/SELL/HOLD, conf=0-100, reason=concise. No markdown."
) )
else: else:
# Data only (for cached contexts where instructions are known) # Data only (for cached contexts where instructions are known)

View File

@@ -8,7 +8,7 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import ssl import ssl
from typing import Any, cast from typing import Any
import aiohttp import aiohttp
@@ -20,39 +20,6 @@ _KIS_VTS_HOST = "openapivts.koreainvestment.com"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def kr_tick_unit(price: float) -> int:
"""Return KRX tick size for the given price level.
KRX price tick rules (domestic stocks):
price < 2,000 → 1원
2,000 ≤ price < 5,000 → 5원
5,000 ≤ price < 20,000 → 10원
20,000 ≤ price < 50,000 → 50원
50,000 ≤ price < 200,000 → 100원
200,000 ≤ price < 500,000 → 500원
500,000 ≤ price → 1,000원
"""
if price < 2_000:
return 1
if price < 5_000:
return 5
if price < 20_000:
return 10
if price < 50_000:
return 50
if price < 200_000:
return 100
if price < 500_000:
return 500
return 1_000
def kr_round_down(price: float) -> int:
"""Round *down* price to the nearest KRX tick unit."""
tick = kr_tick_unit(price)
return int(price // tick * tick)
class LeakyBucket: class LeakyBucket:
"""Simple leaky-bucket rate limiter for async code.""" """Simple leaky-bucket rate limiter for async code."""
@@ -231,64 +198,12 @@ class KISBroker:
except (TimeoutError, aiohttp.ClientError) as exc: except (TimeoutError, aiohttp.ClientError) as exc:
raise ConnectionError(f"Network error fetching orderbook: {exc}") from exc raise ConnectionError(f"Network error fetching orderbook: {exc}") from exc
async def get_current_price(
self, stock_code: str
) -> tuple[float, float, float]:
"""Fetch current price data for a domestic stock.
Uses the ``inquire-price`` API (FHKST01010100), which works in both
real and VTS environments and returns the actual last-traded price.
Returns:
(current_price, prdy_ctrt, frgn_ntby_qty)
- current_price: Last traded price in KRW.
- prdy_ctrt: Day change rate (%).
- frgn_ntby_qty: Foreigner net buy quantity.
"""
await self._rate_limiter.acquire()
session = self._get_session()
headers = await self._auth_headers("FHKST01010100")
params = {
"FID_COND_MRKT_DIV_CODE": "J",
"FID_INPUT_ISCD": stock_code,
}
url = f"{self._base_url}/uapi/domestic-stock/v1/quotations/inquire-price"
def _f(val: str | None) -> float:
try:
return float(val or "0")
except ValueError:
return 0.0
try:
async with session.get(url, headers=headers, params=params) as resp:
if resp.status != 200:
text = await resp.text()
raise ConnectionError(
f"get_current_price failed ({resp.status}): {text}"
)
data = await resp.json()
out = data.get("output", {})
return (
_f(out.get("stck_prpr")),
_f(out.get("prdy_ctrt")),
_f(out.get("frgn_ntby_qty")),
)
except (TimeoutError, aiohttp.ClientError) as exc:
raise ConnectionError(
f"Network error fetching current price: {exc}"
) from exc
async def get_balance(self) -> dict[str, Any]: async def get_balance(self) -> dict[str, Any]:
"""Fetch current account balance and holdings.""" """Fetch current account balance and holdings."""
await self._rate_limiter.acquire() await self._rate_limiter.acquire()
session = self._get_session() session = self._get_session()
# TR_ID: 실전 TTTC8434R, 모의 VTTC8434R headers = await self._auth_headers("VTTC8434R") # 모의투자 잔고조회
# Source: 한국투자증권 오픈API 전체문서 (20260221) — '국내주식 잔고조회' 시트
tr_id = "TTTC8434R" if self._settings.MODE == "live" else "VTTC8434R"
headers = await self._auth_headers(tr_id)
params = { params = {
"CANO": self._account_no, "CANO": self._account_no,
"ACNT_PRDT_CD": self._product_cd, "ACNT_PRDT_CD": self._product_cd,
@@ -333,30 +248,14 @@ class KISBroker:
await self._rate_limiter.acquire() await self._rate_limiter.acquire()
session = self._get_session() session = self._get_session()
# TR_ID: 실전 BUY=TTTC0012U SELL=TTTC0011U, 모의 BUY=VTTC0012U SELL=VTTC0011U tr_id = "VTTC0802U" if order_type == "BUY" else "VTTC0801U"
# Source: 한국투자증권 오픈API 전체문서 (20260221) — '주식주문(현금)' 시트
# ※ TTTC0802U/VTTC0802U는 미수매수(증거금40% 계좌 전용) — 현금주문에 사용 금지
if self._settings.MODE == "live":
tr_id = "TTTC0012U" if order_type == "BUY" else "TTTC0011U"
else:
tr_id = "VTTC0012U" if order_type == "BUY" else "VTTC0011U"
# KRX requires limit orders to be rounded down to the tick unit.
# ORD_DVSN: "00"=지정가, "01"=시장가
if price > 0:
ord_dvsn = "00" # 지정가
ord_price = kr_round_down(price)
else:
ord_dvsn = "01" # 시장가
ord_price = 0
body = { body = {
"CANO": self._account_no, "CANO": self._account_no,
"ACNT_PRDT_CD": self._product_cd, "ACNT_PRDT_CD": self._product_cd,
"PDNO": stock_code, "PDNO": stock_code,
"ORD_DVSN": ord_dvsn, "ORD_DVSN": "01" if price > 0 else "06", # 01=지정가, 06=시장가
"ORD_QTY": str(quantity), "ORD_QTY": str(quantity),
"ORD_UNPR": str(ord_price), "ORD_UNPR": str(price),
} }
hash_key = await self._get_hash_key(body) hash_key = await self._get_hash_key(body)
@@ -405,46 +304,26 @@ class KISBroker:
await self._rate_limiter.acquire() await self._rate_limiter.acquire()
session = self._get_session() session = self._get_session()
if ranking_type == "volume": # TR_ID for volume ranking
# 거래량순위: FHPST01710000 / /quotations/volume-rank tr_id = "FHPST01710000" if ranking_type == "volume" else "FHPST01710100"
tr_id = "FHPST01710000"
url = f"{self._base_url}/uapi/domestic-stock/v1/quotations/volume-rank"
params: dict[str, str] = {
"FID_COND_MRKT_DIV_CODE": "J",
"FID_COND_SCR_DIV_CODE": "20171",
"FID_INPUT_ISCD": "0000",
"FID_DIV_CLS_CODE": "0",
"FID_BLNG_CLS_CODE": "0",
"FID_TRGT_CLS_CODE": "111111111",
"FID_TRGT_EXLS_CLS_CODE": "0000000000",
"FID_INPUT_PRICE_1": "0",
"FID_INPUT_PRICE_2": "0",
"FID_VOL_CNT": "0",
"FID_INPUT_DATE_1": "",
}
else:
# 등락률순위: FHPST01700000 / /ranking/fluctuation (소문자 파라미터)
tr_id = "FHPST01700000"
url = f"{self._base_url}/uapi/domestic-stock/v1/ranking/fluctuation"
params = {
"fid_cond_mrkt_div_code": "J",
"fid_cond_scr_div_code": "20170",
"fid_input_iscd": "0000",
"fid_rank_sort_cls_code": "0",
"fid_input_cnt_1": str(limit),
"fid_prc_cls_code": "0",
"fid_input_price_1": "0",
"fid_input_price_2": "0",
"fid_vol_cnt": "0",
"fid_trgt_cls_code": "0",
"fid_trgt_exls_cls_code": "0",
"fid_div_cls_code": "0",
"fid_rsfl_rate1": "0",
"fid_rsfl_rate2": "0",
}
headers = await self._auth_headers(tr_id) headers = await self._auth_headers(tr_id)
params = {
"FID_COND_MRKT_DIV_CODE": "J", # Stock/ETF/ETN
"FID_COND_SCR_DIV_CODE": "20001", # Volume surge
"FID_INPUT_ISCD": "0000", # All stocks
"FID_DIV_CLS_CODE": "0", # All types
"FID_BLNG_CLS_CODE": "0",
"FID_TRGT_CLS_CODE": "111111111",
"FID_TRGT_EXLS_CLS_CODE": "000000",
"FID_INPUT_PRICE_1": "0",
"FID_INPUT_PRICE_2": "0",
"FID_VOL_CNT": "0",
"FID_INPUT_DATE_1": "",
}
url = f"{self._base_url}/uapi/domestic-stock/v1/quotations/volume-rank"
try: try:
async with session.get(url, headers=headers, params=params) as resp: async with session.get(url, headers=headers, params=params) as resp:
if resp.status != 200: if resp.status != 200:
@@ -466,7 +345,7 @@ class KISBroker:
rankings = [] rankings = []
for item in data.get("output", [])[:limit]: for item in data.get("output", [])[:limit]:
rankings.append({ rankings.append({
"stock_code": item.get("stck_shrn_iscd") or item.get("mksc_shrn_iscd", ""), "stock_code": item.get("mksc_shrn_iscd", ""),
"name": item.get("hts_kor_isnm", ""), "name": item.get("hts_kor_isnm", ""),
"price": _safe_float(item.get("stck_prpr", "0")), "price": _safe_float(item.get("stck_prpr", "0")),
"volume": _safe_float(item.get("acml_vol", "0")), "volume": _safe_float(item.get("acml_vol", "0")),
@@ -478,112 +357,6 @@ class KISBroker:
except (TimeoutError, aiohttp.ClientError) as exc: except (TimeoutError, aiohttp.ClientError) as exc:
raise ConnectionError(f"Network error fetching rankings: {exc}") from exc raise ConnectionError(f"Network error fetching rankings: {exc}") from exc
async def get_domestic_pending_orders(self) -> list[dict[str, Any]]:
"""Fetch unfilled (pending) domestic limit orders.
The KIS pending-orders API (TTTC0084R) is unsupported in paper (VTS)
mode, so this method returns an empty list immediately when MODE is
not "live".
Returns:
List of pending order dicts from the KIS ``output`` field.
Each dict includes keys such as ``odno``, ``orgn_odno``,
``ord_gno_brno``, ``psbl_qty``, ``sll_buy_dvsn_cd``, ``pdno``.
"""
if self._settings.MODE != "live":
logger.debug(
"get_domestic_pending_orders: paper mode — TTTC0084R unsupported, returning []"
)
return []
await self._rate_limiter.acquire()
session = self._get_session()
# TR_ID: 실전 TTTC0084R (모의 미지원)
# Source: 한국투자증권 오픈API 전체문서 (20260221) — '주식 미체결조회' 시트
headers = await self._auth_headers("TTTC0084R")
params = {
"CANO": self._account_no,
"ACNT_PRDT_CD": self._product_cd,
"INQR_DVSN_1": "0",
"INQR_DVSN_2": "0",
"CTX_AREA_FK100": "",
"CTX_AREA_NK100": "",
}
url = f"{self._base_url}/uapi/domestic-stock/v1/trading/inquire-psbl-rvsecncl"
try:
async with session.get(url, headers=headers, params=params) as resp:
if resp.status != 200:
text = await resp.text()
raise ConnectionError(
f"get_domestic_pending_orders failed ({resp.status}): {text}"
)
data = await resp.json()
return data.get("output", []) or []
except (TimeoutError, aiohttp.ClientError) as exc:
raise ConnectionError(
f"Network error fetching domestic pending orders: {exc}"
) from exc
async def cancel_domestic_order(
self,
stock_code: str,
orgn_odno: str,
krx_fwdg_ord_orgno: str,
qty: int,
) -> dict[str, Any]:
"""Cancel an unfilled domestic limit order.
Args:
stock_code: 6-digit domestic stock code (``pdno``).
orgn_odno: Original order number from pending-orders response
(``orgn_odno`` field).
krx_fwdg_ord_orgno: KRX forwarding order branch number from
pending-orders response (``ord_gno_brno`` field).
qty: Quantity to cancel (use ``psbl_qty`` from pending order).
Returns:
Raw KIS API response dict (check ``rt_cd == "0"`` for success).
"""
await self._rate_limiter.acquire()
session = self._get_session()
# TR_ID: 실전 TTTC0013U, 모의 VTTC0013U
# Source: 한국투자증권 오픈API 전체문서 (20260221) — '주식주문(정정취소)' 시트
tr_id = "TTTC0013U" if self._settings.MODE == "live" else "VTTC0013U"
body = {
"CANO": self._account_no,
"ACNT_PRDT_CD": self._product_cd,
"KRX_FWDG_ORD_ORGNO": krx_fwdg_ord_orgno,
"ORGN_ODNO": orgn_odno,
"ORD_DVSN": "00",
"ORD_QTY": str(qty),
"ORD_UNPR": "0",
"RVSE_CNCL_DVSN_CD": "02",
"QTY_ALL_ORD_YN": "Y",
}
hash_key = await self._get_hash_key(body)
headers = await self._auth_headers(tr_id)
headers["hashkey"] = hash_key
url = f"{self._base_url}/uapi/domestic-stock/v1/trading/order-rvsecncl"
try:
async with session.post(url, headers=headers, json=body) as resp:
if resp.status != 200:
text = await resp.text()
raise ConnectionError(
f"cancel_domestic_order failed ({resp.status}): {text}"
)
return cast(dict[str, Any], await resp.json())
except (TimeoutError, aiohttp.ClientError) as exc:
raise ConnectionError(
f"Network error cancelling domestic order: {exc}"
) from exc
async def get_daily_prices( async def get_daily_prices(
self, self,
stock_code: str, stock_code: str,

View File

@@ -12,38 +12,6 @@ from src.broker.kis_api import KISBroker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Ranking API uses different exchange codes than order/quote APIs.
_RANKING_EXCHANGE_MAP: dict[str, str] = {
"NASD": "NAS",
"NYSE": "NYS",
"AMEX": "AMS",
"SEHK": "HKS",
"SHAA": "SHS",
"SZAA": "SZS",
"HSX": "HSX",
"HNX": "HNX",
"TSE": "TSE",
}
# Price inquiry API (HHDFS00000300) uses the same short exchange codes as rankings.
# NASD → NAS, NYSE → NYS, AMEX → AMS (confirmed: AMEX returns empty, AMS returns price).
_PRICE_EXCHANGE_MAP: dict[str, str] = _RANKING_EXCHANGE_MAP
# Cancel order TR_IDs per exchange code — (live_tr_id, paper_tr_id).
# Source: 한국투자증권 오픈API 전체문서 (20260221) — '해외주식 주문취소' 시트
_CANCEL_TR_ID_MAP: dict[str, tuple[str, str]] = {
"NASD": ("TTTT1004U", "VTTT1004U"),
"NYSE": ("TTTT1004U", "VTTT1004U"),
"AMEX": ("TTTT1004U", "VTTT1004U"),
"SEHK": ("TTTS1003U", "VTTS1003U"),
"TSE": ("TTTS0309U", "VTTS0309U"),
"SHAA": ("TTTS0302U", "VTTS0302U"),
"SZAA": ("TTTS0306U", "VTTS0306U"),
"HNX": ("TTTS0312U", "VTTS0312U"),
"HSX": ("TTTS0312U", "VTTS0312U"),
}
class OverseasBroker: class OverseasBroker:
"""KIS Overseas Stock API wrapper that reuses KISBroker infrastructure.""" """KIS Overseas Stock API wrapper that reuses KISBroker infrastructure."""
@@ -76,11 +44,9 @@ class OverseasBroker:
session = self._broker._get_session() session = self._broker._get_session()
headers = await self._broker._auth_headers("HHDFS00000300") headers = await self._broker._auth_headers("HHDFS00000300")
# Map internal exchange codes to the short form expected by the price API.
price_excd = _PRICE_EXCHANGE_MAP.get(exchange_code, exchange_code)
params = { params = {
"AUTH": "", "AUTH": "",
"EXCD": price_excd, "EXCD": exchange_code,
"SYMB": stock_code, "SYMB": stock_code,
} }
url = f"{self._broker._base_url}/uapi/overseas-price/v1/quotations/price" url = f"{self._broker._base_url}/uapi/overseas-price/v1/quotations/price"
@@ -104,7 +70,7 @@ class OverseasBroker:
ranking_type: str = "fluctuation", ranking_type: str = "fluctuation",
limit: int = 30, limit: int = 30,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Fetch overseas rankings (price change or volume surge). """Fetch overseas rankings (price change or volume amount).
Ranking API specs may differ by account/product. Endpoint paths and Ranking API specs may differ by account/product. Endpoint paths and
TR_IDs are configurable via settings and can be overridden in .env. TR_IDs are configurable via settings and can be overridden in .env.
@@ -115,65 +81,66 @@ class OverseasBroker:
await self._broker._rate_limiter.acquire() await self._broker._rate_limiter.acquire()
session = self._broker._get_session() session = self._broker._get_session()
ranking_excd = _RANKING_EXCHANGE_MAP.get(exchange_code, exchange_code)
if ranking_type == "volume": if ranking_type == "volume":
tr_id = self._broker._settings.OVERSEAS_RANKING_VOLUME_TR_ID configured_tr_id = self._broker._settings.OVERSEAS_RANKING_VOLUME_TR_ID
path = self._broker._settings.OVERSEAS_RANKING_VOLUME_PATH configured_path = self._broker._settings.OVERSEAS_RANKING_VOLUME_PATH
params: dict[str, str] = { default_tr_id = "HHDFS76200200"
"KEYB": "", # NEXT KEY BUFF — Required, 공백 default_path = "/uapi/overseas-price/v1/quotations/inquire-volume-rank"
"AUTH": "",
"EXCD": ranking_excd,
"MIXN": "0",
"VOL_RANG": "0",
}
else: else:
tr_id = self._broker._settings.OVERSEAS_RANKING_FLUCT_TR_ID configured_tr_id = self._broker._settings.OVERSEAS_RANKING_FLUCT_TR_ID
path = self._broker._settings.OVERSEAS_RANKING_FLUCT_PATH configured_path = self._broker._settings.OVERSEAS_RANKING_FLUCT_PATH
params = { default_tr_id = "HHDFS76200100"
"KEYB": "", # NEXT KEY BUFF — Required, 공백 default_path = "/uapi/overseas-price/v1/quotations/inquire-updown-rank"
"AUTH": "",
"EXCD": ranking_excd,
"NDAY": "0",
"GUBN": "1", # 0=하락율, 1=상승율 — 변동성 스캐너는 급등 종목 우선
"VOL_RANG": "0",
}
headers = await self._broker._auth_headers(tr_id) endpoint_specs: list[tuple[str, str]] = [(configured_tr_id, configured_path)]
url = f"{self._broker._base_url}{path}" if (configured_tr_id, configured_path) != (default_tr_id, default_path):
endpoint_specs.append((default_tr_id, default_path))
try: # Try common param variants used by KIS overseas quotation APIs.
async with session.get(url, headers=headers, params=params) as resp: param_variants = [
if resp.status != 200: {"AUTH": "", "EXCD": exchange_code, "NREC": str(max(limit, 30))},
text = await resp.text() {"AUTH": "", "OVRS_EXCG_CD": exchange_code, "NREC": str(max(limit, 30))},
if resp.status == 404: {"AUTH": "", "EXCD": exchange_code},
logger.warning( {"AUTH": "", "OVRS_EXCG_CD": exchange_code},
"Overseas ranking endpoint unavailable (404) for %s/%s; " ]
"using symbol fallback scan",
exchange_code,
ranking_type,
)
return []
raise ConnectionError(
f"fetch_overseas_rankings failed ({resp.status}): {text}"
)
data = await resp.json() last_error: str | None = None
rows = self._extract_ranking_rows(data) saw_http_404 = False
if rows: for tr_id, path in endpoint_specs:
return rows[:limit] headers = await self._broker._auth_headers(tr_id)
url = f"{self._broker._base_url}{path}"
for params in param_variants:
try:
async with session.get(url, headers=headers, params=params) as resp:
text = await resp.text()
if resp.status != 200:
last_error = f"HTTP {resp.status}: {text}"
if resp.status == 404:
saw_http_404 = True
continue
logger.debug( data = await resp.json()
"Overseas ranking returned empty for %s/%s (keys=%s)", rows = self._extract_ranking_rows(data)
exchange_code, if rows:
ranking_type, return rows[:limit]
list(data.keys()),
) # keep trying another param variant if response has no usable rows
return [] last_error = f"empty output (keys={list(data.keys())})"
except (TimeoutError, aiohttp.ClientError) as exc: except (TimeoutError, aiohttp.ClientError) as exc:
raise ConnectionError( last_error = str(exc)
f"Network error fetching overseas rankings: {exc}" continue
) from exc
if saw_http_404:
logger.warning(
"Overseas ranking endpoint unavailable (404) for %s/%s; using symbol fallback scan",
exchange_code,
ranking_type,
)
return []
raise ConnectionError(
f"fetch_overseas_rankings failed for {exchange_code}/{ranking_type}: {last_error}"
)
async def get_overseas_balance(self, exchange_code: str) -> dict[str, Any]: async def get_overseas_balance(self, exchange_code: str) -> dict[str, Any]:
""" """
@@ -191,12 +158,8 @@ class OverseasBroker:
await self._broker._rate_limiter.acquire() await self._broker._rate_limiter.acquire()
session = self._broker._get_session() session = self._broker._get_session()
# TR_ID: 실전 TTTS3012R, 모의 VTTS3012R # Virtual trading TR_ID for overseas balance inquiry
# Source: 한국투자증권 오픈API 전체문서 (20260221) — '해외주식 잔고조회' 시트 headers = await self._broker._auth_headers("VTTS3012R")
balance_tr_id = (
"TTTS3012R" if self._broker._settings.MODE == "live" else "VTTS3012R"
)
headers = await self._broker._auth_headers(balance_tr_id)
params = { params = {
"CANO": self._broker._account_no, "CANO": self._broker._account_no,
"ACNT_PRDT_CD": self._broker._product_cd, "ACNT_PRDT_CD": self._broker._product_cd,
@@ -222,59 +185,6 @@ class OverseasBroker:
f"Network error fetching overseas balance: {exc}" f"Network error fetching overseas balance: {exc}"
) from exc ) from exc
async def get_overseas_buying_power(
self,
exchange_code: str,
stock_code: str,
price: float,
) -> dict[str, Any]:
"""
Fetch overseas buying power for a specific stock and price.
Args:
exchange_code: Exchange code (e.g., "NASD", "NYSE")
stock_code: Stock ticker symbol
price: Current stock price (used for quantity calculation)
Returns:
API response; key field: output.ord_psbl_frcr_amt (주문가능외화금액)
Raises:
ConnectionError: On network or API errors
"""
await self._broker._rate_limiter.acquire()
session = self._broker._get_session()
# TR_ID: 실전 TTTS3007R, 모의 VTTS3007R
# Source: 한국투자증권 오픈API 전체문서 (20260221) — '해외주식 매수가능금액조회' 시트
ps_tr_id = (
"TTTS3007R" if self._broker._settings.MODE == "live" else "VTTS3007R"
)
headers = await self._broker._auth_headers(ps_tr_id)
params = {
"CANO": self._broker._account_no,
"ACNT_PRDT_CD": self._broker._product_cd,
"OVRS_EXCG_CD": exchange_code,
"OVRS_ORD_UNPR": f"{price:.2f}",
"ITEM_CD": stock_code,
}
url = (
f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/inquire-psamount"
)
try:
async with session.get(url, headers=headers, params=params) as resp:
if resp.status != 200:
text = await resp.text()
raise ConnectionError(
f"get_overseas_buying_power failed ({resp.status}): {text}"
)
return await resp.json()
except (TimeoutError, aiohttp.ClientError) as exc:
raise ConnectionError(
f"Network error fetching overseas buying power: {exc}"
) from exc
async def send_overseas_order( async def send_overseas_order(
self, self,
exchange_code: str, exchange_code: str,
@@ -302,12 +212,8 @@ class OverseasBroker:
await self._broker._rate_limiter.acquire() await self._broker._rate_limiter.acquire()
session = self._broker._get_session() session = self._broker._get_session()
# TR_ID: 실전 BUY=TTTT1002U SELL=TTTT1006U, 모의 BUY=VTTT1002U SELL=VTTT1001U # Virtual trading TR_IDs for overseas orders
# Source: 한국투자증권 오픈API 전체문서 (20260221) — '해외주식 주문' 시트 tr_id = "VTTT1002U" if order_type == "BUY" else "VTTT1006U"
if self._broker._settings.MODE == "live":
tr_id = "TTTT1002U" if order_type == "BUY" else "TTTT1006U"
else:
tr_id = "VTTT1002U" if order_type == "BUY" else "VTTT1001U"
body = { body = {
"CANO": self._broker._account_no, "CANO": self._broker._account_no,
@@ -334,158 +240,20 @@ class OverseasBroker:
f"send_overseas_order failed ({resp.status}): {text}" f"send_overseas_order failed ({resp.status}): {text}"
) )
data = await resp.json() data = await resp.json()
rt_cd = data.get("rt_cd", "") logger.info(
msg1 = data.get("msg1", "") "Overseas order submitted",
if rt_cd == "0": extra={
logger.info( "exchange": exchange_code,
"Overseas order submitted", "stock_code": stock_code,
extra={ "action": order_type,
"exchange": exchange_code, },
"stock_code": stock_code, )
"action": order_type,
},
)
else:
logger.warning(
"Overseas order rejected (rt_cd=%s): %s [%s %s %s qty=%d]",
rt_cd,
msg1,
order_type,
stock_code,
exchange_code,
quantity,
)
return data return data
except (TimeoutError, aiohttp.ClientError) as exc: except (TimeoutError, aiohttp.ClientError) as exc:
raise ConnectionError( raise ConnectionError(
f"Network error sending overseas order: {exc}" f"Network error sending overseas order: {exc}"
) from exc ) from exc
async def get_overseas_pending_orders(
self, exchange_code: str
) -> list[dict[str, Any]]:
"""Fetch unfilled (pending) overseas orders for a given exchange.
Args:
exchange_code: Exchange code (e.g., "NASD", "SEHK").
For US markets, NASD returns all US pending orders (NASD/NYSE/AMEX).
Returns:
List of pending order dicts with fields: odno, pdno, sll_buy_dvsn_cd,
ft_ord_qty, nccs_qty, ft_ord_unpr3, ovrs_excg_cd.
Always returns [] in paper mode (TTTS3018R is live-only).
Raises:
ConnectionError: On network or API errors (live mode only).
"""
if self._broker._settings.MODE != "live":
logger.debug(
"Pending orders API (TTTS3018R) not supported in paper mode; returning []"
)
return []
await self._broker._rate_limiter.acquire()
session = self._broker._get_session()
# TTTS3018R: 해외주식 미체결내역조회 (실전 전용)
# Source: 한국투자증권 오픈API 전체문서 (20260221) — '해외주식 미체결조회' 시트
headers = await self._broker._auth_headers("TTTS3018R")
params = {
"CANO": self._broker._account_no,
"ACNT_PRDT_CD": self._broker._product_cd,
"OVRS_EXCG_CD": exchange_code,
"SORT_SQN": "DS",
"CTX_AREA_FK200": "",
"CTX_AREA_NK200": "",
}
url = (
f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/inquire-nccs"
)
try:
async with session.get(url, headers=headers, params=params) as resp:
if resp.status != 200:
text = await resp.text()
raise ConnectionError(
f"get_overseas_pending_orders failed ({resp.status}): {text}"
)
data = await resp.json()
output = data.get("output", [])
if isinstance(output, list):
return output
return []
except (TimeoutError, aiohttp.ClientError) as exc:
raise ConnectionError(
f"Network error fetching pending orders: {exc}"
) from exc
async def cancel_overseas_order(
self,
exchange_code: str,
stock_code: str,
odno: str,
qty: int,
) -> dict[str, Any]:
"""Cancel an overseas limit order.
Args:
exchange_code: Exchange code (e.g., "NASD", "SEHK").
stock_code: Stock ticker symbol.
odno: Original order number to cancel.
qty: Unfilled quantity to cancel.
Returns:
API response dict containing rt_cd and msg1.
Raises:
ValueError: If exchange_code has no cancel TR_ID mapping.
ConnectionError: On network or API errors.
"""
tr_ids = _CANCEL_TR_ID_MAP.get(exchange_code)
if tr_ids is None:
raise ValueError(f"No cancel TR_ID mapping for exchange: {exchange_code}")
live_tr_id, paper_tr_id = tr_ids
tr_id = live_tr_id if self._broker._settings.MODE == "live" else paper_tr_id
await self._broker._rate_limiter.acquire()
session = self._broker._get_session()
# RVSE_CNCL_DVSN_CD="02" means cancel (not revision).
# OVRS_ORD_UNPR must be "0" for cancellations.
# Source: 한국투자증권 오픈API 전체문서 (20260221) — '해외주식 정정취소주문' 시트
body = {
"CANO": self._broker._account_no,
"ACNT_PRDT_CD": self._broker._product_cd,
"OVRS_EXCG_CD": exchange_code,
"PDNO": stock_code,
"ORGN_ODNO": odno,
"RVSE_CNCL_DVSN_CD": "02",
"ORD_QTY": str(qty),
"OVRS_ORD_UNPR": "0",
"ORD_SVR_DVSN_CD": "0",
}
hash_key = await self._broker._get_hash_key(body)
headers = await self._broker._auth_headers(tr_id)
headers["hashkey"] = hash_key
url = (
f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/order-rvsecncl"
)
try:
async with session.post(url, headers=headers, json=body) as resp:
if resp.status != 200:
text = await resp.text()
raise ConnectionError(
f"cancel_overseas_order failed ({resp.status}): {text}"
)
return await resp.json()
except (TimeoutError, aiohttp.ClientError) as exc:
raise ConnectionError(
f"Network error cancelling overseas order: {exc}"
) from exc
def _get_currency_code(self, exchange_code: str) -> str: def _get_currency_code(self, exchange_code: str) -> str:
""" """
Map exchange code to currency code. Map exchange code to currency code.

View File

@@ -13,11 +13,11 @@ class Settings(BaseSettings):
KIS_APP_KEY: str KIS_APP_KEY: str
KIS_APP_SECRET: str KIS_APP_SECRET: str
KIS_ACCOUNT_NO: str # format: "XXXXXXXX-XX" KIS_ACCOUNT_NO: str # format: "XXXXXXXX-XX"
KIS_BASE_URL: str = "https://openapivts.koreainvestment.com:29443" KIS_BASE_URL: str = "https://openapivts.koreainvestment.com:9443"
# Google Gemini # Google Gemini
GEMINI_API_KEY: str GEMINI_API_KEY: str
GEMINI_MODEL: str = "gemini-2.0-flash" GEMINI_MODEL: str = "gemini-pro"
# External Data APIs (optional — for data-driven decisions) # External Data APIs (optional — for data-driven decisions)
NEWS_API_KEY: str | None = None NEWS_API_KEY: str | None = None
@@ -55,11 +55,6 @@ class Settings(BaseSettings):
# Trading mode # Trading mode
MODE: str = Field(default="paper", pattern="^(paper|live)$") MODE: str = Field(default="paper", pattern="^(paper|live)$")
# Simulated USD cash for VTS (paper) overseas trading.
# KIS VTS overseas balance API returns errors for most accounts.
# This value is used as a fallback when the balance API returns 0 in paper mode.
PAPER_OVERSEAS_CASH: float = Field(default=50000.0, ge=0.0)
# Trading frequency mode (daily = batch API calls, realtime = per-stock calls) # Trading frequency mode (daily = batch API calls, realtime = per-stock calls)
TRADE_MODE: str = Field(default="daily", pattern="^(daily|realtime)$") TRADE_MODE: str = Field(default="daily", pattern="^(daily|realtime)$")
DAILY_SESSIONS: int = Field(default=4, ge=1, le=10) DAILY_SESSIONS: int = Field(default=4, ge=1, le=10)
@@ -93,26 +88,16 @@ class Settings(BaseSettings):
TELEGRAM_COMMANDS_ENABLED: bool = True TELEGRAM_COMMANDS_ENABLED: bool = True
TELEGRAM_POLLING_INTERVAL: float = 1.0 # seconds TELEGRAM_POLLING_INTERVAL: float = 1.0 # seconds
# Telegram notification type filters (granular control)
# circuit_breaker is always sent regardless — safety-critical
TELEGRAM_NOTIFY_TRADES: bool = True # BUY/SELL execution alerts
TELEGRAM_NOTIFY_MARKET_OPEN_CLOSE: bool = True # Market open/close alerts
TELEGRAM_NOTIFY_FAT_FINGER: bool = True # Fat-finger rejection alerts
TELEGRAM_NOTIFY_SYSTEM_EVENTS: bool = True # System start/shutdown alerts
TELEGRAM_NOTIFY_PLAYBOOK: bool = True # Playbook generated/failed alerts
TELEGRAM_NOTIFY_SCENARIO_MATCH: bool = True # Scenario matched alerts (most frequent)
TELEGRAM_NOTIFY_ERRORS: bool = True # Error alerts
# Overseas ranking API (KIS endpoint/TR_ID may vary by account/product) # Overseas ranking API (KIS endpoint/TR_ID may vary by account/product)
# Override these from .env if your account uses different specs. # Override these from .env if your account uses different specs.
OVERSEAS_RANKING_ENABLED: bool = True OVERSEAS_RANKING_ENABLED: bool = True
OVERSEAS_RANKING_FLUCT_TR_ID: str = "HHDFS76290000" OVERSEAS_RANKING_FLUCT_TR_ID: str = "HHDFS76200100"
OVERSEAS_RANKING_VOLUME_TR_ID: str = "HHDFS76270000" OVERSEAS_RANKING_VOLUME_TR_ID: str = "HHDFS76200200"
OVERSEAS_RANKING_FLUCT_PATH: str = ( OVERSEAS_RANKING_FLUCT_PATH: str = (
"/uapi/overseas-stock/v1/ranking/updown-rate" "/uapi/overseas-price/v1/quotations/inquire-updown-rank"
) )
OVERSEAS_RANKING_VOLUME_PATH: str = ( OVERSEAS_RANKING_VOLUME_PATH: str = (
"/uapi/overseas-stock/v1/ranking/volume-surge" "/uapi/overseas-price/v1/quotations/inquire-volume-rank"
) )
# Dashboard (optional) # Dashboard (optional)

View File

@@ -3,9 +3,8 @@
from __future__ import annotations from __future__ import annotations
import json import json
import os
import sqlite3 import sqlite3
from datetime import UTC, datetime, timezone from datetime import UTC, datetime
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -13,11 +12,10 @@ from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
def create_dashboard_app(db_path: str, mode: str = "paper") -> FastAPI: def create_dashboard_app(db_path: str) -> FastAPI:
"""Create dashboard FastAPI app bound to a SQLite database path.""" """Create dashboard FastAPI app bound to a SQLite database path."""
app = FastAPI(title="The Ouroboros Dashboard", version="1.0.0") app = FastAPI(title="The Ouroboros Dashboard", version="1.0.0")
app.state.db_path = db_path app.state.db_path = db_path
app.state.mode = mode
@app.get("/") @app.get("/")
def index() -> FileResponse: def index() -> FileResponse:
@@ -81,49 +79,14 @@ def create_dashboard_app(db_path: str, mode: str = "paper") -> FastAPI:
total_pnl += market_status[market]["total_pnl"] total_pnl += market_status[market]["total_pnl"]
total_decisions += market_status[market]["decision_count"] total_decisions += market_status[market]["decision_count"]
cb_threshold = float(os.getenv("CIRCUIT_BREAKER_PCT", "-3.0"))
pnl_pct_rows = conn.execute(
"""
SELECT key, value
FROM system_metrics
WHERE key LIKE 'portfolio_pnl_pct_%'
ORDER BY updated_at DESC
LIMIT 20
"""
).fetchall()
current_pnl_pct: float | None = None
if pnl_pct_rows:
values = [
json.loads(row["value"]).get("pnl_pct")
for row in pnl_pct_rows
if json.loads(row["value"]).get("pnl_pct") is not None
]
if values:
current_pnl_pct = round(min(values), 4)
if current_pnl_pct is None:
cb_status = "unknown"
elif current_pnl_pct <= cb_threshold:
cb_status = "tripped"
elif current_pnl_pct <= cb_threshold + 1.0:
cb_status = "warning"
else:
cb_status = "ok"
return { return {
"date": today, "date": today,
"mode": mode,
"markets": market_status, "markets": market_status,
"totals": { "totals": {
"trade_count": total_trades, "trade_count": total_trades,
"total_pnl": round(total_pnl, 2), "total_pnl": round(total_pnl, 2),
"decision_count": total_decisions, "decision_count": total_decisions,
}, },
"circuit_breaker": {
"threshold_pct": cb_threshold,
"current_pnl_pct": current_pnl_pct,
"status": cb_status,
},
} }
@app.get("/api/playbook/{date_str}") @app.get("/api/playbook/{date_str}")
@@ -296,50 +259,6 @@ def create_dashboard_app(db_path: str, mode: str = "paper") -> FastAPI:
) )
return {"market": market, "count": len(decisions), "decisions": decisions} return {"market": market, "count": len(decisions), "decisions": decisions}
@app.get("/api/pnl/history")
def get_pnl_history(
days: int = Query(default=30, ge=1, le=365),
market: str = Query("all"),
) -> dict[str, Any]:
"""Return daily P&L history for charting."""
with _connect(db_path) as conn:
if market == "all":
rows = conn.execute(
"""
SELECT DATE(timestamp) AS date,
SUM(pnl) AS daily_pnl,
COUNT(*) AS trade_count
FROM trades
WHERE pnl IS NOT NULL
AND DATE(timestamp) >= DATE('now', ?)
GROUP BY DATE(timestamp)
ORDER BY DATE(timestamp)
""",
(f"-{days} days",),
).fetchall()
else:
rows = conn.execute(
"""
SELECT DATE(timestamp) AS date,
SUM(pnl) AS daily_pnl,
COUNT(*) AS trade_count
FROM trades
WHERE pnl IS NOT NULL
AND market = ?
AND DATE(timestamp) >= DATE('now', ?)
GROUP BY DATE(timestamp)
ORDER BY DATE(timestamp)
""",
(market, f"-{days} days"),
).fetchall()
return {
"days": days,
"market": market,
"labels": [row["date"] for row in rows],
"pnl": [round(float(row["daily_pnl"]), 2) for row in rows],
"trades": [int(row["trade_count"]) for row in rows],
}
@app.get("/api/scenarios/active") @app.get("/api/scenarios/active")
def get_active_scenarios( def get_active_scenarios(
market: str = Query("US"), market: str = Query("US"),
@@ -378,68 +297,12 @@ def create_dashboard_app(db_path: str, mode: str = "paper") -> FastAPI:
) )
return {"market": market, "date": date_str, "count": len(matches), "matches": matches} return {"market": market, "date": date_str, "count": len(matches), "matches": matches}
@app.get("/api/positions")
def get_positions() -> dict[str, Any]:
"""Return all currently open positions (last trade per symbol is BUY)."""
with _connect(db_path) as conn:
rows = conn.execute(
"""
SELECT stock_code, market, exchange_code,
price AS entry_price, quantity, timestamp AS entry_time,
decision_id
FROM (
SELECT stock_code, market, exchange_code, price, quantity,
timestamp, decision_id, action,
ROW_NUMBER() OVER (
PARTITION BY stock_code, market
ORDER BY timestamp DESC
) AS rn
FROM trades
)
WHERE rn = 1 AND action = 'BUY'
ORDER BY entry_time DESC
"""
).fetchall()
now = datetime.now(timezone.utc)
positions = []
for row in rows:
entry_time_str = row["entry_time"]
try:
entry_dt = datetime.fromisoformat(entry_time_str.replace("Z", "+00:00"))
held_seconds = int((now - entry_dt).total_seconds())
held_hours = held_seconds // 3600
held_minutes = (held_seconds % 3600) // 60
if held_hours >= 1:
held_display = f"{held_hours}h {held_minutes}m"
else:
held_display = f"{held_minutes}m"
except (ValueError, TypeError):
held_display = "--"
positions.append(
{
"stock_code": row["stock_code"],
"market": row["market"],
"exchange_code": row["exchange_code"],
"entry_price": row["entry_price"],
"quantity": row["quantity"],
"entry_time": entry_time_str,
"held": held_display,
"decision_id": row["decision_id"],
}
)
return {"count": len(positions), "positions": positions}
return app return app
def _connect(db_path: str) -> sqlite3.Connection: def _connect(db_path: str) -> sqlite3.Connection:
conn = sqlite3.connect(db_path) conn = sqlite3.connect(db_path)
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA busy_timeout=8000")
return conn return conn

View File

@@ -1,10 +1,9 @@
<!doctype html> <!doctype html>
<html lang="ko"> <html lang="en">
<head> <head>
<meta charset="UTF-8" /> <meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>The Ouroboros Dashboard</title> <title>The Ouroboros Dashboard</title>
<script src="https://cdn.jsdelivr.net/npm/chart.js@4.4.0/dist/chart.umd.min.js"></script>
<style> <style>
:root { :root {
--bg: #0b1724; --bg: #0b1724;
@@ -12,787 +11,51 @@
--fg: #e6eef7; --fg: #e6eef7;
--muted: #9fb3c8; --muted: #9fb3c8;
--accent: #3cb371; --accent: #3cb371;
--red: #e05555;
--warn: #e8a040;
--border: #28455f;
} }
* { box-sizing: border-box; margin: 0; padding: 0; }
body { body {
margin: 0;
font-family: ui-monospace, SFMono-Regular, Menlo, monospace; font-family: ui-monospace, SFMono-Regular, Menlo, monospace;
background: radial-gradient(circle at top left, #173b58, var(--bg)); background: radial-gradient(circle at top left, #173b58, var(--bg));
color: var(--fg); color: var(--fg);
min-height: 100vh;
font-size: 13px;
} }
.wrap { max-width: 1100px; margin: 0 auto; padding: 20px 16px; } .wrap {
max-width: 900px;
/* Header */ margin: 48px auto;
header { padding: 0 16px;
display: flex;
align-items: center;
justify-content: space-between;
margin-bottom: 20px;
padding-bottom: 12px;
border-bottom: 1px solid var(--border);
} }
header h1 { font-size: 18px; color: var(--accent); letter-spacing: 0.5px; }
.header-right { display: flex; align-items: center; gap: 12px; color: var(--muted); font-size: 12px; }
.refresh-btn {
background: none; border: 1px solid var(--border); color: var(--muted);
padding: 4px 10px; border-radius: 6px; cursor: pointer; font-family: inherit;
font-size: 12px; transition: border-color 0.2s;
}
.refresh-btn:hover { border-color: var(--accent); color: var(--accent); }
.mode-badge {
padding: 3px 10px; border-radius: 5px; font-size: 12px; font-weight: 700;
letter-spacing: 0.5px;
}
.mode-badge.live {
background: rgba(224, 85, 85, 0.15); color: var(--red);
border: 1px solid rgba(224, 85, 85, 0.4);
animation: pulse-warn 2s ease-in-out infinite;
}
.mode-badge.paper {
background: rgba(232, 160, 64, 0.15); color: var(--warn);
border: 1px solid rgba(232, 160, 64, 0.4);
}
/* CB Gauge */
.cb-gauge-wrap {
display: flex; align-items: center; gap: 8px;
font-size: 11px; color: var(--muted);
}
.cb-dot {
width: 8px; height: 8px; border-radius: 50%; flex-shrink: 0;
}
.cb-dot.ok { background: var(--accent); }
.cb-dot.warning { background: var(--warn); animation: pulse-warn 1.2s ease-in-out infinite; }
.cb-dot.tripped { background: var(--red); animation: pulse-warn 0.6s ease-in-out infinite; }
.cb-dot.unknown { background: var(--border); }
@keyframes pulse-warn {
0%, 100% { opacity: 1; }
50% { opacity: 0.35; }
}
.cb-bar-wrap { width: 64px; height: 5px; background: rgba(255,255,255,0.08); border-radius: 3px; overflow: hidden; }
.cb-bar-fill { height: 100%; border-radius: 3px; transition: width 0.4s, background 0.4s; }
/* Summary cards */
.cards { display: grid; grid-template-columns: repeat(4, 1fr); gap: 12px; margin-bottom: 20px; }
@media (max-width: 700px) { .cards { grid-template-columns: repeat(2, 1fr); } }
.card { .card {
background: var(--panel); background: color-mix(in oklab, var(--panel), black 12%);
border: 1px solid var(--border); border: 1px solid #28455f;
border-radius: 10px; border-radius: 12px;
padding: 16px; padding: 20px;
} }
.card-label { color: var(--muted); font-size: 11px; margin-bottom: 6px; text-transform: uppercase; letter-spacing: 0.5px; } h1 {
.card-value { font-size: 22px; font-weight: 700; } margin-top: 0;
.card-sub { color: var(--muted); font-size: 11px; margin-top: 4px; }
.positive { color: var(--accent); }
.negative { color: var(--red); }
.neutral { color: var(--fg); }
/* Chart panel */
.chart-panel {
background: var(--panel);
border: 1px solid var(--border);
border-radius: 10px;
padding: 16px;
margin-bottom: 20px;
} }
.panel-header { code {
display: flex; color: var(--accent);
align-items: center;
justify-content: space-between;
margin-bottom: 16px;
} }
.panel-title { font-size: 13px; color: var(--muted); font-weight: 600; } li {
.chart-container { position: relative; height: 180px; } margin: 6px 0;
.chart-error { color: var(--muted); text-align: center; padding: 40px 0; font-size: 12px; } color: var(--muted);
/* Days selector */
.days-selector { display: flex; gap: 4px; }
.day-btn {
background: none; border: 1px solid var(--border); color: var(--muted);
padding: 3px 8px; border-radius: 4px; cursor: pointer; font-family: inherit; font-size: 11px;
} }
.day-btn.active { border-color: var(--accent); color: var(--accent); background: rgba(60, 179, 113, 0.08); }
/* Decisions panel */
.decisions-panel {
background: var(--panel);
border: 1px solid var(--border);
border-radius: 10px;
padding: 16px;
}
.market-tabs { display: flex; gap: 6px; flex-wrap: wrap; }
.tab-btn {
background: none; border: 1px solid var(--border); color: var(--muted);
padding: 4px 10px; border-radius: 6px; cursor: pointer; font-family: inherit; font-size: 11px;
}
.tab-btn.active { border-color: var(--accent); color: var(--accent); background: rgba(60, 179, 113, 0.08); }
.decisions-table { width: 100%; border-collapse: collapse; margin-top: 14px; }
.decisions-table th {
text-align: left; color: var(--muted); font-size: 11px; font-weight: 600;
padding: 6px 8px; border-bottom: 1px solid var(--border); white-space: nowrap;
}
.decisions-table td {
padding: 8px 8px; border-bottom: 1px solid rgba(40, 69, 95, 0.5);
vertical-align: middle; white-space: nowrap;
}
.decisions-table tr:last-child td { border-bottom: none; }
.decisions-table tr:hover td { background: rgba(255,255,255,0.02); }
.badge {
display: inline-block; padding: 2px 7px; border-radius: 4px;
font-size: 11px; font-weight: 700; letter-spacing: 0.5px;
}
.badge-buy { background: rgba(60, 179, 113, 0.15); color: var(--accent); }
.badge-sell { background: rgba(224, 85, 85, 0.15); color: var(--red); }
.badge-hold { background: rgba(159, 179, 200, 0.12); color: var(--muted); }
.conf-bar-wrap { display: flex; align-items: center; gap: 6px; min-width: 90px; }
.conf-bar { flex: 1; height: 6px; background: rgba(255,255,255,0.08); border-radius: 3px; overflow: hidden; }
.conf-fill { height: 100%; border-radius: 3px; background: var(--accent); transition: width 0.3s; }
.conf-val { color: var(--muted); font-size: 11px; min-width: 26px; text-align: right; }
.rationale-cell { max-width: 200px; overflow: hidden; text-overflow: ellipsis; color: var(--muted); }
.empty-row td { text-align: center; color: var(--muted); padding: 24px; }
/* Positions panel */
.positions-panel {
background: var(--panel);
border: 1px solid var(--border);
border-radius: 10px;
padding: 16px;
margin-bottom: 20px;
}
.positions-table { width: 100%; border-collapse: collapse; margin-top: 14px; }
.positions-table th {
text-align: left; color: var(--muted); font-size: 11px; font-weight: 600;
padding: 6px 8px; border-bottom: 1px solid var(--border); white-space: nowrap;
}
.positions-table td {
padding: 8px 8px; border-bottom: 1px solid rgba(40, 69, 95, 0.5);
vertical-align: middle; white-space: nowrap;
}
.positions-table tr:last-child td { border-bottom: none; }
.positions-table tr:hover td { background: rgba(255,255,255,0.02); }
.pos-empty { color: var(--muted); text-align: center; padding: 20px 0; font-size: 12px; }
.pos-count {
display: inline-block; background: rgba(60, 179, 113, 0.12);
color: var(--accent); font-size: 11px; font-weight: 700;
padding: 2px 8px; border-radius: 10px; margin-left: 8px;
}
/* Spinner */
.spinner { display: inline-block; width: 12px; height: 12px; border: 2px solid var(--border); border-top-color: var(--accent); border-radius: 50%; animation: spin 0.8s linear infinite; }
@keyframes spin { to { transform: rotate(360deg); } }
/* Generic panel */
.panel {
background: var(--panel);
border: 1px solid var(--border);
border-radius: 10px;
padding: 16px;
margin-top: 20px;
}
/* Playbook panel - details/summary accordion */
.playbook-panel details { border: 1px solid var(--border); border-radius: 4px; margin-bottom: 6px; }
.playbook-panel summary { padding: 8px 12px; cursor: pointer; font-weight: 600; background: var(--bg); color: var(--fg); }
.playbook-panel summary:hover { color: var(--accent); }
.playbook-panel pre { margin: 0; padding: 12px; background: var(--bg); overflow-x: auto;
font-size: 11px; color: #a0c4ff; white-space: pre-wrap; }
/* Scorecard KPI card grid */
.scorecard-grid { display: grid; grid-template-columns: repeat(auto-fill, minmax(140px, 1fr)); gap: 10px; }
.kpi-card { background: var(--bg); border: 1px solid var(--border); border-radius: 6px; padding: 12px; text-align: center; }
.kpi-card .kpi-label { font-size: 11px; color: var(--muted); margin-bottom: 4px; }
.kpi-card .kpi-value { font-size: 20px; font-weight: 700; color: var(--fg); }
/* Scenarios table */
.scenarios-table { width: 100%; border-collapse: collapse; font-size: 13px; }
.scenarios-table th { background: var(--bg); padding: 8px; text-align: left; border-bottom: 1px solid var(--border);
color: var(--muted); font-size: 11px; font-weight: 600; white-space: nowrap; }
.scenarios-table td { padding: 7px 8px; border-bottom: 1px solid rgba(40,69,95,0.5); }
.scenarios-table tr:hover td { background: rgba(255,255,255,0.02); }
/* Context table */
.context-table { width: 100%; border-collapse: collapse; font-size: 12px; }
.context-table th { background: var(--bg); padding: 8px; text-align: left; border-bottom: 1px solid var(--border);
color: var(--muted); font-size: 11px; font-weight: 600; white-space: nowrap; }
.context-table td { padding: 6px 8px; border-bottom: 1px solid rgba(40,69,95,0.5); vertical-align: top; }
.context-value { max-height: 60px; overflow-y: auto; color: #a0c4ff; word-break: break-all; }
/* Common panel select controls */
.panel-controls { display: flex; gap: 8px; align-items: center; flex-wrap: wrap; }
.panel-controls select, .panel-controls input[type="number"] {
background: var(--bg); color: var(--fg); border: 1px solid var(--border);
border-radius: 4px; padding: 4px 8px; font-size: 13px; font-family: inherit;
}
.panel-date { color: var(--muted); font-size: 12px; }
.empty-msg { color: var(--muted); text-align: center; padding: 20px 0; font-size: 12px; }
</style> </style>
</head> </head>
<body> <body>
<div class="wrap"> <div class="wrap">
<!-- Header --> <div class="card">
<header> <h1>The Ouroboros Dashboard API</h1>
<h1>&#x1F40D; The Ouroboros</h1> <p>Use the following endpoints:</p>
<div class="header-right"> <ul>
<span class="mode-badge" id="mode-badge">--</span> <li><code>/api/status</code></li>
<div class="cb-gauge-wrap" id="cb-gauge" title="Circuit Breaker"> <li><code>/api/playbook/{date}?market=KR</code></li>
<span class="cb-dot unknown" id="cb-dot"></span> <li><code>/api/scorecard/{date}?market=KR</code></li>
<span id="cb-label">CB --</span> <li><code>/api/performance?market=all</code></li>
<div class="cb-bar-wrap"> <li><code>/api/context/{layer}</code></li>
<div class="cb-bar-fill" id="cb-bar" style="width:0%;background:var(--accent)"></div> <li><code>/api/decisions?market=KR</code></li>
</div> <li><code>/api/scenarios/active?market=US</code></li>
</div> </ul>
<span id="last-updated">--</span>
<button class="refresh-btn" onclick="refreshAll()">&#x21BA; 새로고침</button>
</div>
</header>
<!-- Summary cards -->
<div class="cards">
<div class="card">
<div class="card-label">오늘 거래</div>
<div class="card-value neutral" id="card-trades">--</div>
<div class="card-sub" id="card-trades-sub">거래 건수</div>
</div>
<div class="card">
<div class="card-label">오늘 P&amp;L</div>
<div class="card-value" id="card-pnl">--</div>
<div class="card-sub" id="card-pnl-sub">실현 손익</div>
</div>
<div class="card">
<div class="card-label">승률</div>
<div class="card-value neutral" id="card-winrate">--</div>
<div class="card-sub">전체 누적</div>
</div>
<div class="card">
<div class="card-label">누적 거래</div>
<div class="card-value neutral" id="card-total">--</div>
<div class="card-sub">전체 기간</div>
</div>
</div>
<!-- Open Positions -->
<div class="positions-panel">
<div class="panel-header">
<span class="panel-title">
현재 보유 포지션
<span class="pos-count" id="positions-count">0</span>
</span>
</div>
<table class="positions-table">
<thead>
<tr>
<th>종목</th>
<th>시장</th>
<th>수량</th>
<th>진입가</th>
<th>보유 시간</th>
</tr>
</thead>
<tbody id="positions-body">
<tr><td colspan="5" class="pos-empty"><span class="spinner"></span></td></tr>
</tbody>
</table>
</div>
<!-- P&L Chart -->
<div class="chart-panel">
<div class="panel-header">
<span class="panel-title">P&amp;L 추이</span>
<div class="days-selector">
<button class="day-btn active" data-days="7" onclick="selectDays(this)">7일</button>
<button class="day-btn" data-days="30" onclick="selectDays(this)">30일</button>
<button class="day-btn" data-days="90" onclick="selectDays(this)">90일</button>
</div>
</div>
<div class="chart-container">
<canvas id="pnl-chart"></canvas>
<div class="chart-error" id="chart-error" style="display:none">데이터 없음</div>
</div>
</div>
<!-- Decisions log -->
<div class="decisions-panel">
<div class="panel-header">
<span class="panel-title">최근 결정 로그</span>
<div class="market-tabs" id="market-tabs">
<button class="tab-btn active" data-market="KR" onclick="selectMarket(this)">KR</button>
<button class="tab-btn" data-market="US_NASDAQ" onclick="selectMarket(this)">US_NASDAQ</button>
<button class="tab-btn" data-market="US_NYSE" onclick="selectMarket(this)">US_NYSE</button>
<button class="tab-btn" data-market="JP" onclick="selectMarket(this)">JP</button>
<button class="tab-btn" data-market="HK" onclick="selectMarket(this)">HK</button>
</div>
</div>
<table class="decisions-table">
<thead>
<tr>
<th>시각</th>
<th>종목</th>
<th>액션</th>
<th>신뢰도</th>
<th>사유</th>
</tr>
</thead>
<tbody id="decisions-body">
<tr class="empty-row"><td colspan="5"><span class="spinner"></span></td></tr>
</tbody>
</table>
</div>
<!-- playbook panel -->
<div class="panel playbook-panel">
<div class="panel-header">
<span class="panel-title">&#x1F4CB; 프리마켓 플레이북</span>
<div class="panel-controls">
<select id="pb-market-select" onchange="fetchPlaybook()">
<option value="KR">KR</option>
<option value="US_NASDAQ">US_NASDAQ</option>
<option value="US_NYSE">US_NYSE</option>
</select>
<span id="pb-date" class="panel-date"></span>
</div>
</div>
<div id="playbook-content"><p class="empty-msg">데이터 없음</p></div>
</div>
<!-- scorecard panel -->
<div class="panel">
<div class="panel-header">
<span class="panel-title">&#x1F4CA; 일간 스코어카드</span>
<div class="panel-controls">
<select id="sc-market-select" onchange="fetchScorecard()">
<option value="KR">KR</option>
<option value="US_NASDAQ">US_NASDAQ</option>
</select>
<span id="sc-date" class="panel-date"></span>
</div>
</div>
<div id="scorecard-grid" class="scorecard-grid"><p class="empty-msg">데이터 없음</p></div>
</div>
<!-- scenarios panel -->
<div class="panel">
<div class="panel-header">
<span class="panel-title">&#x1F3AF; 활성 시나리오 매칭</span>
<div class="panel-controls">
<select id="scen-market-select" onchange="fetchScenarios()">
<option value="KR">KR</option>
<option value="US_NASDAQ">US_NASDAQ</option>
</select>
</div>
</div>
<div id="scenarios-content"><p class="empty-msg">데이터 없음</p></div>
</div>
<!-- context layer panel -->
<div class="panel">
<div class="panel-header">
<span class="panel-title">&#x1F9E0; 컨텍스트 트리</span>
<div class="panel-controls">
<select id="ctx-layer-select" onchange="fetchContext()">
<option value="L7_REALTIME">L7_REALTIME</option>
<option value="L6_DAILY">L6_DAILY</option>
<option value="L5_WEEKLY">L5_WEEKLY</option>
<option value="L4_MONTHLY">L4_MONTHLY</option>
<option value="L3_QUARTERLY">L3_QUARTERLY</option>
<option value="L2_YEARLY">L2_YEARLY</option>
<option value="L1_LIFETIME">L1_LIFETIME</option>
</select>
<input id="ctx-limit" type="number" value="20" min="1" max="200"
style="width:60px;" onchange="fetchContext()">
</div>
</div>
<div id="context-content"><p class="empty-msg">데이터 없음</p></div>
</div> </div>
</div> </div>
<script>
let pnlChart = null;
let currentDays = 7;
let currentMarket = 'KR';
function fmt(dt) {
try {
const d = new Date(dt);
return d.toLocaleTimeString('ko-KR', { hour: '2-digit', minute: '2-digit', hour12: false });
} catch { return dt || '--'; }
}
function fmtPnl(v) {
if (v === null || v === undefined) return '--';
const n = parseFloat(v);
const cls = n > 0 ? 'positive' : n < 0 ? 'negative' : 'neutral';
const sign = n > 0 ? '+' : '';
return `<span class="${cls}">${sign}${n.toFixed(2)}</span>`;
}
function badge(action) {
const a = (action || '').toUpperCase();
const cls = a === 'BUY' ? 'badge-buy' : a === 'SELL' ? 'badge-sell' : 'badge-hold';
return `<span class="badge ${cls}">${a}</span>`;
}
function confBar(conf) {
const pct = Math.min(Math.max(conf || 0, 0), 100);
return `<div class="conf-bar-wrap">
<div class="conf-bar"><div class="conf-fill" style="width:${pct}%"></div></div>
<span class="conf-val">${pct}</span>
</div>`;
}
function fmtPrice(v, market) {
if (v === null || v === undefined) return '--';
const n = parseFloat(v);
const sym = market === 'KR' ? '₩' : market === 'JP' ? '¥' : market === 'HK' ? 'HK$' : '$';
return sym + n.toLocaleString('en-US', { minimumFractionDigits: 0, maximumFractionDigits: 4 });
}
async function fetchPositions() {
const tbody = document.getElementById('positions-body');
const countEl = document.getElementById('positions-count');
try {
const r = await fetch('/api/positions');
if (!r.ok) throw new Error('fetch failed');
const d = await r.json();
countEl.textContent = d.count ?? 0;
if (!d.positions || d.positions.length === 0) {
tbody.innerHTML = '<tr><td colspan="5" class="pos-empty">현재 보유 중인 포지션 없음</td></tr>';
return;
}
tbody.innerHTML = d.positions.map(p => `
<tr>
<td><strong>${p.stock_code || '--'}</strong></td>
<td><span style="color:var(--muted);font-size:11px">${p.market || '--'}</span></td>
<td>${p.quantity ?? '--'}</td>
<td>${fmtPrice(p.entry_price, p.market)}</td>
<td style="color:var(--muted);font-size:11px">${p.held || '--'}</td>
</tr>
`).join('');
} catch {
tbody.innerHTML = '<tr><td colspan="5" class="pos-empty">데이터 로드 실패</td></tr>';
}
}
function renderCbGauge(cb) {
if (!cb) return;
const dot = document.getElementById('cb-dot');
const label = document.getElementById('cb-label');
const bar = document.getElementById('cb-bar');
const status = cb.status || 'unknown';
const threshold = cb.threshold_pct ?? -3.0;
const current = cb.current_pnl_pct;
// dot color
dot.className = `cb-dot ${status}`;
// label
if (current !== null && current !== undefined) {
const sign = current > 0 ? '+' : '';
label.textContent = `CB ${sign}${current.toFixed(2)}%`;
} else {
label.textContent = 'CB --';
}
// bar: fill = how much of the threshold has been consumed (0%=safe, 100%=tripped)
const colorMap = { ok: 'var(--accent)', warning: 'var(--warn)', tripped: 'var(--red)', unknown: 'var(--border)' };
bar.style.background = colorMap[status] || 'var(--border)';
if (current !== null && current !== undefined && threshold < 0) {
const fillPct = Math.min(Math.max((current / threshold) * 100, 0), 100);
bar.style.width = `${fillPct}%`;
} else {
bar.style.width = '0%';
}
}
async function fetchStatus() {
try {
const r = await fetch('/api/status');
if (!r.ok) return;
const d = await r.json();
const t = d.totals || {};
document.getElementById('card-trades').textContent = t.trade_count ?? '--';
const pnlEl = document.getElementById('card-pnl');
const pnlV = t.total_pnl;
if (pnlV !== undefined) {
const n = parseFloat(pnlV);
const sign = n > 0 ? '+' : '';
pnlEl.textContent = `${sign}${n.toFixed(2)}`;
pnlEl.className = `card-value ${n > 0 ? 'positive' : n < 0 ? 'negative' : 'neutral'}`;
}
document.getElementById('card-pnl-sub').textContent = `결정 ${t.decision_count ?? 0}`;
renderCbGauge(d.circuit_breaker);
renderModeBadge(d.mode);
} catch {}
}
function renderModeBadge(mode) {
const el = document.getElementById('mode-badge');
if (!el) return;
if (mode === 'live') {
el.textContent = '🔴 실전투자';
el.className = 'mode-badge live';
} else {
el.textContent = '🟡 모의투자';
el.className = 'mode-badge paper';
}
}
async function fetchPerformance() {
try {
const r = await fetch('/api/performance?market=all');
if (!r.ok) return;
const d = await r.json();
const c = d.combined || {};
document.getElementById('card-winrate').textContent = c.win_rate !== undefined ? `${c.win_rate}%` : '--';
document.getElementById('card-total').textContent = c.total_trades ?? '--';
} catch {}
}
async function fetchPnlHistory(days) {
try {
const r = await fetch(`/api/pnl/history?days=${days}`);
if (!r.ok) throw new Error('fetch failed');
const d = await r.json();
renderChart(d);
} catch {
document.getElementById('chart-error').style.display = 'block';
}
}
function renderChart(data) {
const errEl = document.getElementById('chart-error');
if (!data.labels || data.labels.length === 0) {
errEl.style.display = 'block';
return;
}
errEl.style.display = 'none';
const colors = data.pnl.map(v => v >= 0 ? 'rgba(60,179,113,0.75)' : 'rgba(224,85,85,0.75)');
const borderColors = data.pnl.map(v => v >= 0 ? '#3cb371' : '#e05555');
if (pnlChart) { pnlChart.destroy(); pnlChart = null; }
const ctx = document.getElementById('pnl-chart').getContext('2d');
pnlChart = new Chart(ctx, {
type: 'bar',
data: {
labels: data.labels,
datasets: [{
label: 'Daily P&L',
data: data.pnl,
backgroundColor: colors,
borderColor: borderColors,
borderWidth: 1,
borderRadius: 3,
}]
},
options: {
responsive: true,
maintainAspectRatio: false,
plugins: {
legend: { display: false },
tooltip: {
callbacks: {
label: ctx => {
const v = ctx.parsed.y;
const sign = v >= 0 ? '+' : '';
const trades = data.trades[ctx.dataIndex];
return [`P&L: ${sign}${v.toFixed(2)}`, `거래: ${trades}`];
}
}
}
},
scales: {
x: {
ticks: { color: '#9fb3c8', font: { size: 10 }, maxRotation: 0 },
grid: { color: 'rgba(40,69,95,0.4)' }
},
y: {
ticks: { color: '#9fb3c8', font: { size: 10 } },
grid: { color: 'rgba(40,69,95,0.4)' }
}
}
}
});
}
async function fetchDecisions(market) {
const tbody = document.getElementById('decisions-body');
tbody.innerHTML = '<tr class="empty-row"><td colspan="5"><span class="spinner"></span></td></tr>';
try {
const r = await fetch(`/api/decisions?market=${market}&limit=50`);
if (!r.ok) throw new Error('fetch failed');
const d = await r.json();
if (!d.decisions || d.decisions.length === 0) {
tbody.innerHTML = '<tr class="empty-row"><td colspan="5">결정 로그 없음</td></tr>';
return;
}
tbody.innerHTML = d.decisions.map(dec => `
<tr>
<td>${fmt(dec.timestamp)}</td>
<td>${dec.stock_code || '--'}</td>
<td>${badge(dec.action)}</td>
<td>${confBar(dec.confidence)}</td>
<td class="rationale-cell" title="${(dec.rationale || '').replace(/"/g, '&quot;')}">${dec.rationale || '--'}</td>
</tr>
`).join('');
} catch {
tbody.innerHTML = '<tr class="empty-row"><td colspan="5">데이터 로드 실패</td></tr>';
}
}
function selectDays(btn) {
document.querySelectorAll('.day-btn').forEach(b => b.classList.remove('active'));
btn.classList.add('active');
currentDays = parseInt(btn.dataset.days, 10);
fetchPnlHistory(currentDays);
}
function selectMarket(btn) {
document.querySelectorAll('.tab-btn').forEach(b => b.classList.remove('active'));
btn.classList.add('active');
currentMarket = btn.dataset.market;
fetchDecisions(currentMarket);
}
function todayStr() {
return new Date().toISOString().slice(0, 10);
}
function esc(s) {
return String(s ?? '').replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;').replace(/"/g, '&quot;');
}
async function fetchJSON(url) {
const r = await fetch(url);
if (!r.ok) throw new Error(`HTTP ${r.status}`);
return r.json();
}
async function fetchPlaybook() {
const market = document.getElementById('pb-market-select').value;
const date = todayStr();
document.getElementById('pb-date').textContent = date;
const el = document.getElementById('playbook-content');
try {
const data = await fetchJSON(`/api/playbook/${date}?market=${market}`);
const stocks = data.stock_playbooks ?? [];
if (stocks.length === 0) {
el.innerHTML = '<p class="empty-msg">오늘 플레이북 없음</p>';
return;
}
el.innerHTML = stocks.map(sp =>
`<details><summary>${esc(sp.stock_code ?? '?')}${esc(sp.signal ?? '')}</summary>` +
`<pre>${esc(JSON.stringify(sp, null, 2))}</pre></details>`
).join('');
} catch {
el.innerHTML = '<p class="empty-msg">플레이북 없음 (오늘 미생성 또는 API 오류)</p>';
}
}
async function fetchScorecard() {
const market = document.getElementById('sc-market-select').value;
const date = todayStr();
document.getElementById('sc-date').textContent = date;
const el = document.getElementById('scorecard-grid');
try {
const data = await fetchJSON(`/api/scorecard/${date}?market=${market}`);
const sc = data.scorecard ?? {};
const entries = Object.entries(sc);
if (entries.length === 0) {
el.innerHTML = '<p class="empty-msg">스코어카드 없음</p>';
return;
}
el.className = 'scorecard-grid';
el.innerHTML = entries.map(([k, v]) => `
<div class="kpi-card">
<div class="kpi-label">${esc(k)}</div>
<div class="kpi-value">${typeof v === 'number' ? v.toFixed(2) : esc(String(v))}</div>
</div>`).join('');
} catch {
el.innerHTML = '<p class="empty-msg">스코어카드 없음 (오늘 미생성 또는 API 오류)</p>';
}
}
async function fetchScenarios() {
const market = document.getElementById('scen-market-select').value;
const date = todayStr();
const el = document.getElementById('scenarios-content');
try {
const data = await fetchJSON(`/api/scenarios/active?market=${market}&date_str=${date}&limit=50`);
const matches = data.matches ?? [];
if (matches.length === 0) {
el.innerHTML = '<p class="empty-msg">활성 시나리오 없음</p>';
return;
}
el.innerHTML = `<table class="scenarios-table">
<thead><tr><th>종목</th><th>신호</th><th>신뢰도</th><th>매칭 조건</th></tr></thead>
<tbody>${matches.map(m => `
<tr>
<td>${esc(m.stock_code)}</td>
<td>${esc(m.signal ?? '-')}</td>
<td>${esc(m.confidence ?? '-')}</td>
<td><code style="font-size:11px">${esc(JSON.stringify(m.scenario_match ?? {}))}</code></td>
</tr>`).join('')}
</tbody></table>`;
} catch {
el.innerHTML = '<p class="empty-msg">데이터 없음</p>';
}
}
async function fetchContext() {
const layer = document.getElementById('ctx-layer-select').value;
const limit = Math.min(Math.max(parseInt(document.getElementById('ctx-limit').value, 10) || 20, 1), 200);
const el = document.getElementById('context-content');
try {
const data = await fetchJSON(`/api/context/${layer}?limit=${limit}`);
const entries = data.entries ?? [];
if (entries.length === 0) {
el.innerHTML = '<p class="empty-msg">컨텍스트 없음</p>';
return;
}
el.innerHTML = `<table class="context-table">
<thead><tr><th>timeframe</th><th>key</th><th>value</th><th>updated</th></tr></thead>
<tbody>${entries.map(e => `
<tr>
<td>${esc(e.timeframe)}</td>
<td>${esc(e.key)}</td>
<td><div class="context-value">${esc(JSON.stringify(e.value ?? e.raw_value))}</div></td>
<td style="font-size:11px;color:var(--muted)">${esc((e.updated_at ?? '').slice(0, 16))}</td>
</tr>`).join('')}
</tbody></table>`;
} catch {
el.innerHTML = '<p class="empty-msg">데이터 없음</p>';
}
}
async function refreshAll() {
document.getElementById('last-updated').textContent = '업데이트 중...';
await Promise.all([
fetchStatus(),
fetchPerformance(),
fetchPositions(),
fetchPnlHistory(currentDays),
fetchDecisions(currentMarket),
fetchPlaybook(),
fetchScorecard(),
fetchScenarios(),
fetchContext(),
]);
const now = new Date();
const timeStr = now.toLocaleTimeString('ko-KR', { hour: '2-digit', minute: '2-digit', second: '2-digit', hour12: false });
document.getElementById('last-updated').textContent = `마지막 업데이트: ${timeStr}`;
}
// Initial load
refreshAll();
// Auto-refresh every 30 seconds
setInterval(refreshAll, 30000);
</script>
</body> </body>
</html> </html>

View File

@@ -14,11 +14,6 @@ def init_db(db_path: str) -> sqlite3.Connection:
if db_path != ":memory:": if db_path != ":memory:":
Path(db_path).parent.mkdir(parents=True, exist_ok=True) Path(db_path).parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(db_path) conn = sqlite3.connect(db_path)
# Enable WAL mode for concurrent read/write (dashboard + trading loop).
# WAL does not apply to in-memory databases.
if db_path != ":memory:":
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA busy_timeout=5000")
conn.execute( conn.execute(
""" """
CREATE TABLE IF NOT EXISTS trades ( CREATE TABLE IF NOT EXISTS trades (
@@ -33,13 +28,12 @@ def init_db(db_path: str) -> sqlite3.Connection:
pnl REAL DEFAULT 0.0, pnl REAL DEFAULT 0.0,
market TEXT DEFAULT 'KR', market TEXT DEFAULT 'KR',
exchange_code TEXT DEFAULT 'KRX', exchange_code TEXT DEFAULT 'KRX',
decision_id TEXT, decision_id TEXT
mode TEXT DEFAULT 'paper'
) )
""" """
) )
# Migration: Add columns if they don't exist (backward-compatible schema upgrades) # Migration: Add market and exchange_code columns if they don't exist
cursor = conn.execute("PRAGMA table_info(trades)") cursor = conn.execute("PRAGMA table_info(trades)")
columns = {row[1] for row in cursor.fetchall()} columns = {row[1] for row in cursor.fetchall()}
@@ -51,8 +45,6 @@ def init_db(db_path: str) -> sqlite3.Connection:
conn.execute("ALTER TABLE trades ADD COLUMN selection_context TEXT") conn.execute("ALTER TABLE trades ADD COLUMN selection_context TEXT")
if "decision_id" not in columns: if "decision_id" not in columns:
conn.execute("ALTER TABLE trades ADD COLUMN decision_id TEXT") conn.execute("ALTER TABLE trades ADD COLUMN decision_id TEXT")
if "mode" not in columns:
conn.execute("ALTER TABLE trades ADD COLUMN mode TEXT DEFAULT 'paper'")
# Context tree tables for multi-layered memory management # Context tree tables for multi-layered memory management
conn.execute( conn.execute(
@@ -139,25 +131,6 @@ def init_db(db_path: str) -> sqlite3.Connection:
conn.execute( conn.execute(
"CREATE INDEX IF NOT EXISTS idx_decision_logs_confidence ON decision_logs(confidence)" "CREATE INDEX IF NOT EXISTS idx_decision_logs_confidence ON decision_logs(confidence)"
) )
# Index for open-position queries (partition by stock_code, market, ordered by timestamp)
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_trades_stock_market_ts"
" ON trades (stock_code, market, timestamp DESC)"
)
# Lightweight key-value store for trading system runtime metrics (dashboard use only)
# Intentionally separate from the AI context tree to preserve separation of concerns.
conn.execute(
"""
CREATE TABLE IF NOT EXISTS system_metrics (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
updated_at TEXT NOT NULL
)
"""
)
conn.commit() conn.commit()
return conn return conn
@@ -175,7 +148,6 @@ def log_trade(
exchange_code: str = "KRX", exchange_code: str = "KRX",
selection_context: dict[str, any] | None = None, selection_context: dict[str, any] | None = None,
decision_id: str | None = None, decision_id: str | None = None,
mode: str = "paper",
) -> None: ) -> None:
"""Insert a trade record into the database. """Insert a trade record into the database.
@@ -191,8 +163,6 @@ def log_trade(
market: Market code market: Market code
exchange_code: Exchange code exchange_code: Exchange code
selection_context: Scanner selection data (RSI, volume_ratio, signal, score) selection_context: Scanner selection data (RSI, volume_ratio, signal, score)
decision_id: Unique decision identifier for audit linking
mode: Trading mode ('paper' or 'live') for data separation
""" """
# Serialize selection context to JSON # Serialize selection context to JSON
context_json = json.dumps(selection_context) if selection_context else None context_json = json.dumps(selection_context) if selection_context else None
@@ -201,10 +171,9 @@ def log_trade(
""" """
INSERT INTO trades ( INSERT INTO trades (
timestamp, stock_code, action, confidence, rationale, timestamp, stock_code, action, confidence, rationale,
quantity, price, pnl, market, exchange_code, selection_context, decision_id, quantity, price, pnl, market, exchange_code, selection_context, decision_id
mode
) )
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", """,
( (
datetime.now(UTC).isoformat(), datetime.now(UTC).isoformat(),
@@ -219,7 +188,6 @@ def log_trade(
exchange_code, exchange_code,
context_json, context_json,
decision_id, decision_id,
mode,
), ),
) )
conn.commit() conn.commit()
@@ -254,11 +222,10 @@ def get_open_position(
"""Return open position if latest trade is BUY, else None.""" """Return open position if latest trade is BUY, else None."""
cursor = conn.execute( cursor = conn.execute(
""" """
SELECT action, decision_id, price, quantity, timestamp SELECT action, decision_id, price, quantity
FROM trades FROM trades
WHERE stock_code = ? WHERE stock_code = ?
AND market = ? AND market = ?
AND action IN ('BUY', 'SELL')
ORDER BY timestamp DESC ORDER BY timestamp DESC
LIMIT 1 LIMIT 1
""", """,
@@ -267,7 +234,7 @@ def get_open_position(
row = cursor.fetchone() row = cursor.fetchone()
if not row or row[0] != "BUY": if not row or row[0] != "BUY":
return None return None
return {"decision_id": row[1], "price": row[2], "quantity": row[3], "timestamp": row[4]} return {"decision_id": row[1], "price": row[2], "quantity": row[3]}
def get_recent_symbols( def get_recent_symbols(

File diff suppressed because it is too large Load Diff

View File

@@ -4,9 +4,8 @@ import asyncio
import logging import logging
import time import time
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from dataclasses import dataclass, fields from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import ClassVar
import aiohttp import aiohttp
@@ -59,45 +58,6 @@ class LeakyBucket:
self._tokens -= 1.0 self._tokens -= 1.0
@dataclass
class NotificationFilter:
"""Granular on/off flags for each notification type.
circuit_breaker is intentionally omitted — it is always sent regardless.
"""
# Maps user-facing command keys to dataclass field names
KEYS: ClassVar[dict[str, str]] = {
"trades": "trades",
"market": "market_open_close",
"fatfinger": "fat_finger",
"system": "system_events",
"playbook": "playbook",
"scenario": "scenario_match",
"errors": "errors",
}
trades: bool = True
market_open_close: bool = True
fat_finger: bool = True
system_events: bool = True
playbook: bool = True
scenario_match: bool = True
errors: bool = True
def set_flag(self, key: str, value: bool) -> bool:
"""Set a filter flag by user-facing key. Returns False if key is unknown."""
field = self.KEYS.get(key.lower())
if field is None:
return False
setattr(self, field, value)
return True
def as_dict(self) -> dict[str, bool]:
"""Return {user_key: current_value} for display."""
return {k: getattr(self, field) for k, field in self.KEYS.items()}
@dataclass @dataclass
class NotificationMessage: class NotificationMessage:
"""Internal notification message structure.""" """Internal notification message structure."""
@@ -119,7 +79,6 @@ class TelegramClient:
chat_id: str | None = None, chat_id: str | None = None,
enabled: bool = True, enabled: bool = True,
rate_limit: float = DEFAULT_RATE, rate_limit: float = DEFAULT_RATE,
notification_filter: NotificationFilter | None = None,
) -> None: ) -> None:
""" """
Initialize Telegram client. Initialize Telegram client.
@@ -129,14 +88,12 @@ class TelegramClient:
chat_id: Target chat ID (user or group) chat_id: Target chat ID (user or group)
enabled: Enable/disable notifications globally enabled: Enable/disable notifications globally
rate_limit: Maximum messages per second rate_limit: Maximum messages per second
notification_filter: Granular per-type on/off flags
""" """
self._bot_token = bot_token self._bot_token = bot_token
self._chat_id = chat_id self._chat_id = chat_id
self._enabled = enabled self._enabled = enabled
self._rate_limiter = LeakyBucket(rate=rate_limit) self._rate_limiter = LeakyBucket(rate=rate_limit)
self._session: aiohttp.ClientSession | None = None self._session: aiohttp.ClientSession | None = None
self._filter = notification_filter if notification_filter is not None else NotificationFilter()
if not enabled: if not enabled:
logger.info("Telegram notifications disabled via configuration") logger.info("Telegram notifications disabled via configuration")
@@ -161,26 +118,6 @@ class TelegramClient:
if self._session is not None and not self._session.closed: if self._session is not None and not self._session.closed:
await self._session.close() await self._session.close()
def set_notification(self, key: str, value: bool) -> bool:
"""Toggle a notification type by user-facing key at runtime.
Args:
key: User-facing key (e.g. "scenario", "market", "all")
value: True to enable, False to disable
Returns:
True if key was valid, False if unknown.
"""
if key == "all":
for k in NotificationFilter.KEYS:
self._filter.set_flag(k, value)
return True
return self._filter.set_flag(key, value)
def filter_status(self) -> dict[str, bool]:
"""Return current per-type filter state keyed by user-facing names."""
return self._filter.as_dict()
async def send_message(self, text: str, parse_mode: str = "HTML") -> bool: async def send_message(self, text: str, parse_mode: str = "HTML") -> bool:
""" """
Send a generic text message to Telegram. Send a generic text message to Telegram.
@@ -256,8 +193,6 @@ class TelegramClient:
price: Execution price price: Execution price
confidence: AI confidence level (0-100) confidence: AI confidence level (0-100)
""" """
if not self._filter.trades:
return
emoji = "🟢" if action == "BUY" else "🔴" emoji = "🟢" if action == "BUY" else "🔴"
message = ( message = (
f"<b>{emoji} {action}</b>\n" f"<b>{emoji} {action}</b>\n"
@@ -277,8 +212,6 @@ class TelegramClient:
Args: Args:
market_name: Name of the market (e.g., "Korea", "United States") market_name: Name of the market (e.g., "Korea", "United States")
""" """
if not self._filter.market_open_close:
return
message = f"<b>Market Open</b>\n{market_name} trading session started" message = f"<b>Market Open</b>\n{market_name} trading session started"
await self._send_notification( await self._send_notification(
NotificationMessage(priority=NotificationPriority.LOW, message=message) NotificationMessage(priority=NotificationPriority.LOW, message=message)
@@ -292,8 +225,6 @@ class TelegramClient:
market_name: Name of the market market_name: Name of the market
pnl_pct: Final P&L percentage for the session pnl_pct: Final P&L percentage for the session
""" """
if not self._filter.market_open_close:
return
pnl_sign = "+" if pnl_pct >= 0 else "" pnl_sign = "+" if pnl_pct >= 0 else ""
pnl_emoji = "📈" if pnl_pct >= 0 else "📉" pnl_emoji = "📈" if pnl_pct >= 0 else "📉"
message = ( message = (
@@ -340,8 +271,6 @@ class TelegramClient:
total_cash: Total available cash total_cash: Total available cash
max_pct: Maximum allowed percentage max_pct: Maximum allowed percentage
""" """
if not self._filter.fat_finger:
return
attempted_pct = (order_amount / total_cash) * 100 if total_cash > 0 else 0 attempted_pct = (order_amount / total_cash) * 100 if total_cash > 0 else 0
message = ( message = (
f"<b>Fat-Finger Protection</b>\n" f"<b>Fat-Finger Protection</b>\n"
@@ -364,8 +293,6 @@ class TelegramClient:
mode: Trading mode ("paper" or "live") mode: Trading mode ("paper" or "live")
enabled_markets: List of enabled market codes enabled_markets: List of enabled market codes
""" """
if not self._filter.system_events:
return
mode_emoji = "📝" if mode == "paper" else "💰" mode_emoji = "📝" if mode == "paper" else "💰"
markets_str = ", ".join(enabled_markets) markets_str = ", ".join(enabled_markets)
message = ( message = (
@@ -393,8 +320,6 @@ class TelegramClient:
scenario_count: Total number of scenarios scenario_count: Total number of scenarios
token_count: Gemini token usage for the playbook token_count: Gemini token usage for the playbook
""" """
if not self._filter.playbook:
return
message = ( message = (
f"<b>Playbook Generated</b>\n" f"<b>Playbook Generated</b>\n"
f"Market: {market}\n" f"Market: {market}\n"
@@ -422,8 +347,6 @@ class TelegramClient:
condition_summary: Short summary of the matched condition condition_summary: Short summary of the matched condition
confidence: Scenario confidence (0-100) confidence: Scenario confidence (0-100)
""" """
if not self._filter.scenario_match:
return
message = ( message = (
f"<b>Scenario Matched</b>\n" f"<b>Scenario Matched</b>\n"
f"Symbol: <code>{stock_code}</code>\n" f"Symbol: <code>{stock_code}</code>\n"
@@ -443,8 +366,6 @@ class TelegramClient:
market: Market code (e.g., "KR", "US") market: Market code (e.g., "KR", "US")
reason: Failure reason summary reason: Failure reason summary
""" """
if not self._filter.playbook:
return
message = ( message = (
f"<b>Playbook Failed</b>\n" f"<b>Playbook Failed</b>\n"
f"Market: {market}\n" f"Market: {market}\n"
@@ -461,8 +382,6 @@ class TelegramClient:
Args: Args:
reason: Reason for shutdown (e.g., "Normal shutdown", "Circuit breaker") reason: Reason for shutdown (e.g., "Normal shutdown", "Circuit breaker")
""" """
if not self._filter.system_events:
return
message = f"<b>System Shutdown</b>\n{reason}" message = f"<b>System Shutdown</b>\n{reason}"
priority = ( priority = (
NotificationPriority.CRITICAL NotificationPriority.CRITICAL
@@ -473,48 +392,6 @@ class TelegramClient:
NotificationMessage(priority=priority, message=message) NotificationMessage(priority=priority, message=message)
) )
async def notify_unfilled_order(
self,
stock_code: str,
market: str,
action: str,
quantity: int,
outcome: str,
new_price: float | None = None,
) -> None:
"""Notify about an unfilled overseas order that was cancelled or resubmitted.
Args:
stock_code: Stock ticker symbol.
market: Exchange/market code (e.g., "NASD", "SEHK").
action: "BUY" or "SELL".
quantity: Unfilled quantity.
outcome: "cancelled" or "resubmitted".
new_price: New order price if resubmitted (None if only cancelled).
"""
if not self._filter.trades:
return
# SELL resubmit is high priority — position liquidation at risk.
# BUY cancel is medium priority — only cash is freed.
priority = (
NotificationPriority.HIGH
if action == "SELL"
else NotificationPriority.MEDIUM
)
outcome_emoji = "🔄" if outcome == "resubmitted" else ""
outcome_label = "재주문" if outcome == "resubmitted" else "취소됨"
action_emoji = "🔴" if action == "SELL" else "🟢"
lines = [
f"<b>{outcome_emoji} 미체결 주문 {outcome_label}</b>",
f"Symbol: <code>{stock_code}</code> ({market})",
f"Action: {action_emoji} {action}",
f"Quantity: {quantity:,} shares",
]
if new_price is not None:
lines.append(f"New Price: {new_price:.4f}")
message = "\n".join(lines)
await self._send_notification(NotificationMessage(priority=priority, message=message))
async def notify_error( async def notify_error(
self, error_type: str, error_msg: str, context: str self, error_type: str, error_msg: str, context: str
) -> None: ) -> None:
@@ -526,8 +403,6 @@ class TelegramClient:
error_msg: Error message error_msg: Error message
context: Error context (e.g., stock code, market) context: Error context (e.g., stock code, market)
""" """
if not self._filter.errors:
return
message = ( message = (
f"<b>Error: {error_type}</b>\n" f"<b>Error: {error_type}</b>\n"
f"Context: {context}\n" f"Context: {context}\n"
@@ -554,7 +429,6 @@ class TelegramCommandHandler:
self._client = client self._client = client
self._polling_interval = polling_interval self._polling_interval = polling_interval
self._commands: dict[str, Callable[[], Awaitable[None]]] = {} self._commands: dict[str, Callable[[], Awaitable[None]]] = {}
self._commands_with_args: dict[str, Callable[[list[str]], Awaitable[None]]] = {}
self._last_update_id = 0 self._last_update_id = 0
self._polling_task: asyncio.Task[None] | None = None self._polling_task: asyncio.Task[None] | None = None
self._running = False self._running = False
@@ -563,7 +437,7 @@ class TelegramCommandHandler:
self, command: str, handler: Callable[[], Awaitable[None]] self, command: str, handler: Callable[[], Awaitable[None]]
) -> None: ) -> None:
""" """
Register a command handler (no arguments). Register a command handler.
Args: Args:
command: Command name (without leading slash, e.g., "start") command: Command name (without leading slash, e.g., "start")
@@ -572,19 +446,6 @@ class TelegramCommandHandler:
self._commands[command] = handler self._commands[command] = handler
logger.debug("Registered command handler: /%s", command) logger.debug("Registered command handler: /%s", command)
def register_command_with_args(
self, command: str, handler: Callable[[list[str]], Awaitable[None]]
) -> None:
"""
Register a command handler that receives trailing arguments.
Args:
command: Command name (without leading slash, e.g., "notify")
handler: Async function receiving list of argument tokens
"""
self._commands_with_args[command] = handler
logger.debug("Registered command handler (with args): /%s", command)
async def start_polling(self) -> None: async def start_polling(self) -> None:
"""Start long polling for commands.""" """Start long polling for commands."""
if self._running: if self._running:
@@ -646,19 +507,9 @@ class TelegramCommandHandler:
async with session.post(url, json=payload) as resp: async with session.post(url, json=payload) as resp:
if resp.status != 200: if resp.status != 200:
error_text = await resp.text() error_text = await resp.text()
if resp.status == 409: logger.error(
# Another bot instance is already polling — stop this poller entirely. "getUpdates API error (status=%d): %s", resp.status, error_text
# Retrying would keep conflicting with the other instance. )
self._running = False
logger.warning(
"Telegram conflict (409): another instance is already polling. "
"Disabling Telegram commands for this process. "
"Ensure only one instance of The Ouroboros is running at a time.",
)
else:
logger.error(
"getUpdates API error (status=%d): %s", resp.status, error_text
)
return [] return []
data = await resp.json() data = await resp.json()
@@ -715,14 +566,11 @@ class TelegramCommandHandler:
# Remove @botname suffix if present (for group chats) # Remove @botname suffix if present (for group chats)
command_name = command_parts[0].split("@")[0] command_name = command_parts[0].split("@")[0]
# Execute handler (args-aware handlers take priority) # Execute handler
args_handler = self._commands_with_args.get(command_name) handler = self._commands.get(command_name)
if args_handler: if handler:
logger.info("Executing command: /%s %s", command_name, command_parts[1:])
await args_handler(command_parts[1:])
elif command_name in self._commands:
logger.info("Executing command: /%s", command_name) logger.info("Executing command: /%s", command_name)
await self._commands[command_name]() await handler()
else: else:
logger.debug("Unknown command: /%s", command_name) logger.debug("Unknown command: /%s", command_name)
await self._client.send_message( await self._client.send_message(

View File

@@ -46,18 +46,6 @@ class StockCondition(BaseModel):
The ScenarioEngine evaluates all non-None fields as AND conditions. The ScenarioEngine evaluates all non-None fields as AND conditions.
A condition matches only if ALL specified fields are satisfied. A condition matches only if ALL specified fields are satisfied.
Technical indicator fields:
rsi_below / rsi_above — RSI threshold
volume_ratio_above / volume_ratio_below — volume vs previous day
price_above / price_below — absolute price level
price_change_pct_above / price_change_pct_below — intraday % change
Position-aware fields (require market_data enrichment from open position):
unrealized_pnl_pct_above — matches if unrealized P&L > threshold (e.g. 3.0 → +3%)
unrealized_pnl_pct_below — matches if unrealized P&L < threshold (e.g. -2.0 → -2%)
holding_days_above — matches if position held for more than N days
holding_days_below — matches if position held for fewer than N days
""" """
rsi_below: float | None = None rsi_below: float | None = None
@@ -68,10 +56,6 @@ class StockCondition(BaseModel):
price_below: float | None = None price_below: float | None = None
price_change_pct_above: float | None = None price_change_pct_above: float | None = None
price_change_pct_below: float | None = None price_change_pct_below: float | None = None
unrealized_pnl_pct_above: float | None = None
unrealized_pnl_pct_below: float | None = None
holding_days_above: int | None = None
holding_days_below: int | None = None
def has_any_condition(self) -> bool: def has_any_condition(self) -> bool:
"""Check if at least one condition field is set.""" """Check if at least one condition field is set."""
@@ -86,10 +70,6 @@ class StockCondition(BaseModel):
self.price_below, self.price_below,
self.price_change_pct_above, self.price_change_pct_above,
self.price_change_pct_below, self.price_change_pct_below,
self.unrealized_pnl_pct_above,
self.unrealized_pnl_pct_below,
self.holding_days_above,
self.holding_days_below,
) )
) )

View File

@@ -1,8 +1,7 @@
"""Pre-market planner — generates DayPlaybook via Gemini before market open. """Pre-market planner — generates DayPlaybook via Gemini before market open.
One Gemini API call per market per day. Candidates come from SmartVolatilityScanner. One Gemini API call per market per day. Candidates come from SmartVolatilityScanner.
On failure, returns a smart rule-based fallback playbook that uses scanner signals On failure, returns a defensive playbook (all HOLD, no trades).
(momentum/oversold) to generate BUY conditions, avoiding the all-HOLD problem.
""" """
from __future__ import annotations from __future__ import annotations
@@ -75,7 +74,6 @@ class PreMarketPlanner:
market: str, market: str,
candidates: list[ScanCandidate], candidates: list[ScanCandidate],
today: date | None = None, today: date | None = None,
current_holdings: list[dict] | None = None,
) -> DayPlaybook: ) -> DayPlaybook:
"""Generate a DayPlaybook for a market using Gemini. """Generate a DayPlaybook for a market using Gemini.
@@ -83,10 +81,6 @@ class PreMarketPlanner:
market: Market code ("KR" or "US") market: Market code ("KR" or "US")
candidates: Stock candidates from SmartVolatilityScanner candidates: Stock candidates from SmartVolatilityScanner
today: Override date (defaults to date.today()). Use market-local date. today: Override date (defaults to date.today()). Use market-local date.
current_holdings: Currently held positions with entry_price and unrealized_pnl_pct.
Each dict: {"stock_code": str, "name": str, "qty": int,
"entry_price": float, "unrealized_pnl_pct": float,
"holding_days": int}
Returns: Returns:
DayPlaybook with scenarios. Empty/defensive if no candidates or failure. DayPlaybook with scenarios. Empty/defensive if no candidates or failure.
@@ -111,7 +105,6 @@ class PreMarketPlanner:
context_data, context_data,
self_market_scorecard, self_market_scorecard,
cross_market, cross_market,
current_holdings=current_holdings,
) )
# 3. Call Gemini # 3. Call Gemini
@@ -124,8 +117,7 @@ class PreMarketPlanner:
# 4. Parse response # 4. Parse response
playbook = self._parse_response( playbook = self._parse_response(
decision.rationale, today, market, candidates, cross_market, decision.rationale, today, market, candidates, cross_market
current_holdings=current_holdings,
) )
playbook_with_tokens = playbook.model_copy( playbook_with_tokens = playbook.model_copy(
update={"token_count": decision.token_count} update={"token_count": decision.token_count}
@@ -142,7 +134,7 @@ class PreMarketPlanner:
except Exception: except Exception:
logger.exception("Playbook generation failed for %s", market) logger.exception("Playbook generation failed for %s", market)
if self._settings.DEFENSIVE_PLAYBOOK_ON_FAILURE: if self._settings.DEFENSIVE_PLAYBOOK_ON_FAILURE:
return self._smart_fallback_playbook(today, market, candidates, self._settings) return self._defensive_playbook(today, market, candidates)
return self._empty_playbook(today, market) return self._empty_playbook(today, market)
def build_cross_market_context( def build_cross_market_context(
@@ -237,7 +229,6 @@ class PreMarketPlanner:
context_data: dict[str, Any], context_data: dict[str, Any],
self_market_scorecard: dict[str, Any] | None, self_market_scorecard: dict[str, Any] | None,
cross_market: CrossMarketContext | None, cross_market: CrossMarketContext | None,
current_holdings: list[dict] | None = None,
) -> str: ) -> str:
"""Build a structured prompt for Gemini to generate scenario JSON.""" """Build a structured prompt for Gemini to generate scenario JSON."""
max_scenarios = self._settings.MAX_SCENARIOS_PER_STOCK max_scenarios = self._settings.MAX_SCENARIOS_PER_STOCK
@@ -249,26 +240,6 @@ class PreMarketPlanner:
for c in candidates for c in candidates
) )
holdings_text = ""
if current_holdings:
lines = []
for h in current_holdings:
code = h.get("stock_code", "")
name = h.get("name", "")
qty = h.get("qty", 0)
entry_price = h.get("entry_price", 0.0)
pnl_pct = h.get("unrealized_pnl_pct", 0.0)
holding_days = h.get("holding_days", 0)
lines.append(
f" - {code} ({name}): {qty}주 @ {entry_price:,.0f}, "
f"미실현손익 {pnl_pct:+.2f}%, 보유 {holding_days}"
)
holdings_text = (
"\n## Current Holdings (보유 중 — SELL/HOLD 전략 고려 필요)\n"
+ "\n".join(lines)
+ "\n"
)
cross_market_text = "" cross_market_text = ""
if cross_market: if cross_market:
cross_market_text = ( cross_market_text = (
@@ -301,20 +272,10 @@ class PreMarketPlanner:
for key, value in list(layer_data.items())[:5]: for key, value in list(layer_data.items())[:5]:
context_text += f" - {key}: {value}\n" context_text += f" - {key}: {value}\n"
holdings_instruction = ""
if current_holdings:
holding_codes = [h.get("stock_code", "") for h in current_holdings]
holdings_instruction = (
f"- Also include SELL/HOLD scenarios for held stocks: "
f"{', '.join(holding_codes)} "
f"(even if not in candidates list)\n"
)
return ( return (
f"You are a pre-market trading strategist for the {market} market.\n" f"You are a pre-market trading strategist for the {market} market.\n"
f"Generate structured trading scenarios for today.\n\n" f"Generate structured trading scenarios for today.\n\n"
f"## Candidates (from volatility scanner)\n{candidates_text}\n" f"## Candidates (from volatility scanner)\n{candidates_text}\n"
f"{holdings_text}"
f"{self_market_text}" f"{self_market_text}"
f"{cross_market_text}" f"{cross_market_text}"
f"{context_text}\n" f"{context_text}\n"
@@ -332,8 +293,7 @@ class PreMarketPlanner:
f' "stock_code": "...",\n' f' "stock_code": "...",\n'
f' "scenarios": [\n' f' "scenarios": [\n'
f' {{\n' f' {{\n'
f' "condition": {{"rsi_below": 30, "volume_ratio_above": 2.0,' f' "condition": {{"rsi_below": 30, "volume_ratio_above": 2.0}},\n'
f' "unrealized_pnl_pct_above": 3.0, "holding_days_above": 5}},\n'
f' "action": "BUY|SELL|HOLD",\n' f' "action": "BUY|SELL|HOLD",\n'
f' "confidence": 85,\n' f' "confidence": 85,\n'
f' "allocation_pct": 10.0,\n' f' "allocation_pct": 10.0,\n'
@@ -347,8 +307,7 @@ class PreMarketPlanner:
f'}}\n\n' f'}}\n\n'
f"Rules:\n" f"Rules:\n"
f"- Max {max_scenarios} scenarios per stock\n" f"- Max {max_scenarios} scenarios per stock\n"
f"- Candidates list is the primary source for BUY candidates\n" f"- Only use stocks from the candidates list\n"
f"{holdings_instruction}"
f"- Confidence 0-100 (80+ for actionable trades)\n" f"- Confidence 0-100 (80+ for actionable trades)\n"
f"- stop_loss_pct must be <= 0, take_profit_pct must be >= 0\n" f"- stop_loss_pct must be <= 0, take_profit_pct must be >= 0\n"
f"- Return ONLY the JSON, no markdown fences or explanation\n" f"- Return ONLY the JSON, no markdown fences or explanation\n"
@@ -361,19 +320,12 @@ class PreMarketPlanner:
market: str, market: str,
candidates: list[ScanCandidate], candidates: list[ScanCandidate],
cross_market: CrossMarketContext | None, cross_market: CrossMarketContext | None,
current_holdings: list[dict] | None = None,
) -> DayPlaybook: ) -> DayPlaybook:
"""Parse Gemini's JSON response into a validated DayPlaybook.""" """Parse Gemini's JSON response into a validated DayPlaybook."""
cleaned = self._extract_json(response_text) cleaned = self._extract_json(response_text)
data = json.loads(cleaned) data = json.loads(cleaned)
valid_codes = {c.stock_code for c in candidates} valid_codes = {c.stock_code for c in candidates}
# Holdings are also valid — AI may generate SELL/HOLD scenarios for them
if current_holdings:
for h in current_holdings:
code = h.get("stock_code", "")
if code:
valid_codes.add(code)
# Parse market outlook # Parse market outlook
outlook_str = data.get("market_outlook", "neutral") outlook_str = data.get("market_outlook", "neutral")
@@ -437,10 +389,6 @@ class PreMarketPlanner:
price_below=cond_data.get("price_below"), price_below=cond_data.get("price_below"),
price_change_pct_above=cond_data.get("price_change_pct_above"), price_change_pct_above=cond_data.get("price_change_pct_above"),
price_change_pct_below=cond_data.get("price_change_pct_below"), price_change_pct_below=cond_data.get("price_change_pct_below"),
unrealized_pnl_pct_above=cond_data.get("unrealized_pnl_pct_above"),
unrealized_pnl_pct_below=cond_data.get("unrealized_pnl_pct_below"),
holding_days_above=cond_data.get("holding_days_above"),
holding_days_below=cond_data.get("holding_days_below"),
) )
if not condition.has_any_condition(): if not condition.has_any_condition():
@@ -522,99 +470,3 @@ class PreMarketPlanner:
), ),
], ],
) )
@staticmethod
def _smart_fallback_playbook(
today: date,
market: str,
candidates: list[ScanCandidate],
settings: Settings,
) -> DayPlaybook:
"""Rule-based fallback playbook when Gemini is unavailable.
Uses scanner signals (RSI, volume_ratio) to generate meaningful BUY
conditions instead of the all-SELL defensive playbook. Candidates are
already pre-qualified by SmartVolatilityScanner, so we trust their
signals and build actionable scenarios from them.
Scenario logic per candidate:
- momentum signal: BUY when volume_ratio exceeds scanner threshold
- oversold signal: BUY when RSI is below oversold threshold
- always: SELL stop-loss at -3.0% as guard
"""
stock_playbooks = []
for c in candidates:
scenarios: list[StockScenario] = []
if c.signal == "momentum":
scenarios.append(
StockScenario(
condition=StockCondition(
volume_ratio_above=settings.VOL_MULTIPLIER,
),
action=ScenarioAction.BUY,
confidence=80,
allocation_pct=10.0,
stop_loss_pct=-3.0,
take_profit_pct=5.0,
rationale=(
f"Rule-based BUY: momentum signal, "
f"volume={c.volume_ratio:.1f}x (fallback planner)"
),
)
)
elif c.signal == "oversold":
scenarios.append(
StockScenario(
condition=StockCondition(
rsi_below=settings.RSI_OVERSOLD_THRESHOLD,
),
action=ScenarioAction.BUY,
confidence=80,
allocation_pct=10.0,
stop_loss_pct=-3.0,
take_profit_pct=5.0,
rationale=(
f"Rule-based BUY: oversold signal, "
f"RSI={c.rsi:.0f} (fallback planner)"
),
)
)
# Always add stop-loss guard
scenarios.append(
StockScenario(
condition=StockCondition(price_change_pct_below=-3.0),
action=ScenarioAction.SELL,
confidence=90,
stop_loss_pct=-3.0,
rationale="Rule-based stop-loss (fallback planner)",
)
)
stock_playbooks.append(
StockPlaybook(
stock_code=c.stock_code,
scenarios=scenarios,
)
)
logger.info(
"Smart fallback playbook for %s: %d stocks with rule-based BUY/SELL conditions",
market,
len(stock_playbooks),
)
return DayPlaybook(
date=today,
market=market,
market_outlook=MarketOutlook.NEUTRAL,
default_action=ScenarioAction.HOLD,
stock_playbooks=stock_playbooks,
global_rules=[
GlobalRule(
condition="portfolio_pnl_pct < -2.0",
action=ScenarioAction.REDUCE_ALL,
rationale="Defensive: reduce on loss threshold",
),
],
)

View File

@@ -206,37 +206,6 @@ class ScenarioEngine:
if condition.price_change_pct_below is not None: if condition.price_change_pct_below is not None:
checks.append(price_change_pct is not None and price_change_pct < condition.price_change_pct_below) checks.append(price_change_pct is not None and price_change_pct < condition.price_change_pct_below)
# Position-aware conditions
unrealized_pnl_pct = self._safe_float(market_data.get("unrealized_pnl_pct"))
if condition.unrealized_pnl_pct_above is not None or condition.unrealized_pnl_pct_below is not None:
if "unrealized_pnl_pct" not in market_data:
self._warn_missing_key("unrealized_pnl_pct")
if condition.unrealized_pnl_pct_above is not None:
checks.append(
unrealized_pnl_pct is not None
and unrealized_pnl_pct > condition.unrealized_pnl_pct_above
)
if condition.unrealized_pnl_pct_below is not None:
checks.append(
unrealized_pnl_pct is not None
and unrealized_pnl_pct < condition.unrealized_pnl_pct_below
)
holding_days = self._safe_float(market_data.get("holding_days"))
if condition.holding_days_above is not None or condition.holding_days_below is not None:
if "holding_days" not in market_data:
self._warn_missing_key("holding_days")
if condition.holding_days_above is not None:
checks.append(
holding_days is not None
and holding_days > condition.holding_days_above
)
if condition.holding_days_below is not None:
checks.append(
holding_days is not None
and holding_days < condition.holding_days_below
)
return len(checks) > 0 and all(checks) return len(checks) > 0 and all(checks)
def _evaluate_global_condition( def _evaluate_global_condition(
@@ -297,9 +266,5 @@ class ScenarioEngine:
details["current_price"] = self._safe_float(market_data.get("current_price")) details["current_price"] = self._safe_float(market_data.get("current_price"))
if condition.price_change_pct_above is not None or condition.price_change_pct_below is not None: if condition.price_change_pct_above is not None or condition.price_change_pct_below is not None:
details["price_change_pct"] = self._safe_float(market_data.get("price_change_pct")) details["price_change_pct"] = self._safe_float(market_data.get("price_change_pct"))
if condition.unrealized_pnl_pct_above is not None or condition.unrealized_pnl_pct_below is not None:
details["unrealized_pnl_pct"] = self._safe_float(market_data.get("unrealized_pnl_pct"))
if condition.holding_days_above is not None or condition.holding_days_below is not None:
details["holding_days"] = self._safe_float(market_data.get("holding_days"))
return details return details

View File

@@ -3,11 +3,9 @@
from __future__ import annotations from __future__ import annotations
import sqlite3 import sqlite3
import sys
import tempfile import tempfile
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from pathlib import Path from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest import pytest
@@ -365,435 +363,3 @@ class TestHealthMonitor:
assert "timestamp" in report assert "timestamp" in report
assert "checks" in report assert "checks" in report
assert len(report["checks"]) == 3 assert len(report["checks"]) == 3
# ---------------------------------------------------------------------------
# BackupExporter — additional coverage for previously uncovered branches
# ---------------------------------------------------------------------------
@pytest.fixture
def empty_db(tmp_path: Path) -> Path:
"""Create a temporary database with NO trade records."""
db_path = tmp_path / "empty_trades.db"
conn = sqlite3.connect(str(db_path))
conn.execute(
"""CREATE TABLE trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp TEXT NOT NULL,
stock_code TEXT NOT NULL,
action TEXT NOT NULL,
quantity INTEGER NOT NULL,
price REAL NOT NULL,
confidence INTEGER NOT NULL,
rationale TEXT,
pnl REAL DEFAULT 0.0
)"""
)
conn.commit()
conn.close()
return db_path
class TestBackupExporterAdditional:
"""Cover branches missed in the original TestBackupExporter suite."""
def test_export_all_default_formats(self, temp_db: Path, tmp_path: Path) -> None:
"""export_all with formats=None must default to JSON+CSV+Parquet path."""
exporter = BackupExporter(str(temp_db))
# formats=None triggers the default list assignment (line 62)
results = exporter.export_all(tmp_path / "out", formats=None, compress=False)
# JSON and CSV must always succeed; Parquet needs pyarrow
assert ExportFormat.JSON in results
assert ExportFormat.CSV in results
def test_export_all_logs_error_on_failure(
self, temp_db: Path, tmp_path: Path
) -> None:
"""export_all must log an error and continue when one format fails."""
exporter = BackupExporter(str(temp_db))
# Patch _export_format to raise on JSON, succeed on CSV
original = exporter._export_format
def failing_export(fmt, *args, **kwargs): # type: ignore[no-untyped-def]
if fmt == ExportFormat.JSON:
raise RuntimeError("simulated failure")
return original(fmt, *args, **kwargs)
exporter._export_format = failing_export # type: ignore[method-assign]
results = exporter.export_all(
tmp_path / "out",
formats=[ExportFormat.JSON, ExportFormat.CSV],
compress=False,
)
# JSON failed → not in results; CSV succeeded → in results
assert ExportFormat.JSON not in results
assert ExportFormat.CSV in results
def test_export_csv_empty_trades_no_compress(
self, empty_db: Path, tmp_path: Path
) -> None:
"""CSV export with no trades and compress=False must write header row only."""
exporter = BackupExporter(str(empty_db))
results = exporter.export_all(
tmp_path / "out",
formats=[ExportFormat.CSV],
compress=False,
)
assert ExportFormat.CSV in results
out = results[ExportFormat.CSV]
assert out.exists()
content = out.read_text()
assert "timestamp" in content
def test_export_csv_empty_trades_compressed(
self, empty_db: Path, tmp_path: Path
) -> None:
"""CSV export with no trades and compress=True must write gzipped header."""
import gzip
exporter = BackupExporter(str(empty_db))
results = exporter.export_all(
tmp_path / "out",
formats=[ExportFormat.CSV],
compress=True,
)
assert ExportFormat.CSV in results
out = results[ExportFormat.CSV]
assert out.suffix == ".gz"
with gzip.open(out, "rt", encoding="utf-8") as f:
content = f.read()
assert "timestamp" in content
def test_export_csv_with_data_compressed(
self, temp_db: Path, tmp_path: Path
) -> None:
"""CSV export with data and compress=True must write gzipped rows."""
import gzip
exporter = BackupExporter(str(temp_db))
results = exporter.export_all(
tmp_path / "out",
formats=[ExportFormat.CSV],
compress=True,
)
assert ExportFormat.CSV in results
out = results[ExportFormat.CSV]
with gzip.open(out, "rt", encoding="utf-8") as f:
lines = f.readlines()
# Header + 3 data rows
assert len(lines) == 4
def test_export_parquet_raises_import_error_without_pyarrow(
self, temp_db: Path, tmp_path: Path
) -> None:
"""Parquet export must raise ImportError when pyarrow is not installed."""
exporter = BackupExporter(str(temp_db))
with patch.dict(sys.modules, {"pyarrow": None, "pyarrow.parquet": None}):
try:
import pyarrow # noqa: F401
pytest.skip("pyarrow is installed; cannot test ImportError path")
except ImportError:
pass
results = exporter.export_all(
tmp_path / "out",
formats=[ExportFormat.PARQUET],
compress=False,
)
# Parquet export fails gracefully; result dict should not contain it
assert ExportFormat.PARQUET not in results
# ---------------------------------------------------------------------------
# CloudStorage — mocked boto3 tests
# ---------------------------------------------------------------------------
@pytest.fixture
def mock_boto3_module():
"""Inject a fake boto3 into sys.modules for the duration of the test."""
mock = MagicMock()
with patch.dict(sys.modules, {"boto3": mock}):
yield mock
@pytest.fixture
def s3_config():
"""Minimal S3Config for tests."""
from src.backup.cloud_storage import S3Config
return S3Config(
endpoint_url="http://localhost:9000",
access_key="minioadmin",
secret_key="minioadmin",
bucket_name="test-bucket",
region="us-east-1",
)
class TestCloudStorage:
"""Test CloudStorage using mocked boto3."""
def test_init_creates_s3_client(self, mock_boto3_module, s3_config) -> None:
"""CloudStorage.__init__ must call boto3.client with the correct args."""
from src.backup.cloud_storage import CloudStorage
storage = CloudStorage(s3_config)
mock_boto3_module.client.assert_called_once()
call_kwargs = mock_boto3_module.client.call_args[1]
assert call_kwargs["aws_access_key_id"] == "minioadmin"
assert call_kwargs["aws_secret_access_key"] == "minioadmin"
assert storage.config == s3_config
def test_init_raises_if_boto3_missing(self, s3_config) -> None:
"""CloudStorage.__init__ must raise ImportError when boto3 is absent."""
with patch.dict(sys.modules, {"boto3": None}): # type: ignore[dict-item]
with pytest.raises((ImportError, TypeError)):
# Re-import to trigger the try/except inside __init__
import importlib
import src.backup.cloud_storage as m
importlib.reload(m)
m.CloudStorage(s3_config)
def test_upload_file_success(
self, mock_boto3_module, s3_config, tmp_path: Path
) -> None:
"""upload_file must call client.upload_file and return the object key."""
from src.backup.cloud_storage import CloudStorage
test_file = tmp_path / "backup.json.gz"
test_file.write_bytes(b"data")
storage = CloudStorage(s3_config)
key = storage.upload_file(test_file, object_key="backups/backup.json.gz")
assert key == "backups/backup.json.gz"
storage.client.upload_file.assert_called_once()
def test_upload_file_default_key(
self, mock_boto3_module, s3_config, tmp_path: Path
) -> None:
"""upload_file without object_key must use the filename as key."""
from src.backup.cloud_storage import CloudStorage
test_file = tmp_path / "myfile.gz"
test_file.write_bytes(b"data")
storage = CloudStorage(s3_config)
key = storage.upload_file(test_file)
assert key == "myfile.gz"
def test_upload_file_not_found(
self, mock_boto3_module, s3_config, tmp_path: Path
) -> None:
"""upload_file must raise FileNotFoundError for missing files."""
from src.backup.cloud_storage import CloudStorage
storage = CloudStorage(s3_config)
with pytest.raises(FileNotFoundError):
storage.upload_file(tmp_path / "nonexistent.gz")
def test_upload_file_propagates_client_error(
self, mock_boto3_module, s3_config, tmp_path: Path
) -> None:
"""upload_file must re-raise exceptions from the boto3 client."""
from src.backup.cloud_storage import CloudStorage
test_file = tmp_path / "backup.gz"
test_file.write_bytes(b"data")
storage = CloudStorage(s3_config)
storage.client.upload_file.side_effect = RuntimeError("network error")
with pytest.raises(RuntimeError, match="network error"):
storage.upload_file(test_file)
def test_download_file_success(
self, mock_boto3_module, s3_config, tmp_path: Path
) -> None:
"""download_file must call client.download_file and return local path."""
from src.backup.cloud_storage import CloudStorage
storage = CloudStorage(s3_config)
dest = tmp_path / "downloads" / "backup.gz"
result = storage.download_file("backups/backup.gz", dest)
assert result == dest
storage.client.download_file.assert_called_once()
def test_download_file_propagates_error(
self, mock_boto3_module, s3_config, tmp_path: Path
) -> None:
"""download_file must re-raise exceptions from the boto3 client."""
from src.backup.cloud_storage import CloudStorage
storage = CloudStorage(s3_config)
storage.client.download_file.side_effect = RuntimeError("timeout")
with pytest.raises(RuntimeError, match="timeout"):
storage.download_file("key", tmp_path / "dest.gz")
def test_list_files_returns_objects(
self, mock_boto3_module, s3_config
) -> None:
"""list_files must return parsed file metadata from S3 response."""
from datetime import timezone
from src.backup.cloud_storage import CloudStorage
storage = CloudStorage(s3_config)
storage.client.list_objects_v2.return_value = {
"Contents": [
{
"Key": "backups/a.gz",
"Size": 1024,
"LastModified": datetime(2026, 1, 1, tzinfo=timezone.utc),
"ETag": '"abc123"',
}
]
}
files = storage.list_files(prefix="backups/")
assert len(files) == 1
assert files[0]["key"] == "backups/a.gz"
assert files[0]["size_bytes"] == 1024
def test_list_files_empty_bucket(
self, mock_boto3_module, s3_config
) -> None:
"""list_files must return empty list when bucket has no objects."""
from src.backup.cloud_storage import CloudStorage
storage = CloudStorage(s3_config)
storage.client.list_objects_v2.return_value = {}
files = storage.list_files()
assert files == []
def test_list_files_propagates_error(
self, mock_boto3_module, s3_config
) -> None:
"""list_files must re-raise exceptions from the boto3 client."""
from src.backup.cloud_storage import CloudStorage
storage = CloudStorage(s3_config)
storage.client.list_objects_v2.side_effect = RuntimeError("auth error")
with pytest.raises(RuntimeError):
storage.list_files()
def test_delete_file_success(
self, mock_boto3_module, s3_config
) -> None:
"""delete_file must call client.delete_object with the correct key."""
from src.backup.cloud_storage import CloudStorage
storage = CloudStorage(s3_config)
storage.delete_file("backups/old.gz")
storage.client.delete_object.assert_called_once_with(
Bucket="test-bucket", Key="backups/old.gz"
)
def test_delete_file_propagates_error(
self, mock_boto3_module, s3_config
) -> None:
"""delete_file must re-raise exceptions from the boto3 client."""
from src.backup.cloud_storage import CloudStorage
storage = CloudStorage(s3_config)
storage.client.delete_object.side_effect = RuntimeError("permission denied")
with pytest.raises(RuntimeError):
storage.delete_file("backups/old.gz")
def test_get_storage_stats_success(
self, mock_boto3_module, s3_config
) -> None:
"""get_storage_stats must aggregate file sizes correctly."""
from datetime import timezone
from src.backup.cloud_storage import CloudStorage
storage = CloudStorage(s3_config)
storage.client.list_objects_v2.return_value = {
"Contents": [
{
"Key": "a.gz",
"Size": 1024 * 1024,
"LastModified": datetime(2026, 1, 1, tzinfo=timezone.utc),
"ETag": '"x"',
},
{
"Key": "b.gz",
"Size": 1024 * 1024,
"LastModified": datetime(2026, 1, 2, tzinfo=timezone.utc),
"ETag": '"y"',
},
]
}
stats = storage.get_storage_stats()
assert stats["total_files"] == 2
assert stats["total_size_bytes"] == 2 * 1024 * 1024
assert stats["total_size_mb"] == pytest.approx(2.0)
def test_get_storage_stats_on_error(
self, mock_boto3_module, s3_config
) -> None:
"""get_storage_stats must return error dict without raising on failure."""
from src.backup.cloud_storage import CloudStorage
storage = CloudStorage(s3_config)
storage.client.list_objects_v2.side_effect = RuntimeError("no connection")
stats = storage.get_storage_stats()
assert "error" in stats
assert stats["total_files"] == 0
def test_verify_connection_success(
self, mock_boto3_module, s3_config
) -> None:
"""verify_connection must return True when head_bucket succeeds."""
from src.backup.cloud_storage import CloudStorage
storage = CloudStorage(s3_config)
result = storage.verify_connection()
assert result is True
def test_verify_connection_failure(
self, mock_boto3_module, s3_config
) -> None:
"""verify_connection must return False when head_bucket raises."""
from src.backup.cloud_storage import CloudStorage
storage = CloudStorage(s3_config)
storage.client.head_bucket.side_effect = RuntimeError("no such bucket")
result = storage.verify_connection()
assert result is False
def test_enable_versioning(
self, mock_boto3_module, s3_config
) -> None:
"""enable_versioning must call put_bucket_versioning."""
from src.backup.cloud_storage import CloudStorage
storage = CloudStorage(s3_config)
storage.enable_versioning()
storage.client.put_bucket_versioning.assert_called_once()
def test_enable_versioning_propagates_error(
self, mock_boto3_module, s3_config
) -> None:
"""enable_versioning must re-raise exceptions from the boto3 client."""
from src.backup.cloud_storage import CloudStorage
storage = CloudStorage(s3_config)
storage.client.put_bucket_versioning.side_effect = RuntimeError("denied")
with pytest.raises(RuntimeError):
storage.enable_versioning()

View File

@@ -2,10 +2,6 @@
from __future__ import annotations from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from src.brain.gemini_client import GeminiClient from src.brain.gemini_client import GeminiClient
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -93,21 +89,9 @@ class TestMalformedJsonHandling:
def test_json_with_missing_fields_returns_hold(self, settings): def test_json_with_missing_fields_returns_hold(self, settings):
client = GeminiClient(settings) client = GeminiClient(settings)
raw = '{"action": "BUY"}' decision = client.parse_response('{"action": "BUY"}')
decision = client.parse_response(raw)
assert decision.action == "HOLD" assert decision.action == "HOLD"
assert decision.confidence == 0 assert decision.confidence == 0
# rationale preserves raw so prompt_override callers (e.g. pre_market_planner)
# can extract non-TradeDecision JSON from decision.rationale (#245)
assert decision.rationale == raw
def test_non_trade_decision_json_preserves_raw_in_rationale(self, settings):
"""Playbook JSON (no action/confidence/rationale) must be preserved for planner."""
client = GeminiClient(settings)
playbook_json = '{"market_outlook": "neutral", "stocks": []}'
decision = client.parse_response(playbook_json)
assert decision.action == "HOLD"
assert decision.rationale == playbook_json
def test_json_with_invalid_action_returns_hold(self, settings): def test_json_with_invalid_action_returns_hold(self, settings):
client = GeminiClient(settings) client = GeminiClient(settings)
@@ -286,132 +270,3 @@ class TestBatchDecisionParsing:
assert decisions["AAPL"].action == "HOLD" assert decisions["AAPL"].action == "HOLD"
assert decisions["AAPL"].confidence == 0 assert decisions["AAPL"].confidence == 0
# ---------------------------------------------------------------------------
# Prompt Override (used by pre_market_planner)
# ---------------------------------------------------------------------------
class TestPromptOverride:
"""decide() must use prompt_override when present in market_data."""
@pytest.mark.asyncio
async def test_prompt_override_is_sent_to_gemini(self, settings):
"""When prompt_override is in market_data, it should be used as the prompt."""
client = GeminiClient(settings)
custom_prompt = "You are a playbook generator. Return JSON with scenarios."
playbook_json = '{"market_outlook": "neutral", "stocks": []}'
mock_response = MagicMock()
mock_response.text = playbook_json
with patch.object(
client._client.aio.models,
"generate_content",
new_callable=AsyncMock,
return_value=mock_response,
) as mock_generate:
market_data = {
"stock_code": "PLANNER",
"current_price": 0,
"prompt_override": custom_prompt,
}
decision = await client.decide(market_data)
# Verify the custom prompt was sent, not a built prompt
mock_generate.assert_called_once()
actual_prompt = mock_generate.call_args[1].get(
"contents", mock_generate.call_args[0][1] if len(mock_generate.call_args[0]) > 1 else None
)
assert actual_prompt == custom_prompt
# Raw response preserved in rationale without parse_response (#247)
assert decision.rationale == playbook_json
@pytest.mark.asyncio
async def test_prompt_override_skips_parse_response(self, settings):
"""prompt_override bypasses parse_response — no Missing fields warning, raw preserved."""
client = GeminiClient(settings)
client._enable_optimization = True
custom_prompt = "Custom playbook prompt"
playbook_json = '{"market_outlook": "bullish", "stocks": [{"stock_code": "AAPL"}]}'
mock_response = MagicMock()
mock_response.text = playbook_json
with patch.object(
client._client.aio.models,
"generate_content",
new_callable=AsyncMock,
return_value=mock_response,
):
with patch.object(client, "parse_response") as mock_parse:
market_data = {
"stock_code": "PLANNER",
"current_price": 0,
"prompt_override": custom_prompt,
}
decision = await client.decide(market_data)
# parse_response must NOT be called for prompt_override
mock_parse.assert_not_called()
# Raw playbook JSON preserved in rationale
assert decision.rationale == playbook_json
@pytest.mark.asyncio
async def test_prompt_override_takes_priority_over_optimization(self, settings):
"""prompt_override must win over enable_optimization=True."""
client = GeminiClient(settings)
client._enable_optimization = True
custom_prompt = "Explicit playbook prompt"
mock_response = MagicMock()
mock_response.text = '{"market_outlook": "neutral", "stocks": []}'
with patch.object(
client._client.aio.models,
"generate_content",
new_callable=AsyncMock,
return_value=mock_response,
) as mock_generate:
market_data = {
"stock_code": "PLANNER",
"current_price": 0,
"prompt_override": custom_prompt,
}
await client.decide(market_data)
actual_prompt = mock_generate.call_args[1].get(
"contents", mock_generate.call_args[0][1] if len(mock_generate.call_args[0]) > 1 else None
)
# The custom prompt must be used, not the compressed prompt
assert actual_prompt == custom_prompt
@pytest.mark.asyncio
async def test_without_prompt_override_uses_build_prompt(self, settings):
"""Without prompt_override, decide() should use build_prompt as before."""
client = GeminiClient(settings)
mock_response = MagicMock()
mock_response.text = '{"action": "HOLD", "confidence": 50, "rationale": "ok"}'
with patch.object(
client._client.aio.models,
"generate_content",
new_callable=AsyncMock,
return_value=mock_response,
) as mock_generate:
market_data = {
"stock_code": "005930",
"current_price": 72000,
}
await client.decide(market_data)
actual_prompt = mock_generate.call_args[1].get(
"contents", mock_generate.call_args[0][1] if len(mock_generate.call_args[0]) > 1 else None
)
# Should contain stock code from build_prompt, not be a custom override
assert "005930" in actual_prompt

View File

@@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, patch
import pytest import pytest
@@ -90,12 +90,12 @@ class TestTokenManagement:
await broker.close() await broker.close()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_token_refresh_cooldown_waits_then_retries(self, settings): async def test_token_refresh_cooldown_prevents_rapid_retries(self, settings):
"""Token refresh should wait out cooldown then retry (issue #54).""" """Token refresh should enforce cooldown after failure (issue #54)."""
broker = KISBroker(settings) broker = KISBroker(settings)
broker._refresh_cooldown = 0.1 # Short cooldown for testing broker._refresh_cooldown = 2.0 # Short cooldown for testing
# All attempts fail with 403 (EGW00133) # First refresh attempt fails with 403 (EGW00133)
mock_resp_403 = AsyncMock() mock_resp_403 = AsyncMock()
mock_resp_403.status = 403 mock_resp_403.status = 403
mock_resp_403.text = AsyncMock( mock_resp_403.text = AsyncMock(
@@ -109,8 +109,8 @@ class TestTokenManagement:
with pytest.raises(ConnectionError, match="Token refresh failed"): with pytest.raises(ConnectionError, match="Token refresh failed"):
await broker._ensure_token() await broker._ensure_token()
# Second attempt within cooldown should wait then retry (and still get 403) # Second attempt within cooldown should fail with cooldown error
with pytest.raises(ConnectionError, match="Token refresh failed"): with pytest.raises(ConnectionError, match="Token refresh on cooldown"):
await broker._ensure_token() await broker._ensure_token()
await broker.close() await broker.close()
@@ -296,647 +296,3 @@ class TestHashKey:
mock_acquire.assert_called_once() mock_acquire.assert_called_once()
await broker.close() await broker.close()
# ---------------------------------------------------------------------------
# fetch_market_rankings — TR_ID, path, params (issue #155)
# ---------------------------------------------------------------------------
def _make_ranking_mock(items: list[dict]) -> AsyncMock:
"""Build a mock HTTP response returning ranking items."""
mock_resp = AsyncMock()
mock_resp.status = 200
mock_resp.json = AsyncMock(return_value={"output": items})
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
mock_resp.__aexit__ = AsyncMock(return_value=False)
return mock_resp
class TestFetchMarketRankings:
"""Verify correct TR_ID, API path, and params per ranking_type (issue #155)."""
@pytest.fixture
def broker(self, settings) -> KISBroker:
b = KISBroker(settings)
b._access_token = "tok"
b._token_expires_at = float("inf")
b._rate_limiter.acquire = AsyncMock()
return b
@pytest.mark.asyncio
async def test_volume_uses_correct_tr_id_and_path(self, broker: KISBroker) -> None:
mock_resp = _make_ranking_mock([])
with patch("aiohttp.ClientSession.get", return_value=mock_resp) as mock_get:
await broker.fetch_market_rankings(ranking_type="volume")
call_kwargs = mock_get.call_args
url = call_kwargs[0][0] if call_kwargs[0] else call_kwargs[1].get("url", "")
headers = call_kwargs[1].get("headers", {})
params = call_kwargs[1].get("params", {})
assert "volume-rank" in url
assert headers.get("tr_id") == "FHPST01710000"
assert params.get("FID_COND_SCR_DIV_CODE") == "20171"
assert params.get("FID_TRGT_EXLS_CLS_CODE") == "0000000000"
@pytest.mark.asyncio
async def test_fluctuation_uses_correct_tr_id_and_path(self, broker: KISBroker) -> None:
mock_resp = _make_ranking_mock([])
with patch("aiohttp.ClientSession.get", return_value=mock_resp) as mock_get:
await broker.fetch_market_rankings(ranking_type="fluctuation")
call_kwargs = mock_get.call_args
url = call_kwargs[0][0] if call_kwargs[0] else call_kwargs[1].get("url", "")
headers = call_kwargs[1].get("headers", {})
params = call_kwargs[1].get("params", {})
assert "ranking/fluctuation" in url
assert headers.get("tr_id") == "FHPST01700000"
assert params.get("fid_cond_scr_div_code") == "20170"
# 실전 API는 4자리("0000") 거부 — 1자리("0")여야 한다 (#240)
assert params.get("fid_rank_sort_cls_code") == "0"
@pytest.mark.asyncio
async def test_volume_returns_parsed_rows(self, broker: KISBroker) -> None:
items = [
{
"mksc_shrn_iscd": "005930",
"hts_kor_isnm": "삼성전자",
"stck_prpr": "75000",
"acml_vol": "10000000",
"prdy_ctrt": "2.5",
"vol_inrt": "150",
}
]
mock_resp = _make_ranking_mock(items)
with patch("aiohttp.ClientSession.get", return_value=mock_resp):
result = await broker.fetch_market_rankings(ranking_type="volume")
assert len(result) == 1
assert result[0]["stock_code"] == "005930"
assert result[0]["price"] == 75000.0
assert result[0]["change_rate"] == 2.5
@pytest.mark.asyncio
async def test_fluctuation_parses_stck_shrn_iscd(self, broker: KISBroker) -> None:
"""실전 API는 mksc_shrn_iscd 대신 stck_shrn_iscd를 반환한다 (#240)."""
items = [
{
"stck_shrn_iscd": "015260",
"hts_kor_isnm": "에이엔피",
"stck_prpr": "794",
"acml_vol": "4896196",
"prdy_ctrt": "29.74",
"vol_inrt": "0",
}
]
mock_resp = _make_ranking_mock(items)
with patch("aiohttp.ClientSession.get", return_value=mock_resp):
result = await broker.fetch_market_rankings(ranking_type="fluctuation")
assert len(result) == 1
assert result[0]["stock_code"] == "015260"
assert result[0]["change_rate"] == 29.74
# ---------------------------------------------------------------------------
# KRX tick unit / round-down helpers (issue #157)
# ---------------------------------------------------------------------------
from src.broker.kis_api import kr_tick_unit, kr_round_down # noqa: E402
class TestKrTickUnit:
"""kr_tick_unit and kr_round_down must implement KRX price tick rules."""
@pytest.mark.parametrize(
"price, expected_tick",
[
(1999, 1),
(2000, 5),
(4999, 5),
(5000, 10),
(19999, 10),
(20000, 50),
(49999, 50),
(50000, 100),
(199999, 100),
(200000, 500),
(499999, 500),
(500000, 1000),
(1000000, 1000),
],
)
def test_tick_unit_boundaries(self, price: int, expected_tick: int) -> None:
assert kr_tick_unit(price) == expected_tick
@pytest.mark.parametrize(
"price, expected_rounded",
[
(188150, 188100), # 100원 단위, 50원 잔여 → 내림
(188100, 188100), # 이미 정렬됨
(75050, 75000), # 100원 단위, 50원 잔여 → 내림
(49950, 49950), # 50원 단위 정렬됨
(49960, 49950), # 50원 단위, 10원 잔여 → 내림
(1999, 1999), # 1원 단위 → 그대로
(5003, 5000), # 10원 단위, 3원 잔여 → 내림
],
)
def test_round_down_to_tick(self, price: int, expected_rounded: int) -> None:
assert kr_round_down(price) == expected_rounded
# ---------------------------------------------------------------------------
# get_current_price (issue #157)
# ---------------------------------------------------------------------------
class TestGetCurrentPrice:
"""get_current_price must use inquire-price API and return (price, change, foreigner)."""
@pytest.fixture
def broker(self, settings) -> KISBroker:
b = KISBroker(settings)
b._access_token = "tok"
b._token_expires_at = float("inf")
b._rate_limiter.acquire = AsyncMock()
return b
@pytest.mark.asyncio
async def test_returns_correct_fields(self, broker: KISBroker) -> None:
mock_resp = AsyncMock()
mock_resp.status = 200
mock_resp.json = AsyncMock(
return_value={
"rt_cd": "0",
"output": {
"stck_prpr": "188600",
"prdy_ctrt": "3.97",
"frgn_ntby_qty": "12345",
},
}
)
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
mock_resp.__aexit__ = AsyncMock(return_value=False)
with patch("aiohttp.ClientSession.get", return_value=mock_resp) as mock_get:
price, change_pct, foreigner = await broker.get_current_price("005930")
assert price == 188600.0
assert change_pct == 3.97
assert foreigner == 12345.0
call_kwargs = mock_get.call_args
url = call_kwargs[0][0] if call_kwargs[0] else call_kwargs[1].get("url", "")
headers = call_kwargs[1].get("headers", {})
assert "inquire-price" in url
assert headers.get("tr_id") == "FHKST01010100"
@pytest.mark.asyncio
async def test_http_error_raises_connection_error(self, broker: KISBroker) -> None:
mock_resp = AsyncMock()
mock_resp.status = 500
mock_resp.text = AsyncMock(return_value="Internal Server Error")
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
mock_resp.__aexit__ = AsyncMock(return_value=False)
with patch("aiohttp.ClientSession.get", return_value=mock_resp):
with pytest.raises(ConnectionError, match="get_current_price failed"):
await broker.get_current_price("005930")
# ---------------------------------------------------------------------------
# send_order tick rounding and ORD_DVSN (issue #157)
# ---------------------------------------------------------------------------
class TestSendOrderTickRounding:
"""send_order must apply KRX tick rounding and correct ORD_DVSN codes."""
@pytest.fixture
def broker(self, settings) -> KISBroker:
b = KISBroker(settings)
b._access_token = "tok"
b._token_expires_at = float("inf")
b._rate_limiter.acquire = AsyncMock()
return b
@pytest.mark.asyncio
async def test_limit_order_rounds_down_to_tick(self, broker: KISBroker) -> None:
"""Price 188150 (not on 100-won tick) must be rounded to 188100."""
mock_hash = AsyncMock()
mock_hash.status = 200
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
mock_hash.__aexit__ = AsyncMock(return_value=False)
mock_order = AsyncMock()
mock_order.status = 200
mock_order.json = AsyncMock(return_value={"rt_cd": "0"})
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
mock_order.__aexit__ = AsyncMock(return_value=False)
with patch(
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
) as mock_post:
await broker.send_order("005930", "BUY", 1, price=188150)
order_call = mock_post.call_args_list[1]
body = order_call[1].get("json", {})
assert body["ORD_UNPR"] == "188100" # rounded down
assert body["ORD_DVSN"] == "00" # 지정가
@pytest.mark.asyncio
async def test_limit_order_ord_dvsn_is_00(self, broker: KISBroker) -> None:
"""send_order with price>0 must use ORD_DVSN='00' (지정가)."""
mock_hash = AsyncMock()
mock_hash.status = 200
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
mock_hash.__aexit__ = AsyncMock(return_value=False)
mock_order = AsyncMock()
mock_order.status = 200
mock_order.json = AsyncMock(return_value={"rt_cd": "0"})
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
mock_order.__aexit__ = AsyncMock(return_value=False)
with patch(
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
) as mock_post:
await broker.send_order("005930", "BUY", 1, price=50000)
order_call = mock_post.call_args_list[1]
body = order_call[1].get("json", {})
assert body["ORD_DVSN"] == "00"
@pytest.mark.asyncio
async def test_market_order_ord_dvsn_is_01(self, broker: KISBroker) -> None:
"""send_order with price=0 must use ORD_DVSN='01' (시장가)."""
mock_hash = AsyncMock()
mock_hash.status = 200
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
mock_hash.__aexit__ = AsyncMock(return_value=False)
mock_order = AsyncMock()
mock_order.status = 200
mock_order.json = AsyncMock(return_value={"rt_cd": "0"})
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
mock_order.__aexit__ = AsyncMock(return_value=False)
with patch(
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
) as mock_post:
await broker.send_order("005930", "SELL", 1, price=0)
order_call = mock_post.call_args_list[1]
body = order_call[1].get("json", {})
assert body["ORD_DVSN"] == "01"
# ---------------------------------------------------------------------------
# TR_ID live/paper branching (issues #201, #202, #203)
# ---------------------------------------------------------------------------
class TestTRIDBranchingDomestic:
"""get_balance and send_order must use correct TR_ID for live vs paper mode."""
def _make_broker(self, settings, mode: str) -> KISBroker:
from src.config import Settings
s = Settings(
KIS_APP_KEY=settings.KIS_APP_KEY,
KIS_APP_SECRET=settings.KIS_APP_SECRET,
KIS_ACCOUNT_NO=settings.KIS_ACCOUNT_NO,
GEMINI_API_KEY=settings.GEMINI_API_KEY,
DB_PATH=":memory:",
ENABLED_MARKETS="KR",
MODE=mode,
)
b = KISBroker(s)
b._access_token = "tok"
b._token_expires_at = float("inf")
b._rate_limiter.acquire = AsyncMock()
return b
@pytest.mark.asyncio
async def test_get_balance_paper_uses_vttc8434r(self, settings) -> None:
broker = self._make_broker(settings, "paper")
mock_resp = AsyncMock()
mock_resp.status = 200
mock_resp.json = AsyncMock(
return_value={"output1": [], "output2": {}}
)
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
mock_resp.__aexit__ = AsyncMock(return_value=False)
with patch("aiohttp.ClientSession.get", return_value=mock_resp) as mock_get:
await broker.get_balance()
headers = mock_get.call_args[1].get("headers", {})
assert headers["tr_id"] == "VTTC8434R"
@pytest.mark.asyncio
async def test_get_balance_live_uses_tttc8434r(self, settings) -> None:
broker = self._make_broker(settings, "live")
mock_resp = AsyncMock()
mock_resp.status = 200
mock_resp.json = AsyncMock(
return_value={"output1": [], "output2": {}}
)
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
mock_resp.__aexit__ = AsyncMock(return_value=False)
with patch("aiohttp.ClientSession.get", return_value=mock_resp) as mock_get:
await broker.get_balance()
headers = mock_get.call_args[1].get("headers", {})
assert headers["tr_id"] == "TTTC8434R"
@pytest.mark.asyncio
async def test_send_order_buy_paper_uses_vttc0012u(self, settings) -> None:
broker = self._make_broker(settings, "paper")
mock_hash = AsyncMock()
mock_hash.status = 200
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
mock_hash.__aexit__ = AsyncMock(return_value=False)
mock_order = AsyncMock()
mock_order.status = 200
mock_order.json = AsyncMock(return_value={"rt_cd": "0"})
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
mock_order.__aexit__ = AsyncMock(return_value=False)
with patch(
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
) as mock_post:
await broker.send_order("005930", "BUY", 1)
order_headers = mock_post.call_args_list[1][1].get("headers", {})
assert order_headers["tr_id"] == "VTTC0012U"
@pytest.mark.asyncio
async def test_send_order_buy_live_uses_tttc0012u(self, settings) -> None:
broker = self._make_broker(settings, "live")
mock_hash = AsyncMock()
mock_hash.status = 200
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
mock_hash.__aexit__ = AsyncMock(return_value=False)
mock_order = AsyncMock()
mock_order.status = 200
mock_order.json = AsyncMock(return_value={"rt_cd": "0"})
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
mock_order.__aexit__ = AsyncMock(return_value=False)
with patch(
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
) as mock_post:
await broker.send_order("005930", "BUY", 1)
order_headers = mock_post.call_args_list[1][1].get("headers", {})
assert order_headers["tr_id"] == "TTTC0012U"
@pytest.mark.asyncio
async def test_send_order_sell_paper_uses_vttc0011u(self, settings) -> None:
broker = self._make_broker(settings, "paper")
mock_hash = AsyncMock()
mock_hash.status = 200
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
mock_hash.__aexit__ = AsyncMock(return_value=False)
mock_order = AsyncMock()
mock_order.status = 200
mock_order.json = AsyncMock(return_value={"rt_cd": "0"})
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
mock_order.__aexit__ = AsyncMock(return_value=False)
with patch(
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
) as mock_post:
await broker.send_order("005930", "SELL", 1)
order_headers = mock_post.call_args_list[1][1].get("headers", {})
assert order_headers["tr_id"] == "VTTC0011U"
@pytest.mark.asyncio
async def test_send_order_sell_live_uses_tttc0011u(self, settings) -> None:
broker = self._make_broker(settings, "live")
mock_hash = AsyncMock()
mock_hash.status = 200
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
mock_hash.__aexit__ = AsyncMock(return_value=False)
mock_order = AsyncMock()
mock_order.status = 200
mock_order.json = AsyncMock(return_value={"rt_cd": "0"})
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
mock_order.__aexit__ = AsyncMock(return_value=False)
with patch(
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
) as mock_post:
await broker.send_order("005930", "SELL", 1)
order_headers = mock_post.call_args_list[1][1].get("headers", {})
assert order_headers["tr_id"] == "TTTC0011U"
# ---------------------------------------------------------------------------
# Domestic Pending Orders (get_domestic_pending_orders)
# ---------------------------------------------------------------------------
class TestGetDomesticPendingOrders:
"""get_domestic_pending_orders must return [] in paper mode and call TTTC0084R in live."""
def _make_broker(self, settings, mode: str) -> KISBroker:
from src.config import Settings
s = Settings(
KIS_APP_KEY=settings.KIS_APP_KEY,
KIS_APP_SECRET=settings.KIS_APP_SECRET,
KIS_ACCOUNT_NO=settings.KIS_ACCOUNT_NO,
GEMINI_API_KEY=settings.GEMINI_API_KEY,
DB_PATH=":memory:",
ENABLED_MARKETS="KR",
MODE=mode,
)
b = KISBroker(s)
b._access_token = "tok"
b._token_expires_at = float("inf")
b._rate_limiter.acquire = AsyncMock()
return b
@pytest.mark.asyncio
async def test_paper_mode_returns_empty(self, settings) -> None:
"""Paper mode must return [] immediately without any API call."""
broker = self._make_broker(settings, "paper")
with patch("aiohttp.ClientSession.get") as mock_get:
result = await broker.get_domestic_pending_orders()
assert result == []
mock_get.assert_not_called()
@pytest.mark.asyncio
async def test_live_mode_calls_tttc0084r_with_correct_params(
self, settings
) -> None:
"""Live mode must call TTTC0084R with INQR_DVSN_1/2 and paging params."""
broker = self._make_broker(settings, "live")
pending = [{"odno": "001", "pdno": "005930", "psbl_qty": "10"}]
mock_resp = AsyncMock()
mock_resp.status = 200
mock_resp.json = AsyncMock(return_value={"output": pending})
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
mock_resp.__aexit__ = AsyncMock(return_value=False)
with patch("aiohttp.ClientSession.get", return_value=mock_resp) as mock_get:
result = await broker.get_domestic_pending_orders()
assert result == pending
headers = mock_get.call_args[1].get("headers", {})
assert headers["tr_id"] == "TTTC0084R"
params = mock_get.call_args[1].get("params", {})
assert params["INQR_DVSN_1"] == "0"
assert params["INQR_DVSN_2"] == "0"
@pytest.mark.asyncio
async def test_live_mode_connection_error(self, settings) -> None:
"""Network error must raise ConnectionError."""
import aiohttp as _aiohttp
broker = self._make_broker(settings, "live")
with patch(
"aiohttp.ClientSession.get",
side_effect=_aiohttp.ClientError("timeout"),
):
with pytest.raises(ConnectionError):
await broker.get_domestic_pending_orders()
# ---------------------------------------------------------------------------
# Domestic Order Cancellation (cancel_domestic_order)
# ---------------------------------------------------------------------------
class TestCancelDomesticOrder:
"""cancel_domestic_order must use correct TR_ID and build body correctly."""
def _make_broker(self, settings, mode: str) -> KISBroker:
from src.config import Settings
s = Settings(
KIS_APP_KEY=settings.KIS_APP_KEY,
KIS_APP_SECRET=settings.KIS_APP_SECRET,
KIS_ACCOUNT_NO=settings.KIS_ACCOUNT_NO,
GEMINI_API_KEY=settings.GEMINI_API_KEY,
DB_PATH=":memory:",
ENABLED_MARKETS="KR",
MODE=mode,
)
b = KISBroker(s)
b._access_token = "tok"
b._token_expires_at = float("inf")
b._rate_limiter.acquire = AsyncMock()
return b
def _make_post_mocks(self, order_payload: dict) -> tuple:
mock_hash = AsyncMock()
mock_hash.status = 200
mock_hash.json = AsyncMock(return_value={"HASH": "h"})
mock_hash.__aenter__ = AsyncMock(return_value=mock_hash)
mock_hash.__aexit__ = AsyncMock(return_value=False)
mock_order = AsyncMock()
mock_order.status = 200
mock_order.json = AsyncMock(return_value=order_payload)
mock_order.__aenter__ = AsyncMock(return_value=mock_order)
mock_order.__aexit__ = AsyncMock(return_value=False)
return mock_hash, mock_order
@pytest.mark.asyncio
async def test_live_uses_tttc0013u(self, settings) -> None:
"""Live mode must use TR_ID TTTC0013U."""
broker = self._make_broker(settings, "live")
mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"})
with patch(
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
) as mock_post:
await broker.cancel_domestic_order("005930", "ORD001", "BRNO01", 5)
order_headers = mock_post.call_args_list[1][1].get("headers", {})
assert order_headers["tr_id"] == "TTTC0013U"
@pytest.mark.asyncio
async def test_paper_uses_vttc0013u(self, settings) -> None:
"""Paper mode must use TR_ID VTTC0013U."""
broker = self._make_broker(settings, "paper")
mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"})
with patch(
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
) as mock_post:
await broker.cancel_domestic_order("005930", "ORD001", "BRNO01", 5)
order_headers = mock_post.call_args_list[1][1].get("headers", {})
assert order_headers["tr_id"] == "VTTC0013U"
@pytest.mark.asyncio
async def test_cancel_sets_rvse_cncl_dvsn_cd_02(self, settings) -> None:
"""Body must have RVSE_CNCL_DVSN_CD='02' (취소) and QTY_ALL_ORD_YN='Y'."""
broker = self._make_broker(settings, "live")
mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"})
with patch(
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
) as mock_post:
await broker.cancel_domestic_order("005930", "ORD001", "BRNO01", 5)
body = mock_post.call_args_list[1][1].get("json", {})
assert body["RVSE_CNCL_DVSN_CD"] == "02"
assert body["QTY_ALL_ORD_YN"] == "Y"
assert body["ORD_UNPR"] == "0"
@pytest.mark.asyncio
async def test_cancel_sets_krx_fwdg_ord_orgno_in_body(self, settings) -> None:
"""Body must include KRX_FWDG_ORD_ORGNO and ORGN_ODNO from arguments."""
broker = self._make_broker(settings, "live")
mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"})
with patch(
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
) as mock_post:
await broker.cancel_domestic_order("005930", "ORD123", "BRN456", 3)
body = mock_post.call_args_list[1][1].get("json", {})
assert body["KRX_FWDG_ORD_ORGNO"] == "BRN456"
assert body["ORGN_ODNO"] == "ORD123"
assert body["ORD_QTY"] == "3"
@pytest.mark.asyncio
async def test_cancel_sets_hashkey_header(self, settings) -> None:
"""Request must include hashkey header (same pattern as send_order)."""
broker = self._make_broker(settings, "live")
mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"})
with patch(
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
) as mock_post:
await broker.cancel_domestic_order("005930", "ORD001", "BRNO01", 2)
order_headers = mock_post.call_args_list[1][1].get("headers", {})
assert "hashkey" in order_headers
assert order_headers["hashkey"] == "h"

View File

@@ -10,7 +10,6 @@ import pytest
from src.context.aggregator import ContextAggregator from src.context.aggregator import ContextAggregator
from src.context.layer import LAYER_CONFIG, ContextLayer from src.context.layer import LAYER_CONFIG, ContextLayer
from src.context.store import ContextStore from src.context.store import ContextStore
from src.context.summarizer import ContextSummarizer
from src.db import init_db, log_trade from src.db import init_db, log_trade
@@ -371,259 +370,3 @@ class TestLayerMetadata:
# L1 aggregates from L2 # L1 aggregates from L2
assert LAYER_CONFIG[ContextLayer.L1_LEGACY].aggregation_source == ContextLayer.L2_ANNUAL assert LAYER_CONFIG[ContextLayer.L1_LEGACY].aggregation_source == ContextLayer.L2_ANNUAL
# ---------------------------------------------------------------------------
# ContextSummarizer tests
# ---------------------------------------------------------------------------
@pytest.fixture
def summarizer(db_conn: sqlite3.Connection) -> ContextSummarizer:
"""Provide a ContextSummarizer backed by an in-memory store."""
return ContextSummarizer(ContextStore(db_conn))
class TestContextSummarizer:
"""Test suite for ContextSummarizer."""
# ------------------------------------------------------------------
# summarize_numeric_values
# ------------------------------------------------------------------
def test_summarize_empty_values(self, summarizer: ContextSummarizer) -> None:
"""Empty list must return SummaryStats with count=0 and no other fields."""
stats = summarizer.summarize_numeric_values([])
assert stats.count == 0
assert stats.mean is None
assert stats.min is None
assert stats.max is None
def test_summarize_single_value(self, summarizer: ContextSummarizer) -> None:
"""Single-element list must return correct stats with std=0 and trend=flat."""
stats = summarizer.summarize_numeric_values([42.0])
assert stats.count == 1
assert stats.mean == 42.0
assert stats.std == 0.0
assert stats.trend == "flat"
def test_summarize_upward_trend(self, summarizer: ContextSummarizer) -> None:
"""Increasing values must produce trend='up'."""
values = [1.0, 2.0, 3.0, 10.0, 20.0, 30.0]
stats = summarizer.summarize_numeric_values(values)
assert stats.trend == "up"
def test_summarize_downward_trend(self, summarizer: ContextSummarizer) -> None:
"""Decreasing values must produce trend='down'."""
values = [30.0, 20.0, 10.0, 3.0, 2.0, 1.0]
stats = summarizer.summarize_numeric_values(values)
assert stats.trend == "down"
def test_summarize_flat_trend(self, summarizer: ContextSummarizer) -> None:
"""Stable values must produce trend='flat'."""
values = [100.0, 100.1, 99.9, 100.0, 100.2, 99.8]
stats = summarizer.summarize_numeric_values(values)
assert stats.trend == "flat"
# ------------------------------------------------------------------
# summarize_layer
# ------------------------------------------------------------------
def test_summarize_layer_no_data(
self, summarizer: ContextSummarizer
) -> None:
"""summarize_layer with no data must return the 'No data' sentinel."""
result = summarizer.summarize_layer(ContextLayer.L6_DAILY)
assert result["count"] == 0
assert "No data" in result["summary"]
def test_summarize_layer_numeric(
self, summarizer: ContextSummarizer, db_conn: sqlite3.Connection
) -> None:
"""summarize_layer must collect numeric values and produce stats."""
store = summarizer.store
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "total_pnl", 100.0)
store.set_context(ContextLayer.L6_DAILY, "2026-02-02", "total_pnl", 200.0)
result = summarizer.summarize_layer(ContextLayer.L6_DAILY)
assert "total_entries" in result
def test_summarize_layer_with_dict_values(
self, summarizer: ContextSummarizer
) -> None:
"""summarize_layer must handle dict values by extracting numeric subkeys."""
store = summarizer.store
# set_context serialises the value as JSON, so passing a dict works
store.set_context(
ContextLayer.L6_DAILY, "2026-02-01", "metrics",
{"win_rate": 65.0, "label": "good"}
)
result = summarizer.summarize_layer(ContextLayer.L6_DAILY)
assert "total_entries" in result
# numeric subkey "win_rate" should appear as "metrics.win_rate"
assert "metrics.win_rate" in result
def test_summarize_layer_with_string_values(
self, summarizer: ContextSummarizer
) -> None:
"""summarize_layer must count string values separately."""
store = summarizer.store
# set_context stores string values as JSON-encoded strings
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "outlook", "BULLISH")
result = summarizer.summarize_layer(ContextLayer.L6_DAILY)
# String fields contribute a `<key>_count` entry
assert "outlook_count" in result
# ------------------------------------------------------------------
# rolling_window_summary
# ------------------------------------------------------------------
def test_rolling_window_summary_basic(
self, summarizer: ContextSummarizer
) -> None:
"""rolling_window_summary must return the expected structure."""
store = summarizer.store
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "pnl", 500.0)
result = summarizer.rolling_window_summary(ContextLayer.L6_DAILY)
assert "window_days" in result
assert "recent_data" in result
assert "historical_summary" in result
def test_rolling_window_summary_no_older_data(
self, summarizer: ContextSummarizer
) -> None:
"""rolling_window_summary with summarize_older=False skips history."""
result = summarizer.rolling_window_summary(
ContextLayer.L6_DAILY, summarize_older=False
)
assert result["historical_summary"] == {}
# ------------------------------------------------------------------
# aggregate_to_higher_layer
# ------------------------------------------------------------------
def test_aggregate_to_higher_layer_mean(
self, summarizer: ContextSummarizer
) -> None:
"""aggregate_to_higher_layer with 'mean' via dict subkeys returns average."""
store = summarizer.store
# Use different outer keys but same inner metric key so get_all_contexts
# returns multiple rows with the target subkey.
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day1", {"pnl": 100.0})
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day2", {"pnl": 200.0})
result = summarizer.aggregate_to_higher_layer(
ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY, "pnl", "mean"
)
assert result == pytest.approx(150.0)
def test_aggregate_to_higher_layer_sum(
self, summarizer: ContextSummarizer
) -> None:
"""aggregate_to_higher_layer with 'sum' must return the total."""
store = summarizer.store
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day1", {"pnl": 100.0})
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day2", {"pnl": 200.0})
result = summarizer.aggregate_to_higher_layer(
ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY, "pnl", "sum"
)
assert result == pytest.approx(300.0)
def test_aggregate_to_higher_layer_max(
self, summarizer: ContextSummarizer
) -> None:
"""aggregate_to_higher_layer with 'max' must return the maximum."""
store = summarizer.store
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day1", {"pnl": 100.0})
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day2", {"pnl": 200.0})
result = summarizer.aggregate_to_higher_layer(
ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY, "pnl", "max"
)
assert result == pytest.approx(200.0)
def test_aggregate_to_higher_layer_min(
self, summarizer: ContextSummarizer
) -> None:
"""aggregate_to_higher_layer with 'min' must return the minimum."""
store = summarizer.store
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day1", {"pnl": 100.0})
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day2", {"pnl": 200.0})
result = summarizer.aggregate_to_higher_layer(
ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY, "pnl", "min"
)
assert result == pytest.approx(100.0)
def test_aggregate_to_higher_layer_no_data(
self, summarizer: ContextSummarizer
) -> None:
"""aggregate_to_higher_layer with no matching key must return None."""
result = summarizer.aggregate_to_higher_layer(
ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY, "nonexistent", "mean"
)
assert result is None
def test_aggregate_to_higher_layer_unknown_func_defaults_to_mean(
self, summarizer: ContextSummarizer
) -> None:
"""Unknown aggregation function must fall back to mean."""
store = summarizer.store
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day1", {"pnl": 100.0})
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day2", {"pnl": 200.0})
result = summarizer.aggregate_to_higher_layer(
ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY, "pnl", "unknown_func"
)
assert result == pytest.approx(150.0)
# ------------------------------------------------------------------
# create_compact_summary + format_summary_for_prompt
# ------------------------------------------------------------------
def test_create_compact_summary(
self, summarizer: ContextSummarizer
) -> None:
"""create_compact_summary must produce a dict keyed by layer value."""
store = summarizer.store
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "pnl", 100.0)
result = summarizer.create_compact_summary([ContextLayer.L6_DAILY])
assert ContextLayer.L6_DAILY.value in result
def test_format_summary_for_prompt_with_numeric_metrics(
self, summarizer: ContextSummarizer
) -> None:
"""format_summary_for_prompt must render avg/trend fields."""
store = summarizer.store
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "pnl", 100.0)
store.set_context(ContextLayer.L6_DAILY, "2026-02-02", "pnl", 200.0)
compact = summarizer.create_compact_summary([ContextLayer.L6_DAILY])
text = summarizer.format_summary_for_prompt(compact)
assert isinstance(text, str)
def test_format_summary_for_prompt_skips_empty_layers(
self, summarizer: ContextSummarizer
) -> None:
"""format_summary_for_prompt must skip layers with no metrics."""
summary = {ContextLayer.L6_DAILY.value: {}}
text = summarizer.format_summary_for_prompt(summary)
assert text == ""
def test_format_summary_non_dict_value(
self, summarizer: ContextSummarizer
) -> None:
"""format_summary_for_prompt must render non-dict values as plain text."""
summary = {
"daily": {
"plain_count": 42,
}
}
text = summarizer.format_summary_for_prompt(summary)
assert "plain_count" in text
assert "42" in text

View File

@@ -296,156 +296,3 @@ def test_scenarios_active_empty_when_no_matches(tmp_path: Path) -> None:
get_active_scenarios = _endpoint(app, "/api/scenarios/active") get_active_scenarios = _endpoint(app, "/api/scenarios/active")
body = get_active_scenarios(market="US", date_str="2026-02-14", limit=50) body = get_active_scenarios(market="US", date_str="2026-02-14", limit=50)
assert body["count"] == 0 assert body["count"] == 0
def test_pnl_history_all_markets(tmp_path: Path) -> None:
app = _app(tmp_path)
get_pnl_history = _endpoint(app, "/api/pnl/history")
body = get_pnl_history(days=30, market="all")
assert body["market"] == "all"
assert isinstance(body["labels"], list)
assert isinstance(body["pnl"], list)
assert len(body["labels"]) == len(body["pnl"])
def test_pnl_history_market_filter(tmp_path: Path) -> None:
app = _app(tmp_path)
get_pnl_history = _endpoint(app, "/api/pnl/history")
body = get_pnl_history(days=30, market="KR")
assert body["market"] == "KR"
# KR has 1 trade with pnl=2.0
assert len(body["labels"]) >= 1
assert body["pnl"][0] == 2.0
def test_positions_returns_open_buy(tmp_path: Path) -> None:
"""BUY가 마지막 거래인 종목은 포지션으로 반환되어야 한다."""
app = _app(tmp_path)
get_positions = _endpoint(app, "/api/positions")
body = get_positions()
# seed_db: 005930은 BUY (오픈), AAPL은 SELL (마지막)
assert body["count"] == 1
pos = body["positions"][0]
assert pos["stock_code"] == "005930"
assert pos["market"] == "KR"
assert pos["quantity"] == 1
assert pos["entry_price"] == 70000
def test_positions_excludes_closed_sell(tmp_path: Path) -> None:
"""마지막 거래가 SELL인 종목은 포지션에 나타나지 않아야 한다."""
app = _app(tmp_path)
get_positions = _endpoint(app, "/api/positions")
body = get_positions()
codes = [p["stock_code"] for p in body["positions"]]
assert "AAPL" not in codes
def test_positions_empty_when_no_trades(tmp_path: Path) -> None:
"""거래 내역이 없으면 빈 포지션 목록을 반환해야 한다."""
db_path = tmp_path / "empty.db"
conn = init_db(str(db_path))
conn.close()
app = create_dashboard_app(str(db_path))
get_positions = _endpoint(app, "/api/positions")
body = get_positions()
assert body["count"] == 0
assert body["positions"] == []
def _seed_cb_context(conn: sqlite3.Connection, pnl_pct: float, market: str = "KR") -> None:
import json as _json
conn.execute(
"INSERT OR REPLACE INTO system_metrics (key, value, updated_at) VALUES (?, ?, ?)",
(
f"portfolio_pnl_pct_{market}",
_json.dumps({"pnl_pct": pnl_pct}),
"2026-02-22T10:00:00+00:00",
),
)
conn.commit()
def test_status_circuit_breaker_ok(tmp_path: Path) -> None:
"""pnl_pct가 -2.0%보다 높으면 status=ok를 반환해야 한다."""
db_path = tmp_path / "cb_ok.db"
conn = init_db(str(db_path))
_seed_cb_context(conn, -1.0)
conn.close()
app = create_dashboard_app(str(db_path))
get_status = _endpoint(app, "/api/status")
body = get_status()
cb = body["circuit_breaker"]
assert cb["status"] == "ok"
assert cb["current_pnl_pct"] == -1.0
assert cb["threshold_pct"] == -3.0
def test_status_circuit_breaker_warning(tmp_path: Path) -> None:
"""pnl_pct가 -2.0% 이하이면 status=warning을 반환해야 한다."""
db_path = tmp_path / "cb_warn.db"
conn = init_db(str(db_path))
_seed_cb_context(conn, -2.5)
conn.close()
app = create_dashboard_app(str(db_path))
get_status = _endpoint(app, "/api/status")
body = get_status()
assert body["circuit_breaker"]["status"] == "warning"
def test_status_circuit_breaker_tripped(tmp_path: Path) -> None:
"""pnl_pct가 임계값(-3.0%) 이하이면 status=tripped를 반환해야 한다."""
db_path = tmp_path / "cb_tripped.db"
conn = init_db(str(db_path))
_seed_cb_context(conn, -3.5)
conn.close()
app = create_dashboard_app(str(db_path))
get_status = _endpoint(app, "/api/status")
body = get_status()
assert body["circuit_breaker"]["status"] == "tripped"
def test_status_circuit_breaker_unknown_when_no_data(tmp_path: Path) -> None:
"""L7 context에 pnl_pct 데이터가 없으면 status=unknown을 반환해야 한다."""
app = _app(tmp_path) # seed_db에는 portfolio_pnl_pct 없음
get_status = _endpoint(app, "/api/status")
body = get_status()
cb = body["circuit_breaker"]
assert cb["status"] == "unknown"
assert cb["current_pnl_pct"] is None
def test_status_mode_paper(tmp_path: Path) -> None:
"""mode=paper로 생성하면 status 응답에 mode=paper가 포함돼야 한다."""
db_path = tmp_path / "dashboard_test.db"
conn = init_db(str(db_path))
_seed_db(conn)
conn.close()
app = create_dashboard_app(str(db_path), mode="paper")
get_status = _endpoint(app, "/api/status")
body = get_status()
assert body["mode"] == "paper"
def test_status_mode_live(tmp_path: Path) -> None:
"""mode=live로 생성하면 status 응답에 mode=live가 포함돼야 한다."""
db_path = tmp_path / "dashboard_test.db"
conn = init_db(str(db_path))
_seed_db(conn)
conn.close()
app = create_dashboard_app(str(db_path), mode="live")
get_status = _endpoint(app, "/api/status")
body = get_status()
assert body["mode"] == "live"
def test_status_mode_default_paper(tmp_path: Path) -> None:
"""mode 파라미터 미전달 시 기본값은 paper여야 한다."""
db_path = tmp_path / "dashboard_test.db"
conn = init_db(str(db_path))
_seed_db(conn)
conn.close()
app = create_dashboard_app(str(db_path))
get_status = _endpoint(app, "/api/status")
body = get_status()
assert body["mode"] == "paper"

View File

@@ -1,8 +1,5 @@
"""Tests for database helper functions.""" """Tests for database helper functions."""
import tempfile
import os
from src.db import get_open_position, init_db, log_trade from src.db import get_open_position, init_db, log_trade
@@ -61,135 +58,3 @@ def test_get_open_position_returns_none_when_latest_is_sell() -> None:
def test_get_open_position_returns_none_when_no_trades() -> None: def test_get_open_position_returns_none_when_no_trades() -> None:
conn = init_db(":memory:") conn = init_db(":memory:")
assert get_open_position(conn, "AAPL", "US_NASDAQ") is None assert get_open_position(conn, "AAPL", "US_NASDAQ") is None
# ---------------------------------------------------------------------------
# WAL mode tests (issue #210)
# ---------------------------------------------------------------------------
def test_wal_mode_applied_to_file_db() -> None:
"""File-based DB must use WAL journal mode for dashboard concurrent reads."""
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
conn = init_db(db_path)
cursor = conn.execute("PRAGMA journal_mode")
mode = cursor.fetchone()[0]
assert mode == "wal", f"Expected WAL mode, got {mode}"
conn.close()
finally:
os.unlink(db_path)
# Clean up WAL auxiliary files if they exist
for ext in ("-wal", "-shm"):
path = db_path + ext
if os.path.exists(path):
os.unlink(path)
def test_wal_mode_not_applied_to_memory_db() -> None:
""":memory: DB must not apply WAL (SQLite does not support WAL for in-memory)."""
conn = init_db(":memory:")
cursor = conn.execute("PRAGMA journal_mode")
mode = cursor.fetchone()[0]
# In-memory DBs default to 'memory' journal mode
assert mode != "wal", "WAL should not be set on in-memory database"
conn.close()
# ---------------------------------------------------------------------------
# mode column tests (issue #212)
# ---------------------------------------------------------------------------
def test_log_trade_stores_mode_paper() -> None:
"""log_trade must persist mode='paper' in the trades table."""
conn = init_db(":memory:")
log_trade(
conn=conn,
stock_code="005930",
action="BUY",
confidence=85,
rationale="test",
mode="paper",
)
row = conn.execute("SELECT mode FROM trades ORDER BY id DESC LIMIT 1").fetchone()
assert row is not None
assert row[0] == "paper"
def test_log_trade_stores_mode_live() -> None:
"""log_trade must persist mode='live' in the trades table."""
conn = init_db(":memory:")
log_trade(
conn=conn,
stock_code="005930",
action="BUY",
confidence=85,
rationale="test",
mode="live",
)
row = conn.execute("SELECT mode FROM trades ORDER BY id DESC LIMIT 1").fetchone()
assert row is not None
assert row[0] == "live"
def test_log_trade_default_mode_is_paper() -> None:
"""log_trade without explicit mode must default to 'paper'."""
conn = init_db(":memory:")
log_trade(
conn=conn,
stock_code="005930",
action="HOLD",
confidence=50,
rationale="test",
)
row = conn.execute("SELECT mode FROM trades ORDER BY id DESC LIMIT 1").fetchone()
assert row is not None
assert row[0] == "paper"
def test_mode_column_exists_in_schema() -> None:
"""trades table must have a mode column after init_db."""
conn = init_db(":memory:")
cursor = conn.execute("PRAGMA table_info(trades)")
columns = {row[1] for row in cursor.fetchall()}
assert "mode" in columns
def test_mode_migration_adds_column_to_existing_db() -> None:
"""init_db must add mode column to existing DBs that lack it (migration)."""
import sqlite3
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
# Create DB without mode column (simulate old schema)
old_conn = sqlite3.connect(db_path)
old_conn.execute(
"""CREATE TABLE trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp TEXT NOT NULL,
stock_code TEXT NOT NULL,
action TEXT NOT NULL,
confidence INTEGER NOT NULL,
rationale TEXT,
quantity INTEGER,
price REAL,
pnl REAL DEFAULT 0.0,
market TEXT DEFAULT 'KR',
exchange_code TEXT DEFAULT 'KRX',
decision_id TEXT
)"""
)
old_conn.commit()
old_conn.close()
# Run init_db — should add mode column via migration
conn = init_db(db_path)
cursor = conn.execute("PRAGMA table_info(trades)")
columns = {row[1] for row in cursor.fetchall()}
assert "mode" in columns
conn.close()
finally:
os.unlink(db_path)

View File

@@ -1,117 +0,0 @@
"""Tests for JSON structured logging configuration."""
from __future__ import annotations
import json
import logging
import sys
from src.logging_config import JSONFormatter, setup_logging
class TestJSONFormatter:
"""Test JSONFormatter output."""
def test_basic_log_record(self) -> None:
"""JSONFormatter must emit valid JSON with required fields."""
formatter = JSONFormatter()
record = logging.LogRecord(
name="test.logger",
level=logging.INFO,
pathname="",
lineno=0,
msg="Hello %s",
args=("world",),
exc_info=None,
)
output = formatter.format(record)
data = json.loads(output)
assert data["level"] == "INFO"
assert data["logger"] == "test.logger"
assert data["message"] == "Hello world"
assert "timestamp" in data
def test_includes_exception_info(self) -> None:
"""JSONFormatter must include exception info when present."""
formatter = JSONFormatter()
try:
raise ValueError("test error")
except ValueError:
exc_info = sys.exc_info()
record = logging.LogRecord(
name="test",
level=logging.ERROR,
pathname="",
lineno=0,
msg="oops",
args=(),
exc_info=exc_info,
)
output = formatter.format(record)
data = json.loads(output)
assert "exception" in data
assert "ValueError" in data["exception"]
def test_extra_trading_fields_included(self) -> None:
"""Extra trading fields attached to the record must appear in JSON."""
formatter = JSONFormatter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="",
lineno=0,
msg="trade",
args=(),
exc_info=None,
)
record.stock_code = "005930" # type: ignore[attr-defined]
record.action = "BUY" # type: ignore[attr-defined]
record.confidence = 85 # type: ignore[attr-defined]
record.pnl_pct = -1.5 # type: ignore[attr-defined]
record.order_amount = 1_000_000 # type: ignore[attr-defined]
output = formatter.format(record)
data = json.loads(output)
assert data["stock_code"] == "005930"
assert data["action"] == "BUY"
assert data["confidence"] == 85
assert data["pnl_pct"] == -1.5
assert data["order_amount"] == 1_000_000
def test_none_extra_fields_excluded(self) -> None:
"""Extra fields that are None must not appear in JSON output."""
formatter = JSONFormatter()
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="",
lineno=0,
msg="no extras",
args=(),
exc_info=None,
)
output = formatter.format(record)
data = json.loads(output)
assert "stock_code" not in data
assert "action" not in data
assert "confidence" not in data
class TestSetupLogging:
"""Test setup_logging function."""
def test_configures_root_logger(self) -> None:
"""setup_logging must attach a JSON handler to the root logger."""
setup_logging(level=logging.DEBUG)
root = logging.getLogger()
json_handlers = [
h for h in root.handlers if isinstance(h.formatter, JSONFormatter)
]
assert len(json_handlers) == 1
assert root.level == logging.DEBUG
def test_avoids_duplicate_handlers(self) -> None:
"""Calling setup_logging twice must not add duplicate handlers."""
setup_logging()
setup_logging()
root = logging.getLogger()
assert len(root.handlers) == 1

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -164,23 +164,18 @@ class TestGeneratePlaybook:
assert pb.market_outlook == MarketOutlook.NEUTRAL assert pb.market_outlook == MarketOutlook.NEUTRAL
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_gemini_failure_returns_smart_fallback(self) -> None: async def test_gemini_failure_returns_defensive(self) -> None:
planner = _make_planner() planner = _make_planner()
planner._gemini.decide = AsyncMock(side_effect=RuntimeError("API timeout")) planner._gemini.decide = AsyncMock(side_effect=RuntimeError("API timeout"))
# oversold candidate (signal="oversold", rsi=28.5)
candidates = [_candidate()] candidates = [_candidate()]
pb = await planner.generate_playbook("KR", candidates, today=date(2026, 2, 8)) pb = await planner.generate_playbook("KR", candidates, today=date(2026, 2, 8))
assert pb.default_action == ScenarioAction.HOLD assert pb.default_action == ScenarioAction.HOLD
# Smart fallback uses NEUTRAL outlook (not NEUTRAL_TO_BEARISH) assert pb.market_outlook == MarketOutlook.NEUTRAL_TO_BEARISH
assert pb.market_outlook == MarketOutlook.NEUTRAL
assert pb.stock_count == 1 assert pb.stock_count == 1
# Oversold candidate → first scenario is BUY, second is SELL stop-loss # Defensive playbook has stop-loss scenarios
scenarios = pb.stock_playbooks[0].scenarios assert pb.stock_playbooks[0].scenarios[0].action == ScenarioAction.SELL
assert scenarios[0].action == ScenarioAction.BUY
assert scenarios[0].condition.rsi_below == 30
assert scenarios[1].action == ScenarioAction.SELL
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_gemini_failure_empty_when_defensive_disabled(self) -> None: async def test_gemini_failure_empty_when_defensive_disabled(self) -> None:
@@ -662,339 +657,3 @@ class TestDefensivePlaybook:
assert pb.stock_count == 0 assert pb.stock_count == 0
assert pb.market == "US" assert pb.market == "US"
assert pb.market_outlook == MarketOutlook.NEUTRAL assert pb.market_outlook == MarketOutlook.NEUTRAL
# ---------------------------------------------------------------------------
# Smart fallback playbook
# ---------------------------------------------------------------------------
class TestSmartFallbackPlaybook:
"""Tests for _smart_fallback_playbook — rule-based BUY/SELL on Gemini failure."""
def _make_settings(self) -> Settings:
return Settings(
KIS_APP_KEY="test",
KIS_APP_SECRET="test",
KIS_ACCOUNT_NO="12345678-01",
GEMINI_API_KEY="test",
RSI_OVERSOLD_THRESHOLD=30,
VOL_MULTIPLIER=2.0,
)
def test_momentum_candidate_gets_buy_on_volume(self) -> None:
candidates = [
_candidate(code="CHOW", signal="momentum", volume_ratio=13.64, rsi=100.0)
]
settings = self._make_settings()
pb = PreMarketPlanner._smart_fallback_playbook(
date(2026, 2, 17), "US_AMEX", candidates, settings
)
assert pb.stock_count == 1
sp = pb.stock_playbooks[0]
assert sp.stock_code == "CHOW"
# First scenario: BUY with volume_ratio_above
buy_sc = sp.scenarios[0]
assert buy_sc.action == ScenarioAction.BUY
assert buy_sc.condition.volume_ratio_above == 2.0
assert buy_sc.condition.rsi_below is None
assert buy_sc.confidence == 80
# Second scenario: stop-loss SELL
sell_sc = sp.scenarios[1]
assert sell_sc.action == ScenarioAction.SELL
assert sell_sc.condition.price_change_pct_below == -3.0
def test_oversold_candidate_gets_buy_on_rsi(self) -> None:
candidates = [
_candidate(code="005930", signal="oversold", rsi=22.0, volume_ratio=3.5)
]
settings = self._make_settings()
pb = PreMarketPlanner._smart_fallback_playbook(
date(2026, 2, 17), "KR", candidates, settings
)
sp = pb.stock_playbooks[0]
buy_sc = sp.scenarios[0]
assert buy_sc.action == ScenarioAction.BUY
assert buy_sc.condition.rsi_below == 30
assert buy_sc.condition.volume_ratio_above is None
def test_all_candidates_have_stop_loss_sell(self) -> None:
candidates = [
_candidate(code="AAA", signal="momentum", volume_ratio=5.0),
_candidate(code="BBB", signal="oversold", rsi=25.0),
]
settings = self._make_settings()
pb = PreMarketPlanner._smart_fallback_playbook(
date(2026, 2, 17), "US_NASDAQ", candidates, settings
)
assert pb.stock_count == 2
for sp in pb.stock_playbooks:
sell_scenarios = [s for s in sp.scenarios if s.action == ScenarioAction.SELL]
assert len(sell_scenarios) == 1
assert sell_scenarios[0].condition.price_change_pct_below == -3.0
assert sell_scenarios[0].condition.price_change_pct_below == -3.0
def test_market_outlook_is_neutral(self) -> None:
candidates = [_candidate(signal="momentum", volume_ratio=5.0)]
settings = self._make_settings()
pb = PreMarketPlanner._smart_fallback_playbook(
date(2026, 2, 17), "US_AMEX", candidates, settings
)
assert pb.market_outlook == MarketOutlook.NEUTRAL
def test_default_action_is_hold(self) -> None:
candidates = [_candidate(signal="momentum", volume_ratio=5.0)]
settings = self._make_settings()
pb = PreMarketPlanner._smart_fallback_playbook(
date(2026, 2, 17), "US_AMEX", candidates, settings
)
assert pb.default_action == ScenarioAction.HOLD
def test_has_global_reduce_all_rule(self) -> None:
candidates = [_candidate(signal="momentum", volume_ratio=5.0)]
settings = self._make_settings()
pb = PreMarketPlanner._smart_fallback_playbook(
date(2026, 2, 17), "US_AMEX", candidates, settings
)
assert len(pb.global_rules) == 1
rule = pb.global_rules[0]
assert rule.action == ScenarioAction.REDUCE_ALL
assert "portfolio_pnl_pct" in rule.condition
def test_empty_candidates_returns_empty_playbook(self) -> None:
settings = self._make_settings()
pb = PreMarketPlanner._smart_fallback_playbook(
date(2026, 2, 17), "US_AMEX", [], settings
)
assert pb.stock_count == 0
def test_vol_multiplier_applied_from_settings(self) -> None:
"""VOL_MULTIPLIER=3.0 should set volume_ratio_above=3.0 for momentum."""
candidates = [_candidate(signal="momentum", volume_ratio=5.0)]
settings = self._make_settings()
settings = settings.model_copy(update={"VOL_MULTIPLIER": 3.0})
pb = PreMarketPlanner._smart_fallback_playbook(
date(2026, 2, 17), "US_AMEX", candidates, settings
)
buy_sc = pb.stock_playbooks[0].scenarios[0]
assert buy_sc.condition.volume_ratio_above == 3.0
def test_rsi_oversold_threshold_applied_from_settings(self) -> None:
"""RSI_OVERSOLD_THRESHOLD=25 should set rsi_below=25 for oversold."""
candidates = [_candidate(signal="oversold", rsi=22.0)]
settings = self._make_settings()
settings = settings.model_copy(update={"RSI_OVERSOLD_THRESHOLD": 25})
pb = PreMarketPlanner._smart_fallback_playbook(
date(2026, 2, 17), "KR", candidates, settings
)
buy_sc = pb.stock_playbooks[0].scenarios[0]
assert buy_sc.condition.rsi_below == 25
@pytest.mark.asyncio
async def test_generate_playbook_uses_smart_fallback_on_gemini_error(self) -> None:
"""generate_playbook() should use smart fallback (not defensive) on API failure."""
planner = _make_planner()
planner._gemini.decide = AsyncMock(side_effect=ConnectionError("429 quota exceeded"))
# momentum candidate
candidates = [
_candidate(code="CHOW", signal="momentum", volume_ratio=13.64, rsi=100.0)
]
pb = await planner.generate_playbook(
"US_AMEX", candidates, today=date(2026, 2, 18)
)
# Should NOT be all-SELL defensive; should have BUY for momentum
assert pb.stock_count == 1
buy_scenarios = [
s for s in pb.stock_playbooks[0].scenarios
if s.action == ScenarioAction.BUY
]
assert len(buy_scenarios) == 1
assert buy_scenarios[0].condition.volume_ratio_above == 2.0 # VOL_MULTIPLIER default
# ---------------------------------------------------------------------------
# Holdings in prompt (#170)
# ---------------------------------------------------------------------------
class TestHoldingsInPrompt:
"""Tests for current_holdings parameter in generate_playbook / _build_prompt."""
def _make_holdings(self) -> list[dict]:
return [
{
"stock_code": "005930",
"name": "Samsung",
"qty": 10,
"entry_price": 71000.0,
"unrealized_pnl_pct": 2.3,
"holding_days": 3,
}
]
def test_build_prompt_includes_holdings_section(self) -> None:
"""Prompt should contain a Current Holdings section when holdings are given."""
planner = _make_planner()
candidates = [_candidate()]
holdings = self._make_holdings()
prompt = planner._build_prompt(
"KR",
candidates,
context_data={},
self_market_scorecard=None,
cross_market=None,
current_holdings=holdings,
)
assert "## Current Holdings" in prompt
assert "005930" in prompt
assert "+2.30%" in prompt
assert "보유 3일" in prompt
def test_build_prompt_no_holdings_omits_section(self) -> None:
"""Prompt should NOT contain a Current Holdings section when holdings=None."""
planner = _make_planner()
candidates = [_candidate()]
prompt = planner._build_prompt(
"KR",
candidates,
context_data={},
self_market_scorecard=None,
cross_market=None,
current_holdings=None,
)
assert "## Current Holdings" not in prompt
def test_build_prompt_empty_holdings_omits_section(self) -> None:
"""Empty list should also omit the holdings section."""
planner = _make_planner()
candidates = [_candidate()]
prompt = planner._build_prompt(
"KR",
candidates,
context_data={},
self_market_scorecard=None,
cross_market=None,
current_holdings=[],
)
assert "## Current Holdings" not in prompt
def test_build_prompt_holdings_instruction_included(self) -> None:
"""Prompt should include instruction to generate scenarios for held stocks."""
planner = _make_planner()
candidates = [_candidate()]
holdings = self._make_holdings()
prompt = planner._build_prompt(
"KR",
candidates,
context_data={},
self_market_scorecard=None,
cross_market=None,
current_holdings=holdings,
)
assert "005930" in prompt
assert "SELL/HOLD" in prompt
@pytest.mark.asyncio
async def test_generate_playbook_passes_holdings_to_prompt(self) -> None:
"""generate_playbook should pass current_holdings through to the prompt."""
planner = _make_planner()
candidates = [_candidate()]
holdings = self._make_holdings()
# Capture the actual prompt sent to Gemini
captured_prompts: list[str] = []
original_decide = planner._gemini.decide
async def capture_and_call(data: dict) -> TradeDecision:
captured_prompts.append(data.get("prompt_override", ""))
return await original_decide(data)
planner._gemini.decide = capture_and_call # type: ignore[method-assign]
await planner.generate_playbook(
"KR", candidates, today=date(2026, 2, 8), current_holdings=holdings
)
assert len(captured_prompts) == 1
assert "## Current Holdings" in captured_prompts[0]
assert "005930" in captured_prompts[0]
@pytest.mark.asyncio
async def test_holdings_stock_allowed_in_parse_response(self) -> None:
"""Holdings stocks not in candidates list should be accepted in the response."""
holding_code = "000660" # Not in candidates
stocks = [
{
"stock_code": "005930", # candidate
"scenarios": [
{
"condition": {"rsi_below": 30},
"action": "BUY",
"confidence": 85,
"rationale": "oversold",
}
],
},
{
"stock_code": holding_code, # holding only
"scenarios": [
{
"condition": {"price_change_pct_below": -2.0},
"action": "SELL",
"confidence": 90,
"rationale": "stop-loss",
}
],
},
]
planner = _make_planner(gemini_response=_gemini_response_json(stocks=stocks))
candidates = [_candidate()] # only 005930
holdings = [
{
"stock_code": holding_code,
"name": "SK Hynix",
"qty": 5,
"entry_price": 180000.0,
"unrealized_pnl_pct": -1.5,
"holding_days": 7,
}
]
pb = await planner.generate_playbook(
"KR",
candidates,
today=date(2026, 2, 8),
current_holdings=holdings,
)
codes = [sp.stock_code for sp in pb.stock_playbooks]
assert "005930" in codes
assert holding_code in codes

View File

@@ -440,135 +440,3 @@ class TestEvaluate:
assert result.action == ScenarioAction.BUY assert result.action == ScenarioAction.BUY
assert result.match_details["rsi"] == 25.0 assert result.match_details["rsi"] == 25.0
assert isinstance(result.match_details["rsi"], float) assert isinstance(result.match_details["rsi"], float)
# ---------------------------------------------------------------------------
# Position-aware condition tests (#171)
# ---------------------------------------------------------------------------
class TestPositionAwareConditions:
"""Tests for unrealized_pnl_pct and holding_days condition fields."""
def test_evaluate_condition_unrealized_pnl_above_matches(
self, engine: ScenarioEngine
) -> None:
"""unrealized_pnl_pct_above should match when P&L exceeds threshold."""
condition = StockCondition(unrealized_pnl_pct_above=3.0)
assert engine.evaluate_condition(condition, {"unrealized_pnl_pct": 5.0}) is True
def test_evaluate_condition_unrealized_pnl_above_no_match(
self, engine: ScenarioEngine
) -> None:
"""unrealized_pnl_pct_above should NOT match when P&L is below threshold."""
condition = StockCondition(unrealized_pnl_pct_above=3.0)
assert engine.evaluate_condition(condition, {"unrealized_pnl_pct": 2.0}) is False
def test_evaluate_condition_unrealized_pnl_below_matches(
self, engine: ScenarioEngine
) -> None:
"""unrealized_pnl_pct_below should match when P&L is under threshold."""
condition = StockCondition(unrealized_pnl_pct_below=-2.0)
assert engine.evaluate_condition(condition, {"unrealized_pnl_pct": -3.5}) is True
def test_evaluate_condition_unrealized_pnl_below_no_match(
self, engine: ScenarioEngine
) -> None:
"""unrealized_pnl_pct_below should NOT match when P&L is above threshold."""
condition = StockCondition(unrealized_pnl_pct_below=-2.0)
assert engine.evaluate_condition(condition, {"unrealized_pnl_pct": -1.0}) is False
def test_evaluate_condition_holding_days_above_matches(
self, engine: ScenarioEngine
) -> None:
"""holding_days_above should match when position held longer than threshold."""
condition = StockCondition(holding_days_above=5)
assert engine.evaluate_condition(condition, {"holding_days": 7}) is True
def test_evaluate_condition_holding_days_above_no_match(
self, engine: ScenarioEngine
) -> None:
"""holding_days_above should NOT match when position held shorter."""
condition = StockCondition(holding_days_above=5)
assert engine.evaluate_condition(condition, {"holding_days": 3}) is False
def test_evaluate_condition_holding_days_below_matches(
self, engine: ScenarioEngine
) -> None:
"""holding_days_below should match when position held fewer days."""
condition = StockCondition(holding_days_below=3)
assert engine.evaluate_condition(condition, {"holding_days": 1}) is True
def test_evaluate_condition_holding_days_below_no_match(
self, engine: ScenarioEngine
) -> None:
"""holding_days_below should NOT match when held more days."""
condition = StockCondition(holding_days_below=3)
assert engine.evaluate_condition(condition, {"holding_days": 5}) is False
def test_combined_pnl_and_holding_days(self, engine: ScenarioEngine) -> None:
"""Combined position-aware conditions should AND-evaluate correctly."""
condition = StockCondition(
unrealized_pnl_pct_above=3.0,
holding_days_above=5,
)
# Both met → match
assert engine.evaluate_condition(
condition,
{"unrealized_pnl_pct": 4.5, "holding_days": 7},
) is True
# Only pnl met → no match
assert engine.evaluate_condition(
condition,
{"unrealized_pnl_pct": 4.5, "holding_days": 3},
) is False
def test_missing_unrealized_pnl_does_not_match(
self, engine: ScenarioEngine
) -> None:
"""Missing unrealized_pnl_pct key should not match the condition."""
condition = StockCondition(unrealized_pnl_pct_above=3.0)
assert engine.evaluate_condition(condition, {}) is False
def test_missing_holding_days_does_not_match(
self, engine: ScenarioEngine
) -> None:
"""Missing holding_days key should not match the condition."""
condition = StockCondition(holding_days_above=5)
assert engine.evaluate_condition(condition, {}) is False
def test_match_details_includes_position_fields(
self, engine: ScenarioEngine
) -> None:
"""match_details should include position fields when condition specifies them."""
pb = _playbook(
scenarios=[
StockScenario(
condition=StockCondition(unrealized_pnl_pct_above=3.0),
action=ScenarioAction.SELL,
confidence=90,
rationale="Take profit",
)
]
)
result = engine.evaluate(
pb,
"005930",
{"unrealized_pnl_pct": 5.0},
{},
)
assert result.action == ScenarioAction.SELL
assert "unrealized_pnl_pct" in result.match_details
assert result.match_details["unrealized_pnl_pct"] == 5.0
def test_position_conditions_parse_from_planner(self) -> None:
"""StockCondition should accept and store new fields from JSON parsing."""
condition = StockCondition(
unrealized_pnl_pct_above=3.0,
unrealized_pnl_pct_below=None,
holding_days_above=5,
holding_days_below=None,
)
assert condition.unrealized_pnl_pct_above == 3.0
assert condition.holding_days_above == 5
assert condition.has_any_condition() is True

View File

@@ -350,42 +350,6 @@ class TestSmartVolatilityScanner:
assert [c.stock_code for c in candidates] == ["ABCD"] assert [c.stock_code for c in candidates] == ["ABCD"]
class TestImpliedRSIFormula:
"""Test the implied_rsi formula in SmartVolatilityScanner (issue #181)."""
def test_neutral_change_gives_neutral_rsi(self) -> None:
"""0% change → implied_rsi = 50 (neutral)."""
# formula: 50 + (change_rate * 2.0)
rsi = max(0.0, min(100.0, 50.0 + (0.0 * 2.0)))
assert rsi == 50.0
def test_10pct_change_gives_rsi_70(self) -> None:
"""10% upward change → implied_rsi = 70 (momentum signal)."""
rsi = max(0.0, min(100.0, 50.0 + (10.0 * 2.0)))
assert rsi == 70.0
def test_minus_10pct_gives_rsi_30(self) -> None:
"""-10% change → implied_rsi = 30 (oversold signal)."""
rsi = max(0.0, min(100.0, 50.0 + (-10.0 * 2.0)))
assert rsi == 30.0
def test_saturation_at_25pct(self) -> None:
"""Saturation occurs at >=25% change (not 12.5% as with old coefficient 4.0)."""
rsi_12pct = max(0.0, min(100.0, 50.0 + (12.5 * 2.0)))
rsi_25pct = max(0.0, min(100.0, 50.0 + (25.0 * 2.0)))
rsi_30pct = max(0.0, min(100.0, 50.0 + (30.0 * 2.0)))
# At 12.5% change: RSI = 75 (not 100, unlike old formula)
assert rsi_12pct == 75.0
# At 25%+ saturation
assert rsi_25pct == 100.0
assert rsi_30pct == 100.0 # Capped
def test_negative_saturation(self) -> None:
"""Saturation at -25% gives RSI = 0."""
rsi = max(0.0, min(100.0, 50.0 + (-25.0 * 2.0)))
assert rsi == 0.0
class TestRSICalculation: class TestRSICalculation:
"""Test RSI calculation in VolatilityAnalyzer.""" """Test RSI calculation in VolatilityAnalyzer."""

View File

@@ -1,32 +0,0 @@
"""Tests for BaseStrategy abstract class."""
from __future__ import annotations
from typing import Any
import pytest
from src.strategies.base import BaseStrategy
class ConcreteStrategy(BaseStrategy):
"""Minimal concrete strategy for testing."""
def evaluate(self, market_data: dict[str, Any]) -> dict[str, Any]:
return {"action": "HOLD", "confidence": 50, "rationale": "test"}
def test_base_strategy_cannot_be_instantiated() -> None:
"""BaseStrategy cannot be instantiated directly (it's abstract)."""
with pytest.raises(TypeError):
BaseStrategy() # type: ignore[abstract]
def test_concrete_strategy_evaluate_returns_decision() -> None:
"""Concrete subclass must implement evaluate and return a dict."""
strategy = ConcreteStrategy()
result = strategy.evaluate({"close": [100.0, 101.0]})
assert isinstance(result, dict)
assert result["action"] == "HOLD"
assert result["confidence"] == 50
assert "rationale" in result

View File

@@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, patch
import aiohttp import aiohttp
import pytest import pytest
from src.notifications.telegram_client import NotificationFilter, NotificationPriority, TelegramClient from src.notifications.telegram_client import NotificationPriority, TelegramClient
class TestTelegramClientInit: class TestTelegramClientInit:
@@ -481,187 +481,3 @@ class TestClientCleanup:
# Should not raise exception # Should not raise exception
await client.close() await client.close()
class TestNotificationFilter:
"""Test granular notification filter behavior."""
def test_default_filter_allows_all(self) -> None:
"""Default NotificationFilter has all flags enabled."""
f = NotificationFilter()
assert f.trades is True
assert f.market_open_close is True
assert f.fat_finger is True
assert f.system_events is True
assert f.playbook is True
assert f.scenario_match is True
assert f.errors is True
def test_client_uses_default_filter_when_none_given(self) -> None:
"""TelegramClient creates a default NotificationFilter when none provided."""
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
assert isinstance(client._filter, NotificationFilter)
assert client._filter.scenario_match is True
def test_client_stores_provided_filter(self) -> None:
"""TelegramClient stores a custom NotificationFilter."""
nf = NotificationFilter(scenario_match=False, trades=False)
client = TelegramClient(
bot_token="123:abc", chat_id="456", enabled=True, notification_filter=nf
)
assert client._filter.scenario_match is False
assert client._filter.trades is False
assert client._filter.market_open_close is True # default still True
@pytest.mark.asyncio
async def test_scenario_match_filtered_does_not_send(self) -> None:
"""notify_scenario_matched skips send when scenario_match=False."""
nf = NotificationFilter(scenario_match=False)
client = TelegramClient(
bot_token="123:abc", chat_id="456", enabled=True, notification_filter=nf
)
with patch("aiohttp.ClientSession.post") as mock_post:
await client.notify_scenario_matched(
stock_code="005930", action="BUY", condition_summary="rsi<30", confidence=85.0
)
mock_post.assert_not_called()
@pytest.mark.asyncio
async def test_trades_filtered_does_not_send(self) -> None:
"""notify_trade_execution skips send when trades=False."""
nf = NotificationFilter(trades=False)
client = TelegramClient(
bot_token="123:abc", chat_id="456", enabled=True, notification_filter=nf
)
with patch("aiohttp.ClientSession.post") as mock_post:
await client.notify_trade_execution(
stock_code="005930", market="KR", action="BUY",
quantity=10, price=70000.0, confidence=85.0
)
mock_post.assert_not_called()
@pytest.mark.asyncio
async def test_market_open_close_filtered_does_not_send(self) -> None:
"""notify_market_open/close skip send when market_open_close=False."""
nf = NotificationFilter(market_open_close=False)
client = TelegramClient(
bot_token="123:abc", chat_id="456", enabled=True, notification_filter=nf
)
with patch("aiohttp.ClientSession.post") as mock_post:
await client.notify_market_open("Korea")
await client.notify_market_close("Korea", pnl_pct=1.5)
mock_post.assert_not_called()
@pytest.mark.asyncio
async def test_circuit_breaker_always_sends_regardless_of_filter(self) -> None:
"""notify_circuit_breaker always sends (no filter flag)."""
nf = NotificationFilter(
trades=False, market_open_close=False, fat_finger=False,
system_events=False, playbook=False, scenario_match=False, errors=False,
)
client = TelegramClient(
bot_token="123:abc", chat_id="456", enabled=True, notification_filter=nf
)
mock_resp = AsyncMock()
mock_resp.status = 200
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
mock_resp.__aexit__ = AsyncMock(return_value=False)
with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
await client.notify_circuit_breaker(pnl_pct=-3.5, threshold=-3.0)
assert mock_post.call_count == 1
@pytest.mark.asyncio
async def test_errors_filtered_does_not_send(self) -> None:
"""notify_error skips send when errors=False."""
nf = NotificationFilter(errors=False)
client = TelegramClient(
bot_token="123:abc", chat_id="456", enabled=True, notification_filter=nf
)
with patch("aiohttp.ClientSession.post") as mock_post:
await client.notify_error("TestError", "something went wrong", "KR")
mock_post.assert_not_called()
@pytest.mark.asyncio
async def test_playbook_filtered_does_not_send(self) -> None:
"""notify_playbook_generated/failed skip send when playbook=False."""
nf = NotificationFilter(playbook=False)
client = TelegramClient(
bot_token="123:abc", chat_id="456", enabled=True, notification_filter=nf
)
with patch("aiohttp.ClientSession.post") as mock_post:
await client.notify_playbook_generated("KR", 3, 10, 1200)
await client.notify_playbook_failed("KR", "timeout")
mock_post.assert_not_called()
@pytest.mark.asyncio
async def test_system_events_filtered_does_not_send(self) -> None:
"""notify_system_start/shutdown skip send when system_events=False."""
nf = NotificationFilter(system_events=False)
client = TelegramClient(
bot_token="123:abc", chat_id="456", enabled=True, notification_filter=nf
)
with patch("aiohttp.ClientSession.post") as mock_post:
await client.notify_system_start("paper", ["KR"])
await client.notify_system_shutdown("Normal shutdown")
mock_post.assert_not_called()
def test_set_flag_valid_key(self) -> None:
"""set_flag returns True and updates field for a known key."""
nf = NotificationFilter()
assert nf.set_flag("scenario", False) is True
assert nf.scenario_match is False
def test_set_flag_invalid_key(self) -> None:
"""set_flag returns False for an unknown key."""
nf = NotificationFilter()
assert nf.set_flag("unknown_key", False) is False
def test_as_dict_keys_match_KEYS(self) -> None:
"""as_dict() returns every key defined in KEYS."""
nf = NotificationFilter()
d = nf.as_dict()
assert set(d.keys()) == set(NotificationFilter.KEYS.keys())
def test_set_notification_valid_key(self) -> None:
"""TelegramClient.set_notification toggles filter at runtime."""
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
assert client._filter.scenario_match is True
assert client.set_notification("scenario", False) is True
assert client._filter.scenario_match is False
def test_set_notification_all_off(self) -> None:
"""set_notification('all', False) disables every filter flag."""
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
assert client.set_notification("all", False) is True
for v in client.filter_status().values():
assert v is False
def test_set_notification_all_on(self) -> None:
"""set_notification('all', True) enables every filter flag."""
client = TelegramClient(
bot_token="123:abc", chat_id="456", enabled=True,
notification_filter=NotificationFilter(
trades=False, market_open_close=False, scenario_match=False,
fat_finger=False, system_events=False, playbook=False, errors=False,
),
)
assert client.set_notification("all", True) is True
for v in client.filter_status().values():
assert v is True
def test_set_notification_unknown_key(self) -> None:
"""set_notification returns False for an unknown key."""
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
assert client.set_notification("unknown", False) is False
def test_filter_status_reflects_current_state(self) -> None:
"""filter_status() matches the current NotificationFilter state."""
nf = NotificationFilter(trades=False, scenario_match=False)
client = TelegramClient(
bot_token="123:abc", chat_id="456", enabled=True, notification_filter=nf
)
status = client.filter_status()
assert status["trades"] is False
assert status["scenario"] is False
assert status["market"] is True

View File

@@ -875,139 +875,3 @@ class TestGetUpdates:
updates = await handler._get_updates() updates = await handler._get_updates()
assert updates == [] assert updates == []
@pytest.mark.asyncio
async def test_get_updates_409_stops_polling(self) -> None:
"""409 Conflict response stops the poller (_running = False) and returns empty list."""
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
handler = TelegramCommandHandler(client)
handler._running = True # simulate active poller
mock_resp = AsyncMock()
mock_resp.status = 409
mock_resp.text = AsyncMock(
return_value='{"ok":false,"error_code":409,"description":"Conflict"}'
)
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
mock_resp.__aexit__ = AsyncMock(return_value=False)
with patch("aiohttp.ClientSession.post", return_value=mock_resp):
updates = await handler._get_updates()
assert updates == []
assert handler._running is False # poller stopped
@pytest.mark.asyncio
async def test_poll_loop_exits_after_409(self) -> None:
"""_poll_loop exits naturally after _running is set to False by a 409 response."""
import asyncio as _asyncio
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
handler = TelegramCommandHandler(client)
call_count = 0
async def mock_get_updates_409() -> list[dict]:
nonlocal call_count
call_count += 1
# Simulate 409 stopping the poller
handler._running = False
return []
handler._get_updates = mock_get_updates_409 # type: ignore[method-assign]
handler._running = True
task = _asyncio.create_task(handler._poll_loop())
await _asyncio.wait_for(task, timeout=2.0)
# _get_updates called exactly once, then loop exited
assert call_count == 1
assert handler._running is False
class TestCommandWithArgs:
"""Test register_command_with_args and argument dispatch."""
def test_register_command_with_args_stored(self) -> None:
"""register_command_with_args stores handler in _commands_with_args."""
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
handler = TelegramCommandHandler(client)
async def my_handler(args: list[str]) -> None:
pass
handler.register_command_with_args("notify", my_handler)
assert "notify" in handler._commands_with_args
assert handler._commands_with_args["notify"] is my_handler
@pytest.mark.asyncio
async def test_args_handler_receives_arguments(self) -> None:
"""Args handler is called with the trailing tokens."""
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
handler = TelegramCommandHandler(client)
received: list[list[str]] = []
async def capture(args: list[str]) -> None:
received.append(args)
handler.register_command_with_args("notify", capture)
update = {
"message": {
"chat": {"id": "456"},
"text": "/notify scenario off",
}
}
await handler._handle_update(update)
assert received == [["scenario", "off"]]
@pytest.mark.asyncio
async def test_args_handler_takes_priority_over_no_args_handler(self) -> None:
"""When both handlers exist for same command, args handler wins."""
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
handler = TelegramCommandHandler(client)
no_args_called = []
args_called = []
async def no_args_handler() -> None:
no_args_called.append(True)
async def args_handler(args: list[str]) -> None:
args_called.append(args)
handler.register_command("notify", no_args_handler)
handler.register_command_with_args("notify", args_handler)
update = {
"message": {
"chat": {"id": "456"},
"text": "/notify all off",
}
}
await handler._handle_update(update)
assert args_called == [["all", "off"]]
assert no_args_called == []
@pytest.mark.asyncio
async def test_args_handler_with_no_trailing_args(self) -> None:
"""/notify with no args still dispatches to args handler with empty list."""
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
handler = TelegramCommandHandler(client)
received: list[list[str]] = []
async def capture(args: list[str]) -> None:
received.append(args)
handler.register_command_with_args("notify", capture)
update = {
"message": {
"chat": {"id": "456"},
"text": "/notify",
}
}
await handler._handle_update(update)
assert received == [[]]

View File

@@ -124,10 +124,6 @@ class TestPromptOptimizer:
assert len(prompt) < 300 assert len(prompt) < 300
assert "005930" in prompt assert "005930" in prompt
assert "75000" in prompt assert "75000" in prompt
# Keys must match parse_response expectations (#242)
assert '"action"' in prompt
assert '"confidence"' in prompt
assert '"rationale"' in prompt
def test_build_compressed_prompt_no_instructions(self): def test_build_compressed_prompt_no_instructions(self):
"""Test compressed prompt without instructions.""" """Test compressed prompt without instructions."""