chore: Add test for torch.Tensor.uniform_ and remove unnecessary comments
hmahmood24 committed Nov 17, 2023
1 parent e73a282 commit 79d36a1
Showing 3 changed files with 61 additions and 3 deletions.
1 change: 0 additions & 1 deletion ivy/functional/frontends/torch/nn/parameter.py
@@ -20,7 +20,6 @@ def __init__(self, data=None, device=None, requires_grad=True):
         self.grad_fn = None
 
     def __deepcopy__(self, memo):
-        # TODO: Need to add test. Adding for KLA demo (priority)
         if id(self) in memo:
             return memo[id(self)]
         else:
14 changes: 12 additions & 2 deletions ivy/functional/frontends/torch/tensor.py
@@ -45,7 +45,6 @@ def __repr__(self):
         )
 
     def __hash__(self):
-        # TODO: Need to add test. Adding for KLA demo (priority)
         return id(self)
 
     # Properties #
@@ -1995,8 +1994,19 @@ def random_(
         )
         return self.ivy_array
 
+    @with_unsupported_dtypes(
+        {
+            "2.1.1 and below": (
+                "integer",
+                "unsigned",
+                "bfloat16",
+                "bool",
+                "complex",
+            )
+        },
+        "torch",
+    )
     def uniform_(self, from_=0, to=1, *, generator=None):
-        # TODO: Need to add test. Adding for KLA demo (priority)
         ret = ivy.random_uniform(
             low=from_, high=to, shape=self.shape, dtype=self.dtype, seed=generator
         )
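For context, the decorator marks integer, unsigned, bfloat16, bool, and complex dtypes as unsupported for torch 2.1.1 and below, so uniform_ only needs to handle floating dtypes. A minimal usage sketch of the new frontend method follows — assuming a numpy backend is installed, that torch_frontend.tensor is the frontend's creation call, and that uniform_ fills the tensor in place (which is what the new test exercises via test_inplace):

# Minimal sketch, not part of the commit: assumes the numpy backend is
# installed and that torch_frontend.tensor creates a frontend Tensor.
import ivy
import ivy.functional.frontends.torch as torch_frontend

ivy.set_backend("numpy")

t = torch_frontend.tensor([[0.0, 0.0], [0.0, 0.0]])  # float dtype, supported
t.uniform_(from_=-1.0, to=1.0)  # fills t in place with samples from U(-1, 1)
print(t)  # every element now lies in [-1, 1)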
49 changes: 49 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py
@@ -13500,6 +13500,55 @@ def test_torch_unfold(
     )
 
 
+# uniform_
+@handle_frontend_method(
+    class_tree=CLASS_TREE,
+    init_tree="torch.tensor",
+    method_name="uniform_",
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("valid"),
+        min_value=1,
+        max_value=5,
+        min_num_dims=1,
+        max_num_dims=5,
+    ),
+    from_=helpers.floats(min_value=-1000, max_value=0),
+    to=helpers.floats(min_value=1, max_value=1000),
+    test_inplace=st.just(True),
+)
+def test_torch_uniform_(
+    dtype_and_x,
+    from_,
+    to,
+    frontend_method_data,
+    init_flags,
+    method_flags,
+    frontend,
+    on_device,
+    backend_fw,
+):
+    input_dtype, x = dtype_and_x
+    method_flags.num_positional_args = 3
+    helpers.test_frontend_method(
+        init_input_dtypes=input_dtype,
+        backend_to_test=backend_fw,
+        init_all_as_kwargs_np={
+            "data": x[0],
+        },
+        method_input_dtypes=input_dtype,
+        method_all_as_kwargs_np={
+            "from_": from_,
+            "to": to,
+        },
+        frontend_method_data=frontend_method_data,
+        init_flags=init_flags,
+        method_flags=method_flags,
+        frontend=frontend,
+        on_device=on_device,
+        test_values=False,
+    )
+
+
 # unique
 @handle_frontend_method(
     class_tree=CLASS_TREE,
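The new test can be exercised locally with pytest. The sketch below is only an illustration — it assumes the ivy repository root as the working directory and the test-suite dependencies (pytest, hypothesis) installed; -k filters the run down to the new test function:

# Hypothetical local invocation, not part of the commit.
import pytest

pytest.main(
    [
        "ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py",
        "-k",
        "test_torch_uniform_",
    ]
)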
