feat: web application multipage support
@@ -0,0 +1,10 @@
import streamlit as st


class Configuration:
    def __init__(self, config):
        self.config = config

    def index(self):
        st.title("Configuration")
        st.json(self.config.model_dump())
@@ -0,0 +1,77 @@
import pandas as pd
import streamlit as st

from core.utils.data_loader import OPTIMIZED_DTYPES


@st.cache_data
def load_dataset(file_path: str) -> pd.DataFrame:
    try:
        return pd.read_csv(file_path, dtype=OPTIMIZED_DTYPES)
    except Exception as e:
        st.error(f"Error loading dataset: {e}")
        return pd.DataFrame()


class Dashboard:
    def __init__(self, config, experiment_tracker, experiment_runner):
        self.config = config
        self.experiment_tracker = experiment_tracker
        self.experiment_runner = experiment_runner

    def index(self):
        st.title("Dashboard")
        col1, col2, col3, col4 = st.columns(4)

        # Load basic statistics
        try:
            data_path = self.config.paths.get_data_path(self.config.data.output_files["featured"])
            if data_path.exists():
                df = load_dataset(str(data_path))

                with col1:
                    st.metric("Total Names", f"{len(df):,}")

                with col2:
                    annotated = (df.get("annotated", 0) == 1).sum()
                    st.metric("Annotated Names", f"{annotated:,}")

                with col3:
                    provinces = df["province"].nunique() if "province" in df.columns else 0
                    st.metric("Provinces", provinces)

                with col4:
                    if "sex" in df.columns:
                        gender_dist = df["sex"].value_counts()
                        ratio = gender_dist.get("f", 0) / max(gender_dist.get("m", 1), 1)
                        st.metric("F/M Ratio", f"{ratio:.2f}")
            else:
                st.warning("No processed data found. Please run data processing first.")

        except Exception as e:
            st.error(f"Error loading dashboard data: {e}")

        # Recent experiments
        st.subheader("Recent Experiments")
        experiments = self.experiment_tracker.list_experiments()[:5]

        if experiments:
            exp_data = []
            for exp in experiments:
                exp_data.append(
                    {
                        "Name": exp.config.name,
                        "Model": exp.config.model_type,
                        "Status": exp.status.value,
                        "Accuracy": (
                            f"{exp.test_metrics.get('accuracy', 0):.3f}"
                            if exp.test_metrics
                            else "N/A"
                        ),
                        "Date": exp.start_time.strftime("%Y-%m-%d %H:%M"),
                    }
                )

            st.dataframe(pd.DataFrame(exp_data), use_container_width=True)
        else:
            st.info("No experiments found. Create your first experiment in the Experiments tab!")
@@ -0,0 +1,155 @@
from datetime import datetime

import pandas as pd
import plotly.express as px
import streamlit as st

from core.utils.data_loader import OPTIMIZED_DTYPES


@st.cache_data
def load_dataset(file_path: str) -> pd.DataFrame:
    try:
        return pd.read_csv(file_path, dtype=OPTIMIZED_DTYPES)
    except Exception as e:
        st.error(f"Error loading dataset: {e}")
        return pd.DataFrame()


class DataOverview:
    def __init__(self, config):
        self.config = config

    def index(self):
        st.title("Data Overview")
        data_files = {
            "Names": self.config.data.input_file,
            "Featured Dataset": self.config.data.output_files["featured"],
            "Evaluation Dataset": self.config.data.output_files["evaluation"],
            "Male Names": self.config.data.output_files["males"],
            "Female Names": self.config.data.output_files["females"],
        }

        selected_file = st.selectbox("Select Dataset", list(data_files.keys()))
        file_path = self.config.paths.get_data_path(data_files[selected_file])

        if not file_path.exists():
            st.warning(f"Dataset not found: {file_path}")
            st.warning("Please run data processing first to generate datasets.")
            return

        # Load and display data
        df = load_dataset(str(file_path))

        if df.empty:
            st.error("Failed to load dataset")
            return

        # Basic statistics
        col1, col2, col3, col4 = st.columns(4)

        with col1:
            st.metric("Total Records", f"{len(df):,}")

        with col2:
            if "annotated" in df.columns:
                annotated_pct = (df["annotated"] == 1).mean() * 100
                st.metric("Annotated", f"{annotated_pct:.1f}%")

        with col3:
            if "words" in df.columns:
                avg_words = df["words"].mean()
                st.metric("Avg Words", f"{avg_words:.1f}")

        with col4:
            if "length" in df.columns:
                avg_length = df["length"].mean()
                st.metric("Avg Length", f"{avg_length:.0f}")

        # Data quality analysis
        st.subheader("Data Quality Analysis")

        col1, col2 = st.columns(2)

        with col1:
            # Missing values
            missing_data = df.isnull().sum()
            if missing_data.sum() > 0:
                fig = px.bar(
                    x=missing_data.index, y=missing_data.values, title="Missing Values by Column"
                )
                fig.update_layout(height=400)
                st.plotly_chart(fig, use_container_width=True)
            else:
                st.success("No missing values found")

        with col2:
            # Gender distribution
            if "sex" in df.columns:
                gender_counts = df["sex"].value_counts()
                fig = px.pie(
                    values=gender_counts.values,
                    names=gender_counts.index,
                    title="Gender Distribution",
                )
                fig.update_layout(height=400)
                st.plotly_chart(fig, use_container_width=True)

        # Word count distribution
        if "words" in df.columns:
            st.subheader("Name Structure Analysis")

            col1, col2 = st.columns(2)

            with col1:
                word_dist = df["words"].value_counts().sort_index()
                fig = px.bar(
                    x=word_dist.index,
                    y=word_dist.values,
                    title="Distribution of Word Count in Names",
                )
                fig.update_layout(height=400)
                st.plotly_chart(fig, use_container_width=True)

            with col2:
                # Province distribution
                if "province" in df.columns:
                    province_counts = df["province"].value_counts().head(10)
                    fig = px.bar(
                        x=province_counts.values,
                        y=province_counts.index,
                        orientation="h",
                        title="Top 10 Provinces by Name Count",
                    )
                    fig.update_layout(height=400)
                    st.plotly_chart(fig, use_container_width=True)

        # Sample data
        st.subheader("Sample Data")

        # Display columns selector
        if not df.empty:
            columns_to_show = st.multiselect(
                "Select columns to display",
                df.columns.tolist(),
                default=(
                    ["name", "sex", "province", "words"]
                    if all(col in df.columns for col in ["name", "sex", "province", "words"])
                    else df.columns[:5].tolist()
                ),
            )

            if columns_to_show:
                sample_size = st.slider("Number of rows to display", 10, min(1000, len(df)), 50)
                st.dataframe(df[columns_to_show].head(sample_size), use_container_width=True)

        # Data export
        st.subheader("Export Data")
        if st.button("Download as CSV"):
            csv = df.to_csv(index=False)
            st.download_button(
                label="Download CSV",
                data=csv,
                file_name=f"{selected_file.lower().replace(' ', '_')}_{datetime.now().strftime('%Y%m%d')}.csv",
                mime="text/csv",
            )
