feat: add NER testing interface and evaluation statistics handling
@@ -44,10 +44,11 @@ def train(config: PipelineConfig):
         batch_size=config.processing.batch_size,
         dropout_rate=0.3,
     )
-    name_model.evaluate(eval_data)
+    evaluation_results = name_model.evaluate(eval_data)

     model_path = name_model.save()
     logging.info(f"Model saved to: {model_path}")
+    print(f"Evaluation results: {evaluation_results}")


 def run_pipeline(config: PipelineConfig, reset: bool = False):
@@ -23,6 +23,7 @@ class NameModel:
         self.ner = None
         self.model_path = None
         self.training_stats = {}
+        self.evaluation_stats = {}

     def create_blank_model(self, language: str = "fr") -> None:
         """Create a blank spaCy model with NER pipeline"""
@@ -304,7 +305,7 @@ class NameModel:
                 "support": tp + fn,
             }

-        evaluation_results = {
+        self.evaluation_stats = {
             "overall": {
                 "precision": precision,
                 "recall": recall,
@@ -317,8 +318,7 @@ class NameModel:
             "by_label": label_metrics,
         }
-
         logging.info(f"NER Evaluation completed. Overall F1: {f1_score:.4f}")
-        return evaluation_results
+        return self.evaluation_stats

     def save(self, model_name: str = "drc_ner_model") -> str:
         """Save the trained model"""
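For reference, the dictionary that evaluate() now stores on self.evaluation_stats and returns looks roughly like the sketch below; it is not part of the diff, the numbers are placeholders, and the per-label metric names other than "support" are assumptions inferred from the overall metrics shown above.

# Illustrative shape only; values are placeholders, per-label keys partly assumed.
example_evaluation_stats = {
    "overall": {"precision": 0.93, "recall": 0.90, "f1_score": 0.91},
    "by_label": {
        "NATIVE": {"precision": 0.94, "recall": 0.92, "f1_score": 0.93, "support": 120},
        "SURNAME": {"precision": 0.91, "recall": 0.88, "f1_score": 0.89, "support": 95},
    },
}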
@@ -333,11 +333,15 @@ class NameModel:
         self.nlp.to_disk(model_dir)
         self.model_path = str(model_dir)

-        # Save training statistics
-        stats_path = model_dir / "training_stats.json"
-        with open(stats_path, "w", encoding="utf-8") as f:
+        # Save training and evaluation statistics
+        training_stats_path = model_dir / "training_stats.json"
+        with open(training_stats_path, "w", encoding="utf-8") as f:
             json.dump(self.training_stats, f, indent=2)

+        evaluation_stats_path = model_dir / "evaluation_stats.json"
+        with open(evaluation_stats_path, "w", encoding="utf-8") as f:
+            json.dump(self.evaluation_stats, f, indent=2)
+
         logging.info(f"NER Model saved to {model_dir}")
         return self.model_path
@@ -352,11 +356,16 @@ class NameModel:
         self.model_path = model_path

         # Load training statistics if available
-        stats_path = Path(model_path) / "training_stats.json"
-        if stats_path.exists():
-            with open(stats_path, "r", encoding="utf-8") as f:
+        training_stats_path = Path(model_path) / "training_stats.json"
+        if training_stats_path.exists():
+            with open(training_stats_path, "r", encoding="utf-8") as f:
                 self.training_stats = json.load(f)

+        evaluation_stats_path = Path(model_path) / "evaluation_stats.json"
+        if evaluation_stats_path.exists():
+            with open(evaluation_stats_path, "r", encoding="utf-8") as f:
+                self.evaluation_stats = json.load(f)
+
         logging.info("NER Model loaded successfully")

     def predict(self, text: str) -> Dict[str, Any]:
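Taken together, the save/load changes mean a reloaded model carries its metrics with it. A minimal usage sketch follows; it is not part of the diff, build_config() is a hypothetical stand-in for however the pipeline actually constructs its PipelineConfig, while NameModel(config) and load() are called exactly as the new NERTesting interface calls them below.

# Sketch of the round trip enabled above; build_config() is hypothetical.
from processing.ner.name_model import NameModel

config = build_config()  # stand-in for the real PipelineConfig construction
model = NameModel(config)
model.load(str(config.paths.models_dir / "drc_ner_model"))

print(model.training_stats.get("training_examples"))
print(model.evaluation_stats.get("overall", {}))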
@@ -0,0 +1,2 @@
from .ner_testing import NERTesting
@@ -0,0 +1,166 @@
from pathlib import Path

import streamlit as st
from spacy import displacy

from core.config import PipelineConfig
from processing.ner.name_model import NameModel


class NERTesting:
    def __init__(self, config: PipelineConfig):
        self.config = config
        self.model_path = config.paths.models_dir / "drc_ner_model"
        self.ner_model = None
        self.training_stats = None
        self.evaluation_stats = None

    def load_ner_model(self) -> bool:
        """Load the trained NER model"""
        try:
            if self.ner_model is None:
                self.ner_model = NameModel(self.config)
                self.ner_model.load(str(self.model_path))
                self.training_stats = self.ner_model.training_stats
                self.evaluation_stats = self.ner_model.evaluation_stats
            return True
        except Exception as e:
            st.error(f"Error loading NER model: {e}")
            return False

    def index(self):
        st.title("Named Entity Recognition")

        # Load model
        if not self.load_ner_model():
            st.warning("NER model could not be loaded. Please ensure the model is trained and available.")
            return

        # Display model information
        self.show_model_training_info()
        self.show_model_evaluation_info()

        st.markdown("---")
        st.subheader("Test the NER Model")
        input_method = st.radio("Input Method", ["Single Name", "Multiple Names"])
        if input_method == "Single Name":
            self.test_single_name()
        elif input_method == "Multiple Names":
            self.test_multiple_names()

    def show_model_training_info(self):
        if self.training_stats:
            col1, col2, col3, col4 = st.columns(4)

            with col1:
                st.metric("Training Examples", f"{self.training_stats.get('training_examples', 0):,}")
            with col2:
                st.metric("Epochs", self.training_stats.get('epochs', 0))
            with col3:
                st.metric("Final Loss", f"{self.training_stats.get('final_loss', 0):.2f}")
            with col4:
                st.metric("Batch Size", f"{self.training_stats.get('batch_size', 0):,}")

    def show_model_evaluation_info(self):
        if self.evaluation_stats:
            col1, col2, col3 = st.columns(3)
            overall = self.evaluation_stats.get('overall', {})

            with col1:
                st.metric("Overall Precision", f"{overall['precision']:.2f}")
            with col2:
                st.metric("Overall Recall", f"{overall['recall']:.2f}")
            with col3:
                st.metric("Overall F1 Score", f"{overall['f1_score']:.2f}")

            st.json(self.evaluation_stats.get("by_label", {}))

    def test_single_name(self):
        name_input = st.text_input(
            "Name:",
            placeholder="e.g., Jean Baptiste Mukendi, Marie Kabamba Tshiala, Joseph Kasongo",
            help="Enter a full name or multiple names separated by spaces"
        )
        if name_input.strip():
            if st.button("Analyze Name", type="primary"):
                self.analyze_and_display(name_input)

    def test_multiple_names(self):
        names_input = st.text_area(
            "Names:",
            placeholder="Jean Baptiste Mukendi\nMarie Kabamba Tshiala\nJoseph Kasongo\nGrace Mbuyi Kalala",
            height=150,
            help="Enter each name on a new line"
        )

        if names_input.strip():
            if st.button("Analyze All Names", type="primary"):
                names = [name.strip() for name in names_input.split('\n') if name.strip()]
                for i, name in enumerate(names):
                    st.markdown(f"**Name {i+1}: {name}**")
                    self.analyze_and_display(name)
                    if i < len(names) - 1:
                        st.markdown("---")

    def analyze_and_display(self, text: str):
        try:
            result = self.ner_model.predict(text)
            st.subheader("Analysis Results")
            entities = result.get('entities', [])

            if entities:
                self.show_visual_entities(text, entities)
                native_count = sum(1 for e in entities if e['label'] == 'NATIVE')
                surname_count = sum(1 for e in entities if e['label'] == 'SURNAME')

                col1, col2, col3 = st.columns(3)
                with col1:
                    st.metric("Total Entities", len(entities))
                with col2:
                    st.metric("Native Names", native_count)
                with col3:
                    st.metric("Surnames", surname_count)

            else:
                st.warning("No entities detected in the input text.")
                st.info("Try using traditional Congolese names or ensure the spelling is correct.")

        except Exception as e:
            st.error(f"Error analyzing text: {e}")

    @classmethod
    def show_visual_entities(cls, text: str, entities: list):
        try:
            # Convert our entities format to spaCy format for displacy
            ents = []
            for entity in entities:
                ents.append({
                    "start": entity['start'],
                    "end": entity['end'],
                    "label": entity['label']
                })

            # Create doc-like structure for displacy
            doc_data = {
                "text": text,
                "ents": ents,
                "title": None
            }

            # Custom colors for our labels
            colors = {
                "NATIVE": "#74C0FC",   # Light blue
                "SURNAME": "#69DB7C"   # Light green
            }

            options = {
                "colors": colors,
                "distance": 90
            }

            # Generate HTML visualization
            html = displacy.render(doc_data, style="ent", manual=True, options=options)
            st.markdown(html, unsafe_allow_html=True)

        except Exception as e:
            st.warning(f"Could not generate visual representation: {e}")
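The interface above only assumes that NameModel.predict() returns a dict with an "entities" list of character-offset spans labelled NATIVE or SURNAME; a sketch of that assumed contract follows. It is not part of the diff, and which label applies to which token is not shown in this commit, so the assignment below is purely illustrative.

# Assumed output shape of NameModel.predict(); offsets are character positions
# in the analysed text, and the label assignment here is illustrative only.
example_prediction = {
    "entities": [
        {"start": 0, "end": 4, "label": "SURNAME"},
        {"start": 14, "end": 21, "label": "NATIVE"},
    ],
}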
@@ -60,10 +60,7 @@ class Predictions:

    def show_single_prediction(self, experiment):
        """Show single name prediction interface"""
        st.subheader("Single Name Prediction")

        name_input = st.text_input("Enter a name:", placeholder="e.g., Jean Baptiste Mukendi")

        if name_input and st.button("Predict Gender"):
            try:
                # Load the model
@@ -132,11 +129,7 @@ class Predictions:
        st.info(f"Features used: {', '.join([f.value for f in experiment.config.features])}")

    def show_batch_prediction(self, experiment):
        """Show batch prediction interface"""
        st.subheader("Batch Prediction")

        uploaded_file = st.file_uploader("Upload CSV file with names", type="csv")

        if uploaded_file is not None:
            try:
                df = pd.read_csv(uploaded_file, dtype=OPTIMIZED_DTYPES)
@@ -251,11 +244,6 @@ class Predictions:
        st.plotly_chart(fig, use_container_width=True)

    def show_dataset_prediction(self, experiment):
        """Show dataset prediction interface"""
        st.subheader("Dataset Prediction")
        st.write("Apply the model to existing datasets")

        # Dataset selection
        dataset_options = {
            "Featured Dataset": self.config.data.output_files["featured"],
            "Evaluation Dataset": self.config.data.output_files["evaluation"],
@@ -0,0 +1,19 @@
import sys
from pathlib import Path

import streamlit as st

# Add parent directory to Python path to access core modules
parent_dir = Path(__file__).parent.parent.parent
sys.path.insert(0, str(parent_dir))

from web.interfaces.ner_testing import NERTesting

st.set_page_config(page_title="NER Testing", page_icon="🏷️", layout="wide")

if "config" in st.session_state:
    ner_testing = NERTesting(st.session_state.config)
    ner_testing.index()
else:
    st.error("Please run the main app first to initialize the configuration.")
    st.markdown("Go back to the [main page](/) to start the application.")
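This page relies on the main Streamlit app having already placed a PipelineConfig into st.session_state.config. A hedged sketch of what that handoff might look like in the main entry point; the file location and constructor call are assumptions, not part of this commit.

# Hypothetical excerpt from the main app entry point; not part of this commit.
import streamlit as st
from core.config import PipelineConfig

if "config" not in st.session_state:
    st.session_state.config = PipelineConfig()  # assumed constructor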