Skip to content

Commit

Permalink
fixed default attribute values for tracked shapes (#703)
Browse files Browse the repository at this point in the history
  • Loading branch information
azhavoro authored and nmanovic committed Sep 10, 2019
1 parent 0967a74 commit fda7c1a
Showing 1 changed file with 27 additions and 20 deletions.
47 changes: 27 additions & 20 deletions cvat/apps/engine/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@
from .log import slogger
from . import serializers

"""dot.notation access to dictionary attributes"""
class dotdict(OrderedDict):
__getattr__ = OrderedDict.get
__setattr__ = OrderedDict.__setitem__
__delattr__ = OrderedDict.__delitem__
__eq__ = lambda self, other: self.id == other.id
__hash__ = lambda self: self.id

class PatchAction(str, Enum):
CREATE = "create"
UPDATE = "update"
Expand Down Expand Up @@ -142,15 +150,6 @@ def bulk_create(db_model, objects, flt_param):
return []

def _merge_table_rows(rows, keys_for_merge, field_id):
"""dot.notation access to dictionary attributes"""
from collections import OrderedDict
class dotdict(OrderedDict):
__getattr__ = OrderedDict.get
__setattr__ = OrderedDict.__setitem__
__delattr__ = OrderedDict.__delitem__
__eq__ = lambda self, other: self.id == other.id
__hash__ = lambda self: self.id

# It is necessary to keep a stable order of original rows
# (e.g. for tracked boxes). Otherwise prev_box.frame can be bigger
# than next_box.frame.
Expand Down Expand Up @@ -202,12 +201,16 @@ def __init__(self, pk, user):
"all": OrderedDict(),
}
for db_attr in db_label.attributespec_set.all():
default_value = dotdict([
('spec_id', db_attr.id),
('value', db_attr.default_value),
])
if db_attr.mutable:
self.db_attributes[db_label.id]["mutable"][db_attr.id] = db_attr
self.db_attributes[db_label.id]["mutable"][db_attr.id] = default_value
else:
self.db_attributes[db_label.id]["immutable"][db_attr.id] = db_attr
self.db_attributes[db_label.id]["immutable"][db_attr.id] = default_value

self.db_attributes[db_label.id]["all"][db_attr.id] = db_attr
self.db_attributes[db_label.id]["all"][db_attr.id] = default_value

def reset(self):
self.ir_data.reset()
Expand Down Expand Up @@ -458,13 +461,13 @@ def delete(self, data=None):
self._commit()

@staticmethod
def _extend_attributes(attributeval_set, attribute_specs):
def _extend_attributes(attributeval_set, default_attribute_values):
shape_attribute_specs_set = set(attr.spec_id for attr in attributeval_set)
for db_attr_spec in attribute_specs:
if db_attr_spec.id not in shape_attribute_specs_set:
attributeval_set.append(OrderedDict([
('spec_id', db_attr_spec.id),
('value', db_attr_spec.default_value),
for db_attr in default_attribute_values:
if db_attr.spec_id not in shape_attribute_specs_set:
attributeval_set.append(dotdict([
('spec_id', db_attr.spec_id),
('value', db_attr.value),
]))

def _init_tags_from_db(self):
Expand Down Expand Up @@ -600,12 +603,16 @@ def _init_tracks_from_db(self):
self._extend_attributes(db_track.labeledtrackattributeval_set,
self.db_attributes[db_track.label_id]["immutable"].values())

default_attribute_values = self.db_attributes[db_track.label_id]["mutable"].values()
for db_shape in db_track["trackedshape_set"]:
db_shape["trackedshapeattributeval_set"] = list(
set(db_shape["trackedshapeattributeval_set"])
)
self._extend_attributes(db_shape["trackedshapeattributeval_set"],
self.db_attributes[db_track.label_id]["mutable"].values())
# in case of trackedshapes need to interpolate attriute values and extend it
# by previous shape attribute values (not default values)
self._extend_attributes(db_shape["trackedshapeattributeval_set"], default_attribute_values)
default_attribute_values = db_shape["trackedshapeattributeval_set"]


serializer = serializers.LabeledTrackSerializer(db_tracks, many=True)
self.ir_data.tracks = serializer.data
Expand Down

0 comments on commit fda7c1a

Please sign in to comment.