@@ -0,0 +1,133 @@
import pandas as pd
import plotly.express as px
import streamlit as st

from core.utils.data_loader import OPTIMIZED_DTYPES
from web.interfaces.log_reader import LogReader


@st.cache_data
def load_dataset(file_path: str) -> pd.DataFrame:
    try:
        return pd.read_csv(file_path, dtype=OPTIMIZED_DTYPES)
    except Exception as e:
        st.error(f"Error loading dataset: {e}")
        return pd.DataFrame()


class DataProcessing:
    def __init__(self, config, pipeline_monitor):
        self.config = config
        self.pipeline_monitor = pipeline_monitor

    def index(self):
        st.title("Data Processing")
        status = self.pipeline_monitor.get_pipeline_status()

        # Overall progress
        overall_progress = status["overall_completion"] / 100
        st.progress(overall_progress)
        st.write(f"Overall Progress: {status['overall_completion']:.1f}%")

        # Step details
        for step_name, step_status in status["steps"].items():
            with st.expander(f"{step_name.replace('_', ' ').title()} - {step_status['status']}"):
                col1, col2, col3 = st.columns(3)

                with col1:
                    st.metric("Processed Batches", step_status["processed_batches"])

                with col2:
                    st.metric("Total Batches", step_status["total_batches"])

                with col3:
                    st.metric("Failed Batches", step_status["failed_batches"])

                if step_status["completion_percentage"] > 0:
                    st.progress(step_status["completion_percentage"] / 100)

        # Read actual log entries from the log file
        st.subheader("Recent Processing Logs")
        try:
            log_file_path = self.config.paths.logs_dir / "pipeline.development.log"
            log_reader = LogReader(log_file_path)

            # Options for filtering logs
            col1, col2 = st.columns(2)
            with col1:
                log_level_filter = st.selectbox(
                    "Filter by Level",
                    ["All", "INFO", "WARNING", "ERROR", "DEBUG", "CRITICAL"],
                    key="log_level_filter",
                )

            with col2:
                num_entries = st.number_input(
                    "Number of entries", min_value=5, max_value=50, value=10, key="num_log_entries"
                )

            # Get log entries based on filter
            if log_level_filter == "All":
                log_entries = log_reader.read_last_entries(num_entries)
            else:
                log_entries = log_reader.read_entries_by_level(log_level_filter, num_entries)

            if log_entries:
                for entry in log_entries:
                    if entry.level == "ERROR":
                        st.error(
                            f"[{entry.timestamp.strftime('%Y-%m-%d %H:%M:%S')}] {entry.level}: {entry.message}"
                        )
                    elif entry.level == "WARNING":
                        st.warning(
                            f"[{entry.timestamp.strftime('%Y-%m-%d %H:%M:%S')}] {entry.level}: {entry.message}"
                        )
                    elif entry.level == "INFO":
                        st.info(
                            f"[{entry.timestamp.strftime('%Y-%m-%d %H:%M:%S')}] {entry.level}: {entry.message}"
                        )
                    else:
                        st.text(
                            f"[{entry.timestamp.strftime('%Y-%m-%d %H:%M:%S')}] {entry.level}: {entry.message}"
                        )

                # Show log statistics
                st.subheader("Log Statistics")
                log_stats = log_reader.get_log_stats()

                if log_stats:
                    col1, col2, col3, col4 = st.columns(4)

                    with col1:
                        st.metric("Total Lines", log_stats.get("total_lines", 0))
                    with col2:
                        st.metric("INFO", log_stats.get("INFO", 0))
                    with col3:
                        st.metric("WARNING", log_stats.get("WARNING", 0))
                    with col4:
                        st.metric("ERROR", log_stats.get("ERROR", 0))

                    # Log level distribution chart
                    levels = ["INFO", "WARNING", "ERROR", "DEBUG", "CRITICAL"]
                    counts = [log_stats.get(level, 0) for level in levels]

                    if sum(counts) > 0:
                        fig = px.bar(
                            x=levels,
                            y=counts,
                            title="Log Entries by Level",
                            color=levels,
                            color_discrete_map={
                                "INFO": "blue",
                                "WARNING": "orange",
                                "ERROR": "red",
                                "DEBUG": "gray",
                                "CRITICAL": "darkred",
                            },
                        )
                        st.plotly_chart(fig, use_container_width=True)
            else:
                st.info("No log entries found or log file is empty.")

        except Exception as e:
            st.error(f"Error reading log file: {e}")
@@ -0,0 +1,431 @@
from typing import List, Dict, Any

import streamlit as st

from core.utils.region_mapper import RegionMapper
from research.experiment import ExperimentConfig, ExperimentStatus
from research.experiment.experiment_builder import ExperimentBuilder
from research.experiment.experiment_runner import ExperimentRunner
from research.experiment.experiment_tracker import ExperimentTracker
from research.experiment.feature_extractor import FeatureType
from research.model_registry import list_available_models


