refactor: reorganize project structure and enhance model verbosity
This commit is contained in:
+33
-32
@@ -11,6 +11,7 @@ from core.utils.data_loader import DataLoader
|
||||
from research.experiment import FeatureType, ExperimentConfig
|
||||
from research.experiment.experiment_runner import ExperimentRunner
|
||||
from research.experiment.experiment_tracker import ExperimentTracker
|
||||
from research.model_registry import MODEL_REGISTRY
|
||||
|
||||
|
||||
class ModelTrainer:
|
||||
@@ -21,25 +22,24 @@ class ModelTrainer:
|
||||
self.data_loader = DataLoader(self.config)
|
||||
self.experiment_runner = ExperimentRunner(self.config)
|
||||
self.experiment_tracker = ExperimentTracker(self.config)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Setup model artifacts directory
|
||||
self.models_dir = self.config.paths.models_dir
|
||||
self.models_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def train_single_model(
|
||||
self,
|
||||
model_name: str,
|
||||
model_type: str = "logistic_regression",
|
||||
features: List[str] = None,
|
||||
model_params: Dict[str, Any] = None,
|
||||
save_artifacts: bool = True,
|
||||
self,
|
||||
model_name: str,
|
||||
model_type: str = "logistic_regression",
|
||||
features: List[str] = None,
|
||||
model_params: Dict[str, Any] = None,
|
||||
save_artifacts: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
Train a single model and save its artifacts.
|
||||
Returns the experiment ID.
|
||||
"""
|
||||
self.logger.info(f"Training {model_type} model: {model_name}")
|
||||
logging.info(f"Training {model_type} model: {model_name}")
|
||||
|
||||
if features is None:
|
||||
features = ["full_name"]
|
||||
@@ -60,10 +60,10 @@ class ModelTrainer:
|
||||
experiment = self.experiment_tracker.get_experiment(experiment_id)
|
||||
|
||||
if experiment and experiment.test_metrics:
|
||||
self.logger.info("Training completed successfully!")
|
||||
self.logger.info(f" Experiment ID: {experiment_id}")
|
||||
self.logger.info(f" Test Accuracy: {experiment.test_metrics.get('accuracy', 0):.4f}")
|
||||
self.logger.info(f" Test F1-Score: {experiment.test_metrics.get('f1', 0):.4f}")
|
||||
logging.info("Training completed successfully!")
|
||||
logging.info(f"Experiment ID: {experiment_id}")
|
||||
logging.info(f"Test Accuracy: {experiment.test_metrics.get('accuracy', 0):.4f}")
|
||||
logging.info(f"Test F1-Score: {experiment.test_metrics.get('f1', 0):.4f}")
|
||||
|
||||
if save_artifacts:
|
||||
self.save_model_artifacts(experiment_id)
|
||||
@@ -71,12 +71,15 @@ class ModelTrainer:
|
||||
return experiment_id
|
||||
|
||||
def train_multiple_models(
|
||||
self, base_name: str, model_configs: List[Dict[str, Any]], save_all: bool = True
|
||||
self,
|
||||
base_name: str,
|
||||
model_configs: List[Dict[str, Any]],
|
||||
save_all: bool = True
|
||||
) -> List[str]:
|
||||
"""
|
||||
Train multiple models with different configurations.
|
||||
"""
|
||||
self.logger.info(f"Training {len(model_configs)} models...")
|
||||
logging.info(f"Training {len(model_configs)} models...")
|
||||
|
||||
experiment_ids = []
|
||||
|
||||
@@ -94,10 +97,10 @@ class ModelTrainer:
|
||||
experiment_ids.append(exp_id)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to train {model_name}: {e}")
|
||||
logging.error(f"Failed to train {model_name}: {e}")
|
||||
continue
|
||||
|
||||
self.logger.info(f"Completed training {len(experiment_ids)} models successfully")
|
||||
logging.info(f"Completed training {len(experiment_ids)} models successfully")
|
||||
return experiment_ids
|
||||
|
||||
def save_model_artifacts(self, experiment_id: str) -> Dict[str, str]:
|
||||
@@ -145,7 +148,7 @@ class ModelTrainer:
|
||||
df = self.data_loader.load_csv_complete(data_path)
|
||||
|
||||
# Generate learning curve
|
||||
self.logger.info("Generating learning curve...")
|
||||
logging.info("Generating learning curve...")
|
||||
trained_model.generate_learning_curve(df, df[experiment.config.target_column])
|
||||
|
||||
# Plot and save learning curve
|
||||
@@ -169,7 +172,7 @@ class ModelTrainer:
|
||||
json.dump(trained_model.training_history, f, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Could not generate learning curves: {e}")
|
||||
logging.warning(f"Could not generate learning curves: {e}")
|
||||
|
||||
# Save artifacts metadata
|
||||
metadata = {
|
||||
@@ -193,17 +196,17 @@ class ModelTrainer:
|
||||
with open(metadata_path, "w") as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
|
||||
self.logger.info(f"Model artifacts saved to: {model_dir}")
|
||||
self.logger.info(f" - Complete model: {model_path.name}")
|
||||
self.logger.info(f" - Configuration: {config_path.name}")
|
||||
self.logger.info(f" - Results: {results_path.name}")
|
||||
self.logger.info(f" - Metadata: {metadata_path.name}")
|
||||
logging.info(f"Model artifacts saved to: {model_dir}")
|
||||
logging.info(f" - Complete model: {model_path.name}")
|
||||
logging.info(f" - Configuration: {config_path.name}")
|
||||
logging.info(f" - Results: {results_path.name}")
|
||||
logging.info(f" - Metadata: {metadata_path.name}")
|
||||
|
||||
if learning_curve_path and learning_curve_path.exists():
|
||||
self.logger.info(f" - Learning curve: {learning_curve_path.name}")
|
||||
logging.info(f" - Learning curve: {learning_curve_path.name}")
|
||||
|
||||
if training_history_path and training_history_path.exists():
|
||||
self.logger.info(f" - Training history: {training_history_path.name}")
|
||||
logging.info(f" - Training history: {training_history_path.name}")
|
||||
|
||||
return {
|
||||
"model_dir": str(model_dir),
|
||||
@@ -231,16 +234,14 @@ class ModelTrainer:
|
||||
metadata = json.load(f)
|
||||
|
||||
model_type = metadata["model_type"]
|
||||
from research.model_registry import MODEL_REGISTRY
|
||||
|
||||
model_class = MODEL_REGISTRY[model_type]
|
||||
|
||||
# Load the complete model
|
||||
loaded_model = model_class.load(str(model_path))
|
||||
|
||||
self.logger.info(f"Loaded model: {metadata['model_name']}")
|
||||
self.logger.info(f" Type: {model_type}")
|
||||
self.logger.info(f" Accuracy: {metadata['test_accuracy']:.4f}")
|
||||
logging.info(f"Loaded model: {metadata['model_name']}")
|
||||
logging.info(f" Type: {model_type}")
|
||||
logging.info(f" Accuracy: {metadata['test_accuracy']:.4f}")
|
||||
|
||||
return loaded_model
|
||||
|
||||
@@ -259,10 +260,10 @@ class ModelTrainer:
|
||||
metadata = json.load(f)
|
||||
models_data.append(metadata)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Could not read metadata for {model_dir.name}: {e}")
|
||||
logging.warning(f"Could not read metadata for {model_dir.name}: {e}")
|
||||
|
||||
if not models_data:
|
||||
self.logger.info("No saved models found.")
|
||||
logging.info("No saved models found.")
|
||||
return pd.DataFrame()
|
||||
|
||||
df = pd.DataFrame(models_data)
|
||||
|
||||
Reference in New Issue
Block a user