Files
drc-ners-nlp/web/interfaces/ner_testing.py
T
2025-09-21 16:23:44 +02:00

159 lines
5.8 KiB
Python

from pathlib import Path
import streamlit as st
from spacy import displacy
from core.config import PipelineConfig
from processing.ner.name_model import NameModel
class NERTesting:
    """Streamlit page for inspecting and interactively testing the trained NER model.

    The model is loaded from ``<models_dir>/drc_ner_model`` the first time
    :meth:`load_ner_model` succeeds and is reused on later calls.
    """

    def __init__(self, config: PipelineConfig):
        """Store the pipeline config and derive the saved-model path.

        Args:
            config: Pipeline configuration; ``config.paths.models_dir`` is the
                directory containing the ``drc_ner_model`` folder.
        """
        self.config = config
        self.model_path = config.paths.models_dir / "drc_ner_model"
        self.ner_model = None
        self.training_stats = None
        self.evaluation_stats = None

    def load_ner_model(self) -> bool:
        """Load the trained NER model (once) and cache its statistics.

        Returns:
            True if a loaded model is available, False if loading failed. On
            failure the error is shown on the page via ``st.error``.
        """
        if self.ner_model is not None:
            return True
        try:
            # Build into a local and only publish it after ``load`` succeeds;
            # otherwise a failed load would leave an unloaded model cached and
            # every later call would wrongly report success.
            model = NameModel(self.config)
            model.load(str(self.model_path))
            training_stats = model.training_stats
        except Exception as e:
            st.error(f"Error loading NER model: {e}")
            return False
        self.ner_model = model
        self.training_stats = training_stats
        # NOTE(review): evaluation statistics are not read from the model here,
        # so show_model_evaluation_info currently renders nothing.
        self.evaluation_stats = {}
        return True

    def index(self):
        """Render the page: model summary followed by the interactive tester."""
        st.title("Named Entity Recognition")

        if not self.load_ner_model():
            st.warning(
                "NER model could not be loaded. Please ensure the model is trained and available."
            )
            return

        self.show_model_training_info()
        self.show_model_evaluation_info()

        st.markdown("---")
        st.subheader("Test the NER Model")

        input_method = st.radio("Input Method", ["Single Name", "Multiple Names"])
        if input_method == "Single Name":
            self.test_single_name()
        elif input_method == "Multiple Names":
            self.test_multiple_names()

    @staticmethod
    def _numeric_stat(stats: dict, key: str):
        """Return ``stats[key]``, or 0 when the key is missing or its value is None.

        Keeps the numeric format specs (``:,`` / ``:.2f``) used by the metric
        widgets from raising on absent statistics.
        """
        value = stats.get(key)
        return 0 if value is None else value

    def show_model_training_info(self):
        """Show headline training statistics in four metric columns, if any are loaded."""
        stats = self.training_stats
        if not stats:
            return
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric(
                "Training Examples", f"{self._numeric_stat(stats, 'training_examples'):,}"
            )
        with col2:
            st.metric("Epochs", stats.get("epochs", 0))
        with col3:
            st.metric("Final Loss", f"{self._numeric_stat(stats, 'final_loss'):.2f}")
        with col4:
            st.metric("Batch Size", f"{self._numeric_stat(stats, 'batch_size'):,}")

    def show_model_evaluation_info(self):
        """Show overall precision/recall/F1 and the per-label breakdown, if any are loaded."""
        if not self.evaluation_stats:
            return
        # One column per metric shown below.
        col1, col2, col3 = st.columns(3)
        overall = self.evaluation_stats.get("overall") or {}
        with col1:
            st.metric("Overall Precision", f"{self._numeric_stat(overall, 'precision'):.2f}")
        with col2:
            st.metric("Overall Recall", f"{self._numeric_stat(overall, 'recall'):.2f}")
        with col3:
            st.metric("Overall F1 Score", f"{self._numeric_stat(overall, 'f1_score'):.2f}")
        st.json(self.evaluation_stats.get("by_label", {}))

    def test_single_name(self):
        """Render a single-line input and analyze it when the button is pressed."""
        name_input = st.text_input(
            "Name:",
            placeholder="e.g., Jean Baptiste Mukendi, Marie Kabamba Tshiala, Joseph Kasongo",
            help="Enter a full name or multiple names separated by spaces",
        )
        if name_input.strip() and st.button("Analyze Name", type="primary"):
            self.analyze_and_display(name_input)

    def test_multiple_names(self):
        """Render a multi-line input and analyze each non-blank line as a separate name."""
        names_input = st.text_area(
            "Names:",
            placeholder="Jean Baptiste Mukendi\nMarie Kabamba Tshiala\nJoseph Kasongo\nGrace Mbuyi Kalala",
            height=150,
            help="Enter each name on a new line",
        )
        if not names_input.strip():
            return
        if st.button("Analyze All Names", type="primary"):
            # splitlines() also handles "\r\n" and other line separators.
            names = [line.strip() for line in names_input.splitlines() if line.strip()]
            for i, name in enumerate(names):
                st.markdown(f"**Name {i+1}: {name}**")
                self.analyze_and_display(name)
                if i < len(names) - 1:
                    st.markdown("---")

    def analyze_and_display(self, text: str):
        """Run the model on ``text`` and render highlighted entities plus label counts.

        Any exception (including calling this before a model is loaded) is
        reported on the page with ``st.error`` rather than propagated.
        """
        try:
            result = self.ner_model.predict(text)
            st.subheader("Analysis Results")
            entities = result.get("entities", [])
            if not entities:
                st.warning("No entities detected in the input text.")
                st.info("Try using traditional Congolese names or ensure the spelling is correct.")
                return

            self.show_visual_entities(text, entities)
            labels = [entity["label"] for entity in entities]
            col1, col2, col3 = st.columns(3)
            with col1:
                st.metric("Total Entities", len(entities))
            with col2:
                st.metric("Native Names", labels.count("NATIVE"))
            with col3:
                st.metric("Surnames", labels.count("SURNAME"))
        except Exception as e:
            st.error(f"Error analyzing text: {e}")

    @classmethod
    def show_visual_entities(cls, text: str, entities: list):
        """Render ``entities`` over ``text`` with displaCy's entity visualizer.

        Args:
            text: The analyzed input string.
            entities: Dicts carrying character offsets ``start``/``end`` and a
                ``label`` (as produced by the model's ``predict``).

        Rendering failures are shown as a warning rather than raised.
        """
        try:
            # Convert to displaCy's manual "ent" format. The entity renderer emits
            # spans in list order with a running offset, so sort by start offset
            # to keep the surrounding text intact.
            ents = sorted(
                (
                    {"start": entity["start"], "end": entity["end"], "label": entity["label"]}
                    for entity in entities
                ),
                key=lambda span: span["start"],
            )
            doc_data = {"text": text, "ents": ents, "title": None}
            colors = {"NATIVE": "#74C0FC", "SURNAME": "#69DB7C"}  # light blue, light green
            options = {"colors": colors, "distance": 90}
            html = displacy.render(doc_data, style="ent", manual=True, options=options)
            st.markdown(html, unsafe_allow_html=True)
        except Exception as e:
            st.warning(f"Could not generate visual representation: {e}")