fix: lstm model
This commit is contained in:
+4
-4
@@ -86,9 +86,9 @@ def evaluate_lstm(df, threshold, max_len=6):
|
||||
- encoder.classes_: An array of class names corresponding to the label encoding.
|
||||
:rtype: Tuple
|
||||
"""
|
||||
model = tf.keras.models.load_model(os.path.join(GENDER_MODELS_DIR, "BiLSTM_model.h5"))
|
||||
tokenizer = load_pickle(os.path.join(GENDER_MODELS_DIR, "BiLSTM_tokenizer.pkl"))
|
||||
encoder = load_pickle(os.path.join(GENDER_MODELS_DIR, "BiLSTM_label_encoder.pkl"))
|
||||
model = tf.keras.models.load_model(os.path.join(GENDER_MODELS_DIR, "lstm_model.keras"))
|
||||
tokenizer = load_pickle(os.path.join(GENDER_MODELS_DIR, "lstm_tokenizer.pkl"))
|
||||
encoder = load_pickle(os.path.join(GENDER_MODELS_DIR, "lstm_label_encoder.pkl"))
|
||||
|
||||
sequences = tokenizer.texts_to_sequences(df["name"])
|
||||
X = pad_sequences(sequences, maxlen=max_len, padding="post")
|
||||
@@ -118,7 +118,7 @@ def evaluate_transformer(df, threshold, max_len=6):
|
||||
probabilities for the positive class, and a list of the label classes.
|
||||
:rtype: Tuple
|
||||
"""
|
||||
model = tf.keras.models.load_model(os.path.join(GENDER_MODELS_DIR, "transformer.h5"))
|
||||
model = tf.keras.models.load_model(os.path.join(GENDER_MODELS_DIR, "transformer.keras"))
|
||||
tokenizer = load_pickle(os.path.join(GENDER_MODELS_DIR, "transformer_tokenizer.pkl"))
|
||||
encoder = load_pickle(os.path.join(GENDER_MODELS_DIR, "transformer_label_encoder.pkl"))
|
||||
|
||||
|
||||
@@ -111,8 +111,9 @@ def build_model(cfg: Config, vocab_size: int) -> Sequential:
|
||||
logging.info("Building LSTM model")
|
||||
model = Sequential([
|
||||
Embedding(input_dim=vocab_size, output_dim=cfg.embedding_dim),
|
||||
Bidirectional(LSTM(cfg.lstm_units, return_sequences=True)),
|
||||
Bidirectional(LSTM(cfg.lstm_units)),
|
||||
Dense(32, activation="relu"),
|
||||
Dense(64, activation="relu"),
|
||||
Dense(2, activation="softmax")
|
||||
])
|
||||
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
|
||||
@@ -138,7 +139,7 @@ def evaluate_proba(y_true, y_proba, threshold, class_names):
|
||||
classification report.
|
||||
:return: None
|
||||
"""
|
||||
y_pred = 1 if y_proba[:, 1] >= threshold else 0
|
||||
y_pred = (y_proba[:, 1] >= threshold).astype(int)
|
||||
acc = accuracy_score(y_true, y_pred)
|
||||
pr, rc, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
|
||||
cm = confusion_matrix(y_true, y_pred)
|
||||
@@ -213,9 +214,9 @@ def save_artifacts(model, tokenizer, encoder):
|
||||
|
||||
:return: None
|
||||
"""
|
||||
model_path = os.path.join(GENDER_MODELS_DIR, "BiLSTM_model.h5")
|
||||
tokenizer_path = os.path.join(GENDER_MODELS_DIR, "BiLSTM_tokenizer.pkl")
|
||||
encoder_path = os.path.join(GENDER_MODELS_DIR, "BiLSTM_label_encoder.pkl")
|
||||
model_path = os.path.join(GENDER_MODELS_DIR, "lstm_model.keras")
|
||||
tokenizer_path = os.path.join(GENDER_MODELS_DIR, "lstm_tokenizer.pkl")
|
||||
encoder_path = os.path.join(GENDER_MODELS_DIR, "lstm_label_encoder.pkl")
|
||||
|
||||
model.save(model_path)
|
||||
with open(tokenizer_path, "wb") as f:
|
||||
|
||||
@@ -65,9 +65,9 @@ def predict_lstm(names: List[str], threshold: float, max_len=6):
|
||||
categories, and probabilities are the prediction scores for each input name.
|
||||
:rtype: Tuple[numpy.ndarray, numpy.ndarray]
|
||||
"""
|
||||
model_path = os.path.join(GENDER_MODELS_DIR, "BiLSTM_model.h5")
|
||||
tokenizer_path = os.path.join(GENDER_MODELS_DIR, "BiLSTM_tokenizer.pkl")
|
||||
encoder_path = os.path.join(GENDER_MODELS_DIR, "BiLSTM_label_encoder.pkl")
|
||||
model_path = os.path.join(GENDER_MODELS_DIR, "lstm_model.keras")
|
||||
tokenizer_path = os.path.join(GENDER_MODELS_DIR, "lstm_tokenizer.pkl")
|
||||
encoder_path = os.path.join(GENDER_MODELS_DIR, "lstm_label_encoder.pkl")
|
||||
|
||||
model = tf.keras.models.load_model(model_path)
|
||||
tokenizer: Tokenizer = load_pickle(tokenizer_path)
|
||||
@@ -104,7 +104,7 @@ def predict_transformer(names: List[str], threshold: float, max_len=6):
|
||||
corresponds to one class, and the second index corresponds to another).
|
||||
:rtype: Tuple[List[str], numpy.ndarray]
|
||||
"""
|
||||
model_path = os.path.join(GENDER_MODELS_DIR, "transformer.h5")
|
||||
model_path = os.path.join(GENDER_MODELS_DIR, "transformer.keras")
|
||||
tokenizer_path = os.path.join(GENDER_MODELS_DIR, "transformer_tokenizer.pkl")
|
||||
encoder_path = os.path.join(GENDER_MODELS_DIR, "transformer_label_encoder.pkl")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user