Skip to content

Commit

Permalink
Save and load hashkey for explorer (#981)
Browse files Browse the repository at this point in the history
<!-- Contributing guide:
https://github.com/openvinotoolkit/datumaro/blob/develop/CONTRIBUTING.md
-->

### Summary

<!--
Resolves #111 and #222.
Depends on #1000 (for series of dependent commits).

This PR introduces this capability to make the project better in this
and that.

- Added this feature
- Removed that feature
- Fixed the problem #1234
-->
- Ticket no. 107264
- Save and load `HashKey` for dataset after `explore` command
- Get the list of datasets in the explorer
- Export `HashKey` annotation in `datumaro` format
- Use cases for the explorer
  - explore command both with and without a target
  - explore -> add -> explore
  - explore -> merge -> explore
  - Support versioning

### How to test
<!-- Describe the testing procedure for reviewers, if changes are
not fully covered by unit tests or manual testing can be complicated.
-->

### Checklist
<!-- Put an 'x' in all the boxes that apply -->
- [X] I have added unit tests to cover my changes.
- [X] I have added integration tests to cover my changes.
- [X] I have added the description of my changes into
[CHANGELOG](https://github.com/openvinotoolkit/datumaro/blob/develop/CHANGELOG.md).
- [X] I have updated the
[documentation](https://github.com/openvinotoolkit/datumaro/tree/develop/docs)
accordingly.

### License

- [ ] I submit _my code changes_ under the same [MIT
License](https://github.com/openvinotoolkit/datumaro/blob/develop/LICENSE)
that covers the project.
  Feel free to contact the maintainers if that's a concern.
- [ ] I have updated the license header for each file (see an example
below).

```python
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT
```
  • Loading branch information
sooahleex authored May 11, 2023
1 parent 38bbf0c commit 9ab0954
Show file tree
Hide file tree
Showing 13 changed files with 640 additions and 105 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/openvinotoolkit/datumaro/pull/973>)
- Enrich stack trace for better user experience when importing
(<https://github.com/openvinotoolkit/datumaro/pull/992>)
- Save and load hashkey for explorer
(<https://github.com/openvinotoolkit/datumaro/pull/981>)

### Bug fixes
- Fix Mapillary Vistas data format
Expand Down
91 changes: 64 additions & 27 deletions datumaro/cli/commands/explore.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@

import argparse
import logging as log
import os
import os.path as osp

import numpy as np
import shutil

from datumaro.components.errors import ProjectNotFoundError
from datumaro.components.explorer import Explorer
from datumaro.components.visualizer import Visualizer
from datumaro.util.image import save_image
from datumaro.util import str_to_bool
from datumaro.util.scope import scope_add, scoped

from ..util import MultilineFormatter
Expand Down Expand Up @@ -42,10 +41,7 @@ def build_parser(parser_ctor=argparse.ArgumentParser):
formatter_class=MultilineFormatter,
)

parser.add_argument(
"_positionals", nargs=argparse.REMAINDER, help=argparse.SUPPRESS
) # workaround for -- eaten by positionals
parser.add_argument("target", nargs="+", default="project", help="Target dataset")
parser.add_argument("target", nargs="?", help="Target dataset")
parser.add_argument(
"-q",
"--query",
Expand All @@ -61,9 +57,25 @@ def build_parser(parser_ctor=argparse.ArgumentParser):
help="Directory of the project to operate on (default: current dir)",
)
parser.add_argument(
"-s", "--save", dest="save", default=True, help="Save explorer result as png"
"-s",
"--save",
action="store_true",
default=False,
help="Save explorer result files on explore_result folder",
)
parser.add_argument(
"--stage",
type=str_to_bool,
default=True,
help="""
Include this action as a project build step.
If true, this operation will be saved in the project
build tree, allowing to reproduce the resulting dataset later.
Applicable only to main project targets (i.e. data sources
and the 'project' target, but not intermediate stages)
(default: %(default)s)
""",
)

parser.set_defaults(command=explore_command)

return parser
Expand All @@ -75,7 +87,7 @@ def get_sensitive_args():
"target",
"query",
"topk",
"save",
"project_dir",
]
}

Expand All @@ -89,36 +101,61 @@ def explore_command(args):
if args.project_dir:
raise

dataset, _ = parse_full_revpath(args.target[0], project)
if args.target:
targets = [args.target]
else:
targets = list(project.working_tree.sources)

source_datasets = []
for target in targets:
target_dataset, _ = parse_full_revpath(target, project)
source_datasets.append(target_dataset)

explorer_args = {"save_hashkey": True}
build_tree = project.working_tree.clone()
for target in targets:
build_tree.build_targets.add_explore_stage(target, params=explorer_args)

explorer = Explorer(dataset)
explorer = Explorer(*source_datasets)
for dataset in source_datasets:
dst_dir = dataset.data_path
dataset.save(dst_dir, save_media=True)

if args.stage:
project.working_tree.config.update(build_tree.config)
project.working_tree.save()

# Get query datasetitem through query path
if osp.exists(args.query):
query_datasetitem = dataset.get_datasetitem_by_path(args.query)
query_datasetitem = None
for dataset in source_datasets:
try:
query_datasetitem = dataset.get_datasetitem_by_path(args.query)
except Exception:
continue
if not query_datasetitem:
break
else:
query_datasetitem = args.query

results = explorer.explore_topk(query_datasetitem, args.topk)

subset_list = []
id_list = []
result_path_list = []
log.info("Most similar {} results of query in dataset".format(args.topk))
log.info(f"Most similar {args.topk} results of query in dataset")
for result in results:
subset_list.append(result.subset)
id_list.append(result.id)
path = getattr(result.media, "path", None)
result_path_list.append(path)
log.info("id: {} | subset: {} | path : {}".format(result.id, result.subset, path))

visualizer = Visualizer(dataset, figsize=(20, 20), alpha=0)
fig = visualizer.vis_gallery(id_list, subset_list)
log.info(f"id: {result.id} | subset: {result.subset} | path : {path}")

if args.save:
fig.canvas.draw()
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
save_image(osp.join("./explorer.png"), data, create_dir=True)
saved_result_path = osp.join(args.project_dir, "explore_result")
if osp.exists(saved_result_path):
shutil.rmtree(saved_result_path)
os.makedirs(saved_result_path)
for result in results:
saved_subset_path = osp.join(saved_result_path, result.subset)
if not osp.exists(saved_subset_path):
os.makedirs(saved_subset_path)
shutil.copyfile(path, osp.join(saved_subset_path, result.id + ".jpg"))

return 0
70 changes: 50 additions & 20 deletions datumaro/components/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
#
# SPDX-License-Identifier: MIT

from typing import List, Optional, Union
from typing import List, Optional, Sequence, Union

import numpy as np

from datumaro.components.annotation import HashKey
from datumaro.components.dataset import IDataset
from datumaro.components.dataset import Dataset
from datumaro.components.dataset_base import DatasetItem
from datumaro.components.errors import MediaTypeError
from datumaro.components.media import Image
from datumaro.components.media import Image, MediaElement
from datumaro.plugins.explorer import ExplorerLauncher


Expand All @@ -25,10 +25,18 @@ def calculate_hamming(B1, B2):
return distH


def select_uninferenced_dataset(dataset):
    """Return a new dataset containing only the items that lack a HashKey annotation.

    Items already carrying a HashKey are skipped, so only the remaining
    ("un-inferenced") items need to be pushed through the hashing model.
    """
    pending = Dataset(media_type=MediaElement)
    for item in dataset:
        already_hashed = any(isinstance(ann, HashKey) for ann in item.annotations)
        if not already_hashed:
            pending.put(item)
    return pending


class Explorer:
def __init__(
self,
dataset: IDataset,
*datasets: Sequence[Dataset],
topk: int = 10,
) -> None:
"""
Expand All @@ -41,28 +49,50 @@ def __init__(
topk:
Number of images.
"""
self._model = ExplorerLauncher(model_name="clip_visual_ViT-B_32")
self._text_model = ExplorerLauncher(model_name="clip_text_ViT-B_32")
inference = dataset.run_model(self._model, append_annotation=True)
self._model = None
self._text_model = None
self._topk = topk

database_keys = []
item_list = []

for item in inference:
for annotation in item.annotations:
if isinstance(annotation, HashKey):
try:
hash_key = annotation.hash_key[0]
hash_key = np.unpackbits(hash_key, axis=-1)
database_keys.append(hash_key)
item_list.append(item)
except Exception:
hash_key = None
datasets_to_infer = [select_uninferenced_dataset(dataset) for dataset in datasets]
datasets = self.compute_hash_key(datasets, datasets_to_infer)

for dataset in datasets:
for item in dataset:
for annotation in item.annotations:
if isinstance(annotation, HashKey):
try:
hash_key = annotation.hash_key[0]
hash_key = np.unpackbits(hash_key, axis=-1)
database_keys.append(hash_key)
item_list.append(item)
except Exception:
continue

self._database_keys = database_keys
self._item_list = item_list

@property
def model(self):
    """Visual CLIP launcher used to hash image items, created lazily on first access."""
    launcher = self._model
    if launcher is None:
        launcher = ExplorerLauncher(model_name="clip_visual_ViT-B_32")
        self._model = launcher
    return launcher

@property
def text_model(self):
    """Text CLIP launcher used to hash string queries, created lazily on first access."""
    launcher = self._text_model
    if launcher is None:
        launcher = ExplorerLauncher(model_name="clip_text_ViT-B_32")
        self._text_model = launcher
    return launcher

def compute_hash_key(self, datasets, datasets_to_infer):
    """Run the hashing model over the not-yet-inferenced datasets, then merge
    the newly annotated items back into the corresponding full datasets.

    Returns the *datasets* sequence (updated in place).
    """
    for pending in datasets_to_infer:
        # Skip empty datasets: nothing to infer, and avoids a useless model run.
        if len(pending) > 0:
            pending.run_model(self.model, append_annotation=True)
    for full, pending in zip(datasets, datasets_to_infer):
        full.update(pending)
    return datasets

def explore_topk(
self,
query: Union[DatasetItem, str, List[DatasetItem], List[str]],
Expand Down Expand Up @@ -91,7 +121,7 @@ def explore_topk(
break
query_hash_key_list.append(q_hash_key)
elif isinstance(q, str):
q_hash_key = self._text_model.launch(q)[0][0].hash_key
q_hash_key = self.text_model.launch(q)[0][0].hash_key
query_hash_key_list.append(q_hash_key)

sims = np.zeros(shape=database_keys.shape[0] * len(query_hash_key_list))
Expand Down Expand Up @@ -131,7 +161,7 @@ def cal_ind(x):
pass

elif isinstance(query, str):
query_key = self._text_model.launch(query)[0][0].hash_key
query_key = self.text_model.launch(query)[0][0].hash_key
else:
raise MediaTypeError(
"Unexpected media type of query '%s'. "
Expand Down
9 changes: 8 additions & 1 deletion datumaro/components/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
)
from datumaro.components.media import Image, PointCloud
from datumaro.components.progress_reporting import NullProgressReporter, ProgressReporter
from datumaro.util.meta_file_util import save_meta_file
from datumaro.util.meta_file_util import save_hashkey_file, save_meta_file
from datumaro.util.os_util import rmtree
from datumaro.util.scope import on_error_do, scoped

Expand Down Expand Up @@ -174,6 +174,7 @@ def __init__(
image_ext: Optional[str] = None,
default_image_ext: Optional[str] = None,
save_dataset_meta: bool = False,
save_hashkey_meta: bool = False,
ctx: Optional[ExportContext] = None,
):
default_image_ext = default_image_ext or self.DEFAULT_IMAGE_EXT
Expand Down Expand Up @@ -202,6 +203,7 @@ def __init__(
self._save_dir = save_dir

self._save_dataset_meta = save_dataset_meta
self._save_hashkey_meta = save_hashkey_meta

# TODO: refactor this variable.
# Can be used by a subclass to store the current patch info
Expand Down Expand Up @@ -278,9 +280,14 @@ def _save_point_cloud(self, item=None, path=None, *, name=None, subdir=None, bas
def _save_meta_file(self, path):
    """Write the dataset's category metadata under *path* via save_meta_file."""
    save_meta_file(path, self._extractor.categories())

def _save_hashkey_file(self, path):
    """Persist hash keys for the exported dataset's items under *path*
    (delegates to save_hashkey_file)."""
    save_hashkey_file(path, self._extractor)


# TODO: Currently, ExportContextComponent is introduced only for Datumaro and DatumaroBinary format
# for multi-processing. We need to propagate this to everywhere in Datumaro 1.2.0


class ExportContextComponent:
def __init__(
self,
Expand Down
13 changes: 13 additions & 0 deletions datumaro/components/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ class BuildStageType(Enum):
filter = auto()
convert = auto()
inference = auto()
explore = auto()


class Pipeline:
Expand Down Expand Up @@ -915,6 +916,18 @@ def add_convert_stage(
name=name,
)

def add_explore_stage(
    self, target: str, params: Optional[Dict] = None, name: Optional[str] = None
):
    """Register an 'explore' build stage for *target* in the build tree.

    params: optional stage parameters; an empty dict is used when omitted.
    name: optional explicit stage name passed through to add_stage.
    """
    stage_config = {
        "type": BuildStageType.explore.name,
        "params": params or {},
    }
    return self.add_stage(target, stage_config, name=name)

@staticmethod
def make_target_name(target: str, stage: Optional[str] = None) -> str:
if stage:
Expand Down
Loading

0 comments on commit 9ab0954

Please sign in to comment.