fix: normalize hyper params
This commit is contained in:
@@ -52,22 +52,21 @@ class NeuralNetworkModel(BaseModel):
|
||||
logging.info(f"Vocabulary size: {vocab_size}")
|
||||
|
||||
# Get additional model parameters
|
||||
max_len = self.config.model_params.get("max_len", 6)
|
||||
|
||||
self.model = self.build_model_with_vocab(
|
||||
vocab_size=vocab_size, max_len=max_len, **self.config.model_params
|
||||
)
|
||||
self.model = self.build_model_with_vocab(vocab_size=vocab_size, **self.config.model_params)
|
||||
|
||||
# Train the neural network
|
||||
logging.info(
|
||||
f"Fitting model with {X_prepared.shape[0]} samples and {X_prepared.shape[1]} features"
|
||||
)
|
||||
logging.info(X_prepared[0])
|
||||
logging.info(f"Model parameters: {self.config.model_params}")
|
||||
|
||||
history = self.model.fit(
|
||||
X_prepared,
|
||||
y_encoded,
|
||||
epochs=self.config.model_params.get("epochs", 10),
|
||||
batch_size=self.config.model_params.get("batch_size", 64),
|
||||
validation_split=0.1,
|
||||
validation_split=self.config.model_params.get("validation_split", 0.1),
|
||||
verbose=2,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user