import numpy as np
import pandas as pd

from core.config.pipeline_config import PipelineConfig
from processing.steps.feature_extraction_step import Gender
from core.utils.data_loader import DataLoader
from processing.batch.batch_config import BatchConfig
from processing.steps import PipelineStep


class DataSplittingStep(PipelineStep):
    """Configuration-driven data splitting step.

    Batches pass through this step unchanged; the actual work happens in
    ``save_splits``, which writes the evaluation/featured split and the
    optional per-gender files selected by ``pipeline_config.data``.
    """

    def __init__(self, pipeline_config: PipelineConfig):
        """Register the step as ``"data_splitting"`` with a single-worker batch setup.

        Args:
            pipeline_config: Pipeline configuration; its ``processing``
                section supplies the batch size and checkpoint interval.
        """
        batch_config = BatchConfig(
            batch_size=pipeline_config.processing.batch_size,
            max_workers=1,  # No need for parallelism in splitting
            checkpoint_interval=pipeline_config.processing.checkpoint_interval,
            use_multiprocessing=False,
        )
        super().__init__("data_splitting", pipeline_config, batch_config)

        self.data_loader = DataLoader(pipeline_config)
        # Cached set of evaluation row positions; filled on first use.
        self.eval_indices = None

    def determine_eval_indices(self, total_size: int) -> set:
        """Determine evaluation indices consistently across batches.

        The indices are drawn once and cached on the instance, so later calls
        return the same set regardless of the ``total_size`` they pass.

        Args:
            total_size: Number of rows to draw positions from.

        Returns:
            A set of distinct row positions in ``range(total_size)`` holding
            ``int(total_size * evaluation_fraction)`` entries.
        """
        if self.eval_indices is None:
            data_cfg = self.pipeline_config.data
            # A private RandomState yields the same draw as seeding the global
            # generator with this seed, without reseeding NumPy's
            # process-wide RNG as a side effect.
            rng = np.random.RandomState(data_cfg.random_seed)
            eval_size = int(total_size * data_cfg.evaluation_fraction)
            chosen = rng.choice(total_size, size=eval_size, replace=False)
            self.eval_indices = set(chosen.tolist())
        return self.eval_indices

    def process_batch(self, batch: pd.DataFrame, batch_id: int) -> pd.DataFrame:
        """Process batch for data splitting - no modification needed.

        Args:
            batch: Batch of rows handed to this step.
            batch_id: Identifier of the batch (unused here).

        Returns:
            A copy of ``batch``.
        """
        return batch.copy()

    @staticmethod
    def _positional_mask(length: int, positions: set) -> np.ndarray:
        """Return a boolean array of ``length`` that is True at ``positions``.

        Positions outside ``range(length)`` are ignored, so a cached set drawn
        for a different frame size cannot raise here.
        """
        in_range = np.fromiter(
            (pos for pos in positions if 0 <= pos < length),
            dtype=np.intp,
        )
        mask = np.zeros(length, dtype=bool)
        mask[in_range] = True
        return mask

    def save_splits(self, df: pd.DataFrame) -> None:
        """Save the split datasets based on configuration.

        With ``split_evaluation`` enabled, the rows selected by
        ``determine_eval_indices`` go to the ``"evaluation"`` file and the
        rest to the ``"featured"`` file; otherwise the whole frame is written
        as ``"featured"``. With ``split_by_gender`` enabled, rows are also
        written to the ``"males"`` and ``"females"`` files by their ``"sex"``
        value.

        Args:
            df: The full dataset to split and persist.
        """
        data_cfg = self.pipeline_config.data
        output_files = data_cfg.output_files
        data_dir = self.pipeline_config.paths.data_dir

        if data_cfg.split_evaluation:
            eval_indices = self.determine_eval_indices(len(df))
            # The sampled values are row positions, not index labels, so the
            # mask is built positionally. Matching them against df.index
            # would pick the wrong rows whenever the index is not the default
            # 0..n-1 range (filtered, concatenated or non-integer index).
            eval_mask = self._positional_mask(len(df), eval_indices)
            self.data_loader.save_csv(df[eval_mask], data_dir / output_files["evaluation"])
            self.data_loader.save_csv(df[~eval_mask], data_dir / output_files["featured"])
        else:
            self.data_loader.save_csv(df, data_dir / output_files["featured"])

        if data_cfg.split_by_gender:
            # NOTE(review): the gender files are built from the full `df`, so
            # they include evaluation rows even when split_evaluation is on -
            # confirm this is intended.
            df_males = df[df["sex"] == Gender.MALE.value]
            df_females = df[df["sex"] == Gender.FEMALE.value]
            self.data_loader.save_csv(df_males, data_dir / output_files["males"])
            self.data_loader.save_csv(df_females, data_dir / output_files["females"])