Skip to content

Commit

Permalink
test out jupyter notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
bw4sz committed Oct 28, 2024
1 parent 12ce8bb commit dac708e
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 10 deletions.
15 changes: 9 additions & 6 deletions DeepForest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@ def parse_args():
return known_args, unknown_args

def convert_unknown_args_to_dict(unknown_args):
def set_nested_dict(d, keys, value):
for key in keys[:-1]:
d = d.setdefault(key, {})
d[keys[-1]] = value

kwargs = {}
key = None
for arg in unknown_args:
if arg.startswith('--'):
key = arg.lstrip('--')
kwargs[key] = None
key = arg.lstrip('--').split('.')
set_nested_dict(kwargs, key, None)
else:
if kwargs[key] is None:
kwargs[key] = arg
else:
kwargs[key] += f" {arg}"
set_nested_dict(kwargs, key, arg)
return kwargs


Expand Down
2 changes: 1 addition & 1 deletion deepforest_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ train:
eps: 1e-08

# Print loss every n epochs
epochs: 1
epochs: 10
# Useful debugging flag in pytorch lightning, set to True to get a single batch of training to test settings.
fast_dev_run: False
# pin images to GPU memory for fast training. This depends on GPU size and number of images.
Expand Down
29 changes: 29 additions & 0 deletions docs/examples/Datasets.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,35 @@
"m.config[\"train\"][\"csv_file\"] =\"<dummy file, existing dataloader>\"\n",
"m.trainer.fit(m, train_loader) "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluation loader"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from milliontrees.common.data_loaders import get_eval_loader\n",
"\n",
"test_dataset = dataset.get_subset(\"train\")\n",
"test_loader = get_eval_loader(\"standard\", test_dataset, batch_size=16)\n",
"\n",
"# Show one batch of the loader\n",
"for metadata, image, targets in test_loader:\n",
" print(\"Targets is a list of dictionaries with the following keys: \", targets[0].keys())\n",
" print(f\"Image shape: {image.shape}, Image type: {type(image)}\")\n",
" print(f\"Annotation shape of the first image: {targets[0]['boxes'].shape}\")\n",
" break # Just show the first batch\n",
"\n",
"# Evaluate\n",
"dataset.eval(all_y_pred, all_y_true, all_metadata)"
]
}
],
"metadata": {
Expand Down
30 changes: 27 additions & 3 deletions tests/test_TreeBoxes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from milliontrees.datasets.TreeBoxes import TreeBoxesDataset
from milliontrees.common.data_loaders import get_train_loader
from milliontrees.common.data_loaders import get_train_loader, get_eval_loader

import torch
import pytest
Expand Down Expand Up @@ -54,6 +54,31 @@ def test_get_train_dataloader(dataset, batch_size):
assert len(metadata) == batch_size
break

def test_get_test_dataloader(dataset, batch_size):
dataset = TreeBoxesDataset(download=False, root_dir=dataset)
test_dataset = dataset.get_subset("test")

for metadata, image, targets in test_dataset:
boxes, labels = targets["boxes"], targets["labels"]
assert image.shape == (100, 100, 3)
assert image.dtype == np.float32
assert image.min() >= 0.0 and image.max() <= 1.0
assert boxes.shape == (2, 4)
assert labels.shape == (2,)
assert metadata.shape == (2,2)
break

test_loader = get_eval_loader('standard', test_dataset, batch_size=batch_size)
for metadata, x, targets in test_loader:
y = targets[0]["boxes"]
assert torch.is_tensor(targets[0]["boxes"])
assert x.shape == (batch_size, 3, 448, 448)
assert x.dtype == torch.float32
assert x.min() >= 0.0 and x.max() <= 1.0
assert y.shape[1] == 4
assert len(metadata) == batch_size
break

# Test structure with real annotation data to ensure format is correct
# Do not run on github actions, long running.
@pytest.mark.skipif(not on_hipergator, reason="Do not run on github actions")
Expand Down Expand Up @@ -93,5 +118,4 @@ def test_TreeBoxes_download(tmpdir):
assert image.min() >= 0.0 and image.max() <= 1.0
assert boxes.shape[1] == 4
assert metadata.shape[1] == 1
break

break

0 comments on commit dac708e

Please sign in to comment.