Files
drc-ners-nlp/research/model_registry.py
T

45 lines
1.5 KiB
Python

from typing import List
from research.base_model import BaseModel
from research.experiment import ExperimentConfig
from research.models.bigru_model import BiGRUModel
from research.models.cnn_model import CNNModel
from research.models.ensemble_model import EnsembleModel
from research.models.lightgbm_model import LightGBMModel
from research.models.logistic_regression_model import LogisticRegressionModel
from research.models.lstm_model import LSTMModel
from research.models.naive_bayes_model import NaiveBayesModel
from research.models.random_forest_model import RandomForestModel
from research.models.svm_model import SVMModel
from research.models.transformer_model import TransformerModel
from research.models.xgboost_model import XGBoostModel
# Lookup table from a model-type string to the model class registered under
# that name. create_model() resolves config.model_type against it, and
# list_available_models() reports its keys.
MODEL_REGISTRY = {
    "bigru": BiGRUModel,
    "cnn": CNNModel,
    "ensemble": EnsembleModel,
    "lightgbm": LightGBMModel,
    "logistic_regression": LogisticRegressionModel,
    "lstm": LSTMModel,
    "naive_bayes": NaiveBayesModel,
    "random_forest": RandomForestModel,
    "svm": SVMModel,
    "transformer": TransformerModel,
    "xgboost": XGBoostModel,
}
def create_model(config: ExperimentConfig) -> BaseModel:
    """Build the model registered under ``config.model_type``.

    Args:
        config: Experiment configuration. Its ``model_type`` attribute is
            looked up in ``MODEL_REGISTRY`` and the config itself is passed
            to the selected class as its only constructor argument.

    Returns:
        A new instance of the registered model class.

    Raises:
        ValueError: If ``config.model_type`` has no entry in
            ``MODEL_REGISTRY``.
    """
    registered_cls = MODEL_REGISTRY.get(config.model_type)
    if registered_cls is not None:
        return registered_cls(config)

    # Reaching this point means the lookup found nothing for the given type.
    raise ValueError(f"Unknown model type: {config.model_type}")
def list_available_models() -> List[str]:
    """Return the model-type names currently present in ``MODEL_REGISTRY``.

    The result is a new list, so callers may modify it without affecting
    the registry.
    """
    # Iterating a dict yields its keys, so no explicit .keys() call is needed.
    return list(MODEL_REGISTRY)