Browse Source
* wasi-nn: use resources Recent discussion in the wasi-nn proposal (see [wasi-nn#59], e.g.) has concluded that the right approach for representing wasi-nn "things" (tensors, graph, etc.) is with a component model _resource_. This sweeping change brings Wasmtime's implementation in line with that decision. Initially I had structured this PR to remove all of the WITX-based implementation (#8530). But, after consulting in a Zulip [thread] on what other WASI proposals aim to do, this PR pivoted to support _both_` the WITX-based and WIT-based ABIs (e.g., preview1 era versus preview2, component model era). What is clear is that the WITX-based specification will remain "frozen in time" while the WIT-based implementation moves forward. What that means for this PR is a "split world" paradigm. In many places, we have to distinguish between the `wit` and `witx` versions of the same thing. This change isn't the end state yet: it's a big step forward towards bringing Wasmtime back in line with the WIT spec but, despite my best efforts, doesn't fully fix all the TODOs left behind over several years of development. I have, however, taken the liberty to refactor and fix various parts as I came across them (e.g., the ONNX backend). I plan to continue working on this in future PRs to figure out a good error paradigm (the current one is too wordy) and device residence. [wasi-nn#59]: https://github.com/WebAssembly/wasi-nn/pull/59 [thread]: https://bytecodealliance.zulipchat.com/#narrow/stream/219900-wasi/topic/wasi-nn's.20preview1.20vs.20preview2.20timeline prtest:full * vet: audit `ort`-related crate updates * Simplify `WasiNnView` With @alexcrichton's help, this change removes the `trait WasiNnView` and `struct WasiNnImpl` wrapping that the WIT-based implementation used for accessing the host context. Instead, `WasiNnView` is now a `struct` containing the mutable references it needs to make things work. This unwraps one complex layer of abstraction, though it does have the downside that it complicates CLI code to split borrows of `Host`. * Temporarily disable WIT check * Refactor errors to use `trappable_error_type` This change simplifies the return types of the host implementations of the WIT-based wasi-nn. There is more work to be done with errors, e.g., to catch up with the upstream decision to return errors as resources. But this is better than the previous mess.pull/8886/head
Andrew Brown
4 months ago
committed by
GitHub
36 changed files with 1571 additions and 743 deletions
@ -1,16 +0,0 @@ |
|||||
use anyhow::{Context, Result}; |
|
||||
use std::fs; |
|
||||
use test_programs::nn::{classify, sort_results}; |
|
||||
use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding}; |
|
||||
|
|
||||
pub fn main() -> Result<()> { |
|
||||
let graph = GraphBuilder::new(GraphEncoding::Onnx, ExecutionTarget::CPU) |
|
||||
.build_from_cache("mobilenet")?; |
|
||||
let tensor = fs::read("fixture/kitten.rgb") |
|
||||
.context("the tensor file to be mapped to the fixture directory")?; |
|
||||
let results = classify(graph, tensor)?; |
|
||||
let top_five = &sort_results(&results)[..5]; |
|
||||
println!("found results, sorted top 5: {:?}", top_five); |
|
||||
assert_eq!(top_five[0].class_id(), 284); |
|
||||
Ok(()) |
|
||||
} |
|
@ -1,18 +1,20 @@ |
|||||
use anyhow::{Context, Result}; |
use anyhow::{Context, Result}; |
||||
use std::fs; |
use std::fs; |
||||
use test_programs::nn::{classify, sort_results}; |
use test_programs::nn::{sort_results, wit}; |
||||
use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding}; |
|
||||
|
|
||||
pub fn main() -> Result<()> { |
pub fn main() -> Result<()> { |
||||
let model = fs::read("fixture/model.onnx") |
let model = fs::read("fixture/model.onnx") |
||||
.context("the model file to be mapped to the fixture directory")?; |
.context("the model file to be mapped to the fixture directory")?; |
||||
let graph = |
let graph = wit::load( |
||||
GraphBuilder::new(GraphEncoding::Onnx, ExecutionTarget::CPU).build_from_bytes([&model])?; |
&[model], |
||||
|
wit::GraphEncoding::Onnx, |
||||
|
wit::ExecutionTarget::Cpu, |
||||
|
)?; |
||||
let tensor = fs::read("fixture/000000062808.rgb") |
let tensor = fs::read("fixture/000000062808.rgb") |
||||
.context("the tensor file to be mapped to the fixture directory")?; |
.context("the tensor file to be mapped to the fixture directory")?; |
||||
let results = classify(graph, tensor)?; |
let results = wit::classify(graph, ("input", tensor), "output")?; |
||||
let top_five = &sort_results(&results)[..5]; |
let top_five = &sort_results(&results)[..5]; |
||||
// 963 is meat loaf, meatloaf.
|
// 963 is "meat loaf, meatloaf."
|
||||
// https://github.com/onnx/models/blob/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/synset.txt#L963
|
// https://github.com/onnx/models/blob/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/synset.txt#L963
|
||||
assert_eq!(top_five[0].class_id(), 963); |
assert_eq!(top_five[0].class_id(), 963); |
||||
println!("found results, sorted top 5: {:?}", top_five); |
println!("found results, sorted top 5: {:?}", top_five); |
@ -0,0 +1,25 @@ |
|||||
|
use anyhow::{Context, Result}; |
||||
|
use std::fs; |
||||
|
use test_programs::nn::{sort_results, wit}; |
||||
|
|
||||
|
pub fn main() -> Result<()> { |
||||
|
let xml = fs::read("fixture/model.xml") |
||||
|
.context("the model file to be mapped to the fixture directory")?; |
||||
|
let weights = fs::read("fixture/model.bin") |
||||
|
.context("the weights file to be mapped to the fixture directory")?; |
||||
|
let graph = wit::load( |
||||
|
&[xml, weights], |
||||
|
wit::GraphEncoding::Openvino, |
||||
|
wit::ExecutionTarget::Cpu, |
||||
|
)?; |
||||
|
let tensor = fs::read("fixture/tensor.bgr") |
||||
|
.context("the tensor file to be mapped to the fixture directory")?; |
||||
|
let results = wit::classify( |
||||
|
graph, |
||||
|
("input", tensor), |
||||
|
"MobilenetV2/Predictions/Reshape_1", |
||||
|
)?; |
||||
|
let top_five = &sort_results(&results)[..5]; |
||||
|
println!("found results, sorted top 5: {:?}", top_five); |
||||
|
Ok(()) |
||||
|
} |
@ -0,0 +1,17 @@ |
|||||
|
use anyhow::{Context, Result}; |
||||
|
use std::fs; |
||||
|
use test_programs::nn::{sort_results, wit}; |
||||
|
|
||||
|
pub fn main() -> Result<()> { |
||||
|
let graph = wit::load_by_name("fixtures")?; |
||||
|
let tensor: Vec<u8> = fs::read("fixture/tensor.bgr") |
||||
|
.context("the tensor file to be mapped to the fixture directory")?; |
||||
|
let results = wit::classify( |
||||
|
graph, |
||||
|
("input", tensor), |
||||
|
"MobilenetV2/Predictions/Reshape_1", |
||||
|
)?; |
||||
|
let top_five = &sort_results(&results)[..5]; |
||||
|
println!("found results, sorted top 5: {:?}", top_five); |
||||
|
Ok(()) |
||||
|
} |
@ -1,15 +1,14 @@ |
|||||
use anyhow::{Context, Result}; |
use anyhow::{Context, Result}; |
||||
use std::fs; |
use std::fs; |
||||
use test_programs::nn::{classify, sort_results}; |
use test_programs::nn::{sort_results, wit}; |
||||
use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding}; |
|
||||
|
|
||||
pub fn main() -> Result<()> { |
pub fn main() -> Result<()> { |
||||
let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU) |
let graph = wit::load_by_name("mobilenet")?; |
||||
.build_from_cache("fixtures")?; |
|
||||
let tensor = fs::read("fixture/tensor.bgr") |
let tensor = fs::read("fixture/tensor.bgr") |
||||
.context("the tensor file to be mapped to the fixture directory")?; |
.context("the tensor file to be mapped to the fixture directory")?; |
||||
let results = classify(graph, tensor)?; |
let results = wit::classify(graph, ("input", tensor), "output")?; |
||||
let top_five = &sort_results(&results)[..5]; |
let top_five = &sort_results(&results)[..5]; |
||||
println!("found results, sorted top 5: {:?}", top_five); |
println!("found results, sorted top 5: {:?}", top_five); |
||||
|
assert_eq!(top_five[0].class_id(), 284); |
||||
Ok(()) |
Ok(()) |
||||
} |
} |
@ -0,0 +1,22 @@ |
|||||
|
use anyhow::{Context, Result}; |
||||
|
use std::fs; |
||||
|
use test_programs::nn::{sort_results, witx}; |
||||
|
|
||||
|
pub fn main() -> Result<()> { |
||||
|
let model = fs::read("fixture/model.onnx") |
||||
|
.context("the model file to be mapped to the fixture directory")?; |
||||
|
let graph = witx::load( |
||||
|
&[&model], |
||||
|
witx::GraphEncoding::Onnx, |
||||
|
witx::ExecutionTarget::CPU, |
||||
|
)?; |
||||
|
let tensor = fs::read("fixture/000000062808.rgb") |
||||
|
.context("the tensor file to be mapped to the fixture directory")?; |
||||
|
let results = witx::classify(graph, tensor)?; |
||||
|
let top_five = &sort_results(&results)[..5]; |
||||
|
// 963 is "meat loaf, meatloaf."
|
||||
|
// https://github.com/onnx/models/blob/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/synset.txt#L963
|
||||
|
assert_eq!(top_five[0].class_id(), 963); |
||||
|
println!("found results, sorted top 5: {:?}", top_five); |
||||
|
Ok(()) |
||||
|
} |
@ -1,18 +1,20 @@ |
|||||
use anyhow::{Context, Result}; |
use anyhow::{Context, Result}; |
||||
use std::fs; |
use std::fs; |
||||
use test_programs::nn::{classify, sort_results}; |
use test_programs::nn::{sort_results, witx}; |
||||
use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding}; |
|
||||
|
|
||||
pub fn main() -> Result<()> { |
pub fn main() -> Result<()> { |
||||
let xml = fs::read("fixture/model.xml") |
let xml = fs::read("fixture/model.xml") |
||||
.context("the model file to be mapped to the fixture directory")?; |
.context("the model file to be mapped to the fixture directory")?; |
||||
let weights = fs::read("fixture/model.bin") |
let weights = fs::read("fixture/model.bin") |
||||
.context("the weights file to be mapped to the fixture directory")?; |
.context("the weights file to be mapped to the fixture directory")?; |
||||
let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU) |
let graph = witx::load( |
||||
.build_from_bytes([&xml, &weights])?; |
&[&xml, &weights], |
||||
|
witx::GraphEncoding::Openvino, |
||||
|
witx::ExecutionTarget::CPU, |
||||
|
)?; |
||||
let tensor = fs::read("fixture/tensor.bgr") |
let tensor = fs::read("fixture/tensor.bgr") |
||||
.context("the tensor file to be mapped to the fixture directory")?; |
.context("the tensor file to be mapped to the fixture directory")?; |
||||
let results = classify(graph, tensor)?; |
let results = witx::classify(graph, tensor)?; |
||||
let top_five = &sort_results(&results)[..5]; |
let top_five = &sort_results(&results)[..5]; |
||||
println!("found results, sorted top 5: {:?}", top_five); |
println!("found results, sorted top 5: {:?}", top_five); |
||||
Ok(()) |
Ok(()) |
@ -0,0 +1,17 @@ |
|||||
|
use anyhow::{Context, Result}; |
||||
|
use std::fs; |
||||
|
use test_programs::nn::{sort_results, witx}; |
||||
|
|
||||
|
pub fn main() -> Result<()> { |
||||
|
let graph = witx::load_by_name( |
||||
|
"fixtures", |
||||
|
witx::GraphEncoding::Openvino, |
||||
|
witx::ExecutionTarget::CPU, |
||||
|
)?; |
||||
|
let tensor: Vec<u8> = fs::read("fixture/tensor.bgr") |
||||
|
.context("the tensor file to be mapped to the fixture directory")?; |
||||
|
let results = witx::classify(graph, tensor)?; |
||||
|
let top_five = &sort_results(&results)[..5]; |
||||
|
println!("found results, sorted top 5: {:?}", top_five); |
||||
|
Ok(()) |
||||
|
} |
@ -0,0 +1,18 @@ |
|||||
|
use anyhow::{Context, Result}; |
||||
|
use std::fs; |
||||
|
use test_programs::nn::{sort_results, witx}; |
||||
|
|
||||
|
pub fn main() -> Result<()> { |
||||
|
let graph = witx::load_by_name( |
||||
|
"mobilenet", |
||||
|
witx::GraphEncoding::Onnx, |
||||
|
witx::ExecutionTarget::CPU, |
||||
|
)?; |
||||
|
let tensor = fs::read("fixture/tensor.bgr") |
||||
|
.context("the tensor file to be mapped to the fixture directory")?; |
||||
|
let results = witx::classify(graph, tensor)?; |
||||
|
let top_five = &sort_results(&results)[..5]; |
||||
|
println!("found results, sorted top 5: {:?}", top_five); |
||||
|
assert_eq!(top_five[0].class_id(), 284); |
||||
|
Ok(()) |
||||
|
} |
@ -0,0 +1,338 @@ |
|||||
|
//! Implements a `wasi-nn` [`BackendInner`] using ONNX via the `ort` crate.
|
||||
|
|
||||
|
use super::{BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner}; |
||||
|
use crate::backend::{read, Id}; |
||||
|
use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType}; |
||||
|
use crate::{ExecutionContext, Graph}; |
||||
|
use anyhow::Context; |
||||
|
use ort::{inputs, GraphOptimizationLevel, Session}; |
||||
|
use std::path::Path; |
||||
|
use std::sync::{Arc, Mutex}; |
||||
|
|
||||
|
#[derive(Default)] |
||||
|
pub struct OnnxBackend(); |
||||
|
unsafe impl Send for OnnxBackend {} |
||||
|
unsafe impl Sync for OnnxBackend {} |
||||
|
|
||||
|
impl BackendInner for OnnxBackend { |
||||
|
fn encoding(&self) -> GraphEncoding { |
||||
|
GraphEncoding::Onnx |
||||
|
} |
||||
|
|
||||
|
fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> { |
||||
|
if builders.len() != 1 { |
||||
|
return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()).into()); |
||||
|
} |
||||
|
|
||||
|
let session = Session::builder()? |
||||
|
.with_optimization_level(GraphOptimizationLevel::Level3)? |
||||
|
.commit_from_memory(builders[0])?; |
||||
|
|
||||
|
let box_: Box<dyn BackendGraph> = |
||||
|
Box::new(OnnxGraph(Arc::new(Mutex::new(session)), target)); |
||||
|
Ok(box_.into()) |
||||
|
} |
||||
|
|
||||
|
fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir> { |
||||
|
Some(self) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
impl BackendFromDir for OnnxBackend { |
||||
|
fn load_from_dir( |
||||
|
&mut self, |
||||
|
path: &Path, |
||||
|
target: ExecutionTarget, |
||||
|
) -> Result<Graph, BackendError> { |
||||
|
let model = read(&path.join("model.onnx"))?; |
||||
|
self.load(&[&model], target) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
struct OnnxGraph(Arc<Mutex<Session>>, #[allow(dead_code)] ExecutionTarget); |
||||
|
unsafe impl Send for OnnxGraph {} |
||||
|
unsafe impl Sync for OnnxGraph {} |
||||
|
|
||||
|
impl BackendGraph for OnnxGraph { |
||||
|
fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> { |
||||
|
let session = self.0.lock().unwrap(); |
||||
|
// We need to hold on to the names of the inputs in order for
|
||||
|
// `set_input` to work with both indexes and names. Having the
|
||||
|
// dimensions and type around is useful for validation but could be
|
||||
|
// retrieved from the session.
|
||||
|
let mut inputs = vec![]; |
||||
|
for input in &session.inputs { |
||||
|
let shape = Shape::from_onnx_input(input)?; |
||||
|
inputs.push(TensorSlot { |
||||
|
shape, |
||||
|
tensor: None, |
||||
|
}); |
||||
|
} |
||||
|
// We need to keep track of the output shapes since they are used for
|
||||
|
// creating the output tensor.
|
||||
|
let mut outputs = vec![]; |
||||
|
for output in &session.outputs { |
||||
|
let shape = Shape::from_onnx_output(output)?; |
||||
|
outputs.push(TensorSlot { |
||||
|
shape, |
||||
|
tensor: None, |
||||
|
}); |
||||
|
} |
||||
|
let box_: Box<dyn BackendExecutionContext> = Box::new(OnnxExecutionContext { |
||||
|
session: self.0.clone(), |
||||
|
inputs, |
||||
|
outputs, |
||||
|
}); |
||||
|
Ok(box_.into()) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
struct OnnxExecutionContext { |
||||
|
session: Arc<Mutex<Session>>, |
||||
|
inputs: Vec<TensorSlot>, |
||||
|
outputs: Vec<TensorSlot>, |
||||
|
} |
||||
|
|
||||
|
unsafe impl Send for OnnxExecutionContext {} |
||||
|
unsafe impl Sync for OnnxExecutionContext {} |
||||
|
|
||||
|
impl OnnxExecutionContext { |
||||
|
/// Helper function for finding the internal index of a tensor by [`Id`].
|
||||
|
fn find(&self, id: Id, list: &[TensorSlot]) -> Result<usize, BackendError> { |
||||
|
let index = match id { |
||||
|
Id::Index(i) => { |
||||
|
let i = i as usize; |
||||
|
if i < list.len() { |
||||
|
i |
||||
|
} else { |
||||
|
return Err(BackendError::BackendAccess(anyhow::anyhow!( |
||||
|
"incorrect tensor index: {i} >= {}", |
||||
|
list.len() |
||||
|
))); |
||||
|
} |
||||
|
} |
||||
|
Id::Name(n) => list.iter().position(|s| s.shape.name == n).ok_or_else(|| { |
||||
|
BackendError::BackendAccess(anyhow::anyhow!("unknown tensor name: {n}")) |
||||
|
})?, |
||||
|
}; |
||||
|
Ok(index) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
impl BackendExecutionContext for OnnxExecutionContext { |
||||
|
fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> { |
||||
|
let index = self.find(id, &self.inputs)?; |
||||
|
let input = &mut self.inputs[index]; |
||||
|
if let Err(e) = input.shape.matches(tensor) { |
||||
|
return Err(e.into()); |
||||
|
} |
||||
|
// Hold the tensor data on the context until `compute` is called.
|
||||
|
input.tensor.replace(tensor.clone()); |
||||
|
Ok(()) |
||||
|
} |
||||
|
|
||||
|
fn compute(&mut self) -> Result<(), BackendError> { |
||||
|
let mut session_inputs: Vec<ort::SessionInputValue<'_>> = vec![]; |
||||
|
for i in &self.inputs { |
||||
|
session_inputs.extend(to_input_value(i)?); |
||||
|
} |
||||
|
let session = self.session.lock().unwrap(); |
||||
|
let session_outputs = session.run(session_inputs.as_slice())?; |
||||
|
for i in 0..self.outputs.len() { |
||||
|
// TODO: fix preexisting gap--this only handles f32 tensors.
|
||||
|
let raw: (Vec<i64>, &[f32]) = session_outputs[i].try_extract_raw_tensor()?; |
||||
|
let f32s = raw.1.to_vec(); |
||||
|
let output = &mut self.outputs[i]; |
||||
|
output.tensor.replace(Tensor { |
||||
|
dimensions: output.shape.dimensions_as_u32()?, |
||||
|
ty: output.shape.ty, |
||||
|
data: f32_vec_to_bytes(f32s), |
||||
|
}); |
||||
|
} |
||||
|
Ok(()) |
||||
|
} |
||||
|
|
||||
|
fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError> { |
||||
|
let index = self.find(id, &self.outputs)?; |
||||
|
let output = &self.outputs[index]; |
||||
|
if let Some(tensor) = &output.tensor { |
||||
|
Ok(tensor.clone()) |
||||
|
} else { |
||||
|
Err(BackendError::BackendAccess(anyhow::anyhow!( |
||||
|
"missing output tensor: {}; has `compute` been called?", |
||||
|
output.shape.name |
||||
|
))) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
impl From<ort::Error> for BackendError { |
||||
|
fn from(e: ort::Error) -> Self { |
||||
|
BackendError::BackendAccess(e.into()) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
/// Holds a slot for ONNX session inputs and outputs.
|
||||
|
///
|
||||
|
/// TODO: it seems unfortunate that we have to "hold" some extra data per
|
||||
|
/// session but in the input case, this is necessary for name-based indexing.
|
||||
|
struct TensorSlot { |
||||
|
shape: Shape, |
||||
|
tensor: Option<Tensor>, |
||||
|
} |
||||
|
|
||||
|
/// Describes a tensor in ONNX terms.
|
||||
|
struct Shape { |
||||
|
name: String, |
||||
|
dimensions: Vec<i64>, |
||||
|
ty: TensorType, |
||||
|
} |
||||
|
|
||||
|
impl Shape { |
||||
|
fn from_onnx_input(input: &ort::Input) -> Result<Self, BackendError> { |
||||
|
let name = input.name.clone(); |
||||
|
let (dimensions, ty) = convert_value_type(&input.input_type)?; |
||||
|
Ok(Self { |
||||
|
name, |
||||
|
dimensions, |
||||
|
ty, |
||||
|
}) |
||||
|
} |
||||
|
|
||||
|
fn from_onnx_output(output: &ort::Output) -> Result<Self, BackendError> { |
||||
|
let name = output.name.clone(); |
||||
|
let (dimensions, ty) = convert_value_type(&output.output_type)?; |
||||
|
Ok(Self { |
||||
|
name, |
||||
|
dimensions, |
||||
|
ty, |
||||
|
}) |
||||
|
} |
||||
|
|
||||
|
fn dimensions_as_u32(&self) -> Result<Vec<u32>, BackendError> { |
||||
|
self.dimensions |
||||
|
.iter() |
||||
|
.map(|d| if *d == -1 { Ok(1) } else { convert_i64(d) }) |
||||
|
.collect() |
||||
|
} |
||||
|
|
||||
|
fn matches(&self, tensor: &Tensor) -> anyhow::Result<()> { |
||||
|
if self.dimensions.len() != tensor.dimensions.len() { |
||||
|
return Err(anyhow::anyhow!( |
||||
|
"input tensor cardinality does not match model: {:?} != {:?}", |
||||
|
self.dimensions, |
||||
|
tensor.dimensions |
||||
|
)); |
||||
|
} else { |
||||
|
for (&shape_dim, &tensor_dim) in self.dimensions.iter().zip(tensor.dimensions.iter()) { |
||||
|
let tensor_dim = tensor_dim as i64; |
||||
|
if !is_dynamic_dimension(shape_dim) && shape_dim != tensor_dim { |
||||
|
return Err(anyhow::anyhow!( |
||||
|
"input tensor dimensions do not match model: {:?} != {:?}", |
||||
|
self.dimensions, |
||||
|
tensor.dimensions |
||||
|
)); |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
if self.ty != tensor.ty { |
||||
|
return Err(anyhow::anyhow!( |
||||
|
"input tensor type does not match model: {:?} != {:?}", |
||||
|
self.ty, |
||||
|
tensor.ty |
||||
|
)); |
||||
|
} |
||||
|
Ok(()) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
fn convert_value_type(vt: &ort::ValueType) -> Result<(Vec<i64>, TensorType), BackendError> { |
||||
|
match vt { |
||||
|
ort::ValueType::Tensor { ty, dimensions } => { |
||||
|
let dims = dimensions.clone(); |
||||
|
let ty = (*ty).try_into()?; |
||||
|
Ok((dims, ty)) |
||||
|
} |
||||
|
_ => Err(BackendError::BackendAccess(anyhow::anyhow!( |
||||
|
"unsupported input type: {vt:?}" |
||||
|
))), |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
fn convert_i64(i: &i64) -> Result<u32, BackendError> { |
||||
|
u32::try_from(*i).map_err(|d| -> BackendError { |
||||
|
anyhow::anyhow!("unable to convert dimension to u32: {d}").into() |
||||
|
}) |
||||
|
} |
||||
|
|
||||
|
impl TryFrom<ort::TensorElementType> for TensorType { |
||||
|
type Error = BackendError; |
||||
|
fn try_from(ty: ort::TensorElementType) -> Result<Self, Self::Error> { |
||||
|
match ty { |
||||
|
ort::TensorElementType::Float32 => Ok(TensorType::Fp32), |
||||
|
ort::TensorElementType::Float64 => Ok(TensorType::Fp64), |
||||
|
ort::TensorElementType::Uint8 => Ok(TensorType::U8), |
||||
|
ort::TensorElementType::Int32 => Ok(TensorType::I32), |
||||
|
ort::TensorElementType::Int64 => Ok(TensorType::I64), |
||||
|
_ => Err(BackendError::BackendAccess(anyhow::anyhow!( |
||||
|
"unsupported tensor type: {ty:?}" |
||||
|
))), |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
fn to_input_value(slot: &TensorSlot) -> Result<[ort::SessionInputValue<'_>; 1], BackendError> { |
||||
|
match &slot.tensor { |
||||
|
Some(tensor) => match tensor.ty { |
||||
|
TensorType::Fp32 => { |
||||
|
let data = bytes_to_f32_vec(tensor.data.to_vec()); |
||||
|
let dimensions = tensor |
||||
|
.dimensions |
||||
|
.iter() |
||||
|
.map(|d| *d as i64) // TODO: fewer conversions
|
||||
|
.collect::<Vec<i64>>(); |
||||
|
Ok(inputs![(dimensions, Arc::new(data.into_boxed_slice()))] |
||||
|
.context("failed to create ONNX session input")?) |
||||
|
} |
||||
|
_ => { |
||||
|
unimplemented!("{:?} not supported by ONNX", tensor.ty); |
||||
|
} |
||||
|
}, |
||||
|
None => { |
||||
|
return Err(BackendError::BackendAccess(anyhow::anyhow!( |
||||
|
"missing input tensor: {}", |
||||
|
slot.shape.name |
||||
|
))); |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
pub fn f32_vec_to_bytes(data: Vec<f32>) -> Vec<u8> { |
||||
|
let chunks: Vec<[u8; 4]> = data.into_iter().map(|f| f.to_le_bytes()).collect(); |
||||
|
let result: Vec<u8> = chunks.iter().flatten().copied().collect(); |
||||
|
result |
||||
|
} |
||||
|
|
||||
|
pub fn bytes_to_f32_vec(data: Vec<u8>) -> Vec<f32> { |
||||
|
let chunks: Vec<&[u8]> = data.chunks(4).collect(); |
||||
|
let v: Vec<f32> = chunks |
||||
|
.into_iter() |
||||
|
.map(|c| f32::from_le_bytes(c.try_into().unwrap())) |
||||
|
.collect(); |
||||
|
|
||||
|
v.into_iter().collect() |
||||
|
} |
||||
|
|
||||
|
/// Returns whether the dimension is dynamic.
|
||||
|
///
|
||||
|
/// ONNX uses [dimensional variables] (i.e., name strings) to indicate that the
|
||||
|
/// value of a tensor dimension is user-defined, not fixed by the model. This is
|
||||
|
/// useful for batching up several inference requests, e.g. When `ort` returns a
|
||||
|
/// dimension of this kind, though, it uses `-1` to indicate that the dimension
|
||||
|
/// is dynamic.
|
||||
|
///
|
||||
|
/// [dimensional variables]:
|
||||
|
/// https://onnx.ai/onnx/repo-docs/IR.html#static-tensor-shapes
|
||||
|
fn is_dynamic_dimension(d: i64) -> bool { |
||||
|
d == -1 |
||||
|
} |
@ -1,149 +0,0 @@ |
|||||
//! Implements a `wasi-nn` [`BackendInner`] using ONNX via ort.
|
|
||||
|
|
||||
use super::{BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner}; |
|
||||
use crate::backend::read; |
|
||||
use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType}; |
|
||||
use crate::{ExecutionContext, Graph}; |
|
||||
use ort::{inputs, GraphOptimizationLevel, Session}; |
|
||||
use std::path::Path; |
|
||||
use std::sync::{Arc, Mutex}; |
|
||||
|
|
||||
#[derive(Default)] |
|
||||
pub struct OnnxBackend(); |
|
||||
unsafe impl Send for OnnxBackend {} |
|
||||
unsafe impl Sync for OnnxBackend {} |
|
||||
|
|
||||
impl BackendInner for OnnxBackend { |
|
||||
fn encoding(&self) -> GraphEncoding { |
|
||||
GraphEncoding::Onnx |
|
||||
} |
|
||||
|
|
||||
fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> { |
|
||||
if builders.len() != 1 { |
|
||||
return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()).into()); |
|
||||
} |
|
||||
|
|
||||
let session = Session::builder()? |
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)? |
|
||||
.with_model_from_memory(builders[0])?; |
|
||||
|
|
||||
let box_: Box<dyn BackendGraph> = |
|
||||
Box::new(ONNXGraph(Arc::new(Mutex::new(session)), target)); |
|
||||
Ok(box_.into()) |
|
||||
} |
|
||||
|
|
||||
fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir> { |
|
||||
Some(self) |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
impl BackendFromDir for OnnxBackend { |
|
||||
fn load_from_dir( |
|
||||
&mut self, |
|
||||
path: &Path, |
|
||||
target: ExecutionTarget, |
|
||||
) -> Result<Graph, BackendError> { |
|
||||
let model = read(&path.join("model.onnx"))?; |
|
||||
self.load(&[&model], target) |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
struct ONNXGraph(Arc<Mutex<Session>>, #[allow(dead_code)] ExecutionTarget); |
|
||||
|
|
||||
unsafe impl Send for ONNXGraph {} |
|
||||
unsafe impl Sync for ONNXGraph {} |
|
||||
|
|
||||
impl BackendGraph for ONNXGraph { |
|
||||
fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> { |
|
||||
let session = self.0.lock().unwrap(); |
|
||||
let inputs = session.inputs.iter().map(|_| None).collect::<Vec<_>>(); |
|
||||
let outputs = session.outputs.iter().map(|_| None).collect::<Vec<_>>(); |
|
||||
let box_: Box<dyn BackendExecutionContext> = Box::new(ONNXExecutionContext { |
|
||||
session: self.0.clone(), |
|
||||
inputs, |
|
||||
outputs, |
|
||||
}); |
|
||||
Ok(box_.into()) |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
struct ONNXExecutionContext { |
|
||||
session: Arc<Mutex<Session>>, |
|
||||
inputs: Vec<Option<Tensor>>, |
|
||||
outputs: Vec<Option<Vec<u8>>>, |
|
||||
} |
|
||||
|
|
||||
unsafe impl Send for ONNXExecutionContext {} |
|
||||
unsafe impl Sync for ONNXExecutionContext {} |
|
||||
|
|
||||
impl BackendExecutionContext for ONNXExecutionContext { |
|
||||
fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError> { |
|
||||
self.inputs[index as usize].replace(tensor.clone()); |
|
||||
Ok(()) |
|
||||
} |
|
||||
|
|
||||
fn compute(&mut self) -> Result<(), BackendError> { |
|
||||
let shaped_inputs: Vec<_> = self |
|
||||
.inputs |
|
||||
.iter() |
|
||||
.enumerate() |
|
||||
.map(|(i, _o)| { |
|
||||
let input = self.inputs[i].as_ref().unwrap(); |
|
||||
let dims = input |
|
||||
.dimensions |
|
||||
.as_slice() |
|
||||
.iter() |
|
||||
.map(|d| *d as i64) |
|
||||
.collect::<Vec<_>>(); |
|
||||
match input.tensor_type { |
|
||||
TensorType::Fp32 => { |
|
||||
let data = bytes_to_f32_vec(input.data.to_vec()); |
|
||||
inputs![(dims, Arc::new(data.into_boxed_slice()))].unwrap() |
|
||||
} |
|
||||
_ => { |
|
||||
unimplemented!("{:?} not supported by ONNX", input.tensor_type); |
|
||||
} |
|
||||
} |
|
||||
}) |
|
||||
.flatten() |
|
||||
.collect(); |
|
||||
|
|
||||
let session = self.session.lock().unwrap(); |
|
||||
let res = session.run(shaped_inputs.as_slice())?; |
|
||||
|
|
||||
for i in 0..self.outputs.len() { |
|
||||
let raw: (Vec<i64>, &[f32]) = res[i].extract_raw_tensor()?; |
|
||||
let f32s = raw.1.to_vec(); |
|
||||
self.outputs[i].replace(f32_vec_to_bytes(f32s)); |
|
||||
} |
|
||||
Ok(()) |
|
||||
} |
|
||||
|
|
||||
fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result<u32, BackendError> { |
|
||||
let output = self.outputs[index as usize].as_ref().unwrap(); |
|
||||
destination[..output.len()].copy_from_slice(output); |
|
||||
Ok(output.len() as u32) |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
impl From<ort::Error> for BackendError { |
|
||||
fn from(e: ort::Error) -> Self { |
|
||||
BackendError::BackendAccess(e.into()) |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
pub fn f32_vec_to_bytes(data: Vec<f32>) -> Vec<u8> { |
|
||||
let chunks: Vec<[u8; 4]> = data.into_iter().map(|f| f.to_le_bytes()).collect(); |
|
||||
let result: Vec<u8> = chunks.iter().flatten().copied().collect(); |
|
||||
result |
|
||||
} |
|
||||
|
|
||||
pub fn bytes_to_f32_vec(data: Vec<u8>) -> Vec<f32> { |
|
||||
let chunks: Vec<&[u8]> = data.chunks(4).collect(); |
|
||||
let v: Vec<f32> = chunks |
|
||||
.into_iter() |
|
||||
.map(|c| f32::from_le_bytes(c.try_into().unwrap())) |
|
||||
.collect(); |
|
||||
|
|
||||
v.into_iter().collect() |
|
||||
} |
|
@ -1,146 +0,0 @@ |
|||||
//! Implements the host state for the `wasi-nn` API: [WasiNnCtx].
|
|
||||
|
|
||||
use crate::backend::{self, BackendError}; |
|
||||
use crate::wit::types::GraphEncoding; |
|
||||
use crate::{Backend, ExecutionContext, Graph, InMemoryRegistry, Registry}; |
|
||||
use anyhow::anyhow; |
|
||||
use std::{collections::HashMap, hash::Hash, path::Path}; |
|
||||
use thiserror::Error; |
|
||||
use wiggle::GuestError; |
|
||||
|
|
||||
type GraphId = u32; |
|
||||
type GraphExecutionContextId = u32; |
|
||||
type BackendName = String; |
|
||||
type GraphDirectory = String; |
|
||||
|
|
||||
/// Construct an in-memory registry from the available backends and a list of
|
|
||||
/// `(<backend name>, <graph directory>)`. This assumes graphs can be loaded
|
|
||||
/// from a local directory, which is a safe assumption currently for the current
|
|
||||
/// model types.
|
|
||||
pub fn preload( |
|
||||
preload_graphs: &[(BackendName, GraphDirectory)], |
|
||||
) -> anyhow::Result<(impl IntoIterator<Item = Backend>, Registry)> { |
|
||||
let mut backends = backend::list(); |
|
||||
let mut registry = InMemoryRegistry::new(); |
|
||||
for (kind, path) in preload_graphs { |
|
||||
let kind_ = kind.parse()?; |
|
||||
let backend = backends |
|
||||
.iter_mut() |
|
||||
.find(|b| b.encoding() == kind_) |
|
||||
.ok_or(anyhow!("unsupported backend: {}", kind))? |
|
||||
.as_dir_loadable() |
|
||||
.ok_or(anyhow!("{} does not support directory loading", kind))?; |
|
||||
registry.load(backend, Path::new(path))?; |
|
||||
} |
|
||||
Ok((backends, Registry::from(registry))) |
|
||||
} |
|
||||
|
|
||||
/// Capture the state necessary for calling into the backend ML libraries.
|
|
||||
pub struct WasiNnCtx { |
|
||||
pub(crate) backends: HashMap<GraphEncoding, Backend>, |
|
||||
pub(crate) registry: Registry, |
|
||||
pub(crate) graphs: Table<GraphId, Graph>, |
|
||||
pub(crate) executions: Table<GraphExecutionContextId, ExecutionContext>, |
|
||||
} |
|
||||
|
|
||||
impl WasiNnCtx { |
|
||||
/// Make a new context from the default state.
|
|
||||
pub fn new(backends: impl IntoIterator<Item = Backend>, registry: Registry) -> Self { |
|
||||
let backends = backends.into_iter().map(|b| (b.encoding(), b)).collect(); |
|
||||
Self { |
|
||||
backends, |
|
||||
registry, |
|
||||
graphs: Table::default(), |
|
||||
executions: Table::default(), |
|
||||
} |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
/// Possible errors while interacting with [WasiNnCtx].
|
|
||||
#[derive(Debug, Error)] |
|
||||
pub enum WasiNnError { |
|
||||
#[error("backend error")] |
|
||||
BackendError(#[from] BackendError), |
|
||||
#[error("guest error")] |
|
||||
GuestError(#[from] GuestError), |
|
||||
#[error("usage error")] |
|
||||
UsageError(#[from] UsageError), |
|
||||
} |
|
||||
|
|
||||
#[derive(Debug, Error)] |
|
||||
pub enum UsageError { |
|
||||
#[error("Invalid context; has the load function been called?")] |
|
||||
InvalidContext, |
|
||||
#[error("Only OpenVINO's IR is currently supported, passed encoding: {0:?}")] |
|
||||
InvalidEncoding(GraphEncoding), |
|
||||
#[error("OpenVINO expects only two buffers (i.e. [ir, weights]), passed: {0}")] |
|
||||
InvalidNumberOfBuilders(u32), |
|
||||
#[error("Invalid graph handle; has it been loaded?")] |
|
||||
InvalidGraphHandle, |
|
||||
#[error("Invalid execution context handle; has it been initialized?")] |
|
||||
InvalidExecutionContextHandle, |
|
||||
#[error("Not enough memory to copy tensor data of size: {0}")] |
|
||||
NotEnoughMemory(u32), |
|
||||
#[error("No graph found with name: {0}")] |
|
||||
NotFound(String), |
|
||||
} |
|
||||
|
|
||||
pub(crate) type WasiNnResult<T> = std::result::Result<T, WasiNnError>; |
|
||||
|
|
||||
/// Record handle entries in a table.
|
|
||||
pub struct Table<K, V> { |
|
||||
entries: HashMap<K, V>, |
|
||||
next_key: u32, |
|
||||
} |
|
||||
|
|
||||
impl<K, V> Default for Table<K, V> { |
|
||||
fn default() -> Self { |
|
||||
Self { |
|
||||
entries: HashMap::new(), |
|
||||
next_key: 0, |
|
||||
} |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
impl<K, V> Table<K, V> |
|
||||
where |
|
||||
K: Eq + Hash + From<u32> + Copy, |
|
||||
{ |
|
||||
pub fn insert(&mut self, value: V) -> K { |
|
||||
let key = self.use_next_key(); |
|
||||
self.entries.insert(key, value); |
|
||||
key |
|
||||
} |
|
||||
|
|
||||
pub fn get(&self, key: K) -> Option<&V> { |
|
||||
self.entries.get(&key) |
|
||||
} |
|
||||
|
|
||||
pub fn get_mut(&mut self, key: K) -> Option<&mut V> { |
|
||||
self.entries.get_mut(&key) |
|
||||
} |
|
||||
|
|
||||
fn use_next_key(&mut self) -> K { |
|
||||
let current = self.next_key; |
|
||||
self.next_key += 1; |
|
||||
K::from(current) |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
#[cfg(test)] |
|
||||
mod test { |
|
||||
use super::*; |
|
||||
use crate::registry::GraphRegistry; |
|
||||
|
|
||||
#[test] |
|
||||
fn example() { |
|
||||
struct FakeRegistry; |
|
||||
impl GraphRegistry for FakeRegistry { |
|
||||
fn get_mut(&mut self, _: &str) -> Option<&mut Graph> { |
|
||||
None |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
let _ctx = WasiNnCtx::new([], Registry::from(FakeRegistry)); |
|
||||
} |
|
||||
} |
|
@ -1,52 +1,6 @@ |
|||||
use crate::check::artifacts_dir; |
//! Provide a Wasmtime embedding for executing wasi-nn test programs.
|
||||
use anyhow::Result; |
|
||||
use std::path::Path; |
|
||||
use wasi_common::sync::{Dir, WasiCtxBuilder}; |
|
||||
use wasi_common::WasiCtx; |
|
||||
use wasmtime::{Config, Engine, Linker, Module, Store}; |
|
||||
use wasmtime_wasi_nn::{Backend, InMemoryRegistry, WasiNnCtx}; |
|
||||
|
|
||||
const PREOPENED_DIR_NAME: &str = "fixture"; |
pub mod wit; |
||||
|
pub mod witx; |
||||
|
|
||||
/// Run a wasi-nn test program. This is modeled after
|
pub const PREOPENED_DIR_NAME: &str = "fixture"; |
||||
/// `crates/wasi/tests/all/main.rs` but still uses the older preview1 API
|
|
||||
/// for file reads.
|
|
||||
pub fn run(path: &str, backend: Backend, preload_model: bool) -> Result<()> { |
|
||||
let path = Path::new(path); |
|
||||
let engine = Engine::new(&Config::new())?; |
|
||||
let mut linker = Linker::new(&engine); |
|
||||
wasmtime_wasi_nn::witx::add_to_linker(&mut linker, |s: &mut Ctx| &mut s.wasi_nn)?; |
|
||||
wasi_common::sync::add_to_linker(&mut linker, |s: &mut Ctx| &mut s.wasi)?; |
|
||||
let module = Module::from_file(&engine, path)?; |
|
||||
let mut store = Store::new(&engine, Ctx::new(&artifacts_dir(), preload_model, backend)?); |
|
||||
let instance = linker.instantiate(&mut store, &module)?; |
|
||||
let start = instance.get_typed_func::<(), ()>(&mut store, "_start")?; |
|
||||
start.call(&mut store, ())?; |
|
||||
Ok(()) |
|
||||
} |
|
||||
|
|
||||
/// The host state for running wasi-nn tests.
|
|
||||
struct Ctx { |
|
||||
wasi: WasiCtx, |
|
||||
wasi_nn: WasiNnCtx, |
|
||||
} |
|
||||
|
|
||||
impl Ctx { |
|
||||
fn new(preopen_dir: &Path, preload_model: bool, mut backend: Backend) -> Result<Self> { |
|
||||
let preopen_dir = Dir::open_ambient_dir(preopen_dir, cap_std::ambient_authority())?; |
|
||||
let mut builder = WasiCtxBuilder::new(); |
|
||||
builder |
|
||||
.inherit_stdio() |
|
||||
.preopened_dir(preopen_dir, PREOPENED_DIR_NAME)?; |
|
||||
let wasi = builder.build(); |
|
||||
|
|
||||
let mut registry = InMemoryRegistry::new(); |
|
||||
let mobilenet_dir = artifacts_dir(); |
|
||||
if preload_model { |
|
||||
registry.load((backend).as_dir_loadable().unwrap(), &mobilenet_dir)?; |
|
||||
} |
|
||||
let wasi_nn = WasiNnCtx::new([backend.into()], registry.into()); |
|
||||
|
|
||||
Ok(Self { wasi, wasi_nn }) |
|
||||
} |
|
||||
} |
|
||||
|
@ -0,0 +1,73 @@ |
|||||
|
use super::PREOPENED_DIR_NAME; |
||||
|
use crate::check::artifacts_dir; |
||||
|
use anyhow::{anyhow, Result}; |
||||
|
use std::path::Path; |
||||
|
use wasmtime::component::{Component, Linker, ResourceTable}; |
||||
|
use wasmtime::{Config, Engine, Store}; |
||||
|
use wasmtime_wasi::bindings::sync::Command; |
||||
|
use wasmtime_wasi::{DirPerms, FilePerms, WasiCtx, WasiCtxBuilder}; |
||||
|
use wasmtime_wasi_nn::wit::WasiNnView; |
||||
|
use wasmtime_wasi_nn::{wit::WasiNnCtx, Backend, InMemoryRegistry}; |
||||
|
|
||||
|
/// Run a wasi-nn test program. This is modeled after
|
||||
|
/// `crates/wasi/tests/all/main.rs` but still uses the older preview1 API for
|
||||
|
/// file reads.
|
||||
|
pub fn run(path: &str, backend: Backend, preload_model: bool) -> Result<()> { |
||||
|
let path = Path::new(path); |
||||
|
let engine = Engine::new(&Config::new())?; |
||||
|
let mut linker = Linker::new(&engine); |
||||
|
wasmtime_wasi_nn::wit::add_to_linker(&mut linker, |c: &mut Ctx| { |
||||
|
WasiNnView::new(&mut c.table, &mut c.wasi_nn) |
||||
|
})?; |
||||
|
wasmtime_wasi::add_to_linker_sync(&mut linker)?; |
||||
|
let module = Component::from_file(&engine, path)?; |
||||
|
let mut store = Store::new(&engine, Ctx::new(&artifacts_dir(), preload_model, backend)?); |
||||
|
let command = Command::instantiate(&mut store, &module, &linker)?; |
||||
|
let result = command.wasi_cli_run().call_run(&mut store)?; |
||||
|
result.map_err(|_| anyhow!("failed to run command")) |
||||
|
} |
||||
|
|
||||
|
/// The host state for running wasi-nn component tests.
|
||||
|
struct Ctx { |
||||
|
wasi: WasiCtx, |
||||
|
wasi_nn: WasiNnCtx, |
||||
|
table: ResourceTable, |
||||
|
} |
||||
|
|
||||
|
impl Ctx { |
||||
|
fn new(preopen_dir: &Path, preload_model: bool, mut backend: Backend) -> Result<Self> { |
||||
|
let mut builder = WasiCtxBuilder::new(); |
||||
|
builder.inherit_stdio().preopened_dir( |
||||
|
preopen_dir, |
||||
|
PREOPENED_DIR_NAME, |
||||
|
DirPerms::READ, |
||||
|
FilePerms::READ, |
||||
|
)?; |
||||
|
let wasi = builder.build(); |
||||
|
|
||||
|
let mut registry = InMemoryRegistry::new(); |
||||
|
let mobilenet_dir = artifacts_dir(); |
||||
|
if preload_model { |
||||
|
registry.load((backend).as_dir_loadable().unwrap(), &mobilenet_dir)?; |
||||
|
} |
||||
|
let wasi_nn = WasiNnCtx::new([backend.into()], registry.into()); |
||||
|
|
||||
|
let table = ResourceTable::new(); |
||||
|
|
||||
|
Ok(Self { |
||||
|
wasi, |
||||
|
wasi_nn, |
||||
|
table, |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
impl wasmtime_wasi::WasiView for Ctx { |
||||
|
fn ctx(&mut self) -> &mut WasiCtx { |
||||
|
&mut self.wasi |
||||
|
} |
||||
|
|
||||
|
fn table(&mut self) -> &mut ResourceTable { |
||||
|
&mut self.table |
||||
|
} |
||||
|
} |
@ -0,0 +1,52 @@ |
|||||
|
use super::PREOPENED_DIR_NAME; |
||||
|
use crate::check::artifacts_dir; |
||||
|
use anyhow::Result; |
||||
|
use std::path::Path; |
||||
|
use wasmtime::{Config, Engine, Linker, Module, Store}; |
||||
|
use wasmtime_wasi::{preview1::WasiP1Ctx, DirPerms, FilePerms, WasiCtxBuilder}; |
||||
|
use wasmtime_wasi_nn::{witx::WasiNnCtx, Backend, InMemoryRegistry}; |
||||
|
|
||||
|
/// Run a wasi-nn test program. This is modeled after
|
||||
|
/// `crates/wasi/tests/all/main.rs` but still uses the older preview1 API
|
||||
|
/// for file reads.
|
||||
|
pub fn run(path: &str, backend: Backend, preload_model: bool) -> Result<()> { |
||||
|
let path = Path::new(path); |
||||
|
let engine = Engine::new(&Config::new())?; |
||||
|
let mut linker = Linker::new(&engine); |
||||
|
wasmtime_wasi_nn::witx::add_to_linker(&mut linker, |s: &mut Ctx| &mut s.wasi_nn)?; |
||||
|
wasmtime_wasi::preview1::add_to_linker_sync(&mut linker, |s: &mut Ctx| &mut s.wasi)?; |
||||
|
let module = Module::from_file(&engine, path)?; |
||||
|
let mut store = Store::new(&engine, Ctx::new(&artifacts_dir(), preload_model, backend)?); |
||||
|
let instance = linker.instantiate(&mut store, &module)?; |
||||
|
let start = instance.get_typed_func::<(), ()>(&mut store, "_start")?; |
||||
|
start.call(&mut store, ())?; |
||||
|
Ok(()) |
||||
|
} |
||||
|
|
||||
|
/// The host state for running wasi-nn tests.
|
||||
|
struct Ctx { |
||||
|
wasi: WasiP1Ctx, |
||||
|
wasi_nn: WasiNnCtx, |
||||
|
} |
||||
|
|
||||
|
impl Ctx { |
||||
|
fn new(preopen_dir: &Path, preload_model: bool, mut backend: Backend) -> Result<Self> { |
||||
|
let mut builder = WasiCtxBuilder::new(); |
||||
|
builder.inherit_stdio().preopened_dir( |
||||
|
preopen_dir, |
||||
|
PREOPENED_DIR_NAME, |
||||
|
DirPerms::READ, |
||||
|
FilePerms::READ, |
||||
|
)?; |
||||
|
let wasi = builder.build_p1(); |
||||
|
|
||||
|
let mut registry = InMemoryRegistry::new(); |
||||
|
let mobilenet_dir = artifacts_dir(); |
||||
|
if preload_model { |
||||
|
registry.load((backend).as_dir_loadable().unwrap(), &mobilenet_dir)?; |
||||
|
} |
||||
|
let wasi_nn = WasiNnCtx::new([backend.into()], registry.into()); |
||||
|
|
||||
|
Ok(Self { wasi, wasi_nn }) |
||||
|
} |
||||
|
} |
Loading…
Reference in new issue