feat: balanced dataset loading
This commit is contained in:
+25
-28
@@ -4,6 +4,7 @@ import json
|
||||
import os
|
||||
import pickle
|
||||
from typing import Optional
|
||||
from typing import List, Dict
|
||||
|
||||
# Paths
|
||||
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
@@ -17,15 +18,6 @@ NER_MODELS_DIR = os.path.join(MODELS_DIR, 'ner')
|
||||
NER_RESULT_DIR = os.path.join(ROOT_DIR, 'ner', 'results')
|
||||
|
||||
|
||||
def clean_spacing(filename: str) -> Optional[str]:
    """Read a text file from ``DATA_DIR`` and replace NUL characters with spaces.

    Args:
        filename: Path of the file, relative to ``DATA_DIR``.

    Returns:
        The file content with every NUL character replaced by a space, or
        ``None`` when the file cannot be opened or decoded as UTF-8.
    """
    # NOTE(review): the second entry maps a plain space to a plain space,
    # which is a no-op; presumably the key was meant to be a non-breaking
    # space (U+00A0) that was lost somewhere -- confirm before changing it.
    table = str.maketrans({'\x00': ' ', ' ': ' '})
    try:
        with open(os.path.join(DATA_DIR, filename), 'r', encoding='utf8') as f:
            content = f.read()
    except (OSError, ValueError):
        # Best-effort loader: a missing/unreadable file (OSError) or invalid
        # UTF-8 / invalid path (ValueError, which includes UnicodeError) is
        # reported as None. Programming errors are no longer swallowed.
        return None
    return content.translate(table)
|
||||
|
||||
|
||||
def load_json_dataset(path: str) -> list:
|
||||
print(f">> Loading JSON dataset from {path}")
|
||||
with open(os.path.join(DATA_DIR, path), "r", encoding="utf-8") as f:
|
||||
@@ -40,30 +32,35 @@ def save_csv_dataset(data: list, path: str) -> None:
|
||||
writer.writerows(data)
|
||||
|
||||
|
||||
def load_csv_dataset(path: str, limit: Optional[int] = None, balanced: bool = False) -> List[Dict[str, str]]:
    """Load rows from a CSV file under ``DATA_DIR`` as dictionaries.

    The file is decoded as UTF-8 with undecodable bytes replaced, and NUL
    characters are stripped before parsing because the ``csv`` module cannot
    handle them reliably.

    Args:
        path: Path of the CSV file, relative to ``DATA_DIR``.
        limit: Maximum number of rows to return; ``None`` or ``0`` means no
            limit. With ``balanced=True`` at most ``limit // 2`` rows per
            group are returned, so an odd limit yields ``limit - 1`` rows.
        balanced: When true, keep only rows whose ``sex`` column is ``m`` or
            ``f`` (case-insensitive) and return the same number of each,
            all ``m`` rows first, followed by the ``f`` rows.

    Returns:
        A list of row dictionaries keyed by the CSV header.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    print(f">> Loading CSV dataset from {path}")

    file_path = os.path.join(DATA_DIR, path)
    # newline="" lets the csv module handle line endings inside quoted fields.
    with open(file_path, "r", encoding="utf-8", errors="replace", newline="") as f:
        raw_text = f.read().replace('\x00', '')

    reader = csv.DictReader(io.StringIO(raw_text))
    print(f">> Detected fieldnames: {reader.fieldnames}")

    if balanced:
        by_sex: Dict[str, List[Dict[str, str]]] = {'m': [], 'f': []}
        for row in reader:
            # DictReader fills missing trailing columns with None, so guard
            # before lowercasing instead of relying on the .get() default.
            sex = (row.get("sex") or "").lower()
            if sex in by_sex:
                by_sex[sex].append(row)

        # Truncate both groups to the size of the smaller one.
        min_len = min(len(by_sex['m']), len(by_sex['f']))
        if limit:
            min_len = min(min_len, limit // 2)
        data = by_sex['m'][:min_len] + by_sex['f'][:min_len]
    else:
        data = []
        for i, row in enumerate(reader):
            data.append(row)
            if limit and i + 1 >= limit:
                break

    print(">> Successfully loaded with UTF-8 encoding")
    return data
|
||||
|
||||
|
||||
def save_json_dataset(data: list, path: str) -> None:
|
||||
|
||||
Reference in New Issue
Block a user