"""Predictions interface for the Streamlit app"""

from datetime import datetime
from typing import Optional

import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st

from core.utils import get_data_file_path
from research.experiment.experiment_runner import ExperimentRunner
from research.experiment.experiment_tracker import ExperimentTracker

# Mapping from the model's raw class labels to human-readable labels.
# Labels not listed here are displayed as-is rather than coerced to a gender.
GENDER_LABELS = {"f": "Female", "m": "Male"}


class Predictions:
    """Handles prediction interface.

    Renders three prediction modes (single name, batch CSV upload, and
    prediction over a configured dataset) for a model picked from the
    completed experiments returned by ``experiment_tracker``.
    """

    def __init__(self, config, experiment_tracker: ExperimentTracker, experiment_runner: ExperimentRunner):
        self.config = config
        self.experiment_tracker = experiment_tracker
        self.experiment_runner = experiment_runner

    @staticmethod
    def _gender_label(prediction) -> str:
        """Map a raw prediction (e.g. ``"f"``/``"m"``) to a display label.

        Unknown labels are returned in their string form instead of being
        reported as "Male".
        """
        return GENDER_LABELS.get(prediction, str(prediction))

    @staticmethod
    def _build_model_options(experiments) -> dict:
        """Build ``{display label: experiment}`` for experiments with test metrics.

        Experiments without ``test_metrics`` are skipped. When two experiments
        would produce the same label (same name and accuracy), the later one
        gets its experiment id appended so that neither is silently dropped.
        """
        options = {}
        for exp in experiments:
            if not exp.test_metrics:
                continue
            label = f"{exp.config.name} (Acc: {exp.test_metrics.get('accuracy', 0):.3f})"
            if label in options:
                label = f"{label} [{exp.experiment_id}]"
            options[label] = exp
        return options

    def index(self):
        """Main predictions page: choose a trained model, then a prediction mode."""
        st.header("Make Predictions")

        # Only completed experiments that produced a saved model are usable.
        experiments = self.experiment_tracker.list_experiments()
        completed_experiments = [
            e for e in experiments if e.status.value == "completed" and e.model_path
        ]

        if not completed_experiments:
            st.warning("No trained models available. Please run some experiments first.")
            return

        # Model selection
        model_options = self._build_model_options(completed_experiments)
        if not model_options:
            # Completed models exist, but none has test metrics to label it with.
            st.warning("No trained models with test metrics available.")
            return

        selected_model_name = st.selectbox("Select Model", list(model_options.keys()))
        if not selected_model_name:
            return

        selected_experiment = model_options[selected_model_name]

        # Prediction modes
        prediction_mode = st.radio(
            "Prediction Mode", ["Single Name", "Batch Upload", "Dataset Prediction"]
        )

        if prediction_mode == "Single Name":
            self.show_single_prediction(selected_experiment)
        elif prediction_mode == "Batch Upload":
            self.show_batch_prediction(selected_experiment)
        elif prediction_mode == "Dataset Prediction":
            self.show_dataset_prediction(selected_experiment)

    def show_single_prediction(self, experiment):
        """Show single name prediction interface."""
        st.subheader("Single Name Prediction")

        name_input = st.text_input("Enter a name:", placeholder="e.g., Jean Baptiste Mukendi")
        # Whitespace-only input is treated as empty.
        name_input = (name_input or "").strip()

        if name_input and st.button("Predict Gender"):
            try:
                model = self.experiment_runner.load_experiment_model(experiment.experiment_id)
                if model is None:
                    st.error("Failed to load model")
                    return

                input_df = self._prepare_single_input(name_input)
                prediction = model.predict(input_df)[0]
                confidence = self._get_prediction_confidence(model, input_df)

                self._display_single_prediction_results(prediction, confidence, experiment, name_input)

            except Exception as e:
                # UI boundary: report any model/feature failure to the user.
                st.error(f"Error making prediction: {e}")

    def _prepare_single_input(self, name_input: str) -> pd.DataFrame:
        """Build a one-row DataFrame with the columns the batch path also fills in.

        ``words`` and ``length`` are derived from the name; the remaining
        columns get placeholder defaults.
        """
        return pd.DataFrame(
            {
                "name": [name_input],
                "words": [len(name_input.split())],
                "length": [len(name_input.replace(" ", ""))],
                "province": ["unknown"],  # Default values
                "identified_name": [None],
                "identified_surname": [None],
                "probable_native": [None],
                "probable_surname": [None],
            }
        )

    def _get_prediction_confidence(self, model, input_df: pd.DataFrame) -> Optional[float]:
        """Return the highest class probability for the first row, or None.

        This is best-effort: models without ``predict_proba`` (or whose call
        fails) yield None instead of an error.
        """
        try:
            probabilities = model.predict_proba(input_df)[0]
            return float(max(probabilities))
        except Exception:
            return None

    def _display_single_prediction_results(self, prediction: str, confidence: Optional[float], experiment,
                                           name_input: str):
        """Display the predicted label, optional confidence, and model info."""
        col1, col2 = st.columns(2)

        with col1:
            st.success(f"**Predicted Gender:** {self._gender_label(prediction)}")

        with col2:
            # A legitimate 0.0 confidence must still be shown.
            if confidence is not None:
                st.metric("Confidence", f"{confidence:.2%}")

        st.info(f"Model used: {experiment.config.name}")
        st.info(
            f"Features used: {', '.join([f.value for f in experiment.config.features])}"
        )

    def show_batch_prediction(self, experiment):
        """Show batch prediction interface for an uploaded CSV."""
        st.subheader("Batch Prediction")

        uploaded_file = st.file_uploader("Upload CSV file with names", type="csv")

        if uploaded_file is not None:
            try:
                df = pd.read_csv(uploaded_file)
                st.write("**Uploaded Data Preview:**")
                st.dataframe(df.head(), use_container_width=True)

                df = self._prepare_batch_data(df)

                if st.button("Run Batch Prediction"):
                    self._run_batch_prediction(df, experiment)

            except Exception as e:
                st.error(f"Error processing file: {e}")

    def _prepare_batch_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Ensure ``df`` has a ``name`` column plus the derived/default feature columns.

        If there is no ``name`` column, the user picks one and it is renamed.
        ``words``/``length`` are derived from ``name``; other missing columns
        are filled with None.
        """
        if "name" not in df.columns:
            name_column = st.selectbox("Select the name column:", df.columns)
            df = df.rename(columns={name_column: "name"})

        required_columns = [
            "words",
            "length",
            "province",
            "identified_name",
            "identified_surname",
            "probable_native",
            "probable_surname",
        ]

        for col in required_columns:
            if col in df.columns:
                continue
            if col == "words":
                df[col] = df["name"].str.split().str.len()
            elif col == "length":
                df[col] = df["name"].str.replace(" ", "").str.len()
            else:
                df[col] = None

        return df

    def _run_batch_prediction(self, df: pd.DataFrame, experiment):
        """Predict for every row of ``df`` (adding result columns) and display them."""
        with st.spinner("Making predictions..."):
            model = self.experiment_runner.load_experiment_model(experiment.experiment_id)
            if model is None:
                st.error("Failed to load model")
                return

            df["predicted_gender"] = model.predict(df)
            df["gender_label"] = df["predicted_gender"].map(self._gender_label)

            # Confidence is best-effort: not every model exposes predict_proba.
            try:
                probabilities = model.predict_proba(df)
                df["confidence"] = np.max(probabilities, axis=1)
            except Exception:
                df["confidence"] = None

            st.success("Predictions completed!")

            self._display_batch_results(df)

    def _display_batch_results(self, df: pd.DataFrame):
        """Show the prediction table, a CSV download, and summary stats."""
        result_columns = ["name", "gender_label", "predicted_gender"]
        # Omit confidence when the model could not provide any values.
        if "confidence" in df.columns and df["confidence"].notna().any():
            result_columns.append("confidence")

        st.dataframe(df[result_columns], use_container_width=True)

        csv = df.to_csv(index=False)
        st.download_button(
            label="Download Predictions",
            data=csv,
            file_name=f"predictions_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
            mime="text/csv",
        )

        self._display_batch_summary(df)

    def _display_batch_summary(self, df: pd.DataFrame):
        """Display counts per predicted label and a distribution pie chart."""
        st.subheader("Prediction Summary")

        gender_counts = df["gender_label"].value_counts()

        col1, col2, col3 = st.columns(3)
        # Cast numpy integers to plain int for st.metric.
        with col1:
            st.metric("Total Predictions", len(df))
        with col2:
            st.metric("Female", int(gender_counts.get("Female", 0)))
        with col3:
            st.metric("Male", int(gender_counts.get("Male", 0)))

        fig = px.pie(
            values=gender_counts.values,
            names=gender_counts.index,
            title="Predicted Gender Distribution",
        )
        st.plotly_chart(fig, use_container_width=True)

    def show_dataset_prediction(self, experiment):
        """Show the interface for applying the model to a configured dataset file."""
        st.subheader("Dataset Prediction")
        st.write("Apply the model to existing datasets")

        dataset_options = {
            "Featured Dataset": self.config.data.output_files["featured"],
            "Evaluation Dataset": self.config.data.output_files["evaluation"],
        }

        selected_dataset = st.selectbox("Select Dataset", list(dataset_options.keys()))
        file_path = get_data_file_path(dataset_options[selected_dataset], self.config)

        if not file_path.exists():
            st.warning(f"Dataset not found: {file_path}")
            return

        df = self._load_dataset(str(file_path))
        if df.empty:
            return

        st.write(f"Dataset contains {len(df):,} records")

        col1, col2 = st.columns(2)

        with col1:
            sample_size = int(
                st.number_input("Sample size (0 = all data)", 0, len(df), min(1000, len(df)))
            )

        with col2:
            # Comparison is only possible when ground-truth labels are present.
            compare_with_actual = False
            if "sex" in df.columns:
                compare_with_actual = st.checkbox("Compare with actual labels", value=True)

        if st.button("Run Dataset Prediction"):
            self._run_dataset_prediction(df, experiment, sample_size, compare_with_actual)

    def _load_dataset(self, file_path: str) -> pd.DataFrame:
        """Load dataset with error handling; returns an empty DataFrame on failure."""
        try:
            return pd.read_csv(file_path)
        except Exception as e:
            st.error(f"Error loading dataset: {e}")
            return pd.DataFrame()

    def _run_dataset_prediction(self, df: pd.DataFrame, experiment, sample_size: int, compare_with_actual: bool):
        """Predict on ``df`` (or a seeded random sample of it) and display results."""
        with st.spinner("Running predictions..."):
            if sample_size > 0:
                # Explicit copy so the column assignment below targets an
                # independent frame, not a view of ``df``.
                df_sample = df.sample(n=sample_size, random_state=42).copy()
            else:
                df_sample = df

            model = self.experiment_runner.load_experiment_model(experiment.experiment_id)
            if model is None:
                st.error("Failed to load model")
                return

            df_sample["predicted_gender"] = model.predict(df_sample)

            if compare_with_actual and "sex" in df_sample.columns:
                self._display_dataset_comparison(df_sample)
            else:
                self._display_dataset_predictions(df_sample)

    def _display_dataset_comparison(self, df_sample: pd.DataFrame):
        """Display accuracy, a labelled confusion matrix, and example rows.

        Rows with a missing ``sex`` value are excluded. They have no ground
        truth, and sklearn's confusion_matrix cannot mix NaN with string labels.
        """
        from sklearn.metrics import confusion_matrix

        labeled = df_sample.dropna(subset=["sex"])
        if labeled.empty:
            st.warning("No rows with actual labels to compare against.")
            return

        y_true = labeled["sex"]
        y_pred = labeled["predicted_gender"]
        correct_mask = y_true == y_pred

        st.metric("Accuracy on Selected Data", f"{correct_mask.mean():.4f}")

        # Fix the label order so that matrix rows/columns can be named.
        labels = sorted(set(y_true) | set(y_pred), key=str)
        tick_names = [str(label) for label in labels]
        cm = confusion_matrix(y_true, y_pred, labels=labels)
        fig = px.imshow(
            cm,
            x=tick_names,
            y=tick_names,
            labels={"x": "Predicted", "y": "Actual", "color": "Count"},
            text_auto=True,
            aspect="auto",
            title="Confusion Matrix",
        )
        st.plotly_chart(fig, use_container_width=True)

        shown_columns = [c for c in ("name", "sex", "predicted_gender") if c in labeled.columns]
        col1, col2 = st.columns(2)

        with col1:
            st.write("**Sample Correct Predictions**")
            st.dataframe(labeled[correct_mask][shown_columns].head(10), use_container_width=True)

        with col2:
            st.write("**Sample Incorrect Predictions**")
            st.dataframe(labeled[~correct_mask][shown_columns].head(10), use_container_width=True)

    def _display_dataset_predictions(self, df_sample: pd.DataFrame):
        """Display example predictions and the predicted label distribution."""
        st.write("**Sample Predictions**")
        shown_columns = [c for c in ("name", "predicted_gender") if c in df_sample.columns]
        st.dataframe(df_sample[shown_columns].head(20), use_container_width=True)

        gender_counts = df_sample["predicted_gender"].value_counts()
        fig = px.pie(
            values=gender_counts.values,
            names=gender_counts.index,
            title="Predicted Gender Distribution",
        )
        st.plotly_chart(fig, use_container_width=True)