Browse Source

support enums with more than 256 variants in derive macro (#4370)

* support enums with more than 256 variants in derive macro

This addresses #4361.  Technically, we now support up to 2^32 variants, which is
the maximum for the canonical ABI.  In practice, though, the derived code for
enums with even just 2^16 variants takes a prohibitively long time to compile.

Signed-off-by: Joel Dice <joel.dice@fermyon.com>

* simplify `LowerExpander::expand_variant` code

Signed-off-by: Joel Dice <joel.dice@fermyon.com>
pull/4382/head
Joel Dice 2 years ago
committed by GitHub
parent
commit
5542c4ef26
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 10
      Cargo.lock
  2. 1
      Cargo.toml
  3. 121
      crates/component-macro/src/lib.rs
  4. 15
      crates/misc/component-macro-test/Cargo.toml
  5. 34
      crates/misc/component-macro-test/src/lib.rs
  6. 37
      tests/all/component_model/macros.rs

10
Cargo.lock

@ -450,6 +450,15 @@ dependencies = [
"cc",
]
[[package]]
name = "component-macro-test"
version = "0.1.0"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "console"
version = "0.15.0"
@ -3413,6 +3422,7 @@ dependencies = [
"anyhow",
"async-trait",
"clap 3.1.15",
"component-macro-test",
"criterion",
"env_logger 0.9.0",
"filecheck",

1
Cargo.toml

@ -60,6 +60,7 @@ async-trait = "0.1"
wat = "1.0.43"
once_cell = "1.9.0"
rayon = "1.5.0"
component-macro-test = { path = "crates/misc/component-macro-test" }
[target.'cfg(windows)'.dev-dependencies]
windows-sys = { version = "0.36.0", features = ["Win32_System_Memory"] }

121
crates/component-macro/src/lib.rs

@ -145,6 +145,54 @@ fn add_trait_bounds(generics: &syn::Generics, bound: syn::TypeParamBound) -> syn
generics
}
/// Width, in bytes, of the discriminant stored for a derived `variant`,
/// `enum`, or `union` component type.
#[derive(Clone, Copy, Debug)]
enum DiscriminantSize {
    /// The discriminant occupies one byte (`u8`).
    Size1,
    /// The discriminant occupies two bytes (`u16`).
    Size2,
    /// The discriminant occupies four bytes (`u32`).
    Size4,
}
impl DiscriminantSize {
    /// Emit `discriminant` as tokens for an integer of this size's width
    /// (`u8`, `u16`, or `u32`).
    ///
    /// # Panics
    ///
    /// Panics if `discriminant` does not fit in the integer type selected by
    /// `self`.
    fn quote(self, discriminant: usize) -> TokenStream {
        match self {
            Self::Size1 => {
                let as_u8 = u8::try_from(discriminant).unwrap();
                quote!(#as_u8)
            }
            Self::Size2 => {
                let as_u16 = u16::try_from(discriminant).unwrap();
                quote!(#as_u16)
            }
            Self::Size4 => {
                let as_u32 = u32::try_from(discriminant).unwrap();
                quote!(#as_u32)
            }
        }
    }
}
impl From<DiscriminantSize> for u32 {
fn from(size: DiscriminantSize) -> u32 {
match size {
DiscriminantSize::Size1 => 1,
DiscriminantSize::Size2 => 2,
DiscriminantSize::Size4 => 4,
}
}
}
/// Pick the discriminant width for a type with `case_count` cases.
///
/// Returns `None` when `case_count` exceeds `0xFFFF_FFFF`, i.e. when no
/// supported discriminant width applies.
///
/// NOTE(review): a `case_count` of exactly 256 selects `Size2` here even
/// though case indices `0..=255` would fit in a single byte -- confirm this
/// boundary against the canonical ABI's discriminant-type rule.
fn discriminant_size(case_count: usize) -> Option<DiscriminantSize> {
    // Reject counts beyond the widest supported discriminant first.
    if case_count > 0xFFFF_FFFF {
        return None;
    }

    // Otherwise select the narrowest width whose threshold covers the count.
    let width = if case_count > 0xFFFF {
        DiscriminantSize::Size4
    } else if case_count > 0xFF {
        DiscriminantSize::Size2
    } else {
        DiscriminantSize::Size1
    };

    Some(width)
}
struct VariantCase<'a> {
attrs: &'a [syn::Attribute],
ident: &'a syn::Ident,
@ -157,6 +205,7 @@ trait Expander {
fn expand_variant(
&self,
input: &DeriveInput,
discriminant_size: DiscriminantSize,
cases: &[VariantCase],
style: VariantStyle,
) -> Result<TokenStream>;
@ -217,6 +266,13 @@ fn expand_variant(
));
}
let discriminant_size = discriminant_size(body.variants.len()).ok_or_else(|| {
Error::new(
input.ident.span(),
"`enum`s with more than 2^32 variants are not supported",
)
})?;
let cases = body
.variants
.iter()
@ -240,8 +296,13 @@ fn expand_variant(
name.span(),
format!(
"`{}` component types can only be derived for Rust `enum`s \
containing variants with at most one unnamed field each",
style
containing variants with {}",
style,
match style {
VariantStyle::Variant => "at most one unnamed field each",
VariantStyle::Enum => "no fields",
VariantStyle::Union => "exactly one unnamed field each",
}
),
))
}
@ -251,7 +312,7 @@ fn expand_variant(
)
.collect::<Result<Vec<_>>>()?;
expander.expand_variant(input, &cases, style)
expander.expand_variant(input, discriminant_size, &cases, style)
}
#[proc_macro_derive(Lift, attributes(component))]
@ -321,6 +382,7 @@ impl Expander for LiftExpander {
fn expand_variant(
&self,
input: &DeriveInput,
discriminant_size: DiscriminantSize,
cases: &[VariantCase],
_style: VariantStyle,
) -> Result<TokenStream> {
@ -330,31 +392,26 @@ impl Expander for LiftExpander {
let mut loads = TokenStream::new();
for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() {
let index_u8 = u8::try_from(index).map_err(|_| {
Error::new(
input.ident.span(),
"`enum`s with more than 256 variants not yet supported",
)
})?;
let index_u32 = u32::try_from(index).unwrap();
let index_i32 = index_u8 as i32;
let index_quoted = discriminant_size.quote(index);
if let Some(ty) = ty {
lifts.extend(
quote!(#index_i32 => Self::#ident(<#ty as wasmtime::component::Lift>::lift(
quote!(#index_u32 => Self::#ident(<#ty as wasmtime::component::Lift>::lift(
store, options, unsafe { &src.payload.#ident }
)?),),
);
loads.extend(
quote!(#index_u8 => Self::#ident(<#ty as wasmtime::component::Lift>::load(
quote!(#index_quoted => Self::#ident(<#ty as wasmtime::component::Lift>::load(
memory, &payload[..<#ty as wasmtime::component::ComponentType>::SIZE32]
)?),),
);
} else {
lifts.extend(quote!(#index_i32 => Self::#ident,));
lifts.extend(quote!(#index_u32 => Self::#ident,));
loads.extend(quote!(#index_u8 => Self::#ident,));
loads.extend(quote!(#index_quoted => Self::#ident,));
}
}
@ -362,6 +419,14 @@ impl Expander for LiftExpander {
let generics = add_trait_bounds(&input.generics, parse_quote!(wasmtime::component::Lift));
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let from_bytes = match discriminant_size {
DiscriminantSize::Size1 => quote!(bytes[0]),
DiscriminantSize::Size2 => quote!(u16::from_le_bytes(bytes[0..2].try_into()?)),
DiscriminantSize::Size4 => quote!(u32::from_le_bytes(bytes[0..4].try_into()?)),
};
let payload_offset = u32::from(discriminant_size) as usize;
let expanded = quote! {
unsafe impl #impl_generics wasmtime::component::Lift for #name #ty_generics #where_clause {
#[inline]
@ -370,7 +435,7 @@ impl Expander for LiftExpander {
options: &#internal::Options,
src: &Self::Lower,
) -> #internal::anyhow::Result<Self> {
Ok(match src.tag.get_i32() {
Ok(match src.tag.get_u32() {
#lifts
discrim => #internal::anyhow::bail!("unexpected discriminant: {}", discrim),
})
@ -380,8 +445,8 @@ impl Expander for LiftExpander {
fn load(memory: &#internal::Memory, bytes: &[u8]) -> #internal::anyhow::Result<Self> {
let align = <Self as wasmtime::component::ComponentType>::ALIGN32;
debug_assert!((bytes.as_ptr() as usize) % (align as usize) == 0);
let discrim = bytes[0];
let payload = &bytes[#internal::align_to(1, align)..];
let discrim = #from_bytes;
let payload = &bytes[#internal::align_to(#payload_offset, align)..];
Ok(match discrim {
#loads
discrim => #internal::anyhow::bail!("unexpected discriminant: {}", discrim),
@ -456,6 +521,7 @@ impl Expander for LowerExpander {
fn expand_variant(
&self,
input: &DeriveInput,
discriminant_size: DiscriminantSize,
cases: &[VariantCase],
_style: VariantStyle,
) -> Result<TokenStream> {
@ -465,14 +531,9 @@ impl Expander for LowerExpander {
let mut stores = TokenStream::new();
for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() {
let index_u8 = u8::try_from(index).map_err(|_| {
Error::new(
input.ident.span(),
"`enum`s with more than 256 variants not yet supported",
)
})?;
let index_u32 = u32::try_from(index).unwrap();
let index_i32 = index_u8 as i32;
let index_quoted = discriminant_size.quote(index);
let pattern;
let lower;
@ -492,12 +553,14 @@ impl Expander for LowerExpander {
}
lowers.extend(quote!(#pattern => {
#internal::map_maybe_uninit!(dst.tag).write(wasmtime::ValRaw::i32(#index_i32));
#internal::map_maybe_uninit!(dst.tag).write(wasmtime::ValRaw::i32(#index_u32 as i32));
#lower
}));
let discriminant_size = u32::from(discriminant_size) as usize;
stores.extend(quote!(#pattern => {
memory.get::<1>(offset)[0] = #index_u8;
*memory.get::<#discriminant_size>(offset) = #index_quoted.to_le_bytes();
#store
}));
}
@ -668,6 +731,7 @@ impl Expander for ComponentTypeExpander {
fn expand_variant(
&self,
input: &DeriveInput,
discriminant_size: DiscriminantSize,
cases: &[VariantCase],
style: VariantStyle,
) -> Result<TokenStream> {
@ -766,6 +830,7 @@ impl Expander for ComponentTypeExpander {
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let lower = format_ident!("Lower{}", name);
let lower_payload = format_ident!("LowerPayload{}", name);
let discriminant_size = u32::from(discriminant_size);
// You may wonder why we make the types of all the fields of the #lower struct and #lower_payload union
// generic. This is to work around a [normalization bug in
@ -806,11 +871,11 @@ impl Expander for ComponentTypeExpander {
const SIZE32: usize = {
let mut size = 0;
#sizes
#internal::align_to(1, Self::ALIGN32) + size
#internal::align_to(#discriminant_size as usize, Self::ALIGN32) + size
};
const ALIGN32: u32 = {
let mut align = 1;
let mut align = #discriminant_size;
#alignments
align
};

15
crates/misc/component-macro-test/Cargo.toml

@ -0,0 +1,15 @@
[package]
name = "component-macro-test"
authors = ["The Wasmtime Project Developers"]
license = "Apache-2.0 WITH LLVM-exception"
version = "0.1.0"
edition = "2021"
publish = false
[lib]
proc-macro = true
[dependencies]
proc-macro2 = "1.0"
quote = "1.0"
syn = { version = "1.0", features = ["full"] }

34
crates/misc/component-macro-test/src/lib.rs

@ -0,0 +1,34 @@
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::parse_macro_input;
/// Attribute macro that replaces the variants of the annotated `enum` with
/// generated unit variants.
///
/// The attribute argument is parsed as an integer literal giving the number
/// of variants to generate; the annotated item is parsed as an `enum`. Any
/// error reported by `expand_variants` is turned into a compile error.
#[proc_macro_attribute]
pub fn add_variants(
    attr: proc_macro::TokenStream,
    item: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    // Parse the attribute argument first, then the item; either macro
    // invocation returns early with a compile error on a parse failure.
    let requested = parse_macro_input!(attr as syn::LitInt);
    let target = parse_macro_input!(item as syn::ItemEnum);

    match expand_variants(&requested, target) {
        Ok(tokens) => tokens.into(),
        Err(error) => error.into_compile_error().into(),
    }
}
fn expand_variants(count: &syn::LitInt, mut ty: syn::ItemEnum) -> syn::Result<TokenStream> {
let count = count
.base10_digits()
.parse::<usize>()
.map_err(|_| syn::Error::new(count.span(), "expected unsigned integer"))?;
ty.variants = (0..count)
.map(|index| syn::Variant {
attrs: Vec::new(),
ident: syn::Ident::new(&format!("V{}", index), Span::call_site()),
fields: syn::Fields::Unit,
discriminant: None,
})
.collect();
Ok(quote!(#ty))
}

37
tests/all/component_model/macros.rs

@ -1,5 +1,6 @@
use super::TypedFuncExt;
use anyhow::Result;
use component_macro_test::add_variants;
use std::fmt::Write;
use wasmtime::component::{Component, ComponentType, Lift, Linker, Lower};
use wasmtime::Store;
@ -475,5 +476,41 @@ fn enum_derive() -> Result<()> {
.get_typed_func::<(Foo,), Foo, _>(&mut store, "echo")
.is_err());
#[add_variants(257)]
#[derive(ComponentType, Lift, Lower, PartialEq, Eq, Debug, Copy, Clone)]
#[component(enum)]
enum Many {}
let component = Component::new(
&engine,
make_echo_component(
&format!(
r#"(type $Foo (enum {}))"#,
(0..257)
.map(|index| format!(r#""V{}""#, index))
.collect::<Vec<_>>()
.join(" ")
),
4,
),
)?;
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
let func = instance.get_typed_func::<(Many,), Many, _>(&mut store, "echo")?;
for &input in &[Many::V0, Many::V1, Many::V254, Many::V255, Many::V256] {
let output = func.call_and_post_return(&mut store, (input,))?;
assert_eq!(input, output);
}
// TODO: The following case takes forever (i.e. I gave up after 30 minutes) to compile; we'll need to profile
// the compiler to find out why, which may point the way to a more efficient option. On the other hand, this
// may not be worth spending time on. Enums with over 2^16 variants are rare enough.
// #[add_variants(65537)]
// #[derive(ComponentType, Lift, Lower, PartialEq, Eq, Debug, Copy, Clone)]
// #[component(enum)]
// enum ManyMore {}
Ok(())
}

Loading…
Cancel
Save