From d3b3840278e7de674203ffb9204ae844563cf3ce Mon Sep 17 00:00:00 2001 From: bernard-ng Date: Mon, 6 Oct 2025 00:37:29 +0200 Subject: [PATCH] fix: nn models pad_sequences --- src/ners/research/models/bigru_model.py | 7 +- src/ners/research/models/cnn_model.py | 5 +- src/ners/research/models/lstm_model.py | 5 +- src/ners/research/models/transformer_model.py | 5 +- src/ners/research/neural_network_model.py | 23 ++ src/notebooks/experiments.ipynb | 254 ++++++++++++------ src/notebooks/names.ipynb | 4 +- 7 files changed, 211 insertions(+), 92 deletions(-) diff --git a/src/ners/research/models/bigru_model.py b/src/ners/research/models/bigru_model.py index daabfcf..b6c1bbf 100644 --- a/src/ners/research/models/bigru_model.py +++ b/src/ners/research/models/bigru_model.py @@ -23,6 +23,7 @@ class BiGRUModel(NeuralNetworkModel): input_dim=vocab_size, output_dim=params.get("embedding_dim", 64), mask_zero=True, + input_length=params.get("max_len", 6), ), # First recurrent block returns full sequences to allow stacking. # Moderate dropout + optional recurrent_dropout to reduce overfitting @@ -69,4 +70,8 @@ class BiGRUModel(NeuralNetworkModel): sequences = self.tokenizer.texts_to_sequences(text_data) max_len = self.config.model_params.get("max_len", 6) - return pad_sequences(sequences, maxlen=max_len, padding="post") + # Ensure padding and truncation are applied on the right to keep + # contiguous non-zero tokens on the left, matching RNN mask expectations. + return pad_sequences( + sequences, maxlen=max_len, padding="post", truncating="post" + ) diff --git a/src/ners/research/models/cnn_model.py b/src/ners/research/models/cnn_model.py index c7097a2..011011d 100644 --- a/src/ners/research/models/cnn_model.py +++ b/src/ners/research/models/cnn_model.py @@ -83,4 +83,7 @@ class CNNModel(NeuralNetworkModel): "max_len", 20 ) # Longer for character level - return pad_sequences(sequences, maxlen=max_len, padding="post") + # Right-side padding and truncation ensure contiguous non-zero tokens on the left + return pad_sequences( + sequences, maxlen=max_len, padding="post", truncating="post" + ) diff --git a/src/ners/research/models/lstm_model.py b/src/ners/research/models/lstm_model.py index 792c7a4..78b878f 100644 --- a/src/ners/research/models/lstm_model.py +++ b/src/ners/research/models/lstm_model.py @@ -68,4 +68,7 @@ class LSTMModel(NeuralNetworkModel): sequences = self.tokenizer.texts_to_sequences(text_data) max_len = self.config.model_params.get("max_len", 6) - return pad_sequences(sequences, maxlen=max_len, padding="post") + # Right-side padding and truncation to preserve contiguous non-zero tokens + return pad_sequences( + sequences, maxlen=max_len, padding="post", truncating="post" + ) diff --git a/src/ners/research/models/transformer_model.py b/src/ners/research/models/transformer_model.py index 1a9876f..2a581d7 100644 --- a/src/ners/research/models/transformer_model.py +++ b/src/ners/research/models/transformer_model.py @@ -88,4 +88,7 @@ class TransformerModel(NeuralNetworkModel): sequences = self.tokenizer.texts_to_sequences(text_data) max_len = int(self.config.model_params.get("max_len", 6)) - return pad_sequences(sequences, maxlen=max_len, padding="post") + # Right-side padding and truncation for consistent masking/shape + return pad_sequences( + sequences, maxlen=max_len, padding="post", truncating="post" + ) diff --git a/src/ners/research/neural_network_model.py b/src/ners/research/neural_network_model.py index 6181367..db68419 100644 --- a/src/ners/research/neural_network_model.py +++ b/src/ners/research/neural_network_model.py @@ -149,6 +149,29 @@ class NeuralNetworkModel(BaseModel): if invalid_mask.any(): arr[invalid_mask] = oov_index + # Enforce strictly right-padded masks for RNN/cuDNN compatibility. + # Any zero appearing before the last non-zero in a sequence will be + # replaced with the OOV index so the mask remains contiguous True->False. + try: + nz = arr != 0 # non-padding tokens + if nz.ndim == 2 and arr.shape[1] > 0: + # Identify rows that have at least one non-zero + has_nz = nz.any(axis=1) + # Compute last non-zero position per row; if none, set to -1 + indices = np.arange(arr.shape[1], dtype=np.int64) + # Max of indices where nz is True gives last non-zero + last_pos = (nz * indices).max(axis=1) + last_pos = np.where(has_nz, last_pos, -1) + # Broadcast to mark the left region up to last non-zero (inclusive) + left_region = indices <= last_pos[:, None] + # Zeros inside the left region are invalid padding -> set to OOV + zero_inside = (~nz) & left_region + if zero_inside.any(): + arr[zero_inside] = oov_index + except Exception: + # Best-effort; skip if any unexpected broadcasting issue occurs + pass + # Use int32 for TF embedding ops compatibility return arr.astype(np.int32, copy=False) except Exception as e: diff --git a/src/notebooks/experiments.ipynb b/src/notebooks/experiments.ipynb index 73348c5..e884ffa 100644 --- a/src/notebooks/experiments.ipynb +++ b/src/notebooks/experiments.ipynb @@ -73,60 +73,82 @@ " cv = exp.get(\"cv_metrics\", {}) or {}\n", "\n", " cm = exp.get(\"confusion_matrix\")\n", - " tn=fp=fn=tp=np.nan\n", - " if isinstance(cm, list) and len(cm)==2 and all(isinstance(r, list) and len(r)==2 for r in cm):\n", + " tn = fp = fn = tp = np.nan\n", + " if (\n", + " isinstance(cm, list)\n", + " and len(cm) == 2\n", + " and all(isinstance(r, list) and len(r) == 2 for r in cm)\n", + " ):\n", " # By inspection of the provided metrics, mapping is:\n", " # rows = true [f, m]; cols = pred [f, m]\n", - " tn, fp = cm[0][0], cm[0][1] # true negatives and false positives for positive class 'm'\n", + " tn, fp = (\n", + " cm[0][0],\n", + " cm[0][1],\n", + " ) # true negatives and false positives for positive class 'm'\n", " fn, tp = cm[1][0], cm[1][1]\n", "\n", " # Derived metrics from confusion matrix (where present)\n", - " def safe_div(a,b): \n", - " return float(a)/float(b) if (b not in (0, None) and not pd.isna(b)) else np.nan\n", + " def safe_div(a, b):\n", + " return (\n", + " float(a) / float(b) if (b not in (0, None) and not pd.isna(b)) else np.nan\n", + " )\n", "\n", - " sensitivity = safe_div(tp, tp+fn) # TPR for 'm'\n", - " specificity = safe_div(tn, tn+fp) # TNR for 'm'\n", + " sensitivity = safe_div(tp, tp + fn) # TPR for 'm'\n", + " specificity = safe_div(tn, tn + fp) # TNR for 'm'\n", " balanced_acc = np.nanmean([sensitivity, specificity])\n", - " mcc_num = (tp*tn - fp*fn)\n", - " mcc_den = sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)) if all(x==x for x in [tp+fp, tp+fn, tn+fp, tn+fn]) else np.nan\n", + " mcc_num = tp * tn - fp * fn\n", + " mcc_den = (\n", + " sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))\n", + " if all(x == x for x in [tp + fp, tp + fn, tn + fp, tn + fn])\n", + " else np.nan\n", + " )\n", " mcc = safe_div(mcc_num, mcc_den)\n", "\n", " n_test = exp.get(\"test_size\") or np.nansum([tn, fp, fn, tp])\n", " test_acc = te.get(\"accuracy\", np.nan)\n", " # 95% CI for accuracy via normal approximation (ok for n=2000)\n", - " if pd.notna(test_acc) and pd.notna(n_test) and n_test>0:\n", - " se = np.sqrt(test_acc*(1-test_acc)/n_test)\n", - " acc_ci_lo = test_acc - 1.96*se\n", - " acc_ci_hi = test_acc + 1.96*se\n", + " if pd.notna(test_acc) and pd.notna(n_test) and n_test > 0:\n", + " se = np.sqrt(test_acc * (1 - test_acc) / n_test)\n", + " acc_ci_lo = test_acc - 1.96 * se\n", + " acc_ci_hi = test_acc + 1.96 * se\n", " else:\n", " acc_ci_lo = acc_ci_hi = np.nan\n", "\n", - " rows.append({\n", - " \"experiment_id\": exp_id,\n", - " \"model\": name or model_type,\n", - " \"model_family\": (model_type or \"\").upper(),\n", - " \"feature_set\": features,\n", - " \"train_accuracy\": tr.get(\"accuracy\", np.nan),\n", - " \"test_accuracy\": test_acc,\n", - " \"cv_accuracy_mean\": cv.get(\"accuracy\", np.nan),\n", - " \"cv_accuracy_std\": cv.get(\"accuracy_std\", np.nan),\n", - " \"train_f1\": tr.get(\"f1\", np.nan),\n", - " \"test_f1\": te.get(\"f1\", np.nan),\n", - " \"cv_f1_mean\": cv.get(\"f1\", np.nan),\n", - " \"cv_f1_std\": cv.get(\"f1_std\", np.nan),\n", - " \"TP\": tp, \"FP\": fp, \"TN\": tn, \"FN\": fn,\n", - " \"sensitivity_TPR_m\": sensitivity,\n", - " \"specificity_TNR_m\": specificity,\n", - " \"balanced_accuracy\": balanced_acc,\n", - " \"MCC\": mcc,\n", - " \"n_test\": n_test,\n", - " \"acc_95ci_lo\": acc_ci_lo,\n", - " \"acc_95ci_hi\": acc_ci_hi,\n", - " \"train_minus_test_gap\": (tr.get(\"accuracy\", np.nan) - test_acc) if pd.notna(tr.get(\"accuracy\", np.nan)) and pd.notna(test_acc) else np.nan,\n", - " \"test_minus_cv_gap\": (test_acc - cv.get(\"accuracy\", np.nan)) if pd.notna(test_acc) and pd.notna(cv.get(\"accuracy\", np.nan)) else np.nan,\n", - " \"start_time\": exp.get(\"start_time\"),\n", - " \"end_time\": exp.get(\"end_time\")\n", - " })\n", + " rows.append(\n", + " {\n", + " \"experiment_id\": exp_id,\n", + " \"model\": name or model_type,\n", + " \"model_family\": (model_type or \"\").upper(),\n", + " \"feature_set\": features,\n", + " \"train_accuracy\": tr.get(\"accuracy\", np.nan),\n", + " \"test_accuracy\": test_acc,\n", + " \"cv_accuracy_mean\": cv.get(\"accuracy\", np.nan),\n", + " \"cv_accuracy_std\": cv.get(\"accuracy_std\", np.nan),\n", + " \"train_f1\": tr.get(\"f1\", np.nan),\n", + " \"test_f1\": te.get(\"f1\", np.nan),\n", + " \"cv_f1_mean\": cv.get(\"f1\", np.nan),\n", + " \"cv_f1_std\": cv.get(\"f1_std\", np.nan),\n", + " \"TP\": tp,\n", + " \"FP\": fp,\n", + " \"TN\": tn,\n", + " \"FN\": fn,\n", + " \"sensitivity_TPR_m\": sensitivity,\n", + " \"specificity_TNR_m\": specificity,\n", + " \"balanced_accuracy\": balanced_acc,\n", + " \"MCC\": mcc,\n", + " \"n_test\": n_test,\n", + " \"acc_95ci_lo\": acc_ci_lo,\n", + " \"acc_95ci_hi\": acc_ci_hi,\n", + " \"train_minus_test_gap\": (tr.get(\"accuracy\", np.nan) - test_acc)\n", + " if pd.notna(tr.get(\"accuracy\", np.nan)) and pd.notna(test_acc)\n", + " else np.nan,\n", + " \"test_minus_cv_gap\": (test_acc - cv.get(\"accuracy\", np.nan))\n", + " if pd.notna(test_acc) and pd.notna(cv.get(\"accuracy\", np.nan))\n", + " else np.nan,\n", + " \"start_time\": exp.get(\"start_time\"),\n", + " \"end_time\": exp.get(\"end_time\"),\n", + " }\n", + " )\n", "\n", "df = pd.DataFrame(rows)" ] @@ -139,23 +161,53 @@ "outputs": [], "source": [ "# Clean and order categorical fields\n", - "df[\"feature_set\"] = df[\"feature_set\"].replace({\"full_name\":\"Full name\",\"native_name\":\"Native\",\"surname\":\"Surname\"})\n", - "order_features = [\"Full name\",\"Surname\",\"Native\"]\n", - "df[\"feature_set\"] = pd.Categorical(df[\"feature_set\"], categories=order_features, ordered=True)\n", + "df[\"feature_set\"] = df[\"feature_set\"].replace(\n", + " {\"full_name\": \"Full name\", \"native_name\": \"Native\", \"surname\": \"Surname\"}\n", + ")\n", + "order_features = [\"Full name\", \"Surname\", \"Native\"]\n", + "df[\"feature_set\"] = pd.Categorical(\n", + " df[\"feature_set\"], categories=order_features, ordered=True\n", + ")\n", "\n", - "order_family = [\"LOGISTIC_REGRESSION\",\"LIGHTGBM\",\"LSTM\",\"CNN\",\"BIGRU\", \"RANDOM_FOREST\", \"TRANSFORMER\", \"NAIVE_BAYES\", \"XGBOOST\"]\n", - "df[\"model_family\"] = pd.Categorical(df[\"model_family\"], categories=order_family, ordered=True)\n", + "order_family = [\n", + " \"LOGISTIC_REGRESSION\",\n", + " \"LIGHTGBM\",\n", + " \"LSTM\",\n", + " \"CNN\",\n", + " \"BIGRU\",\n", + " \"RANDOM_FOREST\",\n", + " \"TRANSFORMER\",\n", + " \"NAIVE_BAYES\",\n", + " \"XGBOOST\",\n", + "]\n", + "df[\"model_family\"] = pd.Categorical(\n", + " df[\"model_family\"], categories=order_family, ordered=True\n", + ")\n", "\n", "# Summary table (subset of most relevant columns)\n", "summary_cols = [\n", - " \"experiment_id\",\"model_family\",\"feature_set\",\n", - " \"train_accuracy\",\"test_accuracy\",\"cv_accuracy_mean\",\"cv_accuracy_std\",\n", - " \"acc_95ci_lo\",\"acc_95ci_hi\",\n", - " \"balanced_accuracy\",\"MCC\",\n", - " \"train_minus_test_gap\",\"test_minus_cv_gap\",\n", - " \"n_test\"\n", + " \"experiment_id\",\n", + " \"model_family\",\n", + " \"feature_set\",\n", + " \"train_accuracy\",\n", + " \"test_accuracy\",\n", + " \"cv_accuracy_mean\",\n", + " \"cv_accuracy_std\",\n", + " \"acc_95ci_lo\",\n", + " \"acc_95ci_hi\",\n", + " \"balanced_accuracy\",\n", + " \"MCC\",\n", + " \"train_minus_test_gap\",\n", + " \"test_minus_cv_gap\",\n", + " \"n_test\",\n", "]\n", - "summary = df[summary_cols].sort_values([\"model_family\",\"feature_set\",\"test_accuracy\"], ascending=[True, True, False]).reset_index(drop=True)\n", + "summary = (\n", + " df[summary_cols]\n", + " .sort_values(\n", + " [\"model_family\", \"feature_set\", \"test_accuracy\"], ascending=[True, True, False]\n", + " )\n", + " .reset_index(drop=True)\n", + ")\n", "\n", "# Display the master summary table\n", "display(summary)" @@ -171,25 +223,37 @@ "# Build a pivot for plotting\n", "plot_df = df.dropna(subset=[\"test_accuracy\"]).copy()\n", "# Prepare positions\n", - "families = [f for f in order_family if f in plot_df[\"model_family\"].astype(str).unique()]\n", - "features = [f for f in order_features if f in plot_df[\"feature_set\"].astype(str).unique()]\n", + "families = [\n", + " f for f in order_family if f in plot_df[\"model_family\"].astype(str).unique()\n", + "]\n", + "features = [\n", + " f for f in order_features if f in plot_df[\"feature_set\"].astype(str).unique()\n", + "]\n", "\n", "# Bar positions\n", "x = np.arange(len(families))\n", - "width = 0.8 / max(1,len(features)) # total width split by features\n", + "width = 0.8 / max(1, len(features)) # total width split by features\n", "\n", - "fig1 = plt.figure(figsize=(10,6))\n", + "fig1 = plt.figure(figsize=(10, 6))\n", "for i, feat in enumerate(features):\n", - " sub = plot_df[plot_df[\"feature_set\"].astype(str)==feat]\n", + " sub = plot_df[plot_df[\"feature_set\"].astype(str) == feat]\n", " # Align to families\n", " y = []\n", " yerr = [[], []] # lower and upper errors for asymmetric CI\n", " for fam in families:\n", - " row = sub[sub[\"model_family\"].astype(str)==fam]\n", + " row = sub[sub[\"model_family\"].astype(str) == fam]\n", " if len(row):\n", " val = float(row.iloc[0][\"test_accuracy\"])\n", - " lo = float(row.iloc[0][\"acc_95ci_lo\"]) if pd.notna(row.iloc[0][\"acc_95ci_lo\"]) else np.nan\n", - " hi = float(row.iloc[0][\"acc_95ci_hi\"]) if pd.notna(row.iloc[0][\"acc_95ci_hi\"]) else np.nan\n", + " lo = (\n", + " float(row.iloc[0][\"acc_95ci_lo\"])\n", + " if pd.notna(row.iloc[0][\"acc_95ci_lo\"])\n", + " else np.nan\n", + " )\n", + " hi = (\n", + " float(row.iloc[0][\"acc_95ci_hi\"])\n", + " if pd.notna(row.iloc[0][\"acc_95ci_hi\"])\n", + " else np.nan\n", + " )\n", " else:\n", " val, lo, hi = np.nan, np.nan, np.nan\n", " y.append(val)\n", @@ -201,7 +265,14 @@ " yerr[0].append(np.nan)\n", " yerr[1].append(np.nan)\n", "\n", - " plt.bar(x + i*width - (len(features)-1)*width/2, y, width, label=feat, yerr=yerr, capsize=4)\n", + " plt.bar(\n", + " x + i * width - (len(features) - 1) * width / 2,\n", + " y,\n", + " width,\n", + " label=feat,\n", + " yerr=yerr,\n", + " capsize=4,\n", + " )\n", "\n", "plt.xticks(x, families, rotation=0)\n", "plt.ylabel(\"Test accuracy\")\n", @@ -219,15 +290,15 @@ "metadata": {}, "outputs": [], "source": [ - "fig2 = plt.figure(figsize=(10,6))\n", + "fig2 = plt.figure(figsize=(10, 6))\n", "for i, feat in enumerate(features):\n", - " sub = plot_df[plot_df[\"feature_set\"].astype(str)==feat]\n", + " sub = plot_df[plot_df[\"feature_set\"].astype(str) == feat]\n", " y = []\n", " for fam in families:\n", - " row = sub[sub[\"model_family\"].astype(str)==fam]\n", + " row = sub[sub[\"model_family\"].astype(str) == fam]\n", " val = float(row.iloc[0][\"test_f1\"]) if len(row) else np.nan\n", " y.append(val)\n", - " plt.bar(x + i*width - (len(features)-1)*width/2, y, width, label=feat)\n", + " plt.bar(x + i * width - (len(features) - 1) * width / 2, y, width, label=feat)\n", "\n", "plt.xticks(x, families, rotation=0)\n", "plt.ylabel(\"Test F1\")\n", @@ -245,14 +316,18 @@ "metadata": {}, "outputs": [], "source": [ - "fig3 = plt.figure(figsize=(7,7))\n", + "fig3 = plt.figure(figsize=(7, 7))\n", "for feat in features:\n", - " sub = df[df[\"feature_set\"].astype(str)==feat]\n", + " sub = df[df[\"feature_set\"].astype(str) == feat]\n", " plt.scatter(sub[\"train_accuracy\"], sub[\"test_accuracy\"], label=feat)\n", "# y=x reference\n", - "lims = [min(df[\"train_accuracy\"].min(), df[\"test_accuracy\"].min())-0.02, max(df[\"train_accuracy\"].max(), df[\"test_accuracy\"].max())+0.02]\n", + "lims = [\n", + " min(df[\"train_accuracy\"].min(), df[\"test_accuracy\"].min()) - 0.02,\n", + " max(df[\"train_accuracy\"].max(), df[\"test_accuracy\"].max()) + 0.02,\n", + "]\n", "plt.plot(lims, lims, linestyle=\"--\")\n", - "plt.xlim(lims); plt.ylim(lims)\n", + "plt.xlim(lims)\n", + "plt.ylim(lims)\n", "plt.xlabel(\"Train accuracy\")\n", "plt.ylabel(\"Test accuracy\")\n", "plt.title(\"Overfitting analysis: Train vs Test accuracy\")\n", @@ -268,22 +343,24 @@ "metadata": {}, "outputs": [], "source": [ - "best_rows = df.sort_values(\"test_accuracy\", ascending=False).groupby(\"feature_set\").head(1)\n", + "best_rows = (\n", + " df.sort_values(\"test_accuracy\", ascending=False).groupby(\"feature_set\").head(1)\n", + ")\n", "for _, row in best_rows.iterrows():\n", " cm = np.array([[row[\"TN\"], row[\"FP\"]], [row[\"FN\"], row[\"TP\"]]], dtype=float)\n", " if np.isnan(cm).any():\n", " continue\n", - " fig = plt.figure(figsize=(5,5))\n", + " fig = plt.figure(figsize=(5, 5))\n", " im = plt.imshow(cm, interpolation=\"nearest\")\n", " plt.title(f\"Confusion Matrix — {row['model_family']} ({row['feature_set']})\")\n", - " plt.xticks([0,1], [\"Pred: f\",\"Pred: m\"])\n", - " plt.yticks([0,1], [\"True: f\",\"True: m\"])\n", + " plt.xticks([0, 1], [\"Pred: f\", \"Pred: m\"])\n", + " plt.yticks([0, 1], [\"True: f\", \"True: m\"])\n", " # Annotate counts and rates\n", " total = cm.sum()\n", " for i in range(2):\n", " for j in range(2):\n", - " val = cm[i,j]\n", - " plt.text(j, i, f\"{int(val)}\\n({val/total:.2%})\", ha=\"center\", va=\"center\")\n", + " val = cm[i, j]\n", + " plt.text(j, i, f\"{int(val)}\\n({val / total:.2%})\", ha=\"center\", va=\"center\")\n", " plt.colorbar(im, fraction=0.046, pad=0.04)\n", " plt.tight_layout()\n", " plt.show()" @@ -298,34 +375,37 @@ "source": [ "deltas = []\n", "for fam in families:\n", - " fam_rows = df[df[\"model_family\"].astype(str)==fam]\n", - " base = fam_rows[fam_rows[\"feature_set\"]==\"Native\"]\n", + " fam_rows = df[df[\"model_family\"].astype(str) == fam]\n", + " base = fam_rows[fam_rows[\"feature_set\"] == \"Native\"]\n", " if len(base):\n", " base_acc = float(base.iloc[0][\"test_accuracy\"])\n", - " for feat in [\"Full name\",\"Surname\"]:\n", - " tgt = fam_rows[fam_rows[\"feature_set\"]==feat]\n", + " for feat in [\"Full name\", \"Surname\"]:\n", + " tgt = fam_rows[fam_rows[\"feature_set\"] == feat]\n", " if len(tgt):\n", - " deltas.append({\n", - " \"model_family\": fam,\n", - " \"comparison\": f\"{feat} minus Native\",\n", - " \"delta_accuracy\": float(tgt.iloc[0][\"test_accuracy\"]) - base_acc\n", - " })\n", + " deltas.append(\n", + " {\n", + " \"model_family\": fam,\n", + " \"comparison\": f\"{feat} minus Native\",\n", + " \"delta_accuracy\": float(tgt.iloc[0][\"test_accuracy\"])\n", + " - base_acc,\n", + " }\n", + " )\n", "\n", "deltas_df = pd.DataFrame(deltas)\n", "display(deltas_df)\n", "\n", - "fig5 = plt.figure(figsize=(10,6))\n", + "fig5 = plt.figure(figsize=(10, 6))\n", "# Make bars grouped by model_family\n", "comp_types = deltas_df[\"comparison\"].unique().tolist() if not deltas_df.empty else []\n", "x2 = np.arange(len(families))\n", "width2 = 0.8 / max(1, len(comp_types))\n", "for i, comp in enumerate(comp_types):\n", - " sub = deltas_df[deltas_df[\"comparison\"]==comp]\n", + " sub = deltas_df[deltas_df[\"comparison\"] == comp]\n", " y = []\n", " for fam in families:\n", - " row = sub[sub[\"model_family\"]==fam]\n", + " row = sub[sub[\"model_family\"] == fam]\n", " y.append(float(row.iloc[0][\"delta_accuracy\"]) if len(row) else np.nan)\n", - " plt.bar(x2 + i*width2 - (len(comp_types)-1)*width2/2, y, width2, label=comp)\n", + " plt.bar(x2 + i * width2 - (len(comp_types) - 1) * width2 / 2, y, width2, label=comp)\n", "\n", "plt.xticks(x2, families)\n", "plt.axhline(0, linestyle=\"--\")\n", diff --git a/src/notebooks/names.ipynb b/src/notebooks/names.ipynb index 398f670..066005c 100644 --- a/src/notebooks/names.ipynb +++ b/src/notebooks/names.ipynb @@ -113,7 +113,9 @@ "df_name_categories.head(12)\n", "\n", "# save data\n", - "df_name_categories.to_csv(\"../../assets/identified_category_distribution.csv\", index=False)" + "df_name_categories.to_csv(\n", + " \"../../assets/identified_category_distribution.csv\", index=False\n", + ")" ] }, {