From a14f944fccec03916c84e7678badab4cdd3d0f98 Mon Sep 17 00:00:00 2001 From: agentson Date: Sat, 14 Feb 2026 21:36:57 +0900 Subject: [PATCH] feat: link decision outcomes to trades via decision_id (issue #92) Add decision_id column to trades table, capture log_decision() return value, and update original BUY decision outcome on SELL execution. Co-Authored-By: Claude Opus 4.6 --- src/db.py | 35 +++++++++++++-- src/main.py | 44 ++++++++++++++++-- tests/test_main.py | 108 ++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 180 insertions(+), 7 deletions(-) diff --git a/src/db.py b/src/db.py index 7190699..ea57f79 100644 --- a/src/db.py +++ b/src/db.py @@ -6,6 +6,7 @@ import json import sqlite3 from datetime import UTC, datetime from pathlib import Path +from typing import Any def init_db(db_path: str) -> sqlite3.Connection: @@ -26,7 +27,8 @@ def init_db(db_path: str) -> sqlite3.Connection: price REAL, pnl REAL DEFAULT 0.0, market TEXT DEFAULT 'KR', - exchange_code TEXT DEFAULT 'KRX' + exchange_code TEXT DEFAULT 'KRX', + decision_id TEXT ) """ ) @@ -41,6 +43,8 @@ def init_db(db_path: str) -> sqlite3.Connection: conn.execute("ALTER TABLE trades ADD COLUMN exchange_code TEXT DEFAULT 'KRX'") if "selection_context" not in columns: conn.execute("ALTER TABLE trades ADD COLUMN selection_context TEXT") + if "decision_id" not in columns: + conn.execute("ALTER TABLE trades ADD COLUMN decision_id TEXT") # Context tree tables for multi-layered memory management conn.execute( @@ -143,6 +147,7 @@ def log_trade( market: str = "KR", exchange_code: str = "KRX", selection_context: dict[str, any] | None = None, + decision_id: str | None = None, ) -> None: """Insert a trade record into the database. @@ -166,9 +171,9 @@ def log_trade( """ INSERT INTO trades ( timestamp, stock_code, action, confidence, rationale, - quantity, price, pnl, market, exchange_code, selection_context + quantity, price, pnl, market, exchange_code, selection_context, decision_id ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( datetime.now(UTC).isoformat(), @@ -182,6 +187,30 @@ def log_trade( market, exchange_code, context_json, + decision_id, ), ) conn.commit() + + +def get_latest_buy_trade( + conn: sqlite3.Connection, stock_code: str, market: str +) -> dict[str, Any] | None: + """Fetch the most recent BUY trade for a stock and market.""" + cursor = conn.execute( + """ + SELECT decision_id, price, quantity + FROM trades + WHERE stock_code = ? + AND market = ? + AND action = 'BUY' + AND decision_id IS NOT NULL + ORDER BY timestamp DESC + LIMIT 1 + """, + (stock_code, market), + ) + row = cursor.fetchone() + if not row: + return None + return {"decision_id": row[0], "price": row[1], "quantity": row[2]} diff --git a/src/main.py b/src/main.py index 97c9afd..e9b484f 100644 --- a/src/main.py +++ b/src/main.py @@ -26,7 +26,7 @@ from src.context.store import ContextStore from src.core.criticality import CriticalityAssessor from src.core.priority_queue import PriorityTaskQueue from src.core.risk_manager import CircuitBreakerTripped, FatFingerRejected, RiskManager -from src.db import init_db, log_trade +from src.db import get_latest_buy_trade, init_db, log_trade from src.logging.decision_logger import DecisionLogger from src.logging_config import setup_logging from src.markets.schedule import MarketInfo, get_next_market_open, get_open_markets @@ -279,7 +279,7 @@ async def trading_cycle( "pnl_pct": pnl_pct, } - decision_logger.log_decision( + decision_id = decision_logger.log_decision( stock_code=stock_code, market=market.code, exchange_code=market.exchange_code, @@ -291,6 +291,9 @@ async def trading_cycle( ) # 3. Execute if actionable + quantity = 0 + trade_price = current_price + trade_pnl = 0.0 if decision.action in ("BUY", "SELL"): # Determine order size (simplified: 1 lot) quantity = 1 @@ -346,6 +349,18 @@ async def trading_cycle( except Exception as exc: logger.warning("Telegram notification failed: %s", exc) + if decision.action == "SELL": + buy_trade = get_latest_buy_trade(db_conn, stock_code, market.code) + if buy_trade and buy_trade.get("price") is not None: + buy_price = float(buy_trade["price"]) + buy_qty = int(buy_trade.get("quantity") or 1) + trade_pnl = (trade_price - buy_price) * buy_qty + decision_logger.update_outcome( + decision_id=buy_trade["decision_id"], + pnl=trade_pnl, + accuracy=1 if trade_pnl > 0 else 0, + ) + # 6. Log trade with selection context selection_context = None if stock_code in market_candidates: @@ -363,9 +378,13 @@ async def trading_cycle( action=decision.action, confidence=decision.confidence, rationale=decision.rationale, + quantity=quantity, + price=trade_price, + pnl=trade_pnl, market=market.code, exchange_code=market.exchange_code, selection_context=selection_context, + decision_id=decision_id, ) # 7. Latency monitoring @@ -600,7 +619,7 @@ async def run_daily_session( "pnl_pct": pnl_pct, } - decision_logger.log_decision( + decision_id = decision_logger.log_decision( stock_code=stock_code, market=market.code, exchange_code=market.exchange_code, @@ -612,6 +631,9 @@ async def run_daily_session( ) # Execute if actionable + quantity = 0 + trade_price = stock_data["current_price"] + trade_pnl = 0.0 if decision.action in ("BUY", "SELL"): quantity = 1 order_amount = stock_data["current_price"] * quantity @@ -684,6 +706,18 @@ async def run_daily_session( ) continue + if decision.action == "SELL": + buy_trade = get_latest_buy_trade(db_conn, stock_code, market.code) + if buy_trade and buy_trade.get("price") is not None: + buy_price = float(buy_trade["price"]) + buy_qty = int(buy_trade.get("quantity") or 1) + trade_pnl = (trade_price - buy_price) * buy_qty + decision_logger.update_outcome( + decision_id=buy_trade["decision_id"], + pnl=trade_pnl, + accuracy=1 if trade_pnl > 0 else 0, + ) + # Log trade log_trade( conn=db_conn, @@ -691,8 +725,12 @@ async def run_daily_session( action=decision.action, confidence=decision.confidence, rationale=decision.rationale, + quantity=quantity, + price=trade_price, + pnl=trade_pnl, market=market.code, exchange_code=market.exchange_code, + decision_id=decision_id, ) logger.info("Daily trading session completed") diff --git a/tests/test_main.py b/tests/test_main.py index 417a74c..4952297 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -5,8 +5,10 @@ from unittest.mock import ANY, AsyncMock, MagicMock, patch import pytest -from src.core.risk_manager import CircuitBreakerTripped, FatFingerRejected from src.context.layer import ContextLayer +from src.core.risk_manager import CircuitBreakerTripped, FatFingerRejected +from src.db import init_db, log_trade +from src.logging.decision_logger import DecisionLogger from src.main import safe_float, trading_cycle from src.strategy.models import ( DayPlaybook, @@ -44,6 +46,17 @@ def _make_hold_match(stock_code: str = "005930") -> ScenarioMatch: ) +def _make_sell_match(stock_code: str = "005930") -> ScenarioMatch: + """Create a ScenarioMatch that returns SELL.""" + return ScenarioMatch( + stock_code=stock_code, + matched_scenario=None, + action=ScenarioAction.SELL, + confidence=90, + rationale="Test sell", + ) + + class TestSafeFloat: """Test safe_float() helper function.""" @@ -1113,3 +1126,96 @@ class TestScenarioEngineIntegration: # REDUCE_ALL is not BUY or SELL — no order sent mock_broker.send_order.assert_not_called() mock_telegram.notify_trade_execution.assert_not_called() + + +@pytest.mark.asyncio +async def test_sell_updates_original_buy_decision_outcome() -> None: + """SELL should update the original BUY decision outcome in decision_logs.""" + db_conn = init_db(":memory:") + decision_logger = DecisionLogger(db_conn) + + buy_decision_id = decision_logger.log_decision( + stock_code="005930", + market="KR", + exchange_code="KRX", + action="BUY", + confidence=85, + rationale="Initial buy", + context_snapshot={}, + input_data={}, + ) + log_trade( + conn=db_conn, + stock_code="005930", + action="BUY", + confidence=85, + rationale="Initial buy", + quantity=1, + price=100.0, + pnl=0.0, + market="KR", + exchange_code="KRX", + decision_id=buy_decision_id, + ) + + broker = MagicMock() + broker.get_orderbook = AsyncMock( + return_value={"output1": {"stck_prpr": "120", "frgn_ntby_qty": "0"}} + ) + broker.get_balance = AsyncMock( + return_value={ + "output2": [ + { + "tot_evlu_amt": "100000", + "dnca_tot_amt": "10000", + "pchs_amt_smtl_amt": "90000", + } + ] + } + ) + broker.send_order = AsyncMock(return_value={"msg1": "OK"}) + + overseas_broker = MagicMock() + engine = MagicMock(spec=ScenarioEngine) + engine.evaluate = MagicMock(return_value=_make_sell_match()) + risk = MagicMock() + context_store = MagicMock( + get_latest_timeframe=MagicMock(return_value=None), + set_context=MagicMock(), + ) + criticality_assessor = MagicMock( + assess_market_conditions=MagicMock(return_value=MagicMock(value="NORMAL")), + get_timeout=MagicMock(return_value=5.0), + ) + telegram = MagicMock() + telegram.notify_trade_execution = AsyncMock() + telegram.notify_fat_finger = AsyncMock() + telegram.notify_circuit_breaker = AsyncMock() + telegram.notify_scenario_matched = AsyncMock() + + market = MagicMock() + market.name = "Korea" + market.code = "KR" + market.exchange_code = "KRX" + market.is_domestic = True + + await trading_cycle( + broker=broker, + overseas_broker=overseas_broker, + scenario_engine=engine, + playbook=_make_playbook(), + risk=risk, + db_conn=db_conn, + decision_logger=decision_logger, + context_store=context_store, + criticality_assessor=criticality_assessor, + telegram=telegram, + market=market, + stock_code="005930", + scan_candidates={}, + ) + + updated_buy = decision_logger.get_decision_by_id(buy_decision_id) + assert updated_buy is not None + assert updated_buy.outcome_pnl == 20.0 + assert updated_buy.outcome_accuracy == 1