Skip to content

Commit 823fd7f

Browse files
committed
xgboost training set to use GPU if available
1 parent 33adda5 commit 823fd7f

1 file changed

Lines changed: 12 additions & 0 deletions

File tree

src/model/train_model.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,20 @@ def train_pdf_classifier(texts, labels, output_dir="src/model/models"):
9292
dtrain = xgb.DMatrix(X_train_vec, label=y_train)
9393
dtest = xgb.DMatrix(X_test_vec, label=y_test)
9494

95+
# Use GPU if available (e.g. HPC gpu nodes), fall back to CPU
96+
try:
97+
_probe = xgb.DMatrix(X_train_vec[:1], label=y_train[:1])
98+
xgb.train({"device": "cuda", "tree_method": "hist"}, _probe, num_boost_round=1)
99+
device, tree_method = "cuda", "hist"
100+
print("[INFO] GPU detected — training with CUDA.")
101+
except xgb.core.XGBoostError:
102+
device, tree_method = "cpu", "hist"
103+
print("[INFO] No GPU available — training on CPU.")
104+
95105
# XGBoost parameters
96106
params = {
107+
"device": device,
108+
"tree_method": tree_method,
97109
"objective": "binary:logistic", # binary classification
98110
"eval_metric": "logloss", # log loss metric
99111
"eta": 0.05, # learning rate

0 commit comments

Comments
 (0)