@ -30,9 +30,52 @@ use crate::ir::Opcode;
/// `finite()` method.) An infinite cost is used to represent a value
/// that cannot be computed, or otherwise serve as a sentinel when
/// performing search for the lowest-cost representation of a value.
///
/// The cost is packed into a single `u32`: the bits above the low 8 hold the
/// accumulated opcode cost and the low 8 bits hold the expression depth.
/// `Debug`, `Ord` and `PartialOrd` are implemented by hand for this type, so
/// only the remaining traits are derived here.
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) struct Cost(u32);
impl core ::fmt ::Debug for Cost {
fn fmt ( & self , f : & mut core ::fmt ::Formatter < '_ > ) -> core ::fmt ::Result {
if * self = = Cost ::infinity ( ) {
write ! ( f , "Cost::Infinite" )
} else {
f . debug_struct ( "Cost::Finite" )
. field ( "op_cost" , & self . op_cost ( ) )
. field ( "depth" , & self . depth ( ) )
. finish ( )
}
}
}
impl Ord for Cost {
    /// Order costs by opcode cost first, then by depth.
    #[inline]
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // The opcode cost lives in the high bits and the depth in the low
        // bits of the packed integer, so a plain integer comparison already
        // sorts by opcode cost and only falls back to depth on a tie.
        //
        // Depth is deliberately the tie-breaker (not the primary key): among
        // expressions with equal opcode cost we want the shallow, wide one
        // over the narrow, deep one. E.g. `(a + b) + (c + d)` beats
        // `((a + b) + c) + d`, which exposes more instruction-level
        // parallelism and shortens live ranges.
        let (lhs, rhs) = (self.0, other.0);
        Ord::cmp(&lhs, &rhs)
    }
}
impl PartialOrd for Cost {
    /// Costs are totally ordered, so this always returns `Some` of the
    /// result of the `Ord` implementation.
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl Cost {
/// Number of low bits of the packed `u32` that store the depth.
const DEPTH_BITS: u8 = 8;

/// Mask that selects the depth bits (the low `DEPTH_BITS` bits).
const DEPTH_MASK: u32 = (1u32 << Self::DEPTH_BITS) - 1;

/// Mask that selects the opcode-cost bits (everything above the depth bits).
const OP_COST_MASK: u32 = !Self::DEPTH_MASK;

/// Largest opcode cost that `new_finite` will store: one less than the
/// largest value that fits in the opcode-cost bits.
// NOTE(review): presumably the `- 1` keeps every finite cost distinct from
// `Cost::infinity()` (cf. the `debug_assert_ne!` in `new_finite`) — confirm.
const MAX_OP_COST: u32 = (Self::OP_COST_MASK >> Self::DEPTH_BITS) - 1;
pub ( crate ) fn infinity ( ) -> Cost {
// 2^32 - 1 is, uh, pretty close to infinite... (we use `Cost`
// only for heuristics and always saturate so this suffices!)
@ -43,11 +86,38 @@ impl Cost {
Cost ( 0 )
}
/// Clamp this cost at a "finite" value. Can be used in
/// conjunction with saturating ops to avoid saturating into
/// `infinity()`.
fn finite ( self ) -> Cost {
Cost ( std ::cmp ::min ( u32 ::MAX - 1 , self . 0 ) )
/// Build a finite cost out of an opcode cost and a depth.
///
/// `opcode_cost` is clamped to `MAX_OP_COST` before being packed into the
/// high bits; `depth` occupies the low `DEPTH_BITS` bits unchanged.
fn new_finite(opcode_cost: u32, depth: u8) -> Cost {
    let clamped = opcode_cost.min(Self::MAX_OP_COST);
    let packed = (clamped << Self::DEPTH_BITS) | u32::from(depth);
    let cost = Cost(packed);
    // A cost built here must never be the infinite sentinel.
    debug_assert_ne!(cost, Cost::infinity());
    cost
}
/// Extract the depth component stored in the low `DEPTH_BITS` bits.
fn depth(&self) -> u8 {
    // After masking with `DEPTH_MASK` only the low 8 bits can be set, so the
    // narrowing conversion to `u8` cannot fail.
    u8::try_from(self.0 & Self::DEPTH_MASK).unwrap()
}
fn op_cost ( & self ) -> u32 {
( self . 0 & Self ::OP_COST_MASK ) > > Self ::DEPTH_BITS
}
/// Compute the cost of the operation and its given operands.
///
/// Caller is responsible for checking that the opcode came from an instruction
/// that satisfies `inst_predicates::is_pure_for_egraph()`.
pub ( crate ) fn of_pure_op ( op : Opcode , operand_costs : impl IntoIterator < Item = Self > ) -> Self {
let c = pure_op_cost ( op ) + operand_costs . into_iter ( ) . sum ( ) ;
Cost ::new_finite ( c . op_cost ( ) , c . depth ( ) . saturating_add ( 1 ) )
}
}
impl std::iter::Sum<Cost> for Cost {
    /// Add up every cost yielded by `iter`, starting from `Cost::zero()`.
    fn sum<I: Iterator<Item = Cost>>(iter: I) -> Self {
        let mut total = Self::zero();
        for cost in iter {
            total = total + cost;
        }
        total
    }
}
@ -59,22 +129,29 @@ impl std::default::Default for Cost {
impl std::ops::Add<Cost> for Cost {
    type Output = Cost;

    /// Combine two costs into a finite cost.
    ///
    /// The opcode costs are added (saturating, then clamped to
    /// `MAX_OP_COST`), and the depth of the result is the larger of the two
    /// input depths.
    fn add(self, other: Cost) -> Cost {
        // Opcode costs accumulate; clamp so the sum stays in the finite range.
        let op_cost = std::cmp::min(
            self.op_cost().saturating_add(other.op_cost()),
            Self::MAX_OP_COST,
        );
        // Adding two costs does not nest them, so the combined depth is just
        // the deeper of the two operands.
        let depth = std::cmp::max(self.depth(), other.depth());
        Cost::new_finite(op_cost, depth)
    }
}
/// Return the cost of a *pure* opcode. Caller is responsible for
/// checking that the opcode came from an instruction that satisfies
/// `inst_predicates::is_pure_for_egraph()`.
pub ( crate ) fn pure_op_cost ( op : Opcode ) -> Cost {
/// Return the cost of a *pure* opcode.
///
/// Caller is responsible for checking that the opcode came from an instruction
/// that satisfies `inst_predicates::is_pure_for_egraph()`.
fn pure_op_cost ( op : Opcode ) -> Cost {
match op {
// Constants.
Opcode ::Iconst | Opcode ::F32const | Opcode ::F64const = > Cost ( 1 ) ,
Opcode ::Iconst | Opcode ::F32const | Opcode ::F64const = > Cost ::new_finite ( 1 , 0 ) ,
// Extends/reduces.
Opcode ::Uextend | Opcode ::Sextend | Opcode ::Ireduce | Opcode ::Iconcat | Opcode ::Isplit = > {
Cost ( 2 )
Cost ::new_finite ( 2 , 0 )
}
// "Simple" arithmetic.
@ -86,9 +163,9 @@ pub(crate) fn pure_op_cost(op: Opcode) -> Cost {
| Opcode ::Bnot
| Opcode ::Ishl
| Opcode ::Ushr
| Opcode ::Sshr = > Cost ( 3 ) ,
| Opcode ::Sshr = > Cost ::new_finite ( 3 , 0 ) ,
// Everything else (pure.)
_ = > Cost ( 4 ) ,
_ = > Cost ::new_finite ( 4 , 0 ) ,
}
}