refactoring: add initial pipeline configuration and model classes
This commit is contained in:
@@ -1,109 +0,0 @@
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import ollama
|
||||
import pandas as pd
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from tqdm import tqdm
|
||||
from typing import Optional
|
||||
|
||||
from misc import load_prompt, load_csv_dataset, DATA_DIR, logging
|
||||
|
||||
|
||||
class NameAnalysis(BaseModel):
|
||||
identified_name: Optional[str]
|
||||
identified_surname: Optional[str]
|
||||
|
||||
|
||||
def analyze_name(client: ollama.Client, model: str, prompt: str, name: str) -> dict:
|
||||
"""
|
||||
Analyze a name using the specified model and prompt.
|
||||
Returns a dictionary with identified name, surname, and category.
|
||||
"""
|
||||
try:
|
||||
response = client.chat(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": name},
|
||||
],
|
||||
format=NameAnalysis.model_json_schema(),
|
||||
)
|
||||
analysis = NameAnalysis.model_validate_json(response.message.content)
|
||||
return analysis.model_dump()
|
||||
except ValidationError as ve:
|
||||
logging.warning(f"Validation error: {ve}")
|
||||
except Exception as e:
|
||||
logging.error(f"Unexpected error: {e}")
|
||||
return {"identified_name": None, "identified_surname": None}
|
||||
|
||||
|
||||
def save_checkpoint(df: pd.DataFrame):
|
||||
df.to_csv(os.path.join(DATA_DIR, "names_featured.csv"), index=False)
|
||||
logging.critical(f"Checkpoint saved")
|
||||
|
||||
|
||||
def build_updates(llm_model: str, df: pd.DataFrame, entries: pd.DataFrame) -> pd.DataFrame:
|
||||
BATCH_SIZE = 10
|
||||
|
||||
client = ollama.Client()
|
||||
prompt = load_prompt()
|
||||
updates = []
|
||||
|
||||
# Set logging level for HTTP client to reduce noise
|
||||
# This is useful to avoid excessive logging from the HTTP client used by Ollama
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
for idx, (row_idx, row) in enumerate(entries.iterrows(), 1):
|
||||
try:
|
||||
entry = analyze_name(client, llm_model, prompt, row["name"])
|
||||
entry["annotated"] = 1
|
||||
updates.append((row_idx, entry))
|
||||
logging.info(f"Analyzed: {row['name']} - {entry}")
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to analyze '{row['name']}': {e}")
|
||||
continue
|
||||
|
||||
if idx % BATCH_SIZE == 0 or idx == len(entries):
|
||||
update_df = pd.DataFrame.from_dict(dict(updates), orient="index")
|
||||
update_df["annotated"] = pd.to_numeric(update_df["annotated"], errors="coerce").fillna(0).astype("Int8")
|
||||
|
||||
df.update(update_df)
|
||||
save_checkpoint(df)
|
||||
updates.clear() # avoid re-applying same updates
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def main(llm_model: str = "llama3.2:3b"):
|
||||
df = pd.DataFrame(load_csv_dataset(os.path.join(DATA_DIR, "names_featured.csv")))
|
||||
|
||||
# Safely cast 'annotated' column to Int8, handling float-like strings (e.g., '1.0')
|
||||
df["annotated"] = pd.to_numeric(df["annotated"], errors="coerce").fillna(0).astype(float).astype("Int8")
|
||||
|
||||
entries = df[df["annotated"] == 0]
|
||||
if entries.empty:
|
||||
logging.info("No names to analyze.")
|
||||
return
|
||||
|
||||
logging.info(f"Found {len(entries)} names to analyze.")
|
||||
df = build_updates(llm_model, df, entries)
|
||||
save_checkpoint(df)
|
||||
logging.info("Analysis complete.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Analyze names using an LLM model.")
|
||||
parser.add_argument(
|
||||
"--llm_model",
|
||||
type=str,
|
||||
default="mistral:7b",
|
||||
help="Ollama model name to use (default: mistral:7b)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
main(llm_model=args.llm_model)
|
||||
except Exception as e:
|
||||
logging.error(f"Fatal error: {e}", exc_info=True)
|
||||
Reference in New Issue
Block a user