From d5a4aaaf4aa82fb838da37ecf9f41aba9eb24e46 Mon Sep 17 00:00:00 2001 From: bernard-ng Date: Mon, 11 Aug 2025 07:13:09 +0200 Subject: [PATCH] feat: add NER annotation step and integrate into pipeline --- README.md | 1 + config/pipeline.development.yaml | 4 +- config/pipeline.production.yaml | 1 + config/pipeline.yaml | 31 +- core/config/annotation_config.py | 29 ++ core/config/config_manager.py | 2 +- core/config/data_config.py | 5 +- core/config/llm_config.py | 13 - core/config/pipeline_config.py | 6 +- main.py | 4 + monitor.py | 106 +----- ner.py | 52 +++ processing/monitoring/pipeline_monitor.py | 2 +- processing/ner/__init__.py | 0 processing/ner/ner_data_builder.py | 198 +++++++++++ processing/ner/ner_name_model.py | 356 ++++++++++++++++++++ processing/ner/ner_name_tagger.py | 200 +++++++++++ processing/steps/__init__.py | 12 +- processing/steps/data_cleaning_step.py | 3 + processing/steps/feature_extraction_step.py | 23 +- processing/steps/llm_annotation_step.py | 55 ++- processing/steps/ner_annotation_step.py | 164 +++++++++ requirements.txt | 1 + 23 files changed, 1108 insertions(+), 160 deletions(-) create mode 100644 core/config/annotation_config.py delete mode 100644 core/config/llm_config.py create mode 100755 ner.py create mode 100644 processing/ner/__init__.py create mode 100644 processing/ner/ner_data_builder.py create mode 100644 processing/ner/ner_name_model.py create mode 100644 processing/ner/ner_name_tagger.py create mode 100644 processing/steps/ner_annotation_step.py diff --git a/README.md b/README.md index 60bc12b..a338236 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,7 @@ the `drc-ners-nlp/config/pipeline.yaml` file. stages: - "data_cleaning" - "feature_extraction" + - "ner_annotation" - "llm_annotation" - "data_splitting" ``` diff --git a/config/pipeline.development.yaml b/config/pipeline.development.yaml index bccc0ae..d437458 100644 --- a/config/pipeline.development.yaml +++ b/config/pipeline.development.yaml @@ -12,6 +12,7 @@ processing: stages: - "data_cleaning" - "feature_extraction" + #- "ner_annotation" #- "llm_annotation" - "data_splitting" @@ -27,7 +28,8 @@ llm: # Data handling configuration data: - max_dataset_size: 100_000 + split_evaluation: false + max_dataset_size: null balance_by_sex: true # Enhanced logging for development diff --git a/config/pipeline.production.yaml b/config/pipeline.production.yaml index ace19e0..7cdcc97 100644 --- a/config/pipeline.production.yaml +++ b/config/pipeline.production.yaml @@ -12,6 +12,7 @@ processing: stages: - "data_cleaning" - "feature_extraction" + - "ner_annotation" - "llm_annotation" - "data_splitting" diff --git a/config/pipeline.yaml b/config/pipeline.yaml index c316842..79e2fa8 100644 --- a/config/pipeline.yaml +++ b/config/pipeline.yaml @@ -18,9 +18,10 @@ paths: checkpoints_dir: "./data/checkpoints" # Directory for model checkpoints # Pipeline stages -stages: # List of stages in the processing pipeline +stages: # List of stages in the processing pipeline - "data_cleaning" # Data cleaning stage - "feature_extraction" # Feature extraction stage + - "ner_annotation" # NER-based annotation stage - "llm_annotation" # LLM annotation stage (computational intensive) - "data_splitting" # Data splitting stage @@ -36,15 +37,20 @@ processing: - "latin1" chunk_size: 100_000 # Size of data chunks to process in parallel -# LLM annotation settings -llm: - model_name: "mistral:7b" # Name of the LLM model to use - requests_per_minute: 60 # Requests per minute to the LLM service - requests_per_second: 2 # Requests per 
second to the LLM service - retry_attempts: 3 # Number of retry attempts for LLM requests - timeout_seconds: 600 # Timeout for LLM requests - max_concurrent_requests: 2 # Maximum concurrent requests to the LLM service - enable_rate_limiting: true # Enable rate limiting to avoid overloading the LLM service +# Annotation settings +annotation: + llm: + model_name: "mistral:7b" # Name of the LLM model to use + requests_per_minute: 60 # Requests per minute to the LLM service + requests_per_second: 2 # Requests per second to the LLM service + retry_attempts: 3 # Number of retry attempts for LLM requests + timeout_seconds: 600 # Timeout for LLM requests + max_concurrent_requests: 2 # Maximum concurrent requests to the LLM service + enable_rate_limiting: true # Enable rate limiting to avoid overloading the LLM service + + ner: + model_name: "drc_names_ner" # Name of the NER model to use + retry_attempts: 3 # Number of retry attempts for NER requests # Data handling configuration data: @@ -54,8 +60,11 @@ data: evaluation: "names_evaluation.csv" # Output file for evaluation set males: "names_males.csv" # Output files for male names females: "names_females.csv" # Output files for female names - split_evaluation: true # Should the dataset be split into training and evaluation sets ? + ner_data: "names_ner.json" # Output file for NER annotated data + ner_spacy: "names_ner.spacy" # Output file for NER annotated data using spaCy format + split_evaluation: false # Should the dataset be split into training and evaluation sets ? split_by_gender: true # Should the dataset be split by gender ? + split_ner_data: true # Should the NER data be extracted and saved? evaluation_fraction: 0.2 # Fraction of data to use for evaluation random_seed: 42 # Random seed for reproducibility max_dataset_size: null # Maximum size of the dataset to process, set to null for no diff --git a/core/config/annotation_config.py b/core/config/annotation_config.py new file mode 100644 index 0000000..59c1c21 --- /dev/null +++ b/core/config/annotation_config.py @@ -0,0 +1,29 @@ +from pydantic import BaseModel + +class NERConfig(BaseModel): + """NER annotation configuration""" + + model_name: str = "drc_names_ner" + retry_attempts: int = 3 + + +class LLMConfig(BaseModel): + """LLM annotation configuration""" + + model_name: str = "mistral:7b" + requests_per_minute: int = 60 + requests_per_second: int = 2 + retry_attempts: int = 3 + timeout_seconds: int = 30 + max_concurrent_requests: int = 2 + enable_rate_limiting: bool = False + + +class AnnotationConfig(BaseModel): + """Base class for annotation configurations""" + + llm: LLMConfig = LLMConfig() + ner: NERConfig = NERConfig() + + class Config: + arbitrary_types_allowed = True diff --git a/core/config/config_manager.py b/core/config/config_manager.py index a55aecb..3939405 100644 --- a/core/config/config_manager.py +++ b/core/config/config_manager.py @@ -65,7 +65,7 @@ class ConfigManager: # Ensure paths are properly set if "paths" not in config_data: - config_data["paths"] = self.default_paths.dict() + config_data["paths"] = self.default_paths.model_dump() self._config = PipelineConfig(**config_data) return self._config diff --git a/core/config/data_config.py b/core/config/data_config.py index 639f68f..cb5c150 100644 --- a/core/config/data_config.py +++ b/core/config/data_config.py @@ -14,10 +14,13 @@ class DataConfig(BaseModel): "evaluation": "names_evaluation.csv", "males": "names_males.csv", "females": "names_females.csv", + "ner_data": "names_ner.json", + "ner_spacy": 
"names_ner.spacy" } ) - split_evaluation: bool = True + split_evaluation: bool = False split_by_gender: bool = True + split_ner_data: bool = True evaluation_fraction: float = 0.2 random_seed: int = 42 diff --git a/core/config/llm_config.py b/core/config/llm_config.py deleted file mode 100644 index 5f05967..0000000 --- a/core/config/llm_config.py +++ /dev/null @@ -1,13 +0,0 @@ -from pydantic import BaseModel - - -class LLMConfig(BaseModel): - """LLM annotation configuration""" - - model_name: str = "mistral:7b" - requests_per_minute: int = 60 - requests_per_second: int = 2 - retry_attempts: int = 3 - timeout_seconds: int = 30 - max_concurrent_requests: int = 2 - enable_rate_limiting: bool = False diff --git a/core/config/pipeline_config.py b/core/config/pipeline_config.py index 87e412b..43bf1a8 100644 --- a/core/config/pipeline_config.py +++ b/core/config/pipeline_config.py @@ -1,8 +1,8 @@ from pydantic import BaseModel -from core.config.logging_config import LoggingConfig +from core.config.annotation_config import AnnotationConfig from core.config.data_config import DataConfig -from core.config.llm_config import LLMConfig +from core.config.logging_config import LoggingConfig from core.config.processing_config import ProcessingConfig from core.config.project_paths import ProjectPaths @@ -17,7 +17,7 @@ class PipelineConfig(BaseModel): paths: ProjectPaths stages: list[str] = [] processing: ProcessingConfig = ProcessingConfig() - llm: LLMConfig = LLMConfig() + annotation: AnnotationConfig = AnnotationConfig() data: DataConfig = DataConfig() logging: LoggingConfig = LoggingConfig() diff --git a/main.py b/main.py index 1d2c7e4..4e13e0d 100755 --- a/main.py +++ b/main.py @@ -8,11 +8,13 @@ from core.config import setup_config from core.utils import get_data_file_path from core.utils.data_loader import DataLoader from processing.batch.batch_config import BatchConfig +from processing.ner.ner_data_builder import NERDataBuilder from processing.pipeline import Pipeline from processing.steps.data_cleaning_step import DataCleaningStep from processing.steps.data_splitting_step import DataSplittingStep from processing.steps.feature_extraction_step import FeatureExtractionStep from processing.steps.llm_annotation_step import LLMAnnotationStep +from processing.steps.ner_annotation_step import NERAnnotationStep def create_pipeline(config) -> Pipeline: @@ -29,6 +31,7 @@ def create_pipeline(config) -> Pipeline: steps = [ DataCleaningStep(config), FeatureExtractionStep(config), + NERAnnotationStep(config), LLMAnnotationStep(config), DataSplittingStep(config), ] @@ -67,6 +70,7 @@ def run_pipeline(config) -> int: splitting_step = pipeline.steps[-1] if isinstance(splitting_step, DataSplittingStep): splitting_step.save_splits(result_df) + NERDataBuilder(config).build(result_df) # Show completion statistics progress = pipeline.get_progress() diff --git a/monitor.py b/monitor.py index 4e37049..d76a9f3 100755 --- a/monitor.py +++ b/monitor.py @@ -5,65 +5,31 @@ import traceback from pathlib import Path from core.config import setup_config -from processing.monitoring.data_analyzer import DatasetAnalyzer from processing.monitoring.pipeline_monitor import PipelineMonitor def main(): - parser = argparse.ArgumentParser( - description="Monitor and manage the DRC names processing pipeline" - ) - parser.add_argument("--config", type=Path, help="Path to configuration file") - parser.add_argument( - "--env", type=str, default="development", - help="Environment name (default: development)" - ) + choices = ["data_cleaning", 
"feature_extraction", "ner_annotation", "llm_annotation", "data_splitting"] + parser = argparse.ArgumentParser(description="Monitor and manage the DRC names processing pipeline") + parser.add_argument("--config", type=Path, help="Path to configuration file") + parser.add_argument("--env", type=str, default="development", help="Environment") subparsers = parser.add_subparsers(dest="command", help="Available commands") # Status command - status_parser = subparsers.add_parser("status", help="Show pipeline status") - status_parser.add_argument( - "--detailed", - action="store_true", - help="Show detailed information including failed batch IDs", - ) + subparsers.add_parser("status", help="Show pipeline status") # Clean command clean_parser = subparsers.add_parser("clean", help="Clean checkpoint files") - clean_parser.add_argument( - "--step", - type=str, - choices=["data_cleaning", "feature_extraction", "llm_annotation", "data_splitting"], - help="Clean specific step (default: all)", - ) - clean_parser.add_argument( - "--keep-last", type=int, default=1, help="Number of recent checkpoints to keep (default: 1)" - ) + clean_parser.add_argument("--step", type=str, choices=choices, help="Specific step (default: all)") + clean_parser.add_argument("--keep-last", type=int, default=1, help="Checkpoints to keep (default: 1)") clean_parser.add_argument("--force", action="store_true", help="Clean without confirmation") # Reset command reset_parser = subparsers.add_parser("reset", help="Reset pipeline step") - reset_parser.add_argument( - "step", - type=str, - choices=["data_cleaning", "feature_extraction", "llm_annotation", "data_splitting"], - help="Step to reset", - ) + reset_parser.add_argument("--step", type=str, choices=choices, help="Specific step (default: all)") + reset_parser.add_argument("--all", action="store_true", help="Reset all steps") reset_parser.add_argument("--force", action="store_true", help="Reset without confirmation") - - # Analyze command - analyze_parser = subparsers.add_parser("analyze", help="Analyze dataset") - analyze_parser.add_argument( - "--file", - type=str, - default="names_featured.csv", - help="Dataset file to analyze (default: names_featured.csv)", - ) - - # Checkpoint info command - info_parser = subparsers.add_parser("info", help="Show checkpoint information") - args = parser.parse_args() if not args.command: @@ -71,13 +37,11 @@ def main(): return 1 try: - # Load configuration and setup logging - config = setup_config(config_path=args.config, env=args.env) - + setup_config(config_path=args.config, env=args.env) monitor = PipelineMonitor() if args.command == "status": - monitor.print_status(detailed=args.detailed) + monitor.print_status(detailed=True) elif args.command == "clean": checkpoint_info = monitor.count_checkpoint_files() @@ -106,49 +70,13 @@ def main(): print("Cancelled") return 0 - monitor.reset_step(args.step) - print(f"Reset completed for {args.step}") + if args.step: + monitor.reset_step(args.step) + else: + for step in monitor.steps: + monitor.reset_step(step) - elif args.command == "analyze": - # Use configured data directory - data_dir = config.paths.data_dir - filepath = data_dir / args.file - - if not filepath.exists(): - print(f"File not found: {filepath}") - return 1 - - analyzer = DatasetAnalyzer(str(filepath)) - - if not analyzer.load_data(): - return 1 - - completion_stats = analyzer.analyze_completion() - - print(f"\n=== Dataset Analysis: {args.file} ===") - print(f"Total rows: {completion_stats['total_rows']:,}") - print( - f"Annotated: 
{completion_stats['annotated_rows']:,} ({completion_stats['annotation_percentage']:.1f}%)") - print(f"Unannotated: {completion_stats['unannotated_rows']:,}") - print( - f"Complete names: {completion_stats['complete_names']:,} ({completion_stats['completeness_percentage']:.1f}%)" - ) - - elif args.command == "info": - checkpoint_info = monitor.count_checkpoint_files() - - print(f"\n=== Checkpoint Information ===") - print(f"Total storage: {checkpoint_info['total_size_mb']:.1f} MB") - print() - - for step in monitor.steps: - step_info = checkpoint_info[step] - print(f"{step.replace('_', ' ').title()}:") - print(f" Files: {step_info['files']}") - print(f" Size: {step_info['size_mb']:.1f} MB") - print() - - return 0 + print(f"Reset completed") except Exception as e: print(f"Monitoring failed: {e}") diff --git a/ner.py b/ner.py new file mode 100755 index 0000000..d2a8834 --- /dev/null +++ b/ner.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +import argparse +import logging +import sys +from pathlib import Path + +from core.config import setup_config +from processing.ner.ner_data_builder import NERDataBuilder +from processing.ner.ner_name_model import NERNameModel + + +def train(config_path=None, env="development"): + """Train the NER model.""" + try: + config = setup_config(config_path=config_path, env=env) + trainer = NERNameModel(config) + builder = NERDataBuilder(config) + + data_path = Path(config.paths.data_dir) / config.data.output_files["ner_data"] + if not data_path.exists(): + builder.build() + + trainer.create_blank_model("fr") + data = trainer.load_data(str(data_path)) + + split_idx = int(len(data) * 0.8) + train_data, eval_data = data[:split_idx], data[split_idx:] + + logging.info(f"Training with {len(train_data)} examples, evaluating on {len(eval_data)}") + trainer.train(train_data, epochs=1, batch_size=config.processing.batch_size, dropout_rate=0.3) + trainer.evaluate(eval_data) + + model_path = trainer.save() + logging.info(f"Model saved to: {model_path}") + return 0 + + except Exception as e: + logging.error(f"NER Training failed: {e}", exc_info=True) + return 1 + + +def main(): + parser = argparse.ArgumentParser(description="Train NER model for DRC names") + parser.add_argument("--config", type=str, help="Path to configuration file") + parser.add_argument("--env", type=str, default="development", help="Environment name") + args = parser.parse_args() + + sys.exit(train(config_path=args.config, env=args.env)) + + +if __name__ == "__main__": + main() diff --git a/processing/monitoring/pipeline_monitor.py b/processing/monitoring/pipeline_monitor.py index 7b131f7..92a7488 100644 --- a/processing/monitoring/pipeline_monitor.py +++ b/processing/monitoring/pipeline_monitor.py @@ -19,7 +19,7 @@ class PipelineMonitor: self.paths = paths self.checkpoint_dir = paths.checkpoints_dir - self.steps = ["data_cleaning", "feature_extraction", "llm_annotation", "data_splitting"] + self.steps = ["data_cleaning", "feature_extraction", "ner_annotation", "llm_annotation", "data_splitting"] def get_step_status(self, step_name: str) -> Dict: """Get status of a specific pipeline step""" diff --git a/processing/ner/__init__.py b/processing/ner/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/processing/ner/ner_data_builder.py b/processing/ner/ner_data_builder.py new file mode 100644 index 0000000..5ce23b3 --- /dev/null +++ b/processing/ner/ner_data_builder.py @@ -0,0 +1,198 @@ +import ast +import json +import logging +from pathlib import Path + +import pandas as pd +import spacy +from 
spacy.tokens import DocBin +from spacy.util import filter_spans + +from core.config import PipelineConfig +from core.utils import get_data_file_path + + +class NERDataBuilder: + def __init__(self, config: PipelineConfig): + self.config = config + + @classmethod + def parse_entities(cls, entities_str): + """Parse entity string (tuple format or JSON) into spaCy-style tuples.""" + if not entities_str or entities_str in ["[]", "", "nan"]: + return [] + + entities_str = str(entities_str).strip() + + # Handle different formats + try: + # Try to parse as Python literal (tuples or lists) + if entities_str.startswith("[(") and entities_str.endswith(")]"): + # Standard tuple format: [(0, 6, 'NATIVE'), ...] + return ast.literal_eval(entities_str) + elif entities_str.startswith("[[") and entities_str.endswith("]]"): + # Nested list format: [[0, 6, 'NATIVE'], ...] + nested_list = ast.literal_eval(entities_str) + return [(start, end, label) for start, end, label in nested_list] + elif entities_str.startswith("[{") and entities_str.endswith("}]"): + # JSON format: [{"start": 0, "end": 6, "label": "NATIVE"}, ...] + json_entities = json.loads(entities_str) + return [(e["start"], e["end"], e["label"]) for e in json_entities] + else: + # Try general ast.literal_eval for other formats + parsed = ast.literal_eval(entities_str) + if isinstance(parsed, list): + # Convert any list format to tuples + result = [] + for item in parsed: + if isinstance(item, (list, tuple)) and len(item) == 3: + result.append((item[0], item[1], item[2])) + return result + + except (ValueError, SyntaxError, json.JSONDecodeError) as e: + logging.warning(f"Failed to parse entities: {entities_str} ({e})") + return [] + + logging.warning(f"Unknown entity format: {entities_str}") + return [] + + @classmethod + def validate_entities(cls, entities, text): + """Validate and sort entity tuples, removing overlaps and invalid spans.""" + if not entities or not text: + return [] + + text = str(text).strip() + if not text: + return [] + + # Filter out invalid entities + valid_entities = [] + for entity in entities: + if not isinstance(entity, (list, tuple)) or len(entity) != 3: + logging.warning(f"Invalid entity format: {entity}") + continue + + start, end, label = entity + + # Ensure start/end are integers + try: + start = int(start) + end = int(end) + except (ValueError, TypeError): + logging.warning(f"Invalid start/end positions: {entity}") + continue + + # Ensure label is string + if not isinstance(label, str): + logging.warning(f"Invalid label type: {entity}") + continue + + # Check bounds + if not (0 <= start < end <= len(text)): + logging.warning(f"Entity span out of bounds: {entity} for text '{text}' (length {len(text)})") + continue + + # Check that span contains actual text + span_text = text[start:end].strip() + if not span_text: + logging.warning(f"Empty span: {entity} in text '{text}'") + continue + + valid_entities.append((start, end, label)) + + if not valid_entities: + return [] + + # Sort by start position + valid_entities.sort(key=lambda x: (x[0], x[1])) + + # Remove overlapping entities (keep the first one) + filtered = [] + for start, end, label in valid_entities: + # Check for overlap with already added entities + has_overlap = False + for e_start, e_end, _ in filtered: + if not (end <= e_start or start >= e_end): + has_overlap = True + logging.warning( + f"Removing overlapping entity ({start}, {end}, '{label}') " + f"conflicts with ({e_start}, {e_end}) in '{text}'" + ) + break + + if not has_overlap: + filtered.append((start, 
end, label)) + + return filtered + + @classmethod + def create_doc(cls, text, entities, nlp): + """Create a spaCy Doc object with entities added.""" + doc = nlp(text) + ents = [] + + for start, end, label in entities: + span = doc.char_span(start, end, label=label, alignment_mode="contract") \ + or doc.char_span(start, end, label=label, alignment_mode="strict") + if span: + ents.append(span) + else: + logging.warning(f"Could not create span ({start}, {end}, '{label}') in '{text}'") + + doc.ents = filter_spans(ents) if ents else [] + return doc + + def build(self, data: pd.DataFrame = None) -> int: + """Build the dataset for NER training.""" + logging.info("Building dataset for NER training") + try: + df = pd.read_csv(get_data_file_path("names_featured.csv", self.config)) \ + if data is None \ + else data + + ner_df = df[df["ner_tagged"] == 1].copy() + if ner_df.empty: + logging.error("No NER tagged data found in the CSV") + return 1 + + logging.info(f"Found {len(ner_df)} NER tagged entries") + nlp = spacy.blank("fr") + doc_bin, training_data = DocBin(), [] + processed_count, skipped_count = 0, 0 + + for _, row in ner_df.iterrows(): + text = str(row.get("name", "")).strip() + if not text: + continue + + entities = self.parse_entities(row.get("ner_entities", "[]")) + entities = self.validate_entities(entities, text) + + training_data.append((text, {"entities": entities})) + try: + doc_bin.add(self.create_doc(text, entities, nlp)) + processed_count += 1 + except Exception as e: + logging.error(f"Error processing '{text}': {e}") + skipped_count += 1 + + if not training_data: + logging.error("No valid training examples generated") + return 1 + + json_path = Path(self.config.paths.data_dir) / self.config.data.output_files["ner_data"] + spacy_path = Path(self.config.paths.data_dir) / self.config.data.output_files["ner_spacy"] + + with open(json_path, "w", encoding="utf-8") as f: + json.dump(training_data, f, ensure_ascii=False, indent=None) + doc_bin.to_disk(spacy_path) + + logging.info(f"Processed: {processed_count}, Skipped: {skipped_count}") + logging.info(f"Saved NER data in json format to {json_path}") + logging.info(f"Saved NER data in spaCy format to {spacy_path}") + return 0 + + except Exception as e: + logging.error(f"Failed to build NER dataset: {e}", exc_info=True) + return 1 diff --git a/processing/ner/ner_name_model.py b/processing/ner/ner_name_model.py new file mode 100644 index 0000000..9537910 --- /dev/null +++ b/processing/ner/ner_name_model.py @@ -0,0 +1,356 @@ +import json +import logging +import os +from pathlib import Path +from typing import Dict, Any, List, Tuple + +import spacy +from spacy.training import Example +from spacy.util import minibatch + +from core.config.pipeline_config import PipelineConfig + + +class NERNameModel: + """NER model trainer using spaCy for DRC names entity recognition""" + + def __init__(self, config: PipelineConfig): + self.config = config + self.nlp = None + self.ner = None + self.model_path = None + self.training_stats = {} + + def create_blank_model(self, language: str = "fr") -> None: + """Create a blank spaCy model with NER pipeline""" + logging.info(f"Creating blank {language} model for NER training") + + # Create blank model - French tokenizer works well for DRC names + self.nlp = spacy.blank(language) + + # Add NER pipeline component + if "ner" not in self.nlp.pipe_names: + self.ner = self.nlp.add_pipe("ner") + else: + self.ner = self.nlp.get_pipe("ner") + + # Add our custom labels + self.ner.add_label("NATIVE") + 
self.ner.add_label("SURNAME") + + logging.info("Blank model created with NATIVE and SURNAME labels") + + @classmethod + def load_data(cls, data_path: str) -> List[Tuple[str, Dict]]: + """Load training data from JSON file - compatible with NERNameTagger output format""" + if not os.path.exists(data_path): + raise FileNotFoundError(f"Training data not found at {data_path}") + + logging.info(f"Loading training data from {data_path}") + + with open(data_path, 'r', encoding='utf-8') as f: + raw_data = json.load(f) + + # Validate and clean training data + valid_data = [] + skipped_count = 0 + + for i, item in enumerate(raw_data): + try: + if not isinstance(item, (list, tuple)) or len(item) != 2: + logging.warning(f"Skipping invalid training example format at index {i}: {item}") + skipped_count += 1 + continue + + text, annotations = item + + # Validate text + if not isinstance(text, str) or not text.strip(): + logging.warning(f"Skipping invalid text at index {i}: {repr(text)}") + skipped_count += 1 + continue + + # Handle different annotation formats from NERNameTagger + if not isinstance(annotations, dict) or "entities" not in annotations: + logging.warning(f"Skipping invalid annotations at index {i}: {annotations}") + skipped_count += 1 + continue + + entities_raw = annotations["entities"] + + # Parse entities - handle both string and list formats from tagger + if isinstance(entities_raw, str): + # String format from tagger: "[(0, 6, 'NATIVE'), ...]" + try: + import ast + entities = ast.literal_eval(entities_raw) + if not isinstance(entities, list): + logging.warning(f"Parsed entities is not a list at index {i}: {entities}") + skipped_count += 1 + continue + except (ValueError, SyntaxError) as e: + logging.warning(f"Failed to parse entity string at index {i}: {entities_raw} ({e})") + skipped_count += 1 + continue + elif isinstance(entities_raw, list): + # Already in list format + entities = entities_raw + else: + logging.warning(f"Skipping invalid entities format at index {i}: {entities_raw}") + skipped_count += 1 + continue + + # Validate each entity + valid_entities = [] + for entity in entities: + if not isinstance(entity, (list, tuple)) or len(entity) != 3: + logging.warning(f"Skipping invalid entity format in '{text}': {entity}") + continue + + start, end, label = entity + + # Validate entity components + if (not isinstance(start, int) or not isinstance(end, int) or + not isinstance(label, str) or start >= end or + start < 0 or end > len(text)): + logging.warning(f"Skipping invalid entity bounds in '{text}': {entity}") + continue + + # Check for overlaps with already validated entities + has_overlap = any( + start < v_end and end > v_start + for v_start, v_end, _ in valid_entities + ) + + if has_overlap: + logging.warning(f"Skipping overlapping entity in '{text}': {entity}") + continue + + # Validate that the span doesn't contain spaces (matching tagger validation) + span_text = text[start:end] + if not span_text or span_text != span_text.strip() or ' ' in span_text: + logging.warning(f"Skipping entity with spaces in '{text}': {entity} -> '{span_text}'") + continue + + valid_entities.append((start, end, label)) + + if not valid_entities: + logging.warning(f"Skipping training example with no valid entities: '{text}'") + skipped_count += 1 + continue + + # Sort entities by start position + valid_entities.sort(key=lambda x: x[0]) + valid_data.append((text.strip(), {"entities": valid_entities})) + + except Exception as e: + logging.error(f"Error processing training example at index {i}: {e}") 
+ skipped_count += 1 + continue + + logging.info(f"Loaded {len(valid_data)} valid training examples, skipped {skipped_count} invalid ones") + + if not valid_data: + raise ValueError("No valid training examples found in the data") + + return valid_data + + def train( + self, + data: List[Tuple[str, Dict]], + epochs: int = 5, + batch_size: int = 16, + dropout_rate: float = 0.2, + ) -> None: + """Train the NER model""" + logging.info(f"Starting NER training with {len(data)} examples") + logging.info(f"Training parameters: epochs={epochs}, batch_size={batch_size}, dropout={dropout_rate}") + + if self.nlp is None: + raise ValueError("Model not initialized. Call create_blank_model() first.") + + # Initialize the model + self.nlp.initialize() + + # Training loop + losses_history = [] + + for epoch in range(epochs): + losses = {} + + # Create training examples + examples = [] + for text, annotations in data: + doc = self.nlp.make_doc(text) + example = Example.from_dict(doc, annotations) + examples.append(example) + logging.info(f"Training example: {text[:30]}... with entities {annotations.get('entities', [])}") + + # Train in batches + batches = minibatch(examples, size=batch_size) + for batch in batches: + self.nlp.update( + batch, + losses=losses, + drop=dropout_rate, + sgd=self.nlp.create_optimizer() + ) + logging.info(f"Training batch with {len(batch)} examples, current losses: {losses}") + + epoch_loss = losses.get("ner", 0) + losses_history.append(epoch_loss) + logging.info(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss:.4f}") + + # Store training statistics + self.training_stats = { + "epochs": epochs, + "final_loss": losses_history[-1] if losses_history else 0, + "training_examples": len(data), + "loss_history": losses_history, + "batch_size": batch_size, + "dropout_rate": dropout_rate + } + + logging.info(f"Training completed. Final loss: {self.training_stats['final_loss']:.4f}") + + def evaluate(self, test_data: List[Tuple[str, Dict]]) -> Dict[str, Any]: + """Evaluate the trained model on test data""" + if self.nlp is None: + raise ValueError("Model not trained. 
Call train_model() first.") + + logging.info(f"Evaluating model on {len(test_data)} test examples") + + total_examples = len(test_data) + correct_entities = 0 + predicted_entities = 0 + actual_entities = 0 + + entity_stats = {"NATIVE": {"tp": 0, "fp": 0, "fn": 0}, "SURNAME": {"tp": 0, "fp": 0, "fn": 0}} + + for text, annotations in test_data: + # Get actual entities + actual_ents = set() + for start, end, label in annotations.get("entities", []): + actual_ents.add((start, end, label)) + actual_entities += 1 + + # Get predicted entities + doc = self.nlp(text) + predicted_ents = set() + for ent in doc.ents: + predicted_ents.add((ent.start_char, ent.end_char, ent.label_)) + predicted_entities += 1 + + # Calculate matches + matches = actual_ents.intersection(predicted_ents) + correct_entities += len(matches) + + # Update per-label statistics + for start, end, label in actual_ents: + if (start, end, label) in predicted_ents: + entity_stats[label]["tp"] += 1 + else: + entity_stats[label]["fn"] += 1 + + for start, end, label in predicted_ents: + if (start, end, label) not in actual_ents: + entity_stats[label]["fp"] += 1 + + # Calculate overall metrics + precision = correct_entities / predicted_entities if predicted_entities > 0 else 0 + recall = correct_entities / actual_entities if actual_entities > 0 else 0 + f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 + + # Calculate per-label metrics + label_metrics = {} + for label, stats in entity_stats.items(): + tp, fp, fn = stats["tp"], stats["fp"], stats["fn"] + label_precision = tp / (tp + fp) if (tp + fp) > 0 else 0 + label_recall = tp / (tp + fn) if (tp + fn) > 0 else 0 + label_f1 = ( + 2 * (label_precision * label_recall) / (label_precision + label_recall)) \ + if (label_precision + label_recall) > 0 else 0 + + label_metrics[label] = { + "precision": label_precision, + "recall": label_recall, + "f1_score": label_f1, + "support": tp + fn + } + + evaluation_results = { + "overall": { + "precision": precision, + "recall": recall, + "f1_score": f1_score, + "total_examples": total_examples, + "correct_entities": correct_entities, + "predicted_entities": predicted_entities, + "actual_entities": actual_entities + }, + "by_label": label_metrics + } + + logging.info(f"NER Evaluation completed. Overall F1: {f1_score:.4f}") + return evaluation_results + + def save(self, model_name: str = "drc_ner_model") -> str: + """Save the trained model""" + if self.nlp is None: + raise ValueError("No model to save. 
Train a model first.") + + # Create model directory + model_dir = self.config.paths.models_dir / model_name + model_dir.mkdir(parents=True, exist_ok=True) + + # Save the model + self.nlp.to_disk(model_dir) + self.model_path = str(model_dir) + + # Save training statistics + stats_path = model_dir / "training_stats.json" + with open(stats_path, 'w', encoding='utf-8') as f: + json.dump(self.training_stats, f, indent=2) + + logging.info(f"NER Model saved to {model_dir}") + return self.model_path + + def load(self, model_path: str) -> None: + """Load a trained model""" + if not os.path.exists(model_path): + raise FileNotFoundError(f"Model not found at {model_path}") + + logging.info(f"Loading model from {model_path}") + self.nlp = spacy.load(model_path) + self.ner = self.nlp.get_pipe("ner") + self.model_path = model_path + + # Load training statistics if available + stats_path = Path(model_path) / "training_stats.json" + if stats_path.exists(): + with open(stats_path, 'r', encoding='utf-8') as f: + self.training_stats = json.load(f) + + logging.info("NER Model loaded successfully") + + def predict(self, text: str) -> Dict[str, Any]: + """Make predictions on a single text""" + if self.nlp is None: + raise ValueError("No model loaded. Load or train a model first.") + + doc = self.nlp(text) + entities = [] + + for ent in doc.ents: + entities.append({ + "text": ent.text, + "label": ent.label_, + "start": ent.start_char, + "end": ent.end_char, + "confidence": getattr(ent, 'score', None) # If confidence scores are available + }) + + return { + "text": text, + "entities": entities + } diff --git a/processing/ner/ner_name_tagger.py b/processing/ner/ner_name_tagger.py new file mode 100644 index 0000000..2460e9d --- /dev/null +++ b/processing/ner/ner_name_tagger.py @@ -0,0 +1,200 @@ +from typing import Union, Dict, Any, List +import logging + + +class NERNameTagger: + def tag_name(self, name: str, probable_native: str, probable_surname: str) -> Union[Dict[str, Any], None]: + """Create a single NER training example using probable_native and probable_surname""" + if not name or not probable_native or not probable_surname: + return None + + name = name.strip() + probable_native = probable_native.strip() + probable_surname = probable_surname.strip() + + entities = [] + used_spans = [] # Track used character spans to prevent overlaps + + # Helper function to check if a span overlaps with any existing span + def has_overlap(start, end): + for used_start, used_end in used_spans: + if not (end <= used_start or start >= used_end): + return True + return False + + # Find positions of native names in the full name + native_words = probable_native.split() + name_lower = name.lower() # Use lowercase for consistent searching + processed_native_words = set() + + for native_word in native_words: + native_word = native_word.strip() + if len(native_word) < 2: # Skip very short words + continue + + native_word_lower = native_word.lower() + + # Skip if we've already processed this exact word + if native_word_lower in processed_native_words: + continue + processed_native_words.add(native_word_lower) + + # Find the first occurrence of this native word that doesn't overlap + start_pos = 0 + while True: + pos = name_lower.find(native_word_lower, start_pos) # Case-insensitive search + if pos == -1: + break + + # Calculate end position - make sure we only include the word itself + end_pos = pos + len(native_word_lower) + + # Double-check that the extracted span matches exactly what we expect + extracted_text = name[pos:end_pos] # 
Get original case text + if extracted_text.lower() != native_word_lower: + start_pos = pos + 1 + continue + + # Check if this is a word boundary match and doesn't overlap + if (self._is_word_boundary_match(name, pos, end_pos) and + not has_overlap(pos, end_pos)): + entities.append((pos, end_pos, 'NATIVE')) + used_spans.append((pos, end_pos)) + break # Only take the first non-overlapping occurrence + + start_pos = pos + 1 + + # Find position of surname in the full name + if probable_surname and len(probable_surname.strip()) >= 2: + surname_lower = probable_surname.lower() + + # Find the first occurrence that doesn't overlap + start_pos = 0 + while True: + pos = name_lower.find(surname_lower, start_pos) # Case-insensitive search + if pos == -1: + break + + # Calculate end position correctly - exact match only + end_pos = pos + len(surname_lower) + + # Double-check that the extracted span matches exactly what we expect + extracted_text = name[pos:end_pos] # Get original case text + if extracted_text.lower() != surname_lower: + start_pos = pos + 1 + continue + + if (self._is_word_boundary_match(name, pos, end_pos) and + not has_overlap(pos, end_pos)): + entities.append((pos, end_pos, 'SURNAME')) + used_spans.append((pos, end_pos)) + break + + start_pos = pos + 1 + + if not entities: + logging.warning(f"No valid entities found for name: '{name}' with native: '{probable_native}' and surname: '{probable_surname}'") + return None + + # Sort entities by position and validate + entities.sort(key=lambda x: x[0]) + + # Final validation - ensure no overlaps and valid spans + validated_entities = [] + for start, end, label in entities: + # Check bounds + if not (0 <= start < end <= len(name)): + logging.warning(f"Invalid span bounds ({start}, {end}) for text length {len(name)}: '{name}'") + continue + + # Check for overlaps with already validated entities + if any(start < v_end and end > v_start for v_start, v_end, _ in validated_entities): + logging.warning(f"Overlapping span ({start}, {end}, '{label}') in '{name}'") + continue + + # CRITICAL VALIDATION: Check that the span contains only the expected word (no spaces) + span_text = name[start:end] + if not span_text or span_text != span_text.strip() or ' ' in span_text: + logging.warning(f"Span contains spaces or is empty ({start}, {end}) in '{name}': '{span_text}'") + continue + + validated_entities.append((start, end, label)) + + if not validated_entities: + logging.warning(f"No valid entities after validation for: '{name}'") + return None + + # Convert to string format that matches the dataset + entities_str = str(validated_entities) + + return { + "entities": entities_str, + "spans": validated_entities # Keep the original tuples for internal use + } + + @classmethod + def _is_word_boundary_match(cls, text: str, start: int, end: int) -> bool: + """Check if the match is at word boundaries""" + # Check character before start position + if start > 0: + prev_char = text[start - 1] + if prev_char.isalnum(): + return False + + # Check character after end position + if end < len(text): + next_char = text[end] + if next_char.isalnum(): + return False + + return True + + @classmethod + def validate_entities(cls, name: str, entities_str: str) -> bool: + """Validate that entity annotations are correct for a given name""" + try: + import ast + entities = ast.literal_eval(entities_str) + + # Check for overlaps and valid bounds + sorted_entities = sorted(entities, key=lambda x: x[0]) + + for i, (start, end, label) in enumerate(sorted_entities): + # Check bounds + 
if not (0 <= start < end <= len(name)): + return False + + # Check for overlaps with next entity + if i < len(sorted_entities) - 1: + next_start = sorted_entities[i + 1][0] + if end > next_start: + return False + + # Extract the text span and validate it's not empty + span_text = name[start:end] + if not span_text.strip(): + return False + + return True + except (ValueError, SyntaxError, TypeError): + return False + + @classmethod + def extract_entity_text(cls, name: str, entities_str: str) -> Dict[str, List[str]]: + """Extract the actual text for each entity type""" + result = {'NATIVE': [], 'SURNAME': []} + + try: + import ast + entities = ast.literal_eval(entities_str) + + for start, end, label in entities: + if 0 <= start < end <= len(name): + span_text = name[start:end] + if label in result: + result[label].append(span_text) + + except (ValueError, SyntaxError, TypeError): + pass + + return result diff --git a/processing/steps/__init__.py b/processing/steps/__init__.py index 065a622..b4f960d 100644 --- a/processing/steps/__init__.py +++ b/processing/steps/__init__.py @@ -6,9 +6,10 @@ from dataclasses import dataclass from typing import List, Optional import pandas as pd +from pydantic import BaseModel -from processing.batch.batch_config import BatchConfig from core.config.pipeline_config import PipelineConfig +from processing.batch.batch_config import BatchConfig @dataclass @@ -25,11 +26,18 @@ class PipelineState: self.failed_batches = [] +class NameAnnotation(BaseModel): + """Model for name annotation results""" + + identified_name: Optional[str] + identified_surname: Optional[str] + + class PipelineStep(ABC): """Abstract base class for pipeline steps""" def __init__( - self, name: str, pipeline_config: PipelineConfig, batch_config: Optional[BatchConfig] = None + self, name: str, pipeline_config: PipelineConfig, batch_config: Optional[BatchConfig] = None ): self.name = name self.pipeline_config = pipeline_config diff --git a/processing/steps/data_cleaning_step.py b/processing/steps/data_cleaning_step.py index 394f61c..1443984 100644 --- a/processing/steps/data_cleaning_step.py +++ b/processing/steps/data_cleaning_step.py @@ -25,4 +25,7 @@ class DataCleaningStep(PipelineStep): # Apply text cleaning batch = self.text_cleaner.clean_dataframe_text_columns(batch) + # Remove duplicates + batch = batch.drop_duplicates(subset=self.required_columns) + return batch diff --git a/processing/steps/feature_extraction_step.py b/processing/steps/feature_extraction_step.py index 3a4d520..3c23f91 100644 --- a/processing/steps/feature_extraction_step.py +++ b/processing/steps/feature_extraction_step.py @@ -5,6 +5,7 @@ import pandas as pd from core.config.pipeline_config import PipelineConfig from core.utils.region_mapper import RegionMapper +from processing.ner.ner_name_tagger import NERNameTagger from processing.steps import PipelineStep @@ -24,6 +25,7 @@ class FeatureExtractionStep(PipelineStep): def __init__(self, pipeline_config: PipelineConfig): super().__init__("feature_extraction", pipeline_config) self.region_mapper = RegionMapper() + self.name_tagger = NERNameTagger() @classmethod def validate_gender(cls, gender: str) -> Gender: @@ -52,7 +54,7 @@ class FeatureExtractionStep(PipelineStep): # Basic features batch["words"] = batch["name"].str.count(" ") + 1 - batch["length"] = batch["name"].str.replace(" ", "", regex=False).str.len() + batch["length"] = batch["name"].str.len() # Handle year column if "year" in batch.columns: @@ -63,6 +65,8 @@ class FeatureExtractionStep(PipelineStep): 
batch["probable_surname"] = None batch["identified_name"] = None batch["identified_surname"] = None + batch["ner_entities"] = None + batch["ner_tagged"] = 0 batch["annotated"] = 0 # Vectorized category assignment @@ -81,14 +85,19 @@ class FeatureExtractionStep(PipelineStep): # Auto-assign for 3-word names three_word_mask = batch["words"] == 3 - batch.loc[three_word_mask, "identified_name"] = batch.loc[ - three_word_mask, "probable_native" - ] - batch.loc[three_word_mask, "identified_surname"] = batch.loc[ - three_word_mask, "probable_surname" - ] + batch.loc[three_word_mask, "identified_name"] = batch.loc[three_word_mask, "probable_native"] + batch.loc[three_word_mask, "identified_surname"] = batch.loc[three_word_mask, "probable_surname"] batch.loc[three_word_mask, "annotated"] = 1 + # Tag names with NER entities + three_word_rows = batch[three_word_mask] + for idx, row in three_word_rows.iterrows(): + entity = self.name_tagger.tag_name(row['name'], row['identified_name'], row['identified_surname']) + + if entity: + batch.at[idx, "ner_entities"] = entity["entities"] + batch.at[idx, "ner_tagged"] = 1 + # Map regions to provinces batch["province"] = self.region_mapper.map_regions_vectorized(batch["region"]) diff --git a/processing/steps/llm_annotation_step.py b/processing/steps/llm_annotation_step.py index b8a3343..8aa80fb 100644 --- a/processing/steps/llm_annotation_step.py +++ b/processing/steps/llm_annotation_step.py @@ -1,25 +1,18 @@ import logging import time from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Dict, Optional +from typing import Dict import ollama import pandas as pd -from pydantic import ValidationError, BaseModel +from pydantic import ValidationError from core.config.pipeline_config import PipelineConfig from core.utils.prompt_manager import PromptManager -from core.utils.rate_limiter import RateLimiter from core.utils.rate_limiter import RateLimitConfig +from core.utils.rate_limiter import RateLimiter from processing.batch.batch_config import BatchConfig -from processing.steps import PipelineStep - - -class NameAnnotation(BaseModel): - """Model for name annotation results""" - - identified_name: Optional[str] - identified_surname: Optional[str] +from processing.steps import PipelineStep, NameAnnotation class LLMAnnotationStep(PipelineStep): @@ -27,10 +20,12 @@ class LLMAnnotationStep(PipelineStep): def __init__(self, pipeline_config: PipelineConfig): # Create custom batch config for LLM processing + self.llm_config = pipeline_config.annotation.llm batch_config = BatchConfig( batch_size=pipeline_config.processing.batch_size, max_workers=min( - pipeline_config.llm.max_concurrent_requests, pipeline_config.processing.max_workers + self.llm_config.max_concurrent_requests, + pipeline_config.processing.max_workers ), checkpoint_interval=pipeline_config.processing.checkpoint_interval, use_multiprocessing=pipeline_config.processing.use_multiprocessing, @@ -39,7 +34,7 @@ class LLMAnnotationStep(PipelineStep): self.prompt = PromptManager(pipeline_config).load_prompt() self.rate_limiter = ( - self._create_rate_limiter() if pipeline_config.llm.enable_rate_limiting else None + self._create_rate_limiter() if self.llm_config.enable_rate_limiting else None ) # Statistics @@ -53,14 +48,14 @@ class LLMAnnotationStep(PipelineStep): def _create_rate_limiter(self): """Create rate limiter based on configuration""" rate_config = RateLimitConfig( - requests_per_minute=self.pipeline_config.llm.requests_per_minute, - 
requests_per_second=self.pipeline_config.llm.requests_per_second, + requests_per_minute=self.llm_config.requests_per_minute, + requests_per_second=self.llm_config.requests_per_second, ) return RateLimiter(rate_config) - def analyze_name_with_retry(self, client: ollama.Client, name: str, row_id: int) -> Dict: + def analyze_name(self, client: ollama.Client, name: str) -> Dict: """Analyze a name with retry logic and rate limiting""" - for attempt in range(self.pipeline_config.llm.retry_attempts): + for attempt in range(self.llm_config.retry_attempts): try: # Apply rate limiting if enabled if self.rate_limiter: @@ -68,7 +63,7 @@ class LLMAnnotationStep(PipelineStep): start_time = time.time() response = client.chat( - model=self.pipeline_config.llm.model_name, + model=self.llm_config.model_name, messages=[ {"role": "system", "content": self.prompt}, {"role": "user", "content": name}, @@ -77,9 +72,9 @@ class LLMAnnotationStep(PipelineStep): ) elapsed_time = time.time() - start_time - if elapsed_time > self.pipeline_config.llm.timeout_seconds: + if elapsed_time > self.llm_config.timeout_seconds: raise TimeoutError( - f"Request took {elapsed_time:.2f}s, exceeding {self.pipeline_config.llm.timeout_seconds}s timeout" + f"Request took {elapsed_time:.2f}s, exceeding {self.llm_config.timeout_seconds}s timeout" ) annotation = NameAnnotation.model_validate_json(response.message.content) @@ -98,12 +93,12 @@ class LLMAnnotationStep(PipelineStep): except (ValidationError, TimeoutError, Exception) as e: logging.warning( - f"Error analyzing '{name}' (attempt {attempt + 1}/{self.pipeline_config.llm.retry_attempts}): {e}" + f"Error analyzing '{name}' (attempt {attempt + 1}/{self.llm_config.retry_attempts}): {e}" ) # Exponential backoff with jitter - if attempt < self.pipeline_config.llm.retry_attempts - 1: - wait_time = (2**attempt) + (time.time() % 1) + if attempt < self.llm_config.retry_attempts - 1: + wait_time = (2 ** attempt) + (time.time() % 1) time.sleep(min(wait_time, 10)) self.failed_requests += 1 @@ -112,7 +107,7 @@ class LLMAnnotationStep(PipelineStep): "identified_surname": None, "annotated": 0, "processing_time": 0, - "attempts": self.pipeline_config.llm.retry_attempts, + "attempts": self.llm_config.retry_attempts, "failed": True, } @@ -125,18 +120,18 @@ class LLMAnnotationStep(PipelineStep): logging.info(f"Batch {batch_id}: No entries to annotate") return batch - logging.info(f"Batch {batch_id}: Annotating {len(unannotated_entries)} entries") + logging.info(f"Batch {batch_id}: Annotating {len(unannotated_entries)} entries with LLM") batch = batch.copy() client = ollama.Client() # Process with controlled concurrency - max_workers = self.pipeline_config.llm.max_concurrent_requests + max_workers = self.llm_config.max_concurrent_requests if len(unannotated_entries) == 1 or max_workers == 1: # Sequential processing for idx, row in unannotated_entries.iterrows(): - result = self.analyze_name_with_retry(client, row["name"], idx) + result = self.analyze_name(client, row["name"]) for field, value in result.items(): if field not in ["failed"]: batch.loc[idx, field] = value @@ -146,7 +141,7 @@ class LLMAnnotationStep(PipelineStep): future_to_idx = {} for idx, row in unannotated_entries.iterrows(): - future = executor.submit(self.analyze_name_with_retry, client, row["name"], idx) + future = executor.submit(self.analyze_name, client, row["name"]) future_to_idx[future] = idx for future in as_completed(future_to_idx): @@ -161,8 +156,6 @@ class LLMAnnotationStep(PipelineStep): batch.loc[idx, "annotated"] = 0 # 
Ensure proper data types - batch["annotated"] = ( - pd.to_numeric(batch["annotated"], errors="coerce").fillna(0).astype("Int8") - ) + batch["annotated"] = pd.to_numeric(batch["annotated"], errors="coerce").fillna(0).astype("Int8") return batch diff --git a/processing/steps/ner_annotation_step.py b/processing/steps/ner_annotation_step.py new file mode 100644 index 0000000..1280ec5 --- /dev/null +++ b/processing/steps/ner_annotation_step.py @@ -0,0 +1,164 @@ +import logging +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Dict + +import pandas as pd + +from core.config.pipeline_config import PipelineConfig +from processing.steps import PipelineStep, NameAnnotation +from processing.ner.ner_name_model import NERNameModel + + +class NERAnnotationStep(PipelineStep): + """NER annotation step using trained spaCy model for entity recognition""" + + def __init__(self, pipeline_config: PipelineConfig): + # Create custom batch config for NER processing + super().__init__("ner_annotation", pipeline_config) + + self.model_name = "drc_ner_model" + self.model_path = pipeline_config.paths.models_dir / "drc_ner_model" + self.ner_trainer = NERNameModel(pipeline_config) + self.ner_config = pipeline_config.annotation.ner + + # Statistics + self.successful_requests = 0 + self.failed_requests = 0 + self.total_retry_attempts = 0 + + # Load the model + self._load_ner_model() + + def _load_ner_model(self) -> None: + """Load the trained NER model""" + try: + if self.model_path.exists(): + logging.info(f"Loading NER model from {self.model_path}") + self.ner_trainer.load(str(self.model_path)) + logging.info("NER model loaded successfully") + else: + logging.warning(f"NER model not found at {self.model_path}") + logging.warning("NER annotation will be skipped. 
Train the model first.") + self.ner_trainer.nlp = None + except Exception as e: + logging.error(f"Failed to load NER model: {e}") + self.ner_trainer.nlp = None + + def analyze_name(self, name: str) -> Dict: + """Analyze a name with retry logic""" + if self.ner_trainer.nlp is None: + return { + "identified_name": None, + "identified_surname": None, + "annotated": 0, + "processing_time": 0, + "attempts": 0, + "failed": True, + } + + for attempt in range(self.ner_config.retry_attempts): + try: + start_time = time.time() + + # Get NER predictions + prediction = self.ner_trainer.predict(name.lower()) + entities = prediction.get('entities', []) + + elapsed_time = time.time() - start_time + + # Extract native names and surnames from entities + native_parts = [] + surname_parts = [] + + for entity in entities: + if entity['label'] == 'NATIVE': + native_parts.append(entity['text']) + elif entity['label'] == 'SURNAME': + surname_parts.append(entity['text']) + + # Create annotation result in same format as LLM step + annotation = NameAnnotation( + identified_name=" ".join(native_parts) if native_parts else None, + identified_surname=" ".join(surname_parts) if surname_parts else None + ) + + result = { + **annotation.model_dump(), + "annotated": 1, + "processing_time": elapsed_time, + "attempts": attempt + 1, + } + + self.successful_requests += 1 + if attempt > 0: + self.total_retry_attempts += attempt + + return result + + except Exception as e: + logging.warning( + f"Error analyzing '{name}' with NER (attempt {attempt + 1}/{self.ner_config.retry_attempts}): {e}" + ) + + # Small delay between retries + if attempt < self.ner_config.retry_attempts - 1: + time.sleep(0.1) + + self.failed_requests += 1 + return { + "identified_name": None, + "identified_surname": None, + "annotated": 0, + "processing_time": 0, + "attempts": self.ner_config.retry_attempts, + "failed": True, + } + + def process_batch(self, batch: pd.DataFrame, batch_id: int) -> pd.DataFrame: + """Process batch with NER annotation using same logic as LLM step""" + unannotated_mask = batch.get("annotated", 0) == 0 + unannotated_entries = batch[unannotated_mask] + + if unannotated_entries.empty: + logging.info(f"Batch {batch_id}: No entries to annotate") + return batch + + logging.info(f"Batch {batch_id}: Annotating {len(unannotated_entries)} entries with NER") + + batch = batch.copy() + + # Process with controlled concurrency + max_workers = self.batch_config.max_workers + + if len(unannotated_entries) == 1 or max_workers == 1: + # Sequential processing + for idx, row in unannotated_entries.iterrows(): + result = self.analyze_name(row["name"]) + for field, value in result.items(): + if field not in ["failed"]: + batch.loc[idx, field] = value + else: + # Concurrent processing + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_idx = {} + + for idx, row in unannotated_entries.iterrows(): + future = executor.submit(self.analyze_name, row["name"]) + future_to_idx[future] = idx + + for future in as_completed(future_to_idx): + idx = future_to_idx[future] + try: + result = future.result() + for field, value in result.items(): + if field not in ["failed"]: + batch.loc[idx, field] = value + except Exception as e: + logging.error(f"Failed to process row {idx}: {e}") + batch.loc[idx, "annotated"] = 0 + + # Ensure proper data types + batch["annotated"] = pd.to_numeric(batch["annotated"], errors="coerce").fillna(0).astype("Int8") + + return batch diff --git a/requirements.txt b/requirements.txt index 49d26e4..8e2f2d8 100644 --- 
a/requirements.txt +++ b/requirements.txt @@ -176,3 +176,4 @@ altair==5.1.2 PyYAML~=6.0.2 xgboost~=3.0.3 lightgbm~=4.6.0 +spacy~=3.8.7 \ No newline at end of file
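
A minimal end-to-end sketch of how the pieces introduced in this patch fit together, following the same calls that ner.py and the ner_annotation stage make. It only uses APIs defined in the patch (setup_config, NERDataBuilder.build, NERNameModel.create_blank_model/load_data/train/evaluate/save/predict); the 80/20 split, the epoch count, and the sample name passed to predict() are illustrative assumptions, not values fixed by the patch.

    from pathlib import Path

    from core.config import setup_config
    from processing.ner.ner_data_builder import NERDataBuilder
    from processing.ner.ner_name_model import NERNameModel

    config = setup_config(env="development")

    # Build names_ner.json / names_ner.spacy from the NER-tagged rows of names_featured.csv
    NERDataBuilder(config).build()

    # Train a blank French pipeline with the NATIVE / SURNAME labels added by this patch
    model = NERNameModel(config)
    model.create_blank_model("fr")
    data = model.load_data(str(Path(config.paths.data_dir) / config.data.output_files["ner_data"]))

    split = int(len(data) * 0.8)  # illustrative 80/20 train/eval split, as in ner.py
    model.train(data[:split], epochs=5, batch_size=config.processing.batch_size)
    model.evaluate(data[split:])

    # Saves under config.paths.models_dir / "drc_ner_model", where NERAnnotationStep looks for it
    model.save()

    print(model.predict("mbombo kalala gracia"))  # hypothetical name, for illustration only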