refactoring: add initial pipeline configuration and model classes
This commit is contained in:
@@ -1,109 +0,0 @@
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import ollama
|
||||
import pandas as pd
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from tqdm import tqdm
|
||||
from typing import Optional
|
||||
|
||||
from misc import load_prompt, load_csv_dataset, DATA_DIR, logging
|
||||
|
||||
|
||||
class NameAnalysis(BaseModel):
    """Structured LLM output: name components identified from a raw full name."""

    # Either field may be None when the model cannot identify that component.
    identified_name: Optional[str]
    identified_surname: Optional[str]
|
||||
|
||||
|
||||
def analyze_name(client: ollama.Client, model: str, prompt: str, name: str) -> dict:
    """
    Analyze a name using the specified model and prompt.

    Returns a dict with keys 'identified_name' and 'identified_surname';
    both values are None when the model response cannot be parsed or
    fails schema validation.
    """
    try:
        # Constrain the model output to the NameAnalysis JSON schema.
        response = client.chat(
            model=model,
            messages=[
                {"role": "system", "content": prompt},
                {"role": "user", "content": name},
            ],
            format=NameAnalysis.model_json_schema(),
        )
        # Validate the raw JSON payload against the schema before use.
        analysis = NameAnalysis.model_validate_json(response.message.content)
        return analysis.model_dump()
    except ValidationError as ve:
        logging.warning(f"Validation error: {ve}")
    except Exception as e:
        logging.error(f"Unexpected error: {e}")
    # Both error paths fall through to a null result so callers need no
    # special-case handling.
    return {"identified_name": None, "identified_surname": None}
|
||||
|
||||
|
||||
def save_checkpoint(df: pd.DataFrame):
    """Persist the working DataFrame to names_featured.csv (overwrites in place)."""
    df.to_csv(os.path.join(DATA_DIR, "names_featured.csv"), index=False)
    # FIX: this is routine progress information, not a critical failure; the
    # original logged at CRITICAL with a placeholder-less f-string.
    logging.info("Checkpoint saved")
|
||||
|
||||
|
||||
def build_updates(llm_model: str, df: pd.DataFrame, entries: pd.DataFrame) -> pd.DataFrame:
    """
    Run the LLM over every row of `entries` and merge the results into `df`.

    Results are applied and checkpointed every BATCH_SIZE rows so progress
    survives interruption. Returns the updated `df`.
    """
    BATCH_SIZE = 10

    client = ollama.Client()
    prompt = load_prompt()
    updates = []

    # Set logging level for HTTP client to reduce noise
    # This is useful to avoid excessive logging from the HTTP client used by Ollama
    logging.getLogger("httpx").setLevel(logging.WARNING)

    for idx, (row_idx, row) in enumerate(entries.iterrows(), 1):
        try:
            entry = analyze_name(client, llm_model, prompt, row["name"])
            entry["annotated"] = 1
            updates.append((row_idx, entry))
            logging.info(f"Analyzed: {row['name']} - {entry}")
        except Exception as e:
            logging.warning(f"Failed to analyze '{row['name']}': {e}")
            continue

        if idx % BATCH_SIZE == 0 or idx == len(entries):
            # BUG FIX: if every row in the window failed, `updates` is empty and
            # the DataFrame built from it has no 'annotated' column, so the cast
            # below raised KeyError. Skip the flush instead.
            if not updates:
                continue
            update_df = pd.DataFrame.from_dict(dict(updates), orient="index")
            update_df["annotated"] = pd.to_numeric(update_df["annotated"], errors="coerce").fillna(0).astype("Int8")

            df.update(update_df)
            save_checkpoint(df)
            updates.clear()  # avoid re-applying same updates

    return df
|
||||
|
||||
|
||||
def main(llm_model: str = "llama3.2:3b"):
    """
    Annotate the featured-names dataset with an LLM.

    Loads names_featured.csv, selects rows whose 'annotated' flag is 0,
    runs the LLM over them via build_updates, and writes checkpoints back
    to the same file.
    """
    df = pd.DataFrame(load_csv_dataset(os.path.join(DATA_DIR, "names_featured.csv")))

    # Safely cast 'annotated' to Int8, handling float-like strings (e.g. '1.0').
    # CONSISTENCY FIX: dropped the redundant intermediate .astype(float) so the
    # expression matches the identical cast used in build_updates.
    df["annotated"] = pd.to_numeric(df["annotated"], errors="coerce").fillna(0).astype("Int8")

    entries = df[df["annotated"] == 0]
    if entries.empty:
        logging.info("No names to analyze.")
        return

    logging.info(f"Found {len(entries)} names to analyze.")
    df = build_updates(llm_model, df, entries)
    save_checkpoint(df)
    logging.info("Analysis complete.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Analyze names using an LLM model.")
    # NOTE(review): CLI default (mistral:7b) differs from main()'s own default
    # (llama3.2:3b) — confirm this is intentional.
    parser.add_argument(
        "--llm_model",
        type=str,
        default="mistral:7b",
        help="Ollama model name to use (default: mistral:7b)",
    )
    args = parser.parse_args()

    # Top-level guard: log any uncaught failure with a traceback instead of
    # letting the interpreter print it raw.
    try:
        main(llm_model=args.llm_model)
    except Exception as e:
        logging.error(f"Fatal error: {e}", exc_info=True)
|
||||
@@ -0,0 +1,11 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class BatchConfig:
    """Tunable knobs for batched pipeline execution."""

    # Rows handed to a pipeline step per batch.
    batch_size: int = 1000
    # Worker count; a value of 1 selects the sequential code path.
    max_workers: int = 4
    # Persist pipeline state every N batches.
    checkpoint_interval: int = 5
    # True -> ProcessPoolExecutor, False -> ThreadPoolExecutor.
    use_multiprocessing: bool = False
