Skip to content

Commit

Permalink
Merge pull request #5256 from voxel51/on-disk-instances-updates
Browse files Browse the repository at this point in the history
Support on-disk instance segmentations in SDK
  • Loading branch information
brimoor authored Dec 13, 2024
2 parents 73da3ae + 8331661 commit 64cf79b
Show file tree
Hide file tree
Showing 7 changed files with 335 additions and 106 deletions.
8 changes: 5 additions & 3 deletions docs/source/user_guide/using_datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2542,7 +2542,7 @@ Object detections stored in |Detections| may also have instance segmentation
masks.

These masks can be stored in one of two ways: either directly in the database
via the :attr:`mask<fiftyone.core.labels.Detection.mask>` attribute, or on
via the :attr:`mask <fiftyone.core.labels.Detection.mask>` attribute, or on
disk referenced by the
:attr:`mask_path <fiftyone.core.labels.Detection.mask_path>` attribute.

Expand Down Expand Up @@ -2605,8 +2605,10 @@ object's bounding box when visualizing in the App.
<Detection: {
'id': '5f8709282018186b6ef6682b',
'attributes': {},
'tags': [],
'label': 'cat',
'bounding_box': [0.48, 0.513, 0.397, 0.288],
'mask': None,
'mask_path': '/path/to/mask.png',
'confidence': 0.96,
'index': None,
Expand All @@ -2615,8 +2617,8 @@ object's bounding box when visualizing in the App.
}>,
}>
Like all |Label| types, you can also add custom attributes to your detections
by dynamically adding new fields to each |Detection| instance:
Like all |Label| types, you can also add custom attributes to your instance
segmentations by dynamically adding new fields to each |Detection| instance:

.. code-block:: python
:linenos:
Expand Down
46 changes: 33 additions & 13 deletions fiftyone/core/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -10662,9 +10662,7 @@ def _handle_db_fields(self, paths, frames=False):
db_fields_map = self._get_db_fields_map(frames=frames)
return [db_fields_map.get(p, p) for p in paths]

def _get_media_fields(
self, include_filepath=True, whitelist=None, frames=False
):
def _get_media_fields(self, whitelist=None, blacklist=None, frames=False):
media_fields = {}

