diff --git a/crates/wasi-nn/src/backend/winml.rs b/crates/wasi-nn/src/backend/winml.rs index 5940b901e6..e11761f867 100644 --- a/crates/wasi-nn/src/backend/winml.rs +++ b/crates/wasi-nn/src/backend/winml.rs @@ -3,7 +3,7 @@ use super::{BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner}; use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor}; use crate::{ExecutionContext, Graph}; -use std::{fs::File, io::Read, path::Path}; +use std::{fs::File, io::Read, mem::size_of, path::Path}; use windows::core::{ComInterface, HSTRING}; use windows::Storage::Streams::{ DataWriter, InMemoryRandomAccessStream, RandomAccessStreamReference, @@ -134,7 +134,6 @@ impl BackendExecutionContext for WinMLExecutionContext { let output_name = self.session.Model()?.OutputFeatures()?.GetAt(index)?; let output_name_hstring = output_name.Name()?; - // Print results. let vector_view = self .result .as_ref() @@ -144,14 +143,15 @@ impl BackendExecutionContext for WinMLExecutionContext { .cast::()? .GetAsVectorView()?; let output: Vec = vector_view.into_iter().collect(); + let len_to_copy = output.len() * size_of::(); unsafe { - destination.copy_from_slice(std::slice::from_raw_parts( + destination[..len_to_copy].copy_from_slice(std::slice::from_raw_parts( output.as_ptr() as *const u8, - output.len() * 4, + len_to_copy, )); } - Ok(0) + Ok(len_to_copy as u32) } }