From dac708e5f36f2c400d4039129ad8652ad30b79f8 Mon Sep 17 00:00:00 2001
From: bw4sz
Date: Mon, 28 Oct 2024 14:25:07 -0400
Subject: [PATCH] test out jupyter notebook

---
 DeepForest.py                | 37 +++++++++++++++++++++++++++++++------
 deepforest_config.yml        |  2 +-
 docs/examples/Datasets.ipynb | 29 +++++++++++++++++++++++++++++
 tests/test_TreeBoxes.py      | 30 +++++++++++++++++++++++++++---
 4 files changed, 88 insertions(+), 10 deletions(-)

diff --git a/DeepForest.py b/DeepForest.py
index fda171c..515caa3 100644
--- a/DeepForest.py
+++ b/DeepForest.py
@@ -16,16 +16,41 @@ def parse_args():
     return known_args, unknown_args
 
 def convert_unknown_args_to_dict(unknown_args):
+    """Convert leftover ``--key value`` CLI tokens into a nested dict.
+
+    Dotted option names address nested keys (``--train.lr 0.1`` gives
+    ``{"train": {"lr": "0.1"}}``), an option without a value maps to
+    ``None``, and several value tokens after one option are joined with
+    single spaces.
+
+    Raises:
+        ValueError: if a value token appears before any ``--option``.
+    """
+    def set_nested_dict(d, keys, value):
+        # Walk/create the intermediate dicts, then assign the leaf.
+        for key in keys[:-1]:
+            d = d.setdefault(key, {})
+        d[keys[-1]] = value
+
     kwargs = {}
+    key = None
+    values = []
     for arg in unknown_args:
         if arg.startswith('--'):
-            key = arg.lstrip('--')
-            kwargs[key] = None
+            # Slice the prefix off: lstrip('--') strips a character set,
+            # not a prefix, so it would also eat a third leading dash.
+            key = arg[2:].split('.')
+            values = []
+            set_nested_dict(kwargs, key, None)
+        elif key is None:
+            raise ValueError(
+                f"Got value {arg!r} before any --option; cannot assign it to a key"
+            )
         else:
-            if kwargs[key] is None:
-                kwargs[key] = arg
-            else:
-                kwargs[key] += f" {arg}"
+            # Keep multi-token values ("--a 1 2" -> "1 2") instead of
+            # overwriting with the last token.
+            values.append(arg)
+            set_nested_dict(kwargs, key, " ".join(values))
 
     return kwargs
 
diff --git a/deepforest_config.yml b/deepforest_config.yml
index 3db43fa..ab93f40 100644
--- a/deepforest_config.yml
+++ b/deepforest_config.yml
@@ -45,7 +45,7 @@ train:
             eps: 1e-08
 
     # Print loss every n epochs
-    epochs: 1
+    epochs: 10
     # Useful debugging flag in pytorch lightning, set to True to get a single batch of training to test settings.
     fast_dev_run: False
     # pin images to GPU memory for fast training. This depends on GPU size and number of images.
diff --git a/docs/examples/Datasets.ipynb b/docs/examples/Datasets.ipynb
index 6748640..12e9318 100644
--- a/docs/examples/Datasets.ipynb
+++ b/docs/examples/Datasets.ipynb
@@ -251,6 +251,35 @@
     "m.config[\"train\"][\"csv_file\"] =\"\"\n",
     "m.trainer.fit(m, train_loader) "
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Evaluation loader"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from milliontrees.common.data_loaders import get_eval_loader\n",
+    "\n",
+    "test_dataset = dataset.get_subset(\"test\")\n",
+    "test_loader = get_eval_loader(\"standard\", test_dataset, batch_size=16)\n",
+    "\n",
+    "# Show one batch of the loader\n",
+    "for metadata, image, targets in test_loader:\n",
+    "    print(\"Targets is a list of dictionaries with the following keys: \", targets[0].keys())\n",
+    "    print(f\"Image shape: {image.shape}, Image type: {type(image)}\")\n",
+    "    print(f\"Annotation shape of the first image: {targets[0]['boxes'].shape}\")\n",
+    "    break  # Just show the first batch\n",
+    "\n",
+    "# Evaluate\n",
+    "dataset.eval(all_y_pred, all_y_true, all_metadata)"
+   ]
   }
  ],
  "metadata": {
diff --git a/tests/test_TreeBoxes.py b/tests/test_TreeBoxes.py
index 2974c65..942cb18 100644
--- a/tests/test_TreeBoxes.py
+++ b/tests/test_TreeBoxes.py
@@ -1,5 +1,5 @@
 from milliontrees.datasets.TreeBoxes import TreeBoxesDataset
-from milliontrees.common.data_loaders import get_train_loader
+from milliontrees.common.data_loaders import get_train_loader, get_eval_loader
 import torch
 import pytest
 
@@ -54,6 +54,31 @@ def test_get_train_dataloader(dataset, batch_size):
         assert len(metadata) == batch_size
         break
 
+def test_get_test_dataloader(dataset, batch_size):
+    dataset = TreeBoxesDataset(download=False, root_dir=dataset)
+    test_dataset = dataset.get_subset("test")
+
+    for metadata, image, targets in test_dataset:
+        boxes, labels = targets["boxes"], targets["labels"]
+        assert image.shape == (100, 100, 3)
+        assert image.dtype == np.float32
+        assert image.min() >= 0.0 and image.max() <= 1.0
+        assert boxes.shape == (2, 4)
+        assert labels.shape == (2,)
+        assert metadata.shape == (2, 2)
+        break
+
+    test_loader = get_eval_loader('standard', test_dataset, batch_size=batch_size)
+    for metadata, x, targets in test_loader:
+        y = targets[0]["boxes"]
+        assert torch.is_tensor(y)
+        assert x.shape == (batch_size, 3, 448, 448)
+        assert x.dtype == torch.float32
+        assert x.min() >= 0.0 and x.max() <= 1.0
+        assert y.shape[1] == 4
+        assert len(metadata) == batch_size
+        break
+
 # Test structure with real annotation data to ensure format is correct
 # Do not run on github actions, long running.
 @pytest.mark.skipif(not on_hipergator, reason="Do not run on github actions")
@@ -93,5 +118,4 @@ def test_TreeBoxes_download(tmpdir):
         assert image.min() >= 0.0 and image.max() <= 1.0
         assert boxes.shape[1] == 4
         assert metadata.shape[1] == 1
-        break
-    
\ No newline at end of file
+        break