fix: lstm model

This commit is contained in:
2025-06-24 09:40:42 +02:00
parent d8980ec328
commit fb95c72ab7
3 changed files with 14 additions and 13 deletions
+6 -5
View File
@@ -111,8 +111,9 @@ def build_model(cfg: Config, vocab_size: int) -> Sequential:
logging.info("Building LSTM model")
model = Sequential([
Embedding(input_dim=vocab_size, output_dim=cfg.embedding_dim),
Bidirectional(LSTM(cfg.lstm_units, return_sequences=True)),
Bidirectional(LSTM(cfg.lstm_units)),
Dense(32, activation="relu"),
Dense(64, activation="relu"),
Dense(2, activation="softmax")
])
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
@@ -138,7 +139,7 @@ def evaluate_proba(y_true, y_proba, threshold, class_names):
classification report.
:return: None
"""
y_pred = 1 if y_proba[:, 1] >= threshold else 0
y_pred = (y_proba[:, 1] >= threshold).astype(int)
acc = accuracy_score(y_true, y_pred)
pr, rc, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
cm = confusion_matrix(y_true, y_pred)
@@ -213,9 +214,9 @@ def save_artifacts(model, tokenizer, encoder):
:return: None
"""
model_path = os.path.join(GENDER_MODELS_DIR, "BiLSTM_model.h5")
tokenizer_path = os.path.join(GENDER_MODELS_DIR, "BiLSTM_tokenizer.pkl")
encoder_path = os.path.join(GENDER_MODELS_DIR, "BiLSTM_label_encoder.pkl")
model_path = os.path.join(GENDER_MODELS_DIR, "lstm_model.keras")
tokenizer_path = os.path.join(GENDER_MODELS_DIR, "lstm_tokenizer.pkl")
encoder_path = os.path.join(GENDER_MODELS_DIR, "lstm_label_encoder.pkl")
model.save(model_path)
with open(tokenizer_path, "wb") as f: