Some checks failed
CI / test (pull_request) Has been cancelled
pre_market_planner처럼 prompt_override를 사용하는 호출자는 플레이북 JSON 등 TradeDecision이 아닌 raw 텍스트를 기대한다. 기존에는 parse_response를 통과시켜 항상 "Missing fields" 경고가 발생했다. decide()에서 prompt_override 감지 시 parse_response를 건너뛰고 raw 응답을 rationale에 담아 직접 반환하도록 수정한다. 정상 응답인데 경고가 뜨는 문제가 해결된다. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
388 lines
15 KiB
Python
388 lines
15 KiB
Python
"""TDD tests for brain/gemini_client.py — written BEFORE implementation."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from src.brain.gemini_client import GeminiClient
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Response Parsing
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestResponseParsing:
    """Well-formed Gemini JSON must be parsed into a validated TradeDecision."""

    def test_valid_buy_response(self, settings):
        """A BUY payload keeps its action, confidence and rationale."""
        payload = '{"action": "BUY", "confidence": 90, "rationale": "Strong momentum"}'
        parsed = GeminiClient(settings).parse_response(payload)
        assert parsed.action == "BUY"
        assert parsed.confidence == 90
        assert parsed.rationale == "Strong momentum"

    def test_valid_sell_response(self, settings):
        """A SELL payload with confidence 85 is reported as SELL."""
        payload = '{"action": "SELL", "confidence": 85, "rationale": "Overbought RSI"}'
        parsed = GeminiClient(settings).parse_response(payload)
        assert parsed.action == "SELL"

    def test_valid_hold_response(self, settings):
        """A HOLD payload is reported as HOLD."""
        payload = '{"action": "HOLD", "confidence": 95, "rationale": "Sideways market"}'
        parsed = GeminiClient(settings).parse_response(payload)
        assert parsed.action == "HOLD"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Confidence Threshold Enforcement
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestConfidenceThreshold:
    """A decision with confidence below 80 MUST be downgraded to HOLD."""

    def test_low_confidence_buy_becomes_hold(self, settings):
        """BUY at confidence 65 turns into HOLD; the confidence value is kept."""
        payload = '{"action": "BUY", "confidence": 65, "rationale": "Weak signal"}'
        parsed = GeminiClient(settings).parse_response(payload)
        assert parsed.action == "HOLD"
        assert parsed.confidence == 65

    def test_low_confidence_sell_becomes_hold(self, settings):
        """SELL at confidence 79 (one below the threshold) turns into HOLD."""
        payload = '{"action": "SELL", "confidence": 79, "rationale": "Uncertain"}'
        parsed = GeminiClient(settings).parse_response(payload)
        assert parsed.action == "HOLD"

    def test_exactly_threshold_is_allowed(self, settings):
        """BUY at confidence 80 is left as BUY."""
        payload = '{"action": "BUY", "confidence": 80, "rationale": "Just enough"}'
        parsed = GeminiClient(settings).parse_response(payload)
        assert parsed.action == "BUY"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Malformed JSON Handling
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestMalformedJsonHandling:
    """Garbage from Gemini must fall back to HOLD instead of crashing the parser."""

    def test_empty_string_returns_hold(self, settings):
        """An empty response yields HOLD with zero confidence."""
        fallback = GeminiClient(settings).parse_response("")
        assert fallback.action == "HOLD"
        assert fallback.confidence == 0

    def test_plain_text_returns_hold(self, settings):
        """Free-form text (no JSON at all) yields HOLD with zero confidence."""
        parser = GeminiClient(settings)
        fallback = parser.parse_response("I think you should buy Samsung stock")
        assert fallback.action == "HOLD"
        assert fallback.confidence == 0

    def test_partial_json_returns_hold(self, settings):
        """Truncated JSON yields HOLD with zero confidence."""
        parser = GeminiClient(settings)
        fallback = parser.parse_response('{"action": "BUY", "confidence":')
        assert fallback.action == "HOLD"
        assert fallback.confidence == 0

    def test_json_with_missing_fields_returns_hold(self, settings):
        """Valid JSON lacking required fields yields HOLD and keeps the raw text."""
        incomplete = '{"action": "BUY"}'
        fallback = GeminiClient(settings).parse_response(incomplete)
        assert fallback.action == "HOLD"
        assert fallback.confidence == 0
        # The raw text is kept in rationale so that prompt_override callers
        # (e.g. pre_market_planner) can still extract non-TradeDecision JSON
        # from decision.rationale (#245).
        assert fallback.rationale == incomplete

    def test_non_trade_decision_json_preserves_raw_in_rationale(self, settings):
        """Playbook JSON (no action/confidence/rationale) must be preserved for planner."""
        playbook_json = '{"market_outlook": "neutral", "stocks": []}'
        fallback = GeminiClient(settings).parse_response(playbook_json)
        assert fallback.action == "HOLD"
        assert fallback.rationale == playbook_json

    def test_json_with_invalid_action_returns_hold(self, settings):
        """An unknown action value yields HOLD with zero confidence."""
        bogus = '{"action": "YOLO", "confidence": 99, "rationale": "moon"}'
        fallback = GeminiClient(settings).parse_response(bogus)
        assert fallback.action == "HOLD"
        assert fallback.confidence == 0

    def test_json_wrapped_in_markdown_code_block(self, settings):
        """Gemini often wraps JSON in ```json ... ``` blocks."""
        fenced = '```json\n{"action": "BUY", "confidence": 92, "rationale": "Good"}\n```'
        parsed = GeminiClient(settings).parse_response(fenced)
        assert parsed.action == "BUY"
        assert parsed.confidence == 92
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Prompt Construction
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPromptConstruction:
    """Prompts built for Gemini must include all required market data."""

    def test_prompt_contains_stock_code(self, settings):
        """The stock code from the market data appears in the built prompt."""
        snapshot = {
            "stock_code": "005930",
            "current_price": 72000,
            "orderbook": {"asks": [], "bids": []},
            "foreigner_net": -50000,
        }
        built = GeminiClient(settings).build_prompt_sync(snapshot)
        assert "005930" in built

    def test_prompt_contains_price(self, settings):
        """The current price from the market data appears in the built prompt."""
        snapshot = {
            "stock_code": "005930",
            "current_price": 72000,
            "orderbook": {"asks": [], "bids": []},
            "foreigner_net": -50000,
        }
        built = GeminiClient(settings).build_prompt_sync(snapshot)
        assert "72000" in built

    def test_prompt_enforces_json_output_format(self, settings):
        """The built prompt mentions JSON and the action/confidence fields."""
        snapshot = {
            "stock_code": "005930",
            "current_price": 72000,
            "orderbook": {"asks": [], "bids": []},
            "foreigner_net": 0,
        }
        built = GeminiClient(settings).build_prompt_sync(snapshot)
        # Checked in the same order as before: JSON, then the two field names.
        for keyword in ("JSON", "action", "confidence"):
            assert keyword in built
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Batch Decision Making
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestBatchDecisionParsing:
    """The batch response parser must turn JSON arrays into per-stock decisions."""

    def test_parse_valid_batch_response(self, settings):
        """Each array entry is mapped to a decision keyed by its stock code."""
        universe = [
            {"stock_code": "AAPL", "current_price": 185.5},
            {"stock_code": "MSFT", "current_price": 420.0},
        ]
        raw = """[
{"code": "AAPL", "action": "BUY", "confidence": 85, "rationale": "Strong momentum"},
{"code": "MSFT", "action": "HOLD", "confidence": 50, "rationale": "Wait for earnings"}
]"""
        parser = GeminiClient(settings)

        by_code = parser._parse_batch_response(raw, universe, token_count=100)

        assert len(by_code) == 2
        assert by_code["AAPL"].action == "BUY"
        assert by_code["AAPL"].confidence == 85
        assert by_code["MSFT"].action == "HOLD"
        assert by_code["MSFT"].confidence == 50

    def test_parse_batch_with_markdown_wrapper(self, settings):
        """A ```json fenced array is parsed like a bare array."""
        universe = [{"stock_code": "AAPL", "current_price": 185.5}]
        raw = """```json
[{"code": "AAPL", "action": "BUY", "confidence": 90, "rationale": "Good"}]
```"""
        parser = GeminiClient(settings)

        by_code = parser._parse_batch_response(raw, universe, token_count=100)

        assert by_code["AAPL"].action == "BUY"
        assert by_code["AAPL"].confidence == 90

    def test_parse_batch_empty_response_returns_hold_for_all(self, settings):
        """An empty response produces a HOLD entry for every requested stock."""
        universe = [
            {"stock_code": "AAPL", "current_price": 185.5},
            {"stock_code": "MSFT", "current_price": 420.0},
        ]
        parser = GeminiClient(settings)

        by_code = parser._parse_batch_response("", universe, token_count=100)

        assert len(by_code) == 2
        assert by_code["AAPL"].action == "HOLD"
        assert by_code["AAPL"].confidence == 0
        assert by_code["MSFT"].action == "HOLD"

    def test_parse_batch_malformed_json_returns_hold_for_all(self, settings):
        """Non-JSON text produces HOLD with zero confidence."""
        universe = [{"stock_code": "AAPL", "current_price": 185.5}]
        raw = "This is not JSON"
        parser = GeminiClient(settings)

        by_code = parser._parse_batch_response(raw, universe, token_count=100)

        assert by_code["AAPL"].action == "HOLD"
        assert by_code["AAPL"].confidence == 0

    def test_parse_batch_not_array_returns_hold_for_all(self, settings):
        """A single JSON object (not an array) produces HOLD with zero confidence."""
        universe = [{"stock_code": "AAPL", "current_price": 185.5}]
        raw = '{"code": "AAPL", "action": "BUY", "confidence": 90, "rationale": "Good"}'
        parser = GeminiClient(settings)

        by_code = parser._parse_batch_response(raw, universe, token_count=100)

        assert by_code["AAPL"].action == "HOLD"
        assert by_code["AAPL"].confidence == 0

    def test_parse_batch_missing_stock_gets_hold(self, settings):
        """A requested stock absent from the response gets HOLD with zero confidence."""
        universe = [
            {"stock_code": "AAPL", "current_price": 185.5},
            {"stock_code": "MSFT", "current_price": 420.0},
        ]
        # Only AAPL is answered; MSFT is deliberately left out of the response.
        raw = '[{"code": "AAPL", "action": "BUY", "confidence": 85, "rationale": "Good"}]'
        parser = GeminiClient(settings)

        by_code = parser._parse_batch_response(raw, universe, token_count=100)

        assert by_code["AAPL"].action == "BUY"
        assert by_code["MSFT"].action == "HOLD"
        assert by_code["MSFT"].confidence == 0

    def test_parse_batch_invalid_action_becomes_hold(self, settings):
        """An unknown action in a batch entry is reported as HOLD."""
        universe = [{"stock_code": "AAPL", "current_price": 185.5}]
        raw = '[{"code": "AAPL", "action": "YOLO", "confidence": 90, "rationale": "Moon"}]'
        parser = GeminiClient(settings)

        by_code = parser._parse_batch_response(raw, universe, token_count=100)

        assert by_code["AAPL"].action == "HOLD"

    def test_parse_batch_low_confidence_becomes_hold(self, settings):
        """BUY at confidence 65 is reported as HOLD; the confidence value is kept."""
        universe = [{"stock_code": "AAPL", "current_price": 185.5}]
        raw = '[{"code": "AAPL", "action": "BUY", "confidence": 65, "rationale": "Weak"}]'
        parser = GeminiClient(settings)

        by_code = parser._parse_batch_response(raw, universe, token_count=100)

        assert by_code["AAPL"].action == "HOLD"
        assert by_code["AAPL"].confidence == 65

    def test_parse_batch_missing_fields_gets_hold(self, settings):
        """An entry without confidence and rationale gets HOLD with zero confidence."""
        universe = [{"stock_code": "AAPL", "current_price": 185.5}]
        # The entry below omits both confidence and rationale.
        raw = '[{"code": "AAPL", "action": "BUY"}]'
        parser = GeminiClient(settings)

        by_code = parser._parse_batch_response(raw, universe, token_count=100)

        assert by_code["AAPL"].action == "HOLD"
        assert by_code["AAPL"].confidence == 0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Prompt Override (used by pre_market_planner)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPromptOverride:
    """decide() must honour a prompt_override entry found in market_data."""

    @staticmethod
    def _prompt_sent_to(mock_generate):
        """Return the prompt the mocked generate_content was last called with.

        The ``contents`` keyword argument wins; otherwise the second positional
        argument is used, or ``None`` when there is no such argument.
        """
        positional, keywords = mock_generate.call_args
        fallback = positional[1] if len(positional) > 1 else None
        return keywords.get("contents", fallback)

    @pytest.mark.asyncio
    async def test_prompt_override_is_sent_to_gemini(self, settings):
        """When prompt_override is in market_data, it should be used as the prompt."""
        client = GeminiClient(settings)

        custom_prompt = "You are a playbook generator. Return JSON with scenarios."
        playbook_json = '{"market_outlook": "neutral", "stocks": []}'
        market_data = {
            "stock_code": "PLANNER",
            "current_price": 0,
            "prompt_override": custom_prompt,
        }

        mock_response = MagicMock()
        mock_response.text = playbook_json

        with patch.object(
            client._client.aio.models,
            "generate_content",
            new_callable=AsyncMock,
            return_value=mock_response,
        ) as mock_generate:
            decision = await client.decide(market_data)

        # The override itself, not a prompt built from market data, must be sent.
        mock_generate.assert_called_once()
        assert self._prompt_sent_to(mock_generate) == custom_prompt
        # Raw response preserved in rationale without parse_response (#247)
        assert decision.rationale == playbook_json

    @pytest.mark.asyncio
    async def test_prompt_override_skips_parse_response(self, settings):
        """prompt_override bypasses parse_response — no Missing fields warning, raw preserved."""
        client = GeminiClient(settings)
        client._enable_optimization = True

        custom_prompt = "Custom playbook prompt"
        playbook_json = '{"market_outlook": "bullish", "stocks": [{"stock_code": "AAPL"}]}'
        market_data = {
            "stock_code": "PLANNER",
            "current_price": 0,
            "prompt_override": custom_prompt,
        }

        mock_response = MagicMock()
        mock_response.text = playbook_json

        generate_patch = patch.object(
            client._client.aio.models,
            "generate_content",
            new_callable=AsyncMock,
            return_value=mock_response,
        )
        with generate_patch, patch.object(client, "parse_response") as mock_parse:
            decision = await client.decide(market_data)

        # parse_response must NOT be called for prompt_override
        mock_parse.assert_not_called()
        # Raw playbook JSON preserved in rationale
        assert decision.rationale == playbook_json

    @pytest.mark.asyncio
    async def test_without_prompt_override_uses_build_prompt(self, settings):
        """Without prompt_override, decide() should use build_prompt as before."""
        client = GeminiClient(settings)
        market_data = {
            "stock_code": "005930",
            "current_price": 72000,
        }

        mock_response = MagicMock()
        mock_response.text = '{"action": "HOLD", "confidence": 50, "rationale": "ok"}'

        with patch.object(
            client._client.aio.models,
            "generate_content",
            new_callable=AsyncMock,
            return_value=mock_response,
        ) as mock_generate:
            await client.decide(market_data)

        # The prompt should carry the stock code from build_prompt rather than
        # being a custom override.
        assert "005930" in self._prompt_sent_to(mock_generate)
|