diff --git a/processing/annotate.py b/processing/annotate.py index 2d32949..731fa78 100644 --- a/processing/annotate.py +++ b/processing/annotate.py @@ -67,7 +67,10 @@ def build_updates(llm_model: str, df: pd.DataFrame, entries: pd.DataFrame) -> pd if idx % BATCH_SIZE == 0 or idx == len(entries): - df.update(pd.DataFrame.from_dict(dict(updates), orient="index")) + update_df = pd.DataFrame.from_dict(dict(updates), orient="index") + update_df = update_df['annotated'].astype('Int8').fillna(0) + + df.update(update_df) save_checkpoint(df) updates.clear() # avoid re-applying same updates