Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update demoBERT input dimensions to match Triton requirement #1051

Merged
merged 5 commits into from
May 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions demo/BERT/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,12 @@ This demo BERT application can be run within the TensorRT OSS build container. I

Download the SQuAD v1.1 training and dev datasets.
```bash
sh ./scripts/download_squad.sh
bash ./scripts/download_squad.sh
```

Download the TensorFlow checkpoints for the BERT large model with sequence length 128, fine-tuned for SQuAD v2.0.
```bash
sh scripts/download_model.sh
bash scripts/download_model.sh
```

**Note:** Since the datasets and checkpoints are stored in the directory mounted from the host, they do *not* need to be downloaded each time the container is launched.
Expand Down
61 changes: 31 additions & 30 deletions demo/BERT/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,39 +528,32 @@ def load_onnx_weights_and_quant(path, config):
return weights_dict

def emb_layernorm(builder, network, config, weights_dict, builder_config, sequence_lengths, batch_sizes):
if len(batch_sizes) > 1 or len(sequence_lengths) > 1:
# int8 only supports some sequence lengths, so dynamic sequence length is not allowed.
input_ids = network.add_input(name="input_ids", dtype=trt.int32, shape=(-1 if len(sequence_lengths) > 1 else sequence_lengths[0], -1 if len(batch_sizes) > 1 else batch_sizes[0]))
segment_ids = network.add_input(name="segment_ids", dtype=trt.int32, shape=(-1 if len(sequence_lengths) > 1 else sequence_lengths[0], -1 if len(batch_sizes) > 1 else batch_sizes[0]))
input_mask = network.add_input(name="input_mask", dtype=trt.int32, shape=(-1 if len(sequence_lengths) > 1 else sequence_lengths[0], -1 if len(batch_sizes) > 1 else batch_sizes[0]))

# Specify profiles for the batch sizes we're interested in.
# Make sure the profile also works for all sizes not covered by the previous profile.
prev_batch_size = 0
for batch_size in sorted(batch_sizes):
if len(sequence_lengths) == 1:
min_shape = (sequence_lengths[0], prev_batch_size + 1)
shape = (sequence_lengths[0], batch_size)
# int8 only supports some sequence lengths, so dynamic sequence length is not allowed.
input_ids = network.add_input(name="input_ids", dtype=trt.int32, shape=(-1, -1 if len(sequence_lengths) > 1 else sequence_lengths[0]))
segment_ids = network.add_input(name="segment_ids", dtype=trt.int32, shape=(-1, -1 if len(sequence_lengths) > 1 else sequence_lengths[0]))
input_mask = network.add_input(name="input_mask", dtype=trt.int32, shape=(-1, -1 if len(sequence_lengths) > 1 else sequence_lengths[0]))

# Specify profiles for the batch sizes we're interested in.
# Make sure the profile also works for all sizes not covered by the previous profile.

for batch_size in sorted(batch_sizes):
if len(sequence_lengths) == 1:
profile = builder.create_optimization_profile()
min_shape = (1, sequence_lengths[0])
shape = (batch_size, sequence_lengths[0])
profile.set_shape("input_ids", min=min_shape, opt=shape, max=shape)
profile.set_shape("segment_ids", min=min_shape, opt=shape, max=shape)
profile.set_shape("input_mask", min=min_shape, opt=shape, max=shape)
builder_config.add_optimization_profile(profile)
else:
for sequence_length in sorted(sequence_lengths):
profile = builder.create_optimization_profile()
min_shape = (1, sequence_length)
shape = (batch_size, sequence_length)
profile.set_shape("input_ids", min=min_shape, opt=shape, max=shape)
profile.set_shape("segment_ids", min=min_shape, opt=shape, max=shape)
profile.set_shape("input_mask", min=min_shape, opt=shape, max=shape)
builder_config.add_optimization_profile(profile)
else:
prev_sequence_length = 0
for sequence_length in sorted(sequence_lengths):
profile = builder.create_optimization_profile()
min_shape = (prev_sequence_length + 1, prev_batch_size + 1)
shape = (sequence_length, batch_size)
profile.set_shape("input_ids", min=min_shape, opt=shape, max=shape)
profile.set_shape("segment_ids", min=min_shape, opt=shape, max=shape)
profile.set_shape("input_mask", min=min_shape, opt=shape, max=shape)
builder_config.add_optimization_profile(profile)
prev_sequence_length = sequence_length
prev_batch_size = batch_size
else:
input_ids = network.add_input(name="input_ids", dtype=trt.int32, shape=(sequence_lengths[0], batch_sizes[0]))
segment_ids = network.add_input(name="segment_ids", dtype=trt.int32, shape=(sequence_lengths[0], batch_sizes[0]))
input_mask = network.add_input(name="input_mask", dtype=trt.int32, shape=(sequence_lengths[0], batch_sizes[0]))

