Skip to content

Commit

Permalink
i1 fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 24, 2025
1 parent d399df6 commit 8b1bee1
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1320,13 +1320,16 @@ DenseElementsAttr fromTensor(stablehlo::Tensor inp) {
auto elemType = type.getElementType();

int64_t bitWidth = -1;
bool isI1 = false;
if (auto inType = dyn_cast<IntegerType>(elemType)) {
// For bitwidth = 1: Packed into 8bit.
bitWidth = inType.getWidth();
if (bitWidth == 1) {
bitWidth = 8;
isI1 = true;
SmallVector<bool, 1> data;
data.reserve(inp.getNumElements());
auto v = inp.getData();
for (size_t i = 0; i < inp.getNumElements(); ++i)
data.push_back(v[i] ? 1 : 0);
return DenseElementsAttr::get(type, data);
}
}

Expand All @@ -1336,8 +1339,6 @@ DenseElementsAttr fromTensor(stablehlo::Tensor inp) {
bitWidth = bitWidth / 8;

auto size = inp.getNumElements() * bitWidth;
if (isI1)
size = llvm::alignTo<8>(size);
auto floatValues = ArrayRef(inp.getData(), size);
return DenseElementsAttr::getFromRawBuffer(type, floatValues);
}
Expand Down

0 comments on commit 8b1bee1

Please sign in to comment.