# drc-ners-nlp/interface/results_analysis.py
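"""Streamlit interface for exploring experiment results.

Provides three views: side-by-side experiment comparison, aggregate
performance analysis across completed experiments, and a detailed
per-experiment model analysis (metrics, confusion matrix, feature
importance, and prediction examples).
"""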

from typing import List

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st

from research.experiment.experiment_runner import ExperimentRunner
from research.experiment.experiment_tracker import ExperimentTracker


class ResultsAnalysis:
    """Handles experiment results and analysis interface"""

    def __init__(
        self, config, experiment_tracker: ExperimentTracker, experiment_runner: ExperimentRunner
    ):
        self.config = config
        self.experiment_tracker = experiment_tracker
        self.experiment_runner = experiment_runner

    def index(self):
        """Main results analysis page"""
        st.header("Results & Analysis")

        tab1, tab2, tab3 = st.tabs(
            ["Experiment Comparison", "Performance Analysis", "Model Analysis"]
        )

        with tab1:
            self.show_experiment_comparison()
        with tab2:
            self.show_performance_analysis()
        with tab3:
            self.show_model_analysis()

    def show_experiment_comparison(self):
        """Show experiment comparison interface"""
        st.subheader("Compare Experiments")

        experiments = self.experiment_tracker.list_experiments()
        completed_experiments = [e for e in experiments if e.status.value == "completed"]

        if not completed_experiments:
            st.warning("No completed experiments found.")
            return

        # Experiment selection
        exp_options = {
            f"{exp.config.name} ({exp.experiment_id[:8]})": exp.experiment_id
            for exp in completed_experiments
        }
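        # Default to the first five experiments (or fewer, if fewer are available).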
        selected_exp_names = st.multiselect(
            "Select Experiments to Compare",
            list(exp_options.keys()),
            default=list(exp_options.keys())[: min(5, len(exp_options))],
        )

        if not selected_exp_names:
            st.info("Please select experiments to compare.")
            return

        selected_exp_ids = [exp_options[name] for name in selected_exp_names]

        # Generate comparison
        comparison_df = self.experiment_runner.compare_experiments(selected_exp_ids)

        if comparison_df.empty:
            st.error("No data available for comparison.")
            return

        self._display_comparison_table(comparison_df)
        self._display_comparison_charts(comparison_df)

    def _display_comparison_table(self, comparison_df: pd.DataFrame):
        """Display comparison table"""
        st.write("**Experiment Comparison Table**")

        # Select columns to display
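        # "test_*" columns hold held-out metrics and "cv_*" columns hold cross-validation scores.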
        metric_columns = [
            col for col in comparison_df.columns if col.startswith("test_") or col.startswith("cv_")
        ]
        display_columns = ["name", "model_type", "features"] + metric_columns
        available_columns = [col for col in display_columns if col in comparison_df.columns]

        st.dataframe(comparison_df[available_columns], use_container_width=True)

    def _display_comparison_charts(self, comparison_df: pd.DataFrame):
        """Display comparison charts"""
        st.write("**Performance Comparison**")

        if "test_accuracy" in comparison_df.columns:
            fig = px.bar(
                comparison_df,
                x="name",
                y="test_accuracy",
                color="model_type",
                title="Test Accuracy Comparison",
            )
            fig.update_layout(xaxis_tickangle=-45)
            st.plotly_chart(fig, use_container_width=True)

        # Metric comparison across multiple metrics
        metric_columns = [
            col for col in comparison_df.columns if col.startswith("test_") or col.startswith("cv_")
        ]
        if len(metric_columns) > 1:
            metric_to_plot = st.selectbox("Select Metric for Detailed Comparison", metric_columns)

            if metric_to_plot in comparison_df.columns:
                fig = px.bar(
                    comparison_df,
                    x="name",
                    y=metric_to_plot,
                    color="model_type",
                    title=f"{metric_to_plot.replace('_', ' ').title()} Comparison",
                )
                fig.update_layout(xaxis_tickangle=-45)
                st.plotly_chart(fig, use_container_width=True)

    def show_performance_analysis(self):
        """Show performance analysis across experiments"""
        st.subheader("Performance Analysis")

        experiments = self.experiment_tracker.list_experiments()
        completed_experiments = [
            e for e in experiments if e.status.value == "completed" and e.test_metrics
        ]

        if not completed_experiments:
            st.warning("No completed experiments with metrics found.")
            return

        # Prepare data for analysis
        analysis_data = self._prepare_analysis_data(completed_experiments)
        analysis_df = pd.DataFrame(analysis_data)

        self._display_performance_trends(analysis_df)
        self._display_model_comparison(analysis_df)
        self._display_top_experiments(analysis_df)

    def _prepare_analysis_data(self, completed_experiments: List) -> List[dict]:
        """Prepare data for performance analysis"""
        analysis_data = []

        for exp in completed_experiments:
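            # One row per experiment; **exp.test_metrics spreads each metric name
            # (e.g. "accuracy") into its own column for the plots and tables below.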
            row = {
                "experiment_id": exp.experiment_id,
                "name": exp.config.name,
                "model_type": exp.config.model_type,
                "feature_count": len(exp.config.features),
                "features": ", ".join([f.value for f in exp.config.features]),
                "train_size": exp.train_size,
                "test_size": exp.test_size,
                **exp.test_metrics,
            }
            analysis_data.append(row)

        return analysis_data

    def _display_performance_trends(self, analysis_df: pd.DataFrame):
        """Display performance trend charts"""
        col1, col2 = st.columns(2)

        with col1:
            # Accuracy vs Training Size
            if "accuracy" in analysis_df.columns and "train_size" in analysis_df.columns:
                fig = px.scatter(
                    analysis_df,
                    x="train_size",
                    y="accuracy",
                    color="model_type",
                    hover_data=["name"],
                    title="Accuracy vs Training Size",
                )
                st.plotly_chart(fig, use_container_width=True)

        with col2:
            # Feature Count vs Performance
            if "accuracy" in analysis_df.columns and "feature_count" in analysis_df.columns:
                fig = px.scatter(
                    analysis_df,
                    x="feature_count",
                    y="accuracy",
                    color="model_type",
                    hover_data=["name"],
                    title="Accuracy vs Number of Features",
                )
                st.plotly_chart(fig, use_container_width=True)

    def _display_model_comparison(self, analysis_df: pd.DataFrame):
        """Display model type comparison"""
        if "accuracy" in analysis_df.columns:
            model_performance = (
                analysis_df.groupby("model_type")["accuracy"]
                .agg(["mean", "std", "count"])
                .reset_index()
            )
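            # Mean accuracy per model type; the standard deviation is drawn as error bars.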
            fig = go.Figure()
            fig.add_trace(
                go.Bar(
                    x=model_performance["model_type"],
                    y=model_performance["mean"],
                    error_y=dict(type="data", array=model_performance["std"]),
                    name="Average Accuracy",
                )
            )
            fig.update_layout(title="Average Accuracy by Model Type", yaxis_title="Accuracy")
            st.plotly_chart(fig, use_container_width=True)

    def _display_top_experiments(self, analysis_df: pd.DataFrame):
        """Display top performing experiments"""
        st.subheader("Top Performing Experiments")

        if "accuracy" in analysis_df.columns:
            display_columns = ["name", "model_type", "features", "accuracy"]

            # Add other metrics if available
            for metric in ["precision", "recall", "f1"]:
                if metric in analysis_df.columns:
                    display_columns.append(metric)

            top_experiments = analysis_df.nlargest(5, "accuracy")[display_columns]
            st.dataframe(top_experiments, use_container_width=True)

    def show_model_analysis(self):
        """Show detailed model analysis"""
        st.subheader("Model Analysis")

        experiments = self.experiment_tracker.list_experiments()
        completed_experiments = [e for e in experiments if e.status.value == "completed"]

        if not completed_experiments:
            st.warning("No completed experiments found.")
            return

        # Select experiment for detailed analysis
        exp_options = {
            f"{exp.config.name} ({exp.experiment_id[:8]})": exp for exp in completed_experiments
        }
        selected_exp_name = st.selectbox(
            "Select Experiment for Detailed Analysis", list(exp_options.keys())
        )

        if not selected_exp_name:
            return

        selected_exp = exp_options[selected_exp_name]

        self._display_experiment_details(selected_exp)
        self._display_confusion_matrix(selected_exp)
        self._display_feature_importance(selected_exp)
        self._display_prediction_examples(selected_exp)

    def _display_experiment_details(self, experiment):
        """Display experiment configuration and metrics"""
        col1, col2 = st.columns(2)

        with col1:
            st.write("**Experiment Configuration**")
            st.json(
                {
                    "name": experiment.config.name,
                    "model_type": experiment.config.model_type,
                    "features": [f.value for f in experiment.config.features],
                    "model_params": experiment.config.model_params,
                }
            )

        with col2:
            st.write("**Performance Metrics**")
            if experiment.test_metrics:
                for metric, value in experiment.test_metrics.items():
                    st.metric(metric.title(), f"{value:.4f}")

    def _display_confusion_matrix(self, experiment):
        """Display confusion matrix if available"""
        if experiment.confusion_matrix:
            st.write("**Confusion Matrix**")
            cm = np.array(experiment.confusion_matrix)
            fig = px.imshow(cm, text_auto=True, aspect="auto", title="Confusion Matrix")
            st.plotly_chart(fig, use_container_width=True)

    def _display_feature_importance(self, experiment):
        """Display feature importance if available"""
        if experiment.feature_importance:
            st.write("**Feature Importance**")
            importance_data = sorted(
                experiment.feature_importance.items(), key=lambda x: x[1], reverse=True
            )[:20]
            features, importances = zip(*importance_data)

            fig = px.bar(
                x=list(importances),
                y=list(features),
                orientation="h",
                title="Top 20 Feature Importances",
            )
            fig.update_layout(height=600)
            st.plotly_chart(fig, use_container_width=True)

    def _display_prediction_examples(self, experiment):
        """Display prediction examples if available"""
        if experiment.prediction_examples:
            st.write("**Prediction Examples**")
            examples_df = pd.DataFrame(experiment.prediction_examples)

            # Separate correct and incorrect predictions
            correct_examples = examples_df[examples_df["correct"] == True]
            incorrect_examples = examples_df[examples_df["correct"] == False]

            col1, col2 = st.columns(2)

            with col1:
                st.write("**Correct Predictions**")
                if not correct_examples.empty:
                    st.dataframe(
                        correct_examples[["name", "true_label", "predicted_label"]],
                        use_container_width=True,
                    )

            with col2:
                st.write("**Incorrect Predictions**")
                if not incorrect_examples.empty:
                    st.dataframe(
                        incorrect_examples[["name", "true_label", "predicted_label"]],
                        use_container_width=True,
                    )
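

# Minimal wiring sketch (hedged): this module only shows that ResultsAnalysis takes a
# config plus an ExperimentTracker and ExperimentRunner, and that index() renders the
# page. How those dependencies are constructed is not shown in this file, so the
# constructor calls below are hypothetical placeholders, not the project's actual API.
#
#   tracker = ExperimentTracker(config)              # constructor signature assumed
#   runner = ExperimentRunner(config, tracker)       # constructor signature assumed
#   ResultsAnalysis(config, tracker, runner).index()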