From 912d518106e5f5fd77d436df7e874f23bd8b37cd Mon Sep 17 00:00:00 2001
From: bernard-ng
Date: Mon, 29 Sep 2025 22:52:08 +0200
Subject: [PATCH] feat: support gpu

---
 research/neural_network_model.py | 42 ++++++++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/research/neural_network_model.py b/research/neural_network_model.py
index 2f28afd..64f6445 100644
--- a/research/neural_network_model.py
+++ b/research/neural_network_model.py
@@ -71,6 +71,8 @@ class NeuralNetworkModel(BaseModel):
         # Extract and prepare features (this will also initialize tokenizer)
         features_df = self.feature_extractor.extract_features(X)
         X_prepared = self.prepare_features(features_df)
+        # Sanitize any out-of-range indices to avoid embedding scatter errors
+        X_prepared = self._sanitize_sequences(X_prepared)
 
         # Encode labels
         if self.label_encoder is None:
@@ -113,6 +115,44 @@ class NeuralNetworkModel(BaseModel):
         self.is_fitted = True
         return self
 
+    def _sanitize_sequences(self, sequences: np.ndarray) -> np.ndarray:
+        """Clamp invalid token indices to OOV and ensure int32 dtype.
+
+        This prevents rare cases where malformed inputs or dtype issues introduce
+        large or negative indices which can trigger TensorScatterUpdate errors
+        during embedding updates on GPU.
+        """
+        try:
+            if sequences is None:
+                return sequences
+            arr = np.asarray(sequences)
+            # Ensure integer dtype for embedding lookups
+            if not np.issubdtype(arr.dtype, np.integer):
+                arr = arr.astype(np.int64, copy=False)
+
+            if self.tokenizer is not None and hasattr(self.tokenizer, "word_index"):
+                # Use the actual max index present in the tokenizer mapping
+                if self.tokenizer.word_index:
+                    max_idx = max(self.tokenizer.word_index.values())
+                else:
+                    max_idx = 0
+                # OOV token index if available, else fall back to 1
+                oov_index = self.tokenizer.word_index.get(
+                    getattr(self.tokenizer, "oov_token", ""), 1
+                )
+                # Keep zeros (padding) untouched; clamp negatives and > max_idx to OOV
+                invalid_mask = (arr < 0) | (arr > max_idx)
+                # Avoid turning zeros into OOV
+                invalid_mask &= (arr != 0)
+                if invalid_mask.any():
+                    arr[invalid_mask] = oov_index
+
+            # Use int32 for TF embedding ops compatibility
+            return arr.astype(np.int32, copy=False)
+        except Exception as e:
+            logging.debug(f"Sequence sanitization skipped due to: {e}")
+            return sequences
+
     def _collect_text_corpus(self, X: pd.DataFrame) -> List[str]:
         """Combine configured textual features into one string per record."""
 
@@ -165,6 +205,7 @@ class NeuralNetworkModel(BaseModel):
             pass
         features_df = self.feature_extractor.extract_features(X)
         X_prepared = self.prepare_features(features_df)
+        X_prepared = self._sanitize_sequences(X_prepared)
         y_encoded = self.label_encoder.transform(y)
 
         cv = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=self.config.random_seed)
@@ -264,6 +305,7 @@ class NeuralNetworkModel(BaseModel):
         # Prepare features and get vocabulary size
         features_df = self.feature_extractor.extract_features(X)
         X_prepared = self.prepare_features(features_df)
+        X_prepared = self._sanitize_sequences(X_prepared)
         y_encoded = self.label_encoder.transform(y)
 
         vocab_size = len(self.tokenizer.word_index) + 1 if self.tokenizer else 1000
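
For reviewers, a minimal standalone sketch of the clamping rule that the new
_sanitize_sequences helper applies, assuming a Keras-style tokenizer exposing
word_index and oov_token. The StubTokenizer and toy values below are
illustrative only, not part of the patch:

    import numpy as np

    class StubTokenizer:
        # Hypothetical stand-in for a Keras-style tokenizer
        oov_token = "<OOV>"
        word_index = {"<OOV>": 1, "hello": 2, "world": 3}

    tok = StubTokenizer()
    max_idx = max(tok.word_index.values())      # 3
    oov = tok.word_index.get(tok.oov_token, 1)  # 1

    seqs = np.array([[0, 2, 3, 99], [0, -5, 1, 2]])
    # 99 exceeds the vocabulary and -5 is negative; both are clamped to the
    # OOV index, while the padding zeros are left untouched.
    mask = ((seqs < 0) | (seqs > max_idx)) & (seqs != 0)
    seqs[mask] = oov
    print(seqs.astype(np.int32))
    # [[0 2 3 1]
    #  [0 1 1 2]]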