Add complete Ouroboros trading system with TDD test suite
Some checks failed
CI / test (push) Has been cancelled

Implement the full autonomous trading agent architecture:
- KIS broker with async API, token refresh, leaky bucket rate limiter, and hash key signing
- Gemini-powered decision engine with JSON parsing and confidence threshold enforcement
- Risk manager with circuit breaker (-3% P&L) and fat finger protection (30% cap)
- Evolution engine for self-improving strategy generation via failure analysis
- 35 passing tests written TDD-first covering risk, broker, and brain modules
- CI/CD pipeline, Docker multi-stage build, and AI agent context docs

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-02-04 02:08:48 +09:00
parent 9d9945822a
commit d1750af80f
27 changed files with 1842 additions and 0 deletions

23
.env.example Normal file
View File

@@ -0,0 +1,23 @@
# Korea Investment Securities API
KIS_APP_KEY=your_app_key_here
KIS_APP_SECRET=your_app_secret_here
KIS_ACCOUNT_NO=12345678-01
KIS_BASE_URL=https://openapivts.koreainvestment.com:9443
# Google Gemini
GEMINI_API_KEY=your_gemini_api_key_here
GEMINI_MODEL=gemini-pro
# Risk Management
CIRCUIT_BREAKER_PCT=-3.0
FAT_FINGER_PCT=30.0
CONFIDENCE_THRESHOLD=80
# Database
DB_PATH=data/trade_logs.db
# Rate Limiting
RATE_LIMIT_RPS=10.0
# Trading Mode (paper / live)
MODE=paper

39
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,39 @@
name: CI
on:
push:
branches: [main]
pull_request:
branches: [main]
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.11
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install dependencies
run: pip install ".[dev]"
- name: Lint
run: ruff check src/ tests/
- name: Type check
run: mypy src/ --strict
continue-on-error: true
- name: Run tests with coverage
run: pytest -v --cov=src --cov-report=term-missing --cov-fail-under=80
- name: Upload coverage
if: always()
uses: actions/upload-artifact@v4
with:
name: coverage-report
path: htmlcov/

44
Dockerfile Normal file
View File

