feat: document models

This commit is contained in:
2025-09-20 23:35:54 +02:00
parent dd2a9f2711
commit e41b15a863
13 changed files with 256 additions and 47 deletions
+15 -9
View File
@@ -9,6 +9,7 @@ from tensorflow.keras.layers import (
GlobalMaxPooling1D,
Dense,
Dropout,
SpatialDropout1D,
)
from tensorflow.keras.models import Sequential
@@ -24,21 +25,33 @@ class CNNModel(NeuralNetworkModel):
params = kwargs
model = Sequential(
[
# Learn char/subword embeddings; spatial dropout regularizes across channels
# to make the model robust to noisy characters and transliteration.
Embedding(input_dim=vocab_size, output_dim=params.get("embedding_dim", 64)),
SpatialDropout1D(rate=params.get("embedding_dropout", 0.1)),
# Small kernels capture short n-gram like patterns; padding='same' keeps
# sequence length stable for simpler pooling behavior.
Conv1D(
filters=params.get("filters", 64),
kernel_size=params.get("kernel_size", 3),
activation="relu",
padding="same",
),
# Downsample to gain some position invariance and reduce computation.
MaxPooling1D(pool_size=2),
# Second conv layer to compose higher-level motifs (e.g., suffix+vowel).
Conv1D(
filters=params.get("filters", 64),
kernel_size=params.get("kernel_size", 3),
activation="relu",
padding="same",
),
# Global max pooling picks strongest motif evidence anywhere in the name.
GlobalMaxPooling1D(),
# Compact dense head with dropout to control overfitting.
Dense(64, activation="relu"),
Dropout(params.get("dropout", 0.5)),
# Two-way softmax for binary classification.
Dense(2, activation="softmax"),
]
)
@@ -55,21 +68,14 @@ class CNNModel(NeuralNetworkModel):
from tensorflow.keras.preprocessing.sequence import pad_sequences
# Get text data from extracted features - use character level for CNN
text_data = []
for feature_type in self.config.features:
if feature_type.value in X.columns:
text_data.extend(X[feature_type.value].astype(str).tolist())
if not text_data:
# Fallback - should not happen if FeatureExtractor is properly configured
text_data = [""] * len(X)
text_data = self._collect_text_corpus(X)
# Initialize character-level tokenizer
if self.tokenizer is None:
self.tokenizer = Tokenizer(char_level=True, lower=True, oov_token="<OOV>")
self.tokenizer.fit_on_texts(text_data)
sequences = self.tokenizer.texts_to_sequences(text_data[: len(X)])
sequences = self.tokenizer.texts_to_sequences(text_data)
max_len = self.config.model_params.get("max_len", 20) # Longer for character level
return pad_sequences(sequences, maxlen=max_len, padding="post")