Browse Source

Wasi-nn WinML backend returns output size for get_output. (#8745)

This change fixes an issue that get_output always returns 0.
pull/8693/head
jianjunz 5 months ago
committed by GitHub
parent
commit
ca405bb023
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 10
      crates/wasi-nn/src/backend/winml.rs

10
crates/wasi-nn/src/backend/winml.rs

@ -3,7 +3,7 @@
use super::{BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner};
use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor};
use crate::{ExecutionContext, Graph};
use std::{fs::File, io::Read, path::Path};
use std::{fs::File, io::Read, mem::size_of, path::Path};
use windows::core::{ComInterface, HSTRING};
use windows::Storage::Streams::{
DataWriter, InMemoryRandomAccessStream, RandomAccessStreamReference,
@ -134,7 +134,6 @@ impl BackendExecutionContext for WinMLExecutionContext {
let output_name = self.session.Model()?.OutputFeatures()?.GetAt(index)?;
let output_name_hstring = output_name.Name()?;
// Print results.
let vector_view = self
.result
.as_ref()
@ -144,14 +143,15 @@ impl BackendExecutionContext for WinMLExecutionContext {
.cast::<TensorFloat>()?
.GetAsVectorView()?;
let output: Vec<f32> = vector_view.into_iter().collect();
let len_to_copy = output.len() * size_of::<f32>();
unsafe {
destination.copy_from_slice(std::slice::from_raw_parts(
destination[..len_to_copy].copy_from_slice(std::slice::from_raw_parts(
output.as_ptr() as *const u8,
output.len() * 4,
len_to_copy,
));
}
Ok(0)
Ok(len_to_copy as u32)
}
}

Loading…
Cancel
Save