diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..1317e3c --- /dev/null +++ b/.env.example @@ -0,0 +1,23 @@ +# Korea Investment Securities API +KIS_APP_KEY=your_app_key_here +KIS_APP_SECRET=your_app_secret_here +KIS_ACCOUNT_NO=12345678-01 +KIS_BASE_URL=https://openapivts.koreainvestment.com:9443 + +# Google Gemini +GEMINI_API_KEY=your_gemini_api_key_here +GEMINI_MODEL=gemini-pro + +# Risk Management +CIRCUIT_BREAKER_PCT=-3.0 +FAT_FINGER_PCT=30.0 +CONFIDENCE_THRESHOLD=80 + +# Database +DB_PATH=data/trade_logs.db + +# Rate Limiting +RATE_LIMIT_RPS=10.0 + +# Trading Mode (paper / live) +MODE=paper diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..6fcd55a --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,39 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: pip install ".[dev]" + + - name: Lint + run: ruff check src/ tests/ + + - name: Type check + run: mypy src/ --strict + continue-on-error: true + + - name: Run tests with coverage + run: pytest -v --cov=src --cov-report=term-missing --cov-fail-under=80 + + - name: Upload coverage + if: always() + uses: actions/upload-artifact@v4 + with: + name: coverage-report + path: htmlcov/ diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..f2b0423 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,44 @@ +FROM python:3.11-slim AS base + +WORKDIR /app + +# System deps +RUN apt-get update && apt-get install -y --no-install-recommends \ + gcc \ + && rm -rf /var/lib/apt/lists/* + +# Install Python dependencies +COPY pyproject.toml . +RUN pip install --no-cache-dir ".[dev]" + +# Copy source +COPY src/ src/ +COPY tests/ tests/ +COPY docs/ docs/ + +# Create data directory +RUN mkdir -p data + +# Run tests as build validation +RUN pytest -v --tb=short + +# Production stage +FROM python:3.11-slim AS production + +WORKDIR /app + +COPY pyproject.toml . +RUN pip install --no-cache-dir . + +COPY src/ src/ +RUN mkdir -p data + +# Non-root user +RUN useradd --create-home appuser +USER appuser + +HEALTHCHECK --interval=30s --timeout=5s --retries=3 \ + CMD python -c "import sys; sys.exit(0)" + +ENTRYPOINT ["python", "-m", "src.main"] +CMD ["--mode=paper"] diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..fca5c9b --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,30 @@ +services: + ouroboros: + build: + context: . + target: production + container_name: ouroboros + restart: unless-stopped + env_file: + - .env + volumes: + - trade_data:/app/data + command: ["--mode=paper"] + logging: + driver: json-file + options: + max-size: "10m" + max-file: "3" + + # Run tests (one-shot) + test: + build: + context: . + target: base + container_name: ouroboros-test + command: ["pytest", "-v", "--cov=src", "--cov-report=term-missing"] + profiles: + - test + +volumes: + trade_data: diff --git a/docs/agents.md b/docs/agents.md new file mode 100644 index 0000000..8184711 --- /dev/null +++ b/docs/agents.md @@ -0,0 +1,58 @@ +# The Ouroboros - Agent Persona Definition + +## Role: The Guardian + +You are **The Guardian**, the primary AI agent responsible for maintaining and evolving +The Ouroboros trading system. Your mandate is to ensure system integrity, safety, and +continuous improvement. + +## Prime Directives + +1. **NEVER disable, bypass, or weaken `core/risk_manager.py`.** + The risk manager is the last line of defense against catastrophic loss. + Any code change that reduces risk controls MUST be rejected. + +2. **All code changes require a passing test.** + No module may be modified or created without a corresponding test in `tests/`. + Run `pytest -v --cov=src` before proposing any merge. + +3. **Preserve the Circuit Breaker.** + The daily P&L circuit breaker (-3.0% threshold) is non-negotiable. + It may only be made *stricter*, never relaxed. + +4. **Fat Finger Protection is sacred.** + The 30% max-order-size rule must remain enforced at all times. + +## Decision Framework + +When modifying code, follow this priority order: + +1. **Safety** - Will this change increase risk exposure? If yes, reject. +2. **Correctness** - Is the logic provably correct? Verify with tests. +3. **Performance** - Only optimize after safety and correctness are guaranteed. +4. **Readability** - Code must be understandable by future agents and humans. + +## File Ownership + +| Module | Guardian Rule | +|---|---| +| `core/risk_manager.py` | READ-ONLY. Changes require human approval + 2 passing test suites. | +| `broker/kis_api.py` | Rate limiter must never be removed. Token refresh must remain automatic. | +| `brain/gemini_client.py` | Confidence < 80 MUST force HOLD. This rule cannot be weakened. | +| `evolution/optimizer.py` | Generated strategies must pass ALL tests before activation. | +| `strategies/*` | New strategies are welcome but must inherit `BaseStrategy`. | + +## Prohibited Actions + +- Removing or commenting out `assert` statements in tests +- Hardcoding API keys or secrets in source files +- Disabling rate limiting on broker API calls +- Allowing orders when the circuit breaker has tripped +- Merging code with test coverage below 80% + +## Context for Collaboration + +When working with other AI agents (Cursor, Cline, etc.): +- Share this document as the system constitution +- All agents must acknowledge these rules before making changes +- Conflicts are resolved by defaulting to the *safer* option diff --git a/docs/skills.md b/docs/skills.md new file mode 100644 index 0000000..dc75ba4 --- /dev/null +++ b/docs/skills.md @@ -0,0 +1,109 @@ +# The Ouroboros - Available Skills & Tools + +## Development Tools + +### Run Tests +```bash +pytest -v --cov=src --cov-report=term-missing +``` +Run the full test suite with coverage reporting. **Must pass before any merge.** + +### Run Specific Test Module +```bash +pytest tests/test_risk.py -v +pytest tests/test_broker.py -v +pytest tests/test_brain.py -v +``` + +### Type Checking +```bash +python -m mypy src/ --strict +``` + +### Linting +```bash +ruff check src/ tests/ +ruff format src/ tests/ +``` + +## Operational Tools + +### Start Trading Agent (Development) +```bash +python -m src.main --mode=paper +``` +Runs the agent in paper-trading mode (no real orders). + +### Start Trading Agent (Production) +```bash +docker compose up -d ouroboros +``` +Runs the full system via Docker Compose with all safety checks enabled. + +### View Logs +```bash +docker compose logs -f ouroboros +``` +Stream JSON-formatted structured logs. + +### Run Backtester +```bash +python -m src.evolution.optimizer --backtest --days=30 +``` +Analyze the last 30 days of trade logs and generate performance metrics. + +## Evolution Tools + +### Generate New Strategy +```bash +python -m src.evolution.optimizer --evolve +``` +Triggers the evolution engine to: +1. Analyze `trade_logs.db` for failing patterns +2. Ask Gemini to generate a new strategy +3. Run tests on the new strategy +4. Create a PR if tests pass + +### Validate Strategy +```bash +pytest tests/ -k "strategy" -v +``` +Run only strategy-related tests to validate a new strategy file. + +## Deployment Tools + +### Build Docker Image +```bash +docker build -t ouroboros:latest . +``` + +### Deploy with Docker Compose +```bash +docker compose up -d +``` + +### Health Check +```bash +curl http://localhost:8080/health +``` + +## Database Tools + +### View Trade Logs +```bash +sqlite3 data/trade_logs.db "SELECT * FROM trades ORDER BY timestamp DESC LIMIT 20;" +``` + +### Export Trade History +```bash +sqlite3 -header -csv data/trade_logs.db "SELECT * FROM trades;" > trades_export.csv +``` + +## Safety Checklist (Pre-Deploy) + +- [ ] `pytest -v --cov=src` passes with >= 80% coverage +- [ ] `ruff check src/ tests/` reports no errors +- [ ] `.env` file contains valid KIS and Gemini API keys +- [ ] Circuit breaker threshold is set to -3.0% or stricter +- [ ] Rate limiter is configured for KIS API limits +- [ ] Docker health check endpoint responds 200 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..a1d141a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,40 @@ +[project] +name = "the-ouroboros" +version = "0.1.0" +description = "Evolutionary AI Trading Agent for KIS" +requires-python = ">=3.11" +dependencies = [ + "aiohttp>=3.9,<4", + "pydantic>=2.5,<3", + "pydantic-settings>=2.1,<3", + "google-generativeai>=0.8,<1", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0,<9", + "pytest-asyncio>=0.23,<1", + "pytest-cov>=5.0,<6", + "ruff>=0.5,<1", + "mypy>=1.10,<2", +] + +[project.scripts] +ouroboros = "src.main:main" + +[tool.pytest.ini_options] +testpaths = ["tests"] +asyncio_mode = "auto" + +[tool.ruff] +target-version = "py311" +line-length = 100 + +[tool.ruff.lint] +select = ["E", "F", "I", "N", "W", "UP"] + +[tool.mypy] +python_version = "3.11" +strict = true +warn_return_any = true +warn_unused_configs = true diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/brain/__init__.py b/src/brain/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/brain/gemini_client.py b/src/brain/gemini_client.py new file mode 100644 index 0000000..9004545 --- /dev/null +++ b/src/brain/gemini_client.py @@ -0,0 +1,152 @@ +"""Decision engine powered by Google Gemini. + +Constructs prompts from market data, calls Gemini, and parses structured +JSON responses into validated TradeDecision objects. +""" + +from __future__ import annotations + +import json +import logging +import re +from dataclasses import dataclass +from typing import Any + +import google.generativeai as genai + +from src.config import Settings + +logger = logging.getLogger(__name__) + +VALID_ACTIONS = {"BUY", "SELL", "HOLD"} + + +@dataclass(frozen=True) +class TradeDecision: + """Validated decision from the AI brain.""" + + action: str # "BUY" | "SELL" | "HOLD" + confidence: int # 0-100 + rationale: str + + +class GeminiClient: + """Wraps the Gemini API for trade decision-making.""" + + def __init__(self, settings: Settings) -> None: + self._settings = settings + self._confidence_threshold = settings.CONFIDENCE_THRESHOLD + genai.configure(api_key=settings.GEMINI_API_KEY) + self._model = genai.GenerativeModel(settings.GEMINI_MODEL) + + # ------------------------------------------------------------------ + # Prompt Construction + # ------------------------------------------------------------------ + + def build_prompt(self, market_data: dict[str, Any]) -> str: + """Build a structured prompt from market data. + + The prompt instructs Gemini to return valid JSON with action, + confidence, and rationale fields. + """ + return ( + "You are a professional Korean stock market trading analyst.\n" + "Analyze the following market data and decide whether to BUY, SELL, or HOLD.\n\n" + f"Stock Code: {market_data['stock_code']}\n" + f"Current Price: {market_data['current_price']}\n" + f"Orderbook: {json.dumps(market_data['orderbook'], ensure_ascii=False)}\n" + f"Foreigner Net Buy/Sell: {market_data['foreigner_net']}\n\n" + "You MUST respond with ONLY valid JSON in the following format:\n" + '{"action": "BUY"|"SELL"|"HOLD", "confidence": , "rationale": ""}\n\n' + "Rules:\n" + "- action must be exactly one of: BUY, SELL, HOLD\n" + "- confidence must be an integer from 0 to 100\n" + "- rationale must explain your reasoning concisely\n" + "- Do NOT wrap the JSON in markdown code blocks\n" + ) + + # ------------------------------------------------------------------ + # Response Parsing + # ------------------------------------------------------------------ + + def parse_response(self, raw: str) -> TradeDecision: + """Parse a raw Gemini response into a TradeDecision. + + Handles: valid JSON, JSON wrapped in markdown code blocks, + malformed JSON, missing fields, and invalid action values. + + On any failure, returns a safe HOLD with confidence 0. + """ + if not raw or not raw.strip(): + logger.warning("Empty response from Gemini — defaulting to HOLD") + return TradeDecision(action="HOLD", confidence=0, rationale="Empty response") + + # Strip markdown code fences if present + cleaned = raw.strip() + match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", cleaned, re.DOTALL) + if match: + cleaned = match.group(1).strip() + + try: + data = json.loads(cleaned) + except json.JSONDecodeError: + logger.warning("Malformed JSON from Gemini — defaulting to HOLD") + return TradeDecision( + action="HOLD", confidence=0, rationale="Malformed JSON response" + ) + + # Validate required fields + if not all(k in data for k in ("action", "confidence", "rationale")): + logger.warning("Missing fields in Gemini response — defaulting to HOLD") + return TradeDecision( + action="HOLD", confidence=0, rationale="Missing required fields" + ) + + action = str(data["action"]).upper() + if action not in VALID_ACTIONS: + logger.warning("Invalid action '%s' from Gemini — defaulting to HOLD", action) + return TradeDecision( + action="HOLD", confidence=0, rationale=f"Invalid action: {action}" + ) + + confidence = int(data["confidence"]) + rationale = str(data["rationale"]) + + # Enforce confidence threshold + if confidence < self._confidence_threshold: + logger.info( + "Confidence %d < threshold %d — forcing HOLD", + confidence, + self._confidence_threshold, + ) + action = "HOLD" + + return TradeDecision(action=action, confidence=confidence, rationale=rationale) + + # ------------------------------------------------------------------ + # API Call + # ------------------------------------------------------------------ + + async def decide(self, market_data: dict[str, Any]) -> TradeDecision: + """Build prompt, call Gemini, and return a parsed decision.""" + prompt = self.build_prompt(market_data) + logger.info("Requesting trade decision from Gemini") + + try: + response = await self._model.generate_content_async(prompt) + raw = response.text + except Exception as exc: + logger.error("Gemini API error: %s", exc) + return TradeDecision( + action="HOLD", confidence=0, rationale=f"API error: {exc}" + ) + + decision = self.parse_response(raw) + logger.info( + "Gemini decision", + extra={ + "action": decision.action, + "confidence": decision.confidence, + }, + ) + return decision diff --git a/src/broker/__init__.py b/src/broker/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/broker/kis_api.py b/src/broker/kis_api.py new file mode 100644 index 0000000..104df5f --- /dev/null +++ b/src/broker/kis_api.py @@ -0,0 +1,245 @@ +"""Async wrapper for the Korea Investment Securities (KIS) Open API. + +Handles token refresh, rate limiting (leaky bucket), and hash key generation. +""" + +from __future__ import annotations + +import asyncio +import hashlib +import json +import logging +import time +from typing import Any + +import aiohttp + +from src.config import Settings + +logger = logging.getLogger(__name__) + + +class LeakyBucket: + """Simple leaky-bucket rate limiter for async code.""" + + def __init__(self, rate: float) -> None: + """Args: + rate: Maximum requests per second. + """ + self._rate = rate + self._interval = 1.0 / rate + self._last = 0.0 + self._lock = asyncio.Lock() + + async def acquire(self) -> None: + async with self._lock: + now = asyncio.get_event_loop().time() + wait = self._last + self._interval - now + if wait > 0: + await asyncio.sleep(wait) + self._last = asyncio.get_event_loop().time() + + +class KISBroker: + """Async client for KIS Open API with automatic token management.""" + + def __init__(self, settings: Settings) -> None: + self._settings = settings + self._base_url = settings.KIS_BASE_URL + self._app_key = settings.KIS_APP_KEY + self._app_secret = settings.KIS_APP_SECRET + self._account_no = settings.account_number + self._product_cd = settings.account_product_code + + self._session: aiohttp.ClientSession | None = None + self._access_token: str | None = None + self._token_expires_at: float = 0.0 + self._rate_limiter = LeakyBucket(settings.RATE_LIMIT_RPS) + + def _get_session(self) -> aiohttp.ClientSession: + if self._session is None or self._session.closed: + timeout = aiohttp.ClientTimeout(total=10) + self._session = aiohttp.ClientSession(timeout=timeout) + return self._session + + async def close(self) -> None: + if self._session and not self._session.closed: + await self._session.close() + + # ------------------------------------------------------------------ + # Token Management + # ------------------------------------------------------------------ + + async def _ensure_token(self) -> str: + """Return a valid access token, refreshing if expired.""" + now = asyncio.get_event_loop().time() + if self._access_token and now < self._token_expires_at: + return self._access_token + + logger.info("Refreshing KIS access token") + session = self._get_session() + url = f"{self._base_url}/oauth2/tokenP" + body = { + "grant_type": "client_credentials", + "appkey": self._app_key, + "appsecret": self._app_secret, + } + + async with session.post(url, json=body) as resp: + if resp.status != 200: + text = await resp.text() + raise ConnectionError(f"Token refresh failed ({resp.status}): {text}") + data = await resp.json() + + self._access_token = data["access_token"] + self._token_expires_at = now + data.get("expires_in", 86400) - 60 # 1-min buffer + logger.info("Token refreshed successfully") + return self._access_token + + # ------------------------------------------------------------------ + # Hash Key (required for POST bodies) + # ------------------------------------------------------------------ + + async def _get_hash_key(self, body: dict[str, Any]) -> str: + """Request a hash key from KIS for POST request body signing.""" + session = self._get_session() + url = f"{self._base_url}/uapi/hashkey" + headers = { + "content-Type": "application/json", + "appKey": self._app_key, + "appSecret": self._app_secret, + } + + async with session.post(url, json=body, headers=headers) as resp: + if resp.status != 200: + text = await resp.text() + raise ConnectionError(f"Hash key request failed ({resp.status}): {text}") + data = await resp.json() + + return data["HASH"] + + # ------------------------------------------------------------------ + # Common Headers + # ------------------------------------------------------------------ + + async def _auth_headers(self, tr_id: str) -> dict[str, str]: + token = await self._ensure_token() + return { + "content-type": "application/json; charset=utf-8", + "authorization": f"Bearer {token}", + "appkey": self._app_key, + "appsecret": self._app_secret, + "tr_id": tr_id, + } + + # ------------------------------------------------------------------ + # API Methods + # ------------------------------------------------------------------ + + async def get_orderbook(self, stock_code: str) -> dict[str, Any]: + """Fetch the current orderbook for a given stock code.""" + await self._rate_limiter.acquire() + session = self._get_session() + + headers = await self._auth_headers("FHKST01010200") + params = { + "FID_COND_MRKT_DIV_CODE": "J", + "FID_INPUT_ISCD": stock_code, + } + url = f"{self._base_url}/uapi/domestic-stock/v1/quotations/inquire-asking-price-exp-ccn" + + try: + async with session.get(url, headers=headers, params=params) as resp: + if resp.status != 200: + text = await resp.text() + raise ConnectionError( + f"get_orderbook failed ({resp.status}): {text}" + ) + return await resp.json() + except (aiohttp.ClientError, asyncio.TimeoutError) as exc: + raise ConnectionError(f"Network error fetching orderbook: {exc}") from exc + + async def get_balance(self) -> dict[str, Any]: + """Fetch current account balance and holdings.""" + await self._rate_limiter.acquire() + session = self._get_session() + + headers = await self._auth_headers("VTTC8434R") # 모의투자 잔고조회 + params = { + "CANO": self._account_no, + "ACNT_PRDT_CD": self._product_cd, + "AFHR_FLPR_YN": "N", + "OFL_YN": "", + "INQR_DVSN": "02", + "UNPR_DVSN": "01", + "FUND_STTL_ICLD_YN": "N", + "FNCG_AMT_AUTO_RDPT_YN": "N", + "PRCS_DVSN": "01", + "CTX_AREA_FK100": "", + "CTX_AREA_NK100": "", + } + url = f"{self._base_url}/uapi/domestic-stock/v1/trading/inquire-balance" + + try: + async with session.get(url, headers=headers, params=params) as resp: + if resp.status != 200: + text = await resp.text() + raise ConnectionError( + f"get_balance failed ({resp.status}): {text}" + ) + return await resp.json() + except (aiohttp.ClientError, asyncio.TimeoutError) as exc: + raise ConnectionError(f"Network error fetching balance: {exc}") from exc + + async def send_order( + self, + stock_code: str, + order_type: str, # "BUY" or "SELL" + quantity: int, + price: int = 0, + ) -> dict[str, Any]: + """Submit a buy or sell order. + + Args: + stock_code: 6-digit stock code. + order_type: "BUY" or "SELL". + quantity: Number of shares. + price: Order price (0 for market order). + """ + await self._rate_limiter.acquire() + session = self._get_session() + + tr_id = "VTTC0802U" if order_type == "BUY" else "VTTC0801U" + body = { + "CANO": self._account_no, + "ACNT_PRDT_CD": self._product_cd, + "PDNO": stock_code, + "ORD_DVSN": "01" if price > 0 else "06", # 01=지정가, 06=시장가 + "ORD_QTY": str(quantity), + "ORD_UNPR": str(price), + } + + hash_key = await self._get_hash_key(body) + headers = await self._auth_headers(tr_id) + headers["hashkey"] = hash_key + + url = f"{self._base_url}/uapi/domestic-stock/v1/trading/order-cash" + + try: + async with session.post(url, headers=headers, json=body) as resp: + if resp.status != 200: + text = await resp.text() + raise ConnectionError( + f"send_order failed ({resp.status}): {text}" + ) + data = await resp.json() + logger.info( + "Order submitted", + extra={ + "stock_code": stock_code, + "action": order_type, + }, + ) + return data + except (aiohttp.ClientError, asyncio.TimeoutError) as exc: + raise ConnectionError(f"Network error sending order: {exc}") from exc diff --git a/src/config.py b/src/config.py new file mode 100644 index 0000000..d9bd569 --- /dev/null +++ b/src/config.py @@ -0,0 +1,44 @@ +"""Strictly typed configuration loaded from environment variables.""" + +from __future__ import annotations + +from pydantic import Field +from pydantic_settings import BaseSettings + + +class Settings(BaseSettings): + """Application settings — loaded from .env or environment variables.""" + + # KIS Open API + KIS_APP_KEY: str + KIS_APP_SECRET: str + KIS_ACCOUNT_NO: str # format: "XXXXXXXX-XX" + KIS_BASE_URL: str = "https://openapivts.koreainvestment.com:9443" + + # Google Gemini + GEMINI_API_KEY: str + GEMINI_MODEL: str = "gemini-pro" + + # Risk Management + CIRCUIT_BREAKER_PCT: float = Field(default=-3.0, le=0.0) + FAT_FINGER_PCT: float = Field(default=30.0, gt=0.0, le=100.0) + CONFIDENCE_THRESHOLD: int = Field(default=80, ge=0, le=100) + + # Database + DB_PATH: str = "data/trade_logs.db" + + # Rate Limiting (requests per second for KIS API) + RATE_LIMIT_RPS: float = 10.0 + + # Trading mode + MODE: str = Field(default="paper", pattern="^(paper|live)$") + + model_config = {"env_file": ".env", "env_file_encoding": "utf-8"} + + @property + def account_number(self) -> str: + return self.KIS_ACCOUNT_NO.split("-")[0] + + @property + def account_product_code(self) -> str: + return self.KIS_ACCOUNT_NO.split("-")[1] diff --git a/src/core/__init__.py b/src/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/core/risk_manager.py b/src/core/risk_manager.py new file mode 100644 index 0000000..22d755b --- /dev/null +++ b/src/core/risk_manager.py @@ -0,0 +1,84 @@ +"""Risk management — the Shield that protects the portfolio. + +This module is READ-ONLY by policy (see docs/agents.md). +Changes require human approval and two passing test suites. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass + +from src.config import Settings + +logger = logging.getLogger(__name__) + + +class CircuitBreakerTripped(SystemExit): + """Raised when daily P&L loss exceeds the allowed threshold.""" + + def __init__(self, pnl_pct: float, threshold: float) -> None: + self.pnl_pct = pnl_pct + self.threshold = threshold + super().__init__( + f"CIRCUIT BREAKER: Daily P&L {pnl_pct:.2f}% exceeded " + f"threshold {threshold:.2f}%. All trading halted." + ) + + +class FatFingerRejected(Exception): + """Raised when an order exceeds the maximum allowed proportion of cash.""" + + def __init__(self, order_amount: float, total_cash: float, max_pct: float) -> None: + self.order_amount = order_amount + self.total_cash = total_cash + self.max_pct = max_pct + ratio = (order_amount / total_cash * 100) if total_cash > 0 else float("inf") + super().__init__( + f"FAT FINGER: Order {order_amount:,.0f} is {ratio:.1f}% of " + f"cash {total_cash:,.0f} (max allowed: {max_pct:.1f}%)." + ) + + +class RiskManager: + """Pre-order risk gate that enforces circuit breaker and fat-finger checks.""" + + def __init__(self, settings: Settings) -> None: + self._cb_threshold = settings.CIRCUIT_BREAKER_PCT + self._ff_max_pct = settings.FAT_FINGER_PCT + + def check_circuit_breaker(self, current_pnl_pct: float) -> None: + """Halt trading if daily loss exceeds the threshold. + + The threshold is inclusive: exactly -3.0% is allowed, but -3.01% is not. + """ + if current_pnl_pct < self._cb_threshold: + logger.critical( + "Circuit breaker tripped", + extra={"pnl_pct": current_pnl_pct}, + ) + raise CircuitBreakerTripped(current_pnl_pct, self._cb_threshold) + + def check_fat_finger(self, order_amount: float, total_cash: float) -> None: + """Reject orders that exceed the maximum proportion of available cash.""" + if total_cash <= 0: + raise FatFingerRejected(order_amount, total_cash, self._ff_max_pct) + + ratio_pct = (order_amount / total_cash) * 100 + if ratio_pct > self._ff_max_pct: + logger.warning( + "Fat finger check failed", + extra={"order_amount": order_amount}, + ) + raise FatFingerRejected(order_amount, total_cash, self._ff_max_pct) + + def validate_order( + self, + current_pnl_pct: float, + order_amount: float, + total_cash: float, + ) -> None: + """Run all pre-order risk checks. Raises on failure.""" + self.check_circuit_breaker(current_pnl_pct) + self.check_fat_finger(order_amount, total_cash) + logger.info("Order passed risk validation") diff --git a/src/db.py b/src/db.py new file mode 100644 index 0000000..16f1e0e --- /dev/null +++ b/src/db.py @@ -0,0 +1,59 @@ +"""Database layer for trade logging.""" + +from __future__ import annotations + +import sqlite3 +from datetime import datetime, timezone +from typing import Any + + +def init_db(db_path: str) -> sqlite3.Connection: + """Initialize the trade logs database and return a connection.""" + conn = sqlite3.connect(db_path) + conn.execute( + """ + CREATE TABLE IF NOT EXISTS trades ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL, + stock_code TEXT NOT NULL, + action TEXT NOT NULL, + confidence INTEGER NOT NULL, + rationale TEXT, + quantity INTEGER, + price REAL, + pnl REAL DEFAULT 0.0 + ) + """ + ) + conn.commit() + return conn + + +def log_trade( + conn: sqlite3.Connection, + stock_code: str, + action: str, + confidence: int, + rationale: str, + quantity: int = 0, + price: float = 0.0, + pnl: float = 0.0, +) -> None: + """Insert a trade record into the database.""" + conn.execute( + """ + INSERT INTO trades (timestamp, stock_code, action, confidence, rationale, quantity, price, pnl) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + datetime.now(timezone.utc).isoformat(), + stock_code, + action, + confidence, + rationale, + quantity, + price, + pnl, + ), + ) + conn.commit() diff --git a/src/evolution/__init__.py b/src/evolution/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/evolution/optimizer.py b/src/evolution/optimizer.py new file mode 100644 index 0000000..4947b76 --- /dev/null +++ b/src/evolution/optimizer.py @@ -0,0 +1,229 @@ +"""Evolution Engine — analyzes trade logs and generates new strategies. + +This module: +1. Reads trade_logs.db to identify failing patterns +2. Asks Gemini to generate a new strategy class +3. Runs pytest on the generated file +4. Creates a simulated PR if tests pass +""" + +from __future__ import annotations + +import json +import logging +import sqlite3 +import subprocess +import textwrap +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import google.generativeai as genai + +from src.config import Settings + +logger = logging.getLogger(__name__) + +STRATEGIES_DIR = Path("src/strategies") +STRATEGY_TEMPLATE = textwrap.dedent("""\ + \"\"\"Auto-generated strategy: {name} + + Generated at: {timestamp} + Rationale: {rationale} + \"\"\" + + from __future__ import annotations + from typing import Any + from src.strategies.base import BaseStrategy + + + class {class_name}(BaseStrategy): + \"\"\"Strategy: {name}\"\"\" + + def evaluate(self, market_data: dict[str, Any]) -> dict[str, Any]: + {body} +""") + + +class EvolutionOptimizer: + """Analyzes trade history and evolves trading strategies.""" + + def __init__(self, settings: Settings) -> None: + self._settings = settings + self._db_path = settings.DB_PATH + genai.configure(api_key=settings.GEMINI_API_KEY) + self._model = genai.GenerativeModel(settings.GEMINI_MODEL) + + # ------------------------------------------------------------------ + # Analysis + # ------------------------------------------------------------------ + + def analyze_failures(self, limit: int = 50) -> list[dict[str, Any]]: + """Find trades where high confidence led to losses.""" + conn = sqlite3.connect(self._db_path) + conn.row_factory = sqlite3.Row + try: + rows = conn.execute( + """ + SELECT stock_code, action, confidence, pnl, rationale, timestamp + FROM trades + WHERE confidence >= 80 AND pnl < 0 + ORDER BY pnl ASC + LIMIT ? + """, + (limit,), + ).fetchall() + return [dict(r) for r in rows] + finally: + conn.close() + + def get_performance_summary(self) -> dict[str, Any]: + """Return aggregate performance metrics from trade logs.""" + conn = sqlite3.connect(self._db_path) + try: + row = conn.execute( + """ + SELECT + COUNT(*) as total_trades, + SUM(CASE WHEN pnl > 0 THEN 1 ELSE 0 END) as wins, + SUM(CASE WHEN pnl < 0 THEN 1 ELSE 0 END) as losses, + COALESCE(AVG(pnl), 0) as avg_pnl, + COALESCE(SUM(pnl), 0) as total_pnl + FROM trades + """ + ).fetchone() + return { + "total_trades": row[0], + "wins": row[1] or 0, + "losses": row[2] or 0, + "avg_pnl": round(row[3], 2), + "total_pnl": round(row[4], 2), + } + finally: + conn.close() + + # ------------------------------------------------------------------ + # Strategy Generation + # ------------------------------------------------------------------ + + async def generate_strategy(self, failures: list[dict[str, Any]]) -> Path | None: + """Ask Gemini to generate a new strategy based on failure analysis. + + Returns the path to the generated strategy file, or None on failure. + """ + prompt = ( + "You are a quantitative trading strategy developer.\n" + "Analyze these failed trades and generate an improved strategy.\n\n" + f"Failed trades:\n{json.dumps(failures, indent=2, default=str)}\n\n" + "Generate a Python class that inherits from BaseStrategy.\n" + "The class must have an `evaluate(self, market_data: dict) -> dict` method.\n" + "The method must return a dict with keys: action, confidence, rationale.\n" + "Respond with ONLY the method body (Python code), no class definition.\n" + ) + + try: + response = await self._model.generate_content_async(prompt) + body = response.text.strip() + except Exception as exc: + logger.error("Failed to generate strategy: %s", exc) + return None + + # Clean up code fences + if body.startswith("```"): + lines = body.split("\n") + body = "\n".join(lines[1:-1]) + + # Create strategy file + timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") + version = f"v{timestamp}" + class_name = f"Strategy_{version}" + file_name = f"{version}_evolved.py" + + STRATEGIES_DIR.mkdir(parents=True, exist_ok=True) + file_path = STRATEGIES_DIR / file_name + + # Indent the body for the class method + indented_body = textwrap.indent(body, " ") + + content = STRATEGY_TEMPLATE.format( + name=version, + timestamp=datetime.now(timezone.utc).isoformat(), + rationale="Auto-evolved from failure analysis", + class_name=class_name, + body=indented_body.strip(), + ) + + file_path.write_text(content) + logger.info("Generated strategy file: %s", file_path) + return file_path + + # ------------------------------------------------------------------ + # Validation + # ------------------------------------------------------------------ + + def validate_strategy(self, strategy_path: Path) -> bool: + """Run pytest on the generated strategy. Returns True if all tests pass.""" + logger.info("Validating strategy: %s", strategy_path) + result = subprocess.run( + ["python", "-m", "pytest", "tests/", "-v", "--tb=short"], + capture_output=True, + text=True, + timeout=120, + ) + if result.returncode == 0: + logger.info("Strategy validation PASSED") + return True + else: + logger.warning( + "Strategy validation FAILED:\n%s", result.stdout + result.stderr + ) + # Clean up failing strategy + strategy_path.unlink(missing_ok=True) + return False + + # ------------------------------------------------------------------ + # PR Simulation + # ------------------------------------------------------------------ + + def create_pr_simulation(self, strategy_path: Path) -> dict[str, str]: + """Simulate creating a pull request for the new strategy.""" + pr = { + "title": f"[Evolution] New strategy: {strategy_path.stem}", + "branch": f"evolution/{strategy_path.stem}", + "body": ( + f"Auto-generated strategy from evolution engine.\n" + f"File: {strategy_path}\n" + f"All tests passed." + ), + "status": "ready_for_review", + } + logger.info("PR simulation created: %s", pr["title"]) + return pr + + # ------------------------------------------------------------------ + # Full Pipeline + # ------------------------------------------------------------------ + + async def evolve(self) -> dict[str, Any] | None: + """Run the full evolution pipeline. + + 1. Analyze failures + 2. Generate new strategy + 3. Validate with tests + 4. Create PR simulation + + Returns PR info on success, None on failure. + """ + failures = self.analyze_failures() + if not failures: + logger.info("No failure patterns found — skipping evolution") + return None + + strategy_path = await self.generate_strategy(failures) + if strategy_path is None: + return None + + if not self.validate_strategy(strategy_path): + return None + + return self.create_pr_simulation(strategy_path) diff --git a/src/logging_config.py b/src/logging_config.py new file mode 100644 index 0000000..54fd50a --- /dev/null +++ b/src/logging_config.py @@ -0,0 +1,42 @@ +"""JSON-formatted structured logging for machine readability.""" + +from __future__ import annotations + +import logging +import sys +from datetime import datetime, timezone +from typing import Any + +import json + + +class JSONFormatter(logging.Formatter): + """Emit log records as single-line JSON objects.""" + + def format(self, record: logging.LogRecord) -> str: + log_entry: dict[str, Any] = { + "timestamp": datetime.now(timezone.utc).isoformat(), + "level": record.levelname, + "logger": record.name, + "message": record.getMessage(), + } + if record.exc_info and record.exc_info[1]: + log_entry["exception"] = self.formatException(record.exc_info) + # Merge any extra fields attached to the record + for key in ("stock_code", "action", "confidence", "pnl_pct", "order_amount"): + value = getattr(record, key, None) + if value is not None: + log_entry[key] = value + return json.dumps(log_entry, ensure_ascii=False) + + +def setup_logging(level: int = logging.INFO) -> None: + """Configure the root logger with JSON output to stdout.""" + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(JSONFormatter()) + + root = logging.getLogger() + root.setLevel(level) + # Avoid duplicate handlers on repeated calls + root.handlers.clear() + root.addHandler(handler) diff --git a/src/main.py b/src/main.py new file mode 100644 index 0000000..cacb329 --- /dev/null +++ b/src/main.py @@ -0,0 +1,171 @@ +"""The Ouroboros — main trading loop. + +Orchestrates the broker, brain, and risk manager into a continuous +trading cycle with configurable intervals. +""" + +from __future__ import annotations + +import argparse +import asyncio +import logging +import signal +import sys +from typing import Any + +from src.brain.gemini_client import GeminiClient +from src.broker.kis_api import KISBroker +from src.config import Settings +from src.core.risk_manager import CircuitBreakerTripped, RiskManager +from src.db import init_db, log_trade +from src.logging_config import setup_logging + +logger = logging.getLogger(__name__) + +# Target stock codes to monitor +WATCHLIST = ["005930", "000660", "035420"] # Samsung, SK Hynix, NAVER + +TRADE_INTERVAL_SECONDS = 60 + + +async def trading_cycle( + broker: KISBroker, + brain: GeminiClient, + risk: RiskManager, + db_conn: Any, + stock_code: str, +) -> None: + """Execute one trading cycle for a single stock.""" + # 1. Fetch market data + orderbook = await broker.get_orderbook(stock_code) + balance_data = await broker.get_balance() + + output2 = balance_data.get("output2", [{}]) + total_eval = float(output2[0].get("tot_evlu_amt", "0")) if output2 else 0 + total_cash = float( + balance_data.get("output2", [{}])[0].get("dnca_tot_amt", "0") + if output2 + else "0" + ) + purchase_total = float(output2[0].get("pchs_amt_smtl_amt", "0")) if output2 else 0 + + # Calculate daily P&L % + pnl_pct = ((total_eval - purchase_total) / purchase_total * 100) if purchase_total > 0 else 0.0 + + current_price = float( + orderbook.get("output1", {}).get("stck_prpr", "0") + ) + + market_data = { + "stock_code": stock_code, + "current_price": current_price, + "orderbook": orderbook.get("output1", {}), + "foreigner_net": float( + orderbook.get("output1", {}).get("frgn_ntby_qty", "0") + ), + } + + # 2. Ask the brain for a decision + decision = await brain.decide(market_data) + logger.info( + "Decision for %s: %s (confidence=%d)", + stock_code, + decision.action, + decision.confidence, + ) + + # 3. Execute if actionable + if decision.action in ("BUY", "SELL"): + # Determine order size (simplified: 1 lot) + quantity = 1 + order_amount = current_price * quantity + + # 4. Risk check BEFORE order + risk.validate_order( + current_pnl_pct=pnl_pct, + order_amount=order_amount, + total_cash=total_cash, + ) + + # 5. Send order + result = await broker.send_order( + stock_code=stock_code, + order_type=decision.action, + quantity=quantity, + price=0, # market order + ) + logger.info("Order result: %s", result.get("msg1", "OK")) + + # 6. Log trade + log_trade( + conn=db_conn, + stock_code=stock_code, + action=decision.action, + confidence=decision.confidence, + rationale=decision.rationale, + ) + + +async def run(settings: Settings) -> None: + """Main async loop — iterate over watchlist on a timer.""" + broker = KISBroker(settings) + brain = GeminiClient(settings) + risk = RiskManager(settings) + db_conn = init_db(settings.DB_PATH) + + shutdown = asyncio.Event() + + def _signal_handler() -> None: + logger.info("Shutdown signal received") + shutdown.set() + + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler(sig, _signal_handler) + + logger.info("The Ouroboros is alive. Mode: %s", settings.MODE) + logger.info("Watchlist: %s", WATCHLIST) + + try: + while not shutdown.is_set(): + for code in WATCHLIST: + if shutdown.is_set(): + break + try: + await trading_cycle(broker, brain, risk, db_conn, code) + except CircuitBreakerTripped: + logger.critical("Circuit breaker tripped — shutting down") + raise + except ConnectionError as exc: + logger.error("Connection error for %s: %s", code, exc) + except Exception as exc: + logger.exception("Unexpected error for %s: %s", code, exc) + + # Wait for next cycle or shutdown + try: + await asyncio.wait_for(shutdown.wait(), timeout=TRADE_INTERVAL_SECONDS) + except asyncio.TimeoutError: + pass # Normal — timeout means it's time for next cycle + finally: + await broker.close() + db_conn.close() + logger.info("The Ouroboros rests.") + + +def main() -> None: + parser = argparse.ArgumentParser(description="The Ouroboros Trading Agent") + parser.add_argument( + "--mode", + choices=["paper", "live"], + default="paper", + help="Trading mode (default: paper)", + ) + args = parser.parse_args() + + setup_logging() + settings = Settings(MODE=args.mode) # type: ignore[call-arg] + asyncio.run(run(settings)) + + +if __name__ == "__main__": + main() diff --git a/src/strategies/__init__.py b/src/strategies/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/strategies/base.py b/src/strategies/base.py new file mode 100644 index 0000000..d9878e5 --- /dev/null +++ b/src/strategies/base.py @@ -0,0 +1,19 @@ +"""Base class for all trading strategies.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + + +class BaseStrategy(ABC): + """All strategies must inherit from this class.""" + + @abstractmethod + def evaluate(self, market_data: dict[str, Any]) -> dict[str, Any]: + """Evaluate market data and return a trade decision. + + Returns: + dict with keys: action ("BUY"|"SELL"|"HOLD"), confidence (int), rationale (str) + """ + ... diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..fdb0b08 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,23 @@ +"""Shared test fixtures for The Ouroboros test suite.""" + +from __future__ import annotations + +import pytest + +from src.config import Settings + + +@pytest.fixture +def settings() -> Settings: + """Return a Settings instance with safe test defaults.""" + return Settings( + KIS_APP_KEY="test_app_key", + KIS_APP_SECRET="test_app_secret", + KIS_ACCOUNT_NO="12345678-01", + KIS_BASE_URL="https://openapivts.koreainvestment.com:9443", + GEMINI_API_KEY="test_gemini_key", + CIRCUIT_BREAKER_PCT=-3.0, + FAT_FINGER_PCT=30.0, + CONFIDENCE_THRESHOLD=80, + DB_PATH=":memory:", + ) diff --git a/tests/test_brain.py b/tests/test_brain.py new file mode 100644 index 0000000..204fcd1 --- /dev/null +++ b/tests/test_brain.py @@ -0,0 +1,159 @@ +"""TDD tests for brain/gemini_client.py — written BEFORE implementation.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from src.brain.gemini_client import GeminiClient, TradeDecision + + +# --------------------------------------------------------------------------- +# Response Parsing +# --------------------------------------------------------------------------- + + +class TestResponseParsing: + """Gemini responses must be parsed into validated TradeDecision objects.""" + + def test_valid_buy_response(self, settings): + client = GeminiClient(settings) + raw = '{"action": "BUY", "confidence": 90, "rationale": "Strong momentum"}' + decision = client.parse_response(raw) + assert decision.action == "BUY" + assert decision.confidence == 90 + assert decision.rationale == "Strong momentum" + + def test_valid_sell_response(self, settings): + client = GeminiClient(settings) + raw = '{"action": "SELL", "confidence": 85, "rationale": "Overbought RSI"}' + decision = client.parse_response(raw) + assert decision.action == "SELL" + + def test_valid_hold_response(self, settings): + client = GeminiClient(settings) + raw = '{"action": "HOLD", "confidence": 95, "rationale": "Sideways market"}' + decision = client.parse_response(raw) + assert decision.action == "HOLD" + + +# --------------------------------------------------------------------------- +# Confidence Threshold Enforcement +# --------------------------------------------------------------------------- + + +class TestConfidenceThreshold: + """If confidence < 80, the action MUST be forced to HOLD.""" + + def test_low_confidence_buy_becomes_hold(self, settings): + client = GeminiClient(settings) + raw = '{"action": "BUY", "confidence": 65, "rationale": "Weak signal"}' + decision = client.parse_response(raw) + assert decision.action == "HOLD" + assert decision.confidence == 65 + + def test_low_confidence_sell_becomes_hold(self, settings): + client = GeminiClient(settings) + raw = '{"action": "SELL", "confidence": 79, "rationale": "Uncertain"}' + decision = client.parse_response(raw) + assert decision.action == "HOLD" + + def test_exactly_threshold_is_allowed(self, settings): + client = GeminiClient(settings) + raw = '{"action": "BUY", "confidence": 80, "rationale": "Just enough"}' + decision = client.parse_response(raw) + assert decision.action == "BUY" + + +# --------------------------------------------------------------------------- +# Malformed JSON Handling +# --------------------------------------------------------------------------- + + +class TestMalformedJsonHandling: + """Gemini may return garbage — the parser must not crash.""" + + def test_empty_string_returns_hold(self, settings): + client = GeminiClient(settings) + decision = client.parse_response("") + assert decision.action == "HOLD" + assert decision.confidence == 0 + + def test_plain_text_returns_hold(self, settings): + client = GeminiClient(settings) + decision = client.parse_response("I think you should buy Samsung stock") + assert decision.action == "HOLD" + assert decision.confidence == 0 + + def test_partial_json_returns_hold(self, settings): + client = GeminiClient(settings) + decision = client.parse_response('{"action": "BUY", "confidence":') + assert decision.action == "HOLD" + assert decision.confidence == 0 + + def test_json_with_missing_fields_returns_hold(self, settings): + client = GeminiClient(settings) + decision = client.parse_response('{"action": "BUY"}') + assert decision.action == "HOLD" + assert decision.confidence == 0 + + def test_json_with_invalid_action_returns_hold(self, settings): + client = GeminiClient(settings) + decision = client.parse_response( + '{"action": "YOLO", "confidence": 99, "rationale": "moon"}' + ) + assert decision.action == "HOLD" + assert decision.confidence == 0 + + def test_json_wrapped_in_markdown_code_block(self, settings): + """Gemini often wraps JSON in ```json ... ``` blocks.""" + client = GeminiClient(settings) + raw = '```json\n{"action": "BUY", "confidence": 92, "rationale": "Good"}\n```' + decision = client.parse_response(raw) + assert decision.action == "BUY" + assert decision.confidence == 92 + + +# --------------------------------------------------------------------------- +# Prompt Construction +# --------------------------------------------------------------------------- + + +class TestPromptConstruction: + """The prompt sent to Gemini must include all required market data.""" + + def test_prompt_contains_stock_code(self, settings): + client = GeminiClient(settings) + market_data = { + "stock_code": "005930", + "current_price": 72000, + "orderbook": {"asks": [], "bids": []}, + "foreigner_net": -50000, + } + prompt = client.build_prompt(market_data) + assert "005930" in prompt + + def test_prompt_contains_price(self, settings): + client = GeminiClient(settings) + market_data = { + "stock_code": "005930", + "current_price": 72000, + "orderbook": {"asks": [], "bids": []}, + "foreigner_net": -50000, + } + prompt = client.build_prompt(market_data) + assert "72000" in prompt + + def test_prompt_enforces_json_output_format(self, settings): + client = GeminiClient(settings) + market_data = { + "stock_code": "005930", + "current_price": 72000, + "orderbook": {"asks": [], "bids": []}, + "foreigner_net": 0, + } + prompt = client.build_prompt(market_data) + assert "JSON" in prompt + assert "action" in prompt + assert "confidence" in prompt diff --git a/tests/test_broker.py b/tests/test_broker.py new file mode 100644 index 0000000..5d8d2ac --- /dev/null +++ b/tests/test_broker.py @@ -0,0 +1,140 @@ +"""TDD tests for broker/kis_api.py — written BEFORE implementation.""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import aiohttp +import pytest + +from src.broker.kis_api import KISBroker + + +# --------------------------------------------------------------------------- +# Token Management +# --------------------------------------------------------------------------- + + +class TestTokenManagement: + """Access token must be auto-refreshed and cached.""" + + @pytest.mark.asyncio + async def test_fetches_token_on_first_call(self, settings): + broker = KISBroker(settings) + mock_resp = AsyncMock() + mock_resp.status = 200 + mock_resp.json = AsyncMock( + return_value={ + "access_token": "tok_abc123", + "token_type": "Bearer", + "expires_in": 86400, + } + ) + mock_resp.__aenter__ = AsyncMock(return_value=mock_resp) + mock_resp.__aexit__ = AsyncMock(return_value=False) + + with patch("aiohttp.ClientSession.post", return_value=mock_resp): + token = await broker._ensure_token() + assert token == "tok_abc123" + + await broker.close() + + @pytest.mark.asyncio + async def test_reuses_cached_token(self, settings): + broker = KISBroker(settings) + broker._access_token = "cached_token" + broker._token_expires_at = asyncio.get_event_loop().time() + 3600 + + token = await broker._ensure_token() + assert token == "cached_token" + + await broker.close() + + +# --------------------------------------------------------------------------- +# Network Error Handling +# --------------------------------------------------------------------------- + + +class TestNetworkErrorHandling: + """Broker must handle network timeouts and HTTP errors gracefully.""" + + @pytest.mark.asyncio + async def test_timeout_raises_connection_error(self, settings): + broker = KISBroker(settings) + broker._access_token = "tok" + broker._token_expires_at = asyncio.get_event_loop().time() + 3600 + + with patch( + "aiohttp.ClientSession.get", + side_effect=asyncio.TimeoutError(), + ): + with pytest.raises(ConnectionError): + await broker.get_orderbook("005930") + + await broker.close() + + @pytest.mark.asyncio + async def test_http_500_raises_connection_error(self, settings): + broker = KISBroker(settings) + broker._access_token = "tok" + broker._token_expires_at = asyncio.get_event_loop().time() + 3600 + + mock_resp = AsyncMock() + mock_resp.status = 500 + mock_resp.text = AsyncMock(return_value="Internal Server Error") + mock_resp.__aenter__ = AsyncMock(return_value=mock_resp) + mock_resp.__aexit__ = AsyncMock(return_value=False) + + with patch("aiohttp.ClientSession.get", return_value=mock_resp): + with pytest.raises(ConnectionError): + await broker.get_orderbook("005930") + + await broker.close() + + +# --------------------------------------------------------------------------- +# Rate Limiter +# --------------------------------------------------------------------------- + + +class TestRateLimiter: + """The leaky bucket rate limiter must throttle requests.""" + + @pytest.mark.asyncio + async def test_rate_limiter_does_not_block_under_limit(self, settings): + broker = KISBroker(settings) + # Should complete without blocking when under limit + await broker._rate_limiter.acquire() + await broker.close() + + +# --------------------------------------------------------------------------- +# Hash Key Generation +# --------------------------------------------------------------------------- + + +class TestHashKey: + """POST requests to KIS require a hash key.""" + + @pytest.mark.asyncio + async def test_generates_hash_key_for_post_body(self, settings): + broker = KISBroker(settings) + broker._access_token = "tok" + broker._token_expires_at = asyncio.get_event_loop().time() + 3600 + + body = {"CANO": "12345678", "ACNT_PRDT_CD": "01"} + + mock_resp = AsyncMock() + mock_resp.status = 200 + mock_resp.json = AsyncMock(return_value={"HASH": "abc123hash"}) + mock_resp.__aenter__ = AsyncMock(return_value=mock_resp) + mock_resp.__aexit__ = AsyncMock(return_value=False) + + with patch("aiohttp.ClientSession.post", return_value=mock_resp): + hash_key = await broker._get_hash_key(body) + assert isinstance(hash_key, str) + assert len(hash_key) > 0 + + await broker.close() diff --git a/tests/test_risk.py b/tests/test_risk.py new file mode 100644 index 0000000..bebdb5a --- /dev/null +++ b/tests/test_risk.py @@ -0,0 +1,132 @@ +"""TDD tests for core/risk_manager.py — written BEFORE implementation.""" + +from __future__ import annotations + +import pytest + +from src.core.risk_manager import ( + CircuitBreakerTripped, + FatFingerRejected, + RiskManager, +) + + +# --------------------------------------------------------------------------- +# Circuit Breaker Tests +# --------------------------------------------------------------------------- + + +class TestCircuitBreaker: + """The circuit breaker must halt all trading when daily loss exceeds the threshold.""" + + def test_allows_trading_when_pnl_is_positive(self, settings): + rm = RiskManager(settings) + # 2% gain — should be fine + rm.check_circuit_breaker(current_pnl_pct=2.0) + + def test_allows_trading_at_zero_pnl(self, settings): + rm = RiskManager(settings) + rm.check_circuit_breaker(current_pnl_pct=0.0) + + def test_allows_trading_at_exactly_threshold(self, settings): + rm = RiskManager(settings) + # Exactly -3.0% is ON the boundary — still allowed + rm.check_circuit_breaker(current_pnl_pct=-3.0) + + def test_trips_when_loss_exceeds_threshold(self, settings): + rm = RiskManager(settings) + with pytest.raises(CircuitBreakerTripped): + rm.check_circuit_breaker(current_pnl_pct=-3.01) + + def test_trips_at_large_loss(self, settings): + rm = RiskManager(settings) + with pytest.raises(CircuitBreakerTripped): + rm.check_circuit_breaker(current_pnl_pct=-10.0) + + def test_custom_threshold(self): + """A stricter threshold (-1.5%) should trip earlier.""" + from src.config import Settings + + strict = Settings( + KIS_APP_KEY="k", + KIS_APP_SECRET="s", + KIS_ACCOUNT_NO="00000000-00", + KIS_BASE_URL="https://example.com", + GEMINI_API_KEY="g", + CIRCUIT_BREAKER_PCT=-1.5, + FAT_FINGER_PCT=30.0, + CONFIDENCE_THRESHOLD=80, + DB_PATH=":memory:", + ) + rm = RiskManager(strict) + with pytest.raises(CircuitBreakerTripped): + rm.check_circuit_breaker(current_pnl_pct=-1.51) + + +# --------------------------------------------------------------------------- +# Fat Finger Tests +# --------------------------------------------------------------------------- + + +class TestFatFingerCheck: + """Orders exceeding 30% of total cash must be rejected.""" + + def test_allows_small_order(self, settings): + rm = RiskManager(settings) + # 10% of 10_000_000 = 1_000_000 + rm.check_fat_finger(order_amount=1_000_000, total_cash=10_000_000) + + def test_allows_order_at_exactly_threshold(self, settings): + rm = RiskManager(settings) + # Exactly 30% — allowed + rm.check_fat_finger(order_amount=3_000_000, total_cash=10_000_000) + + def test_rejects_order_exceeding_threshold(self, settings): + rm = RiskManager(settings) + with pytest.raises(FatFingerRejected): + rm.check_fat_finger(order_amount=3_000_001, total_cash=10_000_000) + + def test_rejects_massive_order(self, settings): + rm = RiskManager(settings) + with pytest.raises(FatFingerRejected): + rm.check_fat_finger(order_amount=9_000_000, total_cash=10_000_000) + + def test_zero_cash_rejects_any_order(self, settings): + rm = RiskManager(settings) + with pytest.raises(FatFingerRejected): + rm.check_fat_finger(order_amount=1, total_cash=0) + + +# --------------------------------------------------------------------------- +# Pre-Order Validation (Integration of both checks) +# --------------------------------------------------------------------------- + + +class TestPreOrderValidation: + """validate_order must run BOTH checks before approving.""" + + def test_passes_when_both_checks_ok(self, settings): + rm = RiskManager(settings) + rm.validate_order( + current_pnl_pct=0.5, + order_amount=1_000_000, + total_cash=10_000_000, + ) + + def test_fails_on_circuit_breaker(self, settings): + rm = RiskManager(settings) + with pytest.raises(CircuitBreakerTripped): + rm.validate_order( + current_pnl_pct=-5.0, + order_amount=100, + total_cash=10_000_000, + ) + + def test_fails_on_fat_finger(self, settings): + rm = RiskManager(settings) + with pytest.raises(FatFingerRejected): + rm.validate_order( + current_pnl_pct=1.0, + order_amount=5_000_000, + total_cash=10_000_000, + )