Browse Source

Add FP16 and I64 support for wasi-nn WinML backend. (#8964)

* Add FP16 and I64 support for wasi-nn WinML backend.

Some devices may not support FP32.

prtest:full

* Remove unnecessary features.

* Address comments.

* Check alignment before from_raw_parts.

* Implement PartialEq for Tensor.

* Remove duplicated shape info from set_input.

* Update alignment checker.

* Add comments about creating TensorFloat16Bit from f32 array.

* Use PartialEq attribute.

* Audit new WinML dependencies

---------

Co-authored-by: Andrew Brown <andrew.brown@intel.com>
pull/9100/head
jianjunz 3 months ago
committed by GitHub
parent
commit
6907868078
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 76
      Cargo.lock
  2. 8
      crates/wasi-nn/Cargo.toml
  3. 240
      crates/wasi-nn/src/backend/winml.rs
  4. 2
      crates/wasi-nn/src/lib.rs
  5. 35
      supply-chain/audits.toml
  6. 36
      supply-chain/imports.lock

76
Cargo.lock

@ -161,6 +161,12 @@ version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4a4ddaa51a5bc52a6948f74c06d20aaaddb71924eab79b8c97a8c556e942d6a"
[[package]]
name = "base64"
version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]]
name = "bit-set"
version = "0.5.2"
@ -1187,9 +1193,9 @@ checksum = "cda653ca797810c02f7ca4b804b40b8b95ae046eb989d356bce17919a8c25499"
[[package]]
name = "flate2"
version = "1.0.28"
version = "1.0.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46303f565772937ffe1d394a4fac6f411c6013172fadde9dcdb1e147a086940e"
checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae"
dependencies = [
"crc32fast",
"miniz_oxide",
@ -2387,6 +2393,21 @@ dependencies = [
"zeroize",
]
[[package]]
name = "rustls"
version = "0.23.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebbbdb961df0ad3f2652da8f3fdc4b36122f568f968f45ad3316f26c025c677b"
dependencies = [
"log",
"once_cell",
"ring",
"rustls-pki-types",
"rustls-webpki",
"subtle",
"zeroize",
]
[[package]]
name = "rustls-pki-types"
version = "1.3.1"
@ -2682,9 +2703,9 @@ dependencies = [
[[package]]
name = "tar"
version = "0.4.40"
version = "0.4.41"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b16afcea1f22891c49a00c751c7b63b2233284064f11a200fc624137c51e2ddb"
checksum = "cb797dad5fb5b76fcf519e702f4a589483b5ef06567f160c392832c1f5e44909"
dependencies = [
"filetime",
"libc",
@ -2745,7 +2766,7 @@ name = "test-programs"
version = "0.0.0"
dependencies = [
"anyhow",
"base64",
"base64 0.21.0",
"futures",
"getrandom",
"libc",
@ -2864,7 +2885,7 @@ version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f"
dependencies = [
"rustls",
"rustls 0.22.4",
"rustls-pki-types",
"tokio",
]
@ -3032,16 +3053,15 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]]
name = "ureq"
version = "2.9.6"
version = "2.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11f214ce18d8b2cbe84ed3aa6486ed3f5b285cf8d8fbdbce9f3f767a724adc35"
checksum = "72139d247e5f97a3eff96229a7ae85ead5328a39efe76f8bf5a06313d505b6ea"
dependencies = [
"base64",
"base64 0.22.1",
"log",
"once_cell",
"rustls",
"rustls 0.23.7",
"rustls-pki-types",
"rustls-webpki",
"url",
"webpki-roots",
]
@ -3492,7 +3512,7 @@ name = "wasmtime-cache"
version = "25.0.0"
dependencies = [
"anyhow",
"base64",
"base64 0.21.0",
"directories-next",
"filetime",
"log",
@ -3859,14 +3879,14 @@ version = "25.0.0"
dependencies = [
"anyhow",
"async-trait",
"base64",
"base64 0.21.0",
"bytes",
"futures",
"http",
"http-body",
"http-body-util",
"hyper",
"rustls",
"rustls 0.22.4",
"sha2",
"test-log",
"test-programs-artifacts",
@ -4141,6 +4161,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be"
dependencies = [
"windows-core",
"windows-implement",
"windows-interface",
"windows-targets 0.52.0",
]
@ -4153,6 +4175,28 @@ dependencies = [
"windows-targets 0.52.0",
]
[[package]]
name = "windows-implement"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12168c33176773b86799be25e2a2ba07c7aab9968b37541f1094dbd7a60c8946"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.60",
]
[[package]]
name = "windows-interface"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d8dc32e0095a7eeccebd0e3f09e9509365ecb3fc6ac4d6f5f14a3f6392942d1"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.60",
]
[[package]]
name = "windows-sys"
version = "0.48.0"
@ -4416,9 +4460,9 @@ dependencies = [
[[package]]
name = "xattr"
version = "1.2.0"
version = "1.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "914566e6413e7fa959cc394fb30e563ba80f3541fbd40816d4c05a0fc3f2a0f1"
checksum = "8da84f1a25939b27f6820d92aed108f83ff920fdf11a7b19366c27c4cda81d4f"
dependencies = [
"libc",
"linux-raw-sys",

8
crates/wasi-nn/Cargo.toml

@ -40,7 +40,13 @@ ort = { version = "2.0.0-rc.2", default-features = false, features = [
[target.'cfg(windows)'.dependencies.windows]
version = "0.52"
features = ["AI_MachineLearning", "Storage_Streams", "Foundation_Collections"]
features = [
"AI_MachineLearning",
"Storage_Streams",
"Foundation_Collections",
# For getting IVectorView<i64> from tensor.dimensions.
"implement",
]
optional = true
[build-dependencies]

240
crates/wasi-nn/src/backend/winml.rs

@ -13,7 +13,7 @@ use crate::backend::{
use crate::wit::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
use crate::{ExecutionContext, Graph};
use std::{fs::File, io::Read, mem::size_of, path::Path};
use windows::core::{ComInterface, HSTRING};
use windows::core::{ComInterface, Error, IInspectable, HSTRING};
use windows::Foundation::Collections::IVectorView;
use windows::Storage::Streams::{
DataWriter, InMemoryRandomAccessStream, RandomAccessStreamReference,
@ -21,7 +21,7 @@ use windows::Storage::Streams::{
use windows::AI::MachineLearning::{
ILearningModelFeatureDescriptor, LearningModel, LearningModelBinding, LearningModelDevice,
LearningModelDeviceKind, LearningModelEvaluationResult, LearningModelSession,
TensorFeatureDescriptor, TensorFloat,
TensorFeatureDescriptor, TensorFloat, TensorFloat16Bit, TensorInt64Bit, TensorKind,
};
#[derive(Default)]
@ -136,32 +136,14 @@ impl WinMLExecutionContext {
impl BackendExecutionContext for WinMLExecutionContext {
fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> {
// TODO: Clear previous bindings when needed.
let input_features = self.session.Model()?.InputFeatures()?;
let index = self.find(id, &input_features)?;
let input = input_features.GetAt(index)?;
// TODO: Support other tensor types. Only FP32 is supported right now.
match tensor.ty {
crate::wit::types::TensorType::Fp32 => {}
_ => unimplemented!(),
}
// TODO: this is quite unsafe and probably incorrect--will the slice
// still be around by the time the binding is used?!
let data = unsafe {
std::slice::from_raw_parts(
tensor.data.as_ptr() as *const f32,
tensor.data.len() / size_of::<f32>(),
)
};
self.binding.Bind(
&input.Name()?,
&TensorFloat::CreateFromArray(
&input.cast::<TensorFeatureDescriptor>()?.Shape()?,
data,
)?,
)?;
let inspectable = to_inspectable(tensor)?;
self.binding.Bind(&input.Name()?, &inspectable)?;
Ok(())
}
@ -175,23 +157,21 @@ impl BackendExecutionContext for WinMLExecutionContext {
if let Some(result) = &self.result {
let output_features = self.session.Model()?.OutputFeatures()?;
let index = self.find(id, &output_features)?;
let output = output_features.GetAt(index)?;
// TODO: this only handles FP32!
let tensor = result
.Outputs()?
.Lookup(&output.Name()?)?
.cast::<TensorFloat>()?;
let dimensions = dimensions_as_u32(&tensor.Shape()?)?;
let view = tensor.GetAsVectorView()?;
let mut data = Vec::with_capacity(view.Size()? as usize * size_of::<f32>());
for f in view.into_iter() {
data.extend(f.to_le_bytes());
}
Ok(Tensor {
ty: TensorType::Fp32,
dimensions,
data,
})
let output_feature = output_features.GetAt(index)?;
let tensor_kind = match output_feature.Kind()? {
windows::AI::MachineLearning::LearningModelFeatureKind::Tensor => output_feature
.cast::<TensorFeatureDescriptor>()?
.TensorKind()?,
_ => unimplemented!(
"the WinML backend only supports tensors, found: {:?}",
output_feature.Kind()
),
};
let tensor = to_tensor(
result.Outputs()?.Lookup(&output_feature.Name()?)?,
tensor_kind,
);
tensor
} else {
return Err(BackendError::BackendAccess(anyhow::Error::msg(
"Output is not ready.",
@ -226,3 +206,181 @@ fn convert_i64(i: i64) -> Result<u32, BackendError> {
anyhow::anyhow!("unable to convert dimension to u32: {d}").into()
})
}
// Convert from wasi-nn tensor to WinML tensor.
fn to_inspectable(tensor: &Tensor) -> Result<IInspectable, Error> {
let shape = IVectorView::<i64>::try_from(
tensor
.dimensions
.iter()
.map(|&x| x as i64)
.collect::<Vec<i64>>(),
)?;
match tensor.ty {
// f16 is not official supported by stable version of Rust. https://github.com/rust-lang/rust/issues/116909
// Therefore we create TensorFloat16Bit from f32 array. https://microsoft.github.io/windows-docs-rs/doc/windows/AI/MachineLearning/struct.TensorFloat16Bit.html#method.CreateFromArray
TensorType::Fp16 => unsafe {
let data = std::slice::from_raw_parts(
tensor.data.as_ptr().cast::<f32>(),
tensor.data.len() / size_of::<f32>(),
);
check_alignment::<f32>(data);
TensorFloat16Bit::CreateFromArray(&shape, data)?.cast::<IInspectable>()
},
TensorType::Fp32 => unsafe {
let data = std::slice::from_raw_parts(
tensor.data.as_ptr().cast::<f32>(),
tensor.data.len() / size_of::<f32>(),
);
check_alignment::<f32>(data);
TensorFloat::CreateFromArray(&shape, data)?.cast::<IInspectable>()
},
TensorType::I64 => unsafe {
let data = std::slice::from_raw_parts(
tensor.data.as_ptr().cast::<i64>(),
tensor.data.len() / size_of::<i64>(),
);
check_alignment::<i64>(data);
TensorInt64Bit::CreateFromArray(&shape, data)?.cast::<IInspectable>()
},
_ => unimplemented!(),
}
}
// Convert from WinML tensor to wasi-nn tensor.
fn to_tensor(inspectable: IInspectable, tensor_kind: TensorKind) -> Result<Tensor, BackendError> {
let tensor = match tensor_kind {
TensorKind::Float16 => {
let output_tensor = inspectable.cast::<TensorFloat16Bit>()?;
let dimensions = dimensions_as_u32(&output_tensor.Shape()?)?;
let view = output_tensor.GetAsVectorView()?;
// TODO: Move to f16 when it's available in stable.
let data = view.into_iter().flat_map(f32::to_le_bytes).collect();
Tensor {
ty: TensorType::Fp16,
dimensions,
data,
}
}
TensorKind::Float => {
let output_tensor = inspectable.cast::<TensorFloat>()?;
let dimensions = dimensions_as_u32(&output_tensor.Shape()?)?;
let view = output_tensor.GetAsVectorView()?;
let data = view.into_iter().flat_map(f32::to_le_bytes).collect();
Tensor {
ty: TensorType::Fp32,
dimensions,
data,
}
}
TensorKind::Int64 => {
let output_tensor = inspectable.cast::<TensorInt64Bit>()?;
let dimensions = dimensions_as_u32(&output_tensor.Shape()?)?;
let view = output_tensor.GetAsVectorView()?;
let data = view.into_iter().flat_map(i64::to_le_bytes).collect();
Tensor {
ty: TensorType::I64,
dimensions,
data,
}
}
_ => unimplemented!(),
};
Ok(tensor)
}
fn check_alignment<T>(data: &[T]) {
let (prefix, _slice, suffix) = unsafe { data.align_to::<T>() };
assert!(
prefix.is_empty() && suffix.is_empty(),
"Data is not aligned to {:?}'s alignment",
std::any::type_name::<T>()
);
}
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trip tests, one per supported element type: build a wasi-nn
    // tensor, convert it to a WinML tensor, check the elements WinML sees,
    // then convert back and compare with the tensor we started from.

    #[test]
    fn fp16() {
        let expected = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor {
            ty: TensorType::Fp16,
            dimensions: vec![2, 3],
            data: expected.iter().flat_map(|v| v.to_ne_bytes()).collect(),
        };

        let inspectable = to_inspectable(&tensor).expect("conversion to WinML should succeed");
        let elements = inspectable
            .cast::<TensorFloat16Bit>()
            .unwrap()
            .GetAsVectorView()
            .unwrap()
            .into_iter()
            .collect::<Vec<f32>>();
        assert_eq!(elements, expected);

        // Convert back.
        let round_trip = to_tensor(inspectable, TensorKind::Float16)
            .expect("conversion back to wasi-nn should succeed");
        assert_eq!(round_trip, tensor);
    }

    #[test]
    fn fp32() {
        let expected = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor {
            ty: TensorType::Fp32,
            dimensions: vec![2, 3],
            data: expected.iter().flat_map(|v| v.to_ne_bytes()).collect(),
        };

        let inspectable = to_inspectable(&tensor).expect("conversion to WinML should succeed");
        let elements = inspectable
            .cast::<TensorFloat>()
            .unwrap()
            .GetAsVectorView()
            .unwrap()
            .into_iter()
            .collect::<Vec<f32>>();
        assert_eq!(elements, expected);

        // Convert back.
        let round_trip = to_tensor(inspectable, TensorKind::Float)
            .expect("conversion back to wasi-nn should succeed");
        assert_eq!(round_trip, tensor);
    }

    #[test]
    fn i64() {
        let expected = vec![6i64, 5, 4, 3, 2, 1];
        let tensor = Tensor {
            ty: TensorType::I64,
            dimensions: vec![1, 6],
            data: expected.iter().flat_map(|v| v.to_ne_bytes()).collect(),
        };

        let inspectable = to_inspectable(&tensor).expect("conversion to WinML should succeed");
        let elements = inspectable
            .cast::<TensorInt64Bit>()
            .unwrap()
            .GetAsVectorView()
            .unwrap()
            .into_iter()
            .collect::<Vec<i64>>();
        assert_eq!(elements, expected);

        // Convert back.
        let round_trip = to_tensor(inspectable, TensorKind::Int64)
            .expect("conversion back to wasi-nn should succeed");
        assert_eq!(round_trip, tensor);
    }
}

2
crates/wasi-nn/src/lib.rs

@ -68,7 +68,7 @@ impl std::ops::Deref for Graph {
/// Eventually, this may be defined in each backend as they gain the ability to
/// hold tensors on various devices (TODO:
/// https://github.com/WebAssembly/wasi-nn/pull/70).
#[derive(Clone)]
#[derive(Clone, PartialEq)]
pub struct Tensor {
dimensions: Vec<u32>,
ty: wit::TensorType,

35
supply-chain/audits.toml

@ -995,6 +995,11 @@ criteria = "safe-to-run"
version = "0.21.0"
notes = "This crate has no dependencies, no build.rs, and contains no unsafe code."
[[audits.base64]]
who = "Andrew Brown <andrew.brown@intel.com>"
criteria = "safe-to-deploy"
delta = "0.21.3 -> 0.22.1"
[[audits.bitflags]]
who = "Jamey Sharp <jsharp@fastly.com>"
criteria = "safe-to-deploy"
@ -2408,6 +2413,12 @@ who = "Pat Hickey <phickey@fastly.com>"
criteria = "safe-to-deploy"
delta = "0.21.0 -> 0.21.6"
[[audits.rustls]]
who = "Andrew Brown <andrew.brown@intel.com>"
criteria = "safe-to-deploy"
delta = "0.22.4 -> 0.23.7"
notes = "No new unsafe code."
[[audits.rustls-webpki]]
who = "Pat Hickey <phickey@fastly.com>"
criteria = "safe-to-deploy"
@ -2727,6 +2738,12 @@ criteria = "safe-to-deploy"
version = "2.9.1"
notes = "As advertised, the crate is a blocking HTTP client library; it uses no `unsafe`. Security-conscious users might want to audit its dependencies for crypto-related functionality (e.g., TLS)."
[[audits.ureq]]
who = "Andrew Brown <andrew.brown@intel.com>"
criteria = "safe-to-deploy"
delta = "2.9.6 -> 2.10.0"
notes = "No `unsafe` changes; this audit observed mainly license and documentation changes."
[[audits.url]]
who = "Alex Crichton <alex@alexcrichton.com>"
criteria = "safe-to-deploy"
@ -3419,6 +3436,12 @@ criteria = "safe-to-run"
delta = "4.4.0 -> 5.0.0"
notes = "Only one `unsafe` block, it's what a `which` crate is expected to be."
[[audits.windows-implement]]
who = "Andrew Brown <andrew.brown@intel.com>"
criteria = "safe-to-deploy"
version = "0.52.0"
notes = "Procedural macros for accessing COM interfaces; necessarily `unsafe` but safety rationale is clearly documented."
[[audits.winx]]
who = "Dan Gohman <dev@sunfishcode.online>"
criteria = "safe-to-deploy"
@ -3491,6 +3514,12 @@ criteria = "safe-to-deploy"
version = "1.2.0"
notes = "This crate contains `unsafe` calls to libc `extattr_*` functions as one would expect from the crate's purpose."
[[audits.xattr]]
who = "Andrew Brown <andrew.brown@intel.com>"
criteria = "safe-to-deploy"
delta = "1.2.0 -> 1.3.1"
notes = "Minor changes to MacOS-specific code."
[[audits.zstd]]
who = "Alex Crichton <alex@alexcrichton.com>"
criteria = "safe-to-deploy"
@ -4002,6 +4031,12 @@ user-id = 64539 # Kenny Kerr (kennykerr)
start = "2021-11-15"
end = "2025-01-02"
[[trusted.windows-interface]]
criteria = "safe-to-deploy"
user-id = 64539 # Kenny Kerr (kennykerr)
start = "2022-02-18"
end = "2025-08-07"
[[trusted.windows-sys]]
criteria = "safe-to-deploy"
user-id = 64539 # Kenny Kerr (kennykerr)

36
supply-chain/imports.lock

@ -949,6 +949,13 @@ user-id = 1
user-login = "alexcrichton"
user-name = "Alex Crichton"
[[publisher.tar]]
version = "0.4.41"
when = "2024-06-04"
user-id = 1
user-login = "alexcrichton"
user-name = "Alex Crichton"
[[publisher.target-lexicon]]
version = "0.12.16"
when = "2024-07-30"
@ -1322,6 +1329,13 @@ user-id = 64539
user-login = "kennykerr"
user-name = "Kenny Kerr"
[[publisher.windows-interface]]
version = "0.52.0"
when = "2023-11-15"
user-id = 64539
user-login = "kennykerr"
user-name = "Kenny Kerr"
[[publisher.windows-sys]]
version = "0.48.0"
when = "2023-03-31"
@ -1676,6 +1690,21 @@ criteria = "safe-to-deploy"
version = "0.9.4"
aggregated-from = "https://chromium.googlesource.com/chromiumos/third_party/rust_crates/+/main/cargo-vet/audits.toml?format=TEXT"
[[audits.isrg.audits.base64]]
who = "Tim Geoghegan <timg@letsencrypt.org>"
criteria = "safe-to-deploy"
delta = "0.21.0 -> 0.21.1"
[[audits.isrg.audits.base64]]
who = "Brandon Pitman <bran@bran.land>"
criteria = "safe-to-deploy"
delta = "0.21.1 -> 0.21.2"
[[audits.isrg.audits.base64]]
who = "David Cook <dcook@divviup.org>"
criteria = "safe-to-deploy"
delta = "0.21.2 -> 0.21.3"
[[audits.isrg.audits.block-buffer]]
who = "David Cook <dcook@divviup.org>"
criteria = "safe-to-deploy"
@ -1886,6 +1915,13 @@ criteria = "safe-to-deploy"
delta = "1.9.0 -> 2.0.0"
aggregated-from = "https://hg.mozilla.org/mozilla-central/raw-file/tip/supply-chain/audits.toml"
[[audits.mozilla.audits.flate2]]
who = "Alex Franchuk <afranchuk@mozilla.com>"
criteria = "safe-to-deploy"
delta = "1.0.28 -> 1.0.30"
notes = "Some new unsafe code, however it has been verified and there are unit tests as well."
aggregated-from = "https://hg.mozilla.org/mozilla-central/raw-file/tip/supply-chain/audits.toml"
[[audits.mozilla.audits.fnv]]
who = "Bobby Holley <bobbyholley@gmail.com>"
criteria = "safe-to-deploy"

Loading…
Cancel
Save