diff --git a/.gitmodules b/.gitmodules index dc9dfa50f5..e1fa8b72c6 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,9 +1,6 @@ [submodule "tests/spec_testsuite"] path = tests/spec_testsuite url = https://github.com/WebAssembly/testsuite -[submodule "crates/wasi-nn/spec"] - path = crates/wasi-nn/spec - url = https://github.com/WebAssembly/wasi-nn [submodule "tests/wasi_testsuite/wasi-threads"] path = tests/wasi_testsuite/wasi-threads url = https://github.com/WebAssembly/wasi-threads diff --git a/ci/vendor-wit.sh b/ci/vendor-wit.sh index d594de918d..30c1cbb928 100755 --- a/ci/vendor-wit.sh +++ b/ci/vendor-wit.sh @@ -4,6 +4,7 @@ # particular tag in upstream repositories. # # This script is executed on CI to ensure that everything is up-to-date. +set -ex # Space-separated list of wasi proposals that are vendored here along with the # tag that they're all vendored at. @@ -15,13 +16,10 @@ repos="cli clocks filesystem http io random sockets" tag=0.2.0 -set -ex - +# First, replace the existing vendored WIT files in the `wasi` crate. dst=crates/wasi/wit/deps - rm -rf $dst mkdir -p $dst - for repo in $repos; do mkdir $dst/$repo curl -L https://github.com/WebAssembly/wasi-$repo/archive/refs/tags/v$tag.tar.gz | \ @@ -29,5 +27,14 @@ for repo in $repos; do rm -rf $dst/$repo/deps* done +# Also replace the `wasi-http` WIT files since they match those in the `wasi` +# crate. rm -rf crates/wasi-http/wit/deps cp -r $dst crates/wasi-http/wit + +# Separately (for now), vendor the `wasi-nn` WIT files since their retrieval is +# slightly different than above. +repo=https://raw.githubusercontent.com/WebAssembly/wasi-nn +revision=e2310b +curl -L $repo/$revision/wit/wasi-nn.wit -o crates/wasi-nn/wit/wasi-nn.wit +curl -L $repo/$revision/wasi-nn.witx -o crates/wasi-nn/witx/wasi-nn.witx diff --git a/crates/wasi-nn/build.rs b/crates/wasi-nn/build.rs index 535d0e0f80..f3c018366c 100644 --- a/crates/wasi-nn/build.rs +++ b/crates/wasi-nn/build.rs @@ -1,12 +1,12 @@ //! This build script: -//! 
- has the configuration necessary for the wiggle and witx macros. +//! - has the configuration necessary for the Wiggle and WITX macros. fn main() { - // This is necessary for Wiggle/Witx macros. + // This is necessary for Wiggle/WITX macros. let cwd = std::env::current_dir().unwrap(); - let wasi_root = cwd.join("spec"); + let wasi_root = cwd.join("witx"); println!("cargo:rustc-env=WASI_ROOT={}", wasi_root.display()); - // Also automatically rebuild if the Witx files change + // Also automatically rebuild if the WITX files change for entry in walkdir::WalkDir::new(wasi_root) { println!("cargo:rerun-if-changed={}", entry.unwrap().path().display()); } diff --git a/crates/wasi-nn/spec b/crates/wasi-nn/spec deleted file mode 160000 index e2310b860d..0000000000 --- a/crates/wasi-nn/spec +++ /dev/null @@ -1 +0,0 @@ -Subproject commit e2310b860db2ff1719c9d69816099b87e85fabdb diff --git a/crates/wasi-nn/src/wit.rs b/crates/wasi-nn/src/wit.rs index dd9c9cc085..dbe894357c 100644 --- a/crates/wasi-nn/src/wit.rs +++ b/crates/wasi-nn/src/wit.rs @@ -22,7 +22,7 @@ use std::{error::Error, fmt, hash::Hash, str::FromStr}; mod gen_ { wasmtime::component::bindgen!({ world: "ml", - path: "spec/wit/wasi-nn.wit", + path: "wit/wasi-nn.wit", trappable_imports: true, }); } diff --git a/crates/wasi-nn/wit/wasi-nn.wit b/crates/wasi-nn/wit/wasi-nn.wit new file mode 100644 index 0000000000..19e3de875d --- /dev/null +++ b/crates/wasi-nn/wit/wasi-nn.wit @@ -0,0 +1,150 @@ +package wasi:nn; + +/// `wasi-nn` is a WASI API for performing machine learning (ML) inference. The API is not (yet) +/// capable of performing ML training. WebAssembly programs that want to use a host's ML +/// capabilities can access these capabilities through `wasi-nn`'s core abstractions: _graphs_ and +/// _tensors_. A user `load`s an ML model -- instantiated as a _graph_ -- to use in an ML _backend_. +/// Then, the user passes _tensor_ inputs to the _graph_, computes the inference, and retrieves the +/// _tensor_ outputs. 
+/// +/// This example world shows how to use these primitives together. +world ml { + import tensor; + import graph; + import inference; + import errors; +} + +/// All inputs and outputs to an ML inference are represented as `tensor`s. +interface tensor { + /// The dimensions of a tensor. + /// + /// The array length matches the tensor rank and each element in the array describes the size of + /// each dimension + type tensor-dimensions = list<u32>; + + /// The type of the elements in a tensor. + enum tensor-type { + FP16, + FP32, + FP64, + BF16, + U8, + I32, + I64 + } + + /// The tensor data. + /// + /// Initially conceived as a sparse representation, each empty cell would be filled with zeros + /// and the array length must match the product of all of the dimensions and the number of bytes + /// in the type (e.g., a 2x2 tensor with 4-byte f32 elements would have a data array of length + /// 16). Naturally, this representation requires some knowledge of how to lay out data in + /// memory--e.g., using row-major ordering--and could perhaps be improved. + type tensor-data = list<u8>; + + record tensor { + // Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To represent a tensor + // containing a single value, use `[1]` for the tensor dimensions. + dimensions: tensor-dimensions, + + // Describe the type of element in the tensor (e.g., `f32`). + tensor-type: tensor-type, + + // Contains the tensor data. + data: tensor-data, + } +} + +/// A `graph` is a loaded instance of a specific ML model (e.g., MobileNet) for a specific ML +/// framework (e.g., TensorFlow): +interface graph { + use errors.{error}; + use tensor.{tensor}; + + /// An execution graph for performing inference (i.e., a model). + /// + /// TODO: replace with `resource` (https://github.com/WebAssembly/wasi-nn/issues/47). + type graph = u32; + + /// Describes the encoding of the graph.
This allows the API to be implemented by various +/// backends that encode (i.e., serialize) their graph IR with different formats. + enum graph-encoding { + openvino, + onnx, + tensorflow, + pytorch, + tensorflowlite, + autodetect, + } + + /// Define where the graph should be executed. + enum execution-target { + cpu, + gpu, + tpu + } + + /// The graph initialization data. + /// + /// This gets bundled up into an array of buffers because implementing backends may encode their + /// graph IR in parts (e.g., OpenVINO stores its IR and weights separately). + type graph-builder = list<u8>; + + /// Load a `graph` from an opaque sequence of bytes to use for inference. + load: func(builder: list<graph-builder>, encoding: graph-encoding, target: execution-target) -> result<graph, error>; + + /// Load a `graph` by name. + /// + /// How the host expects the names to be passed and how it stores the graphs for retrieval via + /// this function is **implementation-specific**. This allows hosts to choose name schemes that + /// range from simple to complex (e.g., URLs?) and caching mechanisms of various kinds. + load-by-name: func(name: string) -> result<graph, error>; +} + +/// An inference "session" is encapsulated by a `graph-execution-context`. This structure binds a +/// `graph` to input tensors before `compute`-ing an inference: +interface inference { + use errors.{error}; + use tensor.{tensor, tensor-data}; + use graph.{graph}; + + /// Bind a `graph` to the input and output tensors for an inference. + /// + /// TODO: this is no longer necessary in WIT (https://github.com/WebAssembly/wasi-nn/issues/43) + type graph-execution-context = u32; + + /// Create an execution instance of a loaded graph. + init-execution-context: func(graph: graph) -> result<graph-execution-context, error>; + + /// Define the inputs to use for inference. + set-input: func(ctx: graph-execution-context, index: u32, tensor: tensor) -> result<_, error>; + + /// Compute the inference on the given inputs.
+ /// + /// Note the expected sequence of calls: `set-input`, `compute`, `get-output`. TODO: this + /// expectation could be removed as a part of https://github.com/WebAssembly/wasi-nn/issues/43. + compute: func(ctx: graph-execution-context) -> result<_, error>; + + /// Extract the outputs after inference. + get-output: func(ctx: graph-execution-context, index: u32) -> result<tensor-data, error>; +} + +/// TODO: create function-specific errors (https://github.com/WebAssembly/wasi-nn/issues/42) +interface errors { + enum error { + // Caller module passed an invalid argument. + invalid-argument, + // Invalid encoding. + invalid-encoding, + busy, + // Runtime Error. + runtime-error, + // Unsupported operation. + unsupported-operation, + // Graph is too large. + too-large, + // Graph not found. + not-found + } +} diff --git a/crates/wasi-nn/witx/wasi-nn.witx b/crates/wasi-nn/witx/wasi-nn.witx new file mode 100644 index 0000000000..e413c2e03d --- /dev/null +++ b/crates/wasi-nn/witx/wasi-nn.witx @@ -0,0 +1,92 @@ +;; This WITX version of the wasi-nn API is retained for consistency only. See the `wit/wasi-nn.wit` +;; version for the official specification and documentation.
+ +(typename $buffer_size u32) +(typename $nn_errno + (enum (@witx tag u16) + $success + $invalid_argument + $invalid_encoding + $missing_memory + $busy + $runtime_error + $unsupported_operation + $too_large + $not_found + ) +) +(typename $tensor_dimensions (list u32)) +(typename $tensor_type + (enum (@witx tag u8) + $f16 + $f32 + $f64 + $u8 + $i32 + $i64 + ) +) +(typename $tensor_data (list u8)) +(typename $tensor + (record + (field $dimensions $tensor_dimensions) + (field $type $tensor_type) + (field $data $tensor_data) + ) +) +(typename $graph_builder (list u8)) +(typename $graph_builder_array (list $graph_builder)) +(typename $graph (handle)) +(typename $graph_encoding + (enum (@witx tag u8) + $openvino + $onnx + $tensorflow + $pytorch + $tensorflowlite + $autodetect + ) +) +(typename $execution_target + (enum (@witx tag u8) + $cpu + $gpu + $tpu + ) +) +(typename $graph_execution_context (handle)) + +(module $wasi_ephemeral_nn + (import "memory" (memory)) + (@interface func (export "load") + (param $builder $graph_builder_array) + (param $encoding $graph_encoding) + (param $target $execution_target) + (result $error (expected $graph (error $nn_errno))) + ) + (@interface func (export "load_by_name") + (param $name string) + (result $error (expected $graph (error $nn_errno))) + ) + (@interface func (export "init_execution_context") + (param $graph $graph) + (result $error (expected $graph_execution_context (error $nn_errno))) + ) + (@interface func (export "set_input") + (param $context $graph_execution_context) + (param $index u32) + (param $tensor $tensor) + (result $error (expected (error $nn_errno))) + ) + (@interface func (export "get_output") + (param $context $graph_execution_context) + (param $index u32) + (param $out_buffer (@witx pointer u8)) + (param $out_buffer_max_size $buffer_size) + (result $error (expected $buffer_size (error $nn_errno))) + ) + (@interface func (export "compute") + (param $context $graph_execution_context) + (result $error 
(expected (error $nn_errno))) + ) +)