|
||||
@@ -0,0 +1,102 @@
|
||||
import logging
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
||||
from typing import Iterator
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from processing.batch.batch_config import BatchConfig
|
||||
from processing.steps import PipelineStep
|
||||
|
||||
|
||||
class BatchProcessor:
    """Handles batch processing with concurrency and checkpointing"""

    def __init__(self, config: BatchConfig):
        # config controls batch size, worker count, checkpoint cadence, and
        # the executor type (threads vs processes).
        self.config = config

    def create_batches(self, df: pd.DataFrame) -> Iterator[tuple[pd.DataFrame, int]]:
        """Create batches from DataFrame.

        Yields (batch, batch_id) pairs; batch_id is the 0-based position of
        the slice. Each batch is a copy so steps may mutate it freely.
        """
        total_rows = len(df)
        batch_size = self.config.batch_size

        for i in range(0, total_rows, batch_size):
            batch = df.iloc[i : i + batch_size].copy()
            batch_id = i // batch_size
            yield batch, batch_id

    def process_sequential(self, step: PipelineStep, df: pd.DataFrame) -> pd.DataFrame:
        """Process batches sequentially.

        Batches that already have a checkpoint on disk are loaded instead of
        recomputed (idempotent resume). Failed batches are recorded in
        step.state.failed_batches and excluded from the result.
        """
        results = []

        for batch, batch_id in self.create_batches(df):
            if step.batch_exists(batch_id):
                logging.info(f"Batch {batch_id} already processed, loading from checkpoint")
                processed_batch = step.load_batch(batch_id)
            else:
                try:
                    processed_batch = step.process_batch(batch, batch_id)
                    step.save_batch(processed_batch, batch_id)
                    step.state.processed_batches += 1
                except Exception as e:
                    logging.error(f"Failed to process batch {batch_id}: {e}")
                    step.state.failed_batches.append(batch_id)
                    continue

            results.append(processed_batch)

            # Save state periodically
            if batch_id % self.config.checkpoint_interval == 0:
                step.save_state()

        return pd.concat(results, ignore_index=True) if results else pd.DataFrame()

    def process_concurrent(self, step: PipelineStep, df: pd.DataFrame) -> pd.DataFrame:
        """Process batches concurrently.

        Uses threads by default, processes when config.use_multiprocessing is
        set. Results are reassembled in batch-id order so output ordering does
        not depend on completion order.
        """
        executor_class = (
            ProcessPoolExecutor if self.config.use_multiprocessing else ThreadPoolExecutor
        )
        results = {}

        with executor_class(max_workers=self.config.max_workers) as executor:
            # Submit all batches
            future_to_batch = {}
            for batch, batch_id in self.create_batches(df):
                if step.batch_exists(batch_id):
                    logging.info(f"Batch {batch_id} already processed, loading from checkpoint")
                    results[batch_id] = step.load_batch(batch_id)
                else:
                    future = executor.submit(step.process_batch, batch, batch_id)
                    future_to_batch[future] = (batch_id, batch)

            # Collect results as they complete
            for future in as_completed(future_to_batch):
                batch_id, batch = future_to_batch[future]
                try:
                    processed_batch = future.result()
                    step.save_batch(processed_batch, batch_id)
                    results[batch_id] = processed_batch
                    step.state.processed_batches += 1
                    logging.info(f"Completed batch {batch_id}")
                except Exception as e:
                    logging.error(f"Failed to process batch {batch_id}: {e}")
                    step.state.failed_batches.append(batch_id)

        # Reassemble results in order
        ordered_results = []
        for batch_id in sorted(results.keys()):
            ordered_results.append(results[batch_id])

        step.save_state()
        return pd.concat(ordered_results, ignore_index=True) if ordered_results else pd.DataFrame()

    def process(self, step: PipelineStep, df: pd.DataFrame) -> pd.DataFrame:
        """Process data using the configured strategy"""
        # Ceiling division: number of batches needed to cover len(df).
        step.state.total_batches = (len(df) + self.config.batch_size - 1) // self.config.batch_size
        step.load_state()

        logging.info(f"Starting {step.name} with {step.state.total_batches} batches")

        if self.config.max_workers == 1:
            return self.process_sequential(step, df)
        else:
            return self.process_concurrent(step, df)
|
||||
@@ -0,0 +1,80 @@
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class DatasetAnalyzer:
    """Analyze dataset statistics and quality"""

    def __init__(self, filepath: str):
        # The DataFrame is populated lazily by load_data() (or by the caller).
        self.filepath = filepath
        self.df = None

    def load_data(self) -> bool:
        """Load dataset for analysis; returns True on success."""
        try:
            self.df = pd.read_csv(self.filepath)
        except Exception as e:
            logging.error(f"Failed to load {self.filepath}: {e}")
            return False
        return True

    def analyze_completion(self) -> Dict:
        """Analyze annotation completion status"""
        if self.df is None:
            return {}

        frame = self.df
        row_total = len(frame)

        # Annotation status (absent column means nothing is annotated yet).
        if "annotated" in frame.columns:
            done = (frame["annotated"] == 1).sum()
            pending = (frame["annotated"] == 0).sum()
        else:
            done = 0
            pending = row_total

        # Rows where both name components were identified.
        full_names = 0
        if "identified_name" in frame.columns and "identified_surname" in frame.columns:
            both_present = frame["identified_name"].notna() & frame["identified_surname"].notna()
            full_names = both_present.sum()

        def pct(part):
            return (part / row_total * 100) if row_total > 0 else 0

        return {
            "total_rows": row_total,
            "annotated_rows": done,
            "unannotated_rows": pending,
            "annotation_percentage": pct(done),
            "complete_names": full_names,
            "completeness_percentage": pct(full_names),
        }

    def analyze_quality(self) -> Dict:
        """Analyze data quality metrics"""
        if self.df is None:
            return {}

        metrics = {}

        # Per-column missing-value counts.
        metrics["missing_values"] = self.df.isnull().sum().to_dict()

        # Distribution of name lengths (characters).
        if "name" in self.df.columns:
            lengths = self.df["name"].str.len()
            metrics["name_length"] = {
                "mean": lengths.mean(),
                "median": lengths.median(),
                "min": lengths.min(),
                "max": lengths.max(),
            }

        # Frequency of word counts, ordered by word count.
        if "words" in self.df.columns:
            metrics["word_distribution"] = self.df["words"].value_counts().sort_index().to_dict()

        return metrics
