import ast
import json
import logging
import os
from pathlib import Path
from typing import Dict, Any, List, Tuple

import spacy
from spacy.training import Example
from spacy.util import minibatch

from core.config.pipeline_config import PipelineConfig


class NERNameModel:
    """NER model trainer using spaCy for DRC names entity recognition"""

    def __init__(self, config: PipelineConfig):
        self.config = config
        self.nlp = None
        self.ner = None
        self.model_path = None
        self.training_stats = {}

    def create_blank_model(self, language: str = "fr") -> None:
        """Create a blank spaCy model with NER pipeline"""
        logging.info(f"Creating blank {language} model for NER training")

        # Create blank model - French tokenizer works well for DRC names
        self.nlp = spacy.blank(language)

        # Add NER pipeline component
        if "ner" not in self.nlp.pipe_names:
            self.ner = self.nlp.add_pipe("ner")
        else:
            self.ner = self.nlp.get_pipe("ner")

        # Add our custom labels
        self.ner.add_label("NATIVE")
        self.ner.add_label("SURNAME")

        logging.info("Blank model created with NATIVE and SURNAME labels")

    @classmethod
    def load_data(cls, data_path: str) -> List[Tuple[str, Dict]]:
        """Load training data from JSON file - compatible with NERNameTagger output format"""
        if not os.path.exists(data_path):
            raise FileNotFoundError(f"Training data not found at {data_path}")

        logging.info(f"Loading training data from {data_path}")

        with open(data_path, "r", encoding="utf-8") as f:
            raw_data = json.load(f)

        # Validate and clean training data
        valid_data = []
        skipped_count = 0

        for i, item in enumerate(raw_data):
            try:
                if not isinstance(item, (list, tuple)) or len(item) != 2:
                    logging.warning(
                        f"Skipping invalid training example format at index {i}: {item}"
                    )
                    skipped_count += 1
                    continue

                text, annotations = item

                # Validate text
                if not isinstance(text, str) or not text.strip():
                    logging.warning(f"Skipping invalid text at index {i}: {repr(text)}")
                    skipped_count += 1
                    continue

                # Handle different annotation formats from NERNameTagger
                if not isinstance(annotations, dict) or "entities" not in annotations:
                    logging.warning(f"Skipping invalid annotations at index {i}: {annotations}")
                    skipped_count += 1
                    continue

                entities_raw = annotations["entities"]

                # Parse entities - handle both string and list formats from tagger
                if isinstance(entities_raw, str):
                    # String format from tagger: "[(0, 6, 'NATIVE'), ...]"
                    try:
                        entities = ast.literal_eval(entities_raw)
                        if not isinstance(entities, list):
                            logging.warning(
                                f"Parsed entities is not a list at index {i}: {entities}"
                            )
                            skipped_count += 1
                            continue
                    except (ValueError, SyntaxError) as e:
                        logging.warning(
                            f"Failed to parse entity string at index {i}: {entities_raw} ({e})"
                        )
                        skipped_count += 1
                        continue
                elif isinstance(entities_raw, list):
                    # Already in list format
                    entities = entities_raw
                else:
                    logging.warning(
                        f"Skipping invalid entities format at index {i}: {entities_raw}"
                    )
                    skipped_count += 1
                    continue

                # Validate each entity
                valid_entities = []
                for entity in entities:
                    if not isinstance(entity, (list, tuple)) or len(entity) != 3:
                        logging.warning(f"Skipping invalid entity format in '{text}': {entity}")
                        continue

                    start, end, label = entity

                    # Validate entity components
                    if (
                        not isinstance(start, int)
                        or not isinstance(end, int)
                        or not isinstance(label, str)
                        or start >= end
                        or start < 0
                        or end > len(text)
                    ):
                        logging.warning(f"Skipping invalid entity bounds in '{text}': {entity}")
                        continue

                    # Check for overlaps with already validated entities
                    has_overlap = any(
                        start < v_end and end > v_start for v_start, v_end, _ in valid_entities
                    )
                    if has_overlap:
                        logging.warning(f"Skipping overlapping entity in '{text}': {entity}")
                        continue

                    # Validate that the span doesn't contain spaces (matching tagger validation)
                    span_text = text[start:end]
                    if not span_text or span_text != span_text.strip() or " " in span_text:
                        logging.warning(
                            f"Skipping entity with spaces in '{text}': {entity} -> '{span_text}'"
                        )
                        continue

                    valid_entities.append((start, end, label))

                if not valid_entities:
                    logging.warning(f"Skipping training example with no valid entities: '{text}'")
                    skipped_count += 1
                    continue

                # Sort entities by start position
                valid_entities.sort(key=lambda x: x[0])
                valid_data.append((text.strip(), {"entities": valid_entities}))

            except Exception as e:
                logging.error(f"Error processing training example at index {i}: {e}")
                skipped_count += 1
                continue

        logging.info(
            f"Loaded {len(valid_data)} valid training examples, skipped {skipped_count} invalid ones"
        )

        if not valid_data:
            raise ValueError("No valid training examples found in the data")

        return valid_data

    def train(
        self,
        data: List[Tuple[str, Dict]],
        epochs: int = 5,
        batch_size: int = 16,
        dropout_rate: float = 0.2,
    ) -> None:
        """Train the NER model"""
        logging.info(f"Starting NER training with {len(data)} examples")
        logging.info(
            f"Training parameters: epochs={epochs}, batch_size={batch_size}, dropout={dropout_rate}"
        )

        if self.nlp is None:
            raise ValueError("Model not initialized. Call create_blank_model() first.")

        # Initialize the model and keep the returned optimizer so its state
        # persists across batches and epochs (instead of creating a new one per batch)
        optimizer = self.nlp.initialize()

        # Create training examples once; they are reused every epoch
        examples = []
        for text, annotations in data:
            doc = self.nlp.make_doc(text)
            example = Example.from_dict(doc, annotations)
            examples.append(example)
            logging.info(
                f"Training example: {text[:30]}... with entities {annotations.get('entities', [])}"
            )

        # Training loop
        losses_history = []
        for epoch in range(epochs):
            losses = {}

            # Train in batches
            batches = minibatch(examples, size=batch_size)
            for batch in batches:
                self.nlp.update(batch, losses=losses, drop=dropout_rate, sgd=optimizer)
                logging.info(f"Training batch with {len(batch)} examples, current losses: {losses}")

            epoch_loss = losses.get("ner", 0)
            losses_history.append(epoch_loss)
            logging.info(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss:.4f}")

        # Store training statistics
        self.training_stats = {
            "epochs": epochs,
            "final_loss": losses_history[-1] if losses_history else 0,
            "training_examples": len(data),
            "loss_history": losses_history,
            "batch_size": batch_size,
            "dropout_rate": dropout_rate,
        }

        logging.info(f"Training completed. Final loss: {self.training_stats['final_loss']:.4f}")

    def evaluate(self, test_data: List[Tuple[str, Dict]]) -> Dict[str, Any]:
        """Evaluate the trained model on test data"""
        if self.nlp is None:
            raise ValueError("Model not trained. Call train() first.")

        logging.info(f"Evaluating model on {len(test_data)} test examples")

        total_examples = len(test_data)
        correct_entities = 0
        predicted_entities = 0
        actual_entities = 0

        entity_stats = {
            "NATIVE": {"tp": 0, "fp": 0, "fn": 0},
            "SURNAME": {"tp": 0, "fp": 0, "fn": 0},
        }

        for text, annotations in test_data:
            # Get actual entities
            actual_ents = set()
            for start, end, label in annotations.get("entities", []):
                actual_ents.add((start, end, label))
                actual_entities += 1

            # Get predicted entities
            doc = self.nlp(text)
            predicted_ents = set()
            for ent in doc.ents:
                predicted_ents.add((ent.start_char, ent.end_char, ent.label_))
                predicted_entities += 1

            # Calculate matches
            matches = actual_ents.intersection(predicted_ents)
            correct_entities += len(matches)

            # Update per-label statistics
            for start, end, label in actual_ents:
                if (start, end, label) in predicted_ents:
                    entity_stats[label]["tp"] += 1
                else:
                    entity_stats[label]["fn"] += 1

            for start, end, label in predicted_ents:
                if (start, end, label) not in actual_ents:
                    entity_stats[label]["fp"] += 1

        # Calculate overall metrics
        precision = correct_entities / predicted_entities if predicted_entities > 0 else 0
        recall = correct_entities / actual_entities if actual_entities > 0 else 0
        f1_score = (
            2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        )

        # Calculate per-label metrics
        label_metrics = {}
        for label, stats in entity_stats.items():
            tp, fp, fn = stats["tp"], stats["fp"], stats["fn"]
            label_precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            label_recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            label_f1 = (
                (2 * (label_precision * label_recall) / (label_precision + label_recall))
                if (label_precision + label_recall) > 0
                else 0
            )
            label_metrics[label] = {
                "precision": label_precision,
                "recall": label_recall,
                "f1_score": label_f1,
                "support": tp + fn,
            }

        evaluation_results = {
            "overall": {
                "precision": precision,
                "recall": recall,
                "f1_score": f1_score,
                "total_examples": total_examples,
                "correct_entities": correct_entities,
                "predicted_entities": predicted_entities,
                "actual_entities": actual_entities,
            },
            "by_label": label_metrics,
        }

        logging.info(f"NER Evaluation completed. Overall F1: {f1_score:.4f}")
        return evaluation_results

    def save(self, model_name: str = "drc_ner_model") -> str:
        """Save the trained model"""
        if self.nlp is None:
            raise ValueError("No model to save. Train a model first.")

        # Create model directory
        model_dir = self.config.paths.models_dir / model_name
        model_dir.mkdir(parents=True, exist_ok=True)

        # Save the model
        self.nlp.to_disk(model_dir)
        self.model_path = str(model_dir)

        # Save training statistics
        stats_path = model_dir / "training_stats.json"
        with open(stats_path, "w", encoding="utf-8") as f:
            json.dump(self.training_stats, f, indent=2)

        logging.info(f"NER Model saved to {model_dir}")
        return self.model_path

    def load(self, model_path: str) -> None:
        """Load a trained model"""
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model not found at {model_path}")

        logging.info(f"Loading model from {model_path}")
        self.nlp = spacy.load(model_path)
        self.ner = self.nlp.get_pipe("ner")
        self.model_path = model_path

        # Load training statistics if available
        stats_path = Path(model_path) / "training_stats.json"
        if stats_path.exists():
            with open(stats_path, "r", encoding="utf-8") as f:
                self.training_stats = json.load(f)

        logging.info("NER Model loaded successfully")

    def predict(self, text: str) -> Dict[str, Any]:
        """Make predictions on a single text"""
        if self.nlp is None:
            raise ValueError("No model loaded. Load or train a model first.")

        doc = self.nlp(text)
        entities = []

        for ent in doc.ents:
            entities.append(
                {
                    "text": ent.text,
                    "label": ent.label_,
                    "start": ent.start_char,
                    "end": ent.end_char,
                    "confidence": getattr(ent, "score", None),  # If confidence scores are available
                }
            )

        return {"text": text, "entities": entities}
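

# ---------------------------------------------------------------------------
# Minimal usage sketch showing the intended call order:
# load_data -> create_blank_model -> train -> evaluate -> save -> predict.
# Assumptions (not confirmed by this module): PipelineConfig() can be
# constructed without arguments, and "data/ner_training.json" is a
# hypothetical path to NERNameTagger output. Adjust both for your setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    config = PipelineConfig()  # assumed no-argument constructor
    trainer = NERNameModel(config)

    # Placeholder path to tagger output; an 80/20 train/test split for illustration
    data = NERNameModel.load_data("data/ner_training.json")
    split = int(len(data) * 0.8)
    train_data, test_data = data[:split], data[split:]

    trainer.create_blank_model(language="fr")
    trainer.train(train_data, epochs=5, batch_size=16, dropout_rate=0.2)

    results = trainer.evaluate(test_data)
    print(json.dumps(results["overall"], indent=2))

    trainer.save("drc_ner_model")
    print(trainer.predict("Kabongo Mwamba Jean"))  # example input text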