fix: artifacts saving and dataset loading
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, Optional
|
||||
|
||||
@@ -23,7 +22,7 @@ from tensorflow.keras.models import Model
|
||||
from tensorflow.keras.preprocessing.sequence import pad_sequences
|
||||
from tensorflow.keras.preprocessing.text import Tokenizer
|
||||
|
||||
from misc import GENDER_MODELS_DIR, load_csv_dataset
|
||||
from misc import GENDER_MODELS_DIR, load_csv_dataset, save_pickle
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format=">> %(message)s")
|
||||
|
||||
@@ -198,7 +197,7 @@ def evaluate_proba(y_true, y_proba, threshold, class_names):
|
||||
:return: None. Outputs performance metrics and confusion matrix to the logging
|
||||
system and the console.
|
||||
"""
|
||||
y_pred = 1 if y_proba[:, 1] >= threshold else 0
|
||||
y_pred = (y_proba[:, 1] >= threshold).astype(int)
|
||||
acc = accuracy_score(y_true, y_pred)
|
||||
pr, rc, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary")
|
||||
cm = confusion_matrix(y_true, y_pred)
|
||||
@@ -257,16 +256,13 @@ def save_artifacts(model, tokenizer, encoder):
|
||||
:param encoder: The label encoder used for encoding target labels.
|
||||
:return: None
|
||||
"""
|
||||
model_path = os.path.join(GENDER_MODELS_DIR, "transformer.h5")
|
||||
tokenizer_path = os.path.join(GENDER_MODELS_DIR, "transformer_tokenizer.pkl")
|
||||
encoder_path = os.path.join(GENDER_MODELS_DIR, "transformer_label_encoder.pkl")
|
||||
os.makedirs(GENDER_MODELS_DIR, exist_ok=True)
|
||||
model.save(os.path.join(GENDER_MODELS_DIR, "transformer.keras"))
|
||||
|
||||
model.save(model_path)
|
||||
with open(tokenizer_path, "wb") as f:
|
||||
pickle.dump(tokenizer, f)
|
||||
with open(encoder_path, "wb") as f:
|
||||
pickle.dump(encoder, f)
|
||||
logging.info("Model and artifacts saved.")
|
||||
save_pickle(tokenizer, os.path.join(GENDER_MODELS_DIR, "transformer_tokenizer.pkl"))
|
||||
save_pickle(encoder, os.path.join(GENDER_MODELS_DIR, "transformer_label_encoder.pkl"))
|
||||
|
||||
logging.info(f"Model and artifacts saved to {GENDER_MODELS_DIR}")
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
Reference in New Issue
Block a user