fix: artifacts saving and dataset loading
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, Optional
|
||||
|
||||
@@ -16,10 +15,11 @@ from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val
|
||||
from sklearn.pipeline import make_pipeline, Pipeline
|
||||
from sklearn.preprocessing import LabelEncoder
|
||||
|
||||
from misc import GENDER_MODELS_DIR, load_csv_dataset
|
||||
from misc import GENDER_MODELS_DIR, load_csv_dataset, save_pickle
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format=">> %(message)s")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
dataset_path: str
|
||||
@@ -169,15 +169,10 @@ def save_artifacts(model, encoder, cfg: Config):
|
||||
:type cfg: Config
|
||||
:return: None
|
||||
"""
|
||||
model_path = os.path.join(GENDER_MODELS_DIR, "regression_model.pkl")
|
||||
encoder_path = os.path.join(GENDER_MODELS_DIR, "regression_label_encoder.pkl")
|
||||
save_pickle(model, os.path.join(GENDER_MODELS_DIR, "regression_model.pkl"))
|
||||
save_pickle(encoder, os.path.join(GENDER_MODELS_DIR, "regression_label_encoder.pkl"))
|
||||
|
||||
with open(model_path, "wb") as f:
|
||||
pickle.dump(model, f)
|
||||
with open(encoder_path, "wb") as f:
|
||||
pickle.dump(encoder, f)
|
||||
logging.info(f"Saved model to: {model_path}")
|
||||
logging.info(f"Saved label encoder to: {encoder_path}")
|
||||
logging.info(f"Model and artifacts saved to {GENDER_MODELS_DIR}")
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
Reference in New Issue
Block a user