if frames:
Expand All @@ -10674,13 +10672,13 @@ def _get_media_fields(
schema = self.get_field_schema(flat=True)
app_media_fields = set(self._dataset.app_config.media_fields)

if include_filepath:
# 'filepath' should already be in set, but add it just in case
app_media_fields.add("filepath")
else:
app_media_fields.discard("filepath")
# 'filepath' should already be in set, but add it just in case
app_media_fields.add("filepath")

for field_name, field in schema.items():
while isinstance(field, fof.ListField):
field = field.field

if field_name in app_media_fields:
media_fields[field_name] = None
elif isinstance(field, fof.EmbeddedDocumentField) and issubclass(
Expand All @@ -10695,14 +10693,28 @@ def _get_media_fields(
whitelist = {whitelist}

media_fields = {
k: v for k, v in media_fields.items() if k in whitelist
k: v
for k, v in media_fields.items()
if any(w == k or k.startswith(w + ".") for w in whitelist)
}

if blacklist is not None:
if etau.is_container(blacklist):
blacklist = set(blacklist)
else:
blacklist = {blacklist}

media_fields = {
k: v
for k, v in media_fields.items()
if not any(w == k or k.startswith(w + ".") for w in blacklist)
}

return media_fields

def _resolve_media_field(self, media_field):
def _parse_media_field(self, media_field):
if media_field in self._dataset.app_config.media_fields:
return media_field
return media_field, None

_media_field, is_frame_field = self._handle_frame_field(media_field)

Expand All @@ -10711,12 +10723,20 @@ def _resolve_media_field(self, media_field):
if leaf is not None:
leaf = root + "." + leaf

if _media_field in (root, leaf):
if _media_field in (root, leaf) or root.startswith(
_media_field + "."
):
_resolved_field = leaf if leaf is not None else root
if is_frame_field:
_resolved_field = self._FRAMES_PREFIX + _resolved_field

return _resolved_field
_list_fields = self._parse_field_name(
_resolved_field, auto_unwind=False
)[-2]
if _list_fields:
return _resolved_field, _list_fields[0]

return _resolved_field, None

raise ValueError("'%s' is not a valid media field" % media_field)

Expand Down
7 changes: 4 additions & 3 deletions fiftyone/core/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,8 @@ class Detection(_HasAttributesDict, _HasID, _HasMedia, Label):
its bounding box, which should be a 2D binary or 0/1 integer numpy
array
mask_path (None): the absolute path to the instance segmentation image
on disk
on disk, which should be a single-channel PNG image where any
non-zero values represent the instance's extent
confidence (None): a confidence in ``[0, 1]`` for the detection
index (None): an index for the object
attributes ({}): a dict mapping attribute names to :class:`Attribute`
Expand Down Expand Up @@ -532,8 +533,8 @@ def to_segmentation(self, mask=None, frame_size=None, target=255):
"""
if not self.has_mask:
raise ValueError(
"Only detections with their `mask` attributes populated can "
"be converted to segmentations"
"Only detections with their `mask` or `mask_path` attribute "
"populated can be converted to segmentations"
)

mask, target = _parse_segmentation_target(mask, frame_size, target)
Expand Down
106 changes: 61 additions & 45 deletions fiftyone/utils/data/exporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
import warnings
from collections import defaultdict

from bson import json_util
import pydash

import eta.core.datasets as etad
import eta.core.frameutils as etaf
import eta.core.serial as etas
import eta.core.utils as etau
from bson import json_util

import fiftyone as fo
import fiftyone.core.collections as foc
Expand Down Expand Up @@ -1892,7 +1894,7 @@ def log_collection(self, sample_collection):
self._metadata["frame_fields"] = schema

self._media_fields = sample_collection._get_media_fields(
include_filepath=False
blacklist="filepath",
)

info = dict(sample_collection.info)
Expand Down Expand Up @@ -2029,34 +2031,38 @@ def _export_frame_labels(self, sample, uuid):

def _export_media_fields(self, sd):
for field_name, key in self._media_fields.items():
value = sd.get(field_name, None)
if value is None:
continue

if key is not None:
self._export_media_field(value, field_name, key=key)
else:
self._export_media_field(sd, field_name)
self._export_media_field(sd, field_name, key=key)

def _export_media_field(self, d, field_name, key=None):
if key is not None:
value = d.get(key, None)
else:
key = field_name
value = d.get(field_name, None)

value = pydash.get(d, field_name, None)
if value is None:
return

media_exporter = self._get_media_field_exporter(field_name)
outpath, _ = media_exporter.export(value)

if self.abs_paths:
d[key] = outpath
else:
d[key] = fou.safe_relpath(
outpath, self.export_dir, default=outpath
)
if not isinstance(value, (list, tuple)):
value = [value]

for _d in value:
if key is not None:
_value = _d.get(key, None)
else:
_value = _d

if _value is None:
continue

outpath, _ = media_exporter.export(_value)

if not self.abs_paths:
outpath = fou.safe_relpath(
outpath, self.export_dir, default=outpath
)

if key is not None:
_d[key] = outpath
else:
pydash.set_(d, field_name, outpath)

def _get_media_field_exporter(self, field_name):
media_exporter = self._media_field_exporters.get(field_name, None)
Expand Down Expand Up @@ -2196,7 +2202,7 @@ def export_samples(self, sample_collection, progress=None):
_sample_collection = sample_collection

self._media_fields = sample_collection._get_media_fields(
include_filepath=False
blacklist="filepath"
)

logger.info("Exporting samples...")
Expand Down Expand Up @@ -2333,33 +2339,43 @@ def _prep_sample(sd):

def _export_media_fields(self, sd):
for field_name, key in self._media_fields.items():
value = sd.get(field_name, None)
if value is None:
continue
self._export_media_field(sd, field_name, key=key)

def _export_media_field(self, d, field_name, key=None):
value = pydash.get(d, field_name, None)
if value is None:
return

media_exporter = self._get_media_field_exporter(field_name)

if not isinstance(value, (list, tuple)):
value = [value]

for _d in value:
if key is not None:
self._export_media_field(value, field_name, key=key)
_value = _d.get(key, None)
else:
self._export_media_field(sd, field_name)
_value = _d

def _export_media_field(self, d, field_name, key=None):
if key is not None:
value = d.get(key, None)
else:
key = field_name
value = d.get(field_name, None)
if _value is None:
continue

if value is None:
return
if self.export_media is not False:
# Store relative path
_, uuid = media_exporter.export(_value)
outpath = os.path.join("fields", field_name, uuid)
elif self.rel_dir is not None:
# Remove `rel_dir` prefix from path
outpath = fou.safe_relpath(
_value, self.rel_dir, default=_value
)
else:
continue

if self.export_media is not False:
# Store relative path
media_exporter = self._get_media_field_exporter(field_name)
_, uuid = media_exporter.export(value)
d[key] = os.path.join("fields", field_name, uuid)
elif self.rel_dir is not None:
# Remove `rel_dir` prefix from path
d[key] = fou.safe_relpath(value, self.rel_dir, default=value)
if key is not None:
_d[key] = outpath
else:
pydash.set_(d, field_name, outpath)

def _get_media_field_exporter(self, field_name):
media_exporter = self._media_field_exporters.get(field_name, None)
Expand Down
52 changes: 32 additions & 20 deletions fiftyone/utils/data/importers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from bson import json_util
from mongoengine.base import get_document
import pydash

import eta.core.datasets as etad
import eta.core.image as etai
Expand Down Expand Up @@ -2151,32 +2152,43 @@ def _import_runs(dataset, runs, results_dir, run_cls):

def _parse_media_fields(sd, media_fields, rel_dir):
for field_name, key in media_fields.items():
value = sd.get(field_name, None)
value = pydash.get(sd, field_name, None)
if value is None:
continue

if isinstance(value, dict):
if key is False:
try:
_cls = value.get("_cls", None)
key = get_document(_cls)._MEDIA_FIELD
except Exception as e:
logger.warning(
"Failed to infer media field for '%s'. Reason: %s",
field_name,
e,
)
key = None

media_fields[field_name] = key

if key is not None:
path = value.get(key, None)
if path is not None and not os.path.isabs(path):
value[key] = os.path.join(rel_dir, path)
_parse_nested_media_field(
value, media_fields, rel_dir, field_name, key
)
elif isinstance(value, list):
for d in value:
_parse_nested_media_field(
d, media_fields, rel_dir, field_name, key
)
elif etau.is_str(value):
if not os.path.isabs(value):
sd[field_name] = os.path.join(rel_dir, value)
pydash.set_(sd, field_name, os.path.join(rel_dir, value))


def _parse_nested_media_field(d, media_fields, rel_dir, field_name, key):
    """Converts the relative media path stored in a nested serialized
    document to an absolute path, in-place.

    Args:
        d: the serialized document (a dict) that may contain a media path.
            Non-dict values (e.g. ``None`` entries of a list field) are
            skipped
        media_fields: the dict mapping field names to media keys. When the
            key is inferred by this function, the result is written back to
            ``media_fields[field_name]``
        rel_dir: the directory to prepend to relative paths
        field_name: the name of the field that ``d`` was read from
        key: the key of ``d`` that stores the media path, or ``None`` if
            there is no such key, or ``False`` if the key must be inferred
            from the ``_cls`` of ``d``
    """
    if not isinstance(d, dict):
        # Nothing to parse. Returning before the inference step also ensures
        # that a stray non-dict entry cannot cache `None` as the media key
        # for `field_name`, which would disable path resolution for every
        # document of this field that is parsed afterwards
        return

    if key is False:
        # Best-effort inference: any failure is logged and treated as "this
        # field has no media key"
        try:
            _cls = d.get("_cls", None)
            key = get_document(_cls)._MEDIA_FIELD
        except Exception as e:
            logger.warning(
                "Failed to infer media field for '%s'. Reason: %s",
                field_name,
                e,
            )
            key = None

        # Cache the result so that the caller's mapping no longer holds
        # `False` for this field
        media_fields[field_name] = key

    if key is not None:
        path = d.get(key, None)
        if path is not None and not os.path.isabs(path):
            d[key] = os.path.join(rel_dir, path)


class ImageDirectoryImporter(UnlabeledImageDatasetImporter):
Expand Down
Loading

0 comments on commit 64cf79b

Please sign in to comment.