diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp index 5ceb123459..7c8de9694c 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/Concrete/Analysis/MemoryUsage.cpp @@ -107,7 +107,7 @@ static std::optional on_exit(scf::ForOp &op, int64_t getElementTypeSize(mlir::Type elementType) { if (auto integerType = mlir::dyn_cast(elementType)) { auto width = integerType.getWidth(); - return std::ceil(width / 8); + return std::ceil((double)width / 8); } if (mlir::dyn_cast(elementType)) { return 8; @@ -199,7 +199,7 @@ static std::optional on_enter(mlir::Operation *op, // the search would be faster if we use an unsorted_set, but we need a hash // function for mlir::Value - if (std::find(visited.begin(), visited.end(), lastVisitedBuffer) != + if (std::find(visited.begin(), visited.end(), lastVisitedBuffer) == visited.end()) { visited.push_back(lastVisitedBuffer); diff --git a/compilers/concrete-compiler/compiler/tests/python/test_compilation.py b/compilers/concrete-compiler/compiler/tests/python/test_compilation.py index 9533877359..079b129724 100644 --- a/compilers/concrete-compiler/compiler/tests/python/test_compilation.py +++ b/compilers/concrete-compiler/compiler/tests/python/test_compilation.py @@ -453,7 +453,8 @@ def test_crt_decomposition_feedback(): return %result: tensor<4x2x!FHE.eint<6>> loc("some/random/location.py":10:2) } """, - {'loc("some/random/location.py":10:2)': 1187392}, + # 4*4*4097*8 (input1) + 4*2 (input2) + 4*2*4097*8 + 4097*3*8 + 4096*8 + 869*8 (temporary buffers) + 4*2*4097*8 (output buffer) + 64*8 (constant TLU) + {'loc("some/random/location.py":10:2)': 1187400}, id="single location", ), pytest.param( @@ -466,8 +467,12 @@ def test_crt_decomposition_feedback(): } """, { - 'loc("@matmul some/random/location.py":10:2)': 852176, + # 4*4*4097*8 (input1) + 4*2 (input2) + 4*2*4097*8 (matmul result buffer) + 4097*2*8 (temporary buffers) + 'loc("@matmul some/random/location.py":10:2)': 852184, + # 4*2*4097*8 + 4*2 (input2) + *2*4097*8 + 4097*3*8 + 4096*8 + 869*8 (temporary buffers) + 4*2*4097*8 (output buffer) + 64*8 (constant TLU) + # 4*2*4097*8 (matmul result buffer) + 4*2*4097*8 (result buffer) + 4097*8 + 4096*8 + 869*8 (temporary buffers) + 64*8 (constant TLU) 'loc("@lut some/random/location.py":11:2)': 597424, + # 4*2*4097*8 (result buffer) 'loc("@return some/random/location.py":12:2)': 262208, }, id="multiple location",