Skip to content

Commit

Permalink
not the most robust but should work
Browse files Browse the repository at this point in the history
  • Loading branch information
drisspg committed Mar 13, 2024
1 parent c24053f commit 9ffce2a
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions transformer_nuggets/utils/shape_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,18 +133,17 @@ def construct_input(
input_dict = {}
for pair in input_str:
k, v = pair.split(":")
# Check if dtype is present
assert (
"[" in v
), f"Shape is not present in {v}, you are calling a function that accepts non tensor kwargs and I have yet to create string serialization for these"
maybe_dtype_tuple = v.split("[")
maybe_dtype = maybe_dtype_tuple[0]
shape = "[" + maybe_dtype_tuple[1]
shape = ast.literal_eval(shape)
dtype = abbr_to_dtype[maybe_dtype] if maybe_dtype else default_dtype
# Need to convert the input name to self if it is input
input_dict[input_to_self(k)] = torch.rand(
shape, dtype=dtype, device=device, requires_grad=requires_grad
)
# If '[' is in the value, then it is a tensor
if "[" in v:
maybe_dtype_tuple = v.split("[")
maybe_dtype = maybe_dtype_tuple[0]
shape = "[" + maybe_dtype_tuple[1]
shape = ast.literal_eval(shape)
dtype = abbr_to_dtype[maybe_dtype] if maybe_dtype else default_dtype
input_dict[input_to_self(k)] = torch.rand(
shape, dtype=dtype, device=device, requires_grad=requires_grad
)
else:
input_dict[k] = ast.literal_eval(v)
op_inpts.append(input_dict)
return op_inpts

0 comments on commit 9ffce2a

Please sign in to comment.