sdk2.5 restore code (PaddlePaddle#534)
* restore code

* rm ipu_strategy.check()
gglin001 authored Mar 21, 2022 · 1 parent b9955d4 · commit 5d12eaf
Showing 6 changed files with 39 additions and 94 deletions.
3 changes: 1 addition & 2 deletions cmake/inference_lib.cmake
@@ -54,8 +54,7 @@ function(copy TARGET)
add_custom_command(TARGET ${TARGET} POST_BUILD
COMMAND mkdir -p "${dst}"
COMMAND cp -r "${src}" "${dst}"
# COMMENT "copying ${src} -> ${dst}"
)
COMMENT "copying ${src} -> ${dst}")
endif (WIN32) # not windows
endforeach ()
endfunction()
2 changes: 1 addition & 1 deletion paddle/fluid/inference/CMakeLists.txt
@@ -88,7 +88,7 @@ set(SHARED_INFERENCE_SRCS
# shared inference library deps
set(SHARED_INFERENCE_DEPS ${fluid_modules} ${phi_modules} analysis_predictor)

if (WITH_CRYPTO)
if (WITH_CRYPTO)
set(SHARED_INFERENCE_DEPS ${SHARED_INFERENCE_DEPS} paddle_crypto)
endif (WITH_CRYPTO)

70 changes: 35 additions & 35 deletions paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -71,12 +71,12 @@ function(download_quant_data_without_verify install_dir data_file)
endfunction()

function(download_model_and_data install_dir model_name model_check_sum data_name data_check_sum)
download_data(${install_dir} ${model_name} ${model_check_sum})
download_data(${install_dir} ${model_name} ${model_check_sum})
download_data(${install_dir} ${data_name} ${data_check_sum})
endfunction()

function(download_model_and_data_without_verify install_dir model_name data_name)
download_data_without_verify(${install_dir} ${model_name})
download_data_without_verify(${install_dir} ${model_name})
download_data_without_verify(${install_dir} ${data_name})
endfunction()

@@ -165,7 +165,7 @@ function(inference_analysis_api_test_with_fake_data_run TARGET_NAME test_binary
inference_analysis_test_run(${TARGET_NAME}
COMMAND ${test_binary}
ARGS --infer_model=${model_dir}/model
--disable_mkldnn_fc=${disable_fc})
--disable_mkldnn_fc=${disable_fc})
endfunction()

function(inference_analysis_api_quant_test_run TARGET_NAME test_binary fp32_model_dir int8_model_dir data_path)
@@ -233,7 +233,7 @@ if(NOT APPLE AND WITH_MKLML)
set(RNN1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn1")
download_model_and_data_without_verify(${RNN1_INSTALL_DIR} "rnn1/model.tar.gz" "rnn1/data.txt.tar.gz")
inference_analysis_api_test(test_analyzer_rnn1 ${RNN1_INSTALL_DIR} analyzer_rnn1_tester.cc)

# seq_pool1
set(SEQ_POOL1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/seq_pool")
download_model_and_data_without_verify(${SEQ_POOL1_INSTALL_DIR} "seq_pool1_model_.tar.gz" "seq_pool1_data.txt.tar.gz")
@@ -277,7 +277,7 @@ inference_analysis_test(test_analyzer_small_dam SRCS analyzer_dam_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${DAM_SMALL_INSTALL_DIR}/model --infer_data=${DAM_SMALL_INSTALL_DIR}/data.txt)

#save model
#save model
inference_analysis_api_test(test_analyzer_save_model ${DAM_SMALL_INSTALL_DIR} analyzer_save_model_tester.cc)

# chinese_ner
@@ -331,17 +331,17 @@ inference_analysis_api_test(test_analyzer_seq_conv1 ${SEQ_CONV1_INSTALL_DIR} ana
# transformer, the dataset only works on batch_size=8 now
set(TRANSFORMER_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/transformer")
download_model_and_data_without_verify(${TRANSFORMER_INSTALL_DIR} "temp/transformer_model.tar.gz" "temp/transformer_data.txt.tar.gz")
inference_analysis_test(test_analyzer_transformer SRCS analyzer_transformer_compare_tester.cc
inference_analysis_test(test_analyzer_transformer SRCS analyzer_transformer_compare_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRANSFORMER_INSTALL_DIR}/model --infer_data=${TRANSFORMER_INSTALL_DIR}/data.txt --batch_size=8
ARGS --infer_model=${TRANSFORMER_INSTALL_DIR}/model --infer_data=${TRANSFORMER_INSTALL_DIR}/data.txt --batch_size=8
--cpu_num_threads=${CPU_NUM_THREADS_ON_CI})
inference_analysis_test(test_analyzer_transformer_fuse SRCS analyzer_transformer_fuse_tester.cc
inference_analysis_test(test_analyzer_transformer_fuse SRCS analyzer_transformer_fuse_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRANSFORMER_INSTALL_DIR}/model --infer_data=${TRANSFORMER_INSTALL_DIR}/data.txt --batch_size=8
ARGS --infer_model=${TRANSFORMER_INSTALL_DIR}/model --infer_data=${TRANSFORMER_INSTALL_DIR}/data.txt --batch_size=8
--cpu_num_threads=${CPU_NUM_THREADS_ON_CI})
inference_analysis_test(test_analyzer_transformer_profile SRCS analyzer_transformer_profile_tester.cc
inference_analysis_test(test_analyzer_transformer_profile SRCS analyzer_transformer_profile_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRANSFORMER_INSTALL_DIR}/model --infer_data=${TRANSFORMER_INSTALL_DIR}/data.txt --batch_size=8
ARGS --infer_model=${TRANSFORMER_INSTALL_DIR}/model --infer_data=${TRANSFORMER_INSTALL_DIR}/data.txt --batch_size=8
--cpu_num_threads=${CPU_NUM_THREADS_ON_CI})

# ocr
@@ -354,9 +354,9 @@ inference_analysis_api_test(test_analyzer_ocr ${OCR_INSTALL_DIR} analyzer_vis_te
# densebox
set(DENSEBOX_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/densebox")
download_data_without_verify(${DENSEBOX_INSTALL_DIR} "densebox.tar.gz")
inference_analysis_test(test_analyzer_detect_functional_mkldnn SRCS analyzer_detect_functional_mkldnn_tester.cc
inference_analysis_test(test_analyzer_detect_functional_mkldnn SRCS analyzer_detect_functional_mkldnn_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${DENSEBOX_INSTALL_DIR}/model --infer_data=${DENSEBOX_INSTALL_DIR}/detect_input_50.txt
ARGS --infer_model=${DENSEBOX_INSTALL_DIR}/model --infer_data=${DENSEBOX_INSTALL_DIR}/detect_input_50.txt
--infer_shape=${DENSEBOX_INSTALL_DIR}/shape_50.txt)

# mobilenet with transpose op
@@ -424,24 +424,24 @@ if(WITH_MKLDNN)
set(INT8_MOBILENETV1_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv1")
download_int8_data_without_verify(${INT8_MOBILENETV1_MODEL_DIR} "mobilenetv1_int8_model.tar.gz" )
inference_analysis_api_int8_test_run(test_analyzer_int8_mobilenetv1 ${INT8_IMG_CLASS_TEST_APP} ${INT8_MOBILENETV1_MODEL_DIR} ${IMAGENET_DATA_PATH})

# mobilenetv2 int8
set(INT8_MOBILENETV2_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv2")
download_int8_data_without_verify(${INT8_MOBILENETV2_MODEL_DIR} "mobilenet_v2_int8_model.tar.gz" )
inference_analysis_api_int8_test_run(test_analyzer_int8_mobilenetv2 ${INT8_IMG_CLASS_TEST_APP} ${INT8_MOBILENETV2_MODEL_DIR} ${IMAGENET_DATA_PATH})

# resnet101 int8
# TODO(grygielski) Enable after MKL-DNN 1.0 merge
set(INT8_RESNET101_MODEL_DIR "${INT8_DATA_DIR}/resnet101")
download_int8_data_without_verify(${INT8_RESNET101_MODEL_DIR} "Res101_int8_model.tar.gz" )
# inference_analysis_api_int8_test_run(test_analyzer_int8_resnet101 ${INT8_IMG_CLASS_TEST_APP} ${INT8_RESNET101_MODEL_DIR} ${IMAGENET_DATA_PATH})

# vgg16 int8
# TODO(grygielski) Enable after MKL-DNN 1.0 merge
set(INT8_VGG16_MODEL_DIR "${INT8_DATA_DIR}/vgg16")
download_int8_data_without_verify(${INT8_VGG16_MODEL_DIR} "VGG16_int8_model.tar.gz" )
# inference_analysis_api_int8_test_run(test_analyzer_int8_vgg16 ${INT8_IMG_CLASS_TEST_APP} ${INT8_VGG16_MODEL_DIR} ${IMAGENET_DATA_PATH})

# vgg19 int8
# TODO(grygielski) Enable after MKL-DNN 1.0 merge
set(INT8_VGG19_MODEL_DIR "${INT8_DATA_DIR}/vgg19")
@@ -479,7 +479,7 @@ if(WITH_MKLDNN)

# resnet50 bfloat16
inference_analysis_api_bfloat16_test_run(test_analyzer_bfloat16_resnet50 ${BF16_IMG_CLASS_TEST_APP} ${INT8_RESNET50_MODEL_DIR} ${IMAGENET_DATA_PATH})

# googlenet bfloat16
inference_analysis_api_bfloat16_test_run(test_analyzer_bfloat16_googlenet ${BF16_IMG_CLASS_TEST_APP} ${INT8_GOOGLENET_MODEL_DIR} ${IMAGENET_DATA_PATH})

@@ -534,11 +534,11 @@ if(WITH_MKLDNN)
inference_analysis_api_lexical_bfloat16_test_run(test_analyzer_lexical_gru_bfloat16 ${LEXICAL_TEST_APP} ${GRU_MODEL_PATH} ${GRU_DATA_PATH})
# run post-training quantization lexical analysis test
inference_analysis_api_lexical_int8_test_run(test_analyzer_lexical_gru_int8 ${LEXICAL_TEST_APP} ${GRU_MODEL_PATH} ${GRU_DATA_PATH} false)
# run post-training quantization lexical analysis test with multi_gru fuse
# run post-training quantization lexical analysis test with multi_gru fuse
inference_analysis_api_lexical_int8_test_run(test_analyzer_lexical_gru_int8_multi_gru ${LEXICAL_TEST_APP} ${GRU_MODEL_PATH} ${GRU_DATA_PATH} true)

### optimized FP32 vs. Quant INT8 tests

set(QUANT_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/quant")
set(QUANT_IMG_CLASS_TEST_APP "test_analyzer_quant_image_classification")
set(QUANT_IMG_CLASS_TEST_APP_SRC "analyzer_quant_image_classification_tester.cc")
@@ -557,7 +557,7 @@ if(WITH_MKLDNN)
inference_analysis_api_quant_test_run(test_analyzer_quant_performance_benchmark ${QUANT_IMG_CLASS_TEST_APP} ${QUANT2_MobileNetV1_MODEL_DIR}/MobileNet_qat_perf/float ${QUANT2_INT8_MobileNetV1_MODEL_DIR}/MobileNet_qat_perf_int8 ${IMAGENET_DATA_PATH})

### Other tests

# MKLDNN quantizer config
set(MKLDNN_QUANTIZER_CONFIG_TEST_APP "test_mkldnn_quantizer_config")
set(MKLDNN_QUANTIZER_CONFIG_TEST_APP_SRC "mkldnn_quantizer_config_tester.cc")
@@ -569,7 +569,7 @@ if(WITH_MKLDNN)
set(IMAGENET_SMALL_DATA_DIR "${INT8_DATA_DIR}/imagenet_small")
set(IMAGENET_SMALL_OUTPUT_FILE "imagenet_small.bin")
preprocess_data2bin_test_run(preprocess_local_imagenet "full_ILSVRC2012_val_preprocess.py" ${IMAGENET_SMALL_DATA_DIR} ${IMAGENET_SMALL_OUTPUT_FILE})

# preprocess data2bin pascalvoc
download_int8_data_without_verify(${INT8_DATA_DIR} "pascalvoc_small.tar.gz")
set(PASCALVOC_SMALL_DATA_DIR "${INT8_DATA_DIR}/pascalvoc_small")
@@ -587,7 +587,7 @@ endif()

# multiple models prediction
set(MMP_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/multi_model_prediction")
download_data_without_verify(${MMP_INSTALL_DIR} PaddleInference/mobilenet_v2_models.tar.gz)
download_data_without_verify(${MMP_INSTALL_DIR} PaddleInference/mobilenet_v2_models.tar.gz)
inference_multiple_models_analysis_api_test(test_analyzer_multi_model_prediction ${MMP_INSTALL_DIR} analyzer_mmp_tester.cc)

if(WITH_GPU AND TENSORRT_FOUND)
@@ -615,15 +615,15 @@ if(WITH_GPU AND TENSORRT_FOUND)
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
inference_analysis_test(trt_split_converter_test SRCS trt_split_converter_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TEST_SPLIT_CONVERTER_MODEL}/)
inference_analysis_test(test_analyzer_capi_exp_gpu SRCS analyzer_capi_exp_gpu_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_inference_c
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
inference_analysis_test(test_analyzer_capi_exp_xpu SRCS analyzer_capi_exp_xpu_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_inference_c
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)

set(TRT_MODEL_QUANT_RESNET_DIR "${INFERENCE_DEMO_INSTALL_DIR}/small_quant_model")
if (NOT EXISTS ${INFERENCE_DEMO_INSTALL_DIR}/small_quant_model.tgz)
inference_download_and_uncompress_without_verify(${INFERENCE_DEMO_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "small_quant_model.tgz")
@@ -659,7 +659,7 @@ if(WITH_GPU AND TENSORRT_FOUND)
endif()

inference_analysis_test(test_trt_dynamic_shape_ernie SRCS trt_dynamic_shape_ernie_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4)

set(TEST_TRT_TRANSFORMER_PRUNE_MODEL "${TRT_MODEL_INSTALL_DIR}/transformer_prune")
@@ -668,23 +668,23 @@ if(WITH_GPU AND TENSORRT_FOUND)
endif()

inference_analysis_test(test_trt_dynamic_shape_transformer_prune SRCS trt_dynamic_shape_transformer_prune_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TEST_TRT_TRANSFORMER_PRUNE_MODEL}/transformer_prune)

if (NOT EXISTS ${TEST_TRT_ERNIE_MODEL}/ernie_model_4_unserialized.tgz)
inference_download_and_uncompress(${TEST_TRT_ERNIE_MODEL} ${INFERENCE_URL}/tensorrt_test "ernie_model_4_unserialized.tgz" 833d73fc6a7f7e1ee4a1fd6419209e55)
endif()

inference_analysis_test(test_trt_dynamic_shape_ernie_ser_deser SRCS trt_dynamic_shape_ernie_serialize_deserialize_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4_unserialized)

if (NOT EXISTS ${TEST_TRT_ERNIE_MODEL}/ernie_model_4_fp16_unserialized.tgz)
inference_download_and_uncompress(${TEST_TRT_ERNIE_MODEL} ${INFERENCE_URL}/tensorrt_test "ernie_model_4_fp16_unserialized.tgz" c5ff2d0cad79953ffbf2b8b9e2fae6e4)
endif()

inference_analysis_test(test_trt_dynamic_shape_ernie_fp16_ser_deser SRCS trt_dynamic_shape_ernie_fp16_serialize_deserialize_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4_fp16_unserialized)

endif()
@@ -717,8 +717,8 @@ if (NOT APPLE AND NOT WIN32)
ARGS --infer_model=${MOBILENET_INSTALL_DIR}/model)
endif()
inference_analysis_test(test_analyzer_zerocopytensor_tensor SRCS analyzer_zerocopy_tensor_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${OCR_INSTALL_DIR}/model)
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${OCR_INSTALL_DIR}/model)

if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
inference_analysis_test(test_analyzer_dist_model SRCS analyzer_dist_model_tester.cc
@@ -727,16 +727,16 @@ if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
endif()

inference_analysis_test(test_analyzer_paddletensor_tensor SRCS analyzer_paddle_tensor_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${OCR_INSTALL_DIR}/model --infer_data=${OCR_INSTALL_DIR}/data.txt --refer_result=${OCR_INSTALL_DIR}/result.txt)

EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${OCR_INSTALL_DIR}/model --infer_data=${OCR_INSTALL_DIR}/data.txt --refer_result=${OCR_INSTALL_DIR}/result.txt)
if(WITH_MKLDNN)
inference_analysis_test(test_analyzer_capi_exp_int SRCS analyzer_capi_exp_int_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_inference_c
ARGS --infer_model=${INT8_DATA_DIR}/resnet50/model)
endif()

inference_analysis_test(test_analyzer_capi_exp_ner SRCS analyzer_capi_exp_ner_tester.cc
inference_analysis_test(test_analyzer_capi_exp_ner SRCS analyzer_capi_exp_ner_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_inference_c
ARGS --infer_model=${CHINESE_NER_INSTALL_DIR}/model)

2 changes: 1 addition & 1 deletion python/CMakeLists.txt
@@ -90,7 +90,7 @@ add_custom_target(copy_paddle_pybind ALL DEPENDS ${FLUID_CORE_DEPS})
IF(WIN32)
add_custom_command(OUTPUT ${PADDLE_PYTHON_BUILD_DIR}/.timestamp
COMMAND ${CMAKE_COMMAND} -E copy_directory ${PADDLE_SOURCE_DIR}/python/paddle ${PADDLE_BINARY_DIR}/python/paddle/
COMMAND ${CMAKE_COMMAND} -E env ${py_env} ${PYTHON_EXECUTABLE} setup.py --quiet bdist_wheel
COMMAND ${CMAKE_COMMAND} -E env ${py_env} ${PYTHON_EXECUTABLE} setup.py bdist_wheel
COMMENT "Packing whl packages------>>>"
DEPENDS copy_paddle_pybind ${FLUID_CORE} framework_py_proto profiler_py_proto pass_desc_py_proto ${PY_FILES})
ELSE(WIN32)
45 changes: 0 additions & 45 deletions python/paddle/fluid/compiler.py
@@ -809,48 +809,6 @@ def disable_pattern(self, pattern):
"""
self._ipu_strategy.disable_pattern(pattern)

def check(self):
"""
This function is going to check if the ipu_strategy is valid.
"""
if self.get_option("enable_distribution"):
if 'POPDIST_NUM_TOTAL_REPLICAS' not in os.environ:
raise RuntimeError(
"Please use poprun to run the program with POD128 and POD256"
)
required_local_replicas = int(
os.environ.get('POPDIST_NUM_LOCAL_REPLICAS', default='1'))
required_total_replicas = int(
os.environ.get('POPDIST_NUM_TOTAL_REPLICAS', default='1'))
required_ipus_per_replica = int(
os.environ.get('POPDIST_NUM_IPUS_PER_REPLICA', default='1'))

local_replicas = self.get_option("replicated_graph_count")
total_replicas = self.get_option("global_replication_factor")
local_num_ipus = self.get_option("num_ipus")

if required_local_replicas != local_replicas:
raise RuntimeError(
"Please set valid replicated_graph_count for distribution. Expect %d, but received %d."
% (required_local_replicas, local_replicas))
if required_total_replicas != total_replicas:
raise RuntimeError(
"Please set valid global_replication_factor for distribution. Expect %d, but received %d."
% (required_total_replicas, total_replicas))
if required_ipus_per_replica * local_replicas != local_num_ipus:
raise RuntimeError(
"Please set valid num_ipus for distribution. Expect %d, but received %d."
% (required_ipus_per_replica * local_replicas,
local_num_ipus))

if local_replicas != total_replicas:
replica_index = int(
os.environ.get('POPDIST_REPLICA_INDEX_OFFSET', default='0'))
self.set_options({
"enable_distributed_replicated_graphs": True,
"global_replica_offset": replica_index
})

@property
def num_ipus(self):
"""
@@ -952,9 +910,6 @@ def __init__(self, program=None, scope=None, ipu_strategy=None):
else:
self._ipu_strategy = IpuStrategy()

# check if the ipu_strategy is valid
self._ipu_strategy.check()

if ipu_strategy.has_custom_ops:
self._custom_op_names = set(ipu_strategy.custom_op_names)
else:
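Note: with `check()` gone, the validation and the automatic distributed-replication setup it performed no longer happen when the compiled program is constructed. Below is a minimal sketch, assuming only the `set_options` method and the option/environment-variable names visible in the removed code above (the helper name itself is hypothetical), of how a caller could apply the same configuration explicitly before handing the strategy to the constructor shown in the second hunk.

```python
import os


def apply_popdist_options(ipu_strategy):
    """Hypothetical helper mirroring the removed IpuStrategy.check():
    read the POPDIST_* variables set by poprun and enable distributed
    replicated graphs when local and global replica counts differ."""
    if 'POPDIST_NUM_TOTAL_REPLICAS' not in os.environ:
        # Not launched via poprun; nothing to configure.
        return ipu_strategy
    local_replicas = int(os.environ.get('POPDIST_NUM_LOCAL_REPLICAS', '1'))
    total_replicas = int(os.environ.get('POPDIST_NUM_TOTAL_REPLICAS', '1'))
    if local_replicas != total_replicas:
        replica_index = int(
            os.environ.get('POPDIST_REPLICA_INDEX_OFFSET', '0'))
        # Option names are taken from the removed check() above.
        ipu_strategy.set_options({
            "enable_distributed_replicated_graphs": True,
            "global_replica_offset": replica_index,
        })
    return ipu_strategy
```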
11 changes: 1 addition & 10 deletions python/paddle/nn/layer/transformer.py
@@ -496,7 +496,6 @@ def __init__(self,
dim_feedforward,
dropout=0.1,
activation="relu",
approximate=False,
attn_dropout=None,
act_dropout=None,
normalize_before=False,
Expand Down Expand Up @@ -539,8 +538,6 @@ def __init__(self,
self.dropout1 = Dropout(dropout, mode="upscale_in_train")
self.dropout2 = Dropout(dropout, mode="upscale_in_train")
self.activation = getattr(F, activation)
self.act_name = activation
self.act_approximate = approximate

def forward(self, src, src_mask=None, cache=None):
r"""
@@ -593,13 +590,7 @@ def forward(self, src, src_mask=None, cache=None):
residual = src
if self.normalize_before:
src = self.norm2(src)
if self.act_name == "gelu":
src = self.linear2(
self.dropout(
self.activation(
self.linear1(src), approximate=self.act_approximate)))
else:
src = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = residual + self.dropout2(src)
if not self.normalize_before:
src = self.norm2(src)
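A side effect visible in this last hunk: with the `approximate` argument removed, `activation="gelu"` always resolves to the default form of `F.gelu`. A minimal sketch, assuming `paddle.nn.TransformerEncoderLayer` and the `approximate` flag of `paddle.nn.functional.gelu` that the removed branch was calling, of how the tanh-approximate variant could still be restored by overriding the resolved activation after construction:

```python
import paddle.nn.functional as F
from paddle.nn import TransformerEncoderLayer

layer = TransformerEncoderLayer(
    d_model=512, nhead=8, dim_feedforward=2048, activation="gelu")
# forward() now always calls self.activation(self.linear1(src)), so swapping
# the resolved activation is enough to get the approximate GELU back.
layer.activation = lambda x: F.gelu(x, approximate=True)
```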
