Skip to content

Commit

Permalink
Bugfix for point extraction with random crop disabled, plus allow con…
Browse files Browse the repository at this point in the history
…figuration of point extraction parallelism with a saner default. (#266)
  • Loading branch information
cdoersch authored Jan 10, 2023
1 parent 7dfdad0 commit f0cec8f
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions challenges/point_tracking/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ def add_tracks(data,
area_range=(min_area, max_area),
max_attempts=20)
else:
crop_window = tf.constant([0, 0, shp[2], shp[3]],
crop_window = tf.constant([0, 0, shp[1], shp[2]],
dtype=tf.int32,
shape=[4])

Expand Down Expand Up @@ -692,6 +692,7 @@ def create_point_tracking_dataset(
sampling_stride=4,
max_seg_id=25,
max_sampled_frac=0.1,
num_parallel_point_extraction_calls=16,
**kwargs):
"""Construct a dataset for point tracking using Kubric: go/kubric.
Expand All @@ -712,6 +713,8 @@ def create_point_tracking_dataset(
the to graph is proportional to this number, so prefer small values.
max_sampled_frac: Float. The maximum fraction of points to sample from each
object, out of all points that lie on the sampling grid.
num_parallel_point_extraction_calls: Int. The num_parallel_calls for the
map function for point extraction.
**kwargs: additional args to pass to tfds.load.
Returns:
Expand All @@ -736,7 +739,7 @@ def create_point_tracking_dataset(
sampling_stride=sampling_stride,
max_seg_id=max_seg_id,
max_sampled_frac=max_sampled_frac),
num_parallel_calls=2)
num_parallel_calls=num_parallel_point_extraction_calls)
if shuffle_buffer_size is not None:
ds = ds.shuffle(shuffle_buffer_size)

Expand Down

0 comments on commit f0cec8f

Please sign in to comment.