Skip to content

Commit

Permalink
Merge branch 'master' into jeffra/inject_v2
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffra authored Jan 6, 2021
2 parents 1b42798 + 5ab1279 commit a5a34c6
Show file tree
Hide file tree
Showing 50 changed files with 1,335 additions and 253 deletions.
10 changes: 4 additions & 6 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@ name: Build

# Controls when the action will run.
on:
# Triggers the workflow on push or pull request events but only for the master branch
push:
branches: [ master ]
paths-ignore:
- 'docs/**'
pull_request:
branches: [ master ]

# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:
paths-ignore:
- 'docs/**'

# A workflow run is made up of one or more jobs that can run sequentially or in parallel
jobs:
Expand Down
47 changes: 47 additions & 0 deletions .github/workflows/pre-compile-ops.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# This is a basic workflow to help you get started with Actions

name: Tests-w-precompiled-ops

# Controls when the action will run.
on:
# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:

# A workflow run is made up of one or more jobs that can run sequentially or in parallel
jobs:
# This workflow contains a single job called "build"
build:
# The type of runner that the job will run on
runs-on: self-hosted

# Steps represent a sequence of tasks that will be executed as part of the job
steps:
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
- uses: actions/checkout@v2

# Runs a single command using the runners shell
- name: environment
run: |
nvidia-smi
which python
python --version
which nvcc
nvcc --version
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
# Runs a set of commands using the runners shell
- name: Install deepspeed
run: |
DS_BUILD_OPS=1 pip install .[dev]
ds_report
- name: Formatting checks
run: |
pre-commit run --all-files
# Runs a set of commands using the runners shell
- name: Unit tests
run: |
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose -x tests/unit/
39 changes: 39 additions & 0 deletions bin/ds_elastic
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/usr/bin/env python

import argparse
import json

import deepspeed
from deepspeed.elasticity import compute_elastic_config


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, help="DeepSpeed config json")
parser.add_argument('-w', '--world-size', type=int, default=0, help="Intended/current world size")
args = parser.parse_args()
ds_config = json.load(open(args.config, 'r'))

ds_version = deepspeed.__version__

elastic_config = ds_config['elasticity']
print('------------------------------------------')
print("Elasticity config:")
print('------------------------------------------')
print(json.dumps(elastic_config, indent=4, sort_keys=True))

if args.world_size > 0:
final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config(ds_config=ds_config, target_deepspeed_version=ds_version, world_size=args.world_size)
print('------------------------------------------')
print(f"Calculated results for world size {args.world_size}:")
print('------------------------------------------')
print(f'final_batch_size .... {final_batch_size}')
print(f'valid_gpus .......... {valid_gpus}')
print(f'micro_batch_size .... {micro_batch_size}')
else:
final_batch_size, valid_gpus = compute_elastic_config(ds_config=ds_config, target_deepspeed_version=ds_version)
print('------------------------------------------')
print("Calculated results:")
print('------------------------------------------')
print(f'final_batch_size .... {final_batch_size}')
print(f'valid_gpus .......... {valid_gpus}')
13 changes: 10 additions & 3 deletions csrc/transformer/ds_transformer_cuda.cpp
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

static std::unordered_map<int, std::shared_ptr<void>> s_transformer_layers;

const int init_seq_length = 128;

// C++ interface

