feat: add walk-forward splitter with purge and embargo controls (TASK-CODE-005)
This commit is contained in:
75
src/analysis/walk_forward_split.py
Normal file
75
src/analysis/walk_forward_split.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
"""Walk-forward splitter with purge/embargo controls."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class WalkForwardFold:
    """One train/test partition produced by a walk-forward split.

    The instance is frozen, so the two fields cannot be rebound after
    construction (the lists they refer to are ordinary mutable lists).
    """

    # Positional sample indices used for fitting in this fold.
    train_indices: list[int]
    # Positional sample indices used for evaluation in this fold.
    test_indices: list[int]

    @property
    def train_size(self) -> int:
        """Return the number of training indices in this fold."""
        indices = self.train_indices
        return len(indices)

    @property
    def test_size(self) -> int:
        """Return the number of test indices in this fold."""
        indices = self.test_indices
        return len(indices)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_walk_forward_splits(
    *,
    n_samples: int,
    train_size: int,
    test_size: int,
    step_size: int | None = None,
    purge_size: int = 0,
    embargo_size: int = 0,
    min_train_size: int = 1,
) -> list[WalkForwardFold]:
    """Generate chronological folds with purge/embargo leakage controls.

    Samples are addressed by position ``0 .. n_samples - 1``. The first test
    window starts at ``train_size + purge_size`` and each following window
    starts ``step_size`` positions later; generation stops as soon as a full
    test window no longer fits inside ``n_samples``.

    For every test window the training window is the ``train_size`` samples
    that end ``purge_size`` positions before the test window starts. When
    ``embargo_size > 0``, the ``embargo_size`` positions immediately after the
    *previous* test window are additionally removed from the training
    indices. A window whose remaining training set is smaller than
    ``min_train_size`` produces no fold, but the walk still advances past it.

    Args:
        n_samples: Total number of samples; must be positive.
        train_size: Length of the rolling training window; must be positive.
        test_size: Length of each test window; must be positive.
        step_size: Distance between consecutive test-window starts. Defaults
            to ``test_size`` when ``None``; must be positive.
        purge_size: Number of samples dropped between the end of the training
            window and the start of the test window; must be >= 0.
        embargo_size: Number of samples after the previous test window that
            are excluded from the current training indices; must be >= 0.
        min_train_size: Minimum number of training indices a fold needs in
            order to be emitted; must be positive.

    Returns:
        The folds in chronological order. The list is empty when no window
        fits or when every window falls below ``min_train_size``.

    Raises:
        ValueError: If any argument violates the constraints listed above.
    """
    if n_samples <= 0:
        raise ValueError("n_samples must be positive")
    if train_size <= 0 or test_size <= 0:
        raise ValueError("train_size and test_size must be positive")
    if purge_size < 0 or embargo_size < 0:
        raise ValueError("purge_size and embargo_size must be >= 0")
    if min_train_size <= 0:
        raise ValueError("min_train_size must be positive")

    step = test_size if step_size is None else step_size
    if step <= 0:
        raise ValueError("step_size must be positive")

    folds: list[WalkForwardFold] = []
    prev_test_end: int | None = None
    test_start = train_size + purge_size

    while test_start + test_size <= n_samples:
        test_stop = test_start + test_size  # exclusive upper bound
        train_stop = test_start - purge_size  # exclusive upper bound

        # test_start never drops below train_size + purge_size, so the
        # training window always starts at a non-negative position and
        # needs neither clamping nor an emptiness guard.
        train_indices = list(range(train_stop - train_size, train_stop))

        if prev_test_end is not None and embargo_size > 0:
            emb_from = prev_test_end + 1
            emb_to = prev_test_end + embargo_size
            train_indices = [
                i for i in train_indices if not emb_from <= i <= emb_to
            ]

        if len(train_indices) >= min_train_size:
            folds.append(
                WalkForwardFold(
                    train_indices=train_indices,
                    test_indices=list(range(test_start, test_stop)),
                )
            )

        # Advance even when the fold was skipped so the walk always ends.
        prev_test_end = test_stop - 1
        test_start += step

    return folds
|
||||||
76
tests/test_walk_forward_split.py
Normal file
76
tests/test_walk_forward_split.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.analysis.walk_forward_split import generate_walk_forward_splits
|
||||||
|
|
||||||
|
|
||||||
|
def test_generates_sequential_folds() -> None:
    """Without an explicit step, test windows advance by ``test_size``."""
    result = generate_walk_forward_splits(n_samples=30, train_size=10, test_size=5)

    assert len(result) == 4

    first, second = result[0], result[1]
    assert first.train_indices == list(range(10))
    assert first.test_indices == list(range(10, 15))
    assert second.train_indices == list(range(5, 15))
    assert second.test_indices == list(range(15, 20))
|
||||||
|
|
||||||
|
|
||||||
|
def test_purge_removes_boundary_samples_before_test() -> None:
    """A purge gap separates the end of training from the start of testing."""
    params = {"n_samples": 25, "train_size": 8, "test_size": 4, "purge_size": 2}
    opening_fold = generate_walk_forward_splits(**params)[0]

    # The first test window starts at 10; with a purge of 2 the training
    # window must stop at index 7, leaving 8 and 9 unused.
    assert opening_fold.train_indices == list(range(8))
    assert opening_fold.test_indices == list(range(10, 14))
|
||||||
|
|
||||||
|
|
||||||
|
def test_embargo_excludes_post_test_samples_from_next_train() -> None:
    """Samples right after a test window are kept out of the next train set."""
    result = generate_walk_forward_splits(
        n_samples=45,
        train_size=15,
        test_size=5,
        step_size=10,
        embargo_size=3,
    )
    assert len(result) >= 2

    # The first fold tests on 15..19 and the next training window spans
    # 10..24, so an embargo of 3 must drop 20, 21 and 22 while keeping 23.
    next_train = result[1].train_indices
    for embargoed in (20, 21, 22):
        assert embargoed not in next_train
    assert 23 in next_train
|
||||||
|
|
||||||
|
|
||||||
|
def test_respects_min_train_size_and_returns_empty_when_impossible() -> None:
    """No fold is emitted when the train window cannot reach the minimum."""
    # A 5-sample training window can never satisfy a minimum of 6.
    result = generate_walk_forward_splits(
        n_samples=15, train_size=5, test_size=5, min_train_size=6
    )
    assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    ("n_samples", "train_size", "test_size"),
    [(0, 10, 2), (10, 0, 2), (10, 5, 0)],
)
def test_invalid_args_raise(n_samples: int, train_size: int, test_size: int) -> None:
    """Each non-positive size argument is rejected with ``ValueError``."""
    bad_kwargs = {
        "n_samples": n_samples,
        "train_size": train_size,
        "test_size": test_size,
    }
    with pytest.raises(ValueError):
        generate_walk_forward_splits(**bad_kwargs)
|
||||||
Reference in New Issue
Block a user