Skip to content

Commit

Permalink
Update demoBERT input dimensions to match Triton requirement (#1051)
Browse files Browse the repository at this point in the history
Signed-off-by: Rajeev Rao <rajeevrao@nvidia.com>
  • Loading branch information
mengdong authored May 11, 2021
1 parent c9c1327 commit ab20a8a
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 58 deletions.
4 changes: 2 additions & 2 deletions demo/BERT/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,12 @@ This demo BERT application can be run within the TensorRT OSS build container. I

Download SQuAD v1.1 training and dev dataset.
```bash
sh ./scripts/download_squad.sh
bash ./scripts/download_squad.sh
```

Download Tensorflow checkpoints for BERT large model with sequence length 128, fine-tuned for SQuAD v2.0.
```bash
sh scripts/download_model.sh
bash scripts/download_model.sh
```

**Note:** Since the datasets and checkpoints are stored in the directory mounted from the host, they do *not* need to be downloaded each time the container is launched.
Expand Down
61 changes: 31 additions & 30 deletions demo/BERT/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,39 +528,32 @@ def load_onnx_weights_and_quant(path, config):
return weights_dict

def emb_layernorm(builder, network, config, weights_dict, builder_config, sequence_lengths, batch_sizes):
if len(batch_sizes) > 1 or len(sequence_lengths) > 1:
# int8 only support some of the sequence length, we dynamic on sequence length is not allowed.
input_ids = network.add_input(name="input_ids", dtype=trt.int32, shape=(-1 if len(sequence_lengths) > 1 else sequence_lengths[0], -1 if len(batch_sizes) > 1 else batch_sizes[0]))
segment_ids = network.add_input(name="segment_ids", dtype=trt.int32, shape=(-1 if len(sequence_lengths) > 1 else sequence_lengths[0], -1 if len(batch_sizes) > 1 else batch_sizes[0]))
input_mask = network.add_input(name="input_mask", dtype=trt.int32, shape=(-1 if len(sequence_lengths) > 1 else sequence_lengths[0], -1 if len(batch_sizes) > 1 else batch_sizes[0]))

# Specify profiles for the batch sizes we're interested in.
# Make sure the profile also works for all sizes not covered by the previous profile.
prev_batch_size = 0
for batch_size in sorted(batch_sizes):
if len(sequence_lengths) == 1:
min_shape = (sequence_lengths[0], prev_batch_size + 1)
shape = (sequence_lengths[0], batch_size)
# int8 only support some of the sequence length, we dynamic on sequence length is not allowed.
input_ids = network.add_input(name="input_ids", dtype=trt.int32, shape=(-1, -1 if len(sequence_lengths) > 1 else sequence_lengths[0]))
segment_ids = network.add_input(name="segment_ids", dtype=trt.int32, shape=(-1, -1 if len(sequence_lengths) > 1 else sequence_lengths[0]))
input_mask = network.add_input(name="input_mask", dtype=trt.int32, shape=(-1, -1 if len(sequence_lengths) > 1 else sequence_lengths[0]))

# Specify profiles for the batch sizes we're interested in.
# Make sure the profile also works for all sizes not covered by the previous profile.

for batch_size in sorted(batch_sizes):
if len(sequence_lengths) == 1:
profile = builder.create_optimization_profile()
min_shape = (1, sequence_lengths[0])
shape = (batch_size, sequence_lengths[0])
profile.set_shape("input_ids", min=min_shape, opt=shape, max=shape)
profile.set_shape("segment_ids", min=min_shape, opt=shape, max=shape)
profile.set_shape("input_mask", min=min_shape, opt=shape, max=shape)
builder_config.add_optimization_profile(profile)
else:
for sequence_length in sorted(sequence_lengths):
profile = builder.create_optimization_profile()
min_shape = (1, sequence_length)
shape = (batch_size, sequence_length)
profile.set_shape("input_ids", min=min_shape, opt=shape, max=shape)
profile.set_shape("segment_ids", min=min_shape, opt=shape, max=shape)
profile.set_shape("input_mask", min=min_shape, opt=shape, max=shape)
builder_config.add_optimization_profile(profile)
else:
prev_sequence_length = 0
for sequence_length in sorted(sequence_lengths):
profile = builder.create_optimization_profile()
min_shape = (prev_sequence_length + 1, prev_batch_size + 1)
shape = (sequence_length, batch_size)
profile.set_shape("input_ids", min=min_shape, opt=shape, max=shape)
profile.set_shape("segment_ids", min=min_shape, opt=shape, max=shape)
profile.set_shape("input_mask", min=min_shape, opt=shape, max=shape)
builder_config.add_optimization_profile(profile)
prev_sequence_length = sequence_length
prev_batch_size = batch_size
else:
input_ids = network.add_input(name="input_ids", dtype=trt.int32, shape=(sequence_lengths[0], batch_sizes[0]))
segment_ids = network.add_input(name="segment_ids", dtype=trt.int32, shape=(sequence_lengths[0], batch_sizes[0]))
input_mask = network.add_input(name="input_mask", dtype=trt.int32, shape=(sequence_lengths[0], batch_sizes[0]))

wbeta = trt.PluginField("bert_embeddings_layernorm_beta", weights_dict["bert_embeddings_layernorm_beta"].numpy(), trt.PluginFieldType.FLOAT32)
wgamma = trt.PluginField("bert_embeddings_layernorm_gamma", weights_dict["bert_embeddings_layernorm_gamma"].numpy(), trt.PluginFieldType.FLOAT32)
Expand All @@ -574,7 +567,15 @@ def emb_layernorm(builder, network, config, weights_dict, builder_config, sequen
pfc = trt.PluginFieldCollection([wbeta, wgamma, wwordemb, wtokemb, wposemb, output_fp16, mha_type])
fn = emln_plg_creator.create_plugin("embeddings", pfc)

inputs = [input_ids, segment_ids, input_mask]
input_ids = network.add_shuffle(input_ids)
input_ids.second_transpose = (1, 0)
segment_ids = network.add_shuffle(segment_ids)
segment_ids.second_transpose = (1, 0)
input_mask = network.add_shuffle(input_mask)
input_mask.second_transpose = (1, 0)
inputs = [input_ids.get_output(0),
segment_ids.get_output(0),
input_mask.get_output(0)]
emb_layer = network.add_plugin_v2(inputs, fn)

if config.use_qat:
Expand Down
4 changes: 2 additions & 2 deletions demo/BERT/infer_c/bert_infer.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ struct BertInference
exit(-1);
}

mEngine = TrtUniquePtr<ICudaEngine>(runtime->deserializeCudaEngine(bytes.data(), bytes.size(), nullptr));
mEngine = TrtUniquePtr<ICudaEngine>(runtime->deserializeCudaEngine(bytes.data(), bytes.size()));
if (mEngine == nullptr)
{
gLogError << "Error deserializing CUDA engine\n";
Expand Down Expand Up @@ -175,7 +175,7 @@ struct BertInference
{
for (int i = 0; i < kBERT_INPUT_NUM; i++)
{
mContext->setBindingDimensions(i + bindingIdxOffset, Dims2(mSeqLength, batchSize));
mContext->setBindingDimensions(i + bindingIdxOffset, Dims2(batchSize, mSeqLength));
}
}

Expand Down
4 changes: 2 additions & 2 deletions demo/BERT/inference.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@
" engine.create_execution_context() as context:\n",
"\n",
" # We always use batch size 1.\n",
" input_shape = (max_seq_length, 1)\n",
" input_shape = (1, max_seq_length)\n",
" input_nbytes = trt.volume(input_shape) * trt.int32.itemsize\n",
" \n",
" # Allocate device memory for inputs.\n",
Expand Down Expand Up @@ -349,7 +349,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
"version": "3.8.3"
}
},
"nbformat": 4,
Expand Down
10 changes: 5 additions & 5 deletions demo/BERT/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def question_features(tokens, question):
num_binding_per_profile = engine.num_bindings // engine.num_optimization_profiles
for idx in range(engine.num_optimization_profiles):
profile_shape = engine.get_profile_shape(profile_index = idx, binding = idx * num_binding_per_profile)
if profile_shape[0][1] <= args.batch_size and profile_shape[2][1] >= args.batch_size and profile_shape[0][0] <= max_seq_length and profile_shape[2][0] >= max_seq_length:
if profile_shape[0][0] <= args.batch_size and profile_shape[2][0] >= args.batch_size and profile_shape[0][1] <= max_seq_length and profile_shape[2][1] >= max_seq_length:
selected_profile = idx
break
if selected_profile == -1:
Expand All @@ -141,7 +141,7 @@ def question_features(tokens, question):