class Experiments:
    def __init__(
        self, config, experiment_tracker: ExperimentTracker, experiment_runner: ExperimentRunner
    ):
        self.config = config
        self.experiment_tracker = experiment_tracker
        self.experiment_runner = experiment_runner

    def index(self):
        st.title("Experiments")
        tab1, tab2, tab3 = st.tabs(["New Experiment", "Experiment List", "Batch Experiments"])

        with tab1:
            self.show_experiment_creation()

        with tab2:
            self.show_experiment_list()

        with tab3:
            self.show_batch_experiments()

    def show_experiment_creation(self):
        """Show interface for creating new experiments"""
        st.subheader("Create New Experiment")

        with st.form("new_experiment"):
            col1, col2 = st.columns(2)

            with col1:
                exp_name = st.text_input(
                    "Experiment Name", placeholder="e.g., native_name_gender_prediction"
                )
                description = st.text_area(
                    "Description", placeholder="Brief description of the experiment"
                )
                model_type = st.selectbox("Model Type", list_available_models())

                # Feature selection
                feature_options = [f.value for f in FeatureType]
                selected_features = st.multiselect(
                    "Features to Use", feature_options, default=["full_name"]
                )

            with col2:
                # Model parameters
                st.write("**Model Parameters**")
                model_params = {}
                if model_type == "logistic_regression":
                    ngram_min = st.number_input("N-gram Min", 1, 5, 2)
                    ngram_max = st.number_input("N-gram Max", 2, 8, 5)
                    max_features = st.number_input("Max Features", 1000, 50000, 10000)
                    model_params = {
                        "ngram_range": [ngram_min, ngram_max],
                        "max_features": max_features,
                    }
                elif model_type == "random_forest":
                    n_estimators = st.number_input("Number of Trees", 10, 500, 100)
                    max_depth = st.number_input("Max Depth", 1, 20, 10)
                    model_params = {
                        "n_estimators": n_estimators,
                        "max_depth": max_depth if max_depth > 0 else None,
                    }

                # Training parameters
                st.write("**Training Parameters**")
                test_size = st.slider("Test Set Size", 0.1, 0.5, 0.2)
                cv_folds = st.number_input("Cross-Validation Folds", 3, 10, 5)

                tags = st.text_input(
                    "Tags (comma-separated)", placeholder="e.g., baseline, feature_study"
                )

            # Advanced options
            with st.expander("Advanced Options"):
                # Data filters
                st.write("**Data Filters**")
                filter_province = st.selectbox(
                    "Filter by Province (optional)",
                    ["None"] + RegionMapper().get_provinces(),
                )

                min_words = st.number_input("Minimum Word Count", 0, 10, 0)
                max_words = st.number_input("Maximum Word Count (0 = no limit)", 0, 20, 0)

            submitted = st.form_submit_button("Create and Run Experiment", type="primary")

        if submitted:
            self._handle_experiment_submission(
                exp_name,
                description,
                model_type,
                selected_features,
                model_params,
                test_size,
                cv_folds,
                tags,
                filter_province,
                min_words,
                max_words,
            )

    def _handle_experiment_submission(
        self,
        exp_name: str,
        description: str,
        model_type: str,
        selected_features: List[str],
        model_params: Dict[str, Any],
        test_size: float,
        cv_folds: int,
        tags: str,
        filter_province: str,
        min_words: int,
        max_words: int,
    ):
        """Handle experiment form submission"""
        if not exp_name:
            st.error("Please provide an experiment name")
            return

        if not selected_features:
            st.error("Please select at least one feature")
            return

        try:
            # Prepare data filters
            train_filter = {}
            if filter_province != "None":
                train_filter["province"] = filter_province
            if min_words > 0:
                train_filter["words"] = {"min": min_words}
            if max_words > 0:
                if "words" in train_filter:
                    train_filter["words"]["max"] = max_words
                else:
                    train_filter["words"] = {"max": max_words}

            # Create experiment config
            features = [FeatureType(f) for f in selected_features]
            tag_list = [tag.strip() for tag in tags.split(",") if tag.strip()]

            config = ExperimentConfig(
                name=exp_name,
                description=description,
                tags=tag_list,
                model_type=model_type,
                model_params=model_params,
                features=features,
                train_data_filter=train_filter if train_filter else None,
                test_size=test_size,
                cross_validation_folds=cv_folds,
            )

            # Run experiment
            with st.spinner("Running experiment..."):
                experiment_id = self.experiment_runner.run_experiment(config)

            st.success(f"Experiment completed successfully!")
            st.info(f"Experiment ID: `{experiment_id}`")

            # Show results
            experiment = self.experiment_tracker.get_experiment(experiment_id)
            if experiment and experiment.test_metrics:
                st.write("**Results:**")
                for metric, value in experiment.test_metrics.items():
                    st.metric(metric.title(), f"{value:.4f}")

        except Exception as e:
            st.error(f"Error running experiment: {e}")

    def show_experiment_list(self):
        """Show list of all experiments with filtering"""
        st.subheader("All Experiments")

        # Filters
        col1, col2, col3 = st.columns(3)

        with col1:
            status_filter = st.selectbox(
                "Filter by Status", ["All", "completed", "running", "failed", "pending"]
            )

        with col2:
            model_filter = st.selectbox("Filter by Model", ["All"] + list_available_models())

        with col3:
            tag_filter = st.text_input("Filter by Tags (comma-separated)")

        # Get and filter experiments
        experiments = self._get_filtered_experiments(status_filter, model_filter, tag_filter)

        if not experiments:
            st.info("No experiments found matching the filters.")
            return

        # Display experiments
        for i, exp in enumerate(experiments):
            with st.expander(
                f"{exp.config.name} - {exp.status.value} - {exp.start_time.strftime('%Y-%m-%d %H:%M')}"
            ):
                self._display_experiment_details(exp, i)

    def _get_filtered_experiments(self, status_filter: str, model_filter: str, tag_filter: str):
        """Get experiments with applied filters"""
        experiments = self.experiment_tracker.list_experiments()

        # Apply filters
        if status_filter != "All":
            experiments = [e for e in experiments if e.status == ExperimentStatus(status_filter)]

        if model_filter != "All":
            experiments = [e for e in experiments if e.config.model_type == model_filter]

        if tag_filter:
            tags = [tag.strip() for tag in tag_filter.split(",")]
            experiments = [e for e in experiments if any(tag in e.config.tags for tag in tags)]

        return experiments

    def _display_experiment_details(self, exp, index: int):
        """Display details for a single experiment"""
        col1, col2, col3 = st.columns(3)

        with col1:
            st.write(f"**Model:** {exp.config.model_type}")
            st.write(f"**Features:** {', '.join([f.value for f in exp.config.features])}")
            st.write(f"**Tags:** {', '.join(exp.config.tags)}")

        with col2:
            if exp.test_metrics:
                for metric, value in exp.test_metrics.items():
                    st.metric(metric.title(), f"{value:.4f}")

        with col3:
            st.write(f"**Train Size:** {exp.train_size:,}")
            st.write(f"**Test Size:** {exp.test_size:,}")

            if st.button(f"View Details", key=f"details_{index}"):
                st.session_state.selected_experiment = exp.experiment_id
                st.rerun()

        if exp.config.description:
            st.write(f"**Description:** {exp.config.description}")

    def show_batch_experiments(self):
        """Show interface for running batch experiments"""
        st.subheader("Batch Experiments")
        st.write("Run multiple experiments with different parameter combinations.")

        # Parameter sweep configuration
        with st.form("batch_experiments"):
            st.write("**Parameter Sweep Configuration**")

            col1, col2 = st.columns(2)

            with col1:
                base_name = st.text_input("Base Experiment Name", "parameter_sweep")
                model_types = st.multiselect(
                    "Model Types", list_available_models(), default=["logistic_regression"]
                )

                # N-gram ranges for logistic regression
                st.write("**Logistic Regression Parameters**")
                ngram_ranges = st.text_area(
                    "N-gram Ranges (one per line, format: min,max)", "2,4\n2,5\n3,6"
                )

            with col2:
                feature_combinations = st.multiselect(
                    "Feature Combinations",
                    [f.value for f in FeatureType],
                    default=["full_name", "native_name", "surname"],
                )

                test_sizes = st.text_input("Test Sizes (comma-separated)", "0.15,0.2,0.25")

                tags = st.text_input("Common Tags", "parameter_sweep,batch")

            if st.form_submit_button("🚀 Run Batch Experiments"):
                self.run_batch_experiments(
                    base_name, model_types, ngram_ranges, feature_combinations, test_sizes, tags
                )

    def run_batch_experiments(
        self,
        base_name: str,
        model_types: List[str],
        ngram_ranges: str,
        feature_combinations: List[str],
        test_sizes: str,
        tags: str,
    ):
        """Run batch experiments with parameter combinations"""
        with st.spinner("Running batch experiments..."):
            try:
                experiments = []

                # Parse parameters
                ngram_list = []
                for line in ngram_ranges.strip().split("\n"):
                    if "," in line:
                        min_val, max_val = map(int, line.split(","))
                        ngram_list.append([min_val, max_val])

                test_size_list = [float(x.strip()) for x in test_sizes.split(",")]
                tag_list = [tag.strip() for tag in tags.split(",") if tag.strip()]

                # Generate experiment combinations
                exp_count = 0
                for model_type in model_types:
                    for feature_combo in feature_combinations:
                        for test_size in test_size_list:
                            if model_type == "logistic_regression":
                                for ngram_range in ngram_list:
                                    exp_name = f"{base_name}_{model_type}_{feature_combo}_{ngram_range[0]}_{ngram_range[1]}_{test_size}"

                                    config = ExperimentConfig(
                                        name=exp_name,
                                        description=f"Batch experiment: {model_type} with {feature_combo}",
                                        model_type=model_type,
                                        features=[FeatureType(feature_combo)],
                                        model_params={"ngram_range": ngram_range},
                                        test_size=test_size,
                                        tags=tag_list,
                                    )
                                    experiments.append(config)
                                    exp_count += 1
                            else:
                                exp_name = f"{base_name}_{model_type}_{feature_combo}_{test_size}"

                                config = ExperimentConfig(
                                    name=exp_name,
                                    description=f"Batch experiment: {model_type} with {feature_combo}",
                                    model_type=model_type,
                                    features=[FeatureType(feature_combo)],
                                    test_size=test_size,
                                    tags=tag_list,
                                )
                                experiments.append(config)
                                exp_count += 1

                # Run experiments
                experiment_ids = self.experiment_runner.run_experiment_batch(experiments)

                st.success(f"Completed {len(experiment_ids)} batch experiments")

                # Show summary
                if experiment_ids:
                    comparison = self.experiment_runner.compare_experiments(experiment_ids)
                    st.write("**Batch Results Summary:**")
                    st.dataframe(
                        comparison[["name", "model_type", "test_accuracy"]],
                        use_container_width=True,
                    )

            except Exception as e:
                st.error(f"Error running batch experiments: {e}")

    def run_baseline_experiments(self):
        """Run baseline experiments"""
        with st.spinner("Running baseline experiments..."):
            try:
                builder = ExperimentBuilder()
                experiments = builder.create_baseline_experiments()
                experiment_ids = self.experiment_runner.run_experiment_batch(experiments)

                st.success(f"Completed {len(experiment_ids)} baseline experiments")

                # Show quick comparison
                if experiment_ids:
                    comparison = self.experiment_runner.compare_experiments(experiment_ids)
                    st.write("**Results Summary:**")
                    st.dataframe(
                        comparison[["name", "model_type", "test_accuracy"]],
                        use_container_width=True,
                    )

            except Exception as e:
                st.error(f"Error running baseline experiments: {e}")

    def run_ablation_study(self):
        """Run feature ablation study"""
        with st.spinner("Running ablation study..."):
            try:
                builder = ExperimentBuilder()
                experiments = builder.create_feature_ablation_study()
                experiment_ids = self.experiment_runner.run_experiment_batch(experiments)

                st.success(f"Completed {len(experiment_ids)} ablation experiments")

            except Exception as e:
                st.error(f"Error running ablation study: {e}")

    def run_component_study(self):
        """Run name component study"""
        with st.spinner("Running component study..."):
            try:
                builder = ExperimentBuilder()
                experiments = builder.create_name_component_study()
                experiment_ids = self.experiment_runner.run_experiment_batch(experiments)

                st.success(f"Completed {len(experiment_ids)} component experiments")

            except Exception as e:
                st.error(f"Error running component study: {e}")

    def run_province_study(self):
        """Run province-specific study"""
        with st.spinner("Running province study..."):
            try:
                builder = ExperimentBuilder()
                experiments = builder.create_province_specific_study()
                experiment_ids = self.experiment_runner.run_experiment_batch(experiments)

                st.success(f"Completed {len(experiment_ids)} province experiments")

            except Exception as e:
                st.error(f"Error running province study: {e}")
