Browse Source

Rework the switch module in cranelift-frontend in terms of brif (#5644)

Rework the compilation strategy for switch to:
* use brif instead of brz and brnz
* generate tables inline, rather than delaying them to after the decision tree has been generated
* avoid allocating new vectors by using slices into the sorted contiguous ranges
* avoid generating some unconditional jumps
* output differences in test output using the similar crate for easier debugging
pull/5651/head
Trevor Elliott 2 years ago
committed by GitHub
parent
commit
b47006d432
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      Cargo.lock
  2. 3
      cranelift/frontend/Cargo.toml
  3. 334
      cranelift/frontend/src/switch.rs

1
Cargo.lock

@ -622,6 +622,7 @@ dependencies = [
"cranelift-codegen",
"hashbrown",
"log",
"similar",
"smallvec",
"target-lexicon",
]

3
cranelift/frontend/Cargo.toml

@ -17,6 +17,9 @@ log = { workspace = true }
hashbrown = { workspace = true, optional = true }
smallvec = { workspace = true }
[dev-dependencies]
similar = { workspace = true }
[features]
default = ["std"]
std = ["cranelift-codegen/std"]

334
cranelift/frontend/src/switch.rs

@ -108,27 +108,19 @@ impl Switch {
}
/// Binary search for the right `ContiguousCaseRange`.
fn build_search_tree(
fn build_search_tree<'a>(
bx: &mut FunctionBuilder,
val: Value,
otherwise: Block,
contiguous_case_ranges: Vec<ContiguousCaseRange>,
) -> Vec<(EntryIndex, Block, Vec<Block>)> {
let mut cases_and_jt_blocks = Vec::new();
contiguous_case_ranges: &'a [ContiguousCaseRange],
) {
// Avoid allocation in the common case
if contiguous_case_ranges.len() <= 3 {
Self::build_search_branches(
bx,
val,
otherwise,
contiguous_case_ranges,
&mut cases_and_jt_blocks,
);
return cases_and_jt_blocks;
Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
return;
}
let mut stack: Vec<(Option<Block>, Vec<ContiguousCaseRange>)> = Vec::new();
let mut stack = Vec::new();
stack.push((None, contiguous_case_ranges));
while let Some((block, contiguous_case_ranges)) = stack.pop() {
@ -137,17 +129,10 @@ impl Switch {
}
if contiguous_case_ranges.len() <= 3 {
Self::build_search_branches(
bx,
val,
otherwise,
contiguous_case_ranges,
&mut cases_and_jt_blocks,
);
Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
} else {
let split_point = contiguous_case_ranges.len() / 2;
let mut left = contiguous_case_ranges;
let right = left.split_off(split_point);
let (left, right) = contiguous_case_ranges.split_at(split_point);
let left_block = bx.create_block();
let right_block = bx.create_block();
@ -155,8 +140,8 @@ impl Switch {
let first_index = right[0].first_index;
let should_take_right_side =
icmp_imm_u128(bx, IntCC::UnsignedGreaterThanOrEqual, val, first_index);
bx.ins().brnz(should_take_right_side, right_block, &[]);
bx.ins().jump(left_block, &[]);
bx.ins()
.brif(should_take_right_side, right_block, &[], left_block, &[]);
bx.seal_block(left_block);
bx.seal_block(right_block);
@ -165,126 +150,108 @@ impl Switch {
stack.push((Some(right_block), right));
}
}
cases_and_jt_blocks
}
/// Linear search for the right `ContiguousCaseRange`.
fn build_search_branches(
fn build_search_branches<'a>(
bx: &mut FunctionBuilder,
val: Value,
otherwise: Block,
contiguous_case_ranges: Vec<ContiguousCaseRange>,
cases_and_jt_blocks: &mut Vec<(EntryIndex, Block, Vec<Block>)>,
contiguous_case_ranges: &'a [ContiguousCaseRange],
) {
let mut was_branch = false;
let ins_fallthrough_jump = |was_branch: bool, bx: &mut FunctionBuilder| {
if was_branch {
let block = bx.create_block();
bx.ins().jump(block, &[]);
bx.seal_block(block);
bx.switch_to_block(block);
}
};
for ContiguousCaseRange {
first_index,
blocks,
} in contiguous_case_ranges.into_iter().rev()
{
match (blocks.len(), first_index) {
(1, 0) => {
ins_fallthrough_jump(was_branch, bx);
bx.ins().brz(val, blocks[0], &[]);
}
(1, _) => {
ins_fallthrough_jump(was_branch, bx);
let is_good_val = icmp_imm_u128(bx, IntCC::Equal, val, first_index);
bx.ins().brnz(is_good_val, blocks[0], &[]);
}
(_, 0) => {
// if `first_index` is 0, then `icmp_imm uge val, first_index` is trivially true
let jt_block = bx.create_block();
bx.ins().jump(jt_block, &[]);
bx.seal_block(jt_block);
cases_and_jt_blocks.push((first_index, jt_block, blocks));
// `jump otherwise` below must not be hit, because the current block has been
// filled above. This is the last iteration anyway, as 0 is the smallest
// unsigned int, so just return here.
return;
let last_ix = contiguous_case_ranges.len() - 1;
for (ix, range) in contiguous_case_ranges.iter().rev().enumerate() {
let alternate = if ix == last_ix {
otherwise
} else {
bx.create_block()
};
if range.first_index == 0 {
assert_eq!(alternate, otherwise);
if let Some(block) = range.single_block() {
bx.ins().brif(val, otherwise, &[], block, &[]);
} else {
Self::build_jump_table(bx, val, otherwise, 0, &range.blocks);
}
(_, _) => {
ins_fallthrough_jump(was_branch, bx);
} else {
if let Some(block) = range.single_block() {
let is_good_val = icmp_imm_u128(bx, IntCC::Equal, val, range.first_index);
bx.ins().brif(is_good_val, block, &[], alternate, &[]);
} else {
let is_good_val = icmp_imm_u128(
bx,
IntCC::UnsignedGreaterThanOrEqual,
val,
range.first_index,
);
let jt_block = bx.create_block();
let is_good_val =
icmp_imm_u128(bx, IntCC::UnsignedGreaterThanOrEqual, val, first_index);
bx.ins().brnz(is_good_val, jt_block, &[]);
bx.ins().brif(is_good_val, jt_block, &[], alternate, &[]);
bx.seal_block(jt_block);
cases_and_jt_blocks.push((first_index, jt_block, blocks));
bx.switch_to_block(jt_block);
Self::build_jump_table(bx, val, otherwise, range.first_index, &range.blocks);
}
}
was_branch = true;
}
bx.ins().jump(otherwise, &[]);
if alternate != otherwise {
bx.seal_block(alternate);
bx.switch_to_block(alternate);
}
}
}
/// For every item in `cases_and_jt_blocks` this will create a jump table in the specified block.
fn build_jump_tables(
fn build_jump_table(
bx: &mut FunctionBuilder,
val: Value,
otherwise: Block,
cases_and_jt_blocks: Vec<(EntryIndex, Block, Vec<Block>)>,
first_index: EntryIndex,
blocks: &[Block],
) {
for (first_index, jt_block, blocks) in cases_and_jt_blocks.into_iter().rev() {
// There are currently no 128bit systems supported by rustc, but once we do ensure that
// we don't silently ignore a part of the jump table for 128bit integers on 128bit systems.
assert!(
u32::try_from(blocks.len()).is_ok(),
"Jump tables bigger than 2^32-1 are not yet supported"
);
// There are currently no 128bit systems supported by rustc, but once we do ensure that
// we don't silently ignore a part of the jump table for 128bit integers on 128bit systems.
assert!(
u32::try_from(blocks.len()).is_ok(),
"Jump tables bigger than 2^32-1 are not yet supported"
);
let mut jt_data = JumpTableData::new();
for block in blocks {
jt_data.push_entry(block);
}
let jump_table = bx.create_jump_table(jt_data);
let jt_data = JumpTableData::with_blocks(Vec::from(blocks));
let jump_table = bx.create_jump_table(jt_data);
bx.switch_to_block(jt_block);
let discr = if first_index == 0 {
val
let discr = if first_index == 0 {
val
} else {
if let Ok(first_index) = u64::try_from(first_index) {
bx.ins().iadd_imm(val, (first_index as i64).wrapping_neg())
} else {
if let Ok(first_index) = u64::try_from(first_index) {
bx.ins().iadd_imm(val, (first_index as i64).wrapping_neg())
} else {
let (lsb, msb) = (first_index as u64, (first_index >> 64) as u64);
let lsb = bx.ins().iconst(types::I64, lsb as i64);
let msb = bx.ins().iconst(types::I64, msb as i64);
let index = bx.ins().iconcat(lsb, msb);
bx.ins().isub(val, index)
}
};
let (lsb, msb) = (first_index as u64, (first_index >> 64) as u64);
let lsb = bx.ins().iconst(types::I64, lsb as i64);
let msb = bx.ins().iconst(types::I64, msb as i64);
let index = bx.ins().iconcat(lsb, msb);
bx.ins().isub(val, index)
}
};
let discr = match bx.func.dfg.value_type(discr).bits() {
bits if bits > 32 => {
// Check for overflow of cast to u32. This is the max supported jump table entries.
let new_block = bx.create_block();
let bigger_than_u32 =
bx.ins()
.icmp_imm(IntCC::UnsignedGreaterThan, discr, u32::MAX as i64);
bx.ins().brnz(bigger_than_u32, otherwise, &[]);
bx.ins().jump(new_block, &[]);
bx.seal_block(new_block);
bx.switch_to_block(new_block);
// Cast to i32, as br_table is not implemented for i64/i128
bx.ins().ireduce(types::I32, discr)
}
bits if bits < 32 => bx.ins().uextend(types::I32, discr),
_ => discr,
};
let discr = match bx.func.dfg.value_type(discr).bits() {
bits if bits > 32 => {
// Check for overflow of cast to u32. This is the max supported jump table entries.
let new_block = bx.create_block();
let bigger_than_u32 =
bx.ins()
.icmp_imm(IntCC::UnsignedGreaterThan, discr, u32::MAX as i64);
bx.ins()
.brif(bigger_than_u32, otherwise, &[], new_block, &[]);
bx.seal_block(new_block);
bx.switch_to_block(new_block);
// Cast to i32, as br_table is not implemented for i64/i128
bx.ins().ireduce(types::I32, discr)
}
bits if bits < 32 => bx.ins().uextend(types::I32, discr),
_ => discr,
};
bx.ins().br_table(discr, otherwise, jump_table);
}
bx.ins().br_table(discr, otherwise, jump_table);
}
/// Build the switch
@ -307,9 +274,7 @@ impl Switch {
}
let contiguous_case_ranges = self.collect_contiguous_case_ranges();
let cases_and_jt_blocks =
Self::build_search_tree(bx, val, otherwise, contiguous_case_ranges);
Self::build_jump_tables(bx, val, otherwise, cases_and_jt_blocks);
Self::build_search_tree(bx, val, otherwise, &contiguous_case_ranges);
}
}
@ -351,6 +316,15 @@ impl ContiguousCaseRange {
blocks: Vec::new(),
}
}
/// Returns `Some` block when there is only a single block in this range.
fn single_block(&self) -> Option<Block> {
if self.blocks.len() == 1 {
Some(self.blocks[0])
} else {
None
}
}
}
#[cfg(test)]
@ -384,43 +358,52 @@ mod tests {
}};
}
macro_rules! assert_eq_output {
($actual:ident, $expected:literal) => {
if $actual != $expected {
assert!(
false,
"\n{}",
similar::TextDiff::from_lines($expected, &$actual)
.unified_diff()
.header("expected", "actual")
);
}
};
}
#[test]
fn switch_zero() {
let func = setup!(0, [0,]);
assert_eq!(
assert_eq_output!(
func,
"block0:
v0 = iconst.i8 0
brz v0, block1 ; v0 = 0
jump block0"
brif v0, block0, block1 ; v0 = 0"
);
}
#[test]
fn switch_single() {
let func = setup!(0, [1,]);
assert_eq!(
assert_eq_output!(
func,
"block0:
v0 = iconst.i8 0
v1 = icmp_imm eq v0, 1 ; v0 = 0
brnz v1, block1
jump block0"
brif v1, block1, block0"
);
}
#[test]
fn switch_bool() {
let func = setup!(0, [0, 1,]);
assert_eq!(
assert_eq_output!(
func,
" jt0 = jump_table [block1, block2]
block0:
v0 = iconst.i8 0
jump block3
block3:
v1 = uextend.i32 v0 ; v0 = 0
br_table v1, block0, jt0"
);
@ -429,56 +412,50 @@ block3:
#[test]
fn switch_two_gap() {
let func = setup!(0, [0, 2,]);
assert_eq!(
assert_eq_output!(
func,
"block0:
v0 = iconst.i8 0
v1 = icmp_imm eq v0, 2 ; v0 = 0
brnz v1, block2
jump block3
brif v1, block2, block3
block3:
brz.i8 v0, block1 ; v0 = 0
jump block0"
brif.i8 v0, block0, block1 ; v0 = 0"
);
}
#[test]
fn switch_many() {
let func = setup!(0, [0, 1, 5, 7, 10, 11, 12,]);
assert_eq!(
assert_eq_output!(
func,
" jt0 = jump_table [block1, block2]
jt1 = jump_table [block5, block6, block7]
" jt0 = jump_table [block5, block6, block7]
jt1 = jump_table [block1, block2]
block0:
v0 = iconst.i8 0
v1 = icmp_imm uge v0, 7 ; v0 = 0
brnz v1, block9
jump block8
brif v1, block9, block8
block9:
v2 = icmp_imm.i8 uge v0, 10 ; v0 = 0
brnz v2, block10
jump block11
brif v2, block11, block10
block11:
v3 = icmp_imm.i8 eq v0, 7 ; v0 = 0
brnz v3, block4
jump block0
v3 = iadd_imm.i8 v0, -10 ; v0 = 0
v4 = uextend.i32 v3
br_table v4, block0, jt0
block10:
v5 = icmp_imm.i8 eq v0, 7 ; v0 = 0
brif v5, block4, block0
block8:
v4 = icmp_imm.i8 eq v0, 5 ; v0 = 0
brnz v4, block3
jump block12
v6 = icmp_imm.i8 eq v0, 5 ; v0 = 0
brif v6, block3, block12
block12:
v5 = uextend.i32 v0 ; v0 = 0
br_table v5, block0, jt0
block10:
v6 = iadd_imm.i8 v0, -10 ; v0 = 0
v7 = uextend.i32 v6
v7 = uextend.i32 v0 ; v0 = 0
br_table v7, block0, jt1"
);
}
@ -486,51 +463,46 @@ block10:
#[test]
fn switch_min_index_value() {
let func = setup!(0, [i8::MIN as u8 as u128, 1,]);
assert_eq!(
assert_eq_output!(
func,
"block0:
v0 = iconst.i8 0
v1 = icmp_imm eq v0, 128 ; v0 = 0
brnz v1, block1
jump block3
brif v1, block1, block3
block3:
v2 = icmp_imm.i8 eq v0, 1 ; v0 = 0
brnz v2, block2
jump block0"
brif v2, block2, block0"
);
}
#[test]
fn switch_max_index_value() {
let func = setup!(0, [i8::MAX as u8 as u128, 1,]);
assert_eq!(
assert_eq_output!(
func,
"block0:
v0 = iconst.i8 0
v1 = icmp_imm eq v0, 127 ; v0 = 0
brnz v1, block1
jump block3
brif v1, block1, block3
block3:
v2 = icmp_imm.i8 eq v0, 1 ; v0 = 0
brnz v2, block2
jump block0"
brif v2, block2, block0"
)
}
#[test]
fn switch_optimal_codegen() {
let func = setup!(0, [-1i8 as u8 as u128, 0, 1,]);
assert_eq!(
assert_eq_output!(
func,
" jt0 = jump_table [block2, block3]
block0:
v0 = iconst.i8 0
v1 = icmp_imm eq v0, 255 ; v0 = 0
brnz v1, block1
jump block4
brif v1, block1, block4
block4:
v2 = uextend.i32 v0 ; v0 = 0
@ -617,20 +589,16 @@ block4:
.trim_start_matches("function u0:0() fast {\n")
.trim_end_matches("\n}\n")
.to_string();
assert_eq!(
assert_eq_output!(
func,
" jt0 = jump_table [block2, block1]
block0:
v0 = iconst.i64 0
jump block4
v1 = icmp_imm ugt v0, 0xffff_ffff ; v0 = 0
brif v1, block3, block4
block4:
v1 = icmp_imm.i64 ugt v0, 0xffff_ffff ; v0 = 0
brnz v1, block3
jump block5
block5:
v2 = ireduce.i32 v0 ; v0 = 0
br_table v2, block3, jt0"
);
@ -659,21 +627,17 @@ block5:
.trim_start_matches("function u0:0() fast {\n")
.trim_end_matches("\n}\n")
.to_string();
assert_eq!(
assert_eq_output!(
func,
" jt0 = jump_table [block2, block1]
block0:
v0 = iconst.i64 0
v1 = uextend.i128 v0 ; v0 = 0
jump block4
v2 = icmp_imm ugt v1, 0xffff_ffff
brif v2, block3, block4
block4:
v2 = icmp_imm.i128 ugt v1, 0xffff_ffff
brnz v2, block3
jump block5
block5:
v3 = ireduce.i32 v1
br_table v3, block3, jt0"
);

Loading…
Cancel
Save