wbeta = trt.PluginField("bert_embeddings_layernorm_beta", weights_dict["bert_embeddings_layernorm_beta"].numpy(), trt.PluginFieldType.FLOAT32)
wgamma = trt.PluginField("bert_embeddings_layernorm_gamma", weights_dict["bert_embeddings_layernorm_gamma"].numpy(), trt.PluginFieldType.FLOAT32)
Expand All @@ -574,7 +567,15 @@ def emb_layernorm(builder, network, config, weights_dict, builder_config, sequen
pfc = trt.PluginFieldCollection([wbeta, wgamma, wwordemb, wtokemb, wposemb, output_fp16, mha_type])
fn = emln_plg_creator.create_plugin("embeddings", pfc)

inputs = [input_ids, segment_ids, input_mask]
input_ids = network.add_shuffle(input_ids)
input_ids.second_transpose = (1, 0)
segment_ids = network.add_shuffle(segment_ids)
segment_ids.second_transpose = (1, 0)
input_mask = network.add_shuffle(input_mask)
input_mask.second_transpose = (1, 0)
inputs = [input_ids.get_output(0),
segment_ids.get_output(0),
input_mask.get_output(0)]
emb_layer = network.add_plugin_v2(inputs, fn)

if config.use_qat:
Expand Down
4 changes: 2 additions & 2 deletions demo/BERT/infer_c/bert_infer.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ struct BertInference
exit(-1);
}

mEngine = TrtUniquePtr<ICudaEngine>(runtime->deserializeCudaEngine(bytes.data(), bytes.size(), nullptr));
mEngine = TrtUniquePtr<ICudaEngine>(runtime->deserializeCudaEngine(bytes.data(), bytes.size()));
if (mEngine == nullptr)
{
gLogError << "Error deserializing CUDA engine\n";
Expand Down Expand Up @@ -175,7 +175,7 @@ struct BertInference
{
for (int i = 0; i < kBERT_INPUT_NUM; i++)
{
mContext->setBindingDimensions(i + bindingIdxOffset, Dims2(mSeqLength, batchSize));
mContext->setBindingDimensions(i + bindingIdxOffset, Dims2(batchSize, mSeqLength));
}
}

Expand Down
4 changes: 2 additions & 2 deletions demo/BERT/inference.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@
" engine.create_execution_context() as context:\n",
"\n",
" # We always use batch size 1.\n",
" input_shape = (max_seq_length, 1)\n",
" input_shape = (1, max_seq_length)\n",
" input_nbytes = trt.volume(input_shape) * trt.int32.itemsize\n",
" \n",
" # Allocate device memory for inputs.\n",
Expand Down Expand Up @@ -349,7 +349,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
"version": "3.8.3"
}
},
"nbformat": 4,
Expand Down
10 changes: 5 additions & 5 deletions demo/BERT/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def question_features(tokens, question):
num_binding_per_profile = engine.num_bindings // engine.num_optimization_profiles
for idx in range(engine.num_optimization_profiles):
profile_shape = engine.get_profile_shape(profile_index = idx, binding = idx * num_binding_per_profile)
if profile_shape[0][1] <= args.batch_size and profile_shape[2][1] >= args.batch_size and profile_shape[0][0] <= max_seq_length and profile_shape[2][0] >= max_seq_length:
if profile_shape[0][0] <= args.batch_size and profile_shape[2][0] >= args.batch_size and profile_shape[0][1] <= max_seq_length and profile_shape[2][1] >= max_seq_length:
selected_profile = idx
break
if selected_profile == -1:
Expand All @@ -141,7 +141,7 @@ def question_features(tokens, question):