|
||||
@@ -0,0 +1,179 @@
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict
|
||||
|
||||
from core.config.config_manager import ConfigManager
|
||||
from core.config.project_paths import ProjectPaths
|
||||
|
||||
|
||||
class PipelineMonitor:
    """Monitor and manage pipeline execution"""

    def __init__(self, paths: Optional[ProjectPaths] = None):
        if paths is None:
            # Use default configuration if none provided
            config_manager = ConfigManager()
            paths = config_manager.default_paths

        self.paths = paths
        self.checkpoint_dir = paths.checkpoints_dir
        # Ordered list of known step names; drives status aggregation below.
        self.steps = ["data_cleaning", "feature_extraction", "llm_annotation", "data_splitting"]

    def get_step_status(self, step_name: str) -> Dict:
        """Get status of a specific pipeline step.

        Reads <checkpoints>/<step>/pipeline_state.json; a missing file is
        reported as "not_started" rather than an error.
        """
        step_dir = self.checkpoint_dir / step_name
        state_file = step_dir / "pipeline_state.json"

        if not state_file.exists():
            return {
                "step": step_name,
                "status": "not_started",
                "processed_batches": 0,
                "total_batches": 0,
                "failed_batches": 0,
                "completion_percentage": 0.0,
            }

        try:
            with open(state_file, "r") as f:
                state = json.load(f)

            processed = state.get("processed_batches", 0)
            total = state.get("total_batches", 0)
            failed = len(state.get("failed_batches", []))

            # Derive a coarse status from the batch counters.
            if total == 0:
                completion = 0.0
                status = "not_started"
            elif processed >= total:
                completion = 100.0
                status = "completed" if failed == 0 else "completed_with_errors"
            else:
                completion = (processed / total) * 100
                status = "in_progress"

            return {
                "step": step_name,
                "status": status,
                "processed_batches": processed,
                "total_batches": total,
                "failed_batches": failed,
                "completion_percentage": completion,
                "last_checkpoint": state.get("last_checkpoint"),
                "failed_batch_ids": state.get("failed_batches", []),
            }

        except Exception as e:
            logging.error(f"Error reading state for {step_name}: {e}")
            return {"step": step_name, "status": "error", "error": str(e)}

    def get_pipeline_status(self) -> Dict:
        """Get overall pipeline status.

        Aggregates per-step statuses; note that a later step's status can
        overwrite an earlier escalation (last step in the list wins).
        """
        step_statuses = {}
        overall_status = "not_started"
        total_completion = 0.0

        for step in self.steps:
            status = self.get_step_status(step)
            step_statuses[step] = status

            if status["status"] == "error":
                overall_status = "error"
            elif status["status"] in ["in_progress"]:
                overall_status = "in_progress"
            elif status["status"] == "completed_with_errors":
                overall_status = "completed_with_errors"

            total_completion += status.get("completion_percentage", 0)

        avg_completion = total_completion / len(self.steps)

        # Only promote to "completed" when no step reported a problem.
        if avg_completion >= 100 and overall_status not in ["error", "completed_with_errors"]:
            overall_status = "completed"

        return {
            "overall_status": overall_status,
            "overall_completion": avg_completion,
            "steps": step_statuses,
            "timestamp": datetime.now().isoformat(),
        }

    def print_status(self, detailed: bool = False):
        """Print pipeline status in a human-readable format"""
        status = self.get_pipeline_status()

        print("\n=== Pipeline Status ===")
        print(f"Overall Status: {status['overall_status'].upper()}")
        print(f"Overall Completion: {status['overall_completion']:.1f}%")
        print(f"Last Updated: {status['timestamp']}")
        print()

        for step_name, step_status in status["steps"].items():
            print(f"{step_name.replace('_', ' ').title()}:")
            print(f" Status: {step_status['status']}")
            print(f" Progress: {step_status['completion_percentage']:.1f}%")
            print(f" Batches: {step_status['processed_batches']}/{step_status['total_batches']}")

            if step_status["failed_batches"] > 0:
                print(f" Failed Batches: {step_status['failed_batches']}")

            # Error-status entries lack 'failed_batch_ids', hence the key check.
            if detailed and "failed_batch_ids" in step_status:
                print(f" Failed Batch IDs: {step_status['failed_batch_ids']}")

            print()

    def count_checkpoint_files(self) -> Dict:
        """Count checkpoint files for each step"""
        counts = {}
        total_size = 0

        for step in self.steps:
            step_dir = self.checkpoint_dir / step
            if step_dir.exists():
                csv_files = list(step_dir.glob("*.csv"))
                step_size = sum(f.stat().st_size for f in csv_files)
                counts[step] = {"files": len(csv_files), "size_mb": step_size / (1024 * 1024)}
                total_size += step_size
            else:
                counts[step] = {"files": 0, "size_mb": 0}

        counts["total_size_mb"] = total_size / (1024 * 1024)
        return counts

    def clean_step_checkpoints(self, step_name: str, keep_last: int = 1):
        """Clean checkpoint files for a specific step.

        Deletes batch_*.csv files, keeping the `keep_last` most recent ones
        (lexicographic order; batch ids are zero-padded so this matches
        numeric order). keep_last == 0 deletes everything.
        """
        step_dir = self.checkpoint_dir / step_name

        if not step_dir.exists():
            logging.info(f"No checkpoints found for {step_name}")
            return

        csv_files = sorted(step_dir.glob("batch_*.csv"))

        if len(csv_files) <= keep_last:
            logging.info(f"Only {len(csv_files)} checkpoint files for {step_name}, keeping all")
            return

        files_to_delete = csv_files[:-keep_last] if keep_last > 0 else csv_files

        for file_path in files_to_delete:
            try:
                file_path.unlink()
                logging.info(f"Deleted {file_path}")
            except Exception as e:
                logging.error(f"Failed to delete {file_path}: {e}")

    def reset_step(self, step_name: str):
        """Reset a pipeline step by removing its checkpoints and state"""
        step_dir = self.checkpoint_dir / step_name

        if step_dir.exists():
            try:
                shutil.rmtree(step_dir)
                logging.info(f"Reset step: {step_name}")
            except Exception as e:
                logging.error(f"Failed to reset {step_name}: {e}")
        else:
            logging.info(f"Step {step_name} has no checkpoints to reset")
