feat: add NER annotation step and integrate into pipeline
@@ -0,0 +1,198 @@
import ast
import json
import logging
from pathlib import Path

import pandas as pd
import spacy
from spacy.tokens import DocBin
from spacy.util import filter_spans

from core.config import PipelineConfig
from core.utils import get_data_file_path


class NERDataBuilder:
    def __init__(self, config: PipelineConfig):
        self.config = config

    @classmethod
    def parse_entities(cls, entities_str):
        """Parse entity string (tuple format or JSON) into spaCy-style tuples."""
        if not entities_str or entities_str in ["[]", "", "nan"]:
            return []

        entities_str = str(entities_str).strip()

        # Handle different formats
        try:
            # Try to parse as Python literal (tuples or lists)
            if entities_str.startswith("[(") and entities_str.endswith(")]"):
                # Standard tuple format: [(0, 6, 'NATIVE'), ...]
                return ast.literal_eval(entities_str)
            elif entities_str.startswith("[[") and entities_str.endswith("]]"):
                # Nested list format: [[0, 6, 'NATIVE'], ...]
                nested_list = ast.literal_eval(entities_str)
                return [(start, end, label) for start, end, label in nested_list]
            elif entities_str.startswith("[{") and entities_str.endswith("}]"):
                # JSON format: [{"start": 0, "end": 6, "label": "NATIVE"}, ...]
                json_entities = json.loads(entities_str)
                return [(e["start"], e["end"], e["label"]) for e in json_entities]
            else:
                # Try general ast.literal_eval for other formats
                parsed = ast.literal_eval(entities_str)
                if isinstance(parsed, list):
                    # Convert any list format to tuples
                    result = []
                    for item in parsed:
                        if isinstance(item, (list, tuple)) and len(item) == 3:
                            result.append((item[0], item[1], item[2]))
                    return result

        except (ValueError, SyntaxError, json.JSONDecodeError) as e:
            logging.warning(f"Failed to parse entities: {entities_str} ({e})")
            return []

        logging.warning(f"Unknown entity format: {entities_str}")
        return []

    @classmethod
    def validate_entities(cls, entities, text):
        """Validate and sort entity tuples, removing overlaps and invalid spans."""
        if not entities or not text:
            return []

        text = str(text).strip()
        if not text:
            return []

        # Filter out invalid entities
        valid_entities = []
        for entity in entities:
            if not isinstance(entity, (list, tuple)) or len(entity) != 3:
                logging.warning(f"Invalid entity format: {entity}")
                continue

            start, end, label = entity

            # Ensure start/end are integers
            try:
                start = int(start)
                end = int(end)
            except (ValueError, TypeError):
                logging.warning(f"Invalid start/end positions: {entity}")
                continue

            # Ensure label is a string
            if not isinstance(label, str):
                logging.warning(f"Invalid label type: {entity}")
                continue

            # Check bounds
            if not (0 <= start < end <= len(text)):
                logging.warning(f"Entity span out of bounds: {entity} for text '{text}' (length {len(text)})")
                continue

            # Check that span contains actual text
            span_text = text[start:end].strip()
            if not span_text:
                logging.warning(f"Empty span: {entity} in text '{text}'")
                continue

            valid_entities.append((start, end, label))

        if not valid_entities:
            return []

        # Sort by start position
        valid_entities.sort(key=lambda x: (x[0], x[1]))

        # Remove overlapping entities (keep the first one)
        filtered = []
        for start, end, label in valid_entities:
            # Check for overlap with already added entities
            has_overlap = False
            for e_start, e_end, _ in filtered:
                if not (end <= e_start or start >= e_end):
                    has_overlap = True
                    logging.warning(
                        f"Removing overlapping entity ({start}, {end}, '{label}') "
                        f"conflicts with ({e_start}, {e_end}) in '{text}'"
                    )
                    break

            if not has_overlap:
                filtered.append((start, end, label))

        return filtered

    @classmethod
    def create_doc(cls, text, entities, nlp):
        """Create a spaCy Doc object with entities added."""
        doc = nlp(text)
        ents = []

        for start, end, label in entities:
            span = doc.char_span(start, end, label=label, alignment_mode="contract") \
                or doc.char_span(start, end, label=label, alignment_mode="strict")
            if span:
                ents.append(span)
            else:
                logging.warning(f"Could not create span ({start}, {end}, '{label}') in '{text}'")

        doc.ents = filter_spans(ents) if ents else []
        return doc

    def build(self, data: pd.DataFrame = None) -> int:
        """Build the dataset for NER training."""
        logging.info("Building dataset for NER training")
        try:
            df = pd.read_csv(get_data_file_path("names_featured.csv", self.config)) \
                if data is None \
                else data

            ner_df = df[df["ner_tagged"] == 1].copy()
            if ner_df.empty:
                logging.error("No NER tagged data found in the CSV")
                return 1

            logging.info(f"Found {len(ner_df)} NER tagged entries")
            nlp = spacy.blank("fr")
            doc_bin, training_data = DocBin(), []
            processed_count, skipped_count = 0, 0

            for _, row in ner_df.iterrows():
                text = str(row.get("name", "")).strip()
                if not text:
                    continue

                entities = self.parse_entities(row.get("ner_entities", "[]"))
                entities = self.validate_entities(entities, text)

                training_data.append((text, {"entities": entities}))
                try:
                    doc_bin.add(self.create_doc(text, entities, nlp))
                    processed_count += 1
                except Exception as e:
                    logging.error(f"Error processing '{text}': {e}")
                    skipped_count += 1

            if not training_data:
                logging.error("No valid training examples generated")
                return 1

            json_path = Path(self.config.paths.data_dir) / self.config.data.output_files["ner_data"]
            spacy_path = Path(self.config.paths.data_dir) / self.config.data.output_files["ner_spacy"]

            with open(json_path, "w", encoding="utf-8") as f:
                json.dump(training_data, f, ensure_ascii=False, indent=None)
            doc_bin.to_disk(spacy_path)

            logging.info(f"Processed: {processed_count}, Skipped: {skipped_count}")
            logging.info(f"Saved NER data in JSON format to {json_path}")
            logging.info(f"Saved NER data in spaCy format to {spacy_path}")
            return 0

        except Exception as e:
            logging.error(f"Failed to build NER dataset: {e}", exc_info=True)
            return 1
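For context, a minimal sketch of how the builder above might be invoked from the pipeline. The module path, the PipelineConfig loader, and the config file name are assumptions for illustration and are not part of this commit:

# Hypothetical usage sketch (not part of this commit).
import sys

from core.config import PipelineConfig
from ner_data_builder import NERDataBuilder  # module name is an assumption


def main() -> int:
    config = PipelineConfig.load("pipeline.yaml")  # hypothetical loader
    builder = NERDataBuilder(config)
    # Reads names_featured.csv, writes the JSON training data and the DocBin;
    # returns 0 on success, 1 on failure (matching build()'s exit-code convention).
    return builder.build()


if __name__ == "__main__":
    sys.exit(main())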
@@ -0,0 +1,356 @@
import json
import logging
import os
from pathlib import Path
from typing import Dict, Any, List, Tuple

import spacy
from spacy.training import Example
from spacy.util import minibatch

from core.config.pipeline_config import PipelineConfig


class NERNameModel:
    """NER model trainer using spaCy for DRC name entity recognition"""

    def __init__(self, config: PipelineConfig):
        self.config = config
        self.nlp = None
        self.ner = None
        self.model_path = None
        self.training_stats = {}

    def create_blank_model(self, language: str = "fr") -> None:
        """Create a blank spaCy model with an NER pipeline"""
        logging.info(f"Creating blank {language} model for NER training")

        # Create blank model - the French tokenizer works well for DRC names
        self.nlp = spacy.blank(language)

        # Add the NER pipeline component
        if "ner" not in self.nlp.pipe_names:
            self.ner = self.nlp.add_pipe("ner")
        else:
            self.ner = self.nlp.get_pipe("ner")

        # Add our custom labels
        self.ner.add_label("NATIVE")
        self.ner.add_label("SURNAME")

        logging.info("Blank model created with NATIVE and SURNAME labels")

    @classmethod
    def load_data(cls, data_path: str) -> List[Tuple[str, Dict]]:
        """Load training data from JSON file - compatible with NERNameTagger output format"""
        if not os.path.exists(data_path):
            raise FileNotFoundError(f"Training data not found at {data_path}")

        logging.info(f"Loading training data from {data_path}")

        with open(data_path, 'r', encoding='utf-8') as f:
            raw_data = json.load(f)

        # Validate and clean training data
        valid_data = []
        skipped_count = 0

        for i, item in enumerate(raw_data):
            try:
                if not isinstance(item, (list, tuple)) or len(item) != 2:
                    logging.warning(f"Skipping invalid training example format at index {i}: {item}")
                    skipped_count += 1
                    continue

                text, annotations = item

                # Validate text
                if not isinstance(text, str) or not text.strip():
                    logging.warning(f"Skipping invalid text at index {i}: {repr(text)}")
                    skipped_count += 1
                    continue

                # Handle different annotation formats from NERNameTagger
                if not isinstance(annotations, dict) or "entities" not in annotations:
                    logging.warning(f"Skipping invalid annotations at index {i}: {annotations}")
                    skipped_count += 1
                    continue

                entities_raw = annotations["entities"]

                # Parse entities - handle both string and list formats from the tagger
                if isinstance(entities_raw, str):
                    # String format from the tagger: "[(0, 6, 'NATIVE'), ...]"
                    try:
                        import ast
                        entities = ast.literal_eval(entities_raw)
                        if not isinstance(entities, list):
                            logging.warning(f"Parsed entities is not a list at index {i}: {entities}")
                            skipped_count += 1
                            continue
                    except (ValueError, SyntaxError) as e:
                        logging.warning(f"Failed to parse entity string at index {i}: {entities_raw} ({e})")
                        skipped_count += 1
                        continue
                elif isinstance(entities_raw, list):
                    # Already in list format
                    entities = entities_raw
                else:
                    logging.warning(f"Skipping invalid entities format at index {i}: {entities_raw}")
                    skipped_count += 1
                    continue

                # Validate each entity
                valid_entities = []
                for entity in entities:
                    if not isinstance(entity, (list, tuple)) or len(entity) != 3:
                        logging.warning(f"Skipping invalid entity format in '{text}': {entity}")
                        continue

                    start, end, label = entity

                    # Validate entity components
                    if (not isinstance(start, int) or not isinstance(end, int) or
                            not isinstance(label, str) or start >= end or
                            start < 0 or end > len(text)):
                        logging.warning(f"Skipping invalid entity bounds in '{text}': {entity}")
                        continue

                    # Check for overlaps with already validated entities
                    has_overlap = any(
                        start < v_end and end > v_start
                        for v_start, v_end, _ in valid_entities
                    )

                    if has_overlap:
                        logging.warning(f"Skipping overlapping entity in '{text}': {entity}")
                        continue

                    # Validate that the span doesn't contain spaces (matching tagger validation)
                    span_text = text[start:end]
                    if not span_text or span_text != span_text.strip() or ' ' in span_text:
                        logging.warning(f"Skipping entity with spaces in '{text}': {entity} -> '{span_text}'")
                        continue

                    valid_entities.append((start, end, label))

                if not valid_entities:
                    logging.warning(f"Skipping training example with no valid entities: '{text}'")
                    skipped_count += 1
                    continue

                # Sort entities by start position
                valid_entities.sort(key=lambda x: x[0])
                valid_data.append((text.strip(), {"entities": valid_entities}))

            except Exception as e:
                logging.error(f"Error processing training example at index {i}: {e}")
                skipped_count += 1
                continue

        logging.info(f"Loaded {len(valid_data)} valid training examples, skipped {skipped_count} invalid ones")

        if not valid_data:
            raise ValueError("No valid training examples found in the data")

        return valid_data

    def train(
        self,
        data: List[Tuple[str, Dict]],
        epochs: int = 5,
        batch_size: int = 16,
        dropout_rate: float = 0.2,
    ) -> None:
        """Train the NER model"""
        logging.info(f"Starting NER training with {len(data)} examples")
        logging.info(f"Training parameters: epochs={epochs}, batch_size={batch_size}, dropout={dropout_rate}")

        if self.nlp is None:
            raise ValueError("Model not initialized. Call create_blank_model() first.")

        # Initialize the model and create the optimizer once so its state
        # persists across batches instead of being rebuilt on every update
        optimizer = self.nlp.initialize()

        # Training loop
        losses_history = []

        for epoch in range(epochs):
            losses = {}

            # Create training examples
            examples = []
            for text, annotations in data:
                doc = self.nlp.make_doc(text)
                example = Example.from_dict(doc, annotations)
                examples.append(example)
                logging.info(f"Training example: {text[:30]}... with entities {annotations.get('entities', [])}")

            # Train in batches
            batches = minibatch(examples, size=batch_size)
            for batch in batches:
                self.nlp.update(
                    batch,
                    losses=losses,
                    drop=dropout_rate,
                    sgd=optimizer
                )
                logging.info(f"Training batch with {len(batch)} examples, current losses: {losses}")

            epoch_loss = losses.get("ner", 0)
            losses_history.append(epoch_loss)
            logging.info(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss:.4f}")

        # Store training statistics
        self.training_stats = {
            "epochs": epochs,
            "final_loss": losses_history[-1] if losses_history else 0,
            "training_examples": len(data),
            "loss_history": losses_history,
            "batch_size": batch_size,
            "dropout_rate": dropout_rate
        }

        logging.info(f"Training completed. Final loss: {self.training_stats['final_loss']:.4f}")

    def evaluate(self, test_data: List[Tuple[str, Dict]]) -> Dict[str, Any]:
        """Evaluate the trained model on test data"""
        if self.nlp is None:
            raise ValueError("Model not trained. Call train() first.")

        logging.info(f"Evaluating model on {len(test_data)} test examples")

        total_examples = len(test_data)
        correct_entities = 0
        predicted_entities = 0
        actual_entities = 0

        entity_stats = {"NATIVE": {"tp": 0, "fp": 0, "fn": 0}, "SURNAME": {"tp": 0, "fp": 0, "fn": 0}}

        for text, annotations in test_data:
            # Get actual entities
            actual_ents = set()
            for start, end, label in annotations.get("entities", []):
                actual_ents.add((start, end, label))
                actual_entities += 1

            # Get predicted entities
            doc = self.nlp(text)
            predicted_ents = set()
            for ent in doc.ents:
                predicted_ents.add((ent.start_char, ent.end_char, ent.label_))
                predicted_entities += 1

            # Calculate matches
            matches = actual_ents.intersection(predicted_ents)
            correct_entities += len(matches)

            # Update per-label statistics
            for start, end, label in actual_ents:
                if (start, end, label) in predicted_ents:
                    entity_stats[label]["tp"] += 1
                else:
                    entity_stats[label]["fn"] += 1

            for start, end, label in predicted_ents:
                if (start, end, label) not in actual_ents:
                    entity_stats[label]["fp"] += 1

        # Calculate overall metrics
        precision = correct_entities / predicted_entities if predicted_entities > 0 else 0
        recall = correct_entities / actual_entities if actual_entities > 0 else 0
        f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        # Calculate per-label metrics
        label_metrics = {}
        for label, stats in entity_stats.items():
            tp, fp, fn = stats["tp"], stats["fp"], stats["fn"]
            label_precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            label_recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            label_f1 = (
                2 * (label_precision * label_recall) / (label_precision + label_recall)
            ) if (label_precision + label_recall) > 0 else 0

            label_metrics[label] = {
                "precision": label_precision,
                "recall": label_recall,
                "f1_score": label_f1,
                "support": tp + fn
            }

        evaluation_results = {
            "overall": {
                "precision": precision,
                "recall": recall,
                "f1_score": f1_score,
                "total_examples": total_examples,
                "correct_entities": correct_entities,
                "predicted_entities": predicted_entities,
                "actual_entities": actual_entities
            },
            "by_label": label_metrics
        }

        logging.info(f"NER evaluation completed. Overall F1: {f1_score:.4f}")
        return evaluation_results

    def save(self, model_name: str = "drc_ner_model") -> str:
        """Save the trained model"""
        if self.nlp is None:
            raise ValueError("No model to save. Train a model first.")

        # Create the model directory
        model_dir = self.config.paths.models_dir / model_name
        model_dir.mkdir(parents=True, exist_ok=True)

        # Save the model
        self.nlp.to_disk(model_dir)
        self.model_path = str(model_dir)

        # Save training statistics
        stats_path = model_dir / "training_stats.json"
        with open(stats_path, 'w', encoding='utf-8') as f:
            json.dump(self.training_stats, f, indent=2)

        logging.info(f"NER model saved to {model_dir}")
        return self.model_path

    def load(self, model_path: str) -> None:
        """Load a trained model"""
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model not found at {model_path}")

        logging.info(f"Loading model from {model_path}")
        self.nlp = spacy.load(model_path)
        self.ner = self.nlp.get_pipe("ner")
        self.model_path = model_path

        # Load training statistics if available
        stats_path = Path(model_path) / "training_stats.json"
        if stats_path.exists():
            with open(stats_path, 'r', encoding='utf-8') as f:
                self.training_stats = json.load(f)

        logging.info("NER model loaded successfully")

    def predict(self, text: str) -> Dict[str, Any]:
        """Make predictions on a single text"""
        if self.nlp is None:
            raise ValueError("No model loaded. Load or train a model first.")

        doc = self.nlp(text)
        entities = []

        for ent in doc.ents:
            entities.append({
                "text": ent.text,
                "label": ent.label_,
                "start": ent.start_char,
                "end": ent.end_char,
                "confidence": getattr(ent, 'score', None)  # If confidence scores are available
            })

        return {
            "text": text,
            "entities": entities
        }
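For context, a minimal end-to-end sketch of how the trainer above might be driven. The module path, config loader, data file name, and the 80/20 split are assumptions, not part of this commit; the method calls match the class as written:

# Hypothetical training sketch (not part of this commit).
from core.config.pipeline_config import PipelineConfig
from ner_name_model import NERNameModel  # module name is an assumption

config = PipelineConfig.load("pipeline.yaml")  # hypothetical loader
model = NERNameModel(config)
model.create_blank_model("fr")  # blank French pipeline with NATIVE/SURNAME labels

# load_data() expects the JSON produced by NERDataBuilder; the path is an assumption.
data = NERNameModel.load_data("data/ner_training_data.json")
split = int(len(data) * 0.8)
train_data, test_data = data[:split], data[split:]

model.train(train_data, epochs=10, batch_size=32, dropout_rate=0.3)
metrics = model.evaluate(test_data)
print(metrics["overall"]["f1_score"], metrics["by_label"])

model.save("drc_ner_model")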
@@ -0,0 +1,200 @@
from typing import Union, Dict, Any, List
import logging


class NERNameTagger:
    def tag_name(self, name: str, probable_native: str, probable_surname: str) -> Union[Dict[str, Any], None]:
        """Create a single NER training example using probable_native and probable_surname"""
        if not name or not probable_native or not probable_surname:
            return None

        name = name.strip()
        probable_native = probable_native.strip()
        probable_surname = probable_surname.strip()

        entities = []
        used_spans = []  # Track used character spans to prevent overlaps

        # Helper function to check if a span overlaps with any existing span
        def has_overlap(start, end):
            for used_start, used_end in used_spans:
                if not (end <= used_start or start >= used_end):
                    return True
            return False

        # Find positions of native names in the full name
        native_words = probable_native.split()
        name_lower = name.lower()  # Use lowercase for consistent searching
        processed_native_words = set()

        for native_word in native_words:
            native_word = native_word.strip()
            if len(native_word) < 2:  # Skip very short words
                continue

            native_word_lower = native_word.lower()

            # Skip if we've already processed this exact word
            if native_word_lower in processed_native_words:
                continue
            processed_native_words.add(native_word_lower)

            # Find the first occurrence of this native word that doesn't overlap
            start_pos = 0
            while True:
                pos = name_lower.find(native_word_lower, start_pos)  # Case-insensitive search
                if pos == -1:
                    break

                # Calculate the end position - make sure we only include the word itself
                end_pos = pos + len(native_word_lower)

                # Double-check that the extracted span matches exactly what we expect
                extracted_text = name[pos:end_pos]  # Get original-case text
                if extracted_text.lower() != native_word_lower:
                    start_pos = pos + 1
                    continue

                # Check that this is a word-boundary match and doesn't overlap
                if (self._is_word_boundary_match(name, pos, end_pos) and
                        not has_overlap(pos, end_pos)):
                    entities.append((pos, end_pos, 'NATIVE'))
                    used_spans.append((pos, end_pos))
                    break  # Only take the first non-overlapping occurrence

                start_pos = pos + 1

        # Find the position of the surname in the full name
        if probable_surname and len(probable_surname.strip()) >= 2:
            surname_lower = probable_surname.lower()

            # Find the first occurrence that doesn't overlap
            start_pos = 0
            while True:
                pos = name_lower.find(surname_lower, start_pos)  # Case-insensitive search
                if pos == -1:
                    break

                # Calculate the end position correctly - exact match only
                end_pos = pos + len(surname_lower)

                # Double-check that the extracted span matches exactly what we expect
                extracted_text = name[pos:end_pos]  # Get original-case text
                if extracted_text.lower() != surname_lower:
                    start_pos = pos + 1
                    continue

                if (self._is_word_boundary_match(name, pos, end_pos) and
                        not has_overlap(pos, end_pos)):
                    entities.append((pos, end_pos, 'SURNAME'))
                    used_spans.append((pos, end_pos))
                    break

                start_pos = pos + 1

        if not entities:
            logging.warning(f"No valid entities found for name: '{name}' with native: '{probable_native}' and surname: '{probable_surname}'")
            return None

        # Sort entities by position and validate
        entities.sort(key=lambda x: x[0])

        # Final validation - ensure no overlaps and valid spans
        validated_entities = []
        for start, end, label in entities:
            # Check bounds
            if not (0 <= start < end <= len(name)):
                logging.warning(f"Invalid span bounds ({start}, {end}) for text length {len(name)}: '{name}'")
                continue

            # Check for overlaps with already validated entities
            if any(start < v_end and end > v_start for v_start, v_end, _ in validated_entities):
                logging.warning(f"Overlapping span ({start}, {end}, '{label}') in '{name}'")
                continue

            # CRITICAL VALIDATION: check that the span contains only the expected word (no spaces)
            span_text = name[start:end]
            if not span_text or span_text != span_text.strip() or ' ' in span_text:
                logging.warning(f"Span contains spaces or is empty ({start}, {end}) in '{name}': '{span_text}'")
                continue

            validated_entities.append((start, end, label))

        if not validated_entities:
            logging.warning(f"No valid entities after validation for: '{name}'")
            return None

        # Convert to the string format that matches the dataset
        entities_str = str(validated_entities)

        return {
            "entities": entities_str,
            "spans": validated_entities  # Keep the original tuples for internal use
        }

    @classmethod
    def _is_word_boundary_match(cls, text: str, start: int, end: int) -> bool:
        """Check if the match is at word boundaries"""
        # Check the character before the start position
        if start > 0:
            prev_char = text[start - 1]
            if prev_char.isalnum():
                return False

        # Check the character after the end position
        if end < len(text):
            next_char = text[end]
            if next_char.isalnum():
                return False

        return True

    @classmethod
    def validate_entities(cls, name: str, entities_str: str) -> bool:
        """Validate that entity annotations are correct for a given name"""
        try:
            import ast
            entities = ast.literal_eval(entities_str)

            # Check for overlaps and valid bounds
            sorted_entities = sorted(entities, key=lambda x: x[0])

            for i, (start, end, label) in enumerate(sorted_entities):
                # Check bounds
                if not (0 <= start < end <= len(name)):
                    return False

                # Check for overlap with the next entity
                if i < len(sorted_entities) - 1:
                    next_start = sorted_entities[i + 1][0]
                    if end > next_start:
                        return False

                # Extract the text span and validate that it's not empty
                span_text = name[start:end]
                if not span_text.strip():
                    return False

            return True
        except (ValueError, SyntaxError, TypeError):
            return False

    @classmethod
    def extract_entity_text(cls, name: str, entities_str: str) -> Dict[str, List[str]]:
        """Extract the actual text for each entity type"""
        result = {'NATIVE': [], 'SURNAME': []}

        try:
            import ast
            entities = ast.literal_eval(entities_str)

            for start, end, label in entities:
                if 0 <= start < end <= len(name):
                    span_text = name[start:end]
                    if label in result:
                        result[label].append(span_text)

        except (ValueError, SyntaxError, TypeError):
            pass

        return result
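For context, a minimal sketch of what the tagger above produces. The example name is invented purely to illustrate the expected input/output shape and is not part of this commit:

# Hypothetical usage sketch (not part of this commit); the name below is invented.
tagger = NERNameTagger()
tagged = tagger.tag_name(
    name="Kabongo Ilunga Jean",
    probable_native="Kabongo Ilunga",
    probable_surname="Jean",
)
if tagged is not None:
    # tagged["entities"] is the string form stored in the dataset, e.g.
    # "[(0, 7, 'NATIVE'), (8, 14, 'NATIVE'), (15, 19, 'SURNAME')]"
    print(tagged["entities"])
    # extract_entity_text() maps the spans back to their surface forms.
    print(NERNameTagger.extract_entity_text("Kabongo Ilunga Jean", tagged["entities"]))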