Skip to content

Commit

Permalink
Fix remaining Ruff UP complaints
Browse files Browse the repository at this point in the history
  • Loading branch information
akx committed Feb 7, 2024
1 parent 348be3e commit ac0d782
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

class ListHandler(logging.Handler):
def __init__(self, *args, **kwargs):
super(ListHandler, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
self.logs = []

def emit(self, record):
Expand Down Expand Up @@ -161,8 +161,7 @@ def __len__(self):
return len(self.data)

def __iter__(self):
for element in self.data:
yield element
yield from self.data

iterable_dataset = DummyIterableDataset([n for n in range(30)])
dataloader = DataLoader(iterable_dataset, batch_size=4)
Expand Down Expand Up @@ -194,8 +193,7 @@ def __len__(self):
return len(self.data)

def __iter__(self):
for element in self.data:
yield element
yield from self.data

iterable_dataset = DummyIterableDataset(torch.as_tensor(range(30)))
dataloader = DataLoader(iterable_dataset, batch_size=4)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ def __init__(self, data):
self.data = data

def __iter__(self):
for element in self.data:
yield element
yield from self.data


def create_accelerator(even_batches=True):
Expand Down
5 changes: 2 additions & 3 deletions src/accelerate/utils/megatron_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,7 @@ def build_pretraining_data_loader(self, dataset, consumed_samples):
def build_train_valid_test_data_iterators(self):
def cyclic_iter(iter):
while True:
for x in iter:
yield x
yield from iter

args = get_args()

Expand Down Expand Up @@ -926,7 +925,7 @@ class MegatronEngine(torch.nn.Module):
"""

def __init__(self, accelerator, model, optimizer, scheduler):
super(MegatronEngine, self).__init__()
super().__init__()
self.module = model
self.base_model = model[0]
self.optimizer = optimizer
Expand Down
3 changes: 1 addition & 2 deletions src/accelerate/utils/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,7 @@ def named_module_tensors(
Whether or not to remove the non persistent buffer from the buffers. Useful only when include_buffers =
True
"""
for named_parameter in module.named_parameters(recurse=recurse):
yield named_parameter
yield from module.named_parameters(recurse=recurse)

if include_buffers:
non_persistent_buffers = set()
Expand Down

0 comments on commit ac0d782

Please sign in to comment.