feat: Triple Barrier 시간장벽을 캘린더 분 기반으로 전환 (#329) #346
@@ -8,8 +8,10 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
from statistics import mean
|
from statistics import mean
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
from src.analysis.backtest_cost_guard import BacktestCostModel, validate_backtest_cost_model
|
from src.analysis.backtest_cost_guard import BacktestCostModel, validate_backtest_cost_model
|
||||||
from src.analysis.triple_barrier import TripleBarrierSpec, label_with_triple_barrier
|
from src.analysis.triple_barrier import TripleBarrierSpec, label_with_triple_barrier
|
||||||
@@ -22,6 +24,7 @@ class BacktestBar:
|
|||||||
low: float
|
low: float
|
||||||
close: float
|
close: float
|
||||||
session_id: str
|
session_id: str
|
||||||
|
timestamp: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -86,16 +89,27 @@ def run_v2_backtest_pipeline(
|
|||||||
highs = [float(bar.high) for bar in bars]
|
highs = [float(bar.high) for bar in bars]
|
||||||
lows = [float(bar.low) for bar in bars]
|
lows = [float(bar.low) for bar in bars]
|
||||||
closes = [float(bar.close) for bar in bars]
|
closes = [float(bar.close) for bar in bars]
|
||||||
|
timestamps = [bar.timestamp for bar in bars]
|
||||||
normalized_entries = sorted(set(int(i) for i in entry_indices))
|
normalized_entries = sorted(set(int(i) for i in entry_indices))
|
||||||
if normalized_entries[0] < 0 or normalized_entries[-1] >= len(bars):
|
if normalized_entries[0] < 0 or normalized_entries[-1] >= len(bars):
|
||||||
raise IndexError("entry index out of range")
|
raise IndexError("entry index out of range")
|
||||||
|
|
||||||
|
resolved_timestamps: list[datetime] | None = None
|
||||||
|
if triple_barrier_spec.max_holding_minutes is not None:
|
||||||
|
if any(ts is None for ts in timestamps):
|
||||||
|
raise ValueError(
|
||||||
|
"BacktestBar.timestamp is required for all bars when "
|
||||||
|
"triple_barrier_spec.max_holding_minutes is set"
|
||||||
|
)
|
||||||
|
resolved_timestamps = cast(list[datetime], timestamps)
|
||||||
|
|
||||||
labels_by_bar_index: dict[int, int] = {}
|
labels_by_bar_index: dict[int, int] = {}
|
||||||
for idx in normalized_entries:
|
for idx in normalized_entries:
|
||||||
labels_by_bar_index[idx] = label_with_triple_barrier(
|
labels_by_bar_index[idx] = label_with_triple_barrier(
|
||||||
highs=highs,
|
highs=highs,
|
||||||
lows=lows,
|
lows=lows,
|
||||||
closes=closes,
|
closes=closes,
|
||||||
|
timestamps=resolved_timestamps,
|
||||||
entry_index=idx,
|
entry_index=idx,
|
||||||
side=side,
|
side=side,
|
||||||
spec=triple_barrier_spec,
|
spec=triple_barrier_spec,
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
from src.analysis.backtest_cost_guard import BacktestCostModel
|
from src.analysis.backtest_cost_guard import BacktestCostModel
|
||||||
from src.analysis.backtest_pipeline import (
|
from src.analysis.backtest_pipeline import (
|
||||||
BacktestBar,
|
BacktestBar,
|
||||||
@@ -12,6 +14,7 @@ from src.analysis.walk_forward_split import generate_walk_forward_splits
|
|||||||
|
|
||||||
|
|
||||||
def _bars() -> list[BacktestBar]:
|
def _bars() -> list[BacktestBar]:
|
||||||
|
base_ts = datetime(2026, 2, 28, 0, 0, tzinfo=UTC)
|
||||||
closes = [100.0, 101.0, 102.0, 101.5, 103.0, 102.5, 104.0, 103.5, 105.0, 104.5, 106.0, 105.5]
|
closes = [100.0, 101.0, 102.0, 101.5, 103.0, 102.5, 104.0, 103.5, 105.0, 104.5, 106.0, 105.5]
|
||||||
bars: list[BacktestBar] = []
|
bars: list[BacktestBar] = []
|
||||||
for i, close in enumerate(closes):
|
for i, close in enumerate(closes):
|
||||||
@@ -21,6 +24,7 @@ def _bars() -> list[BacktestBar]:
|
|||||||
low=close - 1.0,
|
low=close - 1.0,
|
||||||
close=close,
|
close=close,
|
||||||
session_id="KRX_REG" if i % 2 == 0 else "US_PRE",
|
session_id="KRX_REG" if i % 2 == 0 else "US_PRE",
|
||||||
|
timestamp=base_ts + timedelta(minutes=i),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return bars
|
return bars
|
||||||
@@ -43,7 +47,7 @@ def test_pipeline_happy_path_returns_fold_and_artifact_contract() -> None:
|
|||||||
triple_barrier_spec=TripleBarrierSpec(
|
triple_barrier_spec=TripleBarrierSpec(
|
||||||
take_profit_pct=0.02,
|
take_profit_pct=0.02,
|
||||||
stop_loss_pct=0.01,
|
stop_loss_pct=0.01,
|
||||||
max_holding_bars=3,
|
max_holding_minutes=3,
|
||||||
),
|
),
|
||||||
walk_forward=WalkForwardConfig(
|
walk_forward=WalkForwardConfig(
|
||||||
train_size=4,
|
train_size=4,
|
||||||
@@ -84,7 +88,7 @@ def test_pipeline_cost_guard_fail_fast() -> None:
|
|||||||
triple_barrier_spec=TripleBarrierSpec(
|
triple_barrier_spec=TripleBarrierSpec(
|
||||||
take_profit_pct=0.02,
|
take_profit_pct=0.02,
|
||||||
stop_loss_pct=0.01,
|
stop_loss_pct=0.01,
|
||||||
max_holding_bars=3,
|
max_holding_minutes=3,
|
||||||
),
|
),
|
||||||
walk_forward=WalkForwardConfig(train_size=2, test_size=1),
|
walk_forward=WalkForwardConfig(train_size=2, test_size=1),
|
||||||
cost_model=bad,
|
cost_model=bad,
|
||||||
@@ -119,7 +123,7 @@ def test_pipeline_deterministic_seed_free_deterministic_result() -> None:
|
|||||||
triple_barrier_spec=TripleBarrierSpec(
|
triple_barrier_spec=TripleBarrierSpec(
|
||||||
take_profit_pct=0.02,
|
take_profit_pct=0.02,
|
||||||
stop_loss_pct=0.01,
|
stop_loss_pct=0.01,
|
||||||
max_holding_bars=3,
|
max_holding_minutes=3,
|
||||||
),
|
),
|
||||||
walk_forward=WalkForwardConfig(
|
walk_forward=WalkForwardConfig(
|
||||||
train_size=4,
|
train_size=4,
|
||||||
@@ -134,3 +138,31 @@ def test_pipeline_deterministic_seed_free_deterministic_result() -> None:
|
|||||||
out1 = run_v2_backtest_pipeline(**cfg)
|
out1 = run_v2_backtest_pipeline(**cfg)
|
||||||
out2 = run_v2_backtest_pipeline(**cfg)
|
out2 = run_v2_backtest_pipeline(**cfg)
|
||||||
assert out1 == out2
|
assert out1 == out2
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_rejects_minutes_spec_when_timestamp_missing() -> None:
|
||||||
|
bars = _bars()
|
||||||
|
bars[2] = BacktestBar(
|
||||||
|
high=bars[2].high,
|
||||||
|
low=bars[2].low,
|
||||||
|
close=bars[2].close,
|
||||||
|
session_id=bars[2].session_id,
|
||||||
|
timestamp=None,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
run_v2_backtest_pipeline(
|
||||||
|
bars=bars,
|
||||||
|
entry_indices=[0, 1, 2, 3],
|
||||||
|
side=1,
|
||||||
|
triple_barrier_spec=TripleBarrierSpec(
|
||||||
|
take_profit_pct=0.02,
|
||||||
|
stop_loss_pct=0.01,
|
||||||
|
max_holding_minutes=3,
|
||||||
|
),
|
||||||
|
walk_forward=WalkForwardConfig(train_size=2, test_size=1),
|
||||||
|
cost_model=_cost_model(),
|
||||||
|
)
|
||||||
|
except ValueError as exc:
|
||||||
|
assert "BacktestBar.timestamp is required" in str(exc)
|
||||||
|
else:
|
||||||
|
raise AssertionError("expected timestamp validation error")
|
||||||
|
|||||||
Reference in New Issue
Block a user