
In model configuration support defining list of tensors as input #2593

Closed
arunsu opened this issue Mar 4, 2021 · 5 comments

@arunsu

arunsu commented Mar 4, 2021

Is your feature request related to a problem? Please describe.
Some models take a list of images for inference. In the model configuration file "config.pbtxt", the tensor datatypes supported by Triton do not include a list type. Is there a plan to support lists, or is there a workaround to pass a list of tensors to the model?

Example: the RetinaNet model takes a list of images:

import torch
import torchvision.models as models

retina50 = models.detection.retinanet_resnet50_fpn(pretrained=True)
retina50.eval()
# two images of different sizes, passed as a Python list
dummy_input = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
output = retina50(dummy_input)

Describe the solution you'd like
Support a list of tensors, or an alternative way to provide the list.

Describe alternatives you've considered
In config.pbtxt I have the following input, which does not work with the model:
input [
  {
    name: "input__0"
    data_type: TYPE_FP32
    dims: [ 3, 480, 640 ]
  }
]


@GuanLuo
Contributor

GuanLuo commented Mar 4, 2021

So your model can accept a variable number of inputs? Or is what you are showing here batching the images into one input? Image models usually accept a batch of images as a single input so that a group of images is processed in the same inference. Triton can serve such a model; you just need to set max_batch_size in your model config to match the maximum batch size your model supports.
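For reference, a minimal config.pbtxt sketch for such a batched model might look like the following (the model name, dims, and batch size here are illustrative assumptions, not values from the model above):

name: "retinanet"
platform: "pytorch_libtorch"
max_batch_size: 8   # Triton will batch up to 8 images per inference
input [
  {
    name: "input__0"
    data_type: TYPE_FP32
    dims: [ 3, 480, 640 ]   # per-image shape; the batch dimension is implicit
  }
]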

@CoderHam
Contributor

CoderHam commented Mar 5, 2021

While libtorch does support passing lists of tensors, Triton Server does not. You can build a simple wrapper model that converts a tensor into a list of tensors and passes it on to your model. Once you have this wrapper model, simply trace it and you should be able to use the traced model inside Triton.
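As a rough sketch of the idea, assuming the RetinaNet model from the original post (the class name, output selection, and shapes are illustrative assumptions; detection models contain data-dependent control flow, so expect tracer warnings):

import torch
import torchvision.models as models

class RetinaNetWrapper(torch.nn.Module):
    # Converts Triton's single batched tensor into the list of
    # 3D image tensors that RetinaNet expects.
    def __init__(self):
        super().__init__()
        self.model = models.detection.retinanet_resnet50_fpn(pretrained=True).eval()

    def forward(self, images):
        # images: [N, 3, H, W] -> list of N tensors of shape [3, H, W]
        image_list = list(torch.unbind(images, dim=0))
        detections = self.model(image_list)
        # Return plain tensors rather than a list of dicts, so the traced
        # graph has outputs Triton can map to config.pbtxt entries.
        return detections[0]["boxes"], detections[0]["scores"], detections[0]["labels"]

wrapper = RetinaNetWrapper().eval()
example = torch.rand(2, 3, 480, 640)
traced = torch.jit.trace(wrapper, example)   # tracer warnings are expected here
traced.save("model.pt")

Note that tracing bakes in the control-flow path seen during tracing, and detection outputs such as boxes have data-dependent shapes, so a production wrapper would likely need padded outputs or a scripted export instead.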

@CoderHam
Contributor

Closing this since Triton does not intend to add support for handling lists of tensors as inputs. Please use the workaround shared above.

@ketan-b

ketan-b commented Oct 8, 2021

@CoderHam What exactly do you mean by a wrapper model? Could you please explain it a bit?

@wscjxky

wscjxky commented Oct 25, 2021

  • I had the same problem and solved it: just change the model outputs to a tuple or namedtuple, then use jit.trace to export the .pt model.

I believe tritonserver:21.09-py3 can support multiple inputs for PyTorch models.

    from collections import namedtuple

    def forward(self, image, points):
        coord_features = self.dist_maps(image, points)
        .........
        instance_out = feature_extractor_out[0]
        outputs = {'instances': instance_out}

        #### add the following code
        # Convert dict/list outputs to (named) tuples so jit.trace can
        # export a graph whose outputs Triton can address by position.
        if isinstance(outputs, dict):
            data_named_tuple = namedtuple("outputs", sorted(outputs.keys()))  # type: ignore
            outputs = data_named_tuple(**outputs)  # type: ignore

        elif isinstance(outputs, list):
            outputs = tuple(outputs)
        #### end of added code

        return outputs
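The export step mentioned above might then look like this (the model class name and input shapes are assumptions for illustration):

    import torch

    model = MyModel().eval()          # hypothetical class containing the forward() above
    image = torch.rand(1, 3, 28, 28)
    points = torch.rand(1, 28, 28)
    traced = torch.jit.trace(model, (image, points))
    traced.save("model.pt")           # goes under <model_repository>/infer_model/1/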
name: "infer_model"
platform: "pytorch_libtorch"

version_policy {
  latest {
    num_versions: 1
  }
}
input {
  name: "inputs__0"
  data_type: TYPE_FP32
  dims: [1, 3, 28, 28]
}
input {
  name: "inputs__1"
  data_type: TYPE_FP32
  dims: [1, 28, 28]
}

output {
  name: "classes__0"
  data_type: TYPE_FP32
  dims: [1, 1, 28, 28]
}
  • Triton Inference code
import numpy as np
import tritonclient.http as httpclient
from tritonclient.utils import np_to_triton_dtype

model = np.ones((1, 28, 28), dtype=np.float32)
check = np.ones((1, 3, 28, 28), dtype=np.float32)

with httpclient.InferenceServerClient("localhost:8000") as client:
    inputs = [
        httpclient.InferInput("inputs__0", check.shape,
                              np_to_triton_dtype(check.dtype)),
        httpclient.InferInput("inputs__1", model.shape,
                              np_to_triton_dtype(model.dtype))
    ]
    inputs[0].set_data_from_numpy(check)
    inputs[1].set_data_from_numpy(model)
    result = client.infer("infer_model", inputs)
    print(result.as_numpy("classes__0"))
  • Otherwise, I also tried to create wrapper code around my model, using split, cat, squeeze, and narrow to combine multiple inputs into one, like this:

    class Model(torch.nn.Module):
        ...
        def forward(self, inputs):
            # inputs: [2, 3, 28, 28], the concatenation of image and points
            image, points = torch.split(inputs, 1)
            # recover the original [1, 28, 28] points tensor
            points = points.squeeze(1).narrow(1, 0, 1).squeeze(1)

  • Inference with the combined input:

    x0 = torch.ones(1, 3, 28, 28).to(device)
    x1 = torch.ones(1, 28, 28).to(device)
    # expand x1 to x0's shape so the two can be concatenated along dim 0
    x1 = x1.unsqueeze(1).expand(1, 3, 28, 28)
    inputs = torch.cat([x0, x1])

    model(inputs)
