feat: support gpu
This commit is contained in:
@@ -20,6 +20,14 @@ class XGBoostModel(TraditionalModel):
|
||||
def build_model(self) -> BaseEstimator:
|
||||
params = self.config.model_params
|
||||
|
||||
# Optional GPU acceleration
|
||||
use_gpu = bool(params.get("use_gpu", False))
|
||||
default_tree_method = "gpu_hist" if use_gpu else "hist"
|
||||
tree_method = params.get("tree_method", default_tree_method)
|
||||
predictor = params.get(
|
||||
"predictor", "gpu_predictor" if tree_method.startswith("gpu") else "auto"
|
||||
)
|
||||
|
||||
# Histogram-based trees and parallelism provide fast training; default
|
||||
# logloss metric suits binary classification of gender.
|
||||
return xgb.XGBClassifier(
|
||||
@@ -31,7 +39,8 @@ class XGBoostModel(TraditionalModel):
|
||||
random_state=self.config.random_seed,
|
||||
eval_metric="logloss",
|
||||
n_jobs=params.get("n_jobs", -1),
|
||||
tree_method=params.get("tree_method", "hist"),
|
||||
tree_method=tree_method,
|
||||
predictor=predictor,
|
||||
verbosity=2,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user