feat: enhance logging and memory management across modules

This commit is contained in:
2025-08-13 23:09:05 +02:00
parent 47e52d130c
commit 9601c5e44d
48 changed files with 1004 additions and 773 deletions
+19 -12
View File
@@ -9,6 +9,7 @@ import plotly.express as px
import streamlit as st
from core.utils import get_data_file_path
from core.utils.data_loader import OPTIMIZED_DTYPES
from research.experiment.experiment_runner import ExperimentRunner
from research.experiment.experiment_tracker import ExperimentTracker
@@ -16,7 +17,9 @@ from research.experiment.experiment_tracker import ExperimentTracker
class Predictions:
"""Handles prediction interface"""
def __init__(self, config, experiment_tracker: ExperimentTracker, experiment_runner: ExperimentRunner):
def __init__(
self, config, experiment_tracker: ExperimentTracker, experiment_runner: ExperimentRunner
):
self.config = config
self.experiment_tracker = experiment_tracker
self.experiment_runner = experiment_runner
@@ -86,7 +89,9 @@ class Predictions:
confidence = self._get_prediction_confidence(model, input_df)
# Display results
self._display_single_prediction_results(prediction, confidence, experiment, name_input)
self._display_single_prediction_results(
prediction, confidence, experiment, name_input
)
except Exception as e:
st.error(f"Error making prediction: {e}")
@@ -114,8 +119,9 @@ class Predictions:
except:
return None
def _display_single_prediction_results(self, prediction: str, confidence: Optional[float],
experiment, name_input: str):
def _display_single_prediction_results(
self, prediction: str, confidence: Optional[float], experiment, name_input: str
):
"""Display single prediction results"""
col1, col2 = st.columns(2)
@@ -129,9 +135,7 @@ class Predictions:
# Additional info
st.info(f"Model used: {experiment.config.name}")
st.info(
f"Features used: {', '.join([f.value for f in experiment.config.features])}"
)
st.info(f"Features used: {', '.join([f.value for f in experiment.config.features])}")
def show_batch_prediction(self, experiment):
"""Show batch prediction interface"""
@@ -141,7 +145,7 @@ class Predictions:
if uploaded_file is not None:
try:
df = pd.read_csv(uploaded_file)
df = pd.read_csv(uploaded_file, dtype=OPTIMIZED_DTYPES)
st.write("**Uploaded Data Preview:**")
st.dataframe(df.head(), use_container_width=True)
@@ -296,13 +300,14 @@ class Predictions:
def _load_dataset(self, file_path: str) -> pd.DataFrame:
"""Load dataset with error handling"""
try:
return pd.read_csv(file_path)
return pd.read_csv(file_path, dtype=OPTIMIZED_DTYPES)
except Exception as e:
st.error(f"Error loading dataset: {e}")
return pd.DataFrame()
def _run_dataset_prediction(self, df: pd.DataFrame, experiment, sample_size: int,
compare_with_actual: bool):
def _run_dataset_prediction(
self, df: pd.DataFrame, experiment, sample_size: int, compare_with_actual: bool
):
"""Run dataset prediction and display results"""
with st.spinner("Running predictions..."):
# Sample data if requested
@@ -353,7 +358,9 @@ class Predictions:
with col2:
st.write("**Sample Incorrect Predictions**")
incorrect_sample = df_sample[~correct_mask][["name", "sex", "predicted_gender"]].head(10)
incorrect_sample = df_sample[~correct_mask][["name", "sex", "predicted_gender"]].head(
10
)
st.dataframe(incorrect_sample, use_container_width=True)
def _display_dataset_predictions(self, df_sample: pd.DataFrame):