refactoring: add initial pipeline configuration and model classes
This commit is contained in:
@@ -0,0 +1,44 @@
|
||||
from typing import List
|
||||
|
||||
from research.base_model import BaseModel
|
||||
from research.experiment import ExperimentConfig
|
||||
from research.models.bigru_model import BiGRUModel
|
||||
from research.models.cnn_model import CNNModel
|
||||
from research.models.ensemble_model import EnsembleModel
|
||||
from research.models.lightgbm_model import LightGBMModel
|
||||
from research.models.logistic_regression_model import LogisticRegressionModel
|
||||
from research.models.lstm_model import LSTMModel
|
||||
from research.models.naive_bayes_model import NaiveBayesModel
|
||||
from research.models.random_forest_model import RandomForestModel
|
||||
from research.models.svm_model import SVMModel
|
||||
from research.models.transformer_model import TransformerModel
|
||||
from research.models.xgboost_model import XGBoostModel
|
||||
|
||||
MODEL_REGISTRY = {
|
||||
"bigru": BiGRUModel,
|
||||
"cnn": CNNModel,
|
||||
"ensemble": EnsembleModel,
|
||||
"lightgbm": LightGBMModel,
|
||||
"logistic_regression": LogisticRegressionModel,
|
||||
"lstm": LSTMModel,
|
||||
"naive_bayes": NaiveBayesModel,
|
||||
"random_forest": RandomForestModel,
|
||||
"svm": SVMModel,
|
||||
"transformer": TransformerModel,
|
||||
"xgboost": XGBoostModel,
|
||||
}
|
||||
|
||||
|
||||
def create_model(config: ExperimentConfig) -> BaseModel:
|
||||
"""Factory function to create models"""
|
||||
model_class = MODEL_REGISTRY.get(config.model_type)
|
||||
|
||||
if model_class is None:
|
||||
raise ValueError(f"Unknown model type: {config.model_type}")
|
||||
|
||||
return model_class(config)
|
||||
|
||||
|
||||
def list_available_models() -> List[str]:
|
||||
"""List all available model types"""
|
||||
return list(MODEL_REGISTRY.keys())
|
||||
Reference in New Issue
Block a user