Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dynamic label field upgrades #3152

Merged
merged 16 commits into from
Jun 12, 2023
39 changes: 15 additions & 24 deletions fiftyone/core/clips.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,9 @@ def _untag_labels(self, tags, label_field, ids=None, label_ids=None):
)

def _to_source_ids(self, label_field, ids, label_ids):
label_type = self._source_collection._get_label_field_type(label_field)
is_list_field = issubclass(label_type, fol._LABEL_LIST_FIELDS)
_, is_list_field = self._source_collection._get_label_field_root(
label_field
)

if not is_list_field:
return ids, label_ids
Expand Down Expand Up @@ -774,6 +775,7 @@ def _write_temporal_detection_clips(
dataset, src_collection, field, other_fields=None
):
src_dataset = src_collection._dataset
root, is_list_field = src_collection._get_label_field_root(field)
label_type = src_collection._get_label_field_type(field)

supported_types = (fol.TemporalDetection, fol.TemporalDetections)
Expand Down Expand Up @@ -804,28 +806,24 @@ def _write_temporal_detection_clips(
{"$match": {"$expr": {"$gt": ["$" + field, None]}}},
]

if is_list_field:
pipeline.append({"$unwind": "$" + root})

if label_type is fol.TemporalDetections:
list_path = field + "." + label_type._LABEL_LIST_FIELD
pipeline.extend(
[
{"$unwind": "$" + list_path},
{"$addFields": {field: "$" + list_path}},
]
)
pipeline.append({"$addFields": {field: "$" + root}})

support_path = field + ".support"
pipeline.extend(
[
{
"$addFields": {
"_id": "$" + field + "._id",
"support": "$" + support_path,
"support": "$" + field + ".support",
field + "._cls": "Classification",
"_rand": {"$rand": {}},
"_dataset_id": dataset._doc.id,
}
},
{"$project": {support_path: False}},
{"$project": {field + ".support": False}},
{"$out": dataset._sample_collection_name},
]
)
Expand Down Expand Up @@ -975,22 +973,15 @@ def _write_manual_clips(dataset, src_collection, clips, other_fields=None):

def _get_trajectories(sample_collection, frame_field):
path = sample_collection._FRAMES_PREFIX + frame_field
label_type = sample_collection._get_label_field_type(path)
root, is_list_field = sample_collection._get_label_field_root(path)
root, _ = sample_collection._handle_frame_field(root)

if not issubclass(label_type, fol._LABEL_LIST_FIELDS):
raise ValueError(
"Frame field '%s' has type %s, but trajectories can only be "
"extracted for label list fields %s"
% (
frame_field,
label_type,
fol._LABEL_LIST_FIELDS,
)
)
if not is_list_field:
raise ValueError("Trajectories can only be extracted for label lists")

