Skip to content

Commit

Permalink
Best hyperparameters set
Browse files Browse the repository at this point in the history
  • Loading branch information
Junhui Yang authored and Junhui Yang committed Apr 2, 2024
1 parent 3e57919 commit 31dddcf
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 105 deletions.
4 changes: 2 additions & 2 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ modeling:
stratify_by: "neighbourhood_group"
# Maximum number of features to consider for the TF-IDF applied to the title of the
# listing (the column called "name")
max_tfidf_features: 5
max_tfidf_features: 30
# NOTE: you can put here any parameter that is accepted by the constructor of
# RandomForestRegressor. This is a subsample, but more could be added:
random_forest:
Expand All @@ -33,6 +33,6 @@ modeling:
# Here -1 means all available cores
n_jobs: -1
criterion: squared_error
max_features: 0.5
max_features: 0.33
# Do not change the following
oob_score: true
109 changes: 78 additions & 31 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
from omegaconf import DictConfig

_steps = [
"download",
"basic_cleaning",
"data_check",
"data_split",
"train_random_forest",
'download',
'basic_cleaning',
'data_check',
'data_split',
'train_random_forest',
# NOTE: We do not include this in the steps so it is not run by mistake.
# You first need to promote a model export to "prod" before you can run this,
# You first need to promote a model export to 'prod' before you can run this,
# then you need to run this step explicitly
# "test_regression_model"
# 'test_regression_model'
]


Expand All @@ -25,47 +25,77 @@
def go(config: DictConfig):

# Setup the wandb experiment. All runs will be grouped under this name
os.environ["WANDB_PROJECT"] = config["main"]["project_name"]
os.environ["WANDB_RUN_GROUP"] = config["main"]["experiment_name"]
os.environ['WANDB_PROJECT'] = config['main']['project_name']
os.environ['WANDB_RUN_GROUP'] = config['main']['experiment_name']

# Steps to execute
steps_par = config['main']['steps']
active_steps = steps_par.split(",") if steps_par != "all" else _steps
active_steps = steps_par.split(',') if steps_par != 'all' else _steps

# Move to a temporary directory
with tempfile.TemporaryDirectory() as tmp_dir:

if "download" in active_steps:
if 'download' in active_steps:
# Download file and load in W&B
_ = mlflow.run(
f"{config['main']['components_repository']}/get_data",
"main",
'main',
version='main',
parameters={
"sample": config["etl"]["sample"],
"artifact_name": "sample.csv",
"artifact_type": "raw_data",
"artifact_description": "Raw file as downloaded"
'sample': config['etl']['sample'],
'artifact_name': 'sample.csv',
'artifact_type': 'raw_data',
'artifact_description': 'Raw file as downloaded'
},
)

if "basic_cleaning" in active_steps:
if 'basic_cleaning' in active_steps:
##################
# Implement here #
##################
pass

if "data_check" in active_steps:
_ = mlflow.run(
os.path.join(hydra.utils.get_original_cwd(), 'src', 'basic_cleaning'),
'main',
parameters={
'input_artifact': 'sample.csv:latest',
'output_artifact': 'clean_sample.csv',
'output_type': 'clean_sample',
'output_description': 'Data with outliers and null values removed',
'min_price': config['etl']['min_price'],
'max_price': config['etl']['max_price']
},
)

if 'data_check' in active_steps:
##################
# Implement here #
##################
pass

if "data_split" in active_steps:
_ = mlflow.run(
os.path.join(hydra.utils.get_original_cwd(), 'src', 'data_check'),
'main',
parameters={
'csv': 'clean_sample.csv:latest',
'ref': 'clean_sample.csv:reference',
'kl_threshold': config['data_check']['kl_threshold'],
'min_price': config['etl']['min_price'],
'max_price': config['etl']['max_price']
},
)

if 'data_split' in active_steps:
##################
# Implement here #
##################
pass
_ = mlflow.run(
os.path.join(hydra.utils.get_original_cwd(), 'components', 'train_val_test_split'),
'main',
parameters={
'input': 'clean_sample.csv:latest',
'test_size': config['modeling']['test_size'],
'random_seed': config['modeling']['random_seed'],
'stratify_by': config['modeling']['stratify_by']
},
)

if "train_random_forest" in active_steps:

Expand All @@ -80,17 +110,34 @@ def go(config: DictConfig):
##################
# Implement here #
##################

pass

if "test_regression_model" in active_steps:
_ = mlflow.run(
os.path.join(hydra.utils.get_original_cwd(), 'src', 'train_random_forest'),
'main',
parameters={
'trainval_artifact': 'trainval_data.csv:latest',
'val_size': config['modeling']['val_size'],
'random_seed': config['modeling']['random_seed'],
'stratify_by': config['modeling']['stratify_by'],
'rf_config': rf_config,
'max_tfidf_features': config['modeling']['max_tfidf_features'],
'output_artifact': 'random_forest_export'
},
)

if 'test_regression_model' in active_steps:

##################
# Implement here #
##################

pass
_ = mlflow.run(
os.path.join(hydra.utils.get_original_cwd(), 'components', 'test_regression_model'),
'main',
parameters={
'mlflow_model': 'random_forest_export:prod',
'test_dataset': 'test_data.csv:latest'
},
)


if __name__ == "__main__":
if __name__ == '__main__':
go()
4 changes: 2 additions & 2 deletions src/data_check/conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ channels:
dependencies:
- python=3.10.0
- pandas=2.1.3
- pytest=6.2.2
- scipy=1.5.2
- pytest=6.2.5
- scipy=1.7.3
- pip=23.3.1
- pip:
- mlflow==2.8.1
Expand Down
5 changes: 5 additions & 0 deletions src/data_check/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,8 @@ def test_similar_neigh_distrib(data: pd.DataFrame, ref_data: pd.DataFrame, kl_th
########################################################
# Implement here test_row_count and test_price_range #
########################################################
def test_row_count(data):
assert 15000 < data.shape[0] < 1000000

def test_price_range(data, min_price, max_price):
assert data['price'].between(min_price, max_price).all()
Loading

0 comments on commit 31dddcf

Please sign in to comment.