feat: support gpu
This commit is contained in:
@@ -20,6 +20,12 @@ class LightGBMModel(TraditionalModel):
|
||||
def build_model(self) -> BaseEstimator:
|
||||
params = self.config.model_params
|
||||
|
||||
# Optional GPU acceleration
|
||||
use_gpu = bool(params.get("use_gpu", False))
|
||||
device = params.get("device", "gpu" if use_gpu else "cpu")
|
||||
gpu_platform_id = params.get("gpu_platform_id", None)
|
||||
gpu_device_id = params.get("gpu_device_id", None)
|
||||
|
||||
# Leaf-wise boosted trees excel on sparse/categorical mixes; binary objective
|
||||
# and parallelism improve training speed for this task.
|
||||
return lgb.LGBMClassifier(
|
||||
@@ -33,6 +39,9 @@ class LightGBMModel(TraditionalModel):
|
||||
objective=params.get("objective", "binary"),
|
||||
n_jobs=params.get("n_jobs", -1),
|
||||
verbose=2,
|
||||
device=device,
|
||||
gpu_platform_id=gpu_platform_id,
|
||||
gpu_device_id=gpu_device_id,
|
||||
)
|
||||
|
||||
def prepare_features(self, X: pd.DataFrame) -> np.ndarray:
|
||||
|
||||
Reference in New Issue
Block a user