Skip to content

Commit

Permalink
fix ruff lint
Browse files Browse the repository at this point in the history
  • Loading branch information
kashif committed Sep 13, 2024
1 parent da6a595 commit 0a79fd3
Show file tree
Hide file tree
Showing 8 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/gluonts/core/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def equals_default_impl(this: Any, that: Any) -> bool:
A boolean value indicating whether ``this`` and ``that`` are
structurally equal.
"""
if type(this) != type(that):
if type(this) is not type(that):
return False

if hasattr(this, "__init_args__") and hasattr(that, "__init_args__"):
Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/dataset/arrow/dec.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def decode_batch(self, batch):
)

for name, value in row.items():
if type(value) == np.ndarray and value.dtype == object:
if type(value) is np.ndarray and value.dtype == object:
row[name] = np.stack(value)

yield row
2 changes: 1 addition & 1 deletion src/gluonts/mx/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def equals_parameter_dict(
equals
Dispatching function.
"""
if type(this) != type(that):
if type(this) is not type(that):
return False

def strip_prefix_enumeration(key, prefix):
Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/mx/model/n_beats/_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def __eq__(self, that):
if this returns false if for some reason the order of the predictors
list has been altered.
"""
if type(self) != type(that):
if type(self) is not type(that):
return False

try:
Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/mx/model/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def predict(
)

def __eq__(self, that):
if type(self) != type(that):
if type(self) is not type(that):
return False

if not equals(self.prediction_length, that.prediction_length):
Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/testutil/equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def assert_recursively_close(


def _assert_recursively_close(obj_a, obj_b, location, *args, **kwargs):
assert type(obj_a) == type(
assert type(obj_a) is type(
obj_b
), f"types don't match (at {location}) {type(obj_a)} != {type(obj_b)}"
if isinstance(obj_a, (str, int)):
Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/torch/model/deep_npts/_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@


def init_weights(module: nn.Module, scale: float = 1.0):
if type(module) == nn.Linear:
if type(module) is nn.Linear:
nn.init.uniform_(module.weight, -scale, scale)
nn.init.zeros_(module.bias)

Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/zebras/_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def pad(self, value, left: int = 0, right: int = 0) -> TimeSeries:
@staticmethod
def _batch(xs: List[TimeSeries]) -> BatchTimeSeries:
for series in xs:
assert type(series) == TimeSeries
assert type(series) is TimeSeries

pluck = pluck_attr(xs)

Expand Down

0 comments on commit 0a79fd3

Please sign in to comment.