feat: web application multipage support
This commit is contained in:
@@ -0,0 +1,68 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
import spacy
|
||||
from spacy.tokens import DocBin
|
||||
|
||||
from core.config import PipelineConfig
|
||||
from core.utils.data_loader import DataLoader
|
||||
from .name_tagger import NameTagger
|
||||
|
||||
|
||||
class NameBuilder:
    """Build NER training artifacts (JSON pairs + spaCy DocBin) from the engineered dataset."""

    def __init__(self, config: PipelineConfig):
        self.config = config
        self.data_loader = DataLoader(config)
        self.tagger = NameTagger()

    def build(self) -> int:
        """Build and persist NER training data.

        Returns:
            0 on success, 1 when no usable NER-tagged rows (or no valid
            examples after validation) are found.
        """
        filepath = self.config.paths.get_data_path(self.config.data.output_files["engineered"])
        df = self.data_loader.load_csv_complete(filepath)
        df = df[["name", "ner_tagged", "ner_entities"]]

        # Filter early
        ner_df = df.loc[df["ner_tagged"] == 1, ["name", "ner_entities"]]
        if ner_df.empty:
            logging.error("No NER tagged data found")
            return 1

        total_rows = len(df)
        del df  # No need to keep in memory

        logging.info(f"Found {len(ner_df)} NER tagged entries")
        nlp = spacy.blank("fr")

        # Use NameTagger for parsing and validation.
        # BUG FIX: parse_entities/validate_entities are defined as instance
        # methods on NameTagger, so they must be called on the tagger instance;
        # calling them on the class bound the data Series to `self` and raised
        # a TypeError at runtime.
        parsed_entities = self.tagger.parse_entities(ner_df["ner_entities"])
        validated_entities = self.tagger.validate_entities(ner_df["name"], parsed_entities)

        # Drop rows with no valid entities
        mask = validated_entities.map(bool)
        ner_df = ner_df.loc[mask]
        validated_entities = validated_entities.loc[mask]

        if ner_df.empty:
            logging.error("No valid training examples after validation")
            return 1

        # Prepare training data: (name, {"entities": [(start, end, label), ...]}) pairs
        training_data = list(
            zip(ner_df["name"].tolist(), [{"entities": ents} for ents in validated_entities])
        )

        # Use NameTagger to create spaCy docs (create_docs is a classmethod)
        docs = NameTagger.create_docs(nlp, ner_df["name"].tolist(), validated_entities.tolist())
        doc_bin = DocBin(docs=docs)

        # Save
        json_path = self.config.paths.get_data_path(self.config.data.output_files["ner_data"])
        spacy_path = self.config.paths.get_data_path(self.config.data.output_files["ner_spacy"])

        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(training_data, f, ensure_ascii=False, separators=(",", ":"))
        doc_bin.to_disk(spacy_path)

        logging.info(f"Processed: {len(training_data)}, Skipped: {total_rows - len(training_data)}")
        logging.info(f"Saved NER JSON to {json_path}")
        logging.info(f"Saved NER spacy to {spacy_path}")
        return 0
