From 44eb695b730d6abcea38e661053cd06e872b78c7 Mon Sep 17 00:00:00 2001 From: Jordan Graesser Date: Wed, 31 Jul 2024 18:50:28 +1000 Subject: [PATCH] use cpu device --- tests/test_train.py | 96 +++++++++++++++++++++++---------------------- 1 file changed, 50 insertions(+), 46 deletions(-) diff --git a/tests/test_train.py b/tests/test_train.py index 558716fa..dce02235 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -50,51 +50,51 @@ def create_data(group: int) -> Data: return batch_data -# def test_train(): -# num_data = 10 -# with tempfile.TemporaryDirectory() as tmp_path: -# ppaths = setup_paths(tmp_path) -# for i in range(num_data): -# data_path = ( -# ppaths.process_path / f'data_{i:06d}_2021_{i:06d}_none.pt' -# ) -# batch_data = create_data(i) -# batch_data.to_file(data_path) - -# dataset = EdgeDataset( -# ppaths.train_path, -# processes=0, -# threads_per_worker=1, -# random_seed=100, -# ) - -# cultionet_params = CultionetParams( -# ckpt_file=ppaths.ckpt_file, -# model_name="cultionet", -# dataset=dataset, -# val_frac=0.2, -# batch_size=2, -# load_batch_workers=0, -# hidden_channels=16, -# num_classes=2, -# edge_class=2, -# model_type=ModelTypes.TOWERUNET, -# res_block_type=ResBlockTypes.RESA, -# attention_weights=AttentionTypes.SPATIAL_CHANNEL, -# activation_type="SiLU", -# dilations=[1, 2], -# dropout=0.2, -# deep_supervision=True, -# pool_attention=False, -# pool_by_max=True, -# repeat_resa_kernel=False, -# batchnorm_first=True, -# epochs=1, -# device="cpu", -# devices=1, -# precision="16-mixed", -# ) -# cultionet.fit(cultionet_params) +def test_train(): + num_data = 10 + with tempfile.TemporaryDirectory() as tmp_path: + ppaths = setup_paths(tmp_path) + for i in range(num_data): + data_path = ( + ppaths.process_path / f'data_{i:06d}_2021_{i:06d}_none.pt' + ) + batch_data = create_data(i) + batch_data.to_file(data_path) + + dataset = EdgeDataset( + ppaths.train_path, + processes=0, + threads_per_worker=1, + random_seed=100, + ) + + cultionet_params = CultionetParams( + ckpt_file=ppaths.ckpt_file, + model_name="cultionet", + dataset=dataset, + val_frac=0.2, + batch_size=2, + load_batch_workers=0, + hidden_channels=16, + num_classes=2, + edge_class=2, + model_type=ModelTypes.TOWERUNET, + res_block_type=ResBlockTypes.RESA, + attention_weights=AttentionTypes.SPATIAL_CHANNEL, + activation_type="SiLU", + dilations=[1, 2], + dropout=0.2, + deep_supervision=True, + pool_attention=False, + pool_by_max=True, + repeat_resa_kernel=False, + batchnorm_first=True, + epochs=1, + device="cpu", + devices=1, + precision="16-mixed", + ) + cultionet.fit(cultionet_params) def test_train_cli(): @@ -112,7 +112,11 @@ def test_train_cli(): with open(tmp_path / "data/classes.info", "w") as f: json.dump({"max_crop_class": 1, "edge_class": 2}, f) - command = f"cultionet train -p {str(tmp_path.absolute())} --val-frac 0.2 --augment-prob 0.5 --epochs 2 --hidden-channels 16 --processes 1 --load-batch-workers 0 --batch-size 2 --dropout 0.2 --deep-sup --dilations 1 2 --pool-by-max --learning-rate 0.01 --weight-decay 1e-4 --attention-weights natten" + command = f"cultionet train -p {str(tmp_path.absolute())} " + "--val-frac 0.2 --augment-prob 0.5 --epochs 1 --hidden-channels 16 " + "--processes 1 --load-batch-workers 0 --batch-size 2 --dropout 0.2 " + "--deep-sup --dilations 1 2 --pool-by-max --learning-rate 0.01 " + "--weight-decay 1e-4 --attention-weights natten --device cpu" try: subprocess.run(