333 lines
12 KiB
Python
333 lines
12 KiB
Python
from typing import List
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import plotly.express as px
|
|
import plotly.graph_objects as go
|
|
import streamlit as st
|
|
|
|
from research.experiment.experiment_runner import ExperimentRunner
|
|
from research.experiment.experiment_tracker import ExperimentTracker
|
|
|
|
|
|
class ResultsAnalysis:
|
|
"""Handles experiment results and analysis interface"""
|
|
|
|
def __init__(self, config, experiment_tracker: ExperimentTracker, experiment_runner: ExperimentRunner):
|
|
self.config = config
|
|
self.experiment_tracker = experiment_tracker
|
|
self.experiment_runner = experiment_runner
|
|
|
|
def index(self):
|
|
"""Main results analysis page"""
|
|
st.header("Results & Analysis")
|
|
tab1, tab2, tab3 = st.tabs(["Experiment Comparison", "Performance Analysis", "Model Analysis"])
|
|
|
|
with tab1:
|
|
self.show_experiment_comparison()
|
|
|
|
with tab2:
|
|
self.show_performance_analysis()
|
|
|
|
with tab3:
|
|
self.show_model_analysis()
|
|
|
|
def show_experiment_comparison(self):
|
|
"""Show experiment comparison interface"""
|
|
st.subheader("Compare Experiments")
|
|
|
|
experiments = self.experiment_tracker.list_experiments()
|
|
completed_experiments = [e for e in experiments if e.status.value == "completed"]
|
|
|
|
if not completed_experiments:
|
|
st.warning("No completed experiments found.")
|
|
return
|
|
|
|
# Experiment selection
|
|
exp_options = {
|
|
f"{exp.config.name} ({exp.experiment_id[:8]})": exp.experiment_id
|
|
for exp in completed_experiments
|
|
}
|
|
|
|
selected_exp_names = st.multiselect(
|
|
"Select Experiments to Compare",
|
|
list(exp_options.keys()),
|
|
default=list(exp_options.keys())[: min(5, len(exp_options))],
|
|
)
|
|
|
|
if not selected_exp_names:
|
|
st.info("Please select experiments to compare.")
|
|
return
|
|
|
|
selected_exp_ids = [exp_options[name] for name in selected_exp_names]
|
|
|
|
# Generate comparison
|
|
comparison_df = self.experiment_runner.compare_experiments(selected_exp_ids)
|
|
|
|
if comparison_df.empty:
|
|
st.error("No data available for comparison.")
|
|
return
|
|
|
|
self._display_comparison_table(comparison_df)
|
|
self._display_comparison_charts(comparison_df)
|
|
|
|
def _display_comparison_table(self, comparison_df: pd.DataFrame):
|
|
"""Display comparison table"""
|
|
st.write("**Experiment Comparison Table**")
|
|
|
|
# Select columns to display
|
|
metric_columns = [
|
|
col for col in comparison_df.columns if col.startswith("test_") or col.startswith("cv_")
|
|
]
|
|
display_columns = ["name", "model_type", "features"] + metric_columns
|
|
available_columns = [col for col in display_columns if col in comparison_df.columns]
|
|
|
|
st.dataframe(comparison_df[available_columns], use_container_width=True)
|
|
|
|
def _display_comparison_charts(self, comparison_df: pd.DataFrame):
|
|
"""Display comparison charts"""
|
|
st.write("**Performance Comparison**")
|
|
|
|
if "test_accuracy" in comparison_df.columns:
|
|
fig = px.bar(
|
|
comparison_df,
|
|
x="name",
|
|
y="test_accuracy",
|
|
color="model_type",
|
|
title="Test Accuracy Comparison",
|
|
)
|
|
fig.update_layout(xaxis_tickangle=-45)
|
|
st.plotly_chart(fig, use_container_width=True)
|
|
|
|
# Metric comparison across multiple metrics
|
|
metric_columns = [
|
|
col for col in comparison_df.columns if col.startswith("test_") or col.startswith("cv_")
|
|
]
|
|
|
|
if len(metric_columns) > 1:
|
|
metric_to_plot = st.selectbox("Select Metric for Detailed Comparison", metric_columns)
|
|
|
|
if metric_to_plot in comparison_df.columns:
|
|
fig = px.bar(
|
|
comparison_df,
|
|
x="name",
|
|
y=metric_to_plot,
|
|
color="model_type",
|
|
title=f"{metric_to_plot.replace('_', ' ').title()} Comparison",
|
|
)
|
|
fig.update_layout(xaxis_tickangle=-45)
|
|
st.plotly_chart(fig, use_container_width=True)
|
|
|
|
def show_performance_analysis(self):
|
|
"""Show performance analysis across experiments"""
|
|
st.subheader("Performance Analysis")
|
|
|
|
experiments = self.experiment_tracker.list_experiments()
|
|
completed_experiments = [
|
|
e for e in experiments if e.status.value == "completed" and e.test_metrics
|
|
]
|
|
|
|
if not completed_experiments:
|
|
st.warning("No completed experiments with metrics found.")
|
|
return
|
|
|
|
# Prepare data for analysis
|
|
analysis_data = self._prepare_analysis_data(completed_experiments)
|
|
analysis_df = pd.DataFrame(analysis_data)
|
|
|
|
self._display_performance_trends(analysis_df)
|
|
self._display_model_comparison(analysis_df)
|
|
self._display_top_experiments(analysis_df)
|
|
|
|
def _prepare_analysis_data(self, completed_experiments: List) -> List[dict]:
|
|
"""Prepare data for performance analysis"""
|
|
analysis_data = []
|
|
for exp in completed_experiments:
|
|
row = {
|
|
"experiment_id": exp.experiment_id,
|
|
"name": exp.config.name,
|
|
"model_type": exp.config.model_type,
|
|
"feature_count": len(exp.config.features),
|
|
"features": ", ".join([f.value for f in exp.config.features]),
|
|
"train_size": exp.train_size,
|
|
"test_size": exp.test_size,
|
|
**exp.test_metrics,
|
|
}
|
|
analysis_data.append(row)
|
|
return analysis_data
|
|
|
|
def _display_performance_trends(self, analysis_df: pd.DataFrame):
|
|
"""Display performance trend charts"""
|
|
col1, col2 = st.columns(2)
|
|
|
|
with col1:
|
|
# Accuracy vs Training Size
|
|
if "accuracy" in analysis_df.columns and "train_size" in analysis_df.columns:
|
|
fig = px.scatter(
|
|
analysis_df,
|
|
x="train_size",
|
|
y="accuracy",
|
|
color="model_type",
|
|
hover_data=["name"],
|
|
title="Accuracy vs Training Size",
|
|
)
|
|
st.plotly_chart(fig, use_container_width=True)
|
|
|
|
with col2:
|
|
# Feature Count vs Performance
|
|
if "accuracy" in analysis_df.columns and "feature_count" in analysis_df.columns:
|
|
fig = px.scatter(
|
|
analysis_df,
|
|
x="feature_count",
|
|
y="accuracy",
|
|
color="model_type",
|
|
hover_data=["name"],
|
|
title="Accuracy vs Number of Features",
|
|
)
|
|
st.plotly_chart(fig, use_container_width=True)
|
|
|
|
def _display_model_comparison(self, analysis_df: pd.DataFrame):
|
|
"""Display model type comparison"""
|
|
if "accuracy" in analysis_df.columns:
|
|
model_performance = (
|
|
analysis_df.groupby("model_type")["accuracy"]
|
|
.agg(["mean", "std", "count"])
|
|
.reset_index()
|
|
)
|
|
|
|
fig = go.Figure()
|
|
fig.add_trace(
|
|
go.Bar(
|
|
x=model_performance["model_type"],
|
|
y=model_performance["mean"],
|
|
error_y=dict(type="data", array=model_performance["std"]),
|
|
name="Average Accuracy",
|
|
)
|
|
)
|
|
fig.update_layout(title="Average Accuracy by Model Type", yaxis_title="Accuracy")
|
|
st.plotly_chart(fig, use_container_width=True)
|
|
|
|
def _display_top_experiments(self, analysis_df: pd.DataFrame):
|
|
"""Display top performing experiments"""
|
|
st.subheader("Top Performing Experiments")
|
|
|
|
if "accuracy" in analysis_df.columns:
|
|
display_columns = ["name", "model_type", "features", "accuracy"]
|
|
|
|
# Add other metrics if available
|
|
for metric in ["precision", "recall", "f1"]:
|
|
if metric in analysis_df.columns:
|
|
display_columns.append(metric)
|
|
|
|
top_experiments = analysis_df.nlargest(5, "accuracy")[display_columns]
|
|
st.dataframe(top_experiments, use_container_width=True)
|
|
|
|
def show_model_analysis(self):
|
|
"""Show detailed model analysis"""
|
|
st.subheader("Model Analysis")
|
|
|
|
experiments = self.experiment_tracker.list_experiments()
|
|
completed_experiments = [e for e in experiments if e.status.value == "completed"]
|
|
|
|
if not completed_experiments:
|
|
st.warning("No completed experiments found.")
|
|
return
|
|
|
|
# Select experiment for detailed analysis
|
|
exp_options = {
|
|
f"{exp.config.name} ({exp.experiment_id[:8]})": exp for exp in completed_experiments
|
|
}
|
|
|
|
selected_exp_name = st.selectbox(
|
|
"Select Experiment for Detailed Analysis", list(exp_options.keys())
|
|
)
|
|
|
|
if not selected_exp_name:
|
|
return
|
|
|
|
selected_exp = exp_options[selected_exp_name]
|
|
|
|
self._display_experiment_details(selected_exp)
|
|
self._display_confusion_matrix(selected_exp)
|
|
self._display_feature_importance(selected_exp)
|
|
self._display_prediction_examples(selected_exp)
|
|
|
|
def _display_experiment_details(self, experiment):
|
|
"""Display experiment configuration and metrics"""
|
|
col1, col2 = st.columns(2)
|
|
|
|
with col1:
|
|
st.write("**Experiment Configuration**")
|
|
st.json(
|
|
{
|
|
"name": experiment.config.name,
|
|
"model_type": experiment.config.model_type,
|
|
"features": [f.value for f in experiment.config.features],
|
|
"model_params": experiment.config.model_params,
|
|
}
|
|
)
|
|
|
|
with col2:
|
|
st.write("**Performance Metrics**")
|
|
if experiment.test_metrics:
|
|
for metric, value in experiment.test_metrics.items():
|
|
st.metric(metric.title(), f"{value:.4f}")
|
|
|
|
def _display_confusion_matrix(self, experiment):
|
|
"""Display confusion matrix if available"""
|
|
if experiment.confusion_matrix:
|
|
st.write("**Confusion Matrix**")
|
|
cm = np.array(experiment.confusion_matrix)
|
|
|
|
fig = px.imshow(cm, text_auto=True, aspect="auto", title="Confusion Matrix")
|
|
st.plotly_chart(fig, use_container_width=True)
|
|
|
|
def _display_feature_importance(self, experiment):
|
|
"""Display feature importance if available"""
|
|
if experiment.feature_importance:
|
|
st.write("**Feature Importance**")
|
|
|
|
importance_data = sorted(
|
|
experiment.feature_importance.items(), key=lambda x: x[1], reverse=True
|
|
)[:20]
|
|
|
|
features, importances = zip(*importance_data)
|
|
|
|
fig = px.bar(
|
|
x=list(importances),
|
|
y=list(features),
|
|
orientation="h",
|
|
title="Top 20 Feature Importances",
|
|
)
|
|
fig.update_layout(height=600)
|
|
st.plotly_chart(fig, use_container_width=True)
|
|
|
|
def _display_prediction_examples(self, experiment):
|
|
"""Display prediction examples if available"""
|
|
if experiment.prediction_examples:
|
|
st.write("**Prediction Examples**")
|
|
|
|
examples_df = pd.DataFrame(experiment.prediction_examples)
|
|
|
|
# Separate correct and incorrect predictions
|
|
correct_examples = examples_df[examples_df["correct"] == True]
|
|
incorrect_examples = examples_df[examples_df["correct"] == False]
|
|
|
|
col1, col2 = st.columns(2)
|
|
|
|
with col1:
|
|
st.write("**Correct Predictions**")
|
|
if not correct_examples.empty:
|
|
st.dataframe(
|
|
correct_examples[["name", "true_label", "predicted_label"]],
|
|
use_container_width=True,
|
|
)
|
|
|
|
with col2:
|
|
st.write("**Incorrect Predictions**")
|
|
if not incorrect_examples.empty:
|
|
st.dataframe(
|
|
incorrect_examples[["name", "true_label", "predicted_label"]],
|
|
use_container_width=True,
|
|
)
|