Question about singa-auto GPU memory overflow while training for several trials #68

Open
SeanCho1996 opened this issue Oct 22, 2020 · 2 comments


@SeanCho1996

I was trying to train a VGG model with singa-auto in a local environment. The training Python script is PyPandaVgg.py:

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#

from __future__ import division
from __future__ import print_function
import os
import argparse
from typing import Union, Dict, Any

# Singa-auto Dependency
from singa_auto.model import CategoricalKnob, FixedKnob, utils
from singa_auto.model.knob import BaseKnob
from singa_auto.constants import ModelDependency
from singa_auto.model.dev import test_model_class

# PyTorch Dependency
import torch.nn as nn
from torchvision.models.vgg import vgg11_bn

# Misc Third-party Machine-Learning Dependency
import numpy as np

# singa easy Modules Dependency
from singa_easy.models.TorchModel import TorchModel

KnobConfig = Dict[str, BaseKnob]
Knobs = Dict[str, Any]
Params = Dict[str, Union[str, int, float, np.ndarray]]


class PyPandaVgg(TorchModel):
    """
    Implementation of PyTorch VGG
    """

    def __init__(self, **knobs):
        super().__init__(**knobs)

    def _create_model(self, scratch: bool, num_classes: int):
        model = vgg11_bn(pretrained=not scratch)
        num_features = 4096
        model.classifier[6] = nn.Linear(num_features, num_classes)
        print("create model {}".format(model))
        return model

    @staticmethod
    def get_knob_config():
        return {
            # Learning parameters
            'lr': FixedKnob(0.0001),  ### learning_rate
            'weight_decay': FixedKnob(0.0),
            'drop_rate': FixedKnob(0.0),
            'max_epochs': FixedKnob(5),
            'batch_size': CategoricalKnob([256]),
            'max_iter': FixedKnob(20),
            'optimizer': CategoricalKnob(['adam']),
            'scratch': FixedKnob(True),

            # Data augmentation
            'max_image_size': FixedKnob(32),
            'share_params': CategoricalKnob(['SHARE_PARAMS']),
            'tag': CategoricalKnob(['relabeled']),
            'workers': FixedKnob(8),
            'seed': FixedKnob(123456),
            'scale': FixedKnob(512),
            'horizontal_flip': FixedKnob(True),

            # Self-paced Learning and Loss Revision
            'enable_spl': FixedKnob(True),
            'spl_threshold_init': FixedKnob(16.0),
            'spl_mu': FixedKnob(1.3),
            'enable_lossrevise': FixedKnob(False),
            'lossrevise_slop': FixedKnob(2.0),

            # Label Adaptation
            'enable_label_adaptation': FixedKnob(False),

            # GM Prior Regularization
            'enable_gm_prior_regularization': FixedKnob(False),
            'gm_prior_regularization_a': FixedKnob(0.001),
            'gm_prior_regularization_b': FixedKnob(0.0001),
            'gm_prior_regularization_alpha': FixedKnob(0.5),
            'gm_prior_regularization_num': FixedKnob(4),
            'gm_prior_regularization_lambda': FixedKnob(0.0001),
            'gm_prior_regularization_upt_freq': FixedKnob(100),
            'gm_prior_regularization_param_upt_freq': FixedKnob(50),

            # Explanation
            'enable_explanation': FixedKnob(True),
            'explanation_gradcam': FixedKnob(True),
            'explanation_lime': FixedKnob(True),

            # Model Slicing
            'enable_model_slicing': FixedKnob(False),
            'model_slicing_groups': FixedKnob(0),
            'model_slicing_rate': FixedKnob(1.0),
            'model_slicing_scheduler_type': FixedKnob('randomminmax'),
            'model_slicing_randnum': FixedKnob(1),

            # MC Dropout
            'enable_mc_dropout': FixedKnob(True),
            'mc_trials_n': FixedKnob(10)
        }


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--train_path',
                        type=str,
                        default='./dataset.zip',
                        help='Path to train dataset')
    parser.add_argument('--val_path',
                        type=str,
                        default='./dataset.zip',
                        help='Path to validation dataset')
    parser.add_argument('--test_path',
                        type=str,
                        default='./dataset.zip',
                        help='Path to test dataset')
    print(os.getcwd())
    parser.add_argument(
        '--query_path',
        type=str,
        default=
        # 'examples/data/image_classification/1463729893_339.jpg,examples/data/image_classification/1463729893_326.jpg,examples/data/image_classification/eed35e9d04814071.jpg',
        'examples/data/image_classification/1463729893_339.jpg',
        help='Path(s) to query image(s), delimited by commas')
    (args, _) = parser.parse_known_args()

    # queries = utils.dataset.load_images(args.query_path.split(',')).tolist()

    test_model_class(model_file_path=__file__,
                     model_class='PyPandaVgg',
                     task='IMAGE_CLASSIFICATION',
                     dependencies={
                         ModelDependency.TORCH: '1.0.1',
                         ModelDependency.TORCHVISION: '0.2.2',
                     },
                     train_dataset_path=args.train_path,
                     val_dataset_path=args.val_path,
                     test_dataset_path=args.test_path,)
                    #  queries=queries)

    # Test without singa-auto frame

    # parser = argparse.ArgumentParser()
    # parser.add_argument('--path',
    #                     type=str,
    #                     help='Path root of the model file')
    # parser.add_argument('--fname',
    #                     type=str,
    #                     help='Model file name')
    # parser.add_argument('--img',
    #                     type=str,
    #                     help='Path to test img')
    # (args, _) = parser.parse_known_args()
    # path = args.path
    # fname = args.fname
    # img = args.img
    # from singa_auto.param_store import FileParamStore
    # from singa_auto.advisor.advisor import RandomAdvisor
    # knobs = PyPandaVgg.get_knob_config()
    # adviser = RandomAdvisor(knobs, {})
    # knobs = {
    #     name: adviser._propose_knob(knob)
    #     for (name, knob) in adviser.knob_config.items()
    # }
    # model = PyPandaVgg(**knobs)
    # params = FileParamStore(path).load(fname)
    # model.load_parameters(params)
    #
    # with open(img, "rb") as f:
    #     img_bytes = [f.read()]
    # queries = utils.dataset.load_images_from_bytes(img_bytes).tolist()
    #
    # print(model.predict(queries))

My initial trial number was set to 7 and Time_hour to 0.5 h, but as the process went on, around the 4th or 5th trial the GPU memory usage rose to 10169 MiB and kept rising, until the process crashed with a GPU out-of-memory error:

Traceback (most recent call last):
  File "PyPandaVgg.py", line 158, in <module>
    test_dataset_path=args.test_path,)
  File "/home/zhaozixiao/projects/singa_vgg/singa-auto/singa_auto/model/dev.py", line 316, in test_model_class
    train_args=train_args)
  File "/home/zhaozixiao/projects/singa_vgg/singa-auto/singa_auto/model/dev.py", line 126, in tune_model
    **(train_args or {}))
  File "/home/zhaozixiao/projects/singa_vgg/singa-auto/singa_easy/models/TorchModel.py", line 340, in train
    trainloss.backward()
  File "/home/zhaozixiao/miniconda3/envs/singa/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/zhaozixiao/miniconda3/envs/singa/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: CUDA out of memory. Tried to allocate 392.00 MiB (GPU 0; 10.76 GiB total capacity; 9.56 GiB already allocated; 325.12 MiB free; 38.61 MiB cached)
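
For reference, the accumulation can be confirmed by logging PyTorch's allocator statistics at the end of each trial. A rough sketch (the helper name is mine; it assumes a CUDA-enabled PyTorch, and memory_cached was later renamed memory_reserved in newer PyTorch releases):

import torch

def log_gpu_memory(tag):
    """Print GPU memory held by tensors and by the caching allocator, in MiB."""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024 ** 2
        cached = torch.cuda.memory_cached() / 1024 ** 2
        print('[{}] allocated: {:.1f} MiB, cached: {:.1f} MiB'.format(tag, allocated, cached))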

When I attempted to debug the whole procedure, I found that the potential problem is in the dev.py file, more precisely, in the trial loop of the tune_model function:

    while True:
        trial_no += 1

        # Advisor checks free workers
        worker_ids = train_cache.get_workers()
        assert worker_id in worker_ids

        # Advisor checks worker doesn't already have a proposal
        proposal = train_cache.get_proposal(worker_id)
        assert proposal is None

        # Advisor sends a proposal to worker
        # Overriding knobs from args
        proposal = advisor.propose(worker_id, trial_no)
        if proposal is None:
            print('No more proposals from advisor - to stop training')
            break
        proposal.knobs = {**proposal.knobs, **knobs_from_args}
        train_cache.create_proposal(worker_id, proposal)

        # Worker receives proposal
        proposal = train_cache.get_proposal(worker_id)
        assert proposal is not None

        # Worker starts trial
        _print_header(f'Trial #{trial_no}')
        print('Proposal from advisor:', proposal)

        # Worker loads model
        model_inst = py_model_class(**proposal.knobs)

        # Worker pulls shared params
        shared_params = _pull_shared_params(proposal, param_cache)

        # Worker trains model
        print('Training model...')
        model_inst.train(train_dataset_path,
                         annotation_dataset_path=annotation_dataset_path,
                         shared_params=shared_params,
                         **(train_args or {}))

        # Worker evaluates model
        result = _evaluate_model(model_inst, proposal, val_dataset_path, annotation_dataset_path)

        # Worker caches/saves model parameters
        store_params_id = _save_model(model_inst, proposal, result, param_cache,
                                      param_store)

        # Update best saved model
        if result.score is not None and store_params_id is not None and result.score > best_model_score:
            inform_user(
                'Best saved model so far! Beats previous best of score {}!'.
                format(best_model_score))
            best_store_params_id = store_params_id
            best_proposal = proposal
            best_model_score = result.score
            best_trial_no = trial_no

            # Test best model, if test dataset provided
            if test_dataset_path is not None:
                print('Evaluating new best model on test dataset...')
                best_model_test_score = model_inst.evaluate(test_dataset_path,
                                                            annotation_dataset_path=annotation_dataset_path)
                inform_user(
                    'Score on test dataset: {}'.format(best_model_test_score))

        # Worker sends result to advisor
        print('Giving feedback to advisor...')
        train_cache.create_result(worker_id, result)
        train_cache.delete_proposal(worker_id)

        # Advisor receives result
        # Advisor ingests feedback
        result = train_cache.take_result(worker_id)
        assert result is not None
        advisor.feedback(worker_id, result)

        # Destroy model
        model_inst.destroy()

At the end of each iteration, the destroy function is called to release the model temporarily held in GPU memory, but when I dug into this function in model.py,

    def destroy(self):
        '''
        Destroy this model instance, freeing any resources held by this model instance.
        No other instance methods will be called subsequently.
        '''
        pass

I found that this function's body is empty, so apparently the model is never actually freed. My question is: is this function simply not implemented yet, or is there something wrong with my understanding? Thank you.
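
For reference, here is a minimal sketch of what I would have expected destroy to do for a PyTorch-backed model. The attribute name _model is my guess at TorchModel's internals, not the actual field:

import gc
import torch

def destroy(self):
    '''
    Free resources held by this model instance, including GPU memory.
    '''
    # Drop the reference to the underlying PyTorch module (attribute name assumed).
    if hasattr(self, '_model'):
        del self._model
    gc.collect()                  # collect lingering Python references to GPU tensors
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached blocks to the CUDA driver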

@SeanCho1996
Author

I tried updating PyTorch to 1.4.0, with torchvision 0.5.0, and the problem is solved. Maybe the script needs to change

dependencies={
    ModelDependency.TORCH: '1.0.1',
    ModelDependency.TORCHVISION: '0.2.2',
},

to the newest version of torch?
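
For example (these are simply the versions that worked for me; I have not checked the full compatibility matrix):

dependencies={
    ModelDependency.TORCH: '1.4.0',
    ModelDependency.TORCHVISION: '0.5.0',
},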

@chrishkchris
Contributor

chrishkchris commented Oct 26, 2020

Thanks for the information.
I see your script is using a batch size of 256, which cannot fit into a single GPU.
A VGG (VGG11-BN) is normally very large, so we use a small batch size (32 in this case).

If training on CPU only, a batch size of 256 works for VGG. In other words, the batch size of 256 is meant for CPU.
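
Concretely, in get_knob_config that would mean something like (32 is a suggested value for a single ~11 GiB GPU, not a tested optimum):

'batch_size': CategoricalKnob([32]),  # 256 only fits when running on CPU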
