feat: support gpu

This commit is contained in:
2025-09-29 21:07:23 +02:00
parent 9e35f95107
commit a1d500830b
15 changed files with 661 additions and 85 deletions
+10 -1
View File
@@ -20,6 +20,14 @@ class XGBoostModel(TraditionalModel):
def build_model(self) -> BaseEstimator:
params = self.config.model_params
# Optional GPU acceleration
use_gpu = bool(params.get("use_gpu", False))
default_tree_method = "gpu_hist" if use_gpu else "hist"
tree_method = params.get("tree_method", default_tree_method)
predictor = params.get(
"predictor", "gpu_predictor" if tree_method.startswith("gpu") else "auto"
)
# Histogram-based trees and parallelism provide fast training; default
# logloss metric suits binary classification of gender.
return xgb.XGBClassifier(
@@ -31,7 +39,8 @@ class XGBoostModel(TraditionalModel):
random_state=self.config.random_seed,
eval_metric="logloss",
n_jobs=params.get("n_jobs", -1),
tree_method=params.get("tree_method", "hist"),
tree_method=tree_method,
predictor=predictor,
verbosity=2,
)