template <typename T>
Expand Down Expand Up @@ -591,7 +593,6 @@ int create_transformer_layer(int layer_id,
int hidden_dim,
int num_heads,
int intermediate_size,
int seq_length,
float attn_dropout_ratio,
float hidden_dropout_ratio,
int seed,
Expand All @@ -604,14 +605,14 @@ int create_transformer_layer(int layer_id,
{
Context::Instance().SetSeed(seed);
Context::Instance().TestGemmFP16(
test_gemm, batch_size, seq_length, num_heads, hidden_dim / num_heads);
test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads);

auto layer = std::make_shared<BertTransformerLayer<T>>(layer_id,
batch_size,
hidden_dim,
num_heads,
intermediate_size,
seq_length,
init_seq_length,
attn_dropout_ratio,
hidden_dropout_ratio,
pre_or_postLayerNorm,
Expand Down Expand Up @@ -873,6 +874,12 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
std::shared_ptr<BertTransformerLayer<T>> layer =
std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

int seq_len = layer->GetSeqLength();
if (g_output.size(1) != seq_len) {
seq_len = g_output.size(1);
layer->SetSeqLength(seq_len, bsz);
}

auto grad_input = torch::empty_like(input);
auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
Expand Down
20 changes: 14 additions & 6 deletions csrc/transformer/softmax_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ __global__ void attn_softmax(float* vals,
#endif

int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);

for (int i = 1; i < iters; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
Expand Down Expand Up @@ -113,7 +114,8 @@ __global__ void attn_softmax(float* vals,
#endif

int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);

for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }

Expand Down Expand Up @@ -216,7 +218,8 @@ __global__ void attn_softmax(__half* vals,
#endif

int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);

for (int i = 1; i < iters; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
Expand Down Expand Up @@ -252,7 +255,8 @@ __global__ void attn_softmax(__half* vals,
#endif

int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);

for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }

Expand Down Expand Up @@ -339,7 +343,9 @@ void launch_attn_softmax<float>(float* vals,
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);

iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 512)
attn_softmax<32, (threads / 128), 128><<<grid_dim, block_dim, 0, stream>>>(
vals, attn_mask, heads, seq_length4, iterations);
Expand Down Expand Up @@ -408,7 +414,9 @@ void launch_attn_softmax<__half>(__half* vals,
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);

iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 512)
attn_softmax<32, (threads / 128), 128><<<grid_dim, block_dim, 0, stream>>>(
vals, attn_mask, heads, seq_length4, iterations);
Expand Down
1 change: 1 addition & 0 deletions deepspeed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .runtime.activation_checkpointing import checkpointing
from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from .utils import log_dist
from .utils.distributed import init_distributed

from .pipe import PipelineModule

Expand Down
8 changes: 8 additions & 0 deletions deepspeed/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''

#############################################
# Torch distributed constants
#############################################
TORCH_DISTRIBUTED_DEFAULT_PORT = 29500
1 change: 1 addition & 0 deletions deepspeed/elasticity/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .elasticity import compute_elastic_config, elasticity_enabled, ensure_immutable_elastic_config
80 changes: 80 additions & 0 deletions deepspeed/elasticity/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""
Copyright 2020 The Microsoft DeepSpeed Team
"""

import json
from .constants import *


class ElasticityError(Exception):
"""
Base exception for all elasticity related errors
"""
pass


class ElasticityConfigError(ElasticityError):
"""
Elasticity configuration error
"""
pass


class ElasticityIncompatibleWorldSize(ElasticityError):
"""
Attempting to run a world size that is incompatible with a given elastic config
"""
pass


class ElasticityConfig:
"""
Elastic config object, constructed from a param dictionary that only contains elastic
config parameters, example below:
If elasticity is enabled, user must specify (at least) max_train_batch_size
and micro_batch_sizes.
{
"enabled": true,
"max_train_batch_size": 2000,
"micro_batch_sizes": [2,4,6],
"min_gpus": 1,
"max_gpus" : 10000
"min_time": 20
"ignore_non_elastic_batch_info": false
"version": 0.1
}
"""
def __init__(self, param_dict):
self.enabled = param_dict.get(ENABLED, ENABLED_DEFAULT)
if self.enabled:
if MAX_ACCEPTABLE_BATCH_SIZE in param_dict:
self.max_acceptable_batch_size = param_dict[MAX_ACCEPTABLE_BATCH_SIZE]
else:
raise ElasticityConfigError(
f"Elasticity config missing {MAX_ACCEPTABLE_BATCH_SIZE}")
if MICRO_BATCHES in param_dict:
self.micro_batches = param_dict[MICRO_BATCHES]
else:
raise ElasticityConfigError(f"Elasticity config missing {MICRO_BATCHES}")
else:
self.max_acceptable_batch_size = param_dict.get(
MAX_ACCEPTABLE_BATCH_SIZE,
MAX_ACCEPTABLE_BATCH_SIZE_DEFAULT)
self.micro_batches = param_dict.get(MICRO_BATCHES, MICRO_BATCHES_DEFAULT)
self.min_gpus = param_dict.get(MIN_GPUS, MIN_GPUS_DEFAULT)
self.max_gpus = param_dict.get(MAX_GPUS, MAX_GPUS_DEFAULT)
self.min_time = param_dict.get(MIN_TIME, MIN_TIME_DEFAULT)
self.version = param_dict.get(VERSION, VERSION_DEFAULT)
self.prefer_larger_batch_size = param_dict.get(PREFER_LARGER_BATCH,
PREFER_LARGER_BATCH_DEFAULT)
self.ignore_non_elastic_batch_info = param_dict.get(
IGNORE_NON_ELASTIC_BATCH_INFO,
IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT)

def repr(self):
return self.__dict__

def __repr__(self):
return json.dumps(self.__dict__, sort_keys=True, indent=4)
74 changes: 74 additions & 0 deletions deepspeed/elasticity/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""
Copyright 2020 The Microsoft DeepSpeed Team
"""

#########################################
# Elasticity
#########################################
''' Elasticity Utility in DeepSpeed can be used to create highly elastic jobs compatible
with a large number of GPUs. For elastic jobs, DeepSpeed will provide a batch size that
can support a large number of GPUs based on the user specified parameters
'''
FORMAT = '''
Elasticity should be enabled as:
"elasticity": {
"enabled": true,
"max_train_batch_size": 2000,
"micro_batch_sizes": [2,4,6],
"min_gpus": 1,
"max_gpus" : 10000
"min_time": 20,
"prefer_larger_batch": true,
"ignore_non_elastic_batch_info": false,
"version": 0.1
}
'''

ELASTICITY = 'elasticity'

# Current elasticity version
LATEST_ELASTICITY_VERSION = 0.1

ENABLED = 'enabled'
ENABLED_DEFAULT = False

# Max acceptable train_batch_size
MAX_ACCEPTABLE_BATCH_SIZE = 'max_train_batch_size'
MAX_ACCEPTABLE_BATCH_SIZE_DEFAULT = 2000

# Acceptable micro batch sizes, same as train_micro_batch_size_per_gpu
MICRO_BATCHES = 'micro_batch_sizes'
MICRO_BATCHES_DEFAULT = [2, 4, 6]

# Min/max of GPUs to search over
MIN_GPUS = 'min_gpus'
MIN_GPUS_DEFAULT = 1
MAX_GPUS = 'max_gpus'
MAX_GPUS_DEFAULT = 10000

# Minimum running time (minutes) before the scheduler will scale us
MIN_TIME = "min_time"
MIN_TIME_DEFAULT = "20"

# When finding a suitable batch size, attempt to find one that is closest
# to the max train batch size given.
PREFER_LARGER_BATCH = 'prefer_larger_batch'
PREFER_LARGER_BATCH_DEFAULT = True

# In order to reduce confusion, if elastic mode is enabled we
# require (via assert) that no batch info is set outside of the
# elastic config. You can turn off this assert via this config
# but keep in mind that all batch info defined outside the
# elastic mode *will be ignored*.
IGNORE_NON_ELASTIC_BATCH_INFO = 'ignore_non_elastic_batch_info'
IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT = False

# Version of elastic logic to use
VERSION = "version"
VERSION_DEFAULT = LATEST_ELASTICITY_VERSION

# Minimum deepspeed version to use elasticity
MINIMUM_DEEPSPEED_VERSION = "0.3.8"

# Environment variable storing elastic config from resource scheduler
DEEPSPEED_ELASTICITY_CONFIG = "DEEPSPEED_ELASTICITY_CONFIG"
Loading

0 comments on commit a5a34c6

Please sign in to comment.