@ -145,6 +145,54 @@ fn add_trait_bounds(generics: &syn::Generics, bound: syn::TypeParamBound) -> syn
generics
}
#[ derive(Debug, Copy, Clone) ]
enum DiscriminantSize {
Size1 ,
Size2 ,
Size4 ,
}
impl DiscriminantSize {
fn quote ( self , discriminant : usize ) -> TokenStream {
match self {
Self ::Size1 = > {
let discriminant = u8 ::try_from ( discriminant ) . unwrap ( ) ;
quote ! ( # discriminant )
}
Self ::Size2 = > {
let discriminant = u16 ::try_from ( discriminant ) . unwrap ( ) ;
quote ! ( # discriminant )
}
Self ::Size4 = > {
let discriminant = u32 ::try_from ( discriminant ) . unwrap ( ) ;
quote ! ( # discriminant )
}
}
}
}
impl From < DiscriminantSize > for u32 {
fn from ( size : DiscriminantSize ) -> u32 {
match size {
DiscriminantSize ::Size1 = > 1 ,
DiscriminantSize ::Size2 = > 2 ,
DiscriminantSize ::Size4 = > 4 ,
}
}
}
fn discriminant_size ( case_count : usize ) -> Option < DiscriminantSize > {
if case_count < = 0xFF {
Some ( DiscriminantSize ::Size1 )
} else if case_count < = 0xFFFF {
Some ( DiscriminantSize ::Size2 )
} else if case_count < = 0xFFFF_FFFF {
Some ( DiscriminantSize ::Size4 )
} else {
None
}
}
struct VariantCase < 'a > {
attrs : & 'a [ syn ::Attribute ] ,
ident : & 'a syn ::Ident ,
@ -157,6 +205,7 @@ trait Expander {
fn expand_variant (
& self ,
input : & DeriveInput ,
discriminant_size : DiscriminantSize ,
cases : & [ VariantCase ] ,
style : VariantStyle ,
) -> Result < TokenStream > ;
@ -217,6 +266,13 @@ fn expand_variant(
) ) ;
}
let discriminant_size = discriminant_size ( body . variants . len ( ) ) . ok_or_else ( | | {
Error ::new (
input . ident . span ( ) ,
"`enum`s with more than 2^32 variants are not supported" ,
)
} ) ? ;
let cases = body
. variants
. iter ( )
@ -240,8 +296,13 @@ fn expand_variant(
name . span ( ) ,
format ! (
" ` { } ` component types can only be derived for Rust ` enum ` s \
containing variants with at most one unnamed field each " ,
style
containing variants with { } " ,
style ,
match style {
VariantStyle ::Variant = > "at most one unnamed field each" ,
VariantStyle ::Enum = > "no fields" ,
VariantStyle ::Union = > "exactly one unnamed field each" ,
}
) ,
) )
}
@ -251,7 +312,7 @@ fn expand_variant(
)
. collect ::< Result < Vec < _ > > > ( ) ? ;
expander . expand_variant ( input , & cases , style )
expander . expand_variant ( input , discriminant_size , & cases , style )
}
#[ proc_macro_derive(Lift, attributes(component)) ]
@ -321,6 +382,7 @@ impl Expander for LiftExpander {
fn expand_variant (
& self ,
input : & DeriveInput ,
discriminant_size : DiscriminantSize ,
cases : & [ VariantCase ] ,
_style : VariantStyle ,
) -> Result < TokenStream > {
@ -330,31 +392,26 @@ impl Expander for LiftExpander {
let mut loads = TokenStream ::new ( ) ;
for ( index , VariantCase { ident , ty , . . } ) in cases . iter ( ) . enumerate ( ) {
let index_u8 = u8 ::try_from ( index ) . map_err ( | _ | {
Error ::new (
input . ident . span ( ) ,
"`enum`s with more than 256 variants not yet supported" ,
)
} ) ? ;
let index_u32 = u32 ::try_from ( index ) . unwrap ( ) ;
let index_i32 = index_u8 as i32 ;
let index_quoted = discriminant_size . quote ( index ) ;
if let Some ( ty ) = ty {
lifts . extend (
quote ! ( # index_i 32 = > Self ::# ident ( < # ty as wasmtime ::component ::Lift > ::lift (
quote ! ( # index_u 32 = > Self ::# ident ( < # ty as wasmtime ::component ::Lift > ::lift (
store , options , unsafe { & src . payload . # ident }
) ? ) , ) ,
) ;
loads . extend (
quote ! ( # index_u8 = > Self ::# ident ( < # ty as wasmtime ::component ::Lift > ::load (
quote ! ( # index_quoted = > Self ::# ident ( < # ty as wasmtime ::component ::Lift > ::load (
memory , & payload [ . . < # ty as wasmtime ::component ::ComponentType > ::SIZE32 ]
) ? ) , ) ,
) ;
} else {
lifts . extend ( quote ! ( # index_i 32 = > Self ::# ident , ) ) ;
lifts . extend ( quote ! ( # index_u 32 = > Self ::# ident , ) ) ;
loads . extend ( quote ! ( # index_u8 = > Self ::# ident , ) ) ;
loads . extend ( quote ! ( # index_quoted = > Self ::# ident , ) ) ;
}
}
@ -362,6 +419,14 @@ impl Expander for LiftExpander {
let generics = add_trait_bounds ( & input . generics , parse_quote ! ( wasmtime ::component ::Lift ) ) ;
let ( impl_generics , ty_generics , where_clause ) = generics . split_for_impl ( ) ;
let from_bytes = match discriminant_size {
DiscriminantSize ::Size1 = > quote ! ( bytes [ 0 ] ) ,
DiscriminantSize ::Size2 = > quote ! ( u16 ::from_le_bytes ( bytes [ 0 . . 2 ] . try_into ( ) ? ) ) ,
DiscriminantSize ::Size4 = > quote ! ( u32 ::from_le_bytes ( bytes [ 0 . . 4 ] . try_into ( ) ? ) ) ,
} ;
let payload_offset = u32 ::from ( discriminant_size ) as usize ;
let expanded = quote ! {
unsafe impl # impl_generics wasmtime ::component ::Lift for # name # ty_generics # where_clause {
#[ inline ]
@ -370,7 +435,7 @@ impl Expander for LiftExpander {
options : & # internal ::Options ,
src : & Self ::Lower ,
) -> # internal ::anyhow ::Result < Self > {
Ok ( match src . tag . get_i 32 ( ) {
Ok ( match src . tag . get_u 32 ( ) {
# lifts
discrim = > # internal ::anyhow ::bail ! ( "unexpected discriminant: {}" , discrim ) ,
} )
@ -380,8 +445,8 @@ impl Expander for LiftExpander {
fn load ( memory : & # internal ::Memory , bytes : & [ u8 ] ) -> # internal ::anyhow ::Result < Self > {
let align = < Self as wasmtime ::component ::ComponentType > ::ALIGN32 ;
debug_assert ! ( ( bytes . as_ptr ( ) as usize ) % ( align as usize ) = = 0 ) ;
let discrim = bytes [ 0 ] ;
let payload = & bytes [ # internal ::align_to ( 1 , align ) . . ] ;
let discrim = # from_bytes ;
let payload = & bytes [ # internal ::align_to ( # payload_offset , align ) . . ] ;
Ok ( match discrim {
# loads
discrim = > # internal ::anyhow ::bail ! ( "unexpected discriminant: {}" , discrim ) ,
@ -456,6 +521,7 @@ impl Expander for LowerExpander {
fn expand_variant (
& self ,
input : & DeriveInput ,
discriminant_size : DiscriminantSize ,
cases : & [ VariantCase ] ,
_style : VariantStyle ,
) -> Result < TokenStream > {
@ -465,14 +531,9 @@ impl Expander for LowerExpander {
let mut stores = TokenStream ::new ( ) ;
for ( index , VariantCase { ident , ty , . . } ) in cases . iter ( ) . enumerate ( ) {
let index_u8 = u8 ::try_from ( index ) . map_err ( | _ | {
Error ::new (
input . ident . span ( ) ,
"`enum`s with more than 256 variants not yet supported" ,
)
} ) ? ;
let index_u32 = u32 ::try_from ( index ) . unwrap ( ) ;
let index_i32 = index_u8 as i32 ;
let index_quoted = discriminant_size . quote ( index ) ;
let pattern ;
let lower ;
@ -492,12 +553,14 @@ impl Expander for LowerExpander {
}
lowers . extend ( quote ! ( # pattern = > {
# internal ::map_maybe_uninit ! ( dst . tag ) . write ( wasmtime ::ValRaw ::i32 ( # index_i32 ) ) ;
# internal ::map_maybe_uninit ! ( dst . tag ) . write ( wasmtime ::ValRaw ::i32 ( # index_u32 as i32 ) ) ;
# lower
} ) ) ;
let discriminant_size = u32 ::from ( discriminant_size ) as usize ;
stores . extend ( quote ! ( # pattern = > {
memory . get ::< 1 > ( offset ) [ 0 ] = # index_u8 ;
* memory . get ::< # discriminant_size > ( offset ) = # index_quoted . to_le_bytes ( ) ;
# store
} ) ) ;
}
@ -668,6 +731,7 @@ impl Expander for ComponentTypeExpander {
fn expand_variant (
& self ,
input : & DeriveInput ,
discriminant_size : DiscriminantSize ,
cases : & [ VariantCase ] ,
style : VariantStyle ,
) -> Result < TokenStream > {
@ -766,6 +830,7 @@ impl Expander for ComponentTypeExpander {
let ( impl_generics , ty_generics , where_clause ) = generics . split_for_impl ( ) ;
let lower = format_ident ! ( "Lower{}" , name ) ;
let lower_payload = format_ident ! ( "LowerPayload{}" , name ) ;
let discriminant_size = u32 ::from ( discriminant_size ) ;
// You may wonder why we make the types of all the fields of the #lower struct and #lower_payload union
// generic. This is to work around a [normalization bug in
@ -806,11 +871,11 @@ impl Expander for ComponentTypeExpander {
const SIZE32 : usize = {
let mut size = 0 ;
# sizes
# internal ::align_to ( 1 , Self ::ALIGN32 ) + size
# internal ::align_to ( # discriminant_size as usize , Self ::ALIGN32 ) + size
} ;
const ALIGN32 : u32 = {
let mut align = 1 ;
let mut align = # discriminant_size ;
# alignments
align
} ;