# Specify input shapes. These must be within the min/max bounds of the active profile
# Note that input shapes can be specified on a per-inference basis, but in this case, we only have a single shape.
input_shape = (max_seq_length, args.batch_size)
input_shape = (args.batch_size, max_seq_length)
input_nbytes = trt.volume(input_shape) * trt.int32.itemsize
for binding in range(3):
context.set_binding_shape(binding_idx_offset + binding, input_shape)
Expand All @@ -168,9 +168,9 @@ def inference(features, tokens):
eval_time_elapsed = 0
for feature_index, feature in enumerate(features):
# Copy inputs
input_ids_batch = np.dstack([feature.input_ids] * args.batch_size).squeeze()
segment_ids_batch = np.dstack([feature.segment_ids] * args.batch_size).squeeze()
input_mask_batch = np.dstack([feature.input_mask] * args.batch_size).squeeze()
input_ids_batch = np.repeat(np.expand_dims(feature.input_ids, 0), args.batch_size, axis=0)
segment_ids_batch = np.repeat(np.expand_dims(feature.segment_ids, 0), args.batch_size, axis=0)
input_mask_batch = np.repeat(np.expand_dims(feature.input_mask, 0), args.batch_size, axis=0)

input_ids = cuda.register_host_memory(np.ascontiguousarray(input_ids_batch.ravel()))
segment_ids = cuda.register_host_memory(np.ascontiguousarray(segment_ids_batch.ravel()))
Expand Down
16 changes: 8 additions & 8 deletions demo/BERT/perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def main():