@@ -0,0 +1,182 @@
import re
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Optional


@dataclass
class LogEntry:
    """Represents a single log entry."""

    timestamp: datetime
    logger: str
    level: str
    message: str
    raw_line: str


class LogReader:
    """Utility class for reading and parsing log files."""

    def __init__(self, log_file_path: Path):
        """Initialize the log reader with a log file path."""
        self.log_file_path = Path(log_file_path)
        # Pattern to match Python logging format: timestamp - logger - level - message
        self.log_pattern = re.compile(
            r"(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3}) - (.+?) - (\w+) - (.+)"
        )

    def read_last_entries(self, count: int = 10) -> List[LogEntry]:
        """Read the last N entries from the log file."""
        if not self.log_file_path.exists():
            return []

        try:
            with open(self.log_file_path, "r", encoding="utf-8") as file:
                lines = file.readlines()

            # Parse log entries from the end
            entries = []
            for line in reversed(lines[-count * 2 :]):  # Read more lines in case some don't match
                entry = self._parse_log_line(line.strip())
                if entry:
                    entries.append(entry)
                    if len(entries) >= count:
                        break

            # Return entries in chronological order (oldest first of the last N)
            return list(reversed(entries))

        except Exception as e:
            print(f"Error reading log file: {e}")
            return []

    def read_entries_by_level(self, level: str, count: int = 50) -> List[LogEntry]:
        """Read entries filtered by log level."""
        if not self.log_file_path.exists():
            return []

        try:
            with open(self.log_file_path, "r", encoding="utf-8") as file:
                lines = file.readlines()

            entries = []
            for line in reversed(lines):
                entry = self._parse_log_line(line.strip())
                if entry and entry.level.upper() == level.upper():
                    entries.append(entry)
                    if len(entries) >= count:
                        break

            return list(reversed(entries))

        except Exception as e:
            print(f"Error reading log file: {e}")
            return []

    def read_entries_since(self, since: datetime, count: int = 100) -> List[LogEntry]:
        """Read entries since a specific datetime."""
        if not self.log_file_path.exists():
            return []

        try:
            with open(self.log_file_path, "r", encoding="utf-8") as file:
                lines = file.readlines()

            entries = []
            for line in reversed(lines):
                entry = self._parse_log_line(line.strip())
                if entry:
                    if entry.timestamp >= since:
                        entries.append(entry)
                    else:
                        # Stop reading if we've gone past the since time
                        break
                    if len(entries) >= count:
                        break

            return list(reversed(entries))

        except Exception as e:
            print(f"Error reading log file: {e}")
            return []

    def get_log_stats(self) -> Dict[str, int]:
        """Get statistics about the log file."""
        if not self.log_file_path.exists():
            return {}

        try:
            with open(self.log_file_path, "r", encoding="utf-8") as file:
                lines = file.readlines()

            stats = {
                "total_lines": len(lines),
                "INFO": 0,
                "WARNING": 0,
                "ERROR": 0,
                "DEBUG": 0,
                "CRITICAL": 0,
            }

            for line in lines:
                entry = self._parse_log_line(line.strip())
                if entry:
                    level = entry.level.upper()
                    if level in stats:
                        stats[level] += 1

            return stats

        except Exception as e:
            print(f"Error reading log file: {e}")
            return {}

    def _parse_log_line(self, line: str) -> Optional[LogEntry]:
        """Parse a single log line into a LogEntry object."""
        if not line:
            return None

        match = self.log_pattern.match(line)
        if not match:
            return None

        try:
            timestamp_str, logger, level, message = match.groups()
            timestamp = datetime.strptime(timestamp_str, "%Y-%m-%d %H:%M:%S,%f")

            return LogEntry(
                timestamp=timestamp, logger=logger, level=level, message=message, raw_line=line
            )
        except ValueError:
            return None