|
||||
@@ -0,0 +1,57 @@
|
||||
import logging
|
||||
|
||||
import pandas as pd
|
||||
from typing import Dict, Any
|
||||
import time
|
||||
|
||||
from processing.batch.batch_config import BatchConfig
|
||||
from processing.batch.batch_processor import BatchProcessor
|
||||
from processing.steps import PipelineStep
|
||||
|
||||
|
||||
class Pipeline:
    """Main pipeline orchestrator"""

    def __init__(self, config: BatchConfig):
        self.config = config
        self.processor = BatchProcessor(config)
        self.steps = []

    def add_step(self, step: PipelineStep):
        """Add a processing step to the pipeline"""
        self.steps.append(step)

    def run(self, input_data: pd.DataFrame) -> pd.DataFrame:
        """Run the complete pipeline"""
        # Work on a copy; each step's output feeds the next step.
        data = input_data.copy()

        for step in self.steps:
            logging.info(f"Running pipeline step: {step.name}")
            started = time.time()

            data = self.processor.process(step, data)

            elapsed_time = time.time() - started
            logging.info(f"Completed {step.name} in {elapsed_time:.2f} seconds")

            if step.state.failed_batches:
                logging.warning(
                    f"Step {step.name} had {len(step.state.failed_batches)} failed batches"
                )

        return data

    def get_progress(self) -> Dict[str, Any]:
        """Get progress information for all steps"""
        return {
            step.name: {
                "processed_batches": step.state.processed_batches,
                "total_batches": step.state.total_batches,
                "failed_batches": len(step.state.failed_batches),
                # max(1, ...) guards against division by zero before any
                # batches have been counted.
                "completion_percentage": (
                    step.state.processed_batches / max(1, step.state.total_batches)
                )
                * 100,
            }
            for step in self.steps
        }
|
||||
@@ -1,119 +0,0 @@
|
||||
import os
|
||||
import argparse
|
||||
import pandas as pd
|
||||
from misc import DATA_DIR, REGION_MAPPING, logging
|
||||
|
||||
|
||||
def clean(filepath) -> pd.DataFrame:
    """
    Clean the CSV file by removing null bytes, non-breaking spaces, and extra spaces.
    Also, it attempts to read the file with different encodings to handle potential encoding issues.

    The cleaned data is written back to `filepath` as UTF-8 and returned.

    Raises UnicodeDecodeError when none of the candidate encodings works
    (or when reading fails for any other reason on every attempt).
    """
    encodings = ['utf-8', 'utf-16', 'latin1']
    for enc in encodings:
        try:
            logging.info(f"Trying to read {filepath} with encoding: {enc}")
            # Use chunked reading to handle large files
            chunks = pd.read_csv(filepath, encoding=enc, chunksize=100_000, on_bad_lines='skip')
            cleaned_chunks = []

            for chunk in chunks:
                # Drop rows with essential missing values early
                chunk = chunk.dropna(subset=['name', 'sex', 'region'])

                # Clean string columns in-place
                for col in chunk.select_dtypes(include='object').columns:
                    chunk[col] = (
                        chunk[col]
                        .astype(str)
                        .str.replace('\x00', ' ', regex=False)
                        .str.replace('\u00a0', ' ', regex=False)
                        .str.replace(' +', ' ', regex=True)
                        .str.strip()
                        .str.lower()
                    )

                cleaned_chunks.append(chunk)

            df = pd.concat(cleaned_chunks, ignore_index=True)
            df.to_csv(filepath, index=False, encoding='utf-8')
            logging.info(f"Successfully read with encoding: {enc}")
            return df
        except Exception:
            continue
    # BUG FIX: UnicodeDecodeError requires (encoding, object, start, end, reason);
    # calling it with a single string raised TypeError instead of the intended error.
    raise UnicodeDecodeError(
        'utf-8', b'', 0, 1, f"Unable to decode {filepath} with common encodings."
    )
|
||||
|
||||
|
||||
def process(df: pd.DataFrame) -> pd.DataFrame:
    """
    Process the DataFrame to extract features and clean data.
    This includes counting words, calculating name length, and extracting probable native names and surnames.
    Also maps regions to provinces based on REGION_MAPPING.
    """

    logging.info("Preprocessing names")
    names = df['name']
    df['words'] = names.str.count(' ') + 1
    df['length'] = names.str.replace(' ', '', regex=False).str.len()
    df['year'] = df['year'].astype(int)

    # Calculate probable_native and probable_surname
    tokens = names.str.split()
    df['probable_native'] = tokens.apply(lambda parts: ' '.join(parts[:-1]) if len(parts) > 1 else '')
    df['probable_surname'] = tokens.apply(lambda parts: parts[-1] if parts else '')
    df['identified_category'] = df['words'].apply(lambda n: 'compose' if n > 3 else 'simple')
    df['identified_name'] = None
    df['identified_surname'] = None
    df['annotated'] = 0

    # We can assume that if a name has exactly 3 words, the first two are the native name and the last is the surname
    # This is a common pattern in Congolese names
    is_three_words = df['words'] == 3
    df.loc[is_three_words, 'identified_name'] = df.loc[is_three_words, 'probable_native']
    df.loc[is_three_words, 'identified_surname'] = df.loc[is_three_words, 'probable_surname']
    df.loc[is_three_words, 'annotated'] = 1

    logging.info("Mapping regions to provinces")
    provinces = df['region'].map(lambda r: REGION_MAPPING.get(r, ('AUTRES', 'AUTRES'))[1])
    df['province'] = provinces.str.lower()

    return df
