refactor: reorganize project structure and enhance model verbosity
@@ -0,0 +1,332 @@
from typing import List

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st

from research.experiment.experiment_runner import ExperimentRunner
from research.experiment.experiment_tracker import ExperimentTracker


class ResultsAnalysis:
    """Handles experiment results and analysis interface"""

    def __init__(self, config, experiment_tracker: ExperimentTracker, experiment_runner: ExperimentRunner):
        self.config = config
        self.experiment_tracker = experiment_tracker
        self.experiment_runner = experiment_runner

    def index(self):
        """Main results analysis page"""
        st.header("Results & Analysis")
        tab1, tab2, tab3 = st.tabs(["Experiment Comparison", "Performance Analysis", "Model Analysis"])

        with tab1:
            self.show_experiment_comparison()

        with tab2:
            self.show_performance_analysis()

        with tab3:
            self.show_model_analysis()

    def show_experiment_comparison(self):
        """Show experiment comparison interface"""
        st.subheader("Compare Experiments")

        experiments = self.experiment_tracker.list_experiments()
        completed_experiments = [e for e in experiments if e.status.value == "completed"]

        if not completed_experiments:
            st.warning("No completed experiments found.")
            return

        # Experiment selection
        exp_options = {
            f"{exp.config.name} ({exp.experiment_id[:8]})": exp.experiment_id
            for exp in completed_experiments
        }

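        # Pre-select up to the first five experiments by default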
        selected_exp_names = st.multiselect(
            "Select Experiments to Compare",
            list(exp_options.keys()),
            default=list(exp_options.keys())[: min(5, len(exp_options))],
        )

        if not selected_exp_names:
            st.info("Please select experiments to compare.")
            return

        selected_exp_ids = [exp_options[name] for name in selected_exp_names]

        # Generate comparison
        comparison_df = self.experiment_runner.compare_experiments(selected_exp_ids)

        if comparison_df.empty:
            st.error("No data available for comparison.")
            return

        self._display_comparison_table(comparison_df)
        self._display_comparison_charts(comparison_df)

    def _display_comparison_table(self, comparison_df: pd.DataFrame):
        """Display comparison table"""
        st.write("**Experiment Comparison Table**")

        # Select columns to display
        metric_columns = [
            col for col in comparison_df.columns if col.startswith("test_") or col.startswith("cv_")
        ]
        display_columns = ["name", "model_type", "features"] + metric_columns
        available_columns = [col for col in display_columns if col in comparison_df.columns]

        st.dataframe(comparison_df[available_columns], use_container_width=True)

    def _display_comparison_charts(self, comparison_df: pd.DataFrame):
        """Display comparison charts"""
        st.write("**Performance Comparison**")

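        # Bar chart of test accuracy per experiment, colored by model type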
if "test_accuracy" in comparison_df.columns:
|
||||
fig = px.bar(
|
||||
comparison_df,
|
||||
x="name",
|
||||
y="test_accuracy",
|
||||
color="model_type",
|
||||
title="Test Accuracy Comparison",
|
||||
)
|
||||
fig.update_layout(xaxis_tickangle=-45)
|
||||
st.plotly_chart(fig, use_container_width=True)
|
||||
|
||||
# Metric comparison across multiple metrics
|
||||
metric_columns = [
|
||||
col for col in comparison_df.columns if col.startswith("test_") or col.startswith("cv_")
|
||||
]
|
||||
|
||||
if len(metric_columns) > 1:
|
||||
metric_to_plot = st.selectbox("Select Metric for Detailed Comparison", metric_columns)
|
||||
|
||||
if metric_to_plot in comparison_df.columns:
|
||||
fig = px.bar(
|
||||
comparison_df,
|
||||
x="name",
|
||||
y=metric_to_plot,
|
||||
color="model_type",
|
||||
title=f"{metric_to_plot.replace('_', ' ').title()} Comparison",
|
||||
)
|
||||
fig.update_layout(xaxis_tickangle=-45)
|
||||
st.plotly_chart(fig, use_container_width=True)
|
||||
|
||||
    def show_performance_analysis(self):
        """Show performance analysis across experiments"""
        st.subheader("Performance Analysis")

        experiments = self.experiment_tracker.list_experiments()
        completed_experiments = [
            e for e in experiments if e.status.value == "completed" and e.test_metrics
        ]

        if not completed_experiments:
            st.warning("No completed experiments with metrics found.")
            return

        # Prepare data for analysis
        analysis_data = self._prepare_analysis_data(completed_experiments)
        analysis_df = pd.DataFrame(analysis_data)

        self._display_performance_trends(analysis_df)
        self._display_model_comparison(analysis_df)
        self._display_top_experiments(analysis_df)

    def _prepare_analysis_data(self, completed_experiments: List) -> List[dict]:
        """Prepare data for performance analysis"""
        analysis_data = []
        for exp in completed_experiments:
            row = {
                "experiment_id": exp.experiment_id,
                "name": exp.config.name,
                "model_type": exp.config.model_type,
                "feature_count": len(exp.config.features),
                "features": ", ".join([f.value for f in exp.config.features]),
                "train_size": exp.train_size,
                "test_size": exp.test_size,
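                # Spread each test metric (e.g. accuracy) into its own column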
                **exp.test_metrics,
            }
            analysis_data.append(row)
        return analysis_data

    def _display_performance_trends(self, analysis_df: pd.DataFrame):
        """Display performance trend charts"""
        col1, col2 = st.columns(2)

        with col1:
            # Accuracy vs Training Size
            if "accuracy" in analysis_df.columns and "train_size" in analysis_df.columns:
                fig = px.scatter(
                    analysis_df,
                    x="train_size",
                    y="accuracy",
                    color="model_type",
                    hover_data=["name"],
                    title="Accuracy vs Training Size",
                )
                st.plotly_chart(fig, use_container_width=True)

        with col2:
            # Feature Count vs Performance
            if "accuracy" in analysis_df.columns and "feature_count" in analysis_df.columns:
                fig = px.scatter(
                    analysis_df,
                    x="feature_count",
                    y="accuracy",
                    color="model_type",
                    hover_data=["name"],
                    title="Accuracy vs Number of Features",
                )
                st.plotly_chart(fig, use_container_width=True)

    def _display_model_comparison(self, analysis_df: pd.DataFrame):
        """Display model type comparison"""
        if "accuracy" in analysis_df.columns:
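            # Aggregate accuracy per model type: mean, standard deviation, and experiment count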
            model_performance = (
                analysis_df.groupby("model_type")["accuracy"]
                .agg(["mean", "std", "count"])
                .reset_index()
            )

            fig = go.Figure()
            fig.add_trace(
                go.Bar(
                    x=model_performance["model_type"],
                    y=model_performance["mean"],
                    error_y=dict(type="data", array=model_performance["std"]),
                    name="Average Accuracy",
                )
            )
            fig.update_layout(title="Average Accuracy by Model Type", yaxis_title="Accuracy")
            st.plotly_chart(fig, use_container_width=True)

    def _display_top_experiments(self, analysis_df: pd.DataFrame):
        """Display top performing experiments"""
        st.subheader("Top Performing Experiments")

        if "accuracy" in analysis_df.columns:
            display_columns = ["name", "model_type", "features", "accuracy"]

            # Add other metrics if available
            for metric in ["precision", "recall", "f1"]:
                if metric in analysis_df.columns:
                    display_columns.append(metric)

            top_experiments = analysis_df.nlargest(5, "accuracy")[display_columns]
            st.dataframe(top_experiments, use_container_width=True)

    def show_model_analysis(self):
        """Show detailed model analysis"""
        st.subheader("Model Analysis")

        experiments = self.experiment_tracker.list_experiments()
        completed_experiments = [e for e in experiments if e.status.value == "completed"]

        if not completed_experiments:
            st.warning("No completed experiments found.")
            return

        # Select experiment for detailed analysis
        exp_options = {
            f"{exp.config.name} ({exp.experiment_id[:8]})": exp for exp in completed_experiments
        }

        selected_exp_name = st.selectbox(
            "Select Experiment for Detailed Analysis", list(exp_options.keys())
        )

        if not selected_exp_name:
            return

        selected_exp = exp_options[selected_exp_name]

        self._display_experiment_details(selected_exp)
        self._display_confusion_matrix(selected_exp)
        self._display_feature_importance(selected_exp)
        self._display_prediction_examples(selected_exp)

    def _display_experiment_details(self, experiment):
        """Display experiment configuration and metrics"""
        col1, col2 = st.columns(2)

        with col1:
            st.write("**Experiment Configuration**")
            st.json(
                {
                    "name": experiment.config.name,
                    "model_type": experiment.config.model_type,
                    "features": [f.value for f in experiment.config.features],
                    "model_params": experiment.config.model_params,
                }
            )

        with col2:
            st.write("**Performance Metrics**")
            if experiment.test_metrics:
                for metric, value in experiment.test_metrics.items():
                    st.metric(metric.title(), f"{value:.4f}")

    def _display_confusion_matrix(self, experiment):
        """Display confusion matrix if available"""
        if experiment.confusion_matrix:
            st.write("**Confusion Matrix**")
            cm = np.array(experiment.confusion_matrix)

            fig = px.imshow(cm, text_auto=True, aspect="auto", title="Confusion Matrix")
            st.plotly_chart(fig, use_container_width=True)

    def _display_feature_importance(self, experiment):
        """Display feature importance if available"""
        if experiment.feature_importance:
            st.write("**Feature Importance**")

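            # Sort features by importance (descending) and keep only the top 20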
            importance_data = sorted(
                experiment.feature_importance.items(), key=lambda x: x[1], reverse=True
            )[:20]

            features, importances = zip(*importance_data)

            fig = px.bar(
                x=list(importances),
                y=list(features),
                orientation="h",
                title="Top 20 Feature Importances",
            )
            fig.update_layout(height=600)
            st.plotly_chart(fig, use_container_width=True)

    def _display_prediction_examples(self, experiment):
        """Display prediction examples if available"""
        if experiment.prediction_examples:
            st.write("**Prediction Examples**")

            examples_df = pd.DataFrame(experiment.prediction_examples)

            # Separate correct and incorrect predictions
            correct_examples = examples_df[examples_df["correct"] == True]
            incorrect_examples = examples_df[examples_df["correct"] == False]

            col1, col2 = st.columns(2)

            with col1:
                st.write("**Correct Predictions**")
                if not correct_examples.empty:
                    st.dataframe(
                        correct_examples[["name", "true_label", "predicted_label"]],
                        use_container_width=True,
                    )

            with col2:
                st.write("**Incorrect Predictions**")
                if not incorrect_examples.empty:
                    st.dataframe(
                        incorrect_examples[["name", "true_label", "predicted_label"]],
                        use_container_width=True,
                    )