fix: normalize hyper params

This commit is contained in:
2025-09-21 13:10:07 +02:00
parent 83d21c640b
commit 63e23d6600
8 changed files with 26 additions and 19 deletions
+5 -6
View File
@@ -52,22 +52,21 @@ class NeuralNetworkModel(BaseModel):
logging.info(f"Vocabulary size: {vocab_size}")
# Get additional model parameters
max_len = self.config.model_params.get("max_len", 6)
self.model = self.build_model_with_vocab(
vocab_size=vocab_size, max_len=max_len, **self.config.model_params
)
self.model = self.build_model_with_vocab(vocab_size=vocab_size, **self.config.model_params)
# Train the neural network
logging.info(
f"Fitting model with {X_prepared.shape[0]} samples and {X_prepared.shape[1]} features"
)
logging.info(X_prepared[0])
logging.info(f"Model parameters: {self.config.model_params}")
history = self.model.fit(
X_prepared,
y_encoded,
epochs=self.config.model_params.get("epochs", 10),
batch_size=self.config.model_params.get("batch_size", 64),
validation_split=0.1,
validation_split=self.config.model_params.get("validation_split", 0.1),
verbose=2,
)