|
||||
|
||||
|
||||
def save_artifacts(df: pd.DataFrame, split_eval: bool = True, split_by_sex: bool = True) -> None:
    """
    Splits the input DataFrame into evaluation and featured datasets, saves them as CSV files,
    and additionally saves separate CSV files for male and female entries if requested.
    """

    if split_eval:
        logging.info("Saving evaluation and featured datasets")
        # 20% random holdout; fixed seed keeps the split reproducible.
        eval_idx = df.sample(frac=0.2, random_state=42).index
        df.loc[eval_idx].to_csv(os.path.join(DATA_DIR, 'names_evaluation.csv'), index=False)
        df.drop(index=eval_idx).to_csv(os.path.join(DATA_DIR, 'names_featured.csv'), index=False)
    else:
        df.to_csv(os.path.join(DATA_DIR, 'names_featured.csv'), index=False)

    if split_by_sex:
        logging.info("Saving by sex")
        for sex_code, filename in (('m', 'names_males.csv'), ('f', 'names_females.csv')):
            df[df['sex'] == sex_code].to_csv(os.path.join(DATA_DIR, filename), index=False)
|
||||
|
||||
|
||||
def main(split_eval: bool = True, split_by_sex: bool = True):
    """Clean the raw names file, derive features, and write the split artifacts."""
    raw_path = os.path.join(DATA_DIR, 'names.csv')
    featured = process(clean(raw_path))
    save_artifacts(featured, split_eval=split_eval, split_by_sex=split_by_sex)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Prepare name datasets with optional splits.")

    # Paired on/off flags: each split defaults to enabled and can be disabled
    # explicitly via its --no- counterpart (both write to the same dest).
    parser.add_argument('--split_eval', action='store_true', default=True, help="Split into evaluation and featured datasets (default: True)")
    parser.add_argument('--no-split_eval', action='store_false', dest='split_eval', help="Do not split into evaluation and featured datasets")
    parser.add_argument('--split_by_sex', action='store_true', default=True, help="Split by sex into male/female datasets (default: True)")
    parser.add_argument('--no-split_by_sex', action='store_false', dest='split_by_sex', help="Do not split by sex into male/female datasets")

    args = parser.parse_args()
    main(split_eval=args.split_eval, split_by_sex=args.split_by_sex)
|
||||
@@ -0,0 +1,111 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from processing.batch.batch_config import BatchConfig
|
||||
from core.config.pipeline_config import PipelineConfig
|
||||
|
||||
|
||||
@dataclass
class PipelineState:
    """Tracks the state of pipeline execution"""

    processed_batches: int = 0
    total_batches: int = 0
    # FIX: the original declared `failed_batches: List[int] = None` and patched
    # it in __post_init__; default_factory gives each instance its own list and
    # matches the declared type.
    failed_batches: List[int] = field(default_factory=list)
    last_checkpoint: Optional[str] = None

    def __post_init__(self):
        # Kept for backward compatibility: tolerate callers that explicitly
        # pass failed_batches=None.
        if self.failed_batches is None:
            self.failed_batches = []
|
||||
|
||||
|
||||
class PipelineStep(ABC):
    """Abstract base class for pipeline steps.

    Subclasses implement process_batch; this base class provides per-step
    checkpoint files (one CSV per batch) and a JSON state file, both under
    <checkpoints_dir>/<step name>/.
    """

    def __init__(
        self, name: str, pipeline_config: PipelineConfig, batch_config: Optional[BatchConfig] = None
    ):
        self.name = name
        self.pipeline_config = pipeline_config

        # Use provided batch_config or create default from pipeline config
        if batch_config is None:
            batch_config = BatchConfig(
                batch_size=pipeline_config.processing.batch_size,
                max_workers=pipeline_config.processing.max_workers,
                checkpoint_interval=pipeline_config.processing.checkpoint_interval,
                use_multiprocessing=pipeline_config.processing.use_multiprocessing,
            )
        self.batch_config = batch_config
        # Mutable execution-progress record (see PipelineState).
        self.state = PipelineState()

    @abstractmethod
    def process_batch(self, batch: pd.DataFrame, batch_id: int) -> pd.DataFrame:
        """Process a single batch of data"""
        pass

    def get_checkpoint_path(self, batch_id: int) -> str:
        """Get the checkpoint file path for a batch (creates the directory)."""
        checkpoint_dir = self.pipeline_config.paths.checkpoints_dir / self.name
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        # Zero-padded id keeps lexicographic order equal to numeric order.
        return str(checkpoint_dir / f"batch_{batch_id:06d}.csv")

    def get_state_path(self) -> str:
        """Get the state file path (creates the directory)."""
        state_dir = self.pipeline_config.paths.checkpoints_dir / self.name
        state_dir.mkdir(parents=True, exist_ok=True)
        return str(state_dir / "pipeline_state.json")

    def save_state(self):
        """Save pipeline state to disk"""
        state_file = self.get_state_path()
        with open(state_file, "w") as f:
            json.dump(
                {
                    "processed_batches": self.state.processed_batches,
                    "total_batches": self.state.total_batches,
                    "failed_batches": self.state.failed_batches,
                    "last_checkpoint": self.state.last_checkpoint,
                },
                f,
            )

    def load_state(self) -> bool:
        """Load pipeline state from disk. Returns True if state was loaded."""
        state_file = self.get_state_path()
        if os.path.exists(state_file):
            try:
                with open(state_file, "r") as f:
                    state_data = json.load(f)
                self.state.processed_batches = state_data.get("processed_batches", 0)
                self.state.total_batches = state_data.get("total_batches", 0)
                self.state.failed_batches = state_data.get("failed_batches", [])
                self.state.last_checkpoint = state_data.get("last_checkpoint")
                return True
            except Exception as e:
                # Corrupt/unreadable state is non-fatal; caller starts fresh.
                logging.warning(f"Failed to load state: {e}")
        return False

    def batch_exists(self, batch_id: int) -> bool:
        """Check if a batch has already been processed (idempotency)"""
        checkpoint_path = self.get_checkpoint_path(batch_id)
        return os.path.exists(checkpoint_path)

    def save_batch(self, batch: pd.DataFrame, batch_id: int):
        """Save processed batch to checkpoint"""
        checkpoint_path = self.get_checkpoint_path(batch_id)
        batch.to_csv(checkpoint_path, index=False)
        logging.info(f"Saved batch {batch_id} to {checkpoint_path}")

    def load_batch(self, batch_id: int) -> Optional[pd.DataFrame]:
        """Load processed batch from checkpoint; None when no checkpoint exists."""
        checkpoint_path = self.get_checkpoint_path(batch_id)
        if os.path.exists(checkpoint_path):
            return pd.read_csv(checkpoint_path)
        return None
