From bb4db9b1d883d374a330699a85fef61b23539ff5 Mon Sep 17 00:00:00 2001 From: zkh2016 Date: Mon, 9 May 2022 07:52:18 +0000 Subject: [PATCH 1/3] test sparse model --- .../unittests/test_sparse_middle_extractor.py | 324 ++++++++++++++++++ .../tests/unittests/test_sparse_mnist.py | 126 +++++++ 2 files changed, 450 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_sparse_middle_extractor.py create mode 100644 python/paddle/fluid/tests/unittests/test_sparse_mnist.py diff --git a/python/paddle/fluid/tests/unittests/test_sparse_middle_extractor.py b/python/paddle/fluid/tests/unittests/test_sparse_middle_extractor.py new file mode 100644 index 0000000000000..ae52b4a413336 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sparse_middle_extractor.py @@ -0,0 +1,324 @@ +import paddle +import paddle.nn as nn +import paddle.sparse as sparse +from paddle.fluid.framework import _test_eager_guard +import time +import numpy as np +import torch +import spconv.pytorch as spconv +import inspect + +class MiddleExtractor(paddle.nn.Layer): + def __init__(self, + #output_shape, + use_norm=True, + num_input_features=128, + num_filters_down1=[64], + num_filters_down2=[64, 64], + name='MiddleExtractor'): + super(MiddleExtractor, self).__init__() + self.name = name + if not use_norm: + self.middle_conv = paddle.nn.Sequential( + #nn.Pad3D(1), + nn.Conv3D(num_input_features, 64, 3, stride=(2, 1, 1), data_format='NDHWC'), + #nn.BatchNorm3D(64, epsilon=1e-3, momentum=0.001, data_format='NDHWC'), + nn.ReLU(), + #nn.Pad3D([1, 1, 1, 1, 0, 0]), + nn.Conv3D(64, 64, 3, stride=(1, 1, 1), data_format='NDHWC'), + #nn.BatchNorm3D(64, epsilon=1e-3, momentum=0.001, data_format='NDHWC'), + nn.ReLU(), + #nn.Pad3D(1), + nn.Conv3D(64, 64, 3, stride=(2, 1, 1), data_format='NDHWC'), + #nn.BatchNorm3D(64, epsilon=1e-3, momentum=0.001, data_format='NDHWC'), + nn.ReLU(), + ) + else: + self.middle_conv = paddle.nn.Sequential( + #nn.Pad3D(1), + nn.Conv3D(num_input_features, 64, 3, stride=(2, 1, 1), data_format='NDHWC'), + nn.BatchNorm3D(64, epsilon=1e-3, momentum=0.001, data_format='NDHWC'), + nn.ReLU(), + #nn.Pad3D([1, 1, 1, 1, 0, 0]), + nn.Conv3D(64, 64, 3, stride=(1, 1, 1), data_format='NDHWC'), + nn.BatchNorm3D(64, epsilon=1e-3, momentum=0.001, data_format='NDHWC'), + nn.ReLU(), + #nn.Pad3D(1), + nn.Conv3D(64, 64, 3, stride=(2, 1, 1), data_format='NDHWC'), + nn.BatchNorm3D(64, epsilon=1e-3, momentum=0.001, data_format='NDHWC'), + nn.ReLU(), + ) + def forward(self, x): + return self.middle_conv(x) + + +def get_pos_to_kw_map(func): + pos_to_kw = {} + fsig = inspect.signature(func) + pos = 0 + for name, info in fsig.parameters.items(): + if info.kind is info.POSITIONAL_OR_KEYWORD: + pos_to_kw[pos] = name + pos += 1 + return pos_to_kw + +def change_default_args(**kwargs): + def layer_wrapper(layer_class): + class DefaultArgLayer(layer_class): + def __init__(self, *args, **kw): + pos_to_kw = get_pos_to_kw_map(layer_class.__init__) + kw_to_pos = {kw: pos for pos, kw in pos_to_kw.items()} + for key, val in kwargs.items(): + if key not in kw and kw_to_pos[key] > len(args): + kw[key] = val + super().__init__(*args, **kw) + + return DefaultArgLayer + + return layer_wrapper + +class Empty(torch.nn.Module): + def __init__(self, *args, **kwargs): + super(Empty, self).__init__() + + def forward(self, *args, **kwargs): + if len(args) == 1: + return args[0] + elif len(args) == 0: + return None + return args + +class SpconvMiddleExtractor(torch.nn.Module): + def __init__(self, + #output_shape, + use_norm=True, + 
num_input_features=128, + num_filters_down1=[64], + num_filters_down2=[64, 64], + name='SpconvMiddleExtractor'): + super(SpconvMiddleExtractor, self).__init__() + if use_norm: + BatchNorm1d = change_default_args( + eps=1e-3, momentum=0.01)(torch.nn.BatchNorm1d) + Linear = change_default_args(bias=False)(nn.Linear) + else: + BatchNorm1d = Empty + Linear = change_default_args(bias=True)(nn.Linear) + + middle_layers = [] + + num_filters = [num_input_features] + num_filters_down1 + filters_pairs_d1 = [[num_filters[i], num_filters[i + 1]] + for i in range(len(num_filters) - 1)] + + for i, o in filters_pairs_d1: + middle_layers.append(spconv.SubMConv3d(i, o, 3, bias=False)) + if use_norm: + #middle_layers.append(BatchNorm1d(o)) + middle_layers.append(torch.nn.BatchNorm1d(o, eps=1e-3, momentum=0.01)) + middle_layers.append(torch.nn.ReLU()) + + middle_layers.append( + spconv.SparseConv3d( + num_filters[-1], + num_filters[-1], (3, 1, 1), (2, 1, 1), + bias=False)) + + if use_norm: + #middle_layers.append( + # BatchNorm1d(num_filters[-1])) + middle_layers.append(torch.nn.BatchNorm1d(o, eps=1e-3, momentum=0.01)) + middle_layers.append(torch.nn.ReLU()) + + + # assert len(num_filters_down2) > 0 + if len(num_filters_down1) == 0: + num_filters = [num_filters[-1]] + num_filters_down2 + else: + num_filters = [num_filters_down1[-1]] + num_filters_down2 + filters_pairs_d2 = [[num_filters[i], num_filters[i + 1]] + for i in range(len(num_filters) - 1)] + for i, o in filters_pairs_d2: + middle_layers.append(spconv.SubMConv3d(i, o, 3, bias=False)) + if use_norm: + #middle_layers.append(BatchNorm1d(o)) + middle_layers.append(torch.nn.BatchNorm1d(o, eps=1e-3, momentum=0.01)) + middle_layers.append(torch.nn.ReLU()) + middle_layers.append( + spconv.SparseConv3d( + num_filters[-1], + num_filters[-1], (3, 1, 1), (2, 1, 1), + bias=False)) + if use_norm: + #middle_layers.append( + #BatchNorm1d(num_filters[-1])) + middle_layers.append(torch.nn.BatchNorm1d(o, eps=1e-3, momentum=0.01)) + middle_layers.append(torch.nn.ReLU()) + #middle_layers.append(scn.SparseToDense(3, num_filters[-1])) + middle_layers.append(spconv.ToDense()) + self.middle_conv = spconv.SparseSequential(*middle_layers) + + def forward(self, x): + out = self.middle_conv(x) + return out + +class SparseMiddleExtractor(paddle.nn.Layer): + def __init__(self, + #output_shape, + use_norm=True, + num_input_features=128, + num_filters_down1=[64], + num_filters_down2=[64, 64], + name='SparseMiddleExtractor'): + super(SparseMiddleExtractor, self).__init__() + self.name = name + + middle_layers = [] + num_filters = [num_input_features] + num_filters_down1 + filters_pairs_d1 = [[num_filters[i], num_filters[i + 1]] for i in range(len(num_filters) - 1)] + for i, o in filters_pairs_d1: + middle_layers.append(sparse.SubmConv3D(i, o, 3, bias_attr=False)) + if use_norm: + middle_layers.append(sparse.BatchNorm(o, epsilon=1e-3, momentum=0.01)) + middle_layers.append(sparse.ReLU()) + + middle_layers.append(sparse.Conv3D(num_filters[-1], num_filters[-1], (3, 1, 1), (2, 1, 1), bias_attr=False)) + + if use_norm: + middle_layers.append(sparse.BatchNorm(num_filters[-1], epsilon=1e-3, momentum=0.01)) + middle_layers.append(sparse.ReLU()) + + + if len(num_filters_down1) == 0: + num_filters = [num_filters[-1]] + num_filters_down2 + else: + num_filters = [num_filters_down1[-1]] + num_filters_down2 + + filters_pairs_d2 = [[num_filters[i], num_filters[i + 1]] for i in range(len(num_filters) - 1)] + + for i, o in filters_pairs_d2: + middle_layers.append(sparse.SubmConv3D(i, o, 3, 
bias_attr=False)) + if use_norm: + middle_layers.append(sparse.BatchNorm(o, epsilon=1e-3, momentum=0.01)) + middle_layers.append(sparse.ReLU()) + + middle_layers.append(sparse.Conv3D(num_filters[-1], num_filters[-1], (3, 1, 1), (2, 1, 1), bias_attr=False)) + if use_norm: + middle_layers.append(sparse.BatchNorm(num_filters[-1], epsilon=1e-3, momentum=0.01)) + middle_layers.append(sparse.ReLU()) + + self.middle_conv = nn.Sequential(*middle_layers) + + def forward(self, x): + sparse_out = self.middle_conv(x) + #return sparse_out + return sparse_out.to_dense() + + +def test(): + paddle.seed(0) + with _test_eager_guard(): + in_channels = 128 + # Note: 1. paddle的BatchNorm1D的输入shape不能太大,否则报CUDNN_STATUS_NOT_SUPPORTED. + shape = [20, 40, 100] + batch_size = 1 + sparsity = 0.95 + + full_shape = [batch_size] + shape + [in_channels] + print(full_shape) + + total_elements = np.prod(shape) + nnz = int(total_elements * (1-sparsity)) + print("nnz=", nnz) + + #product indices + indices = [] + for i in range(4): + indices.append(paddle.randint(0, full_shape[i], [1, nnz])) + + indices = paddle.concat(indices) + #product values + values = paddle.randn((nnz, in_channels)) + + sparse_x = paddle.sparse.sparse_coo_tensor(indices, values, shape=full_shape) + + dense_x = sparse_x.to_dense() + + #spconv + device = torch.device("cuda") + torch_x = torch.tensor(dense_x.numpy(), device=device) + + spconv_x = spconv.SparseConvTensor.from_dense(torch_x) + + #whether to use batch_norm + use_norm = True + + dense_model = MiddleExtractor(use_norm=use_norm, num_input_features=in_channels) + spconv_model = SpconvMiddleExtractor(use_norm=use_norm, num_input_features=in_channels).to(device) + sparse_model = SparseMiddleExtractor(use_norm=use_norm, num_input_features=in_channels) + layer_nums = len(sparse_model.middle_conv) + block_size = 3 if use_norm else 2 + layer_nums = int(layer_nums / block_size) + + for i in range(0, layer_nums): + weight = paddle.to_tensor(spconv_model.middle_conv[i * block_size].weight.detach().cpu().numpy()) + sparse_model.middle_conv[i * block_size].weight.set_value(paddle.transpose(paddle.to_tensor(weight), [1,2,3,4,0])) + if use_norm: + bn_weight = paddle.to_tensor(spconv_model.middle_conv[i*block_size + 1].weight.detach().cpu().numpy()) + sparse_model.middle_conv[i * block_size + 1].weight.set_value(bn_weight) + + print(dense_model) + print(sparse_model) + print(spconv_model) + paddle.device.cuda.synchronize() + + #warm up + dense_x.stop_gradient=True + out1 = dense_model(dense_x) + paddle.device.cuda.synchronize() + sparse_x.stop_gradient=True + out2 = sparse_model(sparse_x) + paddle.device.cuda.synchronize() + spconv_x.features.required_grad=False + out3 = spconv_model(spconv_x) + torch.cuda.synchronize(device) + #warm up + + t0 = time.time() + #padde dense + dense_x.stop_gradient=False + out1 = dense_model(dense_x) + out1.backward(out1) + paddle.device.cuda.synchronize() + t1 = time.time() + + #padde sparse + sparse_x.stop_gradient=False + out2 = sparse_model(sparse_x) + out2.backward(out2) + paddle.device.cuda.synchronize() + t2 = time.time() + + #spconv + spconv_x.features.required_grad=True + spconv_x.features.requires_grad_() + out3 = spconv_model(spconv_x) + out3.backward(out3) + torch.cuda.synchronize(device) + t3 = time.time() + + # Note 2. sparse的BatchNorm底层是使用paddle.nn.BatchNorm1D对values进行bn计算,测试发现BatchNorm1D的性能比BatchNorm3D差,因此use_norm=True的情况,需要更高的稀疏度才能比dense的快 + # Note 3. 
只跑前向,sparse的耗时和spconv接近,稀疏度越高sparse的性能越好,当前方式测试前向+反向,spconv的耗时很高, 原因未知 + print("dense time: ", t1 - t0) + print("sparse time: ", t2 - t1) + print("spconv time: ", t3 - t2) + + # Note 4. paddle和torch的BN存在误差,测试shape=(4000, 64)的随机输入,单层BN前向误差在1e-6, 反向误差在1e-4 + #verify the forward calculation result + assert np.allclose(paddle.transpose(out2, [0, 4, 1, 2, 3]).numpy(), out3.detach().cpu().numpy(), atol=1e-4, rtol=1e-4) + + #verify the backward calculation result + assert np.allclose(spconv_x.features.grad.cpu().numpy(), + sparse_x.grad.values().numpy(), atol=1e-3, rtol=1e-3) + +test() diff --git a/python/paddle/fluid/tests/unittests/test_sparse_mnist.py b/python/paddle/fluid/tests/unittests/test_sparse_mnist.py new file mode 100644 index 0000000000000..3589dc83090f3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sparse_mnist.py @@ -0,0 +1,126 @@ +import paddle +from paddle.vision.transforms import Compose, Normalize, ToTensor +from paddle.fluid.framework import _test_eager_guard +import time + +paddle.disable_static() +#transform = Compose([Normalize(mean=[127.5], +# std=[127.5], +# data_format='CHW')]) +transform = Compose([ToTensor()]) +# 使用transform对数据集做归一化 +print('download training data and load training data') +train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform) +test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform) +print('load finished') + +import numpy as np +#import matplotlib.pyplot as plt +train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1] +train_data0 = train_data0.reshape([28,28]) +#plt.figure(figsize=(2,2)) +#plt.imshow(train_data0, cmap=plt.cm.binary) +print('train_data0 label is: ' + str(train_label_0)) + + +import paddle +import paddle.nn.functional as F +class SparseLeNet(paddle.nn.Layer): + def __init__(self): + super(SparseLeNet, self).__init__() + #self.bn = paddle.sparse.BatchNorm(1) + self.conv1 = paddle.sparse.Conv3D(in_channels=1, out_channels=6, kernel_size=[1, 5, 5], stride=[1, 1, 1], padding=[0, 2, 2]) + self.relu1 = paddle.sparse.ReLU() + self.pool1 = paddle.sparse.MaxPool3D(kernel_size=[1, 2, 2], stride=[1, 2, 2]) + self.conv2 = paddle.sparse.Conv3D(in_channels=6, out_channels=16, kernel_size=[1, 5, 5], stride=[1, 1, 1]) + self.relu2 = paddle.sparse.ReLU() + self.pool2 = paddle.sparse.MaxPool3D(kernel_size=[1, 2, 2], stride=[1, 2, 2]) + + self.fc1 = paddle.nn.Linear(16*5*5, 120) + self.fc2 = paddle.nn.Linear(120, 84) + self.fc3 = paddle.nn.Linear(84, 10) + + def forward(self, x): + #x = self.bn(x) + x = self.conv1(x) + x = self.relu1(x) + x = self.pool1(x) + x = self.conv2(x) + x = self.relu2(x) + x = self.pool2(x) + x = x.to_dense() + + x = paddle.flatten(x, start_axis=1, stop_axis=-1) + x = self.fc1(x) + x = paddle.nn.functional.relu(x) + x = self.fc2(x) + x = paddle.nn.functional.relu(x) + x = self.fc3(x) + return x + +import paddle.nn.functional as F +train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True) +# 加载训练集 batch_size 设为 64 +# sparse 训练 + +def prepare_data(x_data): + x_data = paddle.transpose(x_data, perm=[0, 2, 3, 1]) + x_data = paddle.reshape(x_data, [x_data.shape[0], 1, x_data.shape[1], x_data.shape[2], x_data.shape[3]]) + return x_data + +def sparse_train(model): + model.train() + epochs = 2 + optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters()) + # 用Adam作为优化函数 + for epoch in range(epochs): + for batch_id, data in enumerate(train_loader()): + x_data = data[0] + y_data = data[1] + x_data = prepare_data(x_data) + 
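+            # prepare_data() has reshaped the batch to NDHWC with a depth of 1; the
+            # conversion below keeps the first four dims (N, D, H, W) sparse and the
+            # channel dim dense, which is the layout the paddle.sparse layers expect.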
x_data = x_data.to_sparse_coo(4) + x_data.stop_gradient=False + predicts = model(x_data) + loss = F.cross_entropy(predicts, y_data) + # 计算损失 + acc = paddle.metric.accuracy(predicts, y_data) + loss.backward() + if batch_id % 300 == 0: + print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy())) + optim.step() + optim.clear_grad() + +test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=64) +# 加载测试数据集 +def test(model): + model.eval() + batch_size = 64 + for batch_id, data in enumerate(test_loader()): + x_data = data[0] + y_data = data[1] + x_data = prepare_data(x_data) + x_data = x_data.to_sparse_coo(4) + predicts = model(x_data) + # 获取预测结果 + loss = F.cross_entropy(predicts, y_data) + acc = paddle.metric.accuracy(predicts, y_data) + if batch_id % 20 == 0: + print("batch_id: {}, loss is: {}, acc is: {}".format(batch_id, loss.numpy(), acc.numpy())) + +with _test_eager_guard(): + sparse_model = SparseLeNet() + print(sparse_model) + + t0 = time.time() + sparse_train(sparse_model) + t1 = time.time() + print("spare time:", t1-t0) + test(sparse_model) + #x = paddle.randn((1, 1,28,28,1)) + #x.stop_gradient=False + #sparse_x = x.to_sparse_coo(4) + #print("sparse_x values shape:", sparse_x.values().shape) + #out = sparse_model(sparse_x) + #out.backward(out) + #print("end") + From f5f94413899c1d203f517ed7680ad31d3d17e6b0 Mon Sep 17 00:00:00 2001 From: zkh2016 Date: Wed, 13 Jul 2022 07:49:43 +0000 Subject: [PATCH 2/3] opt sparse_mask --- .../{sparse_mask_kernel.cc => mask_kernel.cc} | 2 +- .../{sparse_mask_kernel.cu => mask_kernel.cu} | 143 ++++++++++-------- .../{sparse_mask_kernel.h => mask_kernel.h} | 0 .../sparse/sparse_utils_grad_kernel.cc | 1 - .../kernels/sparse/sparse_utils_grad_kernel.h | 2 +- 5 files changed, 84 insertions(+), 64 deletions(-) rename paddle/phi/kernels/sparse/cpu/{sparse_mask_kernel.cc => mask_kernel.cc} (99%) rename paddle/phi/kernels/sparse/gpu/{sparse_mask_kernel.cu => mask_kernel.cu} (72%) rename paddle/phi/kernels/sparse/{sparse_mask_kernel.h => mask_kernel.h} (100%) diff --git a/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc b/paddle/phi/kernels/sparse/cpu/mask_kernel.cc similarity index 99% rename from paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc rename to paddle/phi/kernels/sparse/cpu/mask_kernel.cc index cf2acd8557333..92c015101264c 100644 --- a/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/mask_kernel.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/phi/kernels/sparse/sparse_mask_kernel.h" +#include "paddle/phi/kernels/sparse/mask_kernel.h" #include "paddle/phi/api/ext/dispatch.h" #include "paddle/phi/core/ddim.h" diff --git a/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu b/paddle/phi/kernels/sparse/gpu/mask_kernel.cu similarity index 72% rename from paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu rename to paddle/phi/kernels/sparse/gpu/mask_kernel.cu index 21d6850bdc4aa..39fa89c0379b7 100644 --- a/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/mask_kernel.cu @@ -12,9 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "paddle/phi/kernels/sparse/sparse_mask_kernel.h" - -#include +#include "paddle/phi/kernels/sparse/mask_kernel.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" @@ -24,6 +22,7 @@ limitations under the License. */ #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h" @@ -72,11 +71,7 @@ void SparseMaskGPUKernel(const GPUContext& dev_ctx, phi::backends::gpu::GpuMemcpyAsync(sparse_offsets.data(), &h_sparse_offsets[0], sizeof(int64_t) * sparse_dim, -#ifdef PADDLE_WITH_HIP - hipMemcpyHostToDevice, -#else - cudaMemcpyHostToDevice, -#endif + gpuMemcpyHostToDevice, dev_ctx.stream()); DenseTensor out_indices = phi::EmptyLike(dev_ctx, indices); @@ -93,14 +88,15 @@ void SparseMaskGPUKernel(const GPUContext& dev_ctx, auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num * cols, 1); - MaskKernel<<>>( - x_ptr, - indices_ptr, - sparse_offsets.data(), - non_zero_num, - cols, - sparse_dim, - out_values_ptr); + MaskKernel + <<>>( + x_ptr, + indices_ptr, + sparse_offsets.data(), + non_zero_num, + cols, + sparse_dim, + out_values_ptr); out->SetMember(out_indices, out_values, dims, true); } @@ -121,19 +117,31 @@ void SparseMaskKernel(const Context& dev_ctx, })); } -template -__global__ void SparseMaskCopyKernel(const IntT* x_indexs, - const IntT* mask_indexs, - const IntT* bound_out, - const T* x_values, - const int64_t n, - const int64_t stride, - T* out_values) { +template +__global__ void MaskTable(const IntT* x_indexs, const int n, int* table) { + CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) { + int index = x_indexs[i]; + table[index] = i == 0 ? -1 : i; + } +} + +template +__global__ void MaskCopy(const IntT* mask_indexs, + const int* table, + const int n, + const int stride, + const T* x_values, + T* out_values) { + using LoadT = phi::AlignedVector; + using StoreT = phi::AlignedVector; CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) { - const IntT j = bound_out[i]; - if (j >= 0 && j < n && mask_indexs[i] == x_indexs[j]) { - for (int k = 0; k < stride; k++) { - out_values[i * stride + k] = x_values[j * stride + k]; + int j = table[mask_indexs[i]]; + if (j != 0) { + if (j == -1) j = 0; + for (int k = 0; k < stride; k += VecSize) { + LoadT vec_x; + phi::Load(x_values + j * stride + k, &vec_x); + phi::Store(vec_x, out_values + i * stride + k); } } } @@ -179,11 +187,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx, phi::backends::gpu::GpuMemcpyAsync(d_sparse_offsets.data(), sparse_offsets.data(), sizeof(IntT) * sparse_dim, -#ifdef PADDLE_WITH_HIP - hipMemcpyHostToDevice, -#else - cudaMemcpyHostToDevice, -#endif + gpuMemcpyHostToDevice, dev_ctx.stream()); // 3. flatten x indices and mask indices @@ -210,37 +214,54 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx, mask_indexs.numel(), sparse_dim, mask_indexs_ptr); -// 4. call thrust::lower_bound -#ifdef PADDLE_WITH_HIP - thrust::lower_bound(thrust::hip::par.on(dev_ctx.stream()), -#else - thrust::lower_bound(thrust::cuda::par.on(dev_ctx.stream()), -#endif - x_indexs_ptr, - x_indexs_ptr + x_indexs.numel(), - mask_indexs_ptr, - mask_indexs_ptr + mask_indexs.numel(), - bound_out_ptr); - // 5. 
copy value to out + int table_size = 1; + auto x_dims = x.dims(); + for (int i = 0; i < x_dims.size() - 1; i++) { + table_size *= x_dims[i]; + } + DenseTensor table = phi::Empty(dev_ctx, {table_size}); + phi::backends::gpu::GpuMemsetAsync( + table.data(), 0, table_size * sizeof(int), dev_ctx.stream()); + const int64_t stride = + x.dims().size() == sparse_dim ? 1 : x.non_zero_elements().dims()[1]; *out = phi::EmptyLike(dev_ctx, x.non_zero_elements()); phi::funcs::SetConstant set_zero; set_zero(dev_ctx, out, static_cast(0)); T* out_ptr = out->data(); - - const int64_t stride = - x.dims().size() == sparse_dim ? 1 : x.non_zero_elements().dims()[1]; - - SparseMaskCopyKernel<<>>(x_indexs_ptr, - mask_indexs_ptr, - bound_out_ptr, - x.non_zero_elements().data(), - mask_indexs.numel(), - stride, - out_ptr); + config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_indexs.numel(), 1); + MaskTable<<>>( + x_indexs_ptr, x_indexs.numel(), table.data()); + config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, mask_indexs.numel(), 1); + const int VecBytes = 16; + const int VecSize = VecBytes / sizeof(T); + if (stride % VecSize == 0) { + MaskCopy + <<>>(mask_indexs_ptr, + table.data(), + mask_indexs.numel(), + stride, + x.non_zero_elements().data(), + out_ptr); + } else { + MaskCopy<<>>(mask_indexs_ptr, + table.data(), + mask_indexs.numel(), + stride, + x.non_zero_elements().data(), + out_ptr); + } } template @@ -257,7 +278,7 @@ void SparseMaskHelperKernel(const Context& dev_ctx, } // namespace sparse } // namespace phi -PD_REGISTER_KERNEL(sparse_mask, +PD_REGISTER_KERNEL(mask, GPU, ALL_LAYOUT, phi::sparse::SparseMaskKernel, @@ -272,7 +293,7 @@ PD_REGISTER_KERNEL(sparse_mask, kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO); } -PD_REGISTER_KERNEL(sparse_mask_helper, +PD_REGISTER_KERNEL(mask_helper, GPU, ALL_LAYOUT, phi::sparse::SparseMaskHelperKernel, diff --git a/paddle/phi/kernels/sparse/sparse_mask_kernel.h b/paddle/phi/kernels/sparse/mask_kernel.h similarity index 100% rename from paddle/phi/kernels/sparse/sparse_mask_kernel.h rename to paddle/phi/kernels/sparse/mask_kernel.h diff --git a/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc b/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc index 69677be34b231..9425c14b79b36 100644 --- a/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc @@ -15,7 +15,6 @@ limitations under the License. */ #include "paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/sparse/sparse_mask_kernel.h" namespace phi { namespace sparse { diff --git a/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h b/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h index a00b9c275c292..7cf97c3f48ece 100644 --- a/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h +++ b/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h @@ -16,7 +16,7 @@ limitations under the License. 
*/ #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/sparse_coo_tensor.h" -#include "paddle/phi/kernels/sparse/sparse_mask_kernel.h" +#include "paddle/phi/kernels/sparse/mask_kernel.h" namespace phi { namespace sparse { From 9457848aa1d78b52a0cff1bc3c7a7e0963e81515 Mon Sep 17 00:00:00 2001 From: zkh2016 Date: Wed, 13 Jul 2022 07:52:31 +0000 Subject: [PATCH 3/3] rm unused file --- .../unittests/test_sparse_middle_extractor.py | 324 ------------------ .../tests/unittests/test_sparse_mnist.py | 126 ------- 2 files changed, 450 deletions(-) delete mode 100644 python/paddle/fluid/tests/unittests/test_sparse_middle_extractor.py delete mode 100644 python/paddle/fluid/tests/unittests/test_sparse_mnist.py diff --git a/python/paddle/fluid/tests/unittests/test_sparse_middle_extractor.py b/python/paddle/fluid/tests/unittests/test_sparse_middle_extractor.py deleted file mode 100644 index ae52b4a413336..0000000000000 --- a/python/paddle/fluid/tests/unittests/test_sparse_middle_extractor.py +++ /dev/null @@ -1,324 +0,0 @@ -import paddle -import paddle.nn as nn -import paddle.sparse as sparse -from paddle.fluid.framework import _test_eager_guard -import time -import numpy as np -import torch -import spconv.pytorch as spconv -import inspect - -class MiddleExtractor(paddle.nn.Layer): - def __init__(self, - #output_shape, - use_norm=True, - num_input_features=128, - num_filters_down1=[64], - num_filters_down2=[64, 64], - name='MiddleExtractor'): - super(MiddleExtractor, self).__init__() - self.name = name - if not use_norm: - self.middle_conv = paddle.nn.Sequential( - #nn.Pad3D(1), - nn.Conv3D(num_input_features, 64, 3, stride=(2, 1, 1), data_format='NDHWC'), - #nn.BatchNorm3D(64, epsilon=1e-3, momentum=0.001, data_format='NDHWC'), - nn.ReLU(), - #nn.Pad3D([1, 1, 1, 1, 0, 0]), - nn.Conv3D(64, 64, 3, stride=(1, 1, 1), data_format='NDHWC'), - #nn.BatchNorm3D(64, epsilon=1e-3, momentum=0.001, data_format='NDHWC'), - nn.ReLU(), - #nn.Pad3D(1), - nn.Conv3D(64, 64, 3, stride=(2, 1, 1), data_format='NDHWC'), - #nn.BatchNorm3D(64, epsilon=1e-3, momentum=0.001, data_format='NDHWC'), - nn.ReLU(), - ) - else: - self.middle_conv = paddle.nn.Sequential( - #nn.Pad3D(1), - nn.Conv3D(num_input_features, 64, 3, stride=(2, 1, 1), data_format='NDHWC'), - nn.BatchNorm3D(64, epsilon=1e-3, momentum=0.001, data_format='NDHWC'), - nn.ReLU(), - #nn.Pad3D([1, 1, 1, 1, 0, 0]), - nn.Conv3D(64, 64, 3, stride=(1, 1, 1), data_format='NDHWC'), - nn.BatchNorm3D(64, epsilon=1e-3, momentum=0.001, data_format='NDHWC'), - nn.ReLU(), - #nn.Pad3D(1), - nn.Conv3D(64, 64, 3, stride=(2, 1, 1), data_format='NDHWC'), - nn.BatchNorm3D(64, epsilon=1e-3, momentum=0.001, data_format='NDHWC'), - nn.ReLU(), - ) - def forward(self, x): - return self.middle_conv(x) - - -def get_pos_to_kw_map(func): - pos_to_kw = {} - fsig = inspect.signature(func) - pos = 0 - for name, info in fsig.parameters.items(): - if info.kind is info.POSITIONAL_OR_KEYWORD: - pos_to_kw[pos] = name - pos += 1 - return pos_to_kw - -def change_default_args(**kwargs): - def layer_wrapper(layer_class): - class DefaultArgLayer(layer_class): - def __init__(self, *args, **kw): - pos_to_kw = get_pos_to_kw_map(layer_class.__init__) - kw_to_pos = {kw: pos for pos, kw in pos_to_kw.items()} - for key, val in kwargs.items(): - if key not in kw and kw_to_pos[key] > len(args): - kw[key] = val - super().__init__(*args, **kw) - - return DefaultArgLayer - - return layer_wrapper - -class Empty(torch.nn.Module): - def __init__(self, *args, **kwargs): - super(Empty, self).__init__() - - def 
forward(self, *args, **kwargs): - if len(args) == 1: - return args[0] - elif len(args) == 0: - return None - return args - -class SpconvMiddleExtractor(torch.nn.Module): - def __init__(self, - #output_shape, - use_norm=True, - num_input_features=128, - num_filters_down1=[64], - num_filters_down2=[64, 64], - name='SpconvMiddleExtractor'): - super(SpconvMiddleExtractor, self).__init__() - if use_norm: - BatchNorm1d = change_default_args( - eps=1e-3, momentum=0.01)(torch.nn.BatchNorm1d) - Linear = change_default_args(bias=False)(nn.Linear) - else: - BatchNorm1d = Empty - Linear = change_default_args(bias=True)(nn.Linear) - - middle_layers = [] - - num_filters = [num_input_features] + num_filters_down1 - filters_pairs_d1 = [[num_filters[i], num_filters[i + 1]] - for i in range(len(num_filters) - 1)] - - for i, o in filters_pairs_d1: - middle_layers.append(spconv.SubMConv3d(i, o, 3, bias=False)) - if use_norm: - #middle_layers.append(BatchNorm1d(o)) - middle_layers.append(torch.nn.BatchNorm1d(o, eps=1e-3, momentum=0.01)) - middle_layers.append(torch.nn.ReLU()) - - middle_layers.append( - spconv.SparseConv3d( - num_filters[-1], - num_filters[-1], (3, 1, 1), (2, 1, 1), - bias=False)) - - if use_norm: - #middle_layers.append( - # BatchNorm1d(num_filters[-1])) - middle_layers.append(torch.nn.BatchNorm1d(o, eps=1e-3, momentum=0.01)) - middle_layers.append(torch.nn.ReLU()) - - - # assert len(num_filters_down2) > 0 - if len(num_filters_down1) == 0: - num_filters = [num_filters[-1]] + num_filters_down2 - else: - num_filters = [num_filters_down1[-1]] + num_filters_down2 - filters_pairs_d2 = [[num_filters[i], num_filters[i + 1]] - for i in range(len(num_filters) - 1)] - for i, o in filters_pairs_d2: - middle_layers.append(spconv.SubMConv3d(i, o, 3, bias=False)) - if use_norm: - #middle_layers.append(BatchNorm1d(o)) - middle_layers.append(torch.nn.BatchNorm1d(o, eps=1e-3, momentum=0.01)) - middle_layers.append(torch.nn.ReLU()) - middle_layers.append( - spconv.SparseConv3d( - num_filters[-1], - num_filters[-1], (3, 1, 1), (2, 1, 1), - bias=False)) - if use_norm: - #middle_layers.append( - #BatchNorm1d(num_filters[-1])) - middle_layers.append(torch.nn.BatchNorm1d(o, eps=1e-3, momentum=0.01)) - middle_layers.append(torch.nn.ReLU()) - #middle_layers.append(scn.SparseToDense(3, num_filters[-1])) - middle_layers.append(spconv.ToDense()) - self.middle_conv = spconv.SparseSequential(*middle_layers) - - def forward(self, x): - out = self.middle_conv(x) - return out - -class SparseMiddleExtractor(paddle.nn.Layer): - def __init__(self, - #output_shape, - use_norm=True, - num_input_features=128, - num_filters_down1=[64], - num_filters_down2=[64, 64], - name='SparseMiddleExtractor'): - super(SparseMiddleExtractor, self).__init__() - self.name = name - - middle_layers = [] - num_filters = [num_input_features] + num_filters_down1 - filters_pairs_d1 = [[num_filters[i], num_filters[i + 1]] for i in range(len(num_filters) - 1)] - for i, o in filters_pairs_d1: - middle_layers.append(sparse.SubmConv3D(i, o, 3, bias_attr=False)) - if use_norm: - middle_layers.append(sparse.BatchNorm(o, epsilon=1e-3, momentum=0.01)) - middle_layers.append(sparse.ReLU()) - - middle_layers.append(sparse.Conv3D(num_filters[-1], num_filters[-1], (3, 1, 1), (2, 1, 1), bias_attr=False)) - - if use_norm: - middle_layers.append(sparse.BatchNorm(num_filters[-1], epsilon=1e-3, momentum=0.01)) - middle_layers.append(sparse.ReLU()) - - - if len(num_filters_down1) == 0: - num_filters = [num_filters[-1]] + num_filters_down2 - else: - num_filters = 
[num_filters_down1[-1]] + num_filters_down2 - - filters_pairs_d2 = [[num_filters[i], num_filters[i + 1]] for i in range(len(num_filters) - 1)] - - for i, o in filters_pairs_d2: - middle_layers.append(sparse.SubmConv3D(i, o, 3, bias_attr=False)) - if use_norm: - middle_layers.append(sparse.BatchNorm(o, epsilon=1e-3, momentum=0.01)) - middle_layers.append(sparse.ReLU()) - - middle_layers.append(sparse.Conv3D(num_filters[-1], num_filters[-1], (3, 1, 1), (2, 1, 1), bias_attr=False)) - if use_norm: - middle_layers.append(sparse.BatchNorm(num_filters[-1], epsilon=1e-3, momentum=0.01)) - middle_layers.append(sparse.ReLU()) - - self.middle_conv = nn.Sequential(*middle_layers) - - def forward(self, x): - sparse_out = self.middle_conv(x) - #return sparse_out - return sparse_out.to_dense() - - -def test(): - paddle.seed(0) - with _test_eager_guard(): - in_channels = 128 - # Note: 1. paddle的BatchNorm1D的输入shape不能太大,否则报CUDNN_STATUS_NOT_SUPPORTED. - shape = [20, 40, 100] - batch_size = 1 - sparsity = 0.95 - - full_shape = [batch_size] + shape + [in_channels] - print(full_shape) - - total_elements = np.prod(shape) - nnz = int(total_elements * (1-sparsity)) - print("nnz=", nnz) - - #product indices - indices = [] - for i in range(4): - indices.append(paddle.randint(0, full_shape[i], [1, nnz])) - - indices = paddle.concat(indices) - #product values - values = paddle.randn((nnz, in_channels)) - - sparse_x = paddle.sparse.sparse_coo_tensor(indices, values, shape=full_shape) - - dense_x = sparse_x.to_dense() - - #spconv - device = torch.device("cuda") - torch_x = torch.tensor(dense_x.numpy(), device=device) - - spconv_x = spconv.SparseConvTensor.from_dense(torch_x) - - #whether to use batch_norm - use_norm = True - - dense_model = MiddleExtractor(use_norm=use_norm, num_input_features=in_channels) - spconv_model = SpconvMiddleExtractor(use_norm=use_norm, num_input_features=in_channels).to(device) - sparse_model = SparseMiddleExtractor(use_norm=use_norm, num_input_features=in_channels) - layer_nums = len(sparse_model.middle_conv) - block_size = 3 if use_norm else 2 - layer_nums = int(layer_nums / block_size) - - for i in range(0, layer_nums): - weight = paddle.to_tensor(spconv_model.middle_conv[i * block_size].weight.detach().cpu().numpy()) - sparse_model.middle_conv[i * block_size].weight.set_value(paddle.transpose(paddle.to_tensor(weight), [1,2,3,4,0])) - if use_norm: - bn_weight = paddle.to_tensor(spconv_model.middle_conv[i*block_size + 1].weight.detach().cpu().numpy()) - sparse_model.middle_conv[i * block_size + 1].weight.set_value(bn_weight) - - print(dense_model) - print(sparse_model) - print(spconv_model) - paddle.device.cuda.synchronize() - - #warm up - dense_x.stop_gradient=True - out1 = dense_model(dense_x) - paddle.device.cuda.synchronize() - sparse_x.stop_gradient=True - out2 = sparse_model(sparse_x) - paddle.device.cuda.synchronize() - spconv_x.features.required_grad=False - out3 = spconv_model(spconv_x) - torch.cuda.synchronize(device) - #warm up - - t0 = time.time() - #padde dense - dense_x.stop_gradient=False - out1 = dense_model(dense_x) - out1.backward(out1) - paddle.device.cuda.synchronize() - t1 = time.time() - - #padde sparse - sparse_x.stop_gradient=False - out2 = sparse_model(sparse_x) - out2.backward(out2) - paddle.device.cuda.synchronize() - t2 = time.time() - - #spconv - spconv_x.features.required_grad=True - spconv_x.features.requires_grad_() - out3 = spconv_model(spconv_x) - out3.backward(out3) - torch.cuda.synchronize(device) - t3 = time.time() - - # Note 2. 
sparse的BatchNorm底层是使用paddle.nn.BatchNorm1D对values进行bn计算,测试发现BatchNorm1D的性能比BatchNorm3D差,因此use_norm=True的情况,需要更高的稀疏度才能比dense的快 - # Note 3. 只跑前向,sparse的耗时和spconv接近,稀疏度越高sparse的性能越好,当前方式测试前向+反向,spconv的耗时很高, 原因未知 - print("dense time: ", t1 - t0) - print("sparse time: ", t2 - t1) - print("spconv time: ", t3 - t2) - - # Note 4. paddle和torch的BN存在误差,测试shape=(4000, 64)的随机输入,单层BN前向误差在1e-6, 反向误差在1e-4 - #verify the forward calculation result - assert np.allclose(paddle.transpose(out2, [0, 4, 1, 2, 3]).numpy(), out3.detach().cpu().numpy(), atol=1e-4, rtol=1e-4) - - #verify the backward calculation result - assert np.allclose(spconv_x.features.grad.cpu().numpy(), - sparse_x.grad.values().numpy(), atol=1e-3, rtol=1e-3) - -test() diff --git a/python/paddle/fluid/tests/unittests/test_sparse_mnist.py b/python/paddle/fluid/tests/unittests/test_sparse_mnist.py deleted file mode 100644 index 3589dc83090f3..0000000000000 --- a/python/paddle/fluid/tests/unittests/test_sparse_mnist.py +++ /dev/null @@ -1,126 +0,0 @@ -import paddle -from paddle.vision.transforms import Compose, Normalize, ToTensor -from paddle.fluid.framework import _test_eager_guard -import time - -paddle.disable_static() -#transform = Compose([Normalize(mean=[127.5], -# std=[127.5], -# data_format='CHW')]) -transform = Compose([ToTensor()]) -# 使用transform对数据集做归一化 -print('download training data and load training data') -train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform) -test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform) -print('load finished') - -import numpy as np -#import matplotlib.pyplot as plt -train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1] -train_data0 = train_data0.reshape([28,28]) -#plt.figure(figsize=(2,2)) -#plt.imshow(train_data0, cmap=plt.cm.binary) -print('train_data0 label is: ' + str(train_label_0)) - - -import paddle -import paddle.nn.functional as F -class SparseLeNet(paddle.nn.Layer): - def __init__(self): - super(SparseLeNet, self).__init__() - #self.bn = paddle.sparse.BatchNorm(1) - self.conv1 = paddle.sparse.Conv3D(in_channels=1, out_channels=6, kernel_size=[1, 5, 5], stride=[1, 1, 1], padding=[0, 2, 2]) - self.relu1 = paddle.sparse.ReLU() - self.pool1 = paddle.sparse.MaxPool3D(kernel_size=[1, 2, 2], stride=[1, 2, 2]) - self.conv2 = paddle.sparse.Conv3D(in_channels=6, out_channels=16, kernel_size=[1, 5, 5], stride=[1, 1, 1]) - self.relu2 = paddle.sparse.ReLU() - self.pool2 = paddle.sparse.MaxPool3D(kernel_size=[1, 2, 2], stride=[1, 2, 2]) - - self.fc1 = paddle.nn.Linear(16*5*5, 120) - self.fc2 = paddle.nn.Linear(120, 84) - self.fc3 = paddle.nn.Linear(84, 10) - - def forward(self, x): - #x = self.bn(x) - x = self.conv1(x) - x = self.relu1(x) - x = self.pool1(x) - x = self.conv2(x) - x = self.relu2(x) - x = self.pool2(x) - x = x.to_dense() - - x = paddle.flatten(x, start_axis=1, stop_axis=-1) - x = self.fc1(x) - x = paddle.nn.functional.relu(x) - x = self.fc2(x) - x = paddle.nn.functional.relu(x) - x = self.fc3(x) - return x - -import paddle.nn.functional as F -train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True) -# 加载训练集 batch_size 设为 64 -# sparse 训练 - -def prepare_data(x_data): - x_data = paddle.transpose(x_data, perm=[0, 2, 3, 1]) - x_data = paddle.reshape(x_data, [x_data.shape[0], 1, x_data.shape[1], x_data.shape[2], x_data.shape[3]]) - return x_data - -def sparse_train(model): - model.train() - epochs = 2 - optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters()) - # 用Adam作为优化函数 - for epoch in 
range(epochs): - for batch_id, data in enumerate(train_loader()): - x_data = data[0] - y_data = data[1] - x_data = prepare_data(x_data) - x_data = x_data.to_sparse_coo(4) - x_data.stop_gradient=False - predicts = model(x_data) - loss = F.cross_entropy(predicts, y_data) - # 计算损失 - acc = paddle.metric.accuracy(predicts, y_data) - loss.backward() - if batch_id % 300 == 0: - print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy())) - optim.step() - optim.clear_grad() - -test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=64) -# 加载测试数据集 -def test(model): - model.eval() - batch_size = 64 - for batch_id, data in enumerate(test_loader()): - x_data = data[0] - y_data = data[1] - x_data = prepare_data(x_data) - x_data = x_data.to_sparse_coo(4) - predicts = model(x_data) - # 获取预测结果 - loss = F.cross_entropy(predicts, y_data) - acc = paddle.metric.accuracy(predicts, y_data) - if batch_id % 20 == 0: - print("batch_id: {}, loss is: {}, acc is: {}".format(batch_id, loss.numpy(), acc.numpy())) - -with _test_eager_guard(): - sparse_model = SparseLeNet() - print(sparse_model) - - t0 = time.time() - sparse_train(sparse_model) - t1 = time.time() - print("spare time:", t1-t0) - test(sparse_model) - #x = paddle.randn((1, 1,28,28,1)) - #x.stop_gradient=False - #sparse_x = x.to_sparse_coo(4) - #print("sparse_x values shape:", sparse_x.values().shape) - #out = sparse_model(sparse_x) - #out.backward(out) - #print("end") -
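For reference, the core of the sparse_mask_helper rework in PATCH 2/3 is replacing the thrust::lower_bound search over x_indexs with a dense lookup table built over the flattened indices (MaskTable) followed by a gather of the matching value rows (MaskCopy). Below is a minimal host-side NumPy sketch of that idea; it mirrors the kernel's encoding (an empty slot is 0, so row 0 is stored as -1), and the function name mask_helper_reference is made up for illustration only, not part of the patch:

import numpy as np

def mask_helper_reference(x_indexs, x_values, mask_indexs, table_size):
    # x_indexs:    flattened indices of x's non-zero entries (1-D int array)
    # x_values:    matching value rows, shape [len(x_indexs), stride]
    # mask_indexs: flattened indices of the mask's non-zero entries
    # table_size:  size of the flattened index space (the kernel uses the
    #              product of all dims of x except the last, channel, dim)
    stride = x_values.shape[1]

    # MaskTable: table[idx] maps a flattened index back to its row in x_values.
    # 0 means "not present"; row 0 is encoded as -1 so it can be told apart
    # from the zero-initialized empty slots.
    table = np.zeros(table_size, dtype=np.int64)
    for i, idx in enumerate(x_indexs):
        table[idx] = -1 if i == 0 else i

    # MaskCopy: for every mask index, copy the matching row of x, otherwise
    # leave the zero-initialized row in place.
    out = np.zeros((len(mask_indexs), stride), dtype=x_values.dtype)
    for i, idx in enumerate(mask_indexs):
        j = table[idx]
        if j != 0:
            out[i] = x_values[0 if j == -1 else j]
    return out

Compared with the earlier thrust::lower_bound search, each mask index now costs a single table lookup at the price of a temporary table sized to the flattened index space; the CUDA version in the patch additionally vectorizes the row copy with phi::AlignedVector whenever the value stride is a multiple of 16 / sizeof(T) elements.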