|
|
@ -3,7 +3,7 @@ |
|
|
|
use super::{BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner}; |
|
|
|
use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor}; |
|
|
|
use crate::{ExecutionContext, Graph}; |
|
|
|
use std::{fs::File, io::Read, path::Path}; |
|
|
|
use std::{fs::File, io::Read, mem::size_of, path::Path}; |
|
|
|
use windows::core::{ComInterface, HSTRING}; |
|
|
|
use windows::Storage::Streams::{ |
|
|
|
DataWriter, InMemoryRandomAccessStream, RandomAccessStreamReference, |
|
|
@ -134,7 +134,6 @@ impl BackendExecutionContext for WinMLExecutionContext { |
|
|
|
let output_name = self.session.Model()?.OutputFeatures()?.GetAt(index)?; |
|
|
|
let output_name_hstring = output_name.Name()?; |
|
|
|
|
|
|
|
// Print results.
|
|
|
|
let vector_view = self |
|
|
|
.result |
|
|
|
.as_ref() |
|
|
@ -144,14 +143,15 @@ impl BackendExecutionContext for WinMLExecutionContext { |
|
|
|
.cast::<TensorFloat>()? |
|
|
|
.GetAsVectorView()?; |
|
|
|
let output: Vec<f32> = vector_view.into_iter().collect(); |
|
|
|
let len_to_copy = output.len() * size_of::<f32>(); |
|
|
|
unsafe { |
|
|
|
destination.copy_from_slice(std::slice::from_raw_parts( |
|
|
|
destination[..len_to_copy].copy_from_slice(std::slice::from_raw_parts( |
|
|
|
output.as_ptr() as *const u8, |
|
|
|
output.len() * 4, |
|
|
|
len_to_copy, |
|
|
|
)); |
|
|
|
} |
|
|
|
|
|
|
|
Ok(0) |
|
|
|
Ok(len_to_copy as u32) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|