[pytorch] Can't get byte data from torch.ScalarType.Byte tensor #1321

Closed
jxtps opened this issue Feb 9, 2023 · 23 comments

Comments

@jxtps
Contributor

jxtps commented Feb 9, 2023

I'm using fairly large images and want to avoid the overhead of using floats. However, when I try to read byte data from a torch.ScalarType.Byte tensor I get a weird exception.

Minimal code example to reproduce:

package misc;

import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.pytorch.Tensor;
import org.bytedeco.pytorch.global.torch;

import java.nio.FloatBuffer;

public class PytorchUint8TestBench {
    public static void main(String[] args) {
        long[] shape = new long[]{1, 3, 20, 20};
        float[] floats = new float[3 * 20 * 20];
        FloatPointer p = new FloatPointer(FloatBuffer.wrap(floats));
        Tensor floatT = torch.from_blob(p, shape);
        Tensor byteT = floatT.to(torch.ScalarType.Byte);
        System.out.println("byteT.dtype().isScalarType(torch.ScalarType.Byte): " + byteT.dtype().isScalarType(torch.ScalarType.Byte));
        BytePointer byteP = byteT.data_ptr_byte(); // <-- Exception here
        byte[] bytes = new byte[3 * 20 * 20];
        byteP.get(bytes);
    }
}

This produces:

byteT.dtype().isScalarType(torch.ScalarType.Byte): true
Exception in thread "main" java.lang.RuntimeException: expected scalar type Char but found Byte
Exception raised from data_ptr at C:\build\aten\src\ATen\core\TensorMethods.cpp:18 (most recent call first):
00007FFAAF449CD200007FFAAF449C70 c10.dll!c10::Error::Error [<unknown file> @ <unknown line number>]
00007FFAAF44975E00007FFAAF449710 c10.dll!c10::detail::torchCheckFail [<unknown file> @ <unknown line number>]
00007FF8804BB95D00007FF8804BB8E0 torch_cpu.dll!at::TensorBase::data_ptr<signed char> [<unknown file> @ <unknown line number>]
00007FFA6E330B5700007FFA6E330AE0 jnitorch.dll!Java_org_bytedeco_pytorch_TensorBase_data_1ptr_1byte [<unknown file> @ <unknown line number>]
0000000003A217B0 <unknown symbol address> !<unknown symbol> [<unknown file> @ <unknown line number>]

	at org.bytedeco.pytorch.TensorBase.data_ptr_byte(Native Method)
	at misc.PytorchUint8TestBench.main(PytorchUint8TestBench.java:18)

Reading the exception, it looks like Java_org_bytedeco_pytorch_TensorBase_data_1ptr_1byte is for some reason calling TensorBase::data_ptr<signed char> when it should probably be calling TensorBase::data_ptr<byte>?

Version: 1.10.2 and 1.12.1 on Windows. I haven't tried 1.13.1, as snapshots are tricky in my setup.

@jxtps
Contributor Author

jxtps commented Feb 9, 2023

Workaround:

        // Untyped pointer: skips the dtype check done by data_ptr<T> (needs org.bytedeco.javacpp.Pointer)
        Pointer pp = byteT.data_ptr();
        BytePointer byteP = new BytePointer(pp);

I checked that I got reasonable data back and it seems to work. That also allows viewing the data as e.g. ints, which can be useful when working with images.
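As a minimal sketch of what "viewing the data as ints" can mean, assuming the tensor bytes have already been copied into a Java `byte[]` (for example with `byteP.get(bytes)`), the plain-JDK code below reinterprets every four bytes as one `int` in native byte order without copying, similar in spirit to wrapping the raw `Pointer` in an `IntPointer`. The `RawByteViews` class and `asInts` method are hypothetical, not part of the JavaCPP presets.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.IntBuffer;

// Hypothetical helper: views packed uint8 data (e.g. 4 channels per pixel)
// as one int per pixel. The IntBuffer is a view backed by the same array.
class RawByteViews {
    static IntBuffer asInts(byte[] bytes) {
        // asIntBuffer() drops trailing bytes that do not fill a whole int
        return ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()).asIntBuffer();
    }

    public static void main(String[] args) {
        byte[] rgba = {(byte) 200, 10, 20, (byte) 255};
        IntBuffer pixels = asInts(rgba);
        System.out.println(pixels.remaining()); // one int built from four bytes
        System.out.printf("%08x%n", pixels.get(0)); // byte order depends on the platform
    }
}
```

Native order is used here because tensor memory is laid out in the machine's own byte order; a fixed order would scramble the channels on some platforms.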

@saudet
Member

saudet commented Feb 10, 2023

There are no types named "byte" in C++. I wonder what it's expecting here. @HGuillemet Would you know?

@jxtps
Contributor Author

jxtps commented Feb 10, 2023

https://github.com/bytedeco/javacpp-presets/blob/master/pytorch/src/gen/java/org/bytedeco/pytorch/TensorBase.java#L297 reads:

public native @Name("data_ptr<int8_t>") BytePointer data_ptr_byte();

What is int8_t defined as?

@HGuillemet
Collaborator

Apparently, in PyTorch, scalar type Byte is an unsigned byte (uint8_t) and Char is a signed byte (int8_t), so we need to change the Info in the presets to:

               .put(new Info("at::TensorBase::data_ptr<uint8_t>").javaNames("data_ptr_byte"))
               .put(new Info("at::TensorBase::data_ptr<int8_t>").javaNames("data_ptr_char"))
...
               .put(new Info("at::Tensor::item<uint8_t>").javaNames("item_byte"))
               .put(new Info("at::Tensor::item<int8_t>").javaNames("item_char"))
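To make the proposed naming concrete, here is a small, self-contained model (plain Java, no PyTorch involved) of the rule that trips the original example: `data_ptr<T>` only succeeds when `T` is exactly the tensor's element type, and ATen's `Byte` stores `uint8_t` while `Char` stores `int8_t`. The `ScalarTypeModel` enum, the `checkDataPtr` method, and the message format are assumptions that imitate the error text in the traces in this issue; they are not the ATen or JavaCPP API.

```java
// Hypothetical model of ATen scalar types and the C++ element type each one stores.
enum ScalarTypeModel {
    Byte("uint8_t"), Char("int8_t"), Short("int16_t"), Int("int"),
    Long("int64_t"), Half("at::Half"), Float("float"), Double("double");

    final String cppType;

    ScalarTypeModel(String cppType) { this.cppType = cppType; }

    static ScalarTypeModel fromCppType(String cppType) {
        for (ScalarTypeModel t : values()) {
            if (t.cppType.equals(cppType)) return t;
        }
        throw new IllegalArgumentException("unknown element type: " + cppType);
    }
}

class DataPtrCheck {
    // Mimics the check behind data_ptr<T>(): T must match the tensor dtype exactly,
    // with no conversion between signed and unsigned bytes.
    static void checkDataPtr(ScalarTypeModel tensorType, String requestedCppType) {
        ScalarTypeModel requested = ScalarTypeModel.fromCppType(requestedCppType);
        if (requested != tensorType) {
            throw new RuntimeException("expected scalar type " + requested
                    + " but found " + tensorType);
        }
    }

    public static void main(String[] args) {
        checkDataPtr(ScalarTypeModel.Byte, "uint8_t"); // passes: the proposed data_ptr_byte mapping
        try {
            checkDataPtr(ScalarTypeModel.Byte, "int8_t"); // the old data_ptr_byte mapping
        } catch (RuntimeException e) {
            System.out.println(e.getMessage()); // prints "expected scalar type Char but found Byte"
        }
    }
}
```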

BTW, we could add something for half floats too, but since Java has no type for them and we would need to return a ShortPointer anyway, we might as well leave it as is and have users call the raw data_ptr, as @jxtps did above.
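If half-float data is read through a `ShortPointer` (or the raw `data_ptr`), each `short` holds an IEEE 754 binary16 bit pattern rather than a numeric value. A minimal decoding sketch in plain Java follows; `HalfBits.toFloat` is a hypothetical helper, not part of JavaCPP. JDK 20 and later also provide `Float.float16ToFloat(short)` for the same conversion.

```java
class HalfBits {
    // Converts an IEEE 754 binary16 bit pattern (as stored in a short) to a float.
    static float toFloat(short bits) {
        int h = bits & 0xFFFF;
        int sign = (h >>> 15) & 0x1;
        int exponent = (h >>> 10) & 0x1F;
        int mantissa = h & 0x3FF;

        float magnitude;
        if (exponent == 0) {
            // Zero or subnormal: mantissa * 2^-24
            magnitude = mantissa * 0x1p-24f;
        } else if (exponent == 0x1F) {
            // All-ones exponent: infinity when the mantissa is zero, NaN otherwise
            magnitude = (mantissa == 0) ? Float.POSITIVE_INFINITY : Float.NaN;
        } else {
            // Normal: (1 + mantissa / 1024) * 2^(exponent - 15)
            magnitude = (1.0f + mantissa / 1024.0f) * (float) Math.pow(2, exponent - 15);
        }
        return sign == 0 ? magnitude : -magnitude;
    }

    public static void main(String[] args) {
        System.out.println(toFloat((short) 0x3C00)); // 1.0
        System.out.println(toFloat((short) 0xC000)); // -2.0
        System.out.println(toFloat((short) 0x7BFF)); // 65504.0, the largest finite half
    }
}
```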

@HGuillemet
Collaborator

HGuillemet commented Feb 10, 2023

Or we could use the terms of dtype: int8 and uint8.

Also it would be nicer if item_byte returned a java int (or short) equal to the unsigned byte.

@saudet
Member

saudet commented Feb 10, 2023 via email

@saudet
Member

saudet commented Feb 10, 2023

> Or we could use the terms of dtype: int8 and uint8.

But if you want to do something like that, please be consistent with the C++ API, not the Python API.

@jxtps
Contributor Author

jxtps commented Feb 10, 2023

> Also it would be nicer if item_byte returned a java int (or short) equal to the unsigned byte.

No, no, please don't! That explodes the memory consumption, which would (at least for my use case) completely undo the rationale for using bytes in the first place.

A ShortPointer for FP16 sounds good; I guess it'll be a while until JEP 401 lands.

@HGuillemet
Collaborator

item_byte is the function that returns a single value when the tensor has 0 dimensions.
data_ptr_int8 and data_ptr_uint8 will both return byte pointers, of course.

@saudet
Copy link
Member

saudet commented Feb 10, 2023

Or maybe it's OK the way it is now. Java doesn't have types for unsigned bytes, half floats, and whatnot, so in those cases it makes sense to let users deal with the raw Pointer themselves. There are a few indexers in JavaCPP to deal with those, though, and I've already added something for them in AbstractTensor: https://github.com/bytedeco/javacpp-presets/blob/master/pytorch/src/main/java/org/bytedeco/pytorch/AbstractTensor.java#L98
@jxtps Are you having any problems when calling Tensor.createIndexer()?

saudet added the enhancement label and removed the bug label on Feb 11, 2023
@jxtps
Contributor Author

jxtps commented Feb 11, 2023

    byte[] bytes = new byte[128];
    bytes[0] = 127;
    Tensor t = Tensor.create(bytes);
    UByteIndexer i = t.createIndexer();
    int a = i.get(0);
    System.out.println("a: " + a);

produces the expected "a: 127", so it looks OK?
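Note that 127 reads back the same whether a byte is treated as signed or unsigned, so this test cannot tell the two apart; a value above 127 can. The plain-JDK sketch below shows the difference; it does not call JavaCPP, and the `UnsignedBytes` class is hypothetical. An unsigned indexer such as `UByteIndexer` is expected to return the unsigned reading, an `int` in 0..255.

```java
class UnsignedBytes {
    // Signed reading: what a plain Java byte (or an int8 view) yields.
    static int signedValue(byte b) {
        return b;
    }

    // Unsigned reading: what a uint8 tensor element actually represents.
    static int unsignedValue(byte b) {
        return Byte.toUnsignedInt(b); // same as b & 0xFF
    }

    public static void main(String[] args) {
        byte[] bytes = new byte[128];
        bytes[0] = 127;
        bytes[1] = (byte) 200;
        System.out.println(signedValue(bytes[0]) + " " + unsignedValue(bytes[0])); // 127 127
        System.out.println(signedValue(bytes[1]) + " " + unsignedValue(bytes[1])); // -56 200
    }
}
```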

@HGuillemet
Collaborator

Is this issue solved, and can it be closed?

@jxtps
Contributor Author

jxtps commented Oct 17, 2023

The original exception still happens on "org.bytedeco" % "pytorch-platform" % "2.0.1-1.5.9":

java.lang.RuntimeException: expected scalar type Char but found Byte
Exception raised from data_ptr at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\build\build\aten\src\ATen\core\TensorMethods.cpp:20 (most recent call first):
00007FFF39AAD24200007FFF39AAD1E0 c10.dll!c10::Error::Error [<unknown file> @ <unknown line number>]
00007FFF39AACE8A00007FFF39AACE30 c10.dll!c10::detail::torchCheckFail [<unknown file> @ <unknown line number>]
00007FFE0C4B80E300007FFE0C4B8000 torch_cpu.dll!at::TensorBase::data_ptr<signed char> [<unknown file> @ <unknown line number>]
00007FFF17D3DC1700007FFF17D3DBA0 jnitorch.dll!Java_org_bytedeco_pytorch_TensorBase_data_1ptr_1byte [<unknown file> @ <unknown line number>]
0000000002BD17B0 <unknown symbol address> !<unknown symbol> [<unknown file> @ <unknown line number>]

We should probably either fix it or remove TensorBase.data_ptr_byte()?

@HGuillemet
Collaborator

It has already been renamed to data_ptr_char() since PR #1360

@saudet
Copy link
Member

saudet commented Oct 17, 2023

> It has already been renamed to data_ptr_char() since PR #1360

Why did you rename that? Please don't break backward compatibility just for fun.

@saudet
Copy link
Member

saudet commented Oct 18, 2023

I assumed the plan was to add getters for uint8_t as well, but that hasn't been done.

@HGuillemet
Collaborator

Right, we can add back data_ptr_byte for data_ptr<uint8_t> and item_byte for item<uint8_t>

@saudet
Member

saudet commented Oct 18, 2023

Yes, at least that would kind of make sense.

@HGuillemet
Collaborator

A detail to be decided: should item_byte (used when the tensor type is uint8_t) return a (signed) byte, a short, or an int?
I'd go for int.

@jxtps
Contributor Author

jxtps commented Oct 18, 2023

int sounds great. I trust the reintroduced data_ptr_byte won't crash?

@HGuillemet
Collaborator

Sure, as long as your tensor is ScalarType.Byte.
Concerning item<T>, T is the type we want the result converted to, whereas in data_ptr<T>, T must match the type of the tensor. We can call t.item_int() even if t contains floats or bytes, so there is no need for an item_byte.

@saudet
Copy link
Member

saudet commented Oct 19, 2023

int item_byte() sounds fine though.

@HGuillemet
Collaborator

It would be the same as int item_int().
item<X> is a macro for item().toX(), where item() returns a Scalar of the same type as the tensor.
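The distinction can be modeled in plain Java: `item()` yields a scalar that already holds the element's value, and the typed accessors convert from it, so reading a `uint8` element through an `int` accessor gives the unsigned value. The `ScalarModel` class and its methods are hypothetical stand-ins for `c10::Scalar` and its conversions, written only to illustrate the semantics discussed here, not the real API.

```java
// Hypothetical stand-in for c10::Scalar: item() on a uint8 tensor is modeled as
// an integral scalar holding the unsigned value (0..255), not a signed Java byte.
final class ScalarModel {
    private final long integral;

    private ScalarModel(long integral) { this.integral = integral; }

    // Models item() on a one-element uint8 tensor whose raw byte is rawByte.
    static ScalarModel fromUint8(byte rawByte) {
        return new ScalarModel(Byte.toUnsignedInt(rawByte));
    }

    // Models item<int>() / item_int(): a value conversion, not a reinterpretation of bits.
    int toInt() { return (int) integral; }

    // Models item<float>() / item_float().
    float toFloat() { return (float) integral; }

    public static void main(String[] args) {
        ScalarModel s = fromUint8((byte) 200);
        System.out.println(s.toInt());   // 200
        System.out.println(s.toFloat()); // 200.0
    }
}
```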
