feat: add walk-forward splitter with purge and embargo controls (TASK-CODE-005)

This commit is contained in:
agentson
2026-02-27 08:28:11 +09:00
parent cfd5351b58
commit e56819e9e2
2 changed files with 151 additions and 0 deletions

View File

@@ -0,0 +1,75 @@
"""Walk-forward splitter with purge/embargo controls."""
from __future__ import annotations
from dataclasses import dataclass
@dataclass(frozen=True)
class WalkForwardFold:
train_indices: list[int]
test_indices: list[int]
@property
def train_size(self) -> int:
return len(self.train_indices)
@property
def test_size(self) -> int:
return len(self.test_indices)
def generate_walk_forward_splits(
*,
n_samples: int,
train_size: int,
test_size: int,
step_size: int | None = None,
purge_size: int = 0,
embargo_size: int = 0,
min_train_size: int = 1,
) -> list[WalkForwardFold]:
"""Generate chronological folds with purge/embargo leakage controls."""
if n_samples <= 0:
raise ValueError("n_samples must be positive")
if train_size <= 0 or test_size <= 0:
raise ValueError("train_size and test_size must be positive")
if purge_size < 0 or embargo_size < 0:
raise ValueError("purge_size and embargo_size must be >= 0")
if min_train_size <= 0:
raise ValueError("min_train_size must be positive")
step = step_size if step_size is not None else test_size
if step <= 0:
raise ValueError("step_size must be positive")
folds: list[WalkForwardFold] = []
prev_test_end: int | None = None
test_start = train_size + purge_size
while test_start + test_size <= n_samples:
test_end = test_start + test_size - 1
train_end = test_start - purge_size - 1
if train_end < 0:
break
train_start = max(0, train_end - train_size + 1)
train_indices = list(range(train_start, train_end + 1))
if prev_test_end is not None and embargo_size > 0:
emb_from = prev_test_end + 1
emb_to = prev_test_end + embargo_size
train_indices = [i for i in train_indices if i < emb_from or i > emb_to]
if len(train_indices) >= min_train_size:
folds.append(
WalkForwardFold(
train_indices=train_indices,
test_indices=list(range(test_start, test_end + 1)),
)
)
prev_test_end = test_end
test_start += step
return folds