146 lines
5.2 KiB
Python
146 lines
5.2 KiB
Python
import json
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import Optional, Union, Dict, Any
|
|
|
|
import yaml
|
|
|
|
from core.config.pipeline_config import PipelineConfig
|
|
from core.config.project_paths import ProjectPaths
|
|
|
|
|
|
class ConfigManager:
    """Centralized configuration management.

    Locates a pipeline configuration file (YAML or JSON), loads it into a
    ``PipelineConfig``, caches the loaded instance, and supports saving,
    in-memory updates and per-environment override files.

    NOTE(review): ``PipelineConfig`` is presumably a pydantic v2 model (the
    class relies on ``model_dump()``) -- confirm against its definition.
    """

    def __init__(self, config_path: Optional[Union[str, Path]] = None):
        """Create a manager bound to ``config_path``.

        Args:
            config_path: Location of the configuration file, as a string or
                ``Path``. When omitted (or empty), the standard locations
                are searched via ``_find_config_file``.
        """
        # Always store a Path: the other methods call .exists(), .suffix and
        # .parent on it, which a plain string does not provide.
        self.config_path = Path(config_path) if config_path else self._find_config_file()
        self._config: Optional[PipelineConfig] = None
        self._setup_default_paths()

    @classmethod
    def _find_config_file(cls) -> Path:
        """Find configuration file in standard locations.

        Returns:
            The first existing candidate, or ``<cwd>/config/pipeline.yaml``
            when none of the candidates exists.
        """
        cwd = Path.cwd()
        default_path = cwd / "config" / "pipeline.yaml"
        possible_paths = [
            default_path,
            cwd / "config" / "pipeline.yml",
            cwd / "pipeline.yaml",
            Path(__file__).parent.parent.parent / "config" / "pipeline.yaml",
        ]

        for path in possible_paths:
            if path.exists():
                return path

        # Return default path if none found
        return default_path

    @staticmethod
    def _read_config_data(path: Union[str, Path]) -> Dict[str, Any]:
        """Parse a YAML (``.yaml``/``.yml``) or JSON file into a dict.

        Args:
            path: File to read. Any suffix other than ``.yaml``/``.yml`` is
                parsed as JSON.

        Returns:
            The top-level mapping; an empty dict when the document is empty
            (``null``).

        Raises:
            ValueError: If the top-level document is not a mapping.
            OSError: If the file cannot be opened.
        """
        path = Path(path)
        with open(path, "r", encoding="utf-8") as f:
            if path.suffix.lower() in (".yaml", ".yml"):
                data = yaml.safe_load(f)
            else:
                data = json.load(f)

        # An empty YAML document (or JSON ``null``) parses to None.
        if data is None:
            return {}
        if not isinstance(data, dict):
            raise ValueError(
                f"Top-level config in {path} must be a mapping, got {type(data).__name__}"
            )
        return data

    def _setup_default_paths(self):
        """Setup default project paths.

        Builds ``self.default_paths`` with every directory derived from the
        directory three levels above this file.
        """
        root_dir = Path(__file__).parent.parent.parent
        data_root = root_dir / "data"
        self.default_paths = ProjectPaths(
            root_dir=root_dir,
            configs_dir=root_dir / "config",
            data_dir=data_root / "dataset",
            models_dir=data_root / "models",
            outputs_dir=data_root / "outputs",
            logs_dir=data_root / "logs",
            checkpoints_dir=data_root / "checkpoints",
        )

    def load_config(self, config_path: Optional[Path] = None) -> "PipelineConfig":
        """Load configuration from file.

        Args:
            config_path: Optional new location. When given it replaces
                ``self.config_path`` for this and all later calls.

        Returns:
            The loaded configuration (also cached on the manager), or a
            default configuration when the file is missing or cannot be
            loaded. The failure is logged rather than raised.
        """
        if config_path:
            self.config_path = Path(config_path)

        if not self.config_path.exists():
            logging.warning(f"Config file not found: {self.config_path}. Using defaults.")
            return self._create_default_config()

        try:
            config_data = self._read_config_data(self.config_path)

            # Ensure paths are properly set
            if "paths" not in config_data:
                config_data["paths"] = self.default_paths

            self._config = PipelineConfig(**config_data)
            return self._config

        except Exception as e:
            # Deliberate best-effort: a broken file degrades to defaults.
            logging.error(f"Failed to load config from {self.config_path}: {e}")
            return self._create_default_config()

    def _create_default_config(self) -> "PipelineConfig":
        """Create default configuration using ``self.default_paths``."""
        return PipelineConfig(paths=self.default_paths)

    def save_config(self, config: "PipelineConfig", path: Optional[Path] = None):
        """Save configuration to file.

        Args:
            config: Configuration to serialize.
            path: Destination; defaults to ``self.config_path``. A
                ``.yaml``/``.yml`` suffix selects YAML, anything else JSON.

        Write failures are logged, not raised.
        """
        save_path = Path(path or self.config_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)

        config_dict = config.model_dump()

        # Convert Path objects to strings for serialization
        paths = config_dict.get("paths")
        if isinstance(paths, dict):
            config_dict["paths"] = {
                key: str(value) if isinstance(value, Path) else value
                for key, value in paths.items()
            }

        try:
            with open(save_path, "w", encoding="utf-8") as f:
                if save_path.suffix.lower() in (".yaml", ".yml"):
                    yaml.dump(config_dict, f, default_flow_style=False, indent=2)
                else:
                    json.dump(config_dict, f, indent=2)

            logging.info(f"Configuration saved to {save_path}")

        except Exception as e:
            logging.error(f"Failed to save config to {save_path}: {e}")

    def get_config(self) -> "PipelineConfig":
        """Get current configuration, loading if necessary."""
        if self._config is None:
            self._config = self.load_config()
        return self._config

    def update_config(self, updates: Dict[str, Any]):
        """Update configuration with new values.

        Args:
            updates: Possibly nested mapping merged over the current
                configuration; the merged result is re-validated by
                constructing a new ``PipelineConfig`` and cached.
        """
        config = self.get_config()

        # Deep update configuration
        config_dict = config.model_dump()
        self._deep_update(config_dict, updates)

        self._config = PipelineConfig(**config_dict)

    def _deep_update(self, base_dict: Dict, update_dict: Dict):
        """Recursively update nested dictionaries.

        ``base_dict`` is modified in place. A value is merged recursively
        only when both sides hold a dict for the same key; otherwise the
        value from ``update_dict`` replaces the one in ``base_dict``.
        """
        for key, value in update_dict.items():
            if key in base_dict and isinstance(base_dict[key], dict) and isinstance(value, dict):
                self._deep_update(base_dict[key], value)
            else:
                base_dict[key] = value

    def get_environment_config(self, env: str) -> "PipelineConfig":
        """Load environment-specific configuration.

        Looks for ``pipeline.<env>.yaml`` next to the main config file and
        overlays it on the main file's contents.

        The two files are merged as raw mappings *before* validation, so
        keys absent from the environment file keep their base values
        instead of being reset to model defaults. Neither
        ``self.config_path`` nor the cached configuration is modified by
        the merge.

        Args:
            env: Environment name, e.g. ``"dev"``.

        Returns:
            The merged configuration, or ``get_config()`` when no
            environment file exists or the merge fails (failure is logged).
        """
        env_config_path = self.config_path.parent / f"pipeline.{env}.yaml"

        if not env_config_path.exists():
            return self.get_config()

        try:
            merged = (
                self._read_config_data(self.config_path)
                if self.config_path.exists()
                else {}
            )

            # Merge configurations
            self._deep_update(merged, self._read_config_data(env_config_path))

            if "paths" not in merged:
                merged["paths"] = self.default_paths

            return PipelineConfig(**merged)

        except Exception as e:
            logging.error(f"Failed to load environment config from {env_config_path}: {e}")
            return self.get_config()
|