feat: add walk-forward splitter with purge and embargo controls (TASK-CODE-005)
This commit is contained in:
76
tests/test_walk_forward_split.py
Normal file
76
tests/test_walk_forward_split.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from src.analysis.walk_forward_split import generate_walk_forward_splits
|
||||
|
||||
|
||||
def test_generates_sequential_folds() -> None:
|
||||
folds = generate_walk_forward_splits(
|
||||
n_samples=30,
|
||||
train_size=10,
|
||||
test_size=5,
|
||||
)
|
||||
assert len(folds) == 4
|
||||
assert folds[0].train_indices == list(range(0, 10))
|
||||
assert folds[0].test_indices == list(range(10, 15))
|
||||
assert folds[1].train_indices == list(range(5, 15))
|
||||
assert folds[1].test_indices == list(range(15, 20))
|
||||
|
||||
|
||||
def test_purge_removes_boundary_samples_before_test() -> None:
|
||||
folds = generate_walk_forward_splits(
|
||||
n_samples=25,
|
||||
train_size=8,
|
||||
test_size=4,
|
||||
purge_size=2,
|
||||
)
|
||||
first = folds[0]
|
||||
# test starts at 10, purge=2 => train end must be 7
|
||||
assert first.train_indices == list(range(0, 8))
|
||||
assert first.test_indices == list(range(10, 14))
|
||||
|
||||
|
||||
def test_embargo_excludes_post_test_samples_from_next_train() -> None:
|
||||
folds = generate_walk_forward_splits(
|
||||
n_samples=45,
|
||||
train_size=15,
|
||||
test_size=5,
|
||||
step_size=10,
|
||||
embargo_size=3,
|
||||
)
|
||||
assert len(folds) >= 2
|
||||
# Fold1 test: 15..19, next fold train window: 10..24.
|
||||
# embargo_size=3 should remove 20,21,22 from fold2 train.
|
||||
second_train = folds[1].train_indices
|
||||
assert 20 not in second_train
|
||||
assert 21 not in second_train
|
||||
assert 22 not in second_train
|
||||
assert 23 in second_train
|
||||
|
||||
|
||||
def test_respects_min_train_size_and_returns_empty_when_impossible() -> None:
|
||||
folds = generate_walk_forward_splits(
|
||||
n_samples=15,
|
||||
train_size=5,
|
||||
test_size=5,
|
||||
min_train_size=6,
|
||||
)
|
||||
assert folds == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("n_samples", "train_size", "test_size"),
|
||||
[
|
||||
(0, 10, 2),
|
||||
(10, 0, 2),
|
||||
(10, 5, 0),
|
||||
],
|
||||
)
|
||||
def test_invalid_args_raise(n_samples: int, train_size: int, test_size: int) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
generate_walk_forward_splits(
|
||||
n_samples=n_samples,
|
||||
train_size=train_size,
|
||||
test_size=test_size,
|
||||
)
|
||||
Reference in New Issue
Block a user