feat: support gpu
This commit is contained in:
@@ -48,7 +48,7 @@ class BiGRUModel(NeuralNetworkModel):
|
||||
Dense(64, activation="relu"),
|
||||
Dropout(params.get("dropout", 0.5)),
|
||||
# Two-way softmax for binary gender classification.
|
||||
Dense(2, activation="softmax"),
|
||||
Dense(2, activation="softmax", dtype="float32"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ class CNNModel(NeuralNetworkModel):
|
||||
Dense(64, activation="relu"),
|
||||
Dropout(params.get("dropout", 0.5)),
|
||||
# Two-way softmax for binary classification.
|
||||
Dense(2, activation="softmax"),
|
||||
Dense(2, activation="softmax", dtype="float32"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -20,6 +20,12 @@ class LightGBMModel(TraditionalModel):
|
||||
def build_model(self) -> BaseEstimator:
|
||||
params = self.config.model_params
|
||||
|
||||
# Optional GPU acceleration
|
||||
use_gpu = bool(params.get("use_gpu", False))
|
||||
device = params.get("device", "gpu" if use_gpu else "cpu")
|
||||
gpu_platform_id = params.get("gpu_platform_id", None)
|
||||
gpu_device_id = params.get("gpu_device_id", None)
|
||||
|
||||
# Leaf-wise boosted trees excel on sparse/categorical mixes; binary objective
|
||||
# and parallelism improve training speed for this task.
|
||||
return lgb.LGBMClassifier(
|
||||
@@ -33,6 +39,9 @@ class LightGBMModel(TraditionalModel):
|
||||
objective=params.get("objective", "binary"),
|
||||
n_jobs=params.get("n_jobs", -1),
|
||||
verbose=2,
|
||||
device=device,
|
||||
gpu_platform_id=gpu_platform_id,
|
||||
gpu_device_id=gpu_device_id,
|
||||
)
|
||||
|
||||
def prepare_features(self, X: pd.DataFrame) -> np.ndarray:
|
||||
|
||||
@@ -45,7 +45,7 @@ class LSTMModel(NeuralNetworkModel):
|
||||
Dense(64, activation="relu"),
|
||||
Dropout(params.get("dropout", 0.5)),
|
||||
# Two-way softmax for binary classification.
|
||||
Dense(2, activation="softmax"),
|
||||
Dense(2, activation="softmax", dtype="float32"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ class TransformerModel(NeuralNetworkModel):
|
||||
x = GlobalAveragePooling1D()(x)
|
||||
x = Dense(32, activation="relu")(x)
|
||||
x = Dropout(params.get("dropout", 0.1))(x)
|
||||
outputs = Dense(2, activation="softmax")(x)
|
||||
outputs = Dense(2, activation="softmax", dtype="float32")(x)
|
||||
|
||||
model = Model(inputs, outputs)
|
||||
model.compile(
|
||||
|
||||
@@ -20,6 +20,14 @@ class XGBoostModel(TraditionalModel):
|
||||
def build_model(self) -> BaseEstimator:
|
||||
params = self.config.model_params
|
||||
|
||||
# Optional GPU acceleration
|
||||
use_gpu = bool(params.get("use_gpu", False))
|
||||
default_tree_method = "gpu_hist" if use_gpu else "hist"
|
||||
tree_method = params.get("tree_method", default_tree_method)
|
||||
predictor = params.get(
|
||||
"predictor", "gpu_predictor" if tree_method.startswith("gpu") else "auto"
|
||||
)
|
||||
|
||||
# Histogram-based trees and parallelism provide fast training; default
|
||||
# logloss metric suits binary classification of gender.
|
||||
return xgb.XGBClassifier(
|
||||
@@ -31,7 +39,8 @@ class XGBoostModel(TraditionalModel):
|
||||
random_state=self.config.random_seed,
|
||||
eval_metric="logloss",
|
||||
n_jobs=params.get("n_jobs", -1),
|
||||
tree_method=params.get("tree_method", "hist"),
|
||||
tree_method=tree_method,
|
||||
predictor=predictor,
|
||||
verbosity=2,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user