
Question: unsqueeze: axes in not an []int64 #200

Open
loeffel-io opened this issue Oct 4, 2022 · 3 comments

Comments

@loeffel-io
Contributor

loeffel-io commented Oct 4, 2022

Hello @owulveryck,

I'm trying to run this ONNX model:

model

with this code:

package main

import (
	"fmt"
	"log"
	"os"

	"github.com/owulveryck/onnx-go"
	"github.com/owulveryck/onnx-go/backend/x/gorgonnx"
	"gorgonia.org/tensor"
)

func main() {
	// gorgonnx is the Gorgonia-based backend that executes the graph.
	backend := gorgonnx.NewGraph()
	model := onnx.NewModel(backend)

	var err error
	var b []byte
	var output []tensor.Tensor

	// os.ReadFile replaces the deprecated ioutil.ReadFile (Go 1.16+).
	if b, err = os.ReadFile("model.onnx"); err != nil {
		log.Fatal(err)
	}

	// Decode the serialized ONNX model into the backend graph.
	if err = model.UnmarshalBinary(b); err != nil {
		log.Fatal(err)
	}

	var acosGroupTensor tensor.Tensor
	if acosGroupTensor, err = tensor.Argmax(
		tensor.New(tensor.WithShape(1, 5), tensor.Of(tensor.Float32), tensor.WithBacking([]float32{0, 1, 0, 0, 0})),
		1,
	); err != nil {
		log.Fatal(err)
	}

	if err = model.SetInput(0, acosGroupTensor); err != nil {
		log.Fatal(err)
	}

	var acosRatioTensor = tensor.New(tensor.WithShape(1, 4), tensor.Of(tensor.Float32), tensor.WithBacking([]float32{0, 0, -1, 0}))

	if err = model.SetInput(1, acosRatioTensor); err != nil {
		log.Fatal(err)
	}

	var salesRatioTensor = tensor.New(tensor.WithShape(1, 4), tensor.Of(tensor.Float32), tensor.WithBacking([]float32{0, 0, -1, 0}))

	if err = model.SetInput(2, salesRatioTensor); err != nil {
		log.Fatal(err)
	}

	if err = backend.Run(); err != nil {
		log.Fatal(err)
	}

	if output, err = model.GetOutputTensors(); err != nil {
		log.Fatal(err)
	}
	// write the first output to stdout
	fmt.Println(output[0])
}

Running it fails with the error "unsqueeze: axes in not an []int64". It looks like I'm missing something small. I'd love some help ❤️ Thank you!

@owulveryck
Contributor

My guess is that the axes attribute is an int64 rather than an []int64, which causes the error.
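To see why the check fails, here is a minimal sketch of Go's comma-ok type assertion. It assumes, as the o.Attributes lookup in unsqueeze.go suggests, that operator attributes are held as interface{} values, and that this model's axes attribute decodes to a scalar int64. isAxesSlice is a hypothetical helper for illustration, not part of onnx-go.

```go
package main

import "fmt"

// isAxesSlice is a hypothetical helper mirroring the check in unsqueeze.go:
// the assertion only succeeds if the dynamic type is exactly []int64.
func isAxesSlice(v interface{}) bool {
	_, ok := v.([]int64)
	return ok
}

func main() {
	// Assume the decoded model stored "axes" as a scalar int64.
	var attr interface{} = int64(0)

	// A scalar int64 is not an []int64, so ok is false and the
	// operator returns its "axes in not an []int64" error.
	fmt.Println("holds []int64:", isAxesSlice(attr))
	fmt.Printf("dynamic type: %T\n", attr)
}
```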

I don't know why, but you can try changing the code in unsqueeze.go like this:

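	// Accept a scalar int64 "axes" attribute and wrap it in a slice.
	// Use a name other than a, assuming a is the method's receiver (as a.Axes implies).
	axis, ok := o.Attributes["axes"].(int64)
	if !ok {
		return errors.New("unsqueeze: axes is not an int64")
	}
	a.Axes = []int64{axis}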
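A slightly more tolerant variant uses a type switch so that both a scalar int64 and an []int64 are accepted. This is a sketch under the assumption that the attribute value is an interface{} holding one of those two forms; parseAxes is a hypothetical helper, not an onnx-go function, and inside unsqueeze.go its result would be assigned to a.Axes.

```go
package main

import "fmt"

// parseAxes is a hypothetical helper (not part of onnx-go) that accepts the
// "axes" attribute either as a single int64 or as a []int64.
func parseAxes(attr interface{}) ([]int64, error) {
	switch v := attr.(type) {
	case []int64:
		// Already in the expected form.
		return v, nil
	case int64:
		// Wrap a scalar axis in a one-element slice.
		return []int64{v}, nil
	default:
		// Covers missing attributes (nil) and any other type.
		return nil, fmt.Errorf("unsqueeze: axes has unsupported type %T", attr)
	}
}

func main() {
	for _, attr := range []interface{}{int64(0), []int64{0, 2}, "bad"} {
		axes, err := parseAxes(attr)
		fmt.Println(axes, err)
	}
}
```

Keeping the scalar case separate makes the fix explicit without changing behaviour for models that already store axes as a list.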

@loeffel-io
Contributor Author

Thanks @owulveryck, I had created a similar workaround, but then I got stuck because the Cast operator is not implemented.

I've moved to https://github.com/triton-inference-server/server instead.

I'll leave this open to flag that the Unsqueeze operator seems unstable; feel free to close it if that's not the case.
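For reference, the ONNX Unsqueeze operator only inserts size-1 dimensions into the output shape at the positions listed in axes, with negative values counted from the end of the output. The sketch below illustrates that shape rule as described in the ONNX operator specification; unsqueezeShape is a hypothetical helper that works on shapes only and does not use onnx-go or Gorgonia.

```go
package main

import "fmt"

// unsqueezeShape is a hypothetical helper illustrating the ONNX Unsqueeze
// shape rule: the output has rank len(shape)+len(axes), with a dimension of
// size 1 at every listed axis and the input dimensions filling the rest.
func unsqueezeShape(shape []int, axes []int64) ([]int, error) {
	rank := len(shape) + len(axes)
	ones := make([]bool, rank)
	for _, ax := range axes {
		a := int(ax)
		if a < 0 {
			a += rank // negative axes count from the end of the output
		}
		if a < 0 || a >= rank {
			return nil, fmt.Errorf("axis %d out of range for output rank %d", ax, rank)
		}
		if ones[a] {
			return nil, fmt.Errorf("duplicate axis %d", ax)
		}
		ones[a] = true
	}

	// Walk the output positions, taking input dims in order where no 1 is inserted.
	out := make([]int, 0, rank)
	next := 0
	for i := 0; i < rank; i++ {
		if ones[i] {
			out = append(out, 1)
		} else {
			out = append(out, shape[next])
			next++
		}
	}
	return out, nil
}

func main() {
	out, err := unsqueezeShape([]int{3, 4}, []int64{0, -1})
	fmt.Println(out, err)
}
```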
