[pytorch] Can't get byte data from torch.ScalarType.Byte tensor #1321
Comments
Workaround:
Pointer pp = byteT.data_ptr();
BytePointer byteP = new BytePointer(pp);
I checked that I got reasonable data back, and it seems to work. That also allows viewing the data as, e.g., …
There are no types named "byte" in C++. I wonder what it's expecting here. @HGuillemet, would you know?
public native @Name("data_ptr<int8_t>") BytePointer data_ptr_byte();
What is …
Apparently, in PyTorch, scalar type Byte is an unsigned byte (uint8_t) and Char is a signed byte (int8_t), so we need to change the Info in the presets to: …
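To see why that distinction matters on the Java side, here is a small JDK-only sketch. It does not use the pytorch presets, and the class and method names are illustrative only. Java's `byte` is always signed, so the same 8 bits read back as negative values above 127 unless they are widened with `Byte.toUnsignedInt`. That is exactly the gap between a `Char` (int8) and a `Byte` (uint8) tensor.

```java
// Illustrates how the same 8 bits decode differently as int8 vs uint8.
final class ByteInterpretation {
    // Java's byte is signed, which matches int8_t / ScalarType.Char.
    static int asInt8(byte raw) {
        return raw; // sign-extends, e.g. 0xFF -> -1
    }

    // Reading the bits as uint8_t / ScalarType.Byte requires zero-extension.
    static int asUint8(byte raw) {
        return Byte.toUnsignedInt(raw); // e.g. 0xFF -> 255
    }

    public static void main(String[] args) {
        byte[] pixels = {(byte) 0x00, (byte) 0x7F, (byte) 0x80, (byte) 0xFF};
        for (byte p : pixels) {
            // Prints e.g. "0x80 -> int8 -128, uint8 128"
            System.out.println(String.format("0x%02X -> int8 %d, uint8 %d",
                    p & 0xFF, asInt8(p), asUint8(p)));
        }
    }
}
```

For image data, only the uint8 reading gives the intended 0..255 pixel values.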
BTW, we could add something for half floats, too, but since there is no Java type for them and since we would need to return a …
Or we could use the dtype terms: … Also, it would be nicer if …
Sounds all good! Please open a pull request :)
But if you want to do something like that, please be consistent with the C++ API, not the Python API.
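The naming question hinges on how the C++ `ScalarType` names, the Python dtype names, and the C++ element types behind `data_ptr<T>()` line up. The original bug is a mismatch in exactly this table: `Byte` is `uint8_t`, but the preset used `int8_t`. As a reference, here is a hedged sketch of the standard correspondences for the common types as a hypothetical Java enum. `ScalarTypeName` and `fromDtype` are illustrative names and not part of the pytorch presets.

```java
import java.util.Arrays;

// Hypothetical lookup table (not part of the presets) relating the
// c10::ScalarType names, Python dtype names, and C++ element types.
enum ScalarTypeName {
    BYTE("Byte", "torch.uint8", "uint8_t"),
    CHAR("Char", "torch.int8", "int8_t"),
    SHORT("Short", "torch.int16", "int16_t"),
    INT("Int", "torch.int32", "int"),
    LONG("Long", "torch.int64", "int64_t"),
    HALF("Half", "torch.float16", "at::Half"),
    FLOAT("Float", "torch.float32", "float"),
    DOUBLE("Double", "torch.float64", "double");

    final String cppName; // name used by the C++ API (c10::ScalarType)
    final String dtype;   // name used by the Python API
    final String cppType; // element type T behind data_ptr<T>()

    ScalarTypeName(String cppName, String dtype, String cppType) {
        this.cppName = cppName;
        this.dtype = dtype;
        this.cppType = cppType;
    }

    // Finds the entry for a Python dtype name such as "torch.uint8".
    static ScalarTypeName fromDtype(String dtype) {
        return Arrays.stream(values())
                .filter(t -> t.dtype.equals(dtype))
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException("unknown dtype: " + dtype));
    }

    public static void main(String[] args) {
        for (ScalarTypeName t : values()) {
            System.out.println(t.cppName + " = " + t.dtype + " = data_ptr<" + t.cppType + ">");
        }
    }
}
```

Following the C++ API would mean naming accessors after the first column (or the element type), not the dtype.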
No, no, please don't! That explodes the memory consumption, which would (at least for my use case) completely undo the rationale for using bytes in the first place. A ShortPointer for FP16 sounds good; I guess it'll be a while until JEP 401 lands.
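A ShortPointer for FP16 would hold the raw IEEE 754 binary16 bits, 2 bytes per element instead of 4 for float, and decode them only on access. Below is a JDK-only sketch of that decoding. It assumes the standard binary16 layout (1 sign bit, 5 exponent bits with bias 15, 10 mantissa bits). `HalfBits` is a hypothetical helper. JavaCPP's own `HalfIndexer` covers this for real pointers, and on JDK 20 and later `Float.float16ToFloat(short)` performs the same conversion.

```java
// Hypothetical helper: decodes IEEE 754 binary16 bits stored in a short.
final class HalfBits {
    static float halfToFloat(short bits) {
        int h = bits & 0xFFFF;
        int sign = (h >>> 15) & 0x1;
        int exp = (h >>> 10) & 0x1F; // 5-bit biased exponent
        int mant = h & 0x3FF;        // 10-bit fraction
        float value;
        if (exp == 0) {
            // Zero or subnormal: mant * 2^-24
            value = Math.scalb((float) mant, -24);
        } else if (exp == 0x1F) {
            // All-ones exponent: infinity or NaN
            value = (mant == 0) ? Float.POSITIVE_INFINITY : Float.NaN;
        } else {
            // Normal: (1 + mant / 1024) * 2^(exp - 15)
            value = Math.scalb(1.0f + mant / 1024.0f, exp - 15);
        }
        return sign == 0 ? value : -value;
    }

    public static void main(String[] args) {
        short[] samples = {(short) 0x3C00, (short) 0xC000, (short) 0x7BFF};
        for (short s : samples) {
            System.out.println(halfToFloat(s)); // 1.0, -2.0, 65504.0
        }
    }
}
```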
Or maybe it's OK the way it is now. Java doesn't have types for unsigned bytes, half floats, and whatnot, so in those cases it makes sense to let users deal with the raw Pointer themselves. There are a few indexers to deal with those in JavaCPP, though, and I've already added something for them in AbstractTensor: https://github.com/bytedeco/javacpp-presets/blob/master/pytorch/src/main/java/org/bytedeco/pytorch/AbstractTensor.java#L98
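JavaCPP's `UByteIndexer` already does this for real tensors. To illustrate the idea without the presets, here is a hedged JDK-only sketch. `UByteView` is a hypothetical class wrapping a `java.nio.ByteBuffer`. Storage stays at one byte per element, values are widened to an int in 0..255 on read, and they are checked and narrowed on write.

```java
import java.nio.ByteBuffer;

// Hypothetical, minimal stand-in for an unsigned-byte indexer.
final class UByteView {
    private final ByteBuffer buffer;

    UByteView(ByteBuffer buffer) {
        this.buffer = buffer;
    }

    // Reads element `index` as an unsigned value in 0..255.
    int get(int index) {
        return Byte.toUnsignedInt(buffer.get(index));
    }

    // Writes an unsigned value; the low 8 bits are stored as-is.
    UByteView put(int index, int value) {
        if (value < 0 || value > 255) {
            throw new IllegalArgumentException("not a uint8 value: " + value);
        }
        buffer.put(index, (byte) value);
        return this;
    }

    public static void main(String[] args) {
        ByteBuffer raw = ByteBuffer.allocateDirect(4); // native memory, like a tensor
        UByteView view = new UByteView(raw);
        view.put(0, 127).put(1, 200);
        // Prints "127 200 (raw: -56)"
        System.out.println(view.get(0) + " " + view.get(1) + " (raw: " + raw.get(1) + ")");
    }
}
```

The memory footprint is unchanged. Only the accessor widens values, which is why an indexer is preferable to converting the whole tensor to a wider type.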
import org.bytedeco.javacpp.indexer.UByteIndexer;
import org.bytedeco.pytorch.Tensor;

byte[] bytes = new byte[128];
bytes[0] = 127;
Tensor t = Tensor.create(bytes);
UByteIndexer i = t.createIndexer();
int a = i.get(0);
System.out.println("a: " + a);
produces the expected …
Is this issue solved, and can it be closed?
The original exception still happens on …
We should probably either fix it or remove the …
It has already been renamed to …
Why did you rename that? Please don't break backward compatibility just for fun.
I assumed the plan was to add getters for uint8_t as well, but that hasn't been done.
Right, we can add back …
Yes, at least that would kind of make sense.
A detail to be decided: should …
Sure, as long as your tensor is …
It would be the same as …
I'm using fairly large images and want to avoid the overhead of using floats. However, when I try to read byte data from a torch.ScalarType.Byte tensor, I get a weird exception. Minimal code example to reproduce:
This produces:
Reading the exception, it looks like Java_org_bytedeco_pytorch_TensorBase_data_1ptr_1byte is for some reason trying to access TensorBase::data_ptr<signed char> when it should probably be TensorBase::data_ptr<byte>?
Version: 1.10.2 & 1.12.1 on Windows. I haven't tried 1.13.1, as snapshots are tricky in my setup.
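To put the "overhead of using floats" in numbers, here is a back-of-the-envelope Java sketch. It assumes a hypothetical 3 x 4096 x 4096 image, since the issue does not state the actual size. `TensorMemory` and `bytesFor` are illustrative names. Each element costs 1 byte as uint8 (`ScalarType.Byte`), 2 bytes as half, and 4 bytes as float.

```java
// Back-of-the-envelope memory cost of one CHW image tensor per element type.
final class TensorMemory {
    static long bytesFor(long channels, long height, long width, int bytesPerElement) {
        return channels * height * width * bytesPerElement;
    }

    public static void main(String[] args) {
        long c = 3, h = 4096, w = 4096;      // hypothetical image size
        long asUint8 = bytesFor(c, h, w, 1); // ScalarType.Byte
        long asHalf = bytesFor(c, h, w, 2);  // ScalarType.Half
        long asFloat = bytesFor(c, h, w, 4); // ScalarType.Float
        // Prints "uint8: 48 MiB, half: 96 MiB, float: 192 MiB"
        System.out.printf("uint8: %d MiB, half: %d MiB, float: %d MiB%n",
                asUint8 >> 20, asHalf >> 20, asFloat >> 20);
    }
}
```

The 4x factor is per tensor, so it compounds for batches and for any intermediate copies.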