1 Commits

Author SHA1 Message Date
agentson
47aadcb4e7 fix: include exchange_code in latest BUY matching key (#323)
Some checks failed
Gitea CI / test (push) Has been cancelled
Gitea CI / test (pull_request) Has been cancelled
2026-02-28 14:38:53 +09:00
9 changed files with 38 additions and 281 deletions

View File

@@ -5,9 +5,7 @@ Implements first-touch labeling with upper/lower/time barriers.
from __future__ import annotations from __future__ import annotations
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Literal, Sequence from typing import Literal, Sequence
@@ -18,18 +16,9 @@ TieBreakMode = Literal["stop_first", "take_first"]
class TripleBarrierSpec: class TripleBarrierSpec:
take_profit_pct: float take_profit_pct: float
stop_loss_pct: float stop_loss_pct: float
max_holding_bars: int | None = None max_holding_bars: int
max_holding_minutes: int | None = None
tie_break: TieBreakMode = "stop_first" tie_break: TieBreakMode = "stop_first"
def __post_init__(self) -> None:
if self.max_holding_minutes is None and self.max_holding_bars is None:
raise ValueError("one of max_holding_minutes or max_holding_bars must be set")
if self.max_holding_minutes is not None and self.max_holding_minutes <= 0:
raise ValueError("max_holding_minutes must be positive")
if self.max_holding_bars is not None and self.max_holding_bars <= 0:
raise ValueError("max_holding_bars must be positive")
@dataclass(frozen=True) @dataclass(frozen=True)
class TripleBarrierLabel: class TripleBarrierLabel:
@@ -46,7 +35,6 @@ def label_with_triple_barrier(
highs: Sequence[float], highs: Sequence[float],
lows: Sequence[float], lows: Sequence[float],
closes: Sequence[float], closes: Sequence[float],
timestamps: Sequence[datetime] | None = None,
entry_index: int, entry_index: int,
side: int, side: int,
spec: TripleBarrierSpec, spec: TripleBarrierSpec,
@@ -65,6 +53,8 @@ def label_with_triple_barrier(
raise ValueError("highs, lows, closes lengths must match") raise ValueError("highs, lows, closes lengths must match")
if entry_index < 0 or entry_index >= len(closes): if entry_index < 0 or entry_index >= len(closes):
raise IndexError("entry_index out of range") raise IndexError("entry_index out of range")
if spec.max_holding_bars <= 0:
raise ValueError("max_holding_bars must be positive")
entry_price = float(closes[entry_index]) entry_price = float(closes[entry_index])
if entry_price <= 0: if entry_price <= 0:
@@ -78,31 +68,13 @@ def label_with_triple_barrier(
upper = entry_price * (1.0 + spec.stop_loss_pct) upper = entry_price * (1.0 + spec.stop_loss_pct)
lower = entry_price * (1.0 - spec.take_profit_pct) lower = entry_price * (1.0 - spec.take_profit_pct)
if spec.max_holding_minutes is not None: last_index = min(len(closes) - 1, entry_index + spec.max_holding_bars)
if timestamps is None:
raise ValueError("timestamps are required when max_holding_minutes is set")
if len(timestamps) != len(closes):
raise ValueError("timestamps length must match OHLC lengths")
expiry_timestamp = timestamps[entry_index] + timedelta(minutes=spec.max_holding_minutes)
last_index = entry_index
for idx in range(entry_index + 1, len(closes)):
if timestamps[idx] > expiry_timestamp:
break
last_index = idx
else:
assert spec.max_holding_bars is not None
warnings.warn(
"TripleBarrierSpec.max_holding_bars is deprecated; use max_holding_minutes with timestamps instead.",
DeprecationWarning,
stacklevel=2,
)
last_index = min(len(closes) - 1, entry_index + spec.max_holding_bars)
for idx in range(entry_index + 1, last_index + 1): for idx in range(entry_index + 1, last_index + 1):
high_price = float(highs[idx]) h = float(highs[idx])
low_price = float(lows[idx]) l = float(lows[idx])
up_touch = high_price >= upper up_touch = h >= upper
down_touch = low_price <= lower down_touch = l <= lower
if not up_touch and not down_touch: if not up_touch and not down_touch:
continue continue

View File

@@ -109,7 +109,6 @@ def init_db(db_path: str) -> sqlite3.Connection:
stock_code TEXT NOT NULL, stock_code TEXT NOT NULL,
market TEXT NOT NULL, market TEXT NOT NULL,
exchange_code TEXT NOT NULL, exchange_code TEXT NOT NULL,
session_id TEXT DEFAULT 'UNKNOWN',
action TEXT NOT NULL, action TEXT NOT NULL,
confidence INTEGER NOT NULL, confidence INTEGER NOT NULL,
rationale TEXT NOT NULL, rationale TEXT NOT NULL,
@@ -122,27 +121,6 @@ def init_db(db_path: str) -> sqlite3.Connection:
) )
""" """
) )
decision_columns = {
row[1]
for row in conn.execute("PRAGMA table_info(decision_logs)").fetchall()
}
if "session_id" not in decision_columns:
conn.execute("ALTER TABLE decision_logs ADD COLUMN session_id TEXT DEFAULT 'UNKNOWN'")
conn.execute(
"""
UPDATE decision_logs
SET session_id = 'UNKNOWN'
WHERE session_id IS NULL OR session_id = ''
"""
)
if "outcome_pnl" not in decision_columns:
conn.execute("ALTER TABLE decision_logs ADD COLUMN outcome_pnl REAL")
if "outcome_accuracy" not in decision_columns:
conn.execute("ALTER TABLE decision_logs ADD COLUMN outcome_accuracy INTEGER")
if "reviewed" not in decision_columns:
conn.execute("ALTER TABLE decision_logs ADD COLUMN reviewed INTEGER DEFAULT 0")
if "review_notes" not in decision_columns:
conn.execute("ALTER TABLE decision_logs ADD COLUMN review_notes TEXT")
conn.execute( conn.execute(
""" """

View File

@@ -9,7 +9,6 @@ This module:
from __future__ import annotations from __future__ import annotations
import ast
import json import json
import logging import logging
import sqlite3 import sqlite3
@@ -29,24 +28,24 @@ from src.logging.decision_logger import DecisionLogger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
STRATEGIES_DIR = Path("src/strategies") STRATEGIES_DIR = Path("src/strategies")
STRATEGY_TEMPLATE = """\ STRATEGY_TEMPLATE = textwrap.dedent("""\
\"\"\"Auto-generated strategy: {name} \"\"\"Auto-generated strategy: {name}
Generated at: {timestamp} Generated at: {timestamp}
Rationale: {rationale} Rationale: {rationale}
\"\"\" \"\"\"
from __future__ import annotations from __future__ import annotations
from typing import Any from typing import Any
from src.strategies.base import BaseStrategy from src.strategies.base import BaseStrategy
class {class_name}(BaseStrategy): class {class_name}(BaseStrategy):
\"\"\"Strategy: {name}\"\"\" \"\"\"Strategy: {name}\"\"\"
def evaluate(self, market_data: dict[str, Any]) -> dict[str, Any]: def evaluate(self, market_data: dict[str, Any]) -> dict[str, Any]:
{body} {body}
""" """)
class EvolutionOptimizer: class EvolutionOptimizer:
@@ -236,8 +235,7 @@ class EvolutionOptimizer:
file_path = STRATEGIES_DIR / file_name file_path = STRATEGIES_DIR / file_name
# Indent the body for the class method # Indent the body for the class method
normalized_body = textwrap.dedent(body).strip() indented_body = textwrap.indent(body, " ")
indented_body = textwrap.indent(normalized_body, " ")
# Generate rationale from patterns # Generate rationale from patterns
rationale = f"Auto-evolved from {len(failures)} failures. " rationale = f"Auto-evolved from {len(failures)} failures. "
@@ -249,16 +247,9 @@ class EvolutionOptimizer:
timestamp=datetime.now(UTC).isoformat(), timestamp=datetime.now(UTC).isoformat(),
rationale=rationale, rationale=rationale,
class_name=class_name, class_name=class_name,
body=indented_body.rstrip(), body=indented_body.strip(),
) )
try:
parsed = ast.parse(content, filename=str(file_path))
compile(parsed, filename=str(file_path), mode="exec")
except SyntaxError as exc:
logger.warning("Generated strategy failed syntax validation: %s", exc)
return None
file_path.write_text(content) file_path.write_text(content)
logger.info("Generated strategy file: %s", file_path) logger.info("Generated strategy file: %s", file_path)
return file_path return file_path

