feat: Experiment Builder
This commit is contained in:
@@ -1,5 +1,9 @@
|
||||
from typing import List
|
||||
import logging
|
||||
from typing import List, Dict
|
||||
|
||||
import yaml
|
||||
|
||||
from core.config.pipeline_config import PipelineConfig
|
||||
from research.experiment import ExperimentConfig
|
||||
from research.experiment.feature_extractor import FeatureType
|
||||
|
||||
@@ -7,117 +11,98 @@ from research.experiment.feature_extractor import FeatureType
|
||||
class ExperimentBuilder:
|
||||
"""Helper class to build experiment configurations"""
|
||||
|
||||
@staticmethod
|
||||
def create_baseline_experiments() -> List[ExperimentConfig]:
|
||||
"""Create a set of baseline experiments for comparison"""
|
||||
def __init__(self, config: PipelineConfig):
|
||||
self.config = config
|
||||
|
||||
return [
|
||||
# Full name experiments
|
||||
ExperimentConfig(
|
||||
name="baseline_logistic_regression_fullname",
|
||||
description="Logistic regression with full name",
|
||||
model_type="logistic_regression",
|
||||
features=[FeatureType.FULL_NAME],
|
||||
tags=["baseline", "fullname"],
|
||||
),
|
||||
# Native name only
|
||||
ExperimentConfig(
|
||||
name="baseline_logistic_regression_native",
|
||||
description="Logistic regression with native name only",
|
||||
model_type="logistic_regression",
|
||||
features=[FeatureType.NATIVE_NAME],
|
||||
tags=["baseline", "native"],
|
||||
),
|
||||
# Surname only
|
||||
ExperimentConfig(
|
||||
name="baseline_logistic_regression_surname",
|
||||
description="Logistic regression with surname only",
|
||||
model_type="logistic_regression",
|
||||
features=[FeatureType.SURNAME],
|
||||
tags=["baseline", "surname"],
|
||||
),
|
||||
# Random Forest with engineered features
|
||||
ExperimentConfig(
|
||||
name="baseline_rf_engineered",
|
||||
description="Random Forest with engineered features",
|
||||
model_type="random_forest",
|
||||
features=[FeatureType.NAME_LENGTH, FeatureType.WORD_COUNT, FeatureType.PROVINCE],
|
||||
tags=["baseline", "engineered"],
|
||||
),
|
||||
]
|
||||
def load_templates(self, templates: str = "research_templates.yaml") -> dict:
|
||||
"""Load research templates from YAML file"""
|
||||
try:
|
||||
with open(self.config.paths.configs_dir / templates, "r") as file:
|
||||
return yaml.safe_load(file)
|
||||
except FileNotFoundError:
|
||||
logging.error(f"Templates file not found: {templates}")
|
||||
raise
|
||||
except yaml.YAMLError as e:
|
||||
logging.error(f"Error parsing templates file: {e}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def create_feature_ablation_study() -> List[ExperimentConfig]:
|
||||
"""Create experiments for feature ablation study"""
|
||||
base_features = [
|
||||
FeatureType.FULL_NAME,
|
||||
FeatureType.NAME_LENGTH,
|
||||
FeatureType.WORD_COUNT,
|
||||
FeatureType.PROVINCE,
|
||||
]
|
||||
@classmethod
|
||||
def find_template(cls, templates: dict, name: str, experiment_type: str = "baseline") -> dict:
|
||||
"""Find experiment configuration by name and type"""
|
||||
|
||||
experiments = []
|
||||
# Map type to section in templates
|
||||
type_mapping = {
|
||||
"baseline": "baseline_experiments",
|
||||
"advanced": "advanced_experiments",
|
||||
"feature_study": "feature_studies",
|
||||
"tuning": "hyperparameter_tuning",
|
||||
}
|
||||
|
||||
# Test removing each feature one by one
|
||||
for i, feature_to_remove in enumerate(base_features):
|
||||
remaining_features = [f for f in base_features if f != feature_to_remove]
|
||||
|
||||
experiments.append(
|
||||
ExperimentConfig(
|
||||
name=f"ablation_remove_{feature_to_remove.value}",
|
||||
description=f"Ablation study: removed {feature_to_remove.value}",
|
||||
model_type="logistic_regression",
|
||||
features=remaining_features,
|
||||
tags=["ablation", feature_to_remove.value],
|
||||
)
|
||||
section_name = type_mapping.get(experiment_type)
|
||||
if not section_name:
|
||||
available_types = list(type_mapping.keys())
|
||||
raise ValueError(
|
||||
f"Unknown experiment type '{experiment_type}'. Available types: {available_types}"
|
||||
)
|
||||
|
||||
return experiments
|
||||
if section_name not in templates:
|
||||
raise ValueError(f"Section '{section_name}' not found in templates")
|
||||
|
||||
@staticmethod
|
||||
def create_name_component_study() -> List[ExperimentConfig]:
|
||||
"""Create experiments to study different name components"""
|
||||
experiments = []
|
||||
experiments = templates[section_name]
|
||||
|
||||
name_components = [
|
||||
(FeatureType.FIRST_WORD, "first_word"),
|
||||
(FeatureType.LAST_WORD, "last_word"),
|
||||
(FeatureType.NATIVE_NAME, "native_name"),
|
||||
(FeatureType.SURNAME, "surname"),
|
||||
(FeatureType.NAME_BEGINNINGS, "name_beginnings"),
|
||||
(FeatureType.NAME_ENDINGS, "name_endings"),
|
||||
# Search for experiment by model name
|
||||
for experiment in experiments:
|
||||
# Check if this is the experiment we're looking for
|
||||
# Look for experiments that match the model type or contain the name
|
||||
if (
|
||||
experiment.get("model_type") == name
|
||||
or name.lower() in experiment.get("name", "").lower()
|
||||
or experiment.get("name") == name
|
||||
or f"baseline_{name}" == experiment.get("name")
|
||||
or f"advanced_{name}" == experiment.get("name")
|
||||
):
|
||||
return experiment
|
||||
|
||||
# If not found, list available experiments
|
||||
available_experiments = [
|
||||
exp.get("name", exp.get("model_type", "unknown")) for exp in experiments
|
||||
]
|
||||
raise ValueError(
|
||||
f"Experiment '{name}' not found in '{experiment_type}' section. "
|
||||
f"Available experiments: {available_experiments}"
|
||||
)
|
||||
|
||||
for feature, name in name_components:
|
||||
experiments.append(
|
||||
ExperimentConfig(
|
||||
name=f"component_study_{name}",
|
||||
description=f"Study of {name} for gender prediction",
|
||||
model_type="logistic_regression",
|
||||
features=[feature],
|
||||
tags=["component_study", name],
|
||||
)
|
||||
)
|
||||
def get_templates(self, templates_path: str = "research_templates.yaml") -> Dict[str, List[Dict]]:
|
||||
"""Get all available experiments from templates organized by type"""
|
||||
templates = self.load_templates(templates_path)
|
||||
|
||||
return experiments
|
||||
return {
|
||||
"baseline": templates.get("baseline_experiments", []),
|
||||
"advanced": templates.get("advanced_experiments", []),
|
||||
"feature_study": templates.get("feature_studies", []),
|
||||
"tuning": templates.get("hyperparameter_tuning", [])
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def create_province_specific_study() -> List[ExperimentConfig]:
|
||||
"""Create experiments for province-specific analysis"""
|
||||
provinces = ["kinshasa", "bas-congo", "bandundu", "katanga"] # Add more as needed
|
||||
@classmethod
|
||||
def from_template(cls, template_config: dict) -> ExperimentConfig:
|
||||
"""Create an ExperimentConfig from a template configuration"""
|
||||
# Convert feature strings to FeatureType objects
|
||||
features = []
|
||||
for feature_str in template_config.get("features", []):
|
||||
try:
|
||||
features.append(FeatureType(feature_str))
|
||||
except ValueError:
|
||||
logging.warning(f"Unknown feature type: {feature_str}")
|
||||
continue
|
||||
|
||||
experiments = []
|
||||
|
||||
for province in provinces:
|
||||
experiments.append(
|
||||
ExperimentConfig(
|
||||
name=f"province_study_{province}",
|
||||
description=f"Gender prediction for {province} province only",
|
||||
model_type="logistic_regression",
|
||||
features=[FeatureType.FULL_NAME],
|
||||
train_data_filter={"province": province},
|
||||
tags=["province_study", province],
|
||||
)
|
||||
)
|
||||
|
||||
return experiments
|
||||
return ExperimentConfig(
|
||||
name=template_config.get("name"),
|
||||
description=template_config.get("description"),
|
||||
model_type=template_config.get("model_type"),
|
||||
features=features,
|
||||
model_params=template_config.get("model_params", {}),
|
||||
tags=template_config.get("tags", []),
|
||||
test_size=template_config.get("test_size", 0.2),
|
||||
cross_validation_folds=template_config.get("cross_validation_folds", 5),
|
||||
train_data_filter=template_config.get("train_data_filter")
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user