
wasi-nn: use resources (#8873)

* wasi-nn: use resources

Recent discussion in the wasi-nn proposal (see [wasi-nn#59], e.g.) has
concluded that the right approach for representing wasi-nn "things"
(tensors, graphs, etc.) is with a component model _resource_. This
sweeping change brings Wasmtime's implementation in line with that
decision.
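
Concretely, "using resources" means the guest now holds opaque handles backed by
Wasmtime's `ResourceTable` instead of raw integer IDs tracked in a crate-private
table. A minimal sketch of that handle pattern, with a stand-in `Graph` type
rather than the crate's real one:

```rust
use wasmtime::component::{Resource, ResourceTable};

// Stand-in for the host-side graph state; the real `Graph` lives in the crate.
struct Graph;

// The WIT-based host functions push host objects into a `ResourceTable` and
// hand the guest an opaque `Resource<Graph>` handle instead of a raw `u32`.
fn load(table: &mut ResourceTable) -> wasmtime::Result<Resource<Graph>> {
    let handle = table.push(Graph)?;
    Ok(handle)
}

// Later calls resolve the handle back to the host object...
fn use_graph(table: &ResourceTable, graph: &Resource<Graph>) -> wasmtime::Result<()> {
    let _graph: &Graph = table.get(graph)?;
    Ok(())
}

// ...and the generated `drop` glue removes it once the guest drops the handle.
fn drop_graph(table: &mut ResourceTable, graph: Resource<Graph>) -> wasmtime::Result<()> {
    table.delete(graph)?;
    Ok(())
}
```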

Initially I had structured this PR to remove all of the WITX-based
implementation (#8530). But, after consulting a Zulip [thread] on
what other WASI proposals aim to do, this PR pivoted to support _both_
the WITX-based and WIT-based ABIs (i.e., the preview1 era versus the
preview2, component model era). What is clear is that the WITX-based specification
will remain "frozen in time" while the WIT-based implementation moves
forward.

What that means for this PR is a "split world" paradigm. In many places,
we have to distinguish between the `wit` and `witx` versions of the same
thing. This change isn't the end state yet: it's a big step forward
towards bringing Wasmtime back in line with the WIT spec but, despite my
best efforts, doesn't fully fix all the TODOs left behind over several
years of development. I have, however, taken the liberty of refactoring and
fixing various parts as I came across them (e.g., the ONNX backend). I plan
to continue working on this in future PRs to figure out a good error
paradigm (the current one is too wordy) and device residence.
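
In practice the split shows up directly in the crate's API surface. The sketch
below is not code from this PR; it just constructs both contexts side by side,
using the module paths as they appear in the diff:

```rust
fn make_contexts() -> anyhow::Result<()> {
    // The WITX (preview1-era) context keeps its old shape under the `witx` module...
    let (backends, registry) = wasmtime_wasi_nn::preload(&[])?;
    let _witx = wasmtime_wasi_nn::witx::WasiNnCtx::new(backends, registry);

    // ...while the WIT (component-model-era) context lives under the `wit` module.
    let (backends, registry) = wasmtime_wasi_nn::preload(&[])?;
    let _wit = wasmtime_wasi_nn::wit::WasiNnCtx::new(backends, registry);
    Ok(())
}
```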

[wasi-nn#59]: https://github.com/WebAssembly/wasi-nn/pull/59
[thread]: https://bytecodealliance.zulipchat.com/#narrow/stream/219900-wasi/topic/wasi-nn's.20preview1.20vs.20preview2.20timeline

prtest:full

* vet: audit `ort`-related crate updates

* Simplify `WasiNnView`

With @alexcrichton's help, this change removes the `trait WasiNnView`
and `struct WasiNnImpl` wrapping that the WIT-based implementation used
for accessing the host context. Instead, `WasiNnView` is now a `struct`
containing the mutable references it needs to do its work. This removes
one layer of abstraction, though it has the downside of complicating the
CLI code, which must now split its borrows of `Host`.
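
For embedders, the borrow splitting looks roughly like the sketch below (the
`Host` type and field names are illustrative, not the CLI's): the state owns
both a `ResourceTable` and a `WasiNnCtx`, and the accessor passed to
`add_to_linker` lends both out as a single `WasiNnView`.

```rust
use wasmtime::component::{Linker, ResourceTable};
use wasmtime_wasi_nn::wit::{WasiNnCtx, WasiNnView};

// Hypothetical embedder state; the real CLI `Host` has more fields.
struct Host {
    table: ResourceTable,
    nn: WasiNnCtx,
}

// A plain function keeps the borrow splitting obvious: `table` and `nn` are
// disjoint fields, so both can be borrowed mutably at the same time.
fn nn_view(host: &mut Host) -> WasiNnView<'_> {
    WasiNnView::new(&mut host.table, &mut host.nn)
}

fn link_wasi_nn(linker: &mut Linker<Host>) -> anyhow::Result<()> {
    wasmtime_wasi_nn::wit::add_to_linker(linker, nn_view)
}
```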

* Temporarily disable WIT check

* Refactor errors to use `trappable_error_type`

This change simplifies the return types of the host implementations of
the WIT-based wasi-nn. There is more work to be done with errors, e.g.,
to catch up with the upstream decision to return errors as resources.
But this is better than the previous mess.
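
As a rough illustration of the intended split (a sketch against the `Error`
type added in `wit.rs`, not code from this PR): failures the guest can handle
map to ordinary wasi-nn error variants, while internal failures become
`Error::Trap` and abort the guest.

```rust
use wasmtime_wasi_nn::wit::Error;

// Hedged sketch: classify a host-side failure either as a guest-visible
// wasi-nn error or as a trap that aborts the guest outright.
fn classify_failure(unsupported: bool, internal: Option<anyhow::Error>) -> Result<(), Error> {
    if let Some(bug) = internal {
        // E.g. a `ResourceTableError`: the guest cannot meaningfully recover.
        return Err(Error::Trap(bug));
    }
    if unsupported {
        // Surfaced to the guest as a normal wasi-nn error.
        return Err(Error::UnsupportedOperation);
    }
    Ok(())
}
```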
Andrew Brown 4 months ago
committed by GitHub
commit 0f4ae88a7a
  1. Cargo.lock (50)
  2. ci/vendor-wit.sh (6)
  3. crates/bench-api/src/lib.rs (4)
  4. crates/test-programs/artifacts/build.rs (5)
  5. crates/test-programs/src/bin/nn_image_classification_winml.rs (16)
  6. crates/test-programs/src/bin/nn_wit_image_classification_onnx.rs (14)
  7. crates/test-programs/src/bin/nn_wit_image_classification_openvino.rs (25)
  8. crates/test-programs/src/bin/nn_wit_image_classification_openvino_named.rs (17)
  9. crates/test-programs/src/bin/nn_wit_image_classification_winml_named.rs (9)
  10. crates/test-programs/src/bin/nn_witx_image_classification_onnx.rs (22)
  11. crates/test-programs/src/bin/nn_witx_image_classification_openvino.rs (12)
  12. crates/test-programs/src/bin/nn_witx_image_classification_openvino_named.rs (17)
  13. crates/test-programs/src/bin/nn_witx_image_classification_winml_named.rs (18)
  14. crates/test-programs/src/nn.rs (180)
  15. crates/wasi-nn/Cargo.toml (17)
  16. crates/wasi-nn/src/backend/mod.rs (31)
  17. crates/wasi-nn/src/backend/onnx.rs (338)
  18. crates/wasi-nn/src/backend/onnxruntime.rs (149)
  19. crates/wasi-nn/src/backend/openvino.rs (36)
  20. crates/wasi-nn/src/backend/winml.rs (146)
  21. crates/wasi-nn/src/ctx.rs (146)
  22. crates/wasi-nn/src/lib.rs (51)
  23. crates/wasi-nn/src/registry/in_memory.rs (5)
  24. crates/wasi-nn/src/registry/mod.rs (1)
  25. crates/wasi-nn/src/wit.rs (330)
  26. crates/wasi-nn/src/witx.rs (122)
  27. crates/wasi-nn/tests/check/mod.rs (8)
  28. crates/wasi-nn/tests/exec/mod.rs (54)
  29. crates/wasi-nn/tests/exec/wit.rs (73)
  30. crates/wasi-nn/tests/exec/witx.rs (52)
  31. crates/wasi-nn/tests/fixtures/README.md (0)
  32. crates/wasi-nn/tests/test-programs.rs (200)
  33. crates/wasi-nn/wit/wasi-nn.wit (57)
  34. src/commands/run.rs (63)
  35. src/commands/serve.rs (22)
  36. supply-chain/audits.toml (18)

50
Cargo.lock

@ -371,15 +371,6 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "castaway"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a17ed5635fc8536268e5d4de1e22e81ac34419e5f052d4d51f4e01dcc263fcc"
dependencies = [
"rustversion",
]
[[package]]
name = "cc"
version = "1.0.83"
@ -486,19 +477,6 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
[[package]]
name = "compact_str"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f86b9c4c00838774a6d902ef931eff7470720c51d90c2e32cfe15dc304737b3f"
dependencies = [
"castaway",
"cfg-if",
"itoa",
"ryu",
"static_assertions",
]
[[package]]
name = "component-fuzz-util"
version = "0.0.0"
@ -1908,9 +1886,9 @@ dependencies = [
[[package]]
name = "num-traits"
version = "0.2.15"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
]
@ -2028,21 +2006,22 @@ dependencies = [
[[package]]
name = "ort"
version = "2.0.0-rc.0"
version = "2.0.0-rc.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8e5caf4eb2ead4bc137c3ff4e347940e3e556ceb11a4180627f04b63d7342dd"
checksum = "0bc80894094c6a875bfac64415ed456fa661081a278a035e22be661305c87e14"
dependencies = [
"compact_str",
"js-sys",
"ort-sys",
"thiserror",
"tracing",
"web-sys",
]
[[package]]
name = "ort-sys"
version = "2.0.0-rc.0"
version = "2.0.0-rc.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f48b5623df2187e0db543ecb2032a6a999081086b7ffddd318000c00b23ace46"
checksum = "b3d9c1373fc813d3f024d394f621f4c6dde0734c79b1c17113c3bb5bf0084bbe"
dependencies = [
"flate2",
"sha2",
@ -3936,9 +3915,10 @@ dependencies = [
"test-programs-artifacts",
"thiserror",
"tracing",
"tracing-subscriber",
"walkdir",
"wasi-common",
"wasmtime",
"wasmtime-wasi",
"wiggle",
"windows",
]
@ -4024,6 +4004,16 @@ dependencies = [
"wast 211.0.1",
]
[[package]]
name = "web-sys"
version = "0.3.57"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b17e741662c70c8bd24ac5c5b18de314a2c26c32bf8346ee1e6f53de919c283"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "webpki-roots"
version = "0.26.1"

6
ci/vendor-wit.sh

@ -36,5 +36,9 @@ cp -r $dst crates/wasi-http/wit
# slightly different than above.
repo=https://raw.githubusercontent.com/WebAssembly/wasi-nn
revision=e2310b
curl -L $repo/$revision/wit/wasi-nn.wit -o crates/wasi-nn/wit/wasi-nn.wit
curl -L $repo/$revision/wasi-nn.witx -o crates/wasi-nn/witx/wasi-nn.witx
# TODO: the in-tree `wasi-nn` implementation does not yet fully support the
# latest WIT specification on `main`. To create a baseline for moving forward,
# the in-tree WIT incorporates some but not all of the upstream changes. This
# TODO can be removed once the implementation catches up with the spec.
# curl -L $repo/$revision/wit/wasi-nn.wit -o crates/wasi-nn/wit/wasi-nn.wit

4
crates/bench-api/src/lib.rs

@ -418,7 +418,7 @@ struct BenchState {
struct HostState {
wasi: WasiCtx,
#[cfg(feature = "wasi-nn")]
wasi_nn: wasmtime_wasi_nn::WasiNnCtx,
wasi_nn: wasmtime_wasi_nn::witx::WasiNnCtx,
}
impl BenchState {
@ -509,7 +509,7 @@ impl BenchState {
#[cfg(feature = "wasi-nn")]
wasi_nn: {
let (backends, registry) = wasmtime_wasi_nn::preload(&[])?;
wasmtime_wasi_nn::WasiNnCtx::new(backends, registry)
wasmtime_wasi_nn::witx::WasiNnCtx::new(backends, registry)
},
};

5
crates/test-programs/artifacts/build.rs

@ -90,7 +90,10 @@ fn build_and_generate_tests() {
}
// Generate a component from each test.
if kind == "nn" || target == "dwarf_imported_memory" || target == "dwarf_shared_memory" {
if target == "dwarf_imported_memory"
|| target == "dwarf_shared_memory"
|| target.starts_with("nn_witx")
{
continue;
}
let adapter = match target.as_str() {

16
crates/test-programs/src/bin/nn_image_classification_winml.rs

@ -1,16 +0,0 @@
use anyhow::{Context, Result};
use std::fs;
use test_programs::nn::{classify, sort_results};
use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding};
pub fn main() -> Result<()> {
let graph = GraphBuilder::new(GraphEncoding::Onnx, ExecutionTarget::CPU)
.build_from_cache("mobilenet")?;
let tensor = fs::read("fixture/kitten.rgb")
.context("the tensor file to be mapped to the fixture directory")?;
let results = classify(graph, tensor)?;
let top_five = &sort_results(&results)[..5];
println!("found results, sorted top 5: {:?}", top_five);
assert_eq!(top_five[0].class_id(), 284);
Ok(())
}

14
crates/test-programs/src/bin/nn_image_classification_onnx.rs → crates/test-programs/src/bin/nn_wit_image_classification_onnx.rs

@ -1,18 +1,20 @@
use anyhow::{Context, Result};
use std::fs;
use test_programs::nn::{classify, sort_results};
use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding};
use test_programs::nn::{sort_results, wit};
pub fn main() -> Result<()> {
let model = fs::read("fixture/model.onnx")
.context("the model file to be mapped to the fixture directory")?;
let graph =
GraphBuilder::new(GraphEncoding::Onnx, ExecutionTarget::CPU).build_from_bytes([&model])?;
let graph = wit::load(
&[model],
wit::GraphEncoding::Onnx,
wit::ExecutionTarget::Cpu,
)?;
let tensor = fs::read("fixture/000000062808.rgb")
.context("the tensor file to be mapped to the fixture directory")?;
let results = classify(graph, tensor)?;
let results = wit::classify(graph, ("input", tensor), "output")?;
let top_five = &sort_results(&results)[..5];
// 963 is meat loaf, meatloaf.
// 963 is "meat loaf, meatloaf."
// https://github.com/onnx/models/blob/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/synset.txt#L963
assert_eq!(top_five[0].class_id(), 963);
println!("found results, sorted top 5: {:?}", top_five);

25
crates/test-programs/src/bin/nn_wit_image_classification_openvino.rs

@ -0,0 +1,25 @@
use anyhow::{Context, Result};
use std::fs;
use test_programs::nn::{sort_results, wit};
pub fn main() -> Result<()> {
let xml = fs::read("fixture/model.xml")
.context("the model file to be mapped to the fixture directory")?;
let weights = fs::read("fixture/model.bin")
.context("the weights file to be mapped to the fixture directory")?;
let graph = wit::load(
&[xml, weights],
wit::GraphEncoding::Openvino,
wit::ExecutionTarget::Cpu,
)?;
let tensor = fs::read("fixture/tensor.bgr")
.context("the tensor file to be mapped to the fixture directory")?;
let results = wit::classify(
graph,
("input", tensor),
"MobilenetV2/Predictions/Reshape_1",
)?;
let top_five = &sort_results(&results)[..5];
println!("found results, sorted top 5: {:?}", top_five);
Ok(())
}

17
crates/test-programs/src/bin/nn_wit_image_classification_openvino_named.rs

@ -0,0 +1,17 @@
use anyhow::{Context, Result};
use std::fs;
use test_programs::nn::{sort_results, wit};
pub fn main() -> Result<()> {
let graph = wit::load_by_name("fixtures")?;
let tensor: Vec<u8> = fs::read("fixture/tensor.bgr")
.context("the tensor file to be mapped to the fixture directory")?;
let results = wit::classify(
graph,
("input", tensor),
"MobilenetV2/Predictions/Reshape_1",
)?;
let top_five = &sort_results(&results)[..5];
println!("found results, sorted top 5: {:?}", top_five);
Ok(())
}

9
crates/test-programs/src/bin/nn_image_classification_named.rs → crates/test-programs/src/bin/nn_wit_image_classification_winml_named.rs

@ -1,15 +1,14 @@
use anyhow::{Context, Result};
use std::fs;
use test_programs::nn::{classify, sort_results};
use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding};
use test_programs::nn::{sort_results, wit};
pub fn main() -> Result<()> {
let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU)
.build_from_cache("fixtures")?;
let graph = wit::load_by_name("mobilenet")?;
let tensor = fs::read("fixture/tensor.bgr")
.context("the tensor file to be mapped to the fixture directory")?;
let results = classify(graph, tensor)?;
let results = wit::classify(graph, ("input", tensor), "output")?;
let top_five = &sort_results(&results)[..5];
println!("found results, sorted top 5: {:?}", top_five);
assert_eq!(top_five[0].class_id(), 284);
Ok(())
}

22
crates/test-programs/src/bin/nn_witx_image_classification_onnx.rs

@ -0,0 +1,22 @@
use anyhow::{Context, Result};
use std::fs;
use test_programs::nn::{sort_results, witx};
pub fn main() -> Result<()> {
let model = fs::read("fixture/model.onnx")
.context("the model file to be mapped to the fixture directory")?;
let graph = witx::load(
&[&model],
witx::GraphEncoding::Onnx,
witx::ExecutionTarget::CPU,
)?;
let tensor = fs::read("fixture/000000062808.rgb")
.context("the tensor file to be mapped to the fixture directory")?;
let results = witx::classify(graph, tensor)?;
let top_five = &sort_results(&results)[..5];
// 963 is "meat loaf, meatloaf."
// https://github.com/onnx/models/blob/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/synset.txt#L963
assert_eq!(top_five[0].class_id(), 963);
println!("found results, sorted top 5: {:?}", top_five);
Ok(())
}

12
crates/test-programs/src/bin/nn_image_classification.rs → crates/test-programs/src/bin/nn_witx_image_classification_openvino.rs

@ -1,18 +1,20 @@
use anyhow::{Context, Result};
use std::fs;
use test_programs::nn::{classify, sort_results};
use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding};
use test_programs::nn::{sort_results, witx};
pub fn main() -> Result<()> {
let xml = fs::read("fixture/model.xml")
.context("the model file to be mapped to the fixture directory")?;
let weights = fs::read("fixture/model.bin")
.context("the weights file to be mapped to the fixture directory")?;
let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU)
.build_from_bytes([&xml, &weights])?;
let graph = witx::load(
&[&xml, &weights],
witx::GraphEncoding::Openvino,
witx::ExecutionTarget::CPU,
)?;
let tensor = fs::read("fixture/tensor.bgr")
.context("the tensor file to be mapped to the fixture directory")?;
let results = classify(graph, tensor)?;
let results = witx::classify(graph, tensor)?;
let top_five = &sort_results(&results)[..5];
println!("found results, sorted top 5: {:?}", top_five);
Ok(())

17
crates/test-programs/src/bin/nn_witx_image_classification_openvino_named.rs

@ -0,0 +1,17 @@
use anyhow::{Context, Result};
use std::fs;
use test_programs::nn::{sort_results, witx};
pub fn main() -> Result<()> {
let graph = witx::load_by_name(
"fixtures",
witx::GraphEncoding::Openvino,
witx::ExecutionTarget::CPU,
)?;
let tensor: Vec<u8> = fs::read("fixture/tensor.bgr")
.context("the tensor file to be mapped to the fixture directory")?;
let results = witx::classify(graph, tensor)?;
let top_five = &sort_results(&results)[..5];
println!("found results, sorted top 5: {:?}", top_five);
Ok(())
}

18
crates/test-programs/src/bin/nn_witx_image_classification_winml_named.rs

@ -0,0 +1,18 @@
use anyhow::{Context, Result};
use std::fs;
use test_programs::nn::{sort_results, witx};
pub fn main() -> Result<()> {
let graph = witx::load_by_name(
"mobilenet",
witx::GraphEncoding::Onnx,
witx::ExecutionTarget::CPU,
)?;
let tensor = fs::read("fixture/tensor.bgr")
.context("the tensor file to be mapped to the fixture directory")?;
let results = witx::classify(graph, tensor)?;
let top_five = &sort_results(&results)[..5];
println!("found results, sorted top 5: {:?}", top_five);
assert_eq!(top_five[0].class_id(), 284);
Ok(())
}

180
crates/test-programs/src/nn.rs

@ -1,39 +1,147 @@
use anyhow::Result;
use std::time::Instant;
use wasi_nn::{Graph, TensorType};
/// Run a wasi-nn inference using a simple classifier model (single input,
/// single output).
pub fn classify(graph: Graph, tensor: Vec<u8>) -> Result<Vec<f32>> {
let mut context = graph.init_execution_context()?;
println!(
"[nn] created wasi-nn execution context with ID: {}",
context
);
// Many classifiers have a single input; currently, this test suite also
// uses tensors of the same shape, though this is not usually the case.
context.set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor)?;
println!("[nn] set input tensor: {} bytes", tensor.len());
let before = Instant::now();
context.compute()?;
println!(
"[nn] executed graph inference in {} ms",
before.elapsed().as_millis()
);
// Many classifiers emit probabilities as floating point values; here we
// convert the raw bytes to `f32` knowing all models used here use that
// type.
let mut output_buffer = vec![0u8; 1001 * std::mem::size_of::<f32>()];
let num_bytes = context.get_output(0, &mut output_buffer)?;
println!("[nn] retrieved output tensor: {} bytes", num_bytes);
let output: Vec<f32> = output_buffer[..num_bytes]
.chunks(4)
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
.collect();
Ok(output)
//! This module attempts to paper over the differences between the two
//! implementations of wasi-nn: the legacy WITX-based version (`mod witx`) and
//! the up-to-date WIT version (`mod wit`). Since the tests are mainly a simple
//! classifier, this exposes a high-level `classify` function to go along with
//! `load`, etc.
//!
//! This module exists solely for convenience--e.g., it reduces test duplication.
//! In the future it can be safely disposed of or altered as more tests are added.
/// Call `wasi-nn` functions from WebAssembly using the canonical ABI of the
/// component model via WIT-based tooling. Used by `bin/nn_wit_*.rs` tests.
pub mod wit {
use anyhow::{anyhow, Result};
use std::time::Instant;
// Generate the wasi-nn bindings based on the `*.wit` files.
wit_bindgen::generate!({
path: "../wasi-nn/wit",
world: "ml",
default_bindings_module: "test_programs::ml"
});
use self::wasi::nn::errors;
use self::wasi::nn::graph::{self, Graph};
pub use self::wasi::nn::graph::{ExecutionTarget, GraphEncoding}; // Used by tests.
use self::wasi::nn::tensor::{Tensor, TensorType};
/// Load a wasi-nn graph from a set of bytes.
pub fn load(
bytes: &[Vec<u8>],
encoding: GraphEncoding,
target: ExecutionTarget,
) -> Result<Graph> {
graph::load(bytes, encoding, target).map_err(err_as_anyhow)
}
/// Load a wasi-nn graph by name.
pub fn load_by_name(name: &str) -> Result<Graph> {
graph::load_by_name(name).map_err(err_as_anyhow)
}
/// Run a wasi-nn inference using a simple classifier model (single input,
/// single output).
pub fn classify(graph: Graph, input: (&str, Vec<u8>), output: &str) -> Result<Vec<f32>> {
let context = graph.init_execution_context().map_err(err_as_anyhow)?;
println!(
"[nn] created wasi-nn execution context with ID: {:?}",
context
);
// Many classifiers have a single input; currently, this test suite also
// uses tensors of the same shape, though this is not usually the case.
let tensor = Tensor::new(&vec![1, 3, 224, 224], TensorType::Fp32, &input.1);
context.set_input(input.0, tensor).map_err(err_as_anyhow)?;
println!("[nn] set input tensor: {} bytes", input.1.len());
let before = Instant::now();
context.compute().map_err(err_as_anyhow)?;
println!(
"[nn] executed graph inference in {} ms",
before.elapsed().as_millis()
);
// Many classifiers emit probabilities as floating point values; here we
// convert the raw bytes to `f32` knowing all models used here use that
// type.
let output = context.get_output(output).map_err(err_as_anyhow)?;
println!(
"[nn] retrieved output tensor: {} bytes",
output.data().len()
);
let output: Vec<f32> = output
.data()
.chunks(4)
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
.collect();
Ok(output)
}
fn err_as_anyhow(e: errors::Error) -> anyhow::Error {
anyhow!("error: {e:?}")
}
}
/// Call `wasi-nn` functions from WebAssembly using the legacy WITX-based
/// tooling. This older API has been deprecated for the newer WIT-based API but
/// retained for backwards compatibility testing--i.e., `bin/nn_witx_*.rs`
/// tests.
pub mod witx {
use anyhow::Result;
use std::time::Instant;
pub use wasi_nn::{ExecutionTarget, GraphEncoding};
use wasi_nn::{Graph, GraphBuilder, TensorType};
/// Load a wasi-nn graph from a set of bytes.
pub fn load(
bytes: &[&[u8]],
encoding: GraphEncoding,
target: ExecutionTarget,
) -> Result<Graph> {
Ok(GraphBuilder::new(encoding, target).build_from_bytes(bytes)?)
}
/// Load a wasi-nn graph by name.
pub fn load_by_name(
name: &str,
encoding: GraphEncoding,
target: ExecutionTarget,
) -> Result<Graph> {
Ok(GraphBuilder::new(encoding, target).build_from_cache(name)?)
}
/// Run a wasi-nn inference using a simple classifier model (single input,
/// single output).
pub fn classify(graph: Graph, tensor: Vec<u8>) -> Result<Vec<f32>> {
let mut context = graph.init_execution_context()?;
println!(
"[nn] created wasi-nn execution context with ID: {}",
context
);
// Many classifiers have a single input; currently, this test suite also
// uses tensors of the same shape, though this is not usually the case.
context.set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor)?;
println!("[nn] set input tensor: {} bytes", tensor.len());
let before = Instant::now();
context.compute()?;
println!(
"[nn] executed graph inference in {} ms",
before.elapsed().as_millis()
);
// Many classifiers emit probabilities as floating point values; here we
// convert the raw bytes to `f32` knowing all models used here use that
// type.
let mut output_buffer = vec![0u8; 1001 * std::mem::size_of::<f32>()];
let num_bytes = context.get_output(0, &mut output_buffer)?;
println!("[nn] retrieved output tensor: {} bytes", num_bytes);
let output: Vec<f32> = output_buffer[..num_bytes]
.chunks(4)
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
.collect();
Ok(output)
}
}
/// Sort some classification probabilities.

17
crates/wasi-nn/Cargo.toml

@ -20,7 +20,11 @@ anyhow = { workspace = true, features = ['std'] }
wiggle = { workspace = true, features = ["wasmtime"] }
# This dependency is necessary for the WIT-generation macros to work:
wasmtime = { workspace = true, features = ["component-model", "runtime"] }
wasmtime = { workspace = true, features = [
"component-model",
"runtime",
"std",
] }
# These dependencies are necessary for the wasi-nn implementation:
tracing = { workspace = true }
@ -29,7 +33,7 @@ openvino = { version = "0.6.0", features = [
"runtime-linking",
], optional = true }
ort = { version = "2.0.0-rc.0", default-features = false, features = [
ort = { version = "2.0.0-rc.2", default-features = false, features = [
"copy-dylibs",
"download-binaries",
], optional = true }
@ -46,16 +50,17 @@ walkdir = { workspace = true }
cap-std = { workspace = true }
libtest-mimic = { workspace = true }
test-programs-artifacts = { workspace = true }
wasi-common = { workspace = true, features = ["sync"] }
wasmtime-wasi = { workspace = true, features = ["preview1"] }
wasmtime = { workspace = true, features = ["cranelift"] }
tracing-subscriber = { workspace = true }
[features]
default = ["openvino", "winml"]
# openvino is available on all platforms, it requires openvino installed.
# OpenVINO is available on all platforms; it requires OpenVINO to be installed.
openvino = ["dep:openvino"]
# onnx is available on all platforms.
# ONNX is available on all platforms.
onnx = ["dep:ort"]
# winml is only available on Windows 10 1809 and later.
# WinML is only available on Windows 10 1809 and later.
winml = ["dep:windows"]
[[test]]

31
crates/wasi-nn/src/backend/mod.rs

@ -3,20 +3,20 @@
//! implementations to maintain backend-specific state between calls.
#[cfg(feature = "onnx")]
pub mod onnxruntime;
pub mod onnx;
#[cfg(feature = "openvino")]
pub mod openvino;
#[cfg(all(feature = "winml", target_os = "windows"))]
pub mod winml;
#[cfg(feature = "onnx")]
use self::onnxruntime::OnnxBackend;
use self::onnx::OnnxBackend;
#[cfg(feature = "openvino")]
use self::openvino::OpenvinoBackend;
#[cfg(all(feature = "winml", target_os = "windows"))]
use self::winml::WinMLBackend;
use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor};
use crate::wit::{ExecutionTarget, GraphEncoding, Tensor};
use crate::{Backend, ExecutionContext, Graph};
use std::fs::File;
use std::io::Read;
@ -69,9 +69,30 @@ pub trait BackendGraph: Send + Sync {
/// A [BackendExecutionContext] performs the actual inference; this is the
/// backing implementation for a user-facing execution context.
pub trait BackendExecutionContext: Send + Sync {
fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError>;
fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError>;
fn compute(&mut self) -> Result<(), BackendError>;
fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result<u32, BackendError>;
fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError>;
}
/// An identifier for a tensor in a [Graph].
#[derive(Debug)]
pub enum Id {
Index(u32),
Name(String),
}
impl Id {
pub fn index(&self) -> Option<u32> {
match self {
Id::Index(i) => Some(*i),
Id::Name(_) => None,
}
}
pub fn name(&self) -> Option<&str> {
match self {
Id::Index(_) => None,
Id::Name(n) => Some(n),
}
}
}
/// Errors returned by a backend; [BackendError::BackendAccess] is a catch-all

338
crates/wasi-nn/src/backend/onnx.rs

@ -0,0 +1,338 @@
//! Implements a `wasi-nn` [`BackendInner`] using ONNX via the `ort` crate.
use super::{BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner};
use crate::backend::{read, Id};
use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
use crate::{ExecutionContext, Graph};
use anyhow::Context;
use ort::{inputs, GraphOptimizationLevel, Session};
use std::path::Path;
use std::sync::{Arc, Mutex};
#[derive(Default)]
pub struct OnnxBackend();
unsafe impl Send for OnnxBackend {}
unsafe impl Sync for OnnxBackend {}
impl BackendInner for OnnxBackend {
fn encoding(&self) -> GraphEncoding {
GraphEncoding::Onnx
}
fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> {
if builders.len() != 1 {
return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()).into());
}
let session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.commit_from_memory(builders[0])?;
let box_: Box<dyn BackendGraph> =
Box::new(OnnxGraph(Arc::new(Mutex::new(session)), target));
Ok(box_.into())
}
fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir> {
Some(self)
}
}
impl BackendFromDir for OnnxBackend {
fn load_from_dir(
&mut self,
path: &Path,
target: ExecutionTarget,
) -> Result<Graph, BackendError> {
let model = read(&path.join("model.onnx"))?;
self.load(&[&model], target)
}
}
struct OnnxGraph(Arc<Mutex<Session>>, #[allow(dead_code)] ExecutionTarget);
unsafe impl Send for OnnxGraph {}
unsafe impl Sync for OnnxGraph {}
impl BackendGraph for OnnxGraph {
fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> {
let session = self.0.lock().unwrap();
// We need to hold on to the names of the inputs in order for
// `set_input` to work with both indexes and names. Having the
// dimensions and type around is useful for validation but could be
// retrieved from the session.
let mut inputs = vec![];
for input in &session.inputs {
let shape = Shape::from_onnx_input(input)?;
inputs.push(TensorSlot {
shape,
tensor: None,
});
}
// We need to keep track of the output shapes since they are used for
// creating the output tensor.
let mut outputs = vec![];
for output in &session.outputs {
let shape = Shape::from_onnx_output(output)?;
outputs.push(TensorSlot {
shape,
tensor: None,
});
}
let box_: Box<dyn BackendExecutionContext> = Box::new(OnnxExecutionContext {
session: self.0.clone(),
inputs,
outputs,
});
Ok(box_.into())
}
}
struct OnnxExecutionContext {
session: Arc<Mutex<Session>>,
inputs: Vec<TensorSlot>,
outputs: Vec<TensorSlot>,
}
unsafe impl Send for OnnxExecutionContext {}
unsafe impl Sync for OnnxExecutionContext {}
impl OnnxExecutionContext {
/// Helper function for finding the internal index of a tensor by [`Id`].
fn find(&self, id: Id, list: &[TensorSlot]) -> Result<usize, BackendError> {
let index = match id {
Id::Index(i) => {
let i = i as usize;
if i < list.len() {
i
} else {
return Err(BackendError::BackendAccess(anyhow::anyhow!(
"incorrect tensor index: {i} >= {}",
list.len()
)));
}
}
Id::Name(n) => list.iter().position(|s| s.shape.name == n).ok_or_else(|| {
BackendError::BackendAccess(anyhow::anyhow!("unknown tensor name: {n}"))
})?,
};
Ok(index)
}
}
impl BackendExecutionContext for OnnxExecutionContext {
fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> {
let index = self.find(id, &self.inputs)?;
let input = &mut self.inputs[index];
if let Err(e) = input.shape.matches(tensor) {
return Err(e.into());
}
// Hold the tensor data on the context until `compute` is called.
input.tensor.replace(tensor.clone());
Ok(())
}
fn compute(&mut self) -> Result<(), BackendError> {
let mut session_inputs: Vec<ort::SessionInputValue<'_>> = vec![];
for i in &self.inputs {
session_inputs.extend(to_input_value(i)?);
}
let session = self.session.lock().unwrap();
let session_outputs = session.run(session_inputs.as_slice())?;
for i in 0..self.outputs.len() {
// TODO: fix preexisting gap--this only handles f32 tensors.
let raw: (Vec<i64>, &[f32]) = session_outputs[i].try_extract_raw_tensor()?;
let f32s = raw.1.to_vec();
let output = &mut self.outputs[i];
output.tensor.replace(Tensor {
dimensions: output.shape.dimensions_as_u32()?,
ty: output.shape.ty,
data: f32_vec_to_bytes(f32s),
});
}
Ok(())
}
fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError> {
let index = self.find(id, &self.outputs)?;
let output = &self.outputs[index];
if let Some(tensor) = &output.tensor {
Ok(tensor.clone())
} else {
Err(BackendError::BackendAccess(anyhow::anyhow!(
"missing output tensor: {}; has `compute` been called?",
output.shape.name
)))
}
}
}
impl From<ort::Error> for BackendError {
fn from(e: ort::Error) -> Self {
BackendError::BackendAccess(e.into())
}
}
/// Holds a slot for ONNX session inputs and outputs.
///
/// TODO: it seems unfortunate that we have to "hold" some extra data per
/// session but in the input case, this is necessary for name-based indexing.
struct TensorSlot {
shape: Shape,
tensor: Option<Tensor>,
}
/// Describes a tensor in ONNX terms.
struct Shape {
name: String,
dimensions: Vec<i64>,
ty: TensorType,
}
impl Shape {
fn from_onnx_input(input: &ort::Input) -> Result<Self, BackendError> {
let name = input.name.clone();
let (dimensions, ty) = convert_value_type(&input.input_type)?;
Ok(Self {
name,
dimensions,
ty,
})
}
fn from_onnx_output(output: &ort::Output) -> Result<Self, BackendError> {
let name = output.name.clone();
let (dimensions, ty) = convert_value_type(&output.output_type)?;
Ok(Self {
name,
dimensions,
ty,
})
}
fn dimensions_as_u32(&self) -> Result<Vec<u32>, BackendError> {
self.dimensions
.iter()
.map(|d| if *d == -1 { Ok(1) } else { convert_i64(d) })
.collect()
}
fn matches(&self, tensor: &Tensor) -> anyhow::Result<()> {
if self.dimensions.len() != tensor.dimensions.len() {
return Err(anyhow::anyhow!(
"input tensor cardinality does not match model: {:?} != {:?}",
self.dimensions,
tensor.dimensions
));
} else {
for (&shape_dim, &tensor_dim) in self.dimensions.iter().zip(tensor.dimensions.iter()) {
let tensor_dim = tensor_dim as i64;
if !is_dynamic_dimension(shape_dim) && shape_dim != tensor_dim {
return Err(anyhow::anyhow!(
"input tensor dimensions do not match model: {:?} != {:?}",
self.dimensions,
tensor.dimensions
));
}
}
}
if self.ty != tensor.ty {
return Err(anyhow::anyhow!(
"input tensor type does not match model: {:?} != {:?}",
self.ty,
tensor.ty
));
}
Ok(())
}
}
fn convert_value_type(vt: &ort::ValueType) -> Result<(Vec<i64>, TensorType), BackendError> {
match vt {
ort::ValueType::Tensor { ty, dimensions } => {
let dims = dimensions.clone();
let ty = (*ty).try_into()?;
Ok((dims, ty))
}
_ => Err(BackendError::BackendAccess(anyhow::anyhow!(
"unsupported input type: {vt:?}"
))),
}
}
fn convert_i64(i: &i64) -> Result<u32, BackendError> {
u32::try_from(*i).map_err(|d| -> BackendError {
anyhow::anyhow!("unable to convert dimension to u32: {d}").into()
})
}
impl TryFrom<ort::TensorElementType> for TensorType {
type Error = BackendError;
fn try_from(ty: ort::TensorElementType) -> Result<Self, Self::Error> {
match ty {
ort::TensorElementType::Float32 => Ok(TensorType::Fp32),
ort::TensorElementType::Float64 => Ok(TensorType::Fp64),
ort::TensorElementType::Uint8 => Ok(TensorType::U8),
ort::TensorElementType::Int32 => Ok(TensorType::I32),
ort::TensorElementType::Int64 => Ok(TensorType::I64),
_ => Err(BackendError::BackendAccess(anyhow::anyhow!(
"unsupported tensor type: {ty:?}"
))),
}
}
}
fn to_input_value(slot: &TensorSlot) -> Result<[ort::SessionInputValue<'_>; 1], BackendError> {
match &slot.tensor {
Some(tensor) => match tensor.ty {
TensorType::Fp32 => {
let data = bytes_to_f32_vec(tensor.data.to_vec());
let dimensions = tensor
.dimensions
.iter()
.map(|d| *d as i64) // TODO: fewer conversions
.collect::<Vec<i64>>();
Ok(inputs![(dimensions, Arc::new(data.into_boxed_slice()))]
.context("failed to create ONNX session input")?)
}
_ => {
unimplemented!("{:?} not supported by ONNX", tensor.ty);
}
},
None => {
return Err(BackendError::BackendAccess(anyhow::anyhow!(
"missing input tensor: {}",
slot.shape.name
)));
}
}
}
pub fn f32_vec_to_bytes(data: Vec<f32>) -> Vec<u8> {
let chunks: Vec<[u8; 4]> = data.into_iter().map(|f| f.to_le_bytes()).collect();
let result: Vec<u8> = chunks.iter().flatten().copied().collect();
result
}
pub fn bytes_to_f32_vec(data: Vec<u8>) -> Vec<f32> {
let chunks: Vec<&[u8]> = data.chunks(4).collect();
let v: Vec<f32> = chunks
.into_iter()
.map(|c| f32::from_le_bytes(c.try_into().unwrap()))
.collect();
v.into_iter().collect()
}
/// Returns whether the dimension is dynamic.
///
/// ONNX uses [dimensional variables] (i.e., name strings) to indicate that the
/// value of a tensor dimension is user-defined, not fixed by the model. This is
/// useful for batching up several inference requests, e.g. When `ort` returns a
/// dimension of this kind, though, it uses `-1` to indicate that the dimension
/// is dynamic.
///
/// [dimensional variables]:
/// https://onnx.ai/onnx/repo-docs/IR.html#static-tensor-shapes
fn is_dynamic_dimension(d: i64) -> bool {
d == -1
}

149
crates/wasi-nn/src/backend/onnxruntime.rs

@ -1,149 +0,0 @@
//! Implements a `wasi-nn` [`BackendInner`] using ONNX via ort.
use super::{BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner};
use crate::backend::read;
use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
use crate::{ExecutionContext, Graph};
use ort::{inputs, GraphOptimizationLevel, Session};
use std::path::Path;
use std::sync::{Arc, Mutex};
#[derive(Default)]
pub struct OnnxBackend();
unsafe impl Send for OnnxBackend {}
unsafe impl Sync for OnnxBackend {}
impl BackendInner for OnnxBackend {
fn encoding(&self) -> GraphEncoding {
GraphEncoding::Onnx
}
fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> {
if builders.len() != 1 {
return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()).into());
}
let session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_model_from_memory(builders[0])?;
let box_: Box<dyn BackendGraph> =
Box::new(ONNXGraph(Arc::new(Mutex::new(session)), target));
Ok(box_.into())
}
fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir> {
Some(self)
}
}
impl BackendFromDir for OnnxBackend {
fn load_from_dir(
&mut self,
path: &Path,
target: ExecutionTarget,
) -> Result<Graph, BackendError> {
let model = read(&path.join("model.onnx"))?;
self.load(&[&model], target)
}
}
struct ONNXGraph(Arc<Mutex<Session>>, #[allow(dead_code)] ExecutionTarget);
unsafe impl Send for ONNXGraph {}
unsafe impl Sync for ONNXGraph {}
impl BackendGraph for ONNXGraph {
fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> {
let session = self.0.lock().unwrap();
let inputs = session.inputs.iter().map(|_| None).collect::<Vec<_>>();
let outputs = session.outputs.iter().map(|_| None).collect::<Vec<_>>();
let box_: Box<dyn BackendExecutionContext> = Box::new(ONNXExecutionContext {
session: self.0.clone(),
inputs,
outputs,
});
Ok(box_.into())
}
}
struct ONNXExecutionContext {
session: Arc<Mutex<Session>>,
inputs: Vec<Option<Tensor>>,
outputs: Vec<Option<Vec<u8>>>,
}
unsafe impl Send for ONNXExecutionContext {}
unsafe impl Sync for ONNXExecutionContext {}
impl BackendExecutionContext for ONNXExecutionContext {
fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError> {
self.inputs[index as usize].replace(tensor.clone());
Ok(())
}
fn compute(&mut self) -> Result<(), BackendError> {
let shaped_inputs: Vec<_> = self
.inputs
.iter()
.enumerate()
.map(|(i, _o)| {
let input = self.inputs[i].as_ref().unwrap();
let dims = input
.dimensions
.as_slice()
.iter()
.map(|d| *d as i64)
.collect::<Vec<_>>();
match input.tensor_type {
TensorType::Fp32 => {
let data = bytes_to_f32_vec(input.data.to_vec());
inputs![(dims, Arc::new(data.into_boxed_slice()))].unwrap()
}
_ => {
unimplemented!("{:?} not supported by ONNX", input.tensor_type);
}
}
})
.flatten()
.collect();
let session = self.session.lock().unwrap();
let res = session.run(shaped_inputs.as_slice())?;
for i in 0..self.outputs.len() {
let raw: (Vec<i64>, &[f32]) = res[i].extract_raw_tensor()?;
let f32s = raw.1.to_vec();
self.outputs[i].replace(f32_vec_to_bytes(f32s));
}
Ok(())
}
fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result<u32, BackendError> {
let output = self.outputs[index as usize].as_ref().unwrap();
destination[..output.len()].copy_from_slice(output);
Ok(output.len() as u32)
}
}
impl From<ort::Error> for BackendError {
fn from(e: ort::Error) -> Self {
BackendError::BackendAccess(e.into())
}
}
pub fn f32_vec_to_bytes(data: Vec<f32>) -> Vec<u8> {
let chunks: Vec<[u8; 4]> = data.into_iter().map(|f| f.to_le_bytes()).collect();
let result: Vec<u8> = chunks.iter().flatten().copied().collect();
result
}
pub fn bytes_to_f32_vec(data: Vec<u8>) -> Vec<f32> {
let chunks: Vec<&[u8]> = data.chunks(4).collect();
let v: Vec<f32> = chunks
.into_iter()
.map(|c| f32::from_le_bytes(c.try_into().unwrap()))
.collect();
v.into_iter().collect()
}

36
crates/wasi-nn/src/backend/openvino.rs

@ -1,9 +1,9 @@
//! Implements a `wasi-nn` [`BackendInner`] using OpenVINO.
use super::{
read, BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner,
read, BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner, Id,
};
use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
use crate::wit::{self, ExecutionTarget, GraphEncoding, Tensor, TensorType};
use crate::{ExecutionContext, Graph};
use openvino::{InferenceError, Layout, Precision, SetupError, TensorDesc};
use std::path::Path;
@ -99,12 +99,15 @@ impl BackendGraph for OpenvinoGraph {
struct OpenvinoExecutionContext(Arc<openvino::CNNNetwork>, openvino::InferRequest);
impl BackendExecutionContext for OpenvinoExecutionContext {
fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError> {
let input_name = self.0.get_input_name(index as usize)?;
fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> {
let input_name = match id {
Id::Index(i) => self.0.get_input_name(i as usize)?,
Id::Name(name) => name,
};
// Construct the blob structure. TODO: there must be some good way to
// discover the layout here; `desc` should not have to default to NHWC.
let precision = map_tensor_type_to_precision(tensor.tensor_type);
let precision = map_tensor_type_to_precision(tensor.ty);
let dimensions = tensor
.dimensions
.iter()
@ -123,17 +126,20 @@ impl BackendExecutionContext for OpenvinoExecutionContext {
Ok(())
}
fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result<u32, BackendError> {
let output_name = self.0.get_output_name(index as usize)?;
fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError> {
let output_name = match id {
Id::Index(i) => self.0.get_output_name(i as usize)?,
Id::Name(name) => name,
};
let dimensions = vec![]; // TODO: get actual shape
let ty = wit::TensorType::Fp32; // TODO: get actual type.
let blob = self.1.get_blob(&output_name)?;
let blob_size = blob.byte_len()?;
if blob_size > destination.len() {
return Err(BackendError::NotEnoughMemory(blob_size));
}
// Copy the tensor data into the destination buffer.
destination[..blob_size].copy_from_slice(blob.buffer()?);
Ok(blob_size as u32)
let data = blob.buffer()?.to_vec();
Ok(Tensor {
dimensions,
ty,
data,
})
}
}

146
crates/wasi-nn/src/backend/winml.rs

@ -1,16 +1,27 @@
//! Implements a `wasi-nn` [`BackendInner`] using WinML.
use super::{BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner};
use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor};
//!
//! Note that the [docs.rs] documentation for the `windows` crate does not have
//! the right features turned on to read about the functions used; see Microsoft's
//! private documentation instead: [microsoft.github.io/windows-docs-rs].
//!
//! [docs.rs]: https://docs.rs/windows
//! [microsoft.github.io/windows-docs-rs]: https://microsoft.github.io/windows-docs-rs/doc/windows/AI/MachineLearning
use crate::backend::{
BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner, Id,
};
use crate::wit::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
use crate::{ExecutionContext, Graph};
use std::{fs::File, io::Read, mem::size_of, path::Path};
use windows::core::{ComInterface, HSTRING};
use windows::Foundation::Collections::IVectorView;
use windows::Storage::Streams::{
DataWriter, InMemoryRandomAccessStream, RandomAccessStreamReference,
};
use windows::AI::MachineLearning::{
LearningModel, LearningModelBinding, LearningModelDevice, LearningModelDeviceKind,
LearningModelEvaluationResult, LearningModelSession, TensorFeatureDescriptor, TensorFloat,
ILearningModelFeatureDescriptor, LearningModel, LearningModelBinding, LearningModelDevice,
LearningModelDeviceKind, LearningModelEvaluationResult, LearningModelSession,
TensorFeatureDescriptor, TensorFloat,
};
#[derive(Default)]
@ -94,29 +105,64 @@ impl WinMLExecutionContext {
}
}
impl WinMLExecutionContext {
/// Helper function for finding the internal index of a tensor by [`Id`].
fn find(
&self,
id: Id,
list: &IVectorView<ILearningModelFeatureDescriptor>,
) -> Result<u32, BackendError> {
let index = match id {
Id::Index(i) => {
if i < list.Size()? {
i
} else {
return Err(BackendError::BackendAccess(anyhow::anyhow!(
"incorrect tensor index: {i} >= {}",
list.Size()?
)));
}
}
Id::Name(name) => list
.into_iter()
.position(|d| d.Name().unwrap() == name)
.ok_or_else(|| {
BackendError::BackendAccess(anyhow::anyhow!("unknown tensor name: {name}"))
})? as u32,
};
Ok(index)
}
}
impl BackendExecutionContext for WinMLExecutionContext {
fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError> {
fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> {
let input_features = self.session.Model()?.InputFeatures()?;
let index = self.find(id, &input_features)?;
let input = input_features.GetAt(index)?;
// TODO: Support other tensor types. Only FP32 is supported right now.
match tensor.tensor_type {
match tensor.ty {
crate::wit::types::TensorType::Fp32 => {}
_ => unimplemented!(),
}
let input = self.session.Model()?.InputFeatures()?.GetAt(index)?;
unsafe {
let data = std::slice::from_raw_parts(
// TODO: this is quite unsafe and probably incorrect--will the slice
// still be around by the time the binding is used?!
let data = unsafe {
std::slice::from_raw_parts(
tensor.data.as_ptr() as *const f32,
tensor.data.len() / 4,
);
self.binding.Bind(
&input.Name()?,
&TensorFloat::CreateFromArray(
&input.cast::<TensorFeatureDescriptor>()?.Shape()?,
data,
)?,
)?;
}
tensor.data.len() / size_of::<f32>(),
)
};
self.binding.Bind(
&input.Name()?,
&TensorFloat::CreateFromArray(
&input.cast::<TensorFeatureDescriptor>()?.Shape()?,
data,
)?,
)?;
Ok(())
}
@ -125,33 +171,32 @@ impl BackendExecutionContext for WinMLExecutionContext {
Ok(())
}
fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result<u32, BackendError> {
if self.result.is_none() {
fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError> {
if let Some(result) = &self.result {
let output_features = self.session.Model()?.OutputFeatures()?;
let index = self.find(id, &output_features)?;
let output = output_features.GetAt(index)?;
// TODO: this only handles FP32!
let tensor = result
.Outputs()?
.Lookup(&output.Name()?)?
.cast::<TensorFloat>()?;
let dimensions = dimensions_as_u32(&tensor.Shape()?)?;
let view = tensor.GetAsVectorView()?;
let mut data = Vec::with_capacity(view.Size()? as usize * size_of::<f32>());
for f in view.into_iter() {
data.extend(f.to_le_bytes());
}
Ok(Tensor {
ty: TensorType::Fp32,
dimensions,
data,
})
} else {
return Err(BackendError::BackendAccess(anyhow::Error::msg(
"Output is not ready.",
)));
}
let output_name = self.session.Model()?.OutputFeatures()?.GetAt(index)?;
let output_name_hstring = output_name.Name()?;
let vector_view = self
.result
.as_ref()
.unwrap()
.Outputs()?
.Lookup(&output_name_hstring)?
.cast::<TensorFloat>()?
.GetAsVectorView()?;
let output: Vec<f32> = vector_view.into_iter().collect();
let len_to_copy = output.len() * size_of::<f32>();
unsafe {
destination[..len_to_copy].copy_from_slice(std::slice::from_raw_parts(
output.as_ptr() as *const u8,
len_to_copy,
));
}
Ok(len_to_copy as u32)
}
}
@ -168,3 +213,16 @@ impl From<windows::core::Error> for BackendError {
BackendError::BackendAccess(anyhow::Error::new(e))
}
}
fn dimensions_as_u32(dimensions: &IVectorView<i64>) -> Result<Vec<u32>, BackendError> {
dimensions
.into_iter()
.map(|d| if d == -1 { Ok(1) } else { convert_i64(d) })
.collect()
}
fn convert_i64(i: i64) -> Result<u32, BackendError> {
u32::try_from(i).map_err(|d| -> BackendError {
anyhow::anyhow!("unable to convert dimension to u32: {d}").into()
})
}

146
crates/wasi-nn/src/ctx.rs

@ -1,146 +0,0 @@
//! Implements the host state for the `wasi-nn` API: [WasiNnCtx].
use crate::backend::{self, BackendError};
use crate::wit::types::GraphEncoding;
use crate::{Backend, ExecutionContext, Graph, InMemoryRegistry, Registry};
use anyhow::anyhow;
use std::{collections::HashMap, hash::Hash, path::Path};
use thiserror::Error;
use wiggle::GuestError;
type GraphId = u32;
type GraphExecutionContextId = u32;
type BackendName = String;
type GraphDirectory = String;
/// Construct an in-memory registry from the available backends and a list of
/// `(<backend name>, <graph directory>)`. This assumes graphs can be loaded
/// from a local directory, which is a safe assumption currently for the current
/// model types.
pub fn preload(
preload_graphs: &[(BackendName, GraphDirectory)],
) -> anyhow::Result<(impl IntoIterator<Item = Backend>, Registry)> {
let mut backends = backend::list();
let mut registry = InMemoryRegistry::new();
for (kind, path) in preload_graphs {
let kind_ = kind.parse()?;
let backend = backends
.iter_mut()
.find(|b| b.encoding() == kind_)
.ok_or(anyhow!("unsupported backend: {}", kind))?
.as_dir_loadable()
.ok_or(anyhow!("{} does not support directory loading", kind))?;
registry.load(backend, Path::new(path))?;
}
Ok((backends, Registry::from(registry)))
}
/// Capture the state necessary for calling into the backend ML libraries.
pub struct WasiNnCtx {
pub(crate) backends: HashMap<GraphEncoding, Backend>,
pub(crate) registry: Registry,
pub(crate) graphs: Table<GraphId, Graph>,
pub(crate) executions: Table<GraphExecutionContextId, ExecutionContext>,
}
impl WasiNnCtx {
/// Make a new context from the default state.
pub fn new(backends: impl IntoIterator<Item = Backend>, registry: Registry) -> Self {
let backends = backends.into_iter().map(|b| (b.encoding(), b)).collect();
Self {
backends,
registry,
graphs: Table::default(),
executions: Table::default(),
}
}
}
/// Possible errors while interacting with [WasiNnCtx].
#[derive(Debug, Error)]
pub enum WasiNnError {
#[error("backend error")]
BackendError(#[from] BackendError),
#[error("guest error")]
GuestError(#[from] GuestError),
#[error("usage error")]
UsageError(#[from] UsageError),
}
#[derive(Debug, Error)]
pub enum UsageError {
#[error("Invalid context; has the load function been called?")]
InvalidContext,
#[error("Only OpenVINO's IR is currently supported, passed encoding: {0:?}")]
InvalidEncoding(GraphEncoding),
#[error("OpenVINO expects only two buffers (i.e. [ir, weights]), passed: {0}")]
InvalidNumberOfBuilders(u32),
#[error("Invalid graph handle; has it been loaded?")]
InvalidGraphHandle,
#[error("Invalid execution context handle; has it been initialized?")]
InvalidExecutionContextHandle,
#[error("Not enough memory to copy tensor data of size: {0}")]
NotEnoughMemory(u32),
#[error("No graph found with name: {0}")]
NotFound(String),
}
pub(crate) type WasiNnResult<T> = std::result::Result<T, WasiNnError>;
/// Record handle entries in a table.
pub struct Table<K, V> {
entries: HashMap<K, V>,
next_key: u32,
}
impl<K, V> Default for Table<K, V> {
fn default() -> Self {
Self {
entries: HashMap::new(),
next_key: 0,
}
}
}
impl<K, V> Table<K, V>
where
K: Eq + Hash + From<u32> + Copy,
{
pub fn insert(&mut self, value: V) -> K {
let key = self.use_next_key();
self.entries.insert(key, value);
key
}
pub fn get(&self, key: K) -> Option<&V> {
self.entries.get(&key)
}
pub fn get_mut(&mut self, key: K) -> Option<&mut V> {
self.entries.get_mut(&key)
}
fn use_next_key(&mut self) -> K {
let current = self.next_key;
self.next_key += 1;
K::from(current)
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::registry::GraphRegistry;
#[test]
fn example() {
struct FakeRegistry;
impl GraphRegistry for FakeRegistry {
fn get_mut(&mut self, _: &str) -> Option<&mut Graph> {
None
}
}
let _ctx = WasiNnCtx::new([], Registry::from(FakeRegistry));
}
}

51
crates/wasi-nn/src/lib.rs

@ -1,14 +1,34 @@
mod ctx;
mod registry;
pub mod backend;
pub use ctx::{preload, WasiNnCtx};
pub use registry::{GraphRegistry, InMemoryRegistry};
mod registry;
pub mod wit;
pub mod witx;
use anyhow::anyhow;
use core::fmt;
pub use registry::{GraphRegistry, InMemoryRegistry};
use std::path::Path;
use std::sync::Arc;
/// Construct an in-memory registry from the available backends and a list of
/// `(<backend name>, <graph directory>)`. This assumes graphs can be loaded
/// from a local directory, which is currently a safe assumption for the
/// supported model types.
pub fn preload(preload_graphs: &[(String, String)]) -> anyhow::Result<(Vec<Backend>, Registry)> {
let mut backends = backend::list();
let mut registry = InMemoryRegistry::new();
for (kind, path) in preload_graphs {
let kind_ = kind.parse()?;
let backend = backends
.iter_mut()
.find(|b| b.encoding() == kind_)
.ok_or(anyhow!("unsupported backend: {}", kind))?
.as_dir_loadable()
.ok_or(anyhow!("{} does not support directory loading", kind))?;
registry.load(backend, Path::new(path))?;
}
Ok((backends, Registry::from(registry)))
}
/// A machine learning backend.
pub struct Backend(Box<dyn backend::BackendInner>);
impl std::ops::Deref for Backend {
@ -43,6 +63,27 @@ impl std::ops::Deref for Graph {
}
}
/// A host-side tensor.
///
/// Eventually, this may be defined in each backend as they gain the ability to
/// hold tensors on various devices (TODO:
/// https://github.com/WebAssembly/wasi-nn/pull/70).
#[derive(Clone)]
pub struct Tensor {
dimensions: Vec<u32>,
ty: wit::TensorType,
data: Vec<u8>,
}
impl fmt::Debug for Tensor {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Tensor")
.field("dimensions", &self.dimensions)
.field("ty", &self.ty)
.field("data (bytes)", &self.data.len())
.finish()
}
}
/// A backend-defined execution context.
pub struct ExecutionContext(Box<dyn backend::BackendExecutionContext>);
impl From<Box<dyn backend::BackendExecutionContext>> for ExecutionContext {

5
crates/wasi-nn/src/registry/in_memory.rs

@ -2,7 +2,7 @@
use super::{Graph, GraphRegistry};
use crate::backend::BackendFromDir;
use crate::wit::types::ExecutionTarget;
use crate::wit::ExecutionTarget;
use anyhow::{anyhow, bail};
use std::{collections::HashMap, path::Path};
@ -37,6 +37,9 @@ impl InMemoryRegistry {
}
impl GraphRegistry for InMemoryRegistry {
fn get(&self, name: &str) -> Option<&Graph> {
self.0.get(name)
}
fn get_mut(&mut self, name: &str) -> Option<&mut Graph> {
self.0.get_mut(name)
}

1
crates/wasi-nn/src/registry/mod.rs

@ -12,5 +12,6 @@ use crate::Graph;
pub use in_memory::InMemoryRegistry;
pub trait GraphRegistry: Send + Sync {
fn get(&self, name: &str) -> Option<&Graph>;
fn get_mut(&mut self, name: &str) -> Option<&mut Graph>;
}

330
crates/wasi-nn/src/wit.rs

@ -15,8 +15,69 @@
//! [`Backend`]: crate::Backend
//! [`types`]: crate::wit::types
use crate::{ctx::UsageError, WasiNnCtx};
use std::{error::Error, fmt, hash::Hash, str::FromStr};
use crate::backend::Id;
use crate::{Backend, Registry};
use std::collections::HashMap;
use std::hash::Hash;
use std::{fmt, str::FromStr};
use wasmtime::component::{Resource, ResourceTable};
/// Capture the state necessary for calling into the backend ML libraries.
pub struct WasiNnCtx {
pub(crate) backends: HashMap<GraphEncoding, Backend>,
pub(crate) registry: Registry,
}
impl WasiNnCtx {
/// Make a new context from the default state.
pub fn new(backends: impl IntoIterator<Item = Backend>, registry: Registry) -> Self {
let backends = backends.into_iter().map(|b| (b.encoding(), b)).collect();
Self { backends, registry }
}
}
/// A wrapper capturing the needed internal wasi-nn state.
///
/// Unlike other WASI proposals (see `wasmtime-wasi`, `wasmtime-wasi-http`),
/// this wrapper is not a `trait` but rather holds the references directly. This
/// removes one layer of abstraction for simplicity only; it could be added back
/// in the future if embedders need more control here.
pub struct WasiNnView<'a> {
ctx: &'a mut WasiNnCtx,
table: &'a mut ResourceTable,
}
impl<'a> WasiNnView<'a> {
/// Create a new view into the wasi-nn state.
pub fn new(table: &'a mut ResourceTable, ctx: &'a mut WasiNnCtx) -> Self {
Self { ctx, table }
}
}
pub enum Error {
/// Caller module passed an invalid argument.
InvalidArgument,
/// Invalid encoding.
InvalidEncoding,
/// The operation timed out.
Timeout,
/// Runtime Error.
RuntimeError,
/// Unsupported operation.
UnsupportedOperation,
/// Graph is too large.
TooLarge,
/// Graph not found.
NotFound,
/// A runtime error occurred that we should trap on; see `StreamError`.
Trap(anyhow::Error),
}
impl From<wasmtime::component::ResourceTableError> for Error {
fn from(error: wasmtime::component::ResourceTableError) -> Self {
Self::Trap(error.into())
}
}
/// Generate the traits and types from the `wasi-nn` WIT specification.
mod gen_ {
@ -24,126 +85,241 @@ mod gen_ {
world: "ml",
path: "wit/wasi-nn.wit",
trappable_imports: true,
with: {
// Map each wasi-nn WIT resource to a type defined in this crate so the
// `ResourceTable` helper methods can be used.
"wasi:nn/graph/graph": crate::Graph,
"wasi:nn/tensor/tensor": crate::Tensor,
"wasi:nn/inference/graph-execution-context": crate::ExecutionContext,
},
trappable_error_type: {
"wasi:nn/errors/error" => super::Error,
},
});
}
use gen_::wasi::nn as gen; // Shortcut to the module containing the types we need.
use gen_::wasi::nn::{self as gen}; // Shortcut to the module containing the types we need.
// Export the `types` used in this crate as well as `ML::add_to_linker`.
pub mod types {
use super::gen;
pub use gen::graph::{ExecutionTarget, Graph, GraphEncoding};
pub use gen::errors::Error;
pub use gen::graph::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding};
pub use gen::inference::GraphExecutionContext;
pub use gen::tensor::{Tensor, TensorType};
}
pub use gen::graph::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding};
pub use gen::inference::GraphExecutionContext;
pub use gen::tensor::{Tensor, TensorData, TensorDimensions, TensorType};
pub use gen_::Ml as ML;
impl gen::graph::Host for WasiNnCtx {
/// Load an opaque sequence of bytes to use for inference.
/// Add the WIT-based version of the `wasi-nn` API to a
/// [`wasmtime::component::Linker`].
pub fn add_to_linker<T>(
l: &mut wasmtime::component::Linker<T>,
f: impl Fn(&mut T) -> WasiNnView<'_> + Send + Sync + Copy + 'static,
) -> anyhow::Result<()> {
gen::graph::add_to_linker_get_host(l, f)?;
gen::tensor::add_to_linker_get_host(l, f)?;
gen::inference::add_to_linker_get_host(l, f)?;
gen::errors::add_to_linker_get_host(l, f)?;
Ok(())
}
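As a rough usage sketch (not part of this change), an embedder holding a `ResourceTable` next to a `WasiNnCtx` in its store data can wire the API up as follows; `MyHost` is a hypothetical host type, and the closure performs the split borrow that `WasiNnView` exists to carry:
// Hypothetical embedder state: the two pieces `WasiNnView` needs, side by side.
struct MyHost {
    table: wasmtime::component::ResourceTable,
    wasi_nn: WasiNnCtx,
}
fn link_wasi_nn(linker: &mut wasmtime::component::Linker<MyHost>) -> anyhow::Result<()> {
    // Borrow the resource table and the wasi-nn context separately out of the host state.
    add_to_linker(linker, |h: &mut MyHost| {
        WasiNnView::new(&mut h.table, &mut h.wasi_nn)
    })
}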
impl gen::graph::Host for WasiNnView<'_> {
fn load(
&mut self,
builders: Vec<gen::graph::GraphBuilder>,
encoding: gen::graph::GraphEncoding,
target: gen::graph::ExecutionTarget,
) -> wasmtime::Result<Result<gen::graph::Graph, gen::errors::Error>> {
let graph = if let Some(backend) = self.backends.get_mut(&encoding) {
builders: Vec<GraphBuilder>,
encoding: GraphEncoding,
target: ExecutionTarget,
) -> Result<Resource<crate::Graph>, Error> {
tracing::debug!("load {encoding:?} {target:?}");
if let Some(backend) = self.ctx.backends.get_mut(&encoding) {
let slices = builders.iter().map(|s| s.as_slice()).collect::<Vec<_>>();
backend.load(&slices, target.into())?
match backend.load(&slices, target.into()) {
Ok(graph) => {
let graph = self.table.push(graph)?;
Ok(graph)
}
Err(error) => {
tracing::error!("failed to load graph: {error:?}");
Err(Error::RuntimeError)
}
}
} else {
return Err(UsageError::InvalidEncoding(encoding.into()).into());
};
let graph_id = self.graphs.insert(graph);
Ok(Ok(graph_id))
Err(Error::InvalidEncoding)
}
}
fn load_by_name(
&mut self,
name: String,
) -> wasmtime::Result<Result<gen::graph::Graph, gen::errors::Error>> {
if let Some(graph) = self.registry.get_mut(&name) {
let graph_id = self.graphs.insert(graph.clone().into());
Ok(Ok(graph_id))
fn load_by_name(&mut self, name: String) -> Result<Resource<Graph>, Error> {
use core::result::Result::*;
tracing::debug!("load by name {name:?}");
let registry = &self.ctx.registry;
if let Some(graph) = registry.get(&name) {
let graph = graph.clone();
let graph = self.table.push(graph)?;
Ok(graph)
} else {
return Err(UsageError::NotFound(name.to_string()).into());
tracing::error!("failed to find graph with name: {name}");
Err(Error::NotFound)
}
}
}
impl gen::inference::Host for WasiNnCtx {
/// Create an execution instance of a loaded graph.
///
/// TODO: remove completely?
impl gen::graph::HostGraph for WasiNnView<'_> {
fn init_execution_context(
&mut self,
graph_id: gen::graph::Graph,
) -> wasmtime::Result<Result<gen::inference::GraphExecutionContext, gen::errors::Error>> {
let exec_context = if let Some(graph) = self.graphs.get(graph_id) {
graph.init_execution_context()?
} else {
return Err(UsageError::InvalidGraphHandle.into());
};
graph: Resource<Graph>,
) -> Result<Resource<GraphExecutionContext>, Error> {
use core::result::Result::*;
tracing::debug!("initialize execution context");
let graph = self.table.get(&graph)?;
match graph.init_execution_context() {
Ok(exec_context) => {
let exec_context = self.table.push(exec_context)?;
Ok(exec_context)
}
Err(error) => {
tracing::error!("failed to initialize execution context: {error:?}");
Err(Error::RuntimeError)
}
}
}
let exec_context_id = self.executions.insert(exec_context);
Ok(Ok(exec_context_id))
fn drop(&mut self, graph: Resource<Graph>) -> wasmtime::Result<()> {
self.table.delete(graph)?;
Ok(())
}
}
/// Define the inputs to use for inference.
impl gen::inference::HostGraphExecutionContext for WasiNnView<'_> {
fn set_input(
&mut self,
exec_context_id: gen::inference::GraphExecutionContext,
index: u32,
tensor: gen::tensor::Tensor,
) -> wasmtime::Result<Result<(), gen::errors::Error>> {
if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
exec_context.set_input(index, &tensor)?;
Ok(Ok(()))
exec_context: Resource<GraphExecutionContext>,
name: String,
tensor: Resource<Tensor>,
) -> Result<(), Error> {
let tensor = self.table.get(&tensor)?;
tracing::debug!("set input {name:?}: {tensor:?}");
let tensor = tensor.clone(); // TODO: avoid copying the tensor
let exec_context = self.table.get_mut(&exec_context)?;
if let Err(e) = exec_context.set_input(Id::Name(name), &tensor) {
tracing::error!("failed to set input: {e:?}");
Err(Error::InvalidArgument)
} else {
Err(UsageError::InvalidGraphHandle.into())
Ok(())
}
}
/// Compute the inference on the given inputs.
///
/// TODO: refactor to compute(list<tensor>) -> result<list<tensor>, error>
fn compute(
&mut self,
exec_context_id: gen::inference::GraphExecutionContext,
) -> wasmtime::Result<Result<(), gen::errors::Error>> {
if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
exec_context.compute()?;
Ok(Ok(()))
} else {
Err(UsageError::InvalidExecutionContextHandle.into())
fn compute(&mut self, exec_context: Resource<GraphExecutionContext>) -> Result<(), Error> {
let exec_context = &mut self.table.get_mut(&exec_context)?;
tracing::debug!("compute");
match exec_context.compute() {
Ok(()) => Ok(()),
Err(error) => {
tracing::error!("failed to compute: {error:?}");
Err(Error::RuntimeError)
}
}
}
/// Extract the outputs after inference.
#[doc = r" Extract the outputs after inference."]
fn get_output(
&mut self,
exec_context_id: gen::inference::GraphExecutionContext,
index: u32,
) -> wasmtime::Result<Result<gen::tensor::TensorData, gen::errors::Error>> {
if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
// Read the output bytes. TODO: this involves a hard-coded upper
// limit on the tensor size that is necessary because there is no
// way to introspect the graph outputs
// (https://github.com/WebAssembly/wasi-nn/issues/37).
let mut destination = vec![0; 1024 * 1024];
let bytes_read = exec_context.get_output(index, &mut destination)?;
destination.truncate(bytes_read as usize);
Ok(Ok(destination))
} else {
Err(UsageError::InvalidGraphHandle.into())
exec_context: Resource<GraphExecutionContext>,
name: String,
) -> Result<Resource<Tensor>, Error> {
let exec_context = self.table.get_mut(&exec_context)?;
tracing::debug!("get output {name:?}");
match exec_context.get_output(Id::Name(name)) {
Ok(tensor) => {
let tensor = self.table.push(tensor)?;
Ok(tensor)
}
Err(error) => {
tracing::error!("failed to get output: {error:?}");
Err(Error::RuntimeError)
}
}
}
fn drop(&mut self, exec_context: Resource<GraphExecutionContext>) -> wasmtime::Result<()> {
self.table.delete(exec_context)?;
Ok(())
}
}
impl gen::errors::Host for WasiNnCtx {}
impl gen::tensor::HostTensor for WasiNnView<'_> {
fn new(
&mut self,
dimensions: TensorDimensions,
ty: TensorType,
data: TensorData,
) -> wasmtime::Result<Resource<Tensor>> {
let tensor = Tensor {
dimensions,
ty,
data,
};
let tensor = self.table.push(tensor)?;
Ok(tensor)
}
fn dimensions(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorDimensions> {
let tensor = self.table.get(&tensor)?;
Ok(tensor.dimensions.clone())
}
fn ty(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorType> {
let tensor = self.table.get(&tensor)?;
Ok(tensor.ty)
}
fn data(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorData> {
let tensor = self.table.get(&tensor)?;
Ok(tensor.data.clone())
}
impl gen::tensor::Host for WasiNnCtx {}
fn drop(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<()> {
self.table.delete(tensor)?;
Ok(())
}
}
impl gen::tensor::Host for WasiNnView<'_> {}
impl gen::errors::Host for WasiNnView<'_> {
fn convert_error(&mut self, err: Error) -> wasmtime::Result<gen::errors::Error> {
match err {
Error::InvalidArgument => Ok(gen::errors::Error::InvalidArgument),
Error::InvalidEncoding => Ok(gen::errors::Error::InvalidEncoding),
Error::Timeout => Ok(gen::errors::Error::Timeout),
Error::RuntimeError => Ok(gen::errors::Error::RuntimeError),
Error::UnsupportedOperation => Ok(gen::errors::Error::UnsupportedOperation),
Error::TooLarge => Ok(gen::errors::Error::TooLarge),
Error::NotFound => Ok(gen::errors::Error::NotFound),
Error::Trap(e) => Err(e),
}
}
}
impl gen::inference::Host for WasiNnView<'_> {}
impl Hash for gen::graph::GraphEncoding {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
core::mem::discriminant(self).hash(state);
self.to_string().hash(state)
}
}
impl fmt::Display for gen::graph::GraphEncoding {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use gen::graph::GraphEncoding::*;
match self {
Openvino => write!(f, "openvino"),
Onnx => write!(f, "onnx"),
Pytorch => write!(f, "pytorch"),
Tensorflow => write!(f, "tensorflow"),
Tensorflowlite => write!(f, "tensorflowlite"),
Autodetect => write!(f, "autodetect"),
Ggml => write!(f, "ggml"),
}
}
}
@ -168,4 +344,4 @@ impl fmt::Display for GraphEncodingParseError {
write!(f, "unknown graph encoding: {}", self.0)
}
}
impl Error for GraphEncodingParseError {}
impl std::error::Error for GraphEncodingParseError {}
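For illustration only: assuming the `FromStr` implementation that this parse-error type belongs to, `GraphEncoding` round-trips through the lowercase string form shown in the `Display` impl above:
// Illustrative sketch (assumes the `FromStr` impl paired with the error above).
fn _encoding_roundtrip_sketch() -> Result<(), GraphEncodingParseError> {
    let encoding: GraphEncoding = "openvino".parse()?;
    assert_eq!(encoding.to_string(), "openvino");
    Ok(())
}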

122
crates/wasi-nn/src/witx.rs

@ -13,11 +13,83 @@
//!
//! [`types`]: crate::wit::types
use crate::ctx::{UsageError, WasiNnCtx, WasiNnError, WasiNnResult as Result};
use wiggle::{GuestMemory, GuestPtr};
use crate::backend::BackendError;
use crate::backend::Id;
use crate::wit::GraphEncoding;
use crate::{Backend, ExecutionContext, Graph, Registry};
use std::collections::HashMap;
use std::hash::Hash;
use thiserror::Error;
use wiggle::{GuestError, GuestMemory, GuestPtr};
pub use gen::wasi_ephemeral_nn::add_to_linker;
pub(crate) type WasiNnResult<T> = std::result::Result<T, WasiNnError>;
type Result<T> = WasiNnResult<T>;
type GraphId = u32;
type GraphExecutionContextId = u32;
/// Capture the state necessary for calling into the backend ML libraries.
pub struct WasiNnCtx {
pub(crate) backends: HashMap<GraphEncoding, Backend>,
pub(crate) registry: Registry,
pub(crate) graphs: Table<GraphId, Graph>,
pub(crate) executions: Table<GraphExecutionContextId, ExecutionContext>,
}
impl WasiNnCtx {
/// Make a new context from the default state.
pub fn new(backends: impl IntoIterator<Item = Backend>, registry: Registry) -> Self {
let backends = backends.into_iter().map(|b| (b.encoding(), b)).collect();
Self {
backends,
registry,
graphs: Table::default(),
executions: Table::default(),
}
}
}
/// Record handle entries in a table.
pub struct Table<K, V> {
entries: HashMap<K, V>,
next_key: u32,
}
impl<K, V> Default for Table<K, V> {
fn default() -> Self {
Self {
entries: HashMap::new(),
next_key: 0,
}
}
}
impl<K, V> Table<K, V>
where
K: Eq + Hash + From<u32> + Copy,
{
pub fn insert(&mut self, value: V) -> K {
let key = self.use_next_key();
self.entries.insert(key, value);
key
}
pub fn get(&self, key: K) -> Option<&V> {
self.entries.get(&key)
}
pub fn get_mut(&mut self, key: K) -> Option<&mut V> {
self.entries.get_mut(&key)
}
fn use_next_key(&mut self) -> K {
let current = self.next_key;
self.next_key += 1;
K::from(current)
}
}
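As an illustration (not part of this change), the table hands out dense `u32`-backed handles starting at zero:
// Illustrative sketch of the handle table above.
fn _table_usage_sketch() {
    let mut graphs: Table<GraphId, &str> = Table::default();
    let id = graphs.insert("mobilenet");
    assert_eq!(id, 0);
    assert!(graphs.get(id).is_some());
}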
/// Generate the traits and types from the `wasi-nn` WITX specification.
mod gen {
use super::*;
@ -42,9 +114,10 @@ mod gen {
) -> anyhow::Result<types::NnErrno> {
tracing::debug!("host error: {:?}", e);
match e {
WasiNnError::BackendError(_) => unimplemented!(),
WasiNnError::GuestError(_) => unimplemented!(),
WasiNnError::UsageError(_) => unimplemented!(),
WasiNnError::BackendError(_) => Ok(types::NnErrno::RuntimeError),
WasiNnError::GuestError(_) => unimplemented!("guest error conversion"),
WasiNnError::UsageError(_) => Ok(types::NnErrno::UnsupportedOperation),
WasiNnError::NotEnoughMemory(_) => Ok(types::NnErrno::TooLarge),
}
}
}
@ -119,10 +192,10 @@ impl gen::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx {
if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) {
let tensor = crate::wit::types::Tensor {
dimensions: memory.to_vec(tensor.dimensions)?,
tensor_type: tensor.type_.into(),
ty: tensor.type_.into(),
data: memory.to_vec(tensor.data)?,
};
Ok(exec_context.set_input(index, &tensor)?)
Ok(exec_context.set_input(Id::Index(index), &tensor)?)
} else {
Err(UsageError::InvalidGraphHandle.into())
}
@ -149,13 +222,19 @@ impl gen::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx {
out_buffer_max_size: u32,
) -> Result<u32> {
if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) {
let mut destination = memory
let tensor = exec_context.get_output(Id::Index(index))?;
let destination = memory
.as_slice_mut(out_buffer.as_array(out_buffer_max_size))?
.expect(
"cannot use with shared memories; \
see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)",
);
Ok(exec_context.get_output(index, &mut destination)?)
if tensor.data.len() > destination.len() {
Err(WasiNnError::NotEnoughMemory(tensor.data.len()))
} else {
destination[..tensor.data.len()].copy_from_slice(&tensor.data);
Ok(tensor.data.len() as u32)
}
} else {
Err(UsageError::InvalidGraphHandle.into())
}
@ -199,3 +278,28 @@ impl From<gen::types::TensorType> for crate::wit::types::TensorType {
}
}
}
/// Possible errors while interacting with [WasiNnCtx].
#[derive(Debug, Error)]
pub enum WasiNnError {
#[error("backend error")]
BackendError(#[from] BackendError),
#[error("guest error")]
GuestError(#[from] GuestError),
#[error("usage error")]
UsageError(#[from] UsageError),
#[error("not enough memory: requested {0} bytes")]
NotEnoughMemory(usize),
}
#[derive(Debug, Error)]
pub enum UsageError {
#[error("Only OpenVINO's IR is currently supported, passed encoding: {0:?}")]
InvalidEncoding(GraphEncoding),
#[error("Invalid graph handle; has it been loaded?")]
InvalidGraphHandle,
#[error("Invalid execution context handle; has it been initialized?")]
InvalidExecutionContextHandle,
#[error("No graph found with name: {0}")]
NotFound(String),
}
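Illustrative only: the `#[from]` conversions let host code build a `WasiNnError` with `?` or `.into()`, and the error-to-`NnErrno` conversion earlier in this file then reports it to the guest:
// Illustrative sketch: a `UsageError` wraps into `WasiNnError` via `#[from]` and
// ultimately surfaces to the guest as `NnErrno::UnsupportedOperation`.
fn _require_backend(found: bool, encoding: GraphEncoding) -> WasiNnResult<()> {
    if found {
        Ok(())
    } else {
        Err(UsageError::InvalidEncoding(encoding).into())
    }
}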

8
crates/wasi-nn/tests/check/mod.rs

@ -1,10 +1,8 @@
//! This is testing-specific code--it is public only so that it can be
//! accessible both in unit and integration tests.
//! Check that the environment is set up correctly for running tests.
//!
//! This module checks:
//! - that OpenVINO can be found in the environment
//! - that WinML is available
//! - that some ML model artifacts can be downloaded and cached.
//! - that various backends can be located on the system (see sub-modules)
//! - that certain ML model artifacts can be downloaded and cached.
#[allow(unused_imports)]
use anyhow::{anyhow, Context, Result};

54
crates/wasi-nn/tests/exec/mod.rs

@ -1,52 +1,6 @@
use crate::check::artifacts_dir;
use anyhow::Result;
use std::path::Path;
use wasi_common::sync::{Dir, WasiCtxBuilder};
use wasi_common::WasiCtx;
use wasmtime::{Config, Engine, Linker, Module, Store};
use wasmtime_wasi_nn::{Backend, InMemoryRegistry, WasiNnCtx};
//! Provide a Wasmtime embedding for executing wasi-nn test programs.
const PREOPENED_DIR_NAME: &str = "fixture";
pub mod wit;
pub mod witx;
/// Run a wasi-nn test program. This is modeled after
/// `crates/wasi/tests/all/main.rs` but still uses the older preview1 API
/// for file reads.
pub fn run(path: &str, backend: Backend, preload_model: bool) -> Result<()> {
let path = Path::new(path);
let engine = Engine::new(&Config::new())?;
let mut linker = Linker::new(&engine);
wasmtime_wasi_nn::witx::add_to_linker(&mut linker, |s: &mut Ctx| &mut s.wasi_nn)?;
wasi_common::sync::add_to_linker(&mut linker, |s: &mut Ctx| &mut s.wasi)?;
let module = Module::from_file(&engine, path)?;
let mut store = Store::new(&engine, Ctx::new(&artifacts_dir(), preload_model, backend)?);
let instance = linker.instantiate(&mut store, &module)?;
let start = instance.get_typed_func::<(), ()>(&mut store, "_start")?;
start.call(&mut store, ())?;
Ok(())
}
/// The host state for running wasi-nn tests.
struct Ctx {
wasi: WasiCtx,
wasi_nn: WasiNnCtx,
}
impl Ctx {
fn new(preopen_dir: &Path, preload_model: bool, mut backend: Backend) -> Result<Self> {
let preopen_dir = Dir::open_ambient_dir(preopen_dir, cap_std::ambient_authority())?;
let mut builder = WasiCtxBuilder::new();
builder
.inherit_stdio()
.preopened_dir(preopen_dir, PREOPENED_DIR_NAME)?;
let wasi = builder.build();
let mut registry = InMemoryRegistry::new();
let mobilenet_dir = artifacts_dir();
if preload_model {
registry.load((backend).as_dir_loadable().unwrap(), &mobilenet_dir)?;
}
let wasi_nn = WasiNnCtx::new([backend.into()], registry.into());
Ok(Self { wasi, wasi_nn })
}
}
pub const PREOPENED_DIR_NAME: &str = "fixture";

73
crates/wasi-nn/tests/exec/wit.rs

@ -0,0 +1,73 @@
use super::PREOPENED_DIR_NAME;
use crate::check::artifacts_dir;
use anyhow::{anyhow, Result};
use std::path::Path;
use wasmtime::component::{Component, Linker, ResourceTable};
use wasmtime::{Config, Engine, Store};
use wasmtime_wasi::bindings::sync::Command;
use wasmtime_wasi::{DirPerms, FilePerms, WasiCtx, WasiCtxBuilder};
use wasmtime_wasi_nn::wit::WasiNnView;
use wasmtime_wasi_nn::{wit::WasiNnCtx, Backend, InMemoryRegistry};
/// Run a wasi-nn test program. This is modeled after
/// `crates/wasi/tests/all/main.rs` but still uses the older preview1 API for
/// file reads.
pub fn run(path: &str, backend: Backend, preload_model: bool) -> Result<()> {
let path = Path::new(path);
let engine = Engine::new(&Config::new())?;
let mut linker = Linker::new(&engine);
wasmtime_wasi_nn::wit::add_to_linker(&mut linker, |c: &mut Ctx| {
WasiNnView::new(&mut c.table, &mut c.wasi_nn)
})?;
wasmtime_wasi::add_to_linker_sync(&mut linker)?;
let module = Component::from_file(&engine, path)?;
let mut store = Store::new(&engine, Ctx::new(&artifacts_dir(), preload_model, backend)?);
let command = Command::instantiate(&mut store, &module, &linker)?;
let result = command.wasi_cli_run().call_run(&mut store)?;
result.map_err(|_| anyhow!("failed to run command"))
}
/// The host state for running wasi-nn component tests.
struct Ctx {
wasi: WasiCtx,
wasi_nn: WasiNnCtx,
table: ResourceTable,
}
impl Ctx {
fn new(preopen_dir: &Path, preload_model: bool, mut backend: Backend) -> Result<Self> {
let mut builder = WasiCtxBuilder::new();
builder.inherit_stdio().preopened_dir(
preopen_dir,
PREOPENED_DIR_NAME,
DirPerms::READ,
FilePerms::READ,
)?;
let wasi = builder.build();
let mut registry = InMemoryRegistry::new();
let mobilenet_dir = artifacts_dir();
if preload_model {
registry.load((backend).as_dir_loadable().unwrap(), &mobilenet_dir)?;
}
let wasi_nn = WasiNnCtx::new([backend.into()], registry.into());
let table = ResourceTable::new();
Ok(Self {
wasi,
wasi_nn,
table,
})
}
}
impl wasmtime_wasi::WasiView for Ctx {
fn ctx(&mut self) -> &mut WasiCtx {
&mut self.wasi
}
fn table(&mut self) -> &mut ResourceTable {
&mut self.table
}
}

52
crates/wasi-nn/tests/exec/witx.rs

@ -0,0 +1,52 @@
use super::PREOPENED_DIR_NAME;
use crate::check::artifacts_dir;
use anyhow::Result;
use std::path::Path;
use wasmtime::{Config, Engine, Linker, Module, Store};
use wasmtime_wasi::{preview1::WasiP1Ctx, DirPerms, FilePerms, WasiCtxBuilder};
use wasmtime_wasi_nn::{witx::WasiNnCtx, Backend, InMemoryRegistry};
/// Run a wasi-nn test program. This is modeled after
/// `crates/wasi/tests/all/main.rs` but still uses the older preview1 API
/// for file reads.
pub fn run(path: &str, backend: Backend, preload_model: bool) -> Result<()> {
let path = Path::new(path);
let engine = Engine::new(&Config::new())?;
let mut linker = Linker::new(&engine);
wasmtime_wasi_nn::witx::add_to_linker(&mut linker, |s: &mut Ctx| &mut s.wasi_nn)?;
wasmtime_wasi::preview1::add_to_linker_sync(&mut linker, |s: &mut Ctx| &mut s.wasi)?;
let module = Module::from_file(&engine, path)?;
let mut store = Store::new(&engine, Ctx::new(&artifacts_dir(), preload_model, backend)?);
let instance = linker.instantiate(&mut store, &module)?;
let start = instance.get_typed_func::<(), ()>(&mut store, "_start")?;
start.call(&mut store, ())?;
Ok(())
}
/// The host state for running wasi-nn tests.
struct Ctx {
wasi: WasiP1Ctx,
wasi_nn: WasiNnCtx,
}
impl Ctx {
fn new(preopen_dir: &Path, preload_model: bool, mut backend: Backend) -> Result<Self> {
let mut builder = WasiCtxBuilder::new();
builder.inherit_stdio().preopened_dir(
preopen_dir,
PREOPENED_DIR_NAME,
DirPerms::READ,
FilePerms::READ,
)?;
let wasi = builder.build_p1();
let mut registry = InMemoryRegistry::new();
let mobilenet_dir = artifacts_dir();
if preload_model {
registry.load((backend).as_dir_loadable().unwrap(), &mobilenet_dir)?;
}
let wasi_nn = WasiNnCtx::new([backend.into()], registry.into());
Ok(Self { wasi, wasi_nn })
}
}

0
crates/wasi-nn/tests/fixtures/readme.md → crates/wasi-nn/tests/fixtures/README.md

200
crates/wasi-nn/tests/test-programs.rs

@ -23,6 +23,8 @@ use test_programs_artifacts::*;
use wasmtime_wasi_nn::{backend, Backend};
fn main() -> Result<()> {
tracing_subscriber::fmt::init();
if cfg!(miri) {
return Ok(());
}
@ -45,7 +47,7 @@ fn main() -> Result<()> {
let mut trials = Vec::new();
for program in programs {
// Either ignore the test if it cannot run (i.e., downgrade `Fail` to
// `Ignore`) or pre-emptively fail it if `error_on_failed_check` is set.
// `Ignore`) or preemptively fail it if `error_on_failed_check` is set.
let (run_test, mut check) = check_test_program(program);
if !error_on_failed_check {
check = check.downgrade_failure(); // Downgrade `Fail` to `Ignore`.
@ -68,103 +70,122 @@ fn main() -> Result<()> {
/// Return the test program to run and a check that must pass for the test to
/// run.
fn check_test_program(name: &str) -> (fn() -> Result<()>, IgnoreCheck) {
use IgnoreCheck::*;
match name {
"nn_image_classification" => (
nn_image_classification,
if !cfg!(target_arch = "x86_64") {
Fail("requires x86_64".into())
} else if !cfg!(target_os = "linux") && !cfg!(target_os = "windows") {
Fail("requires linux or windows".into())
} else if let Err(e) = check::openvino::is_installed() {
Fail(e.to_string().into())
} else {
Run
},
// Legacy WITX-based tests:
"nn_witx_image_classification_openvino" => (
nn_witx_image_classification_openvino,
IgnoreCheck::for_openvino(),
),
"nn_witx_image_classification_openvino_named" => (
nn_witx_image_classification_openvino_named,
IgnoreCheck::for_openvino(),
),
"nn_witx_image_classification_onnx" => {
(nn_witx_image_classification_onnx, IgnoreCheck::for_onnx())
}
"nn_witx_image_classification_winml_named" => (
nn_witx_image_classification_winml_named,
IgnoreCheck::for_winml(),
),
"nn_image_classification_named" => (
nn_image_classification_named,
if !cfg!(target_arch = "x86_64") {
Fail("requires x86_64".into())
} else if !cfg!(target_os = "linux") && !cfg!(target_os = "windows") {
Fail("requires linux or windows or macos".into())
} else if let Err(e) = check::openvino::is_installed() {
Fail(e.to_string().into())
} else {
Run
},
// WIT-based tests:
"nn_wit_image_classification_openvino" => (
nn_wit_image_classification_openvino,
IgnoreCheck::for_openvino(),
),
"nn_image_classification_onnx" => (
nn_image_classification_onnx,
#[cfg(feature = "onnx")]
if !cfg!(target_arch = "x86_64") && !cfg!(target_arch = "aarch64") {
Fail("requires x86_64 or aarch64".into())
} else if !cfg!(target_os = "linux")
&& !cfg!(target_os = "windows")
&& !cfg!(target_os = "macos")
{
Fail("requires linux, windows, or macos".into())
} else {
Run
},
#[cfg(not(feature = "onnx"))]
Ignore("requires the `onnx` feature".into()),
"nn_wit_image_classification_openvino_named" => (
nn_wit_image_classification_openvino_named,
IgnoreCheck::for_openvino(),
),
"nn_image_classification_winml" => (
nn_image_classification_winml,
#[cfg(all(feature = "winml", target_os = "windows"))]
if !cfg!(target_arch = "x86_64") {
Fail("requires x86_64".into())
} else if cfg!(target_os = "windows") {
Fail("requires windows".into())
} else if let Err(e) = check::winml::is_available() {
Fail(e.to_string().into())
} else {
Run
},
#[cfg(not(all(feature = "winml", target_os = "windows")))]
Ignore("requires the `winml` feature on windows".into()),
"nn_wit_image_classification_onnx" => {
(nn_wit_image_classification_onnx, IgnoreCheck::for_onnx())
}
"nn_wit_image_classification_winml_named" => (
nn_wit_image_classification_winml_named,
IgnoreCheck::for_winml(),
),
_ => panic!("unknown test program: {} (add to this `match`)", name),
}
}
fn nn_image_classification() -> Result<()> {
fn nn_witx_image_classification_openvino() -> Result<()> {
check::openvino::is_installed()?;
check::openvino::are_artifacts_available()?;
let backend = Backend::from(backend::openvino::OpenvinoBackend::default());
exec::run(NN_IMAGE_CLASSIFICATION, backend, false)
exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_OPENVINO, backend, false)
}
fn nn_image_classification_named() -> Result<()> {
fn nn_witx_image_classification_openvino_named() -> Result<()> {
check::openvino::is_installed()?;
check::openvino::are_artifacts_available()?;
let backend = Backend::from(backend::openvino::OpenvinoBackend::default());
exec::run(NN_IMAGE_CLASSIFICATION_NAMED, backend, true)
exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_OPENVINO_NAMED, backend, true)
}
#[cfg(feature = "onnx")]
fn nn_image_classification_onnx() -> Result<()> {
fn nn_witx_image_classification_onnx() -> Result<()> {
check::onnx::are_artifacts_available()?;
let backend = Backend::from(backend::onnxruntime::OnnxBackend::default());
exec::run(NN_IMAGE_CLASSIFICATION_ONNX, backend, false)
let backend = Backend::from(backend::onnx::OnnxBackend::default());
exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_ONNX, backend, false)
}
#[cfg(not(feature = "onnx"))]
fn nn_image_classification_onnx() -> Result<()> {
fn nn_witx_image_classification_onnx() -> Result<()> {
anyhow::bail!("this test requires the `onnx` feature")
}
#[cfg(all(feature = "winml", target_os = "windows"))]
fn nn_image_classification_winml() -> Result<()> {
fn nn_witx_image_classification_winml_named() -> Result<()> {
check::winml::is_available()?;
check::onnx::are_artifacts_available()?;
let backend = Backend::from(backend::winml::WinMLBackend::default());
exec::run(NN_IMAGE_CLASSIFICATION_ONNX, backend, false)
exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_ONNX, backend, false)
}
#[cfg(not(all(feature = "winml", target_os = "windows")))]
fn nn_witx_image_classification_winml_named() -> Result<()> {
anyhow::bail!("this test requires the `winml` feature and only runs on windows")
}
fn nn_wit_image_classification_openvino() -> Result<()> {
check::openvino::is_installed()?;
check::openvino::are_artifacts_available()?;
let backend = Backend::from(backend::openvino::OpenvinoBackend::default());
exec::wit::run(
NN_WIT_IMAGE_CLASSIFICATION_OPENVINO_COMPONENT,
backend,
false,
)
}
fn nn_wit_image_classification_openvino_named() -> Result<()> {
check::openvino::is_installed()?;
check::openvino::are_artifacts_available()?;
let backend = Backend::from(backend::openvino::OpenvinoBackend::default());
exec::wit::run(
NN_WIT_IMAGE_CLASSIFICATION_OPENVINO_NAMED_COMPONENT,
backend,
true,
)
}
#[cfg(feature = "onnx")]
fn nn_wit_image_classification_onnx() -> Result<()> {
check::onnx::are_artifacts_available()?;
let backend = Backend::from(backend::onnx::OnnxBackend::default());
exec::wit::run(NN_WIT_IMAGE_CLASSIFICATION_ONNX_COMPONENT, backend, false)
}
#[cfg(not(feature = "onnx"))]
fn nn_wit_image_classification_onnx() -> Result<()> {
anyhow::bail!("this test requires the `onnx` feature")
}
#[cfg(all(feature = "winml", target_os = "windows"))]
fn nn_wit_image_classification_winml_named() -> Result<()> {
check::winml::is_available()?;
check::onnx::are_artifacts_available()?;
let backend = Backend::from(backend::winml::WinMLBackend::default());
exec::wit::run(NN_WIT_IMAGE_CLASSIFICATION_ONNX_COMPONENT, backend, false)
}
#[cfg(not(all(feature = "winml", target_os = "windows")))]
fn nn_image_classification_winml() -> Result<()> {
fn nn_wit_image_classification_winml_named() -> Result<()> {
anyhow::bail!("this test requires the `winml` feature and only runs on windows")
}
@ -197,3 +218,52 @@ impl IgnoreCheck {
matches!(self, IgnoreCheck::Ignore(_))
}
}
/// Some pre-test checks for various backends.
impl IgnoreCheck {
fn for_openvino() -> IgnoreCheck {
use IgnoreCheck::*;
if !cfg!(target_arch = "x86_64") {
Fail("requires x86_64".into())
} else if !cfg!(target_os = "linux") && !cfg!(target_os = "windows") {
Fail("requires linux or windows or macos".into())
} else if let Err(e) = check::openvino::is_installed() {
Fail(e.to_string().into())
} else {
Run
}
}
fn for_onnx() -> Self {
use IgnoreCheck::*;
#[cfg(feature = "onnx")]
if !cfg!(target_arch = "x86_64") && !cfg!(target_arch = "aarch64") {
Fail("requires x86_64 or aarch64".into())
} else if !cfg!(target_os = "linux")
&& !cfg!(target_os = "windows")
&& !cfg!(target_os = "macos")
{
Fail("requires linux, windows, or macos".into())
} else {
Run
}
#[cfg(not(feature = "onnx"))]
Ignore("requires the `onnx` feature".into())
}
fn for_winml() -> IgnoreCheck {
use IgnoreCheck::*;
#[cfg(all(feature = "winml", target_os = "windows"))]
if !cfg!(target_arch = "x86_64") {
Fail("requires x86_64".into())
} else if !cfg!(target_os = "windows") {
Fail("requires windows".into())
} else if let Err(e) = check::winml::is_available() {
Fail(e.to_string().into())
} else {
Run
}
#[cfg(not(all(feature = "winml", target_os = "windows")))]
Ignore("requires the `winml` feature on windows".into())
}
}

57
crates/wasi-nn/wit/wasi-nn.wit

@ -43,16 +43,18 @@ interface tensor {
/// memory--e.g., using row-major ordering--and could perhaps be improved.
type tensor-data = list<u8>;
record tensor {
resource tensor {
constructor(dimensions: tensor-dimensions, ty: tensor-type, data: tensor-data);
// Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To represent a tensor
// containing a single value, use `[1]` for the tensor dimensions.
dimensions: tensor-dimensions,
dimensions: func() -> tensor-dimensions;
// Describe the type of element in the tensor (e.g., `f32`).
tensor-type: tensor-type,
ty: func() -> tensor-type;
// Contains the tensor data.
data: tensor-data,
// Return the tensor data.
data: func() -> tensor-data;
}
}
@ -61,11 +63,12 @@ interface tensor {
interface graph {
use errors.{error};
use tensor.{tensor};
use inference.{graph-execution-context};
/// An execution graph for performing inference (i.e., a model).
///
/// TODO: replace with `resource` (https://github.com/WebAssembly/wasi-nn/issues/47).
type graph = u32;
resource graph {
init-execution-context: func() -> result<graph-execution-context, error>;
}
/// Describes the encoding of the graph. This allows the API to be implemented by various
/// backends that encode (i.e., serialize) their graph IR with different formats.
@ -75,6 +78,7 @@ interface graph {
tensorflow,
pytorch,
tensorflowlite,
ggml,
autodetect,
}
@ -107,27 +111,25 @@ interface graph {
interface inference {
use errors.{error};
use tensor.{tensor, tensor-data};
use graph.{graph};
/// Bind a `graph` to the input and output tensors for an inference.
///
/// TODO: this is no longer necessary in WIT (https://github.com/WebAssembly/wasi-nn/issues/43)
type graph-execution-context = u32;
/// Create an execution instance of a loaded graph.
init-execution-context: func(graph: graph) -> result<graph-execution-context, error>;
/// Define the inputs to use for inference.
set-input: func(ctx: graph-execution-context, index: u32, tensor: tensor) -> result<_, error>;
/// Compute the inference on the given inputs.
///
/// Note the expected sequence of calls: `set-input`, `compute`, `get-output`. TODO: this
/// expectation could be removed as a part of https://github.com/WebAssembly/wasi-nn/issues/43.
compute: func(ctx: graph-execution-context) -> result<_, error>;
/// Extract the outputs after inference.
get-output: func(ctx: graph-execution-context, index: u32) -> result<tensor-data, error>;
/// TODO: this may no longer be necessary in WIT
/// (https://github.com/WebAssembly/wasi-nn/issues/43)
resource graph-execution-context {
/// Define the inputs to use for inference.
set-input: func(name: string, tensor: tensor) -> result<_, error>;
/// Compute the inference on the given inputs.
///
/// Note the expected sequence of calls: `set-input`, `compute`, `get-output`. TODO: this
/// expectation could be removed as a part of
/// https://github.com/WebAssembly/wasi-nn/issues/43.
compute: func() -> result<_, error>;
/// Extract the outputs after inference.
get-output: func(name: string) -> result<tensor, error>;
}
}
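From the guest's perspective, the resource-based flow above replaces the old index-based `set-input`/`get-output` calls with named tensors. A rough guest-side sketch in Rust, assuming wit-bindgen-generated bindings (the `bindings::wasi::nn` module path, the tensor names "input"/"output", and the exact generated signatures are illustrative, not from this change):
// Rough guest-side sketch; exact generated signatures may differ.
use bindings::wasi::nn::graph::Graph;
use bindings::wasi::nn::tensor::{Tensor, TensorType};
fn infer(graph: &Graph, image: Vec<u8>) -> Result<Vec<u8>, String> {
    let ctx = graph
        .init_execution_context()
        .map_err(|e| format!("failed to create execution context: {e:?}"))?;
    // Build a tensor resource and bind it to a named input.
    let input = Tensor::new(&[1, 3, 224, 224], TensorType::Fp32, &image);
    ctx.set_input("input", input)
        .map_err(|e| format!("failed to set input: {e:?}"))?;
    ctx.compute().map_err(|e| format!("failed to compute: {e:?}"))?;
    // Outputs come back as tensor resources, also looked up by name.
    let output = ctx
        .get_output("output")
        .map_err(|e| format!("failed to get output: {e:?}"))?;
    Ok(output.data())
}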
/// TODO: create function-specific errors (https://github.com/WebAssembly/wasi-nn/issues/42)
@ -137,7 +139,8 @@ interface errors {
invalid-argument,
// Invalid encoding.
invalid-encoding,
busy,
// The operation timed out.
timeout,
// Runtime Error.
runtime-error,
// Unsupported operation.

63
src/commands/run.rs

@ -18,7 +18,7 @@ use wasmtime::{Engine, Func, Module, Store, StoreLimits, Val, ValType};
use wasmtime_wasi::WasiView;
#[cfg(feature = "wasi-nn")]
use wasmtime_wasi_nn::WasiNnCtx;
use wasmtime_wasi_nn::wit::WasiNnView;
#[cfg(feature = "wasi-threads")]
use wasmtime_wasi_threads::WasiThreadsCtx;
@ -624,40 +624,37 @@ impl RunCommand {
{
bail!("Cannot enable wasi-nn when the binary is not compiled with this feature.");
}
#[cfg(feature = "wasi-nn")]
#[cfg(all(feature = "wasi-nn", feature = "component-model"))]
{
let (backends, registry) = self.collect_preloaded_nn_graphs()?;
match linker {
CliLinker::Core(linker) => {
wasmtime_wasi_nn::witx::add_to_linker(linker, |host| {
// This WASI proposal is currently not protected against
// concurrent access--i.e., when wasi-threads is actively
// spawning new threads, we cannot (yet) safely allow access and
// fail if more than one thread has `Arc`-references to the
// context. Once this proposal is updated (as wasi-common has
// been) to allow concurrent access, this `Arc::get_mut`
// limitation can be removed.
Arc::get_mut(host.wasi_nn.as_mut().unwrap())
Arc::get_mut(host.wasi_nn_witx.as_mut().unwrap())
.expect("wasi-nn is not implemented with multi-threading support")
})?;
store.data_mut().wasi_nn_witx = Some(Arc::new(
wasmtime_wasi_nn::witx::WasiNnCtx::new(backends, registry),
));
}
#[cfg(feature = "component-model")]
CliLinker::Component(linker) => {
wasmtime_wasi_nn::wit::ML::add_to_linker(linker, |host| {
Arc::get_mut(host.wasi_nn.as_mut().unwrap())
.expect("wasi-nn is not implemented with multi-threading support")
wasmtime_wasi_nn::wit::add_to_linker(linker, |h: &mut Host| {
let preview2_ctx =
h.preview2_ctx.as_mut().expect("wasip2 is not configured");
let preview2_ctx = Arc::get_mut(preview2_ctx)
.expect("wasmtime_wasi is not compatible with threads")
.get_mut()
.unwrap();
let nn_ctx = Arc::get_mut(h.wasi_nn_wit.as_mut().unwrap())
.expect("wasi-nn is not implemented with multi-threading support");
WasiNnView::new(preview2_ctx.table(), nn_ctx)
})?;
store.data_mut().wasi_nn_wit = Some(Arc::new(
wasmtime_wasi_nn::wit::WasiNnCtx::new(backends, registry),
));
}
}
let graphs = self
.run
.common
.wasi
.nn_graph
.iter()
.map(|g| (g.format.clone(), g.dir.clone()))
.collect::<Vec<_>>();
let (backends, registry) = wasmtime_wasi_nn::preload(&graphs)?;
store.data_mut().wasi_nn = Some(Arc::new(WasiNnCtx::new(backends, registry)));
}
}
@ -767,6 +764,21 @@ impl RunCommand {
store.data_mut().preview2_ctx = Some(Arc::new(Mutex::new(ctx)));
Ok(())
}
#[cfg(feature = "wasi-nn")]
fn collect_preloaded_nn_graphs(
&self,
) -> Result<(Vec<wasmtime_wasi_nn::Backend>, wasmtime_wasi_nn::Registry)> {
let graphs = self
.run
.common
.wasi
.nn_graph
.iter()
.map(|g| (g.format.clone(), g.dir.clone()))
.collect::<Vec<_>>();
wasmtime_wasi_nn::preload(&graphs)
}
}
#[derive(Default, Clone)]
@ -779,7 +791,10 @@ struct Host {
preview2_ctx: Option<Arc<Mutex<wasmtime_wasi::preview1::WasiP1Ctx>>>,
#[cfg(feature = "wasi-nn")]
wasi_nn: Option<Arc<WasiNnCtx>>,
wasi_nn_wit: Option<Arc<wasmtime_wasi_nn::wit::WasiNnCtx>>,
#[cfg(feature = "wasi-nn")]
wasi_nn_witx: Option<Arc<wasmtime_wasi_nn::witx::WasiNnCtx>>,
#[cfg(feature = "wasi-threads")]
wasi_threads: Option<Arc<WasiThreadsCtx<Host>>>,
#[cfg(feature = "wasi-http")]

22
src/commands/serve.rs

@ -17,7 +17,7 @@ use wasmtime_wasi_http::io::TokioIo;
use wasmtime_wasi_http::{body::HyperOutgoingBody, WasiHttpCtx, WasiHttpView};
#[cfg(feature = "wasi-nn")]
use wasmtime_wasi_nn::WasiNnCtx;
use wasmtime_wasi_nn::wit::WasiNnCtx;
struct Host {
table: wasmtime::component::ResourceTable,
@ -75,15 +75,8 @@ impl ServeCommand {
pub fn execute(mut self) -> Result<()> {
self.run.common.init_logging()?;
// We force cli errors before starting to listen for connections so then we don't
// accidentally delay them to the first request.
if self.run.common.wasi.nn == Some(true) {
#[cfg(not(feature = "wasi-nn"))]
{
bail!("Cannot enable wasi-nn when the binary is not compiled with this feature.");
}
}
// We force CLI errors before starting to listen for connections so that
// we don't accidentally delay them to the first request.
if let Some(Profile::Guest { .. }) = &self.run.profile {
bail!("Cannot use the guest profiler with components");
}
@ -99,8 +92,8 @@ impl ServeCommand {
bail!("wasi-threads does not support components yet")
}
// The serve command requires both wasi-http and the component model, so we enable those by
// default here.
// The serve command requires both wasi-http and the component model, so
// we enable those by default here.
if self.run.common.wasi.http.replace(true) == Some(false) {
bail!("wasi-http is required for the serve command, and must not be disabled");
}
@ -227,7 +220,10 @@ impl ServeCommand {
}
#[cfg(feature = "wasi-nn")]
{
wasmtime_wasi_nn::wit::ML::add_to_linker(linker, |host| host.nn.as_mut().unwrap())?;
wasmtime_wasi_nn::wit::add_to_linker(linker, |h: &mut Host| {
let ctx = h.nn.as_mut().unwrap();
wasmtime_wasi_nn::wit::WasiNnView::new(&mut h.table, ctx)
})?;
}
}

18
supply-chain/audits.toml

@ -2028,6 +2028,12 @@ criteria = "safe-to-deploy"
version = "0.46.0"
notes = "one use of unsafe to call windows specific api to get console handle."
[[audits.num-traits]]
who = "Andrew Brown <andrew.brown@intel.com>"
criteria = "safe-to-deploy"
version = "0.2.19"
notes = "As advertised: a numeric library. The only `unsafe` is from some float-to-int conversions, which seems expected."
[[audits.num_cpus]]
who = "Alex Crichton <alex@alexcrichton.com>"
criteria = "safe-to-deploy"
@ -2145,12 +2151,24 @@ criteria = "safe-to-deploy"
version = "2.0.0-rc.0"
notes = "As expected, this crate uses `unsafe` to access the `unsafe` `ort-sys` FFI functions; it also includes several `unsafe` implementations of `Send` for several structures. With the `load-dynamic` feature enabled, this crate will be `libloading` external libraries to call FFI functions. With the `fetch-models` feature enabled, this crate can also download arbitrary models to the local filesystem."
[[audits.ort]]
who = "Andrew Brown <andrew.brown@intel.com>"
criteria = "safe-to-deploy"
delta = "2.0.0-rc.0 -> 2.0.0-rc.2"
notes = "Same as previous audit: the crate inherently uses `unsafe` FFI calls for using ONNX through `ort-sys` (e.g., logging C error strings). The changes are relatively uninteresting: a lot of documentation, some `must_use`, and general refactoring due to changes in the underlying API."
[[audits.ort-sys]]
who = "Andrew Brown <andrew.brown@intel.com>"
criteria = "safe-to-deploy"
version = "2.0.0-rc.0"
notes = "As expected, this crate contains a significant number of `unsafe` definitions to expose the FFI surface of the ONNX libraries. Perhaps surprisingly, it also contains some `unsafe` system calls to locate the user's home directory. Another interesting bit is the `build.rs` script: with the `download-binaries` feature enabled, this script will retrieve and link various ONNX libraries from https://parcel.pyke.io. This seems par for the course with this kind of library, though; the alternative--attempting to find the library on an arbitrary system--can be quite complex."
[[audits.ort-sys]]
who = "Andrew Brown <andrew.brown@intel.com>"
criteria = "safe-to-deploy"
delta = "2.0.0-rc.0 -> 2.0.0-rc.2"
notes = "This crate still downloads the ONNX libraries as a part of the `build.rs` script; now with more platform options for pre-built binaries stored in a `dist.txt` file. Otherwise largely unchanged since the previous audit."
[[audits.overload]]
who = "Pat Hickey <phickey@fastly.com>"
criteria = "safe-to-deploy"