@@ -0,0 +1,44 @@
FROM python:3.11-slim AS base
WORKDIR /app
# System deps
RUN apt-get update && apt-get install -y --no-install-recommends \
gcc \
&& rm -rf /var/lib/apt/lists/*
# Install Python dependencies
COPY pyproject.toml .
RUN pip install --no-cache-dir ".[dev]"
# Copy source
COPY src/ src/
COPY tests/ tests/
COPY docs/ docs/
# Create data directory
RUN mkdir -p data
# Run tests as build validation
RUN pytest -v --tb=short
# Production stage
FROM python:3.11-slim AS production
WORKDIR /app
COPY pyproject.toml .
RUN pip install --no-cache-dir .
COPY src/ src/
RUN mkdir -p data
# Non-root user
RUN useradd --create-home appuser
USER appuser
HEALTHCHECK --interval=30s --timeout=5s --retries=3 \
CMD python -c "import sys; sys.exit(0)"
ENTRYPOINT ["python", "-m", "src.main"]
CMD ["--mode=paper"]

30
docker-compose.yml Normal file
View File

@@ -0,0 +1,30 @@
services:
ouroboros:
build:
context: .
target: production
container_name: ouroboros
restart: unless-stopped
env_file:
- .env
volumes:
- trade_data:/app/data
command: ["--mode=paper"]
logging:
driver: json-file
options:
max-size: "10m"
max-file: "3"
# Run tests (one-shot)
test:
build:
context: .
target: base
container_name: ouroboros-test
command: ["pytest", "-v", "--cov=src", "--cov-report=term-missing"]
profiles:
- test
volumes:
trade_data:

58
docs/agents.md Normal file
View File

@@ -0,0 +1,58 @@
# The Ouroboros - Agent Persona Definition
## Role: The Guardian
You are **The Guardian**, the primary AI agent responsible for maintaining and evolving
The Ouroboros trading system. Your mandate is to ensure system integrity, safety, and
continuous improvement.
## Prime Directives
1. **NEVER disable, bypass, or weaken `core/risk_manager.py`.**
The risk manager is the last line of defense against catastrophic loss.
Any code change that reduces risk controls MUST be rejected.
2. **All code changes require a passing test.**
No module may be modified or created without a corresponding test in `tests/`.
Run `pytest -v --cov=src` before proposing any merge.
3. **Preserve the Circuit Breaker.**
The daily P&L circuit breaker (-3.0% threshold) is non-negotiable.
It may only be made *stricter*, never relaxed.
4. **Fat Finger Protection is sacred.**
The 30% max-order-size rule must remain enforced at all times.
## Decision Framework
When modifying code, follow this priority order:
1. **Safety** - Will this change increase risk exposure? If yes, reject.
2. **Correctness** - Is the logic provably correct? Verify with tests.
3. **Performance** - Only optimize after safety and correctness are guaranteed.
4. **Readability** - Code must be understandable by future agents and humans.
## File Ownership
| Module | Guardian Rule |
|---|---|
| `core/risk_manager.py` | READ-ONLY. Changes require human approval + 2 passing test suites. |
| `broker/kis_api.py` | Rate limiter must never be removed. Token refresh must remain automatic. |
| `brain/gemini_client.py` | Confidence < 80 MUST force HOLD. This rule cannot be weakened. |
| `evolution/optimizer.py` | Generated strategies must pass ALL tests before activation. |
| `strategies/*` | New strategies are welcome but must inherit `BaseStrategy`. |
## Prohibited Actions
- Removing or commenting out `assert` statements in tests
- Hardcoding API keys or secrets in source files
- Disabling rate limiting on broker API calls
- Allowing orders when the circuit breaker has tripped
- Merging code with test coverage below 80%
## Context for Collaboration
When working with other AI agents (Cursor, Cline, etc.):
- Share this document as the system constitution
- All agents must acknowledge these rules before making changes
- Conflicts are resolved by defaulting to the *safer* option

109
docs/skills.md Normal file
View File

@@ -0,0 +1,109 @@
# The Ouroboros - Available Skills & Tools
## Development Tools
### Run Tests
```bash
pytest -v --cov=src --cov-report=term-missing
```
Run the full test suite with coverage reporting. **Must pass before any merge.**
### Run Specific Test Module
```bash
pytest tests/test_risk.py -v
pytest tests/test_broker.py -v
pytest tests/test_brain.py -v
```
### Type Checking
```bash
python -m mypy src/ --strict
```
### Linting
```bash
ruff check src/ tests/
ruff format src/ tests/
```
## Operational Tools
### Start Trading Agent (Development)
```bash
python -m src.main --mode=paper
```
Runs the agent in paper-trading mode (no real orders).
### Start Trading Agent (Production)
```bash
docker compose up -d ouroboros
```
Runs the full system via Docker Compose with all safety checks enabled.
### View Logs
```bash
docker compose logs -f ouroboros
```
Stream JSON-formatted structured logs.
### Run Backtester
```bash
python -m src.evolution.optimizer --backtest --days=30
```
Analyze the last 30 days of trade logs and generate performance metrics.
## Evolution Tools
### Generate New Strategy
```bash
python -m src.evolution.optimizer --evolve
```
Triggers the evolution engine to:
1. Analyze `trade_logs.db` for failing patterns
2. Ask Gemini to generate a new strategy
3. Run tests on the new strategy
4. Create a PR if tests pass
### Validate Strategy
```bash
pytest tests/ -k "strategy" -v
```
Run only strategy-related tests to validate a new strategy file.
## Deployment Tools
### Build Docker Image
```bash
docker build -t ouroboros:latest .
```
### Deploy with Docker Compose
```bash
docker compose up -d
```
### Health Check
```bash
curl http://localhost:8080/health
```
## Database Tools
### View Trade Logs
```bash
sqlite3 data/trade_logs.db "SELECT * FROM trades ORDER BY timestamp DESC LIMIT 20;"
```
### Export Trade History
```bash
sqlite3 -header -csv data/trade_logs.db "SELECT * FROM trades;" > trades_export.csv
```
## Safety Checklist (Pre-Deploy)
- [ ] `pytest -v --cov=src` passes with >= 80% coverage
- [ ] `ruff check src/ tests/` reports no errors
- [ ] `.env` file contains valid KIS and Gemini API keys
- [ ] Circuit breaker threshold is set to -3.0% or stricter
- [ ] Rate limiter is configured for KIS API limits
- [ ] Docker health check endpoint responds 200

40
pyproject.toml Normal file
View File

@@ -0,0 +1,40 @@
[project]
name = "the-ouroboros"
version = "0.1.0"
description = "Evolutionary AI Trading Agent for KIS"
requires-python = ">=3.11"
dependencies = [
"aiohttp>=3.9,<4",
"pydantic>=2.5,<3",
"pydantic-settings>=2.1,<3",
"google-generativeai>=0.8,<1",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0,<9",
"pytest-asyncio>=0.23,<1",
"pytest-cov>=5.0,<6",
"ruff>=0.5,<1",
"mypy>=1.10,<2",
]
[project.scripts]
ouroboros = "src.main:main"
[tool.pytest.ini_options]
testpaths = ["tests"]
asyncio_mode = "auto"
[tool.ruff]
target-version = "py311"
line-length = 100
[tool.ruff.lint]
select = ["E", "F", "I", "N", "W", "UP"]
[tool.mypy]
python_version = "3.11"
strict = true
warn_return_any = true
warn_unused_configs = true

0
src/__init__.py Normal file
View File

0
src/brain/__init__.py Normal file
View File

152
src/brain/gemini_client.py Normal file
View File

@@ -0,0 +1,152 @@
"""Decision engine powered by Google Gemini.
Constructs prompts from market data, calls Gemini, and parses structured
JSON responses into validated TradeDecision objects.
"""
from __future__ import annotations
import json
import logging
import re
from dataclasses import dataclass
from typing import Any
import google.generativeai as genai
from src.config import Settings
logger = logging.getLogger(__name__)
VALID_ACTIONS = {"BUY", "SELL", "HOLD"}
@dataclass(frozen=True)
class TradeDecision:
"""Validated decision from the AI brain."""
action: str # "BUY" | "SELL" | "HOLD"
confidence: int # 0-100
rationale: str
class GeminiClient:
"""Wraps the Gemini API for trade decision-making."""
def __init__(self, settings: Settings) -> None:
self._settings = settings
self._confidence_threshold = settings.CONFIDENCE_THRESHOLD
genai.configure(api_key=settings.GEMINI_API_KEY)
self._model = genai.GenerativeModel(settings.GEMINI_MODEL)
# ------------------------------------------------------------------
# Prompt Construction
# ------------------------------------------------------------------
def build_prompt(self, market_data: dict[str, Any]) -> str:
"""Build a structured prompt from market data.
The prompt instructs Gemini to return valid JSON with action,
confidence, and rationale fields.
"""
return (
"You are a professional Korean stock market trading analyst.\n"
"Analyze the following market data and decide whether to BUY, SELL, or HOLD.\n\n"
f"Stock Code: {market_data['stock_code']}\n"
f"Current Price: {market_data['current_price']}\n"
f"Orderbook: {json.dumps(market_data['orderbook'], ensure_ascii=False)}\n"
f"Foreigner Net Buy/Sell: {market_data['foreigner_net']}\n\n"
"You MUST respond with ONLY valid JSON in the following format:\n"
'{"action": "BUY"|"SELL"|"HOLD", "confidence": <int 0-100>, "rationale": "<string>"}\n\n'
"Rules:\n"
"- action must be exactly one of: BUY, SELL, HOLD\n"
"- confidence must be an integer from 0 to 100\n"
"- rationale must explain your reasoning concisely\n"
"- Do NOT wrap the JSON in markdown code blocks\n"
)
# ------------------------------------------------------------------
# Response Parsing
# ------------------------------------------------------------------
def parse_response(self, raw: str) -> TradeDecision:
"""Parse a raw Gemini response into a TradeDecision.
Handles: valid JSON, JSON wrapped in markdown code blocks,
malformed JSON, missing fields, and invalid action values.
On any failure, returns a safe HOLD with confidence 0.
"""
if not raw or not raw.strip():
logger.warning("Empty response from Gemini — defaulting to HOLD")
return TradeDecision(action="HOLD", confidence=0, rationale="Empty response")
# Strip markdown code fences if present
cleaned = raw.strip()
match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", cleaned, re.DOTALL)
if match:
cleaned = match.group(1).strip()
try:
data = json.loads(cleaned)
except json.JSONDecodeError:
logger.warning("Malformed JSON from Gemini — defaulting to HOLD")
return TradeDecision(
action="HOLD", confidence=0, rationale="Malformed JSON response"
)
# Validate required fields
if not all(k in data for k in ("action", "confidence", "rationale")):
logger.warning("Missing fields in Gemini response — defaulting to HOLD")
return TradeDecision(
action="HOLD", confidence=0, rationale="Missing required fields"
)
action = str(data["action"]).upper()
if action not in VALID_ACTIONS:
logger.warning("Invalid action '%s' from Gemini — defaulting to HOLD", action)
return TradeDecision(
action="HOLD", confidence=0, rationale=f"Invalid action: {action}"
)
confidence = int(data["confidence"])
rationale = str(data["rationale"])
# Enforce confidence threshold
if confidence < self._confidence_threshold:
logger.info(
"Confidence %d < threshold %d — forcing HOLD",
confidence,
self._confidence_threshold,
)
action = "HOLD"
return TradeDecision(action=action, confidence=confidence, rationale=rationale)
# ------------------------------------------------------------------
# API Call
# ------------------------------------------------------------------
async def decide(self, market_data: dict[str, Any]) -> TradeDecision:
"""Build prompt, call Gemini, and return a parsed decision."""
prompt = self.build_prompt(market_data)
logger.info("Requesting trade decision from Gemini")
try:
response = await self._model.generate_content_async(prompt)
raw = response.text
except Exception as exc:
logger.error("Gemini API error: %s", exc)
return TradeDecision(
action="HOLD", confidence=0, rationale=f"API error: {exc}"
)
decision = self.parse_response(raw)
logger.info(
"Gemini decision",
extra={
"action": decision.action,
"confidence": decision.confidence,
},
)
return decision

0
src/broker/__init__.py Normal file
View File

245
src/broker/kis_api.py Normal file
View File

@@ -0,0 +1,245 @@
"""Async wrapper for the Korea Investment Securities (KIS) Open API.
Handles token refresh, rate limiting (leaky bucket), and hash key generation.
"""
from __future__ import annotations
import asyncio
import hashlib
import json
import logging
import time
from typing import Any
import aiohttp
from src.config import Settings
logger = logging.getLogger(__name__)
class LeakyBucket:
"""Simple leaky-bucket rate limiter for async code."""
def __init__(self, rate: float) -> None:
"""Args:
rate: Maximum requests per second.
"""
self._rate = rate
self._interval = 1.0 / rate
self._last = 0.0
self._lock = asyncio.Lock()
async def acquire(self) -> None:
async with self._lock:
now = asyncio.get_event_loop().time()
wait = self._last + self._interval - now
if wait > 0:
await asyncio.sleep(wait)
self._last = asyncio.get_event_loop().time()
class KISBroker:
"""Async client for KIS Open API with automatic token management."""
def __init__(self, settings: Settings) -> None:
self._settings = settings
self._base_url = settings.KIS_BASE_URL
self._app_key = settings.KIS_APP_KEY
self._app_secret = settings.KIS_APP_SECRET
self._account_no = settings.account_number
self._product_cd = settings.account_product_code
self._session: aiohttp.ClientSession | None = None
self._access_token: str | None = None
self._token_expires_at: float = 0.0
self._rate_limiter = LeakyBucket(settings.RATE_LIMIT_RPS)
def _get_session(self) -> aiohttp.ClientSession:
if self._session is None or self._session.closed:
timeout = aiohttp.ClientTimeout(total=10)
self._session = aiohttp.ClientSession(timeout=timeout)
return self._session
async def close(self) -> None:
if self._session and not self._session.closed:
await self._session.close()
# ------------------------------------------------------------------
# Token Management
# ------------------------------------------------------------------
async def _ensure_token(self) -> str:
"""Return a valid access token, refreshing if expired."""
now = asyncio.get_event_loop().time()
if self._access_token and now < self._token_expires_at:
return self._access_token
logger.info("Refreshing KIS access token")
session = self._get_session()
url = f"{self._base_url}/oauth2/tokenP"
body = {
"grant_type": "client_credentials",
"appkey": self._app_key,
"appsecret": self._app_secret,
}
async with session.post(url, json=body) as resp:
if resp.status != 200:
text = await resp.text()
raise ConnectionError(f"Token refresh failed ({resp.status}): {text}")
data = await resp.json()
self._access_token = data["access_token"]
self._token_expires_at = now + data.get("expires_in", 86400) - 60 # 1-min buffer
logger.info("Token refreshed successfully")
return self._access_token
# ------------------------------------------------------------------
# Hash Key (required for POST bodies)
# ------------------------------------------------------------------
async def _get_hash_key(self, body: dict[str, Any]) -> str:
"""Request a hash key from KIS for POST request body signing."""
session = self._get_session()
url = f"{self._base_url}/uapi/hashkey"
headers = {
"content-Type": "application/json",
"appKey": self._app_key,
"appSecret": self._app_secret,
}
async with session.post(url, json=body, headers=headers) as resp:
if resp.status != 200:
text = await resp.text()
raise ConnectionError(f"Hash key request failed ({resp.status}): {text}")
data = await resp.json()
return data["HASH"]
# ------------------------------------------------------------------
# Common Headers
# ------------------------------------------------------------------
async def _auth_headers(self, tr_id: str) -> dict[str, str]:
token = await self._ensure_token()
return {
"content-type": "application/json; charset=utf-8",
"authorization": f"Bearer {token}",
"appkey": self._app_key,
"appsecret": self._app_secret,
"tr_id": tr_id,
}
# ------------------------------------------------------------------
# API Methods
# ------------------------------------------------------------------
async def get_orderbook(self, stock_code: str) -> dict[str, Any]:
"""Fetch the current orderbook for a given stock code."""
await self._rate_limiter.acquire()
session = self._get_session()
headers = await self._auth_headers("FHKST01010200")
params = {
"FID_COND_MRKT_DIV_CODE": "J",
"FID_INPUT_ISCD": stock_code,
}
url = f"{self._base_url}/uapi/domestic-stock/v1/quotations/inquire-asking-price-exp-ccn"
try:
async with session.get(url, headers=headers, params=params) as resp:
if resp.status != 200:
text = await resp.text()
raise ConnectionError(
f"get_orderbook failed ({resp.status}): {text}"
)
return await resp.json()
except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
raise ConnectionError(f"Network error fetching orderbook: {exc}") from exc
async def get_balance(self) -> dict[str, Any]:
"""Fetch current account balance and holdings."""
await self._rate_limiter.acquire()
session = self._get_session()
headers = await self._auth_headers("VTTC8434R") # 모의투자 잔고조회
params = {
"CANO": self._account_no,
"ACNT_PRDT_CD": self._product_cd,
"AFHR_FLPR_YN": "N",
"OFL_YN": "",
"INQR_DVSN": "02",
"UNPR_DVSN": "01",
"FUND_STTL_ICLD_YN": "N",
"FNCG_AMT_AUTO_RDPT_YN": "N",
"PRCS_DVSN": "01",
"CTX_AREA_FK100": "",
"CTX_AREA_NK100": "",
}
url = f"{self._base_url}/uapi/domestic-stock/v1/trading/inquire-balance"
try:
async with session.get(url, headers=headers, params=params) as resp:
if resp.status != 200:
text = await resp.text()
raise ConnectionError(
f"get_balance failed ({resp.status}): {text}"
)
return await resp.json()
except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
raise ConnectionError(f"Network error fetching balance: {exc}") from exc
async def send_order(
self,
stock_code: str,
order_type: str, # "BUY" or "SELL"
quantity: int,
price: int = 0,
) -> dict[str, Any]:
"""Submit a buy or sell order.
Args:
stock_code: 6-digit stock code.
order_type: "BUY" or "SELL".
quantity: Number of shares.
price: Order price (0 for market order).
"""
await self._rate_limiter.acquire()
session = self._get_session()
tr_id = "VTTC0802U" if order_type == "BUY" else "VTTC0801U"
body = {
"CANO": self._account_no,
"ACNT_PRDT_CD": self._product_cd,
"PDNO": stock_code,
"ORD_DVSN": "01" if price > 0 else "06", # 01=지정가, 06=시장가
"ORD_QTY": str(quantity),
"ORD_UNPR": str(price),
}
hash_key = await self._get_hash_key(body)
headers = await self._auth_headers(tr_id)
headers["hashkey"] = hash_key
url = f"{self._base_url}/uapi/domestic-stock/v1/trading/order-cash"
try:
async with session.post(url, headers=headers, json=body) as resp:
if resp.status != 200:
text = await resp.text()
raise ConnectionError(
f"send_order failed ({resp.status}): {text}"
)
data = await resp.json()
logger.info(
"Order submitted",
extra={
"stock_code": stock_code,
"action": order_type,
},
)
return data
except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
raise ConnectionError(f"Network error sending order: {exc}") from exc

44
src/config.py Normal file
View File

@@ -0,0 +1,44 @@
"""Strictly typed configuration loaded from environment variables."""
from __future__ import annotations
from pydantic import Field
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
"""Application settings — loaded from .env or environment variables."""
# KIS Open API
KIS_APP_KEY: str
KIS_APP_SECRET: str
KIS_ACCOUNT_NO: str # format: "XXXXXXXX-XX"
KIS_BASE_URL: str = "https://openapivts.koreainvestment.com:9443"
# Google Gemini
GEMINI_API_KEY: str
GEMINI_MODEL: str = "gemini-pro"
# Risk Management
CIRCUIT_BREAKER_PCT: float = Field(default=-3.0, le=0.0)
FAT_FINGER_PCT: float = Field(default=30.0, gt=0.0, le=100.0)
CONFIDENCE_THRESHOLD: int = Field(default=80, ge=0, le=100)
# Database
DB_PATH: str = "data/trade_logs.db"
# Rate Limiting (requests per second for KIS API)
RATE_LIMIT_RPS: float = 10.0
# Trading mode
MODE: str = Field(default="paper", pattern="^(paper|live)$")
model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}
@property
def account_number(self) -> str:
return self.KIS_ACCOUNT_NO.split("-")[0]
@property
def account_product_code(self) -> str:
return self.KIS_ACCOUNT_NO.split("-")[1]

0
src/core/__init__.py Normal file
View File

84
src/core/risk_manager.py Normal file
View File

@@ -0,0 +1,84 @@
"""Risk management — the Shield that protects the portfolio.
This module is READ-ONLY by policy (see docs/agents.md).
Changes require human approval and two passing test suites.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from src.config import Settings
logger = logging.getLogger(__name__)
class CircuitBreakerTripped(SystemExit):
"""Raised when daily P&L loss exceeds the allowed threshold."""
def __init__(self, pnl_pct: float, threshold: float) -> None:
self.pnl_pct = pnl_pct
self.threshold = threshold
super().__init__(
f"CIRCUIT BREAKER: Daily P&L {pnl_pct:.2f}% exceeded "
f"threshold {threshold:.2f}%. All trading halted."
)
class FatFingerRejected(Exception):
"""Raised when an order exceeds the maximum allowed proportion of cash."""
def __init__(self, order_amount: float, total_cash: float, max_pct: float) -> None:
self.order_amount = order_amount
self.total_cash = total_cash
self.max_pct = max_pct
ratio = (order_amount / total_cash * 100) if total_cash > 0 else float("inf")
super().__init__(
f"FAT FINGER: Order {order_amount:,.0f} is {ratio:.1f}% of "
f"cash {total_cash:,.0f} (max allowed: {max_pct:.1f}%)."
)
class RiskManager:
"""Pre-order risk gate that enforces circuit breaker and fat-finger checks."""
def __init__(self, settings: Settings) -> None:
self._cb_threshold = settings.CIRCUIT_BREAKER_PCT
self._ff_max_pct = settings.FAT_FINGER_PCT
def check_circuit_breaker(self, current_pnl_pct: float) -> None:
"""Halt trading if daily loss exceeds the threshold.
The threshold is inclusive: exactly -3.0% is allowed, but -3.01% is not.
"""
if current_pnl_pct < self._cb_threshold:
logger.critical(
"Circuit breaker tripped",
extra={"pnl_pct": current_pnl_pct},
)
raise CircuitBreakerTripped(current_pnl_pct, self._cb_threshold)
def check_fat_finger(self, order_amount: float, total_cash: float) -> None:
"""Reject orders that exceed the maximum proportion of available cash."""
if total_cash <= 0:
raise FatFingerRejected(order_amount, total_cash, self._ff_max_pct)
ratio_pct = (order_amount / total_cash) * 100
if ratio_pct > self._ff_max_pct:
logger.warning(
"Fat finger check failed",
extra={"order_amount": order_amount},
)
raise FatFingerRejected(order_amount, total_cash, self._ff_max_pct)
def validate_order(
self,
current_pnl_pct: float,
order_amount: float,
total_cash: float,
) -> None:
"""Run all pre-order risk checks. Raises on failure."""
self.check_circuit_breaker(current_pnl_pct)
self.check_fat_finger(order_amount, total_cash)
logger.info("Order passed risk validation")

59
src/db.py Normal file
View File

@@ -0,0 +1,59 @@
"""Database layer for trade logging."""
from __future__ import annotations
import sqlite3
from datetime import datetime, timezone
from typing import Any
def init_db(db_path: str) -> sqlite3.Connection:
"""Initialize the trade logs database and return a connection."""
conn = sqlite3.connect(db_path)
conn.execute(
"""
CREATE TABLE IF NOT EXISTS trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp TEXT NOT NULL,
stock_code TEXT NOT NULL,
action TEXT NOT NULL,
confidence INTEGER NOT NULL,
rationale TEXT,
quantity INTEGER,
price REAL,
pnl REAL DEFAULT 0.0
)
"""
)
conn.commit()
return conn
def log_trade(
conn: sqlite3.Connection,
stock_code: str,
action: str,
confidence: int,
rationale: str,
quantity: int = 0,
price: float = 0.0,
pnl: float = 0.0,
) -> None:
"""Insert a trade record into the database."""
conn.execute(
"""
INSERT INTO trades (timestamp, stock_code, action, confidence, rationale, quantity, price, pnl)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
datetime.now(timezone.utc).isoformat(),
stock_code,
action,
confidence,
rationale,
quantity,
price,
pnl,
),
)
conn.commit()

View File

229
src/evolution/optimizer.py Normal file
View File

@@ -0,0 +1,229 @@
"""Evolution Engine — analyzes trade logs and generates new strategies.
This module:
1. Reads trade_logs.db to identify failing patterns
2. Asks Gemini to generate a new strategy class
3. Runs pytest on the generated file
4. Creates a simulated PR if tests pass
"""
from __future__ import annotations
import json
import logging
import sqlite3
import subprocess
import textwrap
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import google.generativeai as genai
from src.config import Settings
logger = logging.getLogger(__name__)
STRATEGIES_DIR = Path("src/strategies")
STRATEGY_TEMPLATE = textwrap.dedent("""\
\"\"\"Auto-generated strategy: {name}
Generated at: {timestamp}
Rationale: {rationale}
\"\"\"
from __future__ import annotations
from typing import Any
from src.strategies.base import BaseStrategy
class {class_name}(BaseStrategy):
\"\"\"Strategy: {name}\"\"\"
def evaluate(self, market_data: dict[str, Any]) -> dict[str, Any]:
{body}
""")
class EvolutionOptimizer:
"""Analyzes trade history and evolves trading strategies."""
def __init__(self, settings: Settings) -> None:
self._settings = settings
self._db_path = settings.DB_PATH
genai.configure(api_key=settings.GEMINI_API_KEY)
self._model = genai.GenerativeModel(settings.GEMINI_MODEL)
# ------------------------------------------------------------------
# Analysis
# ------------------------------------------------------------------
def analyze_failures(self, limit: int = 50) -> list[dict[str, Any]]:
"""Find trades where high confidence led to losses."""
conn = sqlite3.connect(self._db_path)
conn.row_factory = sqlite3.Row
try:
rows = conn.execute(
"""
SELECT stock_code, action, confidence, pnl, rationale, timestamp
FROM trades
WHERE confidence >= 80 AND pnl < 0
ORDER BY pnl ASC
LIMIT ?
""",
(limit,),
).fetchall()
return [dict(r) for r in rows]
finally:
conn.close()
def get_performance_summary(self) -> dict[str, Any]:
"""Return aggregate performance metrics from trade logs."""
conn = sqlite3.connect(self._db_path)
try:
row = conn.execute(
"""
SELECT
COUNT(*) as total_trades,
SUM(CASE WHEN pnl > 0 THEN 1 ELSE 0 END) as wins,
SUM(CASE WHEN pnl < 0 THEN 1 ELSE 0 END) as losses,
COALESCE(AVG(pnl), 0) as avg_pnl,
COALESCE(SUM(pnl), 0) as total_pnl
FROM trades
"""
).fetchone()
return {
"total_trades": row[0],
"wins": row[1] or 0,
"losses": row[2] or 0,
"avg_pnl": round(row[3], 2),
"total_pnl": round(row[4], 2),
}
finally:
conn.close()
# ------------------------------------------------------------------
# Strategy Generation
# ------------------------------------------------------------------
async def generate_strategy(self, failures: list[dict[str, Any]]) -> Path | None:
"""Ask Gemini to generate a new strategy based on failure analysis.
Returns the path to the generated strategy file, or None on failure.
"""
prompt = (
"You are a quantitative trading strategy developer.\n"
"Analyze these failed trades and generate an improved strategy.\n\n"
f"Failed trades:\n{json.dumps(failures, indent=2, default=str)}\n\n"
"Generate a Python class that inherits from BaseStrategy.\n"
"The class must have an `evaluate(self, market_data: dict) -> dict` method.\n"
"The method must return a dict with keys: action, confidence, rationale.\n"
"Respond with ONLY the method body (Python code), no class definition.\n"
)
try:
response = await self._model.generate_content_async(prompt)
body = response.text.strip()
except Exception as exc:
logger.error("Failed to generate strategy: %s", exc)
return None
# Clean up code fences
if body.startswith("```"):
lines = body.split("\n")
body = "\n".join(lines[1:-1])
# Create strategy file
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
version = f"v{timestamp}"
class_name = f"Strategy_{version}"
file_name = f"{version}_evolved.py"
STRATEGIES_DIR.mkdir(parents=True, exist_ok=True)
file_path = STRATEGIES_DIR / file_name
# Indent the body for the class method
indented_body = textwrap.indent(body, " ")
content = STRATEGY_TEMPLATE.format(
name=version,
timestamp=datetime.now(timezone.utc).isoformat(),
rationale="Auto-evolved from failure analysis",
class_name=class_name,
body=indented_body.strip(),
)
file_path.write_text(content)
logger.info("Generated strategy file: %s", file_path)
return file_path
# ------------------------------------------------------------------
# Validation
# ------------------------------------------------------------------
def validate_strategy(self, strategy_path: Path) -> bool:
"""Run pytest on the generated strategy. Returns True if all tests pass."""
logger.info("Validating strategy: %s", strategy_path)
result = subprocess.run(
["python", "-m", "pytest", "tests/", "-v", "--tb=short"],
capture_output=True,
text=True,
timeout=120,
)
if result.returncode == 0:
logger.info("Strategy validation PASSED")
return True
else:
logger.warning(
"Strategy validation FAILED:\n%s", result.stdout + result.stderr
)
# Clean up failing strategy
strategy_path.unlink(missing_ok=True)
return False
# ------------------------------------------------------------------
# PR Simulation
# ------------------------------------------------------------------
def create_pr_simulation(self, strategy_path: Path) -> dict[str, str]:
"""Simulate creating a pull request for the new strategy."""
pr = {
"title": f"[Evolution] New strategy: {strategy_path.stem}",
"branch": f"evolution/{strategy_path.stem}",
"body": (
f"Auto-generated strategy from evolution engine.\n"
f"File: {strategy_path}\n"
f"All tests passed."
),
"status": "ready_for_review",
}
logger.info("PR simulation created: %s", pr["title"])
return pr
# ------------------------------------------------------------------
# Full Pipeline
# ------------------------------------------------------------------
async def evolve(self) -> dict[str, Any] | None:
"""Run the full evolution pipeline.
1. Analyze failures
2. Generate new strategy
3. Validate with tests
4. Create PR simulation
Returns PR info on success, None on failure.
"""
failures = self.analyze_failures()
if not failures:
logger.info("No failure patterns found — skipping evolution")
return None
strategy_path = await self.generate_strategy(failures)
if strategy_path is None:
return None
if not self.validate_strategy(strategy_path):
return None
return self.create_pr_simulation(strategy_path)

42
src/logging_config.py Normal file
View File

@@ -0,0 +1,42 @@
"""JSON-formatted structured logging for machine readability."""
from __future__ import annotations
import logging
import sys
from datetime import datetime, timezone
from typing import Any
import json
class JSONFormatter(logging.Formatter):
"""Emit log records as single-line JSON objects."""
def format(self, record: logging.LogRecord) -> str:
log_entry: dict[str, Any] = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"level": record.levelname,
"logger": record.name,
"message": record.getMessage(),
}
if record.exc_info and record.exc_info[1]:
log_entry["exception"] = self.formatException(record.exc_info)
# Merge any extra fields attached to the record
for key in ("stock_code", "action", "confidence", "pnl_pct", "order_amount"):
value = getattr(record, key, None)
if value is not None:
log_entry[key] = value
return json.dumps(log_entry, ensure_ascii=False)
def setup_logging(level: int = logging.INFO) -> None:
"""Configure the root logger with JSON output to stdout."""
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(JSONFormatter())
root = logging.getLogger()
root.setLevel(level)
# Avoid duplicate handlers on repeated calls
root.handlers.clear()
root.addHandler(handler)

171
src/main.py Normal file
View File

@@ -0,0 +1,171 @@
"""The Ouroboros — main trading loop.
Orchestrates the broker, brain, and risk manager into a continuous
trading cycle with configurable intervals.
"""
from __future__ import annotations
import argparse
import asyncio
import logging
import signal
import sys
from typing import Any
from src.brain.gemini_client import GeminiClient
from src.broker.kis_api import KISBroker
from src.config import Settings
from src.core.risk_manager import CircuitBreakerTripped, RiskManager
from src.db import init_db, log_trade
from src.logging_config import setup_logging
logger = logging.getLogger(__name__)
# Target stock codes to monitor
WATCHLIST = ["005930", "000660", "035420"] # Samsung, SK Hynix, NAVER
TRADE_INTERVAL_SECONDS = 60
async def trading_cycle(
broker: KISBroker,
brain: GeminiClient,
risk: RiskManager,
db_conn: Any,
stock_code: str,
) -> None:
"""Execute one trading cycle for a single stock."""
# 1. Fetch market data
orderbook = await broker.get_orderbook(stock_code)
balance_data = await broker.get_balance()
output2 = balance_data.get("output2", [{}])
total_eval = float(output2[0].get("tot_evlu_amt", "0")) if output2 else 0
total_cash = float(
balance_data.get("output2", [{}])[0].get("dnca_tot_amt", "0")
if output2
else "0"
)
purchase_total = float(output2[0].get("pchs_amt_smtl_amt", "0")) if output2 else 0
# Calculate daily P&L %
pnl_pct = ((total_eval - purchase_total) / purchase_total * 100) if purchase_total > 0 else 0.0
current_price = float(
orderbook.get("output1", {}).get("stck_prpr", "0")
)
market_data = {
"stock_code": stock_code,
"current_price": current_price,
"orderbook": orderbook.get("output1", {}),
"foreigner_net": float(
orderbook.get("output1", {}).get("frgn_ntby_qty", "0")
),
}
# 2. Ask the brain for a decision
decision = await brain.decide(market_data)
logger.info(
"Decision for %s: %s (confidence=%d)",
stock_code,
decision.action,
decision.confidence,
)
# 3. Execute if actionable
if decision.action in ("BUY", "SELL"):
# Determine order size (simplified: 1 lot)
quantity = 1
order_amount = current_price * quantity
# 4. Risk check BEFORE order
risk.validate_order(
current_pnl_pct=pnl_pct,
order_amount=order_amount,
total_cash=total_cash,
)
# 5. Send order
result = await broker.send_order(
stock_code=stock_code,
order_type=decision.action,
quantity=quantity,
price=0, # market order
)
logger.info("Order result: %s", result.get("msg1", "OK"))
# 6. Log trade
log_trade(
conn=db_conn,
stock_code=stock_code,
action=decision.action,
confidence=decision.confidence,
rationale=decision.rationale,
)
async def run(settings: Settings) -> None:
"""Main async loop — iterate over watchlist on a timer."""
broker = KISBroker(settings)
brain = GeminiClient(settings)
risk = RiskManager(settings)
db_conn = init_db(settings.DB_PATH)
shutdown = asyncio.Event()
def _signal_handler() -> None:
logger.info("Shutdown signal received")
shutdown.set()
loop = asyncio.get_running_loop()
for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(sig, _signal_handler)
logger.info("The Ouroboros is alive. Mode: %s", settings.MODE)
logger.info("Watchlist: %s", WATCHLIST)
try:
while not shutdown.is_set():
for code in WATCHLIST:
if shutdown.is_set():
break
try:
await trading_cycle(broker, brain, risk, db_conn, code)
except CircuitBreakerTripped:
logger.critical("Circuit breaker tripped — shutting down")
raise
except ConnectionError as exc:
logger.error("Connection error for %s: %s", code, exc)
except Exception as exc:
logger.exception("Unexpected error for %s: %s", code, exc)
# Wait for next cycle or shutdown
try:
await asyncio.wait_for(shutdown.wait(), timeout=TRADE_INTERVAL_SECONDS)
except asyncio.TimeoutError:
pass # Normal — timeout means it's time for next cycle
finally:
await broker.close()
db_conn.close()
logger.info("The Ouroboros rests.")
def main() -> None:
parser = argparse.ArgumentParser(description="The Ouroboros Trading Agent")
parser.add_argument(
"--mode",
choices=["paper", "live"],
default="paper",
help="Trading mode (default: paper)",
)
args = parser.parse_args()
setup_logging()
settings = Settings(MODE=args.mode) # type: ignore[call-arg]
asyncio.run(run(settings))
if __name__ == "__main__":
main()

View File

19
src/strategies/base.py Normal file
View File

@@ -0,0 +1,19 @@
"""Base class for all trading strategies."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
class BaseStrategy(ABC):
"""All strategies must inherit from this class."""
@abstractmethod
def evaluate(self, market_data: dict[str, Any]) -> dict[str, Any]:
"""Evaluate market data and return a trade decision.
Returns:
dict with keys: action ("BUY"|"SELL"|"HOLD"), confidence (int), rationale (str)
"""
...

0
tests/__init__.py Normal file
View File

23
tests/conftest.py Normal file
View File

@@ -0,0 +1,23 @@
"""Shared test fixtures for The Ouroboros test suite."""
from __future__ import annotations
import pytest
from src.config import Settings
@pytest.fixture
def settings() -> Settings:
"""Return a Settings instance with safe test defaults."""
return Settings(
KIS_APP_KEY="test_app_key",
KIS_APP_SECRET="test_app_secret",
KIS_ACCOUNT_NO="12345678-01",
KIS_BASE_URL="https://openapivts.koreainvestment.com:9443",
GEMINI_API_KEY="test_gemini_key",
CIRCUIT_BREAKER_PCT=-3.0,
FAT_FINGER_PCT=30.0,
CONFIDENCE_THRESHOLD=80,
DB_PATH=":memory:",
)

159
tests/test_brain.py Normal file
View File

@@ -0,0 +1,159 @@
"""TDD tests for brain/gemini_client.py — written BEFORE implementation."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from src.brain.gemini_client import GeminiClient, TradeDecision
# ---------------------------------------------------------------------------
# Response Parsing
# ---------------------------------------------------------------------------
class TestResponseParsing:
"""Gemini responses must be parsed into validated TradeDecision objects."""
def test_valid_buy_response(self, settings):
client = GeminiClient(settings)
raw = '{"action": "BUY", "confidence": 90, "rationale": "Strong momentum"}'
decision = client.parse_response(raw)
assert decision.action == "BUY"
assert decision.confidence == 90
assert decision.rationale == "Strong momentum"
def test_valid_sell_response(self, settings):
client = GeminiClient(settings)
raw = '{"action": "SELL", "confidence": 85, "rationale": "Overbought RSI"}'
decision = client.parse_response(raw)
assert decision.action == "SELL"
def test_valid_hold_response(self, settings):
client = GeminiClient(settings)
raw = '{"action": "HOLD", "confidence": 95, "rationale": "Sideways market"}'
decision = client.parse_response(raw)
assert decision.action == "HOLD"
# ---------------------------------------------------------------------------
# Confidence Threshold Enforcement
# ---------------------------------------------------------------------------
class TestConfidenceThreshold:
"""If confidence < 80, the action MUST be forced to HOLD."""
def test_low_confidence_buy_becomes_hold(self, settings):
client = GeminiClient(settings)
raw = '{"action": "BUY", "confidence": 65, "rationale": "Weak signal"}'
decision = client.parse_response(raw)
assert decision.action == "HOLD"
assert decision.confidence == 65
def test_low_confidence_sell_becomes_hold(self, settings):
client = GeminiClient(settings)
raw = '{"action": "SELL", "confidence": 79, "rationale": "Uncertain"}'
decision = client.parse_response(raw)
assert decision.action == "HOLD"
def test_exactly_threshold_is_allowed(self, settings):
client = GeminiClient(settings)
raw = '{"action": "BUY", "confidence": 80, "rationale": "Just enough"}'
decision = client.parse_response(raw)
assert decision.action == "BUY"
# ---------------------------------------------------------------------------
# Malformed JSON Handling
# ---------------------------------------------------------------------------
class TestMalformedJsonHandling:
"""Gemini may return garbage — the parser must not crash."""
def test_empty_string_returns_hold(self, settings):
client = GeminiClient(settings)
decision = client.parse_response("")
assert decision.action == "HOLD"
assert decision.confidence == 0
def test_plain_text_returns_hold(self, settings):
client = GeminiClient(settings)
decision = client.parse_response("I think you should buy Samsung stock")
assert decision.action == "HOLD"
assert decision.confidence == 0
def test_partial_json_returns_hold(self, settings):
client = GeminiClient(settings)
decision = client.parse_response('{"action": "BUY", "confidence":')
assert decision.action == "HOLD"
assert decision.confidence == 0
def test_json_with_missing_fields_returns_hold(self, settings):
client = GeminiClient(settings)
decision = client.parse_response('{"action": "BUY"}')
assert decision.action == "HOLD"
assert decision.confidence == 0
def test_json_with_invalid_action_returns_hold(self, settings):
client = GeminiClient(settings)
decision = client.parse_response(
'{"action": "YOLO", "confidence": 99, "rationale": "moon"}'
)
assert decision.action == "HOLD"
assert decision.confidence == 0
def test_json_wrapped_in_markdown_code_block(self, settings):
"""Gemini often wraps JSON in ```json ... ``` blocks."""
client = GeminiClient(settings)
raw = '```json\n{"action": "BUY", "confidence": 92, "rationale": "Good"}\n```'
decision = client.parse_response(raw)
assert decision.action == "BUY"
assert decision.confidence == 92
# ---------------------------------------------------------------------------
# Prompt Construction
# ---------------------------------------------------------------------------
class TestPromptConstruction:
"""The prompt sent to Gemini must include all required market data."""
def test_prompt_contains_stock_code(self, settings):
client = GeminiClient(settings)
market_data = {
"stock_code": "005930",
"current_price": 72000,
"orderbook": {"asks": [], "bids": []},
"foreigner_net": -50000,
}
prompt = client.build_prompt(market_data)
assert "005930" in prompt
def test_prompt_contains_price(self, settings):
client = GeminiClient(settings)
market_data = {
"stock_code": "005930",
"current_price": 72000,
"orderbook": {"asks": [], "bids": []},
"foreigner_net": -50000,
}
prompt = client.build_prompt(market_data)
assert "72000" in prompt
def test_prompt_enforces_json_output_format(self, settings):
client = GeminiClient(settings)
market_data = {
"stock_code": "005930",
"current_price": 72000,
"orderbook": {"asks": [], "bids": []},
"foreigner_net": 0,
}
prompt = client.build_prompt(market_data)
assert "JSON" in prompt
assert "action" in prompt
assert "confidence" in prompt

140
tests/test_broker.py Normal file
View File

@@ -0,0 +1,140 @@
"""TDD tests for broker/kis_api.py — written BEFORE implementation."""
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import aiohttp
import pytest
from src.broker.kis_api import KISBroker
# ---------------------------------------------------------------------------
# Token Management
# ---------------------------------------------------------------------------
class TestTokenManagement:
"""Access token must be auto-refreshed and cached."""
@pytest.mark.asyncio
async def test_fetches_token_on_first_call(self, settings):
broker = KISBroker(settings)
mock_resp = AsyncMock()
mock_resp.status = 200
mock_resp.json = AsyncMock(
return_value={
"access_token": "tok_abc123",
"token_type": "Bearer",
"expires_in": 86400,
}
)
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
mock_resp.__aexit__ = AsyncMock(return_value=False)
with patch("aiohttp.ClientSession.post", return_value=mock_resp):
token = await broker._ensure_token()
assert token == "tok_abc123"
await broker.close()
@pytest.mark.asyncio
async def test_reuses_cached_token(self, settings):
broker = KISBroker(settings)
broker._access_token = "cached_token"
broker._token_expires_at = asyncio.get_event_loop().time() + 3600
token = await broker._ensure_token()
assert token == "cached_token"
await broker.close()
# ---------------------------------------------------------------------------
# Network Error Handling
# ---------------------------------------------------------------------------
class TestNetworkErrorHandling:
"""Broker must handle network timeouts and HTTP errors gracefully."""
@pytest.mark.asyncio
async def test_timeout_raises_connection_error(self, settings):
broker = KISBroker(settings)
broker._access_token = "tok"
broker._token_expires_at = asyncio.get_event_loop().time() + 3600
with patch(
"aiohttp.ClientSession.get",
side_effect=asyncio.TimeoutError(),
):
with pytest.raises(ConnectionError):
await broker.get_orderbook("005930")
await broker.close()
@pytest.mark.asyncio
async def test_http_500_raises_connection_error(self, settings):
broker = KISBroker(settings)
broker._access_token = "tok"
broker._token_expires_at = asyncio.get_event_loop().time() + 3600
mock_resp = AsyncMock()
mock_resp.status = 500
mock_resp.text = AsyncMock(return_value="Internal Server Error")
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
mock_resp.__aexit__ = AsyncMock(return_value=False)
with patch("aiohttp.ClientSession.get", return_value=mock_resp):
with pytest.raises(ConnectionError):
await broker.get_orderbook("005930")
await broker.close()
# ---------------------------------------------------------------------------
# Rate Limiter
# ---------------------------------------------------------------------------
class TestRateLimiter:
"""The leaky bucket rate limiter must throttle requests."""
@pytest.mark.asyncio
async def test_rate_limiter_does_not_block_under_limit(self, settings):
broker = KISBroker(settings)
# Should complete without blocking when under limit
await broker._rate_limiter.acquire()
await broker.close()
# ---------------------------------------------------------------------------
# Hash Key Generation
# ---------------------------------------------------------------------------
class TestHashKey:
"""POST requests to KIS require a hash key."""
@pytest.mark.asyncio
async def test_generates_hash_key_for_post_body(self, settings):
broker = KISBroker(settings)
broker._access_token = "tok"
broker._token_expires_at = asyncio.get_event_loop().time() + 3600
body = {"CANO": "12345678", "ACNT_PRDT_CD": "01"}
mock_resp = AsyncMock()
mock_resp.status = 200
mock_resp.json = AsyncMock(return_value={"HASH": "abc123hash"})
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
mock_resp.__aexit__ = AsyncMock(return_value=False)
with patch("aiohttp.ClientSession.post", return_value=mock_resp):
hash_key = await broker._get_hash_key(body)
assert isinstance(hash_key, str)
assert len(hash_key) > 0
await broker.close()

132
tests/test_risk.py Normal file
View File

@@ -0,0 +1,132 @@
"""TDD tests for core/risk_manager.py — written BEFORE implementation."""
from __future__ import annotations
import pytest
from src.core.risk_manager import (
CircuitBreakerTripped,
FatFingerRejected,
RiskManager,
)
# ---------------------------------------------------------------------------
# Circuit Breaker Tests
# ---------------------------------------------------------------------------
class TestCircuitBreaker:
"""The circuit breaker must halt all trading when daily loss exceeds the threshold."""
def test_allows_trading_when_pnl_is_positive(self, settings):
rm = RiskManager(settings)
# 2% gain — should be fine
rm.check_circuit_breaker(current_pnl_pct=2.0)
def test_allows_trading_at_zero_pnl(self, settings):
rm = RiskManager(settings)
rm.check_circuit_breaker(current_pnl_pct=0.0)
def test_allows_trading_at_exactly_threshold(self, settings):
rm = RiskManager(settings)
# Exactly -3.0% is ON the boundary — still allowed
rm.check_circuit_breaker(current_pnl_pct=-3.0)
def test_trips_when_loss_exceeds_threshold(self, settings):
rm = RiskManager(settings)
with pytest.raises(CircuitBreakerTripped):
rm.check_circuit_breaker(current_pnl_pct=-3.01)
def test_trips_at_large_loss(self, settings):
rm = RiskManager(settings)
with pytest.raises(CircuitBreakerTripped):
rm.check_circuit_breaker(current_pnl_pct=-10.0)
def test_custom_threshold(self):
"""A stricter threshold (-1.5%) should trip earlier."""
from src.config import Settings
strict = Settings(
KIS_APP_KEY="k",
KIS_APP_SECRET="s",
KIS_ACCOUNT_NO="00000000-00",
KIS_BASE_URL="https://example.com",
GEMINI_API_KEY="g",
CIRCUIT_BREAKER_PCT=-1.5,
FAT_FINGER_PCT=30.0,
CONFIDENCE_THRESHOLD=80,
DB_PATH=":memory:",
)
rm = RiskManager(strict)
with pytest.raises(CircuitBreakerTripped):
rm.check_circuit_breaker(current_pnl_pct=-1.51)
# ---------------------------------------------------------------------------
# Fat Finger Tests
# ---------------------------------------------------------------------------
class TestFatFingerCheck:
"""Orders exceeding 30% of total cash must be rejected."""
def test_allows_small_order(self, settings):
rm = RiskManager(settings)
# 10% of 10_000_000 = 1_000_000
rm.check_fat_finger(order_amount=1_000_000, total_cash=10_000_000)
def test_allows_order_at_exactly_threshold(self, settings):
rm = RiskManager(settings)
# Exactly 30% — allowed
rm.check_fat_finger(order_amount=3_000_000, total_cash=10_000_000)
def test_rejects_order_exceeding_threshold(self, settings):
rm = RiskManager(settings)
with pytest.raises(FatFingerRejected):
rm.check_fat_finger(order_amount=3_000_001, total_cash=10_000_000)
def test_rejects_massive_order(self, settings):
rm = RiskManager(settings)
with pytest.raises(FatFingerRejected):
rm.check_fat_finger(order_amount=9_000_000, total_cash=10_000_000)
def test_zero_cash_rejects_any_order(self, settings):
rm = RiskManager(settings)
with pytest.raises(FatFingerRejected):
rm.check_fat_finger(order_amount=1, total_cash=0)
# ---------------------------------------------------------------------------
# Pre-Order Validation (Integration of both checks)
# ---------------------------------------------------------------------------
class TestPreOrderValidation:
"""validate_order must run BOTH checks before approving."""
def test_passes_when_both_checks_ok(self, settings):
rm = RiskManager(settings)
rm.validate_order(
current_pnl_pct=0.5,
order_amount=1_000_000,
total_cash=10_000_000,
)
def test_fails_on_circuit_breaker(self, settings):
rm = RiskManager(settings)
with pytest.raises(CircuitBreakerTripped):
rm.validate_order(
current_pnl_pct=-5.0,
order_amount=100,
total_cash=10_000_000,
)
def test_fails_on_fat_finger(self, settings):
rm = RiskManager(settings)
with pytest.raises(FatFingerRejected):
rm.validate_order(
current_pnl_pct=1.0,
order_amount=5_000_000,
total_cash=10_000_000,
)