45 lines
1.5 KiB
Python
45 lines
1.5 KiB
Python
from typing import List
|
|
|
|
from research.base_model import BaseModel
|
|
from research.experiment import ExperimentConfig
|
|
from research.models.bigru_model import BiGRUModel
|
|
from research.models.cnn_model import CNNModel
|
|
from research.models.ensemble_model import EnsembleModel
|
|
from research.models.lightgbm_model import LightGBMModel
|
|
from research.models.logistic_regression_model import LogisticRegressionModel
|
|
from research.models.lstm_model import LSTMModel
|
|
from research.models.naive_bayes_model import NaiveBayesModel
|
|
from research.models.random_forest_model import RandomForestModel
|
|
from research.models.svm_model import SVMModel
|
|
from research.models.transformer_model import TransformerModel
|
|
from research.models.xgboost_model import XGBoostModel
|
|
|
|
# Maps each supported model-type name (the value expected in
# ``config.model_type`` by ``create_model``) to the class that implements it.
# Entries are listed alphabetically; this insertion order is also the order in
# which ``list_available_models`` reports the names.
MODEL_REGISTRY = dict(
    bigru=BiGRUModel,
    cnn=CNNModel,
    ensemble=EnsembleModel,
    lightgbm=LightGBMModel,
    logistic_regression=LogisticRegressionModel,
    lstm=LSTMModel,
    naive_bayes=NaiveBayesModel,
    random_forest=RandomForestModel,
    svm=SVMModel,
    transformer=TransformerModel,
    xgboost=XGBoostModel,
)
|
|
|
|
|
|
def create_model(config: ExperimentConfig) -> BaseModel:
    """Instantiate the model class registered under ``config.model_type``.

    Args:
        config: Experiment configuration. Its ``model_type`` attribute is used
            as the lookup key into ``MODEL_REGISTRY``; the whole config object
            is then passed to the selected class's constructor.

    Returns:
        A new instance built as ``model_class(config)``.

    Raises:
        ValueError: If ``config.model_type`` is not a key of
            ``MODEL_REGISTRY``. The message lists the registered names so a
            misspelled model type in a config is easy to diagnose.
    """
    model_class = MODEL_REGISTRY.get(config.model_type)
    if model_class is None:
        # Keep the original "Unknown model type: ..." prefix so existing log
        # searches or message checks still match; append the valid choices.
        available = ", ".join(sorted(MODEL_REGISTRY))
        raise ValueError(
            f"Unknown model type: {config.model_type}. "
            f"Available model types: {available}"
        )
    return model_class(config)
|
|
|
|
|
|
def list_available_models() -> List[str]:
    """Return the names of every registered model type.

    The names are the keys of ``MODEL_REGISTRY`` in the registry's insertion
    order, returned as a new list, so mutating the result does not change the
    registry.
    """
    return [*MODEL_REGISTRY]
|