Merge pull request 'feat: minute-based triple barrier horizon (#329)' (#334) from feature/issue-329-triple-barrier-minutes into feature/v3-session-policy-stream
Some checks failed
Gitea CI / test (push) Has been cancelled
Some checks failed
Gitea CI / test (push) Has been cancelled
Reviewed-on: #334
This commit was merged in pull request #334.
This commit is contained in:
@@ -5,7 +5,9 @@ Implements first-touch labeling with upper/lower/time barriers.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timedelta
|
||||||
from typing import Literal, Sequence
|
from typing import Literal, Sequence
|
||||||
|
|
||||||
|
|
||||||
@@ -16,9 +18,18 @@ TieBreakMode = Literal["stop_first", "take_first"]
|
|||||||
class TripleBarrierSpec:
|
class TripleBarrierSpec:
|
||||||
take_profit_pct: float
|
take_profit_pct: float
|
||||||
stop_loss_pct: float
|
stop_loss_pct: float
|
||||||
max_holding_bars: int
|
max_holding_bars: int | None = None
|
||||||
|
max_holding_minutes: int | None = None
|
||||||
tie_break: TieBreakMode = "stop_first"
|
tie_break: TieBreakMode = "stop_first"
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
if self.max_holding_minutes is None and self.max_holding_bars is None:
|
||||||
|
raise ValueError("one of max_holding_minutes or max_holding_bars must be set")
|
||||||
|
if self.max_holding_minutes is not None and self.max_holding_minutes <= 0:
|
||||||
|
raise ValueError("max_holding_minutes must be positive")
|
||||||
|
if self.max_holding_bars is not None and self.max_holding_bars <= 0:
|
||||||
|
raise ValueError("max_holding_bars must be positive")
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class TripleBarrierLabel:
|
class TripleBarrierLabel:
|
||||||
@@ -35,6 +46,7 @@ def label_with_triple_barrier(
|
|||||||
highs: Sequence[float],
|
highs: Sequence[float],
|
||||||
lows: Sequence[float],
|
lows: Sequence[float],
|
||||||
closes: Sequence[float],
|
closes: Sequence[float],
|
||||||
|
timestamps: Sequence[datetime] | None = None,
|
||||||
entry_index: int,
|
entry_index: int,
|
||||||
side: int,
|
side: int,
|
||||||
spec: TripleBarrierSpec,
|
spec: TripleBarrierSpec,
|
||||||
@@ -53,8 +65,6 @@ def label_with_triple_barrier(
|
|||||||
raise ValueError("highs, lows, closes lengths must match")
|
raise ValueError("highs, lows, closes lengths must match")
|
||||||
if entry_index < 0 or entry_index >= len(closes):
|
if entry_index < 0 or entry_index >= len(closes):
|
||||||
raise IndexError("entry_index out of range")
|
raise IndexError("entry_index out of range")
|
||||||
if spec.max_holding_bars <= 0:
|
|
||||||
raise ValueError("max_holding_bars must be positive")
|
|
||||||
|
|
||||||
entry_price = float(closes[entry_index])
|
entry_price = float(closes[entry_index])
|
||||||
if entry_price <= 0:
|
if entry_price <= 0:
|
||||||
@@ -68,13 +78,31 @@ def label_with_triple_barrier(
|
|||||||
upper = entry_price * (1.0 + spec.stop_loss_pct)
|
upper = entry_price * (1.0 + spec.stop_loss_pct)
|
||||||
lower = entry_price * (1.0 - spec.take_profit_pct)
|
lower = entry_price * (1.0 - spec.take_profit_pct)
|
||||||
|
|
||||||
last_index = min(len(closes) - 1, entry_index + spec.max_holding_bars)
|
if spec.max_holding_minutes is not None:
|
||||||
|
if timestamps is None:
|
||||||
|
raise ValueError("timestamps are required when max_holding_minutes is set")
|
||||||
|
if len(timestamps) != len(closes):
|
||||||
|
raise ValueError("timestamps length must match OHLC lengths")
|
||||||
|
expiry_timestamp = timestamps[entry_index] + timedelta(minutes=spec.max_holding_minutes)
|
||||||
|
last_index = entry_index
|
||||||
|
for idx in range(entry_index + 1, len(closes)):
|
||||||
|
if timestamps[idx] > expiry_timestamp:
|
||||||
|
break
|
||||||
|
last_index = idx
|
||||||
|
else:
|
||||||
|
assert spec.max_holding_bars is not None
|
||||||
|
warnings.warn(
|
||||||
|
"TripleBarrierSpec.max_holding_bars is deprecated; use max_holding_minutes with timestamps instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
last_index = min(len(closes) - 1, entry_index + spec.max_holding_bars)
|
||||||
for idx in range(entry_index + 1, last_index + 1):
|
for idx in range(entry_index + 1, last_index + 1):
|
||||||
h = float(highs[idx])
|
high_price = float(highs[idx])
|
||||||
l = float(lows[idx])
|
low_price = float(lows[idx])
|
||||||
|
|
||||||
up_touch = h >= upper
|
up_touch = high_price >= upper
|
||||||
down_touch = l <= lower
|
down_touch = low_price <= lower
|
||||||
if not up_touch and not down_touch:
|
if not up_touch and not down_touch:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from src.analysis.triple_barrier import TripleBarrierSpec, label_with_triple_barrier
|
from src.analysis.triple_barrier import TripleBarrierSpec, label_with_triple_barrier
|
||||||
|
|
||||||
|
|
||||||
@@ -129,3 +133,52 @@ def test_short_tie_break_modes() -> None:
|
|||||||
)
|
)
|
||||||
assert out_take.label == 1
|
assert out_take.label == 1
|
||||||
assert out_take.touched == "take_profit"
|
assert out_take.touched == "take_profit"
|
||||||
|
|
||||||
|
|
||||||
|
def test_minutes_time_barrier_consistent_across_sampling() -> None:
|
||||||
|
base = datetime(2026, 2, 28, 9, 0, tzinfo=UTC)
|
||||||
|
highs = [100.0, 100.5, 100.6, 100.4]
|
||||||
|
lows = [100.0, 99.6, 99.4, 99.5]
|
||||||
|
closes = [100.0, 100.1, 100.0, 100.0]
|
||||||
|
spec = TripleBarrierSpec(
|
||||||
|
take_profit_pct=0.02,
|
||||||
|
stop_loss_pct=0.02,
|
||||||
|
max_holding_minutes=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
out_1m = label_with_triple_barrier(
|
||||||
|
highs=highs,
|
||||||
|
lows=lows,
|
||||||
|
closes=closes,
|
||||||
|
timestamps=[base + timedelta(minutes=i) for i in range(4)],
|
||||||
|
entry_index=0,
|
||||||
|
side=1,
|
||||||
|
spec=spec,
|
||||||
|
)
|
||||||
|
out_5m = label_with_triple_barrier(
|
||||||
|
highs=highs,
|
||||||
|
lows=lows,
|
||||||
|
closes=closes,
|
||||||
|
timestamps=[base + timedelta(minutes=5 * i) for i in range(4)],
|
||||||
|
entry_index=0,
|
||||||
|
side=1,
|
||||||
|
spec=spec,
|
||||||
|
)
|
||||||
|
assert out_1m.touch_bar == 3
|
||||||
|
assert out_5m.touch_bar == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_bars_mode_emits_deprecation_warning() -> None:
|
||||||
|
highs = [100, 101, 103]
|
||||||
|
lows = [100, 99.6, 100]
|
||||||
|
closes = [100, 100, 102]
|
||||||
|
spec = TripleBarrierSpec(take_profit_pct=0.02, stop_loss_pct=0.01, max_holding_bars=3)
|
||||||
|
with pytest.deprecated_call(match="max_holding_bars is deprecated"):
|
||||||
|
label_with_triple_barrier(
|
||||||
|
highs=highs,
|
||||||
|
lows=lows,
|
||||||
|
closes=closes,
|
||||||
|
entry_index=0,
|
||||||
|
side=1,
|
||||||
|
spec=spec,
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user