From c641097fe7718b280c5742ef6d45bb7249ea2751 Mon Sep 17 00:00:00 2001 From: agentson Date: Sat, 28 Feb 2026 14:35:55 +0900 Subject: [PATCH] feat: support minute-based triple barrier horizon (#329) --- src/analysis/triple_barrier.py | 44 +++++++++++++++++++++++----- tests/test_triple_barrier.py | 53 ++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 8 deletions(-) diff --git a/src/analysis/triple_barrier.py b/src/analysis/triple_barrier.py index f609496..793250d 100644 --- a/src/analysis/triple_barrier.py +++ b/src/analysis/triple_barrier.py @@ -5,7 +5,9 @@ Implements first-touch labeling with upper/lower/time barriers. from __future__ import annotations +import warnings from dataclasses import dataclass +from datetime import datetime, timedelta from typing import Literal, Sequence @@ -16,9 +18,18 @@ TieBreakMode = Literal["stop_first", "take_first"] class TripleBarrierSpec: take_profit_pct: float stop_loss_pct: float - max_holding_bars: int + max_holding_bars: int | None = None + max_holding_minutes: int | None = None tie_break: TieBreakMode = "stop_first" + def __post_init__(self) -> None: + if self.max_holding_minutes is None and self.max_holding_bars is None: + raise ValueError("one of max_holding_minutes or max_holding_bars must be set") + if self.max_holding_minutes is not None and self.max_holding_minutes <= 0: + raise ValueError("max_holding_minutes must be positive") + if self.max_holding_bars is not None and self.max_holding_bars <= 0: + raise ValueError("max_holding_bars must be positive") + @dataclass(frozen=True) class TripleBarrierLabel: @@ -35,6 +46,7 @@ def label_with_triple_barrier( highs: Sequence[float], lows: Sequence[float], closes: Sequence[float], + timestamps: Sequence[datetime] | None = None, entry_index: int, side: int, spec: TripleBarrierSpec, @@ -53,8 +65,6 @@ def label_with_triple_barrier( raise ValueError("highs, lows, closes lengths must match") if entry_index < 0 or entry_index >= len(closes): raise IndexError("entry_index out of range") - if spec.max_holding_bars <= 0: - raise ValueError("max_holding_bars must be positive") entry_price = float(closes[entry_index]) if entry_price <= 0: @@ -68,13 +78,31 @@ def label_with_triple_barrier( upper = entry_price * (1.0 + spec.stop_loss_pct) lower = entry_price * (1.0 - spec.take_profit_pct) - last_index = min(len(closes) - 1, entry_index + spec.max_holding_bars) + if spec.max_holding_minutes is not None: + if timestamps is None: + raise ValueError("timestamps are required when max_holding_minutes is set") + if len(timestamps) != len(closes): + raise ValueError("timestamps length must match OHLC lengths") + expiry_timestamp = timestamps[entry_index] + timedelta(minutes=spec.max_holding_minutes) + last_index = entry_index + for idx in range(entry_index + 1, len(closes)): + if timestamps[idx] > expiry_timestamp: + break + last_index = idx + else: + assert spec.max_holding_bars is not None + warnings.warn( + "TripleBarrierSpec.max_holding_bars is deprecated; use max_holding_minutes with timestamps instead.", + DeprecationWarning, + stacklevel=2, + ) + last_index = min(len(closes) - 1, entry_index + spec.max_holding_bars) for idx in range(entry_index + 1, last_index + 1): - h = float(highs[idx]) - l = float(lows[idx]) + high_price = float(highs[idx]) + low_price = float(lows[idx]) - up_touch = h >= upper - down_touch = l <= lower + up_touch = high_price >= upper + down_touch = low_price <= lower if not up_touch and not down_touch: continue diff --git a/tests/test_triple_barrier.py b/tests/test_triple_barrier.py index 1fff8e3..ba82a5e 100644 --- a/tests/test_triple_barrier.py +++ b/tests/test_triple_barrier.py @@ -1,5 +1,9 @@ from __future__ import annotations +from datetime import UTC, datetime, timedelta + +import pytest + from src.analysis.triple_barrier import TripleBarrierSpec, label_with_triple_barrier @@ -129,3 +133,52 @@ def test_short_tie_break_modes() -> None: ) assert out_take.label == 1 assert out_take.touched == "take_profit" + + +def test_minutes_time_barrier_consistent_across_sampling() -> None: + base = datetime(2026, 2, 28, 9, 0, tzinfo=UTC) + highs = [100.0, 100.5, 100.6, 100.4] + lows = [100.0, 99.6, 99.4, 99.5] + closes = [100.0, 100.1, 100.0, 100.0] + spec = TripleBarrierSpec( + take_profit_pct=0.02, + stop_loss_pct=0.02, + max_holding_minutes=5, + ) + + out_1m = label_with_triple_barrier( + highs=highs, + lows=lows, + closes=closes, + timestamps=[base + timedelta(minutes=i) for i in range(4)], + entry_index=0, + side=1, + spec=spec, + ) + out_5m = label_with_triple_barrier( + highs=highs, + lows=lows, + closes=closes, + timestamps=[base + timedelta(minutes=5 * i) for i in range(4)], + entry_index=0, + side=1, + spec=spec, + ) + assert out_1m.touch_bar == 3 + assert out_5m.touch_bar == 1 + + +def test_bars_mode_emits_deprecation_warning() -> None: + highs = [100, 101, 103] + lows = [100, 99.6, 100] + closes = [100, 100, 102] + spec = TripleBarrierSpec(take_profit_pct=0.02, stop_loss_pct=0.01, max_holding_bars=3) + with pytest.deprecated_call(match="max_holding_bars is deprecated"): + label_with_triple_barrier( + highs=highs, + lows=lows, + closes=closes, + entry_index=0, + side=1, + spec=spec, + )