feat: implement unified configuration loading and logging setup across entry points
This commit is contained in:
@@ -21,6 +21,41 @@ def load_config(config_path: Optional[Union[str, Path]] = None) -> PipelineConfi
|
||||
return config_manager.get_config()
|
||||
|
||||
|
||||
def setup_config_and_logging(
    config_path: Optional[Path] = None,
    env: str = "development"
) -> PipelineConfig:
    """Load the pipeline configuration and set up logging in one call.

    Shared bootstrap for entry-point scripts. It resolves which config file
    to read, loads it through ``ConfigManager``, configures logging from the
    loaded config, calls ``ensure_directories`` and logs a short summary.

    Args:
        config_path: Explicit path to a config file. When provided, ``env``
            is not used to build the path.
        env: Environment name; only used to derive the default path
            ``config/pipeline.<env>.yaml`` when ``config_path`` is None.

    Returns:
        The configuration object produced by ``ConfigManager.load_config``.
    """
    # An explicit path wins; otherwise fall back to the per-environment file.
    resolved_path = (
        config_path
        if config_path is not None
        else Path("config") / f"pipeline.{env}.yaml"
    )

    manager = ConfigManager(resolved_path)
    loaded = manager.load_config()

    # Logging is configured before anything below emits log records.
    setup_logging(loaded)

    # Function-scope import kept exactly where the original had it
    # (presumably to avoid an import cycle with core.utils - TODO confirm).
    from core.utils import ensure_directories

    ensure_directories(loaded)

    logging.info(f"Loaded configuration: {loaded.name} v{loaded.version}")
    logging.info(f"Environment: {loaded.environment}")
    logging.info(f"Config file: {resolved_path}")

    return loaded
|
||||
|
||||
|
||||
def setup_logging(config: PipelineConfig):
|
||||
"""Setup logging based on configuration"""
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import field
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -20,3 +20,7 @@ class DataConfig(BaseModel):
|
||||
split_by_gender: bool = True
|
||||
evaluation_fraction: float = 0.2
|
||||
random_seed: int = 42
|
||||
|
||||
# Dataset size limiting options
|
||||
max_dataset_size: Optional[int] = None
|
||||
balance_by_sex: bool = False
|
||||
|
||||
@@ -44,9 +44,71 @@ class DataLoader:
|
||||
raise ValueError(f"Unable to decode {filepath} with any encoding: {encodings}")
|
||||
|
||||
def load_csv_complete(self, filepath: Union[str, Path]) -> pd.DataFrame:
    """Load complete CSV file into memory with size limiting and balancing.

    Every chunk produced by ``self.load_csv_chunked(filepath)`` is collected
    and concatenated into a single frame with a fresh 0..n-1 index. If
    ``self.config.data.max_dataset_size`` is set, the frame is then passed
    through ``self._limit_dataset_size``.

    Args:
        filepath: Location of the CSV file, forwarded unchanged to
            ``load_csv_chunked``.

    Returns:
        The concatenated (and possibly size-limited) DataFrame, or an empty
        DataFrame when the chunked loader yields nothing.
    """
    chunks = list(self.load_csv_chunked(filepath))

    # pd.concat raises on an empty list, so handle "no chunks" explicitly.
    if not chunks:
        return pd.DataFrame()

    df = pd.concat(chunks, ignore_index=True)

    # Apply dataset size limiting only when a limit is configured.
    if self.config.data.max_dataset_size is not None:
        df = self._limit_dataset_size(df)

    return df
|
||||
|
||||
def _limit_dataset_size(self, df: pd.DataFrame) -> pd.DataFrame:
    """Cap the frame at the configured maximum number of rows.

    Reads ``max_dataset_size``, ``balance_by_sex`` and ``random_seed`` from
    ``self.config.data``.

    Args:
        df: Frame to limit.

    Returns:
        ``df`` itself when no limit is configured or it already fits.
        Otherwise the result of ``self._balanced_sample`` (when balancing is
        enabled and a ``"sex"`` column exists) or a seeded random sample of
        exactly ``max_dataset_size`` rows.
    """
    data_cfg = self.config.data
    limit = data_cfg.max_dataset_size

    # Nothing to do: no limit configured, or the frame is small enough.
    fits_already = limit is None or len(df) <= limit
    if fits_already:
        return df

    wants_balancing = data_cfg.balance_by_sex and "sex" in df.columns
    if not wants_balancing:
        # Simple random sampling, reproducible through the configured seed.
        return df.sample(n=limit, random_state=data_cfg.random_seed)

    return self._balanced_sample(df, limit)
|
||||
|
||||
def _balanced_sample(self, df: pd.DataFrame, max_size: int) -> pd.DataFrame:
    """Draw up to ``max_size`` rows split evenly across ``"sex"`` values.

    Each distinct non-null value of ``df["sex"]`` gets an equal quota; the
    remainder of the integer division goes one row each to the categories
    encountered first. A category with fewer rows than its quota contributes
    all of its rows, so the result can be smaller than ``max_size``. Rows
    whose ``"sex"`` is null are never selected on the balanced path.

    Args:
        df: Source frame; must contain a ``"sex"`` column.
        max_size: Target number of rows.

    Returns:
        A shuffled frame with a fresh 0..n-1 index. If the column has no
        non-null values, or no per-category sample could be drawn, a plain
        seeded random sample of ``max_size`` rows from ``df`` is returned
        instead.
    """
    seed = self.config.data.random_seed

    # Distinct categories, in order of first appearance.
    categories = df["sex"].dropna().unique()
    category_count = len(categories)

    if not category_count:
        logging.warning("No valid values found in sex column 'sex', using random sampling")
        return df.sample(n=max_size, random_state=seed)

    # Even share per category plus the leftover rows to hand out one by one.
    base_quota, leftover = divmod(max_size, category_count)

    picks = []
    for position, label in enumerate(categories):
        subset = df[df["sex"] == label]

        # The first `leftover` categories each take one extra row; never ask
        # for more rows than the category actually has.
        quota = base_quota + 1 if position < leftover else base_quota
        quota = min(quota, len(subset))
        if quota <= 0:
            continue

        # Offset the seed per category so the draws are not identical.
        picks.append(subset.sample(n=quota, random_state=seed + position))
        logging.info(f"Sampled {quota} records for sex '{label}'")

    if not picks:
        logging.warning("No balanced samples could be created, using random sampling")
        return df.sample(n=max_size, random_state=seed)

    combined = pd.concat(picks, ignore_index=True)

    # Shuffle so that rows are not grouped by category in the output.
    shuffled = combined.sample(frac=1, random_state=seed).reset_index(drop=True)

    logging.info(f"Created balanced dataset with {len(shuffled)} records from {len(df)} total records")
    return shuffled
|
||||
|
||||
@classmethod
|
||||
def save_csv(
|
||||
|
||||
Reference in New Issue
Block a user