61 lines
2.6 KiB
Python
61 lines
2.6 KiB
Python
import numpy as np
|
|
import pandas as pd
|
|
|
|
from core.config.pipeline_config import PipelineConfig
|
|
from processing.steps.feature_extraction_step import Gender
|
|
from core.utils.data_loader import DataLoader
|
|
|
|
from processing.batch.batch_config import BatchConfig
|
|
from processing.steps import PipelineStep
|
|
|
|
|
|
class DataSplittingStep(PipelineStep):
    """Configuration-driven data splitting step.

    Batches pass through unchanged (see ``process_batch``). ``save_splits``
    writes the dataset to CSV files under ``pipeline_config.paths.data_dir``:
    an optional evaluation/featured split and an optional split on the
    ``"sex"`` column.
    """

    def __init__(self, pipeline_config: PipelineConfig):
        """Create the step with a single-worker, non-multiprocessing batch config.

        Args:
            pipeline_config: Pipeline configuration. ``processing.batch_size``
                and ``processing.checkpoint_interval`` are forwarded to the
                ``BatchConfig``. The whole object is passed to the base class
                (which presumably exposes it as ``self.pipeline_config``, as
                the other methods read it) and to ``DataLoader``.
        """
        batch_config = BatchConfig(
            batch_size=pipeline_config.processing.batch_size,
            max_workers=1,  # No need for parallelism in splitting
            checkpoint_interval=pipeline_config.processing.checkpoint_interval,
            use_multiprocessing=False,
        )
        super().__init__("data_splitting", pipeline_config, batch_config)
        self.data_loader = DataLoader(pipeline_config)
        # Cached positional row indices of the evaluation split (None until drawn).
        self.eval_indices = None
        # Dataset size the cached indices were drawn for; used to detect staleness.
        self._eval_total_size = None

    def determine_eval_indices(self, total_size: int) -> set:
        """Return the positional row indices assigned to the evaluation split.

        Draws ``int(total_size * data.evaluation_fraction)`` distinct positions
        from ``range(total_size)`` without replacement, seeded by
        ``data.random_seed``. The result is cached, so repeated calls with the
        same ``total_size`` return the same set.

        Args:
            total_size: Number of rows in the dataset being split.

        Returns:
            Set of row positions in ``[0, total_size)``.
        """
        cached_size = getattr(self, "_eval_total_size", None)
        # Indices drawn for a different dataset size would select the wrong
        # fraction of rows (or point past the end), so redraw in that case.
        stale = cached_size is not None and cached_size != total_size
        if self.eval_indices is None or stale:
            # A private RandomState yields exactly the same draw as seeding the
            # global legacy generator, without clobbering global NumPy RNG state.
            rng = np.random.RandomState(self.pipeline_config.data.random_seed)
            eval_size = int(total_size * self.pipeline_config.data.evaluation_fraction)
            self.eval_indices = set(rng.choice(total_size, size=eval_size, replace=False))
            self._eval_total_size = total_size
        return self.eval_indices

    def process_batch(self, batch: pd.DataFrame, batch_id: int) -> pd.DataFrame:
        """Return a copy of ``batch``; splitting does not modify rows.

        Args:
            batch: Rows of the current batch.
            batch_id: Batch number (unused).

        Returns:
            ``batch.copy()``.
        """
        return batch.copy()

    def save_splits(self, df: pd.DataFrame) -> None:
        """Write the configured dataset splits to CSV via ``DataLoader.save_csv``.

        Output paths are ``data_dir / output_files[key]``.

        * If ``data.split_evaluation`` is set, the rows at the positions from
          ``determine_eval_indices(len(df))`` go to ``"evaluation"`` and the
          remaining rows go to ``"featured"``. Otherwise all of ``df`` goes to
          ``"featured"``.
        * If ``data.split_by_gender`` is set, rows whose ``"sex"`` equals
          ``Gender.MALE.value`` / ``Gender.FEMALE.value`` go to ``"males"`` /
          ``"females"``. This split is taken from the full ``df``, independent
          of the evaluation split.

        Args:
            df: The complete dataset to split.
        """
        output_files = self.pipeline_config.data.output_files
        data_dir = self.pipeline_config.paths.data_dir

        if self.pipeline_config.data.split_evaluation:
            eval_indices = self.determine_eval_indices(len(df))
            # The indices are row *positions* drawn from range(len(df)), so match
            # them positionally. df.index.isin would compare them to index
            # labels and mis-split any frame without a default 0..n-1 index.
            eval_mask = np.isin(np.arange(len(df)), list(eval_indices))

            df_evaluation = df[eval_mask]
            df_featured = df[~eval_mask]

            self.data_loader.save_csv(df_evaluation, data_dir / output_files["evaluation"])
            self.data_loader.save_csv(df_featured, data_dir / output_files["featured"])
        else:
            self.data_loader.save_csv(df, data_dir / output_files["featured"])

        if self.pipeline_config.data.split_by_gender:
            df_males = df[df["sex"] == Gender.MALE.value]
            df_females = df[df["sex"] == Gender.FEMALE.value]

            self.data_loader.save_csv(df_males, data_dir / output_files["males"])
            self.data_loader.save_csv(df_females, data_dir / output_files["females"])
|