Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev cc clean tensor name scope #7082

Merged
merged 3 commits into from
Dec 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions oneflow/core/framework/nn_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/framework/multi_client_session_context.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/framework/tensor_name_scope.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/job/compiler.h"
Expand Down Expand Up @@ -254,6 +255,9 @@ Maybe<void> NNGraph::CompileAndInitRuntime() {
// TODO(chengcheng): CHECK job valid for each rank.
JUST(CreateAndRegisterNewVariableOpInJobPass());

// NOTE(chengcheng): TensorNameScope needs to be cleared after the current graph is built.
one::TensorNameScope::Global()->Clear();

// NOTE(chengcheng): Global<JobDesc> needs to be cleared before the GlobalJobDescScope is constructed.
if (Global<JobDesc>::Get() != nullptr) { Global<JobDesc>::Delete(); }

Expand Down
5 changes: 5 additions & 0 deletions oneflow/core/framework/tensor_name_scope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,10 @@ void TensorNameScope::Record(const std::shared_ptr<Tensor>& tensor, const std::s
tensor_names_[key] = name;
}

// Drops every tensor-name mapping recorded so far. The map is cleared under
// the scope's mutex so a concurrent Record() never observes a partial state.
void TensorNameScope::Clear() {
  std::lock_guard<std::mutex> guard(mutex_);
  tensor_names_.clear();
}

} // namespace one
} // namespace oneflow
2 changes: 2 additions & 0 deletions oneflow/core/framework/tensor_name_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class TensorNameScope {

void Record(const std::shared_ptr<Tensor>& tensor, const std::string& name);

void Clear();

private:
TensorNameScope() : default_tensor_name_("") {}
virtual ~TensorNameScope() = default;
Expand Down
55 changes: 43 additions & 12 deletions python/oneflow/test/graph/test_graph_free_eager_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,21 @@
import oneflow.unittest


class MyModuleWithEagerTensorForward(flow.nn.Module):
    """Module whose forward pass creates a free eager tensor.

    Used to verify that a graph built from this module correctly captures
    an eager tensor constructed inside ``forward`` (not a registered
    parameter or buffer).
    """

    def __init__(self):
        super().__init__()
        self.linear = flow.nn.Linear(3, 8, False)

    def forward(self, x):
        linear_out = self.linear(x)
        # Free eager tensor created on the fly; graph tracing must capture it.
        offset = flow.tensor([1.0], dtype=linear_out.dtype, device=linear_out.device)
        return linear_out + offset


@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
@flow.unittest.skip_unless_1n1d()
class TestGraphWithEagerTensorCaught(oneflow.unittest.TestCase):
def test_eager_tensor_forward_graph(test_case):
class MyModuleWithEagerTensorForward(flow.nn.Module):
def __init__(self):
super().__init__()
self.linear = flow.nn.Linear(3, 8, False)

def forward(self, x):
y0 = self.linear(x)
eager_t = flow.tensor([1.0], dtype=y0.dtype, device=y0.device)
out = y0 + eager_t
return out

my_net_module = MyModuleWithEagerTensorForward()
flow.nn.init.constant_(my_net_module.linear.weight, 2.3)
x = np.random.randn(5, 3)
Expand Down Expand Up @@ -84,6 +83,38 @@ def build(self):
np.allclose(graph_out.numpy(), eager_out.numpy(), atol=1e-4, rtol=1e-4)
)

def test_two_graph_caught_same_free_eager_tensor(test_case):
    """Two independent graphs may capture the same free eager tensors.

    Regression test: the tensor name scope must be cleared after each
    graph is built, otherwise the second graph would resolve stale names
    recorded while building the first one.
    """
    x_np = np.random.randn(5, 3)
    y_np = np.random.randn(5, 3)
    x = flow.tensor(x_np, dtype=flow.float32)
    y = flow.tensor(y_np, dtype=flow.float32)

    class GraphAdd(flow.nn.Graph):
        def __init__(self):
            super().__init__()

        def build(self):
            # Captures the free eager tensors x and y from the enclosing scope.
            return x + y

    class GraphMul(flow.nn.Graph):
        def __init__(self):
            super().__init__()

        def build(self):
            # Captures the same free eager tensors as GraphAdd.
            return x * y

    add_result = GraphAdd()()
    mul_result = GraphMul()()
    test_case.assertTrue(
        np.allclose(add_result.numpy(), x_np + y_np, atol=1e-4, rtol=1e-4)
    )
    test_case.assertTrue(
        np.allclose(mul_result.numpy(), x_np * y_np, atol=1e-4, rtol=1e-4)
    )


@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
@flow.unittest.skip_unless_1n2d()
Expand Down