fn_expr = F("frames").map(F("frame_number"))
uuid_expr = F("frames").map(
F(frame_field + "." + label_type._LABEL_LIST_FIELD).map(
F(root).map(
F("label").concat(
".", (F("index") != None).if_else(F("index").to_string(), "")
)
Expand Down
124 changes: 74 additions & 50 deletions fiftyone/core/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -1547,9 +1547,8 @@ def _untag_labels(self, tags, label_field, ids=None, label_ids=None):
def _edit_label_tags(
self, update_fcn, label_field, ids=None, label_ids=None
):
label_type, root = self._get_label_field_path(label_field)
root, is_list_field = self._get_label_field_root(label_field)
_root, is_frame_field = self._handle_frame_field(root)
is_list_field = issubclass(label_type, fol._LABEL_LIST_FIELDS)

ops = []

Expand Down Expand Up @@ -1616,11 +1615,10 @@ def _get_selected_labels(self, ids=None, tags=None, fields=None):
is_list_fields = []
is_frame_fields = []
for label_field in label_fields:
label_type, id_path = view._get_label_field_path(label_field, "id")
is_list_field = issubclass(label_type, fol._LABEL_LIST_FIELDS)
root, is_list_field = view._get_label_field_root(label_field)
is_frame_field = view._is_frame_field(label_field)

paths.append(id_path)
paths.append(root + ".id")
is_list_fields.append(is_list_field)
is_frame_fields.append(is_frame_field)

Expand Down Expand Up @@ -2128,10 +2126,9 @@ def set_label_values(
if is_frame_field:
label_field = self._FRAMES_PREFIX + label_field

label_type, root = self._get_label_field_path(label_field)
root, is_list_field = self._get_label_field_root(label_field)
label_id_path = root + ".id"
_root, _ = self._handle_frame_field(root)
is_list_field = issubclass(label_type, fol._LABEL_LIST_FIELDS)
_, label_id_path = self._get_label_field_path(label_field, "id")

id_map = {}

Expand Down Expand Up @@ -2572,12 +2569,11 @@ def _set_labels(self, field_name, sample_ids, label_docs):
"(found: '%s')" % field_name
)

label_type = self._get_label_field_type(field_name)
root, is_list_field = self._get_label_field_root(field_name)
field_name, is_frame_field = self._handle_frame_field(field_name)

ops = []
if issubclass(label_type, fol._LABEL_LIST_FIELDS):
root = field_name + "." + label_type._LABEL_LIST_FIELD
if is_list_field:
elem_id = root + "._id"
set_path = root + ".$"

Expand Down Expand Up @@ -9549,32 +9545,18 @@ def _get_media_fields(
return media_fields

def _get_label_fields(self):
fields = self._get_sample_label_fields()
return [path for path, _ in _iter_label_fields(self)]

if self._has_frame_fields():
fields.extend(self._get_frame_label_fields())

return fields

def _get_sample_label_fields(self):
return list(
self.get_field_schema(
ftype=fof.EmbeddedDocumentField,
embedded_doc_type=fol.Label,
).keys()
)
def _get_label_field_schema(self):
schema = self.get_field_schema()
return dict(_iter_schema_label_fields(schema))

def _get_frame_label_fields(self):
def _get_frame_label_field_schema(self):
if not self._has_frame_fields():
return None

return [
self._FRAMES_PREFIX + field
for field in self.get_frame_field_schema(
ftype=fof.EmbeddedDocumentField,
embedded_doc_type=fol.Label,
).keys()
]
schema = self.get_frame_field_schema()
return dict(_iter_schema_label_fields(schema))

def _get_root_fields(self, fields):
root_fields = set()
Expand Down Expand Up @@ -9622,14 +9604,12 @@ def _get_label_field_type(self, field_name):
field_name, _ = self._handle_group_field(field_name)
field_name, is_frame_field = self._handle_frame_field(field_name)

if field_name.startswith(
"___"
): # for fiftyone.server.view hidden results
# for fiftyone.server.view hidden results
if field_name.startswith("___"):
field_name = field_name[3:]

if field_name.startswith(
"__"
): # for fiftyone.core.stages hidden results
# for fiftyone.core.stages hidden results
if field_name.startswith("__"):
field_name = field_name[2:]

if is_frame_field:
Expand Down Expand Up @@ -9659,6 +9639,18 @@ def _get_label_field_type(self, field_name):

return field.document_type

def _get_label_field_root(self, field_name):
label_type = self._get_label_field_type(field_name)

if issubclass(label_type, fol._LABEL_LIST_FIELDS):
root = field_name + "." + label_type._LABEL_LIST_FIELD
is_list_field = True
else:
root = field_name
is_list_field = isinstance(self.get_field(root), fof.ListField)

return root, is_list_field

def _get_label_field_path(self, field_name, subfield=None):
label_type = self._get_label_field_type(field_name)

Expand Down Expand Up @@ -9773,6 +9765,37 @@ def _get_values_by_id(self, path_or_expr, ids, link_field=None):
return [values_map.get(i, None) for i in ids]


def _iter_label_fields(sample_collection):
schema = sample_collection.get_field_schema()
for path, field in _iter_schema_label_fields(schema):
yield path, field

if not sample_collection._has_frame_fields():
return

prefix = sample_collection._FRAMES_PREFIX
schema = sample_collection.get_frame_field_schema()
for path, field in _iter_schema_label_fields(schema):
yield prefix + path, field


def _iter_schema_label_fields(schema, recursive=True):
for path, field in schema.items():
if isinstance(field, fof.ListField):
field = field.field

if isinstance(field, fof.EmbeddedDocumentField):
if issubclass(field.document_type, fol.Label):
# Do not recurse into Label fields
yield path, field
else:
# Only recurse one level deep into embedded documents
for _path, _field in _iter_schema_label_fields(
field.get_field_schema(), recursive=False
):
yield path + "." + _path, _field


def _serialize_value(field_name, field, value, validate=True):
if value is None:
return None
Expand Down Expand Up @@ -9812,7 +9835,9 @@ def _parse_label_field(
return label_field

if _is_glob_pattern(label_field):
label_field = _get_matching_fields(sample_collection, label_field)
label_field = _get_matching_label_fields(
sample_collection, label_field
)

if etau.is_container(label_field):
return {f: f for f in label_field}
Expand Down Expand Up @@ -9852,7 +9877,7 @@ def _parse_frame_labels_field(
return frame_labels_field

if _is_glob_pattern(frame_labels_field):
frame_labels_field = _get_matching_fields(
frame_labels_field = _get_matching_label_fields(
sample_collection, frame_labels_field, frames=True
)

Expand Down Expand Up @@ -9890,13 +9915,16 @@ def _is_glob_pattern(s):
return "*" in s or "?" in s or "[" in s


def _get_matching_fields(sample_collection, patt, frames=False):
def _get_matching_label_fields(sample_collection, patt, frames=False):
if frames:
schema = sample_collection.get_frame_field_schema()
label_schema = sample_collection._get_frame_label_field_schema()
else:
schema = sample_collection.get_field_schema()
label_schema = sample_collection._get_label_field_schema()

if label_schema is None:
return label_schema

return fnmatch.filter(list(schema.keys()), patt)
return fnmatch.filter(list(label_schema.keys()), patt)


def _get_default_label_fields_for_exporter(
Expand All @@ -9915,9 +9943,7 @@ def _get_default_label_fields_for_exporter(
return None

media_type = sample_collection.media_type
label_schema = sample_collection.get_field_schema(
ftype=fof.EmbeddedDocumentField, embedded_doc_type=fol.Label
)
label_schema = sample_collection._get_label_field_schema()

label_field_or_dict = _get_fields_with_types(
media_type,
Expand Down Expand Up @@ -9954,9 +9980,7 @@ def _get_default_frame_label_fields_for_exporter(
return None

media_type = sample_collection.media_type
frame_label_schema = sample_collection.get_frame_field_schema(
ftype=fof.EmbeddedDocumentField, embedded_doc_type=fol.Label
)
frame_label_schema = sample_collection._get_frame_label_field_schema()

frame_labels_field_or_dict = _get_fields_with_types(
media_type,
Expand Down
Loading