# Specify input shapes. These must be within the min/max bounds of the active profile
# Note that input shapes can be specified on a per-inference basis, but in this case, we only have a single shape.
input_shape = (max_seq_length, args.batch_size)
input_shape = (args.batch_size, max_seq_length)
input_nbytes = trt.volume(input_shape) * trt.int32.itemsize
for binding in range(3):
context.set_binding_shape(binding_idx_offset + binding, input_shape)
Expand All @@ -168,9 +168,9 @@ def inference(features, tokens):
eval_time_elapsed = 0
for feature_index, feature in enumerate(features):
# Copy inputs
input_ids_batch = np.dstack([feature.input_ids] * args.batch_size).squeeze()
segment_ids_batch = np.dstack([feature.segment_ids] * args.batch_size).squeeze()
input_mask_batch = np.dstack([feature.input_mask] * args.batch_size).squeeze()
input_ids_batch = np.repeat(np.expand_dims(feature.input_ids, 0), args.batch_size, axis=0)
segment_ids_batch = np.repeat(np.expand_dims(feature.segment_ids, 0), args.batch_size, axis=0)
input_mask_batch = np.repeat(np.expand_dims(feature.input_mask, 0), args.batch_size, axis=0)

input_ids = cuda.register_host_memory(np.ascontiguousarray(input_ids_batch.ravel()))
segment_ids = cuda.register_host_memory(np.ascontiguousarray(segment_ids_batch.ravel()))
Expand Down
16 changes: 8 additions & 8 deletions demo/BERT/perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def main():