class MultiLogReader:
    """Reader for multiple log files."""

    def __init__(self, log_directory: Path):
        """Initialize with a directory containing log files."""
        self.log_directory = Path(log_directory)

    def get_available_log_files(self) -> List[Path]:
        """Get list of available log files."""
        if not self.log_directory.exists():
            return []

        return list(self.log_directory.glob("*.log"))

    def read_from_all_files(self, count: int = 10) -> List[LogEntry]:
        """Read entries from all log files and merge them chronologically."""
        all_entries = []

        for log_file in self.get_available_log_files():
            reader = LogReader(log_file)
            entries = reader.read_last_entries(count)
            all_entries.extend(entries)

        # Sort by timestamp
        all_entries.sort(key=lambda x: x.timestamp, reverse=True)

        return all_entries[:count]
@@ -0,0 +1,374 @@
from datetime import datetime
from typing import Optional

import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st

from core.utils.data_loader import OPTIMIZED_DTYPES
from research.experiment.experiment_runner import ExperimentRunner
from research.experiment.experiment_tracker import ExperimentTracker


class Predictions:
    def __init__(
        self, config, experiment_tracker: ExperimentTracker, experiment_runner: ExperimentRunner
    ):
        self.config = config
        self.experiment_tracker = experiment_tracker
        self.experiment_runner = experiment_runner

    def index(self):
        st.title("Predictions")

        # Load available models
        experiments = self.experiment_tracker.list_experiments()
        completed_experiments = [
            e for e in experiments if e.status.value == "completed" and e.model_path
        ]

        if not completed_experiments:
            st.warning("No trained models available. Please run some experiments first.")
            return

        # Model selection
        model_options = {
            f"{exp.config.name} (Acc: {exp.test_metrics.get('accuracy', 0):.3f})": exp
            for exp in completed_experiments
            if exp.test_metrics
        }

        selected_model_name = st.selectbox("Select Model", list(model_options.keys()))

        if not selected_model_name:
            return

        selected_experiment = model_options[selected_model_name]

        # Prediction modes
        prediction_mode = st.radio(
            "Prediction Mode", ["Single Name", "Batch Upload", "Dataset Prediction"]
        )

        if prediction_mode == "Single Name":
            self.show_single_prediction(selected_experiment)
        elif prediction_mode == "Batch Upload":
            self.show_batch_prediction(selected_experiment)
        elif prediction_mode == "Dataset Prediction":
            self.show_dataset_prediction(selected_experiment)

    def show_single_prediction(self, experiment):
        """Show single name prediction interface"""
        st.subheader("Single Name Prediction")

        name_input = st.text_input("Enter a name:", placeholder="e.g., Jean Baptiste Mukendi")

        if name_input and st.button("Predict Gender"):
            try:
                # Load the model
                model = self.experiment_runner.load_experiment_model(experiment.experiment_id)

                if model is None:
                    st.error("Failed to load model")
                    return

                # Create a DataFrame with the input
                input_df = self._prepare_single_input(name_input)

                # Make prediction
                prediction = model.predict(input_df)[0]

                # Get prediction probability if available
                confidence = self._get_prediction_confidence(model, input_df)

                # Display results
                self._display_single_prediction_results(
                    prediction, confidence, experiment, name_input
                )

            except Exception as e:
                st.error(f"Error making prediction: {e}")

    def _prepare_single_input(self, name_input: str) -> pd.DataFrame:
        """Prepare single name input for prediction"""
        return pd.DataFrame(
            {
                "name": [name_input],
                "words": [len(name_input.split())],
                "length": [len(name_input.replace(" ", ""))],
                "province": ["unknown"],  # Default values
                "identified_name": [None],
                "identified_surname": [None],
                "probable_native": [None],
                "probable_surname": [None],
            }
        )

    def _get_prediction_confidence(self, model, input_df: pd.DataFrame) -> Optional[float]:
        """Get prediction confidence if available"""
        try:
            probabilities = model.predict_proba(input_df)[0]
            return max(probabilities)
        except:
            return None

    def _display_single_prediction_results(
        self, prediction: str, confidence: Optional[float], experiment, name_input: str
    ):
        """Display single prediction results"""
        col1, col2 = st.columns(2)

        with col1:
            gender_label = "Female" if prediction == "f" else "Male"
            st.success(f"**Predicted Gender:** {gender_label}")

        with col2:
            if confidence:
                st.metric("Confidence", f"{confidence:.2%}")

        # Additional info
        st.info(f"Model used: {experiment.config.name}")
        st.info(f"Features used: {', '.join([f.value for f in experiment.config.features])}")

    def show_batch_prediction(self, experiment):
        """Show batch prediction interface"""
        st.subheader("Batch Prediction")

        uploaded_file = st.file_uploader("Upload CSV file with names", type="csv")

        if uploaded_file is not None:
            try:
                df = pd.read_csv(uploaded_file, dtype=OPTIMIZED_DTYPES)

                st.write("**Uploaded Data Preview:**")
                st.dataframe(df.head(), use_container_width=True)

                # Column selection
                df = self._prepare_batch_data(df)

                if st.button("Run Batch Prediction"):
                    self._run_batch_prediction(df, experiment)

            except Exception as e:
                st.error(f"Error processing file: {e}")

    def _prepare_batch_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Prepare batch data for prediction"""
        # Column selection
        if "name" not in df.columns:
            name_column = st.selectbox("Select the name column:", df.columns)
            df = df.rename(columns={name_column: "name"})

        # Add missing columns with defaults
        required_columns = [
            "words",
            "length",
            "province",
            "identified_name",
            "identified_surname",
            "probable_native",
            "probable_surname",
        ]

        for col in required_columns:
            if col not in df.columns:
                if col == "words":
                    df[col] = df["name"].str.split().str.len()
                elif col == "length":
                    df[col] = df["name"].str.replace(" ", "").str.len()
                else:
                    df[col] = None

        return df

    def _run_batch_prediction(self, df: pd.DataFrame, experiment):
        """Run batch prediction and display results"""
        with st.spinner("Making predictions..."):
            # Load model
            model = self.experiment_runner.load_experiment_model(experiment.experiment_id)

            if model is None:
                st.error("Failed to load model")
                return

            # Make predictions
            predictions = model.predict(df)
            df["predicted_gender"] = predictions
            df["gender_label"] = df["predicted_gender"].map({"f": "Female", "m": "Male"})

            # Try to get probabilities
            try:
                probabilities = model.predict_proba(df)
                df["confidence"] = np.max(probabilities, axis=1)
            except:
                df["confidence"] = None

            st.success("Predictions completed!")

            # Show results
            self._display_batch_results(df)

    def _display_batch_results(self, df: pd.DataFrame):
        """Display batch prediction results"""
        result_columns = ["name", "gender_label", "predicted_gender"]
        if "confidence" in df.columns:
            result_columns.append("confidence")

        st.dataframe(df[result_columns], use_container_width=True)

        # Download results
        csv = df.to_csv(index=False)
        st.download_button(
            label="Download Predictions",
            data=csv,
            file_name=f"predictions_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
            mime="text/csv",
        )

        # Summary statistics
        self._display_batch_summary(df)

    def _display_batch_summary(self, df: pd.DataFrame):
        """Display batch prediction summary"""
        st.subheader("Prediction Summary")
        gender_counts = df["gender_label"].value_counts()

        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("Total Predictions", len(df))
        with col2:
            st.metric("Female", gender_counts.get("Female", 0))
        with col3:
            st.metric("Male", gender_counts.get("Male", 0))

        # Gender distribution chart
        fig = px.pie(
            values=gender_counts.values,
            names=gender_counts.index,
            title="Predicted Gender Distribution",
        )
        st.plotly_chart(fig, use_container_width=True)

    def show_dataset_prediction(self, experiment):
        """Show dataset prediction interface"""
        st.subheader("Dataset Prediction")
        st.write("Apply the model to existing datasets")

        # Dataset selection
        dataset_options = {
            "Featured Dataset": self.config.data.output_files["featured"],
            "Evaluation Dataset": self.config.data.output_files["evaluation"],
        }

        selected_dataset = st.selectbox("Select Dataset", list(dataset_options.keys()))
        file_path = self.config.paths.get_data_path(dataset_options[selected_dataset])

        if not file_path.exists():
            st.warning(f"Dataset not found: {file_path}")
            return

        # Load and show dataset info
        df = self._load_dataset(str(file_path))
        if df.empty:
            return

        st.write(f"Dataset contains {len(df):,} records")

        # Prediction options
        col1, col2 = st.columns(2)

        with col1:
            sample_size = st.number_input(
                "Sample size (0 = all data)", 0, len(df), min(1000, len(df))
            )

        with col2:
            compare_with_actual = False
            if "sex" in df.columns:
                compare_with_actual = st.checkbox("Compare with actual labels", value=True)

        if st.button("Run Dataset Prediction"):
            self._run_dataset_prediction(df, experiment, sample_size, compare_with_actual)

    def _load_dataset(self, file_path: str) -> pd.DataFrame:
        """Load dataset with error handling"""
        try:
            return pd.read_csv(file_path, dtype=OPTIMIZED_DTYPES)
        except Exception as e:
            st.error(f"Error loading dataset: {e}")
            return pd.DataFrame()

    def _run_dataset_prediction(
        self, df: pd.DataFrame, experiment, sample_size: int, compare_with_actual: bool
    ):
        """Run dataset prediction and display results"""
        with st.spinner("Running predictions..."):
            # Sample data if requested
            if sample_size > 0:
                df_sample = df.sample(n=sample_size, random_state=42)
            else:
                df_sample = df

            # Load model and make predictions
            model = self.experiment_runner.load_experiment_model(experiment.experiment_id)

            if model is None:
                st.error("Failed to load model")
                return

            predictions = model.predict(df_sample)
            df_sample["predicted_gender"] = predictions

            # Show results
            if compare_with_actual and "sex" in df_sample.columns:
                self._display_dataset_comparison(df_sample)
            else:
                self._display_dataset_predictions(df_sample)

    def _display_dataset_comparison(self, df_sample: pd.DataFrame):
        """Display dataset predictions with actual comparison"""
        # Calculate accuracy
        accuracy = (df_sample["sex"] == df_sample["predicted_gender"]).mean()
        st.metric("Accuracy on Selected Data", f"{accuracy:.4f}")

        # Confusion matrix
        from sklearn.metrics import confusion_matrix

        cm = confusion_matrix(df_sample["sex"], df_sample["predicted_gender"])

        fig = px.imshow(cm, text_auto=True, aspect="auto", title="Confusion Matrix")
        st.plotly_chart(fig, use_container_width=True)

        # Sample of correct and incorrect predictions
        correct_mask = df_sample["sex"] == df_sample["predicted_gender"]

        col1, col2 = st.columns(2)

        with col1:
            st.write("**Sample Correct Predictions**")
            correct_sample = df_sample[correct_mask][["name", "sex", "predicted_gender"]].head(10)
            st.dataframe(correct_sample, use_container_width=True)

        with col2:
            st.write("**Sample Incorrect Predictions**")
            incorrect_sample = df_sample[~correct_mask][["name", "sex", "predicted_gender"]].head(
                10
            )
            st.dataframe(incorrect_sample, use_container_width=True)

    def _display_dataset_predictions(self, df_sample: pd.DataFrame):
        """Display dataset predictions without comparison"""
        # Just show predictions
        st.write("**Sample Predictions**")
        sample_results = df_sample[["name", "predicted_gender"]].head(20)
        st.dataframe(sample_results, use_container_width=True)

        # Gender distribution
        gender_counts = df_sample["predicted_gender"].value_counts()
        fig = px.pie(
            values=gender_counts.values,
            names=gender_counts.index,
            title="Predicted Gender Distribution",
        )
        st.plotly_chart(fig, use_container_width=True)
@@ -0,0 +1,333 @@
from typing import List

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st

from research.experiment.experiment_runner import ExperimentRunner
from research.experiment.experiment_tracker import ExperimentTracker


class ResultsAnalysis:
    def __init__(
        self, config, experiment_tracker: ExperimentTracker, experiment_runner: ExperimentRunner
    ):
        self.config = config
        self.experiment_tracker = experiment_tracker
        self.experiment_runner = experiment_runner

    def index(self):
        st.title("Results & Analysis")
        tab1, tab2, tab3 = st.tabs(
            ["Experiment Comparison", "Performance Analysis", "Model Analysis"]
        )

        with tab1:
            self.show_experiment_comparison()

        with tab2:
            self.show_performance_analysis()

        with tab3:
            self.show_model_analysis()

    def show_experiment_comparison(self):
        """Show experiment comparison interface"""
        st.subheader("Compare Experiments")

        experiments = self.experiment_tracker.list_experiments()
        completed_experiments = [e for e in experiments if e.status.value == "completed"]

        if not completed_experiments:
            st.warning("No completed experiments found.")
            return

        # Experiment selection
        exp_options = {
            f"{exp.config.name} ({exp.experiment_id[:8]})": exp.experiment_id
            for exp in completed_experiments
        }

        selected_exp_names = st.multiselect(
            "Select Experiments to Compare",
            list(exp_options.keys()),
            default=list(exp_options.keys())[: min(5, len(exp_options))],
        )

        if not selected_exp_names:
            st.info("Please select experiments to compare.")
            return

        selected_exp_ids = [exp_options[name] for name in selected_exp_names]

        # Generate comparison
        comparison_df = self.experiment_runner.compare_experiments(selected_exp_ids)

        if comparison_df.empty:
            st.error("No data available for comparison.")
            return

        self._display_comparison_table(comparison_df)
        self._display_comparison_charts(comparison_df)

    def _display_comparison_table(self, comparison_df: pd.DataFrame):
        """Display comparison table"""
        st.write("**Experiment Comparison Table**")

        # Select columns to display
        metric_columns = [
            col for col in comparison_df.columns if col.startswith("test_") or col.startswith("cv_")
        ]
        display_columns = ["name", "model_type", "features"] + metric_columns
        available_columns = [col for col in display_columns if col in comparison_df.columns]

        st.dataframe(comparison_df[available_columns], use_container_width=True)

    def _display_comparison_charts(self, comparison_df: pd.DataFrame):
        """Display comparison charts"""
        st.write("**Performance Comparison**")

        if "test_accuracy" in comparison_df.columns:
            fig = px.bar(
                comparison_df,
                x="name",
                y="test_accuracy",
                color="model_type",
                title="Test Accuracy Comparison",
            )
            fig.update_layout(xaxis_tickangle=-45)
            st.plotly_chart(fig, use_container_width=True)

        # Metric comparison across multiple metrics
        metric_columns = [
            col for col in comparison_df.columns if col.startswith("test_") or col.startswith("cv_")
        ]

        if len(metric_columns) > 1:
            metric_to_plot = st.selectbox("Select Metric for Detailed Comparison", metric_columns)

            if metric_to_plot in comparison_df.columns:
                fig = px.bar(
                    comparison_df,
                    x="name",
                    y=metric_to_plot,
                    color="model_type",
                    title=f"{metric_to_plot.replace('_', ' ').title()} Comparison",
                )
                fig.update_layout(xaxis_tickangle=-45)
                st.plotly_chart(fig, use_container_width=True)

    def show_performance_analysis(self):
        """Show performance analysis across experiments"""
        st.subheader("Performance Analysis")

        experiments = self.experiment_tracker.list_experiments()
        completed_experiments = [
            e for e in experiments if e.status.value == "completed" and e.test_metrics
        ]

        if not completed_experiments:
            st.warning("No completed experiments with metrics found.")
            return

        # Prepare data for analysis
        analysis_data = self._prepare_analysis_data(completed_experiments)
        analysis_df = pd.DataFrame(analysis_data)

        self._display_performance_trends(analysis_df)
        self._display_model_comparison(analysis_df)
        self._display_top_experiments(analysis_df)

    def _prepare_analysis_data(self, completed_experiments: List) -> List[dict]:
        """Prepare data for performance analysis"""
        analysis_data = []
        for exp in completed_experiments:
            row = {
                "experiment_id": exp.experiment_id,
                "name": exp.config.name,
                "model_type": exp.config.model_type,
                "feature_count": len(exp.config.features),
                "features": ", ".join([f.value for f in exp.config.features]),
                "train_size": exp.train_size,
                "test_size": exp.test_size,
                **exp.test_metrics,
            }
            analysis_data.append(row)
        return analysis_data

    def _display_performance_trends(self, analysis_df: pd.DataFrame):
        """Display performance trend charts"""
        col1, col2 = st.columns(2)

        with col1:
            # Accuracy vs Training Size
            if "accuracy" in analysis_df.columns and "train_size" in analysis_df.columns:
                fig = px.scatter(
                    analysis_df,
                    x="train_size",
                    y="accuracy",
                    color="model_type",
                    hover_data=["name"],
                    title="Accuracy vs Training Size",
                )
                st.plotly_chart(fig, use_container_width=True)

        with col2:
            # Feature Count vs Performance
            if "accuracy" in analysis_df.columns and "feature_count" in analysis_df.columns:
                fig = px.scatter(
                    analysis_df,
                    x="feature_count",
                    y="accuracy",
                    color="model_type",
                    hover_data=["name"],
                    title="Accuracy vs Number of Features",
                )
                st.plotly_chart(fig, use_container_width=True)

    def _display_model_comparison(self, analysis_df: pd.DataFrame):
        """Display model type comparison"""
        if "accuracy" in analysis_df.columns:
            model_performance = (
                analysis_df.groupby("model_type")["accuracy"]
                .agg(["mean", "std", "count"])
                .reset_index()
            )

            fig = go.Figure()
            fig.add_trace(
                go.Bar(
                    x=model_performance["model_type"],
                    y=model_performance["mean"],
                    error_y=dict(type="data", array=model_performance["std"]),
                    name="Average Accuracy",
                )
            )
            fig.update_layout(title="Average Accuracy by Model Type", yaxis_title="Accuracy")
            st.plotly_chart(fig, use_container_width=True)

    def _display_top_experiments(self, analysis_df: pd.DataFrame):
        """Display top performing experiments"""
        st.subheader("Top Performing Experiments")

        if "accuracy" in analysis_df.columns:
            display_columns = ["name", "model_type", "features", "accuracy"]

            # Add other metrics if available
            for metric in ["precision", "recall", "f1"]:
                if metric in analysis_df.columns:
                    display_columns.append(metric)

            top_experiments = analysis_df.nlargest(5, "accuracy")[display_columns]
            st.dataframe(top_experiments, use_container_width=True)

    def show_model_analysis(self):
        """Show detailed model analysis"""
        st.subheader("Model Analysis")

        experiments = self.experiment_tracker.list_experiments()
        completed_experiments = [e for e in experiments if e.status.value == "completed"]

        if not completed_experiments:
            st.warning("No completed experiments found.")
            return

        # Select experiment for detailed analysis
        exp_options = {
            f"{exp.config.name} ({exp.experiment_id[:8]})": exp for exp in completed_experiments
        }

        selected_exp_name = st.selectbox(
            "Select Experiment for Detailed Analysis", list(exp_options.keys())
        )

        if not selected_exp_name:
            return

        selected_exp = exp_options[selected_exp_name]

        self._display_experiment_details(selected_exp)
        self._display_confusion_matrix(selected_exp)
        self._display_feature_importance(selected_exp)
        self._display_prediction_examples(selected_exp)

    def _display_experiment_details(self, experiment):
        """Display experiment configuration and metrics"""
        col1, col2 = st.columns(2)

        with col1:
            st.write("**Experiment Configuration**")
            st.json(
                {
                    "name": experiment.config.name,
                    "model_type": experiment.config.model_type,
                    "features": [f.value for f in experiment.config.features],
                    "model_params": experiment.config.model_params,
                }
            )

        with col2:
            st.write("**Performance Metrics**")
            if experiment.test_metrics:
                for metric, value in experiment.test_metrics.items():
                    st.metric(metric.title(), f"{value:.4f}")

    def _display_confusion_matrix(self, experiment):
        """Display confusion matrix if available"""
        if experiment.confusion_matrix:
            st.write("**Confusion Matrix**")
            cm = np.array(experiment.confusion_matrix)

            fig = px.imshow(cm, text_auto=True, aspect="auto", title="Confusion Matrix")
            st.plotly_chart(fig, use_container_width=True)

    def _display_feature_importance(self, experiment):
        """Display feature importance if available"""
        if experiment.feature_importance:
            st.write("**Feature Importance**")

            importance_data = sorted(
                experiment.feature_importance.items(), key=lambda x: x[1], reverse=True
            )[:20]

            features, importances = zip(*importance_data)

            fig = px.bar(
                x=list(importances),
                y=list(features),
                orientation="h",
                title="Top 20 Feature Importances",
            )
            fig.update_layout(height=600)
            st.plotly_chart(fig, use_container_width=True)

    def _display_prediction_examples(self, experiment):
        """Display prediction examples if available"""
        if experiment.prediction_examples:
            st.write("**Prediction Examples**")

            examples_df = pd.DataFrame(experiment.prediction_examples)

            # Separate correct and incorrect predictions
            correct_examples = examples_df[examples_df["correct"] == True]
            incorrect_examples = examples_df[examples_df["correct"] == False]

            col1, col2 = st.columns(2)

            with col1:
                st.write("**Correct Predictions**")
                if not correct_examples.empty:
                    st.dataframe(
                        correct_examples[["name", "true_label", "predicted_label"]],
                        use_container_width=True,
                    )

            with col2:
                st.write("**Incorrect Predictions**")
                if not incorrect_examples.empty:
                    st.dataframe(
                        incorrect_examples[["name", "true_label", "predicted_label"]],
                        use_container_width=True,
                    )