87 lines
2.7 KiB
Python
87 lines
2.7 KiB
Python
import os
|
|
import argparse
|
|
|
|
import ollama
|
|
import pandas as pd
|
|
from pydantic import BaseModel, ValidationError
|
|
from tqdm import tqdm
|
|
from typing import Optional
|
|
|
|
from misc import load_prompt, load_csv_dataset, DATA_DIR, logging
|
|
|
|
|
|
class NameAnalysis(BaseModel):
|
|
identified_name: Optional[str]
|
|
identified_surname: Optional[str]
|
|
|
|
|
|
def analyze_name(client: ollama.Client, model: str, prompt: str, name: str) -> dict:
|
|
"""
|
|
Analyze a name using the specified model and prompt.
|
|
Returns a dictionary with identified name, surname, and category.
|
|
"""
|
|
try:
|
|
response = client.chat(
|
|
model=model,
|
|
messages=[
|
|
{"role": "system", "content": prompt},
|
|
{"role": "user", "content": name}
|
|
],
|
|
format=NameAnalysis.model_json_schema()
|
|
)
|
|
analysis = NameAnalysis.model_validate_json(response.message.content)
|
|
return analysis.model_dump()
|
|
except ValidationError as ve:
|
|
logging.warning(f"Validation error: {ve}")
|
|
except Exception as e:
|
|
logging.error(f"Unexpected error: {e}")
|
|
return {
|
|
"identified_name": None,
|
|
"identified_surname": None
|
|
}
|
|
|
|
|
|
def build_updates(client: ollama.Client, prompt: str, llm_model: str, rows: pd.DataFrame) -> pd.DataFrame:
|
|
"""
|
|
Build updates for the DataFrame by analyzing names.
|
|
Iterates through the DataFrame rows, analyzes each name, and returns a DataFrame with updates.
|
|
"""
|
|
logging.getLogger("httpx").setLevel(logging.WARNING)
|
|
updates = []
|
|
|
|
for idx, row in rows.iterrows():
|
|
entry = analyze_name(client, llm_model, prompt, row['name'])
|
|
entry["annotated"] = 1
|
|
updates.append((idx, entry))
|
|
logging.info(f"Analyzed name: {row['name']} - {entry}")
|
|
|
|
return pd.DataFrame.from_dict(dict(updates), orient='index')
|
|
|
|
|
|
def main(llm_model: str = "llama3.2:3b"):
|
|
df = pd.DataFrame(load_csv_dataset('names_featured.csv'))
|
|
prompt = load_prompt()
|
|
|
|
entries = df[df['annotated'].astype("Int8") == 0]
|
|
if entries.empty:
|
|
logging.info("No names to analyze.")
|
|
return
|
|
|
|
logging.info(f"Found {len(entries)} names to analyze.")
|
|
client = ollama.Client()
|
|
|
|
df.update(build_updates(client, prompt, llm_model, entries))
|
|
df.to_csv(os.path.join(DATA_DIR, 'names_featured.csv'), index=False)
|
|
logging.info("Done.")
|
|
|
|
|
|
if __name__ == '__main__':
|
|
parser = argparse.ArgumentParser(description="Analyze names using an LLM model.")
|
|
parser.add_argument('--llm_model', type=str, default="llama3.2:3b", help="Ollama model name to use (default: llama3.2:3b)")
|
|
args = parser.parse_args()
|
|
|
|
try:
|
|
main(llm_model=args.llm_model)
|
|
except Exception as e:
|
|
logging.error(f"Fatal error: {e}", exc_info=True)
|