Skip to content

Commit

Permalink
Addressing comments
Browse files Browse the repository at this point in the history
Change-Id: If9f1ee190690f9a810fe41eb1933d736f1eb4ec3
  • Loading branch information
Giuseppe Rossini committed Jun 7, 2021
1 parent 6075b68 commit f816b93
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 12 deletions.
19 changes: 9 additions & 10 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,17 +137,16 @@ class AOTOnDemandAllocator : public ExprVisitor {

void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not supported."; }

void VisitExpr_(const LetNode* op) final { LOG(FATAL) << "if is not supported."; }
void VisitExpr_(const LetNode* op) final { LOG(FATAL) << "let is not supported."; }

private:
void AssignReturnSid(Expr e) {
if (storage_device_map_.find(e) != storage_device_map_.end()) {
auto buffers = storage_device_map_[e];
std::vector<int> return_ids;
return_ids_.clear();
for (auto buffer : buffers) {
return_ids.push_back(buffer.sid);
return_ids_.push_back(buffer.sid);
}
return_ids_ = return_ids;
}
}
/*!
Expand All @@ -163,7 +162,7 @@ class AOTOnDemandAllocator : public ExprVisitor {
* \param prototype The prototype token.
* \return The required memory size.
*/
size_t GetMemorySize(const TensorTypeNode* ttype) {
size_t GetMemorySizeBytes(const TensorTypeNode* ttype) {
ICHECK(ttype != nullptr);
size_t size = 1;
for (IndexExpr dim : ttype->shape) {
Expand Down Expand Up @@ -200,17 +199,17 @@ class AOTOnDemandAllocator : public ExprVisitor {
const auto* ttype = t.as<TensorTypeNode>();
ICHECK(ttype);
StorageInfo buffer;
buffer.sid = sid_++;
buffer.size_bytes = GetMemorySize(ttype);
buffer.sid = next_available_sid_++;
buffer.size_bytes = GetMemorySizeBytes(ttype);
buffer.dev_type = device_type;
buffers.push_back(buffer);
}
} else {
const auto* ttype = op->checked_type().as<TensorTypeNode>();
ICHECK(ttype);
StorageInfo buffer;
buffer.sid = sid_++;
buffer.size_bytes = GetMemorySize(ttype);
buffer.sid = next_available_sid_++;
buffer.size_bytes = GetMemorySizeBytes(ttype);
buffer.dev_type = device_type;
buffers.push_back(buffer);
}
Expand All @@ -221,7 +220,7 @@ class AOTOnDemandAllocator : public ExprVisitor {
/*! \brief mapping of expression -> device type*/
Map<Expr, Integer> node_device_map_;
/*! \brief current id of the temporary allocated*/
int sid_{0};
int next_available_sid_{0};
/*! \brief the set of intermediate tensors that are return variables */
std::vector<int> return_ids_;
};
Expand Down
10 changes: 8 additions & 2 deletions src/tir/transforms/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
// Recall that the arguments of a tvm_call_cpacked are passed as
// TVMValues. But a TVMValue is only a container, that points to
// a real buffer previously allocated. We need to signal that those
// buffers need to be live at the same time (i.e., cannot be overridden)
// buffers need to be live at the same time (i.e., cannot be overwritten during the function
// call)
Array<PrimExpr> args = op->args;
for (auto arg : args) {
const VarNode* var = arg.as<VarNode>();
Expand Down Expand Up @@ -234,7 +235,12 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
bool in_thread_env_{false};
// The scope stack.
std::vector<StmtEntry> scope_;
// This is a map to connect TVMValues to real allocations
// This is a map to connect TVMValues to real allocations. When we pass parameters
// to a tvm_call_cpacked, the data needs to be wrapped in a TVMValue. The wrapping
// happens through the tvm_struct_set built-in. This map is mapping the variable
// representing the TVMValue to the variable representing the real buffer. The live
// analysis needs to happen on the latter and not on the TVMValue which only acts as
// a container.
std::unordered_map<const VarNode*, std::vector<const VarNode*>> value_to_alloc_;
};

Expand Down
5 changes: 5 additions & 0 deletions tests/python/relay/aot/test_crt_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,9 @@ def test_byoc_utvm(use_calculated_workspaces, target_options):


def test_quant_mobilenet_tfl():
"""Since in AOT we pass the output buffer directly from the user, sharing the output buffers is not possible in quantized networks.
This is because the output data type is int8 while the intermediate buffers are int32 or int16. We use quantized MobileNet to stress this
situation and verify that output buffer sharing is disabled in AOT."""
pytest.importorskip("tflite")

import tvm.relay.testing.tf as tf_testing
Expand All @@ -410,6 +413,8 @@ def test_quant_mobilenet_tfl():

@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"])
def test_transpose(target_options):
"""Test that non-inplaceable operations (e.g., transpose) do not happen in-place."""

dtype = "float32"
x = relay.var("x", shape=(10, 5), dtype=dtype)
y = relay.var("y", shape=(10, 5), dtype=dtype)
Expand Down

0 comments on commit f816b93

Please sign in to comment.