createTensorWithDataList static method

OrtValueTensor createTensorWithDataList(
  List data, [
  List<int>? shape
])

Implementation

/// Creates an [OrtValueTensor] whose contents are a native copy of [data].
///
/// [data] is flattened and copied into a freshly `calloc`-allocated buffer.
/// The ONNX element type is chosen from the runtime type of
/// `data.element()`:
///
/// * `Uint8List`, `Int8List`, `Uint16List`, `Int16List`, `Uint32List`,
///   `Int32List` and `Uint64List` map to the matching integer type.
/// * `Int64List` or `int` map to `int64`.
/// * `Float32List` maps to `float`.
/// * `Float64List` or `double` map to `double`.
/// * `bool` maps to `bool`.
/// * `String` is delegated to `_createTensorWithStringList`.
///
/// [shape] defaults to `data.shape` when omitted.
///
/// On success the data buffer is passed to the returned [OrtValueTensor].
/// It is presumably released together with it (TODO confirm against
/// `OrtValueTensor.release`). If tensor creation fails, the data buffer and
/// every temporary native allocation are freed before the error propagates.
///
/// Throws an [Exception] when the element type is not supported. Failures
/// reported by ONNX Runtime surface through `OrtStatus.checkOrtStatus`.
static OrtValueTensor createTensorWithDataList(List data,
    [List<int>? shape]) {
  shape ??= data.shape;
  final element = data.element();
  var dataType = ONNXTensorElementDataType.undefined;
  ffi.Pointer<ffi.Void> dataPtr = ffi.nullptr;
  int dataSize = 0;
  int dataByteCount = 0;
  if (element is Uint8List) {
    final flattenData = data.flatten<int>();
    dataSize = flattenData.length;
    dataType = ONNXTensorElementDataType.uint8;
    dataPtr = (calloc<ffi.Uint8>(dataSize)
          ..asTypedList(dataSize).setRange(0, dataSize, flattenData))
        .cast();
    dataByteCount = dataSize;
  } else if (element is Int8List) {
    final flattenData = data.flatten<int>();
    dataSize = flattenData.length;
    dataType = ONNXTensorElementDataType.int8;
    dataPtr = (calloc<ffi.Int8>(dataSize)
          ..asTypedList(dataSize).setRange(0, dataSize, flattenData))
        .cast();
    dataByteCount = dataSize;
  } else if (element is Uint16List) {
    final flattenData = data.flatten<int>();
    dataSize = flattenData.length;
    dataType = ONNXTensorElementDataType.uint16;
    dataPtr = (calloc<ffi.Uint16>(dataSize)
          ..asTypedList(dataSize).setRange(0, dataSize, flattenData))
        .cast();
    dataByteCount = dataSize * 2;
  } else if (element is Int16List) {
    final flattenData = data.flatten<int>();
    dataSize = flattenData.length;
    dataType = ONNXTensorElementDataType.int16;
    dataPtr = (calloc<ffi.Int16>(dataSize)
          ..asTypedList(dataSize).setRange(0, dataSize, flattenData))
        .cast();
    dataByteCount = dataSize * 2;
  } else if (element is Uint32List) {
    final flattenData = data.flatten<int>();
    dataSize = flattenData.length;
    dataType = ONNXTensorElementDataType.uint32;
    dataPtr = (calloc<ffi.Uint32>(dataSize)
          ..asTypedList(dataSize).setRange(0, dataSize, flattenData))
        .cast();
    dataByteCount = dataSize * 4;
  } else if (element is Int32List) {
    final flattenData = data.flatten<int>();
    dataSize = flattenData.length;
    dataType = ONNXTensorElementDataType.int32;
    dataPtr = (calloc<ffi.Int32>(dataSize)
          ..asTypedList(dataSize).setRange(0, dataSize, flattenData))
        .cast();
    dataByteCount = dataSize * 4;
  } else if (element is Uint64List) {
    final flattenData = data.flatten<int>();
    dataSize = flattenData.length;
    dataType = ONNXTensorElementDataType.uint64;
    dataPtr = (calloc<ffi.Uint64>(dataSize)
          ..asTypedList(dataSize).setRange(0, dataSize, flattenData))
        .cast();
    dataByteCount = dataSize * 8;
  } else if (element is Int64List || element is int) {
    final flattenData = data.flatten<int>();
    dataSize = flattenData.length;
    dataType = ONNXTensorElementDataType.int64;
    dataPtr = (calloc<ffi.Int64>(dataSize)
          ..asTypedList(dataSize).setRange(0, dataSize, flattenData))
        .cast();
    dataByteCount = dataSize * 8;
  } else if (element is Float32List) {
    final flattenData = data.flatten<double>();
    dataSize = flattenData.length;
    dataType = ONNXTensorElementDataType.float;
    dataPtr = (calloc<ffi.Float>(dataSize)
          ..asTypedList(dataSize).setRange(0, dataSize, flattenData))
        .cast();
    dataByteCount = dataSize * 4;
  } else if (element is Float64List || element is double) {
    final flattenData = data.flatten<double>();
    dataSize = flattenData.length;
    dataType = ONNXTensorElementDataType.double;
    dataPtr = (calloc<ffi.Double>(dataSize)
          ..asTypedList(dataSize).setRange(0, dataSize, flattenData))
        .cast();
    dataByteCount = dataSize * 8;
  } else if (element is bool) {
    final flattenData = data.flatten<bool>();
    dataSize = flattenData.length;
    dataType = ONNXTensorElementDataType.bool;
    // Pointer<Bool> has no asTypedList view, so copy element by element.
    final ptr = calloc<ffi.Bool>(dataSize);
    for (int i = 0; i < dataSize; ++i) {
      ptr[i] = flattenData[i];
    }
    dataPtr = ptr.cast();
    dataByteCount = dataSize;
  } else if (element is String) {
    return _createTensorWithStringList(data.cast<String>(), shape);
  } else {
    throw Exception('Invalid inputTensor element type.');
  }

  final shapeSize = shape.length;
  final shapePtr = calloc<ffi.Int64>(shapeSize);
  final ortMemoryInfoPtrPtr = calloc<ffi.Pointer<bg.OrtMemoryInfo>>();
  final ortValuePtrPtr = calloc<ffi.Pointer<bg.OrtValue>>();
  // Becomes true only once the returned tensor holds `dataPtr`. Until
  // then, any failure must free the data buffer here or it would leak.
  var dataOwnershipTransferred = false;
  try {
    shapePtr.asTypedList(shapeSize).setRange(0, shapeSize, shape);

    // Memory info is taken from the shared OrtAllocator instance.
    var statusPtr = OrtEnv.instance.ortApiPtr.ref.AllocatorGetInfo.asFunction<
            bg.OrtStatusPtr Function(ffi.Pointer<bg.OrtAllocator>,
                ffi.Pointer<ffi.Pointer<bg.OrtMemoryInfo>>)>()(
        OrtAllocator.instance.ptr, ortMemoryInfoPtrPtr);
    OrtStatus.checkOrtStatus(statusPtr);
    // or
    // OrtEnv.instance.ortApiPtr.ref.CreateCpuMemoryInfo.asFunction<
    //         bg.OrtStatusPtr Function(
    //             int, int, ffi.Pointer<ffi.Pointer<bg.OrtMemoryInfo>>)>()(
    //     bg.OrtAllocatorType.OrtDeviceAllocator,
    //     bg.OrtMemType.OrtMemTypeCPU,
    //     ortMemoryInfoPtrPtr);
    final ortMemoryInfoPtr = ortMemoryInfoPtrPtr.value;

    // Wraps the existing buffer rather than copying it, which is why
    // `dataPtr` must stay alive for as long as the returned tensor.
    statusPtr = OrtEnv.instance.ortApiPtr.ref.CreateTensorWithDataAsOrtValue
            .asFunction<
                bg.OrtStatusPtr Function(
                    ffi.Pointer<bg.OrtMemoryInfo>,
                    ffi.Pointer<ffi.Void>,
                    int,
                    ffi.Pointer<ffi.Int64>,
                    int,
                    int,
                    ffi.Pointer<ffi.Pointer<bg.OrtValue>>)>()(
        ortMemoryInfoPtr,
        dataPtr,
        dataByteCount,
        shapePtr,
        shapeSize,
        dataType.value,
        ortValuePtrPtr);
    OrtStatus.checkOrtStatus(statusPtr);

    final tensor = OrtValueTensor(ortValuePtrPtr.value, dataPtr);
    dataOwnershipTransferred = true;
    return tensor;
  } finally {
    // Temporaries are freed on both the success and the error path.
    calloc.free(shapePtr);
    calloc.free(ortValuePtrPtr);
    calloc.free(ortMemoryInfoPtrPtr);
    if (!dataOwnershipTransferred) {
      calloc.free(dataPtr);
    }
  }
}