with open(args.engine, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime, runtime.deserialize_cuda_engine(f.read()) as engine, engine.create_execution_context() as context:
# Allocate buffers large enough to store the largest batch size
max_input_shape = (args.sequence_length, max(args.batch_size))
max_output_shape = (args.sequence_length, max(args.batch_size), 2, 1, 1)
max_input_shape = (max(args.batch_size), args.sequence_length)
max_output_shape = (max(args.batch_size), args.sequence_length, 2, 1, 1)
buffers = [
DeviceBuffer(max_input_shape),
DeviceBuffer(max_input_shape),
Expand All @@ -65,9 +65,9 @@ def main():
pseudo_vocab_size = 30522
pseudo_type_vocab_size = 2
np.random.seed(args.random_seed)
test_word_ids = np.random.randint(0, pseudo_vocab_size, (args.sequence_length, max(args.batch_size)), dtype=np.int32)
test_segment_ids = np.random.randint(0, pseudo_type_vocab_size, (args.sequence_length, max(args.batch_size)), dtype=np.int32)
test_input_mask = np.ones((args.sequence_length, max(args.batch_size)), dtype=np.int32)
test_word_ids = np.random.randint(0, pseudo_vocab_size, (max(args.batch_size), args.sequence_length), dtype=np.int32)
test_segment_ids = np.random.randint(0, pseudo_type_vocab_size, (max(args.batch_size), args.sequence_length), dtype=np.int32)
test_input_mask = np.ones((max(args.batch_size), args.sequence_length), dtype=np.int32)

# Copy input h2d
cuda.memcpy_htod(buffers[0].buf, test_word_ids.ravel())
Expand All @@ -86,9 +86,9 @@ def main():
bindings = [0] * binding_idx_offset + [buf.binding() for buf in buffers]

shapes = {
"input_ids": (args.sequence_length, batch_size),
"segment_ids": (args.sequence_length, batch_size),
"input_mask": (args.sequence_length, batch_size),
"input_ids": (batch_size, args.sequence_length),
"segment_ids": (batch_size, args.sequence_length),
"input_mask": (batch_size, args.sequence_length),
}

for binding, shape in shapes.items():
Expand Down
5 changes: 3 additions & 2 deletions demo/BERT/scripts/download_model.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ do
done

# Prepare the download directory
mkdir -p /workspace/TensorRT/demo/BERT/models/fine-tuned
cd /workspace/TensorRT/demo/BERT/models/fine-tuned
mkdir -p models/fine-tuned
pushd models/fine-tuned

# Download the BERT fine-tuned model
echo "Downloading BERT-${FW} ${MODEL} checkpoints for sequence length ${SEQ_LEN} and fine-tuned for SQuAD ${SQUAD}."
Expand All @@ -78,3 +78,4 @@ if [ -n "$CKPT" ]; then
ngc registry model download-version nvidia/${CKPT}:${CKPT_VERSION}
fi
fi
popd
5 changes: 3 additions & 2 deletions demo/BERT/scripts/download_squad.sh
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ done

# Download the SQuAD training and dev datasets
echo "Downloading SQuAD-${VERSION} training and dev datasets"
mkdir -p /workspace/TensorRT/demo/BERT/squad
cd /workspace/TensorRT/demo/BERT/squad
mkdir -p squad
pushd squad
wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-${VERSION}.json
wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-${VERSION}.json
popd
10 changes: 5 additions & 5 deletions demo/BERT/scripts/inference_benchmark.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ SEQUENCE_LENGTH="${4}"
MAX_BATCH="${5}"
GPU_ARCH="${6}"

CHECKPOINTS_DIR="/workspace/TensorRT/demo/BERT/models/fine-tuned/bert_tf_ckpt_${MODEL_VARIANT}_qa_squad2_amp_${SEQUENCE_LENGTH}_v19.03.1"
SQUAD_DIR="/workspace/TensorRT/demo/BERT/squad"
ENGINE_NAME="/workspace/TensorRT/demo/BERT/engines/bert_${MODEL_VARIANT}_${PRECISION}_bs${MAX_BATCH}_seqlen${SEQUENCE_LENGTH}_benchmark.engine"
CHECKPOINTS_DIR="models/fine-tuned/bert_tf_ckpt_${MODEL_VARIANT}_qa_squad2_amp_${SEQUENCE_LENGTH}_v19.03.1"
SQUAD_DIR="BERT/squad"
ENGINE_NAME="engines/bert_${MODEL_VARIANT}_${PRECISION}_bs${MAX_BATCH}_seqlen${SEQUENCE_LENGTH}_benchmark.engine"
# QAT Checkpoint - available only for BERT-Large
QAT_CHECKPOINT="/workspace/TensorRT/demo/BERT/models/fine-tuned/bert_pyt_onnx_large_qa_squad11_amp_fake_quant_v1/bert_large_v1_1_fake_quant.onnx"
CUDAGRAPH_PERFBIN="/workspace/TensorRT/demo/BERT/build/perf"
QAT_CHECKPOINT="models/fine-tuned/bert_pyt_onnx_large_qa_squad11_amp_fake_quant_v1/bert_large_v1_1_fake_quant.onnx"
CUDAGRAPH_PERFBIN="build/perf"

echo "==== Benchmarking BERT ${MODEL_VARIANT} ${PRECISION} SEQLEN ${SEQUENCE_LENGTH} on ${GPU_ARCH} ===="
if [ ! -f ${ENGINE_NAME} ]; then
Expand Down