This repository has been archived by the owner on Mar 12, 2024. It is now read-only.

Remove NestedTensor #233

Closed
whatdhack opened this issue Sep 16, 2020 · 7 comments
Labels
question Further information is requested

Comments

@whatdhack

Instructions To Reproduce the 🐛 Bug:

Is there a way to replace the NestedTensor with two tensors, one for images and the other for masks, without having to retrain the network? Using this odd class as the input causes a lot of headaches.

Expected behavior:

replace NestedTensor with 2 tensors.

Environment:

1.6

@zhiqwang
Contributor

zhiqwang commented Sep 16, 2020

I think the most important question is what scenario you are facing. For example, if you are working on an object detection problem, you can simply drop the NestedTensor, which corresponds to the code in

detr/models/detr.py

Lines 59 to 60 in 5e66b4c

if isinstance(samples, (list, torch.Tensor)):
    samples = nested_tensor_from_tensor_list(samples)

and

detr/models/backbone.py

Lines 72 to 73 in 5e66b4c

def forward(self, tensor_list: NestedTensor):
    xs = self.body(tensor_list.tensors)

This is because it only uses the Tensor information of the inputs, so it is easy to remove NestedTensor in this case.

Otherwise, for segmentation tasks, it seems that we would have to do a lot to detach NestedTensor from the architecture. (In my opinion, NestedTensor here was designed just to handle the more complicated segmentation tasks; see issue #116, where @fmassa has given some explanation of the intention behind introducing NestedTensor.)

@whatdhack
Author

Sounds like a good idea; however, on a closer look, "mask" is used later on in "out", which in turn is used later.

    # detr/models/backbone.py
    def forward(self, tensor_list: NestedTensor):
        xs = self.body(tensor_list.tensors)
        out: Dict[str, NestedTensor] = {}
        for name, x in xs.items():
            m = tensor_list.mask
            assert m is not None
            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(x, mask)
        return out

        # detr/models/detr.py
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)

        src, mask = features[-1].decompose()
        assert mask is not None
        hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]

@whatdhack
Author

This code looks more like C++ than Python. Is all this typing really necessary?

@zhiqwang
Contributor

zhiqwang commented Sep 17, 2020

This code looks more like C++ than Python. Is all this typing really necessary?

Some of this typing is used for conversion to TorchScript, where it is essential.

@zhiqwang
Contributor

zhiqwang commented Sep 17, 2020

..., however on further look, "mask" is used later on in "out", which in turn is used later.

Sorry, you are right. I confused the usage of NestedTensor with that in my own repo :( As I understand it, DETR builds the NestedTensor using the function nested_tensor_from_tensor_list, shown below.

detr/util/misc.py

Lines 283 to 305 in 5e66b4c

def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX
            # call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)
        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], :img.shape[2]] = False
    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)

So if you want to drop NestedTensor, you would need to reimplement this function in another way.
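For illustration, here is a minimal sketch of what such a reimplementation might look like, returning a plain (tensor, mask) tuple instead of a NestedTensor. The name pad_images_and_mask is hypothetical and not part of DETR; the sketch follows the padding and mask convention of the function quoted above (True marks padded pixels) but omits the ONNX tracing branch.

```python
from typing import List, Tuple

import torch
from torch import Tensor


def pad_images_and_mask(images: List[Tensor]) -> Tuple[Tensor, Tensor]:
    """Pad (C, H, W) images to a common size and return (batch, mask).

    Hypothetical helper, not part of DETR. The mask is True on padded
    pixels and False on real pixels, as in nested_tensor_from_tensor_list.
    """
    if len(images) == 0 or any(img.ndim != 3 for img in images):
        raise ValueError("expected a non-empty list of (C, H, W) tensors")
    # Largest extent along each of the C, H and W axes.
    c = max(img.shape[0] for img in images)
    h = max(img.shape[1] for img in images)
    w = max(img.shape[2] for img in images)
    first = images[0]
    batch = torch.zeros((len(images), c, h, w), dtype=first.dtype, device=first.device)
    mask = torch.ones((len(images), h, w), dtype=torch.bool, device=first.device)
    for i, img in enumerate(images):
        # Copy the image into the top-left corner and unmask that region.
        batch[i, : img.shape[0], : img.shape[1], : img.shape[2]] = img
        mask[i, : img.shape[1], : img.shape[2]] = False
    return batch, mask


# Two images of different sizes are padded to a shared 5 x 6 canvas.
images = [torch.rand(3, 4, 6), torch.rand(3, 5, 2)]
batch, mask = pad_images_and_mask(images)
```

The rest of the model would then have to accept batch and mask as two separate arguments wherever it currently reads tensor_list.tensors and tensor_list.mask.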

@fmassa
Contributor

fmassa commented Sep 18, 2020

It is possible to remove NestedTensor without having to retrain the model, but it will require a couple of changes to the code (which shouldn't be too difficult, and most of which have already been pointed out).
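As a rough sketch of one such change: the backbone forward quoted earlier in this thread resizes the padding mask to each feature map and wraps the result in a NestedTensor. The same resizing can be done on plain tensors. The helper name resize_mask and the random features tensor below are hypothetical stand-ins, not DETR code.

```python
from typing import Tuple

import torch
import torch.nn.functional as F
from torch import Tensor


def resize_mask(mask: Tensor, size: Tuple[int, int]) -> Tensor:
    """Resize a boolean (B, H, W) padding mask to a feature map's (h, w).

    Hypothetical helper. Uses the same expression as the backbone forward
    quoted in this thread: interpolate as float (nearest-neighbor is the
    default mode of F.interpolate), then cast back to bool.
    """
    return F.interpolate(mask[None].float(), size=size).to(torch.bool)[0]


# Image-level mask: the right half of the first image is padding,
# the second image has no padding at all.
mask = torch.zeros((2, 8, 8), dtype=torch.bool)
mask[0, :, 4:] = True

# Stand-in for a backbone feature map with spatial stride 2.
features = torch.rand(2, 16, 4, 4)
small_mask = resize_mask(mask, features.shape[-2:])
```

The pair (features, small_mask) can then be passed on as two tensors where the original code calls decompose() on a NestedTensor.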

For more background on NestedTensor, see #116

I believe I've answered your question, and as such I'm closing this issue, but let us know if you have further questions.

@fmassa fmassa closed this as completed Sep 18, 2020
@fmassa fmassa added the question Further information is requested label Sep 18, 2020
@whatdhack
Author

whatdhack commented Sep 18, 2020

Who knows when the Python NestedTensor is going to be finalized? Support from every vendor is probably years away. I would request figuring out a better way rather than relying on uncertain future work.
