fix: lazy session resolver and one-time session_id backfill
This commit is contained in:
26
src/db.py
26
src/db.py
@@ -8,9 +8,6 @@ from datetime import UTC, datetime
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from src.core.order_policy import classify_session_id
|
|
||||||
from src.markets.schedule import MARKETS
|
|
||||||
|
|
||||||
|
|
||||||
def init_db(db_path: str) -> sqlite3.Connection:
|
def init_db(db_path: str) -> sqlite3.Connection:
|
||||||
"""Initialize the trade logs database and return a connection."""
|
"""Initialize the trade logs database and return a connection."""
|
||||||
@@ -60,8 +57,10 @@ def init_db(db_path: str) -> sqlite3.Connection:
|
|||||||
conn.execute("ALTER TABLE trades ADD COLUMN decision_id TEXT")
|
conn.execute("ALTER TABLE trades ADD COLUMN decision_id TEXT")
|
||||||
if "mode" not in columns:
|
if "mode" not in columns:
|
||||||
conn.execute("ALTER TABLE trades ADD COLUMN mode TEXT DEFAULT 'paper'")
|
conn.execute("ALTER TABLE trades ADD COLUMN mode TEXT DEFAULT 'paper'")
|
||||||
|
session_id_added = False
|
||||||
if "session_id" not in columns:
|
if "session_id" not in columns:
|
||||||
conn.execute("ALTER TABLE trades ADD COLUMN session_id TEXT DEFAULT 'UNKNOWN'")
|
conn.execute("ALTER TABLE trades ADD COLUMN session_id TEXT DEFAULT 'UNKNOWN'")
|
||||||
|
session_id_added = True
|
||||||
if "strategy_pnl" not in columns:
|
if "strategy_pnl" not in columns:
|
||||||
conn.execute("ALTER TABLE trades ADD COLUMN strategy_pnl REAL DEFAULT 0.0")
|
conn.execute("ALTER TABLE trades ADD COLUMN strategy_pnl REAL DEFAULT 0.0")
|
||||||
if "fx_pnl" not in columns:
|
if "fx_pnl" not in columns:
|
||||||
@@ -76,6 +75,7 @@ def init_db(db_path: str) -> sqlite3.Connection:
|
|||||||
AND fx_pnl = 0.0
|
AND fx_pnl = 0.0
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
if session_id_added:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"""
|
"""
|
||||||
UPDATE trades
|
UPDATE trades
|
||||||
@@ -232,10 +232,7 @@ def log_trade(
|
|||||||
"""
|
"""
|
||||||
# Serialize selection context to JSON
|
# Serialize selection context to JSON
|
||||||
context_json = json.dumps(selection_context) if selection_context else None
|
context_json = json.dumps(selection_context) if selection_context else None
|
||||||
resolved_session_id = session_id or "UNKNOWN"
|
resolved_session_id = _resolve_session_id(market=market, session_id=session_id)
|
||||||
market_info = MARKETS.get(market)
|
|
||||||
if session_id is None and market_info is not None:
|
|
||||||
resolved_session_id = classify_session_id(market_info)
|
|
||||||
if strategy_pnl is None and fx_pnl is None:
|
if strategy_pnl is None and fx_pnl is None:
|
||||||
strategy_pnl = pnl
|
strategy_pnl = pnl
|
||||||
fx_pnl = 0.0
|
fx_pnl = 0.0
|
||||||
@@ -277,6 +274,21 @@ def log_trade(
|
|||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_session_id(*, market: str, session_id: str | None) -> str:
|
||||||
|
if session_id:
|
||||||
|
return session_id
|
||||||
|
try:
|
||||||
|
from src.core.order_policy import classify_session_id
|
||||||
|
from src.markets.schedule import MARKETS
|
||||||
|
|
||||||
|
market_info = MARKETS.get(market)
|
||||||
|
if market_info is not None:
|
||||||
|
return classify_session_id(market_info)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return "UNKNOWN"
|
||||||
|
|
||||||
|
|
||||||
def get_latest_buy_trade(
|
def get_latest_buy_trade(
|
||||||
conn: sqlite3.Connection, stock_code: str, market: str
|
conn: sqlite3.Connection, stock_code: str, market: str
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
|
|||||||
@@ -297,3 +297,35 @@ def test_log_trade_persists_explicit_session_id() -> None:
|
|||||||
row = conn.execute("SELECT session_id FROM trades ORDER BY id DESC LIMIT 1").fetchone()
|
row = conn.execute("SELECT session_id FROM trades ORDER BY id DESC LIMIT 1").fetchone()
|
||||||
assert row is not None
|
assert row is not None
|
||||||
assert row[0] == "US_PRE"
|
assert row[0] == "US_PRE"
|
||||||
|
|
||||||
|
|
||||||
|
def test_log_trade_auto_derives_session_id_when_not_provided() -> None:
|
||||||
|
conn = init_db(":memory:")
|
||||||
|
log_trade(
|
||||||
|
conn=conn,
|
||||||
|
stock_code="005930",
|
||||||
|
action="BUY",
|
||||||
|
confidence=70,
|
||||||
|
rationale="auto session",
|
||||||
|
market="KR",
|
||||||
|
exchange_code="KRX",
|
||||||
|
)
|
||||||
|
row = conn.execute("SELECT session_id FROM trades ORDER BY id DESC LIMIT 1").fetchone()
|
||||||
|
assert row is not None
|
||||||
|
assert row[0] != "UNKNOWN"
|
||||||
|
|
||||||
|
|
||||||
|
def test_log_trade_unknown_market_falls_back_to_unknown_session() -> None:
|
||||||
|
conn = init_db(":memory:")
|
||||||
|
log_trade(
|
||||||
|
conn=conn,
|
||||||
|
stock_code="X",
|
||||||
|
action="BUY",
|
||||||
|
confidence=70,
|
||||||
|
rationale="unknown market",
|
||||||
|
market="MARS",
|
||||||
|
exchange_code="MARS",
|
||||||
|
)
|
||||||
|
row = conn.execute("SELECT session_id FROM trades ORDER BY id DESC LIMIT 1").fetchone()
|
||||||
|
assert row is not None
|
||||||
|
assert row[0] == "UNKNOWN"
|
||||||
|
|||||||
Reference in New Issue
Block a user