
wasi-nn: use resources (#8873)

* wasi-nn: use resources

Recent discussion in the wasi-nn proposal (see [wasi-nn#59], e.g.) has
concluded that the right approach for representing wasi-nn "things"
(tensors, graphs, etc.) is with a component model _resource_. This
sweeping change brings Wasmtime's implementation in line with that
decision.
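
To make the shape of that change concrete, here is a rough host-side sketch (not the generated bindings from this PR; trait and type names are illustrative): with WITX, wasi-nn objects are bare integer handles the host must track itself, whereas with WIT resources the embedder works with typed `Resource<T>` handles backed by a `ResourceTable`.

```rust
use wasmtime::component::{Resource, ResourceTable};

// Illustrative stand-ins for the host-side objects behind the handles.
pub struct Graph;
pub struct GraphExecutionContext;

// WITX era (sketch): "things" are plain u32 handles; the host keeps its own
// tables mapping handle -> object and must validate every index.
pub trait WitxStyleHost {
    fn init_execution_context(&mut self, graph_handle: u32) -> anyhow::Result<u32>;
}

// WIT era (sketch): "things" are component-model resources; the generated
// bindings hand the host typed handles that live in a `ResourceTable`, so
// ownership and destruction are tracked by the runtime.
pub trait WitStyleHost {
    fn table(&mut self) -> &mut ResourceTable;
    fn init_execution_context(
        &mut self,
        graph: Resource<Graph>,
    ) -> anyhow::Result<Resource<GraphExecutionContext>>;
}
```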

Initially I had structured this PR to remove all of the WITX-based
implementation (#8530). But, after consulting a Zulip [thread] about what
other WASI proposals aim to do, this PR pivoted to support _both_ the
WITX-based and WIT-based ABIs (i.e., the preview1 era versus the preview2,
component model era). What is clear is that the WITX-based specification
will remain "frozen in time" while the WIT-based implementation moves
forward.
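
Concretely, supporting both ABIs means an embedder can carry both flavors of wasi-nn state side by side. A minimal sketch under that assumption (field names are illustrative, and it assumes `wit::WasiNnCtx` exposes the same `new(backends, registry)` constructor as the `witx` context shown in the bench-api diff below):

```rust
// Sketch only: host state carrying both the frozen WITX-era context and the
// evolving WIT-era (component model) context.
struct Host {
    witx_nn: wasmtime_wasi_nn::witx::WasiNnCtx, // preview1-era ABI, frozen in time
    wit_nn: wasmtime_wasi_nn::wit::WasiNnCtx,   // preview2 / component model ABI
}

fn build_host() -> anyhow::Result<Host> {
    // `preload` builds the backends plus a registry from (backend, directory)
    // pairs; an empty list means no preloaded graphs.
    let (backends, registry) = wasmtime_wasi_nn::preload(&[])?;
    let witx_nn = wasmtime_wasi_nn::witx::WasiNnCtx::new(backends, registry);
    // Assumption: the WIT-side context is constructed the same way.
    let (backends, registry) = wasmtime_wasi_nn::preload(&[])?;
    let wit_nn = wasmtime_wasi_nn::wit::WasiNnCtx::new(backends, registry);
    Ok(Host { witx_nn, wit_nn })
}
```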

What that means for this PR is a "split world" paradigm. In many places,
we have to distinguish between the `wit` and `witx` versions of the same
thing. This change isn't the end state yet: it's a big step toward bringing
Wasmtime back in line with the WIT spec, but, despite my best efforts, it
doesn't fully fix all the TODOs left behind over several years of
development. I have, however, taken the liberty of refactoring and fixing
various parts as I came across them (e.g., the ONNX backend). I plan to
continue working on this in future PRs to figure out a good error
paradigm (the current one is too wordy) and device residence.

[wasi-nn#59]: https://github.com/WebAssembly/wasi-nn/pull/59
[thread]: https://bytecodealliance.zulipchat.com/#narrow/stream/219900-wasi/topic/wasi-nn's.20preview1.20vs.20preview2.20timeline

prtest:full

* vet: audit `ort`-related crate updates

* Simplify `WasiNnView`

With @alexcrichton's help, this change removes the `trait WasiNnView`
and `struct WasiNnImpl` wrapping that the WIT-based implementation used
for accessing the host context. Instead, `WasiNnView` is now a `struct`
containing the mutable references it needs to do its work. This unwraps one
complex layer of abstraction, though it has the downside of complicating CLI
code, which must now split its borrows of `Host`.
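
A minimal sketch of that shape (the struct below illustrates the idea, not the exact `WasiNnView` API from this PR): the view is just a bundle of mutable borrows, and producing it forces the CLI to split-borrow its `Host` state field by field.

```rust
use wasmtime::component::ResourceTable;
use wasmtime_wasi_nn::wit::WasiNnCtx;

// Illustrative stand-in for the CLI's store data.
struct Host {
    table: ResourceTable,
    wasi_nn: WasiNnCtx,
}

// Sketch of a `WasiNnView`-like struct: plain mutable references instead of a
// `trait WasiNnView` + `struct WasiNnImpl` wrapper around the whole host.
struct View<'a> {
    ctx: &'a mut WasiNnCtx,
    table: &'a mut ResourceTable,
}

// The downside noted above: handing out the view means splitting the borrow
// of `Host` rather than borrowing it as a whole.
fn nn_view(host: &mut Host) -> View<'_> {
    View {
        ctx: &mut host.wasi_nn,
        table: &mut host.table,
    }
}
```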

* Temporarily disable WIT check

* Refactor errors to use `trappable_error_type`

This change simplifies the return types of the host implementations of
the WIT-based wasi-nn. There is more work to be done with errors, e.g.,
to catch up with the upstream decision to return errors as resources.
But this is better than the previous mess.
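
For reference, `trappable_error_type` is a `wasmtime::component::bindgen!` option that maps a WIT-level error type onto a host-defined Rust error, so the generated host traits can return a plain `Result<T, HostError>` instead of nesting the guest-visible error inside `wasmtime::Result`. A sketch of the effect on signatures (types are illustrative, not this PR's exact bindings):

```rust
use wasmtime::component::Resource;

// Illustrative host error type; the generated glue decides whether a value of
// this type becomes a guest-visible wasi-nn error or escalates to a trap.
pub struct MlError(anyhow::Error);

// Illustrative stand-in for the generated `graph` resource type.
pub struct Graph;

// Without the option (sketch): every host method spells out the guest-visible
// error in its return type.
fn load_verbose() -> wasmtime::Result<Result<Resource<Graph>, String>> {
    todo!()
}

// With `trappable_error_type` (sketch): the host returns its own error type
// and the bindings perform the conversion.
fn load_trappable() -> Result<Resource<Graph>, MlError> {
    todo!()
}
```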
Andrew Brown, committed via GitHub 4 months ago
Commit: 0f4ae88a7a
  1. Cargo.lock (50)
  2. ci/vendor-wit.sh (6)
  3. crates/bench-api/src/lib.rs (4)
  4. crates/test-programs/artifacts/build.rs (5)
  5. crates/test-programs/src/bin/nn_image_classification_winml.rs (16)
  6. crates/test-programs/src/bin/nn_wit_image_classification_onnx.rs (14)
  7. crates/test-programs/src/bin/nn_wit_image_classification_openvino.rs (25)
  8. crates/test-programs/src/bin/nn_wit_image_classification_openvino_named.rs (17)
  9. crates/test-programs/src/bin/nn_wit_image_classification_winml_named.rs (9)
  10. crates/test-programs/src/bin/nn_witx_image_classification_onnx.rs (22)
  11. crates/test-programs/src/bin/nn_witx_image_classification_openvino.rs (12)
  12. crates/test-programs/src/bin/nn_witx_image_classification_openvino_named.rs (17)
  13. crates/test-programs/src/bin/nn_witx_image_classification_winml_named.rs (18)
  14. crates/test-programs/src/nn.rs (180)
  15. crates/wasi-nn/Cargo.toml (17)
  16. crates/wasi-nn/src/backend/mod.rs (31)
  17. crates/wasi-nn/src/backend/onnx.rs (338)
  18. crates/wasi-nn/src/backend/onnxruntime.rs (149)
  19. crates/wasi-nn/src/backend/openvino.rs (36)
  20. crates/wasi-nn/src/backend/winml.rs (146)
  21. crates/wasi-nn/src/ctx.rs (146)
  22. crates/wasi-nn/src/lib.rs (51)
  23. crates/wasi-nn/src/registry/in_memory.rs (5)
  24. crates/wasi-nn/src/registry/mod.rs (1)
  25. crates/wasi-nn/src/wit.rs (330)
  26. crates/wasi-nn/src/witx.rs (122)
  27. crates/wasi-nn/tests/check/mod.rs (8)
  28. crates/wasi-nn/tests/exec/mod.rs (54)
  29. crates/wasi-nn/tests/exec/wit.rs (73)
  30. crates/wasi-nn/tests/exec/witx.rs (52)
  31. crates/wasi-nn/tests/fixtures/README.md (0)
  32. crates/wasi-nn/tests/test-programs.rs (200)
  33. crates/wasi-nn/wit/wasi-nn.wit (57)
  34. src/commands/run.rs (63)
  35. src/commands/serve.rs (22)
  36. supply-chain/audits.toml (18)

50
Cargo.lock

@ -371,15 +371,6 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "castaway"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a17ed5635fc8536268e5d4de1e22e81ac34419e5f052d4d51f4e01dcc263fcc"
dependencies = [
"rustversion",
]
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.0.83" version = "1.0.83"
@ -486,19 +477,6 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
[[package]]
name = "compact_str"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f86b9c4c00838774a6d902ef931eff7470720c51d90c2e32cfe15dc304737b3f"
dependencies = [
"castaway",
"cfg-if",
"itoa",
"ryu",
"static_assertions",
]
[[package]] [[package]]
name = "component-fuzz-util" name = "component-fuzz-util"
version = "0.0.0" version = "0.0.0"
@ -1908,9 +1886,9 @@ dependencies = [
[[package]] [[package]]
name = "num-traits" name = "num-traits"
version = "0.2.15" version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [ dependencies = [
"autocfg", "autocfg",
] ]
@ -2028,21 +2006,22 @@ dependencies = [
[[package]] [[package]]
name = "ort" name = "ort"
version = "2.0.0-rc.0" version = "2.0.0-rc.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8e5caf4eb2ead4bc137c3ff4e347940e3e556ceb11a4180627f04b63d7342dd" checksum = "0bc80894094c6a875bfac64415ed456fa661081a278a035e22be661305c87e14"
dependencies = [ dependencies = [
"compact_str", "js-sys",
"ort-sys", "ort-sys",
"thiserror", "thiserror",
"tracing", "tracing",
"web-sys",
] ]
[[package]] [[package]]
name = "ort-sys" name = "ort-sys"
version = "2.0.0-rc.0" version = "2.0.0-rc.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f48b5623df2187e0db543ecb2032a6a999081086b7ffddd318000c00b23ace46" checksum = "b3d9c1373fc813d3f024d394f621f4c6dde0734c79b1c17113c3bb5bf0084bbe"
dependencies = [ dependencies = [
"flate2", "flate2",
"sha2", "sha2",
@ -3936,9 +3915,10 @@ dependencies = [
"test-programs-artifacts", "test-programs-artifacts",
"thiserror", "thiserror",
"tracing", "tracing",
"tracing-subscriber",
"walkdir", "walkdir",
"wasi-common",
"wasmtime", "wasmtime",
"wasmtime-wasi",
"wiggle", "wiggle",
"windows", "windows",
] ]
@ -4024,6 +4004,16 @@ dependencies = [
"wast 211.0.1", "wast 211.0.1",
] ]
[[package]]
name = "web-sys"
version = "0.3.57"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b17e741662c70c8bd24ac5c5b18de314a2c26c32bf8346ee1e6f53de919c283"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]] [[package]]
name = "webpki-roots" name = "webpki-roots"
version = "0.26.1" version = "0.26.1"

6
ci/vendor-wit.sh

@@ -36,5 +36,9 @@ cp -r $dst crates/wasi-http/wit
 # slightly different than above.
 repo=https://raw.githubusercontent.com/WebAssembly/wasi-nn
 revision=e2310b
-curl -L $repo/$revision/wit/wasi-nn.wit -o crates/wasi-nn/wit/wasi-nn.wit
 curl -L $repo/$revision/wasi-nn.witx -o crates/wasi-nn/witx/wasi-nn.witx
+# TODO: the in-tree `wasi-nn` implementation does not yet fully support the
+# latest WIT specification on `main`. To create a baseline for moving forward,
+# the in-tree WIT incorporates some but not all of the upstream changes. This
+# TODO can be removed once the implementation catches up with the spec.
+# curl -L $repo/$revision/wit/wasi-nn.wit -o crates/wasi-nn/wit/wasi-nn.wit

4
crates/bench-api/src/lib.rs

@@ -418,7 +418,7 @@ struct BenchState {
 struct HostState {
     wasi: WasiCtx,
     #[cfg(feature = "wasi-nn")]
-    wasi_nn: wasmtime_wasi_nn::WasiNnCtx,
+    wasi_nn: wasmtime_wasi_nn::witx::WasiNnCtx,
 }

 impl BenchState {
@@ -509,7 +509,7 @@ impl BenchState {
             #[cfg(feature = "wasi-nn")]
             wasi_nn: {
                 let (backends, registry) = wasmtime_wasi_nn::preload(&[])?;
-                wasmtime_wasi_nn::WasiNnCtx::new(backends, registry)
+                wasmtime_wasi_nn::witx::WasiNnCtx::new(backends, registry)
             },
         };

5
crates/test-programs/artifacts/build.rs

@@ -90,7 +90,10 @@ fn build_and_generate_tests() {
         }

         // Generate a component from each test.
-        if kind == "nn" || target == "dwarf_imported_memory" || target == "dwarf_shared_memory" {
+        if target == "dwarf_imported_memory"
+            || target == "dwarf_shared_memory"
+            || target.starts_with("nn_witx")
+        {
             continue;
         }
         let adapter = match target.as_str() {

16
crates/test-programs/src/bin/nn_image_classification_winml.rs

@ -1,16 +0,0 @@
use anyhow::{Context, Result};
use std::fs;
use test_programs::nn::{classify, sort_results};
use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding};
pub fn main() -> Result<()> {
let graph = GraphBuilder::new(GraphEncoding::Onnx, ExecutionTarget::CPU)
.build_from_cache("mobilenet")?;
let tensor = fs::read("fixture/kitten.rgb")
.context("the tensor file to be mapped to the fixture directory")?;
let results = classify(graph, tensor)?;
let top_five = &sort_results(&results)[..5];
println!("found results, sorted top 5: {:?}", top_five);
assert_eq!(top_five[0].class_id(), 284);
Ok(())
}

14
crates/test-programs/src/bin/nn_image_classification_onnx.rs → crates/test-programs/src/bin/nn_wit_image_classification_onnx.rs

@ -1,18 +1,20 @@
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use std::fs; use std::fs;
use test_programs::nn::{classify, sort_results}; use test_programs::nn::{sort_results, wit};
use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding};
pub fn main() -> Result<()> { pub fn main() -> Result<()> {
let model = fs::read("fixture/model.onnx") let model = fs::read("fixture/model.onnx")
.context("the model file to be mapped to the fixture directory")?; .context("the model file to be mapped to the fixture directory")?;
let graph = let graph = wit::load(
GraphBuilder::new(GraphEncoding::Onnx, ExecutionTarget::CPU).build_from_bytes([&model])?; &[model],
wit::GraphEncoding::Onnx,
wit::ExecutionTarget::Cpu,
)?;
let tensor = fs::read("fixture/000000062808.rgb") let tensor = fs::read("fixture/000000062808.rgb")
.context("the tensor file to be mapped to the fixture directory")?; .context("the tensor file to be mapped to the fixture directory")?;
let results = classify(graph, tensor)?; let results = wit::classify(graph, ("input", tensor), "output")?;
let top_five = &sort_results(&results)[..5]; let top_five = &sort_results(&results)[..5];
// 963 is meat loaf, meatloaf. // 963 is "meat loaf, meatloaf."
// https://github.com/onnx/models/blob/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/synset.txt#L963 // https://github.com/onnx/models/blob/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/synset.txt#L963
assert_eq!(top_five[0].class_id(), 963); assert_eq!(top_five[0].class_id(), 963);
println!("found results, sorted top 5: {:?}", top_five); println!("found results, sorted top 5: {:?}", top_five);

25
crates/test-programs/src/bin/nn_wit_image_classification_openvino.rs

@ -0,0 +1,25 @@
use anyhow::{Context, Result};
use std::fs;
use test_programs::nn::{sort_results, wit};
pub fn main() -> Result<()> {
let xml = fs::read("fixture/model.xml")
.context("the model file to be mapped to the fixture directory")?;
let weights = fs::read("fixture/model.bin")
.context("the weights file to be mapped to the fixture directory")?;
let graph = wit::load(
&[xml, weights],
wit::GraphEncoding::Openvino,
wit::ExecutionTarget::Cpu,
)?;
let tensor = fs::read("fixture/tensor.bgr")
.context("the tensor file to be mapped to the fixture directory")?;
let results = wit::classify(
graph,
("input", tensor),
"MobilenetV2/Predictions/Reshape_1",
)?;
let top_five = &sort_results(&results)[..5];
println!("found results, sorted top 5: {:?}", top_five);
Ok(())
}

17
crates/test-programs/src/bin/nn_wit_image_classification_openvino_named.rs

@ -0,0 +1,17 @@
use anyhow::{Context, Result};
use std::fs;
use test_programs::nn::{sort_results, wit};
pub fn main() -> Result<()> {
let graph = wit::load_by_name("fixtures")?;
let tensor: Vec<u8> = fs::read("fixture/tensor.bgr")
.context("the tensor file to be mapped to the fixture directory")?;
let results = wit::classify(
graph,
("input", tensor),
"MobilenetV2/Predictions/Reshape_1",
)?;
let top_five = &sort_results(&results)[..5];
println!("found results, sorted top 5: {:?}", top_five);
Ok(())
}

9
crates/test-programs/src/bin/nn_image_classification_named.rs → crates/test-programs/src/bin/nn_wit_image_classification_winml_named.rs

@ -1,15 +1,14 @@
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use std::fs; use std::fs;
use test_programs::nn::{classify, sort_results}; use test_programs::nn::{sort_results, wit};
use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding};
pub fn main() -> Result<()> { pub fn main() -> Result<()> {
let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU) let graph = wit::load_by_name("mobilenet")?;
.build_from_cache("fixtures")?;
let tensor = fs::read("fixture/tensor.bgr") let tensor = fs::read("fixture/tensor.bgr")
.context("the tensor file to be mapped to the fixture directory")?; .context("the tensor file to be mapped to the fixture directory")?;
let results = classify(graph, tensor)?; let results = wit::classify(graph, ("input", tensor), "output")?;
let top_five = &sort_results(&results)[..5]; let top_five = &sort_results(&results)[..5];
println!("found results, sorted top 5: {:?}", top_five); println!("found results, sorted top 5: {:?}", top_five);
assert_eq!(top_five[0].class_id(), 284);
Ok(()) Ok(())
} }

22
crates/test-programs/src/bin/nn_witx_image_classification_onnx.rs

@ -0,0 +1,22 @@
use anyhow::{Context, Result};
use std::fs;
use test_programs::nn::{sort_results, witx};
pub fn main() -> Result<()> {
let model = fs::read("fixture/model.onnx")
.context("the model file to be mapped to the fixture directory")?;
let graph = witx::load(
&[&model],
witx::GraphEncoding::Onnx,
witx::ExecutionTarget::CPU,
)?;
let tensor = fs::read("fixture/000000062808.rgb")
.context("the tensor file to be mapped to the fixture directory")?;
let results = witx::classify(graph, tensor)?;
let top_five = &sort_results(&results)[..5];
// 963 is "meat loaf, meatloaf."
// https://github.com/onnx/models/blob/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/synset.txt#L963
assert_eq!(top_five[0].class_id(), 963);
println!("found results, sorted top 5: {:?}", top_five);
Ok(())
}

12
crates/test-programs/src/bin/nn_image_classification.rs → crates/test-programs/src/bin/nn_witx_image_classification_openvino.rs

@ -1,18 +1,20 @@
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use std::fs; use std::fs;
use test_programs::nn::{classify, sort_results}; use test_programs::nn::{sort_results, witx};
use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding};
pub fn main() -> Result<()> { pub fn main() -> Result<()> {
let xml = fs::read("fixture/model.xml") let xml = fs::read("fixture/model.xml")
.context("the model file to be mapped to the fixture directory")?; .context("the model file to be mapped to the fixture directory")?;
let weights = fs::read("fixture/model.bin") let weights = fs::read("fixture/model.bin")
.context("the weights file to be mapped to the fixture directory")?; .context("the weights file to be mapped to the fixture directory")?;
let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU) let graph = witx::load(
.build_from_bytes([&xml, &weights])?; &[&xml, &weights],
witx::GraphEncoding::Openvino,
witx::ExecutionTarget::CPU,
)?;
let tensor = fs::read("fixture/tensor.bgr") let tensor = fs::read("fixture/tensor.bgr")
.context("the tensor file to be mapped to the fixture directory")?; .context("the tensor file to be mapped to the fixture directory")?;
let results = classify(graph, tensor)?; let results = witx::classify(graph, tensor)?;
let top_five = &sort_results(&results)[..5]; let top_five = &sort_results(&results)[..5];
println!("found results, sorted top 5: {:?}", top_five); println!("found results, sorted top 5: {:?}", top_five);
Ok(()) Ok(())

17
crates/test-programs/src/bin/nn_witx_image_classification_openvino_named.rs

@ -0,0 +1,17 @@
use anyhow::{Context, Result};
use std::fs;
use test_programs::nn::{sort_results, witx};
pub fn main() -> Result<()> {
let graph = witx::load_by_name(
"fixtures",
witx::GraphEncoding::Openvino,
witx::ExecutionTarget::CPU,
)?;
let tensor: Vec<u8> = fs::read("fixture/tensor.bgr")
.context("the tensor file to be mapped to the fixture directory")?;
let results = witx::classify(graph, tensor)?;
let top_five = &sort_results(&results)[..5];
println!("found results, sorted top 5: {:?}", top_five);
Ok(())
}

18
crates/test-programs/src/bin/nn_witx_image_classification_winml_named.rs

@ -0,0 +1,18 @@
use anyhow::{Context, Result};
use std::fs;
use test_programs::nn::{sort_results, witx};
pub fn main() -> Result<()> {
let graph = witx::load_by_name(
"mobilenet",
witx::GraphEncoding::Onnx,
witx::ExecutionTarget::CPU,
)?;
let tensor = fs::read("fixture/tensor.bgr")
.context("the tensor file to be mapped to the fixture directory")?;
let results = witx::classify(graph, tensor)?;
let top_five = &sort_results(&results)[..5];
println!("found results, sorted top 5: {:?}", top_five);
assert_eq!(top_five[0].class_id(), 284);
Ok(())
}

180
crates/test-programs/src/nn.rs

@ -1,39 +1,147 @@
use anyhow::Result; //! This module attempts to paper over the differences between the two
use std::time::Instant; //! implementations of wasi-nn: the legacy WITX-based version (`mod witx`) and
use wasi_nn::{Graph, TensorType}; //! the up-to-date WIT version (`mod wit`). Since the tests are mainly a simple
//! classifier, this exposes a high-level `classify` function to go along with
/// Run a wasi-nn inference using a simple classifier model (single input, //! `load`, etc.
/// single output). //!
pub fn classify(graph: Graph, tensor: Vec<u8>) -> Result<Vec<f32>> { //! This module exists solely for convenience--e.g., reduces test duplication.
let mut context = graph.init_execution_context()?; //! In the future can be safely disposed of or altered as more tests are added.
println!(
"[nn] created wasi-nn execution context with ID: {}", /// Call `wasi-nn` functions from WebAssembly using the canonical ABI of the
context /// component model via WIT-based tooling. Used by `bin/nn_wit_*.rs` tests.
); pub mod wit {
use anyhow::{anyhow, Result};
// Many classifiers have a single input; currently, this test suite also use std::time::Instant;
// uses tensors of the same shape, though this is not usually the case.
context.set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor)?; // Generate the wasi-nn bindings based on the `*.wit` files.
println!("[nn] set input tensor: {} bytes", tensor.len()); wit_bindgen::generate!({
path: "../wasi-nn/wit",
let before = Instant::now(); world: "ml",
context.compute()?; default_bindings_module: "test_programs::ml"
println!( });
"[nn] executed graph inference in {} ms", use self::wasi::nn::errors;
before.elapsed().as_millis() use self::wasi::nn::graph::{self, Graph};
); pub use self::wasi::nn::graph::{ExecutionTarget, GraphEncoding}; // Used by tests.
use self::wasi::nn::tensor::{Tensor, TensorType};
// Many classifiers emit probabilities as floating point values; here we
// convert the raw bytes to `f32` knowing all models used here use that /// Load a wasi-nn graph from a set of bytes.
// type. pub fn load(
let mut output_buffer = vec![0u8; 1001 * std::mem::size_of::<f32>()]; bytes: &[Vec<u8>],
let num_bytes = context.get_output(0, &mut output_buffer)?; encoding: GraphEncoding,
println!("[nn] retrieved output tensor: {} bytes", num_bytes); target: ExecutionTarget,
let output: Vec<f32> = output_buffer[..num_bytes] ) -> Result<Graph> {
.chunks(4) graph::load(bytes, encoding, target).map_err(err_as_anyhow)
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])) }
.collect();
Ok(output) /// Load a wasi-nn graph by name.
pub fn load_by_name(name: &str) -> Result<Graph> {
graph::load_by_name(name).map_err(err_as_anyhow)
}
/// Run a wasi-nn inference using a simple classifier model (single input,
/// single output).
pub fn classify(graph: Graph, input: (&str, Vec<u8>), output: &str) -> Result<Vec<f32>> {
let context = graph.init_execution_context().map_err(err_as_anyhow)?;
println!(
"[nn] created wasi-nn execution context with ID: {:?}",
context
);
// Many classifiers have a single input; currently, this test suite also
// uses tensors of the same shape, though this is not usually the case.
let tensor = Tensor::new(&vec![1, 3, 224, 224], TensorType::Fp32, &input.1);
context.set_input(input.0, tensor).map_err(err_as_anyhow)?;
println!("[nn] set input tensor: {} bytes", input.1.len());
let before = Instant::now();
context.compute().map_err(err_as_anyhow)?;
println!(
"[nn] executed graph inference in {} ms",
before.elapsed().as_millis()
);
// Many classifiers emit probabilities as floating point values; here we
// convert the raw bytes to `f32` knowing all models used here use that
// type.
let output = context.get_output(output).map_err(err_as_anyhow)?;
println!(
"[nn] retrieved output tensor: {} bytes",
output.data().len()
);
let output: Vec<f32> = output
.data()
.chunks(4)
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
.collect();
Ok(output)
}
fn err_as_anyhow(e: errors::Error) -> anyhow::Error {
anyhow!("error: {e:?}")
}
}
/// Call `wasi-nn` functions from WebAssembly using the legacy WITX-based
/// tooling. This older API has been deprecated for the newer WIT-based API but
/// retained for backwards compatibility testing--i.e., `bin/nn_witx_*.rs`
/// tests.
pub mod witx {
use anyhow::Result;
use std::time::Instant;
pub use wasi_nn::{ExecutionTarget, GraphEncoding};
use wasi_nn::{Graph, GraphBuilder, TensorType};
/// Load a wasi-nn graph from a set of bytes.
pub fn load(
bytes: &[&[u8]],
encoding: GraphEncoding,
target: ExecutionTarget,
) -> Result<Graph> {
Ok(GraphBuilder::new(encoding, target).build_from_bytes(bytes)?)
}
/// Load a wasi-nn graph by name.
pub fn load_by_name(
name: &str,
encoding: GraphEncoding,
target: ExecutionTarget,
) -> Result<Graph> {
Ok(GraphBuilder::new(encoding, target).build_from_cache(name)?)
}
/// Run a wasi-nn inference using a simple classifier model (single input,
/// single output).
pub fn classify(graph: Graph, tensor: Vec<u8>) -> Result<Vec<f32>> {
let mut context = graph.init_execution_context()?;
println!(
"[nn] created wasi-nn execution context with ID: {}",
context
);
// Many classifiers have a single input; currently, this test suite also
// uses tensors of the same shape, though this is not usually the case.
context.set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor)?;
println!("[nn] set input tensor: {} bytes", tensor.len());
let before = Instant::now();
context.compute()?;
println!(
"[nn] executed graph inference in {} ms",
before.elapsed().as_millis()
);
// Many classifiers emit probabilities as floating point values; here we
// convert the raw bytes to `f32` knowing all models used here use that
// type.
let mut output_buffer = vec![0u8; 1001 * std::mem::size_of::<f32>()];
let num_bytes = context.get_output(0, &mut output_buffer)?;
println!("[nn] retrieved output tensor: {} bytes", num_bytes);
let output: Vec<f32> = output_buffer[..num_bytes]
.chunks(4)
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
.collect();
Ok(output)
}
} }
/// Sort some classification probabilities. /// Sort some classification probabilities.

17
crates/wasi-nn/Cargo.toml

@ -20,7 +20,11 @@ anyhow = { workspace = true, features = ['std'] }
wiggle = { workspace = true, features = ["wasmtime"] } wiggle = { workspace = true, features = ["wasmtime"] }
# This dependency is necessary for the WIT-generation macros to work: # This dependency is necessary for the WIT-generation macros to work:
wasmtime = { workspace = true, features = ["component-model", "runtime"] } wasmtime = { workspace = true, features = [
"component-model",
"runtime",
"std",
] }
# These dependencies are necessary for the wasi-nn implementation: # These dependencies are necessary for the wasi-nn implementation:
tracing = { workspace = true } tracing = { workspace = true }
@ -29,7 +33,7 @@ openvino = { version = "0.6.0", features = [
"runtime-linking", "runtime-linking",
], optional = true } ], optional = true }
ort = { version = "2.0.0-rc.0", default-features = false, features = [ ort = { version = "2.0.0-rc.2", default-features = false, features = [
"copy-dylibs", "copy-dylibs",
"download-binaries", "download-binaries",
], optional = true } ], optional = true }
@ -46,16 +50,17 @@ walkdir = { workspace = true }
cap-std = { workspace = true } cap-std = { workspace = true }
libtest-mimic = { workspace = true } libtest-mimic = { workspace = true }
test-programs-artifacts = { workspace = true } test-programs-artifacts = { workspace = true }
wasi-common = { workspace = true, features = ["sync"] } wasmtime-wasi = { workspace = true, features = ["preview1"] }
wasmtime = { workspace = true, features = ["cranelift"] } wasmtime = { workspace = true, features = ["cranelift"] }
tracing-subscriber = { workspace = true }
[features] [features]
default = ["openvino", "winml"] default = ["openvino", "winml"]
# openvino is available on all platforms, it requires openvino installed. # OpenVINO is available on all platforms; it requires OpenVINO to be installed.
openvino = ["dep:openvino"] openvino = ["dep:openvino"]
# onnx is available on all platforms. # ONNX is available on all platforms.
onnx = ["dep:ort"] onnx = ["dep:ort"]
# winml is only available on Windows 10 1809 and later. # WinML is only available on Windows 10 1809 and later.
winml = ["dep:windows"] winml = ["dep:windows"]
[[test]] [[test]]

31
crates/wasi-nn/src/backend/mod.rs

@ -3,20 +3,20 @@
//! implementations to maintain backend-specific state between calls. //! implementations to maintain backend-specific state between calls.
#[cfg(feature = "onnx")] #[cfg(feature = "onnx")]
pub mod onnxruntime; pub mod onnx;
#[cfg(feature = "openvino")] #[cfg(feature = "openvino")]
pub mod openvino; pub mod openvino;
#[cfg(all(feature = "winml", target_os = "windows"))] #[cfg(all(feature = "winml", target_os = "windows"))]
pub mod winml; pub mod winml;
#[cfg(feature = "onnx")] #[cfg(feature = "onnx")]
use self::onnxruntime::OnnxBackend; use self::onnx::OnnxBackend;
#[cfg(feature = "openvino")] #[cfg(feature = "openvino")]
use self::openvino::OpenvinoBackend; use self::openvino::OpenvinoBackend;
#[cfg(all(feature = "winml", target_os = "windows"))] #[cfg(all(feature = "winml", target_os = "windows"))]
use self::winml::WinMLBackend; use self::winml::WinMLBackend;
use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor}; use crate::wit::{ExecutionTarget, GraphEncoding, Tensor};
use crate::{Backend, ExecutionContext, Graph}; use crate::{Backend, ExecutionContext, Graph};
use std::fs::File; use std::fs::File;
use std::io::Read; use std::io::Read;
@ -69,9 +69,30 @@ pub trait BackendGraph: Send + Sync {
/// A [BackendExecutionContext] performs the actual inference; this is the /// A [BackendExecutionContext] performs the actual inference; this is the
/// backing implementation for a user-facing execution context. /// backing implementation for a user-facing execution context.
pub trait BackendExecutionContext: Send + Sync { pub trait BackendExecutionContext: Send + Sync {
fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError>; fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError>;
fn compute(&mut self) -> Result<(), BackendError>; fn compute(&mut self) -> Result<(), BackendError>;
fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result<u32, BackendError>; fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError>;
}
/// An identifier for a tensor in a [Graph].
#[derive(Debug)]
pub enum Id {
Index(u32),
Name(String),
}
impl Id {
pub fn index(&self) -> Option<u32> {
match self {
Id::Index(i) => Some(*i),
Id::Name(_) => None,
}
}
pub fn name(&self) -> Option<&str> {
match self {
Id::Index(_) => None,
Id::Name(n) => Some(n),
}
}
} }
/// Errors returned by a backend; [BackendError::BackendAccess] is a catch-all /// Errors returned by a backend; [BackendError::BackendAccess] is a catch-all

338
crates/wasi-nn/src/backend/onnx.rs

@ -0,0 +1,338 @@
//! Implements a `wasi-nn` [`BackendInner`] using ONNX via the `ort` crate.
use super::{BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner};
use crate::backend::{read, Id};
use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
use crate::{ExecutionContext, Graph};
use anyhow::Context;
use ort::{inputs, GraphOptimizationLevel, Session};
use std::path::Path;
use std::sync::{Arc, Mutex};
#[derive(Default)]
pub struct OnnxBackend();
unsafe impl Send for OnnxBackend {}
unsafe impl Sync for OnnxBackend {}
impl BackendInner for OnnxBackend {
fn encoding(&self) -> GraphEncoding {
GraphEncoding::Onnx
}
fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> {
if builders.len() != 1 {
return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()).into());
}
let session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.commit_from_memory(builders[0])?;
let box_: Box<dyn BackendGraph> =
Box::new(OnnxGraph(Arc::new(Mutex::new(session)), target));
Ok(box_.into())
}
fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir> {
Some(self)
}
}
impl BackendFromDir for OnnxBackend {
fn load_from_dir(
&mut self,
path: &Path,
target: ExecutionTarget,
) -> Result<Graph, BackendError> {
let model = read(&path.join("model.onnx"))?;
self.load(&[&model], target)
}
}
struct OnnxGraph(Arc<Mutex<Session>>, #[allow(dead_code)] ExecutionTarget);
unsafe impl Send for OnnxGraph {}
unsafe impl Sync for OnnxGraph {}
impl BackendGraph for OnnxGraph {
fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> {
let session = self.0.lock().unwrap();
// We need to hold on to the names of the inputs in order for
// `set_input` to work with both indexes and names. Having the
// dimensions and type around is useful for validation but could be
// retrieved from the session.
let mut inputs = vec![];
for input in &session.inputs {
let shape = Shape::from_onnx_input(input)?;
inputs.push(TensorSlot {
shape,
tensor: None,
});
}
// We need to keep track of the output shapes since they are used for
// creating the output tensor.
let mut outputs = vec![];
for output in &session.outputs {
let shape = Shape::from_onnx_output(output)?;
outputs.push(TensorSlot {
shape,
tensor: None,
});
}
let box_: Box<dyn BackendExecutionContext> = Box::new(OnnxExecutionContext {
session: self.0.clone(),
inputs,
outputs,
});
Ok(box_.into())
}
}
struct OnnxExecutionContext {
session: Arc<Mutex<Session>>,
inputs: Vec<TensorSlot>,
outputs: Vec<TensorSlot>,
}
unsafe impl Send for OnnxExecutionContext {}
unsafe impl Sync for OnnxExecutionContext {}
impl OnnxExecutionContext {
/// Helper function for finding the internal index of a tensor by [`Id`].
fn find(&self, id: Id, list: &[TensorSlot]) -> Result<usize, BackendError> {
let index = match id {
Id::Index(i) => {
let i = i as usize;
if i < list.len() {
i
} else {
return Err(BackendError::BackendAccess(anyhow::anyhow!(
"incorrect tensor index: {i} >= {}",
list.len()
)));
}
}
Id::Name(n) => list.iter().position(|s| s.shape.name == n).ok_or_else(|| {
BackendError::BackendAccess(anyhow::anyhow!("unknown tensor name: {n}"))
})?,
};
Ok(index)
}
}
impl BackendExecutionContext for OnnxExecutionContext {
fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> {
let index = self.find(id, &self.inputs)?;
let input = &mut self.inputs[index];
if let Err(e) = input.shape.matches(tensor) {
return Err(e.into());
}
// Hold the tensor data on the context until `compute` is called.
input.tensor.replace(tensor.clone());
Ok(())
}
fn compute(&mut self) -> Result<(), BackendError> {
let mut session_inputs: Vec<ort::SessionInputValue<'_>> = vec![];
for i in &self.inputs {
session_inputs.extend(to_input_value(i)?);
}
let session = self.session.lock().unwrap();
let session_outputs = session.run(session_inputs.as_slice())?;
for i in 0..self.outputs.len() {
// TODO: fix preexisting gap--this only handles f32 tensors.
let raw: (Vec<i64>, &[f32]) = session_outputs[i].try_extract_raw_tensor()?;
let f32s = raw.1.to_vec();
let output = &mut self.outputs[i];
output.tensor.replace(Tensor {
dimensions: output.shape.dimensions_as_u32()?,
ty: output.shape.ty,
data: f32_vec_to_bytes(f32s),
});
}
Ok(())
}
fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError> {
let index = self.find(id, &self.outputs)?;
let output = &self.outputs[index];
if let Some(tensor) = &output.tensor {
Ok(tensor.clone())
} else {
Err(BackendError::BackendAccess(anyhow::anyhow!(
"missing output tensor: {}; has `compute` been called?",
output.shape.name
)))
}
}
}
impl From<ort::Error> for BackendError {
fn from(e: ort::Error) -> Self {
BackendError::BackendAccess(e.into())
}
}
/// Holds a slot for ONNX session inputs and outputs.
///
/// TODO: it seems unfortunate that we have to "hold" some extra data per
/// session but in the input case, this is necessary for name-based indexing.
struct TensorSlot {
shape: Shape,
tensor: Option<Tensor>,
}
/// Describes a tensor in ONNX terms.
struct Shape {
name: String,
dimensions: Vec<i64>,
ty: TensorType,
}
impl Shape {
fn from_onnx_input(input: &ort::Input) -> Result<Self, BackendError> {
let name = input.name.clone();
let (dimensions, ty) = convert_value_type(&input.input_type)?;
Ok(Self {
name,
dimensions,
ty,
})
}
fn from_onnx_output(output: &ort::Output) -> Result<Self, BackendError> {
let name = output.name.clone();
let (dimensions, ty) = convert_value_type(&output.output_type)?;
Ok(Self {
name,
dimensions,
ty,
})
}
fn dimensions_as_u32(&self) -> Result<Vec<u32>, BackendError> {
self.dimensions
.iter()
.map(|d| if *d == -1 { Ok(1) } else { convert_i64(d) })
.collect()
}
fn matches(&self, tensor: &Tensor) -> anyhow::Result<()> {
if self.dimensions.len() != tensor.dimensions.len() {
return Err(anyhow::anyhow!(
"input tensor cardinality does not match model: {:?} != {:?}",
self.dimensions,
tensor.dimensions
));
} else {
for (&shape_dim, &tensor_dim) in self.dimensions.iter().zip(tensor.dimensions.iter()) {
let tensor_dim = tensor_dim as i64;
if !is_dynamic_dimension(shape_dim) && shape_dim != tensor_dim {
return Err(anyhow::anyhow!(
"input tensor dimensions do not match model: {:?} != {:?}",
self.dimensions,
tensor.dimensions
));
}
}
}
if self.ty != tensor.ty {
return Err(anyhow::anyhow!(
"input tensor type does not match model: {:?} != {:?}",
self.ty,
tensor.ty
));
}
Ok(())
}
}
fn convert_value_type(vt: &ort::ValueType) -> Result<(Vec<i64>, TensorType), BackendError> {
match vt {
ort::ValueType::Tensor { ty, dimensions } => {
let dims = dimensions.clone();
let ty = (*ty).try_into()?;
Ok((dims, ty))
}
_ => Err(BackendError::BackendAccess(anyhow::anyhow!(
"unsupported input type: {vt:?}"
))),
}
}
fn convert_i64(i: &i64) -> Result<u32, BackendError> {
u32::try_from(*i).map_err(|d| -> BackendError {
anyhow::anyhow!("unable to convert dimension to u32: {d}").into()
})
}
impl TryFrom<ort::TensorElementType> for TensorType {
type Error = BackendError;
fn try_from(ty: ort::TensorElementType) -> Result<Self, Self::Error> {
match ty {
ort::TensorElementType::Float32 => Ok(TensorType::Fp32),
ort::TensorElementType::Float64 => Ok(TensorType::Fp64),
ort::TensorElementType::Uint8 => Ok(TensorType::U8),
ort::TensorElementType::Int32 => Ok(TensorType::I32),
ort::TensorElementType::Int64 => Ok(TensorType::I64),
_ => Err(BackendError::BackendAccess(anyhow::anyhow!(
"unsupported tensor type: {ty:?}"
))),
}
}
}
fn to_input_value(slot: &TensorSlot) -> Result<[ort::SessionInputValue<'_>; 1], BackendError> {
match &slot.tensor {
Some(tensor) => match tensor.ty {
TensorType::Fp32 => {
let data = bytes_to_f32_vec(tensor.data.to_vec());
let dimensions = tensor
.dimensions
.iter()
.map(|d| *d as i64) // TODO: fewer conversions
.collect::<Vec<i64>>();
Ok(inputs![(dimensions, Arc::new(data.into_boxed_slice()))]
.context("failed to create ONNX session input")?)
}
_ => {
unimplemented!("{:?} not supported by ONNX", tensor.ty);
}
},
None => {
return Err(BackendError::BackendAccess(anyhow::anyhow!(
"missing input tensor: {}",
slot.shape.name
)));
}
}
}
pub fn f32_vec_to_bytes(data: Vec<f32>) -> Vec<u8> {
let chunks: Vec<[u8; 4]> = data.into_iter().map(|f| f.to_le_bytes()).collect();
let result: Vec<u8> = chunks.iter().flatten().copied().collect();
result
}
pub fn bytes_to_f32_vec(data: Vec<u8>) -> Vec<f32> {
let chunks: Vec<&[u8]> = data.chunks(4).collect();
let v: Vec<f32> = chunks
.into_iter()
.map(|c| f32::from_le_bytes(c.try_into().unwrap()))
.collect();
v.into_iter().collect()
}
/// Returns whether the dimension is dynamic.
///
/// ONNX uses [dimensional variables] (i.e., name strings) to indicate that the
/// value of a tensor dimension is user-defined, not fixed by the model. This is
/// useful for batching up several inference requests, e.g. When `ort` returns a
/// dimension of this kind, though, it uses `-1` to indicate that the dimension
/// is dynamic.
///
/// [dimensional variables]:
/// https://onnx.ai/onnx/repo-docs/IR.html#static-tensor-shapes
fn is_dynamic_dimension(d: i64) -> bool {
d == -1
}

149
crates/wasi-nn/src/backend/onnxruntime.rs

@ -1,149 +0,0 @@
//! Implements a `wasi-nn` [`BackendInner`] using ONNX via ort.
use super::{BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner};
use crate::backend::read;
use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
use crate::{ExecutionContext, Graph};
use ort::{inputs, GraphOptimizationLevel, Session};
use std::path::Path;
use std::sync::{Arc, Mutex};
#[derive(Default)]
pub struct OnnxBackend();
unsafe impl Send for OnnxBackend {}
unsafe impl Sync for OnnxBackend {}
impl BackendInner for OnnxBackend {
fn encoding(&self) -> GraphEncoding {
GraphEncoding::Onnx
}
fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> {
if builders.len() != 1 {
return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()).into());
}
let session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_model_from_memory(builders[0])?;
let box_: Box<dyn BackendGraph> =
Box::new(ONNXGraph(Arc::new(Mutex::new(session)), target));
Ok(box_.into())
}
fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir> {
Some(self)
}
}
impl BackendFromDir for OnnxBackend {
fn load_from_dir(
&mut self,
path: &Path,
target: ExecutionTarget,
) -> Result<Graph, BackendError> {
let model = read(&path.join("model.onnx"))?;
self.load(&[&model], target)
}
}
struct ONNXGraph(Arc<Mutex<Session>>, #[allow(dead_code)] ExecutionTarget);
unsafe impl Send for ONNXGraph {}
unsafe impl Sync for ONNXGraph {}
impl BackendGraph for ONNXGraph {
fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> {
let session = self.0.lock().unwrap();
let inputs = session.inputs.iter().map(|_| None).collect::<Vec<_>>();
let outputs = session.outputs.iter().map(|_| None).collect::<Vec<_>>();
let box_: Box<dyn BackendExecutionContext> = Box::new(ONNXExecutionContext {
session: self.0.clone(),
inputs,
outputs,
});
Ok(box_.into())
}
}
struct ONNXExecutionContext {
session: Arc<Mutex<Session>>,
inputs: Vec<Option<Tensor>>,
outputs: Vec<Option<Vec<u8>>>,
}
unsafe impl Send for ONNXExecutionContext {}
unsafe impl Sync for ONNXExecutionContext {}
impl BackendExecutionContext for ONNXExecutionContext {
fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError> {
self.inputs[index as usize].replace(tensor.clone());
Ok(())
}
fn compute(&mut self) -> Result<(), BackendError> {
let shaped_inputs: Vec<_> = self
.inputs
.iter()
.enumerate()
.map(|(i, _o)| {
let input = self.inputs[i].as_ref().unwrap();
let dims = input
.dimensions
.as_slice()
.iter()
.map(|d| *d as i64)
.collect::<Vec<_>>();
match input.tensor_type {
TensorType::Fp32 => {
let data = bytes_to_f32_vec(input.data.to_vec());
inputs![(dims, Arc::new(data.into_boxed_slice()))].unwrap()
}
_ => {
unimplemented!("{:?} not supported by ONNX", input.tensor_type);
}
}
})
.flatten()
.collect();
let session = self.session.lock().unwrap();
let res = session.run(shaped_inputs.as_slice())?;
for i in 0..self.outputs.len() {
let raw: (Vec<i64>, &[f32]) = res[i].extract_raw_tensor()?;
let f32s = raw.1.to_vec();
self.outputs[i].replace(f32_vec_to_bytes(f32s));
}
Ok(())
}
fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result<u32, BackendError> {
let output = self.outputs[index as usize].as_ref().unwrap();
destination[..output.len()].copy_from_slice(output);
Ok(output.len() as u32)
}
}
impl From<ort::Error> for BackendError {
fn from(e: ort::Error) -> Self {
BackendError::BackendAccess(e.into())
}
}
pub fn f32_vec_to_bytes(data: Vec<f32>) -> Vec<u8> {
let chunks: Vec<[u8; 4]> = data.into_iter().map(|f| f.to_le_bytes()).collect();
let result: Vec<u8> = chunks.iter().flatten().copied().collect();
result
}
pub fn bytes_to_f32_vec(data: Vec<u8>) -> Vec<f32> {
let chunks: Vec<&[u8]> = data.chunks(4).collect();
let v: Vec<f32> = chunks
.into_iter()
.map(|c| f32::from_le_bytes(c.try_into().unwrap()))
.collect();
v.into_iter().collect()
}

36
crates/wasi-nn/src/backend/openvino.rs

@ -1,9 +1,9 @@
//! Implements a `wasi-nn` [`BackendInner`] using OpenVINO. //! Implements a `wasi-nn` [`BackendInner`] using OpenVINO.
use super::{ use super::{
read, BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner, read, BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner, Id,
}; };
use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType}; use crate::wit::{self, ExecutionTarget, GraphEncoding, Tensor, TensorType};
use crate::{ExecutionContext, Graph}; use crate::{ExecutionContext, Graph};
use openvino::{InferenceError, Layout, Precision, SetupError, TensorDesc}; use openvino::{InferenceError, Layout, Precision, SetupError, TensorDesc};
use std::path::Path; use std::path::Path;
@ -99,12 +99,15 @@ impl BackendGraph for OpenvinoGraph {
struct OpenvinoExecutionContext(Arc<openvino::CNNNetwork>, openvino::InferRequest); struct OpenvinoExecutionContext(Arc<openvino::CNNNetwork>, openvino::InferRequest);
impl BackendExecutionContext for OpenvinoExecutionContext { impl BackendExecutionContext for OpenvinoExecutionContext {
fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError> { fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> {
let input_name = self.0.get_input_name(index as usize)?; let input_name = match id {
Id::Index(i) => self.0.get_input_name(i as usize)?,
Id::Name(name) => name,
};
// Construct the blob structure. TODO: there must be some good way to // Construct the blob structure. TODO: there must be some good way to
// discover the layout here; `desc` should not have to default to NHWC. // discover the layout here; `desc` should not have to default to NHWC.
let precision = map_tensor_type_to_precision(tensor.tensor_type); let precision = map_tensor_type_to_precision(tensor.ty);
let dimensions = tensor let dimensions = tensor
.dimensions .dimensions
.iter() .iter()
@ -123,17 +126,20 @@ impl BackendExecutionContext for OpenvinoExecutionContext {
Ok(()) Ok(())
} }
fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result<u32, BackendError> { fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError> {
let output_name = self.0.get_output_name(index as usize)?; let output_name = match id {
Id::Index(i) => self.0.get_output_name(i as usize)?,
Id::Name(name) => name,
};
let dimensions = vec![]; // TODO: get actual shape
let ty = wit::TensorType::Fp32; // TODO: get actual type.
let blob = self.1.get_blob(&output_name)?; let blob = self.1.get_blob(&output_name)?;
let blob_size = blob.byte_len()?; let data = blob.buffer()?.to_vec();
if blob_size > destination.len() { Ok(Tensor {
return Err(BackendError::NotEnoughMemory(blob_size)); dimensions,
} ty,
data,
// Copy the tensor data into the destination buffer. })
destination[..blob_size].copy_from_slice(blob.buffer()?);
Ok(blob_size as u32)
} }
} }

146
crates/wasi-nn/src/backend/winml.rs

@ -1,16 +1,27 @@
//! Implements a `wasi-nn` [`BackendInner`] using WinML. //! Implements a `wasi-nn` [`BackendInner`] using WinML.
//!
use super::{BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner}; //! Note that the [docs.rs] documentation for the `windows` crate does have the
use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor}; //! right features turned on to read about the functions used; see Microsoft's
//! private documentation instead: [microsoft.github.io/windows-docs-rs].
//!
//! [docs.rs]: https://docs.rs/windows
//! [microsoft.github.io/windows-docs-rs]: https://microsoft.github.io/windows-docs-rs/doc/windows/AI/MachineLearning
use crate::backend::{
BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner, Id,
};
use crate::wit::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
use crate::{ExecutionContext, Graph}; use crate::{ExecutionContext, Graph};
use std::{fs::File, io::Read, mem::size_of, path::Path}; use std::{fs::File, io::Read, mem::size_of, path::Path};
use windows::core::{ComInterface, HSTRING}; use windows::core::{ComInterface, HSTRING};
use windows::Foundation::Collections::IVectorView;
use windows::Storage::Streams::{ use windows::Storage::Streams::{
DataWriter, InMemoryRandomAccessStream, RandomAccessStreamReference, DataWriter, InMemoryRandomAccessStream, RandomAccessStreamReference,
}; };
use windows::AI::MachineLearning::{ use windows::AI::MachineLearning::{
LearningModel, LearningModelBinding, LearningModelDevice, LearningModelDeviceKind, ILearningModelFeatureDescriptor, LearningModel, LearningModelBinding, LearningModelDevice,
LearningModelEvaluationResult, LearningModelSession, TensorFeatureDescriptor, TensorFloat, LearningModelDeviceKind, LearningModelEvaluationResult, LearningModelSession,
TensorFeatureDescriptor, TensorFloat,
}; };
#[derive(Default)] #[derive(Default)]
@ -94,29 +105,64 @@ impl WinMLExecutionContext {
} }
} }
impl WinMLExecutionContext {
/// Helper function for finding the internal index of a tensor by [`Id`].
fn find(
&self,
id: Id,
list: &IVectorView<ILearningModelFeatureDescriptor>,
) -> Result<u32, BackendError> {
let index = match id {
Id::Index(i) => {
if i < list.Size()? {
i
} else {
return Err(BackendError::BackendAccess(anyhow::anyhow!(
"incorrect tensor index: {i} >= {}",
list.Size()?
)));
}
}
Id::Name(name) => list
.into_iter()
.position(|d| d.Name().unwrap() == name)
.ok_or_else(|| {
BackendError::BackendAccess(anyhow::anyhow!("unknown tensor name: {name}"))
})? as u32,
};
Ok(index)
}
}
impl BackendExecutionContext for WinMLExecutionContext { impl BackendExecutionContext for WinMLExecutionContext {
fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError> { fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> {
let input_features = self.session.Model()?.InputFeatures()?;
let index = self.find(id, &input_features)?;
let input = input_features.GetAt(index)?;
// TODO: Support other tensor types. Only FP32 is supported right now. // TODO: Support other tensor types. Only FP32 is supported right now.
match tensor.tensor_type { match tensor.ty {
crate::wit::types::TensorType::Fp32 => {} crate::wit::types::TensorType::Fp32 => {}
_ => unimplemented!(), _ => unimplemented!(),
} }
let input = self.session.Model()?.InputFeatures()?.GetAt(index)?; // TODO: this is quite unsafe and probably incorrect--will the slice
unsafe { // still be around by the time the binding is used?!
let data = std::slice::from_raw_parts( let data = unsafe {
std::slice::from_raw_parts(
tensor.data.as_ptr() as *const f32, tensor.data.as_ptr() as *const f32,
tensor.data.len() / 4, tensor.data.len() / size_of::<f32>(),
); )
};
self.binding.Bind(
&input.Name()?, self.binding.Bind(
&TensorFloat::CreateFromArray( &input.Name()?,
&input.cast::<TensorFeatureDescriptor>()?.Shape()?, &TensorFloat::CreateFromArray(
data, &input.cast::<TensorFeatureDescriptor>()?.Shape()?,
)?, data,
)?; )?,
} )?;
Ok(()) Ok(())
} }
@ -125,33 +171,32 @@ impl BackendExecutionContext for WinMLExecutionContext {
Ok(()) Ok(())
} }
fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result<u32, BackendError> { fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError> {
if self.result.is_none() { if let Some(result) = &self.result {
let output_features = self.session.Model()?.OutputFeatures()?;
let index = self.find(id, &output_features)?;
let output = output_features.GetAt(index)?;
// TODO: this only handles FP32!
let tensor = result
.Outputs()?
.Lookup(&output.Name()?)?
.cast::<TensorFloat>()?;
let dimensions = dimensions_as_u32(&tensor.Shape()?)?;
let view = tensor.GetAsVectorView()?;
let mut data = Vec::with_capacity(view.Size()? as usize * size_of::<f32>());
for f in view.into_iter() {
data.extend(f.to_le_bytes());
}
Ok(Tensor {
ty: TensorType::Fp32,
dimensions,
data,
})
} else {
return Err(BackendError::BackendAccess(anyhow::Error::msg( return Err(BackendError::BackendAccess(anyhow::Error::msg(
"Output is not ready.", "Output is not ready.",
))); )));
} }
let output_name = self.session.Model()?.OutputFeatures()?.GetAt(index)?;
let output_name_hstring = output_name.Name()?;
let vector_view = self
.result
.as_ref()
.unwrap()
.Outputs()?
.Lookup(&output_name_hstring)?
.cast::<TensorFloat>()?
.GetAsVectorView()?;
let output: Vec<f32> = vector_view.into_iter().collect();
let len_to_copy = output.len() * size_of::<f32>();
unsafe {
destination[..len_to_copy].copy_from_slice(std::slice::from_raw_parts(
output.as_ptr() as *const u8,
len_to_copy,
));
}
Ok(len_to_copy as u32)
} }
} }
@ -168,3 +213,16 @@ impl From<windows::core::Error> for BackendError {
BackendError::BackendAccess(anyhow::Error::new(e)) BackendError::BackendAccess(anyhow::Error::new(e))
} }
} }
fn dimensions_as_u32(dimensions: &IVectorView<i64>) -> Result<Vec<u32>, BackendError> {
dimensions
.into_iter()
.map(|d| if d == -1 { Ok(1) } else { convert_i64(d) })
.collect()
}
fn convert_i64(i: i64) -> Result<u32, BackendError> {
u32::try_from(i).map_err(|d| -> BackendError {
anyhow::anyhow!("unable to convert dimension to u32: {d}").into()
})
}

146
crates/wasi-nn/src/ctx.rs

@ -1,146 +0,0 @@
//! Implements the host state for the `wasi-nn` API: [WasiNnCtx].
use crate::backend::{self, BackendError};
use crate::wit::types::GraphEncoding;
use crate::{Backend, ExecutionContext, Graph, InMemoryRegistry, Registry};
use anyhow::anyhow;
use std::{collections::HashMap, hash::Hash, path::Path};
use thiserror::Error;
use wiggle::GuestError;
type GraphId = u32;
type GraphExecutionContextId = u32;
type BackendName = String;
type GraphDirectory = String;
/// Construct an in-memory registry from the available backends and a list of
/// `(<backend name>, <graph directory>)`. This assumes graphs can be loaded
/// from a local directory, which is a safe assumption currently for the current
/// model types.
pub fn preload(
preload_graphs: &[(BackendName, GraphDirectory)],
) -> anyhow::Result<(impl IntoIterator<Item = Backend>, Registry)> {
let mut backends = backend::list();
let mut registry = InMemoryRegistry::new();
for (kind, path) in preload_graphs {
let kind_ = kind.parse()?;
let backend = backends
.iter_mut()
.find(|b| b.encoding() == kind_)
.ok_or(anyhow!("unsupported backend: {}", kind))?
.as_dir_loadable()
.ok_or(anyhow!("{} does not support directory loading", kind))?;
registry.load(backend, Path::new(path))?;
}
Ok((backends, Registry::from(registry)))
}
/// Capture the state necessary for calling into the backend ML libraries.
pub struct WasiNnCtx {
pub(crate) backends: HashMap<GraphEncoding, Backend>,
pub(crate) registry: Registry,
pub(crate) graphs: Table<GraphId, Graph>,
pub(crate) executions: Table<GraphExecutionContextId, ExecutionContext>,
}
impl WasiNnCtx {
/// Make a new context from the default state.
pub fn new(backends: impl IntoIterator<Item = Backend>, registry: Registry) -> Self {
let backends = backends.into_iter().map(|b| (b.encoding(), b)).collect();
Self {
backends,
registry,
graphs: Table::default(),
executions: Table::default(),
}
}
}
/// Possible errors while interacting with [WasiNnCtx].
#[derive(Debug, Error)]
pub enum WasiNnError {
#[error("backend error")]
BackendError(#[from] BackendError),
#[error("guest error")]
GuestError(#[from] GuestError),
#[error("usage error")]
UsageError(#[from] UsageError),
}
#[derive(Debug, Error)]
pub enum UsageError {
#[error("Invalid context; has the load function been called?")]
InvalidContext,
#[error("Only OpenVINO's IR is currently supported, passed encoding: {0:?}")]
InvalidEncoding(GraphEncoding),
#[error("OpenVINO expects only two buffers (i.e. [ir, weights]), passed: {0}")]
InvalidNumberOfBuilders(u32),
#[error("Invalid graph handle; has it been loaded?")]
InvalidGraphHandle,
#[error("Invalid execution context handle; has it been initialized?")]
InvalidExecutionContextHandle,
#[error("Not enough memory to copy tensor data of size: {0}")]
NotEnoughMemory(u32),
#[error("No graph found with name: {0}")]
NotFound(String),
}
pub(crate) type WasiNnResult<T> = std::result::Result<T, WasiNnError>;
/// Record handle entries in a table.
pub struct Table<K, V> {
entries: HashMap<K, V>,
next_key: u32,
}
impl<K, V> Default for Table<K, V> {
fn default() -> Self {
Self {
entries: HashMap::new(),
next_key: 0,
}
}
}
impl<K, V> Table<K, V>
where
K: Eq + Hash + From<u32> + Copy,
{
pub fn insert(&mut self, value: V) -> K {
let key = self.use_next_key();
self.entries.insert(key, value);
key
}
pub fn get(&self, key: K) -> Option<&V> {
self.entries.get(&key)
}
pub fn get_mut(&mut self, key: K) -> Option<&mut V> {
self.entries.get_mut(&key)
}
fn use_next_key(&mut self) -> K {
let current = self.next_key;
self.next_key += 1;
K::from(current)
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::registry::GraphRegistry;
#[test]
fn example() {
struct FakeRegistry;
impl GraphRegistry for FakeRegistry {
fn get_mut(&mut self, _: &str) -> Option<&mut Graph> {
None
}
}
let _ctx = WasiNnCtx::new([], Registry::from(FakeRegistry));
}
}

51
crates/wasi-nn/src/lib.rs

@ -1,14 +1,34 @@
mod ctx;
mod registry;
pub mod backend; pub mod backend;
pub use ctx::{preload, WasiNnCtx}; mod registry;
pub use registry::{GraphRegistry, InMemoryRegistry};
pub mod wit; pub mod wit;
pub mod witx; pub mod witx;
use anyhow::anyhow;
use core::fmt;
pub use registry::{GraphRegistry, InMemoryRegistry};
use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
/// Construct an in-memory registry from the available backends and a list of
/// `(<backend name>, <graph directory>)`. This assumes graphs can be loaded
/// from a local directory, which is a safe assumption currently for the current
/// model types.
pub fn preload(preload_graphs: &[(String, String)]) -> anyhow::Result<(Vec<Backend>, Registry)> {
let mut backends = backend::list();
let mut registry = InMemoryRegistry::new();
for (kind, path) in preload_graphs {
let kind_ = kind.parse()?;
let backend = backends
.iter_mut()
.find(|b| b.encoding() == kind_)
.ok_or(anyhow!("unsupported backend: {}", kind))?
.as_dir_loadable()
.ok_or(anyhow!("{} does not support directory loading", kind))?;
registry.load(backend, Path::new(path))?;
}
Ok((backends, Registry::from(registry)))
}
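
As a rough usage sketch (not part of this change), an embedder could preload a graph directory so guests can later fetch it by name; the "openvino" encoding and the `./mobilenet` path here are placeholders, and the named backend must be compiled in and support directory loading.

    fn preload_mobilenet() -> anyhow::Result<(Vec<wasmtime_wasi_nn::Backend>, wasmtime_wasi_nn::Registry)> {
        // The string form must parse as a known `GraphEncoding` (e.g., "openvino", "onnx").
        wasmtime_wasi_nn::preload(&[("openvino".to_string(), "./mobilenet".to_string())])
    }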
/// A machine learning backend. /// A machine learning backend.
pub struct Backend(Box<dyn backend::BackendInner>); pub struct Backend(Box<dyn backend::BackendInner>);
impl std::ops::Deref for Backend { impl std::ops::Deref for Backend {
@ -43,6 +63,27 @@ impl std::ops::Deref for Graph {
} }
} }
/// A host-side tensor.
///
/// Eventually, this may be defined in each backend as they gain the ability to
/// hold tensors on various devices (TODO:
/// https://github.com/WebAssembly/wasi-nn/pull/70).
#[derive(Clone)]
pub struct Tensor {
dimensions: Vec<u32>,
ty: wit::TensorType,
data: Vec<u8>,
}
impl fmt::Debug for Tensor {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Tensor")
.field("dimensions", &self.dimensions)
.field("ty", &self.ty)
.field("data (bytes)", &self.data.len())
.finish()
}
}
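
The `Debug` impl above intentionally prints only the byte length of `data`, so logging a tensor (e.g., via the `tracing::debug!` calls in the WIT host code) stays cheap even for large tensors. A rough in-crate sketch, assuming the generated `TensorType` enum exposes an `Fp32` variant:

    fn tensor_debug_demo() {
        let t = Tensor {
            dimensions: vec![1, 3, 224, 224],
            ty: wit::TensorType::Fp32, // assumed variant name
            data: vec![0u8; 3 * 224 * 224 * 4],
        };
        // Prints dimensions, type, and `data (bytes): 602112` rather than the raw bytes.
        println!("{t:?}");
    }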
/// A backend-defined execution context. /// A backend-defined execution context.
pub struct ExecutionContext(Box<dyn backend::BackendExecutionContext>); pub struct ExecutionContext(Box<dyn backend::BackendExecutionContext>);
impl From<Box<dyn backend::BackendExecutionContext>> for ExecutionContext { impl From<Box<dyn backend::BackendExecutionContext>> for ExecutionContext {

5
crates/wasi-nn/src/registry/in_memory.rs

@@ -2,7 +2,7 @@
 use super::{Graph, GraphRegistry};
 use crate::backend::BackendFromDir;
-use crate::wit::types::ExecutionTarget;
+use crate::wit::ExecutionTarget;
 use anyhow::{anyhow, bail};
 use std::{collections::HashMap, path::Path};
@@ -37,6 +37,9 @@ impl InMemoryRegistry {
 }
 impl GraphRegistry for InMemoryRegistry {
+    fn get(&self, name: &str) -> Option<&Graph> {
+        self.0.get(name)
+    }
     fn get_mut(&mut self, name: &str) -> Option<&mut Graph> {
         self.0.get_mut(name)
     }

1
crates/wasi-nn/src/registry/mod.rs

@@ -12,5 +12,6 @@ use crate::Graph;
 pub use in_memory::InMemoryRegistry;
 pub trait GraphRegistry: Send + Sync {
+    fn get(&self, name: &str) -> Option<&Graph>;
     fn get_mut(&mut self, name: &str) -> Option<&mut Graph>;
 }

330
crates/wasi-nn/src/wit.rs

@ -15,8 +15,69 @@
//! [`Backend`]: crate::Backend //! [`Backend`]: crate::Backend
//! [`types`]: crate::wit::types //! [`types`]: crate::wit::types
use crate::{ctx::UsageError, WasiNnCtx}; use crate::backend::Id;
use std::{error::Error, fmt, hash::Hash, str::FromStr}; use crate::{Backend, Registry};
use std::collections::HashMap;
use std::hash::Hash;
use std::{fmt, str::FromStr};
use wasmtime::component::{Resource, ResourceTable};
/// Capture the state necessary for calling into the backend ML libraries.
pub struct WasiNnCtx {
pub(crate) backends: HashMap<GraphEncoding, Backend>,
pub(crate) registry: Registry,
}
impl WasiNnCtx {
/// Make a new context from the default state.
pub fn new(backends: impl IntoIterator<Item = Backend>, registry: Registry) -> Self {
let backends = backends.into_iter().map(|b| (b.encoding(), b)).collect();
Self { backends, registry }
}
}
/// A wrapper capturing the needed internal wasi-nn state.
///
/// Unlike other WASI proposals (see `wasmtime-wasi`, `wasmtime-wasi-http`),
/// this wrapper is not a `trait` but rather holds the references directly. This
/// removes one layer of abstraction for simplicity only; it could be added back
/// in the future if embedders need more control here.
pub struct WasiNnView<'a> {
ctx: &'a mut WasiNnCtx,
table: &'a mut ResourceTable,
}
impl<'a> WasiNnView<'a> {
/// Create a new view into the wasi-nn state.
pub fn new(table: &'a mut ResourceTable, ctx: &'a mut WasiNnCtx) -> Self {
Self { ctx, table }
}
}
pub enum Error {
/// Caller module passed an invalid argument.
InvalidArgument,
/// Invalid encoding.
InvalidEncoding,
/// The operation timed out.
Timeout,
/// Runtime Error.
RuntimeError,
/// Unsupported operation.
UnsupportedOperation,
/// Graph is too large.
TooLarge,
/// Graph not found.
NotFound,
/// A runtime error occurred that we should trap on; see `StreamError`.
Trap(anyhow::Error),
}
impl From<wasmtime::component::ResourceTableError> for Error {
fn from(error: wasmtime::component::ResourceTableError) -> Self {
Self::Trap(error.into())
}
}
/// Generate the traits and types from the `wasi-nn` WIT specification. /// Generate the traits and types from the `wasi-nn` WIT specification.
mod gen_ { mod gen_ {
@ -24,126 +85,241 @@ mod gen_ {
world: "ml", world: "ml",
path: "wit/wasi-nn.wit", path: "wit/wasi-nn.wit",
trappable_imports: true, trappable_imports: true,
with: {
// Configure all WIT http resources to be defined types in this
// crate to use the `ResourceTable` helper methods.
"wasi:nn/graph/graph": crate::Graph,
"wasi:nn/tensor/tensor": crate::Tensor,
"wasi:nn/inference/graph-execution-context": crate::ExecutionContext,
},
trappable_error_type: {
"wasi:nn/errors/error" => super::Error,
},
}); });
} }
use gen_::wasi::nn as gen; // Shortcut to the module containing the types we need. use gen_::wasi::nn::{self as gen}; // Shortcut to the module containing the types we need.
// Export the `types` used in this crate as well as `ML::add_to_linker`. // Export the `types` used in this crate as well as `ML::add_to_linker`.
pub mod types { pub mod types {
use super::gen; use super::gen;
pub use gen::graph::{ExecutionTarget, Graph, GraphEncoding}; pub use gen::errors::Error;
pub use gen::graph::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding};
pub use gen::inference::GraphExecutionContext; pub use gen::inference::GraphExecutionContext;
pub use gen::tensor::{Tensor, TensorType}; pub use gen::tensor::{Tensor, TensorType};
} }
pub use gen::graph::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding};
pub use gen::inference::GraphExecutionContext;
pub use gen::tensor::{Tensor, TensorData, TensorDimensions, TensorType};
pub use gen_::Ml as ML; pub use gen_::Ml as ML;
impl gen::graph::Host for WasiNnCtx { /// Add the WIT-based version of the `wasi-nn` API to a
/// Load an opaque sequence of bytes to use for inference. /// [`wasmtime::component::Linker`].
pub fn add_to_linker<T>(
l: &mut wasmtime::component::Linker<T>,
f: impl Fn(&mut T) -> WasiNnView<'_> + Send + Sync + Copy + 'static,
) -> anyhow::Result<()> {
gen::graph::add_to_linker_get_host(l, f)?;
gen::tensor::add_to_linker_get_host(l, f)?;
gen::inference::add_to_linker_get_host(l, f)?;
gen::errors::add_to_linker_get_host(l, f)?;
Ok(())
}
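
A condensed sketch of the embedder-side wiring this function enables (the test harness and `wasmtime serve` changes later in this diff do the same thing); the `Host` struct here is a hypothetical embedder state:

    use wasmtime::component::{Linker, ResourceTable};
    use wasmtime_wasi_nn::wit::{self, WasiNnView};

    struct Host {
        table: ResourceTable,
        wasi_nn: wit::WasiNnCtx,
    }

    fn link_wasi_nn(linker: &mut Linker<Host>) -> anyhow::Result<()> {
        // The closure splits borrows of `Host` into the view the generated host traits need.
        wit::add_to_linker(linker, |h: &mut Host| {
            WasiNnView::new(&mut h.table, &mut h.wasi_nn)
        })
    }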
impl gen::graph::Host for WasiNnView<'_> {
fn load( fn load(
&mut self, &mut self,
builders: Vec<gen::graph::GraphBuilder>, builders: Vec<GraphBuilder>,
encoding: gen::graph::GraphEncoding, encoding: GraphEncoding,
target: gen::graph::ExecutionTarget, target: ExecutionTarget,
) -> wasmtime::Result<Result<gen::graph::Graph, gen::errors::Error>> { ) -> Result<Resource<crate::Graph>, Error> {
let graph = if let Some(backend) = self.backends.get_mut(&encoding) { tracing::debug!("load {encoding:?} {target:?}");
if let Some(backend) = self.ctx.backends.get_mut(&encoding) {
let slices = builders.iter().map(|s| s.as_slice()).collect::<Vec<_>>(); let slices = builders.iter().map(|s| s.as_slice()).collect::<Vec<_>>();
backend.load(&slices, target.into())? match backend.load(&slices, target.into()) {
Ok(graph) => {
let graph = self.table.push(graph)?;
Ok(graph)
}
Err(error) => {
tracing::error!("failed to load graph: {error:?}");
Err(Error::RuntimeError)
}
}
} else { } else {
return Err(UsageError::InvalidEncoding(encoding.into()).into()); Err(Error::InvalidEncoding)
}; }
let graph_id = self.graphs.insert(graph);
Ok(Ok(graph_id))
} }
fn load_by_name( fn load_by_name(&mut self, name: String) -> Result<Resource<Graph>, Error> {
&mut self, use core::result::Result::*;
name: String, tracing::debug!("load by name {name:?}");
) -> wasmtime::Result<Result<gen::graph::Graph, gen::errors::Error>> { let registry = &self.ctx.registry;
if let Some(graph) = self.registry.get_mut(&name) { if let Some(graph) = registry.get(&name) {
let graph_id = self.graphs.insert(graph.clone().into()); let graph = graph.clone();
Ok(Ok(graph_id)) let graph = self.table.push(graph)?;
Ok(graph)
} else { } else {
return Err(UsageError::NotFound(name.to_string()).into()); tracing::error!("failed to find graph with name: {name}");
Err(Error::NotFound)
} }
} }
} }
impl gen::inference::Host for WasiNnCtx { impl gen::graph::HostGraph for WasiNnView<'_> {
/// Create an execution instance of a loaded graph.
///
/// TODO: remove completely?
fn init_execution_context( fn init_execution_context(
&mut self, &mut self,
graph_id: gen::graph::Graph, graph: Resource<Graph>,
) -> wasmtime::Result<Result<gen::inference::GraphExecutionContext, gen::errors::Error>> { ) -> Result<Resource<GraphExecutionContext>, Error> {
let exec_context = if let Some(graph) = self.graphs.get(graph_id) { use core::result::Result::*;
graph.init_execution_context()? tracing::debug!("initialize execution context");
} else { let graph = self.table.get(&graph)?;
return Err(UsageError::InvalidGraphHandle.into()); match graph.init_execution_context() {
}; Ok(exec_context) => {
let exec_context = self.table.push(exec_context)?;
Ok(exec_context)
}
Err(error) => {
tracing::error!("failed to initialize execution context: {error:?}");
Err(Error::RuntimeError)
}
}
}
let exec_context_id = self.executions.insert(exec_context); fn drop(&mut self, graph: Resource<Graph>) -> wasmtime::Result<()> {
Ok(Ok(exec_context_id)) self.table.delete(graph)?;
Ok(())
} }
}
/// Define the inputs to use for inference. impl gen::inference::HostGraphExecutionContext for WasiNnView<'_> {
fn set_input( fn set_input(
&mut self, &mut self,
exec_context_id: gen::inference::GraphExecutionContext, exec_context: Resource<GraphExecutionContext>,
index: u32, name: String,
tensor: gen::tensor::Tensor, tensor: Resource<Tensor>,
) -> wasmtime::Result<Result<(), gen::errors::Error>> { ) -> Result<(), Error> {
if let Some(exec_context) = self.executions.get_mut(exec_context_id) { let tensor = self.table.get(&tensor)?;
exec_context.set_input(index, &tensor)?; tracing::debug!("set input {name:?}: {tensor:?}");
Ok(Ok(())) let tensor = tensor.clone(); // TODO: avoid copying the tensor
let exec_context = self.table.get_mut(&exec_context)?;
if let Err(e) = exec_context.set_input(Id::Name(name), &tensor) {
tracing::error!("failed to set input: {e:?}");
Err(Error::InvalidArgument)
} else { } else {
Err(UsageError::InvalidGraphHandle.into()) Ok(())
} }
} }
/// Compute the inference on the given inputs. fn compute(&mut self, exec_context: Resource<GraphExecutionContext>) -> Result<(), Error> {
/// let exec_context = &mut self.table.get_mut(&exec_context)?;
/// TODO: refactor to compute(list<tensor>) -> result<list<tensor>, error> tracing::debug!("compute");
fn compute( match exec_context.compute() {
&mut self, Ok(()) => Ok(()),
exec_context_id: gen::inference::GraphExecutionContext, Err(error) => {
) -> wasmtime::Result<Result<(), gen::errors::Error>> { tracing::error!("failed to compute: {error:?}");
if let Some(exec_context) = self.executions.get_mut(exec_context_id) { Err(Error::RuntimeError)
exec_context.compute()?; }
Ok(Ok(()))
} else {
Err(UsageError::InvalidExecutionContextHandle.into())
} }
} }
/// Extract the outputs after inference. #[doc = r" Extract the outputs after inference."]
fn get_output( fn get_output(
&mut self, &mut self,
exec_context_id: gen::inference::GraphExecutionContext, exec_context: Resource<GraphExecutionContext>,
index: u32, name: String,
) -> wasmtime::Result<Result<gen::tensor::TensorData, gen::errors::Error>> { ) -> Result<Resource<Tensor>, Error> {
if let Some(exec_context) = self.executions.get_mut(exec_context_id) { let exec_context = self.table.get_mut(&exec_context)?;
// Read the output bytes. TODO: this involves a hard-coded upper tracing::debug!("get output {name:?}");
// limit on the tensor size that is necessary because there is no match exec_context.get_output(Id::Name(name)) {
// way to introspect the graph outputs Ok(tensor) => {
// (https://github.com/WebAssembly/wasi-nn/issues/37). let tensor = self.table.push(tensor)?;
let mut destination = vec![0; 1024 * 1024]; Ok(tensor)
let bytes_read = exec_context.get_output(index, &mut destination)?; }
destination.truncate(bytes_read as usize); Err(error) => {
Ok(Ok(destination)) tracing::error!("failed to get output: {error:?}");
} else { Err(Error::RuntimeError)
Err(UsageError::InvalidGraphHandle.into()) }
} }
} }
fn drop(&mut self, exec_context: Resource<GraphExecutionContext>) -> wasmtime::Result<()> {
self.table.delete(exec_context)?;
Ok(())
}
} }
impl gen::errors::Host for WasiNnCtx {} impl gen::tensor::HostTensor for WasiNnView<'_> {
fn new(
&mut self,
dimensions: TensorDimensions,
ty: TensorType,
data: TensorData,
) -> wasmtime::Result<Resource<Tensor>> {
let tensor = Tensor {
dimensions,
ty,
data,
};
let tensor = self.table.push(tensor)?;
Ok(tensor)
}
fn dimensions(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorDimensions> {
let tensor = self.table.get(&tensor)?;
Ok(tensor.dimensions.clone())
}
fn ty(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorType> {
let tensor = self.table.get(&tensor)?;
Ok(tensor.ty)
}
fn data(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorData> {
let tensor = self.table.get(&tensor)?;
Ok(tensor.data.clone())
}
impl gen::tensor::Host for WasiNnCtx {} fn drop(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<()> {
self.table.delete(tensor)?;
Ok(())
}
}
impl gen::tensor::Host for WasiNnView<'_> {}
impl gen::errors::Host for WasiNnView<'_> {
fn convert_error(&mut self, err: Error) -> wasmtime::Result<gen::errors::Error> {
match err {
Error::InvalidArgument => Ok(gen::errors::Error::InvalidArgument),
Error::InvalidEncoding => Ok(gen::errors::Error::InvalidEncoding),
Error::Timeout => Ok(gen::errors::Error::Timeout),
Error::RuntimeError => Ok(gen::errors::Error::RuntimeError),
Error::UnsupportedOperation => Ok(gen::errors::Error::UnsupportedOperation),
Error::TooLarge => Ok(gen::errors::Error::TooLarge),
Error::NotFound => Ok(gen::errors::Error::NotFound),
Error::Trap(e) => Err(e),
}
}
}
impl gen::inference::Host for WasiNnView<'_> {}
impl Hash for gen::graph::GraphEncoding { impl Hash for gen::graph::GraphEncoding {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) { fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
core::mem::discriminant(self).hash(state); self.to_string().hash(state)
}
}
impl fmt::Display for gen::graph::GraphEncoding {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use gen::graph::GraphEncoding::*;
match self {
Openvino => write!(f, "openvino"),
Onnx => write!(f, "onnx"),
Pytorch => write!(f, "pytorch"),
Tensorflow => write!(f, "tensorflow"),
Tensorflowlite => write!(f, "tensorflowlite"),
Autodetect => write!(f, "autodetect"),
Ggml => write!(f, "ggml"),
}
} }
} }
@ -168,4 +344,4 @@ impl fmt::Display for GraphEncodingParseError {
write!(f, "unknown graph encoding: {}", self.0) write!(f, "unknown graph encoding: {}", self.0)
} }
} }
impl Error for GraphEncodingParseError {} impl std::error::Error for GraphEncodingParseError {}

122
crates/wasi-nn/src/witx.rs

@ -13,11 +13,83 @@
//! //!
//! [`types`]: crate::wit::types //! [`types`]: crate::wit::types
use crate::ctx::{UsageError, WasiNnCtx, WasiNnError, WasiNnResult as Result}; use crate::backend::BackendError;
use wiggle::{GuestMemory, GuestPtr}; use crate::backend::Id;
use crate::wit::GraphEncoding;
use crate::{Backend, ExecutionContext, Graph, Registry};
use std::collections::HashMap;
use std::hash::Hash;
use thiserror::Error;
use wiggle::{GuestError, GuestMemory, GuestPtr};
pub use gen::wasi_ephemeral_nn::add_to_linker; pub use gen::wasi_ephemeral_nn::add_to_linker;
pub(crate) type WasiNnResult<T> = std::result::Result<T, WasiNnError>;
type Result<T> = WasiNnResult<T>;
type GraphId = u32;
type GraphExecutionContextId = u32;
/// Capture the state necessary for calling into the backend ML libraries.
pub struct WasiNnCtx {
pub(crate) backends: HashMap<GraphEncoding, Backend>,
pub(crate) registry: Registry,
pub(crate) graphs: Table<GraphId, Graph>,
pub(crate) executions: Table<GraphExecutionContextId, ExecutionContext>,
}
impl WasiNnCtx {
/// Make a new context from the default state.
pub fn new(backends: impl IntoIterator<Item = Backend>, registry: Registry) -> Self {
let backends = backends.into_iter().map(|b| (b.encoding(), b)).collect();
Self {
backends,
registry,
graphs: Table::default(),
executions: Table::default(),
}
}
}
/// Record handle entries in a table.
pub struct Table<K, V> {
entries: HashMap<K, V>,
next_key: u32,
}
impl<K, V> Default for Table<K, V> {
fn default() -> Self {
Self {
entries: HashMap::new(),
next_key: 0,
}
}
}
impl<K, V> Table<K, V>
where
K: Eq + Hash + From<u32> + Copy,
{
pub fn insert(&mut self, value: V) -> K {
let key = self.use_next_key();
self.entries.insert(key, value);
key
}
pub fn get(&self, key: K) -> Option<&V> {
self.entries.get(&key)
}
pub fn get_mut(&mut self, key: K) -> Option<&mut V> {
self.entries.get_mut(&key)
}
fn use_next_key(&mut self) -> K {
let current = self.next_key;
self.next_key += 1;
K::from(current)
}
}
/// Generate the traits and types from the `wasi-nn` WITX specification. /// Generate the traits and types from the `wasi-nn` WITX specification.
mod gen { mod gen {
use super::*; use super::*;
@ -42,9 +114,10 @@ mod gen {
) -> anyhow::Result<types::NnErrno> { ) -> anyhow::Result<types::NnErrno> {
tracing::debug!("host error: {:?}", e); tracing::debug!("host error: {:?}", e);
match e { match e {
WasiNnError::BackendError(_) => unimplemented!(), WasiNnError::BackendError(_) => Ok(types::NnErrno::RuntimeError),
WasiNnError::GuestError(_) => unimplemented!(), WasiNnError::GuestError(_) => unimplemented!("guest error conversion"),
WasiNnError::UsageError(_) => unimplemented!(), WasiNnError::UsageError(_) => Ok(types::NnErrno::UnsupportedOperation),
WasiNnError::NotEnoughMemory(_) => Ok(types::NnErrno::TooLarge),
} }
} }
} }
@ -119,10 +192,10 @@ impl gen::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx {
if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) { if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) {
let tensor = crate::wit::types::Tensor { let tensor = crate::wit::types::Tensor {
dimensions: memory.to_vec(tensor.dimensions)?, dimensions: memory.to_vec(tensor.dimensions)?,
tensor_type: tensor.type_.into(), ty: tensor.type_.into(),
data: memory.to_vec(tensor.data)?, data: memory.to_vec(tensor.data)?,
}; };
Ok(exec_context.set_input(index, &tensor)?) Ok(exec_context.set_input(Id::Index(index), &tensor)?)
} else { } else {
Err(UsageError::InvalidGraphHandle.into()) Err(UsageError::InvalidGraphHandle.into())
} }
@ -149,13 +222,19 @@ impl gen::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx {
out_buffer_max_size: u32, out_buffer_max_size: u32,
) -> Result<u32> { ) -> Result<u32> {
if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) { if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) {
let mut destination = memory let tensor = exec_context.get_output(Id::Index(index))?;
let destination = memory
.as_slice_mut(out_buffer.as_array(out_buffer_max_size))? .as_slice_mut(out_buffer.as_array(out_buffer_max_size))?
.expect( .expect(
"cannot use with shared memories; \ "cannot use with shared memories; \
see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)", see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)",
); );
Ok(exec_context.get_output(index, &mut destination)?) if tensor.data.len() > destination.len() {
Err(WasiNnError::NotEnoughMemory(tensor.data.len()))
} else {
destination[..tensor.data.len()].copy_from_slice(&tensor.data);
Ok(tensor.data.len() as u32)
}
} else { } else {
Err(UsageError::InvalidGraphHandle.into()) Err(UsageError::InvalidGraphHandle.into())
} }
@ -199,3 +278,28 @@ impl From<gen::types::TensorType> for crate::wit::types::TensorType {
} }
} }
} }
/// Possible errors while interacting with [WasiNnCtx].
#[derive(Debug, Error)]
pub enum WasiNnError {
#[error("backend error")]
BackendError(#[from] BackendError),
#[error("guest error")]
GuestError(#[from] GuestError),
#[error("usage error")]
UsageError(#[from] UsageError),
#[error("not enough memory: requested {0} bytes")]
NotEnoughMemory(usize),
}
#[derive(Debug, Error)]
pub enum UsageError {
#[error("Only OpenVINO's IR is currently supported, passed encoding: {0:?}")]
InvalidEncoding(GraphEncoding),
#[error("Invalid graph handle; has it been loaded?")]
InvalidGraphHandle,
#[error("Invalid execution context handle; has it been initialized?")]
InvalidExecutionContextHandle,
#[error("No graph found with name: {0}")]
NotFound(String),
}

8
crates/wasi-nn/tests/check/mod.rs

@@ -1,10 +1,8 @@
-//! This is testing-specific code--it is public only so that it can be
-//! accessible both in unit and integration tests.
+//! Check that the environment is set up correctly for running tests.
 //!
 //! This module checks:
-//! - that OpenVINO can be found in the environment
-//! - that WinML is available
-//! - that some ML model artifacts can be downloaded and cached.
+//! - that various backends can be located on the system (see sub-modules)
+//! - that certain ML model artifacts can be downloaded and cached.
 #[allow(unused_imports)]
 use anyhow::{anyhow, Context, Result};

54
crates/wasi-nn/tests/exec/mod.rs

@ -1,52 +1,6 @@
use crate::check::artifacts_dir; //! Provide a Wasmtime embedding for executing wasi-nn test programs.
use anyhow::Result;
use std::path::Path;
use wasi_common::sync::{Dir, WasiCtxBuilder};
use wasi_common::WasiCtx;
use wasmtime::{Config, Engine, Linker, Module, Store};
use wasmtime_wasi_nn::{Backend, InMemoryRegistry, WasiNnCtx};
const PREOPENED_DIR_NAME: &str = "fixture"; pub mod wit;
pub mod witx;
/// Run a wasi-nn test program. This is modeled after pub const PREOPENED_DIR_NAME: &str = "fixture";
/// `crates/wasi/tests/all/main.rs` but still uses the older preview1 API
/// for file reads.
pub fn run(path: &str, backend: Backend, preload_model: bool) -> Result<()> {
let path = Path::new(path);
let engine = Engine::new(&Config::new())?;
let mut linker = Linker::new(&engine);
wasmtime_wasi_nn::witx::add_to_linker(&mut linker, |s: &mut Ctx| &mut s.wasi_nn)?;
wasi_common::sync::add_to_linker(&mut linker, |s: &mut Ctx| &mut s.wasi)?;
let module = Module::from_file(&engine, path)?;
let mut store = Store::new(&engine, Ctx::new(&artifacts_dir(), preload_model, backend)?);
let instance = linker.instantiate(&mut store, &module)?;
let start = instance.get_typed_func::<(), ()>(&mut store, "_start")?;
start.call(&mut store, ())?;
Ok(())
}
/// The host state for running wasi-nn tests.
struct Ctx {
wasi: WasiCtx,
wasi_nn: WasiNnCtx,
}
impl Ctx {
fn new(preopen_dir: &Path, preload_model: bool, mut backend: Backend) -> Result<Self> {
let preopen_dir = Dir::open_ambient_dir(preopen_dir, cap_std::ambient_authority())?;
let mut builder = WasiCtxBuilder::new();
builder
.inherit_stdio()
.preopened_dir(preopen_dir, PREOPENED_DIR_NAME)?;
let wasi = builder.build();
let mut registry = InMemoryRegistry::new();
let mobilenet_dir = artifacts_dir();
if preload_model {
registry.load((backend).as_dir_loadable().unwrap(), &mobilenet_dir)?;
}
let wasi_nn = WasiNnCtx::new([backend.into()], registry.into());
Ok(Self { wasi, wasi_nn })
}
}

73
crates/wasi-nn/tests/exec/wit.rs

@ -0,0 +1,73 @@
use super::PREOPENED_DIR_NAME;
use crate::check::artifacts_dir;
use anyhow::{anyhow, Result};
use std::path::Path;
use wasmtime::component::{Component, Linker, ResourceTable};
use wasmtime::{Config, Engine, Store};
use wasmtime_wasi::bindings::sync::Command;
use wasmtime_wasi::{DirPerms, FilePerms, WasiCtx, WasiCtxBuilder};
use wasmtime_wasi_nn::wit::WasiNnView;
use wasmtime_wasi_nn::{wit::WasiNnCtx, Backend, InMemoryRegistry};
/// Run a wasi-nn test program. This is modeled after
/// `crates/wasi/tests/all/main.rs` but still uses the older preview1 API for
/// file reads.
pub fn run(path: &str, backend: Backend, preload_model: bool) -> Result<()> {
let path = Path::new(path);
let engine = Engine::new(&Config::new())?;
let mut linker = Linker::new(&engine);
wasmtime_wasi_nn::wit::add_to_linker(&mut linker, |c: &mut Ctx| {
WasiNnView::new(&mut c.table, &mut c.wasi_nn)
})?;
wasmtime_wasi::add_to_linker_sync(&mut linker)?;
let module = Component::from_file(&engine, path)?;
let mut store = Store::new(&engine, Ctx::new(&artifacts_dir(), preload_model, backend)?);
let command = Command::instantiate(&mut store, &module, &linker)?;
let result = command.wasi_cli_run().call_run(&mut store)?;
result.map_err(|_| anyhow!("failed to run command"))
}
/// The host state for running wasi-nn component tests.
struct Ctx {
wasi: WasiCtx,
wasi_nn: WasiNnCtx,
table: ResourceTable,
}
impl Ctx {
fn new(preopen_dir: &Path, preload_model: bool, mut backend: Backend) -> Result<Self> {
let mut builder = WasiCtxBuilder::new();
builder.inherit_stdio().preopened_dir(
preopen_dir,
PREOPENED_DIR_NAME,
DirPerms::READ,
FilePerms::READ,
)?;
let wasi = builder.build();
let mut registry = InMemoryRegistry::new();
let mobilenet_dir = artifacts_dir();
if preload_model {
registry.load((backend).as_dir_loadable().unwrap(), &mobilenet_dir)?;
}
let wasi_nn = WasiNnCtx::new([backend.into()], registry.into());
let table = ResourceTable::new();
Ok(Self {
wasi,
wasi_nn,
table,
})
}
}
impl wasmtime_wasi::WasiView for Ctx {
fn ctx(&mut self) -> &mut WasiCtx {
&mut self.wasi
}
fn table(&mut self) -> &mut ResourceTable {
&mut self.table
}
}

52
crates/wasi-nn/tests/exec/witx.rs

@ -0,0 +1,52 @@
use super::PREOPENED_DIR_NAME;
use crate::check::artifacts_dir;
use anyhow::Result;
use std::path::Path;
use wasmtime::{Config, Engine, Linker, Module, Store};
use wasmtime_wasi::{preview1::WasiP1Ctx, DirPerms, FilePerms, WasiCtxBuilder};
use wasmtime_wasi_nn::{witx::WasiNnCtx, Backend, InMemoryRegistry};
/// Run a wasi-nn test program. This is modeled after
/// `crates/wasi/tests/all/main.rs` but still uses the older preview1 API
/// for file reads.
pub fn run(path: &str, backend: Backend, preload_model: bool) -> Result<()> {
let path = Path::new(path);
let engine = Engine::new(&Config::new())?;
let mut linker = Linker::new(&engine);
wasmtime_wasi_nn::witx::add_to_linker(&mut linker, |s: &mut Ctx| &mut s.wasi_nn)?;
wasmtime_wasi::preview1::add_to_linker_sync(&mut linker, |s: &mut Ctx| &mut s.wasi)?;
let module = Module::from_file(&engine, path)?;
let mut store = Store::new(&engine, Ctx::new(&artifacts_dir(), preload_model, backend)?);
let instance = linker.instantiate(&mut store, &module)?;
let start = instance.get_typed_func::<(), ()>(&mut store, "_start")?;
start.call(&mut store, ())?;
Ok(())
}
/// The host state for running wasi-nn tests.
struct Ctx {
wasi: WasiP1Ctx,
wasi_nn: WasiNnCtx,
}
impl Ctx {
fn new(preopen_dir: &Path, preload_model: bool, mut backend: Backend) -> Result<Self> {
let mut builder = WasiCtxBuilder::new();
builder.inherit_stdio().preopened_dir(
preopen_dir,
PREOPENED_DIR_NAME,
DirPerms::READ,
FilePerms::READ,
)?;
let wasi = builder.build_p1();
let mut registry = InMemoryRegistry::new();
let mobilenet_dir = artifacts_dir();
if preload_model {
registry.load((backend).as_dir_loadable().unwrap(), &mobilenet_dir)?;
}
let wasi_nn = WasiNnCtx::new([backend.into()], registry.into());
Ok(Self { wasi, wasi_nn })
}
}

0
crates/wasi-nn/tests/fixtures/readme.md → crates/wasi-nn/tests/fixtures/README.md

200
crates/wasi-nn/tests/test-programs.rs

@ -23,6 +23,8 @@ use test_programs_artifacts::*;
use wasmtime_wasi_nn::{backend, Backend}; use wasmtime_wasi_nn::{backend, Backend};
fn main() -> Result<()> { fn main() -> Result<()> {
tracing_subscriber::fmt::init();
if cfg!(miri) { if cfg!(miri) {
return Ok(()); return Ok(());
} }
@ -45,7 +47,7 @@ fn main() -> Result<()> {
let mut trials = Vec::new(); let mut trials = Vec::new();
for program in programs { for program in programs {
// Either ignore the test if it cannot run (i.e., downgrade `Fail` to // Either ignore the test if it cannot run (i.e., downgrade `Fail` to
// `Ignore`) or pre-emptively fail it if `error_on_failed_check` is set. // `Ignore`) or preemptively fail it if `error_on_failed_check` is set.
let (run_test, mut check) = check_test_program(program); let (run_test, mut check) = check_test_program(program);
if !error_on_failed_check { if !error_on_failed_check {
check = check.downgrade_failure(); // Downgrade `Fail` to `Ignore`. check = check.downgrade_failure(); // Downgrade `Fail` to `Ignore`.
@ -68,103 +70,122 @@ fn main() -> Result<()> {
/// Return the test program to run and a check that must pass for the test to /// Return the test program to run and a check that must pass for the test to
/// run. /// run.
fn check_test_program(name: &str) -> (fn() -> Result<()>, IgnoreCheck) { fn check_test_program(name: &str) -> (fn() -> Result<()>, IgnoreCheck) {
use IgnoreCheck::*;
match name { match name {
"nn_image_classification" => ( // Legacy WITX-based tests:
nn_image_classification, "nn_witx_image_classification_openvino" => (
if !cfg!(target_arch = "x86_64") { nn_witx_image_classification_openvino,
Fail("requires x86_64".into()) IgnoreCheck::for_openvino(),
} else if !cfg!(target_os = "linux") && !cfg!(target_os = "windows") { ),
Fail("requires linux or windows".into()) "nn_witx_image_classification_openvino_named" => (
} else if let Err(e) = check::openvino::is_installed() { nn_witx_image_classification_openvino_named,
Fail(e.to_string().into()) IgnoreCheck::for_openvino(),
} else { ),
Run "nn_witx_image_classification_onnx" => {
}, (nn_witx_image_classification_onnx, IgnoreCheck::for_onnx())
}
"nn_witx_image_classification_winml_named" => (
nn_witx_image_classification_winml_named,
IgnoreCheck::for_winml(),
), ),
"nn_image_classification_named" => ( // WIT-based tests:
nn_image_classification_named, "nn_wit_image_classification_openvino" => (
if !cfg!(target_arch = "x86_64") { nn_wit_image_classification_openvino,
Fail("requires x86_64".into()) IgnoreCheck::for_openvino(),
} else if !cfg!(target_os = "linux") && !cfg!(target_os = "windows") {
Fail("requires linux or windows or macos".into())
} else if let Err(e) = check::openvino::is_installed() {
Fail(e.to_string().into())
} else {
Run
},
), ),
"nn_image_classification_onnx" => ( "nn_wit_image_classification_openvino_named" => (
nn_image_classification_onnx, nn_wit_image_classification_openvino_named,
#[cfg(feature = "onnx")] IgnoreCheck::for_openvino(),
if !cfg!(target_arch = "x86_64") && !cfg!(target_arch = "aarch64") {
Fail("requires x86_64 or aarch64".into())
} else if !cfg!(target_os = "linux")
&& !cfg!(target_os = "windows")
&& !cfg!(target_os = "macos")
{
Fail("requires linux, windows, or macos".into())
} else {
Run
},
#[cfg(not(feature = "onnx"))]
Ignore("requires the `onnx` feature".into()),
), ),
"nn_image_classification_winml" => ( "nn_wit_image_classification_onnx" => {
nn_image_classification_winml, (nn_wit_image_classification_onnx, IgnoreCheck::for_onnx())
#[cfg(all(feature = "winml", target_os = "windows"))] }
if !cfg!(target_arch = "x86_64") { "nn_wit_image_classification_winml_named" => (
Fail("requires x86_64".into()) nn_wit_image_classification_winml_named,
} else if cfg!(target_os = "windows") { IgnoreCheck::for_winml(),
Fail("requires windows".into())
} else if let Err(e) = check::winml::is_available() {
Fail(e.to_string().into())
} else {
Run
},
#[cfg(not(all(feature = "winml", target_os = "windows")))]
Ignore("requires the `winml` feature on windows".into()),
), ),
_ => panic!("unknown test program: {} (add to this `match`)", name), _ => panic!("unknown test program: {} (add to this `match`)", name),
} }
} }
fn nn_image_classification() -> Result<()> { fn nn_witx_image_classification_openvino() -> Result<()> {
check::openvino::is_installed()?; check::openvino::is_installed()?;
check::openvino::are_artifacts_available()?; check::openvino::are_artifacts_available()?;
let backend = Backend::from(backend::openvino::OpenvinoBackend::default()); let backend = Backend::from(backend::openvino::OpenvinoBackend::default());
exec::run(NN_IMAGE_CLASSIFICATION, backend, false) exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_OPENVINO, backend, false)
} }
fn nn_image_classification_named() -> Result<()> { fn nn_witx_image_classification_openvino_named() -> Result<()> {
check::openvino::is_installed()?; check::openvino::is_installed()?;
check::openvino::are_artifacts_available()?; check::openvino::are_artifacts_available()?;
let backend = Backend::from(backend::openvino::OpenvinoBackend::default()); let backend = Backend::from(backend::openvino::OpenvinoBackend::default());
exec::run(NN_IMAGE_CLASSIFICATION_NAMED, backend, true) exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_OPENVINO_NAMED, backend, true)
} }
#[cfg(feature = "onnx")] #[cfg(feature = "onnx")]
fn nn_image_classification_onnx() -> Result<()> { fn nn_witx_image_classification_onnx() -> Result<()> {
check::onnx::are_artifacts_available()?; check::onnx::are_artifacts_available()?;
let backend = Backend::from(backend::onnxruntime::OnnxBackend::default()); let backend = Backend::from(backend::onnx::OnnxBackend::default());
exec::run(NN_IMAGE_CLASSIFICATION_ONNX, backend, false) exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_ONNX, backend, false)
} }
#[cfg(not(feature = "onnx"))] #[cfg(not(feature = "onnx"))]
fn nn_image_classification_onnx() -> Result<()> { fn nn_witx_image_classification_onnx() -> Result<()> {
anyhow::bail!("this test requires the `onnx` feature") anyhow::bail!("this test requires the `onnx` feature")
} }
#[cfg(all(feature = "winml", target_os = "windows"))] #[cfg(all(feature = "winml", target_os = "windows"))]
fn nn_image_classification_winml() -> Result<()> { fn nn_witx_image_classification_winml_named() -> Result<()> {
check::winml::is_available()?; check::winml::is_available()?;
check::onnx::are_artifacts_available()?; check::onnx::are_artifacts_available()?;
let backend = Backend::from(backend::winml::WinMLBackend::default()); let backend = Backend::from(backend::winml::WinMLBackend::default());
exec::run(NN_IMAGE_CLASSIFICATION_ONNX, backend, false) exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_ONNX, backend, false)
}
#[cfg(not(all(feature = "winml", target_os = "windows")))]
fn nn_witx_image_classification_winml_named() -> Result<()> {
anyhow::bail!("this test requires the `winml` feature and only runs on windows")
} }
fn nn_wit_image_classification_openvino() -> Result<()> {
check::openvino::is_installed()?;
check::openvino::are_artifacts_available()?;
let backend = Backend::from(backend::openvino::OpenvinoBackend::default());
exec::wit::run(
NN_WIT_IMAGE_CLASSIFICATION_OPENVINO_COMPONENT,
backend,
false,
)
}
fn nn_wit_image_classification_openvino_named() -> Result<()> {
check::openvino::is_installed()?;
check::openvino::are_artifacts_available()?;
let backend = Backend::from(backend::openvino::OpenvinoBackend::default());
exec::wit::run(
NN_WIT_IMAGE_CLASSIFICATION_OPENVINO_NAMED_COMPONENT,
backend,
true,
)
}
#[cfg(feature = "onnx")]
fn nn_wit_image_classification_onnx() -> Result<()> {
check::onnx::are_artifacts_available()?;
let backend = Backend::from(backend::onnx::OnnxBackend::default());
exec::wit::run(NN_WIT_IMAGE_CLASSIFICATION_ONNX_COMPONENT, backend, false)
}
#[cfg(not(feature = "onnx"))]
fn nn_wit_image_classification_onnx() -> Result<()> {
anyhow::bail!("this test requires the `onnx` feature")
}
#[cfg(all(feature = "winml", target_os = "windows"))]
fn nn_wit_image_classification_winml_named() -> Result<()> {
check::winml::is_available()?;
check::onnx::are_artifacts_available()?;
let backend = Backend::from(backend::winml::WinMLBackend::default());
exec::wit::run(NN_WIT_IMAGE_CLASSIFICATION_ONNX_COMPONENT, backend, false)
}
#[cfg(not(all(feature = "winml", target_os = "windows")))] #[cfg(not(all(feature = "winml", target_os = "windows")))]
fn nn_image_classification_winml() -> Result<()> { fn nn_wit_image_classification_winml_named() -> Result<()> {
anyhow::bail!("this test requires the `winml` feature and only runs on windows") anyhow::bail!("this test requires the `winml` feature and only runs on windows")
} }
@ -197,3 +218,52 @@ impl IgnoreCheck {
matches!(self, IgnoreCheck::Ignore(_)) matches!(self, IgnoreCheck::Ignore(_))
} }
} }
/// Some pre-test checks for various backends.
impl IgnoreCheck {
fn for_openvino() -> IgnoreCheck {
use IgnoreCheck::*;
if !cfg!(target_arch = "x86_64") {
Fail("requires x86_64".into())
} else if !cfg!(target_os = "linux") && !cfg!(target_os = "windows") {
Fail("requires linux or windows or macos".into())
} else if let Err(e) = check::openvino::is_installed() {
Fail(e.to_string().into())
} else {
Run
}
}
fn for_onnx() -> Self {
use IgnoreCheck::*;
#[cfg(feature = "onnx")]
if !cfg!(target_arch = "x86_64") && !cfg!(target_arch = "aarch64") {
Fail("requires x86_64 or aarch64".into())
} else if !cfg!(target_os = "linux")
&& !cfg!(target_os = "windows")
&& !cfg!(target_os = "macos")
{
Fail("requires linux, windows, or macos".into())
} else {
Run
}
#[cfg(not(feature = "onnx"))]
Ignore("requires the `onnx` feature".into())
}
fn for_winml() -> IgnoreCheck {
use IgnoreCheck::*;
#[cfg(all(feature = "winml", target_os = "windows"))]
if !cfg!(target_arch = "x86_64") {
Fail("requires x86_64".into())
} else if !cfg!(target_os = "windows") {
Fail("requires windows".into())
} else if let Err(e) = check::winml::is_available() {
Fail(e.to_string().into())
} else {
Run
}
#[cfg(not(all(feature = "winml", target_os = "windows")))]
Ignore("requires the `winml` feature on windows".into())
}
}

57
crates/wasi-nn/wit/wasi-nn.wit

@ -43,16 +43,18 @@ interface tensor {
/// memory--e.g., using row-major ordering--and could perhaps be improved. /// memory--e.g., using row-major ordering--and could perhaps be improved.
type tensor-data = list<u8>; type tensor-data = list<u8>;
record tensor { resource tensor {
constructor(dimensions: tensor-dimensions, ty: tensor-type, data: tensor-data);
// Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To represent a tensor // Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To represent a tensor
// containing a single value, use `[1]` for the tensor dimensions. // containing a single value, use `[1]` for the tensor dimensions.
dimensions: tensor-dimensions, dimensions: func() -> tensor-dimensions;
// Describe the type of element in the tensor (e.g., `f32`). // Describe the type of element in the tensor (e.g., `f32`).
tensor-type: tensor-type, ty: func() -> tensor-type;
// Contains the tensor data. // Return the tensor data.
data: tensor-data, data: func() -> tensor-data;
} }
} }
@ -61,11 +63,12 @@ interface tensor {
interface graph { interface graph {
use errors.{error}; use errors.{error};
use tensor.{tensor}; use tensor.{tensor};
use inference.{graph-execution-context};
/// An execution graph for performing inference (i.e., a model). /// An execution graph for performing inference (i.e., a model).
/// resource graph {
/// TODO: replace with `resource` (https://github.com/WebAssembly/wasi-nn/issues/47). init-execution-context: func() -> result<graph-execution-context, error>;
type graph = u32; }
/// Describes the encoding of the graph. This allows the API to be implemented by various /// Describes the encoding of the graph. This allows the API to be implemented by various
/// backends that encode (i.e., serialize) their graph IR with different formats. /// backends that encode (i.e., serialize) their graph IR with different formats.
@ -75,6 +78,7 @@ interface graph {
tensorflow, tensorflow,
pytorch, pytorch,
tensorflowlite, tensorflowlite,
ggml,
autodetect, autodetect,
} }
@ -107,27 +111,25 @@ interface graph {
interface inference { interface inference {
use errors.{error}; use errors.{error};
use tensor.{tensor, tensor-data}; use tensor.{tensor, tensor-data};
use graph.{graph};
/// Bind a `graph` to the input and output tensors for an inference. /// Bind a `graph` to the input and output tensors for an inference.
/// ///
/// TODO: this is no longer necessary in WIT (https://github.com/WebAssembly/wasi-nn/issues/43) /// TODO: this may no longer be necessary in WIT
type graph-execution-context = u32; /// (https://github.com/WebAssembly/wasi-nn/issues/43)
resource graph-execution-context {
/// Create an execution instance of a loaded graph. /// Define the inputs to use for inference.
init-execution-context: func(graph: graph) -> result<graph-execution-context, error>; set-input: func(name: string, tensor: tensor) -> result<_, error>;
/// Define the inputs to use for inference. /// Compute the inference on the given inputs.
set-input: func(ctx: graph-execution-context, index: u32, tensor: tensor) -> result<_, error>; ///
/// Note the expected sequence of calls: `set-input`, `compute`, `get-output`. TODO: this
/// Compute the inference on the given inputs. /// expectation could be removed as a part of
/// /// https://github.com/WebAssembly/wasi-nn/issues/43.
/// Note the expected sequence of calls: `set-input`, `compute`, `get-output`. TODO: this compute: func() -> result<_, error>;
/// expectation could be removed as a part of https://github.com/WebAssembly/wasi-nn/issues/43.
compute: func(ctx: graph-execution-context) -> result<_, error>; /// Extract the outputs after inference.
get-output: func(name: string) -> result<tensor, error>;
/// Extract the outputs after inference. }
get-output: func(ctx: graph-execution-context, index: u32) -> result<tensor-data, error>;
} }
/// TODO: create function-specific errors (https://github.com/WebAssembly/wasi-nn/issues/42) /// TODO: create function-specific errors (https://github.com/WebAssembly/wasi-nn/issues/42)
@ -137,7 +139,8 @@ interface errors {
invalid-argument, invalid-argument,
// Invalid encoding. // Invalid encoding.
invalid-encoding, invalid-encoding,
busy, // The operation timed out.
timeout,
// Runtime Error. // Runtime Error.
runtime-error, runtime-error,
// Unsupported operation. // Unsupported operation.

63
src/commands/run.rs

@ -18,7 +18,7 @@ use wasmtime::{Engine, Func, Module, Store, StoreLimits, Val, ValType};
use wasmtime_wasi::WasiView; use wasmtime_wasi::WasiView;
#[cfg(feature = "wasi-nn")] #[cfg(feature = "wasi-nn")]
use wasmtime_wasi_nn::WasiNnCtx; use wasmtime_wasi_nn::wit::WasiNnView;
#[cfg(feature = "wasi-threads")] #[cfg(feature = "wasi-threads")]
use wasmtime_wasi_threads::WasiThreadsCtx; use wasmtime_wasi_threads::WasiThreadsCtx;
@ -624,40 +624,37 @@ impl RunCommand {
{ {
bail!("Cannot enable wasi-nn when the binary is not compiled with this feature."); bail!("Cannot enable wasi-nn when the binary is not compiled with this feature.");
} }
#[cfg(feature = "wasi-nn")] #[cfg(all(feature = "wasi-nn", feature = "component-model"))]
{ {
let (backends, registry) = self.collect_preloaded_nn_graphs()?;
match linker { match linker {
CliLinker::Core(linker) => { CliLinker::Core(linker) => {
wasmtime_wasi_nn::witx::add_to_linker(linker, |host| { wasmtime_wasi_nn::witx::add_to_linker(linker, |host| {
// This WASI proposal is currently not protected against Arc::get_mut(host.wasi_nn_witx.as_mut().unwrap())
// concurrent access--i.e., when wasi-threads is actively
// spawning new threads, we cannot (yet) safely allow access and
// fail if more than one thread has `Arc`-references to the
// context. Once this proposal is updated (as wasi-common has
// been) to allow concurrent access, this `Arc::get_mut`
// limitation can be removed.
Arc::get_mut(host.wasi_nn.as_mut().unwrap())
.expect("wasi-nn is not implemented with multi-threading support") .expect("wasi-nn is not implemented with multi-threading support")
})?; })?;
store.data_mut().wasi_nn_witx = Some(Arc::new(
wasmtime_wasi_nn::witx::WasiNnCtx::new(backends, registry),
));
} }
#[cfg(feature = "component-model")] #[cfg(feature = "component-model")]
CliLinker::Component(linker) => { CliLinker::Component(linker) => {
wasmtime_wasi_nn::wit::ML::add_to_linker(linker, |host| { wasmtime_wasi_nn::wit::add_to_linker(linker, |h: &mut Host| {
Arc::get_mut(host.wasi_nn.as_mut().unwrap()) let preview2_ctx =
.expect("wasi-nn is not implemented with multi-threading support") h.preview2_ctx.as_mut().expect("wasip2 is not configured");
let preview2_ctx = Arc::get_mut(preview2_ctx)
.expect("wasmtime_wasi is not compatible with threads")
.get_mut()
.unwrap();
let nn_ctx = Arc::get_mut(h.wasi_nn_wit.as_mut().unwrap())
.expect("wasi-nn is not implemented with multi-threading support");
WasiNnView::new(preview2_ctx.table(), nn_ctx)
})?; })?;
store.data_mut().wasi_nn_wit = Some(Arc::new(
wasmtime_wasi_nn::wit::WasiNnCtx::new(backends, registry),
));
} }
} }
let graphs = self
.run
.common
.wasi
.nn_graph
.iter()
.map(|g| (g.format.clone(), g.dir.clone()))
.collect::<Vec<_>>();
let (backends, registry) = wasmtime_wasi_nn::preload(&graphs)?;
store.data_mut().wasi_nn = Some(Arc::new(WasiNnCtx::new(backends, registry)));
} }
} }
@ -767,6 +764,21 @@ impl RunCommand {
store.data_mut().preview2_ctx = Some(Arc::new(Mutex::new(ctx))); store.data_mut().preview2_ctx = Some(Arc::new(Mutex::new(ctx)));
Ok(()) Ok(())
} }
#[cfg(feature = "wasi-nn")]
fn collect_preloaded_nn_graphs(
&self,
) -> Result<(Vec<wasmtime_wasi_nn::Backend>, wasmtime_wasi_nn::Registry)> {
let graphs = self
.run
.common
.wasi
.nn_graph
.iter()
.map(|g| (g.format.clone(), g.dir.clone()))
.collect::<Vec<_>>();
wasmtime_wasi_nn::preload(&graphs)
}
} }
#[derive(Default, Clone)] #[derive(Default, Clone)]
@ -779,7 +791,10 @@ struct Host {
preview2_ctx: Option<Arc<Mutex<wasmtime_wasi::preview1::WasiP1Ctx>>>, preview2_ctx: Option<Arc<Mutex<wasmtime_wasi::preview1::WasiP1Ctx>>>,
#[cfg(feature = "wasi-nn")] #[cfg(feature = "wasi-nn")]
wasi_nn: Option<Arc<WasiNnCtx>>, wasi_nn_wit: Option<Arc<wasmtime_wasi_nn::wit::WasiNnCtx>>,
#[cfg(feature = "wasi-nn")]
wasi_nn_witx: Option<Arc<wasmtime_wasi_nn::witx::WasiNnCtx>>,
#[cfg(feature = "wasi-threads")] #[cfg(feature = "wasi-threads")]
wasi_threads: Option<Arc<WasiThreadsCtx<Host>>>, wasi_threads: Option<Arc<WasiThreadsCtx<Host>>>,
#[cfg(feature = "wasi-http")] #[cfg(feature = "wasi-http")]

22
src/commands/serve.rs

@ -17,7 +17,7 @@ use wasmtime_wasi_http::io::TokioIo;
use wasmtime_wasi_http::{body::HyperOutgoingBody, WasiHttpCtx, WasiHttpView}; use wasmtime_wasi_http::{body::HyperOutgoingBody, WasiHttpCtx, WasiHttpView};
#[cfg(feature = "wasi-nn")] #[cfg(feature = "wasi-nn")]
use wasmtime_wasi_nn::WasiNnCtx; use wasmtime_wasi_nn::wit::WasiNnCtx;
struct Host { struct Host {
table: wasmtime::component::ResourceTable, table: wasmtime::component::ResourceTable,
@ -75,15 +75,8 @@ impl ServeCommand {
pub fn execute(mut self) -> Result<()> { pub fn execute(mut self) -> Result<()> {
self.run.common.init_logging()?; self.run.common.init_logging()?;
// We force cli errors before starting to listen for connections so then we don't // We force cli errors before starting to listen for connections so then
// accidentally delay them to the first request. // we don't accidentally delay them to the first request.
if self.run.common.wasi.nn == Some(true) {
#[cfg(not(feature = "wasi-nn"))]
{
bail!("Cannot enable wasi-nn when the binary is not compiled with this feature.");
}
}
if let Some(Profile::Guest { .. }) = &self.run.profile { if let Some(Profile::Guest { .. }) = &self.run.profile {
bail!("Cannot use the guest profiler with components"); bail!("Cannot use the guest profiler with components");
} }
@ -99,8 +92,8 @@ impl ServeCommand {
bail!("wasi-threads does not support components yet") bail!("wasi-threads does not support components yet")
} }
// The serve command requires both wasi-http and the component model, so we enable those by // The serve command requires both wasi-http and the component model, so
// default here. // we enable those by default here.
if self.run.common.wasi.http.replace(true) == Some(false) { if self.run.common.wasi.http.replace(true) == Some(false) {
bail!("wasi-http is required for the serve command, and must not be disabled"); bail!("wasi-http is required for the serve command, and must not be disabled");
} }
@ -227,7 +220,10 @@ impl ServeCommand {
} }
#[cfg(feature = "wasi-nn")] #[cfg(feature = "wasi-nn")]
{ {
wasmtime_wasi_nn::wit::ML::add_to_linker(linker, |host| host.nn.as_mut().unwrap())?; wasmtime_wasi_nn::wit::add_to_linker(linker, |h: &mut Host| {
let ctx = h.nn.as_mut().unwrap();
wasmtime_wasi_nn::wit::WasiNnView::new(&mut h.table, ctx)
})?;
} }
} }

18
supply-chain/audits.toml

@@ -2028,6 +2028,12 @@ criteria = "safe-to-deploy"
 version = "0.46.0"
 notes = "one use of unsafe to call windows specific api to get console handle."
+
+[[audits.num-traits]]
+who = "Andrew Brown <andrew.brown@intel.com>"
+criteria = "safe-to-deploy"
+version = "0.2.19"
+notes = "As advertised: a numeric library. The only `unsafe` is from some float-to-int conversions, which seems expected."
 [[audits.num_cpus]]
 who = "Alex Crichton <alex@alexcrichton.com>"
 criteria = "safe-to-deploy"
@@ -2145,12 +2151,24 @@ criteria = "safe-to-deploy"
 version = "2.0.0-rc.0"
 notes = "As expected, this crate uses `unsafe` to access the `unsafe` `ort-sys` FFI functions; it also includes several `unsafe` implementations of `Send` for several structures. With the `load-dynamic` feature enabled, this crate will be `libloading` external libraries to call FFI functions. With the `fetch-models` feature enabled, this crate can also download arbitrary models to the local filesystem."
+
+[[audits.ort]]
+who = "Andrew Brown <andrew.brown@intel.com>"
+criteria = "safe-to-deploy"
+delta = "2.0.0-rc.0 -> 2.0.0-rc.2"
+notes = "Same as previous audit: the crate inherently uses `unsafe` FFI calls for using ONNX through `ort-sys` (e.g., logging C error strings). The changes are relatively uninteresting: a lot of documentation, some `must_use`, and general refactoring due to changes in the underlying API."
 [[audits.ort-sys]]
 who = "Andrew Brown <andrew.brown@intel.com>"
 criteria = "safe-to-deploy"
 version = "2.0.0-rc.0"
 notes = "As expected, this crate contains a significant number of `unsafe` definitions to expose the FFI surface of the ONNX libraries. Perhaps surprisingly, it also contains some `unsafe` system calls to locate the user's home directory. Another interesting bit is the `build.rs` script: with the `download-binaries` feature enabled, this script will retrieve and link various ONNX libraries from https://parcel.pyke.io. This seems par for the course with this kind of library, though; the alternative--attempting to find the library on an arbitrary system--can be quite complex."
+
+[[audits.ort-sys]]
+who = "Andrew Brown <andrew.brown@intel.com>"
+criteria = "safe-to-deploy"
+delta = "2.0.0-rc.0 -> 2.0.0-rc.2"
+notes = "This crate still downloads the ONNX libraries as a part of the `build.rs` script; now with more platform options for pre-built binaries stored in a `dist.txt` file. Otherwise largely unchanged since the previous audit."
 [[audits.overload]]
 who = "Pat Hickey <phickey@fastly.com>"
 criteria = "safe-to-deploy"
