From eb139ee09a9c149ef9a91bd23dd12651a50bae90 Mon Sep 17 00:00:00 2001 From: bernard-ng Date: Tue, 24 Jun 2025 21:49:03 +0200 Subject: [PATCH] fix: artifacts saving and dataset loading --- misc/__init__.py | 33 ++++++++++++++++++++++--------- ners/gender/models/logreg.py | 15 +++++--------- ners/gender/models/lstm.py | 16 ++++++--------- ners/gender/models/transformer.py | 20 ++++++++----------- 4 files changed, 43 insertions(+), 41 deletions(-) diff --git a/misc/__init__.py b/misc/__init__.py index b669c35..0f038e7 100644 --- a/misc/__init__.py +++ b/misc/__init__.py @@ -1,8 +1,8 @@ import csv +import io import json import os import pickle -from datetime import datetime from typing import Optional # Paths @@ -16,6 +16,7 @@ GENDER_RESULT_DIR = os.path.join(ROOT_DIR, 'gender', 'results') NER_MODELS_DIR = os.path.join(MODELS_DIR, 'ner') NER_RESULT_DIR = os.path.join(ROOT_DIR, 'ner', 'results') + def clean_spacing(filename: str) -> Optional[str]: try: with open(os.path.join(DATA_DIR, filename), 'r', encoding='utf8') as f: @@ -42,14 +43,27 @@ def save_csv_dataset(data: list, path: str) -> None: def load_csv_dataset(path: str, limit: int = None) -> list: print(f">> Loading CSV dataset from {path}") data = [] - with open(os.path.join(DATA_DIR, path), "r", encoding="utf-8") as f: - reader = csv.DictReader(f) - for row in reader: - data.append(row) - if limit and len(data) >= limit: - break + encodings = ['utf-8', 'utf-16', 'latin1'] - return data + for enc in encodings: + try: + with open(os.path.join(DATA_DIR, path), "r", encoding=enc, errors="replace") as f: + raw_text = f.read().replace('\x00', '') + + csv_buffer = io.StringIO(raw_text) + reader = csv.DictReader(csv_buffer) + print(f">> Detected fieldnames: {reader.fieldnames}") + + for row in reader: + data.append(row) + if limit and len(data) >= limit: + break + print(f">> Successfully loaded with encoding: {enc}") + return data + except Exception as e: + print(f">> Failed with encoding: {enc}, error: {e}") + + raise UnicodeDecodeError("load_csv_dataset", path, 0, 0, "Unable to decode file with common encodings.") def save_json_dataset(data: list, path: str) -> None: @@ -63,6 +77,7 @@ def save_pickle(obj, path): with open(path, "wb") as f: pickle.dump(obj, f) + def load_pickle(path: str): with open(path, "rb") as f: - return pickle.load(f) \ No newline at end of file + return pickle.load(f) diff --git a/ners/gender/models/logreg.py b/ners/gender/models/logreg.py index 9f13631..5aed7dc 100644 --- a/ners/gender/models/logreg.py +++ b/ners/gender/models/logreg.py @@ -1,7 +1,6 @@ import argparse import logging import os -import pickle from dataclasses import dataclass from typing import Tuple, Optional @@ -16,10 +15,11 @@ from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val from sklearn.pipeline import make_pipeline, Pipeline from sklearn.preprocessing import LabelEncoder -from misc import GENDER_MODELS_DIR, load_csv_dataset +from misc import GENDER_MODELS_DIR, load_csv_dataset, save_pickle logging.basicConfig(level=logging.INFO, format=">> %(message)s") + @dataclass class Config: dataset_path: str @@ -169,15 +169,10 @@ def save_artifacts(model, encoder, cfg: Config): :type cfg: Config :return: None """ - model_path = os.path.join(GENDER_MODELS_DIR, "regression_model.pkl") - encoder_path = os.path.join(GENDER_MODELS_DIR, "regression_label_encoder.pkl") + save_pickle(model, os.path.join(GENDER_MODELS_DIR, "regression_model.pkl")) + save_pickle(encoder, os.path.join(GENDER_MODELS_DIR, "regression_label_encoder.pkl")) - with open(model_path, "wb") as f: - pickle.dump(model, f) - with open(encoder_path, "wb") as f: - pickle.dump(encoder, f) - logging.info(f"Saved model to: {model_path}") - logging.info(f"Saved label encoder to: {encoder_path}") + logging.info(f"Model and artifacts saved to {GENDER_MODELS_DIR}") def main(): diff --git a/ners/gender/models/lstm.py b/ners/gender/models/lstm.py index e3cc10d..48750ea 100644 --- a/ners/gender/models/lstm.py +++ b/ners/gender/models/lstm.py @@ -1,7 +1,6 @@ import argparse import logging import os -import pickle from dataclasses import dataclass from typing import Tuple, Optional @@ -18,7 +17,7 @@ from tensorflow.keras.models import Sequential from tensorflow.keras.preprocessing.sequence import pad_sequences from tensorflow.keras.preprocessing.text import Tokenizer -from misc import GENDER_MODELS_DIR, load_csv_dataset +from misc import GENDER_MODELS_DIR, load_csv_dataset, save_pickle logging.basicConfig(level=logging.INFO, format=">> %(message)s") @@ -214,15 +213,12 @@ def save_artifacts(model, tokenizer, encoder): :return: None """ - model_path = os.path.join(GENDER_MODELS_DIR, "lstm_model.keras") - tokenizer_path = os.path.join(GENDER_MODELS_DIR, "lstm_tokenizer.pkl") - encoder_path = os.path.join(GENDER_MODELS_DIR, "lstm_label_encoder.pkl") + os.makedirs(GENDER_MODELS_DIR, exist_ok=True) + model.save(os.path.join(GENDER_MODELS_DIR, "lstm_model.keras")) + + save_pickle(tokenizer, os.path.join(GENDER_MODELS_DIR, "lstm_tokenizer.pkl")) + save_pickle(encoder, os.path.join(GENDER_MODELS_DIR, "lstm_label_encoder.pkl")) - model.save(model_path) - with open(tokenizer_path, "wb") as f: - pickle.dump(tokenizer, f) - with open(encoder_path, "wb") as f: - pickle.dump(encoder, f) logging.info(f"Model and artifacts saved to {GENDER_MODELS_DIR}") diff --git a/ners/gender/models/transformer.py b/ners/gender/models/transformer.py index d6e94bf..558feac 100644 --- a/ners/gender/models/transformer.py +++ b/ners/gender/models/transformer.py @@ -1,7 +1,6 @@ import argparse import logging import os -import pickle from dataclasses import dataclass from typing import Tuple, Optional @@ -23,7 +22,7 @@ from tensorflow.keras.models import Model from tensorflow.keras.preprocessing.sequence import pad_sequences from tensorflow.keras.preprocessing.text import Tokenizer -from misc import GENDER_MODELS_DIR, load_csv_dataset +from misc import GENDER_MODELS_DIR, load_csv_dataset, save_pickle logging.basicConfig(level=logging.INFO, format=">> %(message)s") @@ -198,7 +197,7 @@ def evaluate_proba(y_true, y_proba, threshold, class_names): :return: None. Outputs performance metrics and confusion matrix to the logging system and the console. """ - y_pred = 1 if y_proba[:, 1] >= threshold else 0 + y_pred = (y_proba[:, 1] >= threshold).astype(int) acc = accuracy_score(y_true, y_pred) pr, rc, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary") cm = confusion_matrix(y_true, y_pred) @@ -257,16 +256,13 @@ def save_artifacts(model, tokenizer, encoder): :param encoder: The label encoder used for encoding target labels. :return: None """ - model_path = os.path.join(GENDER_MODELS_DIR, "transformer.h5") - tokenizer_path = os.path.join(GENDER_MODELS_DIR, "transformer_tokenizer.pkl") - encoder_path = os.path.join(GENDER_MODELS_DIR, "transformer_label_encoder.pkl") + os.makedirs(GENDER_MODELS_DIR, exist_ok=True) + model.save(os.path.join(GENDER_MODELS_DIR, "transformer.keras")) - model.save(model_path) - with open(tokenizer_path, "wb") as f: - pickle.dump(tokenizer, f) - with open(encoder_path, "wb") as f: - pickle.dump(encoder, f) - logging.info("Model and artifacts saved.") + save_pickle(tokenizer, os.path.join(GENDER_MODELS_DIR, "transformer_tokenizer.pkl")) + save_pickle(encoder, os.path.join(GENDER_MODELS_DIR, "transformer_label_encoder.pkl")) + + logging.info(f"Model and artifacts saved to {GENDER_MODELS_DIR}") def main():