Skip to content
This repository has been archived by the owner on Nov 14, 2023. It is now read-only.

Commit

Permalink
Fix overly broad exception conversion in LambdaFunction.invoke (cva…
Browse files Browse the repository at this point in the history
…t-ai#6394)

The intent of the `try`/`except` statement is to catch accesses to
missing members of the `data` dictionary, but due to the large amount of
code in the `try` block, it may end up catching entirely unrelated
`KeyError`s. Those unrelated `KeyError`s should not be converted to
`ValidationError`s, since they might not have anything to do with input
validation, and the conversion will make it harder to debug these
exceptions.

An example of these misapplied conversions is a recent bug where a
`KeyError` was coming from inside `_get_image` (fixed by f6420eb).

To fix this, make sure to only catch `KeyError`s emitted by accesses to
`data`.
  • Loading branch information
SpecLad authored and mikhail-treskin committed Oct 25, 2023
1 parent ed0acad commit 6f76f8e
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 108 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added missed auto_add argument to Issue model (<https://github.com/opencv/cvat/pull/6364>)
- \[API\] Performance of several API endpoints (<https://github.com/opencv/cvat/pull/6340>)
- \[API\] Invalid schema for the owner field in several endpoints (<https://github.com/opencv/cvat/pull/6343>)
- Some internal errors occurring during lambda function invocations
could be mistakenly reported as invalid requests
(<https://github.com/opencv/cvat/pull/6394>)
- \[SDK\] Loading tasks that have been cached with the PyTorch adapter
(<https://github.com/opencv/cvat/issues/6047>)
- The problem with importing annotations if dataset has extra dots in filenames
Expand Down
219 changes: 111 additions & 108 deletions cvat/apps/lambda_manager/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,126 +210,129 @@ def invoke(
is_interactive: Optional[bool] = False,
request: Optional[Request] = None
):
try:
if db_job is not None and db_job.get_task_id() != db_task.id:
raise ValidationError("Job task id does not match task id",
code=status.HTTP_400_BAD_REQUEST
)

payload = {}
data = {k: v for k,v in data.items() if v is not None}
threshold = data.get("threshold")
if threshold:
payload.update({ "threshold": threshold })
quality = data.get("quality")
mapping = data.get("mapping", {})

task_attributes = {}
mapping_by_default = {}
for db_label in (db_task.project.label_set if db_task.project_id else db_task.label_set).prefetch_related("attributespec_set").all():
mapping_by_default[db_label.name] = {
'name': db_label.name,
'attributes': {}
if db_job is not None and db_job.get_task_id() != db_task.id:
raise ValidationError("Job task id does not match task id",
code=status.HTTP_400_BAD_REQUEST
)

payload = {}
data = {k: v for k,v in data.items() if v is not None}

def mandatory_arg(name: str) -> Any:
try:
return data[name]
except KeyError:
raise ValidationError(
"`{}` lambda function was called without mandatory argument: {}"
.format(self.id, name),
code=status.HTTP_400_BAD_REQUEST)

threshold = data.get("threshold")
if threshold:
payload.update({ "threshold": threshold })
quality = data.get("quality")
mapping = data.get("mapping", {})

task_attributes = {}
mapping_by_default = {}
for db_label in (db_task.project.label_set if db_task.project_id else db_task.label_set).prefetch_related("attributespec_set").all():
mapping_by_default[db_label.name] = {
'name': db_label.name,
'attributes': {}
}
task_attributes[db_label.name] = {}
for attribute in db_label.attributespec_set.all():
task_attributes[db_label.name][attribute.name] = {
'input_type': attribute.input_type,
'values': attribute.values.split('\n')
}
task_attributes[db_label.name] = {}
for attribute in db_label.attributespec_set.all():
task_attributes[db_label.name][attribute.name] = {
'input_type': attribute.input_type,
'values': attribute.values.split('\n')
}
if not mapping:
# use mapping by default to avoid labels in mapping which
# don't exist in the task
mapping = mapping_by_default
else:
# filter labels in mapping which don't exist in the task
mapping = {k:v for k,v in mapping.items() if v['name'] in mapping_by_default}
if not mapping:
# use mapping by default to avoid labels in mapping which
# don't exist in the task
mapping = mapping_by_default
else:
# filter labels in mapping which don't exist in the task
mapping = {k:v for k,v in mapping.items() if v['name'] in mapping_by_default}

attr_mapping = { label: mapping[label]['attributes'] if 'attributes' in mapping[label] else {} for label in mapping }
mapping = { modelLabel: mapping[modelLabel]['name'] for modelLabel in mapping }
attr_mapping = { label: mapping[label]['attributes'] if 'attributes' in mapping[label] else {} for label in mapping }
mapping = { modelLabel: mapping[modelLabel]['name'] for modelLabel in mapping }

supported_attrs = {}
for func_label, func_attrs in self.func_attributes.items():
if func_label not in mapping:
continue
supported_attrs = {}
for func_label, func_attrs in self.func_attributes.items():
if func_label not in mapping:
continue

mapped_label = mapping[func_label]
mapped_attributes = attr_mapping.get(func_label, {})
supported_attrs[func_label] = {}
mapped_label = mapping[func_label]
mapped_attributes = attr_mapping.get(func_label, {})
supported_attrs[func_label] = {}

if mapped_attributes:
task_attr_names = [task_attr for task_attr in task_attributes[mapped_label]]
for attr in func_attrs:
mapped_attr = mapped_attributes.get(attr["name"])
if mapped_attr in task_attr_names:
supported_attrs[func_label].update({ attr["name"]: task_attributes[mapped_label][mapped_attr] })
if mapped_attributes:
task_attr_names = [task_attr for task_attr in task_attributes[mapped_label]]
for attr in func_attrs:
mapped_attr = mapped_attributes.get(attr["name"])
if mapped_attr in task_attr_names:
supported_attrs[func_label].update({ attr["name"]: task_attributes[mapped_label][mapped_attr] })

# Check job frame boundaries
if db_job:
task_data = db_task.data
data_start_frame = task_data.start_frame
step = task_data.get_frame_step()

for key, desc in (
('frame', 'frame'),
('frame0', 'start frame'),
('frame1', 'end frame'),
):
if key not in data:
continue
# Check job frame boundaries
if db_job:
task_data = db_task.data
data_start_frame = task_data.start_frame
step = task_data.get_frame_step()

abs_frame_id = data_start_frame + data[key] * step
if not db_job.segment.contains_frame(abs_frame_id):
raise ValidationError(f"The {desc} is outside the job range",
code=status.HTTP_400_BAD_REQUEST)
for key, desc in (
('frame', 'frame'),
('frame0', 'start frame'),
('frame1', 'end frame'),
):
if key not in data:
continue

abs_frame_id = data_start_frame + data[key] * step
if not db_job.segment.contains_frame(abs_frame_id):
raise ValidationError(f"The {desc} is outside the job range",
code=status.HTTP_400_BAD_REQUEST)

if self.kind == LambdaType.DETECTOR:
if db_task.data.original_chunk_type == DataChoice.VIDEO:
data_path = db_task.data.video.path
elif db_task.data.original_chunk_type == DataChoice.IMAGESET:
data_path = db_task.data.images.get(frame=data["frame"]).path
else:
data_path = ""
payload.update({
"image": self._get_image(db_task, data["frame"], quality),
"data_path": data_path
})
elif self.kind == LambdaType.INTERACTOR:
payload.update({
"image": self._get_image(db_task, data["frame"], quality),
"pos_points": data["pos_points"][2:] if self.startswith_box else data["pos_points"],
"neg_points": data["neg_points"],
"obj_bbox": data["pos_points"][0:2] if self.startswith_box else None
})
elif self.kind == LambdaType.REID:
payload.update({
"image0": self._get_image(db_task, data["frame0"], quality),
"image1": self._get_image(db_task, data["frame1"], quality),
"boxes0": data["boxes0"],
"boxes1": data["boxes1"]
})
max_distance = data.get("max_distance")
if max_distance:
payload.update({
"max_distance": max_distance
})
elif self.kind == LambdaType.TRACKER:

if self.kind == LambdaType.DETECTOR:
if db_task.data.original_chunk_type == DataChoice.VIDEO:
data_path = db_task.data.video.path
elif db_task.data.original_chunk_type == DataChoice.IMAGESET:
data_path = db_task.data.images.get(frame=data["frame"]).path
else:
data_path = ""
payload.update({
"image": self._get_image(db_task, data["frame"], quality),
"data_path": data_path
elif self.kind == LambdaType.INTERACTOR:
payload.update({
"image": self._get_image(db_task, mandatory_arg("frame"), quality),
"pos_points": mandatory_arg("pos_points")[2:] if self.startswith_box else mandatory_arg("pos_points"),
"neg_points": mandatory_arg("neg_points"),
"obj_bbox": mandatory_arg("pos_points")[0:2] if self.startswith_box else None
})
elif self.kind == LambdaType.REID:
payload.update({
"image0": self._get_image(db_task, mandatory_arg("frame0"), quality),
"image1": self._get_image(db_task, mandatory_arg("frame1"), quality),
"boxes0": mandatory_arg("boxes0"),
"boxes1": mandatory_arg("boxes1")
})
max_distance = data.get("max_distance")
if max_distance:
payload.update({
"image": self._get_image(db_task, data["frame"], quality),
"shapes": data.get("shapes", []),
"states": data.get("states", [])
"max_distance": max_distance
})
else:
raise ValidationError(
'`{}` lambda function has incorrect type: {}'
.format(self.id, self.kind),
code=status.HTTP_500_INTERNAL_SERVER_ERROR)
except KeyError as err:
elif self.kind == LambdaType.TRACKER:
payload.update({
"image": self._get_image(db_task, mandatory_arg("frame"), quality),
"shapes": data.get("shapes", []),
"states": data.get("states", [])
})
else:
raise ValidationError(
"`{}` lambda function was called without mandatory argument: {}"
.format(self.id, str(err)),
code=status.HTTP_400_BAD_REQUEST)
'`{}` lambda function has incorrect type: {}'
.format(self.id, self.kind),
code=status.HTTP_500_INTERNAL_SERVER_ERROR)

if is_interactive and request:
interactive_function_call_signal.send(sender=self, request=request)
Expand Down

0 comments on commit 6f76f8e

Please sign in to comment.