fix: artifacts saving and dataset loading

This commit is contained in:
2025-06-24 21:49:03 +02:00
parent fb95c72ab7
commit eb139ee09a
4 changed files with 43 additions and 41 deletions
+24 -9
View File
@@ -1,8 +1,8 @@
import csv
import io
import json
import os
import pickle
from datetime import datetime
from typing import Optional
# Paths
@@ -16,6 +16,7 @@ GENDER_RESULT_DIR = os.path.join(ROOT_DIR, 'gender', 'results')
NER_MODELS_DIR = os.path.join(MODELS_DIR, 'ner')
NER_RESULT_DIR = os.path.join(ROOT_DIR, 'ner', 'results')
def clean_spacing(filename: str) -> Optional[str]:
try:
with open(os.path.join(DATA_DIR, filename), 'r', encoding='utf8') as f:
@@ -42,14 +43,27 @@ def save_csv_dataset(data: list, path: str) -> None:
def load_csv_dataset(path: str, limit: Optional[int] = None) -> list:
    """Load a CSV dataset from DATA_DIR, trying several encodings in turn.

    The file is read fully, NUL bytes are stripped (common in UTF-16
    exports), and the cleaned text is parsed with csv.DictReader.

    Args:
        path: Filename relative to DATA_DIR.
        limit: If given and > 0, stop after this many rows.

    Returns:
        A list of dict rows (one per CSV record).

    Raises:
        UnicodeDecodeError: If none of the candidate encodings works.
    """
    print(f">> Loading CSV dataset from {path}")
    data = []
    encodings = ['utf-8', 'utf-16', 'latin1']
    for enc in encodings:
        try:
            with open(os.path.join(DATA_DIR, path), "r", encoding=enc, errors="replace") as f:
                raw_text = f.read().replace('\x00', '')
            csv_buffer = io.StringIO(raw_text)
            reader = csv.DictReader(csv_buffer)
            print(f">> Detected fieldnames: {reader.fieldnames}")
            for row in reader:
                data.append(row)
                if limit and len(data) >= limit:
                    break
            print(f">> Successfully loaded with encoding: {enc}")
            return data
        except Exception as e:
            print(f">> Failed with encoding: {enc}, error: {e}")
            # Discard rows collected during a partially-successful attempt
            # so they don't leak into the next encoding's result.
            data = []
    # UnicodeDecodeError requires a bytes object argument; passing the raw
    # str path would make the raise itself fail with TypeError.
    raise UnicodeDecodeError(
        "load_csv_dataset",
        path.encode("utf-8", "replace"),
        0,
        1,
        "Unable to decode file with common encodings.",
    )
def save_json_dataset(data: list, path: str) -> None:
def save_pickle(obj, path):
    """Serialize *obj* to *path* with pickle (binary mode).

    Args:
        obj: Any picklable Python object.
        path: Destination file path; overwritten if it exists.
    """
    with open(path, "wb") as f:
        pickle.dump(obj, f)
def load_pickle(path: str):
    """Deserialize and return the object stored in the pickle file at *path*.

    NOTE(review): pickle.load executes arbitrary code on malicious input —
    only use on trusted, locally-produced artifacts.
    """
    # The original had a duplicated, unreachable `return pickle.load(f)`
    # line (diff artifact); a single return inside the `with` suffices.
    with open(path, "rb") as f:
        return pickle.load(f)