feat: document models

2025-09-20 23:35:54 +02:00
parent dd2a9f2711
commit e41b15a863
13 changed files with 256 additions and 47 deletions
+25 -10
@@ -17,17 +17,38 @@ class BiGRUModel(NeuralNetworkModel):
         params = kwargs
         model = Sequential(
             [
-                Embedding(input_dim=vocab_size, output_dim=params.get("embedding_dim", 64)),
+                # Mask padding tokens so recurrent layers ignore them; fix input length
+                # for better shape inference and to support masking through the stack.
+                Embedding(
+                    input_dim=vocab_size,
+                    output_dim=params.get("embedding_dim", 64),
+                    input_length=max_len,
+                    mask_zero=True,
+                ),
+                # First recurrent block returns full sequences to allow stacking.
+                # Moderate dropout + optional recurrent_dropout to reduce overfitting
+                # on short names while retaining temporal signal.
                 Bidirectional(
                     GRU(
                         params.get("gru_units", 32),
                         return_sequences=True,
                         dropout=params.get("dropout", 0.2),
+                        recurrent_dropout=params.get("recurrent_dropout", 0.0),
                     )
                 ),
-                Bidirectional(GRU(params.get("gru_units", 32), dropout=params.get("dropout", 0.2))),
+                # Second GRU summarizes to the last hidden state (no return_sequences),
+                # capturing bidirectional context efficiently for classification.
+                Bidirectional(
+                    GRU(
+                        params.get("gru_units", 32),
+                        dropout=params.get("dropout", 0.2),
+                        recurrent_dropout=params.get("recurrent_dropout", 0.0),
+                    )
+                ),
+                # Small dense head; ReLU + dropout for capacity and regularization.
                 Dense(64, activation="relu"),
                 Dropout(params.get("dropout", 0.5)),
+                # Two-way softmax for binary gender classification.
                 Dense(2, activation="softmax"),
             ]
         )
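
For reference, a minimal standalone sketch of the documented stack outside the diff: it rebuilds the same layer sequence using the params.get(...) fallbacks above (embedding_dim=64, gru_units=32, dropout=0.2/0.5) and checks that a padded sequence flows through the masked GRUs. The vocab_size and max_len values, as well as the compile settings, are placeholders and assumptions; this commit does not show how the model is compiled.

import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, Bidirectional, GRU, Dense, Dropout

# Placeholder sizes; in the real model they come from the tokenizer and config.
vocab_size, max_len = 1000, 6

model = Sequential(
    [
        # mask_zero=True propagates a padding mask to both GRU layers.
        Embedding(input_dim=vocab_size, output_dim=64, mask_zero=True),
        Bidirectional(GRU(32, return_sequences=True, dropout=0.2)),
        Bidirectional(GRU(32, dropout=0.2)),
        Dense(64, activation="relu"),
        Dropout(0.5),
        Dense(2, activation="softmax"),
    ]
)
# Assumed compile settings, for illustration only.
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

# One padded sequence (0 = padding token); output is a 2-way probability vector.
dummy = np.array([[5, 12, 3, 0, 0, 0]])
print(model.predict(dummy, verbose=0).shape)  # (1, 2)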
@@ -38,19 +59,13 @@ class BiGRUModel(NeuralNetworkModel):
         return model
 
     def prepare_features(self, X: pd.DataFrame) -> np.ndarray:
-        text_data = []
-        for feature_type in self.config.features:
-            if feature_type.value in X.columns:
-                text_data.extend(X[feature_type.value].astype(str).tolist())
-        if not text_data:
-            raise ValueError("No text data found in the provided DataFrame.")
+        text_data = self._collect_text_corpus(X)
         if self.tokenizer is None:
             self.tokenizer = Tokenizer(char_level=False, lower=True, oov_token="<OOV>")
             self.tokenizer.fit_on_texts(text_data)
-        sequences = self.tokenizer.texts_to_sequences(text_data[: len(X)])
+        sequences = self.tokenizer.texts_to_sequences(text_data)
         max_len = self.config.model_params.get("max_len", 6)
         return pad_sequences(sequences, maxlen=max_len, padding="post")
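
The refactored prepare_features delegates corpus collection to self._collect_text_corpus, which this hunk does not show. A plausible sketch of that helper, reconstructed from the removed inline loop (iterate the configured feature columns, cast to strings, fail on an empty corpus); the method name appears in the commit, but its signature and body here are assumptions:

def _collect_text_corpus(self, X: pd.DataFrame) -> list[str]:
    # Hypothetical reconstruction; mirrors the inline logic this commit removed.
    text_data: list[str] = []
    for feature_type in self.config.features:
        if feature_type.value in X.columns:
            text_data.extend(X[feature_type.value].astype(str).tolist())
    if not text_data:
        raise ValueError("No text data found in the provided DataFrame.")
    return text_data

With that in place, prepare_features fits the word-level Tokenizer once over the collected corpus and pads every sequence to max_len (default 6) with trailing zeros, which is exactly what mask_zero=True in the Embedding layer masks out downstream.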