From 5e5e07c60163d7bb8c5d49ab53a7dc08693f5a9a Mon Sep 17 00:00:00 2001 From: bernard-ng Date: Thu, 24 Jul 2025 14:14:03 +0200 Subject: [PATCH] refactor: prompt engineering --- README.md | 2 +- processing/annotate.py | 76 ++++++++++++++++++++++++++---------------- 2 files changed, 48 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index 362bc1f..ec09215 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ python -m processing.prepare --split_eval --split_by_sex ### Annotation | Name | Description | Default | |-------------|-----------------------------------------------------|----------------| -| --llm_model | Ollama model name to use | llama3.2:3b | +| --llm_model | Ollama model name to use | mistral:7b | Example: diff --git a/processing/annotate.py b/processing/annotate.py index dab69af..c539908 100644 --- a/processing/annotate.py +++ b/processing/annotate.py @@ -25,9 +25,9 @@ def analyze_name(client: ollama.Client, model: str, prompt: str, name: str) -> d model=model, messages=[ {"role": "system", "content": prompt}, - {"role": "user", "content": name} + {"role": "user", "content": name}, ], - format=NameAnalysis.model_json_schema() + format=NameAnalysis.model_json_schema(), ) analysis = NameAnalysis.model_validate_json(response.message.content) return analysis.model_dump() @@ -35,51 +35,69 @@ def analyze_name(client: ollama.Client, model: str, prompt: str, name: str) -> d logging.warning(f"Validation error: {ve}") except Exception as e: logging.error(f"Unexpected error: {e}") - return { - "identified_name": None, - "identified_surname": None - } + return {"identified_name": None, "identified_surname": None} -def build_updates(client: ollama.Client, prompt: str, llm_model: str, rows: pd.DataFrame) -> pd.DataFrame: - """ - Build updates for the DataFrame by analyzing names. - Iterates through the DataFrame rows, analyzes each name, and returns a DataFrame with updates. - """ - logging.getLogger("httpx").setLevel(logging.WARNING) +def save_checkpoint(df: pd.DataFrame): + df.to_csv(os.path.join(DATA_DIR, "names_featured.csv"), index=False) + logging.cri(f"Checkpoint saved") + + +def build_updates(llm_model: str, df: pd.DataFrame, entries: pd.DataFrame) -> pd.DataFrame: + BATCH_SIZE = 10 + + client = ollama.Client() + prompt = load_prompt() updates = [] + + # Set logging level for HTTP client to reduce noise + # This is useful to avoid excessive logging from the HTTP client used by Ollama + logging.getLogger("httpx").setLevel(logging.WARNING) - for idx, row in rows.iterrows(): - entry = analyze_name(client, llm_model, prompt, row['name']) - entry["annotated"] = 1 - updates.append((idx, entry)) - logging.info(f"Analyzed name: {row['name']} - {entry}") + + for idx, (row_idx, row) in enumerate(entries.iterrows(), 1): + try: + entry = analyze_name(client, llm_model, prompt, row["name"]) + entry["annotated"] = 1 + updates.append((row_idx, entry)) + logging.info(f"Analyzed : {row['name']} - {entry}") + except Exception as e: + logging.warning(f"Failed to analyze '{row['name']}': {e}") + continue - return pd.DataFrame.from_dict(dict(updates), orient='index') + + if idx % BATCH_SIZE == 0 or idx == len(entries): + df.update(pd.DataFrame.from_dict(dict(updates), orient="index")) + save_checkpoint(df) + updates.clear() # avoid re-applying same updates + + return df def main(llm_model: str = "llama3.2:3b"): - df = pd.DataFrame(load_csv_dataset('names_featured.csv')) - prompt = load_prompt() + df = pd.DataFrame(load_csv_dataset(os.path.join(DATA_DIR, "names_featured.csv"))) - entries = df[df['annotated'].astype("Int8") == 0] + entries = df[df["annotated"].astype("Int8") == 0] if entries.empty: logging.info("No names to analyze.") return logging.info(f"Found {len(entries)} names to analyze.") - client = ollama.Client() - - df.update(build_updates(client, prompt, llm_model, entries)) - df.to_csv(os.path.join(DATA_DIR, 'names_featured.csv'), index=False) - logging.info("Done.") + df = build_updates(llm_model, df, entries) + save_checkpoint(df) + logging.info("Analysis complete.") -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser(description="Analyze names using an LLM model.") - parser.add_argument('--llm_model', type=str, default="llama3.2:3b", help="Ollama model name to use (default: llama3.2:3b)") + parser.add_argument( + "--llm_model", + type=str, + default="mistral:7b", + help="Ollama model name to use (default: mistral:7b)", + ) args = parser.parse_args() - + try: main(llm_model=args.llm_model) except Exception as e: