diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc
index d8e79d40c23eb..21c87681f6877 100644
--- a/paddle/fluid/framework/tensor_util.cc
+++ b/paddle/fluid/framework/tensor_util.cc
@@ -822,6 +822,29 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
 #else
       PADDLE_THROW(platform::errors::Unimplemented(
           "XPUPlace is not supported when not compiled with XPU"));
+#endif
+    } else if (platform::is_npu_place(tensor.place())) {
+#ifdef PADDLE_WITH_ASCEND_CL
+      constexpr size_t kBufSize = 1024 * 1024 * 64;  // 64MB
+      std::unique_ptr<char[]> buf(new char[kBufSize]);
+      auto& npu_dev_ctx =
+          static_cast<const platform::NPUDeviceContext&>(dev_ctx);
+      platform::CPUPlace cpu;
+      uintptr_t data = reinterpret_cast<uintptr_t>(data_ptr);
+      while (size != 0) {
+        size_t size_to_write = std::min(kBufSize, static_cast<size_t>(size));
+        memory::Copy(cpu, buf.get(),
+                     BOOST_GET_CONST(platform::NPUPlace, tensor.place()),
+                     reinterpret_cast<const void*>(data), size_to_write,
+                     npu_dev_ctx.stream());
+        npu_dev_ctx.Wait();
+        os.write(buf.get(), size_to_write);
+        data += size_to_write;
+        size -= size_to_write;
+      }
+#else
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "NPUPlace is not supported when not compiled with NPU"));
 #endif
     } else {
       os.write(static_cast<const char*>(data_ptr),
@@ -877,8 +900,10 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
     auto ctx = platform::CPUDeviceContext();
     size_t size = tensor->numel() * framework::SizeOfType(desc.data_type());
     if (platform::is_gpu_place(dev_ctx.GetPlace()) ||
-        platform::is_xpu_place(dev_ctx.GetPlace())) {
-#if defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU
+        platform::is_xpu_place(dev_ctx.GetPlace()) ||
+        platform::is_npu_place(dev_ctx.GetPlace())) {
+#if defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU || \
+    defined PADDLE_WITH_ASCEND_CL
       Tensor cpu_tensor;
       cpu_tensor.Resize(framework::make_ddim(shape));
       framework::VisitDataType(
@@ -887,13 +912,19 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
       is.read(static_cast<char*>(buf), size);
       auto dst_place = dev_ctx.GetPlace();
       framework::TensorCopy(cpu_tensor, dst_place, dev_ctx, tensor);
+      if (platform::is_npu_place(dev_ctx.GetPlace())) {
+        dev_ctx.Wait();
+      }
 #else
       if (platform::is_gpu_place(dev_ctx.GetPlace())) {
         PADDLE_THROW(platform::errors::Unimplemented(
             "CUDAPlace is not supported when not compiled with CUDA"));
-      } else {
+      } else if (platform::is_xpu_place(dev_ctx.GetPlace())) {
         PADDLE_THROW(platform::errors::Unimplemented(
             "XPUPlace is not supported when not compiled with XPU"));
+      } else {
+        PADDLE_THROW(platform::errors::Unimplemented(
+            "NPUPlace is not supported when not compiled with NPU"));
       }
 #endif
     } else {
@@ -934,8 +965,10 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
     auto ctx = platform::CPUDeviceContext();
     size_t size = tensor->numel() * framework::SizeOfType(desc.data_type());
     if (platform::is_gpu_place(dev_ctx.GetPlace()) ||
-        platform::is_xpu_place(dev_ctx.GetPlace())) {
-#if defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU
+        platform::is_xpu_place(dev_ctx.GetPlace()) ||
+        platform::is_npu_place(dev_ctx.GetPlace())) {
+#if defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU || \
+    defined PADDLE_WITH_ASCEND_CL
       Tensor cpu_tensor;
       cpu_tensor.Resize(framework::make_ddim(dims));
       framework::VisitDataType(
@@ -944,13 +977,19 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
       is.read(static_cast<char*>(buf), size);
       auto dst_place = dev_ctx.GetPlace();
       framework::TensorCopy(cpu_tensor, dst_place, dev_ctx, tensor);
+      if (platform::is_npu_place(dev_ctx.GetPlace())) {
+        dev_ctx.Wait();
+      }
 #else
       if (platform::is_gpu_place(dev_ctx.GetPlace())) {
         PADDLE_THROW(platform::errors::Unimplemented(
             "CUDAPlace is not supported when not compiled with CUDA"));
-      } else {
+      } else if (platform::is_xpu_place(dev_ctx.GetPlace())) {
         PADDLE_THROW(platform::errors::Unimplemented(
             "XPUPlace is not supported when not compiled with XPU"));
+      } else {
+        PADDLE_THROW(platform::errors::Unimplemented(
+            "NPUPlace is not supported when not compiled with NPU"));
       }
 #endif
     } else {
diff --git a/paddle/fluid/operators/load_combine_op_npu.cc b/paddle/fluid/operators/load_combine_op_npu.cc
new file mode 100644
index 0000000000000..4b9b96c23b0b7
--- /dev/null
+++ b/paddle/fluid/operators/load_combine_op_npu.cc
@@ -0,0 +1,25 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/load_combine_op.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_NPU_KERNEL(
+    load_combine,
+    ops::LoadCombineOpKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::LoadCombineOpKernel<paddle::platform::NPUDeviceContext, double>,
+    ops::LoadCombineOpKernel<paddle::platform::NPUDeviceContext, int>,
+    ops::LoadCombineOpKernel<paddle::platform::NPUDeviceContext, int8_t>,
+    ops::LoadCombineOpKernel<paddle::platform::NPUDeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/load_op_npu.cc b/paddle/fluid/operators/load_op_npu.cc
new file mode 100644
index 0000000000000..1f53280345831
--- /dev/null
+++ b/paddle/fluid/operators/load_op_npu.cc
@@ -0,0 +1,24 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/load_op.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_NPU_KERNEL(
+    load, ops::LoadOpKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::LoadOpKernel<paddle::platform::NPUDeviceContext, double>,
+    ops::LoadOpKernel<paddle::platform::NPUDeviceContext, int>,
+    ops::LoadOpKernel<paddle::platform::NPUDeviceContext, int8_t>,
+    ops::LoadOpKernel<paddle::platform::NPUDeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/save_combine_op_npu.cc b/paddle/fluid/operators/save_combine_op_npu.cc
new file mode 100644
index 0000000000000..1fb136a5110db
--- /dev/null
+++ b/paddle/fluid/operators/save_combine_op_npu.cc
@@ -0,0 +1,24 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/save_combine_op.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_NPU_KERNEL(
+    save_combine,
+    ops::SaveCombineOpKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::SaveCombineOpKernel<paddle::platform::NPUDeviceContext, double>,
+    ops::SaveCombineOpKernel<paddle::platform::NPUDeviceContext, int>,
+    ops::SaveCombineOpKernel<paddle::platform::NPUDeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/save_op_npu.cc b/paddle/fluid/operators/save_op_npu.cc
new file mode 100644
index 0000000000000..90db1a0bb85d6
--- /dev/null
+++ b/paddle/fluid/operators/save_op_npu.cc
@@ -0,0 +1,28 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/save_op.h"
+#include "paddle/fluid/platform/float16.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_NPU_KERNEL(
+    save, ops::SaveOpKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::SaveOpKernel<paddle::platform::NPUDeviceContext, double>,
+    ops::SaveOpKernel<paddle::platform::NPUDeviceContext, int>,
+    ops::SaveOpKernel<paddle::platform::NPUDeviceContext, uint8_t>,
+    ops::SaveOpKernel<paddle::platform::NPUDeviceContext, int8_t>,
+    ops::SaveOpKernel<paddle::platform::NPUDeviceContext, int64_t>,
+    ops::SaveOpKernel<paddle::platform::NPUDeviceContext,
+                      paddle::platform::float16>);
diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h
index 51fc3439c9a59..f1ec8c3ea67c7 100644
--- a/paddle/fluid/pybind/tensor_py.h
+++ b/paddle/fluid/pybind/tensor_py.h
@@ -644,6 +644,7 @@ inline py::array TensorToPyArray(const framework::Tensor &tensor,
   }
   bool is_gpu_tensor = platform::is_gpu_place(tensor.place());
   bool is_xpu_tensor = platform::is_xpu_place(tensor.place());
+  bool is_npu_tensor = platform::is_npu_place(tensor.place());
   const auto &tensor_dims = tensor.dims();
   auto tensor_dtype = tensor.type();
   size_t sizeof_dtype = framework::SizeOfType(tensor_dtype);
@@ -662,7 +663,7 @@ inline py::array TensorToPyArray(const framework::Tensor &tensor,
 
   std::string py_dtype_str = details::TensorDTypeToPyDTypeStr(tensor.type());
 
-  if (!is_gpu_tensor && !is_xpu_tensor) {
+  if (!is_gpu_tensor && !is_xpu_tensor && !is_npu_tensor) {
     if (!need_deep_copy) {
       auto base = py::cast(std::move(tensor));
       return py::array(py::dtype(py_dtype_str.c_str()), py_dims, py_strides,
@@ -729,6 +730,34 @@ inline py::array TensorToPyArray(const framework::Tensor &tensor,
     PADDLE_THROW(platform::errors::PermissionDenied(
         "Cannot use CUDAPlace in CPU only version, "
         "Please recompile or reinstall Paddle with CUDA support."));
+#endif
+  } else if (is_npu_tensor) {
+#ifdef PADDLE_WITH_ASCEND_CL
+    py::array py_arr(py::dtype(py_dtype_str.c_str()), py_dims, py_strides);
+    PADDLE_ENFORCE_EQ(py_arr.writeable(), true,
+                      platform::errors::InvalidArgument(
+                          "PyArray is not writable, in which case memory leak "
+                          "or double free would occur"));
+    PADDLE_ENFORCE_EQ(
+        py_arr.owndata(), true,
+        platform::errors::InvalidArgument(
+            "PyArray does not own data, in which case memory leak "
+            "or double free would occur"));
+
+    size_t copy_bytes = sizeof_dtype * numel;
+    auto p = BOOST_GET_CONST(platform::NPUPlace, tensor.place());
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    auto &ctx = *pool.Get(tensor.place());
+    paddle::memory::Copy(
+        platform::CPUPlace(), py_arr.mutable_data(), p, tensor_buf_ptr,
+        copy_bytes,
+        reinterpret_cast<const platform::NPUDeviceContext &>(ctx).stream());
+    ctx.Wait();
+    return py_arr;
+#else
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "Cannot use NPUPlace in CPU/GPU/XPU version, "
+        "Please recompile or reinstall Paddle with NPU support."));
 #endif
   }
   PADDLE_THROW(platform::errors::Unimplemented("Place is not supported"));
diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py
index d5963675a82a0..560abad626405 100644
--- a/python/paddle/fluid/io.py
+++ b/python/paddle/fluid/io.py
@@ -1973,6 +1973,10 @@ def set_var(var, ndarray):
             p = paddle.fluid.core.Place()
             p.set_place(t._place())
             place = paddle.fluid.XPUPlace(p.xpu_device_id())
+        elif p.is_npu_place():
+            p = paddle.fluid.core.Place()
+            p.set_place(t._place())
+            place = paddle.fluid.NPUPlace(p.npu_device_id())
         else:
             p = paddle.fluid.core.Place()
             p.set_place(t._place())
@@ -2115,8 +2119,8 @@ def _load_vars_with_try_catch(exe,
         error_str = "Failed to load model/variables `%s`, please make sure " \
                     "model/variables file is saved with the following APIs: " \
                     "save_params, save_persistables, save_vars."
-        filenames = [var.name for var in vars
-                     ] if filename is None else filename
+        filenames = [var.name for var in
+                     vars] if filename is None else filename
         if raise_error:
             raise RuntimeError(error_str % filenames)
         else:
@@ -2256,6 +2260,10 @@ def set_program_state(program, state_dict):
                 p = paddle.fluid.core.Place()
                 p.set_place(ten_place)
                 py_place = paddle.fluid.XPUPlace(p.xpu_device_id())
+            elif ten_place.is_npu_place():
+                p = paddle.fluid.core.Place()
+                p.set_place(ten_place)
+                py_place = paddle.fluid.NPUPlace(p.npu_device_id())
 
             ten.set(new_para_np, py_place)
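Note on the serialization path above: the NPU branch of `TensorToStream` streams device memory to disk through a fixed 64MB host staging buffer instead of allocating a host copy of the whole tensor. Below is a minimal standalone sketch of that pattern; the `copy_d2h`/`sync` callbacks are illustrative stand-ins for `memory::Copy(...)` on the NPU stream and `npu_dev_ctx.Wait()`, not Paddle APIs.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <ostream>

// Chunked device-to-host serialization: copy at most kBufSize bytes per
// iteration into a host staging buffer, then append the chunk to the stream.
void SerializeDeviceBuffer(
    std::ostream& os, const void* device_ptr, size_t size,
    void (*copy_d2h)(void* dst, const void* src, size_t n),
    void (*sync)()) {
  constexpr size_t kBufSize = 1024 * 1024 * 64;  // 64MB staging buffer
  std::unique_ptr<char[]> buf(new char[kBufSize]);
  uintptr_t data = reinterpret_cast<uintptr_t>(device_ptr);
  while (size != 0) {
    size_t size_to_write = std::min(kBufSize, size);
    copy_d2h(buf.get(), reinterpret_cast<const void*>(data), size_to_write);
    sync();  // the possibly-async copy must finish before the host reads buf
    os.write(buf.get(), size_to_write);
    data += size_to_write;
    size -= size_to_write;
  }
}
```

The per-chunk synchronization matters because a copy enqueued on a device stream returns before the data has landed; without it, `os.write` could read stale or partially-written staging memory.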
diff --git a/python/paddle/fluid/tests/unittests/npu/test_save_load_npu.py b/python/paddle/fluid/tests/unittests/npu/test_save_load_npu.py
new file mode 100644
index 0000000000000..e7e7fb39c913b
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/npu/test_save_load_npu.py
@@ -0,0 +1,108 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import sys
+sys.path.append("..")
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+from paddle.nn import Embedding
+import paddle.fluid.framework as framework
+from paddle.fluid.optimizer import Adam
+from paddle.fluid.dygraph.base import to_variable
+from test_imperative_base import new_program_scope
+from paddle.fluid.executor import global_scope
+import numpy as np
+import six
+import pickle
+import os
+import errno
+from test_static_save_load import *
+
+paddle.enable_static()
+
+
+@unittest.skipIf(not paddle.is_compiled_with_npu(),
+                 "core is not compiled with NPU")
+class TestNPUSaveLoadBase(TestSaveLoadBase):
+    def set_place(self):
+        return fluid.CPUPlace() if not core.is_compiled_with_npu(
+        ) else paddle.NPUPlace(0)
+
+
+@unittest.skipIf(not paddle.is_compiled_with_npu(),
+                 "core is not compiled with NPU")
+class TestNPUSaveLoadPartial(TestSaveLoadPartial):
+    def set_place(self):
+        return fluid.CPUPlace() if not core.is_compiled_with_npu(
+        ) else paddle.NPUPlace(0)
+
+
+@unittest.skipIf(not paddle.is_compiled_with_npu(),
+                 "core is not compiled with NPU")
+class TestNPUSaveLoadSetStateDict(TestSaveLoadSetStateDict):
+    def set_place(self):
+        return fluid.CPUPlace() if not core.is_compiled_with_npu(
+        ) else paddle.NPUPlace(0)
+
+
+@unittest.skipIf(not paddle.is_compiled_with_npu(),
+                 "core is not compiled with NPU")
+class TestNPUProgramStatePartial(TestProgramStatePartial):
+    def set_place(self):
+        return fluid.CPUPlace() if not core.is_compiled_with_npu(
+        ) else paddle.NPUPlace(0)
+
+
+@unittest.skipIf(not paddle.is_compiled_with_npu(),
+                 "core is not compiled with NPU")
+class TestNPULoadFromOldInterface(TestLoadFromOldInterface):
+    def set_place(self):
+        return fluid.CPUPlace() if not core.is_compiled_with_npu(
+        ) else paddle.NPUPlace(0)
+
+
+@unittest.skipIf(not paddle.is_compiled_with_npu(),
+                 "core is not compiled with NPU")
+class TestNPULoadFromOldInterfaceSingleFile(TestLoadFromOldInterfaceSingleFile):
+    def set_place(self):
+        return fluid.CPUPlace() if not core.is_compiled_with_npu(
+        ) else paddle.NPUPlace(0)
+
+
+@unittest.skipIf(not paddle.is_compiled_with_npu(),
+                 "core is not compiled with NPU")
+class TestNPUProgramStateOldSave(TestProgramStateOldSave):
+    def setUp(self):
+        self.test_dygraph = False
+
+    def set_place(self):
+        return fluid.CPUPlace() if not core.is_compiled_with_npu(
+        ) else paddle.NPUPlace(0)
+
+
+@unittest.skipIf(not paddle.is_compiled_with_npu(),
+                 "core is not compiled with NPU")
+class TestNPUProgramStateOldSaveSingleModel(TestProgramStateOldSaveSingleModel):
+    def set_place(self):
+        return fluid.CPUPlace() if not core.is_compiled_with_npu(
+        ) else paddle.NPUPlace(0)
+
+
+if __name__ == '__main__':
+    paddle.enable_static()
+    unittest.main()
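The mirror-image hazard explains the `dev_ctx.Wait()` added after `framework::TensorCopy` in `TensorFromStream` above: on NPU the host-to-device transfer is enqueued asynchronously while the staging `cpu_tensor` is function-local, so returning before the copy drains would free memory the device stream is still reading. A hedged sketch of the same shape, with assumed enqueue/wait primitives rather than Paddle's API:

```cpp
#include <istream>
#include <vector>

// Deserialize into device memory: stage the bytes on the host, enqueue an
// async host-to-device copy, then block until it completes, because the
// staging vector is destroyed when this function returns.
void DeserializeToDevice(
    std::istream& is, void* device_dst, size_t size,
    void (*copy_h2d_async)(void* dst, const void* src, size_t n),
    void (*wait)()) {
  std::vector<char> staging(size);  // host staging buffer
  is.read(staging.data(), static_cast<std::streamsize>(size));
  copy_h2d_async(device_dst, staging.data(), size);  // enqueued, not done
  wait();  // without this, the stream may still be reading `staging`
}
```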
diff --git a/python/paddle/fluid/tests/unittests/test_static_save_load.py b/python/paddle/fluid/tests/unittests/test_static_save_load.py
index 0f4fca6d7f848..200e6fd35fdd3 100644
--- a/python/paddle/fluid/tests/unittests/test_static_save_load.py
+++ b/python/paddle/fluid/tests/unittests/test_static_save_load.py
@@ -18,7 +18,7 @@
 import paddle
 import paddle.fluid as fluid
 import paddle.fluid.core as core
-from paddle.fluid.dygraph.nn import Embedding
+from paddle.nn import Embedding
 import paddle.fluid.framework as framework
 from paddle.fluid.optimizer import Adam
 from paddle.fluid.dygraph.base import to_variable
@@ -30,6 +30,8 @@
 import os
 import errno
 
+paddle.enable_static()
+
 
 class SimpleLSTMRNN(fluid.Layer):
     def __init__(self,
@@ -158,11 +160,10 @@ def __init__(self,
             num_layers=num_layers,
             init_scale=init_scale,
             dropout=dropout)
-        self.embedding = Embedding(
-            size=[vocab_size, hidden_size],
-            dtype='float32',
-            is_sparse=False,
-            param_attr=fluid.ParamAttr(
+        self.embedding = paddle.nn.Embedding(
+            num_embeddings=vocab_size,
+            embedding_dim=hidden_size,
+            weight_attr=fluid.ParamAttr(
                 name='embedding_para',
                 initializer=fluid.initializer.UniformInitializer(
                     low=-init_scale, high=init_scale)))
@@ -186,6 +187,8 @@ def forward(self, input, label, init_hidden, init_cell):
         init_c = fluid.layers.reshape(
             init_cell, shape=[self.num_layers, -1, self.hidden_size])
 
+        # The NPU 'top_k' kernel only supports `int32`, so cast `input` from `int64` to `int32`.
+        input = fluid.layers.cast(input, "int32")
         x_emb = self.embedding(input)
         x_emb = fluid.layers.reshape(
             x_emb, shape=[-1, self.num_steps, self.hidden_size])
@@ -213,6 +216,10 @@ def forward(self, input, label, init_hidden, init_cell):
 
 
 class TestSaveLoadBase(unittest.TestCase):
+    def set_place(self):
+        return fluid.CPUPlace() if not core.is_compiled_with_cuda(
+        ) else fluid.CUDAPlace(0)
+
     def test_ptb_rnn_cpu_float32(self):
         seed = 90
         hidden_size = 10
@@ -234,8 +241,7 @@ def test_ptb_rnn_cpu_float32(self):
             num_steps=num_steps,
             init_scale=init_scale)
 
-        place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
-        ) else fluid.CUDAPlace(0)
+        place = self.set_place()
         exe = fluid.Executor(place)
         sgd = Adam(learning_rate=1e-3)
         x = fluid.layers.data(
@@ -314,6 +320,10 @@ def test_ptb_rnn_cpu_float32(self):
 
 
 class TestSaveLoadPartial(unittest.TestCase):
+    def set_place(self):
+        return fluid.CPUPlace() if not core.is_compiled_with_cuda(
+        ) else fluid.CUDAPlace(0)
+
     def test_ptb_rnn_cpu_float32(self):
         seed = 90
         hidden_size = 10
@@ -335,8 +345,7 @@ def test_ptb_rnn_cpu_float32(self):
             num_steps=num_steps,
             init_scale=init_scale)
 
-        place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
-        ) else fluid.CUDAPlace(0)
+        place = self.set_place()
         exe = fluid.Executor(place)
         sgd = Adam(learning_rate=1e-3)
         x = fluid.layers.data(
@@ -424,6 +433,10 @@ def test_ptb_rnn_cpu_float32(self):
 
 
 class TestSaveLoadSetStateDict(unittest.TestCase):
+    def set_place(self):
+        return fluid.CPUPlace() if not core.is_compiled_with_cuda(
+        ) else fluid.CUDAPlace(0)
+
     def test_ptb_rnn_cpu_float32(self):
         seed = 90
         hidden_size = 10
@@ -445,8 +458,7 @@ def test_ptb_rnn_cpu_float32(self):
             num_steps=num_steps,
             init_scale=init_scale)
 
-        place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
-        ) else fluid.CUDAPlace(0)
+        place = self.set_place()
         exe = fluid.Executor(place)
         sgd = Adam(learning_rate=1e-3)
         x = fluid.layers.data(
@@ -525,6 +537,10 @@ def test_ptb_rnn_cpu_float32(self):
 
 
 class TestProgramStatePartial(unittest.TestCase):
+    def set_place(self):
+        return fluid.CPUPlace() if not core.is_compiled_with_cuda(
+        ) else fluid.CUDAPlace(0)
+
     def test_ptb_rnn_cpu_float32(self):
         seed = 90
         hidden_size = 10
@@ -546,8 +562,7 @@ def test_ptb_rnn_cpu_float32(self):
             num_steps=num_steps,
             init_scale=init_scale)
 
-        place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
-        ) else fluid.CUDAPlace(0)
+        place = self.set_place()
         exe = fluid.Executor(place)
         sgd = Adam(learning_rate=1e-3)
         x = fluid.layers.data(
@@ -707,14 +722,17 @@ def test_ptb_rnn_cpu_float32(self):
 
 
 class TestVariableInit(unittest.TestCase):
+    def set_place(self):
+        return fluid.CPUPlace() if not core.is_compiled_with_cuda(
+        ) else fluid.CUDAPlace(0)
+
     def test_variable_init(self):
 
         x = fluid.data(name="x", shape=[10, 10], dtype='float32')
         y = fluid.layers.fc(x, 10)
         z = fluid.layers.fc(y, 10)
 
-        place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
-        ) else fluid.CUDAPlace(0)
+        place = self.set_place()
         exe = fluid.Executor(place)
         exe.run(fluid.default_startup_program())
@@ -737,8 +755,7 @@ def set_var(var, ndarray):
 
         program = fluid.default_main_program()
         new_scope = fluid.core.Scope()
-        place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
-        ) else fluid.CUDAPlace(0)
+        place = self.set_place()
         exe = fluid.Executor(place)
         parameter_list = list(
             filter(fluid.io.is_parameter, program.list_vars()))
@@ -797,6 +814,10 @@ def setUp(self):
         if os.path.exists("test_static_load_var_list.pdparams"):
             os.remove("test_static_load_var_list.pdparams")
 
+    def set_place(self):
+        return fluid.CPUPlace() if not core.is_compiled_with_cuda(
+        ) else fluid.CUDAPlace(0)
+
     def test_load_from_old_interface(self):
         seed = 90
         hidden_size = 10
@@ -818,8 +839,7 @@ def test_load_from_old_interface(self):
             num_steps=num_steps,
             init_scale=init_scale)
 
-        place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
-        ) else fluid.CUDAPlace(0)
+        place = self.set_place()
         exe = fluid.Executor(place)
         sgd = Adam(learning_rate=1e-3)
         x = fluid.layers.data(
@@ -934,8 +954,7 @@ def test_load_from_old_interface_var_list(self):
             num_steps=num_steps,
             init_scale=init_scale)
 
-        place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
-        ) else fluid.CUDAPlace(0)
+        place = self.set_place()
         exe = fluid.Executor(place)
         sgd = Adam(learning_rate=1e-3)
         x = fluid.layers.data(
@@ -1026,6 +1045,10 @@ def test_load_from_old_interface_var_list(self):
 
 
 class TestLoadFromOldInterfaceSingleFile(unittest.TestCase):
+    def set_place(self):
+        return fluid.CPUPlace() if not core.is_compiled_with_cuda(
+        ) else fluid.CUDAPlace(0)
+
     def test_load_from_old_interface(self):
         seed = 90
         hidden_size = 10
@@ -1047,8 +1070,7 @@ def test_load_from_old_interface(self):
             num_steps=num_steps,
             init_scale=init_scale)
 
-        place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
-        ) else fluid.CUDAPlace(0)
+        place = self.set_place()
         exe = fluid.Executor(place)
         sgd = Adam(learning_rate=1e-3)
         x = fluid.layers.data(
@@ -1170,6 +1192,13 @@
 
 
 class TestProgramStateOldSave(unittest.TestCase):
+    def setUp(self):
+        self.test_dygraph = True
+
+    def set_place(self):
+        return fluid.CPUPlace() if not core.is_compiled_with_cuda(
+        ) else fluid.CUDAPlace(0)
+
     def test_ptb_rnn_cpu_float32(self):
         seed = 90
         hidden_size = 10
@@ -1191,8 +1220,7 @@ def test_ptb_rnn_cpu_float32(self):
             num_steps=num_steps,
             init_scale=init_scale)
 
-        place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
-        ) else fluid.CUDAPlace(0)
+        place = self.set_place()
         exe = fluid.Executor(place)
         sgd = Adam(learning_rate=1e-3)
         x = fluid.layers.data(
@@ -1298,11 +1326,12 @@ def symlink_force(target, link_name):
         fluid.set_program_state(main_program, program_state)
         self.check_in_static(main_program, base_map)
 
-        # make sure `load_program_state` can be used in dynamic graph mode
-        with fluid.dygraph.guard(place):
-            load_state = fluid.load_program_state("test_program_1")
-            for k, v in load_state.items():
-                self.assertTrue(np.array_equal(base_map[k], v))
+        if self.test_dygraph:
+            # make sure `load_program_state` can be used in dynamic graph mode
+            with fluid.dygraph.guard(place):
+                load_state = fluid.load_program_state("test_program_1")
+                for k, v in load_state.items():
+                    self.assertTrue(np.array_equal(base_map[k], v))
 
     def check_in_static(self, main_program, base_map):
         for var in main_program.list_vars():
@@ -1313,40 +1342,11 @@ def check_in_static(self, main_program, base_map):
             self.assertTrue(np.array_equal(new_t, base_t))
 
 
-class TestStaticSaveLoadLargeParameters(unittest.TestCase):
-    def test_large_parameters_static_save(self):
-        # enable static mode
-        paddle.enable_static()
-        LARGE_PARAM = 2**26
-        with new_program_scope():
-            # create network
-            x = paddle.static.data(
-                name="static_save_load_large_x",
-                shape=[None, 10],
-                dtype='float32')
-            z = paddle.static.nn.fc(x, LARGE_PARAM)
-            place = paddle.CPUPlace()
-            exe = paddle.static.Executor(place)
-            exe.run(paddle.static.default_startup_program())
-            prog = paddle.static.default_main_program()
-
-            inputs = np.random.randn(1, 10).astype("float32")
-            result_z = exe.run(program=prog,
-                               feed={"static_save_load_large_x": inputs},
-                               fetch_list=[z.name])
-            path = "test_static_save_load_large_param/static_save"
-            paddle.fluid.save(prog, path)
-
-            paddle.fluid.load(prog, path)
-            result_load = exe.run(program=prog,
-                                  feed={"static_save_load_large_x": inputs},
-                                  fetch_list=[z.name])
-            # compare results before and after saving
-            self.assertTrue(
-                np.sum(np.abs(result_z[0] - result_load[0])) < 1e-15)
-
-
 class TestProgramStateOldSaveSingleModel(unittest.TestCase):
+    def set_place(self):
+        return fluid.CPUPlace() if not core.is_compiled_with_cuda(
+        ) else fluid.CUDAPlace(0)
+
     def test_ptb_rnn_cpu_float32(self):
         seed = 90
         hidden_size = 10
@@ -1368,8 +1368,7 @@ def test_ptb_rnn_cpu_float32(self):
             num_steps=num_steps,
             init_scale=init_scale)
 
-        place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
-        ) else fluid.CUDAPlace(0)
+        place = self.set_place()
        exe = fluid.Executor(place)
         sgd = Adam(learning_rate=1e-3)
         x = fluid.layers.data(
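Two idioms recur throughout this diff. In the C++ changes, each device family gets a runtime place check guarded by its build-time macro, and builds lacking that backend throw instead of silently falling through to the CPU path. A schematic sketch of that dispatch shape (names are illustrative, not Paddle's):

```cpp
#include <stdexcept>

enum class Place { kCPU, kGPU, kXPU, kNPU };

// Place dispatch guarded by build-time feature macros: a place that the
// binary was not compiled for fails loudly rather than being ignored.
void DispatchByPlace(Place place) {
  if (place == Place::kGPU) {
#ifdef WITH_CUDA
    // ... CUDA-specific path ...
#else
    throw std::runtime_error("GPUPlace requires a CUDA build");
#endif
  } else if (place == Place::kNPU) {
#ifdef WITH_ASCEND_CL
    // ... Ascend-specific path ...
#else
    throw std::runtime_error("NPUPlace requires an Ascend build");
#endif
  } else {
    // ... CPU fallback path ...
  }
}
```

On the Python side, the analogous extension point is the `set_place()` hook added to each test class: the NPU suite in test_save_load_npu.py reuses every test body unchanged and overrides only the place selection, plus `test_dygraph = False` where the dygraph check is not exercised.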