|
||||
@@ -1,5 +1,5 @@
|
||||
import gc
|
||||
import random
|
||||
from typing import List
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
@@ -7,7 +7,7 @@ import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from core.config import PipelineConfig
|
||||
from core.utils.data_loader import OPTIMIZED_DTYPES, DataLoader
|
||||
from core.utils.data_loader import DataLoader
|
||||
from processing.ner.formats.connectors_format import ConnectorFormatter
|
||||
from processing.ner.formats.extended_surname_format import ExtendedSurnameFormatter
|
||||
from processing.ner.formats.native_only_format import NativeOnlyFormatter
|
||||
@@ -16,7 +16,7 @@ from processing.ner.formats.position_flipped_format import PositionFlippedFormat
|
||||
from processing.ner.formats.reduced_native_format import ReducedNativeFormatter
|
||||
|
||||
|
||||
class NEREngineering:
|
||||
class NameEngineering:
|
||||
"""
|
||||
Feature engineering for NER dataset to prevent position-based learning
|
||||
and encourage sequence characteristic learning.
|
||||
@@ -66,13 +66,16 @@ class NEREngineering:
|
||||
def compute(self) -> None:
|
||||
logging.info("Applying feature engineering transformations...")
|
||||
input_filepath = self.config.paths.get_data_path(self.config.data.output_files["featured"])
|
||||
output_filepath = self.config.paths.get_data_path(self.config.data.output_files["engineered"])
|
||||
output_filepath = self.config.paths.get_data_path(
|
||||
self.config.data.output_files["engineered"]
|
||||
)
|
||||
|
||||
df = self.data_loader.load_csv_complete(input_filepath)
|
||||
ner_df = df[df["ner_tagged"] == 1].copy()
|
||||
logging.info(f"Loaded {len(ner_df)} NER-tagged records from {len(df)} total records")
|
||||
|
||||
del df # No need to keep in memory
|
||||
gc.collect()
|
||||
|
||||
ner_df = ner_df.sample(frac=1, random_state=self.config.data.random_seed).reset_index(
|
||||
drop=True
|
||||
@@ -1,3 +1,4 @@
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -11,7 +12,7 @@ from spacy.util import minibatch
|
||||
from core.config.pipeline_config import PipelineConfig
|
||||
|
||||
|
||||
class NERNameModel:
|
||||
class NameModel:
|
||||
"""NER model trainer using spaCy for DRC names entity recognition"""
|
||||
|
||||
def __init__(self, config: PipelineConfig):
|
||||
@@ -84,8 +85,6 @@ class NERNameModel:
|
||||
if isinstance(entities_raw, str):
|
||||
# String format from tagger: "[(0, 6, 'NATIVE'), ...]"
|
||||
try:
|
||||
import ast
|
||||
|
||||
entities = ast.literal_eval(entities_raw)
|
||||
if not isinstance(entities, list):
|
||||
logging.warning(
|
||||
@@ -175,9 +174,9 @@ class NERNameModel:
|
||||
def train(
|
||||
self,
|
||||
data: List[Tuple[str, Dict]],
|
||||
epochs: int = 5,
|
||||
batch_size: int = 16,
|
||||
dropout_rate: float = 0.2,
|
||||
epochs: int = 1,
|
||||
batch_size: int = 10_000,
|
||||
dropout_rate: float = 0.3,
|
||||
) -> None:
|
||||
"""Train the NER model"""
|
||||
logging.info(f"Starting NER training with {len(data)} examples")
|
||||
@@ -204,7 +203,7 @@ class NERNameModel:
|
||||
example = Example.from_dict(doc, annotations)
|
||||
examples.append(example)
|
||||
logging.info(
|
||||
f"Training example: {text[:30]}... with entities {annotations.get('entities', [])}"
|
||||
f"Training example: {text[:30]} with entities {annotations.get('entities', [])}"
|
||||
)
|
||||
|
||||
# Train in batches
|
||||
@@ -215,6 +214,7 @@ class NERNameModel:
|
||||
)
|
||||
logging.info(f"Training batch with {len(batch)} examples, current losses: {losses}")
|
||||
|
||||
del batches # free memory
|
||||
epoch_loss = losses.get("ner", 0)
|
||||
losses_history.append(epoch_loss)
|
||||
logging.info(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss:.4f}")
|
||||
@@ -0,0 +1,273 @@
|
||||
from typing import Union, Dict, Any, List
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
import pandas as pd
|
||||
from spacy.util import filter_spans
|
||||
|
||||
|
||||
class NameTagger:
    """Create and validate character-span NER annotations for personal names.

    Produces ``(start, end, label)`` spans with labels ``"NATIVE"`` and
    ``"SURNAME"``, plus helpers to parse/validate stored annotation strings
    and to build spaCy ``Doc`` objects from validated spans.
    """

    def tag_name(
        self, name: str, probable_native: str, probable_surname: str
    ) -> Union[Dict[str, Any], None]:
        """Create a single NER training example using probable_native and probable_surname.

        Returns a dict with ``"entities"`` (stringified span list, matching the
        dataset format) and ``"spans"`` (the raw tuples), or None when no valid
        span can be located.
        """
        if not name or not probable_native or not probable_surname:
            return None

        name = name.strip()
        probable_native = probable_native.strip()
        probable_surname = probable_surname.strip()

        entities = []
        used_spans = []  # Track used character spans to prevent overlaps

        # Helper function to check if a span overlaps with any existing span
        def has_overlap(start, end):
            for used_start, used_end in used_spans:
                if not (end <= used_start or start >= used_end):
                    return True
            return False

        # Find positions of native names in the full name
        native_words = probable_native.split()
        name_lower = name.lower()  # Use lowercase for consistent searching
        processed_native_words = set()

        for native_word in native_words:
            native_word = native_word.strip()
            if len(native_word) < 2:  # Skip very short words
                continue

            native_word_lower = native_word.lower()

            # Skip if we've already processed this exact word
            if native_word_lower in processed_native_words:
                continue
            processed_native_words.add(native_word_lower)

            # Find the first occurrence of this native word that doesn't overlap
            start_pos = 0
            while True:
                pos = name_lower.find(native_word_lower, start_pos)  # Case-insensitive search
                if pos == -1:
                    break

                # Calculate end position - make sure we only include the word itself
                end_pos = pos + len(native_word_lower)

                # Double-check that the extracted span matches exactly what we expect
                extracted_text = name[pos:end_pos]  # Get original case text
                if extracted_text.lower() != native_word_lower:
                    start_pos = pos + 1
                    continue

                # Check if this is a word boundary match and doesn't overlap
                if self._is_word_boundary_match(name, pos, end_pos) and not has_overlap(
                    pos, end_pos
                ):
                    entities.append((pos, end_pos, "NATIVE"))
                    used_spans.append((pos, end_pos))
                    break  # Only take the first non-overlapping occurrence

                start_pos = pos + 1

        # Find position of surname in the full name
        if probable_surname and len(probable_surname.strip()) >= 2:
            surname_lower = probable_surname.lower()

            # Find the first occurrence that doesn't overlap
            start_pos = 0
            while True:
                pos = name_lower.find(surname_lower, start_pos)  # Case-insensitive search
                if pos == -1:
                    break

                # Calculate end position correctly - exact match only
                end_pos = pos + len(surname_lower)

                # Double-check that the extracted span matches exactly what we expect
                extracted_text = name[pos:end_pos]  # Get original case text
                if extracted_text.lower() != surname_lower:
                    start_pos = pos + 1
                    continue

                if self._is_word_boundary_match(name, pos, end_pos) and not has_overlap(
                    pos, end_pos
                ):
                    entities.append((pos, end_pos, "SURNAME"))
                    used_spans.append((pos, end_pos))
                    break

                start_pos = pos + 1

        if not entities:
            logging.warning(
                f"No valid entities found for name: '{name}' with native: '{probable_native}' and surname: '{probable_surname}'"
            )
            return None

        # Sort entities by position and validate
        entities.sort(key=lambda x: x[0])

        # Final validation - ensure no overlaps and valid spans
        validated_entities = []
        for start, end, label in entities:
            # Check bounds
            if not (0 <= start < end <= len(name)):
                logging.warning(
                    f"Invalid span bounds ({start}, {end}) for text length {len(name)}: '{name}'"
                )
                continue

            # Check for overlaps with already validated entities
            if any(start < v_end and end > v_start for v_start, v_end, _ in validated_entities):
                logging.warning(f"Overlapping span ({start}, {end}, '{label}') in '{name}'")
                continue

            # CRITICAL VALIDATION: Check that the span contains only the expected word (no spaces)
            span_text = name[start:end]
            if not span_text or span_text != span_text.strip() or " " in span_text:
                logging.warning(
                    f"Span contains spaces or is empty ({start}, {end}) in '{name}': '{span_text}'"
                )
                continue

            validated_entities.append((start, end, label))

        if not validated_entities:
            logging.warning(f"No valid entities after validation for: '{name}'")
            return None

        # Convert to string format that matches the dataset
        entities_str = str(validated_entities)

        return {
            "entities": entities_str,
            "spans": validated_entities,  # Keep the original tuples for internal use
        }

    @classmethod
    def _is_word_boundary_match(cls, text: str, start: int, end: int) -> bool:
        """Check if the match is at word boundaries (no alphanumeric neighbor)."""
        # Check character before start position
        if start > 0:
            prev_char = text[start - 1]
            if prev_char.isalnum():
                return False

        # Check character after end position
        if end < len(text):
            next_char = text[end]
            if next_char.isalnum():
                return False

        return True

    @classmethod
    def extract_entity_text(cls, name: str, entities_str: str) -> Dict[str, List[str]]:
        """Extract the actual text for each entity type; malformed input yields empty lists."""
        result = {"NATIVE": [], "SURNAME": []}

        try:
            entities = ast.literal_eval(entities_str)

            for start, end, label in entities:
                if 0 <= start < end <= len(name):
                    span_text = name[start:end]
                    if label in result:
                        result[label].append(span_text)

        except (ValueError, SyntaxError, TypeError):
            pass

        return result

    @classmethod
    def parse(cls, entities_str: str) -> List[tuple]:
        """Parse entity strings from various formats.

        Supports formats:
        - [(start, end, label), ...]
        - [[start, end, label], ...]
        - [{"start": start, "end": end, "label": label}, ...]

        Returns [] for empty/unparseable input.
        """
        if not entities_str or entities_str in ["[]", "", "nan"]:
            return []
        entities_str = str(entities_str).strip()
        try:
            if entities_str.startswith("[(") and entities_str.endswith(")]"):
                return ast.literal_eval(entities_str)
            elif entities_str.startswith("[[") and entities_str.endswith("]]"):
                return [tuple(e) for e in ast.literal_eval(entities_str)]
            elif entities_str.startswith("[{") and entities_str.endswith("}]"):
                return [(e["start"], e["end"], e["label"]) for e in json.loads(entities_str)]
            else:
                parsed = ast.literal_eval(entities_str)
                return [tuple(e) for e in parsed if isinstance(e, (list, tuple)) and len(e) == 3]
        except (ValueError, SyntaxError, json.JSONDecodeError):
            return []

    @classmethod
    def parse_entities(cls, series: pd.Series) -> pd.Series:
        """Vectorized parse of entity strings.

        FIX: made a classmethod for consistency with `parse` — callers
        (e.g. NameBuilder) invoke it on the class; as an instance method
        the Series argument would have bound to `self` and crashed.
        Instance calls remain valid.
        """
        return series.map(cls.parse)

    @classmethod
    def validate(cls, text: str, entities: List[tuple]) -> List[tuple]:
        """Advanced entity validation with overlap removal.

        This is more comprehensive than the basic validate_entities method.
        """
        if not entities or not text:
            return []
        text = str(text).strip()
        valid = []

        for ent in entities:
            if not isinstance(ent, (list, tuple)) or len(ent) != 3:
                continue
            start, end, label = ent
            try:
                start, end = int(start), int(end)
            except (ValueError, TypeError):
                continue
            if not isinstance(label, str):
                continue
            if not (0 <= start < end <= len(text)):
                continue
            if not text[start:end].strip():
                continue
            valid.append((start, end, label))

        if not valid:
            return []

        valid.sort(key=lambda x: (x[0], x[1]))

        # Remove overlaps: keep the earliest-starting span, drop any that begins
        # before the previous kept span ends
        filtered, last_end = [], -1
        for s, e, l in valid:
            if s >= last_end:
                filtered.append((s, e, l))
                last_end = e
        return filtered

    @classmethod
    def validate_entities(cls, texts: pd.Series, entities_series: pd.Series) -> pd.Series:
        """Vectorized entity validation.

        FIX: made a classmethod for consistency with `validate` and so it can
        be safely invoked on the class (see parse_entities).
        """
        return pd.Series(map(cls.validate, texts, entities_series), index=texts.index)

    @classmethod
    def create_docs(cls, nlp, texts: List[str], entities: List[List[tuple]]) -> List:
        """Batch create spaCy Docs from texts and entities."""
        docs = []
        for text, ents in zip(texts, entities):
            doc = nlp(text)
            spans = []
            for start, end, label in ents:
                # Prefer "contract" alignment; fall back to "strict" exact alignment
                span = doc.char_span(
                    start, end, label=label, alignment_mode="contract"
                ) or doc.char_span(start, end, label=label, alignment_mode="strict")
                if span:
                    spans.append(span)
            doc.ents = filter_spans(spans)
            docs.append(doc)
        return docs
|
||||
@@ -1,149 +0,0 @@
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import spacy
|
||||
from spacy.tokens import DocBin
|
||||
from spacy.util import filter_spans
|
||||
|
||||
from core.config import PipelineConfig
|
||||
from core.utils.data_loader import DataLoader
|
||||
|
||||
|
||||
class NERDataBuilder:
    """Builds NER training artifacts (JSON pairs + spaCy DocBin) from the engineered CSV."""

    def __init__(self, config: PipelineConfig):
        # Pipeline configuration; provides paths and output file names.
        self.config = config
        self.data_loader = DataLoader(config)

    @staticmethod
    def _parse_entities(series: pd.Series) -> pd.Series:
        """Vectorized parse of entity strings.

        Each element may be a stringified list of tuples, list of lists, or a
        JSON list of {"start", "end", "label"} dicts; unparseable values map to [].
        """

        def _parse(entities_str):
            if not entities_str or entities_str in ["[]", "", "nan"]:
                return []
            entities_str = str(entities_str).strip()
            try:
                # Dispatch on the leading/trailing delimiters to pick a parser.
                if entities_str.startswith("[(") and entities_str.endswith(")]"):
                    return ast.literal_eval(entities_str)
                elif entities_str.startswith("[[") and entities_str.endswith("]]"):
                    return [tuple(e) for e in ast.literal_eval(entities_str)]
                elif entities_str.startswith("[{") and entities_str.endswith("}]"):
                    return [(e["start"], e["end"], e["label"]) for e in json.loads(entities_str)]
                else:
                    # Fallback: best-effort literal parse, keep only 3-element spans.
                    parsed = ast.literal_eval(entities_str)
                    return [
                        tuple(e) for e in parsed if isinstance(e, (list, tuple)) and len(e) == 3
                    ]
            except (ValueError, SyntaxError, json.JSONDecodeError):
                return []

        return series.map(_parse)

    @staticmethod
    def _validate_entities(texts: pd.Series, entities_series: pd.Series) -> pd.Series:
        """Vectorized entity validation.

        Drops malformed tuples, out-of-bounds or whitespace-only spans, then
        removes overlapping spans keeping the earliest-starting ones.
        """

        def _validate(text, entities):
            if not entities or not text:
                return []
            text = str(text).strip()
            valid = []
            for ent in entities:
                # Each entity must be a (start, end, label) triple.
                if not isinstance(ent, (list, tuple)) or len(ent) != 3:
                    continue
                start, end, label = ent
                try:
                    start, end = int(start), int(end)
                except (ValueError, TypeError):
                    continue
                if not isinstance(label, str):
                    continue
                # Bounds check: span must be non-empty and inside the text.
                if not (0 <= start < end <= len(text)):
                    continue
                if not text[start:end].strip():
                    continue
                valid.append((start, end, label))
            if not valid:
                return []
            valid.sort(key=lambda x: (x[0], x[1]))
            # remove overlaps
            filtered, last_end = [], -1
            for s, e, l in valid:
                if s >= last_end:
                    filtered.append((s, e, l))
                    last_end = e
            return filtered

        return pd.Series(map(_validate, texts, entities_series), index=texts.index)

    @staticmethod
    def _create_docs(nlp, texts, entities):
        """Batch create spaCy Docs."""
        docs = []
        for text, ents in zip(texts, entities):
            doc = nlp(text)
            spans = []
            for start, end, label in ents:
                # Prefer "contract" alignment; fall back to "strict" if it fails.
                span = doc.char_span(
                    start, end, label=label, alignment_mode="contract"
                ) or doc.char_span(start, end, label=label, alignment_mode="strict")
                if span:
                    spans.append(span)
            doc.ents = filter_spans(spans)
            docs.append(doc)
        return docs

    def build(self) -> int:
        """Build and persist the NER training data; returns 0 on success, 1 on no data."""
        filepath = self.config.paths.get_data_path(self.config.data.output_files["engineered"])
        df = self.data_loader.load_csv_complete(filepath)
        df = df[["name", "ner_tagged", "ner_entities"]]

        # Filter early
        ner_df = df.loc[df["ner_tagged"] == 1, ["name", "ner_entities"]]
        if ner_df.empty:
            logging.error("No NER tagged data found")
            return 1

        total_rows = len(df)
        del df  # No need to keep in memory

        logging.info(f"Found {len(ner_df)} NER tagged entries")
        nlp = spacy.blank("fr")

        # Vectorized parsing + validation
        parsed_entities = self._parse_entities(ner_df["ner_entities"])
        validated_entities = self._validate_entities(ner_df["name"], parsed_entities)

        # Drop rows with no valid entities
        mask = validated_entities.map(bool)
        ner_df = ner_df.loc[mask]
        validated_entities = validated_entities.loc[mask]

        if ner_df.empty:
            logging.error("No valid training examples after validation")
            return 1

        # Prepare training data
        training_data = list(
            zip(ner_df["name"].tolist(), [{"entities": ents} for ents in validated_entities])
        )

        # Create spaCy DocBin in batch
        docs = self._create_docs(nlp, ner_df["name"].tolist(), validated_entities.tolist())
        doc_bin = DocBin(docs=docs)

        # Save
        json_path = self.config.paths.get_data_path(self.config.data.output_files["ner_data"])
        spacy_path = self.config.paths.get_data_path(self.config.data.output_files["ner_spacy"])

        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(training_data, f, ensure_ascii=False, separators=(",", ":"))
        doc_bin.to_disk(spacy_path)

        logging.info(f"Processed: {len(training_data)}, Skipped: {total_rows - len(training_data)}")
        logging.info(f"Saved NER JSON to {json_path}")
        logging.info(f"Saved NER spacy to {spacy_path}")
        return 0
|
||||
|
||||
@@ -1,212 +0,0 @@
|
||||
from typing import Union, Dict, Any, List
|
||||
import logging
|
||||
|
||||
|
||||
class NERNameTagger:
    """Creates and validates (start, end, label) NER span annotations for names."""

    def tag_name(
        self, name: str, probable_native: str, probable_surname: str
    ) -> Union[Dict[str, Any], None]:
        """Create a single NER training example using probable_native and probable_surname"""
        if not name or not probable_native or not probable_surname:
            return None

        name = name.strip()
        probable_native = probable_native.strip()
        probable_surname = probable_surname.strip()

        entities = []
        used_spans = []  # Track used character spans to prevent overlaps

        # Helper function to check if a span overlaps with any existing span
        def has_overlap(start, end):
            for used_start, used_end in used_spans:
                if not (end <= used_start or start >= used_end):
                    return True
            return False

        # Find positions of native names in the full name
        native_words = probable_native.split()
        name_lower = name.lower()  # Use lowercase for consistent searching
        processed_native_words = set()

        for native_word in native_words:
            native_word = native_word.strip()
            if len(native_word) < 2:  # Skip very short words
                continue

            native_word_lower = native_word.lower()

            # Skip if we've already processed this exact word
            if native_word_lower in processed_native_words:
                continue
            processed_native_words.add(native_word_lower)

            # Find the first occurrence of this native word that doesn't overlap
            start_pos = 0
            while True:
                pos = name_lower.find(native_word_lower, start_pos)  # Case-insensitive search
                if pos == -1:
                    break

                # Calculate end position - make sure we only include the word itself
                end_pos = pos + len(native_word_lower)

                # Double-check that the extracted span matches exactly what we expect
                extracted_text = name[pos:end_pos]  # Get original case text
                if extracted_text.lower() != native_word_lower:
                    start_pos = pos + 1
                    continue

                # Check if this is a word boundary match and doesn't overlap
                if self._is_word_boundary_match(name, pos, end_pos) and not has_overlap(
                    pos, end_pos
                ):
                    entities.append((pos, end_pos, "NATIVE"))
                    used_spans.append((pos, end_pos))
                    break  # Only take the first non-overlapping occurrence

                start_pos = pos + 1

        # Find position of surname in the full name
        if probable_surname and len(probable_surname.strip()) >= 2:
            surname_lower = probable_surname.lower()

            # Find the first occurrence that doesn't overlap
            start_pos = 0
            while True:
                pos = name_lower.find(surname_lower, start_pos)  # Case-insensitive search
                if pos == -1:
                    break

                # Calculate end position correctly - exact match only
                end_pos = pos + len(surname_lower)

                # Double-check that the extracted span matches exactly what we expect
                extracted_text = name[pos:end_pos]  # Get original case text
                if extracted_text.lower() != surname_lower:
                    start_pos = pos + 1
                    continue

                if self._is_word_boundary_match(name, pos, end_pos) and not has_overlap(
                    pos, end_pos
                ):
                    entities.append((pos, end_pos, "SURNAME"))
                    used_spans.append((pos, end_pos))
                    break

                start_pos = pos + 1

        if not entities:
            logging.warning(
                f"No valid entities found for name: '{name}' with native: '{probable_native}' and surname: '{probable_surname}'"
            )
            return None

        # Sort entities by position and validate
        entities.sort(key=lambda x: x[0])

        # Final validation - ensure no overlaps and valid spans
        validated_entities = []
        for start, end, label in entities:
            # Check bounds
            if not (0 <= start < end <= len(name)):
                logging.warning(
                    f"Invalid span bounds ({start}, {end}) for text length {len(name)}: '{name}'"
                )
                continue

            # Check for overlaps with already validated entities
            if any(start < v_end and end > v_start for v_start, v_end, _ in validated_entities):
                logging.warning(f"Overlapping span ({start}, {end}, '{label}') in '{name}'")
                continue

            # CRITICAL VALIDATION: Check that the span contains only the expected word (no spaces)
            span_text = name[start:end]
            if not span_text or span_text != span_text.strip() or " " in span_text:
                logging.warning(
                    f"Span contains spaces or is empty ({start}, {end}) in '{name}': '{span_text}'"
                )
                continue

            validated_entities.append((start, end, label))

        if not validated_entities:
            logging.warning(f"No valid entities after validation for: '{name}'")
            return None

        # Convert to string format that matches the dataset
        entities_str = str(validated_entities)

        return {
            "entities": entities_str,
            "spans": validated_entities,  # Keep the original tuples for internal use
        }

    @classmethod
    def _is_word_boundary_match(cls, text: str, start: int, end: int) -> bool:
        """Check if the match is at word boundaries"""
        # Check character before start position
        if start > 0:
            prev_char = text[start - 1]
            if prev_char.isalnum():
                return False

        # Check character after end position
        if end < len(text):
            next_char = text[end]
            if next_char.isalnum():
                return False

        return True

    @classmethod
    def validate_entities(cls, name: str, entities_str: str) -> bool:
        """Validate that entity annotations are correct for a given name"""
        try:
            import ast

            entities = ast.literal_eval(entities_str)

            # Check for overlaps and valid bounds
            sorted_entities = sorted(entities, key=lambda x: x[0])

            for i, (start, end, label) in enumerate(sorted_entities):
                # Check bounds
                if not (0 <= start < end <= len(name)):
                    return False

                # Check for overlaps with next entity
                if i < len(sorted_entities) - 1:
                    next_start = sorted_entities[i + 1][0]
                    if end > next_start:
                        return False

                # Extract the text span and validate it's not empty
                span_text = name[start:end]
                if not span_text.strip():
                    return False

            return True
        except (ValueError, SyntaxError, TypeError):
            # Unparseable or malformed annotation strings are treated as invalid.
            return False

    @classmethod
    def extract_entity_text(cls, name: str, entities_str: str) -> Dict[str, List[str]]:
        """Extract the actual text for each entity type"""
        result = {"NATIVE": [], "SURNAME": []}

        try:
            import ast

            entities = ast.literal_eval(entities_str)

            for start, end, label in entities:
                if 0 <= start < end <= len(name):
                    span_text = name[start:end]
                    if label in result:
                        result[label].append(span_text)

        except (ValueError, SyntaxError, TypeError):
            # Best-effort extraction: malformed input yields empty lists.
            pass

        return result
|
||||
|
||||
Reference in New Issue
Block a user