fix: normalize hyper params

This commit is contained in:
2025-09-21 13:10:07 +02:00
parent 83d21c640b
commit 63e23d6600
8 changed files with 26 additions and 19 deletions
+5 -5
View File
@@ -22,21 +22,21 @@ from research.neural_network_model import NeuralNetworkModel
class TransformerModel(NeuralNetworkModel):
"""Transformer-based model"""
def build_model_with_vocab(self, vocab_size: int, max_len: int = 6, **kwargs) -> Any:
def build_model_with_vocab(self, vocab_size: int, **kwargs) -> Any:
params = kwargs
# Build Transformer model
inputs = Input(shape=(max_len,))
inputs = Input(shape=(params.get("max_len", 8),))
x = Embedding(
input_dim=vocab_size,
output_dim=params.get("embedding_dim", 64),
input_length=max_len,
input_length=params.get("max_len", 8),
mask_zero=True,
)(inputs)
# Add positional encoding
positions = tf.range(start=0, limit=max_len, delta=1)
pos_embedding = Embedding(input_dim=max_len, output_dim=params.get("embedding_dim", 64))(
positions = tf.range(start=0, limit=params.get("max_len", 8), delta=1)
pos_embedding = Embedding(input_dim=params.get("max_len", 8), output_dim=params.get("embedding_dim", 64))(
positions
)
x = x + pos_embedding