|
||||
@@ -0,0 +1,28 @@
|
||||
import logging
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from core.config.pipeline_config import PipelineConfig
|
||||
from core.utils.text_cleaner import TextCleaner
|
||||
from processing.steps import PipelineStep
|
||||
|
||||
|
||||
class DataCleaningStep(PipelineStep):
    """Configuration-driven data cleaning step"""

    def __init__(self, pipeline_config: PipelineConfig):
        super().__init__("data_cleaning", pipeline_config)
        self.text_cleaner = TextCleaner()
        # Rows missing any of these columns are unusable downstream.
        self.required_columns = ["name", "sex", "region"]

    def process_batch(self, batch: pd.DataFrame, batch_id: int) -> pd.DataFrame:
        """Clean one batch: drop incomplete rows, then normalize text columns."""
        logging.info(f"Cleaning batch {batch_id} with {len(batch)} rows")

        without_gaps = batch.dropna(subset=self.required_columns)
        return self.text_cleaner.clean_dataframe_text_columns(without_gaps)
|
||||
@@ -0,0 +1,60 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from core.config.pipeline_config import PipelineConfig
|
||||
from processing.steps.feature_extraction_step import Gender
|
||||
from core.utils.data_loader import DataLoader
|
||||
|
||||
from processing.batch.batch_config import BatchConfig
|
||||
from processing.steps import PipelineStep
|
||||
|
||||
|
||||
class DataSplittingStep(PipelineStep):
    """Configuration-driven data splitting step"""

    def __init__(self, pipeline_config: PipelineConfig):
        # Splitting is cheap; force a single sequential worker.
        batch_config = BatchConfig(
            batch_size=pipeline_config.processing.batch_size,
            max_workers=1,  # No need for parallelism in splitting
            checkpoint_interval=pipeline_config.processing.checkpoint_interval,
            use_multiprocessing=False,
        )
        super().__init__("data_splitting", pipeline_config, batch_config)
        self.data_loader = DataLoader(pipeline_config)
        self.eval_indices = None

    def determine_eval_indices(self, total_size: int) -> set:
        """Determine evaluation indices consistently across batches"""
        if self.eval_indices is None:
            # Seeded draw keeps the evaluation split reproducible run-to-run.
            np.random.seed(self.pipeline_config.data.random_seed)
            eval_size = int(total_size * self.pipeline_config.data.evaluation_fraction)
            self.eval_indices = set(np.random.choice(total_size, size=eval_size, replace=False))
        return self.eval_indices

    def process_batch(self, batch: pd.DataFrame, batch_id: int) -> pd.DataFrame:
        """Process batch for data splitting - no modification needed"""
        return batch.copy()

    def save_splits(self, df: pd.DataFrame) -> None:
        """Save the split datasets based on configuration"""
        filenames = self.pipeline_config.data.output_files
        out_dir = self.pipeline_config.paths.data_dir

        if self.pipeline_config.data.split_evaluation:
            held_out = df.index.isin(self.determine_eval_indices(len(df)))
            self.data_loader.save_csv(df[held_out], out_dir / filenames["evaluation"])
            self.data_loader.save_csv(df[~held_out], out_dir / filenames["featured"])
        else:
            self.data_loader.save_csv(df, out_dir / filenames["featured"])

        if self.pipeline_config.data.split_by_gender:
            males = df[df["sex"] == Gender.MALE.value]
            females = df[df["sex"] == Gender.FEMALE.value]
            self.data_loader.save_csv(males, out_dir / filenames["males"])
            self.data_loader.save_csv(females, out_dir / filenames["females"])
