[Relay] GradientCell Relay Pass #5039

Merged · 30 commits · Mar 24, 2020

Commits:
3a929be  save  (MarisaKirisame, Feb 14, 2020)
f68955e  gradient.rly  (hypercubestart, Feb 28, 2020)
dfb00ee  fix  (MarisaKirisame, Feb 28, 2020)
75ca326  NOT WORKING: gradient cell pass  (hypercubestart, Mar 3, 2020)
7113720  test gradient pass  (hypercubestart, Mar 3, 2020)
16d63d7  fixed basic call ops  (hypercubestart, Mar 6, 2020)
a083450  more tests  (hypercubestart, Mar 11, 2020)
67a6b01  fix bug  (MarisaKirisame, Mar 11, 2020)
949aa2b  transform calls to one ones_like zero zero_like  (hypercubestart, Mar 11, 2020)
ca15729  maintenance stuff  (hypercubestart, Mar 11, 2020)
182fbdc  fix linting  (hypercubestart, Mar 11, 2020)
109b288  linting  (hypercubestart, Mar 11, 2020)
9115710  linting  (hypercubestart, Mar 11, 2020)
8c7f4e8  throw default  (hypercubestart, Mar 11, 2020)
bbd0a45  remove unrelated changes  (hypercubestart, Mar 12, 2020)
921a03c  import gradent.rly in pass  (hypercubestart, Mar 12, 2020)
02563fb  comment  (hypercubestart, Mar 12, 2020)
e81d0bd  linting  (hypercubestart, Mar 12, 2020)
2a7968c  remove changes to test files  (hypercubestart, Mar 12, 2020)
4f504c1  move gradient_cell.cc to transforms  (hypercubestart, Mar 12, 2020)
0f14861  revert change  (hypercubestart, Mar 21, 2020)
fda4fdf  update files with new commits  (hypercubestart, Mar 21, 2020)
88e6744  type  (hypercubestart, Mar 21, 2020)
c614857  wrapper function to main outermost function type  (hypercubestart, Mar 22, 2020)
0955681  fix linting  (hypercubestart, Mar 22, 2020)
40e629c  fix unsigned and signed int comparison  (hypercubestart, Mar 22, 2020)
4d4b350  review  (hypercubestart, Mar 23, 2020)
953791b  GetConstructor definition in module and change op comparison  (hypercubestart, Mar 23, 2020)
3e18cde  update node instantiations  (hypercubestart, Mar 23, 2020)
4c15056  increase code readability  (hypercubestart, Mar 24, 2020)
8 changes: 8 additions & 0 deletions include/tvm/ir/module.h
@@ -162,6 +162,14 @@ class IRModuleNode : public Object {
*/
TVM_DLL Array<GlobalTypeVar> GetGlobalTypeVars() const;

/*!
 * \brief Find the constructor of an ADT by name.
 * \param adt The name of the ADT the constructor belongs to.
 * \param cons The name of the constructor.
 * \returns The matching Constructor; raises a fatal error if none is found.
*/
TVM_DLL Constructor GetConstructor(const std::string& adt, const std::string& cons) const;

/*!
* \brief Look up a global function by its variable.
* \param var The global var to lookup.
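For reference, a rough Python-side equivalent of this lookup (a sketch only: `GetConstructor` is added here as a C++ method, and the Python attribute names below are assumptions based on TVM's IR bindings):

```python
from tvm import IRModule


def get_constructor(mod: IRModule, adt: str, cons: str):
    """Sketch of IRModuleNode::GetConstructor: find constructor `cons` of ADT `adt`."""
    # Look up the TypeData registered under the ADT's global type variable.
    type_data = mod[mod.get_global_type_var(adt)]
    for ctor in type_data.constructors:
        if ctor.name_hint == cons:
            return ctor
    raise ValueError("%s does not contain constructor %s" % (adt, cons))
```

On the C++ side the pass can now resolve GradCell's constructors by name, e.g. `mod->GetConstructor("GradCell", "Raw")`.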
14 changes: 14 additions & 0 deletions include/tvm/relay/transform.h
@@ -77,6 +77,20 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
*/
TVM_DLL Pass DeadCodeElimination(bool inline_once = false);

/*!
* \brief Convert all expressions of TensorType into GradCell,
* an algebraic data type defined in gradient.rly.
*
 * This delays allocation and can decrease memory usage: calls to
 * ones, ones_like, zeros, and zeros_like no longer instantiate a tensor
 * immediately; the tensor is materialized only when actually needed.
 * The pass also defines + and * operations between GradCell values,
 * which can improve performance on zero-filled or one-filled tensors,
 * a common case in reverse-mode AD.
 *
 * \return The created pass.
*/
TVM_DLL Pass LazyGradientInit();

/*!
* \brief Fold constant expressions.
*
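A minimal usage sketch from the Python side (the function body and shapes are illustrative, and running InferType first is an assumption, since the pass operates on typed IR):

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(2, 2), dtype="float32")
func = relay.Function([x], relay.zeros_like(x))
mod = tvm.IRModule.from_expr(func)
mod = relay.transform.InferType()(mod)
mod = relay.transform.LazyGradientInit()(mod)
# After the pass, zeros_like yields a GradCell Zero holding a thunk,
# so the (2, 2) zero tensor is materialized only if it is forced.
print(mod["main"])
```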
55 changes: 55 additions & 0 deletions python/tvm/relay/std/gradient.rly
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
v0.0.4

/*
 * Stores the gradient value of a tensor of type T.
 * Note that the gradient of T is stored inside a Ref(GradCell[T]) rather than a GradCell[T].
 */
type GradCell[T] {
  Raw(T),
  One(fn() -> T),
  Zero(fn() -> T)
}

def @FromGradCell[T](%g: GradCell[T]) -> T {
  match (%g) {
    Raw(%x) => %x,
    One(%x) => %x(),
    Zero(%x) => %x()
  }
}

def @MultiplyGradCell[T](%multiply: fn(T, T) -> T, %l: GradCell[T], %r: GradCell[T]) -> GradCell[T] {
  match ((%l, %r)) {
    (Zero(_), _) => %l,
    (_, Zero(_)) => %r,
    (One(_), _) => %r,
    (_, One(_)) => %l,
    _ => Raw(%multiply(@FromGradCell(%l), @FromGradCell(%r)))
  }
}

def @AddGradCell[T](%add: fn(T, T) -> T, %l: GradCell[T], %r: GradCell[T]) -> GradCell[T] {
  match ((%l, %r)) {
    (Zero(_), _) => %r,
    (_, Zero(_)) => %l,
    _ => Raw(%add(@FromGradCell(%l), @FromGradCell(%r)))
  }
}
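The laziness and the short-circuiting in `@AddGradCell` may be easier to see in a host-language analogue. Below is a pure-Python sketch of the same semantics (illustrative only, not TVM API; the `One` case is omitted for brevity):

```python
import numpy as np


class Raw:
    """A gradient value that is already materialized."""
    def __init__(self, value):
        self.value = value

    def get(self):  # @FromGradCell: Raw(%x) => %x
        return self.value


class Zero:
    """A zero gradient stored as a thunk; nothing is allocated until forced."""
    def __init__(self, thunk):
        self.thunk = thunk

    def get(self):  # @FromGradCell: Zero(%x) => %x()
        return self.thunk()


def zeros_like(t):
    # No allocation happens here: the zero tensor exists only as a thunk.
    return Zero(lambda: np.zeros_like(t))


def add_grad_cell(add, l, r):
    # Mirrors @AddGradCell: x + 0 = x in either direction, so a Zero
    # operand is dropped without ever being materialized.
    if isinstance(l, Zero):
        return r
    if isinstance(r, Zero):
        return l
    return Raw(add(l.get(), r.get()))


t = np.ones((2, 2), dtype="float32")
g = add_grad_cell(np.add, Raw(t), zeros_like(t))
assert g.get() is t  # the zero tensor was never allocated
```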
13 changes: 13 additions & 0 deletions python/tvm/relay/transform/transform.py
@@ -219,6 +219,19 @@ def DeadCodeElimination(inline_once=False):
"""
return _ffi_api.DeadCodeElimination(inline_once)

def LazyGradientInit():
    """Reduce memory usage of gradient tensors.

    Returns
    -------
    ret : tvm.relay.Pass
        A pass which delays and/or reduces memory allocation
        by lazily allocating zero- or one-filled tensors.
    """
    return _ffi_api.LazyGradientInit()

def FoldConstant():
    """Fold the constant expressions in a Relay program.
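Beyond a single invocation (see the sketch after the transform.h change above), the returned pass object composes with others; a sketch using Sequential (the surrounding passes and the module contents are assumptions):

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(3,), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.ones_like(x)))

seq = tvm.transform.Sequential([
    relay.transform.InferType(),
    relay.transform.LazyGradientInit(),
    relay.transform.InferType(),  # re-infer types over the rewritten module
])
mod = seq(mod)
```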
12 changes: 12 additions & 0 deletions src/ir/module.cc
@@ -96,6 +96,18 @@ GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const std::string& name) const {
return (*it).second;
}

Constructor IRModuleNode::GetConstructor(const std::string& adt, const std::string& cons) const {
  TypeData typeDef = this->LookupTypeDef(adt);
  for (Constructor c : typeDef->constructors) {
    if (cons.compare(c->name_hint) == 0) {
      return c;
    }
  }

  LOG(FATAL) << adt << " does not contain constructor " << cons;
  // Unreachable: LOG(FATAL) does not return; the throw keeps the compiler
  // satisfied about the missing return path.
  throw std::runtime_error("Constructor Not Found.");
}

tvm::Array<GlobalTypeVar> IRModuleNode::GetGlobalTypeVars() const {
std::vector<GlobalTypeVar> global_type_vars;
for (const auto& pair : global_type_var_map_) {