Skip to content

Commit

Permalink
fix: use cpu device (#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
jgrss authored Jul 31, 2024
1 parent 9853319 commit c685fa5
Showing 1 changed file with 50 additions and 46 deletions.
96 changes: 50 additions & 46 deletions tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,51 +50,51 @@ def create_data(group: int) -> Data:
return batch_data


# def test_train():
# num_data = 10
# with tempfile.TemporaryDirectory() as tmp_path:
# ppaths = setup_paths(tmp_path)
# for i in range(num_data):
# data_path = (
# ppaths.process_path / f'data_{i:06d}_2021_{i:06d}_none.pt'
# )
# batch_data = create_data(i)
# batch_data.to_file(data_path)

# dataset = EdgeDataset(
# ppaths.train_path,
# processes=0,
# threads_per_worker=1,
# random_seed=100,
# )

# cultionet_params = CultionetParams(
# ckpt_file=ppaths.ckpt_file,
# model_name="cultionet",
# dataset=dataset,
# val_frac=0.2,
# batch_size=2,
# load_batch_workers=0,
# hidden_channels=16,
# num_classes=2,
# edge_class=2,
# model_type=ModelTypes.TOWERUNET,
# res_block_type=ResBlockTypes.RESA,
# attention_weights=AttentionTypes.SPATIAL_CHANNEL,
# activation_type="SiLU",
# dilations=[1, 2],
# dropout=0.2,
# deep_supervision=True,
# pool_attention=False,
# pool_by_max=True,
# repeat_resa_kernel=False,
# batchnorm_first=True,
# epochs=1,
# device="cpu",
# devices=1,
# precision="16-mixed",
# )
# cultionet.fit(cultionet_params)
def test_train():
num_data = 10
with tempfile.TemporaryDirectory() as tmp_path:
ppaths = setup_paths(tmp_path)
for i in range(num_data):
data_path = (
ppaths.process_path / f'data_{i:06d}_2021_{i:06d}_none.pt'
)
batch_data = create_data(i)
batch_data.to_file(data_path)

dataset = EdgeDataset(
ppaths.train_path,
processes=0,
threads_per_worker=1,
random_seed=100,
)

cultionet_params = CultionetParams(
ckpt_file=ppaths.ckpt_file,
model_name="cultionet",
dataset=dataset,
val_frac=0.2,
batch_size=2,
load_batch_workers=0,
hidden_channels=16,
num_classes=2,
edge_class=2,
model_type=ModelTypes.TOWERUNET,
res_block_type=ResBlockTypes.RESA,
attention_weights=AttentionTypes.SPATIAL_CHANNEL,
activation_type="SiLU",
dilations=[1, 2],
dropout=0.2,
deep_supervision=True,
pool_attention=False,
pool_by_max=True,
repeat_resa_kernel=False,
batchnorm_first=True,
epochs=1,
device="cpu",
devices=1,
precision="16-mixed",
)
cultionet.fit(cultionet_params)


def test_train_cli():
Expand All @@ -112,7 +112,11 @@ def test_train_cli():
with open(tmp_path / "data/classes.info", "w") as f:
json.dump({"max_crop_class": 1, "edge_class": 2}, f)

command = f"cultionet train -p {str(tmp_path.absolute())} --val-frac 0.2 --augment-prob 0.5 --epochs 2 --hidden-channels 16 --processes 1 --load-batch-workers 0 --batch-size 2 --dropout 0.2 --deep-sup --dilations 1 2 --pool-by-max --learning-rate 0.01 --weight-decay 1e-4 --attention-weights natten"
command = f"cultionet train -p {str(tmp_path.absolute())} "
"--val-frac 0.2 --augment-prob 0.5 --epochs 1 --hidden-channels 16 "
"--processes 1 --load-batch-workers 0 --batch-size 2 --dropout 0.2 "
"--deep-sup --dilations 1 2 --pool-by-max --learning-rate 0.01 "
"--weight-decay 1e-4 --attention-weights natten --device cpu"

try:
subprocess.run(
Expand Down

0 comments on commit c685fa5

Please sign in to comment.