feat: enforce session_id persistence in trade ledger (TASK-CODE-007)

This commit is contained in:
agentson
2026-02-27 08:49:04 +09:00
parent 2dbe98615d
commit b2b02b6f57
2 changed files with 43 additions and 3 deletions

View File

@@ -8,6 +8,9 @@ from datetime import UTC, datetime
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from src.core.order_policy import classify_session_id
from src.markets.schedule import MARKETS
def init_db(db_path: str) -> sqlite3.Connection: def init_db(db_path: str) -> sqlite3.Connection:
"""Initialize the trade logs database and return a connection.""" """Initialize the trade logs database and return a connection."""
@@ -35,6 +38,7 @@ def init_db(db_path: str) -> sqlite3.Connection:
fx_pnl REAL DEFAULT 0.0, fx_pnl REAL DEFAULT 0.0,
market TEXT DEFAULT 'KR', market TEXT DEFAULT 'KR',
exchange_code TEXT DEFAULT 'KRX', exchange_code TEXT DEFAULT 'KRX',
session_id TEXT DEFAULT 'UNKNOWN',
selection_context TEXT, selection_context TEXT,
decision_id TEXT, decision_id TEXT,
mode TEXT DEFAULT 'paper' mode TEXT DEFAULT 'paper'
@@ -56,6 +60,8 @@ def init_db(db_path: str) -> sqlite3.Connection:
conn.execute("ALTER TABLE trades ADD COLUMN decision_id TEXT") conn.execute("ALTER TABLE trades ADD COLUMN decision_id TEXT")
if "mode" not in columns: if "mode" not in columns:
conn.execute("ALTER TABLE trades ADD COLUMN mode TEXT DEFAULT 'paper'") conn.execute("ALTER TABLE trades ADD COLUMN mode TEXT DEFAULT 'paper'")
if "session_id" not in columns:
conn.execute("ALTER TABLE trades ADD COLUMN session_id TEXT DEFAULT 'UNKNOWN'")
if "strategy_pnl" not in columns: if "strategy_pnl" not in columns:
conn.execute("ALTER TABLE trades ADD COLUMN strategy_pnl REAL DEFAULT 0.0") conn.execute("ALTER TABLE trades ADD COLUMN strategy_pnl REAL DEFAULT 0.0")
if "fx_pnl" not in columns: if "fx_pnl" not in columns:
@@ -70,6 +76,13 @@ def init_db(db_path: str) -> sqlite3.Connection:
AND fx_pnl = 0.0 AND fx_pnl = 0.0
""" """
) )
conn.execute(
"""
UPDATE trades
SET session_id = 'UNKNOWN'
WHERE session_id IS NULL OR session_id = ''
"""
)
# Context tree tables for multi-layered memory management # Context tree tables for multi-layered memory management
conn.execute( conn.execute(
@@ -192,6 +205,7 @@ def log_trade(
fx_pnl: float | None = None, fx_pnl: float | None = None,
market: str = "KR", market: str = "KR",
exchange_code: str = "KRX", exchange_code: str = "KRX",
session_id: str | None = None,
selection_context: dict[str, any] | None = None, selection_context: dict[str, any] | None = None,
decision_id: str | None = None, decision_id: str | None = None,
mode: str = "paper", mode: str = "paper",
@@ -211,12 +225,17 @@ def log_trade(
fx_pnl: FX PnL component fx_pnl: FX PnL component
market: Market code market: Market code
exchange_code: Exchange code exchange_code: Exchange code
session_id: Session identifier (if omitted, auto-derived from market)
selection_context: Scanner selection data (RSI, volume_ratio, signal, score) selection_context: Scanner selection data (RSI, volume_ratio, signal, score)
decision_id: Unique decision identifier for audit linking decision_id: Unique decision identifier for audit linking
mode: Trading mode ('paper' or 'live') for data separation mode: Trading mode ('paper' or 'live') for data separation
""" """
# Serialize selection context to JSON # Serialize selection context to JSON
context_json = json.dumps(selection_context) if selection_context else None context_json = json.dumps(selection_context) if selection_context else None
resolved_session_id = session_id or "UNKNOWN"
market_info = MARKETS.get(market)
if session_id is None and market_info is not None:
resolved_session_id = classify_session_id(market_info)
if strategy_pnl is None and fx_pnl is None: if strategy_pnl is None and fx_pnl is None:
strategy_pnl = pnl strategy_pnl = pnl
fx_pnl = 0.0 fx_pnl = 0.0
@@ -232,9 +251,9 @@ def log_trade(
INSERT INTO trades ( INSERT INTO trades (
timestamp, stock_code, action, confidence, rationale, timestamp, stock_code, action, confidence, rationale,
quantity, price, pnl, strategy_pnl, fx_pnl, quantity, price, pnl, strategy_pnl, fx_pnl,
market, exchange_code, selection_context, decision_id, mode market, exchange_code, session_id, selection_context, decision_id, mode
) )
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", """,
( (
datetime.now(UTC).isoformat(), datetime.now(UTC).isoformat(),
@@ -249,6 +268,7 @@ def log_trade(
fx_pnl, fx_pnl,
market, market,
exchange_code, exchange_code,
resolved_session_id,
context_json, context_json,
decision_id, decision_id,
mode, mode,

View File

@@ -155,6 +155,7 @@ def test_mode_column_exists_in_schema() -> None:
cursor = conn.execute("PRAGMA table_info(trades)") cursor = conn.execute("PRAGMA table_info(trades)")
columns = {row[1] for row in cursor.fetchall()} columns = {row[1] for row in cursor.fetchall()}
assert "mode" in columns assert "mode" in columns
assert "session_id" in columns
assert "strategy_pnl" in columns assert "strategy_pnl" in columns
assert "fx_pnl" in columns assert "fx_pnl" in columns
@@ -199,15 +200,17 @@ def test_mode_migration_adds_column_to_existing_db() -> None:
cursor = conn.execute("PRAGMA table_info(trades)") cursor = conn.execute("PRAGMA table_info(trades)")
columns = {row[1] for row in cursor.fetchall()} columns = {row[1] for row in cursor.fetchall()}
assert "mode" in columns assert "mode" in columns
assert "session_id" in columns
assert "strategy_pnl" in columns assert "strategy_pnl" in columns
assert "fx_pnl" in columns assert "fx_pnl" in columns
migrated = conn.execute( migrated = conn.execute(
"SELECT pnl, strategy_pnl, fx_pnl FROM trades WHERE stock_code='AAPL' LIMIT 1" "SELECT pnl, strategy_pnl, fx_pnl, session_id FROM trades WHERE stock_code='AAPL' LIMIT 1"
).fetchone() ).fetchone()
assert migrated is not None assert migrated is not None
assert migrated[0] == 123.45 assert migrated[0] == 123.45
assert migrated[1] == 123.45 assert migrated[1] == 123.45
assert migrated[2] == 0.0 assert migrated[2] == 0.0
assert migrated[3] == "UNKNOWN"
conn.close() conn.close()
finally: finally:
os.unlink(db_path) os.unlink(db_path)
@@ -277,3 +280,20 @@ def test_log_trade_partial_fx_input_does_not_infer_negative_strategy_pnl() -> No
assert row[0] == 10.0 assert row[0] == 10.0
assert row[1] == 0.0 assert row[1] == 0.0
assert row[2] == 10.0 assert row[2] == 10.0
def test_log_trade_persists_explicit_session_id() -> None:
conn = init_db(":memory:")
log_trade(
conn=conn,
stock_code="AAPL",
action="BUY",
confidence=70,
rationale="session test",
market="US_NASDAQ",
exchange_code="NASD",
session_id="US_PRE",
)
row = conn.execute("SELECT session_id FROM trades ORDER BY id DESC LIMIT 1").fetchone()
assert row is not None
assert row[0] == "US_PRE"