with open(args.engine, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime, runtime.deserialize_cuda_engine(f.read()) as engine, engine.create_execution_context() as context:
# Allocate buffers large enough to store the largest batch size
max_input_shape = (args.sequence_length, max(args.batch_size))
max_output_shape = (args.sequence_length, max(args.batch_size), 2, 1, 1)
max_input_shape = (max(args.batch_size), args.sequence_length)
max_output_shape = (max(args.batch_size), args.sequence_length, 2, 1, 1)
buffers = [
DeviceBuffer(max_input_shape),
DeviceBuffer(max_input_shape),
Expand All @@ -65,9 +65,9 @@ def main():
pseudo_vocab_size = 30522
pseudo_type_vocab_size = 2
np.random.seed(args.random_seed)
test_word_ids = np.random.randint(0, pseudo_vocab_size, (args.sequence_length, max(args.batch_size)), dtype=np.int32)
test_segment_ids = np.random.randint(0, pseudo_type_vocab_size, (args.sequence_length, max(args.batch_size)), dtype=np.int32)
test_input_mask = np.ones((args.sequence_length, max(args.batch_size)), dtype=np.int32)
test_word_ids = np.random.randint(0, pseudo_vocab_size, (max(args.batch_size), args.sequence_length), dtype=np.int32)
test_segment_ids = np.random.randint(0, pseudo_type_vocab_size, (max(args.batch_size), args.sequence_length), dtype=np.int32)
test_input_mask = np.ones((max(args.batch_size), args.sequence_length), dtype=np.int32)

# Copy input h2d
cuda.memcpy_htod(buffers[0].buf, test_word_ids.ravel())
Expand All @@ -86,9 +86,9 @@ def main():
bindings = [0] * binding_idx_offset + [buf.binding() for buf in buffers]

shapes = {
"input_ids": (args.sequence_length, batch_size),
"segment_ids": (args.sequence_length, batch_size),
"input_mask": (args.sequence_length, batch_size),
"input_ids": (batch_size, args.sequence_length),
"segment_ids": (batch_size, args.sequence_length),
"input_mask": (batch_size, args.sequence_length),
}

for binding, shape in shapes.items():
Expand Down
5 changes: 3 additions & 2 deletions demo/BERT/scripts/download_model.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ do
done

# Prepare the download directory
mkdir -p /workspace/TensorRT/demo/BERT/models/fine-tuned
cd /workspace/TensorRT/demo/BERT/models/fine-tuned
mkdir -p models/fine-tuned
pushd models/fine-tuned

# Download the BERT fine-tuned model
echo "Downloading BERT-${FW} ${MODEL} checkpoints for sequence length ${SEQ_LEN} and fine-tuned for SQuAD ${SQUAD}."
Expand All @@ -78,3 +78,4 @@ if [ -n "$CKPT" ]; then
ngc registry model download-version nvidia/${CKPT}:${CKPT_VERSION}
fi
fi
popd
5 changes: 3 additions & 2 deletions demo/BERT/scripts/download_squad.sh
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ done

# Download the SQuAD training and dev datasets
echo "Downloading SQuAD-${VERSION} training and dev datasets"
mkdir -p /workspace/TensorRT/demo/BERT/squad
cd /workspace/TensorRT/demo/BERT/squad
mkdir -p squad
pushd squad
wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-${VERSION}.json
wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-${VERSION}.json
popd
10 changes: 5 additions & 5 deletions demo/BERT/scripts/inference_benchmark.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ SEQUENCE_LENGTH="${4}"
MAX_BATCH="${5}"
GPU_ARCH="${6}"

CHECKPOINTS_DIR="/workspace/TensorRT/demo/BERT/models/fine-tuned/bert_tf_ckpt_${MODEL_VARIANT}_qa_squad2_amp_${SEQUENCE_LENGTH}_v19.03.1"
SQUAD_DIR="/workspace/TensorRT/demo/BERT/squad"
ENGINE_NAME="/workspace/TensorRT/demo/BERT/engines/bert_${MODEL_VARIANT}_${PRECISION}_bs${MAX_BATCH}_seqlen${SEQUENCE_LENGTH}_benchmark.engine"
CHECKPOINTS_DIR="models/fine-tuned/bert_tf_ckpt_${MODEL_VARIANT}_qa_squad2_amp_${SEQUENCE_LENGTH}_v19.03.1"
SQUAD_DIR="BERT/squad"
ENGINE_NAME="engines/bert_${MODEL_VARIANT}_${PRECISION}_bs${MAX_BATCH}_seqlen${SEQUENCE_LENGTH}_benchmark.engine"
# QAT Checkpoint - available only for BERT-Large
QAT_CHECKPOINT="/workspace/TensorRT/demo/BERT/models/fine-tuned/bert_pyt_onnx_large_qa_squad11_amp_fake_quant_v1/bert_large_v1_1_fake_quant.onnx"
CUDAGRAPH_PERFBIN="/workspace/TensorRT/demo/BERT/build/perf"
QAT_CHECKPOINT="models/fine-tuned/bert_pyt_onnx_large_qa_squad11_amp_fake_quant_v1/bert_large_v1_1_fake_quant.onnx"
CUDAGRAPH_PERFBIN="build/perf"

echo "==== Benchmarking BERT ${MODEL_VARIANT} ${PRECISION} SEQLEN ${SEQUENCE_LENGTH} on ${GPU_ARCH} ===="
if [ ! -f ${ENGINE_NAME} ]; then
Expand Down

0 comments on commit ab20a8a

Please sign in to comment.