feat: support gpu
This commit is contained in:
@@ -45,7 +45,7 @@ class TransformerModel(NeuralNetworkModel):
|
||||
x = GlobalAveragePooling1D()(x)
|
||||
x = Dense(32, activation="relu")(x)
|
||||
x = Dropout(params.get("dropout", 0.1))(x)
|
||||
outputs = Dense(2, activation="softmax")(x)
|
||||
outputs = Dense(2, activation="softmax", dtype="float32")(x)
|
||||
|
||||
model = Model(inputs, outputs)
|
||||
model.compile(
|
||||
|
||||
Reference in New Issue
Block a user