feat: support gpu
@@ -71,6 +71,8 @@ class NeuralNetworkModel(BaseModel):
         # Extract and prepare features (this will also initialize tokenizer)
         features_df = self.feature_extractor.extract_features(X)
         X_prepared = self.prepare_features(features_df)
+        # Sanitize any out-of-range indices to avoid embedding scatter errors
+        X_prepared = self._sanitize_sequences(X_prepared)

         # Encode labels
         if self.label_encoder is None:
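Why this call sits right after `prepare_features`: an embedding update is effectively a scatter-add into the weight matrix keyed by token index, so a single negative or out-of-vocabulary index can abort the whole training step. A rough numpy analogue of that failure mode (illustrative only; the real kernel is the TensorScatterUpdate op named in the docstring below):

import numpy as np

vocab_size, dim = 10, 4
weights = np.zeros((vocab_size, dim))
grads = np.ones((2, dim))

np.add.at(weights, [3, 7], grads)       # fine: both rows exist
try:
    np.add.at(weights, [3, 42], grads)  # 42 >= vocab_size
except IndexError as e:
    print(e)  # index 42 is out of bounds for axis 0 with size 10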
@@ -113,6 +115,44 @@ class NeuralNetworkModel(BaseModel):
         self.is_fitted = True
         return self

+    def _sanitize_sequences(self, sequences: np.ndarray) -> np.ndarray:
+        """Clamp invalid token indices to OOV and ensure int32 dtype.
+
+        This prevents rare cases where malformed inputs or dtype issues introduce
+        large or negative indices which can trigger TensorScatterUpdate errors
+        during embedding updates on GPU.
+        """
+        try:
+            if sequences is None:
+                return sequences
+            arr = np.asarray(sequences)
+            # Ensure integer dtype for embedding lookups
+            if not np.issubdtype(arr.dtype, np.integer):
+                arr = arr.astype(np.int64, copy=False)
+
+            if self.tokenizer is not None and hasattr(self.tokenizer, "word_index"):
+                # Use the actual max index present in the tokenizer mapping
+                if self.tokenizer.word_index:
+                    max_idx = max(self.tokenizer.word_index.values())
+                else:
+                    max_idx = 0
+                # OOV token index if available, else fall back to 1
+                oov_index = self.tokenizer.word_index.get(
+                    getattr(self.tokenizer, "oov_token", "<OOV>"), 1
+                )
+                # Keep zeros (padding) untouched; clamp negatives and > max_idx to OOV
+                invalid_mask = (arr < 0) | (arr > max_idx)
+                # Avoid turning zeros into OOV
+                invalid_mask &= (arr != 0)
+                if invalid_mask.any():
+                    arr[invalid_mask] = oov_index
+
+            # Use int32 for TF embedding ops compatibility
+            return arr.astype(np.int32, copy=False)
+        except Exception as e:
+            logging.debug(f"Sequence sanitization skipped due to: {e}")
+            return sequences
+
     def _collect_text_corpus(self, X: pd.DataFrame) -> List[str]:
         """Combine configured textual features into one string per record."""

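A quick behavioral check of the new method. `ToyTokenizer` and `Stub` are hypothetical stand-ins exposing only the two attributes `_sanitize_sequences` reads (`word_index` and `oov_token`), the unbound call is a test-only shortcut, and the import path is assumed:

import numpy as np
# Hypothetical import path; adjust to wherever NeuralNetworkModel lives.
from models.neural_network import NeuralNetworkModel

class ToyTokenizer:
    """Stand-in with just the attributes _sanitize_sequences relies on."""
    oov_token = "<OOV>"
    word_index = {"<OOV>": 1, "gpu": 2, "cuda": 3}

class Stub:
    tokenizer = ToyTokenizer()

seqs = np.array([[0, 2, 99], [0, -4, 3]], dtype=np.float32)  # wrong dtype on purpose
out = NeuralNetworkModel._sanitize_sequences(Stub(), seqs)
print(out, out.dtype)
# [[0 2 1]
#  [0 1 3]] int32 -- 99 and -4 clamped to <OOV> (index 1), padding zeros kept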
@@ -165,6 +205,7 @@ class NeuralNetworkModel(BaseModel):
             pass
         features_df = self.feature_extractor.extract_features(X)
         X_prepared = self.prepare_features(features_df)
+        X_prepared = self._sanitize_sequences(X_prepared)
         y_encoded = self.label_encoder.transform(y)

         cv = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=self.config.random_seed)
@@ -264,6 +305,7 @@ class NeuralNetworkModel(BaseModel):
         # Prepare features and get vocabulary size
         features_df = self.feature_extractor.extract_features(X)
         X_prepared = self.prepare_features(features_df)
+        X_prepared = self._sanitize_sequences(X_prepared)
         y_encoded = self.label_encoder.transform(y)

         vocab_size = len(self.tokenizer.word_index) + 1 if self.tokenizer else 1000
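One invariant worth noting here: the clamp bound `max_idx = max(word_index.values())` and `vocab_size = len(word_index) + 1` agree whenever the tokenizer assigns contiguous indices 1..N (the Keras Tokenizer default), so every sanitized index lands inside the embedding table. A small sketch of that check, using a made-up four-word vocabulary:

import numpy as np

word_index = {"<OOV>": 1, "gpu": 2, "cuda": 3, "tensor": 4}  # contiguous 1..4
vocab_size = len(word_index) + 1    # 5 rows: index 0 reserved for padding
max_idx = max(word_index.values())
assert max_idx == vocab_size - 1    # clamp bound == last embedding row

sanitized = np.array([[0, 2, 4], [0, 1, 3]], dtype=np.int32)
assert 0 <= sanitized.min() and sanitized.max() < vocab_size  # safe to embed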