fix: NER training loop

This commit is contained in:
2025-08-17 14:15:12 +02:00
parent 3122c92f5e
commit 6faf9f355e
+17 -14
View File
@@ -2,6 +2,7 @@ import ast
 import json
 import logging
 import os
+import random
 from pathlib import Path
 from typing import Dict, Any, List, Tuple
@@ -190,32 +191,34 @@ class NameModel:
         # Initialize the model
         self.nlp.initialize()
+        optimizer = self.nlp.resume_training()
         # Training loop
         losses_history = []
         for epoch in range(epochs):
             losses = {}
             # Create training examples
             examples = []
-            for text, annotations in tqdm(data, description="Create training examples"):
+            for text, annotations in tqdm(data, desc="Create training examples"):
                 doc = self.nlp.make_doc(text)
-                example = Example.from_dict(doc, annotations)
-                examples.append(example)
+                examples.append(Example.from_dict(doc, annotations))
+            # Shuffle examples each epoch (important!)
+            random.shuffle(examples)
             # Train in batches
             batches = minibatch(examples, size=batch_size)
             for batch in batches:
-                self.nlp.update(
-                    batch, losses=losses, drop=dropout_rate, sgd=self.nlp.create_optimizer()
-                )
-                logging.info(f"Training batch with {len(batch)} examples, current losses: {losses}")
+                batch_losses = {}
+                self.nlp.update(batch, losses=batch_losses, drop=dropout_rate, sgd=optimizer)
+                logging.info(f"Training batch with {len(batch)} examples, current losses: {batch_losses}")
+                # Accumulate into total losses dict
+                for k, v in batch_losses.items():
+                    losses[k] = losses.get(k, 0.0) + v
             del batches  # free memory
-            epoch_loss = losses.get("ner", 0)
-            losses_history.append(epoch_loss)
-            logging.info(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss:.4f}")
+            losses_history.append(losses.get("ner", 0))
+            logging.info(f"Epoch {epoch+1}/{epochs}, Total Loss: {losses['ner']:.4f}")
             # Store training statistics
             self.training_stats = {