feat: support gpu

This commit is contained in:
2025-09-29 21:07:23 +02:00
parent 9e35f95107
commit a1d500830b
15 changed files with 661 additions and 85 deletions
+9
View File
@@ -20,6 +20,12 @@ class LightGBMModel(TraditionalModel):
def build_model(self) -> BaseEstimator:
params = self.config.model_params
# Optional GPU acceleration
use_gpu = bool(params.get("use_gpu", False))
device = params.get("device", "gpu" if use_gpu else "cpu")
gpu_platform_id = params.get("gpu_platform_id", None)
gpu_device_id = params.get("gpu_device_id", None)
# Leaf-wise boosted trees excel on sparse/categorical mixes; binary objective
# and parallelism improve training speed for this task.
return lgb.LGBMClassifier(
@@ -33,6 +39,9 @@ class LightGBMModel(TraditionalModel):
objective=params.get("objective", "binary"),
n_jobs=params.get("n_jobs", -1),
verbose=2,
device=device,
gpu_platform_id=gpu_platform_id,
gpu_device_id=gpu_device_id,
)
def prepare_features(self, X: pd.DataFrame) -> np.ndarray: