feat: support minute-based triple barrier horizon (#329)
Some checks failed
Gitea CI / test (push) Has been cancelled
Gitea CI / test (pull_request) Has been cancelled

This commit is contained in:
agentson
2026-02-28 14:35:55 +09:00
parent 13a6d6612a
commit c641097fe7
2 changed files with 89 additions and 8 deletions

View File

@@ -5,7 +5,9 @@ Implements first-touch labeling with upper/lower/time barriers.
from __future__ import annotations
import warnings
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Literal, Sequence
@@ -16,9 +18,18 @@ TieBreakMode = Literal["stop_first", "take_first"]
class TripleBarrierSpec:
take_profit_pct: float
stop_loss_pct: float
max_holding_bars: int
max_holding_bars: int | None = None
max_holding_minutes: int | None = None
tie_break: TieBreakMode = "stop_first"
def __post_init__(self) -> None:
if self.max_holding_minutes is None and self.max_holding_bars is None:
raise ValueError("one of max_holding_minutes or max_holding_bars must be set")
if self.max_holding_minutes is not None and self.max_holding_minutes <= 0:
raise ValueError("max_holding_minutes must be positive")
if self.max_holding_bars is not None and self.max_holding_bars <= 0:
raise ValueError("max_holding_bars must be positive")
@dataclass(frozen=True)
class TripleBarrierLabel:
@@ -35,6 +46,7 @@ def label_with_triple_barrier(
highs: Sequence[float],
lows: Sequence[float],
closes: Sequence[float],
timestamps: Sequence[datetime] | None = None,
entry_index: int,
side: int,
spec: TripleBarrierSpec,
@@ -53,8 +65,6 @@ def label_with_triple_barrier(
raise ValueError("highs, lows, closes lengths must match")
if entry_index < 0 or entry_index >= len(closes):
raise IndexError("entry_index out of range")
if spec.max_holding_bars <= 0:
raise ValueError("max_holding_bars must be positive")
entry_price = float(closes[entry_index])
if entry_price <= 0:
@@ -68,13 +78,31 @@ def label_with_triple_barrier(
upper = entry_price * (1.0 + spec.stop_loss_pct)
lower = entry_price * (1.0 - spec.take_profit_pct)
last_index = min(len(closes) - 1, entry_index + spec.max_holding_bars)
if spec.max_holding_minutes is not None:
if timestamps is None:
raise ValueError("timestamps are required when max_holding_minutes is set")
if len(timestamps) != len(closes):
raise ValueError("timestamps length must match OHLC lengths")
expiry_timestamp = timestamps[entry_index] + timedelta(minutes=spec.max_holding_minutes)
last_index = entry_index
for idx in range(entry_index + 1, len(closes)):
if timestamps[idx] > expiry_timestamp:
break
last_index = idx
else:
assert spec.max_holding_bars is not None
warnings.warn(
"TripleBarrierSpec.max_holding_bars is deprecated; use max_holding_minutes with timestamps instead.",
DeprecationWarning,
stacklevel=2,
)
last_index = min(len(closes) - 1, entry_index + spec.max_holding_bars)
for idx in range(entry_index + 1, last_index + 1):
h = float(highs[idx])
l = float(lows[idx])
high_price = float(highs[idx])
low_price = float(lows[idx])
up_touch = h >= upper
down_touch = l <= lower
up_touch = high_price >= upper
down_touch = low_price <= lower
if not up_touch and not down_touch:
continue