diff --git a/csrc/README.md b/csrc/README.md
new file mode 100644
index 0000000000..cc1db6b9f1
--- /dev/null
+++ b/csrc/README.md
@@ -0,0 +1,15 @@
+# PaddleClas Custom Ops
+
+This document describes how to build and install the PaddleClas custom ops.
+
+## Install pip dependencies
+
+```shell
+pip install -r requirements.txt
+```
+
+## Build the CUDA ops
+
+```shell
+python setup_cuda.py install
+```
\ No newline at end of file
diff --git a/csrc/generation/helper.h b/csrc/generation/helper.h
new file mode 100644
index 0000000000..4a74709aec
--- /dev/null
+++ b/csrc/generation/helper.h
@@ -0,0 +1,103 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/extension.h"
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+
+constexpr int kBlockSize = 256;
+constexpr int kNumWaves = 16;
+
+// Picks a grid size that keeps roughly kNumWaves waves of kBlockSize-thread
+// blocks resident on the current device.
+inline cudaError_t GetNumBlocks(int64_t n, int* num_blocks) {
+  int dev;
+  {
+    cudaError_t err = cudaGetDevice(&dev);
+    if (err != cudaSuccess) { return err; }
+  }
+  int sm_count;
+  {
+    cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
+    if (err != cudaSuccess) { return err; }
+  }
+  int tpm;
+  {
+    cudaError_t err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);
+    if (err != cudaSuccess) { return err; }
+  }
+  *num_blocks = std::max<int>(1, std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,
+                                                   sm_count * tpm / kBlockSize * kNumWaves));
+  return cudaSuccess;
+}
+
+template <typename T>
+__device__ T max_func(const T a, const T b) {
+  return a > b ? a : b;
+}
+
+template <typename T>
+struct MaxOp {
+  __device__ __forceinline__ T operator()(const T& a, const T& b) const {
+    return max_func(a, b);
+  }
+};
+
+// Maps a paddle::DataType onto the CUDA compute type (DataType) and the
+// paddle::Tensor element type (data_t).
+template <paddle::DataType D>
+class PDTraits;
+
+template <>
+class PDTraits<paddle::DataType::FLOAT32> {
+public:
+  typedef float DataType;
+  typedef float data_t;
+};
+
+template <>
+class PDTraits<paddle::DataType::FLOAT16> {
+public:
+  typedef half DataType;
+  typedef paddle::float16 data_t;
+};
+
+template <>
+class PDTraits<paddle::DataType::BFLOAT16> {
+public:
+  typedef __nv_bfloat16 DataType;
+  typedef paddle::bfloat16 data_t;
+};
+
+template <typename T, int Size>
+struct alignas(sizeof(T) * Size) AlignedVector {
+  T val[Size];
+
+  HOSTDEVICE inline const T& operator[](int i) const { return val[i]; }
+  HOSTDEVICE inline T& operator[](int i) { return val[i]; }
+};
+
+template <typename T, int Size>
+HOSTDEVICE inline void Load(const T* addr, AlignedVector<T, Size>* vec) {
+  const AlignedVector<T, Size>* addr_vec =
+      reinterpret_cast<const AlignedVector<T, Size>*>(addr);
+  *vec = *addr_vec;
+}
+
+template <typename T, int Size>
+HOSTDEVICE inline void Store(const AlignedVector<T, Size>& vec, T* addr) {
+  AlignedVector<T, Size>* addr_vec =
+      reinterpret_cast<AlignedVector<T, Size>*>(addr);
+  *addr_vec = vec;
+}
+
+constexpr int VEC_16B = 16;
\ No newline at end of file
diff --git a/csrc/generation/qkv_transpose_split.cu b/csrc/generation/qkv_transpose_split.cu
new file mode 100644
index 0000000000..ba9ee1f8ce
--- /dev/null
+++ b/csrc/generation/qkv_transpose_split.cu
@@ -0,0 +1,193 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "helper.h"
+
+// Splits the fused QKV projection [token_num, 3 * head_num * size_per_head]
+// into padded q/k/v buffers laid out as [batch, head_num, max_len, size_per_head].
+template <typename T, int VecSize>
+__global__ void fusedQKV_transpose_split_kernel(
+    T *q_buf,
+    T *k_buf,
+    T *v_buf,
+    const T *qkv,
+    const int *padding_offset,
+    const int *seq_lens,
+    const int32_t elem_cnt,
+    const int batch_size,
+    const int max_len_this_time,
+    const int seq_len,
+    const int token_num,
+    const int head_num,
+    const int size_per_head) {
+  const int32_t offset = batch_size * max_len_this_time * head_num * size_per_head;
+  const int32_t hidden_size = head_num * size_per_head;
+  const int32_t fused_hidden_size = 3 * hidden_size;
+  int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
+  using LoadT = AlignedVector<T, VecSize>;
+  LoadT src_vec;
+  LoadT bias_vec;
+
+  for (int32_t linear_index = global_thread_idx * VecSize,
+               step = gridDim.x * blockDim.x * VecSize;
+       linear_index < elem_cnt;
+       linear_index += step) {
+    Load<T, VecSize>(&qkv[linear_index], &src_vec);
+    int32_t bias_idx = linear_index % fused_hidden_size;
+    const int32_t token_idx = linear_index / fused_hidden_size;
+    const int32_t ori_token_idx =
+        token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]);
+    const int32_t target_batch_id = ori_token_idx / seq_len;
+    if (seq_lens[target_batch_id] == 0) continue;
+    const int32_t seq_id = ori_token_idx % seq_len;
+
+    // equal to:
+    // const int qkv_id = (linear_index % fused_hidden_size) / hidden_size;
+    const int32_t qkv_id = bias_idx / hidden_size;
+    const int32_t head_id = (linear_index % hidden_size) / size_per_head;
+    const int32_t size_id = linear_index % size_per_head;
+
+    if (qkv_id == 0) {
+      Store<T, VecSize>(
+          src_vec,
+          &q_buf[target_batch_id * head_num * max_len_this_time * size_per_head +
+                 head_id * max_len_this_time * size_per_head + seq_id * size_per_head +
+                 size_id]);
+    } else if (qkv_id == 1) {
+      Store<T, VecSize>(
+          src_vec,
+          &k_buf[target_batch_id * head_num * max_len_this_time * size_per_head +
+                 head_id * max_len_this_time * size_per_head + seq_id * size_per_head +
+                 size_id]);
+    } else {
+      Store<T, VecSize>(
+          src_vec,
+          &v_buf[target_batch_id * head_num * max_len_this_time * size_per_head +
                 head_id * max_len_this_time * size_per_head + seq_id * size_per_head +
+                 size_id]);
+    }
+  }
+}
+
+template <paddle::DataType D>
+std::vector<paddle::Tensor> qkv_transpose_split(const paddle::Tensor& qkv, // [token_num, dim_embed]
+                                                const paddle::Tensor& padding_offset, // [bsz, 1]
+                                                const paddle::Tensor& seq_lens,
+                                                const paddle::Tensor& input_ids,
+                                                int num_head,
+                                                int head_size) {
+  typedef PDTraits<D> traits_;
+  typedef typename traits_::DataType DataType_;
+  typedef typename traits_::data_t data_t;
+
+  auto cu_stream = qkv.stream();
+  std::vector<int64_t> qkv_shape = qkv.shape();
+  const int token_num = qkv_shape[0];
+  const int bsz = seq_lens.shape()[0];
+  const int max_seq_len = input_ids.shape()[1]; // max_seq_len_tensor.copy_to(paddle::CPUPlace(), false).data<int>()[0];
+  auto q_out = paddle::full({bsz, num_head, max_seq_len, head_size}, 0, qkv.dtype(), qkv.place());
+  auto k_out = paddle::full({bsz, num_head, max_seq_len, head_size}, 0, qkv.dtype(), qkv.place());
+  auto v_out = paddle::full({bsz, num_head, max_seq_len, head_size}, 0, qkv.dtype(), qkv.place());
+  constexpr int PackSize = VEC_16B / sizeof(DataType_);
+  const int elem_cnt = token_num * num_head * head_size * 3;
+  const int pack_num = elem_cnt / PackSize;
+  const int blocksize = 128;
+  const int grid_size = (pack_num + blocksize - 1) / blocksize;
+  fusedQKV_transpose_split_kernel<DataType_, PackSize>
+      <<<grid_size, blocksize, 0, cu_stream>>>(
+          reinterpret_cast<DataType_*>(q_out.data<data_t>()),
+          reinterpret_cast<DataType_*>(k_out.data<data_t>()),
+          reinterpret_cast<DataType_*>(v_out.data<data_t>()),
+          reinterpret_cast<DataType_*>(const_cast<data_t*>(qkv.data<data_t>())),
+          padding_offset.data<int>(),
+          seq_lens.data<int>(),
+          elem_cnt,
+          bsz,
+          max_seq_len,
+          max_seq_len,
+          token_num,
+          num_head,
+          head_size);
+  return {q_out, k_out, v_out};
+}
+
+std::vector<paddle::Tensor> QKVTransposeSplit(const paddle::Tensor& qkv,
+                                              const paddle::Tensor& padding_offset,
+                                              const paddle::Tensor& seq_lens,
+                                              const paddle::Tensor& input_ids,
+                                              int num_head,
+                                              int head_size) {
+  switch (qkv.type()) {
+    case paddle::DataType::BFLOAT16: {
+      return qkv_transpose_split<paddle::DataType::BFLOAT16>(
+          qkv,
+          padding_offset,
+          seq_lens,
+          input_ids,
+          num_head,
+          head_size
+      );
+    }
+    case paddle::DataType::FLOAT16: {
+      return qkv_transpose_split<paddle::DataType::FLOAT16>(
+          qkv,
+          padding_offset,
+          seq_lens,
+          input_ids,
+          num_head,
+          head_size
+      );
+    }
+    case paddle::DataType::FLOAT32: {
+      return qkv_transpose_split<paddle::DataType::FLOAT32>(
+          qkv,
+          padding_offset,
+          seq_lens,
+          input_ids,
+          num_head,
+          head_size
+      );
+    }
+    default: {
+      PD_THROW(
+          "NOT supported data type. "
+          "Only float16, bfloat16 and float32 are supported.");
+      break;
+    }
+  }
+}
+
+std::vector<std::vector<int64_t>> QKVTransposeSplitInferShape(const std::vector<int64_t>& qkv_shape,
+                                                              const std::vector<int64_t>& padding_offset_shape,
+                                                              const std::vector<int64_t>& seq_lens_shape,
+                                                              const std::vector<int64_t>& input_ids_shape,
+                                                              int num_head,
+                                                              int head_size) {
+  int64_t bsz = seq_lens_shape[0];
+  return {{bsz, num_head, -1, head_size}, {bsz, num_head, -1, head_size}, {bsz, num_head, -1, head_size}};
+}
+
+std::vector<paddle::DataType> QKVTransposeSplitInferDtype(const paddle::DataType& qkv_dtype,
+                                                          const paddle::DataType& padding_offset_dtype,
+                                                          const paddle::DataType& seq_lens_dtype,
+                                                          const paddle::DataType& input_ids_dtype) {
+  return {qkv_dtype, qkv_dtype, qkv_dtype};
+}
+
+PD_BUILD_OP(qkv_transpose_split)
+    .Inputs({"qkv", "padding_offset", "seq_lens", "input_ids"})
+    .Outputs({"q_out", "k_out", "v_out"})
+    .Attrs({"num_head: int",
+            "head_size: int"})
+    .SetKernelFn(PD_KERNEL(QKVTransposeSplit))
+    .SetInferShapeFn(PD_INFER_SHAPE(QKVTransposeSplitInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(QKVTransposeSplitInferDtype));
\ No newline at end of file
diff --git a/csrc/generation/transpose_remove_padding.cu b/csrc/generation/transpose_remove_padding.cu
new file mode 100644
index 0000000000..5b6b16a7fa
--- /dev/null
+++ b/csrc/generation/transpose_remove_padding.cu
@@ -0,0 +1,177 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "helper.h"
+
+template <typename T, int VecSize>
+__global__ void TransposeRemovingPadding(const T* input_data,
+                                         const int* seq_lens,
+                                         T* output_data,
+                                         const int batch_size,
+                                         const int num_head,
+                                         const int max_len_this_time,
+                                         const int seq_len,
+                                         const int head_dim,
+                                         const int token_num,
+                                         const int elem_cnt,
+                                         const int* padding_offset) {
+  // transpose and remove padding
+  // [batch_size, num_head, max_len_this_time, head_dim] -> [token_num, num_head, head_dim]
+  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
+  const int dim_embed = num_head * head_dim;
+  using LoadT = AlignedVector<T, VecSize>;
+  LoadT src_vec;
+
+  for (int32_t linear_index = idx * VecSize,
+               step = gridDim.x * blockDim.x * VecSize;
+       linear_index < elem_cnt;
+       linear_index += step) {
+    const int token_idx = linear_index / dim_embed;
+    const int ori_token_idx =
+        token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]);
+    const int ori_batch_id = ori_token_idx / seq_len;
+    if (seq_lens && seq_lens[ori_batch_id] == 0) continue;
+    const int ori_seq_id = ori_token_idx % seq_len;
+    const int ori_head_id = (linear_index % dim_embed) / head_dim;
+    const int ori_head_lane = (linear_index % dim_embed) % head_dim;
+    const int ori_idx = ori_batch_id * num_head * max_len_this_time * head_dim +
+                        ori_head_id * max_len_this_time * head_dim +
+                        ori_seq_id * head_dim + ori_head_lane;
+    Load<T, VecSize>(&input_data[ori_idx], &src_vec);
+    Store<T, VecSize>(src_vec, &output_data[linear_index]);
+  }
+}
+
+template <typename T>
+void InvokeTransposeRemovePadding(const T* input_data,
+                                  const int* seq_lens,
+                                  T* output_data,
+                                  const int batch_size,
+                                  const int num_head,
+                                  const int max_len_this_time,
+                                  const int seq_len,
+                                  const int head_dim,
+                                  const int token_num,
+                                  const int* padding_offset,
+                                  cudaStream_t cu_stream) {
+  // [batch_size, num_head, max_len_this_time, head_dim] -> [token_num, num_head, head_dim]
+  constexpr int VEC_16B = 16;
+  const int elem_cnt = token_num * num_head * head_dim;
+  constexpr int PackSize = VEC_16B / sizeof(T);
+  const int32_t pack_num = elem_cnt / PackSize;
+  const int32_t block_size = 128;
+  int32_t grid_size = (pack_num + block_size - 1) / block_size;
+  TransposeRemovingPadding<T, PackSize>
+      <<<grid_size, block_size, 0, cu_stream>>>(input_data,
+                                                seq_lens,
+                                                output_data,
+                                                batch_size,
+                                                num_head,
+                                                max_len_this_time,
+                                                seq_len,
+                                                head_dim,
+                                                token_num,
+                                                elem_cnt,
+                                                padding_offset);
+}
+
+template <paddle::DataType D>
+std::vector<paddle::Tensor> apply_transpose_remove_padding(const paddle::Tensor& input,
+                                                           const paddle::Tensor& seq_lens,
+                                                           const paddle::Tensor& padding_offset) {
+  typedef PDTraits<D> traits_;
+  typedef typename traits_::DataType DataType_;
+  typedef typename traits_::data_t data_t;
+
+  auto cu_stream = input.stream();
+  std::vector<int64_t> input_shape = input.shape();
+  const int bsz = input_shape[0];
+  const int num_head = input_shape[1];
+  const int seq_len = input_shape[2];
+  const int dim_head = input_shape[3];
+  const int token_num = padding_offset.shape()[0];
+
+  auto out = paddle::full({token_num, num_head * dim_head}, 0, input.dtype(), input.place());
+  InvokeTransposeRemovePadding(
+      reinterpret_cast<DataType_*>(const_cast<data_t*>(input.data<data_t>())),
+      seq_lens.data<int>(),
+      reinterpret_cast<DataType_*>(out.data<data_t>()),
+      bsz,
+      num_head,
+      seq_len,
+      seq_len,
+      dim_head,
+      token_num,
+      padding_offset.data<int>(),
+      cu_stream
+  );
+  return {out};
+}
+
+std::vector<paddle::Tensor> ApplyTransposeRemovingPadding(const paddle::Tensor& input,
+                                                          const paddle::Tensor& seq_lens,
+                                                          const paddle::Tensor& padding_offset) {
+  switch (input.type()) {
+    case paddle::DataType::BFLOAT16: {
+      return apply_transpose_remove_padding<paddle::DataType::BFLOAT16>(
+          input,
+          seq_lens,
+          padding_offset
+      );
+    }
+    case paddle::DataType::FLOAT16: {
+      return apply_transpose_remove_padding<paddle::DataType::FLOAT16>(
+          input,
+          seq_lens,
+          padding_offset
+      );
+    }
+    case paddle::DataType::FLOAT32: {
+      return apply_transpose_remove_padding<paddle::DataType::FLOAT32>(
+          input,
+          seq_lens,
+          padding_offset
+      );
+    }
+    default: {
+      PD_THROW(
+          "NOT supported data type. "
+          "Only float16, bfloat16 and float32 are supported.");
+      break;
+    }
+  }
+}
+
+std::vector<std::vector<int64_t>> ApplyTransposeRemovingPaddingInferShape(
+    const std::vector<int64_t>& input_shape,
+    const std::vector<int64_t>& seq_lens_shape,
+    const std::vector<int64_t>& padding_offset_shape) {
+  return {{padding_offset_shape[0], input_shape[1] * input_shape[3]}};
+}
+
+std::vector<paddle::DataType> ApplyTransposeRemovingPaddingInferDtype(
+    const paddle::DataType& input_dtype,
+    const paddle::DataType& seq_lens_dtype,
+    const paddle::DataType& padding_offset_dtype) {
+  return {input_dtype};
+}
+
+PD_BUILD_OP(transpose_remove_padding)
+    .Inputs({"input", "seq_lens", "padding_offset"})
+    .Outputs({"fmha_out"})
+    .SetKernelFn(PD_KERNEL(ApplyTransposeRemovingPadding))
+    .SetInferShapeFn(PD_INFER_SHAPE(ApplyTransposeRemovingPaddingInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(ApplyTransposeRemovingPaddingInferDtype));
\ No newline at end of file
diff --git a/csrc/requirements.txt b/csrc/requirements.txt
new file mode 100644
index 0000000000..0bf0625387
--- /dev/null
+++ b/csrc/requirements.txt
@@ -0,0 +1,2 @@
+cupy-cuda116
+pybind11
\ No newline at end of file
diff --git a/csrc/setup_cuda.py b/csrc/setup_cuda.py
new file mode 100644
index 0000000000..3bdabe7a21
--- /dev/null
+++ b/csrc/setup_cuda.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from paddle.utils.cpp_extension import CUDAExtension, setup + +setup( + name="paddleclas_ops", + ext_modules=CUDAExtension( + sources=[ + "./generation/transpose_remove_padding.cu", + "./generation/qkv_transpose_split.cu", + ] + ), +) \ No newline at end of file diff --git a/docs/zh_CN/fused_vit/README.md b/docs/zh_CN/fused_vit/README.md new file mode 100644 index 0000000000..edfdac0ce1 --- /dev/null +++ b/docs/zh_CN/fused_vit/README.md @@ -0,0 +1,331 @@ +# Fused Vision Transformer 高性能推理使用 + +PaddleClas 中已经添加高性能推理模型相关实现,支持: + +| Model | FP16 | Wint8 | Wint4 | PTQ | +|-------------------------------------------------------------------------------------------------|------|-------|-------|-----| +| [Fused Vision Transformer](../../../ppcls/arch/backbone/model_zoo/fused_vision_transformer.py) | ✅ | ✅ | ✅ | ❌ | + +* 支持以下`fused_vit`类型 + * `Fused_ViT_small_patch16_224` + * `Fused_ViT_base_patch16_224` + * `Fused_ViT_base_patch16_384` + * `Fused_ViT_base_patch32_384` + * `Fused_ViT_large_patch16_224` + * `Fused_ViT_large_patch16_384` + * `Fused_ViT_large_patch32_384` +* 预训练权重来自Vision Transformer对应权重 + +## 安装自定义算子库 + +PaddleClas 针对于 Fused Vision Transformer 系列编写了高性能自定义算子,提升模型在推理和解码过程中的性能。 + +```shell +cd ./PaddleClas/csrc +pip install -r requirements.txt +python setup_cuda.py install +``` + +## 静态图推理 + +* 模型导出 + +```python +from paddleclas import ( + Fused_ViT_large_patch16_224, + Fused_ViT_large_patch32_384 +) +import paddle + +if __name__ == "__main__": + dtype = "float16" + paddle.set_default_dtype(dtype) + path = "/your/path/fused_384_fp16/static_model" + model = Fused_ViT_large_patch32_384(pretrained=True, class_num=1000) + model.eval() + model = paddle.jit.to_static( + model, + input_spec=[ + paddle.static.InputSpec( + shape=[None] + [3, 384, 384], + dtype=dtype + ) + ] + ) + paddle.jit.save(model, path) +``` + +* 模型推理 + +```python +from paddle.inference import create_predictor +from paddle.inference import PrecisionType +from paddle.inference import Config +from paddleclas_ops import ( + qkv_transpose_split, + transpose_remove_padding +) +import paddle +import numpy as np + +from paddleclas import ( + Fused_ViT_large_patch32_384, +) + +def run(predictor, img): + # copy img data to input tensor + input_names = predictor.get_input_names() + for i, name in enumerate(input_names): + input_tensor = predictor.get_input_handle(name) + input_tensor.reshape(img[i].shape) + input_tensor.copy_from_cpu(img[i]) + + # do the inference + predictor.run() + + results = [] + # get out data from output tensor + output_names = predictor.get_output_names() + for i, name in enumerate(output_names): + output_tensor = predictor.get_output_handle(name) + output_data = output_tensor.copy_to_cpu() + results.append(output_data) + return results + +def static_infer(model_file, params_file, images): + config = Config(model_file, params_file) + config.enable_memory_optim() + config.enable_use_gpu(1000, 0) + + predictor = create_predictor(config) + + output = run(predictor, [images]) + + return output + +def main_fp16(): + dtype = "float16" + N, C, H, W = (1, 3, 384, 384) + images = np.random.rand(N, C, H, W).astype(dtype) + + # fp32 static infer + model_file = "/your/path/fused_384_fp16/static_model.pdmodel" + params_file = "/your/path/fused_384_fp16/static_model.pdiparams" + static_fp16_output = static_infer(model_file, params_file, images) + +if __name__ == "__main__": + main_fp16() +``` + +## 动态图推理 + +### FP16 + +* `fused_vit`通过`paddle.set_default_dtype`来设置`weight`的数据类型 + +```python +import paddle + +from paddleclas 
import ( + Fused_ViT_large_patch32_384, +) + +if __name__ == '__main__': + dtype = "float16" + N, C, H, W = (1, 3, 384, 384) + images = paddle.randn([N, C, H, W]).cast(dtype) + paddle.set_default_dtype(dtype) + + # ----- Fused Model ----- + fused_model = Fused_ViT_large_patch32_384(pretrained=True, class_num=1000) + fused_output = fused_model(images) + print(fused_output) +``` + +### Weight Only Int8/Int4 推理 + +> weight only int4 存在精度问题 + +* 参数介绍: + * `use_weight_only`:使用 weight only 推理,默认为 False + * `quant_type`:weight only 类型,默认为`weight_only_int8`,可选`weight_only_int4` + +```python +import paddle + +from paddleclas import ( + Fused_ViT_large_patch32_384, +) + +if __name__ == '__main__': + dtype = "float16" + N, C, H, W = (1, 3, 384, 384) + images = paddle.randn([N, C, H, W]).cast(dtype) + paddle.set_default_dtype(dtype) + + # ----- 8 bits Quanted Model ----- + quanted_model_8 = Fused_ViT_large_patch32_384(pretrained=True, class_num=1000, use_weight_only=True) + quanted_output_8 = quanted_model_8(images) + print(quanted_output_8) + + # ----- 4 bits Quanted Model ----- + quanted_model_4 = Fused_ViT_large_patch32_384(pretrained=True, class_num=1000, use_weight_only=True, quant_type="weight_only_int4") + quanted_output_4 = quanted_model_4(images) + print(quanted_output_4) +``` + +## 性能数据 +### 测试代码 + +```python +from paddle.inference import create_predictor +from paddle.inference import PrecisionType +from paddle.inference import Config +from paddleclas_ops import ( + qkv_transpose_split, + transpose_remove_padding +) +import paddle +import numpy as np +import time + +from paddleclas import ( + Fused_ViT_large_patch16_224, + Fused_ViT_large_patch32_384, + ViT_large_patch16_224, + ViT_large_patch32_384, +) + +paddle.seed(42) +np.random.seed(42) + +warmup_time = 10 +test_time = 100 + +def run(predictor, img): + # copy img data to input tensor + input_names = predictor.get_input_names() + for i, name in enumerate(input_names): + input_tensor = predictor.get_input_handle(name) + input_tensor.reshape(img[i].shape) + input_tensor.copy_from_cpu(img[i]) + + # do the inference + predictor.run() + + results = [] + # get out data from output tensor + output_names = predictor.get_output_names() + for i, name in enumerate(output_names): + output_tensor = predictor.get_output_handle(name) + output_data = output_tensor.copy_to_cpu() + results.append(output_data) + return results + +def static_infer(model_file, params_file, images): + config = Config(model_file, params_file) + config.enable_memory_optim() + config.enable_use_gpu(1000, 0) + + predictor = create_predictor(config) + + # warmup + for i in range(warmup_time): + result = run(predictor, [images]) + + # test + paddle.device.cuda.synchronize() + time_begin = time.time() + for i in range(test_time): + output = run(predictor, [images]) + paddle.device.cuda.synchronize() + time_end = time.time() + print(f"input size: {images.shape}, dtype: {images.dtype}, Description: static model, Avg Time: {(time_end - time_begin) / test_time * 1000} ms") + return output + +def dynamic_infer(model, images, description): + # warmup + for i in range(warmup_time): + output = model(images) + + # test + paddle.device.cuda.synchronize() + time_begin = time.time() + for i in range(test_time): + output = model(images) + paddle.device.cuda.synchronize() + time_end = time.time() + print(f"input size: {images.shape}, dtype: {images.dtype}, Description: {description}, Avg Time: {(time_end - time_begin) / test_time * 1000} ms") + return output + +def main_fp32(): + N, C, H, W = (1, 
3, 384, 384) + # fp32 + dtype = "float32" + paddle.set_default_dtype(dtype) + images = np.random.rand(N, C, H, W).astype(dtype) + images_tensor = paddle.to_tensor(images, dtype=dtype) + + # fp32 origin + origin_model = ViT_large_patch32_384(pretrained=True, class_num=1000) + origin_output = dynamic_infer(origin_model, images_tensor, "Origin") + # print(origin_output) + + # fp32 fused + fused_fp32_model = Fused_ViT_large_patch32_384(pretrained=True, class_num=1000) + fused_fp32_output = dynamic_infer(fused_fp32_model, images_tensor, "Fused fp32") + # print(fused_fp32_output) + + # fp32 static infer + model_file = "/your/path/fused_384_fp32/static_model.pdmodel" + params_file = "/your/path/fused_384_fp32/static_model.pdiparams" + static_fp32_output = static_infer(model_file, params_file, images) + # print(static_fp32_output) + +def main_fp16(): + N, C, H, W = (1, 3, 384, 384) + # fp16 + dtype = "float16" + paddle.set_default_dtype(dtype) + images = np.random.rand(N, C, H, W).astype(dtype) + images_tensor = paddle.to_tensor(images, dtype=dtype) + + # fp16 origin + # need change code in /paddleclas/ppcls/utils/save_load.py load_dygraph_pretrain + # origin_model = ViT_large_patch32_384(pretrained=True, class_num=1000) + # origin_output = dynamic_infer(origin_model, images_tensor, "Origin") + # print(origin_output) + + # fp16 fused + fused_fp16_model = Fused_ViT_large_patch32_384(pretrained=True, class_num=1000) + fused_fp16_output = dynamic_infer(fused_fp16_model, images_tensor, "Fused fp16") + # print(fused_fp16_output) + + # fp16 static infer + model_file = "/your/path/fused_384_fp16/static_model.pdmodel" + params_file = "/your/path/fused_384_fp16/static_model.pdiparams" + static_fp16_output = static_infer(model_file, params_file, images) + # print(static_fp16_output) + + # wint8 + quanted_8_model = Fused_ViT_large_patch32_384(pretrained=True, class_num=1000, use_weight_only=True) + quanted_8_output = dynamic_infer(quanted_8_model, images_tensor, "8bits Fused Quanted") + # print(quanted_8_output) + +if __name__ == "__main__": + main_fp32() + main_fp16() +``` + +### 性能数据—动态图 + +performance_dynamic + +* 此处的提升是与`naive vit`对应精度实现的对比 + * `int8`实现的对比基准为`fp16` + +### 性能数据—静态图 + +performance_static + +* 此处的提升是与`fused vit fp32`的对比 \ No newline at end of file diff --git a/docs/zh_CN/fused_vit/imgs/performance_dynamic.jpg b/docs/zh_CN/fused_vit/imgs/performance_dynamic.jpg new file mode 100644 index 0000000000..9a2af91ce2 Binary files /dev/null and b/docs/zh_CN/fused_vit/imgs/performance_dynamic.jpg differ diff --git a/docs/zh_CN/fused_vit/imgs/performance_static.jpg b/docs/zh_CN/fused_vit/imgs/performance_static.jpg new file mode 100644 index 0000000000..5e4d97dbdc Binary files /dev/null and b/docs/zh_CN/fused_vit/imgs/performance_static.jpg differ diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py index bbab980624..587e3c8f7e 100644 --- a/ppcls/arch/backbone/__init__.py +++ b/ppcls/arch/backbone/__init__.py @@ -55,6 +55,7 @@ from .model_zoo.regnet import RegNetX_200MF, RegNetX_400MF, RegNetX_600MF, RegNetX_800MF, RegNetX_1600MF, RegNetX_3200MF, RegNetX_4GF, RegNetX_6400MF, RegNetX_8GF, RegNetX_12GF, RegNetX_16GF, RegNetX_32GF from .model_zoo.vision_transformer import ViT_small_patch16_224, ViT_base_patch16_224, ViT_base_patch16_384, ViT_base_patch32_384, ViT_large_patch16_224, ViT_large_patch16_384, ViT_large_patch32_384 from .model_zoo.distilled_vision_transformer import DeiT_tiny_patch16_224, DeiT_small_patch16_224, DeiT_base_patch16_224, DeiT_tiny_distilled_patch16_224, 
DeiT_small_distilled_patch16_224, DeiT_base_distilled_patch16_224, DeiT_base_patch16_384, DeiT_base_distilled_patch16_384 +from .model_zoo.fused_vision_transformer import Fused_ViT_small_patch16_224, Fused_ViT_base_patch16_224, Fused_ViT_base_patch16_384, Fused_ViT_base_patch32_384, Fused_ViT_large_patch16_224, Fused_ViT_large_patch16_384, Fused_ViT_large_patch32_384 from .legendary_models.swin_transformer import SwinTransformer_tiny_patch4_window7_224, SwinTransformer_small_patch4_window7_224, SwinTransformer_base_patch4_window7_224, SwinTransformer_base_patch4_window12_384, SwinTransformer_large_patch4_window7_224, SwinTransformer_large_patch4_window12_384 from .model_zoo.swin_transformer_v2 import SwinTransformerV2_tiny_patch4_window8_256, SwinTransformerV2_small_patch4_window8_256, SwinTransformerV2_base_patch4_window8_256, SwinTransformerV2_tiny_patch4_window16_256, SwinTransformerV2_small_patch4_window16_256, SwinTransformerV2_base_patch4_window16_256, SwinTransformerV2_base_patch4_window24_384, SwinTransformerV2_large_patch4_window16_256, SwinTransformerV2_large_patch4_window24_384 from .model_zoo.cswin_transformer import CSWinTransformer_tiny_224, CSWinTransformer_small_224, CSWinTransformer_base_224, CSWinTransformer_large_224, CSWinTransformer_base_384, CSWinTransformer_large_384 diff --git a/ppcls/arch/backbone/model_zoo/fused_vision_transformer.py b/ppcls/arch/backbone/model_zoo/fused_vision_transformer.py new file mode 100644 index 0000000000..af936b9859 --- /dev/null +++ b/ppcls/arch/backbone/model_zoo/fused_vision_transformer.py @@ -0,0 +1,802 @@ +# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Code was based on https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +# reference: https://arxiv.org/abs/2010.11929 + +import paddle +import paddle.nn as nn +from paddle.framework import LayerHelper, in_dynamic_mode +from paddle.nn.initializer import TruncatedNormal, Constant, Normal +from paddle.incubate.nn.functional import ( + fused_layer_norm, + fused_linear, + variable_length_memory_efficient_attention +) +from paddle.nn.quant import weight_quantize, weight_only_linear +from ....utils.save_load import get_pretrain_state_dict, get_pretrain_state_dict_from_url +from ....utils.import_utils import is_paddleclas_ops_available + +if is_paddleclas_ops_available(): + from paddleclas_ops import ( + qkv_transpose_split, + transpose_remove_padding + ) +else: + raise RuntimeError( + "The paddleclas_ops is not installed. 
You can read the docs and install it by hand," + "you can refer to: csrc/README.md" + ) + + +MODEL_URLS = { + "Fused_ViT_small_patch16_224": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_small_patch16_224_pretrained.pdparams", + "Fused_ViT_base_patch16_224": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_base_patch16_224_pretrained.pdparams", + "Fused_ViT_base_patch16_384": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_base_patch16_384_pretrained.pdparams", + "Fused_ViT_base_patch32_384": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_base_patch32_384_pretrained.pdparams", + "Fused_ViT_large_patch16_224": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_large_patch16_224_pretrained.pdparams", + "Fused_ViT_large_patch16_384": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_large_patch16_384_pretrained.pdparams", + "Fused_ViT_large_patch32_384": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_large_patch32_384_pretrained.pdparams", +} + +__all__ = list(MODEL_URLS.keys()) + +trunc_normal_ = TruncatedNormal(std=.02) +normal_ = Normal +zeros_ = Constant(value=0.) +ones_ = Constant(value=1.) + +def to_2tuple(x): + return tuple([x] * 2) + +def fused_act_bias_wrapper( + x, + bias=None, + dequant_scales=None, + shift=None, + smooth=None, + act_method="gelu", + compute_dtype="default", + quant_scale=-1, + quant_round_type=0, + quant_max_bound=0, + quant_min_bound=0, +): + if in_dynamic_mode(): + return paddle._C_ops.fused_bias_act( + x, + bias, + dequant_scales, + shift, + smooth, + act_method, + compute_dtype, + quant_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, + ) + helper = LayerHelper("fused_bias_act") + if x.dtype == "int32": + if compute_dtype == "bf16": + dtype = "uint16" + elif compute_dtype == "fp16": + dtype = "float16" + elif compute_dtype == "fp32": + dtype = "float32" + out = helper.create_variable_for_type_inference(dtype=dtype) + else: + out = helper.create_variable_for_type_inference(dtype=x.dtype) + + inputs = {} + inputs["x"] = x + if bias is not None: + inputs["bias"] = bias + if dequant_scales is not None: + inputs["bias"] = dequant_scales + + if shift is not None: + inputs["shift"] = shift + + if smooth is not None: + inputs["smooth"] = smooth + + attrs = { + "act_method": act_method, + "compute_dtype": compute_dtype, + "quant_scale": quant_scale, + "quant_round_type": quant_round_type, + "quant_max_bound": quant_max_bound, + "quant_min_bound": quant_min_bound, + } + + helper.append_op( + type="fused_bias_act", + inputs=inputs, + outputs={"out": out}, + attrs=attrs, + ) + return out + + +class FusedVisionTransformer(nn.Layer): + """ Fused Vision Transformer with support for patch input + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + class_num=1000, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=False, + qk_scale=None, + norm_layer='nn.LayerNorm', + epsilon=1e-5, + use_weight_only=False, + quant_type="weight_only_int8", + **kwargs): + super().__init__() + self.dtype = self._helper.get_default_dtype() + + self.class_num = class_num + self.num_features = self.embed_dim = embed_dim + self.epsilon = epsilon + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.depth = depth + self.scale = qk_scale or self.head_dim**-0.5 + self.norm_func = fused_layer_norm + self.linear = fused_linear + + self.use_weight_only = use_weight_only + self.quant_type = quant_type + 
self.create_params_type = self.get_weight_create_dtype() + self._norm_weight_dtype = "float32" + + if self.use_weight_only: + assert ( + self.quant_type == "weight_only_int8" or self.quant_type == "weight_only_int4" + ), "Expected quant_type equal to 'weight_only_int8' or 'weight_only_int4' \ + but received quant_type: {}".format( + self.quant_type + ) + self.quant_bits = int(self.quant_type[-1]) + self.weight_dtype = "int" + str(self.quant_bits) + + self.img_size = to_2tuple(img_size) + self.patch_size = to_2tuple(patch_size) + self.patch_embed_proj = nn.Conv2D( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + num_patches = (self.img_size[1] // self.patch_size[1]) * \ + (self.img_size[0] // self.patch_size[0]) + + self.pos_embed = self.create_parameter( + shape=(1, num_patches + 1, embed_dim), default_initializer=trunc_normal_ + ) + self.add_parameter("pos_embed", self.pos_embed) + self.cls_token = self.create_parameter( + shape=(1, 1, embed_dim), default_initializer=trunc_normal_ + ) + self.add_parameter("cls_token", self.cls_token) + + self.norm1_weights, self.norm1_biases = [], [] + self.attn_qkv_weights, self.attn_qkv_biases = [], [] + self.attn_proj_weights, self.attn_proj_biases = [], [] + self.norm2_weights, self.norm2_biases = [], [] + self.mlp_fc1_weights, self.mlp_fc1_biases = [], [] + self.mlp_fc2_weights, self.mlp_fc2_biases = [], [] + + if self.use_weight_only: + self.attn_qkv_weights_scale = [] + self.attn_proj_weights_scale = [] + self.mlp_fc1_weights_scale = [] + self.mlp_fc2_weights_scale = [] + + mlp_hidden_dim = int(embed_dim * mlp_ratio) + self._init_weight_shape(mlp_hidden_dim) + + for i in range(self.depth): + norm1_weight = self.create_parameter( + shape=self.norm1_weight_shape, + default_initializer=ones_, + dtype=self._norm_weight_dtype + ) + norm1_bias = self.create_parameter( + shape=self.norm1_bias_shape, + default_initializer=zeros_, + is_bias=True, + dtype=self._norm_weight_dtype + ) + + attn_qkv_weight = self.create_parameter( + shape=self.attn_qkv_weight_shape, + default_initializer=ones_, + dtype=self.create_params_type + ) + attn_qkv_bias = self.create_parameter( + shape=self.attn_qkv_bias_shape, + default_initializer=zeros_, + is_bias=True, + dtype=self.dtype + ) + + attn_proj_weight = self.create_parameter( + shape=self.attn_proj_weight_shape, + default_initializer=ones_, + dtype=self.create_params_type + ) + attn_proj_bias = self.create_parameter( + shape=self.attn_proj_bias_shape, + default_initializer=zeros_, + is_bias=True, + dtype=self.dtype + ) + + norm2_weight = self.create_parameter( + shape=self.norm2_weight_shape, + default_initializer=ones_, + dtype=self._norm_weight_dtype + ) + norm2_bias = self.create_parameter( + shape=self.norm2_bias_shape, + default_initializer=zeros_, + is_bias=True, + dtype=self._norm_weight_dtype + ) + + mlp_fc1_weight = self.create_parameter( + shape=self.mlp_fc1_weight_shape, + default_initializer=ones_, + dtype=self.create_params_type + ) + mlp_fc1_bias = self.create_parameter( + shape=self.mlp_fc1_bias_shape, + default_initializer=zeros_, + is_bias=True, + dtype=self.dtype + ) + + mlp_fc2_weight = self.create_parameter( + shape=self.mlp_fc2_weight_shape, + default_initializer=ones_, + dtype=self.create_params_type + ) + mlp_fc2_bias = self.create_parameter( + shape=self.mlp_fc2_bias_shape, + default_initializer=zeros_, + is_bias=True, + dtype=self.dtype + ) + + self.norm1_weights.append(norm1_weight) + self.norm1_biases.append(norm1_bias) + self.attn_qkv_weights.append(attn_qkv_weight) + 
self.attn_qkv_biases.append(attn_qkv_bias) + self.attn_proj_weights.append(attn_proj_weight) + self.attn_proj_biases.append(attn_proj_bias) + self.norm2_weights.append(norm2_weight) + self.norm2_biases.append(norm2_bias) + self.mlp_fc1_weights.append(mlp_fc1_weight) + self.mlp_fc1_biases.append(mlp_fc1_bias) + self.mlp_fc2_weights.append(mlp_fc2_weight) + self.mlp_fc2_biases.append(mlp_fc2_bias) + + self.add_parameter("blocks_{}_norm1_weight".format(i), norm1_weight) + self.add_parameter("blocks_{}_norm1_bias".format(i), norm1_bias) + self.add_parameter("blocks_{}_attn_qkv_weight".format(i), attn_qkv_weight) + self.add_parameter("blocks_{}_attn_qkv_bias".format(i), attn_qkv_bias) + self.add_parameter("blocks_{}_attn_proj_weight".format(i), attn_proj_weight) + self.add_parameter("blocks_{}_attn_proj_bias".format(i), attn_proj_bias) + self.add_parameter("blocks_{}_norm2_weight".format(i), norm2_weight) + self.add_parameter("blocks_{}_norm2_bias".format(i), norm2_bias) + self.add_parameter("blocks_{}_mlp_fc1_weight".format(i), mlp_fc1_weight) + self.add_parameter("blocks_{}_mlp_fc1_bias".format(i), mlp_fc1_bias) + self.add_parameter("blocks_{}_mlp_fc2_weight".format(i), mlp_fc2_weight) + self.add_parameter("blocks_{}_mlp_fc2_bias".format(i), mlp_fc2_bias) + + if self.use_weight_only: + attn_qkv_weight_scale = self.create_parameter( + shape=[3 * self.num_heads * self.head_dim], + default_initializer=zeros_, + dtype=self.dtype, + is_bias=False + ) + attn_proj_weight_scale = self.create_parameter( + shape=[self.embed_dim], + default_initializer=zeros_, + dtype=self.dtype, + is_bias=False + ) + mlp_fc1_weight_scale = self.create_parameter( + shape=[mlp_hidden_dim], + default_initializer=zeros_, + dtype=self.dtype, + is_bias=False + ) + mlp_fc2_weight_scale = self.create_parameter( + shape=[self.embed_dim], + default_initializer=zeros_, + dtype=self.dtype, + is_bias=False + ) + + self.attn_qkv_weights_scale.append(attn_qkv_weight_scale) + self.attn_proj_weights_scale.append(attn_proj_weight_scale) + self.mlp_fc1_weights_scale.append(mlp_fc1_weight_scale) + self.mlp_fc2_weights_scale.append(mlp_fc2_weight_scale) + + self.add_parameter("blocks_{}_attn_qkv_weight_scale".format(i), attn_qkv_weight_scale) + self.add_parameter("blocks_{}_attn_proj_weight_scale".format(i), attn_proj_weight_scale) + self.add_parameter("blocks_{}_mlp_fc1_weight_scale".format(i), mlp_fc1_weight_scale) + self.add_parameter("blocks_{}_mlp_fc2_weight_scale".format(i), mlp_fc2_weight_scale) + + self.norm_weight = self.create_parameter( + shape=[embed_dim], + default_initializer=ones_, + dtype=self._norm_weight_dtype + ) + self.norm_bias = self.create_parameter( + shape=[embed_dim], + is_bias=True, + default_initializer=zeros_, + dtype=self._norm_weight_dtype + ) + self.head_weight = self.create_parameter( + shape=[embed_dim, class_num], + default_initializer=ones_, + dtype=self.dtype + ) + self.head_bias = self.create_parameter( + shape=[class_num], + is_bias=True, + default_initializer=zeros_, + dtype=self.dtype + ) + + def _init_weight_shape(self, mlp_hidden_dim): + self.norm1_weight_shape = [self.embed_dim] + self.norm1_bias_shape = [self.embed_dim] + self.attn_qkv_weight_shape = ( + [3 * self.num_heads * self.head_dim, self.embed_dim] + if self.use_weight_only + else [self.embed_dim, 3 * self.num_heads * self.head_dim, ] + ) + self.attn_qkv_bias_shape = [3 * self.num_heads * self.head_dim] + self.attn_proj_weight_shape = ( + [self.embed_dim, self.num_heads * self.head_dim] + if self.use_weight_only + else [self.num_heads * 
self.head_dim, self.embed_dim] + ) + self.attn_proj_bias_shape = [self.num_heads * self.head_dim] + self.norm2_weight_shape = [self.embed_dim] + self.norm2_bias_shape = [self.embed_dim] + self.mlp_fc1_weight_shape = ( + [mlp_hidden_dim, self.embed_dim] + if self.use_weight_only + else [self.embed_dim, mlp_hidden_dim] + ) + self.mlp_fc1_bias_shape = [mlp_hidden_dim] + self.mlp_fc2_weight_shape = ( + [self.embed_dim, mlp_hidden_dim] + if self.use_weight_only + else [mlp_hidden_dim, self.embed_dim] + ) + self.mlp_fc2_bias_shape = [self.embed_dim] + + if self.use_weight_only and self.quant_bits == 4: + self.attn_qkv_weight_shape[0] //= 2 + self.attn_proj_weight_shape[0] //= 2 + self.mlp_fc1_weight_shape[0] //= 2 + self.mlp_fc2_weight_shape[0] //= 2 + + def get_weight_create_dtype(self): + if self.use_weight_only: + return "int8" + else: + return self.dtype + + @paddle.no_grad() + def set_state_dict(self, state_dict): + self.pos_embed.set_value(state_dict["pos_embed"].astype(self.dtype)) + self.cls_token.set_value(state_dict["cls_token"].astype(self.dtype)) + self.patch_embed_proj.weight.set_value(state_dict["patch_embed.proj.weight"].astype(self.dtype)) + self.patch_embed_proj.bias.set_value(state_dict["patch_embed.proj.bias"].astype(self.dtype)) + for i in range(self.depth): + self.norm1_weights[i].set_value(state_dict["blocks.{}.norm1.weight".format(i)].astype(self._norm_weight_dtype)) + self.norm1_biases[i].set_value(state_dict["blocks.{}.norm1.bias".format(i)].astype(self._norm_weight_dtype)) + + if self.use_weight_only: + attn_qkv_weight_tensor = paddle.to_tensor(state_dict["blocks.{}.attn.qkv.weight".format(i)].astype(self.dtype)) + attn_qkv_quanted_weight_tensor, attn_qkv_weight_scale_tensor = weight_quantize( + attn_qkv_weight_tensor, algo=self.quant_type + ) + self.attn_qkv_weights[i].set_value(attn_qkv_quanted_weight_tensor) + self.attn_qkv_weights_scale[i].set_value(attn_qkv_weight_scale_tensor) + else: + self.attn_qkv_weights[i].set_value(state_dict["blocks.{}.attn.qkv.weight".format(i)].astype(self.dtype)) + self.attn_qkv_biases[i].set_value(state_dict["blocks.{}.attn.qkv.bias".format(i)].astype(self.dtype)) + + if self.use_weight_only: + attn_proj_weight_tensor = paddle.to_tensor(state_dict["blocks.{}.attn.proj.weight".format(i)].astype(self.dtype)) + attn_proj_quanted_weight_tensor, attn_proj_weight_scale_tensor = weight_quantize( + attn_proj_weight_tensor, algo=self.quant_type + ) + self.attn_proj_weights[i].set_value(attn_proj_quanted_weight_tensor) + self.attn_proj_weights_scale[i].set_value(attn_proj_weight_scale_tensor) + else: + self.attn_proj_weights[i].set_value(state_dict["blocks.{}.attn.proj.weight".format(i)].astype(self.dtype)) + self.attn_proj_biases[i].set_value(state_dict["blocks.{}.attn.proj.bias".format(i)].astype(self.dtype)) + + self.norm2_weights[i].set_value(state_dict["blocks.{}.norm2.weight".format(i)].astype(self._norm_weight_dtype)) + self.norm2_biases[i].set_value(state_dict["blocks.{}.norm2.bias".format(i)].astype(self._norm_weight_dtype)) + + if self.use_weight_only: + mlp_fc1_weight_tensor = paddle.to_tensor(state_dict["blocks.{}.mlp.fc1.weight".format(i)].astype(self.dtype)) + mlp_fc1_quanted_weight_tensor, mlp_fc1_weight_scale_tensor = weight_quantize( + mlp_fc1_weight_tensor, algo=self.quant_type + ) + self.mlp_fc1_weights[i].set_value(mlp_fc1_quanted_weight_tensor) + self.mlp_fc1_weights_scale[i].set_value(mlp_fc1_weight_scale_tensor) + else: + self.mlp_fc1_weights[i].set_value(state_dict["blocks.{}.mlp.fc1.weight".format(i)].astype(self.dtype)) 
+ self.mlp_fc1_biases[i].set_value(state_dict["blocks.{}.mlp.fc1.bias".format(i)].astype(self.dtype)) + + if self.use_weight_only: + mlp_fc2_weight_tensor = paddle.to_tensor(state_dict["blocks.{}.mlp.fc2.weight".format(i)].astype(self.dtype)) + mlp_fc2_quanted_weight_tensor, mlp_fc2_weight_scale_tensor = weight_quantize( + mlp_fc2_weight_tensor, algo=self.quant_type + ) + self.mlp_fc2_weights[i].set_value(mlp_fc2_quanted_weight_tensor) + self.mlp_fc2_weights_scale[i].set_value(mlp_fc2_weight_scale_tensor) + else: + self.mlp_fc2_weights[i].set_value(state_dict["blocks.{}.mlp.fc2.weight".format(i)].astype(self.dtype)) + self.mlp_fc2_biases[i].set_value(state_dict["blocks.{}.mlp.fc2.bias".format(i)].astype(self.dtype)) + + self.norm_weight.set_value(state_dict["norm.weight"].astype(self._norm_weight_dtype)) + self.norm_bias.set_value(state_dict["norm.bias"].astype(self._norm_weight_dtype)) + self.head_weight.set_value(state_dict["head.weight"].astype(self.dtype)) + self.head_bias.set_value(state_dict["head.bias"].astype(self.dtype)) + + def compute_layernorm_before_qkv(self, src, i): + if i == 0: + ln_out = self.norm_func(src, self.norm1_weights[i], self.norm1_biases[i], self.epsilon) + else: + ln_out = src + + return ln_out + + def compute_qkv_linear(self, ln_out, i): + if self.use_weight_only: + return weight_only_linear( + ln_out, + weight=self.attn_qkv_weights[i], + bias=self.attn_qkv_biases[i], + weight_scale=self.attn_qkv_weights_scale[i], + weight_dtype=self.weight_dtype + ) + + if float(paddle.version.cuda()) < 11.6: + qkv_out = paddle.matmul(ln_out, self.attn_qkv_weights[i]) + if self.attn_qkv_biases[i] is not None: + qkv_out = paddle.add(qkv_out, self.attn_qkv_biases[i]) + return qkv_out + else: + return self.linear(ln_out, self.attn_qkv_weights[i], self.attn_qkv_biases[i]) + + def compute_qkv(self, src, residual_input, i): + ln_out = self.compute_layernorm_before_qkv(src, i) + qkv_out = self.compute_qkv_linear(ln_out, i) + return qkv_out, residual_input + + def compute_fmha(self, qkv_out, padding_offset, seq_lens, input_ids, i): + q_out, k_out, v_out = qkv_transpose_split( + qkv_out, padding_offset, seq_lens, input_ids, self.num_heads, self.head_dim + ) + # cutlass fmha + qktv_out = variable_length_memory_efficient_attention( + q_out, + k_out, + v_out, + seq_lens, + seq_lens, + None, + scale=self.scale + ) + return transpose_remove_padding(qktv_out, seq_lens, padding_offset) + + def compute_out_linear(self, fmha_out, i): + if self.use_weight_only: + return weight_only_linear( + fmha_out, + weight=self.attn_proj_weights[i], + weight_scale=self.attn_proj_weights_scale[i], + weight_dtype=self.weight_dtype + ) + + return paddle.matmul(fmha_out, self.attn_proj_weights[i]) + + def compute_attn(self, qkv_out, padding_offset, seq_lens, input_ids, i): + fmha_out = self.compute_fmha(qkv_out, padding_offset, seq_lens, input_ids, i) + out_linear_out = self.compute_out_linear(fmha_out, i) + return out_linear_out + + def compute_ffn_layernorm(self, out_linear_out, residual_input, i): + """ + tmp_out = layernorm(out_linear_out + attn_proj_biases[i] + residual_input) + """ + norm_out = self.norm_func( + out_linear_out, + norm_weight=self.norm2_weights[i], + norm_bias=self.norm2_biases[i], + epsilon=self.epsilon, + bias=self.attn_proj_biases[i], + residual=residual_input, + ) + tmp_out, residual_input = norm_out[0], norm_out[1] + return tmp_out, residual_input + + def compute_ffn1(self, tmp_out, i): + if self.use_weight_only: + return weight_only_linear( + tmp_out, + weight=self.mlp_fc1_weights[i], 
+ weight_scale=self.mlp_fc1_weights_scale[i], + weight_dtype=self.weight_dtype, + ) + + return paddle.matmul(tmp_out, self.mlp_fc1_weights[i]) + + def compute_ffn2(self, ffn1_out, i): + if self.use_weight_only: + return weight_only_linear( + ffn1_out, + weight=self.mlp_fc2_weights[i], + weight_scale=self.mlp_fc2_weights_scale[i], + weight_dtype=self.weight_dtype, + ) + + return paddle.matmul(ffn1_out, self.mlp_fc2_weights[i]) + + def compute_bias_residual_layernorm(self, ffn2_out, residual_input, i, num_layers): + if i != num_layers - 1: + norm_out = self.norm_func( + ffn2_out, + norm_weight=self.norm1_weights[i + 1], + norm_bias=self.norm1_biases[i + 1], + epsilon=self.epsilon, + bias=self.mlp_fc2_biases[i], + residual=residual_input + ) + tmp_out, residual_input = norm_out[0], norm_out[1] + else: + tmp_out = self.norm_func( + ffn2_out, + norm_weight=self.norm_weight, + norm_bias=self.norm_bias, + epsilon=self.epsilon, + bias=self.mlp_fc2_biases[i], + residual=residual_input + )[0] + return tmp_out, residual_input + + def compute_head_linear(self, ln_out): + if float(paddle.version.cuda()) < 11.6: + qkv_out = paddle.matmul(ln_out, self.head_weight) + if self.head_bias is not None: + qkv_out = paddle.add(qkv_out, self.head_bias) + return qkv_out + else: + return self.linear(ln_out, self.head_weight, self.head_bias) + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.patch_embed_proj(x).flatten(2).transpose((0, 2, 1)) + + cls_tokens = self.cls_token.expand((B, -1, -1)) + x = paddle.concat((cls_tokens, x), axis=1) + x = x + self.pos_embed + + batch, seq_len, _ = x.shape + padding_offset = paddle.zeros([seq_len * batch], dtype='int32') + seq_lens = paddle.full([batch], seq_len, dtype='int32') + input_ids = paddle.full([batch, seq_len], 0, dtype='int32') + + x = x.reshape([-1, x.shape[-1]]) + residual_input = x + for i in range(self.depth): + qkv_out, residual_input = self.compute_qkv(x, residual_input, i) + out_linear_out = self.compute_attn( + qkv_out, + padding_offset, + seq_lens, + input_ids, + i + ) + + # qkv proj linear + layernorm2 + tmp_out, residual_input = self.compute_ffn_layernorm(out_linear_out, residual_input, i) + + # mlp ffn1 matmul + ffn1_out = self.compute_ffn1(tmp_out, i) + ffn1_out = fused_act_bias_wrapper(ffn1_out, self.mlp_fc1_biases[i]) + + # mlp ffn2 matmul + ffn2_out = self.compute_ffn2(ffn1_out, i) + + # layernorm1 + residual_add_bias + tmp_out, residual_input = self.compute_bias_residual_layernorm(ffn2_out, residual_input, i, self.depth) + x = tmp_out + x = x.reshape((batch, seq_len, -1)) + index = paddle.zeros([1], dtype="int32") + x = paddle.index_select(x, index, axis=1).reshape((batch, self.embed_dim)) + x = self.compute_head_linear(x) + + return x + + +def _load_pretrained(pretrained, model, model_url, use_ssld=False): + if pretrained is False: + pass + elif pretrained is True: + weight_state_dict = get_pretrain_state_dict_from_url(model_url, use_ssld=use_ssld) + model.set_state_dict(weight_state_dict) + elif isinstance(pretrained, str): + weight_state_dict = get_pretrain_state_dict(pretrained) + model.set_state_dict(weight_state_dict) + else: + raise RuntimeError( + "pretrained type is not available. Please use `string` or `boolean` type." 
+ ) + + +def Fused_ViT_small_patch16_224(pretrained=False, use_ssld=False, **kwargs): + model = FusedVisionTransformer( + patch_size=16, + embed_dim=768, + depth=8, + num_heads=8, + mlp_ratio=3, + qk_scale=768**-0.5, + **kwargs) + _load_pretrained( + pretrained, + model, + MODEL_URLS["Fused_ViT_small_patch16_224"], + use_ssld=use_ssld) + return model + + +def Fused_ViT_base_patch16_224(pretrained=False, use_ssld=False, **kwargs): + model = FusedVisionTransformer( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + epsilon=1e-6, + **kwargs) + _load_pretrained( + pretrained, + model, + MODEL_URLS["Fused_ViT_base_patch16_224"], + use_ssld=use_ssld) + return model + + +def Fused_ViT_base_patch16_384(pretrained=False, use_ssld=False, **kwargs): + model = FusedVisionTransformer( + img_size=384, + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + epsilon=1e-6, + **kwargs) + _load_pretrained( + pretrained, + model, + MODEL_URLS["Fused_ViT_base_patch16_384"], + use_ssld=use_ssld) + return model + + +def Fused_ViT_base_patch32_384(pretrained=False, use_ssld=False, **kwargs): + model = FusedVisionTransformer( + img_size=384, + patch_size=32, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + epsilon=1e-6, + **kwargs) + _load_pretrained( + pretrained, + model, + MODEL_URLS["Fused_ViT_base_patch32_384"], + use_ssld=use_ssld) + return model + + +def Fused_ViT_large_patch16_224(pretrained=False, use_ssld=False, **kwargs): + model = FusedVisionTransformer( + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + epsilon=1e-6, + **kwargs) + _load_pretrained( + pretrained, + model, + MODEL_URLS["Fused_ViT_large_patch16_224"], + use_ssld=use_ssld) + return model + + +def Fused_ViT_large_patch16_384(pretrained=False, use_ssld=False, **kwargs): + model = FusedVisionTransformer( + img_size=384, + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + epsilon=1e-6, + **kwargs) + _load_pretrained( + pretrained, + model, + MODEL_URLS["Fused_ViT_large_patch16_384"], + use_ssld=use_ssld) + return model + + +def Fused_ViT_large_patch32_384(pretrained=False, use_ssld=False, **kwargs): + model = FusedVisionTransformer( + img_size=384, + patch_size=32, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + epsilon=1e-6, + **kwargs) + _load_pretrained( + pretrained, + model, + MODEL_URLS["Fused_ViT_large_patch32_384"], + use_ssld=use_ssld) + return model \ No newline at end of file diff --git a/ppcls/utils/import_utils.py b/ppcls/utils/import_utils.py new file mode 100644 index 0000000000..bc1f772460 --- /dev/null +++ b/ppcls/utils/import_utils.py @@ -0,0 +1,33 @@ +# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
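+
+# NOTE: these helpers detect optional compiled dependencies at import time.
+# fused_vision_transformer.py calls is_paddleclas_ops_available() so that it can
+# raise a clear error pointing at csrc/README.md when the custom `paddleclas_ops`
+# extension has not been built and installed.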
+
+import importlib.util
+
+def is_package_available(package_name: str) -> bool:
+    """Check whether the given package is available.
+    Args:
+        package_name (str): the installed package name
+    Returns:
+        bool: whether the package is installed
+    """
+    package_spec = importlib.util.find_spec(package_name)
+    return package_spec is not None and package_spec.has_location
+
+
+def is_paddleclas_ops_available() -> bool:
+    """Check whether `paddleclas_ops` is available.
+    Returns:
+        bool: whether `paddleclas_ops` is available
+    """
+    return is_package_available("paddleclas_ops")
\ No newline at end of file
diff --git a/ppcls/utils/save_load.py b/ppcls/utils/save_load.py
index a40f235f87..94412ab188 100644
--- a/ppcls/utils/save_load.py
+++ b/ppcls/utils/save_load.py
@@ -117,6 +117,23 @@ def load_distillation_model(model, pretrained_model):
         pretrained_model))
 
 
+def get_pretrain_state_dict(path=None):
+    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
+        raise ValueError("Model pretrain path {}.pdparams does not "
+                         "exist.".format(path))
+    param_state_dict = paddle.load(path + ".pdparams")
+    return param_state_dict
+
+
+def get_pretrain_state_dict_from_url(pretrained_url, use_ssld=False):
+    if use_ssld:
+        pretrained_url = pretrained_url.replace("_pretrained",
+                                                "_ssld_pretrained")
+    local_weight_path = get_weights_path_from_url(pretrained_url).replace(
+        ".pdparams", "")
+    return get_pretrain_state_dict(path=local_weight_path)
+
+
 def init_model(config,
                net,
                optimizer=None,