Files
drc-ners-nlp/research/experiment/__init__.py
T

92 lines
2.7 KiB
Python

from dataclasses import dataclass, field, asdict
from enum import Enum
from typing import List, Dict, Any, Optional
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from .feature_extractor import FeatureType
@dataclass
class ExperimentConfig:
    """Configuration for a single experiment.

    Groups the experiment's metadata, model, feature, data, training and
    evaluation settings in one place. Use :meth:`to_dict` and
    :meth:`from_dict` to convert to and from a plain dictionary in which
    the ``features`` entries are stored as their enum values.
    """

    # Experiment metadata
    name: str
    description: str = ""
    tags: List[str] = field(default_factory=list)

    # Model configuration
    # e.g. logistic_regression, lstm, transformer, etc.
    model_type: str = "logistic_regression"
    model_params: Dict[str, Any] = field(default_factory=dict)

    # Feature configuration
    features: List[FeatureType] = field(
        default_factory=lambda: [FeatureType.FULL_NAME]
    )
    feature_params: Dict[str, Any] = field(default_factory=dict)

    # Data configuration
    # Filter criteria for training data
    train_data_filter: Optional[Dict[str, Any]] = None
    # NOTE(review): presumably the test-data counterpart of the filter
    # above -- confirm against the code that consumes it.
    test_data_filter: Optional[Dict[str, Any]] = None
    target_column: str = "sex"

    # Training configuration
    test_size: float = 0.2
    random_seed: int = 42
    cross_validation_folds: int = 5

    # Evaluation configuration
    metrics: List[str] = field(
        default_factory=lambda: ["accuracy", "precision", "recall", "f1"]
    )

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary for serialization.

        Returns:
            The dataclass fields as produced by ``dataclasses.asdict``,
            with ``features`` replaced by the list of each member's
            ``.value``.
        """
        result = asdict(self)
        # Convert enums to their values
        result["features"] = [f.value for f in self.features]
        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ExperimentConfig":
        """Create a configuration from a dictionary.

        Args:
            data: Mapping of field names to values. If a ``features`` key
                is present, each entry is converted with ``FeatureType(...)``.
                ``data`` itself is left unmodified.

        Returns:
            A new ``ExperimentConfig`` built from ``data``.

        Raises:
            ValueError: If a ``features`` entry is not a valid
                ``FeatureType`` value.
            TypeError: If ``data`` contains keys that are not fields of
                this class, or omits the required ``name``.
        """
        # Work on a shallow copy: writing the converted enum list back into
        # ``data`` would silently change the caller's dictionary.
        params = dict(data)
        if "features" in params:
            params["features"] = [FeatureType(f) for f in params["features"]]
        return cls(**params)
class ExperimentStatus(Enum):
    """Execution status of an experiment.

    Every member's value is the lowercase string form of its name, so a
    status can be looked up from its string with ``ExperimentStatus(value)``.
    """

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
def calculate_metrics(
    y_true: np.ndarray, y_pred: np.ndarray, metrics: List[str] = None
) -> Dict[str, float]:
    """Compute the requested evaluation metrics for a set of predictions.

    Args:
        y_true: Ground-truth labels, passed straight to scikit-learn.
        y_pred: Predicted labels, passed straight to scikit-learn.
        metrics: Names of the metrics to compute. Recognised names are
            ``"accuracy"``, ``"precision"``, ``"recall"`` and ``"f1"``;
            any other name is ignored. ``None`` selects all four.

    Returns:
        A dictionary mapping each requested, recognised metric name to its
        score. Precision, recall and F1 use ``average="weighted"``.
    """
    requested = (
        ["accuracy", "precision", "recall", "f1"] if metrics is None else metrics
    )

    scores = {}
    if "accuracy" in requested:
        scores["accuracy"] = accuracy_score(y_true, y_pred)

    # Precision, recall and F1 all come from one scikit-learn call, so skip
    # that call entirely when none of them was asked for.
    prf_names = ("precision", "recall", "f1")
    if not any(metric_name in requested for metric_name in prf_names):
        return scores

    prf_values = precision_recall_fscore_support(
        y_true, y_pred, average="weighted"
    )
    # The fourth returned element (support) is dropped: zip stops after the
    # three names.
    for metric_name, value in zip(prf_names, prf_values):
        if metric_name in requested:
            scores[metric_name] = value
    return scores