import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, Optional

import ollama
import pandas as pd
from pydantic import BaseModel, ValidationError

from core.config.pipeline_config import PipelineConfig
from core.utils.prompt_manager import PromptManager
from core.utils.rate_limiter import RateLimitConfig, RateLimiter
from processing.batch.batch_config import BatchConfig
from processing.steps import PipelineStep


class NameAnnotation(BaseModel):
    """Model for name annotation results."""

    identified_name: Optional[str]
    identified_surname: Optional[str]
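
# Illustrative only (the sample values are hypothetical, not captured output):
# `NameAnnotation.model_json_schema()` is what gets handed to Ollama's
# `format=` parameter below, and `model_validate_json` is the inverse applied
# to the model's reply, e.g.
#
#   NameAnnotation.model_validate_json(
#       '{"identified_name": "Ada", "identified_surname": "Lovelace"}'
#   )
#   -> NameAnnotation(identified_name='Ada', identified_surname='Lovelace')

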
class LLMAnnotationStep(PipelineStep):
    """Configuration-driven LLM annotation step."""

    def __init__(self, pipeline_config: PipelineConfig):
        # Create a custom batch config for LLM processing; the worker count is
        # capped by both the LLM concurrency limit and the pipeline-wide limit.
        batch_config = BatchConfig(
            batch_size=pipeline_config.processing.batch_size,
            max_workers=min(
                pipeline_config.llm.max_concurrent_requests,
                pipeline_config.processing.max_workers,
            ),
            checkpoint_interval=pipeline_config.processing.checkpoint_interval,
            use_multiprocessing=pipeline_config.processing.use_multiprocessing,
        )
        super().__init__("llm_annotation", pipeline_config, batch_config)

        self.prompt = PromptManager(pipeline_config).load_prompt()
        self.rate_limiter = (
            self._create_rate_limiter() if pipeline_config.llm.enable_rate_limiting else None
        )

        # Statistics
        self.successful_requests = 0
        self.failed_requests = 0
        self.total_retry_attempts = 0

        # Quiet the HTTP client used by ollama; its per-request INFO logs are noisy.
        logging.getLogger("httpx").setLevel(logging.WARNING)

    def _create_rate_limiter(self):
        """Create a rate limiter based on configuration."""
        rate_config = RateLimitConfig(
            requests_per_minute=self.pipeline_config.llm.requests_per_minute,
            requests_per_second=self.pipeline_config.llm.requests_per_second,
        )
        return RateLimiter(rate_config)

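    # For intuition, assuming RateLimiter enforces both windows (the numbers
    # are hypothetical, not defaults from this codebase):
    # requests_per_minute=120 with requests_per_second=5 allows short bursts
    # of up to 5 req/s while the sustained average stays at 120/60 = 2 req/s.
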
    def analyze_name_with_retry(self, client: ollama.Client, name: str, row_id: int) -> Dict:
        """Analyze a name with retry logic and rate limiting."""
        for attempt in range(self.pipeline_config.llm.retry_attempts):
            try:
                # Apply rate limiting if enabled
                if self.rate_limiter:
                    self.rate_limiter.wait_if_needed()

                start_time = time.time()
                # `format=` constrains the model to emit JSON matching the
                # NameAnnotation schema (Ollama structured outputs).
                response = client.chat(
                    model=self.pipeline_config.llm.model_name,
                    messages=[
                        {"role": "system", "content": self.prompt},
                        {"role": "user", "content": name},
                    ],
                    format=NameAnnotation.model_json_schema(),
                )
                elapsed_time = time.time() - start_time

                # Post-hoc timeout: the request itself is not cancelled; a
                # slow response is discarded here and retried.
                if elapsed_time > self.pipeline_config.llm.timeout_seconds:
                    raise TimeoutError(
                        f"Request took {elapsed_time:.2f}s, exceeding "
                        f"{self.pipeline_config.llm.timeout_seconds}s timeout"
                    )

                annotation = NameAnnotation.model_validate_json(response.message.content)
                result = {
                    **annotation.model_dump(),
                    "annotated": 1,
                    "processing_time": elapsed_time,
                    "attempts": attempt + 1,
                }

                self.successful_requests += 1
                if attempt > 0:
                    self.total_retry_attempts += attempt

                return result

            # ValidationError and TimeoutError are the expected failure modes;
            # catch broadly so any client error also triggers a retry.
            except Exception as e:
                logging.warning(
                    f"Error analyzing '{name}' (attempt {attempt + 1}/"
                    f"{self.pipeline_config.llm.retry_attempts}): {e}"
                )

                # Exponential backoff with sub-second jitter: ~1s, 2s, 4s, ...
                # between attempts, capped at 10s.
                if attempt < self.pipeline_config.llm.retry_attempts - 1:
                    wait_time = (2**attempt) + (time.time() % 1)
                    time.sleep(min(wait_time, 10))

        # All attempts exhausted: record a failed, unannotated result.
        self.failed_requests += 1
        return {
            "identified_name": None,
            "identified_surname": None,
            "annotated": 0,
            "processing_time": 0,
            "attempts": self.pipeline_config.llm.retry_attempts,
            "failed": True,
        }

    def process_batch(self, batch: pd.DataFrame, batch_id: int) -> pd.DataFrame:
        """Process a batch with LLM annotation."""
        # Use a Series default so the comparison always yields a boolean mask;
        # a scalar default (`batch.get("annotated", 0)`) would produce a scalar
        # bool when the column is missing and break the indexing below.
        unannotated_mask = batch.get("annotated", pd.Series(0, index=batch.index)) == 0
        unannotated_entries = batch[unannotated_mask]

        if unannotated_entries.empty:
            logging.info(f"Batch {batch_id}: No entries to annotate")
            return batch

        logging.info(f"Batch {batch_id}: Annotating {len(unannotated_entries)} entries")

        batch = batch.copy()
        client = ollama.Client()

        # Process with controlled concurrency
        max_workers = self.pipeline_config.llm.max_concurrent_requests

        if len(unannotated_entries) == 1 or max_workers == 1:
            # Sequential processing
            for idx, row in unannotated_entries.iterrows():
                result = self.analyze_name_with_retry(client, row["name"], idx)
                for field, value in result.items():
                    if field != "failed":
                        batch.loc[idx, field] = value
        else:
            # Concurrent processing: one thread per in-flight request,
            # capped at max_workers.
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                future_to_idx = {}

                for idx, row in unannotated_entries.iterrows():
                    future = executor.submit(
                        self.analyze_name_with_retry, client, row["name"], idx
                    )
                    future_to_idx[future] = idx

                for future in as_completed(future_to_idx):
                    idx = future_to_idx[future]
                    try:
                        result = future.result()
                        for field, value in result.items():
                            if field != "failed":
                                batch.loc[idx, field] = value
                    except Exception as e:
                        logging.error(f"Failed to process row {idx}: {e}")
                        batch.loc[idx, "annotated"] = 0

        # Ensure a consistent nullable integer dtype for the annotated flag.
        batch["annotated"] = (
            pd.to_numeric(batch["annotated"], errors="coerce").fillna(0).astype("Int8")
        )

        return batch
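

# A minimal usage sketch (illustrative; the `PipelineConfig` construction is
# hypothetical and depends on how configs are loaded in this project):
#
#   config = PipelineConfig(...)  # load or build the pipeline configuration
#   step = LLMAnnotationStep(config)
#   df = pd.DataFrame({"name": ["ada lovelace"], "annotated": [0]})
#   out = step.process_batch(df, batch_id=0)
#   print(out[["name", "identified_name", "identified_surname"]])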