refactoring: add initial pipeline configuration and model classes
This commit is contained in:
@@ -0,0 +1,60 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from core.config.pipeline_config import PipelineConfig
|
||||
from processing.steps.feature_extraction_step import Gender
|
||||
from core.utils.data_loader import DataLoader
|
||||
|
||||
from processing.batch.batch_config import BatchConfig
|
||||
from processing.steps import PipelineStep
|
||||
|
||||
|
||||
class DataSplittingStep(PipelineStep):
|
||||
"""Configuration-driven data splitting step"""
|
||||
|
||||
def __init__(self, pipeline_config: PipelineConfig):
|
||||
batch_config = BatchConfig(
|
||||
batch_size=pipeline_config.processing.batch_size,
|
||||
max_workers=1, # No need for parallelism in splitting
|
||||
checkpoint_interval=pipeline_config.processing.checkpoint_interval,
|
||||
use_multiprocessing=False,
|
||||
)
|
||||
super().__init__("data_splitting", pipeline_config, batch_config)
|
||||
self.data_loader = DataLoader(pipeline_config)
|
||||
self.eval_indices = None
|
||||
|
||||
def determine_eval_indices(self, total_size: int) -> set:
|
||||
"""Determine evaluation indices consistently across batches"""
|
||||
if self.eval_indices is None:
|
||||
np.random.seed(self.pipeline_config.data.random_seed)
|
||||
eval_size = int(total_size * self.pipeline_config.data.evaluation_fraction)
|
||||
self.eval_indices = set(np.random.choice(total_size, size=eval_size, replace=False))
|
||||
return self.eval_indices
|
||||
|
||||
def process_batch(self, batch: pd.DataFrame, batch_id: int) -> pd.DataFrame:
|
||||
"""Process batch for data splitting - no modification needed"""
|
||||
return batch.copy()
|
||||
|
||||
def save_splits(self, df: pd.DataFrame) -> None:
|
||||
"""Save the split datasets based on configuration"""
|
||||
output_files = self.pipeline_config.data.output_files
|
||||
data_dir = self.pipeline_config.paths.data_dir
|
||||
|
||||
if self.pipeline_config.data.split_evaluation:
|
||||
eval_indices = self.determine_eval_indices(len(df))
|
||||
eval_mask = df.index.isin(eval_indices)
|
||||
|
||||
df_evaluation = df[eval_mask]
|
||||
df_featured = df[~eval_mask]
|
||||
|
||||
self.data_loader.save_csv(df_evaluation, data_dir / output_files["evaluation"])
|
||||
self.data_loader.save_csv(df_featured, data_dir / output_files["featured"])
|
||||
else:
|
||||
self.data_loader.save_csv(df, data_dir / output_files["featured"])
|
||||
|
||||
if self.pipeline_config.data.split_by_gender:
|
||||
df_males = df[df["sex"] == Gender.MALE.value]
|
||||
df_females = df[df["sex"] == Gender.FEMALE.value]
|
||||
|
||||
self.data_loader.save_csv(df_males, data_dir / output_files["males"])
|
||||
self.data_loader.save_csv(df_females, data_dir / output_files["females"])
|
||||
Reference in New Issue
Block a user