feat: add NER annotation step and integrate into pipeline
@@ -56,6 +56,7 @@ the `drc-ners-nlp/config/pipeline.yaml` file.
 stages:
   - "data_cleaning"
   - "feature_extraction"
+  - "ner_annotation"
   - "llm_annotation"
   - "data_splitting"
 ```

@@ -12,6 +12,7 @@ processing:
 stages:
   - "data_cleaning"
   - "feature_extraction"
+  #- "ner_annotation"
   #- "llm_annotation"
   - "data_splitting"
 

@@ -27,7 +28,8 @@ llm:
 
 # Data handling configuration
 data:
-  max_dataset_size: 100_000
+  split_evaluation: false
+  max_dataset_size: null
   balance_by_sex: true
 
 # Enhanced logging for development

@@ -12,6 +12,7 @@ processing:
 stages:
   - "data_cleaning"
   - "feature_extraction"
+  - "ner_annotation"
   - "llm_annotation"
   - "data_splitting"
 

@@ -18,9 +18,10 @@ paths:
   checkpoints_dir: "./data/checkpoints" # Directory for model checkpoints
 
 # Pipeline stages
 stages: # List of stages in the processing pipeline
   - "data_cleaning" # Data cleaning stage
   - "feature_extraction" # Feature extraction stage
+  - "ner_annotation" # NER-based annotation stage
   - "llm_annotation" # LLM annotation stage (computational intensive)
   - "data_splitting" # Data splitting stage
 

@@ -36,15 +37,20 @@ processing:
   - "latin1"
   chunk_size: 100_000 # Size of data chunks to process in parallel
 
-# LLM annotation settings
-llm:
-  model_name: "mistral:7b" # Name of the LLM model to use
-  requests_per_minute: 60 # Requests per minute to the LLM service
-  requests_per_second: 2 # Requests per second to the LLM service
-  retry_attempts: 3 # Number of retry attempts for LLM requests
-  timeout_seconds: 600 # Timeout for LLM requests
-  max_concurrent_requests: 2 # Maximum concurrent requests to the LLM service
-  enable_rate_limiting: true # Enable rate limiting to avoid overloading the LLM service
+# Annotation settings
+annotation:
+  llm:
+    model_name: "mistral:7b" # Name of the LLM model to use
+    requests_per_minute: 60 # Requests per minute to the LLM service
+    requests_per_second: 2 # Requests per second to the LLM service
+    retry_attempts: 3 # Number of retry attempts for LLM requests
+    timeout_seconds: 600 # Timeout for LLM requests
+    max_concurrent_requests: 2 # Maximum concurrent requests to the LLM service
+    enable_rate_limiting: true # Enable rate limiting to avoid overloading the LLM service
+
+  ner:
+    model_name: "drc_names_ner" # Name of the NER model to use
+    retry_attempts: 3 # Number of retry attempts for NER requests
 
 # Data handling configuration
 data:

@@ -54,8 +60,11 @@ data:
     evaluation: "names_evaluation.csv" # Output file for evaluation set
     males: "names_males.csv" # Output files for male names
     females: "names_females.csv" # Output files for female names
-  split_evaluation: true # Should the dataset be split into training and evaluation sets ?
+    ner_data: "names_ner.json" # Output file for NER annotated data
+    ner_spacy: "names_ner.spacy" # Output file for NER annotated data using spaCy format
+  split_evaluation: false # Should the dataset be split into training and evaluation sets ?
   split_by_gender: true # Should the dataset be split by gender ?
+  split_ner_data: true # Should the NER data be extracted and saved?
   evaluation_fraction: 0.2 # Fraction of data to use for evaluation
   random_seed: 42 # Random seed for reproducibility
   max_dataset_size: null # Maximum size of the dataset to process, set to null for no

@@ -0,0 +1,29 @@
+from pydantic import BaseModel
+
+class NERConfig(BaseModel):
+    """NER annotation configuration"""
+
+    model_name: str = "drc_names_ner"
+    retry_attempts: int = 3
+
+
+class LLMConfig(BaseModel):
+    """LLM annotation configuration"""
+
+    model_name: str = "mistral:7b"
+    requests_per_minute: int = 60
+    requests_per_second: int = 2
+    retry_attempts: int = 3
+    timeout_seconds: int = 30
+    max_concurrent_requests: int = 2
+    enable_rate_limiting: bool = False
+
+
+class AnnotationConfig(BaseModel):
+    """Base class for annotation configurations"""
+
+    llm: LLMConfig = LLMConfig()
+    ner: NERConfig = NERConfig()
+
+    class Config:
+        arbitrary_types_allowed = True

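For orientation, a minimal sketch of how the new nested settings are read at runtime (assumes a config loaded through `setup_config`, as done elsewhere in this commit; the printed values are the defaults above):

```python
from core.config import setup_config

config = setup_config(config_path=None, env="development")

# The old access path config.llm.* moves under config.annotation after this
# commit; LLM and NER settings now sit side by side.
print(config.annotation.llm.model_name)      # "mistral:7b" by default
print(config.annotation.ner.model_name)      # "drc_names_ner" by default
print(config.annotation.ner.retry_attempts)  # 3 by default
```
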
@@ -65,7 +65,7 @@ class ConfigManager:
 
         # Ensure paths are properly set
         if "paths" not in config_data:
-            config_data["paths"] = self.default_paths.dict()
+            config_data["paths"] = self.default_paths.model_dump()
 
         self._config = PipelineConfig(**config_data)
         return self._config

@@ -14,10 +14,13 @@ class DataConfig(BaseModel):
             "evaluation": "names_evaluation.csv",
             "males": "names_males.csv",
             "females": "names_females.csv",
+            "ner_data": "names_ner.json",
+            "ner_spacy": "names_ner.spacy"
         }
     )
-    split_evaluation: bool = True
+    split_evaluation: bool = False
     split_by_gender: bool = True
+    split_ner_data: bool = True
     evaluation_fraction: float = 0.2
     random_seed: int = 42
 

@@ -1,13 +0,0 @@
-from pydantic import BaseModel
-
-
-class LLMConfig(BaseModel):
-    """LLM annotation configuration"""
-
-    model_name: str = "mistral:7b"
-    requests_per_minute: int = 60
-    requests_per_second: int = 2
-    retry_attempts: int = 3
-    timeout_seconds: int = 30
-    max_concurrent_requests: int = 2
-    enable_rate_limiting: bool = False

@@ -1,8 +1,8 @@
 from pydantic import BaseModel
 
-from core.config.logging_config import LoggingConfig
+from core.config.annotation_config import AnnotationConfig
 from core.config.data_config import DataConfig
-from core.config.llm_config import LLMConfig
+from core.config.logging_config import LoggingConfig
 from core.config.processing_config import ProcessingConfig
 from core.config.project_paths import ProjectPaths
 

@@ -17,7 +17,7 @@ class PipelineConfig(BaseModel):
     paths: ProjectPaths
     stages: list[str] = []
     processing: ProcessingConfig = ProcessingConfig()
-    llm: LLMConfig = LLMConfig()
+    annotation: AnnotationConfig = AnnotationConfig()
     data: DataConfig = DataConfig()
     logging: LoggingConfig = LoggingConfig()
 

@@ -8,11 +8,13 @@ from core.config import setup_config
 from core.utils import get_data_file_path
 from core.utils.data_loader import DataLoader
 from processing.batch.batch_config import BatchConfig
+from processing.ner.ner_data_builder import NERDataBuilder
 from processing.pipeline import Pipeline
 from processing.steps.data_cleaning_step import DataCleaningStep
 from processing.steps.data_splitting_step import DataSplittingStep
 from processing.steps.feature_extraction_step import FeatureExtractionStep
 from processing.steps.llm_annotation_step import LLMAnnotationStep
+from processing.steps.ner_annotation_step import NERAnnotationStep
 
 
 def create_pipeline(config) -> Pipeline:

@@ -29,6 +31,7 @@ def create_pipeline(config) -> Pipeline:
     steps = [
         DataCleaningStep(config),
         FeatureExtractionStep(config),
+        NERAnnotationStep(config),
         LLMAnnotationStep(config),
         DataSplittingStep(config),
     ]

@@ -67,6 +70,7 @@ def run_pipeline(config) -> int:
     splitting_step = pipeline.steps[-1]
     if isinstance(splitting_step, DataSplittingStep):
         splitting_step.save_splits(result_df)
+        NERDataBuilder(config).build(result_df)
 
     # Show completion statistics
     progress = pipeline.get_progress()

@@ -5,65 +5,31 @@ import traceback
 from pathlib import Path
 
 from core.config import setup_config
-from processing.monitoring.data_analyzer import DatasetAnalyzer
 from processing.monitoring.pipeline_monitor import PipelineMonitor
 
 
 def main():
-    parser = argparse.ArgumentParser(
-        description="Monitor and manage the DRC names processing pipeline"
-    )
-    parser.add_argument("--config", type=Path, help="Path to configuration file")
-    parser.add_argument(
-        "--env", type=str, default="development",
-        help="Environment name (default: development)"
-    )
-
+    choices = ["data_cleaning", "feature_extraction", "ner_annotation", "llm_annotation", "data_splitting"]
+    parser = argparse.ArgumentParser(description="Monitor and manage the DRC names processing pipeline")
+    parser.add_argument("--config", type=Path, help="Path to configuration file")
+    parser.add_argument("--env", type=str, default="development", help="Environment")
     subparsers = parser.add_subparsers(dest="command", help="Available commands")
 
     # Status command
-    status_parser = subparsers.add_parser("status", help="Show pipeline status")
-    status_parser.add_argument(
-        "--detailed",
-        action="store_true",
-        help="Show detailed information including failed batch IDs",
-    )
+    subparsers.add_parser("status", help="Show pipeline status")
 
     # Clean command
     clean_parser = subparsers.add_parser("clean", help="Clean checkpoint files")
-    clean_parser.add_argument(
-        "--step",
-        type=str,
-        choices=["data_cleaning", "feature_extraction", "llm_annotation", "data_splitting"],
-        help="Clean specific step (default: all)",
-    )
-    clean_parser.add_argument(
-        "--keep-last", type=int, default=1, help="Number of recent checkpoints to keep (default: 1)"
-    )
+    clean_parser.add_argument("--step", type=str, choices=choices, help="Specific step (default: all)")
+    clean_parser.add_argument("--keep-last", type=int, default=1, help="Checkpoints to keep (default: 1)")
     clean_parser.add_argument("--force", action="store_true", help="Clean without confirmation")
 
     # Reset command
     reset_parser = subparsers.add_parser("reset", help="Reset pipeline step")
-    reset_parser.add_argument(
-        "step",
-        type=str,
-        choices=["data_cleaning", "feature_extraction", "llm_annotation", "data_splitting"],
-        help="Step to reset",
-    )
+    reset_parser.add_argument("--step", type=str, choices=choices, help="Specific step (default: all)")
+    reset_parser.add_argument("--all", action="store_true", help="Reset all steps")
     reset_parser.add_argument("--force", action="store_true", help="Reset without confirmation")
 
-    # Analyze command
-    analyze_parser = subparsers.add_parser("analyze", help="Analyze dataset")
-    analyze_parser.add_argument(
-        "--file",
-        type=str,
-        default="names_featured.csv",
-        help="Dataset file to analyze (default: names_featured.csv)",
-    )
-
-    # Checkpoint info command
-    info_parser = subparsers.add_parser("info", help="Show checkpoint information")
-
     args = parser.parse_args()
 
     if not args.command:

@@ -71,13 +37,11 @@ def main():
         return 1
 
     try:
-        # Load configuration and setup logging
-        config = setup_config(config_path=args.config, env=args.env)
+        setup_config(config_path=args.config, env=args.env)
 
         monitor = PipelineMonitor()
 
         if args.command == "status":
-            monitor.print_status(detailed=args.detailed)
+            monitor.print_status(detailed=True)
 
         elif args.command == "clean":
             checkpoint_info = monitor.count_checkpoint_files()

@@ -106,49 +70,13 @@ def main():
                 print("Cancelled")
                 return 0
 
-            monitor.reset_step(args.step)
-            print(f"Reset completed for {args.step}")
-
-        elif args.command == "analyze":
-            # Use configured data directory
-            data_dir = config.paths.data_dir
-            filepath = data_dir / args.file
-
-            if not filepath.exists():
-                print(f"File not found: {filepath}")
-                return 1
-
-            analyzer = DatasetAnalyzer(str(filepath))
-
-            if not analyzer.load_data():
-                return 1
-
-            completion_stats = analyzer.analyze_completion()
-
-            print(f"\n=== Dataset Analysis: {args.file} ===")
-            print(f"Total rows: {completion_stats['total_rows']:,}")
-            print(
-                f"Annotated: {completion_stats['annotated_rows']:,} ({completion_stats['annotation_percentage']:.1f}%)")
-            print(f"Unannotated: {completion_stats['unannotated_rows']:,}")
-            print(
-                f"Complete names: {completion_stats['complete_names']:,} ({completion_stats['completeness_percentage']:.1f}%)"
-            )
-
-        elif args.command == "info":
-            checkpoint_info = monitor.count_checkpoint_files()
-
-            print(f"\n=== Checkpoint Information ===")
-            print(f"Total storage: {checkpoint_info['total_size_mb']:.1f} MB")
-            print()
-
-            for step in monitor.steps:
-                step_info = checkpoint_info[step]
-                print(f"{step.replace('_', ' ').title()}:")
-                print(f"  Files: {step_info['files']}")
-                print(f"  Size: {step_info['size_mb']:.1f} MB")
-                print()
-
-        return 0
-
+            if args.step:
+                monitor.reset_step(args.step)
+            else:
+                for step in monitor.steps:
+                    monitor.reset_step(step)
+
+            print(f"Reset completed")
 
     except Exception as e:
         print(f"Monitoring failed: {e}")

@@ -0,0 +1,52 @@
+#!/usr/bin/env python3
+import argparse
+import logging
+import sys
+from pathlib import Path
+
+from core.config import setup_config
+from processing.ner.ner_data_builder import NERDataBuilder
+from processing.ner.ner_name_model import NERNameModel
+
+
+def train(config_path=None, env="development"):
+    """Train the NER model."""
+    try:
+        config = setup_config(config_path=config_path, env=env)
+        trainer = NERNameModel(config)
+        builder = NERDataBuilder(config)
+
+        data_path = Path(config.paths.data_dir) / config.data.output_files["ner_data"]
+        if not data_path.exists():
+            builder.build()
+
+        trainer.create_blank_model("fr")
+        data = trainer.load_data(str(data_path))
+
+        split_idx = int(len(data) * 0.8)
+        train_data, eval_data = data[:split_idx], data[split_idx:]
+
+        logging.info(f"Training with {len(train_data)} examples, evaluating on {len(eval_data)}")
+        trainer.train(train_data, epochs=1, batch_size=config.processing.batch_size, dropout_rate=0.3)
+        trainer.evaluate(eval_data)
+
+        model_path = trainer.save()
+        logging.info(f"Model saved to: {model_path}")
+        return 0
+
+    except Exception as e:
+        logging.error(f"NER Training failed: {e}", exc_info=True)
+        return 1
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Train NER model for DRC names")
+    parser.add_argument("--config", type=str, help="Path to configuration file")
+    parser.add_argument("--env", type=str, default="development", help="Environment name")
+    args = parser.parse_args()
+
+    sys.exit(train(config_path=args.config, env=args.env))
+
+
+if __name__ == "__main__":
+    main()

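A hedged smoke-test sketch of the same flow the script wires up, runnable from a REPL; the JSON path below is an assumption based on the default `ner_data` output file and data directory:

```python
from core.config import setup_config
from processing.ner.ner_name_model import NERNameModel

config = setup_config(config_path=None, env="development")
model = NERNameModel(config)
model.create_blank_model("fr")

# "./data/names_ner.json" assumes the default paths/output_files settings.
data = model.load_data("./data/names_ner.json")
model.train(data[:100], epochs=1, batch_size=8)  # tiny run, sanity check only
print(model.evaluate(data[100:120])["overall"]["f1_score"])
```
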
@@ -19,7 +19,7 @@ class PipelineMonitor:
 
         self.paths = paths
         self.checkpoint_dir = paths.checkpoints_dir
-        self.steps = ["data_cleaning", "feature_extraction", "llm_annotation", "data_splitting"]
+        self.steps = ["data_cleaning", "feature_extraction", "ner_annotation", "llm_annotation", "data_splitting"]
 
     def get_step_status(self, step_name: str) -> Dict:
         """Get status of a specific pipeline step"""

@@ -0,0 +1,198 @@
+import ast
+import json
+import logging
+from pathlib import Path
+
+import pandas as pd
+import spacy
+from spacy.tokens import DocBin
+from spacy.util import filter_spans
+
+from core.config import PipelineConfig
+from core.utils import get_data_file_path
+
+
+class NERDataBuilder:
+    def __init__(self, config: PipelineConfig):
+        self.config = config
+
+    @classmethod
+    def parse_entities(cls, entities_str):
+        """Parse entity string (tuple format or JSON) into spaCy-style tuples."""
+        if not entities_str or entities_str in ["[]", "", "nan"]:
+            return []
+
+        entities_str = str(entities_str).strip()
+
+        # Handle different formats
+        try:
+            # Try to parse as Python literal (tuples or lists)
+            if entities_str.startswith("[(") and entities_str.endswith(")]"):
+                # Standard tuple format: [(0, 6, 'NATIVE'), ...]
+                return ast.literal_eval(entities_str)
+            elif entities_str.startswith("[[") and entities_str.endswith("]]"):
+                # Nested list format: [[0, 6, 'NATIVE'], ...]
+                nested_list = ast.literal_eval(entities_str)
+                return [(start, end, label) for start, end, label in nested_list]
+            elif entities_str.startswith("[{") and entities_str.endswith("}]"):
+                # JSON format: [{"start": 0, "end": 6, "label": "NATIVE"}, ...]
+                json_entities = json.loads(entities_str)
+                return [(e["start"], e["end"], e["label"]) for e in json_entities]
+            else:
+                # Try general ast.literal_eval for other formats
+                parsed = ast.literal_eval(entities_str)
+                if isinstance(parsed, list):
+                    # Convert any list format to tuples
+                    result = []
+                    for item in parsed:
+                        if isinstance(item, (list, tuple)) and len(item) == 3:
+                            result.append((item[0], item[1], item[2]))
+                    return result
+
+        except (ValueError, SyntaxError, json.JSONDecodeError) as e:
+            logging.warning(f"Failed to parse entities: {entities_str} ({e})")
+            return []
+
+        logging.warning(f"Unknown entity format: {entities_str}")
+        return []
+
+    @classmethod
+    def validate_entities(cls, entities, text):
+        """Validate and sort entity tuples, removing overlaps and invalid spans."""
+        if not entities or not text:
+            return []
+
+        text = str(text).strip()
+        if not text:
+            return []
+
+        # Filter out invalid entities
+        valid_entities = []
+        for entity in entities:
+            if not isinstance(entity, (list, tuple)) or len(entity) != 3:
+                logging.warning(f"Invalid entity format: {entity}")
+                continue
+
+            start, end, label = entity
+
+            # Ensure start/end are integers
+            try:
+                start = int(start)
+                end = int(end)
+            except (ValueError, TypeError):
+                logging.warning(f"Invalid start/end positions: {entity}")
+                continue
+
+            # Ensure label is string
+            if not isinstance(label, str):
+                logging.warning(f"Invalid label type: {entity}")
+                continue
+
+            # Check bounds
+            if not (0 <= start < end <= len(text)):
+                logging.warning(f"Entity span out of bounds: {entity} for text '{text}' (length {len(text)})")
+                continue
+
+            # Check that span contains actual text
+            span_text = text[start:end].strip()
+            if not span_text:
+                logging.warning(f"Empty span: {entity} in text '{text}'")
+                continue
+
+            valid_entities.append((start, end, label))
+
+        if not valid_entities:
+            return []
+
+        # Sort by start position
+        valid_entities.sort(key=lambda x: (x[0], x[1]))
+
+        # Remove overlapping entities (keep the first one)
+        filtered = []
+        for start, end, label in valid_entities:
+            # Check for overlap with already added entities
+            has_overlap = False
+            for e_start, e_end, _ in filtered:
+                if not (end <= e_start or start >= e_end):
+                    has_overlap = True
+                    logging.warning(
+                        f"Removing overlapping entity ({start}, {end}, '{label}') "
+                        f"conflicts with ({e_start}, {e_end}) in '{text}'"
+                    )
+                    break
+
+            if not has_overlap:
+                filtered.append((start, end, label))
+
+        return filtered
+
+    @classmethod
+    def create_doc(cls, text, entities, nlp):
+        """Create a spaCy Doc object with entities added."""
+        doc = nlp(text)
+        ents = []
+
+        for start, end, label in entities:
+            span = doc.char_span(start, end, label=label, alignment_mode="contract") \
+                or doc.char_span(start, end, label=label, alignment_mode="strict")
+            if span:
+                ents.append(span)
+            else:
+                logging.warning(f"Could not create span ({start}, {end}, '{label}') in '{text}'")
+
+        doc.ents = filter_spans(ents) if ents else []
+        return doc
+
+    def build(self, data: pd.DataFrame = None) -> int:
+        """Build the dataset for NER training."""
+        logging.info("Building dataset for NER training")
+        try:
+            df = pd.read_csv(get_data_file_path("names_featured.csv", self.config)) \
+                if data is None \
+                else data
+
+            ner_df = df[df["ner_tagged"] == 1].copy()
+            if ner_df.empty:
+                logging.error("No NER tagged data found in the CSV")
+                return 1
+
+            logging.info(f"Found {len(ner_df)} NER tagged entries")
+            nlp = spacy.blank("fr")
+            doc_bin, training_data = DocBin(), []
+            processed_count, skipped_count = 0, 0
+
+            for _, row in ner_df.iterrows():
+                text = str(row.get("name", "")).strip()
+                if not text:
+                    continue
+
+                entities = self.parse_entities(row.get("ner_entities", "[]"))
+                entities = self.validate_entities(entities, text)
+
+                training_data.append((text, {"entities": entities}))
+                try:
+                    doc_bin.add(self.create_doc(text, entities, nlp))
+                    processed_count += 1
+                except Exception as e:
+                    logging.error(f"Error processing '{text}': {e}")
+                    skipped_count += 1
+
+            if not training_data:
+                logging.error("No valid training examples generated")
+                return 1
+
+            json_path = Path(self.config.paths.data_dir) / self.config.data.output_files["ner_data"]
+            spacy_path = Path(self.config.paths.data_dir) / self.config.data.output_files["ner_spacy"]
+
+            with open(json_path, "w", encoding="utf-8") as f:
+                json.dump(training_data, f, ensure_ascii=False, indent=None)
+            doc_bin.to_disk(spacy_path)
+
+            logging.info(f"Processed: {processed_count}, Skipped: {skipped_count}")
+            logging.info(f"Saved NER data in json format to {json_path}")
+            logging.info(f"Saved NER data in spaCy format to {spacy_path}")
+            return 0
+
+        except Exception as e:
+            logging.error(f"Failed to build NER dataset: {e}", exc_info=True)
+            return 1

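For reference, the three entity-string formats `parse_entities` normalizes, shown on a toy value (import path taken from this commit's modules):

```python
from processing.ner.ner_data_builder import NERDataBuilder

# Each format normalizes to the same list of (start, end, label) tuples.
print(NERDataBuilder.parse_entities("[(0, 6, 'NATIVE')]"))
print(NERDataBuilder.parse_entities("[[0, 6, 'NATIVE']]"))
print(NERDataBuilder.parse_entities('[{"start": 0, "end": 6, "label": "NATIVE"}]'))
# -> [(0, 6, 'NATIVE')] in all three cases
```
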
@@ -0,0 +1,356 @@
+import json
+import logging
+import os
+from pathlib import Path
+from typing import Dict, Any, List, Tuple
+
+import spacy
+from spacy.training import Example
+from spacy.util import minibatch
+
+from core.config.pipeline_config import PipelineConfig
+
+
+class NERNameModel:
+    """NER model trainer using spaCy for DRC names entity recognition"""
+
+    def __init__(self, config: PipelineConfig):
+        self.config = config
+        self.nlp = None
+        self.ner = None
+        self.model_path = None
+        self.training_stats = {}
+
+    def create_blank_model(self, language: str = "fr") -> None:
+        """Create a blank spaCy model with NER pipeline"""
+        logging.info(f"Creating blank {language} model for NER training")
+
+        # Create blank model - French tokenizer works well for DRC names
+        self.nlp = spacy.blank(language)
+
+        # Add NER pipeline component
+        if "ner" not in self.nlp.pipe_names:
+            self.ner = self.nlp.add_pipe("ner")
+        else:
+            self.ner = self.nlp.get_pipe("ner")
+
+        # Add our custom labels
+        self.ner.add_label("NATIVE")
+        self.ner.add_label("SURNAME")
+
+        logging.info("Blank model created with NATIVE and SURNAME labels")
+
+    @classmethod
+    def load_data(cls, data_path: str) -> List[Tuple[str, Dict]]:
+        """Load training data from JSON file - compatible with NERNameTagger output format"""
+        if not os.path.exists(data_path):
+            raise FileNotFoundError(f"Training data not found at {data_path}")
+
+        logging.info(f"Loading training data from {data_path}")
+
+        with open(data_path, 'r', encoding='utf-8') as f:
+            raw_data = json.load(f)
+
+        # Validate and clean training data
+        valid_data = []
+        skipped_count = 0
+
+        for i, item in enumerate(raw_data):
+            try:
+                if not isinstance(item, (list, tuple)) or len(item) != 2:
+                    logging.warning(f"Skipping invalid training example format at index {i}: {item}")
+                    skipped_count += 1
+                    continue
+
+                text, annotations = item
+
+                # Validate text
+                if not isinstance(text, str) or not text.strip():
+                    logging.warning(f"Skipping invalid text at index {i}: {repr(text)}")
+                    skipped_count += 1
+                    continue
+
+                # Handle different annotation formats from NERNameTagger
+                if not isinstance(annotations, dict) or "entities" not in annotations:
+                    logging.warning(f"Skipping invalid annotations at index {i}: {annotations}")
+                    skipped_count += 1
+                    continue
+
+                entities_raw = annotations["entities"]
+
+                # Parse entities - handle both string and list formats from tagger
+                if isinstance(entities_raw, str):
+                    # String format from tagger: "[(0, 6, 'NATIVE'), ...]"
+                    try:
+                        import ast
+                        entities = ast.literal_eval(entities_raw)
+                        if not isinstance(entities, list):
+                            logging.warning(f"Parsed entities is not a list at index {i}: {entities}")
+                            skipped_count += 1
+                            continue
+                    except (ValueError, SyntaxError) as e:
+                        logging.warning(f"Failed to parse entity string at index {i}: {entities_raw} ({e})")
+                        skipped_count += 1
+                        continue
+                elif isinstance(entities_raw, list):
+                    # Already in list format
+                    entities = entities_raw
+                else:
+                    logging.warning(f"Skipping invalid entities format at index {i}: {entities_raw}")
+                    skipped_count += 1
+                    continue
+
+                # Validate each entity
+                valid_entities = []
+                for entity in entities:
+                    if not isinstance(entity, (list, tuple)) or len(entity) != 3:
+                        logging.warning(f"Skipping invalid entity format in '{text}': {entity}")
+                        continue
+
+                    start, end, label = entity
+
+                    # Validate entity components
+                    if (not isinstance(start, int) or not isinstance(end, int) or
+                            not isinstance(label, str) or start >= end or
+                            start < 0 or end > len(text)):
+                        logging.warning(f"Skipping invalid entity bounds in '{text}': {entity}")
+                        continue
+
+                    # Check for overlaps with already validated entities
+                    has_overlap = any(
+                        start < v_end and end > v_start
+                        for v_start, v_end, _ in valid_entities
+                    )
+
+                    if has_overlap:
+                        logging.warning(f"Skipping overlapping entity in '{text}': {entity}")
+                        continue
+
+                    # Validate that the span doesn't contain spaces (matching tagger validation)
+                    span_text = text[start:end]
+                    if not span_text or span_text != span_text.strip() or ' ' in span_text:
+                        logging.warning(f"Skipping entity with spaces in '{text}': {entity} -> '{span_text}'")
+                        continue
+
+                    valid_entities.append((start, end, label))
+
+                if not valid_entities:
+                    logging.warning(f"Skipping training example with no valid entities: '{text}'")
+                    skipped_count += 1
+                    continue
+
+                # Sort entities by start position
+                valid_entities.sort(key=lambda x: x[0])
+                valid_data.append((text.strip(), {"entities": valid_entities}))
+
+            except Exception as e:
+                logging.error(f"Error processing training example at index {i}: {e}")
+                skipped_count += 1
+                continue
+
+        logging.info(f"Loaded {len(valid_data)} valid training examples, skipped {skipped_count} invalid ones")
+
+        if not valid_data:
+            raise ValueError("No valid training examples found in the data")
+
+        return valid_data
+
+    def train(
+        self,
+        data: List[Tuple[str, Dict]],
+        epochs: int = 5,
+        batch_size: int = 16,
+        dropout_rate: float = 0.2,
+    ) -> None:
+        """Train the NER model"""
+        logging.info(f"Starting NER training with {len(data)} examples")
+        logging.info(f"Training parameters: epochs={epochs}, batch_size={batch_size}, dropout={dropout_rate}")
+
+        if self.nlp is None:
+            raise ValueError("Model not initialized. Call create_blank_model() first.")
+
+        # Initialize the model
+        self.nlp.initialize()
+
+        # Training loop
+        losses_history = []
+
+        for epoch in range(epochs):
+            losses = {}
+
+            # Create training examples
+            examples = []
+            for text, annotations in data:
+                doc = self.nlp.make_doc(text)
+                example = Example.from_dict(doc, annotations)
+                examples.append(example)
+                logging.info(f"Training example: {text[:30]}... with entities {annotations.get('entities', [])}")
+
+            # Train in batches
+            batches = minibatch(examples, size=batch_size)
+            for batch in batches:
+                self.nlp.update(
+                    batch,
+                    losses=losses,
+                    drop=dropout_rate,
+                    sgd=self.nlp.create_optimizer()
+                )
+                logging.info(f"Training batch with {len(batch)} examples, current losses: {losses}")
+
+            epoch_loss = losses.get("ner", 0)
+            losses_history.append(epoch_loss)
+            logging.info(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss:.4f}")
+
+        # Store training statistics
+        self.training_stats = {
+            "epochs": epochs,
+            "final_loss": losses_history[-1] if losses_history else 0,
+            "training_examples": len(data),
+            "loss_history": losses_history,
+            "batch_size": batch_size,
+            "dropout_rate": dropout_rate
+        }
+
+        logging.info(f"Training completed. Final loss: {self.training_stats['final_loss']:.4f}")
+
+    def evaluate(self, test_data: List[Tuple[str, Dict]]) -> Dict[str, Any]:
+        """Evaluate the trained model on test data"""
+        if self.nlp is None:
+            raise ValueError("Model not trained. Call train_model() first.")
+
+        logging.info(f"Evaluating model on {len(test_data)} test examples")
+
+        total_examples = len(test_data)
+        correct_entities = 0
+        predicted_entities = 0
+        actual_entities = 0
+
+        entity_stats = {"NATIVE": {"tp": 0, "fp": 0, "fn": 0}, "SURNAME": {"tp": 0, "fp": 0, "fn": 0}}
+
+        for text, annotations in test_data:
+            # Get actual entities
+            actual_ents = set()
+            for start, end, label in annotations.get("entities", []):
+                actual_ents.add((start, end, label))
+                actual_entities += 1
+
+            # Get predicted entities
+            doc = self.nlp(text)
+            predicted_ents = set()
+            for ent in doc.ents:
+                predicted_ents.add((ent.start_char, ent.end_char, ent.label_))
+                predicted_entities += 1
+
+            # Calculate matches
+            matches = actual_ents.intersection(predicted_ents)
+            correct_entities += len(matches)
+
+            # Update per-label statistics
+            for start, end, label in actual_ents:
+                if (start, end, label) in predicted_ents:
+                    entity_stats[label]["tp"] += 1
+                else:
+                    entity_stats[label]["fn"] += 1
+
+            for start, end, label in predicted_ents:
+                if (start, end, label) not in actual_ents:
+                    entity_stats[label]["fp"] += 1
+
+        # Calculate overall metrics
+        precision = correct_entities / predicted_entities if predicted_entities > 0 else 0
+        recall = correct_entities / actual_entities if actual_entities > 0 else 0
+        f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
+
+        # Calculate per-label metrics
+        label_metrics = {}
+        for label, stats in entity_stats.items():
+            tp, fp, fn = stats["tp"], stats["fp"], stats["fn"]
+            label_precision = tp / (tp + fp) if (tp + fp) > 0 else 0
+            label_recall = tp / (tp + fn) if (tp + fn) > 0 else 0
+            label_f1 = (
+                2 * (label_precision * label_recall) / (label_precision + label_recall)) \
+                if (label_precision + label_recall) > 0 else 0
+
+            label_metrics[label] = {
+                "precision": label_precision,
+                "recall": label_recall,
+                "f1_score": label_f1,
+                "support": tp + fn
+            }
+
+        evaluation_results = {
+            "overall": {
+                "precision": precision,
+                "recall": recall,
+                "f1_score": f1_score,
+                "total_examples": total_examples,
+                "correct_entities": correct_entities,
+                "predicted_entities": predicted_entities,
+                "actual_entities": actual_entities
+            },
+            "by_label": label_metrics
+        }
+
+        logging.info(f"NER Evaluation completed. Overall F1: {f1_score:.4f}")
+        return evaluation_results
+
+    def save(self, model_name: str = "drc_ner_model") -> str:
+        """Save the trained model"""
+        if self.nlp is None:
+            raise ValueError("No model to save. Train a model first.")
+
+        # Create model directory
+        model_dir = self.config.paths.models_dir / model_name
+        model_dir.mkdir(parents=True, exist_ok=True)
+
+        # Save the model
+        self.nlp.to_disk(model_dir)
+        self.model_path = str(model_dir)
+
+        # Save training statistics
+        stats_path = model_dir / "training_stats.json"
+        with open(stats_path, 'w', encoding='utf-8') as f:
+            json.dump(self.training_stats, f, indent=2)
+
+        logging.info(f"NER Model saved to {model_dir}")
+        return self.model_path
+
+    def load(self, model_path: str) -> None:
+        """Load a trained model"""
+        if not os.path.exists(model_path):
+            raise FileNotFoundError(f"Model not found at {model_path}")
+
+        logging.info(f"Loading model from {model_path}")
+        self.nlp = spacy.load(model_path)
+        self.ner = self.nlp.get_pipe("ner")
+        self.model_path = model_path
+
+        # Load training statistics if available
+        stats_path = Path(model_path) / "training_stats.json"
+        if stats_path.exists():
+            with open(stats_path, 'r', encoding='utf-8') as f:
+                self.training_stats = json.load(f)
+
+        logging.info("NER Model loaded successfully")
+
+    def predict(self, text: str) -> Dict[str, Any]:
+        """Make predictions on a single text"""
+        if self.nlp is None:
+            raise ValueError("No model loaded. Load or train a model first.")
+
+        doc = self.nlp(text)
+        entities = []
+
+        for ent in doc.ents:
+            entities.append({
+                "text": ent.text,
+                "label": ent.label_,
+                "start": ent.start_char,
+                "end": ent.end_char,
+                "confidence": getattr(ent, 'score', None)  # If confidence scores are available
+            })
+
+        return {
+            "text": text,
+            "entities": entities
+        }

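A sketch of the prediction output shape implied by `predict()`; the name and character offsets below are illustrative, not taken from the dataset:

```python
# Illustrative only: shape of NERNameModel.predict("Kabila Joseph Mwamba").
# The name and offsets are made up; "confidence" is None because spaCy spans
# carry no score attribute by default (see the getattr call above).
expected = {
    "text": "Kabila Joseph Mwamba",
    "entities": [
        {"text": "Kabila", "label": "NATIVE", "start": 0, "end": 6, "confidence": None},
        {"text": "Mwamba", "label": "SURNAME", "start": 14, "end": 20, "confidence": None},
    ],
}
```
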
@@ -0,0 +1,200 @@
+from typing import Union, Dict, Any, List
+import logging
+
+
+class NERNameTagger:
+    def tag_name(self, name: str, probable_native: str, probable_surname: str) -> Union[Dict[str, Any], None]:
+        """Create a single NER training example using probable_native and probable_surname"""
+        if not name or not probable_native or not probable_surname:
+            return None
+
+        name = name.strip()
+        probable_native = probable_native.strip()
+        probable_surname = probable_surname.strip()
+
+        entities = []
+        used_spans = []  # Track used character spans to prevent overlaps
+
+        # Helper function to check if a span overlaps with any existing span
+        def has_overlap(start, end):
+            for used_start, used_end in used_spans:
+                if not (end <= used_start or start >= used_end):
+                    return True
+            return False
+
+        # Find positions of native names in the full name
+        native_words = probable_native.split()
+        name_lower = name.lower()  # Use lowercase for consistent searching
+        processed_native_words = set()
+
+        for native_word in native_words:
+            native_word = native_word.strip()
+            if len(native_word) < 2:  # Skip very short words
+                continue
+
+            native_word_lower = native_word.lower()
+
+            # Skip if we've already processed this exact word
+            if native_word_lower in processed_native_words:
+                continue
+            processed_native_words.add(native_word_lower)
+
+            # Find the first occurrence of this native word that doesn't overlap
+            start_pos = 0
+            while True:
+                pos = name_lower.find(native_word_lower, start_pos)  # Case-insensitive search
+                if pos == -1:
+                    break
+
+                # Calculate end position - make sure we only include the word itself
+                end_pos = pos + len(native_word_lower)
+
+                # Double-check that the extracted span matches exactly what we expect
+                extracted_text = name[pos:end_pos]  # Get original case text
+                if extracted_text.lower() != native_word_lower:
+                    start_pos = pos + 1
+                    continue
+
+                # Check if this is a word boundary match and doesn't overlap
+                if (self._is_word_boundary_match(name, pos, end_pos) and
+                        not has_overlap(pos, end_pos)):
+                    entities.append((pos, end_pos, 'NATIVE'))
+                    used_spans.append((pos, end_pos))
+                    break  # Only take the first non-overlapping occurrence
+
+                start_pos = pos + 1
+
+        # Find position of surname in the full name
+        if probable_surname and len(probable_surname.strip()) >= 2:
+            surname_lower = probable_surname.lower()
+
+            # Find the first occurrence that doesn't overlap
+            start_pos = 0
+            while True:
+                pos = name_lower.find(surname_lower, start_pos)  # Case-insensitive search
+                if pos == -1:
+                    break
+
+                # Calculate end position correctly - exact match only
+                end_pos = pos + len(surname_lower)
+
+                # Double-check that the extracted span matches exactly what we expect
+                extracted_text = name[pos:end_pos]  # Get original case text
+                if extracted_text.lower() != surname_lower:
+                    start_pos = pos + 1
+                    continue
+
+                if (self._is_word_boundary_match(name, pos, end_pos) and
+                        not has_overlap(pos, end_pos)):
+                    entities.append((pos, end_pos, 'SURNAME'))
+                    used_spans.append((pos, end_pos))
+                    break
+
+                start_pos = pos + 1
+
+        if not entities:
+            logging.warning(f"No valid entities found for name: '{name}' with native: '{probable_native}' and surname: '{probable_surname}'")
+            return None
+
+        # Sort entities by position and validate
+        entities.sort(key=lambda x: x[0])
+
+        # Final validation - ensure no overlaps and valid spans
+        validated_entities = []
+        for start, end, label in entities:
+            # Check bounds
+            if not (0 <= start < end <= len(name)):
+                logging.warning(f"Invalid span bounds ({start}, {end}) for text length {len(name)}: '{name}'")
+                continue
+
+            # Check for overlaps with already validated entities
+            if any(start < v_end and end > v_start for v_start, v_end, _ in validated_entities):
+                logging.warning(f"Overlapping span ({start}, {end}, '{label}') in '{name}'")
+                continue
+
+            # CRITICAL VALIDATION: Check that the span contains only the expected word (no spaces)
+            span_text = name[start:end]
+            if not span_text or span_text != span_text.strip() or ' ' in span_text:
+                logging.warning(f"Span contains spaces or is empty ({start}, {end}) in '{name}': '{span_text}'")
+                continue
+
+            validated_entities.append((start, end, label))
+
+        if not validated_entities:
+            logging.warning(f"No valid entities after validation for: '{name}'")
+            return None
+
+        # Convert to string format that matches the dataset
+        entities_str = str(validated_entities)
+
+        return {
+            "entities": entities_str,
+            "spans": validated_entities  # Keep the original tuples for internal use
+        }
+
+    @classmethod
+    def _is_word_boundary_match(cls, text: str, start: int, end: int) -> bool:
+        """Check if the match is at word boundaries"""
+        # Check character before start position
+        if start > 0:
+            prev_char = text[start - 1]
+            if prev_char.isalnum():
+                return False
+
+        # Check character after end position
+        if end < len(text):
+            next_char = text[end]
+            if next_char.isalnum():
+                return False
+
+        return True
+
+    @classmethod
+    def validate_entities(cls, name: str, entities_str: str) -> bool:
+        """Validate that entity annotations are correct for a given name"""
+        try:
+            import ast
+            entities = ast.literal_eval(entities_str)
+
+            # Check for overlaps and valid bounds
+            sorted_entities = sorted(entities, key=lambda x: x[0])
+
+            for i, (start, end, label) in enumerate(sorted_entities):
+                # Check bounds
+                if not (0 <= start < end <= len(name)):
+                    return False
+
+                # Check for overlaps with next entity
+                if i < len(sorted_entities) - 1:
+                    next_start = sorted_entities[i + 1][0]
+                    if end > next_start:
+                        return False
+
+                # Extract the text span and validate it's not empty
+                span_text = name[start:end]
+                if not span_text.strip():
+                    return False
+
+            return True
+        except (ValueError, SyntaxError, TypeError):
+            return False
+
+    @classmethod
+    def extract_entity_text(cls, name: str, entities_str: str) -> Dict[str, List[str]]:
+        """Extract the actual text for each entity type"""
+        result = {'NATIVE': [], 'SURNAME': []}
+
+        try:
+            import ast
+            entities = ast.literal_eval(entities_str)
+
+            for start, end, label in entities:
+                if 0 <= start < end <= len(name):
+                    span_text = name[start:end]
+                    if label in result:
+                        result[label].append(span_text)
+
+        except (ValueError, SyntaxError, TypeError):
+            pass
+
+        return result

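A quick sketch of the tagger on a three-word name, which is the only case the feature-extraction step currently feeds it (the example name is made up):

```python
from processing.ner.ner_name_tagger import NERNameTagger

tagger = NERNameTagger()
result = tagger.tag_name("Kabila Joseph Mwamba", "Kabila", "Mwamba")

# result["entities"] is the stringified span list stored in the dataframe;
# result["spans"] keeps the raw tuples for internal use.
print(result["entities"])  # "[(0, 6, 'NATIVE'), (14, 20, 'SURNAME')]"
```
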
@@ -6,9 +6,10 @@ from dataclasses import dataclass
 from typing import List, Optional
 
 import pandas as pd
+from pydantic import BaseModel
 
-from processing.batch.batch_config import BatchConfig
 from core.config.pipeline_config import PipelineConfig
+from processing.batch.batch_config import BatchConfig
 
 
 @dataclass

@@ -25,11 +26,18 @@ class PipelineState:
         self.failed_batches = []
 
 
+class NameAnnotation(BaseModel):
+    """Model for name annotation results"""
+
+    identified_name: Optional[str]
+    identified_surname: Optional[str]
+
+
 class PipelineStep(ABC):
     """Abstract base class for pipeline steps"""
 
     def __init__(
         self, name: str, pipeline_config: PipelineConfig, batch_config: Optional[BatchConfig] = None
     ):
         self.name = name
         self.pipeline_config = pipeline_config

@@ -25,4 +25,7 @@ class DataCleaningStep(PipelineStep):
         # Apply text cleaning
         batch = self.text_cleaner.clean_dataframe_text_columns(batch)
 
+        # Remove duplicates
+        batch = batch.drop_duplicates(subset=self.required_columns)
+
         return batch

@@ -5,6 +5,7 @@ import pandas as pd
 
 from core.config.pipeline_config import PipelineConfig
 from core.utils.region_mapper import RegionMapper
+from processing.ner.ner_name_tagger import NERNameTagger
 from processing.steps import PipelineStep
 
 

@@ -24,6 +25,7 @@ class FeatureExtractionStep(PipelineStep):
     def __init__(self, pipeline_config: PipelineConfig):
         super().__init__("feature_extraction", pipeline_config)
         self.region_mapper = RegionMapper()
+        self.name_tagger = NERNameTagger()
 
     @classmethod
     def validate_gender(cls, gender: str) -> Gender:

@@ -52,7 +54,7 @@ class FeatureExtractionStep(PipelineStep):
 
         # Basic features
         batch["words"] = batch["name"].str.count(" ") + 1
-        batch["length"] = batch["name"].str.replace(" ", "", regex=False).str.len()
+        batch["length"] = batch["name"].str.len()
 
         # Handle year column
         if "year" in batch.columns:

@@ -63,6 +65,8 @@ class FeatureExtractionStep(PipelineStep):
         batch["probable_surname"] = None
         batch["identified_name"] = None
         batch["identified_surname"] = None
+        batch["ner_entities"] = None
+        batch["ner_tagged"] = 0
         batch["annotated"] = 0
 
         # Vectorized category assignment

@@ -81,14 +85,19 @@ class FeatureExtractionStep(PipelineStep):
 
         # Auto-assign for 3-word names
         three_word_mask = batch["words"] == 3
-        batch.loc[three_word_mask, "identified_name"] = batch.loc[
-            three_word_mask, "probable_native"
-        ]
-        batch.loc[three_word_mask, "identified_surname"] = batch.loc[
-            three_word_mask, "probable_surname"
-        ]
+        batch.loc[three_word_mask, "identified_name"] = batch.loc[three_word_mask, "probable_native"]
+        batch.loc[three_word_mask, "identified_surname"] = batch.loc[three_word_mask, "probable_surname"]
         batch.loc[three_word_mask, "annotated"] = 1
 
+        # Tag names with NER entities
+        three_word_rows = batch[three_word_mask]
+        for idx, row in three_word_rows.iterrows():
+            entity = self.name_tagger.tag_name(row['name'], row['identified_name'], row['identified_surname'])
+
+            if entity:
+                batch.at[idx, "ner_entities"] = entity["entities"]
+                batch.at[idx, "ner_tagged"] = 1
+
         # Map regions to provinces
         batch["province"] = self.region_mapper.map_regions_vectorized(batch["region"])
 

@@ -1,25 +1,18 @@
 import logging
 import time
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Dict, Optional
+from typing import Dict
 
 import ollama
 import pandas as pd
-from pydantic import ValidationError, BaseModel
+from pydantic import ValidationError
 
 from core.config.pipeline_config import PipelineConfig
 from core.utils.prompt_manager import PromptManager
-from core.utils.rate_limiter import RateLimiter
 from core.utils.rate_limiter import RateLimitConfig
+from core.utils.rate_limiter import RateLimiter
 from processing.batch.batch_config import BatchConfig
-from processing.steps import PipelineStep
-
-
-class NameAnnotation(BaseModel):
-    """Model for name annotation results"""
-
-    identified_name: Optional[str]
-    identified_surname: Optional[str]
+from processing.steps import PipelineStep, NameAnnotation
 
 
 class LLMAnnotationStep(PipelineStep):

@@ -27,10 +20,12 @@ class LLMAnnotationStep(PipelineStep):
|
|||||||
|
|
||||||
def __init__(self, pipeline_config: PipelineConfig):
|
def __init__(self, pipeline_config: PipelineConfig):
|
||||||
# Create custom batch config for LLM processing
|
# Create custom batch config for LLM processing
|
||||||
|
self.llm_config = pipeline_config.annotation.llm
|
||||||
batch_config = BatchConfig(
|
batch_config = BatchConfig(
|
||||||
batch_size=pipeline_config.processing.batch_size,
|
batch_size=pipeline_config.processing.batch_size,
|
||||||
max_workers=min(
|
max_workers=min(
|
||||||
pipeline_config.llm.max_concurrent_requests, pipeline_config.processing.max_workers
|
self.llm_config.max_concurrent_requests,
|
||||||
|
pipeline_config.processing.max_workers
|
||||||
),
|
),
|
||||||
checkpoint_interval=pipeline_config.processing.checkpoint_interval,
|
checkpoint_interval=pipeline_config.processing.checkpoint_interval,
|
||||||
use_multiprocessing=pipeline_config.processing.use_multiprocessing,
|
use_multiprocessing=pipeline_config.processing.use_multiprocessing,
|
||||||
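Caching `pipeline_config.annotation.llm` as `self.llm_config` in `__init__` shortens every later access and follows the config move from a top-level `llm:` block to `annotation.llm`. A rough sketch of the nested config shape this assumes, with attribute names taken from the diff and class names purely illustrative:

```python
# Illustrative shape of the nested annotation config; only the attribute
# names are grounded in the diff, the dataclass names are assumptions.
from dataclasses import dataclass


@dataclass
class LLMConfig:
    model_name: str
    requests_per_minute: int
    requests_per_second: int
    retry_attempts: int
    timeout_seconds: int
    max_concurrent_requests: int
    enable_rate_limiting: bool


@dataclass
class NERConfig:
    retry_attempts: int  # the only field the NER step reads in this commit


@dataclass
class AnnotationConfig:
    llm: LLMConfig
    ner: NERConfig


# After the refactor the step reads:
#   self.llm_config = pipeline_config.annotation.llm
#   self.llm_config.model_name   # instead of pipeline_config.llm.model_name
```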
@@ -39,7 +34,7 @@ class LLMAnnotationStep(PipelineStep):

         self.prompt = PromptManager(pipeline_config).load_prompt()
         self.rate_limiter = (
-            self._create_rate_limiter() if pipeline_config.llm.enable_rate_limiting else None
+            self._create_rate_limiter() if self.llm_config.enable_rate_limiting else None
         )

         # Statistics
@@ -53,14 +48,14 @@ class LLMAnnotationStep(PipelineStep):
     def _create_rate_limiter(self):
         """Create rate limiter based on configuration"""
         rate_config = RateLimitConfig(
-            requests_per_minute=self.pipeline_config.llm.requests_per_minute,
-            requests_per_second=self.pipeline_config.llm.requests_per_second,
+            requests_per_minute=self.llm_config.requests_per_minute,
+            requests_per_second=self.llm_config.requests_per_second,
         )
         return RateLimiter(rate_config)

-    def analyze_name_with_retry(self, client: ollama.Client, name: str, row_id: int) -> Dict:
+    def analyze_name(self, client: ollama.Client, name: str) -> Dict:
         """Analyze a name with retry logic and rate limiting"""
-        for attempt in range(self.pipeline_config.llm.retry_attempts):
+        for attempt in range(self.llm_config.retry_attempts):
             try:
                 # Apply rate limiting if enabled
                 if self.rate_limiter:
@@ -68,7 +63,7 @@ class LLMAnnotationStep(PipelineStep):

                 start_time = time.time()
                 response = client.chat(
-                    model=self.pipeline_config.llm.model_name,
+                    model=self.llm_config.model_name,
                     messages=[
                         {"role": "system", "content": self.prompt},
                         {"role": "user", "content": name},
@@ -77,9 +72,9 @@ class LLMAnnotationStep(PipelineStep):
                 )
                 elapsed_time = time.time() - start_time

-                if elapsed_time > self.pipeline_config.llm.timeout_seconds:
+                if elapsed_time > self.llm_config.timeout_seconds:
                     raise TimeoutError(
-                        f"Request took {elapsed_time:.2f}s, exceeding {self.pipeline_config.llm.timeout_seconds}s timeout"
+                        f"Request took {elapsed_time:.2f}s, exceeding {self.llm_config.timeout_seconds}s timeout"
                     )

                 annotation = NameAnnotation.model_validate_json(response.message.content)
@@ -98,12 +93,12 @@ class LLMAnnotationStep(PipelineStep):

             except (ValidationError, TimeoutError, Exception) as e:
                 logging.warning(
-                    f"Error analyzing '{name}' (attempt {attempt + 1}/{self.pipeline_config.llm.retry_attempts}): {e}"
+                    f"Error analyzing '{name}' (attempt {attempt + 1}/{self.llm_config.retry_attempts}): {e}"
                 )

                 # Exponential backoff with jitter
-                if attempt < self.pipeline_config.llm.retry_attempts - 1:
-                    wait_time = (2**attempt) + (time.time() % 1)
+                if attempt < self.llm_config.retry_attempts - 1:
+                    wait_time = (2 ** attempt) + (time.time() % 1)
                     time.sleep(min(wait_time, 10))

         self.failed_requests += 1
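The retry loop's backoff expression is worth unpacking: `2 ** attempt` doubles the wait on each retry, `time.time() % 1` adds sub-second jitter so concurrent workers do not retry in lockstep, and the sleep is capped at 10 s. A standalone sketch of the same formula:

```python
# Minimal sketch mirroring the backoff-with-jitter used in the retry loop above.
import time


def backoff_delay(attempt: int, cap: float = 10.0) -> float:
    """Exponential backoff (1s, 2s, 4s, ...) plus a 0-1s jitter term, capped."""
    return min((2 ** attempt) + (time.time() % 1), cap)


# attempt 0 -> ~1s, attempt 1 -> ~2s, attempt 2 -> ~4s, attempt 4+ -> capped at 10s
for attempt in range(5):
    print(f"attempt {attempt}: sleep ~{backoff_delay(attempt):.2f}s")
```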
@@ -112,7 +107,7 @@ class LLMAnnotationStep(PipelineStep):
             "identified_surname": None,
             "annotated": 0,
             "processing_time": 0,
-            "attempts": self.pipeline_config.llm.retry_attempts,
+            "attempts": self.llm_config.retry_attempts,
             "failed": True,
         }

@@ -125,18 +120,18 @@ class LLMAnnotationStep(PipelineStep):
             logging.info(f"Batch {batch_id}: No entries to annotate")
             return batch

-        logging.info(f"Batch {batch_id}: Annotating {len(unannotated_entries)} entries")
+        logging.info(f"Batch {batch_id}: Annotating {len(unannotated_entries)} entries with LLM")

         batch = batch.copy()
         client = ollama.Client()

         # Process with controlled concurrency
-        max_workers = self.pipeline_config.llm.max_concurrent_requests
+        max_workers = self.llm_config.max_concurrent_requests

         if len(unannotated_entries) == 1 or max_workers == 1:
             # Sequential processing
             for idx, row in unannotated_entries.iterrows():
-                result = self.analyze_name_with_retry(client, row["name"], idx)
+                result = self.analyze_name(client, row["name"])
                 for field, value in result.items():
                     if field not in ["failed"]:
                         batch.loc[idx, field] = value
@@ -146,7 +141,7 @@ class LLMAnnotationStep(PipelineStep):
                 future_to_idx = {}

                 for idx, row in unannotated_entries.iterrows():
-                    future = executor.submit(self.analyze_name_with_retry, client, row["name"], idx)
+                    future = executor.submit(self.analyze_name, client, row["name"])
                     future_to_idx[future] = idx

                 for future in as_completed(future_to_idx):
@@ -161,8 +156,6 @@ class LLMAnnotationStep(PipelineStep):
                         batch.loc[idx, "annotated"] = 0

         # Ensure proper data types
-        batch["annotated"] = (
-            pd.to_numeric(batch["annotated"], errors="coerce").fillna(0).astype("Int8")
-        )
+        batch["annotated"] = pd.to_numeric(batch["annotated"], errors="coerce").fillna(0).astype("Int8")

         return batch
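With the local class removed, `NameAnnotation` is now imported from `processing.steps` alongside `PipelineStep`, so the LLM step and the new NER step can share one result schema. The relocated model presumably keeps the two fields the deleted class declared:

```python
# What the relocated model presumably looks like in processing/steps.py;
# the fields are exactly the ones the deleted local class declared.
from typing import Optional

from pydantic import BaseModel


class NameAnnotation(BaseModel):
    """Model for name annotation results"""

    identified_name: Optional[str]
    identified_surname: Optional[str]
```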
@@ -0,0 +1,164 @@
+import logging
+import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Dict
+
+import pandas as pd
+
+from core.config.pipeline_config import PipelineConfig
+from processing.steps import PipelineStep, NameAnnotation
+from processing.ner.ner_name_model import NERNameModel
+
+
+class NERAnnotationStep(PipelineStep):
+    """NER annotation step using trained spaCy model for entity recognition"""
+
+    def __init__(self, pipeline_config: PipelineConfig):
+        # Create custom batch config for NER processing
+        super().__init__("ner_annotation", pipeline_config)
+
+        self.model_name = "drc_ner_model"
+        self.model_path = pipeline_config.paths.models_dir / "drc_ner_model"
+        self.ner_trainer = NERNameModel(pipeline_config)
+        self.ner_config = pipeline_config.annotation.ner
+
+        # Statistics
+        self.successful_requests = 0
+        self.failed_requests = 0
+        self.total_retry_attempts = 0
+
+        # Load the model
+        self._load_ner_model()
+
+    def _load_ner_model(self) -> None:
+        """Load the trained NER model"""
+        try:
+            if self.model_path.exists():
+                logging.info(f"Loading NER model from {self.model_path}")
+                self.ner_trainer.load(str(self.model_path))
+                logging.info("NER model loaded successfully")
+            else:
+                logging.warning(f"NER model not found at {self.model_path}")
+                logging.warning("NER annotation will be skipped. Train the model first.")
+                self.ner_trainer.nlp = None
+        except Exception as e:
+            logging.error(f"Failed to load NER model: {e}")
+            self.ner_trainer.nlp = None
+
+    def analyze_name(self, name: str) -> Dict:
+        """Analyze a name with retry logic"""
+        if self.ner_trainer.nlp is None:
+            return {
+                "identified_name": None,
+                "identified_surname": None,
+                "annotated": 0,
+                "processing_time": 0,
+                "attempts": 0,
+                "failed": True,
+            }
+
+        for attempt in range(self.ner_config.retry_attempts):
+            try:
+                start_time = time.time()
+
+                # Get NER predictions
+                prediction = self.ner_trainer.predict(name.lower())
+                entities = prediction.get('entities', [])
+
+                elapsed_time = time.time() - start_time
+
+                # Extract native names and surnames from entities
+                native_parts = []
+                surname_parts = []
+
+                for entity in entities:
+                    if entity['label'] == 'NATIVE':
+                        native_parts.append(entity['text'])
+                    elif entity['label'] == 'SURNAME':
+                        surname_parts.append(entity['text'])
+
+                # Create annotation result in same format as LLM step
+                annotation = NameAnnotation(
+                    identified_name=" ".join(native_parts) if native_parts else None,
+                    identified_surname=" ".join(surname_parts) if surname_parts else None
+                )
+
+                result = {
+                    **annotation.model_dump(),
+                    "annotated": 1,
+                    "processing_time": elapsed_time,
+                    "attempts": attempt + 1,
+                }
+
+                self.successful_requests += 1
+                if attempt > 0:
+                    self.total_retry_attempts += attempt
+
+                return result
+
+            except Exception as e:
+                logging.warning(
+                    f"Error analyzing '{name}' with NER (attempt {attempt + 1}/{self.ner_config.retry_attempts}): {e}"
+                )
+
+                # Small delay between retries
+                if attempt < self.ner_config.retry_attempts - 1:
+                    time.sleep(0.1)
+
+        self.failed_requests += 1
+        return {
+            "identified_name": None,
+            "identified_surname": None,
+            "annotated": 0,
+            "processing_time": 0,
+            "attempts": self.ner_config.retry_attempts,
+            "failed": True,
+        }
+
+    def process_batch(self, batch: pd.DataFrame, batch_id: int) -> pd.DataFrame:
+        """Process batch with NER annotation using same logic as LLM step"""
+        unannotated_mask = batch.get("annotated", 0) == 0
+        unannotated_entries = batch[unannotated_mask]
+
+        if unannotated_entries.empty:
+            logging.info(f"Batch {batch_id}: No entries to annotate")
+            return batch
+
+        logging.info(f"Batch {batch_id}: Annotating {len(unannotated_entries)} entries with NER")
+
+        batch = batch.copy()
+
+        # Process with controlled concurrency
+        max_workers = self.batch_config.max_workers
+
+        if len(unannotated_entries) == 1 or max_workers == 1:
+            # Sequential processing
+            for idx, row in unannotated_entries.iterrows():
+                result = self.analyze_name(row["name"])
+                for field, value in result.items():
+                    if field not in ["failed"]:
+                        batch.loc[idx, field] = value
+        else:
+            # Concurrent processing
+            with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                future_to_idx = {}
+
+                for idx, row in unannotated_entries.iterrows():
+                    future = executor.submit(self.analyze_name, row["name"])
+                    future_to_idx[future] = idx
+
+                for future in as_completed(future_to_idx):
+                    idx = future_to_idx[future]
+                    try:
+                        result = future.result()
+                        for field, value in result.items():
+                            if field not in ["failed"]:
+                                batch.loc[idx, field] = value
+                    except Exception as e:
+                        logging.error(f"Failed to process row {idx}: {e}")
+                        batch.loc[idx, "annotated"] = 0
+
+        # Ensure proper data types
+        batch["annotated"] = pd.to_numeric(batch["annotated"], errors="coerce").fillna(0).astype("Int8")
+
+        return batch
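That closes out the new `NERAnnotationStep`. A hypothetical smoke test is sketched below; the `PipelineConfig.load(...)` helper and the `processing.ner.ner_annotation` module path are assumptions, as neither appears in this commit:

```python
# Hypothetical smoke test for the new step; config loading and the module
# path are assumptions about the surrounding project, not part of this diff.
import pandas as pd

from core.config.pipeline_config import PipelineConfig  # load() helper assumed
from processing.ner.ner_annotation import NERAnnotationStep  # module path assumed

config = PipelineConfig.load("config/pipeline.yaml")  # assumption: a load() helper exists
step = NERAnnotationStep(config)

batch = pd.DataFrame({
    "name": ["kabila kabange joseph"],
    "annotated": [0],
})
result = step.process_batch(batch, batch_id=0)
print(result[["name", "identified_name", "identified_surname", "annotated"]])
```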
@@ -176,3 +176,4 @@ altair==5.1.2
 PyYAML~=6.0.2
 xgboost~=3.0.3
 lightgbm~=4.6.0
+spacy~=3.8.7
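The new `spacy` pin backs the `NERNameModel` that the NER step drives. Only the `predict()` return shape, a dict with an `entities` list of `{"text", "label"}` records, is fixed by the step above; the sketch below shows the spaCy calls it presumably wraps:

```python
# Sketch of the spaCy usage NERNameModel presumably wraps; only the
# predict() return shape is grounded in the step above.
import spacy


class NERNameModel:
    def __init__(self, pipeline_config=None):
        self.nlp = None  # set by load(); the step checks this for None

    def load(self, model_path: str) -> None:
        # Load a trained pipeline previously saved with nlp.to_disk(...)
        self.nlp = spacy.load(model_path)

    def predict(self, text: str) -> dict:
        doc = self.nlp(text)
        return {
            "entities": [
                {"text": ent.text, "label": ent.label_,
                 "start": ent.start_char, "end": ent.end_char}
                for ent in doc.ents
            ]
        }
```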