feat: support gpu
This commit is contained in:
@@ -71,6 +71,8 @@ class NeuralNetworkModel(BaseModel):
|
||||
# Extract and prepare features (this will also initialize tokenizer)
|
||||
features_df = self.feature_extractor.extract_features(X)
|
||||
X_prepared = self.prepare_features(features_df)
|
||||
# Sanitize any out-of-range indices to avoid embedding scatter errors
|
||||
X_prepared = self._sanitize_sequences(X_prepared)
|
||||
|
||||
# Encode labels
|
||||
if self.label_encoder is None:
|
||||
@@ -113,6 +115,44 @@ class NeuralNetworkModel(BaseModel):
|
||||
self.is_fitted = True
|
||||
return self
|
||||
|
||||
def _sanitize_sequences(self, sequences: np.ndarray) -> np.ndarray:
    """Clamp invalid token indices to OOV and ensure int32 dtype.

    This prevents rare cases where malformed inputs or dtype issues introduce
    large or negative indices which can trigger TensorScatterUpdate errors
    during embedding updates on GPU.

    Args:
        sequences: Array-like of token indices, or ``None``.

    Returns:
        An int32 array in which every entry that is negative or greater than
        the largest index in ``self.tokenizer.word_index`` has been replaced
        by the OOV index. Zeros are left untouched. The input is never
        modified in place. ``None`` is returned unchanged, and if anything
        goes wrong the original ``sequences`` object is returned as-is
        (sanitization is best-effort).
    """
    if sequences is None:
        return sequences

    try:
        arr = np.asarray(sequences)

        # Ensure integer dtype for embedding lookups
        if not np.issubdtype(arr.dtype, np.integer):
            arr = arr.astype(np.int64, copy=False)

        tokenizer = self.tokenizer
        if tokenizer is not None and hasattr(tokenizer, "word_index"):
            word_index = tokenizer.word_index

            # Use the actual max index present in the tokenizer mapping
            max_idx = max(word_index.values()) if word_index else 0

            # OOV token index if available, else fall back to 1
            oov_index = word_index.get(
                getattr(tokenizer, "oov_token", "<OOV>"), 1
            )
            # The replacement value must itself be a valid index; otherwise
            # "sanitizing" would just swap one out-of-range index for
            # another (e.g. empty vocabulary: max_idx == 0 but fallback 1).
            # Index 0 is always within [0, max_idx] when max_idx >= 0.
            if not 0 <= oov_index <= max_idx:
                oov_index = 0

            # Clamp negatives and > max_idx to OOV; never touch zeros
            invalid_mask = ((arr < 0) | (arr > max_idx)) & (arr != 0)
            if invalid_mask.any():
                # np.asarray may alias the caller's buffer; copy before
                # writing so the input is not mutated (and so read-only
                # inputs can still be sanitized).
                arr = arr.copy()
                arr[invalid_mask] = oov_index

        # Use int32 for TF embedding ops compatibility
        return arr.astype(np.int32, copy=False)
    except Exception as e:
        # Deliberate best-effort: fall back to the unsanitized input.
        logging.debug(f"Sequence sanitization skipped due to: {e}")
        return sequences
|
||||
|
||||
def _collect_text_corpus(self, X: pd.DataFrame) -> List[str]:
|
||||
"""Combine configured textual features into one string per record."""
|
||||
|
||||
@@ -165,6 +205,7 @@ class NeuralNetworkModel(BaseModel):
|
||||
pass
|
||||
features_df = self.feature_extractor.extract_features(X)
|
||||
X_prepared = self.prepare_features(features_df)
|
||||
X_prepared = self._sanitize_sequences(X_prepared)
|
||||
y_encoded = self.label_encoder.transform(y)
|
||||
|
||||
cv = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=self.config.random_seed)
|
||||
@@ -264,6 +305,7 @@ class NeuralNetworkModel(BaseModel):
|
||||
# Prepare features and get vocabulary size
|
||||
features_df = self.feature_extractor.extract_features(X)
|
||||
X_prepared = self.prepare_features(features_df)
|
||||
X_prepared = self._sanitize_sequences(X_prepared)
|
||||
y_encoded = self.label_encoder.transform(y)
|
||||
|
||||
vocab_size = len(self.tokenizer.word_index) + 1 if self.tokenizer else 1000
|
||||
|
||||
Reference in New Issue
Block a user