feat: support gpu

This commit is contained in:
2025-09-29 21:07:23 +02:00
parent 9e35f95107
commit a1d500830b
15 changed files with 661 additions and 85 deletions
+1 -1
View File
@@ -45,7 +45,7 @@ class TransformerModel(NeuralNetworkModel):
x = GlobalAveragePooling1D()(x)
x = Dense(32, activation="relu")(x)
x = Dropout(params.get("dropout", 0.1))(x)
outputs = Dense(2, activation="softmax")(x)
outputs = Dense(2, activation="softmax", dtype="float32")(x)
model = Model(inputs, outputs)
model.compile(