Skip to content

Commit ec30ee1

Browse files
authored
[Fix] Fix visualization bug in 3d pose (#2594)
1 parent cb48094 commit ec30ee1

File tree

2 files changed

+30
-15
lines changed

2 files changed

+30
-15
lines changed

demo/body3d_pose_lifter_demo.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def parse_args():
9090
'--save-predictions',
9191
action='store_true',
9292
default=False,
93-
help='whether to save predicted results')
93+
help='Whether to save predicted results')
9494
parser.add_argument(
9595
'--device', default='cuda:0', help='Device used for inference')
9696
parser.add_argument(
@@ -124,7 +124,14 @@ def parse_args():
124124
'--use-multi-frames',
125125
action='store_true',
126126
default=False,
127-
help='whether to use multi frames for inference in the 2D pose'
127+
help='Whether to use multi frames for inference in the 2D pose'
128+
'detection stage. Default: False.')
129+
parser.add_argument(
130+
'--online',
131+
action='store_true',
132+
default=False,
133+
help='Inference mode. If set to True, can not use future frame'
134+
'information when using multi frames for inference in the 2D pose'
128135
'detection stage. Default: False.')
129136

130137
args = parser.parse_args()
@@ -405,6 +412,10 @@ def main():
405412
'Only "PoseLifter" model is supported for the 2nd stage ' \
406413
'(2D-to-3D lifting)'
407414

415+
if args.use_multi_frames:
416+
assert 'frame_indices_test' in pose_estimator.cfg.data.test.data_cfg
417+
indices = pose_estimator.cfg.data.test.data_cfg['frame_indices_test']
418+
408419
pose_lifter.cfg.visualizer.radius = args.radius
409420
pose_lifter.cfg.visualizer.line_width = args.thickness
410421
pose_lifter.cfg.visualizer.det_kpt_color = det_kpt_color

mmpose/apis/inference_3d.py

+17-13
Original file line numberDiff line numberDiff line change
@@ -181,16 +181,11 @@ def collate_pose_sequence(pose_results_2d,
181181
pose_sequences = []
182182
for idx in range(N):
183183
pose_seq = PoseDataSample()
184-
gt_instances = InstanceData()
185184
pred_instances = InstanceData()
186185

187-
for k in pose_results_2d[target_frame][idx].gt_instances.keys():
188-
gt_instances.set_field(
189-
pose_results_2d[target_frame][idx].gt_instances[k], k)
190-
for k in pose_results_2d[target_frame][idx].pred_instances.keys():
191-
if k != 'keypoints':
192-
pred_instances.set_field(
193-
pose_results_2d[target_frame][idx].pred_instances[k], k)
186+
gt_instances = pose_results_2d[target_frame][idx].gt_instances.clone()
187+
pred_instances = pose_results_2d[target_frame][
188+
idx].pred_instances.clone()
194189
pose_seq.pred_instances = pred_instances
195190
pose_seq.gt_instances = gt_instances
196191

@@ -228,7 +223,7 @@ def collate_pose_sequence(pose_results_2d,
228223
# replicate the right most frame
229224
keypoints[:, frame_idx + 1:] = keypoints[:, frame_idx]
230225
break
231-
pose_seq.pred_instances.keypoints = keypoints
226+
pose_seq.pred_instances.set_field(keypoints, 'keypoints')
232227
pose_sequences.append(pose_seq)
233228

234229
return pose_sequences
@@ -276,8 +271,15 @@ def inference_pose_lifter_model(model,
276271
bbox_center = None
277272
bbox_scale = None
278273

274+
pose_results_2d_copy = []
279275
for i, pose_res in enumerate(pose_results_2d):
276+
pose_res_copy = []
280277
for j, data_sample in enumerate(pose_res):
278+
data_sample_copy = PoseDataSample()
279+
data_sample_copy.gt_instances = data_sample.gt_instances.clone()
280+
data_sample_copy.pred_instances = data_sample.pred_instances.clone(
281+
)
282+
data_sample_copy.track_id = data_sample.track_id
281283
kpts = data_sample.pred_instances.keypoints
282284
bboxes = data_sample.pred_instances.bboxes
283285
keypoints = []
@@ -292,11 +294,13 @@ def inference_pose_lifter_model(model,
292294
bbox_scale + bbox_center)
293295
else:
294296
keypoints.append(kpt[:, :2])
295-
pose_results_2d[i][j].pred_instances.keypoints = np.array(
296-
keypoints)
297+
data_sample_copy.pred_instances.set_field(
298+
np.array(keypoints), 'keypoints')
299+
pose_res_copy.append(data_sample_copy)
300+
pose_results_2d_copy.append(pose_res_copy)
297301

298-
pose_sequences_2d = collate_pose_sequence(pose_results_2d, with_track_id,
299-
target_idx)
302+
pose_sequences_2d = collate_pose_sequence(pose_results_2d_copy,
303+
with_track_id, target_idx)
300304

301305
if not pose_sequences_2d:
302306
return []

0 commit comments

Comments
 (0)