Add complete Ouroboros trading system with TDD test suite
Some checks failed
CI / test (push) Has been cancelled
Some checks failed
CI / test (push) Has been cancelled
Implement the full autonomous trading agent architecture: - KIS broker with async API, token refresh, leaky bucket rate limiter, and hash key signing - Gemini-powered decision engine with JSON parsing and confidence threshold enforcement - Risk manager with circuit breaker (-3% P&L) and fat finger protection (30% cap) - Evolution engine for self-improving strategy generation via failure analysis - 35 passing tests written TDD-first covering risk, broker, and brain modules - CI/CD pipeline, Docker multi-stage build, and AI agent context docs Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
23
.env.example
Normal file
23
.env.example
Normal file
@@ -0,0 +1,23 @@
|
||||
# Korea Investment Securities API
|
||||
KIS_APP_KEY=your_app_key_here
|
||||
KIS_APP_SECRET=your_app_secret_here
|
||||
KIS_ACCOUNT_NO=12345678-01
|
||||
KIS_BASE_URL=https://openapivts.koreainvestment.com:9443
|
||||
|
||||
# Google Gemini
|
||||
GEMINI_API_KEY=your_gemini_api_key_here
|
||||
GEMINI_MODEL=gemini-pro
|
||||
|
||||
# Risk Management
|
||||
CIRCUIT_BREAKER_PCT=-3.0
|
||||
FAT_FINGER_PCT=30.0
|
||||
CONFIDENCE_THRESHOLD=80
|
||||
|
||||
# Database
|
||||
DB_PATH=data/trade_logs.db
|
||||
|
||||
# Rate Limiting
|
||||
RATE_LIMIT_RPS=10.0
|
||||
|
||||
# Trading Mode (paper / live)
|
||||
MODE=paper
|
||||
39
.github/workflows/ci.yml
vendored
Normal file
39
.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python 3.11
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Install dependencies
|
||||
run: pip install ".[dev]"
|
||||
|
||||
- name: Lint
|
||||
run: ruff check src/ tests/
|
||||
|
||||
- name: Type check
|
||||
run: mypy src/ --strict
|
||||
continue-on-error: true
|
||||
|
||||
- name: Run tests with coverage
|
||||
run: pytest -v --cov=src --cov-report=term-missing --cov-fail-under=80
|
||||
|
||||
- name: Upload coverage
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: coverage-report
|
||||
path: htmlcov/
|
||||
44
Dockerfile
Normal file
44
Dockerfile
Normal file
@@ -0,0 +1,44 @@
|
||||
FROM python:3.11-slim AS base
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# System deps
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
gcc \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Python dependencies
|
||||
COPY pyproject.toml .
|
||||
RUN pip install --no-cache-dir ".[dev]"
|
||||
|
||||
# Copy source
|
||||
COPY src/ src/
|
||||
COPY tests/ tests/
|
||||
COPY docs/ docs/
|
||||
|
||||
# Create data directory
|
||||
RUN mkdir -p data
|
||||
|
||||
# Run tests as build validation
|
||||
RUN pytest -v --tb=short
|
||||
|
||||
# Production stage
|
||||
FROM python:3.11-slim AS production
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY pyproject.toml .
|
||||
RUN pip install --no-cache-dir .
|
||||
|
||||
COPY src/ src/
|
||||
RUN mkdir -p data
|
||||
|
||||
# Non-root user
|
||||
RUN useradd --create-home appuser
|
||||
USER appuser
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=5s --retries=3 \
|
||||
CMD python -c "import sys; sys.exit(0)"
|
||||
|
||||
ENTRYPOINT ["python", "-m", "src.main"]
|
||||
CMD ["--mode=paper"]
|
||||
30
docker-compose.yml
Normal file
30
docker-compose.yml
Normal file
@@ -0,0 +1,30 @@
|
||||
services:
|
||||
ouroboros:
|
||||
build:
|
||||
context: .
|
||||
target: production
|
||||
container_name: ouroboros
|
||||
restart: unless-stopped
|
||||
env_file:
|
||||
- .env
|
||||
volumes:
|
||||
- trade_data:/app/data
|
||||
command: ["--mode=paper"]
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "10m"
|
||||
max-file: "3"
|
||||
|
||||
# Run tests (one-shot)
|
||||
test:
|
||||
build:
|
||||
context: .
|
||||
target: base
|
||||
container_name: ouroboros-test
|
||||
command: ["pytest", "-v", "--cov=src", "--cov-report=term-missing"]
|
||||
profiles:
|
||||
- test
|
||||
|
||||
volumes:
|
||||
trade_data:
|
||||
58
docs/agents.md
Normal file
58
docs/agents.md
Normal file
@@ -0,0 +1,58 @@
|
||||
# The Ouroboros - Agent Persona Definition
|
||||
|
||||
## Role: The Guardian
|
||||
|
||||
You are **The Guardian**, the primary AI agent responsible for maintaining and evolving
|
||||
The Ouroboros trading system. Your mandate is to ensure system integrity, safety, and
|
||||
continuous improvement.
|
||||
|
||||
## Prime Directives
|
||||
|
||||
1. **NEVER disable, bypass, or weaken `core/risk_manager.py`.**
|
||||
The risk manager is the last line of defense against catastrophic loss.
|
||||
Any code change that reduces risk controls MUST be rejected.
|
||||
|
||||
2. **All code changes require a passing test.**
|
||||
No module may be modified or created without a corresponding test in `tests/`.
|
||||
Run `pytest -v --cov=src` before proposing any merge.
|
||||
|
||||
3. **Preserve the Circuit Breaker.**
|
||||
The daily P&L circuit breaker (-3.0% threshold) is non-negotiable.
|
||||
It may only be made *stricter*, never relaxed.
|
||||
|
||||
4. **Fat Finger Protection is sacred.**
|
||||
The 30% max-order-size rule must remain enforced at all times.
|
||||
|
||||
## Decision Framework
|
||||
|
||||
When modifying code, follow this priority order:
|
||||
|
||||
1. **Safety** - Will this change increase risk exposure? If yes, reject.
|
||||
2. **Correctness** - Is the logic provably correct? Verify with tests.
|
||||
3. **Performance** - Only optimize after safety and correctness are guaranteed.
|
||||
4. **Readability** - Code must be understandable by future agents and humans.
|
||||
|
||||
## File Ownership
|
||||
|
||||
| Module | Guardian Rule |
|
||||
|---|---|
|
||||
| `core/risk_manager.py` | READ-ONLY. Changes require human approval + 2 passing test suites. |
|
||||
| `broker/kis_api.py` | Rate limiter must never be removed. Token refresh must remain automatic. |
|
||||
| `brain/gemini_client.py` | Confidence < 80 MUST force HOLD. This rule cannot be weakened. |
|
||||
| `evolution/optimizer.py` | Generated strategies must pass ALL tests before activation. |
|
||||
| `strategies/*` | New strategies are welcome but must inherit `BaseStrategy`. |
|
||||
|
||||
## Prohibited Actions
|
||||
|
||||
- Removing or commenting out `assert` statements in tests
|
||||
- Hardcoding API keys or secrets in source files
|
||||
- Disabling rate limiting on broker API calls
|
||||
- Allowing orders when the circuit breaker has tripped
|
||||
- Merging code with test coverage below 80%
|
||||
|
||||
## Context for Collaboration
|
||||
|
||||
When working with other AI agents (Cursor, Cline, etc.):
|
||||
- Share this document as the system constitution
|
||||
- All agents must acknowledge these rules before making changes
|
||||
- Conflicts are resolved by defaulting to the *safer* option
|
||||
109
docs/skills.md
Normal file
109
docs/skills.md
Normal file
@@ -0,0 +1,109 @@
|
||||
# The Ouroboros - Available Skills & Tools
|
||||
|
||||
## Development Tools
|
||||
|
||||
### Run Tests
|
||||
```bash
|
||||
pytest -v --cov=src --cov-report=term-missing
|
||||
```
|
||||
Run the full test suite with coverage reporting. **Must pass before any merge.**
|
||||
|
||||
### Run Specific Test Module
|
||||
```bash
|
||||
pytest tests/test_risk.py -v
|
||||
pytest tests/test_broker.py -v
|
||||
pytest tests/test_brain.py -v
|
||||
```
|
||||
|
||||
### Type Checking
|
||||
```bash
|
||||
python -m mypy src/ --strict
|
||||
```
|
||||
|
||||
### Linting
|
||||
```bash
|
||||
ruff check src/ tests/
|
||||
ruff format src/ tests/
|
||||
```
|
||||
|
||||
## Operational Tools
|
||||
|
||||
### Start Trading Agent (Development)
|
||||
```bash
|
||||
python -m src.main --mode=paper
|
||||
```
|
||||
Runs the agent in paper-trading mode (no real orders).
|
||||
|
||||
### Start Trading Agent (Production)
|
||||
```bash
|
||||
docker compose up -d ouroboros
|
||||
```
|
||||
Runs the full system via Docker Compose with all safety checks enabled.
|
||||
|
||||
### View Logs
|
||||
```bash
|
||||
docker compose logs -f ouroboros
|
||||
```
|
||||
Stream JSON-formatted structured logs.
|
||||
|
||||
### Run Backtester
|
||||
```bash
|
||||
python -m src.evolution.optimizer --backtest --days=30
|
||||
```
|
||||
Analyze the last 30 days of trade logs and generate performance metrics.
|
||||
|
||||
## Evolution Tools
|
||||
|
||||
### Generate New Strategy
|
||||
```bash
|
||||
python -m src.evolution.optimizer --evolve
|
||||
```
|
||||
Triggers the evolution engine to:
|
||||
1. Analyze `trade_logs.db` for failing patterns
|
||||
2. Ask Gemini to generate a new strategy
|
||||
3. Run tests on the new strategy
|
||||
4. Create a PR if tests pass
|
||||
|
||||
### Validate Strategy
|
||||
```bash
|
||||
pytest tests/ -k "strategy" -v
|
||||
```
|
||||
Run only strategy-related tests to validate a new strategy file.
|
||||
|
||||
## Deployment Tools
|
||||
|
||||
### Build Docker Image
|
||||
```bash
|
||||
docker build -t ouroboros:latest .
|
||||
```
|
||||
|
||||
### Deploy with Docker Compose
|
||||
```bash
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
### Health Check
|
||||
```bash
|
||||
curl http://localhost:8080/health
|
||||
```
|
||||
|
||||
## Database Tools
|
||||
|
||||
### View Trade Logs
|
||||
```bash
|
||||
sqlite3 data/trade_logs.db "SELECT * FROM trades ORDER BY timestamp DESC LIMIT 20;"
|
||||
```
|
||||
|
||||
### Export Trade History
|
||||
```bash
|
||||
sqlite3 -header -csv data/trade_logs.db "SELECT * FROM trades;" > trades_export.csv
|
||||
```
|
||||
|
||||
## Safety Checklist (Pre-Deploy)
|
||||
|
||||
- [ ] `pytest -v --cov=src` passes with >= 80% coverage
|
||||
- [ ] `ruff check src/ tests/` reports no errors
|
||||
- [ ] `.env` file contains valid KIS and Gemini API keys
|
||||
- [ ] Circuit breaker threshold is set to -3.0% or stricter
|
||||
- [ ] Rate limiter is configured for KIS API limits
|
||||
- [ ] Docker health check endpoint responds 200
|
||||
40
pyproject.toml
Normal file
40
pyproject.toml
Normal file
@@ -0,0 +1,40 @@
|
||||
[project]
|
||||
name = "the-ouroboros"
|
||||
version = "0.1.0"
|
||||
description = "Evolutionary AI Trading Agent for KIS"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"aiohttp>=3.9,<4",
|
||||
"pydantic>=2.5,<3",
|
||||
"pydantic-settings>=2.1,<3",
|
||||
"google-generativeai>=0.8,<1",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.0,<9",
|
||||
"pytest-asyncio>=0.23,<1",
|
||||
"pytest-cov>=5.0,<6",
|
||||
"ruff>=0.5,<1",
|
||||
"mypy>=1.10,<2",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
ouroboros = "src.main:main"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
asyncio_mode = "auto"
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py311"
|
||||
line-length = 100
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "N", "W", "UP"]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.11"
|
||||
strict = true
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
0
src/__init__.py
Normal file
0
src/__init__.py
Normal file
0
src/brain/__init__.py
Normal file
0
src/brain/__init__.py
Normal file
152
src/brain/gemini_client.py
Normal file
152
src/brain/gemini_client.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""Decision engine powered by Google Gemini.
|
||||
|
||||
Constructs prompts from market data, calls Gemini, and parses structured
|
||||
JSON responses into validated TradeDecision objects.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import google.generativeai as genai
|
||||
|
||||
from src.config import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VALID_ACTIONS = {"BUY", "SELL", "HOLD"}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TradeDecision:
|
||||
"""Validated decision from the AI brain."""
|
||||
|
||||
action: str # "BUY" | "SELL" | "HOLD"
|
||||
confidence: int # 0-100
|
||||
rationale: str
|
||||
|
||||
|
||||
class GeminiClient:
|
||||
"""Wraps the Gemini API for trade decision-making."""
|
||||
|
||||
def __init__(self, settings: Settings) -> None:
|
||||
self._settings = settings
|
||||
self._confidence_threshold = settings.CONFIDENCE_THRESHOLD
|
||||
genai.configure(api_key=settings.GEMINI_API_KEY)
|
||||
self._model = genai.GenerativeModel(settings.GEMINI_MODEL)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Prompt Construction
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def build_prompt(self, market_data: dict[str, Any]) -> str:
|
||||
"""Build a structured prompt from market data.
|
||||
|
||||
The prompt instructs Gemini to return valid JSON with action,
|
||||
confidence, and rationale fields.
|
||||
"""
|
||||
return (
|
||||
"You are a professional Korean stock market trading analyst.\n"
|
||||
"Analyze the following market data and decide whether to BUY, SELL, or HOLD.\n\n"
|
||||
f"Stock Code: {market_data['stock_code']}\n"
|
||||
f"Current Price: {market_data['current_price']}\n"
|
||||
f"Orderbook: {json.dumps(market_data['orderbook'], ensure_ascii=False)}\n"
|
||||
f"Foreigner Net Buy/Sell: {market_data['foreigner_net']}\n\n"
|
||||
"You MUST respond with ONLY valid JSON in the following format:\n"
|
||||
'{"action": "BUY"|"SELL"|"HOLD", "confidence": <int 0-100>, "rationale": "<string>"}\n\n'
|
||||
"Rules:\n"
|
||||
"- action must be exactly one of: BUY, SELL, HOLD\n"
|
||||
"- confidence must be an integer from 0 to 100\n"
|
||||
"- rationale must explain your reasoning concisely\n"
|
||||
"- Do NOT wrap the JSON in markdown code blocks\n"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Response Parsing
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def parse_response(self, raw: str) -> TradeDecision:
|
||||
"""Parse a raw Gemini response into a TradeDecision.
|
||||
|
||||
Handles: valid JSON, JSON wrapped in markdown code blocks,
|
||||
malformed JSON, missing fields, and invalid action values.
|
||||
|
||||
On any failure, returns a safe HOLD with confidence 0.
|
||||
"""
|
||||
if not raw or not raw.strip():
|
||||
logger.warning("Empty response from Gemini — defaulting to HOLD")
|
||||
return TradeDecision(action="HOLD", confidence=0, rationale="Empty response")
|
||||
|
||||
# Strip markdown code fences if present
|
||||
cleaned = raw.strip()
|
||||
match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", cleaned, re.DOTALL)
|
||||
if match:
|
||||
cleaned = match.group(1).strip()
|
||||
|
||||
try:
|
||||
data = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Malformed JSON from Gemini — defaulting to HOLD")
|
||||
return TradeDecision(
|
||||
action="HOLD", confidence=0, rationale="Malformed JSON response"
|
||||
)
|
||||
|
||||
# Validate required fields
|
||||
if not all(k in data for k in ("action", "confidence", "rationale")):
|
||||
logger.warning("Missing fields in Gemini response — defaulting to HOLD")
|
||||
return TradeDecision(
|
||||
action="HOLD", confidence=0, rationale="Missing required fields"
|
||||
)
|
||||
|
||||
action = str(data["action"]).upper()
|
||||
if action not in VALID_ACTIONS:
|
||||
logger.warning("Invalid action '%s' from Gemini — defaulting to HOLD", action)
|
||||
return TradeDecision(
|
||||
action="HOLD", confidence=0, rationale=f"Invalid action: {action}"
|
||||
)
|
||||
|
||||
confidence = int(data["confidence"])
|
||||
rationale = str(data["rationale"])
|
||||
|
||||
# Enforce confidence threshold
|
||||
if confidence < self._confidence_threshold:
|
||||
logger.info(
|
||||
"Confidence %d < threshold %d — forcing HOLD",
|
||||
confidence,
|
||||
self._confidence_threshold,
|
||||
)
|
||||
action = "HOLD"
|
||||
|
||||
return TradeDecision(action=action, confidence=confidence, rationale=rationale)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# API Call
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def decide(self, market_data: dict[str, Any]) -> TradeDecision:
|
||||
"""Build prompt, call Gemini, and return a parsed decision."""
|
||||
prompt = self.build_prompt(market_data)
|
||||
logger.info("Requesting trade decision from Gemini")
|
||||
|
||||
try:
|
||||
response = await self._model.generate_content_async(prompt)
|
||||
raw = response.text
|
||||
except Exception as exc:
|
||||
logger.error("Gemini API error: %s", exc)
|
||||
return TradeDecision(
|
||||
action="HOLD", confidence=0, rationale=f"API error: {exc}"
|
||||
)
|
||||
|
||||
decision = self.parse_response(raw)
|
||||
logger.info(
|
||||
"Gemini decision",
|
||||
extra={
|
||||
"action": decision.action,
|
||||
"confidence": decision.confidence,
|
||||
},
|
||||
)
|
||||
return decision
|
||||
0
src/broker/__init__.py
Normal file
0
src/broker/__init__.py
Normal file
245
src/broker/kis_api.py
Normal file
245
src/broker/kis_api.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""Async wrapper for the Korea Investment Securities (KIS) Open API.
|
||||
|
||||
Handles token refresh, rate limiting (leaky bucket), and hash key generation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from src.config import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LeakyBucket:
|
||||
"""Simple leaky-bucket rate limiter for async code."""
|
||||
|
||||
def __init__(self, rate: float) -> None:
|
||||
"""Args:
|
||||
rate: Maximum requests per second.
|
||||
"""
|
||||
self._rate = rate
|
||||
self._interval = 1.0 / rate
|
||||
self._last = 0.0
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def acquire(self) -> None:
|
||||
async with self._lock:
|
||||
now = asyncio.get_event_loop().time()
|
||||
wait = self._last + self._interval - now
|
||||
if wait > 0:
|
||||
await asyncio.sleep(wait)
|
||||
self._last = asyncio.get_event_loop().time()
|
||||
|
||||
|
||||
class KISBroker:
|
||||
"""Async client for KIS Open API with automatic token management."""
|
||||
|
||||
def __init__(self, settings: Settings) -> None:
|
||||
self._settings = settings
|
||||
self._base_url = settings.KIS_BASE_URL
|
||||
self._app_key = settings.KIS_APP_KEY
|
||||
self._app_secret = settings.KIS_APP_SECRET
|
||||
self._account_no = settings.account_number
|
||||
self._product_cd = settings.account_product_code
|
||||
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._access_token: str | None = None
|
||||
self._token_expires_at: float = 0.0
|
||||
self._rate_limiter = LeakyBucket(settings.RATE_LIMIT_RPS)
|
||||
|
||||
def _get_session(self) -> aiohttp.ClientSession:
|
||||
if self._session is None or self._session.closed:
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
self._session = aiohttp.ClientSession(timeout=timeout)
|
||||
return self._session
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Token Management
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _ensure_token(self) -> str:
|
||||
"""Return a valid access token, refreshing if expired."""
|
||||
now = asyncio.get_event_loop().time()
|
||||
if self._access_token and now < self._token_expires_at:
|
||||
return self._access_token
|
||||
|
||||
logger.info("Refreshing KIS access token")
|
||||
session = self._get_session()
|
||||
url = f"{self._base_url}/oauth2/tokenP"
|
||||
body = {
|
||||
"grant_type": "client_credentials",
|
||||
"appkey": self._app_key,
|
||||
"appsecret": self._app_secret,
|
||||
}
|
||||
|
||||
async with session.post(url, json=body) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ConnectionError(f"Token refresh failed ({resp.status}): {text}")
|
||||
data = await resp.json()
|
||||
|
||||
self._access_token = data["access_token"]
|
||||
self._token_expires_at = now + data.get("expires_in", 86400) - 60 # 1-min buffer
|
||||
logger.info("Token refreshed successfully")
|
||||
return self._access_token
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Hash Key (required for POST bodies)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _get_hash_key(self, body: dict[str, Any]) -> str:
|
||||
"""Request a hash key from KIS for POST request body signing."""
|
||||
session = self._get_session()
|
||||
url = f"{self._base_url}/uapi/hashkey"
|
||||
headers = {
|
||||
"content-Type": "application/json",
|
||||
"appKey": self._app_key,
|
||||
"appSecret": self._app_secret,
|
||||
}
|
||||
|
||||
async with session.post(url, json=body, headers=headers) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ConnectionError(f"Hash key request failed ({resp.status}): {text}")
|
||||
data = await resp.json()
|
||||
|
||||
return data["HASH"]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Common Headers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _auth_headers(self, tr_id: str) -> dict[str, str]:
|
||||
token = await self._ensure_token()
|
||||
return {
|
||||
"content-type": "application/json; charset=utf-8",
|
||||
"authorization": f"Bearer {token}",
|
||||
"appkey": self._app_key,
|
||||
"appsecret": self._app_secret,
|
||||
"tr_id": tr_id,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# API Methods
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_orderbook(self, stock_code: str) -> dict[str, Any]:
|
||||
"""Fetch the current orderbook for a given stock code."""
|
||||
await self._rate_limiter.acquire()
|
||||
session = self._get_session()
|
||||
|
||||
headers = await self._auth_headers("FHKST01010200")
|
||||
params = {
|
||||
"FID_COND_MRKT_DIV_CODE": "J",
|
||||
"FID_INPUT_ISCD": stock_code,
|
||||
}
|
||||
url = f"{self._base_url}/uapi/domestic-stock/v1/quotations/inquire-asking-price-exp-ccn"
|
||||
|
||||
try:
|
||||
async with session.get(url, headers=headers, params=params) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ConnectionError(
|
||||
f"get_orderbook failed ({resp.status}): {text}"
|
||||
)
|
||||
return await resp.json()
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
|
||||
raise ConnectionError(f"Network error fetching orderbook: {exc}") from exc
|
||||
|
||||
async def get_balance(self) -> dict[str, Any]:
|
||||
"""Fetch current account balance and holdings."""
|
||||
await self._rate_limiter.acquire()
|
||||
session = self._get_session()
|
||||
|
||||
headers = await self._auth_headers("VTTC8434R") # 모의투자 잔고조회
|
||||
params = {
|
||||
"CANO": self._account_no,
|
||||
"ACNT_PRDT_CD": self._product_cd,
|
||||
"AFHR_FLPR_YN": "N",
|
||||
"OFL_YN": "",
|
||||
"INQR_DVSN": "02",
|
||||
"UNPR_DVSN": "01",
|
||||
"FUND_STTL_ICLD_YN": "N",
|
||||
"FNCG_AMT_AUTO_RDPT_YN": "N",
|
||||
"PRCS_DVSN": "01",
|
||||
"CTX_AREA_FK100": "",
|
||||
"CTX_AREA_NK100": "",
|
||||
}
|
||||
url = f"{self._base_url}/uapi/domestic-stock/v1/trading/inquire-balance"
|
||||
|
||||
try:
|
||||
async with session.get(url, headers=headers, params=params) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ConnectionError(
|
||||
f"get_balance failed ({resp.status}): {text}"
|
||||
)
|
||||
return await resp.json()
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
|
||||
raise ConnectionError(f"Network error fetching balance: {exc}") from exc
|
||||
|
||||
async def send_order(
|
||||
self,
|
||||
stock_code: str,
|
||||
order_type: str, # "BUY" or "SELL"
|
||||
quantity: int,
|
||||
price: int = 0,
|
||||
) -> dict[str, Any]:
|
||||
"""Submit a buy or sell order.
|
||||
|
||||
Args:
|
||||
stock_code: 6-digit stock code.
|
||||
order_type: "BUY" or "SELL".
|
||||
quantity: Number of shares.
|
||||
price: Order price (0 for market order).
|
||||
"""
|
||||
await self._rate_limiter.acquire()
|
||||
session = self._get_session()
|
||||
|
||||
tr_id = "VTTC0802U" if order_type == "BUY" else "VTTC0801U"
|
||||
body = {
|
||||
"CANO": self._account_no,
|
||||
"ACNT_PRDT_CD": self._product_cd,
|
||||
"PDNO": stock_code,
|
||||
"ORD_DVSN": "01" if price > 0 else "06", # 01=지정가, 06=시장가
|
||||
"ORD_QTY": str(quantity),
|
||||
"ORD_UNPR": str(price),
|
||||
}
|
||||
|
||||
hash_key = await self._get_hash_key(body)
|
||||
headers = await self._auth_headers(tr_id)
|
||||
headers["hashkey"] = hash_key
|
||||
|
||||
url = f"{self._base_url}/uapi/domestic-stock/v1/trading/order-cash"
|
||||
|
||||
try:
|
||||
async with session.post(url, headers=headers, json=body) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise ConnectionError(
|
||||
f"send_order failed ({resp.status}): {text}"
|
||||
)
|
||||
data = await resp.json()
|
||||
logger.info(
|
||||
"Order submitted",
|
||||
extra={
|
||||
"stock_code": stock_code,
|
||||
"action": order_type,
|
||||
},
|
||||
)
|
||||
return data
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
|
||||
raise ConnectionError(f"Network error sending order: {exc}") from exc
|
||||
44
src/config.py
Normal file
44
src/config.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Strictly typed configuration loaded from environment variables."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings — loaded from .env or environment variables."""
|
||||
|
||||
# KIS Open API
|
||||
KIS_APP_KEY: str
|
||||
KIS_APP_SECRET: str
|
||||
KIS_ACCOUNT_NO: str # format: "XXXXXXXX-XX"
|
||||
KIS_BASE_URL: str = "https://openapivts.koreainvestment.com:9443"
|
||||
|
||||
# Google Gemini
|
||||
GEMINI_API_KEY: str
|
||||
GEMINI_MODEL: str = "gemini-pro"
|
||||
|
||||
# Risk Management
|
||||
CIRCUIT_BREAKER_PCT: float = Field(default=-3.0, le=0.0)
|
||||
FAT_FINGER_PCT: float = Field(default=30.0, gt=0.0, le=100.0)
|
||||
CONFIDENCE_THRESHOLD: int = Field(default=80, ge=0, le=100)
|
||||
|
||||
# Database
|
||||
DB_PATH: str = "data/trade_logs.db"
|
||||
|
||||
# Rate Limiting (requests per second for KIS API)
|
||||
RATE_LIMIT_RPS: float = 10.0
|
||||
|
||||
# Trading mode
|
||||
MODE: str = Field(default="paper", pattern="^(paper|live)$")
|
||||
|
||||
model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}
|
||||
|
||||
@property
|
||||
def account_number(self) -> str:
|
||||
return self.KIS_ACCOUNT_NO.split("-")[0]
|
||||
|
||||
@property
|
||||
def account_product_code(self) -> str:
|
||||
return self.KIS_ACCOUNT_NO.split("-")[1]
|
||||
0
src/core/__init__.py
Normal file
0
src/core/__init__.py
Normal file
84
src/core/risk_manager.py
Normal file
84
src/core/risk_manager.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Risk management — the Shield that protects the portfolio.
|
||||
|
||||
This module is READ-ONLY by policy (see docs/agents.md).
|
||||
Changes require human approval and two passing test suites.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
from src.config import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CircuitBreakerTripped(SystemExit):
|
||||
"""Raised when daily P&L loss exceeds the allowed threshold."""
|
||||
|
||||
def __init__(self, pnl_pct: float, threshold: float) -> None:
|
||||
self.pnl_pct = pnl_pct
|
||||
self.threshold = threshold
|
||||
super().__init__(
|
||||
f"CIRCUIT BREAKER: Daily P&L {pnl_pct:.2f}% exceeded "
|
||||
f"threshold {threshold:.2f}%. All trading halted."
|
||||
)
|
||||
|
||||
|
||||
class FatFingerRejected(Exception):
|
||||
"""Raised when an order exceeds the maximum allowed proportion of cash."""
|
||||
|
||||
def __init__(self, order_amount: float, total_cash: float, max_pct: float) -> None:
|
||||
self.order_amount = order_amount
|
||||
self.total_cash = total_cash
|
||||
self.max_pct = max_pct
|
||||
ratio = (order_amount / total_cash * 100) if total_cash > 0 else float("inf")
|
||||
super().__init__(
|
||||
f"FAT FINGER: Order {order_amount:,.0f} is {ratio:.1f}% of "
|
||||
f"cash {total_cash:,.0f} (max allowed: {max_pct:.1f}%)."
|
||||
)
|
||||
|
||||
|
||||
class RiskManager:
|
||||
"""Pre-order risk gate that enforces circuit breaker and fat-finger checks."""
|
||||
|
||||
def __init__(self, settings: Settings) -> None:
|
||||
self._cb_threshold = settings.CIRCUIT_BREAKER_PCT
|
||||
self._ff_max_pct = settings.FAT_FINGER_PCT
|
||||
|
||||
def check_circuit_breaker(self, current_pnl_pct: float) -> None:
|
||||
"""Halt trading if daily loss exceeds the threshold.
|
||||
|
||||
The threshold is inclusive: exactly -3.0% is allowed, but -3.01% is not.
|
||||
"""
|
||||
if current_pnl_pct < self._cb_threshold:
|
||||
logger.critical(
|
||||
"Circuit breaker tripped",
|
||||
extra={"pnl_pct": current_pnl_pct},
|
||||
)
|
||||
raise CircuitBreakerTripped(current_pnl_pct, self._cb_threshold)
|
||||
|
||||
def check_fat_finger(self, order_amount: float, total_cash: float) -> None:
|
||||
"""Reject orders that exceed the maximum proportion of available cash."""
|
||||
if total_cash <= 0:
|
||||
raise FatFingerRejected(order_amount, total_cash, self._ff_max_pct)
|
||||
|
||||
ratio_pct = (order_amount / total_cash) * 100
|
||||
if ratio_pct > self._ff_max_pct:
|
||||
logger.warning(
|
||||
"Fat finger check failed",
|
||||
extra={"order_amount": order_amount},
|
||||
)
|
||||
raise FatFingerRejected(order_amount, total_cash, self._ff_max_pct)
|
||||
|
||||
def validate_order(
|
||||
self,
|
||||
current_pnl_pct: float,
|
||||
order_amount: float,
|
||||
total_cash: float,
|
||||
) -> None:
|
||||
"""Run all pre-order risk checks. Raises on failure."""
|
||||
self.check_circuit_breaker(current_pnl_pct)
|
||||
self.check_fat_finger(order_amount, total_cash)
|
||||
logger.info("Order passed risk validation")
|
||||
59
src/db.py
Normal file
59
src/db.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Database layer for trade logging."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
|
||||
def init_db(db_path: str) -> sqlite3.Connection:
|
||||
"""Initialize the trade logs database and return a connection."""
|
||||
conn = sqlite3.connect(db_path)
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS trades (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
timestamp TEXT NOT NULL,
|
||||
stock_code TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
confidence INTEGER NOT NULL,
|
||||
rationale TEXT,
|
||||
quantity INTEGER,
|
||||
price REAL,
|
||||
pnl REAL DEFAULT 0.0
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
return conn
|
||||
|
||||
|
||||
def log_trade(
|
||||
conn: sqlite3.Connection,
|
||||
stock_code: str,
|
||||
action: str,
|
||||
confidence: int,
|
||||
rationale: str,
|
||||
quantity: int = 0,
|
||||
price: float = 0.0,
|
||||
pnl: float = 0.0,
|
||||
) -> None:
|
||||
"""Insert a trade record into the database."""
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO trades (timestamp, stock_code, action, confidence, rationale, quantity, price, pnl)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
datetime.now(timezone.utc).isoformat(),
|
||||
stock_code,
|
||||
action,
|
||||
confidence,
|
||||
rationale,
|
||||
quantity,
|
||||
price,
|
||||
pnl,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
0
src/evolution/__init__.py
Normal file
0
src/evolution/__init__.py
Normal file
229
src/evolution/optimizer.py
Normal file
229
src/evolution/optimizer.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""Evolution Engine — analyzes trade logs and generates new strategies.
|
||||
|
||||
This module:
|
||||
1. Reads trade_logs.db to identify failing patterns
|
||||
2. Asks Gemini to generate a new strategy class
|
||||
3. Runs pytest on the generated file
|
||||
4. Creates a simulated PR if tests pass
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import subprocess
|
||||
import textwrap
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import google.generativeai as genai
|
||||
|
||||
from src.config import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
STRATEGIES_DIR = Path("src/strategies")
|
||||
STRATEGY_TEMPLATE = textwrap.dedent("""\
|
||||
\"\"\"Auto-generated strategy: {name}
|
||||
|
||||
Generated at: {timestamp}
|
||||
Rationale: {rationale}
|
||||
\"\"\"
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Any
|
||||
from src.strategies.base import BaseStrategy
|
||||
|
||||
|
||||
class {class_name}(BaseStrategy):
|
||||
\"\"\"Strategy: {name}\"\"\"
|
||||
|
||||
def evaluate(self, market_data: dict[str, Any]) -> dict[str, Any]:
|
||||
{body}
|
||||
""")
|
||||
|
||||
|
||||
class EvolutionOptimizer:
|
||||
"""Analyzes trade history and evolves trading strategies."""
|
||||
|
||||
def __init__(self, settings: Settings) -> None:
|
||||
self._settings = settings
|
||||
self._db_path = settings.DB_PATH
|
||||
genai.configure(api_key=settings.GEMINI_API_KEY)
|
||||
self._model = genai.GenerativeModel(settings.GEMINI_MODEL)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Analysis
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def analyze_failures(self, limit: int = 50) -> list[dict[str, Any]]:
|
||||
"""Find trades where high confidence led to losses."""
|
||||
conn = sqlite3.connect(self._db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
try:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT stock_code, action, confidence, pnl, rationale, timestamp
|
||||
FROM trades
|
||||
WHERE confidence >= 80 AND pnl < 0
|
||||
ORDER BY pnl ASC
|
||||
LIMIT ?
|
||||
""",
|
||||
(limit,),
|
||||
).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_performance_summary(self) -> dict[str, Any]:
|
||||
"""Return aggregate performance metrics from trade logs."""
|
||||
conn = sqlite3.connect(self._db_path)
|
||||
try:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT
|
||||
COUNT(*) as total_trades,
|
||||
SUM(CASE WHEN pnl > 0 THEN 1 ELSE 0 END) as wins,
|
||||
SUM(CASE WHEN pnl < 0 THEN 1 ELSE 0 END) as losses,
|
||||
COALESCE(AVG(pnl), 0) as avg_pnl,
|
||||
COALESCE(SUM(pnl), 0) as total_pnl
|
||||
FROM trades
|
||||
"""
|
||||
).fetchone()
|
||||
return {
|
||||
"total_trades": row[0],
|
||||
"wins": row[1] or 0,
|
||||
"losses": row[2] or 0,
|
||||
"avg_pnl": round(row[3], 2),
|
||||
"total_pnl": round(row[4], 2),
|
||||
}
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Strategy Generation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def generate_strategy(self, failures: list[dict[str, Any]]) -> Path | None:
|
||||
"""Ask Gemini to generate a new strategy based on failure analysis.
|
||||
|
||||
Returns the path to the generated strategy file, or None on failure.
|
||||
"""
|
||||
prompt = (
|
||||
"You are a quantitative trading strategy developer.\n"
|
||||
"Analyze these failed trades and generate an improved strategy.\n\n"
|
||||
f"Failed trades:\n{json.dumps(failures, indent=2, default=str)}\n\n"
|
||||
"Generate a Python class that inherits from BaseStrategy.\n"
|
||||
"The class must have an `evaluate(self, market_data: dict) -> dict` method.\n"
|
||||
"The method must return a dict with keys: action, confidence, rationale.\n"
|
||||
"Respond with ONLY the method body (Python code), no class definition.\n"
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self._model.generate_content_async(prompt)
|
||||
body = response.text.strip()
|
||||
except Exception as exc:
|
||||
logger.error("Failed to generate strategy: %s", exc)
|
||||
return None
|
||||
|
||||
# Clean up code fences
|
||||
if body.startswith("```"):
|
||||
lines = body.split("\n")
|
||||
body = "\n".join(lines[1:-1])
|
||||
|
||||
# Create strategy file
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
|
||||
version = f"v{timestamp}"
|
||||
class_name = f"Strategy_{version}"
|
||||
file_name = f"{version}_evolved.py"
|
||||
|
||||
STRATEGIES_DIR.mkdir(parents=True, exist_ok=True)
|
||||
file_path = STRATEGIES_DIR / file_name
|
||||
|
||||
# Indent the body for the class method
|
||||
indented_body = textwrap.indent(body, " ")
|
||||
|
||||
content = STRATEGY_TEMPLATE.format(
|
||||
name=version,
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
rationale="Auto-evolved from failure analysis",
|
||||
class_name=class_name,
|
||||
body=indented_body.strip(),
|
||||
)
|
||||
|
||||
file_path.write_text(content)
|
||||
logger.info("Generated strategy file: %s", file_path)
|
||||
return file_path
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Validation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def validate_strategy(self, strategy_path: Path) -> bool:
|
||||
"""Run pytest on the generated strategy. Returns True if all tests pass."""
|
||||
logger.info("Validating strategy: %s", strategy_path)
|
||||
result = subprocess.run(
|
||||
["python", "-m", "pytest", "tests/", "-v", "--tb=short"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
logger.info("Strategy validation PASSED")
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
"Strategy validation FAILED:\n%s", result.stdout + result.stderr
|
||||
)
|
||||
# Clean up failing strategy
|
||||
strategy_path.unlink(missing_ok=True)
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# PR Simulation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def create_pr_simulation(self, strategy_path: Path) -> dict[str, str]:
|
||||
"""Simulate creating a pull request for the new strategy."""
|
||||
pr = {
|
||||
"title": f"[Evolution] New strategy: {strategy_path.stem}",
|
||||
"branch": f"evolution/{strategy_path.stem}",
|
||||
"body": (
|
||||
f"Auto-generated strategy from evolution engine.\n"
|
||||
f"File: {strategy_path}\n"
|
||||
f"All tests passed."
|
||||
),
|
||||
"status": "ready_for_review",
|
||||
}
|
||||
logger.info("PR simulation created: %s", pr["title"])
|
||||
return pr
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Full Pipeline
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def evolve(self) -> dict[str, Any] | None:
|
||||
"""Run the full evolution pipeline.
|
||||
|
||||
1. Analyze failures
|
||||
2. Generate new strategy
|
||||
3. Validate with tests
|
||||
4. Create PR simulation
|
||||
|
||||
Returns PR info on success, None on failure.
|
||||
"""
|
||||
failures = self.analyze_failures()
|
||||
if not failures:
|
||||
logger.info("No failure patterns found — skipping evolution")
|
||||
return None
|
||||
|
||||
strategy_path = await self.generate_strategy(failures)
|
||||
if strategy_path is None:
|
||||
return None
|
||||
|
||||
if not self.validate_strategy(strategy_path):
|
||||
return None
|
||||
|
||||
return self.create_pr_simulation(strategy_path)
|
||||
42
src/logging_config.py
Normal file
42
src/logging_config.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""JSON-formatted structured logging for machine readability."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import json
|
||||
|
||||
|
||||
class JSONFormatter(logging.Formatter):
|
||||
"""Emit log records as single-line JSON objects."""
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
log_entry: dict[str, Any] = {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"level": record.levelname,
|
||||
"logger": record.name,
|
||||
"message": record.getMessage(),
|
||||
}
|
||||
if record.exc_info and record.exc_info[1]:
|
||||
log_entry["exception"] = self.formatException(record.exc_info)
|
||||
# Merge any extra fields attached to the record
|
||||
for key in ("stock_code", "action", "confidence", "pnl_pct", "order_amount"):
|
||||
value = getattr(record, key, None)
|
||||
if value is not None:
|
||||
log_entry[key] = value
|
||||
return json.dumps(log_entry, ensure_ascii=False)
|
||||
|
||||
|
||||
def setup_logging(level: int = logging.INFO) -> None:
|
||||
"""Configure the root logger with JSON output to stdout."""
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
handler.setFormatter(JSONFormatter())
|
||||
|
||||
root = logging.getLogger()
|
||||
root.setLevel(level)
|
||||
# Avoid duplicate handlers on repeated calls
|
||||
root.handlers.clear()
|
||||
root.addHandler(handler)
|
||||
171
src/main.py
Normal file
171
src/main.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""The Ouroboros — main trading loop.
|
||||
|
||||
Orchestrates the broker, brain, and risk manager into a continuous
|
||||
trading cycle with configurable intervals.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from src.brain.gemini_client import GeminiClient
|
||||
from src.broker.kis_api import KISBroker
|
||||
from src.config import Settings
|
||||
from src.core.risk_manager import CircuitBreakerTripped, RiskManager
|
||||
from src.db import init_db, log_trade
|
||||
from src.logging_config import setup_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Target stock codes to monitor
|
||||
WATCHLIST = ["005930", "000660", "035420"] # Samsung, SK Hynix, NAVER
|
||||
|
||||
TRADE_INTERVAL_SECONDS = 60
|
||||
|
||||
|
||||
async def trading_cycle(
|
||||
broker: KISBroker,
|
||||
brain: GeminiClient,
|
||||
risk: RiskManager,
|
||||
db_conn: Any,
|
||||
stock_code: str,
|
||||
) -> None:
|
||||
"""Execute one trading cycle for a single stock."""
|
||||
# 1. Fetch market data
|
||||
orderbook = await broker.get_orderbook(stock_code)
|
||||
balance_data = await broker.get_balance()
|
||||
|
||||
output2 = balance_data.get("output2", [{}])
|
||||
total_eval = float(output2[0].get("tot_evlu_amt", "0")) if output2 else 0
|
||||
total_cash = float(
|
||||
balance_data.get("output2", [{}])[0].get("dnca_tot_amt", "0")
|
||||
if output2
|
||||
else "0"
|
||||
)
|
||||
purchase_total = float(output2[0].get("pchs_amt_smtl_amt", "0")) if output2 else 0
|
||||
|
||||
# Calculate daily P&L %
|
||||
pnl_pct = ((total_eval - purchase_total) / purchase_total * 100) if purchase_total > 0 else 0.0
|
||||
|
||||
current_price = float(
|
||||
orderbook.get("output1", {}).get("stck_prpr", "0")
|
||||
)
|
||||
|
||||
market_data = {
|
||||
"stock_code": stock_code,
|
||||
"current_price": current_price,
|
||||
"orderbook": orderbook.get("output1", {}),
|
||||
"foreigner_net": float(
|
||||
orderbook.get("output1", {}).get("frgn_ntby_qty", "0")
|
||||
),
|
||||
}
|
||||
|
||||
# 2. Ask the brain for a decision
|
||||
decision = await brain.decide(market_data)
|
||||
logger.info(
|
||||
"Decision for %s: %s (confidence=%d)",
|
||||
stock_code,
|
||||
decision.action,
|
||||
decision.confidence,
|
||||
)
|
||||
|
||||
# 3. Execute if actionable
|
||||
if decision.action in ("BUY", "SELL"):
|
||||
# Determine order size (simplified: 1 lot)
|
||||
quantity = 1
|
||||
order_amount = current_price * quantity
|
||||
|
||||
# 4. Risk check BEFORE order
|
||||
risk.validate_order(
|
||||
current_pnl_pct=pnl_pct,
|
||||
order_amount=order_amount,
|
||||
total_cash=total_cash,
|
||||
)
|
||||
|
||||
# 5. Send order
|
||||
result = await broker.send_order(
|
||||
stock_code=stock_code,
|
||||
order_type=decision.action,
|
||||
quantity=quantity,
|
||||
price=0, # market order
|
||||
)
|
||||
logger.info("Order result: %s", result.get("msg1", "OK"))
|
||||
|
||||
# 6. Log trade
|
||||
log_trade(
|
||||
conn=db_conn,
|
||||
stock_code=stock_code,
|
||||
action=decision.action,
|
||||
confidence=decision.confidence,
|
||||
rationale=decision.rationale,
|
||||
)
|
||||
|
||||
|
||||
async def run(settings: Settings) -> None:
|
||||
"""Main async loop — iterate over watchlist on a timer."""
|
||||
broker = KISBroker(settings)
|
||||
brain = GeminiClient(settings)
|
||||
risk = RiskManager(settings)
|
||||
db_conn = init_db(settings.DB_PATH)
|
||||
|
||||
shutdown = asyncio.Event()
|
||||
|
||||
def _signal_handler() -> None:
|
||||
logger.info("Shutdown signal received")
|
||||
shutdown.set()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||
loop.add_signal_handler(sig, _signal_handler)
|
||||
|
||||
logger.info("The Ouroboros is alive. Mode: %s", settings.MODE)
|
||||
logger.info("Watchlist: %s", WATCHLIST)
|
||||
|
||||
try:
|
||||
while not shutdown.is_set():
|
||||
for code in WATCHLIST:
|
||||
if shutdown.is_set():
|
||||
break
|
||||
try:
|
||||
await trading_cycle(broker, brain, risk, db_conn, code)
|
||||
except CircuitBreakerTripped:
|
||||
logger.critical("Circuit breaker tripped — shutting down")
|
||||
raise
|
||||
except ConnectionError as exc:
|
||||
logger.error("Connection error for %s: %s", code, exc)
|
||||
except Exception as exc:
|
||||
logger.exception("Unexpected error for %s: %s", code, exc)
|
||||
|
||||
# Wait for next cycle or shutdown
|
||||
try:
|
||||
await asyncio.wait_for(shutdown.wait(), timeout=TRADE_INTERVAL_SECONDS)
|
||||
except asyncio.TimeoutError:
|
||||
pass # Normal — timeout means it's time for next cycle
|
||||
finally:
|
||||
await broker.close()
|
||||
db_conn.close()
|
||||
logger.info("The Ouroboros rests.")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="The Ouroboros Trading Agent")
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
choices=["paper", "live"],
|
||||
default="paper",
|
||||
help="Trading mode (default: paper)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
setup_logging()
|
||||
settings = Settings(MODE=args.mode) # type: ignore[call-arg]
|
||||
asyncio.run(run(settings))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
0
src/strategies/__init__.py
Normal file
0
src/strategies/__init__.py
Normal file
19
src/strategies/base.py
Normal file
19
src/strategies/base.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Base class for all trading strategies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseStrategy(ABC):
|
||||
"""All strategies must inherit from this class."""
|
||||
|
||||
@abstractmethod
|
||||
def evaluate(self, market_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Evaluate market data and return a trade decision.
|
||||
|
||||
Returns:
|
||||
dict with keys: action ("BUY"|"SELL"|"HOLD"), confidence (int), rationale (str)
|
||||
"""
|
||||
...
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
23
tests/conftest.py
Normal file
23
tests/conftest.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Shared test fixtures for The Ouroboros test suite."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from src.config import Settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def settings() -> Settings:
|
||||
"""Return a Settings instance with safe test defaults."""
|
||||
return Settings(
|
||||
KIS_APP_KEY="test_app_key",
|
||||
KIS_APP_SECRET="test_app_secret",
|
||||
KIS_ACCOUNT_NO="12345678-01",
|
||||
KIS_BASE_URL="https://openapivts.koreainvestment.com:9443",
|
||||
GEMINI_API_KEY="test_gemini_key",
|
||||
CIRCUIT_BREAKER_PCT=-3.0,
|
||||
FAT_FINGER_PCT=30.0,
|
||||
CONFIDENCE_THRESHOLD=80,
|
||||
DB_PATH=":memory:",
|
||||
)
|
||||
159
tests/test_brain.py
Normal file
159
tests/test_brain.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""TDD tests for brain/gemini_client.py — written BEFORE implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.brain.gemini_client import GeminiClient, TradeDecision
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response Parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResponseParsing:
|
||||
"""Gemini responses must be parsed into validated TradeDecision objects."""
|
||||
|
||||
def test_valid_buy_response(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
raw = '{"action": "BUY", "confidence": 90, "rationale": "Strong momentum"}'
|
||||
decision = client.parse_response(raw)
|
||||
assert decision.action == "BUY"
|
||||
assert decision.confidence == 90
|
||||
assert decision.rationale == "Strong momentum"
|
||||
|
||||
def test_valid_sell_response(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
raw = '{"action": "SELL", "confidence": 85, "rationale": "Overbought RSI"}'
|
||||
decision = client.parse_response(raw)
|
||||
assert decision.action == "SELL"
|
||||
|
||||
def test_valid_hold_response(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
raw = '{"action": "HOLD", "confidence": 95, "rationale": "Sideways market"}'
|
||||
decision = client.parse_response(raw)
|
||||
assert decision.action == "HOLD"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Confidence Threshold Enforcement
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConfidenceThreshold:
|
||||
"""If confidence < 80, the action MUST be forced to HOLD."""
|
||||
|
||||
def test_low_confidence_buy_becomes_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
raw = '{"action": "BUY", "confidence": 65, "rationale": "Weak signal"}'
|
||||
decision = client.parse_response(raw)
|
||||
assert decision.action == "HOLD"
|
||||
assert decision.confidence == 65
|
||||
|
||||
def test_low_confidence_sell_becomes_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
raw = '{"action": "SELL", "confidence": 79, "rationale": "Uncertain"}'
|
||||
decision = client.parse_response(raw)
|
||||
assert decision.action == "HOLD"
|
||||
|
||||
def test_exactly_threshold_is_allowed(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
raw = '{"action": "BUY", "confidence": 80, "rationale": "Just enough"}'
|
||||
decision = client.parse_response(raw)
|
||||
assert decision.action == "BUY"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Malformed JSON Handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMalformedJsonHandling:
|
||||
"""Gemini may return garbage — the parser must not crash."""
|
||||
|
||||
def test_empty_string_returns_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
decision = client.parse_response("")
|
||||
assert decision.action == "HOLD"
|
||||
assert decision.confidence == 0
|
||||
|
||||
def test_plain_text_returns_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
decision = client.parse_response("I think you should buy Samsung stock")
|
||||
assert decision.action == "HOLD"
|
||||
assert decision.confidence == 0
|
||||
|
||||
def test_partial_json_returns_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
decision = client.parse_response('{"action": "BUY", "confidence":')
|
||||
assert decision.action == "HOLD"
|
||||
assert decision.confidence == 0
|
||||
|
||||
def test_json_with_missing_fields_returns_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
decision = client.parse_response('{"action": "BUY"}')
|
||||
assert decision.action == "HOLD"
|
||||
assert decision.confidence == 0
|
||||
|
||||
def test_json_with_invalid_action_returns_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
decision = client.parse_response(
|
||||
'{"action": "YOLO", "confidence": 99, "rationale": "moon"}'
|
||||
)
|
||||
assert decision.action == "HOLD"
|
||||
assert decision.confidence == 0
|
||||
|
||||
def test_json_wrapped_in_markdown_code_block(self, settings):
|
||||
"""Gemini often wraps JSON in ```json ... ``` blocks."""
|
||||
client = GeminiClient(settings)
|
||||
raw = '```json\n{"action": "BUY", "confidence": 92, "rationale": "Good"}\n```'
|
||||
decision = client.parse_response(raw)
|
||||
assert decision.action == "BUY"
|
||||
assert decision.confidence == 92
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt Construction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPromptConstruction:
|
||||
"""The prompt sent to Gemini must include all required market data."""
|
||||
|
||||
def test_prompt_contains_stock_code(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
market_data = {
|
||||
"stock_code": "005930",
|
||||
"current_price": 72000,
|
||||
"orderbook": {"asks": [], "bids": []},
|
||||
"foreigner_net": -50000,
|
||||
}
|
||||
prompt = client.build_prompt(market_data)
|
||||
assert "005930" in prompt
|
||||
|
||||
def test_prompt_contains_price(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
market_data = {
|
||||
"stock_code": "005930",
|
||||
"current_price": 72000,
|
||||
"orderbook": {"asks": [], "bids": []},
|
||||
"foreigner_net": -50000,
|
||||
}
|
||||
prompt = client.build_prompt(market_data)
|
||||
assert "72000" in prompt
|
||||
|
||||
def test_prompt_enforces_json_output_format(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
market_data = {
|
||||
"stock_code": "005930",
|
||||
"current_price": 72000,
|
||||
"orderbook": {"asks": [], "bids": []},
|
||||
"foreigner_net": 0,
|
||||
}
|
||||
prompt = client.build_prompt(market_data)
|
||||
assert "JSON" in prompt
|
||||
assert "action" in prompt
|
||||
assert "confidence" in prompt
|
||||
140
tests/test_broker.py
Normal file
140
tests/test_broker.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""TDD tests for broker/kis_api.py — written BEFORE implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
from src.broker.kis_api import KISBroker
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Token Management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTokenManagement:
|
||||
"""Access token must be auto-refreshed and cached."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetches_token_on_first_call(self, settings):
|
||||
broker = KISBroker(settings)
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(
|
||||
return_value={
|
||||
"access_token": "tok_abc123",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 86400,
|
||||
}
|
||||
)
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp):
|
||||
token = await broker._ensure_token()
|
||||
assert token == "tok_abc123"
|
||||
|
||||
await broker.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reuses_cached_token(self, settings):
|
||||
broker = KISBroker(settings)
|
||||
broker._access_token = "cached_token"
|
||||
broker._token_expires_at = asyncio.get_event_loop().time() + 3600
|
||||
|
||||
token = await broker._ensure_token()
|
||||
assert token == "cached_token"
|
||||
|
||||
await broker.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Network Error Handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNetworkErrorHandling:
|
||||
"""Broker must handle network timeouts and HTTP errors gracefully."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_raises_connection_error(self, settings):
|
||||
broker = KISBroker(settings)
|
||||
broker._access_token = "tok"
|
||||
broker._token_expires_at = asyncio.get_event_loop().time() + 3600
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.get",
|
||||
side_effect=asyncio.TimeoutError(),
|
||||
):
|
||||
with pytest.raises(ConnectionError):
|
||||
await broker.get_orderbook("005930")
|
||||
|
||||
await broker.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_500_raises_connection_error(self, settings):
|
||||
broker = KISBroker(settings)
|
||||
broker._access_token = "tok"
|
||||
broker._token_expires_at = asyncio.get_event_loop().time() + 3600
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 500
|
||||
mock_resp.text = AsyncMock(return_value="Internal Server Error")
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.get", return_value=mock_resp):
|
||||
with pytest.raises(ConnectionError):
|
||||
await broker.get_orderbook("005930")
|
||||
|
||||
await broker.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rate Limiter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRateLimiter:
|
||||
"""The leaky bucket rate limiter must throttle requests."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limiter_does_not_block_under_limit(self, settings):
|
||||
broker = KISBroker(settings)
|
||||
# Should complete without blocking when under limit
|
||||
await broker._rate_limiter.acquire()
|
||||
await broker.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hash Key Generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHashKey:
|
||||
"""POST requests to KIS require a hash key."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generates_hash_key_for_post_body(self, settings):
|
||||
broker = KISBroker(settings)
|
||||
broker._access_token = "tok"
|
||||
broker._token_expires_at = asyncio.get_event_loop().time() + 3600
|
||||
|
||||
body = {"CANO": "12345678", "ACNT_PRDT_CD": "01"}
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(return_value={"HASH": "abc123hash"})
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp):
|
||||
hash_key = await broker._get_hash_key(body)
|
||||
assert isinstance(hash_key, str)
|
||||
assert len(hash_key) > 0
|
||||
|
||||
await broker.close()
|
||||
132
tests/test_risk.py
Normal file
132
tests/test_risk.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""TDD tests for core/risk_manager.py — written BEFORE implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from src.core.risk_manager import (
|
||||
CircuitBreakerTripped,
|
||||
FatFingerRejected,
|
||||
RiskManager,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Circuit Breaker Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCircuitBreaker:
|
||||
"""The circuit breaker must halt all trading when daily loss exceeds the threshold."""
|
||||
|
||||
def test_allows_trading_when_pnl_is_positive(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
# 2% gain — should be fine
|
||||
rm.check_circuit_breaker(current_pnl_pct=2.0)
|
||||
|
||||
def test_allows_trading_at_zero_pnl(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
rm.check_circuit_breaker(current_pnl_pct=0.0)
|
||||
|
||||
def test_allows_trading_at_exactly_threshold(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
# Exactly -3.0% is ON the boundary — still allowed
|
||||
rm.check_circuit_breaker(current_pnl_pct=-3.0)
|
||||
|
||||
def test_trips_when_loss_exceeds_threshold(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
with pytest.raises(CircuitBreakerTripped):
|
||||
rm.check_circuit_breaker(current_pnl_pct=-3.01)
|
||||
|
||||
def test_trips_at_large_loss(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
with pytest.raises(CircuitBreakerTripped):
|
||||
rm.check_circuit_breaker(current_pnl_pct=-10.0)
|
||||
|
||||
def test_custom_threshold(self):
|
||||
"""A stricter threshold (-1.5%) should trip earlier."""
|
||||
from src.config import Settings
|
||||
|
||||
strict = Settings(
|
||||
KIS_APP_KEY="k",
|
||||
KIS_APP_SECRET="s",
|
||||
KIS_ACCOUNT_NO="00000000-00",
|
||||
KIS_BASE_URL="https://example.com",
|
||||
GEMINI_API_KEY="g",
|
||||
CIRCUIT_BREAKER_PCT=-1.5,
|
||||
FAT_FINGER_PCT=30.0,
|
||||
CONFIDENCE_THRESHOLD=80,
|
||||
DB_PATH=":memory:",
|
||||
)
|
||||
rm = RiskManager(strict)
|
||||
with pytest.raises(CircuitBreakerTripped):
|
||||
rm.check_circuit_breaker(current_pnl_pct=-1.51)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fat Finger Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFatFingerCheck:
|
||||
"""Orders exceeding 30% of total cash must be rejected."""
|
||||
|
||||
def test_allows_small_order(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
# 10% of 10_000_000 = 1_000_000
|
||||
rm.check_fat_finger(order_amount=1_000_000, total_cash=10_000_000)
|
||||
|
||||
def test_allows_order_at_exactly_threshold(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
# Exactly 30% — allowed
|
||||
rm.check_fat_finger(order_amount=3_000_000, total_cash=10_000_000)
|
||||
|
||||
def test_rejects_order_exceeding_threshold(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
with pytest.raises(FatFingerRejected):
|
||||
rm.check_fat_finger(order_amount=3_000_001, total_cash=10_000_000)
|
||||
|
||||
def test_rejects_massive_order(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
with pytest.raises(FatFingerRejected):
|
||||
rm.check_fat_finger(order_amount=9_000_000, total_cash=10_000_000)
|
||||
|
||||
def test_zero_cash_rejects_any_order(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
with pytest.raises(FatFingerRejected):
|
||||
rm.check_fat_finger(order_amount=1, total_cash=0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pre-Order Validation (Integration of both checks)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPreOrderValidation:
|
||||
"""validate_order must run BOTH checks before approving."""
|
||||
|
||||
def test_passes_when_both_checks_ok(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
rm.validate_order(
|
||||
current_pnl_pct=0.5,
|
||||
order_amount=1_000_000,
|
||||
total_cash=10_000_000,
|
||||
)
|
||||
|
||||
def test_fails_on_circuit_breaker(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
with pytest.raises(CircuitBreakerTripped):
|
||||
rm.validate_order(
|
||||
current_pnl_pct=-5.0,
|
||||
order_amount=100,
|
||||
total_cash=10_000_000,
|
||||
)
|
||||
|
||||
def test_fails_on_fat_finger(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
with pytest.raises(FatFingerRejected):
|
||||
rm.validate_order(
|
||||
current_pnl_pct=1.0,
|
||||
order_amount=5_000_000,
|
||||
total_cash=10_000_000,
|
||||
)
|
||||
Reference in New Issue
Block a user