refactor: prompt engineering
This commit is contained in:
@@ -40,7 +40,7 @@ python -m processing.prepare --split_eval --split_by_sex
|
|||||||
### Annotation
|
### Annotation
|
||||||
| Name | Description | Default |
|
| Name | Description | Default |
|
||||||
|-------------|-----------------------------------------------------|----------------|
|
|-------------|-----------------------------------------------------|----------------|
|
||||||
| --llm_model | Ollama model name to use | llama3.2:3b |
|
| --llm_model | Ollama model name to use | mistral:7b |
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
|
|||||||
+47
-29
@@ -25,9 +25,9 @@ def analyze_name(client: ollama.Client, model: str, prompt: str, name: str) -> d
|
|||||||
model=model,
|
model=model,
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": prompt},
|
{"role": "system", "content": prompt},
|
||||||
{"role": "user", "content": name}
|
{"role": "user", "content": name},
|
||||||
],
|
],
|
||||||
format=NameAnalysis.model_json_schema()
|
format=NameAnalysis.model_json_schema(),
|
||||||
)
|
)
|
||||||
analysis = NameAnalysis.model_validate_json(response.message.content)
|
analysis = NameAnalysis.model_validate_json(response.message.content)
|
||||||
return analysis.model_dump()
|
return analysis.model_dump()
|
||||||
@@ -35,51 +35,69 @@ def analyze_name(client: ollama.Client, model: str, prompt: str, name: str) -> d
|
|||||||
logging.warning(f"Validation error: {ve}")
|
logging.warning(f"Validation error: {ve}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Unexpected error: {e}")
|
logging.error(f"Unexpected error: {e}")
|
||||||
return {
|
return {"identified_name": None, "identified_surname": None}
|
||||||
"identified_name": None,
|
|
||||||
"identified_surname": None
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def build_updates(client: ollama.Client, prompt: str, llm_model: str, rows: pd.DataFrame) -> pd.DataFrame:
|
def save_checkpoint(df: pd.DataFrame):
|
||||||
"""
|
df.to_csv(os.path.join(DATA_DIR, "names_featured.csv"), index=False)
|
||||||
Build updates for the DataFrame by analyzing names.
|
logging.info("Checkpoint saved")
|
||||||
Iterates through the DataFrame rows, analyzes each name, and returns a DataFrame with updates.
|
|
||||||
"""
|
|
||||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
def build_updates(llm_model: str, df: pd.DataFrame, entries: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
BATCH_SIZE = 10
|
||||||
|
|
||||||
|
client = ollama.Client()
|
||||||
|
prompt = load_prompt()
|
||||||
updates = []
|
updates = []
|
||||||
|
|
||||||
|
# Set logging level for HTTP client to reduce noise
|
||||||
|
# This is useful to avoid excessive logging from the HTTP client used by Ollama
|
||||||
|
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||||
|
|
||||||
for idx, row in rows.iterrows():
|
|
||||||
entry = analyze_name(client, llm_model, prompt, row['name'])
|
for idx, (row_idx, row) in enumerate(entries.iterrows(), 1):
|
||||||
entry["annotated"] = 1
|
try:
|
||||||
updates.append((idx, entry))
|
entry = analyze_name(client, llm_model, prompt, row["name"])
|
||||||
logging.info(f"Analyzed name: {row['name']} - {entry}")
|
entry["annotated"] = 1
|
||||||
|
updates.append((row_idx, entry))
|
||||||
|
logging.info(f"Analyzed : {row['name']} - {entry}")
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Failed to analyze '{row['name']}': {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
return pd.DataFrame.from_dict(dict(updates), orient='index')
|
|
||||||
|
if idx % BATCH_SIZE == 0 or idx == len(entries):
|
||||||
|
df.update(pd.DataFrame.from_dict(dict(updates), orient="index"))
|
||||||
|
save_checkpoint(df)
|
||||||
|
updates.clear() # avoid re-applying same updates
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
def main(llm_model: str = "llama3.2:3b"):
|
def main(llm_model: str = "mistral:7b"):
|
||||||
df = pd.DataFrame(load_csv_dataset('names_featured.csv'))
|
df = pd.DataFrame(load_csv_dataset(os.path.join(DATA_DIR, "names_featured.csv")))
|
||||||
prompt = load_prompt()
|
|
||||||
|
|
||||||
entries = df[df['annotated'].astype("Int8") == 0]
|
entries = df[df["annotated"].astype("Int8") == 0]
|
||||||
if entries.empty:
|
if entries.empty:
|
||||||
logging.info("No names to analyze.")
|
logging.info("No names to analyze.")
|
||||||
return
|
return
|
||||||
|
|
||||||
logging.info(f"Found {len(entries)} names to analyze.")
|
logging.info(f"Found {len(entries)} names to analyze.")
|
||||||
client = ollama.Client()
|
df = build_updates(llm_model, df, entries)
|
||||||
|
save_checkpoint(df)
|
||||||
df.update(build_updates(client, prompt, llm_model, entries))
|
logging.info("Analysis complete.")
|
||||||
df.to_csv(os.path.join(DATA_DIR, 'names_featured.csv'), index=False)
|
|
||||||
logging.info("Done.")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Analyze names using an LLM model.")
|
parser = argparse.ArgumentParser(description="Analyze names using an LLM model.")
|
||||||
parser.add_argument('--llm_model', type=str, default="llama3.2:3b", help="Ollama model name to use (default: llama3.2:3b)")
|
parser.add_argument(
|
||||||
|
"--llm_model",
|
||||||
|
type=str,
|
||||||
|
default="mistral:7b",
|
||||||
|
help="Ollama model name to use (default: mistral:7b)",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
main(llm_model=args.llm_model)
|
main(llm_model=args.llm_model)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
Reference in New Issue
Block a user