Merge pull request 'feat: Triple Barrier 시간장벽을 캘린더 분 기반으로 전환 (#329)' (#346) from feature/issue-329-triple-barrier-calendar-minutes into feature/v3-session-policy-stream
Some checks failed
Gitea CI / test (push) Has been cancelled

Reviewed-on: #346
This commit was merged in pull request #346.
This commit is contained in:
2026-03-01 09:57:01 +09:00
2 changed files with 49 additions and 3 deletions

View File

@@ -8,8 +8,10 @@ from __future__ import annotations
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime
from statistics import mean from statistics import mean
from typing import Literal from typing import Literal
from typing import cast
from src.analysis.backtest_cost_guard import BacktestCostModel, validate_backtest_cost_model from src.analysis.backtest_cost_guard import BacktestCostModel, validate_backtest_cost_model
from src.analysis.triple_barrier import TripleBarrierSpec, label_with_triple_barrier from src.analysis.triple_barrier import TripleBarrierSpec, label_with_triple_barrier
@@ -22,6 +24,7 @@ class BacktestBar:
low: float low: float
close: float close: float
session_id: str session_id: str
timestamp: datetime | None = None
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -86,16 +89,27 @@ def run_v2_backtest_pipeline(
highs = [float(bar.high) for bar in bars] highs = [float(bar.high) for bar in bars]
lows = [float(bar.low) for bar in bars] lows = [float(bar.low) for bar in bars]
closes = [float(bar.close) for bar in bars] closes = [float(bar.close) for bar in bars]
timestamps = [bar.timestamp for bar in bars]
normalized_entries = sorted(set(int(i) for i in entry_indices)) normalized_entries = sorted(set(int(i) for i in entry_indices))
if normalized_entries[0] < 0 or normalized_entries[-1] >= len(bars): if normalized_entries[0] < 0 or normalized_entries[-1] >= len(bars):
raise IndexError("entry index out of range") raise IndexError("entry index out of range")
resolved_timestamps: list[datetime] | None = None
if triple_barrier_spec.max_holding_minutes is not None:
if any(ts is None for ts in timestamps):
raise ValueError(
"BacktestBar.timestamp is required for all bars when "
"triple_barrier_spec.max_holding_minutes is set"
)
resolved_timestamps = cast(list[datetime], timestamps)
labels_by_bar_index: dict[int, int] = {} labels_by_bar_index: dict[int, int] = {}
for idx in normalized_entries: for idx in normalized_entries:
labels_by_bar_index[idx] = label_with_triple_barrier( labels_by_bar_index[idx] = label_with_triple_barrier(
highs=highs, highs=highs,
lows=lows, lows=lows,
closes=closes, closes=closes,
timestamps=resolved_timestamps,
entry_index=idx, entry_index=idx,
side=side, side=side,
spec=triple_barrier_spec, spec=triple_barrier_spec,

View File

@@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
from datetime import UTC, datetime, timedelta
from src.analysis.backtest_cost_guard import BacktestCostModel from src.analysis.backtest_cost_guard import BacktestCostModel
from src.analysis.backtest_pipeline import ( from src.analysis.backtest_pipeline import (
BacktestBar, BacktestBar,
@@ -12,6 +14,7 @@ from src.analysis.walk_forward_split import generate_walk_forward_splits
def _bars() -> list[BacktestBar]: def _bars() -> list[BacktestBar]:
base_ts = datetime(2026, 2, 28, 0, 0, tzinfo=UTC)
closes = [100.0, 101.0, 102.0, 101.5, 103.0, 102.5, 104.0, 103.5, 105.0, 104.5, 106.0, 105.5] closes = [100.0, 101.0, 102.0, 101.5, 103.0, 102.5, 104.0, 103.5, 105.0, 104.5, 106.0, 105.5]
bars: list[BacktestBar] = [] bars: list[BacktestBar] = []
for i, close in enumerate(closes): for i, close in enumerate(closes):
@@ -21,6 +24,7 @@ def _bars() -> list[BacktestBar]:
low=close - 1.0, low=close - 1.0,
close=close, close=close,
session_id="KRX_REG" if i % 2 == 0 else "US_PRE", session_id="KRX_REG" if i % 2 == 0 else "US_PRE",
timestamp=base_ts + timedelta(minutes=i),
) )
) )
return bars return bars
@@ -43,7 +47,7 @@ def test_pipeline_happy_path_returns_fold_and_artifact_contract() -> None:
triple_barrier_spec=TripleBarrierSpec( triple_barrier_spec=TripleBarrierSpec(
take_profit_pct=0.02, take_profit_pct=0.02,
stop_loss_pct=0.01, stop_loss_pct=0.01,
max_holding_bars=3, max_holding_minutes=3,
), ),
walk_forward=WalkForwardConfig( walk_forward=WalkForwardConfig(
train_size=4, train_size=4,
@@ -84,7 +88,7 @@ def test_pipeline_cost_guard_fail_fast() -> None:
triple_barrier_spec=TripleBarrierSpec( triple_barrier_spec=TripleBarrierSpec(
take_profit_pct=0.02, take_profit_pct=0.02,
stop_loss_pct=0.01, stop_loss_pct=0.01,
max_holding_bars=3, max_holding_minutes=3,
), ),
walk_forward=WalkForwardConfig(train_size=2, test_size=1), walk_forward=WalkForwardConfig(train_size=2, test_size=1),
cost_model=bad, cost_model=bad,
@@ -119,7 +123,7 @@ def test_pipeline_deterministic_seed_free_deterministic_result() -> None:
triple_barrier_spec=TripleBarrierSpec( triple_barrier_spec=TripleBarrierSpec(
take_profit_pct=0.02, take_profit_pct=0.02,
stop_loss_pct=0.01, stop_loss_pct=0.01,
max_holding_bars=3, max_holding_minutes=3,
), ),
walk_forward=WalkForwardConfig( walk_forward=WalkForwardConfig(
train_size=4, train_size=4,
@@ -134,3 +138,31 @@ def test_pipeline_deterministic_seed_free_deterministic_result() -> None:
out1 = run_v2_backtest_pipeline(**cfg) out1 = run_v2_backtest_pipeline(**cfg)
out2 = run_v2_backtest_pipeline(**cfg) out2 = run_v2_backtest_pipeline(**cfg)
assert out1 == out2 assert out1 == out2
def test_pipeline_rejects_minutes_spec_when_timestamp_missing() -> None:
bars = _bars()
bars[2] = BacktestBar(
high=bars[2].high,
low=bars[2].low,
close=bars[2].close,
session_id=bars[2].session_id,
timestamp=None,
)
try:
run_v2_backtest_pipeline(
bars=bars,
entry_indices=[0, 1, 2, 3],
side=1,
triple_barrier_spec=TripleBarrierSpec(
take_profit_pct=0.02,
stop_loss_pct=0.01,
max_holding_minutes=3,
),
walk_forward=WalkForwardConfig(train_size=2, test_size=1),
cost_model=_cost_model(),
)
except ValueError as exc:
assert "BacktestBar.timestamp is required" in str(exc)
else:
raise AssertionError("expected timestamp validation error")