92 lines
2.7 KiB
Python
92 lines
2.7 KiB
Python
from dataclasses import dataclass, field, asdict
|
|
from enum import Enum
|
|
from typing import List, Dict, Any, Optional
|
|
|
|
import numpy as np
|
|
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
|
|
|
|
from .feature_extractor import FeatureType
|
|
|
|
|
|
@dataclass
class ExperimentConfig:
    """Configuration for a single experiment.

    Groups everything needed to describe one run: identifying metadata,
    the model and its parameters, the features to use, data filters,
    train/test split settings and the metric names to report.

    Use ``to_dict`` / ``from_dict`` to round-trip through a plain
    dictionary; in that form ``features`` holds each member's ``.value``
    instead of the ``FeatureType`` member itself.
    """

    # Experiment metadata
    name: str
    description: str = ""
    tags: List[str] = field(default_factory=list)

    # Model configuration
    model_type: str = "logistic_regression"  # logistic_regression, lstm, transformer, etc.
    model_params: Dict[str, Any] = field(default_factory=dict)

    # Feature configuration
    features: List[FeatureType] = field(default_factory=lambda: [FeatureType.FULL_NAME])
    feature_params: Dict[str, Any] = field(default_factory=dict)

    # Data configuration
    train_data_filter: Optional[Dict[str, Any]] = None  # Filter criteria for training data
    test_data_filter: Optional[Dict[str, Any]] = None
    target_column: str = "sex"

    # Training configuration
    test_size: float = 0.2
    random_seed: int = 42
    cross_validation_folds: int = 5

    # Evaluation configuration
    metrics: List[str] = field(default_factory=lambda: ["accuracy", "precision", "recall", "f1"])

    def to_dict(self) -> Dict[str, Any]:
        """Convert the configuration to a dictionary for serialization.

        Returns:
            A dictionary produced by ``dataclasses.asdict`` in which the
            ``features`` entry is replaced by the list of each feature's
            ``.value``.
        """
        result = asdict(self)
        # Convert enums to their underlying values so the result does not
        # carry FeatureType members.
        result["features"] = [f.value for f in self.features]
        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ExperimentConfig":
        """Create a configuration from a dictionary.

        The input mapping is not modified: the ``features`` conversion is
        done on a shallow copy, so the same ``data`` can be reused (for
        example passed to ``from_dict`` again or written back to disk).

        Args:
            data: Field names mapped to values. If ``features`` is present,
                each entry is passed through ``FeatureType(...)``.

        Returns:
            A new ``ExperimentConfig`` built from ``data``.

        Raises:
            TypeError: If ``data`` contains keys that are not fields of this
                dataclass, or omits the required ``name`` field (raised by
                the generated constructor).
            ValueError: Presumably raised by ``FeatureType(...)`` for an
                unrecognised feature value, assuming ``FeatureType`` is a
                standard ``Enum`` -- confirm against ``feature_extractor``.
        """
        # Shallow copy: only the top-level "features" key is rebound, so the
        # caller's dictionary must not be written to.
        kwargs = dict(data)
        if "features" in kwargs:
            kwargs["features"] = [FeatureType(f) for f in kwargs["features"]]
        return cls(**kwargs)
|
|
|
|
|
|
class ExperimentStatus(Enum):
    """Execution status labels for an experiment.

    Every member's value is the lowercase string form of its name, so a
    member can be recovered from that string with
    ``ExperimentStatus("<value>")``.
    """

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
|
|
|
|
|
def calculate_metrics(
    y_true: np.ndarray, y_pred: np.ndarray, metrics: List[str] = None
) -> Dict[str, float]:
    """Compute the requested evaluation metrics for a set of predictions.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        metrics: Names of the metrics to compute. Recognised names are
            ``"accuracy"``, ``"precision"``, ``"recall"`` and ``"f1"``;
            any other name is ignored. ``None`` selects all four.

    Returns:
        A dictionary mapping each recognised, requested metric name to its
        score. Keys are inserted in the fixed order accuracy, precision,
        recall, f1 regardless of the order in ``metrics``. Precision,
        recall and f1 use the ``"weighted"`` average.
    """
    requested = ["accuracy", "precision", "recall", "f1"] if metrics is None else metrics

    scores = {}

    if "accuracy" in requested:
        scores["accuracy"] = accuracy_score(y_true, y_pred)

    # Precision, recall and f1 come from one scikit-learn call, so only make
    # that call when at least one of them was asked for.
    prf_names = ("precision", "recall", "f1")
    wanted = [metric_name for metric_name in prf_names if metric_name in requested]
    if not wanted:
        return scores

    prf_values = precision_recall_fscore_support(y_true, y_pred, average="weighted")
    # zip stops after the three names, dropping the trailing support entry.
    score_by_name = dict(zip(prf_names, prf_values))
    for metric_name in wanted:
        scores[metric_name] = score_by_name[metric_name]

    return scores
|