Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Aug 28, 2023
1 parent 86233ae commit 8dd6f9c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ static std::optional<StringError> on_exit(scf::ForOp &op,
int64_t getElementTypeSize(mlir::Type elementType) {
if (auto integerType = mlir::dyn_cast<mlir::IntegerType>(elementType)) {
auto width = integerType.getWidth();
return std::ceil(width / 8);
return std::ceil((double)width / 8);
}
if (mlir::dyn_cast<mlir::IndexType>(elementType)) {
return 8;
Expand Down Expand Up @@ -199,7 +199,7 @@ static std::optional<StringError> on_enter(mlir::Operation *op,

// the search would be faster if we use an unsorted_set, but we need a hash
// function for mlir::Value
if (std::find(visited.begin(), visited.end(), lastVisitedBuffer) !=
if (std::find(visited.begin(), visited.end(), lastVisitedBuffer) ==
visited.end()) {
visited.push_back(lastVisitedBuffer);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,8 @@ def test_crt_decomposition_feedback():
return %result: tensor<4x2x!FHE.eint<6>> loc("some/random/location.py":10:2)
}
""",
{'loc("some/random/location.py":10:2)': 1187392},
# 4*4*4097*8 (input1) + 4*2 (input2) + 4*2*4097*8 + 4097*3*8 + 4096*8 + 869*8 (temporary buffers) + 4*2*4097*8 (output buffer) + 64*8 (constant TLU)
{'loc("some/random/location.py":10:2)': 1187400},
id="single location",
),
pytest.param(
Expand All @@ -466,8 +467,12 @@ def test_crt_decomposition_feedback():
}
""",
{
'loc("@matmul some/random/location.py":10:2)': 852176,
# 4*4*4097*8 (input1) + 4*2 (input2) + 4*2*4097*8 (matmul result buffer) + 4097*2*8 (temporary buffers)
'loc("@matmul some/random/location.py":10:2)': 852184,
# 4*2*4097*8 + 4*2 (input2) + *2*4097*8 + 4097*3*8 + 4096*8 + 869*8 (temporary buffers) + 4*2*4097*8 (output buffer) + 64*8 (constant TLU)
# 4*2*4097*8 (matmul result buffer) + 4*2*4097*8 (result buffer) + 4097*8 + 4096*8 + 869*8 (temporary buffers) + 64*8 (constant TLU)
'loc("@lut some/random/location.py":11:2)': 597424,
# 4*2*4097*8 (result buffer)
'loc("@return some/random/location.py":12:2)': 262208,
},
id="multiple location",
Expand Down

0 comments on commit 8dd6f9c

Please sign in to comment.