feat: enhance logging and memory management across modules
This commit is contained in:
@@ -48,7 +48,7 @@ class NERNameModel:
|
||||
|
||||
logging.info(f"Loading training data from {data_path}")
|
||||
|
||||
with open(data_path, 'r', encoding='utf-8') as f:
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
raw_data = json.load(f)
|
||||
|
||||
# Validate and clean training data
|
||||
@@ -58,7 +58,9 @@ class NERNameModel:
|
||||
for i, item in enumerate(raw_data):
|
||||
try:
|
||||
if not isinstance(item, (list, tuple)) or len(item) != 2:
|
||||
logging.warning(f"Skipping invalid training example format at index {i}: {item}")
|
||||
logging.warning(
|
||||
f"Skipping invalid training example format at index {i}: {item}"
|
||||
)
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
@@ -83,20 +85,27 @@ class NERNameModel:
|
||||
# String format from tagger: "[(0, 6, 'NATIVE'), ...]"
|
||||
try:
|
||||
import ast
|
||||
|
||||
entities = ast.literal_eval(entities_raw)
|
||||
if not isinstance(entities, list):
|
||||
logging.warning(f"Parsed entities is not a list at index {i}: {entities}")
|
||||
logging.warning(
|
||||
f"Parsed entities is not a list at index {i}: {entities}"
|
||||
)
|
||||
skipped_count += 1
|
||||
continue
|
||||
except (ValueError, SyntaxError) as e:
|
||||
logging.warning(f"Failed to parse entity string at index {i}: {entities_raw} ({e})")
|
||||
logging.warning(
|
||||
f"Failed to parse entity string at index {i}: {entities_raw} ({e})"
|
||||
)
|
||||
skipped_count += 1
|
||||
continue
|
||||
elif isinstance(entities_raw, list):
|
||||
# Already in list format
|
||||
entities = entities_raw
|
||||
else:
|
||||
logging.warning(f"Skipping invalid entities format at index {i}: {entities_raw}")
|
||||
logging.warning(
|
||||
f"Skipping invalid entities format at index {i}: {entities_raw}"
|
||||
)
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
@@ -110,16 +119,20 @@ class NERNameModel:
|
||||
start, end, label = entity
|
||||
|
||||
# Validate entity components
|
||||
if (not isinstance(start, int) or not isinstance(end, int) or
|
||||
not isinstance(label, str) or start >= end or
|
||||
start < 0 or end > len(text)):
|
||||
if (
|
||||
not isinstance(start, int)
|
||||
or not isinstance(end, int)
|
||||
or not isinstance(label, str)
|
||||
or start >= end
|
||||
or start < 0
|
||||
or end > len(text)
|
||||
):
|
||||
logging.warning(f"Skipping invalid entity bounds in '{text}': {entity}")
|
||||
continue
|
||||
|
||||
# Check for overlaps with already validated entities
|
||||
has_overlap = any(
|
||||
start < v_end and end > v_start
|
||||
for v_start, v_end, _ in valid_entities
|
||||
start < v_end and end > v_start for v_start, v_end, _ in valid_entities
|
||||
)
|
||||
|
||||
if has_overlap:
|
||||
@@ -128,8 +141,10 @@ class NERNameModel:
|
||||
|
||||
# Validate that the span doesn't contain spaces (matching tagger validation)
|
||||
span_text = text[start:end]
|
||||
if not span_text or span_text != span_text.strip() or ' ' in span_text:
|
||||
logging.warning(f"Skipping entity with spaces in '{text}': {entity} -> '{span_text}'")
|
||||
if not span_text or span_text != span_text.strip() or " " in span_text:
|
||||
logging.warning(
|
||||
f"Skipping entity with spaces in '{text}': {entity} -> '{span_text}'"
|
||||
)
|
||||
continue
|
||||
|
||||
valid_entities.append((start, end, label))
|
||||
@@ -148,7 +163,9 @@ class NERNameModel:
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
logging.info(f"Loaded {len(valid_data)} valid training examples, skipped {skipped_count} invalid ones")
|
||||
logging.info(
|
||||
f"Loaded {len(valid_data)} valid training examples, skipped {skipped_count} invalid ones"
|
||||
)
|
||||
|
||||
if not valid_data:
|
||||
raise ValueError("No valid training examples found in the data")
|
||||
@@ -156,15 +173,17 @@ class NERNameModel:
|
||||
return valid_data
|
||||
|
||||
def train(
|
||||
self,
|
||||
data: List[Tuple[str, Dict]],
|
||||
epochs: int = 5,
|
||||
batch_size: int = 16,
|
||||
dropout_rate: float = 0.2,
|
||||
self,
|
||||
data: List[Tuple[str, Dict]],
|
||||
epochs: int = 5,
|
||||
batch_size: int = 16,
|
||||
dropout_rate: float = 0.2,
|
||||
) -> None:
|
||||
"""Train the NER model"""
|
||||
logging.info(f"Starting NER training with {len(data)} examples")
|
||||
logging.info(f"Training parameters: epochs={epochs}, batch_size={batch_size}, dropout={dropout_rate}")
|
||||
logging.info(
|
||||
f"Training parameters: epochs={epochs}, batch_size={batch_size}, dropout={dropout_rate}"
|
||||
)
|
||||
|
||||
if self.nlp is None:
|
||||
raise ValueError("Model not initialized. Call create_blank_model() first.")
|
||||
@@ -184,16 +203,15 @@ class NERNameModel:
|
||||
doc = self.nlp.make_doc(text)
|
||||
example = Example.from_dict(doc, annotations)
|
||||
examples.append(example)
|
||||
logging.info(f"Training example: {text[:30]}... with entities {annotations.get('entities', [])}")
|
||||
logging.info(
|
||||
f"Training example: {text[:30]}... with entities {annotations.get('entities', [])}"
|
||||
)
|
||||
|
||||
# Train in batches
|
||||
batches = minibatch(examples, size=batch_size)
|
||||
for batch in batches:
|
||||
self.nlp.update(
|
||||
batch,
|
||||
losses=losses,
|
||||
drop=dropout_rate,
|
||||
sgd=self.nlp.create_optimizer()
|
||||
batch, losses=losses, drop=dropout_rate, sgd=self.nlp.create_optimizer()
|
||||
)
|
||||
logging.info(f"Training batch with {len(batch)} examples, current losses: {losses}")
|
||||
|
||||
@@ -208,7 +226,7 @@ class NERNameModel:
|
||||
"training_examples": len(data),
|
||||
"loss_history": losses_history,
|
||||
"batch_size": batch_size,
|
||||
"dropout_rate": dropout_rate
|
||||
"dropout_rate": dropout_rate,
|
||||
}
|
||||
|
||||
logging.info(f"Training completed. Final loss: {self.training_stats['final_loss']:.4f}")
|
||||
@@ -225,7 +243,10 @@ class NERNameModel:
|
||||
predicted_entities = 0
|
||||
actual_entities = 0
|
||||
|
||||
entity_stats = {"NATIVE": {"tp": 0, "fp": 0, "fn": 0}, "SURNAME": {"tp": 0, "fp": 0, "fn": 0}}
|
||||
entity_stats = {
|
||||
"NATIVE": {"tp": 0, "fp": 0, "fn": 0},
|
||||
"SURNAME": {"tp": 0, "fp": 0, "fn": 0},
|
||||
}
|
||||
|
||||
for text, annotations in test_data:
|
||||
# Get actual entities
|
||||
@@ -259,7 +280,9 @@ class NERNameModel:
|
||||
# Calculate overall metrics
|
||||
precision = correct_entities / predicted_entities if predicted_entities > 0 else 0
|
||||
recall = correct_entities / actual_entities if actual_entities > 0 else 0
|
||||
f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
|
||||
f1_score = (
|
||||
2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
|
||||
)
|
||||
|
||||
# Calculate per-label metrics
|
||||
label_metrics = {}
|
||||
@@ -268,14 +291,16 @@ class NERNameModel:
|
||||
label_precision = tp / (tp + fp) if (tp + fp) > 0 else 0
|
||||
label_recall = tp / (tp + fn) if (tp + fn) > 0 else 0
|
||||
label_f1 = (
|
||||
2 * (label_precision * label_recall) / (label_precision + label_recall)) \
|
||||
if (label_precision + label_recall) > 0 else 0
|
||||
(2 * (label_precision * label_recall) / (label_precision + label_recall))
|
||||
if (label_precision + label_recall) > 0
|
||||
else 0
|
||||
)
|
||||
|
||||
label_metrics[label] = {
|
||||
"precision": label_precision,
|
||||
"recall": label_recall,
|
||||
"f1_score": label_f1,
|
||||
"support": tp + fn
|
||||
"support": tp + fn,
|
||||
}
|
||||
|
||||
evaluation_results = {
|
||||
@@ -286,9 +311,9 @@ class NERNameModel:
|
||||
"total_examples": total_examples,
|
||||
"correct_entities": correct_entities,
|
||||
"predicted_entities": predicted_entities,
|
||||
"actual_entities": actual_entities
|
||||
"actual_entities": actual_entities,
|
||||
},
|
||||
"by_label": label_metrics
|
||||
"by_label": label_metrics,
|
||||
}
|
||||
|
||||
logging.info(f"NER Evaluation completed. Overall F1: {f1_score:.4f}")
|
||||
@@ -309,7 +334,7 @@ class NERNameModel:
|
||||
|
||||
# Save training statistics
|
||||
stats_path = model_dir / "training_stats.json"
|
||||
with open(stats_path, 'w', encoding='utf-8') as f:
|
||||
with open(stats_path, "w", encoding="utf-8") as f:
|
||||
json.dump(self.training_stats, f, indent=2)
|
||||
|
||||
logging.info(f"NER Model saved to {model_dir}")
|
||||
@@ -328,7 +353,7 @@ class NERNameModel:
|
||||
# Load training statistics if available
|
||||
stats_path = Path(model_path) / "training_stats.json"
|
||||
if stats_path.exists():
|
||||
with open(stats_path, 'r', encoding='utf-8') as f:
|
||||
with open(stats_path, "r", encoding="utf-8") as f:
|
||||
self.training_stats = json.load(f)
|
||||
|
||||
logging.info("NER Model loaded successfully")
|
||||
@@ -342,15 +367,14 @@ class NERNameModel:
|
||||
entities = []
|
||||
|
||||
for ent in doc.ents:
|
||||
entities.append({
|
||||
"text": ent.text,
|
||||
"label": ent.label_,
|
||||
"start": ent.start_char,
|
||||
"end": ent.end_char,
|
||||
"confidence": getattr(ent, 'score', None) # If confidence scores are available
|
||||
})
|
||||
entities.append(
|
||||
{
|
||||
"text": ent.text,
|
||||
"label": ent.label_,
|
||||
"start": ent.start_char,
|
||||
"end": ent.end_char,
|
||||
"confidence": getattr(ent, "score", None), # If confidence scores are available
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"text": text,
|
||||
"entities": entities
|
||||
}
|
||||
return {"text": text, "entities": entities}
|
||||
|
||||
Reference in New Issue
Block a user