feat: document models
This commit is contained in:
@@ -27,7 +27,12 @@ class TransformerModel(NeuralNetworkModel):
|
||||
|
||||
# Build Transformer model
|
||||
inputs = Input(shape=(max_len,))
|
||||
x = Embedding(input_dim=vocab_size, output_dim=params.get("embedding_dim", 64))(inputs)
|
||||
x = Embedding(
|
||||
input_dim=vocab_size,
|
||||
output_dim=params.get("embedding_dim", 64),
|
||||
input_length=max_len,
|
||||
mask_zero=True,
|
||||
)(inputs)
|
||||
|
||||
# Add positional encoding
|
||||
positions = tf.range(start=0, limit=max_len, delta=1)
|
||||
@@ -39,6 +44,7 @@ class TransformerModel(NeuralNetworkModel):
|
||||
x = self._transformer_encoder(x, params)
|
||||
x = GlobalAveragePooling1D()(x)
|
||||
x = Dense(32, activation="relu")(x)
|
||||
x = Dropout(params.get("dropout", 0.1))(x)
|
||||
outputs = Dense(2, activation="softmax")(x)
|
||||
|
||||
model = Model(inputs, outputs)
|
||||
@@ -54,6 +60,7 @@ class TransformerModel(NeuralNetworkModel):
|
||||
attn = MultiHeadAttention(
|
||||
num_heads=cfg_params.get("transformer_num_heads", 2),
|
||||
key_dim=cfg_params.get("transformer_head_size", 64),
|
||||
dropout=cfg_params.get("attn_dropout", 0.1),
|
||||
)(x, x)
|
||||
x = LayerNormalization(epsilon=1e-6)(x + Dropout(cfg_params.get("dropout", 0.1))(attn))
|
||||
|
||||
@@ -62,13 +69,7 @@ class TransformerModel(NeuralNetworkModel):
|
||||
return LayerNormalization(epsilon=1e-6)(x + Dropout(cfg_params.get("dropout", 0.1))(ff))
|
||||
|
||||
def prepare_features(self, X: pd.DataFrame) -> np.ndarray:
|
||||
text_data = []
|
||||
for feature_type in self.config.features:
|
||||
if feature_type.value in X.columns:
|
||||
text_data.extend(X[feature_type.value].astype(str).tolist())
|
||||
|
||||
if not text_data:
|
||||
raise ValueError("No text data found in the provided DataFrame.")
|
||||
text_data = self._collect_text_corpus(X)
|
||||
|
||||
# Initialize tokenizer if needed
|
||||
if self.tokenizer is None:
|
||||
@@ -76,7 +77,7 @@ class TransformerModel(NeuralNetworkModel):
|
||||
self.tokenizer.fit_on_texts(text_data)
|
||||
|
||||
# Convert to sequences
|
||||
sequences = self.tokenizer.texts_to_sequences(text_data[: len(X)])
|
||||
sequences = self.tokenizer.texts_to_sequences(text_data)
|
||||
max_len = self.config.model_params.get("max_len", 6)
|
||||
|
||||
return pad_sequences(sequences, maxlen=max_len, padding="post")
|
||||
|
||||
Reference in New Issue
Block a user