fix: artifacts saving and dataset loading
This commit is contained in:
+24
-9
@@ -1,8 +1,8 @@
|
|||||||
import csv
|
import csv
|
||||||
|
import io
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from datetime import datetime
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
# Paths
|
# Paths
|
||||||
@@ -16,6 +16,7 @@ GENDER_RESULT_DIR = os.path.join(ROOT_DIR, 'gender', 'results')
|
|||||||
NER_MODELS_DIR = os.path.join(MODELS_DIR, 'ner')
|
NER_MODELS_DIR = os.path.join(MODELS_DIR, 'ner')
|
||||||
NER_RESULT_DIR = os.path.join(ROOT_DIR, 'ner', 'results')
|
NER_RESULT_DIR = os.path.join(ROOT_DIR, 'ner', 'results')
|
||||||
|
|
||||||
|
|
||||||
def clean_spacing(filename: str) -> Optional[str]:
|
def clean_spacing(filename: str) -> Optional[str]:
|
||||||
try:
|
try:
|
||||||
with open(os.path.join(DATA_DIR, filename), 'r', encoding='utf8') as f:
|
with open(os.path.join(DATA_DIR, filename), 'r', encoding='utf8') as f:
|
||||||
@@ -42,14 +43,27 @@ def save_csv_dataset(data: list, path: str) -> None:
|
|||||||
def load_csv_dataset(path: str, limit: int = None) -> list:
|
def load_csv_dataset(path: str, limit: int = None) -> list:
|
||||||
print(f">> Loading CSV dataset from {path}")
|
print(f">> Loading CSV dataset from {path}")
|
||||||
data = []
|
data = []
|
||||||
with open(os.path.join(DATA_DIR, path), "r", encoding="utf-8") as f:
|
encodings = ['utf-8', 'utf-16', 'latin1']
|
||||||
reader = csv.DictReader(f)
|
|
||||||
for row in reader:
|
|
||||||
data.append(row)
|
|
||||||
if limit and len(data) >= limit:
|
|
||||||
break
|
|
||||||
|
|
||||||
return data
|
for enc in encodings:
|
||||||
|
try:
|
||||||
|
with open(os.path.join(DATA_DIR, path), "r", encoding=enc, errors="replace") as f:
|
||||||
|
raw_text = f.read().replace('\x00', '')
|
||||||
|
|
||||||
|
csv_buffer = io.StringIO(raw_text)
|
||||||
|
reader = csv.DictReader(csv_buffer)
|
||||||
|
print(f">> Detected fieldnames: {reader.fieldnames}")
|
||||||
|
|
||||||
|
for row in reader:
|
||||||
|
data.append(row)
|
||||||
|
if limit and len(data) >= limit:
|
||||||
|
break
|
||||||
|
print(f">> Successfully loaded with encoding: {enc}")
|
||||||
|
return data
|
||||||
|
except Exception as e:
|
||||||
|
print(f">> Failed with encoding: {enc}, error: {e}")
|
||||||
|
|
||||||
|
raise UnicodeDecodeError("load_csv_dataset", path, 0, 0, "Unable to decode file with common encodings.")
|
||||||
|
|
||||||
|
|
||||||
def save_json_dataset(data: list, path: str) -> None:
|
def save_json_dataset(data: list, path: str) -> None:
|
||||||
@@ -63,6 +77,7 @@ def save_pickle(obj, path):
|
|||||||
with open(path, "wb") as f:
|
with open(path, "wb") as f:
|
||||||
pickle.dump(obj, f)
|
pickle.dump(obj, f)
|
||||||
|
|
||||||
|
|
||||||
def load_pickle(path: str):
|
def load_pickle(path: str):
|
||||||
with open(path, "rb") as f:
|
with open(path, "rb") as f:
|
||||||
return pickle.load(f)
|
return pickle.load(f)
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import pickle
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Tuple, Optional
|
from typing import Tuple, Optional
|
||||||
|
|
||||||
@@ -16,10 +15,11 @@ from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val
|
|||||||
from sklearn.pipeline import make_pipeline, Pipeline
|
from sklearn.pipeline import make_pipeline, Pipeline
|
||||||
from sklearn.preprocessing import LabelEncoder
|
from sklearn.preprocessing import LabelEncoder
|
||||||
|
|
||||||
from misc import GENDER_MODELS_DIR, load_csv_dataset
|
from misc import GENDER_MODELS_DIR, load_csv_dataset, save_pickle
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format=">> %(message)s")
|
logging.basicConfig(level=logging.INFO, format=">> %(message)s")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Config:
|
class Config:
|
||||||
dataset_path: str
|
dataset_path: str
|
||||||
@@ -169,15 +169,10 @@ def save_artifacts(model, encoder, cfg: Config):
|
|||||||
:type cfg: Config
|
:type cfg: Config
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
model_path = os.path.join(GENDER_MODELS_DIR, "regression_model.pkl")
|
save_pickle(model, os.path.join(GENDER_MODELS_DIR, "regression_model.pkl"))
|
||||||
encoder_path = os.path.join(GENDER_MODELS_DIR, "regression_label_encoder.pkl")
|
save_pickle(encoder, os.path.join(GENDER_MODELS_DIR, "regression_label_encoder.pkl"))
|
||||||
|
|
||||||
with open(model_path, "wb") as f:
|
logging.info(f"Model and artifacts saved to {GENDER_MODELS_DIR}")
|
||||||
pickle.dump(model, f)
|
|
||||||
with open(encoder_path, "wb") as f:
|
|
||||||
pickle.dump(encoder, f)
|
|
||||||
logging.info(f"Saved model to: {model_path}")
|
|
||||||
logging.info(f"Saved label encoder to: {encoder_path}")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import pickle
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Tuple, Optional
|
from typing import Tuple, Optional
|
||||||
|
|
||||||
@@ -18,7 +17,7 @@ from tensorflow.keras.models import Sequential
|
|||||||
from tensorflow.keras.preprocessing.sequence import pad_sequences
|
from tensorflow.keras.preprocessing.sequence import pad_sequences
|
||||||
from tensorflow.keras.preprocessing.text import Tokenizer
|
from tensorflow.keras.preprocessing.text import Tokenizer
|
||||||
|
|
||||||
from misc import GENDER_MODELS_DIR, load_csv_dataset
|
from misc import GENDER_MODELS_DIR, load_csv_dataset, save_pickle
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format=">> %(message)s")
|
logging.basicConfig(level=logging.INFO, format=">> %(message)s")
|
||||||
|
|
||||||
@@ -214,15 +213,12 @@ def save_artifacts(model, tokenizer, encoder):
|
|||||||
|
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
model_path = os.path.join(GENDER_MODELS_DIR, "lstm_model.keras")
|
os.makedirs(GENDER_MODELS_DIR, exist_ok=True)
|
||||||
tokenizer_path = os.path.join(GENDER_MODELS_DIR, "lstm_tokenizer.pkl")
|
model.save(os.path.join(GENDER_MODELS_DIR, "lstm_model.keras"))
|
||||||
encoder_path = os.path.join(GENDER_MODELS_DIR, "lstm_label_encoder.pkl")
|
|
||||||
|
save_pickle(tokenizer, os.path.join(GENDER_MODELS_DIR, "lstm_tokenizer.pkl"))
|
||||||
|
save_pickle(encoder, os.path.join(GENDER_MODELS_DIR, "lstm_label_encoder.pkl"))
|
||||||
|
|
||||||
model.save(model_path)
|
|
||||||
with open(tokenizer_path, "wb") as f:
|
|
||||||
pickle.dump(tokenizer, f)
|
|
||||||
with open(encoder_path, "wb") as f:
|
|
||||||
pickle.dump(encoder, f)
|
|
||||||
logging.info(f"Model and artifacts saved to {GENDER_MODELS_DIR}")
|
logging.info(f"Model and artifacts saved to {GENDER_MODELS_DIR}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import pickle
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Tuple, Optional
|
from typing import Tuple, Optional
|
||||||
|
|
||||||
@@ -23,7 +22,7 @@ from tensorflow.keras.models import Model
|
|||||||
from tensorflow.keras.preprocessing.sequence import pad_sequences
|
from tensorflow.keras.preprocessing.sequence import pad_sequences
|
||||||
from tensorflow.keras.preprocessing.text import Tokenizer
|
from tensorflow.keras.preprocessing.text import Tokenizer
|
||||||
|
|
||||||
from misc import GENDER_MODELS_DIR, load_csv_dataset
|
from misc import GENDER_MODELS_DIR, load_csv_dataset, save_pickle
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format=">> %(message)s")
|
logging.basicConfig(level=logging.INFO, format=">> %(message)s")
|
||||||
|
|
||||||
@@ -198,7 +197,7 @@ def evaluate_proba(y_true, y_proba, threshold, class_names):
|
|||||||
:return: None. Outputs performance metrics and confusion matrix to the logging
|
:return: None. Outputs performance metrics and confusion matrix to the logging
|
||||||
system and the console.
|
system and the console.
|
||||||
"""
|
"""
|
||||||
y_pred = 1 if y_proba[:, 1] >= threshold else 0
|
y_pred = (y_proba[:, 1] >= threshold).astype(int)
|
||||||
acc = accuracy_score(y_true, y_pred)
|
acc = accuracy_score(y_true, y_pred)
|
||||||
pr, rc, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary")
|
pr, rc, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary")
|
||||||
cm = confusion_matrix(y_true, y_pred)
|
cm = confusion_matrix(y_true, y_pred)
|
||||||
@@ -257,16 +256,13 @@ def save_artifacts(model, tokenizer, encoder):
|
|||||||
:param encoder: The label encoder used for encoding target labels.
|
:param encoder: The label encoder used for encoding target labels.
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
model_path = os.path.join(GENDER_MODELS_DIR, "transformer.h5")
|
os.makedirs(GENDER_MODELS_DIR, exist_ok=True)
|
||||||
tokenizer_path = os.path.join(GENDER_MODELS_DIR, "transformer_tokenizer.pkl")
|
model.save(os.path.join(GENDER_MODELS_DIR, "transformer.keras"))
|
||||||
encoder_path = os.path.join(GENDER_MODELS_DIR, "transformer_label_encoder.pkl")
|
|
||||||
|
|
||||||
model.save(model_path)
|
save_pickle(tokenizer, os.path.join(GENDER_MODELS_DIR, "transformer_tokenizer.pkl"))
|
||||||
with open(tokenizer_path, "wb") as f:
|
save_pickle(encoder, os.path.join(GENDER_MODELS_DIR, "transformer_label_encoder.pkl"))
|
||||||
pickle.dump(tokenizer, f)
|
|
||||||
with open(encoder_path, "wb") as f:
|
logging.info(f"Model and artifacts saved to {GENDER_MODELS_DIR}")
|
||||||
pickle.dump(encoder, f)
|
|
||||||
logging.info("Model and artifacts saved.")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|||||||
Reference in New Issue
Block a user