From e56819e9e238b996dbaca0f42405604804952438 Mon Sep 17 00:00:00 2001 From: agentson Date: Fri, 27 Feb 2026 08:28:11 +0900 Subject: [PATCH] feat: add walk-forward splitter with purge and embargo controls (TASK-CODE-005) --- src/analysis/walk_forward_split.py | 75 +++++++++++++++++++++++++++++ tests/test_walk_forward_split.py | 76 ++++++++++++++++++++++++++++++ 2 files changed, 151 insertions(+) create mode 100644 src/analysis/walk_forward_split.py create mode 100644 tests/test_walk_forward_split.py diff --git a/src/analysis/walk_forward_split.py b/src/analysis/walk_forward_split.py new file mode 100644 index 0000000..2ff7837 --- /dev/null +++ b/src/analysis/walk_forward_split.py @@ -0,0 +1,75 @@ +"""Walk-forward splitter with purge/embargo controls.""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class WalkForwardFold: + train_indices: list[int] + test_indices: list[int] + + @property + def train_size(self) -> int: + return len(self.train_indices) + + @property + def test_size(self) -> int: + return len(self.test_indices) + + +def generate_walk_forward_splits( + *, + n_samples: int, + train_size: int, + test_size: int, + step_size: int | None = None, + purge_size: int = 0, + embargo_size: int = 0, + min_train_size: int = 1, +) -> list[WalkForwardFold]: + """Generate chronological folds with purge/embargo leakage controls.""" + if n_samples <= 0: + raise ValueError("n_samples must be positive") + if train_size <= 0 or test_size <= 0: + raise ValueError("train_size and test_size must be positive") + if purge_size < 0 or embargo_size < 0: + raise ValueError("purge_size and embargo_size must be >= 0") + if min_train_size <= 0: + raise ValueError("min_train_size must be positive") + + step = step_size if step_size is not None else test_size + if step <= 0: + raise ValueError("step_size must be positive") + + folds: list[WalkForwardFold] = [] + prev_test_end: int | None = None + test_start = train_size + purge_size + + while test_start + test_size <= n_samples: + test_end = test_start + test_size - 1 + train_end = test_start - purge_size - 1 + if train_end < 0: + break + + train_start = max(0, train_end - train_size + 1) + train_indices = list(range(train_start, train_end + 1)) + + if prev_test_end is not None and embargo_size > 0: + emb_from = prev_test_end + 1 + emb_to = prev_test_end + embargo_size + train_indices = [i for i in train_indices if i < emb_from or i > emb_to] + + if len(train_indices) >= min_train_size: + folds.append( + WalkForwardFold( + train_indices=train_indices, + test_indices=list(range(test_start, test_end + 1)), + ) + ) + + prev_test_end = test_end + test_start += step + + return folds diff --git a/tests/test_walk_forward_split.py b/tests/test_walk_forward_split.py new file mode 100644 index 0000000..c5003b8 --- /dev/null +++ b/tests/test_walk_forward_split.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import pytest + +from src.analysis.walk_forward_split import generate_walk_forward_splits + + +def test_generates_sequential_folds() -> None: + folds = generate_walk_forward_splits( + n_samples=30, + train_size=10, + test_size=5, + ) + assert len(folds) == 4 + assert folds[0].train_indices == list(range(0, 10)) + assert folds[0].test_indices == list(range(10, 15)) + assert folds[1].train_indices == list(range(5, 15)) + assert folds[1].test_indices == list(range(15, 20)) + + +def test_purge_removes_boundary_samples_before_test() -> None: + folds = generate_walk_forward_splits( + n_samples=25, + train_size=8, + test_size=4, + purge_size=2, + ) + first = folds[0] + # test starts at 10, purge=2 => train end must be 7 + assert first.train_indices == list(range(0, 8)) + assert first.test_indices == list(range(10, 14)) + + +def test_embargo_excludes_post_test_samples_from_next_train() -> None: + folds = generate_walk_forward_splits( + n_samples=45, + train_size=15, + test_size=5, + step_size=10, + embargo_size=3, + ) + assert len(folds) >= 2 + # Fold1 test: 15..19, next fold train window: 10..24. + # embargo_size=3 should remove 20,21,22 from fold2 train. + second_train = folds[1].train_indices + assert 20 not in second_train + assert 21 not in second_train + assert 22 not in second_train + assert 23 in second_train + + +def test_respects_min_train_size_and_returns_empty_when_impossible() -> None: + folds = generate_walk_forward_splits( + n_samples=15, + train_size=5, + test_size=5, + min_train_size=6, + ) + assert folds == [] + + +@pytest.mark.parametrize( + ("n_samples", "train_size", "test_size"), + [ + (0, 10, 2), + (10, 0, 2), + (10, 5, 0), + ], +) +def test_invalid_args_raise(n_samples: int, train_size: int, test_size: int) -> None: + with pytest.raises(ValueError): + generate_walk_forward_splits( + n_samples=n_samples, + train_size=train_size, + test_size=test_size, + )