Add complete Ouroboros trading system with TDD test suite
Some checks failed
CI / test (push) Has been cancelled
Some checks failed
CI / test (push) Has been cancelled
Implement the full autonomous trading agent architecture: - KIS broker with async API, token refresh, leaky bucket rate limiter, and hash key signing - Gemini-powered decision engine with JSON parsing and confidence threshold enforcement - Risk manager with circuit breaker (-3% P&L) and fat finger protection (30% cap) - Evolution engine for self-improving strategy generation via failure analysis - 35 passing tests written TDD-first covering risk, broker, and brain modules - CI/CD pipeline, Docker multi-stage build, and AI agent context docs Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
23
.env.example
Normal file
23
.env.example
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# Korea Investment Securities API
|
||||||
|
KIS_APP_KEY=your_app_key_here
|
||||||
|
KIS_APP_SECRET=your_app_secret_here
|
||||||
|
KIS_ACCOUNT_NO=12345678-01
|
||||||
|
KIS_BASE_URL=https://openapivts.koreainvestment.com:9443
|
||||||
|
|
||||||
|
# Google Gemini
|
||||||
|
GEMINI_API_KEY=your_gemini_api_key_here
|
||||||
|
GEMINI_MODEL=gemini-pro
|
||||||
|
|
||||||
|
# Risk Management
|
||||||
|
CIRCUIT_BREAKER_PCT=-3.0
|
||||||
|
FAT_FINGER_PCT=30.0
|
||||||
|
CONFIDENCE_THRESHOLD=80
|
||||||
|
|
||||||
|
# Database
|
||||||
|
DB_PATH=data/trade_logs.db
|
||||||
|
|
||||||
|
# Rate Limiting
|
||||||
|
RATE_LIMIT_RPS=10.0
|
||||||
|
|
||||||
|
# Trading Mode (paper / live)
|
||||||
|
MODE=paper
|
||||||
39
.github/workflows/ci.yml
vendored
Normal file
39
.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
pull_request:
|
||||||
|
branches: [main]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python 3.11
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.11"
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: pip install ".[dev]"
|
||||||
|
|
||||||
|
- name: Lint
|
||||||
|
run: ruff check src/ tests/
|
||||||
|
|
||||||
|
- name: Type check
|
||||||
|
run: mypy src/ --strict
|
||||||
|
continue-on-error: true
|
||||||
|
|
||||||
|
- name: Run tests with coverage
|
||||||
|
run: pytest -v --cov=src --cov-report=term-missing --cov-fail-under=80
|
||||||
|
|
||||||
|
- name: Upload coverage
|
||||||
|
if: always()
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: coverage-report
|
||||||
|
path: htmlcov/
|
||||||
44
Dockerfile
Normal file
44
Dockerfile
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
FROM python:3.11-slim AS base
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# System deps
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
gcc \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Install Python dependencies
|
||||||
|
COPY pyproject.toml .
|
||||||
|
RUN pip install --no-cache-dir ".[dev]"
|
||||||
|
|
||||||
|
# Copy source
|
||||||
|
COPY src/ src/
|
||||||
|
COPY tests/ tests/
|
||||||
|
COPY docs/ docs/
|
||||||
|
|
||||||
|
# Create data directory
|
||||||
|
RUN mkdir -p data
|
||||||
|
|
||||||
|
# Run tests as build validation
|
||||||
|
RUN pytest -v --tb=short
|
||||||
|
|
||||||
|
# Production stage
|
||||||
|
FROM python:3.11-slim AS production
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY pyproject.toml .
|
||||||
|
RUN pip install --no-cache-dir .
|
||||||
|
|
||||||
|
COPY src/ src/
|
||||||
|
RUN mkdir -p data
|
||||||
|
|
||||||
|
# Non-root user
|
||||||
|
RUN useradd --create-home appuser
|
||||||
|
USER appuser
|
||||||
|
|
||||||
|
HEALTHCHECK --interval=30s --timeout=5s --retries=3 \
|
||||||
|
CMD python -c "import sys; sys.exit(0)"
|
||||||
|
|
||||||
|
ENTRYPOINT ["python", "-m", "src.main"]
|
||||||
|
CMD ["--mode=paper"]
|
||||||
30
docker-compose.yml
Normal file
30
docker-compose.yml
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
services:
|
||||||
|
ouroboros:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
target: production
|
||||||
|
container_name: ouroboros
|
||||||
|
restart: unless-stopped
|
||||||
|
env_file:
|
||||||
|
- .env
|
||||||
|
volumes:
|
||||||
|
- trade_data:/app/data
|
||||||
|
command: ["--mode=paper"]
|
||||||
|
logging:
|
||||||
|
driver: json-file
|
||||||
|
options:
|
||||||
|
max-size: "10m"
|
||||||
|
max-file: "3"
|
||||||
|
|
||||||
|
# Run tests (one-shot)
|
||||||
|
test:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
target: base
|
||||||
|
container_name: ouroboros-test
|
||||||
|
command: ["pytest", "-v", "--cov=src", "--cov-report=term-missing"]
|
||||||
|
profiles:
|
||||||
|
- test
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
trade_data:
|
||||||
58
docs/agents.md
Normal file
58
docs/agents.md
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
# The Ouroboros - Agent Persona Definition
|
||||||
|
|
||||||
|
## Role: The Guardian
|
||||||
|
|
||||||
|
You are **The Guardian**, the primary AI agent responsible for maintaining and evolving
|
||||||
|
The Ouroboros trading system. Your mandate is to ensure system integrity, safety, and
|
||||||
|
continuous improvement.
|
||||||
|
|
||||||
|
## Prime Directives
|
||||||
|
|
||||||
|
1. **NEVER disable, bypass, or weaken `core/risk_manager.py`.**
|
||||||
|
The risk manager is the last line of defense against catastrophic loss.
|
||||||
|
Any code change that reduces risk controls MUST be rejected.
|
||||||
|
|
||||||
|
2. **All code changes require a passing test.**
|
||||||
|
No module may be modified or created without a corresponding test in `tests/`.
|
||||||
|
Run `pytest -v --cov=src` before proposing any merge.
|
||||||
|
|
||||||
|
3. **Preserve the Circuit Breaker.**
|
||||||
|
The daily P&L circuit breaker (-3.0% threshold) is non-negotiable.
|
||||||
|
It may only be made *stricter*, never relaxed.
|
||||||
|
|
||||||
|
4. **Fat Finger Protection is sacred.**
|
||||||
|
The 30% max-order-size rule must remain enforced at all times.
|
||||||
|
|
||||||
|
## Decision Framework
|
||||||
|
|
||||||
|
When modifying code, follow this priority order:
|
||||||
|
|
||||||
|
1. **Safety** - Will this change increase risk exposure? If yes, reject.
|
||||||
|
2. **Correctness** - Is the logic provably correct? Verify with tests.
|
||||||
|
3. **Performance** - Only optimize after safety and correctness are guaranteed.
|
||||||
|
4. **Readability** - Code must be understandable by future agents and humans.
|
||||||
|
|
||||||
|
## File Ownership
|
||||||
|
|
||||||
|
| Module | Guardian Rule |
|
||||||
|
|---|---|
|
||||||
|
| `core/risk_manager.py` | READ-ONLY. Changes require human approval + 2 passing test suites. |
|
||||||
|
| `broker/kis_api.py` | Rate limiter must never be removed. Token refresh must remain automatic. |
|
||||||
|
| `brain/gemini_client.py` | Confidence < 80 MUST force HOLD. This rule cannot be weakened. |
|
||||||
|
| `evolution/optimizer.py` | Generated strategies must pass ALL tests before activation. |
|
||||||
|
| `strategies/*` | New strategies are welcome but must inherit `BaseStrategy`. |
|
||||||
|
|
||||||
|
## Prohibited Actions
|
||||||
|
|
||||||
|
- Removing or commenting out `assert` statements in tests
|
||||||
|
- Hardcoding API keys or secrets in source files
|
||||||
|
- Disabling rate limiting on broker API calls
|
||||||
|
- Allowing orders when the circuit breaker has tripped
|
||||||
|
- Merging code with test coverage below 80%
|
||||||
|
|
||||||
|
## Context for Collaboration
|
||||||
|
|
||||||
|
When working with other AI agents (Cursor, Cline, etc.):
|
||||||
|
- Share this document as the system constitution
|
||||||
|
- All agents must acknowledge these rules before making changes
|
||||||
|
- Conflicts are resolved by defaulting to the *safer* option
|
||||||
109
docs/skills.md
Normal file
109
docs/skills.md
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
# The Ouroboros - Available Skills & Tools
|
||||||
|
|
||||||
|
## Development Tools
|
||||||
|
|
||||||
|
### Run Tests
|
||||||
|
```bash
|
||||||
|
pytest -v --cov=src --cov-report=term-missing
|
||||||
|
```
|
||||||
|
Run the full test suite with coverage reporting. **Must pass before any merge.**
|
||||||
|
|
||||||
|
### Run Specific Test Module
|
||||||
|
```bash
|
||||||
|
pytest tests/test_risk.py -v
|
||||||
|
pytest tests/test_broker.py -v
|
||||||
|
pytest tests/test_brain.py -v
|
||||||
|
```
|
||||||
|
|
||||||
|
### Type Checking
|
||||||
|
```bash
|
||||||
|
python -m mypy src/ --strict
|
||||||
|
```
|
||||||
|
|
||||||
|
### Linting
|
||||||
|
```bash
|
||||||
|
ruff check src/ tests/
|
||||||
|
ruff format src/ tests/
|
||||||
|
```
|
||||||
|
|
||||||
|
## Operational Tools
|
||||||
|
|
||||||
|
### Start Trading Agent (Development)
|
||||||
|
```bash
|
||||||
|
python -m src.main --mode=paper
|
||||||
|
```
|
||||||
|
Runs the agent in paper-trading mode (no real orders).
|
||||||
|
|
||||||
|
### Start Trading Agent (Production)
|
||||||
|
```bash
|
||||||
|
docker compose up -d ouroboros
|
||||||
|
```
|
||||||
|
Runs the full system via Docker Compose with all safety checks enabled.
|
||||||
|
|
||||||
|
### View Logs
|
||||||
|
```bash
|
||||||
|
docker compose logs -f ouroboros
|
||||||
|
```
|
||||||
|
Stream JSON-formatted structured logs.
|
||||||
|
|
||||||
|
### Run Backtester
|
||||||
|
```bash
|
||||||
|
python -m src.evolution.optimizer --backtest --days=30
|
||||||
|
```
|
||||||
|
Analyze the last 30 days of trade logs and generate performance metrics.
|
||||||
|
|
||||||
|
## Evolution Tools
|
||||||
|
|
||||||
|
### Generate New Strategy
|
||||||
|
```bash
|
||||||
|
python -m src.evolution.optimizer --evolve
|
||||||
|
```
|
||||||
|
Triggers the evolution engine to:
|
||||||
|
1. Analyze `trade_logs.db` for failing patterns
|
||||||
|
2. Ask Gemini to generate a new strategy
|
||||||
|
3. Run tests on the new strategy
|
||||||
|
4. Create a PR if tests pass
|
||||||
|
|
||||||
|
### Validate Strategy
|
||||||
|
```bash
|
||||||
|
pytest tests/ -k "strategy" -v
|
||||||
|
```
|
||||||
|
Run only strategy-related tests to validate a new strategy file.
|
||||||
|
|
||||||
|
## Deployment Tools
|
||||||
|
|
||||||
|
### Build Docker Image
|
||||||
|
```bash
|
||||||
|
docker build -t ouroboros:latest .
|
||||||
|
```
|
||||||
|
|
||||||
|
### Deploy with Docker Compose
|
||||||
|
```bash
|
||||||
|
docker compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
### Health Check
|
||||||
|
```bash
|
||||||
|
curl http://localhost:8080/health
|
||||||
|
```
|
||||||
|
|
||||||
|
## Database Tools
|
||||||
|
|
||||||
|
### View Trade Logs
|
||||||
|
```bash
|
||||||
|
sqlite3 data/trade_logs.db "SELECT * FROM trades ORDER BY timestamp DESC LIMIT 20;"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Export Trade History
|
||||||
|
```bash
|
||||||
|
sqlite3 -header -csv data/trade_logs.db "SELECT * FROM trades;" > trades_export.csv
|
||||||
|
```
|
||||||
|
|
||||||
|
## Safety Checklist (Pre-Deploy)
|
||||||
|
|
||||||
|
- [ ] `pytest -v --cov=src` passes with >= 80% coverage
|
||||||
|
- [ ] `ruff check src/ tests/` reports no errors
|
||||||
|
- [ ] `.env` file contains valid KIS and Gemini API keys
|
||||||
|
- [ ] Circuit breaker threshold is set to -3.0% or stricter
|
||||||
|
- [ ] Rate limiter is configured for KIS API limits
|
||||||
|
- [ ] Docker health check endpoint responds 200
|
||||||
40
pyproject.toml
Normal file
40
pyproject.toml
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
[project]
|
||||||
|
name = "the-ouroboros"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Evolutionary AI Trading Agent for KIS"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
dependencies = [
|
||||||
|
"aiohttp>=3.9,<4",
|
||||||
|
"pydantic>=2.5,<3",
|
||||||
|
"pydantic-settings>=2.1,<3",
|
||||||
|
"google-generativeai>=0.8,<1",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest>=8.0,<9",
|
||||||
|
"pytest-asyncio>=0.23,<1",
|
||||||
|
"pytest-cov>=5.0,<6",
|
||||||
|
"ruff>=0.5,<1",
|
||||||
|
"mypy>=1.10,<2",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
ouroboros = "src.main:main"
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["tests"]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
target-version = "py311"
|
||||||
|
line-length = 100
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = ["E", "F", "I", "N", "W", "UP"]
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
python_version = "3.11"
|
||||||
|
strict = true
|
||||||
|
warn_return_any = true
|
||||||
|
warn_unused_configs = true
|
||||||
0
src/__init__.py
Normal file
0
src/__init__.py
Normal file
0
src/brain/__init__.py
Normal file
0
src/brain/__init__.py
Normal file
152
src/brain/gemini_client.py
Normal file
152
src/brain/gemini_client.py
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
"""Decision engine powered by Google Gemini.
|
||||||
|
|
||||||
|
Constructs prompts from market data, calls Gemini, and parses structured
|
||||||
|
JSON responses into validated TradeDecision objects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import google.generativeai as genai
|
||||||
|
|
||||||
|
from src.config import Settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
VALID_ACTIONS = {"BUY", "SELL", "HOLD"}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TradeDecision:
|
||||||
|
"""Validated decision from the AI brain."""
|
||||||
|
|
||||||
|
action: str # "BUY" | "SELL" | "HOLD"
|
||||||
|
confidence: int # 0-100
|
||||||
|
rationale: str
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiClient:
|
||||||
|
"""Wraps the Gemini API for trade decision-making."""
|
||||||
|
|
||||||
|
def __init__(self, settings: Settings) -> None:
|
||||||
|
self._settings = settings
|
||||||
|
self._confidence_threshold = settings.CONFIDENCE_THRESHOLD
|
||||||
|
genai.configure(api_key=settings.GEMINI_API_KEY)
|
||||||
|
self._model = genai.GenerativeModel(settings.GEMINI_MODEL)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Prompt Construction
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def build_prompt(self, market_data: dict[str, Any]) -> str:
|
||||||
|
"""Build a structured prompt from market data.
|
||||||
|
|
||||||
|
The prompt instructs Gemini to return valid JSON with action,
|
||||||
|
confidence, and rationale fields.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
"You are a professional Korean stock market trading analyst.\n"
|
||||||
|
"Analyze the following market data and decide whether to BUY, SELL, or HOLD.\n\n"
|
||||||
|
f"Stock Code: {market_data['stock_code']}\n"
|
||||||
|
f"Current Price: {market_data['current_price']}\n"
|
||||||
|
f"Orderbook: {json.dumps(market_data['orderbook'], ensure_ascii=False)}\n"
|
||||||
|
f"Foreigner Net Buy/Sell: {market_data['foreigner_net']}\n\n"
|
||||||
|
"You MUST respond with ONLY valid JSON in the following format:\n"
|
||||||
|
'{"action": "BUY"|"SELL"|"HOLD", "confidence": <int 0-100>, "rationale": "<string>"}\n\n'
|
||||||
|
"Rules:\n"
|
||||||
|
"- action must be exactly one of: BUY, SELL, HOLD\n"
|
||||||
|
"- confidence must be an integer from 0 to 100\n"
|
||||||
|
"- rationale must explain your reasoning concisely\n"
|
||||||
|
"- Do NOT wrap the JSON in markdown code blocks\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Response Parsing
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def parse_response(self, raw: str) -> TradeDecision:
|
||||||
|
"""Parse a raw Gemini response into a TradeDecision.
|
||||||
|
|
||||||
|
Handles: valid JSON, JSON wrapped in markdown code blocks,
|
||||||
|
malformed JSON, missing fields, and invalid action values.
|
||||||
|
|
||||||
|
On any failure, returns a safe HOLD with confidence 0.
|
||||||
|
"""
|
||||||
|
if not raw or not raw.strip():
|
||||||
|
logger.warning("Empty response from Gemini — defaulting to HOLD")
|
||||||
|
return TradeDecision(action="HOLD", confidence=0, rationale="Empty response")
|
||||||
|
|
||||||
|
# Strip markdown code fences if present
|
||||||
|
cleaned = raw.strip()
|
||||||
|
match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", cleaned, re.DOTALL)
|
||||||
|
if match:
|
||||||
|
cleaned = match.group(1).strip()
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(cleaned)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning("Malformed JSON from Gemini — defaulting to HOLD")
|
||||||
|
return TradeDecision(
|
||||||
|
action="HOLD", confidence=0, rationale="Malformed JSON response"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate required fields
|
||||||
|
if not all(k in data for k in ("action", "confidence", "rationale")):
|
||||||
|
logger.warning("Missing fields in Gemini response — defaulting to HOLD")
|
||||||
|
return TradeDecision(
|
||||||
|
action="HOLD", confidence=0, rationale="Missing required fields"
|
||||||
|
)
|
||||||
|
|
||||||
|
action = str(data["action"]).upper()
|
||||||
|
if action not in VALID_ACTIONS:
|
||||||
|
logger.warning("Invalid action '%s' from Gemini — defaulting to HOLD", action)
|
||||||
|
return TradeDecision(
|
||||||
|
action="HOLD", confidence=0, rationale=f"Invalid action: {action}"
|
||||||
|
)
|
||||||
|
|
||||||
|
confidence = int(data["confidence"])
|
||||||
|
rationale = str(data["rationale"])
|
||||||
|
|
||||||
|
# Enforce confidence threshold
|
||||||
|
if confidence < self._confidence_threshold:
|
||||||
|
logger.info(
|
||||||
|
"Confidence %d < threshold %d — forcing HOLD",
|
||||||
|
confidence,
|
||||||
|
self._confidence_threshold,
|
||||||
|
)
|
||||||
|
action = "HOLD"
|
||||||
|
|
||||||
|
return TradeDecision(action=action, confidence=confidence, rationale=rationale)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# API Call
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def decide(self, market_data: dict[str, Any]) -> TradeDecision:
|
||||||
|
"""Build prompt, call Gemini, and return a parsed decision."""
|
||||||
|
prompt = self.build_prompt(market_data)
|
||||||
|
logger.info("Requesting trade decision from Gemini")
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self._model.generate_content_async(prompt)
|
||||||
|
raw = response.text
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Gemini API error: %s", exc)
|
||||||
|
return TradeDecision(
|
||||||
|
action="HOLD", confidence=0, rationale=f"API error: {exc}"
|
||||||
|
)
|
||||||
|
|
||||||
|
decision = self.parse_response(raw)
|
||||||
|
logger.info(
|
||||||
|
"Gemini decision",
|
||||||
|
extra={
|
||||||
|
"action": decision.action,
|
||||||
|
"confidence": decision.confidence,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return decision
|
||||||
0
src/broker/__init__.py
Normal file
0
src/broker/__init__.py
Normal file
245
src/broker/kis_api.py
Normal file
245
src/broker/kis_api.py
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
"""Async wrapper for the Korea Investment Securities (KIS) Open API.
|
||||||
|
|
||||||
|
Handles token refresh, rate limiting (leaky bucket), and hash key generation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from src.config import Settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LeakyBucket:
|
||||||
|
"""Simple leaky-bucket rate limiter for async code."""
|
||||||
|
|
||||||
|
def __init__(self, rate: float) -> None:
|
||||||
|
"""Args:
|
||||||
|
rate: Maximum requests per second.
|
||||||
|
"""
|
||||||
|
self._rate = rate
|
||||||
|
self._interval = 1.0 / rate
|
||||||
|
self._last = 0.0
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def acquire(self) -> None:
|
||||||
|
async with self._lock:
|
||||||
|
now = asyncio.get_event_loop().time()
|
||||||
|
wait = self._last + self._interval - now
|
||||||
|
if wait > 0:
|
||||||
|
await asyncio.sleep(wait)
|
||||||
|
self._last = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
|
|
||||||
|
class KISBroker:
|
||||||
|
"""Async client for KIS Open API with automatic token management."""
|
||||||
|
|
||||||
|
def __init__(self, settings: Settings) -> None:
|
||||||
|
self._settings = settings
|
||||||
|
self._base_url = settings.KIS_BASE_URL
|
||||||
|
self._app_key = settings.KIS_APP_KEY
|
||||||
|
self._app_secret = settings.KIS_APP_SECRET
|
||||||
|
self._account_no = settings.account_number
|
||||||
|
self._product_cd = settings.account_product_code
|
||||||
|
|
||||||
|
self._session: aiohttp.ClientSession | None = None
|
||||||
|
self._access_token: str | None = None
|
||||||
|
self._token_expires_at: float = 0.0
|
||||||
|
self._rate_limiter = LeakyBucket(settings.RATE_LIMIT_RPS)
|
||||||
|
|
||||||
|
def _get_session(self) -> aiohttp.ClientSession:
|
||||||
|
if self._session is None or self._session.closed:
|
||||||
|
timeout = aiohttp.ClientTimeout(total=10)
|
||||||
|
self._session = aiohttp.ClientSession(timeout=timeout)
|
||||||
|
return self._session
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
if self._session and not self._session.closed:
|
||||||
|
await self._session.close()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Token Management
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _ensure_token(self) -> str:
|
||||||
|
"""Return a valid access token, refreshing if expired."""
|
||||||
|
now = asyncio.get_event_loop().time()
|
||||||
|
if self._access_token and now < self._token_expires_at:
|
||||||
|
return self._access_token
|
||||||
|
|
||||||
|
logger.info("Refreshing KIS access token")
|
||||||
|
session = self._get_session()
|
||||||
|
url = f"{self._base_url}/oauth2/tokenP"
|
||||||
|
body = {
|
||||||
|
"grant_type": "client_credentials",
|
||||||
|
"appkey": self._app_key,
|
||||||
|
"appsecret": self._app_secret,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with session.post(url, json=body) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
text = await resp.text()
|
||||||
|
raise ConnectionError(f"Token refresh failed ({resp.status}): {text}")
|
||||||
|
data = await resp.json()
|
||||||
|
|
||||||
|
self._access_token = data["access_token"]
|
||||||
|
self._token_expires_at = now + data.get("expires_in", 86400) - 60 # 1-min buffer
|
||||||
|
logger.info("Token refreshed successfully")
|
||||||
|
return self._access_token
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Hash Key (required for POST bodies)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _get_hash_key(self, body: dict[str, Any]) -> str:
|
||||||
|
"""Request a hash key from KIS for POST request body signing."""
|
||||||
|
session = self._get_session()
|
||||||
|
url = f"{self._base_url}/uapi/hashkey"
|
||||||
|
headers = {
|
||||||
|
"content-Type": "application/json",
|
||||||
|
"appKey": self._app_key,
|
||||||
|
"appSecret": self._app_secret,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with session.post(url, json=body, headers=headers) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
text = await resp.text()
|
||||||
|
raise ConnectionError(f"Hash key request failed ({resp.status}): {text}")
|
||||||
|
data = await resp.json()
|
||||||
|
|
||||||
|
return data["HASH"]
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Common Headers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _auth_headers(self, tr_id: str) -> dict[str, str]:
|
||||||
|
token = await self._ensure_token()
|
||||||
|
return {
|
||||||
|
"content-type": "application/json; charset=utf-8",
|
||||||
|
"authorization": f"Bearer {token}",
|
||||||
|
"appkey": self._app_key,
|
||||||
|
"appsecret": self._app_secret,
|
||||||
|
"tr_id": tr_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# API Methods
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def get_orderbook(self, stock_code: str) -> dict[str, Any]:
|
||||||
|
"""Fetch the current orderbook for a given stock code."""
|
||||||
|
await self._rate_limiter.acquire()
|
||||||
|
session = self._get_session()
|
||||||
|
|
||||||
|
headers = await self._auth_headers("FHKST01010200")
|
||||||
|
params = {
|
||||||
|
"FID_COND_MRKT_DIV_CODE": "J",
|
||||||
|
"FID_INPUT_ISCD": stock_code,
|
||||||
|
}
|
||||||
|
url = f"{self._base_url}/uapi/domestic-stock/v1/quotations/inquire-asking-price-exp-ccn"
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with session.get(url, headers=headers, params=params) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
text = await resp.text()
|
||||||
|
raise ConnectionError(
|
||||||
|
f"get_orderbook failed ({resp.status}): {text}"
|
||||||
|
)
|
||||||
|
return await resp.json()
|
||||||
|
except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
|
||||||
|
raise ConnectionError(f"Network error fetching orderbook: {exc}") from exc
|
||||||
|
|
||||||
|
async def get_balance(self) -> dict[str, Any]:
|
||||||
|
"""Fetch current account balance and holdings."""
|
||||||
|
await self._rate_limiter.acquire()
|
||||||
|
session = self._get_session()
|
||||||
|
|
||||||
|
headers = await self._auth_headers("VTTC8434R") # 모의투자 잔고조회
|
||||||
|
params = {
|
||||||
|
"CANO": self._account_no,
|
||||||
|
"ACNT_PRDT_CD": self._product_cd,
|
||||||
|
"AFHR_FLPR_YN": "N",
|
||||||
|
"OFL_YN": "",
|
||||||
|
"INQR_DVSN": "02",
|
||||||
|
"UNPR_DVSN": "01",
|
||||||
|
"FUND_STTL_ICLD_YN": "N",
|
||||||
|
"FNCG_AMT_AUTO_RDPT_YN": "N",
|
||||||
|
"PRCS_DVSN": "01",
|
||||||
|
"CTX_AREA_FK100": "",
|
||||||
|
"CTX_AREA_NK100": "",
|
||||||
|
}
|
||||||
|
url = f"{self._base_url}/uapi/domestic-stock/v1/trading/inquire-balance"
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with session.get(url, headers=headers, params=params) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
text = await resp.text()
|
||||||
|
raise ConnectionError(
|
||||||
|
f"get_balance failed ({resp.status}): {text}"
|
||||||
|
)
|
||||||
|
return await resp.json()
|
||||||
|
except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
|
||||||
|
raise ConnectionError(f"Network error fetching balance: {exc}") from exc
|
||||||
|
|
||||||
|
async def send_order(
|
||||||
|
self,
|
||||||
|
stock_code: str,
|
||||||
|
order_type: str, # "BUY" or "SELL"
|
||||||
|
quantity: int,
|
||||||
|
price: int = 0,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Submit a buy or sell order.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stock_code: 6-digit stock code.
|
||||||
|
order_type: "BUY" or "SELL".
|
||||||
|
quantity: Number of shares.
|
||||||
|
price: Order price (0 for market order).
|
||||||
|
"""
|
||||||
|
await self._rate_limiter.acquire()
|
||||||
|
session = self._get_session()
|
||||||
|
|
||||||
|
tr_id = "VTTC0802U" if order_type == "BUY" else "VTTC0801U"
|
||||||
|
body = {
|
||||||
|
"CANO": self._account_no,
|
||||||
|
"ACNT_PRDT_CD": self._product_cd,
|
||||||
|
"PDNO": stock_code,
|
||||||
|
"ORD_DVSN": "01" if price > 0 else "06", # 01=지정가, 06=시장가
|
||||||
|
"ORD_QTY": str(quantity),
|
||||||
|
"ORD_UNPR": str(price),
|
||||||
|
}
|
||||||
|
|
||||||
|
hash_key = await self._get_hash_key(body)
|
||||||
|
headers = await self._auth_headers(tr_id)
|
||||||
|
headers["hashkey"] = hash_key
|
||||||
|
|
||||||
|
url = f"{self._base_url}/uapi/domestic-stock/v1/trading/order-cash"
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with session.post(url, headers=headers, json=body) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
text = await resp.text()
|
||||||
|
raise ConnectionError(
|
||||||
|
f"send_order failed ({resp.status}): {text}"
|
||||||
|
)
|
||||||
|
data = await resp.json()
|
||||||
|
logger.info(
|
||||||
|
"Order submitted",
|
||||||
|
extra={
|
||||||
|
"stock_code": stock_code,
|
||||||
|
"action": order_type,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
|
||||||
|
raise ConnectionError(f"Network error sending order: {exc}") from exc
|
||||||
44
src/config.py
Normal file
44
src/config.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
"""Strictly typed configuration loaded from environment variables."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
"""Application settings — loaded from .env or environment variables."""
|
||||||
|
|
||||||
|
# KIS Open API
|
||||||
|
KIS_APP_KEY: str
|
||||||
|
KIS_APP_SECRET: str
|
||||||
|
KIS_ACCOUNT_NO: str # format: "XXXXXXXX-XX"
|
||||||
|
KIS_BASE_URL: str = "https://openapivts.koreainvestment.com:9443"
|
||||||
|
|
||||||
|
# Google Gemini
|
||||||
|
GEMINI_API_KEY: str
|
||||||
|
GEMINI_MODEL: str = "gemini-pro"
|
||||||
|
|
||||||
|
# Risk Management
|
||||||
|
CIRCUIT_BREAKER_PCT: float = Field(default=-3.0, le=0.0)
|
||||||
|
FAT_FINGER_PCT: float = Field(default=30.0, gt=0.0, le=100.0)
|
||||||
|
CONFIDENCE_THRESHOLD: int = Field(default=80, ge=0, le=100)
|
||||||
|
|
||||||
|
# Database
|
||||||
|
DB_PATH: str = "data/trade_logs.db"
|
||||||
|
|
||||||
|
# Rate Limiting (requests per second for KIS API)
|
||||||
|
RATE_LIMIT_RPS: float = 10.0
|
||||||
|
|
||||||
|
# Trading mode
|
||||||
|
MODE: str = Field(default="paper", pattern="^(paper|live)$")
|
||||||
|
|
||||||
|
model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def account_number(self) -> str:
|
||||||
|
return self.KIS_ACCOUNT_NO.split("-")[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def account_product_code(self) -> str:
|
||||||
|
return self.KIS_ACCOUNT_NO.split("-")[1]
|
||||||
0
src/core/__init__.py
Normal file
0
src/core/__init__.py
Normal file
84
src/core/risk_manager.py
Normal file
84
src/core/risk_manager.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
"""Risk management — the Shield that protects the portfolio.
|
||||||
|
|
||||||
|
This module is READ-ONLY by policy (see docs/agents.md).
|
||||||
|
Changes require human approval and two passing test suites.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from src.config import Settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CircuitBreakerTripped(SystemExit):
|
||||||
|
"""Raised when daily P&L loss exceeds the allowed threshold."""
|
||||||
|
|
||||||
|
def __init__(self, pnl_pct: float, threshold: float) -> None:
|
||||||
|
self.pnl_pct = pnl_pct
|
||||||
|
self.threshold = threshold
|
||||||
|
super().__init__(
|
||||||
|
f"CIRCUIT BREAKER: Daily P&L {pnl_pct:.2f}% exceeded "
|
||||||
|
f"threshold {threshold:.2f}%. All trading halted."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FatFingerRejected(Exception):
|
||||||
|
"""Raised when an order exceeds the maximum allowed proportion of cash."""
|
||||||
|
|
||||||
|
def __init__(self, order_amount: float, total_cash: float, max_pct: float) -> None:
|
||||||
|
self.order_amount = order_amount
|
||||||
|
self.total_cash = total_cash
|
||||||
|
self.max_pct = max_pct
|
||||||
|
ratio = (order_amount / total_cash * 100) if total_cash > 0 else float("inf")
|
||||||
|
super().__init__(
|
||||||
|
f"FAT FINGER: Order {order_amount:,.0f} is {ratio:.1f}% of "
|
||||||
|
f"cash {total_cash:,.0f} (max allowed: {max_pct:.1f}%)."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RiskManager:
|
||||||
|
"""Pre-order risk gate that enforces circuit breaker and fat-finger checks."""
|
||||||
|
|
||||||
|
def __init__(self, settings: Settings) -> None:
|
||||||
|
self._cb_threshold = settings.CIRCUIT_BREAKER_PCT
|
||||||
|
self._ff_max_pct = settings.FAT_FINGER_PCT
|
||||||
|
|
||||||
|
def check_circuit_breaker(self, current_pnl_pct: float) -> None:
|
||||||
|
"""Halt trading if daily loss exceeds the threshold.
|
||||||
|
|
||||||
|
The threshold is inclusive: exactly -3.0% is allowed, but -3.01% is not.
|
||||||
|
"""
|
||||||
|
if current_pnl_pct < self._cb_threshold:
|
||||||
|
logger.critical(
|
||||||
|
"Circuit breaker tripped",
|
||||||
|
extra={"pnl_pct": current_pnl_pct},
|
||||||
|
)
|
||||||
|
raise CircuitBreakerTripped(current_pnl_pct, self._cb_threshold)
|
||||||
|
|
||||||
|
def check_fat_finger(self, order_amount: float, total_cash: float) -> None:
|
||||||
|
"""Reject orders that exceed the maximum proportion of available cash."""
|
||||||
|
if total_cash <= 0:
|
||||||
|
raise FatFingerRejected(order_amount, total_cash, self._ff_max_pct)
|
||||||
|
|
||||||
|
ratio_pct = (order_amount / total_cash) * 100
|
||||||
|
if ratio_pct > self._ff_max_pct:
|
||||||
|
logger.warning(
|
||||||
|
"Fat finger check failed",
|
||||||
|
extra={"order_amount": order_amount},
|
||||||
|
)
|
||||||
|
raise FatFingerRejected(order_amount, total_cash, self._ff_max_pct)
|
||||||
|
|
||||||
|
def validate_order(
|
||||||
|
self,
|
||||||
|
current_pnl_pct: float,
|
||||||
|
order_amount: float,
|
||||||
|
total_cash: float,
|
||||||
|
) -> None:
|
||||||
|
"""Run all pre-order risk checks. Raises on failure."""
|
||||||
|
self.check_circuit_breaker(current_pnl_pct)
|
||||||
|
self.check_fat_finger(order_amount, total_cash)
|
||||||
|
logger.info("Order passed risk validation")
|
||||||
59
src/db.py
Normal file
59
src/db.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
"""Database layer for trade logging."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def init_db(db_path: str) -> sqlite3.Connection:
|
||||||
|
"""Initialize the trade logs database and return a connection."""
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS trades (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
timestamp TEXT NOT NULL,
|
||||||
|
stock_code TEXT NOT NULL,
|
||||||
|
action TEXT NOT NULL,
|
||||||
|
confidence INTEGER NOT NULL,
|
||||||
|
rationale TEXT,
|
||||||
|
quantity INTEGER,
|
||||||
|
price REAL,
|
||||||
|
pnl REAL DEFAULT 0.0
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
return conn
|
||||||
|
|
||||||
|
|
||||||
|
def log_trade(
|
||||||
|
conn: sqlite3.Connection,
|
||||||
|
stock_code: str,
|
||||||
|
action: str,
|
||||||
|
confidence: int,
|
||||||
|
rationale: str,
|
||||||
|
quantity: int = 0,
|
||||||
|
price: float = 0.0,
|
||||||
|
pnl: float = 0.0,
|
||||||
|
) -> None:
|
||||||
|
"""Insert a trade record into the database."""
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO trades (timestamp, stock_code, action, confidence, rationale, quantity, price, pnl)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
datetime.now(timezone.utc).isoformat(),
|
||||||
|
stock_code,
|
||||||
|
action,
|
||||||
|
confidence,
|
||||||
|
rationale,
|
||||||
|
quantity,
|
||||||
|
price,
|
||||||
|
pnl,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
0
src/evolution/__init__.py
Normal file
0
src/evolution/__init__.py
Normal file
229
src/evolution/optimizer.py
Normal file
229
src/evolution/optimizer.py
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
"""Evolution Engine — analyzes trade logs and generates new strategies.
|
||||||
|
|
||||||
|
This module:
|
||||||
|
1. Reads trade_logs.db to identify failing patterns
|
||||||
|
2. Asks Gemini to generate a new strategy class
|
||||||
|
3. Runs pytest on the generated file
|
||||||
|
4. Creates a simulated PR if tests pass
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sqlite3
|
||||||
|
import subprocess
|
||||||
|
import textwrap
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import google.generativeai as genai
|
||||||
|
|
||||||
|
from src.config import Settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
STRATEGIES_DIR = Path("src/strategies")
|
||||||
|
STRATEGY_TEMPLATE = textwrap.dedent("""\
|
||||||
|
\"\"\"Auto-generated strategy: {name}
|
||||||
|
|
||||||
|
Generated at: {timestamp}
|
||||||
|
Rationale: {rationale}
|
||||||
|
\"\"\"
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
from typing import Any
|
||||||
|
from src.strategies.base import BaseStrategy
|
||||||
|
|
||||||
|
|
||||||
|
class {class_name}(BaseStrategy):
|
||||||
|
\"\"\"Strategy: {name}\"\"\"
|
||||||
|
|
||||||
|
def evaluate(self, market_data: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
{body}
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
class EvolutionOptimizer:
|
||||||
|
"""Analyzes trade history and evolves trading strategies."""
|
||||||
|
|
||||||
|
def __init__(self, settings: Settings) -> None:
|
||||||
|
self._settings = settings
|
||||||
|
self._db_path = settings.DB_PATH
|
||||||
|
genai.configure(api_key=settings.GEMINI_API_KEY)
|
||||||
|
self._model = genai.GenerativeModel(settings.GEMINI_MODEL)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Analysis
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def analyze_failures(self, limit: int = 50) -> list[dict[str, Any]]:
|
||||||
|
"""Find trades where high confidence led to losses."""
|
||||||
|
conn = sqlite3.connect(self._db_path)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
try:
|
||||||
|
rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT stock_code, action, confidence, pnl, rationale, timestamp
|
||||||
|
FROM trades
|
||||||
|
WHERE confidence >= 80 AND pnl < 0
|
||||||
|
ORDER BY pnl ASC
|
||||||
|
LIMIT ?
|
||||||
|
""",
|
||||||
|
(limit,),
|
||||||
|
).fetchall()
|
||||||
|
return [dict(r) for r in rows]
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def get_performance_summary(self) -> dict[str, Any]:
|
||||||
|
"""Return aggregate performance metrics from trade logs."""
|
||||||
|
conn = sqlite3.connect(self._db_path)
|
||||||
|
try:
|
||||||
|
row = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT
|
||||||
|
COUNT(*) as total_trades,
|
||||||
|
SUM(CASE WHEN pnl > 0 THEN 1 ELSE 0 END) as wins,
|
||||||
|
SUM(CASE WHEN pnl < 0 THEN 1 ELSE 0 END) as losses,
|
||||||
|
COALESCE(AVG(pnl), 0) as avg_pnl,
|
||||||
|
COALESCE(SUM(pnl), 0) as total_pnl
|
||||||
|
FROM trades
|
||||||
|
"""
|
||||||
|
).fetchone()
|
||||||
|
return {
|
||||||
|
"total_trades": row[0],
|
||||||
|
"wins": row[1] or 0,
|
||||||
|
"losses": row[2] or 0,
|
||||||
|
"avg_pnl": round(row[3], 2),
|
||||||
|
"total_pnl": round(row[4], 2),
|
||||||
|
}
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Strategy Generation
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def generate_strategy(self, failures: list[dict[str, Any]]) -> Path | None:
|
||||||
|
"""Ask Gemini to generate a new strategy based on failure analysis.
|
||||||
|
|
||||||
|
Returns the path to the generated strategy file, or None on failure.
|
||||||
|
"""
|
||||||
|
prompt = (
|
||||||
|
"You are a quantitative trading strategy developer.\n"
|
||||||
|
"Analyze these failed trades and generate an improved strategy.\n\n"
|
||||||
|
f"Failed trades:\n{json.dumps(failures, indent=2, default=str)}\n\n"
|
||||||
|
"Generate a Python class that inherits from BaseStrategy.\n"
|
||||||
|
"The class must have an `evaluate(self, market_data: dict) -> dict` method.\n"
|
||||||
|
"The method must return a dict with keys: action, confidence, rationale.\n"
|
||||||
|
"Respond with ONLY the method body (Python code), no class definition.\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self._model.generate_content_async(prompt)
|
||||||
|
body = response.text.strip()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Failed to generate strategy: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Clean up code fences
|
||||||
|
if body.startswith("```"):
|
||||||
|
lines = body.split("\n")
|
||||||
|
body = "\n".join(lines[1:-1])
|
||||||
|
|
||||||
|
# Create strategy file
|
||||||
|
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
|
||||||
|
version = f"v{timestamp}"
|
||||||
|
class_name = f"Strategy_{version}"
|
||||||
|
file_name = f"{version}_evolved.py"
|
||||||
|
|
||||||
|
STRATEGIES_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
file_path = STRATEGIES_DIR / file_name
|
||||||
|
|
||||||
|
# Indent the body for the class method
|
||||||
|
indented_body = textwrap.indent(body, " ")
|
||||||
|
|
||||||
|
content = STRATEGY_TEMPLATE.format(
|
||||||
|
name=version,
|
||||||
|
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||||
|
rationale="Auto-evolved from failure analysis",
|
||||||
|
class_name=class_name,
|
||||||
|
body=indented_body.strip(),
|
||||||
|
)
|
||||||
|
|
||||||
|
file_path.write_text(content)
|
||||||
|
logger.info("Generated strategy file: %s", file_path)
|
||||||
|
return file_path
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Validation
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def validate_strategy(self, strategy_path: Path) -> bool:
|
||||||
|
"""Run pytest on the generated strategy. Returns True if all tests pass."""
|
||||||
|
logger.info("Validating strategy: %s", strategy_path)
|
||||||
|
result = subprocess.run(
|
||||||
|
["python", "-m", "pytest", "tests/", "-v", "--tb=short"],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
if result.returncode == 0:
|
||||||
|
logger.info("Strategy validation PASSED")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Strategy validation FAILED:\n%s", result.stdout + result.stderr
|
||||||
|
)
|
||||||
|
# Clean up failing strategy
|
||||||
|
strategy_path.unlink(missing_ok=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# PR Simulation
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def create_pr_simulation(self, strategy_path: Path) -> dict[str, str]:
|
||||||
|
"""Simulate creating a pull request for the new strategy."""
|
||||||
|
pr = {
|
||||||
|
"title": f"[Evolution] New strategy: {strategy_path.stem}",
|
||||||
|
"branch": f"evolution/{strategy_path.stem}",
|
||||||
|
"body": (
|
||||||
|
f"Auto-generated strategy from evolution engine.\n"
|
||||||
|
f"File: {strategy_path}\n"
|
||||||
|
f"All tests passed."
|
||||||
|
),
|
||||||
|
"status": "ready_for_review",
|
||||||
|
}
|
||||||
|
logger.info("PR simulation created: %s", pr["title"])
|
||||||
|
return pr
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Full Pipeline
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def evolve(self) -> dict[str, Any] | None:
|
||||||
|
"""Run the full evolution pipeline.
|
||||||
|
|
||||||
|
1. Analyze failures
|
||||||
|
2. Generate new strategy
|
||||||
|
3. Validate with tests
|
||||||
|
4. Create PR simulation
|
||||||
|
|
||||||
|
Returns PR info on success, None on failure.
|
||||||
|
"""
|
||||||
|
failures = self.analyze_failures()
|
||||||
|
if not failures:
|
||||||
|
logger.info("No failure patterns found — skipping evolution")
|
||||||
|
return None
|
||||||
|
|
||||||
|
strategy_path = await self.generate_strategy(failures)
|
||||||
|
if strategy_path is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not self.validate_strategy(strategy_path):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self.create_pr_simulation(strategy_path)
|
||||||
42
src/logging_config.py
Normal file
42
src/logging_config.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
"""JSON-formatted structured logging for machine readability."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
class JSONFormatter(logging.Formatter):
|
||||||
|
"""Emit log records as single-line JSON objects."""
|
||||||
|
|
||||||
|
def format(self, record: logging.LogRecord) -> str:
|
||||||
|
log_entry: dict[str, Any] = {
|
||||||
|
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||||
|
"level": record.levelname,
|
||||||
|
"logger": record.name,
|
||||||
|
"message": record.getMessage(),
|
||||||
|
}
|
||||||
|
if record.exc_info and record.exc_info[1]:
|
||||||
|
log_entry["exception"] = self.formatException(record.exc_info)
|
||||||
|
# Merge any extra fields attached to the record
|
||||||
|
for key in ("stock_code", "action", "confidence", "pnl_pct", "order_amount"):
|
||||||
|
value = getattr(record, key, None)
|
||||||
|
if value is not None:
|
||||||
|
log_entry[key] = value
|
||||||
|
return json.dumps(log_entry, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(level: int = logging.INFO) -> None:
|
||||||
|
"""Configure the root logger with JSON output to stdout."""
|
||||||
|
handler = logging.StreamHandler(sys.stdout)
|
||||||
|
handler.setFormatter(JSONFormatter())
|
||||||
|
|
||||||
|
root = logging.getLogger()
|
||||||
|
root.setLevel(level)
|
||||||
|
# Avoid duplicate handlers on repeated calls
|
||||||
|
root.handlers.clear()
|
||||||
|
root.addHandler(handler)
|
||||||
171
src/main.py
Normal file
171
src/main.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
"""The Ouroboros — main trading loop.
|
||||||
|
|
||||||
|
Orchestrates the broker, brain, and risk manager into a continuous
|
||||||
|
trading cycle with configurable intervals.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from src.brain.gemini_client import GeminiClient
|
||||||
|
from src.broker.kis_api import KISBroker
|
||||||
|
from src.config import Settings
|
||||||
|
from src.core.risk_manager import CircuitBreakerTripped, RiskManager
|
||||||
|
from src.db import init_db, log_trade
|
||||||
|
from src.logging_config import setup_logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Target stock codes to monitor
|
||||||
|
WATCHLIST = ["005930", "000660", "035420"] # Samsung, SK Hynix, NAVER
|
||||||
|
|
||||||
|
TRADE_INTERVAL_SECONDS = 60
|
||||||
|
|
||||||
|
|
||||||
|
async def trading_cycle(
|
||||||
|
broker: KISBroker,
|
||||||
|
brain: GeminiClient,
|
||||||
|
risk: RiskManager,
|
||||||
|
db_conn: Any,
|
||||||
|
stock_code: str,
|
||||||
|
) -> None:
|
||||||
|
"""Execute one trading cycle for a single stock."""
|
||||||
|
# 1. Fetch market data
|
||||||
|
orderbook = await broker.get_orderbook(stock_code)
|
||||||
|
balance_data = await broker.get_balance()
|
||||||
|
|
||||||
|
output2 = balance_data.get("output2", [{}])
|
||||||
|
total_eval = float(output2[0].get("tot_evlu_amt", "0")) if output2 else 0
|
||||||
|
total_cash = float(
|
||||||
|
balance_data.get("output2", [{}])[0].get("dnca_tot_amt", "0")
|
||||||
|
if output2
|
||||||
|
else "0"
|
||||||
|
)
|
||||||
|
purchase_total = float(output2[0].get("pchs_amt_smtl_amt", "0")) if output2 else 0
|
||||||
|
|
||||||
|
# Calculate daily P&L %
|
||||||
|
pnl_pct = ((total_eval - purchase_total) / purchase_total * 100) if purchase_total > 0 else 0.0
|
||||||
|
|
||||||
|
current_price = float(
|
||||||
|
orderbook.get("output1", {}).get("stck_prpr", "0")
|
||||||
|
)
|
||||||
|
|
||||||
|
market_data = {
|
||||||
|
"stock_code": stock_code,
|
||||||
|
"current_price": current_price,
|
||||||
|
"orderbook": orderbook.get("output1", {}),
|
||||||
|
"foreigner_net": float(
|
||||||
|
orderbook.get("output1", {}).get("frgn_ntby_qty", "0")
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 2. Ask the brain for a decision
|
||||||
|
decision = await brain.decide(market_data)
|
||||||
|
logger.info(
|
||||||
|
"Decision for %s: %s (confidence=%d)",
|
||||||
|
stock_code,
|
||||||
|
decision.action,
|
||||||
|
decision.confidence,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Execute if actionable
|
||||||
|
if decision.action in ("BUY", "SELL"):
|
||||||
|
# Determine order size (simplified: 1 lot)
|
||||||
|
quantity = 1
|
||||||
|
order_amount = current_price * quantity
|
||||||
|
|
||||||
|
# 4. Risk check BEFORE order
|
||||||
|
risk.validate_order(
|
||||||
|
current_pnl_pct=pnl_pct,
|
||||||
|
order_amount=order_amount,
|
||||||
|
total_cash=total_cash,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. Send order
|
||||||
|
result = await broker.send_order(
|
||||||
|
stock_code=stock_code,
|
||||||
|
order_type=decision.action,
|
||||||
|
quantity=quantity,
|
||||||
|
price=0, # market order
|
||||||
|
)
|
||||||
|
logger.info("Order result: %s", result.get("msg1", "OK"))
|
||||||
|
|
||||||
|
# 6. Log trade
|
||||||
|
log_trade(
|
||||||
|
conn=db_conn,
|
||||||
|
stock_code=stock_code,
|
||||||
|
action=decision.action,
|
||||||
|
confidence=decision.confidence,
|
||||||
|
rationale=decision.rationale,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def run(settings: Settings) -> None:
|
||||||
|
"""Main async loop — iterate over watchlist on a timer."""
|
||||||
|
broker = KISBroker(settings)
|
||||||
|
brain = GeminiClient(settings)
|
||||||
|
risk = RiskManager(settings)
|
||||||
|
db_conn = init_db(settings.DB_PATH)
|
||||||
|
|
||||||
|
shutdown = asyncio.Event()
|
||||||
|
|
||||||
|
def _signal_handler() -> None:
|
||||||
|
logger.info("Shutdown signal received")
|
||||||
|
shutdown.set()
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||||
|
loop.add_signal_handler(sig, _signal_handler)
|
||||||
|
|
||||||
|
logger.info("The Ouroboros is alive. Mode: %s", settings.MODE)
|
||||||
|
logger.info("Watchlist: %s", WATCHLIST)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while not shutdown.is_set():
|
||||||
|
for code in WATCHLIST:
|
||||||
|
if shutdown.is_set():
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
await trading_cycle(broker, brain, risk, db_conn, code)
|
||||||
|
except CircuitBreakerTripped:
|
||||||
|
logger.critical("Circuit breaker tripped — shutting down")
|
||||||
|
raise
|
||||||
|
except ConnectionError as exc:
|
||||||
|
logger.error("Connection error for %s: %s", code, exc)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Unexpected error for %s: %s", code, exc)
|
||||||
|
|
||||||
|
# Wait for next cycle or shutdown
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(shutdown.wait(), timeout=TRADE_INTERVAL_SECONDS)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass # Normal — timeout means it's time for next cycle
|
||||||
|
finally:
|
||||||
|
await broker.close()
|
||||||
|
db_conn.close()
|
||||||
|
logger.info("The Ouroboros rests.")
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(description="The Ouroboros Trading Agent")
|
||||||
|
parser.add_argument(
|
||||||
|
"--mode",
|
||||||
|
choices=["paper", "live"],
|
||||||
|
default="paper",
|
||||||
|
help="Trading mode (default: paper)",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
setup_logging()
|
||||||
|
settings = Settings(MODE=args.mode) # type: ignore[call-arg]
|
||||||
|
asyncio.run(run(settings))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
0
src/strategies/__init__.py
Normal file
0
src/strategies/__init__.py
Normal file
19
src/strategies/base.py
Normal file
19
src/strategies/base.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
"""Base class for all trading strategies."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class BaseStrategy(ABC):
|
||||||
|
"""All strategies must inherit from this class."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def evaluate(self, market_data: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Evaluate market data and return a trade decision.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with keys: action ("BUY"|"SELL"|"HOLD"), confidence (int), rationale (str)
|
||||||
|
"""
|
||||||
|
...
|
||||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
23
tests/conftest.py
Normal file
23
tests/conftest.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
"""Shared test fixtures for The Ouroboros test suite."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.config import Settings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def settings() -> Settings:
|
||||||
|
"""Return a Settings instance with safe test defaults."""
|
||||||
|
return Settings(
|
||||||
|
KIS_APP_KEY="test_app_key",
|
||||||
|
KIS_APP_SECRET="test_app_secret",
|
||||||
|
KIS_ACCOUNT_NO="12345678-01",
|
||||||
|
KIS_BASE_URL="https://openapivts.koreainvestment.com:9443",
|
||||||
|
GEMINI_API_KEY="test_gemini_key",
|
||||||
|
CIRCUIT_BREAKER_PCT=-3.0,
|
||||||
|
FAT_FINGER_PCT=30.0,
|
||||||
|
CONFIDENCE_THRESHOLD=80,
|
||||||
|
DB_PATH=":memory:",
|
||||||
|
)
|
||||||
159
tests/test_brain.py
Normal file
159
tests/test_brain.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
"""TDD tests for brain/gemini_client.py — written BEFORE implementation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.brain.gemini_client import GeminiClient, TradeDecision
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Response Parsing
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestResponseParsing:
|
||||||
|
"""Gemini responses must be parsed into validated TradeDecision objects."""
|
||||||
|
|
||||||
|
def test_valid_buy_response(self, settings):
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
raw = '{"action": "BUY", "confidence": 90, "rationale": "Strong momentum"}'
|
||||||
|
decision = client.parse_response(raw)
|
||||||
|
assert decision.action == "BUY"
|
||||||
|
assert decision.confidence == 90
|
||||||
|
assert decision.rationale == "Strong momentum"
|
||||||
|
|
||||||
|
def test_valid_sell_response(self, settings):
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
raw = '{"action": "SELL", "confidence": 85, "rationale": "Overbought RSI"}'
|
||||||
|
decision = client.parse_response(raw)
|
||||||
|
assert decision.action == "SELL"
|
||||||
|
|
||||||
|
def test_valid_hold_response(self, settings):
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
raw = '{"action": "HOLD", "confidence": 95, "rationale": "Sideways market"}'
|
||||||
|
decision = client.parse_response(raw)
|
||||||
|
assert decision.action == "HOLD"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Confidence Threshold Enforcement
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfidenceThreshold:
|
||||||
|
"""If confidence < 80, the action MUST be forced to HOLD."""
|
||||||
|
|
||||||
|
def test_low_confidence_buy_becomes_hold(self, settings):
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
raw = '{"action": "BUY", "confidence": 65, "rationale": "Weak signal"}'
|
||||||
|
decision = client.parse_response(raw)
|
||||||
|
assert decision.action == "HOLD"
|
||||||
|
assert decision.confidence == 65
|
||||||
|
|
||||||
|
def test_low_confidence_sell_becomes_hold(self, settings):
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
raw = '{"action": "SELL", "confidence": 79, "rationale": "Uncertain"}'
|
||||||
|
decision = client.parse_response(raw)
|
||||||
|
assert decision.action == "HOLD"
|
||||||
|
|
||||||
|
def test_exactly_threshold_is_allowed(self, settings):
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
raw = '{"action": "BUY", "confidence": 80, "rationale": "Just enough"}'
|
||||||
|
decision = client.parse_response(raw)
|
||||||
|
assert decision.action == "BUY"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Malformed JSON Handling
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMalformedJsonHandling:
|
||||||
|
"""Gemini may return garbage — the parser must not crash."""
|
||||||
|
|
||||||
|
def test_empty_string_returns_hold(self, settings):
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
decision = client.parse_response("")
|
||||||
|
assert decision.action == "HOLD"
|
||||||
|
assert decision.confidence == 0
|
||||||
|
|
||||||
|
def test_plain_text_returns_hold(self, settings):
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
decision = client.parse_response("I think you should buy Samsung stock")
|
||||||
|
assert decision.action == "HOLD"
|
||||||
|
assert decision.confidence == 0
|
||||||
|
|
||||||
|
def test_partial_json_returns_hold(self, settings):
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
decision = client.parse_response('{"action": "BUY", "confidence":')
|
||||||
|
assert decision.action == "HOLD"
|
||||||
|
assert decision.confidence == 0
|
||||||
|
|
||||||
|
def test_json_with_missing_fields_returns_hold(self, settings):
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
decision = client.parse_response('{"action": "BUY"}')
|
||||||
|
assert decision.action == "HOLD"
|
||||||
|
assert decision.confidence == 0
|
||||||
|
|
||||||
|
def test_json_with_invalid_action_returns_hold(self, settings):
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
decision = client.parse_response(
|
||||||
|
'{"action": "YOLO", "confidence": 99, "rationale": "moon"}'
|
||||||
|
)
|
||||||
|
assert decision.action == "HOLD"
|
||||||
|
assert decision.confidence == 0
|
||||||
|
|
||||||
|
def test_json_wrapped_in_markdown_code_block(self, settings):
|
||||||
|
"""Gemini often wraps JSON in ```json ... ``` blocks."""
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
raw = '```json\n{"action": "BUY", "confidence": 92, "rationale": "Good"}\n```'
|
||||||
|
decision = client.parse_response(raw)
|
||||||
|
assert decision.action == "BUY"
|
||||||
|
assert decision.confidence == 92
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Prompt Construction
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptConstruction:
|
||||||
|
"""The prompt sent to Gemini must include all required market data."""
|
||||||
|
|
||||||
|
def test_prompt_contains_stock_code(self, settings):
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
market_data = {
|
||||||
|
"stock_code": "005930",
|
||||||
|
"current_price": 72000,
|
||||||
|
"orderbook": {"asks": [], "bids": []},
|
||||||
|
"foreigner_net": -50000,
|
||||||
|
}
|
||||||
|
prompt = client.build_prompt(market_data)
|
||||||
|
assert "005930" in prompt
|
||||||
|
|
||||||
|
def test_prompt_contains_price(self, settings):
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
market_data = {
|
||||||
|
"stock_code": "005930",
|
||||||
|
"current_price": 72000,
|
||||||
|
"orderbook": {"asks": [], "bids": []},
|
||||||
|
"foreigner_net": -50000,
|
||||||
|
}
|
||||||
|
prompt = client.build_prompt(market_data)
|
||||||
|
assert "72000" in prompt
|
||||||
|
|
||||||
|
def test_prompt_enforces_json_output_format(self, settings):
|
||||||
|
client = GeminiClient(settings)
|
||||||
|
market_data = {
|
||||||
|
"stock_code": "005930",
|
||||||
|
"current_price": 72000,
|
||||||
|
"orderbook": {"asks": [], "bids": []},
|
||||||
|
"foreigner_net": 0,
|
||||||
|
}
|
||||||
|
prompt = client.build_prompt(market_data)
|
||||||
|
assert "JSON" in prompt
|
||||||
|
assert "action" in prompt
|
||||||
|
assert "confidence" in prompt
|
||||||
140
tests/test_broker.py
Normal file
140
tests/test_broker.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
"""TDD tests for broker/kis_api.py — written BEFORE implementation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.broker.kis_api import KISBroker
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Token Management
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenManagement:
|
||||||
|
"""Access token must be auto-refreshed and cached."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fetches_token_on_first_call(self, settings):
|
||||||
|
broker = KISBroker(settings)
|
||||||
|
mock_resp = AsyncMock()
|
||||||
|
mock_resp.status = 200
|
||||||
|
mock_resp.json = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"access_token": "tok_abc123",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": 86400,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||||
|
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession.post", return_value=mock_resp):
|
||||||
|
token = await broker._ensure_token()
|
||||||
|
assert token == "tok_abc123"
|
||||||
|
|
||||||
|
await broker.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reuses_cached_token(self, settings):
|
||||||
|
broker = KISBroker(settings)
|
||||||
|
broker._access_token = "cached_token"
|
||||||
|
broker._token_expires_at = asyncio.get_event_loop().time() + 3600
|
||||||
|
|
||||||
|
token = await broker._ensure_token()
|
||||||
|
assert token == "cached_token"
|
||||||
|
|
||||||
|
await broker.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Network Error Handling
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestNetworkErrorHandling:
|
||||||
|
"""Broker must handle network timeouts and HTTP errors gracefully."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_timeout_raises_connection_error(self, settings):
|
||||||
|
broker = KISBroker(settings)
|
||||||
|
broker._access_token = "tok"
|
||||||
|
broker._token_expires_at = asyncio.get_event_loop().time() + 3600
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"aiohttp.ClientSession.get",
|
||||||
|
side_effect=asyncio.TimeoutError(),
|
||||||
|
):
|
||||||
|
with pytest.raises(ConnectionError):
|
||||||
|
await broker.get_orderbook("005930")
|
||||||
|
|
||||||
|
await broker.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_http_500_raises_connection_error(self, settings):
|
||||||
|
broker = KISBroker(settings)
|
||||||
|
broker._access_token = "tok"
|
||||||
|
broker._token_expires_at = asyncio.get_event_loop().time() + 3600
|
||||||
|
|
||||||
|
mock_resp = AsyncMock()
|
||||||
|
mock_resp.status = 500
|
||||||
|
mock_resp.text = AsyncMock(return_value="Internal Server Error")
|
||||||
|
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||||
|
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession.get", return_value=mock_resp):
|
||||||
|
with pytest.raises(ConnectionError):
|
||||||
|
await broker.get_orderbook("005930")
|
||||||
|
|
||||||
|
await broker.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Rate Limiter
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRateLimiter:
|
||||||
|
"""The leaky bucket rate limiter must throttle requests."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rate_limiter_does_not_block_under_limit(self, settings):
|
||||||
|
broker = KISBroker(settings)
|
||||||
|
# Should complete without blocking when under limit
|
||||||
|
await broker._rate_limiter.acquire()
|
||||||
|
await broker.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Hash Key Generation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestHashKey:
|
||||||
|
"""POST requests to KIS require a hash key."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generates_hash_key_for_post_body(self, settings):
|
||||||
|
broker = KISBroker(settings)
|
||||||
|
broker._access_token = "tok"
|
||||||
|
broker._token_expires_at = asyncio.get_event_loop().time() + 3600
|
||||||
|
|
||||||
|
body = {"CANO": "12345678", "ACNT_PRDT_CD": "01"}
|
||||||
|
|
||||||
|
mock_resp = AsyncMock()
|
||||||
|
mock_resp.status = 200
|
||||||
|
mock_resp.json = AsyncMock(return_value={"HASH": "abc123hash"})
|
||||||
|
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||||
|
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with patch("aiohttp.ClientSession.post", return_value=mock_resp):
|
||||||
|
hash_key = await broker._get_hash_key(body)
|
||||||
|
assert isinstance(hash_key, str)
|
||||||
|
assert len(hash_key) > 0
|
||||||
|
|
||||||
|
await broker.close()
|
||||||
132
tests/test_risk.py
Normal file
132
tests/test_risk.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
"""TDD tests for core/risk_manager.py — written BEFORE implementation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.core.risk_manager import (
|
||||||
|
CircuitBreakerTripped,
|
||||||
|
FatFingerRejected,
|
||||||
|
RiskManager,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Circuit Breaker Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCircuitBreaker:
|
||||||
|
"""The circuit breaker must halt all trading when daily loss exceeds the threshold."""
|
||||||
|
|
||||||
|
def test_allows_trading_when_pnl_is_positive(self, settings):
|
||||||
|
rm = RiskManager(settings)
|
||||||
|
# 2% gain — should be fine
|
||||||
|
rm.check_circuit_breaker(current_pnl_pct=2.0)
|
||||||
|
|
||||||
|
def test_allows_trading_at_zero_pnl(self, settings):
|
||||||
|
rm = RiskManager(settings)
|
||||||
|
rm.check_circuit_breaker(current_pnl_pct=0.0)
|
||||||
|
|
||||||
|
def test_allows_trading_at_exactly_threshold(self, settings):
|
||||||
|
rm = RiskManager(settings)
|
||||||
|
# Exactly -3.0% is ON the boundary — still allowed
|
||||||
|
rm.check_circuit_breaker(current_pnl_pct=-3.0)
|
||||||
|
|
||||||
|
def test_trips_when_loss_exceeds_threshold(self, settings):
|
||||||
|
rm = RiskManager(settings)
|
||||||
|
with pytest.raises(CircuitBreakerTripped):
|
||||||
|
rm.check_circuit_breaker(current_pnl_pct=-3.01)
|
||||||
|
|
||||||
|
def test_trips_at_large_loss(self, settings):
|
||||||
|
rm = RiskManager(settings)
|
||||||
|
with pytest.raises(CircuitBreakerTripped):
|
||||||
|
rm.check_circuit_breaker(current_pnl_pct=-10.0)
|
||||||
|
|
||||||
|
def test_custom_threshold(self):
|
||||||
|
"""A stricter threshold (-1.5%) should trip earlier."""
|
||||||
|
from src.config import Settings
|
||||||
|
|
||||||
|
strict = Settings(
|
||||||
|
KIS_APP_KEY="k",
|
||||||
|
KIS_APP_SECRET="s",
|
||||||
|
KIS_ACCOUNT_NO="00000000-00",
|
||||||
|
KIS_BASE_URL="https://example.com",
|
||||||
|
GEMINI_API_KEY="g",
|
||||||
|
CIRCUIT_BREAKER_PCT=-1.5,
|
||||||
|
FAT_FINGER_PCT=30.0,
|
||||||
|
CONFIDENCE_THRESHOLD=80,
|
||||||
|
DB_PATH=":memory:",
|
||||||
|
)
|
||||||
|
rm = RiskManager(strict)
|
||||||
|
with pytest.raises(CircuitBreakerTripped):
|
||||||
|
rm.check_circuit_breaker(current_pnl_pct=-1.51)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fat Finger Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestFatFingerCheck:
|
||||||
|
"""Orders exceeding 30% of total cash must be rejected."""
|
||||||
|
|
||||||
|
def test_allows_small_order(self, settings):
|
||||||
|
rm = RiskManager(settings)
|
||||||
|
# 10% of 10_000_000 = 1_000_000
|
||||||
|
rm.check_fat_finger(order_amount=1_000_000, total_cash=10_000_000)
|
||||||
|
|
||||||
|
def test_allows_order_at_exactly_threshold(self, settings):
|
||||||
|
rm = RiskManager(settings)
|
||||||
|
# Exactly 30% — allowed
|
||||||
|
rm.check_fat_finger(order_amount=3_000_000, total_cash=10_000_000)
|
||||||
|
|
||||||
|
def test_rejects_order_exceeding_threshold(self, settings):
|
||||||
|
rm = RiskManager(settings)
|
||||||
|
with pytest.raises(FatFingerRejected):
|
||||||
|
rm.check_fat_finger(order_amount=3_000_001, total_cash=10_000_000)
|
||||||
|
|
||||||
|
def test_rejects_massive_order(self, settings):
|
||||||
|
rm = RiskManager(settings)
|
||||||
|
with pytest.raises(FatFingerRejected):
|
||||||
|
rm.check_fat_finger(order_amount=9_000_000, total_cash=10_000_000)
|
||||||
|
|
||||||
|
def test_zero_cash_rejects_any_order(self, settings):
|
||||||
|
rm = RiskManager(settings)
|
||||||
|
with pytest.raises(FatFingerRejected):
|
||||||
|
rm.check_fat_finger(order_amount=1, total_cash=0)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pre-Order Validation (Integration of both checks)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPreOrderValidation:
|
||||||
|
"""validate_order must run BOTH checks before approving."""
|
||||||
|
|
||||||
|
def test_passes_when_both_checks_ok(self, settings):
|
||||||
|
rm = RiskManager(settings)
|
||||||
|
rm.validate_order(
|
||||||
|
current_pnl_pct=0.5,
|
||||||
|
order_amount=1_000_000,
|
||||||
|
total_cash=10_000_000,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_fails_on_circuit_breaker(self, settings):
|
||||||
|
rm = RiskManager(settings)
|
||||||
|
with pytest.raises(CircuitBreakerTripped):
|
||||||
|
rm.validate_order(
|
||||||
|
current_pnl_pct=-5.0,
|
||||||
|
order_amount=100,
|
||||||
|
total_cash=10_000_000,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_fails_on_fat_finger(self, settings):
|
||||||
|
rm = RiskManager(settings)
|
||||||
|
with pytest.raises(FatFingerRejected):
|
||||||
|
rm.validate_order(
|
||||||
|
current_pnl_pct=1.0,
|
||||||
|
order_amount=5_000_000,
|
||||||
|
total_cash=10_000_000,
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user