|
||||
@@ -0,0 +1,99 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from core.config.pipeline_config import PipelineConfig
|
||||
from core.utils.region_mapper import RegionMapper
|
||||
from processing.steps import PipelineStep
|
||||
|
||||
|
||||
# Normalized gender codes used throughout the pipeline; built via the
# Enum functional API (equivalent to the class-statement form).
Gender = Enum("Gender", [("MALE", "m"), ("FEMALE", "f")])
|
||||
|
||||
|
||||
# Name complexity buckets; built via the Enum functional API
# (equivalent to the class-statement form).
NameCategory = Enum("NameCategory", [("SIMPLE", "simple"), ("COMPOSE", "compose")])
|
||||
|
||||
|
||||
class FeatureExtractionStep(PipelineStep):
    """Configuration-driven feature extraction step.

    Derives word-count/length features from raw names, pre-splits each
    name into a probable native-name/surname pair, auto-annotates
    three-word names, maps regions to provinces, and normalizes gender.
    """

    def __init__(self, pipeline_config: PipelineConfig):
        super().__init__("feature_extraction", pipeline_config)
        self.region_mapper = RegionMapper()

    @classmethod
    def validate_gender(cls, gender: str) -> Gender:
        """Validate and normalize a gender value.

        Accepts common English/French spellings (case-insensitive).

        Raises:
            ValueError: if the value matches no known spelling — note this
                includes NaN cells rendered as the string "nan" upstream.
        """
        gender_lower = gender.lower().strip()
        if gender_lower in ["m", "male", "homme", "masculin"]:
            return Gender.MALE
        elif gender_lower in ["f", "female", "femme", "féminin"]:
            return Gender.FEMALE
        else:
            raise ValueError(f"Unknown gender: {gender}")

    @classmethod
    def get_name_category(cls, word_count: int) -> NameCategory:
        """Determine name category based on word count (<= 3 is simple)."""
        if word_count <= 3:
            return NameCategory.SIMPLE
        else:
            return NameCategory.COMPOSE

    def process_batch(self, batch: pd.DataFrame, batch_id: int) -> pd.DataFrame:
        """Extract features from names in batch.

        Returns a copy of the batch with feature and annotation columns
        added; the input frame is never mutated.
        """
        logging.info(f"Extracting features for batch {batch_id} with {len(batch)} rows")

        batch = batch.copy()

        # Basic features. Count words from the same str.split() used for
        # the name decomposition below so both agree even when names carry
        # repeated / leading / trailing whitespace (the previous
        # str.count(" ") + 1 over-counted in those cases).
        name_splits = batch["name"].str.split()
        batch["words"] = name_splits.str.len()
        batch["length"] = batch["name"].str.replace(" ", "", regex=False).str.len()

        # Coerce year to a nullable integer column when present
        if "year" in batch.columns:
            batch["year"] = pd.to_numeric(batch["year"], errors="coerce").astype("Int64")

        # Initialize annotation columns
        batch["probable_native"] = None
        batch["probable_surname"] = None
        batch["identified_name"] = None
        batch["identified_surname"] = None
        batch["annotated"] = 0

        # Category assignment from word count; .apply keeps the single
        # source of truth in get_name_category (this is row-wise, not a
        # true vectorized operation).
        batch["identified_category"] = batch["words"].apply(
            lambda x: self.get_name_category(x).value
        )

        # Probable native name = all but the last word; probable surname =
        # last word. Only assigned for names with at least two words.
        batch["probable_native"] = name_splits.apply(
            lambda x: " ".join(x[:-1]) if isinstance(x, list) and len(x) >= 2 else None
        )
        batch["probable_surname"] = name_splits.apply(
            lambda x: x[-1] if isinstance(x, list) and len(x) >= 2 else None
        )

        # Three-word names are considered unambiguous: promote the
        # probable split to identified values and mark them annotated.
        three_word_mask = batch["words"] == 3
        batch.loc[three_word_mask, "identified_name"] = batch.loc[
            three_word_mask, "probable_native"
        ]
        batch.loc[three_word_mask, "identified_surname"] = batch.loc[
            three_word_mask, "probable_surname"
        ]
        batch.loc[three_word_mask, "annotated"] = 1

        # Map regions to provinces
        batch["province"] = self.region_mapper.map_regions_vectorized(batch["region"])

        # Normalize gender; deliberately strict — an unknown value aborts
        # the batch with ValueError rather than passing through silently.
        if "sex" in batch.columns:
            batch["sex"] = batch["sex"].apply(lambda x: self.validate_gender(str(x)).value)

        return batch
|
||||
@@ -0,0 +1,168 @@
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Dict, Optional
|
||||
|
||||
import ollama
|
||||
import pandas as pd
|
||||
from pydantic import ValidationError, BaseModel
|
||||
|
||||
from core.config.pipeline_config import PipelineConfig
|
||||
from core.utils.prompt_manager import PromptManager
|
||||
from core.utils.rate_limiter import RateLimiter
|
||||
from core.utils.rate_limiter import RateLimitConfig
|
||||
from processing.batch.batch_config import BatchConfig
|
||||
from processing.steps import PipelineStep
|
||||
|
||||
|
||||
class NameAnnotation(BaseModel):
    """Structured schema for a single LLM name-annotation response.

    Both fields are nullable: the model may be unable to identify either
    part of the name.
    """

    identified_name: str | None
    identified_surname: str | None
