feat: switch backtest triple barrier to calendar-minute horizon (#329)
Some checks failed
Gitea CI / test (push) Has been cancelled
Gitea CI / test (pull_request) Failing after 3s

This commit is contained in:
agentson
2026-03-01 09:44:24 +09:00
parent 35d81fb73d
commit 701350fb65
2 changed files with 48 additions and 3 deletions

View File

@@ -8,6 +8,7 @@ from __future__ import annotations
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime
from statistics import mean from statistics import mean
from typing import Literal from typing import Literal
@@ -22,6 +23,7 @@ class BacktestBar:
low: float low: float
close: float close: float
session_id: str session_id: str
timestamp: datetime | None = None
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -86,16 +88,27 @@ def run_v2_backtest_pipeline(
highs = [float(bar.high) for bar in bars] highs = [float(bar.high) for bar in bars]
lows = [float(bar.low) for bar in bars] lows = [float(bar.low) for bar in bars]
closes = [float(bar.close) for bar in bars] closes = [float(bar.close) for bar in bars]
timestamps = [bar.timestamp for bar in bars]
normalized_entries = sorted(set(int(i) for i in entry_indices)) normalized_entries = sorted(set(int(i) for i in entry_indices))
if normalized_entries[0] < 0 or normalized_entries[-1] >= len(bars): if normalized_entries[0] < 0 or normalized_entries[-1] >= len(bars):
raise IndexError("entry index out of range") raise IndexError("entry index out of range")
resolved_timestamps: list[datetime] | None = None
if triple_barrier_spec.max_holding_minutes is not None:
if any(ts is None for ts in timestamps):
raise ValueError(
"BacktestBar.timestamp is required for all bars when "
"triple_barrier_spec.max_holding_minutes is set"
)
resolved_timestamps = [ts for ts in timestamps if ts is not None]
labels_by_bar_index: dict[int, int] = {} labels_by_bar_index: dict[int, int] = {}
for idx in normalized_entries: for idx in normalized_entries:
labels_by_bar_index[idx] = label_with_triple_barrier( labels_by_bar_index[idx] = label_with_triple_barrier(
highs=highs, highs=highs,
lows=lows, lows=lows,
closes=closes, closes=closes,
timestamps=resolved_timestamps,
entry_index=idx, entry_index=idx,
side=side, side=side,
spec=triple_barrier_spec, spec=triple_barrier_spec,

View File

@@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
from datetime import UTC, datetime, timedelta
from src.analysis.backtest_cost_guard import BacktestCostModel from src.analysis.backtest_cost_guard import BacktestCostModel
from src.analysis.backtest_pipeline import ( from src.analysis.backtest_pipeline import (
BacktestBar, BacktestBar,
@@ -12,6 +14,7 @@ from src.analysis.walk_forward_split import generate_walk_forward_splits
def _bars() -> list[BacktestBar]: def _bars() -> list[BacktestBar]:
base_ts = datetime(2026, 2, 28, 0, 0, tzinfo=UTC)
closes = [100.0, 101.0, 102.0, 101.5, 103.0, 102.5, 104.0, 103.5, 105.0, 104.5, 106.0, 105.5] closes = [100.0, 101.0, 102.0, 101.5, 103.0, 102.5, 104.0, 103.5, 105.0, 104.5, 106.0, 105.5]
bars: list[BacktestBar] = [] bars: list[BacktestBar] = []
for i, close in enumerate(closes): for i, close in enumerate(closes):
@@ -21,6 +24,7 @@ def _bars() -> list[BacktestBar]:
low=close - 1.0, low=close - 1.0,
close=close, close=close,
session_id="KRX_REG" if i % 2 == 0 else "US_PRE", session_id="KRX_REG" if i % 2 == 0 else "US_PRE",
timestamp=base_ts + timedelta(minutes=i),
) )
) )
return bars return bars
@@ -43,7 +47,7 @@ def test_pipeline_happy_path_returns_fold_and_artifact_contract() -> None:
triple_barrier_spec=TripleBarrierSpec( triple_barrier_spec=TripleBarrierSpec(
take_profit_pct=0.02, take_profit_pct=0.02,
stop_loss_pct=0.01, stop_loss_pct=0.01,
max_holding_bars=3, max_holding_minutes=3,
), ),
walk_forward=WalkForwardConfig( walk_forward=WalkForwardConfig(
train_size=4, train_size=4,
@@ -84,7 +88,7 @@ def test_pipeline_cost_guard_fail_fast() -> None:
triple_barrier_spec=TripleBarrierSpec( triple_barrier_spec=TripleBarrierSpec(
take_profit_pct=0.02, take_profit_pct=0.02,
stop_loss_pct=0.01, stop_loss_pct=0.01,
max_holding_bars=3, max_holding_minutes=3,
), ),
walk_forward=WalkForwardConfig(train_size=2, test_size=1), walk_forward=WalkForwardConfig(train_size=2, test_size=1),
cost_model=bad, cost_model=bad,
@@ -119,7 +123,7 @@ def test_pipeline_deterministic_seed_free_deterministic_result() -> None:
triple_barrier_spec=TripleBarrierSpec( triple_barrier_spec=TripleBarrierSpec(
take_profit_pct=0.02, take_profit_pct=0.02,
stop_loss_pct=0.01, stop_loss_pct=0.01,
max_holding_bars=3, max_holding_minutes=3,
), ),
walk_forward=WalkForwardConfig( walk_forward=WalkForwardConfig(
train_size=4, train_size=4,
@@ -134,3 +138,31 @@ def test_pipeline_deterministic_seed_free_deterministic_result() -> None:
out1 = run_v2_backtest_pipeline(**cfg) out1 = run_v2_backtest_pipeline(**cfg)
out2 = run_v2_backtest_pipeline(**cfg) out2 = run_v2_backtest_pipeline(**cfg)
assert out1 == out2 assert out1 == out2
def test_pipeline_rejects_minutes_spec_when_timestamp_missing() -> None:
    """Check that a minute-based horizon is rejected when a bar lacks a timestamp.

    The bar at index 2 of the fixture series is rebuilt with the same
    high/low/close/session_id but ``timestamp=None``.  Running the pipeline
    with ``max_holding_minutes`` set must then raise a ``ValueError`` whose
    message mentions that ``BacktestBar.timestamp`` is required.
    """
    series = _bars()
    original = series[2]
    # Rebuild the bar field by field so only the timestamp differs.
    series[2] = BacktestBar(
        high=original.high,
        low=original.low,
        close=original.close,
        session_id=original.session_id,
        timestamp=None,
    )

    caught = None
    try:
        # Collaborators are built inside the ``try`` so that a ValueError
        # from any of them is handled the same way as one from the pipeline.
        entries = [0, 1, 2, 3]
        barrier = TripleBarrierSpec(
            take_profit_pct=0.02,
            stop_loss_pct=0.01,
            max_holding_minutes=3,
        )
        splits = WalkForwardConfig(train_size=2, test_size=1)
        costs = _cost_model()
        run_v2_backtest_pipeline(
            bars=series,
            entry_indices=entries,
            side=1,
            triple_barrier_spec=barrier,
            walk_forward=splits,
            cost_model=costs,
        )
    except ValueError as exc:
        caught = exc

    # No exception at all means the validation did not fire.
    if caught is None:
        raise AssertionError("expected timestamp validation error")
    assert "BacktestBar.timestamp is required" in str(caught)