fix: lstm model
This commit is contained in:
@@ -65,9 +65,9 @@ def predict_lstm(names: List[str], threshold: float, max_len=6):
|
||||
categories, and probabilities are the prediction scores for each input name.
|
||||
:rtype: Tuple[numpy.ndarray, numpy.ndarray]
|
||||
"""
|
||||
model_path = os.path.join(GENDER_MODELS_DIR, "BiLSTM_model.h5")
|
||||
tokenizer_path = os.path.join(GENDER_MODELS_DIR, "BiLSTM_tokenizer.pkl")
|
||||
encoder_path = os.path.join(GENDER_MODELS_DIR, "BiLSTM_label_encoder.pkl")
|
||||
model_path = os.path.join(GENDER_MODELS_DIR, "lstm_model.keras")
|
||||
tokenizer_path = os.path.join(GENDER_MODELS_DIR, "lstm_tokenizer.pkl")
|
||||
encoder_path = os.path.join(GENDER_MODELS_DIR, "lstm_label_encoder.pkl")
|
||||
|
||||
model = tf.keras.models.load_model(model_path)
|
||||
tokenizer: Tokenizer = load_pickle(tokenizer_path)
|
||||
@@ -104,7 +104,7 @@ def predict_transformer(names: List[str], threshold: float, max_len=6):
|
||||
corresponds to one class, and the second index corresponds to another).
|
||||
:rtype: Tuple[List[str], numpy.ndarray]
|
||||
"""
|
||||
model_path = os.path.join(GENDER_MODELS_DIR, "transformer.h5")
|
||||
model_path = os.path.join(GENDER_MODELS_DIR, "transformer.keras")
|
||||
tokenizer_path = os.path.join(GENDER_MODELS_DIR, "transformer_tokenizer.pkl")
|
||||
encoder_path = os.path.join(GENDER_MODELS_DIR, "transformer_label_encoder.pkl")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user