Compare commits
7 Commits
47aadcb4e7
...
92261da414
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
92261da414 | ||
| ea7260d574 | |||
| a2855e286e | |||
| 28ded34441 | |||
|
|
11b9ad126f | ||
|
|
c641097fe7 | ||
|
|
2f3b2149d5 |
@@ -5,7 +5,9 @@ Implements first-touch labeling with upper/lower/time barriers.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timedelta
|
||||||
from typing import Literal, Sequence
|
from typing import Literal, Sequence
|
||||||
|
|
||||||
|
|
||||||
@@ -16,9 +18,18 @@ TieBreakMode = Literal["stop_first", "take_first"]
|
|||||||
class TripleBarrierSpec:
|
class TripleBarrierSpec:
|
||||||
take_profit_pct: float
|
take_profit_pct: float
|
||||||
stop_loss_pct: float
|
stop_loss_pct: float
|
||||||
max_holding_bars: int
|
max_holding_bars: int | None = None
|
||||||
|
max_holding_minutes: int | None = None
|
||||||
tie_break: TieBreakMode = "stop_first"
|
tie_break: TieBreakMode = "stop_first"
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
if self.max_holding_minutes is None and self.max_holding_bars is None:
|
||||||
|
raise ValueError("one of max_holding_minutes or max_holding_bars must be set")
|
||||||
|
if self.max_holding_minutes is not None and self.max_holding_minutes <= 0:
|
||||||
|
raise ValueError("max_holding_minutes must be positive")
|
||||||
|
if self.max_holding_bars is not None and self.max_holding_bars <= 0:
|
||||||
|
raise ValueError("max_holding_bars must be positive")
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class TripleBarrierLabel:
|
class TripleBarrierLabel:
|
||||||
@@ -35,6 +46,7 @@ def label_with_triple_barrier(
|
|||||||
highs: Sequence[float],
|
highs: Sequence[float],
|
||||||
lows: Sequence[float],
|
lows: Sequence[float],
|
||||||
closes: Sequence[float],
|
closes: Sequence[float],
|
||||||
|
timestamps: Sequence[datetime] | None = None,
|
||||||
entry_index: int,
|
entry_index: int,
|
||||||
side: int,
|
side: int,
|
||||||
spec: TripleBarrierSpec,
|
spec: TripleBarrierSpec,
|
||||||
@@ -53,8 +65,6 @@ def label_with_triple_barrier(
|
|||||||
raise ValueError("highs, lows, closes lengths must match")
|
raise ValueError("highs, lows, closes lengths must match")
|
||||||
if entry_index < 0 or entry_index >= len(closes):
|
if entry_index < 0 or entry_index >= len(closes):
|
||||||
raise IndexError("entry_index out of range")
|
raise IndexError("entry_index out of range")
|
||||||
if spec.max_holding_bars <= 0:
|
|
||||||
raise ValueError("max_holding_bars must be positive")
|
|
||||||
|
|
||||||
entry_price = float(closes[entry_index])
|
entry_price = float(closes[entry_index])
|
||||||
if entry_price <= 0:
|
if entry_price <= 0:
|
||||||
@@ -68,13 +78,31 @@ def label_with_triple_barrier(
|
|||||||
upper = entry_price * (1.0 + spec.stop_loss_pct)
|
upper = entry_price * (1.0 + spec.stop_loss_pct)
|
||||||
lower = entry_price * (1.0 - spec.take_profit_pct)
|
lower = entry_price * (1.0 - spec.take_profit_pct)
|
||||||
|
|
||||||
last_index = min(len(closes) - 1, entry_index + spec.max_holding_bars)
|
if spec.max_holding_minutes is not None:
|
||||||
|
if timestamps is None:
|
||||||
|
raise ValueError("timestamps are required when max_holding_minutes is set")
|
||||||
|
if len(timestamps) != len(closes):
|
||||||
|
raise ValueError("timestamps length must match OHLC lengths")
|
||||||
|
expiry_timestamp = timestamps[entry_index] + timedelta(minutes=spec.max_holding_minutes)
|
||||||
|
last_index = entry_index
|
||||||
|
for idx in range(entry_index + 1, len(closes)):
|
||||||
|
if timestamps[idx] > expiry_timestamp:
|
||||||
|
break
|
||||||
|
last_index = idx
|
||||||
|
else:
|
||||||
|
assert spec.max_holding_bars is not None
|
||||||
|
warnings.warn(
|
||||||
|
"TripleBarrierSpec.max_holding_bars is deprecated; use max_holding_minutes with timestamps instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
last_index = min(len(closes) - 1, entry_index + spec.max_holding_bars)
|
||||||
for idx in range(entry_index + 1, last_index + 1):
|
for idx in range(entry_index + 1, last_index + 1):
|
||||||
h = float(highs[idx])
|
high_price = float(highs[idx])
|
||||||
l = float(lows[idx])
|
low_price = float(lows[idx])
|
||||||
|
|
||||||
up_touch = h >= upper
|
up_touch = high_price >= upper
|
||||||
down_touch = l <= lower
|
down_touch = low_price <= lower
|
||||||
if not up_touch and not down_touch:
|
if not up_touch and not down_touch:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
75
src/db.py
75
src/db.py
@@ -109,6 +109,7 @@ def init_db(db_path: str) -> sqlite3.Connection:
|
|||||||
stock_code TEXT NOT NULL,
|
stock_code TEXT NOT NULL,
|
||||||
market TEXT NOT NULL,
|
market TEXT NOT NULL,
|
||||||
exchange_code TEXT NOT NULL,
|
exchange_code TEXT NOT NULL,
|
||||||
|
session_id TEXT DEFAULT 'UNKNOWN',
|
||||||
action TEXT NOT NULL,
|
action TEXT NOT NULL,
|
||||||
confidence INTEGER NOT NULL,
|
confidence INTEGER NOT NULL,
|
||||||
rationale TEXT NOT NULL,
|
rationale TEXT NOT NULL,
|
||||||
@@ -121,6 +122,27 @@ def init_db(db_path: str) -> sqlite3.Connection:
|
|||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
decision_columns = {
|
||||||
|
row[1]
|
||||||
|
for row in conn.execute("PRAGMA table_info(decision_logs)").fetchall()
|
||||||
|
}
|
||||||
|
if "session_id" not in decision_columns:
|
||||||
|
conn.execute("ALTER TABLE decision_logs ADD COLUMN session_id TEXT DEFAULT 'UNKNOWN'")
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
UPDATE decision_logs
|
||||||
|
SET session_id = 'UNKNOWN'
|
||||||
|
WHERE session_id IS NULL OR session_id = ''
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
if "outcome_pnl" not in decision_columns:
|
||||||
|
conn.execute("ALTER TABLE decision_logs ADD COLUMN outcome_pnl REAL")
|
||||||
|
if "outcome_accuracy" not in decision_columns:
|
||||||
|
conn.execute("ALTER TABLE decision_logs ADD COLUMN outcome_accuracy INTEGER")
|
||||||
|
if "reviewed" not in decision_columns:
|
||||||
|
conn.execute("ALTER TABLE decision_logs ADD COLUMN reviewed INTEGER DEFAULT 0")
|
||||||
|
if "review_notes" not in decision_columns:
|
||||||
|
conn.execute("ALTER TABLE decision_logs ADD COLUMN review_notes TEXT")
|
||||||
|
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"""
|
"""
|
||||||
@@ -290,22 +312,47 @@ def _resolve_session_id(*, market: str, session_id: str | None) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def get_latest_buy_trade(
|
def get_latest_buy_trade(
|
||||||
conn: sqlite3.Connection, stock_code: str, market: str
|
conn: sqlite3.Connection,
|
||||||
|
stock_code: str,
|
||||||
|
market: str,
|
||||||
|
exchange_code: str | None = None,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Fetch the most recent BUY trade for a stock and market."""
|
"""Fetch the most recent BUY trade for a stock and market."""
|
||||||
cursor = conn.execute(
|
if exchange_code:
|
||||||
"""
|
cursor = conn.execute(
|
||||||
SELECT decision_id, price, quantity
|
"""
|
||||||
FROM trades
|
SELECT decision_id, price, quantity
|
||||||
WHERE stock_code = ?
|
FROM trades
|
||||||
AND market = ?
|
WHERE stock_code = ?
|
||||||
AND action = 'BUY'
|
AND market = ?
|
||||||
AND decision_id IS NOT NULL
|
AND action = 'BUY'
|
||||||
ORDER BY timestamp DESC
|
AND decision_id IS NOT NULL
|
||||||
LIMIT 1
|
AND (
|
||||||
""",
|
exchange_code = ?
|
||||||
(stock_code, market),
|
OR exchange_code IS NULL
|
||||||
)
|
OR exchange_code = ''
|
||||||
|
)
|
||||||
|
ORDER BY
|
||||||
|
CASE WHEN exchange_code = ? THEN 0 ELSE 1 END,
|
||||||
|
timestamp DESC
|
||||||
|
LIMIT 1
|
||||||
|
""",
|
||||||
|
(stock_code, market, exchange_code, exchange_code),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cursor = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT decision_id, price, quantity
|
||||||
|
FROM trades
|
||||||
|
WHERE stock_code = ?
|
||||||
|
AND market = ?
|
||||||
|
AND action = 'BUY'
|
||||||
|
AND decision_id IS NOT NULL
|
||||||
|
ORDER BY timestamp DESC
|
||||||
|
LIMIT 1
|
||||||
|
""",
|
||||||
|
(stock_code, market),
|
||||||
|
)
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ This module:
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ast
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import sqlite3
|
import sqlite3
|
||||||
@@ -28,24 +29,24 @@ from src.logging.decision_logger import DecisionLogger
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
STRATEGIES_DIR = Path("src/strategies")
|
STRATEGIES_DIR = Path("src/strategies")
|
||||||
STRATEGY_TEMPLATE = textwrap.dedent("""\
|
STRATEGY_TEMPLATE = """\
|
||||||
\"\"\"Auto-generated strategy: {name}
|
\"\"\"Auto-generated strategy: {name}
|
||||||
|
|
||||||
Generated at: {timestamp}
|
Generated at: {timestamp}
|
||||||
Rationale: {rationale}
|
Rationale: {rationale}
|
||||||
\"\"\"
|
\"\"\"
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from src.strategies.base import BaseStrategy
|
from src.strategies.base import BaseStrategy
|
||||||
|
|
||||||
|
|
||||||
class {class_name}(BaseStrategy):
|
class {class_name}(BaseStrategy):
|
||||||
\"\"\"Strategy: {name}\"\"\"
|
\"\"\"Strategy: {name}\"\"\"
|
||||||
|
|
||||||
def evaluate(self, market_data: dict[str, Any]) -> dict[str, Any]:
|
def evaluate(self, market_data: dict[str, Any]) -> dict[str, Any]:
|
||||||
{body}
|
{body}
|
||||||
""")
|
"""
|
||||||
|
|
||||||
|
|
||||||
class EvolutionOptimizer:
|
class EvolutionOptimizer:
|
||||||
@@ -235,7 +236,8 @@ class EvolutionOptimizer:
|
|||||||
file_path = STRATEGIES_DIR / file_name
|
file_path = STRATEGIES_DIR / file_name
|
||||||
|
|
||||||
# Indent the body for the class method
|
# Indent the body for the class method
|
||||||
indented_body = textwrap.indent(body, " ")
|
normalized_body = textwrap.dedent(body).strip()
|
||||||
|
indented_body = textwrap.indent(normalized_body, " ")
|
||||||
|
|
||||||
# Generate rationale from patterns
|
# Generate rationale from patterns
|
||||||
rationale = f"Auto-evolved from {len(failures)} failures. "
|
rationale = f"Auto-evolved from {len(failures)} failures. "
|
||||||
@@ -247,9 +249,16 @@ class EvolutionOptimizer:
|
|||||||
timestamp=datetime.now(UTC).isoformat(),
|
timestamp=datetime.now(UTC).isoformat(),
|
||||||
rationale=rationale,
|
rationale=rationale,
|
||||||
class_name=class_name,
|
class_name=class_name,
|
||||||
body=indented_body.strip(),
|
body=indented_body.rstrip(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed = ast.parse(content, filename=str(file_path))
|
||||||
|
compile(parsed, filename=str(file_path), mode="exec")
|
||||||
|
except SyntaxError as exc:
|
||||||
|
logger.warning("Generated strategy failed syntax validation: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
file_path.write_text(content)
|
file_path.write_text(content)
|
||||||
logger.info("Generated strategy file: %s", file_path)
|
logger.info("Generated strategy file: %s", file_path)
|
||||||
return file_path
|
return file_path
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ class DecisionLog:
|
|||||||
stock_code: str
|
stock_code: str
|
||||||
market: str
|
market: str
|
||||||
exchange_code: str
|
exchange_code: str
|
||||||
|
session_id: str
|
||||||
action: str
|
action: str
|
||||||
confidence: int
|
confidence: int
|
||||||
rationale: str
|
rationale: str
|
||||||
@@ -47,6 +48,7 @@ class DecisionLogger:
|
|||||||
rationale: str,
|
rationale: str,
|
||||||
context_snapshot: dict[str, Any],
|
context_snapshot: dict[str, Any],
|
||||||
input_data: dict[str, Any],
|
input_data: dict[str, Any],
|
||||||
|
session_id: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Log a trading decision with full context.
|
"""Log a trading decision with full context.
|
||||||
|
|
||||||
@@ -59,20 +61,22 @@ class DecisionLogger:
|
|||||||
rationale: Reasoning for the decision
|
rationale: Reasoning for the decision
|
||||||
context_snapshot: L1-L7 context snapshot at decision time
|
context_snapshot: L1-L7 context snapshot at decision time
|
||||||
input_data: Market data inputs (price, volume, orderbook, etc.)
|
input_data: Market data inputs (price, volume, orderbook, etc.)
|
||||||
|
session_id: Runtime session identifier
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
decision_id: Unique identifier for this decision
|
decision_id: Unique identifier for this decision
|
||||||
"""
|
"""
|
||||||
decision_id = str(uuid.uuid4())
|
decision_id = str(uuid.uuid4())
|
||||||
timestamp = datetime.now(UTC).isoformat()
|
timestamp = datetime.now(UTC).isoformat()
|
||||||
|
resolved_session = session_id or "UNKNOWN"
|
||||||
|
|
||||||
self.conn.execute(
|
self.conn.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO decision_logs (
|
INSERT INTO decision_logs (
|
||||||
decision_id, timestamp, stock_code, market, exchange_code,
|
decision_id, timestamp, stock_code, market, exchange_code,
|
||||||
action, confidence, rationale, context_snapshot, input_data
|
session_id, action, confidence, rationale, context_snapshot, input_data
|
||||||
)
|
)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
decision_id,
|
decision_id,
|
||||||
@@ -80,6 +84,7 @@ class DecisionLogger:
|
|||||||
stock_code,
|
stock_code,
|
||||||
market,
|
market,
|
||||||
exchange_code,
|
exchange_code,
|
||||||
|
resolved_session,
|
||||||
action,
|
action,
|
||||||
confidence,
|
confidence,
|
||||||
rationale,
|
rationale,
|
||||||
@@ -106,7 +111,7 @@ class DecisionLogger:
|
|||||||
query = """
|
query = """
|
||||||
SELECT
|
SELECT
|
||||||
decision_id, timestamp, stock_code, market, exchange_code,
|
decision_id, timestamp, stock_code, market, exchange_code,
|
||||||
action, confidence, rationale, context_snapshot, input_data,
|
session_id, action, confidence, rationale, context_snapshot, input_data,
|
||||||
outcome_pnl, outcome_accuracy, reviewed, review_notes
|
outcome_pnl, outcome_accuracy, reviewed, review_notes
|
||||||
FROM decision_logs
|
FROM decision_logs
|
||||||
WHERE reviewed = 0 AND confidence >= ?
|
WHERE reviewed = 0 AND confidence >= ?
|
||||||
@@ -168,7 +173,7 @@ class DecisionLogger:
|
|||||||
"""
|
"""
|
||||||
SELECT
|
SELECT
|
||||||
decision_id, timestamp, stock_code, market, exchange_code,
|
decision_id, timestamp, stock_code, market, exchange_code,
|
||||||
action, confidence, rationale, context_snapshot, input_data,
|
session_id, action, confidence, rationale, context_snapshot, input_data,
|
||||||
outcome_pnl, outcome_accuracy, reviewed, review_notes
|
outcome_pnl, outcome_accuracy, reviewed, review_notes
|
||||||
FROM decision_logs
|
FROM decision_logs
|
||||||
WHERE decision_id = ?
|
WHERE decision_id = ?
|
||||||
@@ -196,7 +201,7 @@ class DecisionLogger:
|
|||||||
"""
|
"""
|
||||||
SELECT
|
SELECT
|
||||||
decision_id, timestamp, stock_code, market, exchange_code,
|
decision_id, timestamp, stock_code, market, exchange_code,
|
||||||
action, confidence, rationale, context_snapshot, input_data,
|
session_id, action, confidence, rationale, context_snapshot, input_data,
|
||||||
outcome_pnl, outcome_accuracy, reviewed, review_notes
|
outcome_pnl, outcome_accuracy, reviewed, review_notes
|
||||||
FROM decision_logs
|
FROM decision_logs
|
||||||
WHERE confidence >= ?
|
WHERE confidence >= ?
|
||||||
@@ -223,13 +228,14 @@ class DecisionLogger:
|
|||||||
stock_code=row[2],
|
stock_code=row[2],
|
||||||
market=row[3],
|
market=row[3],
|
||||||
exchange_code=row[4],
|
exchange_code=row[4],
|
||||||
action=row[5],
|
session_id=row[5] or "UNKNOWN",
|
||||||
confidence=row[6],
|
action=row[6],
|
||||||
rationale=row[7],
|
confidence=row[7],
|
||||||
context_snapshot=json.loads(row[8]),
|
rationale=row[8],
|
||||||
input_data=json.loads(row[9]),
|
context_snapshot=json.loads(row[9]),
|
||||||
outcome_pnl=row[10],
|
input_data=json.loads(row[10]),
|
||||||
outcome_accuracy=row[11],
|
outcome_pnl=row[11],
|
||||||
reviewed=bool(row[12]),
|
outcome_accuracy=row[12],
|
||||||
review_notes=row[13],
|
reviewed=bool(row[13]),
|
||||||
|
review_notes=row[14],
|
||||||
)
|
)
|
||||||
|
|||||||
22
src/main.py
22
src/main.py
@@ -217,6 +217,7 @@ async def sync_positions_from_broker(
|
|||||||
price=avg_price,
|
price=avg_price,
|
||||||
market=log_market,
|
market=log_market,
|
||||||
exchange_code=market.exchange_code,
|
exchange_code=market.exchange_code,
|
||||||
|
session_id=get_session_info(market).session_id,
|
||||||
mode=settings.MODE,
|
mode=settings.MODE,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -1368,10 +1369,12 @@ async def trading_cycle(
|
|||||||
"pnl_pct": pnl_pct,
|
"pnl_pct": pnl_pct,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
runtime_session_id = get_session_info(market).session_id
|
||||||
decision_id = decision_logger.log_decision(
|
decision_id = decision_logger.log_decision(
|
||||||
stock_code=stock_code,
|
stock_code=stock_code,
|
||||||
market=market.code,
|
market=market.code,
|
||||||
exchange_code=market.exchange_code,
|
exchange_code=market.exchange_code,
|
||||||
|
session_id=runtime_session_id,
|
||||||
action=decision.action,
|
action=decision.action,
|
||||||
confidence=decision.confidence,
|
confidence=decision.confidence,
|
||||||
rationale=decision.rationale,
|
rationale=decision.rationale,
|
||||||
@@ -1636,6 +1639,7 @@ async def trading_cycle(
|
|||||||
pnl=0.0,
|
pnl=0.0,
|
||||||
market=market.code,
|
market=market.code,
|
||||||
exchange_code=market.exchange_code,
|
exchange_code=market.exchange_code,
|
||||||
|
session_id=runtime_session_id,
|
||||||
mode=settings.MODE if settings else "paper",
|
mode=settings.MODE if settings else "paper",
|
||||||
)
|
)
|
||||||
logger.info("Order result: %s", result.get("msg1", "OK"))
|
logger.info("Order result: %s", result.get("msg1", "OK"))
|
||||||
@@ -1655,7 +1659,12 @@ async def trading_cycle(
|
|||||||
logger.warning("Telegram notification failed: %s", exc)
|
logger.warning("Telegram notification failed: %s", exc)
|
||||||
|
|
||||||
if decision.action == "SELL" and order_succeeded:
|
if decision.action == "SELL" and order_succeeded:
|
||||||
buy_trade = get_latest_buy_trade(db_conn, stock_code, market.code)
|
buy_trade = get_latest_buy_trade(
|
||||||
|
db_conn,
|
||||||
|
stock_code,
|
||||||
|
market.code,
|
||||||
|
exchange_code=market.exchange_code,
|
||||||
|
)
|
||||||
if buy_trade and buy_trade.get("price") is not None:
|
if buy_trade and buy_trade.get("price") is not None:
|
||||||
buy_price = float(buy_trade["price"])
|
buy_price = float(buy_trade["price"])
|
||||||
buy_qty = int(buy_trade.get("quantity") or 1)
|
buy_qty = int(buy_trade.get("quantity") or 1)
|
||||||
@@ -1690,6 +1699,7 @@ async def trading_cycle(
|
|||||||
pnl=trade_pnl,
|
pnl=trade_pnl,
|
||||||
market=market.code,
|
market=market.code,
|
||||||
exchange_code=market.exchange_code,
|
exchange_code=market.exchange_code,
|
||||||
|
session_id=runtime_session_id,
|
||||||
selection_context=selection_context,
|
selection_context=selection_context,
|
||||||
decision_id=decision_id,
|
decision_id=decision_id,
|
||||||
mode=settings.MODE if settings else "paper",
|
mode=settings.MODE if settings else "paper",
|
||||||
@@ -2497,10 +2507,12 @@ async def run_daily_session(
|
|||||||
"pnl_pct": pnl_pct,
|
"pnl_pct": pnl_pct,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
runtime_session_id = get_session_info(market).session_id
|
||||||
decision_id = decision_logger.log_decision(
|
decision_id = decision_logger.log_decision(
|
||||||
stock_code=stock_code,
|
stock_code=stock_code,
|
||||||
market=market.code,
|
market=market.code,
|
||||||
exchange_code=market.exchange_code,
|
exchange_code=market.exchange_code,
|
||||||
|
session_id=runtime_session_id,
|
||||||
action=decision.action,
|
action=decision.action,
|
||||||
confidence=decision.confidence,
|
confidence=decision.confidence,
|
||||||
rationale=decision.rationale,
|
rationale=decision.rationale,
|
||||||
@@ -2752,7 +2764,12 @@ async def run_daily_session(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if decision.action == "SELL" and order_succeeded:
|
if decision.action == "SELL" and order_succeeded:
|
||||||
buy_trade = get_latest_buy_trade(db_conn, stock_code, market.code)
|
buy_trade = get_latest_buy_trade(
|
||||||
|
db_conn,
|
||||||
|
stock_code,
|
||||||
|
market.code,
|
||||||
|
exchange_code=market.exchange_code,
|
||||||
|
)
|
||||||
if buy_trade and buy_trade.get("price") is not None:
|
if buy_trade and buy_trade.get("price") is not None:
|
||||||
buy_price = float(buy_trade["price"])
|
buy_price = float(buy_trade["price"])
|
||||||
buy_qty = int(buy_trade.get("quantity") or 1)
|
buy_qty = int(buy_trade.get("quantity") or 1)
|
||||||
@@ -2777,6 +2794,7 @@ async def run_daily_session(
|
|||||||
pnl=trade_pnl,
|
pnl=trade_pnl,
|
||||||
market=market.code,
|
market=market.code,
|
||||||
exchange_code=market.exchange_code,
|
exchange_code=market.exchange_code,
|
||||||
|
session_id=runtime_session_id,
|
||||||
decision_id=decision_id,
|
decision_id=decision_id,
|
||||||
mode=settings.MODE,
|
mode=settings.MODE,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
import tempfile
|
import tempfile
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from src.db import get_open_position, init_db, log_trade
|
from src.db import get_latest_buy_trade, get_open_position, init_db, log_trade
|
||||||
|
|
||||||
|
|
||||||
def test_get_open_position_returns_latest_buy() -> None:
|
def test_get_open_position_returns_latest_buy() -> None:
|
||||||
@@ -329,3 +329,89 @@ def test_log_trade_unknown_market_falls_back_to_unknown_session() -> None:
|
|||||||
row = conn.execute("SELECT session_id FROM trades ORDER BY id DESC LIMIT 1").fetchone()
|
row = conn.execute("SELECT session_id FROM trades ORDER BY id DESC LIMIT 1").fetchone()
|
||||||
assert row is not None
|
assert row is not None
|
||||||
assert row[0] == "UNKNOWN"
|
assert row[0] == "UNKNOWN"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_latest_buy_trade_prefers_exchange_code_match() -> None:
|
||||||
|
conn = init_db(":memory:")
|
||||||
|
log_trade(
|
||||||
|
conn=conn,
|
||||||
|
stock_code="AAPL",
|
||||||
|
action="BUY",
|
||||||
|
confidence=80,
|
||||||
|
rationale="legacy",
|
||||||
|
quantity=10,
|
||||||
|
price=120.0,
|
||||||
|
market="US_NASDAQ",
|
||||||
|
exchange_code="",
|
||||||
|
decision_id="legacy-buy",
|
||||||
|
)
|
||||||
|
log_trade(
|
||||||
|
conn=conn,
|
||||||
|
stock_code="AAPL",
|
||||||
|
action="BUY",
|
||||||
|
confidence=85,
|
||||||
|
rationale="matched",
|
||||||
|
quantity=5,
|
||||||
|
price=125.0,
|
||||||
|
market="US_NASDAQ",
|
||||||
|
exchange_code="NASD",
|
||||||
|
decision_id="matched-buy",
|
||||||
|
)
|
||||||
|
matched = get_latest_buy_trade(
|
||||||
|
conn,
|
||||||
|
stock_code="AAPL",
|
||||||
|
market="US_NASDAQ",
|
||||||
|
exchange_code="NASD",
|
||||||
|
)
|
||||||
|
assert matched is not None
|
||||||
|
assert matched["decision_id"] == "matched-buy"
|
||||||
|
|
||||||
|
|
||||||
|
def test_decision_logs_session_id_migration_backfills_unknown() -> None:
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||||
|
db_path = f.name
|
||||||
|
try:
|
||||||
|
old_conn = sqlite3.connect(db_path)
|
||||||
|
old_conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE decision_logs (
|
||||||
|
decision_id TEXT PRIMARY KEY,
|
||||||
|
timestamp TEXT NOT NULL,
|
||||||
|
stock_code TEXT NOT NULL,
|
||||||
|
market TEXT NOT NULL,
|
||||||
|
exchange_code TEXT NOT NULL,
|
||||||
|
action TEXT NOT NULL,
|
||||||
|
confidence INTEGER NOT NULL,
|
||||||
|
rationale TEXT NOT NULL,
|
||||||
|
context_snapshot TEXT NOT NULL,
|
||||||
|
input_data TEXT NOT NULL
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
old_conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO decision_logs (
|
||||||
|
decision_id, timestamp, stock_code, market, exchange_code,
|
||||||
|
action, confidence, rationale, context_snapshot, input_data
|
||||||
|
) VALUES (
|
||||||
|
'd1', '2026-01-01T00:00:00+00:00', 'AAPL', 'US_NASDAQ', 'NASD',
|
||||||
|
'BUY', 80, 'legacy row', '{}', '{}'
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
old_conn.commit()
|
||||||
|
old_conn.close()
|
||||||
|
|
||||||
|
conn = init_db(db_path)
|
||||||
|
columns = {row[1] for row in conn.execute("PRAGMA table_info(decision_logs)").fetchall()}
|
||||||
|
assert "session_id" in columns
|
||||||
|
row = conn.execute(
|
||||||
|
"SELECT session_id FROM decision_logs WHERE decision_id='d1'"
|
||||||
|
).fetchone()
|
||||||
|
assert row is not None
|
||||||
|
assert row[0] == "UNKNOWN"
|
||||||
|
conn.close()
|
||||||
|
finally:
|
||||||
|
os.unlink(db_path)
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ def test_log_decision_creates_record(logger: DecisionLogger, db_conn: sqlite3.Co
|
|||||||
|
|
||||||
# Verify record exists in database
|
# Verify record exists in database
|
||||||
cursor = db_conn.execute(
|
cursor = db_conn.execute(
|
||||||
"SELECT decision_id, action, confidence FROM decision_logs WHERE decision_id = ?",
|
"SELECT decision_id, action, confidence, session_id FROM decision_logs WHERE decision_id = ?",
|
||||||
(decision_id,),
|
(decision_id,),
|
||||||
)
|
)
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
@@ -57,6 +57,7 @@ def test_log_decision_creates_record(logger: DecisionLogger, db_conn: sqlite3.Co
|
|||||||
assert row[0] == decision_id
|
assert row[0] == decision_id
|
||||||
assert row[1] == "BUY"
|
assert row[1] == "BUY"
|
||||||
assert row[2] == 85
|
assert row[2] == 85
|
||||||
|
assert row[3] == "UNKNOWN"
|
||||||
|
|
||||||
|
|
||||||
def test_log_decision_stores_context_snapshot(logger: DecisionLogger) -> None:
|
def test_log_decision_stores_context_snapshot(logger: DecisionLogger) -> None:
|
||||||
@@ -84,6 +85,24 @@ def test_log_decision_stores_context_snapshot(logger: DecisionLogger) -> None:
|
|||||||
assert decision is not None
|
assert decision is not None
|
||||||
assert decision.context_snapshot == context_snapshot
|
assert decision.context_snapshot == context_snapshot
|
||||||
assert decision.input_data == input_data
|
assert decision.input_data == input_data
|
||||||
|
assert decision.session_id == "UNKNOWN"
|
||||||
|
|
||||||
|
|
||||||
|
def test_log_decision_stores_explicit_session_id(logger: DecisionLogger) -> None:
|
||||||
|
decision_id = logger.log_decision(
|
||||||
|
stock_code="AAPL",
|
||||||
|
market="US_NASDAQ",
|
||||||
|
exchange_code="NASD",
|
||||||
|
action="BUY",
|
||||||
|
confidence=88,
|
||||||
|
rationale="session check",
|
||||||
|
context_snapshot={},
|
||||||
|
input_data={},
|
||||||
|
session_id="US_PRE",
|
||||||
|
)
|
||||||
|
decision = logger.get_decision_by_id(decision_id)
|
||||||
|
assert decision is not None
|
||||||
|
assert decision.session_id == "US_PRE"
|
||||||
|
|
||||||
|
|
||||||
def test_get_unreviewed_decisions(logger: DecisionLogger) -> None:
|
def test_get_unreviewed_decisions(logger: DecisionLogger) -> None:
|
||||||
@@ -278,6 +297,7 @@ def test_decision_log_dataclass() -> None:
|
|||||||
stock_code="005930",
|
stock_code="005930",
|
||||||
market="KR",
|
market="KR",
|
||||||
exchange_code="KRX",
|
exchange_code="KRX",
|
||||||
|
session_id="KRX_REG",
|
||||||
action="BUY",
|
action="BUY",
|
||||||
confidence=85,
|
confidence=85,
|
||||||
rationale="Test",
|
rationale="Test",
|
||||||
@@ -286,6 +306,7 @@ def test_decision_log_dataclass() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert log.decision_id == "test-uuid"
|
assert log.decision_id == "test-uuid"
|
||||||
|
assert log.session_id == "KRX_REG"
|
||||||
assert log.action == "BUY"
|
assert log.action == "BUY"
|
||||||
assert log.confidence == 85
|
assert log.confidence == 85
|
||||||
assert log.reviewed is False
|
assert log.reviewed is False
|
||||||
|
|||||||
@@ -245,6 +245,52 @@ async def test_generate_strategy_creates_file(optimizer: EvolutionOptimizer, tmp
|
|||||||
assert "def evaluate" in strategy_path.read_text()
|
assert "def evaluate" in strategy_path.read_text()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_strategy_saves_valid_python_code(
|
||||||
|
optimizer: EvolutionOptimizer, tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test that syntactically valid generated code is saved."""
|
||||||
|
failures = [{"decision_id": "1", "timestamp": "2024-01-15T09:30:00+00:00"}]
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.text = (
|
||||||
|
'price = market_data.get("current_price", 0)\n'
|
||||||
|
'if price > 0:\n'
|
||||||
|
' return {"action": "BUY", "confidence": 80, "rationale": "Positive price"}\n'
|
||||||
|
'return {"action": "HOLD", "confidence": 50, "rationale": "No signal"}\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response)):
|
||||||
|
with patch("src.evolution.optimizer.STRATEGIES_DIR", tmp_path):
|
||||||
|
strategy_path = await optimizer.generate_strategy(failures)
|
||||||
|
|
||||||
|
assert strategy_path is not None
|
||||||
|
assert strategy_path.exists()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_strategy_blocks_invalid_python_code(
|
||||||
|
optimizer: EvolutionOptimizer, tmp_path: Path, caplog: pytest.LogCaptureFixture,
|
||||||
|
) -> None:
|
||||||
|
"""Test that syntactically invalid generated code is not saved."""
|
||||||
|
failures = [{"decision_id": "1", "timestamp": "2024-01-15T09:30:00+00:00"}]
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.text = (
|
||||||
|
'if market_data.get("current_price", 0) > 0\n'
|
||||||
|
' return {"action": "BUY", "confidence": 80, "rationale": "broken"}\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response)):
|
||||||
|
with patch("src.evolution.optimizer.STRATEGIES_DIR", tmp_path):
|
||||||
|
with caplog.at_level("WARNING"):
|
||||||
|
strategy_path = await optimizer.generate_strategy(failures)
|
||||||
|
|
||||||
|
assert strategy_path is None
|
||||||
|
assert list(tmp_path.glob("*.py")) == []
|
||||||
|
assert "failed syntax validation" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_generate_strategy_handles_api_error(optimizer: EvolutionOptimizer) -> None:
|
async def test_generate_strategy_handles_api_error(optimizer: EvolutionOptimizer) -> None:
|
||||||
"""Test that generate_strategy handles Gemini API errors gracefully."""
|
"""Test that generate_strategy handles Gemini API errors gracefully."""
|
||||||
|
|||||||
@@ -1,5 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from src.analysis.triple_barrier import TripleBarrierSpec, label_with_triple_barrier
|
from src.analysis.triple_barrier import TripleBarrierSpec, label_with_triple_barrier
|
||||||
|
|
||||||
|
|
||||||
@@ -129,3 +133,52 @@ def test_short_tie_break_modes() -> None:
|
|||||||
)
|
)
|
||||||
assert out_take.label == 1
|
assert out_take.label == 1
|
||||||
assert out_take.touched == "take_profit"
|
assert out_take.touched == "take_profit"
|
||||||
|
|
||||||
|
|
||||||
|
def test_minutes_time_barrier_consistent_across_sampling() -> None:
|
||||||
|
base = datetime(2026, 2, 28, 9, 0, tzinfo=UTC)
|
||||||
|
highs = [100.0, 100.5, 100.6, 100.4]
|
||||||
|
lows = [100.0, 99.6, 99.4, 99.5]
|
||||||
|
closes = [100.0, 100.1, 100.0, 100.0]
|
||||||
|
spec = TripleBarrierSpec(
|
||||||
|
take_profit_pct=0.02,
|
||||||
|
stop_loss_pct=0.02,
|
||||||
|
max_holding_minutes=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
out_1m = label_with_triple_barrier(
|
||||||
|
highs=highs,
|
||||||
|
lows=lows,
|
||||||
|
closes=closes,
|
||||||
|
timestamps=[base + timedelta(minutes=i) for i in range(4)],
|
||||||
|
entry_index=0,
|
||||||
|
side=1,
|
||||||
|
spec=spec,
|
||||||
|
)
|
||||||
|
out_5m = label_with_triple_barrier(
|
||||||
|
highs=highs,
|
||||||
|
lows=lows,
|
||||||
|
closes=closes,
|
||||||
|
timestamps=[base + timedelta(minutes=5 * i) for i in range(4)],
|
||||||
|
entry_index=0,
|
||||||
|
side=1,
|
||||||
|
spec=spec,
|
||||||
|
)
|
||||||
|
assert out_1m.touch_bar == 3
|
||||||
|
assert out_5m.touch_bar == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_bars_mode_emits_deprecation_warning() -> None:
|
||||||
|
highs = [100, 101, 103]
|
||||||
|
lows = [100, 99.6, 100]
|
||||||
|
closes = [100, 100, 102]
|
||||||
|
spec = TripleBarrierSpec(take_profit_pct=0.02, stop_loss_pct=0.01, max_holding_bars=3)
|
||||||
|
with pytest.deprecated_call(match="max_holding_bars is deprecated"):
|
||||||
|
label_with_triple_barrier(
|
||||||
|
highs=highs,
|
||||||
|
lows=lows,
|
||||||
|
closes=closes,
|
||||||
|
entry_index=0,
|
||||||
|
side=1,
|
||||||
|
spec=spec,
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user