Add complete Ouroboros trading system with TDD test suite
Some checks failed
CI / test (push) Has been cancelled
Some checks failed
CI / test (push) Has been cancelled
Implement the full autonomous trading agent architecture: - KIS broker with async API, token refresh, leaky bucket rate limiter, and hash key signing - Gemini-powered decision engine with JSON parsing and confidence threshold enforcement - Risk manager with circuit breaker (-3% P&L) and fat finger protection (30% cap) - Evolution engine for self-improving strategy generation via failure analysis - 35 passing tests written TDD-first covering risk, broker, and brain modules - CI/CD pipeline, Docker multi-stage build, and AI agent context docs Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
23
tests/conftest.py
Normal file
23
tests/conftest.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Shared test fixtures for The Ouroboros test suite."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from src.config import Settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def settings() -> Settings:
|
||||
"""Return a Settings instance with safe test defaults."""
|
||||
return Settings(
|
||||
KIS_APP_KEY="test_app_key",
|
||||
KIS_APP_SECRET="test_app_secret",
|
||||
KIS_ACCOUNT_NO="12345678-01",
|
||||
KIS_BASE_URL="https://openapivts.koreainvestment.com:9443",
|
||||
GEMINI_API_KEY="test_gemini_key",
|
||||
CIRCUIT_BREAKER_PCT=-3.0,
|
||||
FAT_FINGER_PCT=30.0,
|
||||
CONFIDENCE_THRESHOLD=80,
|
||||
DB_PATH=":memory:",
|
||||
)
|
||||
159
tests/test_brain.py
Normal file
159
tests/test_brain.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""TDD tests for brain/gemini_client.py — written BEFORE implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.brain.gemini_client import GeminiClient, TradeDecision
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response Parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResponseParsing:
|
||||
"""Gemini responses must be parsed into validated TradeDecision objects."""
|
||||
|
||||
def test_valid_buy_response(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
raw = '{"action": "BUY", "confidence": 90, "rationale": "Strong momentum"}'
|
||||
decision = client.parse_response(raw)
|
||||
assert decision.action == "BUY"
|
||||
assert decision.confidence == 90
|
||||
assert decision.rationale == "Strong momentum"
|
||||
|
||||
def test_valid_sell_response(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
raw = '{"action": "SELL", "confidence": 85, "rationale": "Overbought RSI"}'
|
||||
decision = client.parse_response(raw)
|
||||
assert decision.action == "SELL"
|
||||
|
||||
def test_valid_hold_response(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
raw = '{"action": "HOLD", "confidence": 95, "rationale": "Sideways market"}'
|
||||
decision = client.parse_response(raw)
|
||||
assert decision.action == "HOLD"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Confidence Threshold Enforcement
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConfidenceThreshold:
|
||||
"""If confidence < 80, the action MUST be forced to HOLD."""
|
||||
|
||||
def test_low_confidence_buy_becomes_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
raw = '{"action": "BUY", "confidence": 65, "rationale": "Weak signal"}'
|
||||
decision = client.parse_response(raw)
|
||||
assert decision.action == "HOLD"
|
||||
assert decision.confidence == 65
|
||||
|
||||
def test_low_confidence_sell_becomes_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
raw = '{"action": "SELL", "confidence": 79, "rationale": "Uncertain"}'
|
||||
decision = client.parse_response(raw)
|
||||
assert decision.action == "HOLD"
|
||||
|
||||
def test_exactly_threshold_is_allowed(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
raw = '{"action": "BUY", "confidence": 80, "rationale": "Just enough"}'
|
||||
decision = client.parse_response(raw)
|
||||
assert decision.action == "BUY"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Malformed JSON Handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMalformedJsonHandling:
|
||||
"""Gemini may return garbage — the parser must not crash."""
|
||||
|
||||
def test_empty_string_returns_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
decision = client.parse_response("")
|
||||
assert decision.action == "HOLD"
|
||||
assert decision.confidence == 0
|
||||
|
||||
def test_plain_text_returns_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
decision = client.parse_response("I think you should buy Samsung stock")
|
||||
assert decision.action == "HOLD"
|
||||
assert decision.confidence == 0
|
||||
|
||||
def test_partial_json_returns_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
decision = client.parse_response('{"action": "BUY", "confidence":')
|
||||
assert decision.action == "HOLD"
|
||||
assert decision.confidence == 0
|
||||
|
||||
def test_json_with_missing_fields_returns_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
decision = client.parse_response('{"action": "BUY"}')
|
||||
assert decision.action == "HOLD"
|
||||
assert decision.confidence == 0
|
||||
|
||||
def test_json_with_invalid_action_returns_hold(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
decision = client.parse_response(
|
||||
'{"action": "YOLO", "confidence": 99, "rationale": "moon"}'
|
||||
)
|
||||
assert decision.action == "HOLD"
|
||||
assert decision.confidence == 0
|
||||
|
||||
def test_json_wrapped_in_markdown_code_block(self, settings):
|
||||
"""Gemini often wraps JSON in ```json ... ``` blocks."""
|
||||
client = GeminiClient(settings)
|
||||
raw = '```json\n{"action": "BUY", "confidence": 92, "rationale": "Good"}\n```'
|
||||
decision = client.parse_response(raw)
|
||||
assert decision.action == "BUY"
|
||||
assert decision.confidence == 92
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt Construction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPromptConstruction:
|
||||
"""The prompt sent to Gemini must include all required market data."""
|
||||
|
||||
def test_prompt_contains_stock_code(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
market_data = {
|
||||
"stock_code": "005930",
|
||||
"current_price": 72000,
|
||||
"orderbook": {"asks": [], "bids": []},
|
||||
"foreigner_net": -50000,
|
||||
}
|
||||
prompt = client.build_prompt(market_data)
|
||||
assert "005930" in prompt
|
||||
|
||||
def test_prompt_contains_price(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
market_data = {
|
||||
"stock_code": "005930",
|
||||
"current_price": 72000,
|
||||
"orderbook": {"asks": [], "bids": []},
|
||||
"foreigner_net": -50000,
|
||||
}
|
||||
prompt = client.build_prompt(market_data)
|
||||
assert "72000" in prompt
|
||||
|
||||
def test_prompt_enforces_json_output_format(self, settings):
|
||||
client = GeminiClient(settings)
|
||||
market_data = {
|
||||
"stock_code": "005930",
|
||||
"current_price": 72000,
|
||||
"orderbook": {"asks": [], "bids": []},
|
||||
"foreigner_net": 0,
|
||||
}
|
||||
prompt = client.build_prompt(market_data)
|
||||
assert "JSON" in prompt
|
||||
assert "action" in prompt
|
||||
assert "confidence" in prompt
|
||||
140
tests/test_broker.py
Normal file
140
tests/test_broker.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""TDD tests for broker/kis_api.py — written BEFORE implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
from src.broker.kis_api import KISBroker
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Token Management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTokenManagement:
|
||||
"""Access token must be auto-refreshed and cached."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetches_token_on_first_call(self, settings):
|
||||
broker = KISBroker(settings)
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(
|
||||
return_value={
|
||||
"access_token": "tok_abc123",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 86400,
|
||||
}
|
||||
)
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp):
|
||||
token = await broker._ensure_token()
|
||||
assert token == "tok_abc123"
|
||||
|
||||
await broker.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reuses_cached_token(self, settings):
|
||||
broker = KISBroker(settings)
|
||||
broker._access_token = "cached_token"
|
||||
broker._token_expires_at = asyncio.get_event_loop().time() + 3600
|
||||
|
||||
token = await broker._ensure_token()
|
||||
assert token == "cached_token"
|
||||
|
||||
await broker.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Network Error Handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNetworkErrorHandling:
|
||||
"""Broker must handle network timeouts and HTTP errors gracefully."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_raises_connection_error(self, settings):
|
||||
broker = KISBroker(settings)
|
||||
broker._access_token = "tok"
|
||||
broker._token_expires_at = asyncio.get_event_loop().time() + 3600
|
||||
|
||||
with patch(
|
||||
"aiohttp.ClientSession.get",
|
||||
side_effect=asyncio.TimeoutError(),
|
||||
):
|
||||
with pytest.raises(ConnectionError):
|
||||
await broker.get_orderbook("005930")
|
||||
|
||||
await broker.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_500_raises_connection_error(self, settings):
|
||||
broker = KISBroker(settings)
|
||||
broker._access_token = "tok"
|
||||
broker._token_expires_at = asyncio.get_event_loop().time() + 3600
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 500
|
||||
mock_resp.text = AsyncMock(return_value="Internal Server Error")
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.get", return_value=mock_resp):
|
||||
with pytest.raises(ConnectionError):
|
||||
await broker.get_orderbook("005930")
|
||||
|
||||
await broker.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rate Limiter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRateLimiter:
|
||||
"""The leaky bucket rate limiter must throttle requests."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limiter_does_not_block_under_limit(self, settings):
|
||||
broker = KISBroker(settings)
|
||||
# Should complete without blocking when under limit
|
||||
await broker._rate_limiter.acquire()
|
||||
await broker.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hash Key Generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHashKey:
|
||||
"""POST requests to KIS require a hash key."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generates_hash_key_for_post_body(self, settings):
|
||||
broker = KISBroker(settings)
|
||||
broker._access_token = "tok"
|
||||
broker._token_expires_at = asyncio.get_event_loop().time() + 3600
|
||||
|
||||
body = {"CANO": "12345678", "ACNT_PRDT_CD": "01"}
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(return_value={"HASH": "abc123hash"})
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp):
|
||||
hash_key = await broker._get_hash_key(body)
|
||||
assert isinstance(hash_key, str)
|
||||
assert len(hash_key) > 0
|
||||
|
||||
await broker.close()
|
||||
132
tests/test_risk.py
Normal file
132
tests/test_risk.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""TDD tests for core/risk_manager.py — written BEFORE implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from src.core.risk_manager import (
|
||||
CircuitBreakerTripped,
|
||||
FatFingerRejected,
|
||||
RiskManager,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Circuit Breaker Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCircuitBreaker:
|
||||
"""The circuit breaker must halt all trading when daily loss exceeds the threshold."""
|
||||
|
||||
def test_allows_trading_when_pnl_is_positive(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
# 2% gain — should be fine
|
||||
rm.check_circuit_breaker(current_pnl_pct=2.0)
|
||||
|
||||
def test_allows_trading_at_zero_pnl(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
rm.check_circuit_breaker(current_pnl_pct=0.0)
|
||||
|
||||
def test_allows_trading_at_exactly_threshold(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
# Exactly -3.0% is ON the boundary — still allowed
|
||||
rm.check_circuit_breaker(current_pnl_pct=-3.0)
|
||||
|
||||
def test_trips_when_loss_exceeds_threshold(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
with pytest.raises(CircuitBreakerTripped):
|
||||
rm.check_circuit_breaker(current_pnl_pct=-3.01)
|
||||
|
||||
def test_trips_at_large_loss(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
with pytest.raises(CircuitBreakerTripped):
|
||||
rm.check_circuit_breaker(current_pnl_pct=-10.0)
|
||||
|
||||
def test_custom_threshold(self):
|
||||
"""A stricter threshold (-1.5%) should trip earlier."""
|
||||
from src.config import Settings
|
||||
|
||||
strict = Settings(
|
||||
KIS_APP_KEY="k",
|
||||
KIS_APP_SECRET="s",
|
||||
KIS_ACCOUNT_NO="00000000-00",
|
||||
KIS_BASE_URL="https://example.com",
|
||||
GEMINI_API_KEY="g",
|
||||
CIRCUIT_BREAKER_PCT=-1.5,
|
||||
FAT_FINGER_PCT=30.0,
|
||||
CONFIDENCE_THRESHOLD=80,
|
||||
DB_PATH=":memory:",
|
||||
)
|
||||
rm = RiskManager(strict)
|
||||
with pytest.raises(CircuitBreakerTripped):
|
||||
rm.check_circuit_breaker(current_pnl_pct=-1.51)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fat Finger Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFatFingerCheck:
|
||||
"""Orders exceeding 30% of total cash must be rejected."""
|
||||
|
||||
def test_allows_small_order(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
# 10% of 10_000_000 = 1_000_000
|
||||
rm.check_fat_finger(order_amount=1_000_000, total_cash=10_000_000)
|
||||
|
||||
def test_allows_order_at_exactly_threshold(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
# Exactly 30% — allowed
|
||||
rm.check_fat_finger(order_amount=3_000_000, total_cash=10_000_000)
|
||||
|
||||
def test_rejects_order_exceeding_threshold(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
with pytest.raises(FatFingerRejected):
|
||||
rm.check_fat_finger(order_amount=3_000_001, total_cash=10_000_000)
|
||||
|
||||
def test_rejects_massive_order(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
with pytest.raises(FatFingerRejected):
|
||||
rm.check_fat_finger(order_amount=9_000_000, total_cash=10_000_000)
|
||||
|
||||
def test_zero_cash_rejects_any_order(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
with pytest.raises(FatFingerRejected):
|
||||
rm.check_fat_finger(order_amount=1, total_cash=0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pre-Order Validation (Integration of both checks)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPreOrderValidation:
|
||||
"""validate_order must run BOTH checks before approving."""
|
||||
|
||||
def test_passes_when_both_checks_ok(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
rm.validate_order(
|
||||
current_pnl_pct=0.5,
|
||||
order_amount=1_000_000,
|
||||
total_cash=10_000_000,
|
||||
)
|
||||
|
||||
def test_fails_on_circuit_breaker(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
with pytest.raises(CircuitBreakerTripped):
|
||||
rm.validate_order(
|
||||
current_pnl_pct=-5.0,
|
||||
order_amount=100,
|
||||
total_cash=10_000_000,
|
||||
)
|
||||
|
||||
def test_fails_on_fat_finger(self, settings):
|
||||
rm = RiskManager(settings)
|
||||
with pytest.raises(FatFingerRejected):
|
||||
rm.validate_order(
|
||||
current_pnl_pct=1.0,
|
||||
order_amount=5_000_000,
|
||||
total_cash=10_000_000,
|
||||
)
|
||||
Reference in New Issue
Block a user