Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't allocate trailing uninit bits in the InitMap of CTFE Allocations #94936

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 67 additions & 37 deletions compiler/rustc_middle/src/mir/interpret/allocation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ impl<Tag> Allocation<Tag> {
Self {
bytes,
relocations: Relocations::new(),
init_mask: InitMask::new(size, true),
init_mask: InitMask::new_init(size),
align,
mutability,
extra: (),
Expand Down Expand Up @@ -180,7 +180,7 @@ impl<Tag> Allocation<Tag> {
Ok(Allocation {
bytes,
relocations: Relocations::new(),
init_mask: InitMask::new(size, false),
init_mask: InitMask::new_uninit(size),
align,
mutability: Mutability::Mut,
extra: (),
Expand Down Expand Up @@ -629,15 +629,19 @@ impl InitMask {
Size::from_bytes(block * InitMask::BLOCK_SIZE + bit)
}

pub fn new(size: Size, state: bool) -> Self {
pub fn new_init(size: Size) -> Self {
let mut m = InitMask { blocks: vec![], len: Size::ZERO };
m.grow(size, state);
m.grow(size, true);
m
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could keep the old API by using if state { ... } else { ... } -- that seems cleaner (not bothering the users with a somewhat cumbersome API)?

}

pub fn new_uninit(size: Size) -> Self {
InitMask { blocks: vec![], len: size }
}

pub fn set_range(&mut self, start: Size, end: Size, new_state: bool) {
let len = self.len;
if end > len {
if end > len && new_state {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a bit surprised that we allow OOB indices at all here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I think set_range_inbounds is called in situations where the index is in-bounds of the logical size of the InitMap, but OOB of its actual size due to leaving off the tail. So I think the growing logic needs to move to set_range_inbounds.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, you have ensure_blocks for that. But then we have two growing logics? That seems odd.

self.grow(end - len, new_state);
}
self.set_range_inbounds(start, end, new_state);
Expand All @@ -655,14 +659,16 @@ impl InitMask {
(u64::MAX << bita) & (u64::MAX >> (64 - bitb))
};
if new_state {
self.ensure_blocks(blocka);
self.blocks[blocka] |= range;
} else {
self.blocks[blocka] &= !range;
} else if let Some(block) = self.blocks.get_mut(blocka) {
*block &= !range;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a comment saying why it is okay to do nothing here if get_mut returns None.

}
return;
}
// across block boundaries
if new_state {
self.ensure_blocks(blockb);
// Set `bita..64` to `1`.
self.blocks[blocka] |= u64::MAX << bita;
// Set `0..bitb` to `1`.
Expand All @@ -673,15 +679,17 @@ impl InitMask {
for block in (blocka + 1)..blockb {
self.blocks[block] = u64::MAX;
}
} else {
} else if let Some(blocka_val) = self.blocks.get_mut(blocka) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same, please add a comment saying why it is okay to do nothing here if get_mut returns None.

// Set `bita..64` to `0`.
self.blocks[blocka] &= !(u64::MAX << bita);
*blocka_val &= !(u64::MAX << bita);
// Set `0..bitb` to `0`.
if bitb != 0 {
self.blocks[blockb] &= !(u64::MAX >> (64 - bitb));
if let Some(blockb_val) = self.blocks.get_mut(blockb) {
*blockb_val &= !(u64::MAX >> (64 - bitb));
}
}
// Fill in all the other blocks (much faster than one bit at a time).
for block in (blocka + 1)..blockb {
for block in (blocka + 1)..std::cmp::min(blockb, self.blocks.len()) {
self.blocks[block] = 0;
}
}
Expand All @@ -690,7 +698,10 @@ impl InitMask {
#[inline]
pub fn get(&self, i: Size) -> bool {
let (block, bit) = Self::bit_index(i);
(self.blocks[block] & (1 << bit)) != 0
match self.blocks.get(block) {
Some(block) => (*block & (1 << bit)) != 0,
None => false,
}
}

#[inline]
Expand All @@ -702,10 +713,22 @@ impl InitMask {
#[inline]
fn set_bit(&mut self, block: usize, bit: usize, new_state: bool) {
if new_state {
self.ensure_blocks(block);
self.blocks[block] |= 1 << bit;
} else {
self.blocks[block] &= !(1 << bit);
} else if let Some(block) = self.blocks.get_mut(block) {
*block &= !(1 << bit);
}
}

fn ensure_blocks(&mut self, block: usize) {
if block < self.blocks.len() {
return;
}
let additional_blocks = block - self.blocks.len() + 1;
self.blocks.extend(
// FIXME(oli-obk): optimize this by repeating `new_state as Block`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite understand this FIXME. I see it just got moved around but still, it should be clarified or removed IMO.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, basically instead of filling with uninit and then setting all of them to initialized, we can immediately fill with init.

iter::repeat(0).take(usize::try_from(additional_blocks).unwrap()),
);
}

pub fn grow(&mut self, amount: Size, new_state: bool) {
Expand All @@ -716,10 +739,7 @@ impl InitMask {
u64::try_from(self.blocks.len()).unwrap() * Self::BLOCK_SIZE - self.len.bytes();
if amount.bytes() > unused_trailing_bits {
let additional_blocks = amount.bytes() / Self::BLOCK_SIZE + 1;
self.blocks.extend(
// FIXME(oli-obk): optimize this by repeating `new_state as Block`.
iter::repeat(0).take(usize::try_from(additional_blocks).unwrap()),
);
self.ensure_blocks(self.blocks.len() + additional_blocks as usize - 1);
}
let start = self.len;
self.len += amount;
Expand Down Expand Up @@ -821,25 +841,31 @@ impl InitMask {
// (c) 01000000|00000000|00000001
// ^~~~~~~~~~~~~~~~~~^
// start end
if let Some(i) =
search_block(init_mask.blocks[start_block], start_block, start_bit, is_init)
{
// If the range is less than a block, we may find a matching bit after `end`.
//
// For example, we shouldn't successfully find bit (2), because it's after `end`:
//
// (2)
// -------|
// (d) 00000001|00000000|00000001
// ^~~~~^
// start end
//
// An alternative would be to mask off end bits in the same way as we do for start bits,
// but performing this check afterwards is faster and simpler to implement.
if i < end {
return Some(i);
} else {
if let Some(&bits) = init_mask.blocks.get(start_block) {
if let Some(i) = search_block(bits, start_block, start_bit, is_init) {
// If the range is less than a block, we may find a matching bit after `end`.
//
// For example, we shouldn't successfully find bit (2), because it's after `end`:
//
// (2)
// -------|
// (d) 00000001|00000000|00000001
// ^~~~~^
// start end
//
// An alternative would be to mask off end bits in the same way as we do for start bits,
// but performing this check afterwards is faster and simpler to implement.
if i < end {
return Some(i);
} else {
return None;
}
}
} else {
if is_init {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This deserves a comment explaining the case it's handling--trailing uninit may not be allocated, so if the start block doesn't exist then it's all uninit

return None;
} else {
return Some(start);
}
}

Expand All @@ -861,7 +887,8 @@ impl InitMask {
// because both alternatives result in significantly worse codegen.
// `end_block_inclusive + 1` is guaranteed not to wrap, because `end_block_inclusive <= end / BLOCK_SIZE`,
// and `BLOCK_SIZE` (the number of bits per block) will always be at least 8 (1 byte).
for (&bits, block) in init_mask.blocks[start_block + 1..end_block_inclusive + 1]
for (&bits, block) in init_mask.blocks[start_block + 1
..std::cmp::min(end_block_inclusive + 1, init_mask.blocks.len())]
.iter()
.zip(start_block + 1..)
{
Expand All @@ -886,6 +913,9 @@ impl InitMask {
}
}
}
if !is_init && end_block_inclusive >= init_mask.blocks.len() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also deserves a comment, and perhaps ascii art about the case it's handling

return Some(InitMask::size_from_bit_index(init_mask.blocks.len(), 0));
}
}

None
Expand Down
2 changes: 1 addition & 1 deletion src/test/ui-fulldeps/uninit_mask.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use rustc_middle::mir::interpret::InitMask;
use rustc_target::abi::Size;

fn main() {
let mut mask = InitMask::new(Size::from_bytes(500), false);
let mut mask = InitMask::new_uninit(Size::from_bytes(500));
assert!(!mask.get(Size::from_bytes(499)));
mask.set(Size::from_bytes(499), true);
assert!(mask.get(Size::from_bytes(499)));
Expand Down