|
||||
|
||||
|
||||
class LLMAnnotationStep(PipelineStep):
    """Configuration-driven LLM annotation step.

    Sends not-yet-annotated names to an Ollama model (optionally rate
    limited), validates the structured response against NameAnnotation,
    and writes the results back into the batch, retrying with exponential
    backoff on failure.
    """

    def __init__(self, pipeline_config: PipelineConfig):
        # Create custom batch config for LLM processing: cap worker count
        # at the model's concurrent-request limit.
        batch_config = BatchConfig(
            batch_size=pipeline_config.processing.batch_size,
            max_workers=min(
                pipeline_config.llm.max_concurrent_requests, pipeline_config.processing.max_workers
            ),
            checkpoint_interval=pipeline_config.processing.checkpoint_interval,
            use_multiprocessing=pipeline_config.processing.use_multiprocessing,
        )
        super().__init__("llm_annotation", pipeline_config, batch_config)

        self.prompt = PromptManager(pipeline_config).load_prompt()
        self.rate_limiter = (
            self._create_rate_limiter() if pipeline_config.llm.enable_rate_limiting else None
        )

        # Statistics
        self.successful_requests = 0
        self.failed_requests = 0
        self.total_retry_attempts = 0

        # Silence per-request logging from Ollama's underlying HTTP client
        logging.getLogger("httpx").setLevel(logging.WARNING)

    def _create_rate_limiter(self):
        """Create rate limiter based on configuration."""
        rate_config = RateLimitConfig(
            requests_per_minute=self.pipeline_config.llm.requests_per_minute,
            requests_per_second=self.pipeline_config.llm.requests_per_second,
        )
        return RateLimiter(rate_config)

    def analyze_name_with_retry(self, client: ollama.Client, name: str, row_id: int) -> Dict:
        """Analyze a name with retry logic and rate limiting.

        Args:
            client: shared Ollama client.
            name: raw name string to annotate.
            row_id: dataframe index of the row (kept for interface
                stability; not used in the request itself).

        Returns:
            Dict with the annotation fields plus bookkeeping keys
            (``annotated``, ``processing_time``, ``attempts``) — and
            ``failed: True`` when all retry attempts are exhausted.
        """
        for attempt in range(self.pipeline_config.llm.retry_attempts):
            try:
                # Apply rate limiting if enabled
                if self.rate_limiter:
                    self.rate_limiter.wait_if_needed()

                start_time = time.time()
                response = client.chat(
                    model=self.pipeline_config.llm.model_name,
                    messages=[
                        {"role": "system", "content": self.prompt},
                        {"role": "user", "content": name},
                    ],
                    format=NameAnnotation.model_json_schema(),
                )
                elapsed_time = time.time() - start_time

                # NOTE: this check runs *after* the response has arrived,
                # so it cannot abort a slow request — it only converts an
                # over-budget response into a retryable error.
                if elapsed_time > self.pipeline_config.llm.timeout_seconds:
                    raise TimeoutError(
                        f"Request took {elapsed_time:.2f}s, exceeding {self.pipeline_config.llm.timeout_seconds}s timeout"
                    )

                annotation = NameAnnotation.model_validate_json(response.message.content)
                result = {
                    **annotation.model_dump(),
                    "annotated": 1,
                    "processing_time": elapsed_time,
                    "attempts": attempt + 1,
                }

                self.successful_requests += 1
                if attempt > 0:
                    self.total_retry_attempts += attempt

                return result

            # `Exception` already covers ValidationError and TimeoutError,
            # so the former (ValidationError, TimeoutError, Exception)
            # tuple was redundant.
            except Exception as e:
                logging.warning(
                    f"Error analyzing '{name}' (attempt {attempt + 1}/{self.pipeline_config.llm.retry_attempts}): {e}"
                )

                # Exponential backoff with sub-second jitter, capped at 10s
                if attempt < self.pipeline_config.llm.retry_attempts - 1:
                    wait_time = (2**attempt) + (time.time() % 1)
                    time.sleep(min(wait_time, 10))

        # All attempts exhausted: record the failure and return a sentinel
        self.failed_requests += 1
        return {
            "identified_name": None,
            "identified_surname": None,
            "annotated": 0,
            "processing_time": 0,
            "attempts": self.pipeline_config.llm.retry_attempts,
            "failed": True,
        }

    @staticmethod
    def _apply_result(batch: pd.DataFrame, idx, result: Dict) -> None:
        """Write one annotation result into the batch row at *idx*,
        skipping the bookkeeping-only 'failed' flag."""
        for field, value in result.items():
            if field != "failed":
                batch.loc[idx, field] = value

    def process_batch(self, batch: pd.DataFrame, batch_id: int) -> pd.DataFrame:
        """Process batch with LLM annotation."""
        # Build the not-yet-annotated mask. The previous
        # `batch.get("annotated", 0) == 0` collapsed to the scalar True
        # when the column was missing, making `batch[mask]` raise
        # KeyError; treat a missing column as "nothing annotated yet".
        if "annotated" in batch.columns:
            unannotated_mask = batch["annotated"] == 0
        else:
            unannotated_mask = pd.Series(True, index=batch.index)
        unannotated_entries = batch[unannotated_mask]

        if unannotated_entries.empty:
            logging.info(f"Batch {batch_id}: No entries to annotate")
            return batch

        logging.info(f"Batch {batch_id}: Annotating {len(unannotated_entries)} entries")

        batch = batch.copy()
        client = ollama.Client()

        # Process with controlled concurrency
        max_workers = self.pipeline_config.llm.max_concurrent_requests

        if len(unannotated_entries) == 1 or max_workers == 1:
            # Sequential processing
            for idx, row in unannotated_entries.iterrows():
                result = self.analyze_name_with_retry(client, row["name"], idx)
                self._apply_result(batch, idx, result)
        else:
            # Concurrent processing
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                future_to_idx = {
                    executor.submit(self.analyze_name_with_retry, client, row["name"], idx): idx
                    for idx, row in unannotated_entries.iterrows()
                }

                for future in as_completed(future_to_idx):
                    idx = future_to_idx[future]
                    try:
                        self._apply_result(batch, idx, future.result())
                    except Exception as e:
                        logging.error(f"Failed to process row {idx}: {e}")
                        batch.loc[idx, "annotated"] = 0

        # Ensure proper data types: nullable Int8, failures coerced to 0
        batch["annotated"] = (
            pd.to_numeric(batch["annotated"], errors="coerce").fillna(0).astype("Int8")
        )

        return batch
|
||||
@@ -1,26 +0,0 @@
|
||||
import ollama
|
||||
from pydantic import BaseModel
|
||||
|
||||
from misc import load_prompt
|
||||
|
||||
|
||||
class NameAnalysis(BaseModel):
    """Structured schema for the model's name-analysis response.

    Both fields are nullable: the model may fail to identify either part.
    """

    identified_name: str | None
    identified_surname: str | None
|
||||
|
||||
|
||||
# Interactive one-off: read a name from stdin, ask the model to split it
# into name/surname, and print the structured result.
full_name = input("Enter name: ")

llm_client = ollama.Client()
reply = llm_client.chat(
    model="mistral:7b",
    messages=[
        {"role": "system", "content": load_prompt()},
        {"role": "user", "content": full_name},
    ],
    format=NameAnalysis.model_json_schema(),
)

parsed = NameAnalysis.model_validate_json(reply.message.content)
result = parsed.model_dump()

print(result)
|
||||
Reference in New Issue
Block a user