112 lines
4.2 KiB
Python
112 lines
4.2 KiB
Python
import json
import logging
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Optional

import pandas as pd

from core.config.pipeline_config import PipelineConfig
from processing.batch.batch_config import BatchConfig
|
|
|
|
|
|
@dataclass
class PipelineState:
    """Tracks the state of pipeline execution.

    Attributes:
        processed_batches: Number of batches completed so far.
        total_batches: Total number of batches expected in the run.
        failed_batches: IDs of batches whose processing failed.
        last_checkpoint: Path of the most recent checkpoint, if any.
    """

    processed_batches: int = 0
    total_batches: int = 0
    # default_factory gives each instance its own list and lets the
    # annotation be honest (the old `= None` default lied about the type).
    failed_batches: List[int] = field(default_factory=list)
    last_checkpoint: Optional[str] = None

    def __post_init__(self):
        # Preserve the old contract: an explicit None normalizes to [].
        if self.failed_batches is None:
            self.failed_batches = []
|
|
|
|
|
|
class PipelineStep(ABC):
    """Abstract base class for pipeline steps.

    Provides checkpointing and state persistence so that subclasses only
    need to implement ``process_batch``. All artifacts for a step live
    under ``<checkpoints_dir>/<step name>/``: one CSV per processed batch
    plus a ``pipeline_state.json`` progress file.
    """

    def __init__(
        self, name: str, pipeline_config: PipelineConfig, batch_config: Optional[BatchConfig] = None
    ):
        """Initialize the step.

        Args:
            name: Step name; used as the checkpoint subdirectory name.
            pipeline_config: Global pipeline configuration.
            batch_config: Optional batching overrides; when omitted, one is
                derived from ``pipeline_config.processing``.
        """
        self.name = name
        self.pipeline_config = pipeline_config

        # Use provided batch_config or create default from pipeline config
        if batch_config is None:
            batch_config = BatchConfig(
                batch_size=pipeline_config.processing.batch_size,
                max_workers=pipeline_config.processing.max_workers,
                checkpoint_interval=pipeline_config.processing.checkpoint_interval,
                use_multiprocessing=pipeline_config.processing.use_multiprocessing,
            )
        self.batch_config = batch_config
        self.state = PipelineState()

    @abstractmethod
    def process_batch(self, batch: pd.DataFrame, batch_id: int) -> pd.DataFrame:
        """Process a single batch of data and return the transformed batch."""

    def _step_dir(self):
        """Return this step's checkpoint directory, creating it if needed.

        Shared by the checkpoint- and state-path helpers so the
        directory-resolution logic exists in exactly one place.
        """
        step_dir = self.pipeline_config.paths.checkpoints_dir / self.name
        step_dir.mkdir(parents=True, exist_ok=True)
        return step_dir

    def get_checkpoint_path(self, batch_id: int) -> str:
        """Get the checkpoint file path for a batch"""
        # Zero-padded id keeps checkpoint files lexically sorted.
        return str(self._step_dir() / f"batch_{batch_id:06d}.csv")

    def get_state_path(self) -> str:
        """Get the state file path"""
        return str(self._step_dir() / "pipeline_state.json")

    def save_state(self):
        """Save pipeline state to disk as JSON."""
        state_file = self.get_state_path()
        with open(state_file, "w") as f:
            json.dump(
                {
                    "processed_batches": self.state.processed_batches,
                    "total_batches": self.state.total_batches,
                    "failed_batches": self.state.failed_batches,
                    "last_checkpoint": self.state.last_checkpoint,
                },
                f,
            )

    def load_state(self) -> bool:
        """Load pipeline state from disk. Returns True if state was loaded."""
        state_file = self.get_state_path()
        if os.path.exists(state_file):
            try:
                with open(state_file, "r") as f:
                    state_data = json.load(f)
                # Missing keys fall back to "nothing done yet" defaults.
                self.state.processed_batches = state_data.get("processed_batches", 0)
                self.state.total_batches = state_data.get("total_batches", 0)
                self.state.failed_batches = state_data.get("failed_batches", [])
                self.state.last_checkpoint = state_data.get("last_checkpoint")
                return True
            except Exception as e:
                # Best effort: a corrupt/unreadable state file means "start fresh".
                logging.warning(f"Failed to load state: {e}")
        return False

    def batch_exists(self, batch_id: int) -> bool:
        """Check if a batch has already been processed (idempotency)"""
        return os.path.exists(self.get_checkpoint_path(batch_id))

    def save_batch(self, batch: pd.DataFrame, batch_id: int):
        """Save processed batch to checkpoint"""
        checkpoint_path = self.get_checkpoint_path(batch_id)
        batch.to_csv(checkpoint_path, index=False)
        logging.info(f"Saved batch {batch_id} to {checkpoint_path}")

    def load_batch(self, batch_id: int) -> Optional[pd.DataFrame]:
        """Load processed batch from checkpoint"""
        checkpoint_path = self.get_checkpoint_path(batch_id)
        if os.path.exists(checkpoint_path):
            return pd.read_csv(checkpoint_path)
        return None
|