feat: document models

2025-09-20 23:35:54 +02:00
parent dd2a9f2711
commit e41b15a863
13 changed files with 256 additions and 47 deletions
+10 -9
@@ -27,7 +27,12 @@ class TransformerModel(NeuralNetworkModel):
         # Build Transformer model
         inputs = Input(shape=(max_len,))
-        x = Embedding(input_dim=vocab_size, output_dim=params.get("embedding_dim", 64))(inputs)
+        x = Embedding(
+            input_dim=vocab_size,
+            output_dim=params.get("embedding_dim", 64),
+            input_length=max_len,
+            mask_zero=True,
+        )(inputs)
 
         # Add positional encoding
         positions = tf.range(start=0, limit=max_len, delta=1)
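
The hunk ends just before the positional-embedding lines, so that code is not shown. A minimal sketch of learned positional encoding consistent with the visible `tf.range` call, assuming a position `Embedding` added to the token embeddings; note that `mask_zero=True` above makes downstream layers ignore the `0` padding token, which pairs with `padding="post"` in `prepare_features`:

```python
import tensorflow as tf
from tensorflow.keras.layers import Embedding

max_len, embedding_dim = 6, 64  # defaults used elsewhere in this file

# One index per sequence position: shape (max_len,)
positions = tf.range(start=0, limit=max_len, delta=1)

# Learned positional embedding: shape (max_len, embedding_dim)
pos_embedding = Embedding(input_dim=max_len, output_dim=embedding_dim)(positions)

# Broadcasts across the batch: (batch, max_len, dim) + (max_len, dim)
# x = x + pos_embedding
```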
@@ -39,6 +44,7 @@ class TransformerModel(NeuralNetworkModel):
         x = self._transformer_encoder(x, params)
         x = GlobalAveragePooling1D()(x)
         x = Dense(32, activation="relu")(x)
+        x = Dropout(params.get("dropout", 0.1))(x)
         outputs = Dense(2, activation="softmax")(x)
 
         model = Model(inputs, outputs)
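
For context, a typical training setup for this two-way softmax head; the optimizer and loss below are assumptions, since the commit does not show the compile step:

```python
# Assumed training setup (not part of this diff): integer labels in {0, 1}
# pair with the Dense(2, activation="softmax") head via this loss.
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
```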
@@ -54,6 +60,7 @@ class TransformerModel(NeuralNetworkModel):
         attn = MultiHeadAttention(
             num_heads=cfg_params.get("transformer_num_heads", 2),
             key_dim=cfg_params.get("transformer_head_size", 64),
+            dropout=cfg_params.get("attn_dropout", 0.1),
         )(x, x)
         x = LayerNormalization(epsilon=1e-6)(x + Dropout(cfg_params.get("dropout", 0.1))(attn))
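
The closing feed-forward residual appears in the next hunk. Put together, the encoder block implied by the visible lines looks roughly like this; the feed-forward width (`transformer_ff_dim`) is an assumed parameter name, since that sublayer is hidden between the hunks:

```python
from tensorflow.keras.layers import (
    Dense,
    Dropout,
    LayerNormalization,
    MultiHeadAttention,
)

def _transformer_encoder(x, cfg_params):
    # Self-attention sublayer: residual connection + post-layer-norm.
    attn = MultiHeadAttention(
        num_heads=cfg_params.get("transformer_num_heads", 2),
        key_dim=cfg_params.get("transformer_head_size", 64),
        dropout=cfg_params.get("attn_dropout", 0.1),
    )(x, x)
    x = LayerNormalization(epsilon=1e-6)(x + Dropout(cfg_params.get("dropout", 0.1))(attn))

    # Position-wise feed-forward sublayer (hidden in the diff; assumed here).
    ff = Dense(cfg_params.get("transformer_ff_dim", 128), activation="relu")(x)
    ff = Dense(x.shape[-1])(ff)  # project back to the embedding width
    return LayerNormalization(epsilon=1e-6)(x + Dropout(cfg_params.get("dropout", 0.1))(ff))
```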
@@ -62,13 +69,7 @@ class TransformerModel(NeuralNetworkModel):
         return LayerNormalization(epsilon=1e-6)(x + Dropout(cfg_params.get("dropout", 0.1))(ff))
 
     def prepare_features(self, X: pd.DataFrame) -> np.ndarray:
-        text_data = []
-        for feature_type in self.config.features:
-            if feature_type.value in X.columns:
-                text_data.extend(X[feature_type.value].astype(str).tolist())
-        if not text_data:
-            raise ValueError("No text data found in the provided DataFrame.")
+        text_data = self._collect_text_corpus(X)
 
         # Initialize tokenizer if needed
         if self.tokenizer is None:
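
`_collect_text_corpus` is defined elsewhere in this 13-file commit; a sketch reconstructed from the inline logic it replaces:

```python
import pandas as pd

def _collect_text_corpus(self, X: pd.DataFrame) -> list[str]:
    """Gather the configured text columns into one flat corpus."""
    text_data: list[str] = []
    for feature_type in self.config.features:
        if feature_type.value in X.columns:
            text_data.extend(X[feature_type.value].astype(str).tolist())
    if not text_data:
        raise ValueError("No text data found in the provided DataFrame.")
    return text_data
```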
@@ -76,7 +77,7 @@ class TransformerModel(NeuralNetworkModel):
             self.tokenizer.fit_on_texts(text_data)
 
         # Convert to sequences
-        sequences = self.tokenizer.texts_to_sequences(text_data[: len(X)])
+        sequences = self.tokenizer.texts_to_sequences(text_data)
         max_len = self.config.model_params.get("max_len", 6)
         return pad_sequences(sequences, maxlen=max_len, padding="post")
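
The removed slice `text_data[: len(X)]` truncated the corpus to its first `len(X)` entries (effectively the first feature column when several are configured); the fix converts every collected text. A small worked example of the tokenize-and-pad step, with assumed tokenizer settings:

```python
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

# Illustrative texts (stand-ins for the configured feature columns).
texts = ["free prize claim now", "see you at the meeting"]

tokenizer = Tokenizer(num_words=10000, oov_token="<OOV>")  # assumed settings
tokenizer.fit_on_texts(texts)

sequences = tokenizer.texts_to_sequences(texts)  # lists of word indices
padded = pad_sequences(sequences, maxlen=6, padding="post")
print(padded.shape)  # (2, 6) — trailing zeros pad short rows, matching mask_zero=True
```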