
Investigate Resident Memory Increase during Inference #18640

Open
ZekunZh opened this issue Sep 26, 2023 · 9 comments
Labels
bug (Something isn't working), help wanted (Open to be worked on), performance, ver: 2.0.x

Comments

@ZekunZh

ZekunZh commented Sep 26, 2023

Bug description

Resident memory (RSS) keeps growing when the Trainer is instantiated multiple times during inference.

In our production environment we currently need to instantiate a Trainer for each request, each of which contains one image. That's why we observed the OOM issue.

We understand that it might not be best practice to use Lightning in production; any suggestions / comments are welcome! 😃

The following curve can be reproduced with the provided Python script, running 1000 iterations.

[figure: resident memory usage over 1000 iterations with the original strategy (2023-09-26T14h23m39s_memory_usage_originalStrategy)]

What version are you seeing the problem on?

v2.0

How to reproduce the bug

import gc
import os
import re
from datetime import datetime
from pathlib import Path

import numpy as np
import psutil
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning import LightningModule
from lightning.fabric.utilities.optimizer import _optimizers_to_device
from lightning.pytorch import Trainer
from lightning.pytorch.strategies.single_device import SingleDeviceStrategy
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader, Dataset

TIME_FORMAT = "%Y-%m-%dT%Hh%Mm%Ss"

def get_time() -> str:
    """get current time and convert to specific format"""
    return datetime.utcnow().strftime(TIME_FORMAT)


class SimpleDataset(Dataset):
    def __len__(self):
        return 1000

    def __getitem__(self, idx):
        return torch.randn((1, 28, 28))


class SimpleModel(LightningModule):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.layer1 = nn.Linear(28 * 28, 512)
        self.layer2 = nn.Linear(512, 512)
        self.layer3 = nn.Linear(512, 512)
        self.layer4 = nn.Linear(512, 512)
        self.layer5 = nn.Linear(512, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = F.relu(self.layer3(x))
        x = F.relu(self.layer4(x))
        x = self.layer5(x)
        return x


    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        return self(batch)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

class SingleDeviceProdStrategy(SingleDeviceStrategy):
    def teardown(self) -> None:
        _optimizers_to_device(self.optimizers, torch.device("cpu"))
        if self.lightning_module is not None:
            self.lightning_module.cpu()
        self.precision_plugin.teardown()
        assert self.accelerator is not None
        self.accelerator.teardown()
        self.checkpoint_io.teardown()
        gc.collect()


def convert_bytes_to_megabytes(memory_bytes):
    return memory_bytes / 1024 ** 2

def run_inference_and_monitor_memory(tag: str):
    dataset = SimpleDataset()
    dataloader = DataLoader(dataset, batch_size=32)
    model = SimpleModel()

    process = psutil.Process(os.getpid())
    initial_memory = process.memory_info().rss

    memory_usages = []

    N_ITERATIONS = 1000

    for i in range(N_ITERATIONS):
        strategy = SingleDeviceStrategy(device=torch.device("cuda:0"))
        # strategy = SingleDeviceProdStrategy(device=torch.device("cuda:0"))
        trainer = Trainer(strategy=strategy)
        trainer.predict(model, dataloader)
        current_memory = process.memory_info().rss
        memory_usage = convert_bytes_to_megabytes(current_memory - initial_memory)
        print(f"Iteration {i + 1}: Resident Memory used: {memory_usage:.3f} MB")
        memory_usages.append(memory_usage)

    plt.plot(range(1, N_ITERATIONS+1), memory_usages)
    plt.xlabel('Iteration')
    plt.ylabel('Resident Memory used (MB)')
    plt.title('Resident Memory Usage over Iterations')

    # Specify the y-ticks
    min_memory = min(memory_usages)
    max_memory = max(memory_usages)
    yticks = np.linspace(min_memory, max_memory, num=20)  # Increase num to increase density
    plt.yticks(yticks)

    fig_path = Path(__file__).parent / 'oom_minimal_example' / f'{get_time()}_memory_usage_{tag}.png'
    fig_path.parent.mkdir(exist_ok=True, parents=True)
    plt.savefig(fig_path)
    print(f"Saved figure to {fig_path.resolve()}")


def parse_args():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--tag", type=str, required=True)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    run_inference_and_monitor_memory(tag=args.tag)
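
For reference, assuming the script above is saved as e.g. repro_memory.py (the filename is not part of the original report), it can be run as:

python repro_memory.py --tag originalStrategy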

Error messages and logs


Environment

Current environment
  • CUDA:
    • GPU:
      • Tesla T4
    • available: True
    • version: 11.7
  • Lightning:
    • lightning: 2.0.0
    • lightning-cloud: 0.5.36
    • lightning-utilities: 0.8.0
    • pytorch-lightning: 2.0.3
    • torch: 2.0.0
    • torchinfo: 1.5.3
    • torchmetrics: 0.11.4
    • torchvision: 0.15.1
  • Packages:
    • absl-py: 1.4.0
    • adal: 1.2.7
    • addict: 2.4.0
    • aiofiles: 23.1.0
    • aiohttp: 3.8.4
    • aiohttp-retry: 2.8.3
    • aiosignal: 1.3.1
    • albumentations: 1.1.0
    • amqp: 5.1.1
    • antlr4-python3-runtime: 4.9.3
    • anyio: 3.7.0
    • appdirs: 1.4.4
    • argcomplete: 2.1.2
    • arrow: 1.2.3
    • async-timeout: 4.0.2
    • asyncssh: 2.13.1
    • atpublic: 4.0
    • attrs: 23.1.0
    • autoflake: 2.2.0
    • azure-common: 1.1.28
    • azure-core: 1.27.0
    • azure-graphrbac: 0.61.1
    • azure-mgmt-authorization: 3.0.0
    • azure-mgmt-containerregistry: 10.1.0
    • azure-mgmt-core: 1.4.0
    • azure-mgmt-keyvault: 10.2.2
    • azure-mgmt-resource: 22.0.0
    • azure-mgmt-storage: 21.0.0
    • azure-nspkg: 3.0.2
    • azure-storage: 0.36.0
    • azure-storage-blob: 1.1.0
    • azure-storage-common: 1.1.0
    • azure-storage-nspkg: 3.1.0
    • azureml-core: 1.50.0
    • backports.tempfile: 1.0
    • backports.weakref: 1.0.post1
    • bcrypt: 4.0.1
    • beautifulsoup4: 4.12.2
    • billiard: 3.6.4.0
    • black: 23.3.0
    • blessed: 1.20.0
    • blindspin: 2.0.1
    • boto3: 1.26.149
    • botocore: 1.29.149
    • cachetools: 5.3.1
    • celery: 5.2.2
    • certifi: 2023.5.7
    • cffi: 1.15.1
    • cfgv: 3.3.1
    • charset-normalizer: 3.1.0
    • chumpy: 0.71
    • clearml: 1.3.2
    • click: 8.0.2
    • click-didyoumean: 0.3.0
    • click-plugins: 1.1.1
    • click-repl: 0.2.0
    • clickclick: 20.10.2
    • cloudpickle: 2.2.1
    • cmake: 3.26.4
    • colorama: 0.4.6
    • configobj: 5.0.8
    • connexion: 2.14.2
    • contextlib2: 0.5.5
    • contourpy: 1.0.7
    • coverage: 7.2.5
    • crayons: 0.4.0
    • croniter: 1.3.15
    • cryptography: 3.4.8
    • cycler: 0.11.0
    • cython: 0.29.33
    • dacite: 1.7.0
    • dateutils: 0.6.12
    • decorator: 5.1.1
    • deepdiff: 6.3.0
    • deprecated: 1.2.14
    • detectron2: 0.7+cu118
    • detrex: 0.3.0
    • dictdiffer: 0.9.0
    • dill: 0.3.6
    • diskcache: 5.6.1
    • distlib: 0.3.6
    • distro: 1.8.0
    • dnspython: 2.3.0
    • docker: 6.1.3
    • docker-pycreds: 0.4.0
    • dpath: 2.1.6
    • dulwich: 0.21.5
    • dvc: 2.46.0
    • dvc-data: 0.42.3
    • dvc-gs: 2.22.0
    • dvc-http: 2.30.2
    • dvc-objects: 0.22.0
    • dvc-render: 0.5.3
    • dvc-studio-client: 0.10.0
    • dvc-task: 0.2.1
    • einops: 0.6.1
    • et-xmlfile: 1.1.0
    • eventlet: 0.33.3
    • fairscale: 0.4.13
    • fastapi: 0.86.0
    • fiftyone: 0.20.0
    • fiftyone-brain: 0.11.0
    • fiftyone-db: 0.4.0
    • filelock: 3.12.0
    • flake8: 6.0.0
    • flask: 2.2.5
    • flask-testing: 0.8.1
    • flatten-dict: 0.4.2
    • flufl.lock: 7.1.1
    • fonttools: 4.39.4
    • frozenlist: 1.3.3
    • fsspec: 2023.5.0
    • ftfy: 6.1.1
    • funcy: 2.0
    • furl: 2.1.3
    • future: 0.18.3
    • fvcore: 0.1.5.post20220506
    • gcsfs: 2023.5.0
    • gitdb: 4.0.10
    • gitdb2: 2.0.6
    • gitpython: 3.1.31
    • glmlib: 1.0.0
    • glob2: 0.7
    • google-api-core: 1.34.0
    • google-auth: 2.19.1
    • google-auth-oauthlib: 1.0.0
    • google-cloud-core: 2.3.2
    • google-cloud-pubsub: 1.0.2
    • google-cloud-storage: 1.43.0
    • google-crc32c: 1.5.0
    • google-resumable-media: 1.3.0
    • googleapis-common-protos: 1.59.0
    • gputil: 1.4.0
    • grandalf: 0.8
    • graphql-core: 3.2.3
    • greenlet: 2.0.2
    • grpc-google-iam-v1: 0.12.6
    • grpcio: 1.54.2
    • grpcio-status: 1.48.2
    • h11: 0.14.0
    • h2: 4.1.0
    • hpack: 4.0.0
    • httpcore: 0.17.2
    • httpx: 0.24.1
    • huggingface-hub: 0.15.1
    • humanfriendly: 10.0
    • hydra-core: 1.3.2
    • hydra-zen: 0.10.0
    • hypercorn: 0.14.3
    • hyperframe: 6.0.1
    • identify: 2.5.24
    • idna: 3.4
    • imageio: 2.31.0
    • imgaug: 0.4.0
    • inflection: 0.5.1
    • iniconfig: 2.0.0
    • inquirer: 3.1.3
    • iopath: 0.1.9
    • isodate: 0.6.1
    • isort: 5.12.0
    • iterative-telemetry: 0.0.8
    • itsdangerous: 2.1.2
    • jaraco.classes: 3.3.0
    • jeepney: 0.8.0
    • jinja2: 3.1.2
    • jmespath: 1.0.1
    • joblib: 1.2.0
    • json-tricks: 3.17.0
    • jsonpickle: 3.0.1
    • jsonschema: 4.10.0
    • kaleido: 0.2.1
    • keyring: 24.2.0
    • keyrings.google-artifactregistry-auth: 1.1.2
    • kili: 2.120.0
    • kiwisolver: 1.4.4
    • knack: 0.10.1
    • kombu: 5.3.0
    • lazy-loader: 0.2
    • lightning: 2.0.0
    • lightning-cloud: 0.5.36
    • lightning-utilities: 0.8.0
    • lit: 16.0.5.post0
    • markdown: 3.4.3
    • markdown-it-py: 2.2.0
    • markupsafe: 2.1.3
    • matplotlib: 3.7.1
    • mccabe: 0.7.0
    • mdurl: 0.1.2
    • mmcv: 1.4.2
    • mmpose: 0.21.0
    • monai: 0.9.1
    • mongoengine: 0.24.2
    • more-itertools: 8.8.0
    • motor: 3.1.2
    • mpmath: 1.3.0
    • msal: 1.22.0
    • msal-extensions: 1.0.0
    • msrest: 0.7.1
    • msrestazure: 0.6.4
    • multidict: 6.0.4
    • munkres: 1.1.4
    • mypy-extensions: 1.0.0
    • nanotime: 0.5.2
    • ndg-httpsclient: 0.5.1
    • ndjson: 0.3.1
    • networkx: 3.1
    • nibabel: 3.2.1
    • nodeenv: 1.8.0
    • numpy: 1.24.2
    • nvidia-cublas-cu11: 11.10.3.66
    • nvidia-cuda-cupti-cu11: 11.7.101
    • nvidia-cuda-nvrtc-cu11: 11.7.99
    • nvidia-cuda-runtime-cu11: 11.7.99
    • nvidia-cudnn-cu11: 8.5.0.96
    • nvidia-cufft-cu11: 10.9.0.58
    • nvidia-curand-cu11: 10.2.10.91
    • nvidia-cusolver-cu11: 11.4.0.1
    • nvidia-cusparse-cu11: 11.7.4.91
    • nvidia-nccl-cu11: 2.14.3
    • nvidia-nvtx-cu11: 11.7.91
    • oauthlib: 3.2.2
    • omegaconf: 2.2.1
    • opencv-python: 4.7.0.72
    • opencv-python-headless: 4.7.0.72
    • openpyxl: 3.0.7
    • ordered-set: 4.1.0
    • orderedmultidict: 1.0.1
    • orjson: 3.9.0
    • packaging: 23.0
    • pandas: 2.0.2
    • paramiko: 3.2.0
    • pathlib2: 2.3.7.post1
    • pathspec: 0.11.1
    • pathtools: 0.1.2
    • patool: 1.12
    • pika: 1.1.0
    • pillow: 9.5.0
    • pip: 23.2.1
    • pkginfo: 1.9.6
    • platformdirs: 3.5.1
    • plotly: 5.14.1
    • pluggy: 1.0.0
    • portalocker: 2.7.0
    • pprintpp: 0.4.0
    • pre-commit: 3.2.2
    • priority: 2.0.0
    • prompt-toolkit: 3.0.38
    • protobuf: 3.20.3
    • psutil: 5.9.5
    • pyaescrypt: 0.4.3
    • pyasn1: 0.5.0
    • pyasn1-modules: 0.3.0
    • pybind11: 2.11.1
    • pycocotools: 2.0.6
    • pycodestyle: 2.10.0
    • pycparser: 2.21
    • pydantic: 1.10.9
    • pydicom: 2.0.0
    • pydot: 1.4.2
    • pyelftools: 0.27
    • pyflakes: 3.0.1
    • pygit2: 1.12.1
    • pygments: 2.15.1
    • pygtrie: 2.5.0
    • pyjwt: 2.1.0
    • pymongo: 4.3.3
    • pympler: 1.0.1
    • pynacl: 1.5.0
    • pyopenssl: 21.0.0
    • pyparsing: 3.0.9
    • pyrsistent: 0.19.3
    • pysocks: 1.7.1
    • pytest: 7.2.2
    • pytest-mock: 3.10.0
    • python-dateutil: 2.8.2
    • python-editor: 1.0.4
    • python-gdcm: 3.0.21
    • python-multipart: 0.0.6
    • pytorch-lightning: 2.0.3
    • pytz: 2023.3
    • pywavelets: 1.4.1
    • pyyaml: 6.0
    • qudida: 0.0.4
    • readchar: 4.0.5
    • regex: 2023.6.3
    • requests: 2.30.0
    • requests-oauthlib: 1.3.1
    • retrying: 1.3.4
    • rich: 13.4.1
    • rsa: 4.9
    • ruamel.yaml: 0.17.21
    • ruff: 0.0.270
    • s3transfer: 0.6.1
    • schema: 0.7.0
    • scikit-image: 0.20.0
    • scikit-learn: 1.2.2
    • scipy: 1.10.1
    • scmrepo: 0.2.1
    • secretstorage: 3.3.3
    • sentry-sdk: 1.25.1
    • setproctitle: 1.3.2
    • setuptools: 67.2.0
    • shapely: 2.0.1
    • shortuuid: 1.0.11
    • shtab: 1.6.1
    • six: 1.16.0
    • smmap: 5.0.0
    • smmap2: 3.0.1
    • sniffio: 1.3.0
    • sortedcontainers: 2.4.0
    • soupsieve: 2.4.1
    • sqltrie: 0.4.0
    • sse-starlette: 0.10.3
    • sseclient-py: 1.7.2
    • starlette: 0.20.4
    • starsessions: 1.3.0
    • strawberry-graphql: 0.138.1
    • submitit: 1.4.5
    • sympy: 1.12
    • tabulate: 0.9.0
    • tenacity: 8.2.2
    • tensorboard: 2.13.0
    • tensorboard-data-server: 0.7.0
    • termcolor: 2.3.0
    • testcontainers: 3.0.0
    • threadpoolctl: 3.1.0
    • tifffile: 2023.4.12
    • timm: 0.6.13
    • toml: 0.10.2
    • tomli: 2.0.1
    • tomlkit: 0.11.8
    • torch: 2.0.0
    • torchinfo: 1.5.3
    • torchmetrics: 0.11.4
    • torchvision: 0.15.1
    • tqdm: 4.64.0
    • traitlets: 5.9.0
    • triton: 2.0.0
    • typeguard: 4.0.0
    • typing-extensions: 4.6.3
    • tzdata: 2023.3
    • tzlocal: 5.0.1
    • universal-analytics-python3: 1.1.1
    • urllib3: 1.26.16
    • uvicorn: 0.22.0
    • vine: 5.0.0
    • virtualenv: 20.23.0
    • voluptuous: 0.13.1
    • voxel51-eta: 0.8.4
    • wandb: 0.15.0
    • wcwidth: 0.2.6
    • websocket-client: 1.5.2
    • websockets: 11.0.3
    • werkzeug: 2.2.3
    • wheel: 0.40.0
    • wrapt: 1.15.0
    • wsproto: 1.2.0
    • xmltodict: 0.13.0
    • xtcocotools: 1.13
    • yacs: 0.1.8
    • yapf: 0.33.0
    • yarl: 1.9.2
    • zc.lockfile: 3.0.post1
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.11.5
    • release: 5.15.0-1042-gcp
    • version: #50~20.04.1-Ubuntu SMP Mon Sep 11 03:30:57 UTC 2023

More info

The temporary solution to this issue is to add gc.collect() at the end of the teardown method while commenting out self.lightning_module.cpu(); a sketch of this workaround is shown below.
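
A minimal sketch of that workaround, adapted from the SingleDeviceProdStrategy in the script above (same imports as in the script; whether skipping the CPU move is acceptable for a given setup is an assumption to verify):

class SingleDeviceProdStrategy(SingleDeviceStrategy):
    def teardown(self) -> None:
        _optimizers_to_device(self.optimizers, torch.device("cpu"))
        # self.lightning_module.cpu()  # skipped: moving the module back to CPU kept RSS growing
        self.precision_plugin.teardown()
        assert self.accelerator is not None
        self.accelerator.teardown()
        self.checkpoint_io.teardown()
        gc.collect()  # force a collection so objects released during teardown are freed promptly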

Things that I've tried:

  • Only commenting out self.lightning_module.cpu() -> does not work 🛑
    (figure: 2023-09-26T14h41m04s_memory_usage_noModToCPU)

  • Only commenting out _optimizers_to_device(self.optimizers, torch.device("cpu")) -> does not work 🛑
    (figure: 2023-09-26T15h05m44s_memory_usage_noOptToCPU)

  • Commenting out both the module-to-CPU and the optimizer-to-CPU moves -> does not work 🛑
    (figure: 2023-09-26T14h44m17s_memory_usage_noModToCPU-noOptToCPU)

  • Only adding gc.collect() -> partially works 🟡
    (figure: 2023-09-26T14h51m30s_memory_usage_originalTdWithGC)

  • Commenting out _optimizers_to_device(self.optimizers, torch.device("cpu")) + adding gc.collect() -> partially works 🟡
    (figure: 2023-09-26T15h42m45s_memory_usage_noOptToCPU-withGC)

  • Commenting out self.lightning_module.cpu() + adding gc.collect() -> works better 🟢
    (figure: 2023-09-26T15h01m03s_memory_usage_noModToCPUWithGC)

  • Commenting out both self.lightning_module.cpu() and _optimizers_to_device(self.optimizers, torch.device("cpu")) + adding gc.collect() -> similar to the previous one 🟢
    (figure: 2023-09-26T15h17m50s_memory_usage_noModToCPU-noOptToCPU-withGC)

cc @Borda

@ZekunZh ZekunZh added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on Sep 26, 2023
@ZekunZh
Author

ZekunZh commented Sep 27, 2023

cc @awaelchli @four4fish @carmocca

@awaelchli awaelchli added the performance and help wanted (Open to be worked on) labels and removed the needs triage (Waiting to be triaged by maintainers) label on Sep 28, 2023
@awaelchli
Contributor

Thanks for investigating this and providing an example code @ZekunZh!

I'm not sure we can remove self.lightning_module.cpu(); we would need to investigate the implications. This could be an unexpected breaking change for users. We should check whether there is a different solution to this first.

@awaelchli
Contributor

awaelchli commented Sep 29, 2023

I ran a couple of tests with your script, removing Lightning and only running with the raw PyTorch model:

...
for i in range(N_ITERATIONS):
    torch.cuda.empty_cache()
    gc.collect()
    model = model.to("cuda:0")
    with torch.inference_mode():
        for batch in dataloader:
            model(batch.to("cuda:0"))
    # model.cpu()

    current_memory = process.memory_info().rss
    memory_usage = convert_bytes_to_megabytes(current_memory - initial_memory)
    print(f"Iteration {i + 1}: Resident Memory used: {memory_usage:.3f} MB")
    memory_usages.append(memory_usage)
...

[figure: resident memory usage with raw PyTorch, no move back to CPU (2023-09-29T01h57m56s_memory_usage_raw-torch-no-move-to-cpu)]

(results produced with torch nightly 2.2.0.dev20230920+cu121)

While the memory increase is definitely smaller, it is still a steady slope. I suppose that on a production system with thousands of requests these few MB could add up. I'm definitely not familiar with memory management in Python and PyTorch, but there seems to be some hidden state somewhere that's not just in Lightning. Perhaps the impact is just amplified with Lightning and the root cause is something else.
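
One way to narrow down where the residual Python-side allocations come from (a diagnostic sketch, not part of the original report; the iteration count and snapshot comparison key are arbitrary choices, and it reuses SimpleModel / SimpleDataset from the reproduction script above) is to diff tracemalloc snapshots between iterations:

import gc
import tracemalloc

import torch
from torch.utils.data import DataLoader

tracemalloc.start(25)  # keep up to 25 frames per allocation for useful tracebacks

model = SimpleModel()
dataloader = DataLoader(SimpleDataset(), batch_size=32)

baseline = None
for i in range(20):
    model = model.to("cuda:0")
    with torch.inference_mode():
        for batch in dataloader:
            model(batch.to("cuda:0"))
    gc.collect()
    torch.cuda.empty_cache()

    snapshot = tracemalloc.take_snapshot()
    if baseline is not None:
        # Print the call sites whose Python-side allocations grew the most since the last iteration.
        for stat in snapshot.compare_to(baseline, "traceback")[:5]:
            print(stat)
    baseline = snapshot

Note that tracemalloc only tracks allocations made through Python's allocator, so growth inside CUDA or other native allocators would not show up here.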

@MushroomMaula

MushroomMaula commented Feb 22, 2024

Hey, I have the same problem, but during training. I am currently using Lightning in an active-learning loop in which I recreate the trainer in each iteration.
Calling gc.collect() after each iteration successfully fixes this issue.

I am using lightning version 2.2.0.post0 and torch 2.2.0

import gc
import os
from copy import deepcopy
from typing import Any

import psutil
import torch
import torch.nn.functional as F
from lightning import LightningModule, Trainer
from matplotlib import pyplot as plt
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm


class MLP(LightningModule):

    def __init__(
        self, 
        n_hidden=10240  # Quite a large value to amplify the effect, my actual model has roughly the same size
    ):  
        super().__init__()
        torch.manual_seed(0)
        self.model = nn.Sequential(
            nn.Linear(1024, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, 2)
        )

        self._initial_state = deepcopy(self.state_dict())

    def forward(self, x) -> Any:
        return F.log_softmax(self.model(x), dim=1)

    def training_step(self, batch):
        x, y = batch
        pred = self.forward(x)
        loss = F.nll_loss(pred, y.view(-1))
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters())

    def reset(self):
        self.load_state_dict(self._initial_state)


def convert_bytes_to_megabytes(memory_bytes):
    return memory_bytes / 1024 ** 2


def run_pytorch(model, dataloader, iterations=10):
    model.train()

    process = psutil.Process(os.getpid())
    initial_memory = process.memory_info().rss
    memory_usages = []

    for _ in range(iterations):
        model.reset()
        optimizer = optim.Adam(model.parameters())
        for x, y in tqdm(dataloader):
            x = x.to("cuda")
            y = y.to("cuda")
            pred = model(x)
            loss = F.nll_loss(pred, y)
            # loss = model.training_step((x, y))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        current_memory = process.memory_info().rss
        diff = convert_bytes_to_megabytes(current_memory - initial_memory)
        memory_usages.append(diff)

    return memory_usages


def run_lightning(model, dataloader, iterations=10, fix=False):
    process = psutil.Process(os.getpid())
    initial_memory = process.memory_info().rss
    memory_usages = []

    for _ in range(iterations):
        model.reset()
        trainer = Trainer(max_epochs=1, enable_checkpointing=False, logger=False)
        trainer.fit(model, train_dataloaders=dataloader)
        if fix:
            gc.collect()

        current_memory = process.memory_info().rss
        diff = convert_bytes_to_megabytes(current_memory - initial_memory)
        memory_usages.append(diff)

    return memory_usages


def main():
    # Create some example data
    torch.manual_seed(0)
    X = torch.randn(size=(64, 1024))
    Y = (torch.rand(size=(64,)) < 0.2).long()
    dataset = TensorDataset(X, Y)
    dataloader = DataLoader(
        dataset,
        batch_size=64,
        shuffle=True,
        num_workers=1
    )

    model = MLP()
    # Remove model from cpu memory
    model.to("cuda")

    torch_usage = run_pytorch(model, dataloader)
    lightning_usage = run_lightning(model, dataloader)
    lightning_usage_fix = run_lightning(model, dataloader, fix=True)

    plt.plot(torch_usage, label="PyTorch")
    plt.plot(lightning_usage, label="Lightning")
    plt.plot(lightning_usage_fix, label="Lightning - Fix")
    plt.legend()
    plt.xlabel("Iterations")
    plt.ylabel("Memory usage")
    plt.show()


main()

[figure: resident memory usage over iterations for PyTorch, Lightning, and Lightning with the gc.collect() fix]
PS: I think the memory usage for the fixed version is negative because the trainer moves some things directly to the GPU.

@awaelchli
Contributor

Thanks for collecting more data here. So then if gc.collect() "fixes" this, it must mean that there is nothing seriously wrong with the code in Lightning/PyTorch because references have been released. It's just Python not collecting the garbage fast enough? Is that right?

Hypothetically, if we were to insert a gc.collect() at the beginning of Trainer.fit() (cleaning up memory in case there was a trainer instance deleted before), would this be equivalent to your "fix"?

@MushroomMaula

Yes, it seems that the result is the same.
Using the following implementation with fix=True solves the issue.

class GCTrainer(Trainer):

    def fit(self, *args, fix: bool = False, **kwargs):
        # `fix` is keyword-only so positional arguments (e.g. the model) are still forwarded to Trainer.fit.
        if fix:
            gc.collect()
        super().fit(*args, **kwargs)
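
For illustration, a hypothetical call with the sketch above (reusing the model and dataloader from the earlier script):

trainer = GCTrainer(max_epochs=1, enable_checkpointing=False, logger=False)
trainer.fit(model, train_dataloaders=dataloader, fix=True)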

@awaelchli
Contributor

@carmocca What are your thoughts on adding a gc.collect() call at the beginning of Trainer's _run() function?

@carmocca
Contributor

I'm leaning towards not adding it. Instantiating trainers like this in a loop is very unconventional, and there is a cost to triggering gc for everybody else. We also don't understand why these objects are not getting freed periodically as you'd expect. Perhaps this is Python-version-dependent or platform-dependent.

If somebody can explain the cause of this, we would be better informed to create a fix: either by improving the reference counts or by adding this collect() call

@awaelchli
Contributor

If somebody can explain the cause of this, we would be better informed to create a fix: either by improving the reference counts or by adding this collect() call

@carmocca Just to clarify: above we've determined that the Trainer releases these objects, so their refcount is actually 0. It's just that the GC does not collect them from memory quickly enough. The fact that adding gc.collect() makes the memory drop means the refcounts were already 0, so there isn't any fix we could possibly make there. The GC is making the decisions here.

Instantiating trainers like this in a loop is very unconventional

I agree. In light of this I am also OK with closing this issue. But by the same argument, I am also OK with adding the gc.collect(). For the users who do this looping of Trainers, there is already overhead in the setup and teardown of the trainer alone, so a gc.collect() shouldn't be noticeable IMO.
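
As a rough sanity check of that overhead claim (a sketch, not part of the original thread; absolute numbers will vary by machine and Python version):

import gc
import time

from lightning.pytorch import Trainer

# Compare the cost of one gc.collect() against constructing a Trainer.
t0 = time.perf_counter()
gc.collect()
t1 = time.perf_counter()
trainer = Trainer(logger=False, enable_checkpointing=False)
t2 = time.perf_counter()
print(f"gc.collect(): {(t1 - t0) * 1e3:.2f} ms, Trainer(): {(t2 - t1) * 1e3:.2f} ms")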
