refactor: include province and annotation pipeline
@@ -0,0 +1,86 @@
import os
import argparse

import ollama
import pandas as pd
from pydantic import BaseModel, ValidationError
from tqdm import tqdm
from typing import Optional

from misc import load_prompt, load_csv_dataset, DATA_DIR, logging


class NameAnalysis(BaseModel):
    identified_name: Optional[str]
    identified_surname: Optional[str]


def analyze_name(client: ollama.Client, model: str, prompt: str, name: str) -> dict:
    """
    Analyze a name using the specified model and prompt.
    Returns a dictionary with the identified name and surname.
    """
    try:
        response = client.chat(
            model=model,
            messages=[
                {"role": "system", "content": prompt},
                {"role": "user", "content": name}
            ],
            format=NameAnalysis.model_json_schema()
        )
        analysis = NameAnalysis.model_validate_json(response.message.content)
        return analysis.model_dump()
    except ValidationError as ve:
        logging.warning(f"Validation error: {ve}")
    except Exception as e:
        logging.error(f"Unexpected error: {e}")
    return {
        "identified_name": None,
        "identified_surname": None
    }


def build_updates(client: ollama.Client, prompt: str, llm_model: str, rows: pd.DataFrame) -> pd.DataFrame:
    """
    Build updates for the DataFrame by analyzing names.
    Iterates through the DataFrame rows, analyzes each name, and returns a DataFrame with updates.
    """
    logging.getLogger("httpx").setLevel(logging.WARNING)
    updates = []

    # Wrap the iterator in tqdm so the import is used and progress is visible
    for idx, row in tqdm(rows.iterrows(), total=len(rows)):
        entry = analyze_name(client, llm_model, prompt, row['name'])
        entry["annotated"] = 1
        updates.append((idx, entry))
        logging.info(f"Analyzed name: {row['name']} - {entry}")

    return pd.DataFrame.from_dict(dict(updates), orient='index')


def main(llm_model: str = "llama3.2:3b"):
    df = pd.DataFrame(load_csv_dataset('names_featured.csv'))
    prompt = load_prompt()

    entries = df[df['annotated'].astype("Int8") == 0]
    if entries.empty:
        logging.info("No names to analyze.")
        return

    logging.info(f"Found {len(entries)} names to analyze.")
    client = ollama.Client()

    df.update(build_updates(client, prompt, llm_model, entries))
    df.to_csv(os.path.join(DATA_DIR, 'names_featured.csv'), index=False)
    logging.info("Done.")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Analyze names using an LLM model.")
    parser.add_argument('--llm_model', type=str, default="llama3.2:3b", help="Ollama model name to use (default: llama3.2:3b)")
    args = parser.parse_args()

    try:
        main(llm_model=args.llm_model)
    except Exception as e:
        logging.error(f"Fatal error: {e}", exc_info=True)
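
Note: the script above relies on Ollama's structured outputs, constraining generation with a Pydantic JSON schema and validating the reply back into the model. A minimal standalone sketch of that round trip, assuming a local Ollama server with llama3.2:3b pulled; the system prompt and sample name below are illustrative placeholders, since the real prompt comes from load_prompt():

import ollama
from typing import Optional
from pydantic import BaseModel, ValidationError


class NameAnalysis(BaseModel):
    identified_name: Optional[str]
    identified_surname: Optional[str]


client = ollama.Client()
response = client.chat(
    model="llama3.2:3b",  # assumes this model has been pulled locally
    messages=[
        # Placeholder prompt and name, for illustration only
        {"role": "system", "content": "Split the full name into first name and surname."},
        {"role": "user", "content": "jean mukendi"},
    ],
    # Constrains the model to emit JSON matching the schema
    format=NameAnalysis.model_json_schema(),
)

try:
    # Validate the raw JSON string against the schema
    analysis = NameAnalysis.model_validate_json(response.message.content)
    print(analysis.model_dump())
except ValidationError as ve:
    print(f"Model returned non-conforming JSON: {ve}")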
@@ -1,72 +0,0 @@
import os

import ollama
import pandas as pd
from pydantic import BaseModel, ValidationError
from tqdm import tqdm

from misc import load_prompt, load_csv_dataset, DATA_DIR


class NameAnalysis(BaseModel):
    identified_name: str | None
    identified_surname: str | None
    identified_category: str | None


def main():
    dataset = pd.DataFrame(load_csv_dataset('names_featured.csv'))
    prompt = load_prompt()

    print(">> Filtering dataset for names that need analysis...")
    to_analyze = dataset[dataset['llm_annotated'] == 0].copy()
    if to_analyze.empty:
        print(">> No names to analyze.")
        return

    client = ollama.Client()
    updates = []

    print(">> Starting name analysis with LLM...")
    for row in tqdm(to_analyze.itertuples(index=True), total=len(to_analyze)):
        name = row.name
        try:
            response = client.chat(
                model="llama3.2:3b",
                messages=[
                    {"role": "system", "content": prompt},
                    {"role": "user", "content": name}
                ],
                format=NameAnalysis.model_json_schema()
            )
            analysis = NameAnalysis.model_validate_json(response.message.content)
            result = analysis.model_dump()
        except (ValidationError, Exception):
            result = {
                "identified_name": None,
                "identified_surname": None,
                "identified_category": None
            }

        updates.append({
            "index": row.Index,
            "identified_name": result["identified_name"],
            "identified_surname": result["identified_surname"],
            "identified_category": result["identified_category"],
            "llm_annotated": 1
        })

    print(">> Updating dataset with results...")
    updates_df = pd.DataFrame(updates).set_index("index")
    dataset.update(updates_df)

    print(">> Saving updated dataset...")
    dataset.to_csv(os.path.join(DATA_DIR, 'names_featured.csv'), index=False)
    print(">> Done.")


if __name__ == '__main__':
    try:
        main()
    except Exception as e:
        print(f">> Fatal error: {e}")
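
Note: both the removed script above and its replacement write results back via index-aligned DataFrame.update. A small self-contained sketch of that behavior, with made-up rows; NaN/None values in the incoming frame leave the original cells untouched:

import pandas as pd

df = pd.DataFrame({
    "name": ["jean mukendi", "marie kalala"],  # illustrative rows
    "identified_name": [None, None],
    "llm_annotated": [0, 0],
})

# Updates carry the original row index, so alignment is by label, not position
updates = pd.DataFrame(
    {"identified_name": ["jean", None], "llm_annotated": [1, 1]},
    index=[0, 1],
)

df.update(updates)
print(df)
# Row 0 gets identified_name="jean"; row 1 keeps None, because
# DataFrame.update skips NaN values in the incoming frame by default.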
@@ -1,78 +0,0 @@
import os
import pandas as pd
from misc import DATA_DIR


def clean(filepath):
    encodings = ['utf-8', 'utf-16', 'latin1']
    for enc in encodings:
        try:
            print(f">> Trying to read {filepath} with encoding: {enc}")
            # Use chunked reading to handle large files
            chunks = pd.read_csv(filepath, encoding=enc, chunksize=100_000, on_bad_lines='skip')
            cleaned_chunks = []

            for chunk in chunks:
                # Drop rows with essential missing values early
                chunk = chunk.dropna(subset=['name', 'sex', 'region'])

                # Clean string columns in-place
                for col in chunk.select_dtypes(include='object').columns:
                    chunk[col] = (
                        chunk[col]
                        .astype(str)
                        .str.replace('\x00', ' ', regex=False)
                        .str.replace('\u00a0', ' ', regex=False)
                        .str.replace(' +', ' ', regex=True)
                    )

                cleaned_chunks.append(chunk)

            df = pd.concat(cleaned_chunks, ignore_index=True)
            df.to_csv(filepath, index=False, encoding='utf-8')
            print(f">> Successfully read with encoding: {enc}")
            return df
        except Exception:
            continue
    raise UnicodeDecodeError(f"Unable to decode {filepath} with common encodings.")


def process(df: pd.DataFrame):
    print(">> Preprocessing names")
    df['name'] = df['name'].str.strip().str.lower()

    df['words'] = df['name'].str.count(' ') + 1
    df['length'] = df['name'].str.replace(' ', '', regex=False).str.len()

    name_split = df['name'].str.split()
    df['probable_native'] = name_split.apply(lambda x: ' '.join(x[:-1]) if len(x) > 1 else '')
    df['probable_surname'] = name_split.apply(lambda x: x[-1] if x else '')
    df['llm_annotated'] = 0

    return df


def split_and_save(df: pd.DataFrame):
    print(">> Saving evaluation and featured datasets")
    eval_idx = df.sample(frac=0.2, random_state=42).index

    df_evaluation = df.loc[eval_idx]
    df_featured = df.drop(index=eval_idx)

    df_evaluation.to_csv(os.path.join(DATA_DIR, 'names_evaluation.csv'), index=False)
    df_featured.to_csv(os.path.join(DATA_DIR, 'names_featured.csv'), index=False)

    print(">> Saving by sex")
    df[df['sex'].str.lower() == 'm'].to_csv(os.path.join(DATA_DIR, 'names_males.csv'), index=False)
    df[df['sex'].str.lower() == 'f'].to_csv(os.path.join(DATA_DIR, 'names_females.csv'), index=False)


def main():
    filepath = os.path.join(DATA_DIR, 'names.csv')
    df = clean(filepath)
    df = process(df)
    split_and_save(df)


if __name__ == '__main__':
    main()
@@ -0,0 +1,110 @@
import os
import argparse
import pandas as pd
from misc import DATA_DIR, REGION_MAPPING, logging


def clean(filepath) -> pd.DataFrame:
    """
    Clean the CSV file by removing null bytes, non-breaking spaces, and extra spaces.
    It also attempts to read the file with different encodings to handle potential encoding issues.
    """

    encodings = ['utf-8', 'utf-16', 'latin1']
    for enc in encodings:
        try:
            logging.info(f"Trying to read {filepath} with encoding: {enc}")
            # Use chunked reading to handle large files
            chunks = pd.read_csv(filepath, encoding=enc, chunksize=100_000, on_bad_lines='skip')
            cleaned_chunks = []

            for chunk in chunks:
                # Drop rows with essential missing values early
                chunk = chunk.dropna(subset=['name', 'sex', 'region'])

                # Clean string columns in-place
                for col in chunk.select_dtypes(include='object').columns:
                    chunk[col] = (
                        chunk[col]
                        .astype(str)
                        .str.replace('\x00', ' ', regex=False)
                        .str.replace('\u00a0', ' ', regex=False)
                        .str.replace(' +', ' ', regex=True)
                        .str.strip()
                        .str.lower()
                    )

                cleaned_chunks.append(chunk)

            df = pd.concat(cleaned_chunks, ignore_index=True)
            df.to_csv(filepath, index=False, encoding='utf-8')
            logging.info(f"Successfully read with encoding: {enc}")
            return df
        except Exception:
            continue
    # UnicodeDecodeError requires five positional arguments, so raise ValueError here
    raise ValueError(f"Unable to decode {filepath} with common encodings.")


def process(df: pd.DataFrame) -> pd.DataFrame:
    """
    Process the DataFrame to extract features and clean data.
    This includes counting words, calculating name length, and extracting probable native names and surnames.
    It also maps regions to provinces based on REGION_MAPPING.
    """

    logging.info("Preprocessing names")
    df['words'] = df['name'].str.count(' ') + 1
    df['length'] = df['name'].str.replace(' ', '', regex=False).str.len()

    name_split = df['name'].str.split()
    df['probable_native'] = name_split.apply(lambda x: ' '.join(x[:-1]) if len(x) > 1 else '')
    df['probable_surname'] = name_split.apply(lambda x: x[-1] if x else '')
    df['identified_category'] = df['words'].apply(lambda x: 'compose' if x > 3 else 'simple')
    df['identified_name'] = None
    df['identified_surname'] = None

    logging.info("Mapping regions to provinces")
    df['province'] = df['region'].map(lambda r: REGION_MAPPING.get(r, ('AUTRES', 'AUTRES'))[1])
    df['province'] = df['province'].str.lower()
    df['annotated'] = 0

    return df


def save_artifacts(df: pd.DataFrame, split_eval: bool = True, split_by_sex: bool = True) -> None:
    """
    Splits the input DataFrame into evaluation and featured datasets, saves them as CSV files,
    and additionally saves separate CSV files for male and female entries if requested.
    """

    if split_eval:
        logging.info("Saving evaluation and featured datasets")
        eval_idx = df.sample(frac=0.2, random_state=42).index
        df_evaluation = df.loc[eval_idx]
        df_featured = df.drop(index=eval_idx)
        df_evaluation.to_csv(os.path.join(DATA_DIR, 'names_evaluation.csv'), index=False)
        df_featured.to_csv(os.path.join(DATA_DIR, 'names_featured.csv'), index=False)
    else:
        df.to_csv(os.path.join(DATA_DIR, 'names_featured.csv'), index=False)

    if split_by_sex:
        logging.info("Saving by sex")
        df[df['sex'] == 'm'].to_csv(os.path.join(DATA_DIR, 'names_males.csv'), index=False)
        df[df['sex'] == 'f'].to_csv(os.path.join(DATA_DIR, 'names_females.csv'), index=False)


def main(split_eval: bool = True, split_by_sex: bool = True):
    df = process(clean(os.path.join(DATA_DIR, 'names.csv')))
    save_artifacts(df, split_eval=split_eval, split_by_sex=split_by_sex)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Prepare name datasets with optional splits.")

    parser.add_argument('--split_eval', action='store_true', default=True, help="Split into evaluation and featured datasets (default: True)")
    parser.add_argument('--no-split_eval', action='store_false', dest='split_eval', help="Do not split into evaluation and featured datasets")
    parser.add_argument('--split_by_sex', action='store_true', default=True, help="Split by sex into male/female datasets (default: True)")
    parser.add_argument('--no-split_by_sex', action='store_false', dest='split_by_sex', help="Do not split by sex into male/female datasets")

    args = parser.parse_args()
    main(split_eval=args.split_eval, split_by_sex=args.split_by_sex)
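
Note: process() assumes REGION_MAPPING maps a region string to a tuple whose second element is the province, with ('AUTRES', 'AUTRES') as the fallback for unknown regions. The real mapping lives in misc; the entries below are hypothetical, shown only to illustrate the expected shape:

# Hypothetical entries; the real mapping is defined in misc
REGION_MAPPING = {
    "kinshasa": ("KIN", "KINSHASA"),
    "lubumbashi": ("HK", "HAUT-KATANGA"),
}

def to_province(region: str) -> str:
    # Unknown regions fall back to the ('AUTRES', 'AUTRES') sentinel,
    # mirroring the lookup in process()
    return REGION_MAPPING.get(region, ("AUTRES", "AUTRES"))[1].lower()

print(to_province("kinshasa"))   # kinshasa
print(to_province("somewhere"))  # autres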
@@ -7,7 +7,6 @@ from misc import load_prompt
class NameAnalysis(BaseModel):
    identified_name: str | None
    identified_surname: str | None
    identified_category: str | None


name = input("Enter name: ")