Files
drc-ners-nlp/processing/annotate.py
T
2025-07-25 10:42:02 +02:00

110 lines
3.5 KiB
Python

import os
import argparse
import ollama
import pandas as pd
from pydantic import BaseModel, ValidationError
from tqdm import tqdm
from typing import Optional
from misc import load_prompt, load_csv_dataset, DATA_DIR, logging
# Structured result requested from the LLM for a single input name.
# The JSON schema generated from this model is passed to the Ollama chat call
# as the response format, and the raw reply is validated against this model.
# NOTE: documented with comments rather than a docstring so the generated
# JSON schema stays exactly as it was.
class NameAnalysis(BaseModel):
    # Both fields are declared Optional[str] with no default value.
    identified_name: Optional[str]
    identified_surname: Optional[str]
def analyze_name(client: ollama.Client, model: str, prompt: str, name: str) -> dict:
    """
    Ask the given Ollama model to split ``name`` into its components.

    The prompt is sent as the system message and the name as the user
    message; the reply is constrained to, and validated against, the
    ``NameAnalysis`` schema.

    Returns a dict with the keys ``identified_name`` and
    ``identified_surname``. If the reply fails validation or any other
    error occurs, the problem is logged and both values are ``None``.
    """
    fallback = {"identified_name": None, "identified_surname": None}
    conversation = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": name},
    ]

    try:
        reply = client.chat(
            model=model,
            messages=conversation,
            format=NameAnalysis.model_json_schema(),
        )
        parsed = NameAnalysis.model_validate_json(reply.message.content)
        return parsed.model_dump()
    except ValidationError as ve:
        # The model answered, but not in the expected shape.
        logging.warning(f"Validation error: {ve}")
    except Exception as e:
        # Anything else (transport, server, unexpected payload).
        logging.error(f"Unexpected error: {e}")

    return fallback
def save_checkpoint(df: pd.DataFrame):
    """Write ``df`` to ``names_featured.csv`` under ``DATA_DIR`` (index omitted) and log it."""
    checkpoint_path = os.path.join(DATA_DIR, "names_featured.csv")
    df.to_csv(checkpoint_path, index=False)
    # Emitted at CRITICAL level, matching the existing behavior.
    logging.critical("Checkpoint saved")
def _flush_updates(df: pd.DataFrame, updates: list) -> None:
    """Apply pending ``(row_index, entry)`` pairs to ``df`` in place, checkpoint, and clear them."""
    if not updates:
        # Every row of this batch failed: an empty frame would have no
        # "annotated" column to cast, so there is nothing to apply or save.
        return

    update_df = pd.DataFrame.from_dict(dict(updates), orient="index")
    update_df["annotated"] = (
        pd.to_numeric(update_df["annotated"], errors="coerce").fillna(0).astype("Int8")
    )
    # DataFrame.update aligns on index/column labels and only overwrites
    # cells with non-NA values from update_df.
    df.update(update_df)
    save_checkpoint(df)
    updates.clear()  # avoid re-applying same updates


def build_updates(
    llm_model: str,
    df: pd.DataFrame,
    entries: pd.DataFrame,
    batch_size: int = 10,
) -> pd.DataFrame:
    """
    Annotate every row of ``entries`` with the LLM and merge the results into ``df``.

    Each row's ``"name"`` value is sent to ``analyze_name``; the returned
    fields plus ``annotated = 1`` are queued under the row's index. Queued
    results are applied to ``df`` (in place) and checkpointed to disk after
    every ``batch_size`` rows and after the final row.

    Args:
        llm_model: Ollama model name passed to ``analyze_name``.
        df: Full dataset; updated in place and also returned.
        entries: Rows still to be analyzed. Their index is used to address
            rows in ``df``.
        batch_size: Number of rows processed between checkpoints.

    Returns:
        ``df`` with the collected annotations applied.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    client = ollama.Client()
    prompt = load_prompt()
    pending = []
    total = len(entries)

    # Set logging level for HTTP client to reduce noise
    # This is useful to avoid excessive logging from the HTTP client used by Ollama
    logging.getLogger("httpx").setLevel(logging.WARNING)

    for idx, (row_idx, row) in enumerate(entries.iterrows(), 1):
        name = row["name"]
        try:
            entry = analyze_name(client, llm_model, prompt, name)
            # NOTE(review): analyze_name returns a dict of None values when the
            # model call or validation fails, so such rows are still marked
            # annotated here - confirm this is intended.
            entry["annotated"] = 1
            pending.append((row_idx, entry))
            logging.info(f"Analyzed: {name} - {entry}")
        except Exception as e:
            logging.warning(f"Failed to analyze '{name}': {e}")

        # The flush check runs even when the current row failed. Previously a
        # failure skipped it, so a failure on the last row dropped all
        # results still queued at that point.
        if idx % batch_size == 0 or idx == total:
            _flush_updates(df, pending)

    return df
def main(llm_model: str = "llama3.2:3b"):
    """
    Annotate all not-yet-annotated names in ``names_featured.csv``.

    Loads the dataset from ``DATA_DIR``, normalizes the ``annotated`` column
    to ``Int8``, runs ``build_updates`` over the rows whose flag is 0, and
    saves the result. Returns early when no such rows exist.
    """
    dataset_path = os.path.join(DATA_DIR, "names_featured.csv")
    df = pd.DataFrame(load_csv_dataset(dataset_path))

    # Coerce to numbers first so float-like strings (e.g. '1.0') survive the
    # cast; unparsable or missing values count as "not annotated" (0).
    flags = pd.to_numeric(df["annotated"], errors="coerce")
    df["annotated"] = flags.fillna(0).astype(float).astype("Int8")

    pending = df[df["annotated"] == 0]
    if pending.empty:
        logging.info("No names to analyze.")
        return

    logging.info(f"Found {len(pending)} names to analyze.")
    df = build_updates(llm_model, df, pending)
    save_checkpoint(df)
    logging.info("Analysis complete.")
if __name__ == "__main__":
    # Command-line entry point: parse the model name and run the annotation.
    cli = argparse.ArgumentParser(description="Analyze names using an LLM model.")
    # NOTE(review): the CLI default (mistral:7b) differs from the default of
    # main() (llama3.2:3b) - confirm which one is intended.
    cli.add_argument(
        "--llm_model",
        type=str,
        default="mistral:7b",
        help="Ollama model name to use (default: mistral:7b)",
    )
    options = cli.parse_args()

    try:
        main(llm_model=options.llm_model)
    except Exception as exc:
        # Top-level boundary: log with traceback instead of crashing.
        logging.error(f"Fatal error: {exc}", exc_info=True)