Skip to content

Commit

Permalink
[review] use set instead of list
Browse files Browse the repository at this point in the history
  • Loading branch information
tenzen-y committed Aug 18, 2022
1 parent 5b190a4 commit 5bc97d0
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
5 changes: 2 additions & 3 deletions pkg/suggestion/v1beta1/nas/common/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
def validate_operations(operations: list[api_pb2.Operation]) -> (bool, str):

# Validate each operation
for operation in operations:
for operation in set(operations):

# Check OperationType
if not operation.operation_type:
Expand All @@ -29,8 +29,7 @@ def validate_operations(operations: list[api_pb2.Operation]) -> (bool, str):
return False, "Missing ParameterConfigs in Operation:\n{}".format(operation)

# Validate each ParameterConfig in Operation
parameters_list = list(operation.parameter_specs.parameters)
for parameter in parameters_list:
for parameter in set(operation.parameter_specs.parameters):

# Check Name
if not parameter.name:
Expand Down
8 changes: 4 additions & 4 deletions pkg/suggestion/v1beta1/nas/darts/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,17 +167,17 @@ def validate_algorithm_settings(algorithm_settings: list[api_pb2.AlgorithmSettin
return False, "{} should be greater than zero".format(s.name)

# Validate learning rate
if s.name in ["w_lr", "w_lr_min", "alpha_lr"]:
if s.name in {"w_lr", "w_lr_min", "alpha_lr"}:
if not float(s.value) >= 0.0:
return False, "{} should be greater than or equal to zero".format(s.name)

# Validate weight decay
if s.name in ["w_weight_decay", "alpha_weight_decay"]:
if s.name in {"w_weight_decay", "alpha_weight_decay"}:
if not float(s.value) >= 0.0:
return False, "{} should be greater than or equal to zero".format(s.name)

# Validate w_momentum and w_grad_clip
if s.name in ["w_momentum", "w_grad_clip"]:
if s.name in {"w_momentum", "w_grad_clip"}:
if not float(s.value) >= 0.0:
return False, "{} should be greater than or equal to zero".format(s.name)

Expand All @@ -190,7 +190,7 @@ def validate_algorithm_settings(algorithm_settings: list[api_pb2.AlgorithmSettin
return False, "num_workers should be greater than or equal to zero"

# Validate "init_channels", "print_step", "num_nodes" and "stem_multiplier"
if s.name in ["init_channels", "print_step", "num_nodes", "stem_multiplier"]:
if s.name in {"init_channels", "print_step", "num_nodes", "stem_multiplier"}:
if not int(s.value) >= 1:
return False, "{} should be greater than or equal to one".format(s.name)

Expand Down

0 comments on commit 5bc97d0

Please sign in to comment.