-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-664] Support integer type in ImageIter #11864
Conversation
tests/python/unittest/test_image.py
Outdated
path_root='') | ||
for batch in test_iter: | ||
pass | ||
for dtype in ['int32', 'float32']: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this only support int32 & float32? If possible please also check float64 and int64.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, I can check float64 and int64
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If that's the case then please also update the doc string above correspondingly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good overall. Left two small comments.
python/mxnet/image/image.py
Outdated
@@ -1091,7 +1093,7 @@ def __init__(self, batch_size, data_shape, label_width=1, | |||
imgkeys = [] | |||
for line in iter(fin.readline, ''): | |||
line = line.strip().split('\t') | |||
label = nd.array([float(i) for i in line[1:-1]]) | |||
label = nd.array([i for i in line[1:-1]], dtype=dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[i for i in line[1:-1]] is just line[1:-1], right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct. made the change
tests/python/unittest/test_image.py
Outdated
test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=im_list, | ||
path_root='') | ||
for _ in range(3): | ||
def check_ImageIter(dtype='float32'): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: use lower case imageiter to be consistent with test_imageiter
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
44023ad
to
156d793
Compare
@@ -1091,7 +1093,7 @@ def __init__(self, batch_size, data_shape, label_width=1, | |||
imgkeys = [] | |||
for line in iter(fin.readline, ''): | |||
line = line.strip().split('\t') | |||
label = nd.array([float(i) for i in line[1:-1]]) | |||
label = nd.array(line[1:-1], dtype=dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if I pass dtype='int8' or something that is not supported?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
int8 specifically doesn't cause any issue. but yes, generally, a check is required. Adding an assert for the next iteration. Adding a test for the default case too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your contributions Vandana. I have 2 comments:
- What happens if user pass unsupported dtype?
- Please add tests for default case, i.e., I do not pass a dtype fir the Iter
@sandeep-krishnamurthy I've updated the commit - addressed your review comments. Please have a look at the updated changes and let me know your inputs. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks.
Thanks for working on this. LGTM. |
Thanks @haojin2 @apeforest @sandeep-krishnamurthy @Roshrini @szha Can this change be merged? |
Description
Support int32 type for labels in ImageIter. By default, labels are of type float and therefore lead to precision issues.
Checklist
Essentials
Changes
Comments