A fix for seed attribute in the keras random normal generator #2126

Merged · 8 commits · Mar 16, 2023
8 changes: 4 additions & 4 deletions tf2onnx/onnx_opset/generator.py
@@ -60,8 +60,8 @@ def version_1(cls, ctx, node, **kwargs):
# in the rewriter does not trigger. grappler will send the random uniform
# with shape as input so we need to pickup the input here and if the shape is
# const we make it an attribute.
-        seed = node.get_attr("seed")
-        node.set_attr("seed", float(seed.f))
+        seed = node.get_attr("seed2")
+        node.set_attr("seed", float(seed.i))
utils.make_sure(node.inputs[0].is_const(), "%s node with non-const shape requires opset >= 9", node.type)
shape = node.inputs[0].get_tensor_value()
ctx.remove_input(node, node.input[0], 0)
@@ -88,8 +88,8 @@ def version_9(cls, ctx, node, **kwargs):
if node.inputs[0].is_const():
cls.version_1(ctx, node, **kwargs)
else:
-            seed = node.get_attr("seed")
-            node.set_attr("seed", float(seed.f))
+            seed = node.get_attr("seed2")
Collaborator:
Do we also need to update this for version_1 (line 63)?

Contributor (author):
Thanks @fatcat-z, you're quite right. I'll update version_1 as well.

Collaborator:

I'm a little confused: is this behavior (using seed2 instead of seed) the same across different TF versions? Did it change after a particular TF version?

Contributor (author):

As far as I've seen and tested, the version_1 behaviour is also relevant for TF version 2 when the random generator's dim/shape is constant and not derived from the batch size or any other data-dependent size.

I'm checking whether it changed in TF v1.

+            node.set_attr("seed", float(seed.i))
cast_node = ctx.make_node("Cast", [node.input[0]], attr={'to': onnx_pb.TensorProto.INT64})
const_node = ctx.make_node("ConstantOfShape", cast_node.output)
inputs = node.input.copy()
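The gist of the fix can be sketched without tf2onnx installed. TF's random ops carry two integer seed attributes (`seed` and `seed2`, stored in the integer field `.i` of the attribute proto), while ONNX's RandomNormal takes a single float `seed` attribute. Reading the float field `.f` of an integer-valued attribute silently yields the field's default of 0.0 rather than raising, which is how the old code dropped the seed. A minimal sketch, using a hypothetical `Attr` dataclass as a stand-in for the real proto:

```python
from dataclasses import dataclass

# Hypothetical stand-in for an attribute proto: integer attrs live in
# the `i` field, float attrs in the `f` field. Reading the wrong field
# does not raise; it returns that field's default value.
@dataclass
class Attr:
    i: int = 0
    f: float = 0.0

# A TF random op's attributes: both seeds are integers.
attrs = {"seed": Attr(i=87654321), "seed2": Attr(i=42)}

# Old behaviour: read the float field of the int-valued `seed` attr,
# which is always 0.0 — the user's seed is silently discarded.
old_seed = float(attrs["seed"].f)

# Fixed behaviour: read the int field of `seed2` and cast to float,
# since ONNX RandomNormal expects a float `seed` attribute.
new_seed = float(attrs["seed2"].i)

print(old_seed, new_seed)  # 0.0 42.0
```

The cast to `float` matters on its own: even with the right attribute, passing an int where ONNX expects a float attribute can fail validation, so the fix changes both the source field and the target type.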