
[Feature] Speed up the resume process of IterBased loop #1520

@yinaoxiong (Contributor)

Description

What is the feature?

The current resume logic advances the dataloader by calling next(self.dataloader_iterator) for n steps. When n is large this is very slow, because the actual data loading and processing runs for every skipped step. Is there a good way to advance only the indices without executing the real data-loading pipeline?

  1. One possible approach is to agree on a dataset interface with the user that returns dummy data during resume, for example:
class Dataset:

    def __getitem__(self, index):
        if self._skip_flag:
            return fake_data  # cheap dummy data; skipped batches are never used
        # ... actual data loading and processing ...
        return real_data

    def skip(self):
        self._skip_flag = True

    def resume(self):
        self._skip_flag = False



# corresponding handling logic in the training loop
            if (
                hasattr(self.dataloader.dataset, "skip")
                and callable(self.dataloader.dataset.skip)
                and hasattr(self.dataloader.dataset, "resume")
                and callable(self.dataloader.dataset.resume)
            ):
                self.dataloader.dataset.skip()
                for _ in range(self._iter):
                    next(self.dataloader_iterator)
                self.dataloader.dataset.resume()
            else:
                for _ in range(self._iter):
                    next(self.dataloader_iterator)
  2. Approach 1 still requires cooperation from the user. Is it possible to operate on the dataloader itself so that the fast skip is transparent to the user? For example:
                iter_batch_sampler = iter(self.dataloader.batch_sampler)
                for _ in range(self._iter):
                    next(iter_batch_sampler)

Iterating the batch_sampler directly works when num_workers=0, but with multiple workers the restored data order comes out wrong. I would like to know whether there is a good solution for this.

Any other context?

https://discuss.pytorch.org/t/is-there-any-way-to-skip-steps-in-a-dataloader/123201
https://pytorch.org/data/main/dataloader2.html

Snapshot the state of data-preprocessing pipeline (WIP)

Activity

zhouzaida (Collaborator) commented on May 17, 2024

A minimal-change solution is to mock the dataset's __getitem__ method before iterating:

    def run(self) -> None:
        """Launch training."""
        self.runner.call_hook('before_train')
        # In iteration-based training loop, we treat the whole training process
        # as a big epoch and execute the corresponding hook.
        self.runner.call_hook('before_train_epoch')
        if self._iter > 0:
            print_log(
                f'Advance dataloader {self._iter} steps to skip data '
                'that has already been trained',
                logger='current',
                level=logging.WARNING)
            # temporarily mock the dataset's __getitem__ so that the
            # skipped iterations return cheap dummy data
            old_getitem = self.dataloader_iterator.dataset.__getitem__
            self.dataloader_iterator.dataset.__getitem__ = a_new_getitem_method
            for _ in range(self._iter):
                next(self.dataloader_iterator)
            # restore the original __getitem__
            self.dataloader_iterator.dataset.__getitem__ = old_getitem
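
For illustration, a_new_getitem_method above is a placeholder; one hypothetical form (mine, not from the comment) is a stub that ignores the index. Two caveats: dataset[idx] resolves __getitem__ through the class rather than the instance, so the patch may need to be applied to the class depending on how the dataset is indexed; and with num_workers > 0 the dataset is copied into each worker process when the iterator starts, so patching it afterwards in the main process has no effect.

def a_new_getitem_method(self, index):
    # Hypothetical stub: ignore the index and return a constant dummy
    # sample. Skipped batches are discarded, so the value only has to
    # survive the dataloader's collate_fn.
    return 0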
linked a pull request that will close this issue on May 23, 2024: [Fix] Speed up "--resume" #1548
chtzs commented on May 24, 2024

I believe this PR is the cause of the issue: #1471.
While it fixed the resume iteration problem, it also made resuming slow. A suitable solution would be to call the _next_index() method of the DataLoader's built-in iterator, which skips a batch without reading the data.
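
A minimal sketch of that idea (the helper name is mine; this is not the exact code of PR #1548). _next_index() is a private method of PyTorch's _BaseDataLoaderIter, so it may change between versions:

def skip_batches(dataloader_iter, num_batches):
    """Advance the iterator's index stream without loading any data.

    _next_index() just draws the next batch of indices from the sampler.
    This works directly for num_workers=0; a multi-worker iterator
    prefetches index batches as soon as it is created, so the skip must
    account for that prefetching (see PR #1548 for the adopted fix).
    """
    for _ in range(num_batches):
        dataloader_iter._next_index()  # indices are drawn and discarded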

hujh1994 commented on Apr 7, 2025

@zhouzaida's suggestion above (mocking the dataset's __getitem__ method before iterating) does not work with multiple worker processes.

wanghao9610 commented on Apr 12, 2025

I have had the same resume issue, and PR #1548 solved it successfully. Thanks, @chtzs. If anyone gets stuck on this issue, please consider the solution in that PR.
