
[Relay][Op] Multinomial #12284

Merged
merged 8 commits into from
Aug 8, 2022
Conversation


@jwfromm jwfromm commented Aug 3, 2022

This PR introduces the multinomial random operator. It's a neat adaptation of random.uniform that allows weighted selection of indices from a probability tensor. This op is used in new DALL-E-like architectures to generate random images. The PR provides a TOPI implementation and tests, Relay integration, and an initial PyTorch integration. I did not implement sampling without replacement at this time, as it seems complicated to express as a tensor operation.
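The weighted-selection idea described above can be sketched in plain NumPy (this is an illustrative inverse-CDF sketch, not the PR's TOPI code): uniform draws are compared against the cumulative distribution of the normalized weights, and the crossing point is the sampled index.

```python
import numpy as np

def multinomial_sketch(probs, num_samples, seed=0):
    # Draw `num_samples` indices with replacement, weighted by `probs`.
    # Inverse-CDF sampling: an index i is chosen when the uniform draw
    # falls inside the interval [cdf[i-1], cdf[i]).
    rng = np.random.default_rng(seed)
    probs = np.asarray(probs, dtype="float64")
    cdf = np.cumsum(probs / probs.sum())
    u = rng.random(num_samples)  # uniform samples in [0, 1)
    idx = np.searchsorted(cdf, u, side="right")
    # Guard against floating-point round-off in the last cumsum entry.
    return np.minimum(idx, len(cdf) - 1)

samples = multinomial_sketch([0.1, 0.2, 0.7], num_samples=5)
```

Sampling without replacement would require masking out already-chosen indices between draws, which is why it is harder to express as a single tensor operation.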


jwfromm commented Aug 3, 2022

@sfvaroglu can you take a look at this PR?

@jwfromm jwfromm requested a review from tkonolige August 3, 2022 05:56
@sfvaroglu
Contributor

LGTM, thanks @jwfromm! Would be nice to have this in the onnx importer, too :)

Contributor

@tkonolige tkonolige left a comment


Looks good @jwfromm. I would just like a little testing on the output to make sure it is actually a multinomial distribution. Let me know if you think that is too complicated.

assert not (
    replacement is False and num_samples > 1
), "Multinomial without replacement is not yet supported."
seed = np.random.randint(1e6)
Contributor


Ideally there would be one seed that we pass through the entire graph that is set or initialized at runtime. But I don't think we have the infrastructure for that yet. This is fine for now but maybe you could add a comment about how to improve this in the future?

Contributor Author


Yeah, that's a good point; we'd have to use a global dictionary or something for that. I'll add a note. For now, this approach matches how we handle other RNG importer functions.
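The "global dictionary or something" idea could look roughly like the following hypothetical sketch (all names invented for illustration; this is not code from the PR): one shared state object hands every RNG op a distinct but reproducible seed derived from a single base seed set once per import.

```python
class SeedState:
    """Hypothetical per-import seed store for RNG ops.

    Instead of calling np.random.randint(1e6) at each RNG op, every op
    would request its seed here, so one base seed set at import time
    deterministically controls the whole graph.
    """

    def __init__(self, base_seed: int):
        self.base = base_seed
        self.count = 0

    def next_seed(self) -> int:
        # Each call yields a distinct seed; replaying the import with the
        # same base seed reproduces the same sequence.
        self.count += 1
        return (self.base + self.count) % (2**31)

state = SeedState(base_seed=7)
first = state.next_seed()
second = state.next_seed()
```

Replaying the import with the same base seed would regenerate the same per-op seeds, which is the property the review comment is after.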

@@ -157,8 +185,27 @@ def test_uniform(target, dev):
    assert np.max(rands) <= 10.0


@tvm.testing.parametrize_targets
def test_multinomial(target, dev):
    def _verify_multinomial(size, num_samples):
Contributor


Could you do some rough checking of the expected value and variance of the distribution? It's always hard to tell whether these random things are implemented correctly, and I think this would help.

Contributor Author


Is there a good way to do this without potentially introducing flakiness? I guess we could use a fixed seed. Would that be satisfactory?

Contributor

@octoJon octoJon Aug 3, 2022


You could generate a "large" sample of at least 10,000 values and then use a chi-squared test (scipy.stats.chisquare). You'd look at the p-value from that chi-squared test and compare it to an acceptably low threshold for flakiness -- for example, have this unit test fail if the p-value is smaller than 1e-6, which should only happen by chance in one run per million.
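The suggested check can be sketched without the op under test. Here `numpy.random.Generator.choice` stands in for the operator's output, and the chi-squared statistic is computed by hand rather than via `scipy.stats.chisquare` (same quantity, fewer dependencies); the threshold of ~30 corresponds to p < 1e-6 at 3 degrees of freedom.

```python
import numpy as np

# Draw a "large" sample from a known distribution, then compare observed
# category counts against expected counts with a chi-squared statistic.
rng = np.random.default_rng(42)
probs = np.array([0.1, 0.2, 0.3, 0.4])
n = 10_000
samples = rng.choice(len(probs), size=n, p=probs)  # stand-in for the op's output
observed = np.bincount(samples, minlength=len(probs))
expected = probs * n
chi2 = float(np.sum((observed - expected) ** 2 / expected))

# For 3 degrees of freedom, chi2 above ~30 corresponds to p < 1e-6,
# so a correct sampler fails this only about once per million runs.
assert chi2 < 30.0
```

In the real unit test, `samples` would come from the compiled multinomial op instead of `rng.choice`, and the loose p-value threshold keeps CI from being flaky.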

Contributor Author

@jwfromm jwfromm Aug 5, 2022


Thanks for this tip. I added a chi-squared test, which confirms that the function behaves as expected.


jwfromm commented Aug 6, 2022

@tkonolige can you give this another look? I think it's all set to merge.

Member

@junrushao junrushao left a comment


LGTM!

Contributor

@tkonolige tkonolige left a comment


Thanks @jwfromm

@tkonolige tkonolige merged commit b79f950 into apache:main Aug 8, 2022
xinetzone pushed a commit to daobook/tvm that referenced this pull request Nov 25, 2022
* Add multinomial operator.
* Implemented Pytorch integration with multinomial.
* Fixed test parametrization and added ONNX integration.
* Add statistical testing.
* Make get_type more flexible.
@jwfromm jwfromm deleted the torch_multinomial branch April 12, 2023 15:57