77 lines
2.1 KiB
Python
77 lines
2.1 KiB
Python
from __future__ import annotations
|
|
|
|
import pytest
|
|
|
|
from src.analysis.walk_forward_split import generate_walk_forward_splits
|
|
|
|
|
|
def test_generates_sequential_folds() -> None:
    """Rolling 10-train / 5-test windows advance by 5 samples when no step is given.

    With 30 samples this yields exactly four folds. Every fold is checked, not
    just the first two, so the last fold, whose test window must end on the
    final sample (29), is covered as well.
    """
    folds = generate_walk_forward_splits(
        n_samples=30,
        train_size=10,
        test_size=5,
    )

    assert len(folds) == 4
    for fold_no, fold in enumerate(folds):
        # Fold k trains on [5k, 5k + 10) and tests on [5k + 10, 5k + 15).
        start = fold_no * 5
        assert fold.train_indices == list(range(start, start + 10)), f"fold {fold_no} train"
        assert fold.test_indices == list(range(start + 10, start + 15)), f"fold {fold_no} test"
|
|
|
|
|
|
def test_purge_removes_boundary_samples_before_test() -> None:
    """A purge gap of 2 samples separates the 8-sample train window from the test window."""
    splits = generate_walk_forward_splits(
        n_samples=25,
        train_size=8,
        test_size=4,
        purge_size=2,
    )

    # Expected layout of the first fold: train 0..7, samples 8 and 9 purged,
    # test 10..13.
    expected_train = list(range(8))
    expected_test = list(range(10, 14))

    first_fold = splits[0]
    assert first_fold.train_indices == expected_train
    assert first_fold.test_indices == expected_test
|
|
|
|
|
|
def test_embargo_excludes_post_test_samples_from_next_train() -> None:
    """Samples right after a fold's test window are embargoed from the next train set.

    Layout with step_size=10: fold 0 trains on 0..14 and tests on 15..19, and
    fold 1's train window spans 10..24 with its test window at 25..29. With
    embargo_size=3, the three samples following fold 0's test window (20, 21
    and 22) must be missing from fold 1's train set, while 23 and 24 remain.
    """
    folds = generate_walk_forward_splits(
        n_samples=45,
        train_size=15,
        test_size=5,
        step_size=10,
        embargo_size=3,
    )

    assert len(folds) >= 2

    # Pin the fold layout the exclusion checks rely on. Without this, the
    # "not in" assertions below could pass vacuously if the windows moved.
    assert folds[0].test_indices == list(range(15, 20))
    assert folds[1].test_indices == list(range(25, 30))

    second_train = folds[1].train_indices
    for idx in (20, 21, 22):
        assert idx not in second_train, f"embargoed sample {idx} leaked into fold 1 train"

    # Samples past the embargo zone stay in the train window.
    assert 23 in second_train
    assert 24 in second_train
|
|
|
|
|
|
def test_respects_min_train_size_and_returns_empty_when_impossible() -> None:
    """Expect an empty list when min_train_size (6) exceeds train_size (5)."""
    result = generate_walk_forward_splits(
        n_samples=15,
        train_size=5,
        test_size=5,
        min_train_size=6,
    )
    assert result == []
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("n_samples", "train_size", "test_size"),
    [
        (0, 10, 2),  # no samples at all
        (10, 0, 2),  # zero-length train window
        (10, 5, 0),  # zero-length test window
    ],
)
def test_invalid_args_raise(n_samples: int, train_size: int, test_size: int) -> None:
    """Each case sets one size argument to zero and must be rejected with ValueError."""
    call_kwargs = {
        "n_samples": n_samples,
        "train_size": train_size,
        "test_size": test_size,
    }
    with pytest.raises(ValueError):
        generate_walk_forward_splits(**call_kwargs)
|