
wasi-nn: remove Git submodule (#8519)

* wasi-nn: remove Git submodule

To more closely align with the conventions in the `wasmtime-wasi` and
`wasmtime-wasi-http` crates, this change removes the Git submodule that
previously provided the WIT and WITX files for `wasmtime-wasi-nn`. As in
those other crates, the wasi-nn WIT and WITX files will be synced
manually for the time being. This is the first PR toward upgrading the
wasi-nn implementation to match recent spec changes and to improve
preview2 ABI compatibility.

prtest:full

* ci: auto-vendor the wasi-nn WIT files
Branch: pull/8532/head
Author: Andrew Brown (committed by GitHub)
Commit: 71d576e325
Changed files:
1. .gitmodules (3 lines changed)
2. ci/vendor-wit.sh (15 lines changed)
3. crates/wasi-nn/build.rs (8 lines changed)
4. crates/wasi-nn/spec (1 line changed)
5. crates/wasi-nn/src/wit.rs (2 lines changed)
6. crates/wasi-nn/wit/wasi-nn.wit (150 lines changed)
7. crates/wasi-nn/witx/wasi-nn.witx (92 lines changed)

.gitmodules (3 lines changed)

@@ -1,9 +1,6 @@
 [submodule "tests/spec_testsuite"]
 	path = tests/spec_testsuite
 	url = https://github.com/WebAssembly/testsuite
-[submodule "crates/wasi-nn/spec"]
-	path = crates/wasi-nn/spec
-	url = https://github.com/WebAssembly/wasi-nn
 [submodule "tests/wasi_testsuite/wasi-threads"]
 	path = tests/wasi_testsuite/wasi-threads
 	url = https://github.com/WebAssembly/wasi-threads

ci/vendor-wit.sh (15 lines changed)

@@ -4,6 +4,7 @@
 # particular tag in upstream repositories.
 #
 # This script is executed on CI to ensure that everything is up-to-date.
+set -ex
 # Space-separated list of wasi proposals that are vendored here along with the
 # tag that they're all vendored at.
@@ -15,13 +16,10 @@
 repos="cli clocks filesystem http io random sockets"
 tag=0.2.0
-set -ex
 # First, replace the existing vendored WIT files in the `wasi` crate.
 dst=crates/wasi/wit/deps
 rm -rf $dst
 mkdir -p $dst
 for repo in $repos; do
   mkdir $dst/$repo
   curl -L https://github.com/WebAssembly/wasi-$repo/archive/refs/tags/v$tag.tar.gz | \
@@ -29,5 +27,14 @@ for repo in $repos; do
   rm -rf $dst/$repo/deps*
 done
 # Also replace the `wasi-http` WIT files since they match those in the `wasi`
 # crate.
 rm -rf crates/wasi-http/wit/deps
 cp -r $dst crates/wasi-http/wit
+# Separately (for now), vendor the `wasi-nn` WIT files since their retrieval is
+# slightly different than above.
+repo=https://raw.githubusercontent.com/WebAssembly/wasi-nn
+revision=e2310b
+curl -L $repo/$revision/wit/wasi-nn.wit -o crates/wasi-nn/wit/wasi-nn.wit
+curl -L $repo/$revision/wasi-nn.witx -o crates/wasi-nn/witx/wasi-nn.witx

crates/wasi-nn/build.rs (8 lines changed)

@@ -1,12 +1,12 @@
 //! This build script:
-//! - has the configuration necessary for the wiggle and witx macros.
+//! - has the configuration necessary for the Wiggle and WITX macros.
 fn main() {
-    // This is necessary for Wiggle/Witx macros.
+    // This is necessary for Wiggle/WITX macros.
     let cwd = std::env::current_dir().unwrap();
-    let wasi_root = cwd.join("spec");
+    let wasi_root = cwd.join("witx");
     println!("cargo:rustc-env=WASI_ROOT={}", wasi_root.display());
-    // Also automatically rebuild if the Witx files change
+    // Also automatically rebuild if the WITX files change
     for entry in walkdir::WalkDir::new(wasi_root) {
         println!("cargo:rerun-if-changed={}", entry.unwrap().path().display());
     }
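
For context, the `WASI_ROOT` variable exported here is consumed by the crate's WITX bindings, which are not part of this diff. A minimal sketch of that pattern, assuming Wiggle's `from_witx!` macro (the exact options `wasmtime-wasi-nn` passes may differ):

    // Sketch only: Wiggle expands the `$WASI_ROOT` environment variable (set by
    // the build script above) when resolving the WITX document at compile time,
    // so pointing `wasi_root` at `crates/wasi-nn/witx` keeps this working after
    // the submodule is gone.
    wiggle::from_witx!({
        witx: ["$WASI_ROOT/wasi-nn.witx"],
    });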

crates/wasi-nn/spec (submodule removed)

@@ -1 +0,0 @@
-Subproject commit e2310b860db2ff1719c9d69816099b87e85fabdb

crates/wasi-nn/src/wit.rs (2 lines changed)

@@ -22,7 +22,7 @@ use std::{error::Error, fmt, hash::Hash, str::FromStr};
 mod gen_ {
     wasmtime::component::bindgen!({
         world: "ml",
-        path: "spec/wit/wasi-nn.wit",
+        path: "wit/wasi-nn.wit",
         trappable_imports: true,
     });
 }
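
Note that `bindgen!` resolves `path` relative to the crate root (`CARGO_MANIFEST_DIR`), which is why the vendored copy at `crates/wasi-nn/wit/wasi-nn.wit` can replace the submodule path with no other changes. A self-contained sketch of the same pattern, using a made-up inline WIT world instead of the real file (package, interface, and world names here are illustrative):

    use wasmtime::component::bindgen;

    // Illustrative only: `inline` WIT in place of a vendored `path`, shaped
    // like a much smaller `ml` world.
    bindgen!({
        inline: r#"
            package example:demo;

            interface errors {
                enum error { runtime-error }
            }

            world ml-like {
                import errors;
            }
        "#,
        world: "ml-like",
    });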

crates/wasi-nn/wit/wasi-nn.wit (new file, 150 lines)

@@ -0,0 +1,150 @@
package wasi:nn;

/// `wasi-nn` is a WASI API for performing machine learning (ML) inference. The API is not (yet)
/// capable of performing ML training. WebAssembly programs that want to use a host's ML
/// capabilities can access these capabilities through `wasi-nn`'s core abstractions: _graphs_ and
/// _tensors_. A user `load`s an ML model -- instantiated as a _graph_ -- to use in an ML _backend_.
/// Then, the user passes _tensor_ inputs to the _graph_, computes the inference, and retrieves the
/// _tensor_ outputs.
///
/// This example world shows how to use these primitives together.
world ml {
    import tensor;
    import graph;
    import inference;
    import errors;
}

/// All inputs and outputs to an ML inference are represented as `tensor`s.
interface tensor {
    /// The dimensions of a tensor.
    ///
    /// The array length matches the tensor rank and each element in the array describes the size of
    /// each dimension.
    type tensor-dimensions = list<u32>;

    /// The type of the elements in a tensor.
    enum tensor-type {
        FP16,
        FP32,
        FP64,
        BF16,
        U8,
        I32,
        I64
    }

    /// The tensor data.
    ///
    /// Initially conceived as a sparse representation, each empty cell would be filled with zeros
    /// and the array length must match the product of all of the dimensions and the number of bytes
    /// in the type (e.g., a 2x2 tensor with 4-byte f32 elements would have a data array of length
    /// 16). Naturally, this representation requires some knowledge of how to lay out data in
    /// memory--e.g., using row-major ordering--and could perhaps be improved.
    type tensor-data = list<u8>;

    record tensor {
        // Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To represent a tensor
        // containing a single value, use `[1]` for the tensor dimensions.
        dimensions: tensor-dimensions,

        // Describe the type of element in the tensor (e.g., `f32`).
        tensor-type: tensor-type,

        // Contains the tensor data.
        data: tensor-data,
    }
}

/// A `graph` is a loaded instance of a specific ML model (e.g., MobileNet) for a specific ML
/// framework (e.g., TensorFlow):
interface graph {
    use errors.{error};
    use tensor.{tensor};

    /// An execution graph for performing inference (i.e., a model).
    ///
    /// TODO: replace with `resource` (https://github.com/WebAssembly/wasi-nn/issues/47).
    type graph = u32;

    /// Describes the encoding of the graph. This allows the API to be implemented by various
    /// backends that encode (i.e., serialize) their graph IR with different formats.
    enum graph-encoding {
        openvino,
        onnx,
        tensorflow,
        pytorch,
        tensorflowlite,
        autodetect,
    }

    /// Define where the graph should be executed.
    enum execution-target {
        cpu,
        gpu,
        tpu
    }

    /// The graph initialization data.
    ///
    /// This gets bundled up into an array of buffers because implementing backends may encode their
    /// graph IR in parts (e.g., OpenVINO stores its IR and weights separately).
    type graph-builder = list<u8>;

    /// Load a `graph` from an opaque sequence of bytes to use for inference.
    load: func(builder: list<graph-builder>, encoding: graph-encoding, target: execution-target) -> result<graph, error>;

    /// Load a `graph` by name.
    ///
    /// How the host expects the names to be passed and how it stores the graphs for retrieval via
    /// this function is **implementation-specific**. This allows hosts to choose name schemes that
    /// range from simple to complex (e.g., URLs?) and caching mechanisms of various kinds.
    load-by-name: func(name: string) -> result<graph, error>;
}

/// An inference "session" is encapsulated by a `graph-execution-context`. This structure binds a
/// `graph` to input tensors before `compute`-ing an inference:
interface inference {
    use errors.{error};
    use tensor.{tensor, tensor-data};
    use graph.{graph};

    /// Bind a `graph` to the input and output tensors for an inference.
    ///
    /// TODO: this is no longer necessary in WIT (https://github.com/WebAssembly/wasi-nn/issues/43)
    type graph-execution-context = u32;

    /// Create an execution instance of a loaded graph.
    init-execution-context: func(graph: graph) -> result<graph-execution-context, error>;

    /// Define the inputs to use for inference.
    set-input: func(ctx: graph-execution-context, index: u32, tensor: tensor) -> result<_, error>;

    /// Compute the inference on the given inputs.
    ///
    /// Note the expected sequence of calls: `set-input`, `compute`, `get-output`. TODO: this
    /// expectation could be removed as a part of https://github.com/WebAssembly/wasi-nn/issues/43.
    compute: func(ctx: graph-execution-context) -> result<_, error>;

    /// Extract the outputs after inference.
    get-output: func(ctx: graph-execution-context, index: u32) -> result<tensor-data, error>;
}

/// TODO: create function-specific errors (https://github.com/WebAssembly/wasi-nn/issues/42)
interface errors {
    enum error {
        // Caller module passed an invalid argument.
        invalid-argument,
        // Invalid encoding.
        invalid-encoding,
        busy,
        // Runtime Error.
        runtime-error,
        // Unsupported operation.
        unsupported-operation,
        // Graph is too large.
        too-large,
        // Graph not found.
        not-found
    }
}
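
Read together, these interfaces imply the following guest-side flow. This is a hedged sketch against hypothetical generated bindings: the `nn` module path and the exact Rust signatures depend on the bindings generator, and only the call sequence is taken from the WIT above.

    // Hypothetical `nn` module generated from the `ml` world above.
    fn classify(model: Vec<Vec<u8>>, input: nn::tensor::Tensor) -> Result<Vec<u8>, nn::errors::Error> {
        use nn::graph::{load, ExecutionTarget, GraphEncoding};
        use nn::inference::{compute, get_output, init_execution_context, set_input};

        // `load` takes the graph IR parts, their encoding, and an execution target.
        let graph = load(&model, GraphEncoding::Openvino, ExecutionTarget::Cpu)?;
        // An inference "session": bind the loaded graph to an execution context.
        let ctx = init_execution_context(graph)?;
        // Expected sequence per the docs above: set-input, compute, get-output.
        set_input(ctx, 0, &input)?;
        compute(ctx)?;
        get_output(ctx, 0)
    }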

crates/wasi-nn/witx/wasi-nn.witx (new file, 92 lines)

@@ -0,0 +1,92 @@
;; This WITX version of the wasi-nn API is retained for consistency only. See the `wit/wasi-nn.wit`
;; version for the official specification and documentation.

(typename $buffer_size u32)

(typename $nn_errno
  (enum (@witx tag u16)
    $success
    $invalid_argument
    $invalid_encoding
    $missing_memory
    $busy
    $runtime_error
    $unsupported_operation
    $too_large
    $not_found
  )
)

(typename $tensor_dimensions (list u32))

(typename $tensor_type
  (enum (@witx tag u8)
    $f16
    $f32
    $f64
    $u8
    $i32
    $i64
  )
)

(typename $tensor_data (list u8))

(typename $tensor
  (record
    (field $dimensions $tensor_dimensions)
    (field $type $tensor_type)
    (field $data $tensor_data)
  )
)

(typename $graph_builder (list u8))
(typename $graph_builder_array (list $graph_builder))
(typename $graph (handle))

(typename $graph_encoding
  (enum (@witx tag u8)
    $openvino
    $onnx
    $tensorflow
    $pytorch
    $tensorflowlite
    $autodetect
  )
)

(typename $execution_target
  (enum (@witx tag u8)
    $cpu
    $gpu
    $tpu
  )
)

(typename $graph_execution_context (handle))

(module $wasi_ephemeral_nn
  (import "memory" (memory))

  (@interface func (export "load")
    (param $builder $graph_builder_array)
    (param $encoding $graph_encoding)
    (param $target $execution_target)
    (result $error (expected $graph (error $nn_errno)))
  )

  (@interface func (export "load_by_name")
    (param $name string)
    (result $error (expected $graph (error $nn_errno)))
  )

  (@interface func (export "init_execution_context")
    (param $graph $graph)
    (result $error (expected $graph_execution_context (error $nn_errno)))
  )

  (@interface func (export "set_input")
    (param $context $graph_execution_context)
    (param $index u32)
    (param $tensor $tensor)
    (result $error (expected (error $nn_errno)))
  )

  (@interface func (export "get_output")
    (param $context $graph_execution_context)
    (param $index u32)
    (param $out_buffer (@witx pointer u8))
    (param $out_buffer_max_size $buffer_size)
    (result $error (expected $buffer_size (error $nn_errno)))
  )

  (@interface func (export "compute")
    (param $context $graph_execution_context)
    (result $error (expected (error $nn_errno)))
  )
)
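
Unlike the WIT version, the WITX ABI returns outputs through caller-provided memory: `get_output` takes a pointer plus a maximum size and reports how many bytes were written. A hypothetical sketch of the raw core-wasm imports this module lowers to for a wasm32 guest (names and exact integer widths are illustrative, not generated output):

    // Hypothetical raw imports: handles and enum tags are passed as integers,
    // and each function returns an $nn_errno (0 = $success).
    #[link(wasm_import_module = "wasi_ephemeral_nn")]
    extern "C" {
        // (@interface func (export "compute")): one context handle in, errno out.
        fn compute(ctx: u32) -> u32;
        // (@interface func (export "get_output")): fills `out_buffer` with up to
        // `out_buffer_max_size` bytes; the bytes-written count is stored through
        // the trailing result pointer.
        fn get_output(
            ctx: u32,
            index: u32,
            out_buffer: *mut u8,
            out_buffer_max_size: u32,
            bytes_written: *mut u32,
        ) -> u32;
    }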