View File

@@ -19,7 +19,6 @@ class DecisionLog:
stock_code: str stock_code: str
market: str market: str
exchange_code: str exchange_code: str
session_id: str
action: str action: str
confidence: int confidence: int
rationale: str rationale: str
@@ -48,7 +47,6 @@ class DecisionLogger:
rationale: str, rationale: str,
context_snapshot: dict[str, Any], context_snapshot: dict[str, Any],
input_data: dict[str, Any], input_data: dict[str, Any],
session_id: str | None = None,
) -> str: ) -> str:
"""Log a trading decision with full context. """Log a trading decision with full context.
@@ -61,22 +59,20 @@ class DecisionLogger:
rationale: Reasoning for the decision rationale: Reasoning for the decision
context_snapshot: L1-L7 context snapshot at decision time context_snapshot: L1-L7 context snapshot at decision time
input_data: Market data inputs (price, volume, orderbook, etc.) input_data: Market data inputs (price, volume, orderbook, etc.)
session_id: Runtime session identifier
Returns: Returns:
decision_id: Unique identifier for this decision decision_id: Unique identifier for this decision
""" """
decision_id = str(uuid.uuid4()) decision_id = str(uuid.uuid4())
timestamp = datetime.now(UTC).isoformat() timestamp = datetime.now(UTC).isoformat()
resolved_session = session_id or "UNKNOWN"
self.conn.execute( self.conn.execute(
""" """
INSERT INTO decision_logs ( INSERT INTO decision_logs (
decision_id, timestamp, stock_code, market, exchange_code, decision_id, timestamp, stock_code, market, exchange_code,
session_id, action, confidence, rationale, context_snapshot, input_data action, confidence, rationale, context_snapshot, input_data
) )
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", """,
( (
decision_id, decision_id,
@@ -84,7 +80,6 @@ class DecisionLogger:
stock_code, stock_code,
market, market,
exchange_code, exchange_code,
resolved_session,
action, action,
confidence, confidence,
rationale, rationale,
@@ -111,7 +106,7 @@ class DecisionLogger:
query = """ query = """
SELECT SELECT
decision_id, timestamp, stock_code, market, exchange_code, decision_id, timestamp, stock_code, market, exchange_code,
session_id, action, confidence, rationale, context_snapshot, input_data, action, confidence, rationale, context_snapshot, input_data,
outcome_pnl, outcome_accuracy, reviewed, review_notes outcome_pnl, outcome_accuracy, reviewed, review_notes
FROM decision_logs FROM decision_logs
WHERE reviewed = 0 AND confidence >= ? WHERE reviewed = 0 AND confidence >= ?
@@ -173,7 +168,7 @@ class DecisionLogger:
""" """
SELECT SELECT
decision_id, timestamp, stock_code, market, exchange_code, decision_id, timestamp, stock_code, market, exchange_code,
session_id, action, confidence, rationale, context_snapshot, input_data, action, confidence, rationale, context_snapshot, input_data,
outcome_pnl, outcome_accuracy, reviewed, review_notes outcome_pnl, outcome_accuracy, reviewed, review_notes
FROM decision_logs FROM decision_logs
WHERE decision_id = ? WHERE decision_id = ?
@@ -201,7 +196,7 @@ class DecisionLogger:
""" """
SELECT SELECT
decision_id, timestamp, stock_code, market, exchange_code, decision_id, timestamp, stock_code, market, exchange_code,
session_id, action, confidence, rationale, context_snapshot, input_data, action, confidence, rationale, context_snapshot, input_data,
outcome_pnl, outcome_accuracy, reviewed, review_notes outcome_pnl, outcome_accuracy, reviewed, review_notes
FROM decision_logs FROM decision_logs
WHERE confidence >= ? WHERE confidence >= ?
@@ -228,14 +223,13 @@ class DecisionLogger:
stock_code=row[2], stock_code=row[2],
market=row[3], market=row[3],
exchange_code=row[4], exchange_code=row[4],
session_id=row[5] or "UNKNOWN", action=row[5],
action=row[6], confidence=row[6],
confidence=row[7], rationale=row[7],
rationale=row[8], context_snapshot=json.loads(row[8]),
context_snapshot=json.loads(row[9]), input_data=json.loads(row[9]),
input_data=json.loads(row[10]), outcome_pnl=row[10],
outcome_pnl=row[11], outcome_accuracy=row[11],
outcome_accuracy=row[12], reviewed=bool(row[12]),
reviewed=bool(row[13]), review_notes=row[13],
review_notes=row[14],
) )

View File

@@ -217,7 +217,6 @@ async def sync_positions_from_broker(
price=avg_price, price=avg_price,
market=log_market, market=log_market,
exchange_code=market.exchange_code, exchange_code=market.exchange_code,
session_id=get_session_info(market).session_id,
mode=settings.MODE, mode=settings.MODE,
) )
logger.info( logger.info(
@@ -1369,12 +1368,10 @@ async def trading_cycle(
"pnl_pct": pnl_pct, "pnl_pct": pnl_pct,
} }
runtime_session_id = get_session_info(market).session_id
decision_id = decision_logger.log_decision( decision_id = decision_logger.log_decision(
stock_code=stock_code, stock_code=stock_code,
market=market.code, market=market.code,
exchange_code=market.exchange_code, exchange_code=market.exchange_code,
session_id=runtime_session_id,
action=decision.action, action=decision.action,
confidence=decision.confidence, confidence=decision.confidence,
rationale=decision.rationale, rationale=decision.rationale,
@@ -1639,7 +1636,6 @@ async def trading_cycle(
pnl=0.0, pnl=0.0,
market=market.code, market=market.code,
exchange_code=market.exchange_code, exchange_code=market.exchange_code,
session_id=runtime_session_id,
mode=settings.MODE if settings else "paper", mode=settings.MODE if settings else "paper",
) )
logger.info("Order result: %s", result.get("msg1", "OK")) logger.info("Order result: %s", result.get("msg1", "OK"))
@@ -1699,7 +1695,6 @@ async def trading_cycle(
pnl=trade_pnl, pnl=trade_pnl,
market=market.code, market=market.code,
exchange_code=market.exchange_code, exchange_code=market.exchange_code,
session_id=runtime_session_id,
selection_context=selection_context, selection_context=selection_context,
decision_id=decision_id, decision_id=decision_id,
mode=settings.MODE if settings else "paper", mode=settings.MODE if settings else "paper",
@@ -2507,12 +2502,10 @@ async def run_daily_session(
"pnl_pct": pnl_pct, "pnl_pct": pnl_pct,
} }
runtime_session_id = get_session_info(market).session_id
decision_id = decision_logger.log_decision( decision_id = decision_logger.log_decision(
stock_code=stock_code, stock_code=stock_code,
market=market.code, market=market.code,
exchange_code=market.exchange_code, exchange_code=market.exchange_code,
session_id=runtime_session_id,
action=decision.action, action=decision.action,
confidence=decision.confidence, confidence=decision.confidence,
rationale=decision.rationale, rationale=decision.rationale,
@@ -2794,7 +2787,6 @@ async def run_daily_session(
pnl=trade_pnl, pnl=trade_pnl,
market=market.code, market=market.code,
exchange_code=market.exchange_code, exchange_code=market.exchange_code,
session_id=runtime_session_id,
decision_id=decision_id, decision_id=decision_id,
mode=settings.MODE, mode=settings.MODE,
) )

View File

@@ -365,53 +365,3 @@ def test_get_latest_buy_trade_prefers_exchange_code_match() -> None:
) )
assert matched is not None assert matched is not None
assert matched["decision_id"] == "matched-buy" assert matched["decision_id"] == "matched-buy"
def test_decision_logs_session_id_migration_backfills_unknown() -> None:
import sqlite3
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
try:
old_conn = sqlite3.connect(db_path)
old_conn.execute(
"""
CREATE TABLE decision_logs (
decision_id TEXT PRIMARY KEY,
timestamp TEXT NOT NULL,
stock_code TEXT NOT NULL,
market TEXT NOT NULL,
exchange_code TEXT NOT NULL,
action TEXT NOT NULL,
confidence INTEGER NOT NULL,
rationale TEXT NOT NULL,
context_snapshot TEXT NOT NULL,
input_data TEXT NOT NULL
)
"""
)
old_conn.execute(
"""
INSERT INTO decision_logs (
decision_id, timestamp, stock_code, market, exchange_code,
action, confidence, rationale, context_snapshot, input_data
) VALUES (
'd1', '2026-01-01T00:00:00+00:00', 'AAPL', 'US_NASDAQ', 'NASD',
'BUY', 80, 'legacy row', '{}', '{}'
)
"""
)
old_conn.commit()
old_conn.close()
conn = init_db(db_path)
columns = {row[1] for row in conn.execute("PRAGMA table_info(decision_logs)").fetchall()}
assert "session_id" in columns
row = conn.execute(
"SELECT session_id FROM decision_logs WHERE decision_id='d1'"
).fetchone()
assert row is not None
assert row[0] == "UNKNOWN"
conn.close()
finally:
os.unlink(db_path)

View File

@@ -49,7 +49,7 @@ def test_log_decision_creates_record(logger: DecisionLogger, db_conn: sqlite3.Co
# Verify record exists in database # Verify record exists in database
cursor = db_conn.execute( cursor = db_conn.execute(
"SELECT decision_id, action, confidence, session_id FROM decision_logs WHERE decision_id = ?", "SELECT decision_id, action, confidence FROM decision_logs WHERE decision_id = ?",
(decision_id,), (decision_id,),
) )
row = cursor.fetchone() row = cursor.fetchone()
@@ -57,7 +57,6 @@ def test_log_decision_creates_record(logger: DecisionLogger, db_conn: sqlite3.Co
assert row[0] == decision_id assert row[0] == decision_id
assert row[1] == "BUY" assert row[1] == "BUY"
assert row[2] == 85 assert row[2] == 85
assert row[3] == "UNKNOWN"
def test_log_decision_stores_context_snapshot(logger: DecisionLogger) -> None: def test_log_decision_stores_context_snapshot(logger: DecisionLogger) -> None:
@@ -85,24 +84,6 @@ def test_log_decision_stores_context_snapshot(logger: DecisionLogger) -> None:
assert decision is not None assert decision is not None
assert decision.context_snapshot == context_snapshot assert decision.context_snapshot == context_snapshot
assert decision.input_data == input_data assert decision.input_data == input_data
assert decision.session_id == "UNKNOWN"
def test_log_decision_stores_explicit_session_id(logger: DecisionLogger) -> None:
decision_id = logger.log_decision(
stock_code="AAPL",
market="US_NASDAQ",
exchange_code="NASD",
action="BUY",
confidence=88,
rationale="session check",
context_snapshot={},
input_data={},
session_id="US_PRE",
)
decision = logger.get_decision_by_id(decision_id)
assert decision is not None
assert decision.session_id == "US_PRE"
def test_get_unreviewed_decisions(logger: DecisionLogger) -> None: def test_get_unreviewed_decisions(logger: DecisionLogger) -> None:
@@ -297,7 +278,6 @@ def test_decision_log_dataclass() -> None:
stock_code="005930", stock_code="005930",
market="KR", market="KR",
exchange_code="KRX", exchange_code="KRX",
session_id="KRX_REG",
action="BUY", action="BUY",
confidence=85, confidence=85,
rationale="Test", rationale="Test",
@@ -306,7 +286,6 @@ def test_decision_log_dataclass() -> None:
) )
assert log.decision_id == "test-uuid" assert log.decision_id == "test-uuid"
assert log.session_id == "KRX_REG"
assert log.action == "BUY" assert log.action == "BUY"
assert log.confidence == 85 assert log.confidence == 85
assert log.reviewed is False assert log.reviewed is False

View File

@@ -245,52 +245,6 @@ async def test_generate_strategy_creates_file(optimizer: EvolutionOptimizer, tmp
assert "def evaluate" in strategy_path.read_text() assert "def evaluate" in strategy_path.read_text()
@pytest.mark.asyncio
async def test_generate_strategy_saves_valid_python_code(
optimizer: EvolutionOptimizer, tmp_path: Path,
) -> None:
"""Test that syntactically valid generated code is saved."""
failures = [{"decision_id": "1", "timestamp": "2024-01-15T09:30:00+00:00"}]
mock_response = Mock()
mock_response.text = (
'price = market_data.get("current_price", 0)\n'
'if price > 0:\n'
' return {"action": "BUY", "confidence": 80, "rationale": "Positive price"}\n'
'return {"action": "HOLD", "confidence": 50, "rationale": "No signal"}\n'
)
with patch.object(optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response)):
with patch("src.evolution.optimizer.STRATEGIES_DIR", tmp_path):
strategy_path = await optimizer.generate_strategy(failures)
assert strategy_path is not None
assert strategy_path.exists()
@pytest.mark.asyncio
async def test_generate_strategy_blocks_invalid_python_code(
optimizer: EvolutionOptimizer, tmp_path: Path, caplog: pytest.LogCaptureFixture,
) -> None:
"""Test that syntactically invalid generated code is not saved."""
failures = [{"decision_id": "1", "timestamp": "2024-01-15T09:30:00+00:00"}]
mock_response = Mock()
mock_response.text = (
'if market_data.get("current_price", 0) > 0\n'
' return {"action": "BUY", "confidence": 80, "rationale": "broken"}\n'
)
with patch.object(optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response)):
with patch("src.evolution.optimizer.STRATEGIES_DIR", tmp_path):
with caplog.at_level("WARNING"):
strategy_path = await optimizer.generate_strategy(failures)
assert strategy_path is None
assert list(tmp_path.glob("*.py")) == []
assert "failed syntax validation" in caplog.text
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_strategy_handles_api_error(optimizer: EvolutionOptimizer) -> None: async def test_generate_strategy_handles_api_error(optimizer: EvolutionOptimizer) -> None:
"""Test that generate_strategy handles Gemini API errors gracefully.""" """Test that generate_strategy handles Gemini API errors gracefully."""

View File

@@ -1,9 +1,5 @@
from __future__ import annotations from __future__ import annotations
from datetime import UTC, datetime, timedelta
import pytest
from src.analysis.triple_barrier import TripleBarrierSpec, label_with_triple_barrier from src.analysis.triple_barrier import TripleBarrierSpec, label_with_triple_barrier
@@ -133,52 +129,3 @@ def test_short_tie_break_modes() -> None:
) )
assert out_take.label == 1 assert out_take.label == 1
assert out_take.touched == "take_profit" assert out_take.touched == "take_profit"
def test_minutes_time_barrier_consistent_across_sampling() -> None:
base = datetime(2026, 2, 28, 9, 0, tzinfo=UTC)
highs = [100.0, 100.5, 100.6, 100.4]
lows = [100.0, 99.6, 99.4, 99.5]
closes = [100.0, 100.1, 100.0, 100.0]
spec = TripleBarrierSpec(
take_profit_pct=0.02,
stop_loss_pct=0.02,
max_holding_minutes=5,
)
out_1m = label_with_triple_barrier(
highs=highs,
lows=lows,
closes=closes,
timestamps=[base + timedelta(minutes=i) for i in range(4)],
entry_index=0,
side=1,
spec=spec,
)
out_5m = label_with_triple_barrier(
highs=highs,
lows=lows,
closes=closes,
timestamps=[base + timedelta(minutes=5 * i) for i in range(4)],
entry_index=0,
side=1,
spec=spec,
)
assert out_1m.touch_bar == 3
assert out_5m.touch_bar == 1
def test_bars_mode_emits_deprecation_warning() -> None:
highs = [100, 101, 103]
lows = [100, 99.6, 100]
closes = [100, 100, 102]
spec = TripleBarrierSpec(take_profit_pct=0.02, stop_loss_pct=0.01, max_holding_bars=3)
with pytest.deprecated_call(match="max_holding_bars is deprecated"):
label_with_triple_barrier(
highs=highs,
lows=lows,
closes=closes,
entry_index=0,
side=1,
spec=spec,
)