@ -108,27 +108,19 @@ impl Switch {
}
/// Binary search for the right `ContiguousCaseRange`.
fn build_search_tree (
fn build_search_tree < 'a > (
bx : & mut FunctionBuilder ,
val : Value ,
otherwise : Block ,
contiguous_case_ranges : Vec < ContiguousCaseRange > ,
) -> Vec < ( EntryIndex , Block , Vec < Block > ) > {
let mut cases_and_jt_blocks = Vec ::new ( ) ;
contiguous_case_ranges : & 'a [ ContiguousCaseRange ] ,
) {
// Avoid allocation in the common case
if contiguous_case_ranges . len ( ) < = 3 {
Self ::build_search_branches (
bx ,
val ,
otherwise ,
contiguous_case_ranges ,
& mut cases_and_jt_blocks ,
) ;
return cases_and_jt_blocks ;
Self ::build_search_branches ( bx , val , otherwise , contiguous_case_ranges ) ;
return ;
}
let mut stack : Vec < ( Option < Block > , Vec < ContiguousCaseRange > ) > = Vec ::new ( ) ;
let mut stack = Vec ::new ( ) ;
stack . push ( ( None , contiguous_case_ranges ) ) ;
while let Some ( ( block , contiguous_case_ranges ) ) = stack . pop ( ) {
@ -137,17 +129,10 @@ impl Switch {
}
if contiguous_case_ranges . len ( ) < = 3 {
Self ::build_search_branches (
bx ,
val ,
otherwise ,
contiguous_case_ranges ,
& mut cases_and_jt_blocks ,
) ;
Self ::build_search_branches ( bx , val , otherwise , contiguous_case_ranges ) ;
} else {
let split_point = contiguous_case_ranges . len ( ) / 2 ;
let mut left = contiguous_case_ranges ;
let right = left . split_off ( split_point ) ;
let ( left , right ) = contiguous_case_ranges . split_at ( split_point ) ;
let left_block = bx . create_block ( ) ;
let right_block = bx . create_block ( ) ;
@ -155,8 +140,8 @@ impl Switch {
let first_index = right [ 0 ] . first_index ;
let should_take_right_side =
icmp_imm_u128 ( bx , IntCC ::UnsignedGreaterThanOrEqual , val , first_index ) ;
bx . ins ( ) . brnz ( should_take_right_side , right_block , & [ ] ) ;
bx . ins ( ) . jump ( left_block , & [ ] ) ;
bx . ins ( )
. brif ( should_take_right_side , right_block , & [ ] , left_block , & [ ] ) ;
bx . seal_block ( left_block ) ;
bx . seal_block ( right_block ) ;
@ -165,126 +150,108 @@ impl Switch {
stack . push ( ( Some ( right_block ) , right ) ) ;
}
}
cases_and_jt_blocks
}
/// Linear search for the right `ContiguousCaseRange`.
fn build_search_branches (
fn build_search_branches < 'a > (
bx : & mut FunctionBuilder ,
val : Value ,
otherwise : Block ,
contiguous_case_ranges : Vec < ContiguousCaseRange > ,
cases_and_jt_blocks : & mut Vec < ( EntryIndex , Block , Vec < Block > ) > ,
contiguous_case_ranges : & 'a [ ContiguousCaseRange ] ,
) {
let mut was_branch = false ;
let ins_fallthrough_jump = | was_branch : bool , bx : & mut FunctionBuilder | {
if was_branch {
let block = bx . create_block ( ) ;
bx . ins ( ) . jump ( block , & [ ] ) ;
bx . seal_block ( block ) ;
bx . switch_to_block ( block ) ;
}
} ;
for ContiguousCaseRange {
first_index ,
blocks ,
} in contiguous_case_ranges . into_iter ( ) . rev ( )
{
match ( blocks . len ( ) , first_index ) {
( 1 , 0 ) = > {
ins_fallthrough_jump ( was_branch , bx ) ;
bx . ins ( ) . brz ( val , blocks [ 0 ] , & [ ] ) ;
}
( 1 , _ ) = > {
ins_fallthrough_jump ( was_branch , bx ) ;
let is_good_val = icmp_imm_u128 ( bx , IntCC ::Equal , val , first_index ) ;
bx . ins ( ) . brnz ( is_good_val , blocks [ 0 ] , & [ ] ) ;
}
( _ , 0 ) = > {
// if `first_index` is 0, then `icmp_imm uge val, first_index` is trivially true
let jt_block = bx . create_block ( ) ;
bx . ins ( ) . jump ( jt_block , & [ ] ) ;
bx . seal_block ( jt_block ) ;
cases_and_jt_blocks . push ( ( first_index , jt_block , blocks ) ) ;
// `jump otherwise` below must not be hit, because the current block has been
// filled above. This is the last iteration anyway, as 0 is the smallest
// unsigned int, so just return here.
return ;
let last_ix = contiguous_case_ranges . len ( ) - 1 ;
for ( ix , range ) in contiguous_case_ranges . iter ( ) . rev ( ) . enumerate ( ) {
let alternate = if ix = = last_ix {
otherwise
} else {
bx . create_block ( )
} ;
if range . first_index = = 0 {
assert_eq ! ( alternate , otherwise ) ;
if let Some ( block ) = range . single_block ( ) {
bx . ins ( ) . brif ( val , otherwise , & [ ] , block , & [ ] ) ;
} else {
Self ::build_jump_table ( bx , val , otherwise , 0 , & range . blocks ) ;
}
( _ , _ ) = > {
ins_fallthrough_jump ( was_branch , bx ) ;
} else {
if let Some ( block ) = range . single_block ( ) {
let is_good_val = icmp_imm_u128 ( bx , IntCC ::Equal , val , range . first_index ) ;
bx . ins ( ) . brif ( is_good_val , block , & [ ] , alternate , & [ ] ) ;
} else {
let is_good_val = icmp_imm_u128 (
bx ,
IntCC ::UnsignedGreaterThanOrEqual ,
val ,
range . first_index ,
) ;
let jt_block = bx . create_block ( ) ;
let is_good_val =
icmp_imm_u128 ( bx , IntCC ::UnsignedGreaterThanOrEqual , val , first_index ) ;
bx . ins ( ) . brnz ( is_good_val , jt_block , & [ ] ) ;
bx . ins ( ) . brif ( is_good_val , jt_block , & [ ] , alternate , & [ ] ) ;
bx . seal_block ( jt_block ) ;
cases_and_jt_blocks . push ( ( first_index , jt_block , blocks ) ) ;
bx . switch_to_block ( jt_block ) ;
Self ::build_jump_table ( bx , val , otherwise , range . first_index , & range . blocks ) ;
}
}
was_branch = true ;
}
bx . ins ( ) . jump ( otherwise , & [ ] ) ;
if alternate ! = otherwise {
bx . seal_block ( alternate ) ;
bx . switch_to_block ( alternate ) ;
}
}
}
/// For every item in `cases_and_jt_blocks` this will create a jump table in the specified block.
fn build_jump_tables (
fn build_jump_table (
bx : & mut FunctionBuilder ,
val : Value ,
otherwise : Block ,
cases_and_jt_blocks : Vec < ( EntryIndex , Block , Vec < Block > ) > ,
first_index : EntryIndex ,
blocks : & [ Block ] ,
) {
for ( first_index , jt_block , blocks ) in cases_and_jt_blocks . into_iter ( ) . rev ( ) {
// There are currently no 128bit systems supported by rustc, but once we do ensure that
// we don't silently ignore a part of the jump table for 128bit integers on 128bit systems.
assert ! (
u32 ::try_from ( blocks . len ( ) ) . is_ok ( ) ,
"Jump tables bigger than 2^32-1 are not yet supported"
) ;
// There are currently no 128bit systems supported by rustc, but once we do ensure that
// we don't silently ignore a part of the jump table for 128bit integers on 128bit systems.
assert ! (
u32 ::try_from ( blocks . len ( ) ) . is_ok ( ) ,
"Jump tables bigger than 2^32-1 are not yet supported"
) ;
let mut jt_data = JumpTableData ::new ( ) ;
for block in blocks {
jt_data . push_entry ( block ) ;
}
let jump_table = bx . create_jump_table ( jt_data ) ;
let jt_data = JumpTableData ::with_blocks ( Vec ::from ( blocks ) ) ;
let jump_table = bx . create_jump_table ( jt_data ) ;
bx . switch_to_block ( jt_block ) ;
let discr = if first_index = = 0 {
val
let discr = if first_index = = 0 {
val
} else {
if let Ok ( first_index ) = u64 ::try_from ( first_index ) {
bx . ins ( ) . iadd_imm ( val , ( first_index as i64 ) . wrapping_neg ( ) )
} else {
if let Ok ( first_index ) = u64 ::try_from ( first_index ) {
bx . ins ( ) . iadd_imm ( val , ( first_index as i64 ) . wrapping_neg ( ) )
} else {
let ( lsb , msb ) = ( first_index as u64 , ( first_index > > 64 ) as u64 ) ;
let lsb = bx . ins ( ) . iconst ( types ::I64 , lsb as i64 ) ;
let msb = bx . ins ( ) . iconst ( types ::I64 , msb as i64 ) ;
let index = bx . ins ( ) . iconcat ( lsb , msb ) ;
bx . ins ( ) . isub ( val , index )
}
} ;
let ( lsb , msb ) = ( first_index as u64 , ( first_index > > 64 ) as u64 ) ;
let lsb = bx . ins ( ) . iconst ( types ::I64 , lsb as i64 ) ;
let msb = bx . ins ( ) . iconst ( types ::I64 , msb as i64 ) ;
let index = bx . ins ( ) . iconcat ( lsb , msb ) ;
bx . ins ( ) . isub ( val , index )
}
} ;
let discr = match bx . func . dfg . value_type ( discr ) . bits ( ) {
bits if bits > 32 = > {
// Check for overflow of cast to u32. This is the max supported jump table entries.
let new_block = bx . create_block ( ) ;
let bigger_than_u32 =
bx . ins ( )
. icmp_imm ( IntCC ::UnsignedGreaterThan , discr , u32 ::MAX as i64 ) ;
bx . ins ( ) . brnz ( bigger_than_u32 , otherwise , & [ ] ) ;
bx . ins ( ) . jump ( new_block , & [ ] ) ;
bx . seal_block ( new_block ) ;
bx . switch_to_block ( new_block ) ;
// Cast to i32, as br_table is not implemented for i64/i128
bx . ins ( ) . ireduce ( types ::I32 , discr )
}
bits if bits < 32 = > bx . ins ( ) . uextend ( types ::I32 , discr ) ,
_ = > discr ,
} ;
let discr = match bx . func . dfg . value_type ( discr ) . bits ( ) {
bits if bits > 32 = > {
// Check for overflow of cast to u32. This is the max supported jump table entries.
let new_block = bx . create_block ( ) ;
let bigger_than_u32 =
bx . ins ( )
. icmp_imm ( IntCC ::UnsignedGreaterThan , discr , u32 ::MAX as i64 ) ;
bx . ins ( )
. brif ( bigger_than_u32 , otherwise , & [ ] , new_block , & [ ] ) ;
bx . seal_block ( new_block ) ;
bx . switch_to_block ( new_block ) ;
// Cast to i32, as br_table is not implemented for i64/i128
bx . ins ( ) . ireduce ( types ::I32 , discr )
}
bits if bits < 32 = > bx . ins ( ) . uextend ( types ::I32 , discr ) ,
_ = > discr ,
} ;
bx . ins ( ) . br_table ( discr , otherwise , jump_table ) ;
}
bx . ins ( ) . br_table ( discr , otherwise , jump_table ) ;
}
/// Build the switch
@ -307,9 +274,7 @@ impl Switch {
}
let contiguous_case_ranges = self . collect_contiguous_case_ranges ( ) ;
let cases_and_jt_blocks =
Self ::build_search_tree ( bx , val , otherwise , contiguous_case_ranges ) ;
Self ::build_jump_tables ( bx , val , otherwise , cases_and_jt_blocks ) ;
Self ::build_search_tree ( bx , val , otherwise , & contiguous_case_ranges ) ;
}
}
@ -351,6 +316,15 @@ impl ContiguousCaseRange {
blocks : Vec ::new ( ) ,
}
}
/// Yields the only block of this range, or `None` when the range holds zero or
/// several blocks.
fn single_block(&self) -> Option<Block> {
    match self.blocks.as_slice() {
        [only] => Some(*only),
        _ => None,
    }
}
}
#[ cfg(test) ]
@ -384,43 +358,52 @@ mod tests {
} } ;
}
/// Checks that `$actual` equals the `$expected` string literal.
///
/// On a mismatch it panics with a unified line diff of the two texts, labelled
/// "expected" and "actual".
macro_rules! assert_eq_output {
    ($actual:ident, $expected:literal) => {
        if $actual != $expected {
            // The comparison has already failed, so panic directly rather than
            // asserting on a constant `false`.
            panic!(
                "\n{}",
                similar::TextDiff::from_lines($expected, &$actual)
                    .unified_diff()
                    .header("expected", "actual")
            );
        }
    };
}
/// Compares the text produced by `setup!(0, [0,])` against an expected IR listing
/// in which the switch value `v0` is branched on directly, without a comparison.
// NOTE(review): `setup!` is defined outside this excerpt; presumably its first
// argument selects the default block and the list holds the case values — confirm.
#[ test ]
fn switch_zero ( ) {
let func = setup ! ( 0 , [ 0 , ] ) ;
assert_eq ! (
assert_eq_output ! (
func ,
" block0 :
v0 = iconst . i8 0
brz v0 , block1 ; v0 = 0
jump block0 "
brif v0 , block0 , block1 ; v0 = 0 "
) ;
}
/// Compares the text produced by `setup!(0, [1,])` against an expected IR listing
/// that tests the switch value with `icmp_imm eq ... 1` and branches on the result.
// NOTE(review): `setup!` is defined outside this excerpt; presumably its first
// argument selects the default block and the list holds the case values — confirm.
#[ test ]
fn switch_single ( ) {
let func = setup ! ( 0 , [ 1 , ] ) ;
assert_eq ! (
assert_eq_output ! (
func ,
" block0 :
v0 = iconst . i8 0
v1 = icmp_imm eq v0 , 1 ; v0 = 0
brnz v1 , block1
jump block0 "
brif v1 , block1 , block0 "
) ;
}
#[ test ]
fn switch_bool ( ) {
let func = setup ! ( 0 , [ 0 , 1 , ] ) ;
assert_eq ! (
assert_eq_output ! (
func ,
" jt0 = jump_table [ block1 , block2 ]
block0 :
v0 = iconst . i8 0
jump block3
block3 :
v1 = uextend . i32 v0 ; v0 = 0
br_table v1 , block0 , jt0 "
) ;
@ -429,56 +412,50 @@ block3:
/// Compares the text produced by `setup!(0, [0, 2,])` against an expected IR
/// listing with no jump table: an equality test against 2 comes first, followed in
/// `block3` by a branch on the switch value itself.
// NOTE(review): `setup!` is defined outside this excerpt; presumably its first
// argument selects the default block and the list holds the case values — confirm.
#[ test ]
fn switch_two_gap ( ) {
let func = setup ! ( 0 , [ 0 , 2 , ] ) ;
assert_eq ! (
assert_eq_output ! (
func ,
" block0 :
v0 = iconst . i8 0
v1 = icmp_imm eq v0 , 2 ; v0 = 0
brnz v1 , block2
jump block3
brif v1 , block2 , block3
block3 :
brz . i8 v0 , block1 ; v0 = 0
jump block0 "
brif . i8 v0 , block0 , block1 ; v0 = 0 "
) ;
}
/// Compares the text produced by `setup!(0, [0, 1, 5, 7, 10, 11, 12,])` against an
/// expected IR listing that declares two jump tables (a two-entry and a three-entry
/// one), splits on `icmp_imm uge` tests against 7 and 10, and handles 5 and 7 with
/// equality tests.
// NOTE(review): `setup!` is defined outside this excerpt; presumably its first
// argument selects the default block and the list holds the case values — confirm.
#[ test ]
fn switch_many ( ) {
let func = setup ! ( 0 , [ 0 , 1 , 5 , 7 , 10 , 11 , 12 , ] ) ;
assert_eq ! (
assert_eq_output ! (
func ,
" jt0 = jump_table [ block1 , block2 ]
jt1 = jump_table [ block5 , block6 , block7 ]
" jt0 = jump_table [ block5 , block6 , block7 ]
jt1 = jump_table [ block1 , block2 ]
block0 :
v0 = iconst . i8 0
v1 = icmp_imm uge v0 , 7 ; v0 = 0
brnz v1 , block9
jump block8
brif v1 , block9 , block8
block9 :
v2 = icmp_imm . i8 uge v0 , 10 ; v0 = 0
brnz v2 , block10
jump block11
brif v2 , block11 , block10
block11 :
v3 = icmp_imm . i8 eq v0 , 7 ; v0 = 0
brnz v3 , block4
jump block0
v3 = iadd_imm . i8 v0 , - 10 ; v0 = 0
v4 = uextend . i32 v3
br_table v4 , block0 , jt0
block10 :
v5 = icmp_imm . i8 eq v0 , 7 ; v0 = 0
brif v5 , block4 , block0
block8 :
v4 = icmp_imm . i8 eq v0 , 5 ; v0 = 0
brnz v4 , block3
jump block12
v6 = icmp_imm . i8 eq v0 , 5 ; v0 = 0
brif v6 , block3 , block12
block12 :
v5 = uextend . i32 v0 ; v0 = 0
br_table v5 , block0 , jt0
block10 :
v6 = iadd_imm . i8 v0 , - 10 ; v0 = 0
v7 = uextend . i32 v6
v7 = uextend . i32 v0 ; v0 = 0
br_table v7 , block0 , jt1 "
) ;
}
@ -486,51 +463,46 @@ block10:
/// Compares the text produced by `setup!` for the case values
/// `i8::MIN as u8 as u128` and 1 against an expected IR listing that handles them
/// with two equality tests, the first of which compares against 128.
// NOTE(review): `setup!` is defined outside this excerpt; presumably its first
// argument selects the default block and the list holds the case values — confirm.
#[ test ]
fn switch_min_index_value ( ) {
let func = setup ! ( 0 , [ i8 ::MIN as u8 as u128 , 1 , ] ) ;
assert_eq ! (
assert_eq_output ! (
func ,
" block0 :
v0 = iconst . i8 0
v1 = icmp_imm eq v0 , 128 ; v0 = 0
brnz v1 , block1
jump block3
brif v1 , block1 , block3
block3 :
v2 = icmp_imm . i8 eq v0 , 1 ; v0 = 0
brnz v2 , block2
jump block0 "
brif v2 , block2 , block0 "
) ;
}
/// Compares the text produced by `setup!` for the case values
/// `i8::MAX as u8 as u128` and 1 against an expected IR listing that handles them
/// with two equality tests, the first of which compares against 127.
// NOTE(review): `setup!` is defined outside this excerpt; presumably its first
// argument selects the default block and the list holds the case values — confirm.
#[ test ]
fn switch_max_index_value ( ) {
let func = setup ! ( 0 , [ i8 ::MAX as u8 as u128 , 1 , ] ) ;
assert_eq ! (
assert_eq_output ! (
func ,
" block0 :
v0 = iconst . i8 0
v1 = icmp_imm eq v0 , 127 ; v0 = 0
brnz v1 , block1
jump block3
brif v1 , block1 , block3
block3 :
v2 = icmp_imm . i8 eq v0 , 1 ; v0 = 0
brnz v2 , block2
jump block0 "
brif v2 , block2 , block0 "
)
}
#[ test ]
fn switch_optimal_codegen ( ) {
let func = setup ! ( 0 , [ - 1 i8 as u8 as u128 , 0 , 1 , ] ) ;
assert_eq ! (
assert_eq_output ! (
func ,
" jt0 = jump_table [ block2 , block3 ]
block0 :
v0 = iconst . i8 0
v1 = icmp_imm eq v0 , 255 ; v0 = 0
brnz v1 , block1
jump block4
brif v1 , block1 , block4
block4 :
v2 = uextend . i32 v0 ; v0 = 0
@ -617,20 +589,16 @@ block4:
. trim_start_matches ( "function u0:0() fast {\n" )
. trim_end_matches ( "\n}\n" )
. to_string ( ) ;
assert_eq ! (
assert_eq_output ! (
func ,
" jt0 = jump_table [ block2 , block1 ]
block0 :
v0 = iconst . i64 0
jump block4
v1 = icmp_imm ugt v0 , 0xffff_ffff ; v0 = 0
brif v1 , block3 , block4
block4 :
v1 = icmp_imm . i64 ugt v0 , 0xffff_ffff ; v0 = 0
brnz v1 , block3
jump block5
block5 :
v2 = ireduce . i32 v0 ; v0 = 0
br_table v2 , block3 , jt0 "
) ;
@ -659,21 +627,17 @@ block5:
. trim_start_matches ( "function u0:0() fast {\n" )
. trim_end_matches ( "\n}\n" )
. to_string ( ) ;
assert_eq ! (
assert_eq_output ! (
func ,
" jt0 = jump_table [ block2 , block1 ]
block0 :
v0 = iconst . i64 0
v1 = uextend . i128 v0 ; v0 = 0
jump block4
v2 = icmp_imm ugt v1 , 0xffff_ffff
brif v2 , block3 , block4
block4 :
v2 = icmp_imm . i128 ugt v1 , 0xffff_ffff
brnz v2 , block3
jump block5
block5 :
v3 = ireduce . i32 v1
br_table v3 , block3 , jt0 "
) ;