"""Triple barrier labeler utilities. Implements first-touch labeling with upper/lower/time barriers. """ from __future__ import annotations import warnings from collections.abc import Sequence from dataclasses import dataclass from datetime import datetime, timedelta from typing import Literal TieBreakMode = Literal["stop_first", "take_first"] @dataclass(frozen=True) class TripleBarrierSpec: take_profit_pct: float stop_loss_pct: float max_holding_bars: int | None = None max_holding_minutes: int | None = None tie_break: TieBreakMode = "stop_first" def __post_init__(self) -> None: if self.max_holding_minutes is None and self.max_holding_bars is None: raise ValueError("one of max_holding_minutes or max_holding_bars must be set") if self.max_holding_minutes is not None and self.max_holding_minutes <= 0: raise ValueError("max_holding_minutes must be positive") if self.max_holding_bars is not None and self.max_holding_bars <= 0: raise ValueError("max_holding_bars must be positive") @dataclass(frozen=True) class TripleBarrierLabel: label: int # +1 take-profit first, -1 stop-loss first, 0 timeout touched: Literal["take_profit", "stop_loss", "time"] touch_bar: int entry_price: float upper_barrier: float lower_barrier: float def label_with_triple_barrier( *, highs: Sequence[float], lows: Sequence[float], closes: Sequence[float], timestamps: Sequence[datetime] | None = None, entry_index: int, side: int, spec: TripleBarrierSpec, ) -> TripleBarrierLabel: """Label one entry using triple-barrier first-touch rule. Args: highs/lows/closes: OHLC components with identical length. entry_index: Entry bar index in the sequences. side: +1 for long, -1 for short. spec: Barrier specification. """ if side not in {1, -1}: raise ValueError("side must be +1 or -1") if len(highs) != len(lows) or len(highs) != len(closes): raise ValueError("highs, lows, closes lengths must match") if entry_index < 0 or entry_index >= len(closes): raise IndexError("entry_index out of range") entry_price = float(closes[entry_index]) if entry_price <= 0: raise ValueError("entry price must be positive") if side == 1: upper = entry_price * (1.0 + spec.take_profit_pct) lower = entry_price * (1.0 - spec.stop_loss_pct) else: # For short side, favorable move is down. upper = entry_price * (1.0 + spec.stop_loss_pct) lower = entry_price * (1.0 - spec.take_profit_pct) if spec.max_holding_minutes is not None: if timestamps is None: raise ValueError("timestamps are required when max_holding_minutes is set") if len(timestamps) != len(closes): raise ValueError("timestamps length must match OHLC lengths") expiry_timestamp = timestamps[entry_index] + timedelta(minutes=spec.max_holding_minutes) last_index = entry_index for idx in range(entry_index + 1, len(closes)): if timestamps[idx] > expiry_timestamp: break last_index = idx else: assert spec.max_holding_bars is not None warnings.warn( ( "TripleBarrierSpec.max_holding_bars is deprecated; " "use max_holding_minutes with timestamps instead." ), DeprecationWarning, stacklevel=2, ) last_index = min(len(closes) - 1, entry_index + spec.max_holding_bars) for idx in range(entry_index + 1, last_index + 1): high_price = float(highs[idx]) low_price = float(lows[idx]) up_touch = high_price >= upper down_touch = low_price <= lower if not up_touch and not down_touch: continue if up_touch and down_touch: if spec.tie_break == "stop_first": touched = "stop_loss" label = -1 else: touched = "take_profit" label = 1 elif up_touch: touched = "take_profit" if side == 1 else "stop_loss" label = 1 if side == 1 else -1 else: touched = "stop_loss" if side == 1 else "take_profit" label = -1 if side == 1 else 1 return TripleBarrierLabel( label=label, touched=touched, touch_bar=idx, entry_price=entry_price, upper_barrier=upper, lower_barrier=lower, ) return TripleBarrierLabel( label=0, touched="time", touch_bar=last_index, entry_price=entry_price, upper_barrier=upper, lower_barrier=lower, )