tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

InferenceSession.cpp (23698B)


      1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
      2 /* vim: set ts=8 sts=2 et sw=2 tw=80: */
      3 /* This Source Code Form is subject to the terms of the Mozilla Public
      4 * License, v. 2.0. If a copy of the MPL was not distributed with this
      5 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
      6 #include "mozilla/dom/InferenceSession.h"
      7 
      8 #include <prlink.h>
      9 
     10 #include <thread>
     11 
     12 #include "ErrorList.h"
     13 #include "GeckoProfiler.h"
     14 #include "fmt/format.h"
     15 #include "mozilla/Attributes.h"
     16 #include "mozilla/FileUtils.h"
     17 #include "mozilla/Logging.h"
     18 #include "mozilla/RefPtr.h"
     19 #include "mozilla/ScopeExit.h"
     20 #include "mozilla/dom/BindingDeclarations.h"
     21 #include "mozilla/dom/ContentChild.h"
     22 #include "mozilla/dom/ONNXBinding.h"
     23 #include "mozilla/dom/Promise.h"
     24 #include "mozilla/dom/ScriptSettings.h"
     25 #include "mozilla/dom/Tensor.h"
     26 #include "nsString.h"
     27 #include "nsXPCOMPrivate.h"
     28 mozilla::LazyLogModule gONNXLog("GeckoMLONNXNative");
     29 #define LOGV(fmt, ...) \
     30  MOZ_LOG_FMT(gONNXLog, LogLevel::Verbose, fmt, ##__VA_ARGS__)
     31 #define LOGD(fmt, ...) \
     32  MOZ_LOG_FMT(gONNXLog, LogLevel::Debug, fmt, ##__VA_ARGS__)
     33 #define LOGE(fmt, ...) \
     34  MOZ_LOG_FMT(gONNXLog, LogLevel::Error, fmt, ##__VA_ARGS__)
     35 
     36 namespace mozilla::dom {
     37 
     38 // Initialized when the first InferenceSession is initialized,
     39 // valid until the shutdown of the inference process.
     40 static OrtEnv* sEnv = nullptr;
     41 static OrtApi* sAPI = nullptr;
     42 
     43 class AutoOrtStatus {
     44 public:
     45  MOZ_IMPLICIT AutoOrtStatus(OrtStatus* aStatus = nullptr) : mStatus(aStatus) {
     46    MOZ_ASSERT(sAPI);
     47  }
     48  ~AutoOrtStatus() {
     49    if (mStatus) {
     50      sAPI->ReleaseStatus(mStatus);
     51    }
     52  }
     53  explicit operator bool() const { return !!mStatus; }
     54  const char* Message() const { return sAPI->GetErrorMessage(mStatus); }
     55  OrtStatus* mStatus;
     56 };
     57 
// Standard cycle-collected, wrapper-cached DOM object boilerplate for
// InferenceSession (refcounting + nsISupports interface map).
NS_IMPL_CYCLE_COLLECTION_WRAPPERCACHE(InferenceSession);

NS_IMPL_CYCLE_COLLECTING_ADDREF(InferenceSession)
NS_IMPL_CYCLE_COLLECTING_RELEASE(InferenceSession)

NS_INTERFACE_MAP_BEGIN_CYCLE_COLLECTION(InferenceSession)
  NS_WRAPPERCACHE_INTERFACE_MAP_ENTRY
  NS_INTERFACE_MAP_ENTRY(nsISupports)
NS_INTERFACE_MAP_END

// NOTE(review): DYLIB_PATH appears unused in this file — GetOrtAPI()
// hard-codes the "onnxruntime" leaf name instead. Confirm before removing.
#define DYLIB_PATH "onnxruntime"
     69 
     70 OrtSessionOptions* ToOrtSessionOption(
     71    const InferenceSessionSessionOptions& aOptions) {
     72  OrtSessionOptions* sessionOptions = nullptr;
     73  AutoOrtStatus status = sAPI->CreateSessionOptions(&sessionOptions);
     74  if (status) {
     75    LOGD("CreateSessionOptions error: {}", status.Message());
     76    return nullptr;
     77  }
     78 #define SET_BOOL_ON_SESSION(x)                                       \
     79  do {                                                               \
     80    if (aOptions.mEnable##x) {                                       \
     81      status = sAPI->Enable##x(sessionOptions);                      \
     82    } else {                                                         \
     83      status = sAPI->Disable##x(sessionOptions);                     \
     84    }                                                                \
     85    if (status) {                                                    \
     86      LOGE("Setter {} (val: {}) error: {}", #x, aOptions.mEnable##x, \
     87           status.Message());                                        \
     88      return nullptr;                                                \
     89    }                                                                \
     90  } while (0)
     91 
     92  LOGD("CpuMemArena: {}", aOptions.mEnableCpuMemArena);
     93  SET_BOOL_ON_SESSION(CpuMemArena);
     94  LOGD("MemPattern: {}", aOptions.mEnableMemPattern);
     95  SET_BOOL_ON_SESSION(MemPattern);
     96 
     97 #define CALL_API(x, ...)                                           \
     98  do {                                                             \
     99    status = sAPI->x(sessionOptions, __VA_ARGS__);                 \
    100    if (status) {                                                  \
    101      LOGD("SetSessionExecutionMode error: {}", status.Message()); \
    102      return nullptr;                                              \
    103    }                                                              \
    104  } while (0);
    105 
    106  LOGD("Session execution mode: {}", aOptions.mExecutionMode);
    107  CALL_API(SetSessionExecutionMode,
    108           aOptions.mExecutionMode.EqualsASCII("parallel")
    109               ? ExecutionMode::ORT_PARALLEL
    110               : ExecutionMode::ORT_SEQUENTIAL);
    111 
    112  LOGD("Inter op num threads: {}", aOptions.mInterOpNumThreads);
    113  CALL_API(SetInterOpNumThreads, aOptions.mInterOpNumThreads);
    114  LOGD("Inter op num threads: {}", aOptions.mIntraOpNumThreads);
    115  CALL_API(SetInterOpNumThreads, aOptions.mIntraOpNumThreads);
    116  CALL_API(SetSessionLogId, aOptions.mLogId.get());
    117  CALL_API(SetSessionLogSeverityLevel, aOptions.mLogSeverityLevel);
    118  CALL_API(SetSessionLogVerbosityLevel, aOptions.mLogVerbosityLevel);
    119  PathString path;
    120 #ifdef XP_WIN
    121  path = NS_ConvertUTF8toUTF16(aOptions.mOptimizedModelFilePath.get());
    122 
    123 #else
    124  path = aOptions.mOptimizedModelFilePath.get();
    125 #endif
    126  CALL_API(SetOptimizedModelFilePath, path.get());
    127  GraphOptimizationLevel level = ORT_ENABLE_BASIC;
    128  LOGD("Graph optimization level: {}", aOptions.mGraphOptimizationLevel);
    129  if (aOptions.mGraphOptimizationLevel.EqualsASCII("all")) {
    130    level = ORT_ENABLE_ALL;
    131  } else if (aOptions.mGraphOptimizationLevel.EqualsASCII("basic")) {
    132    level = ORT_ENABLE_BASIC;
    133  } else if (aOptions.mGraphOptimizationLevel.EqualsASCII("extended")) {
    134    level = ORT_ENABLE_EXTENDED;
    135  } else if (aOptions.mGraphOptimizationLevel.EqualsASCII("all")) {
    136    level = ORT_ENABLE_ALL;
    137  }
    138  CALL_API(SetSessionGraphOptimizationLevel, level);
    139 
    140  if (aOptions.mFreeDimensionOverrides.WasPassed()) {
    141    for (const auto& rec : aOptions.mFreeDimensionOverrides.Value().Entries()) {
    142      LOGD("Adding free dimension override for key: {}, value: {}", rec.mKey,
    143           rec.mValue);
    144      CALL_API(AddFreeDimensionOverride, rec.mKey.get(), rec.mValue);
    145    }
    146  }
    147 
    148  return sessionOptions;
    149 }  // namespace mozilla::dom
    150 
    151 OrtApi* GetOrtAPI() {
    152 #ifdef XP_WIN
    153  PathString path = GetLibraryFilePathname(LXUL_DLL, (PRFuncPtr)&GetOrtAPI);
    154 #else
    155  PathString path = GetLibraryFilePathname(XUL_DLL, (PRFuncPtr)&GetOrtAPI);
    156 #endif
    157  if (path.IsEmpty()) {
    158    LOGE("Could not locate XUL library when loading onnxruntime");
    159    return nullptr;
    160  }
    161  nsCOMPtr<nsIFile> libFile;
    162  if (NS_FAILED(NS_NewPathStringLocalFile(path, getter_AddRefs(libFile)))) {
    163    LOGE("Could not get path string for local file when loading onnxruntime");
    164    return nullptr;
    165  }
    166 
    167  if (NS_FAILED(libFile->SetNativeLeafName(
    168          MOZ_DLL_PREFIX "onnxruntime" MOZ_DLL_SUFFIX ""_ns))) {
    169    LOGE("SetNativeLeavName error when loading onnxruntime");
    170    return nullptr;
    171  }
    172  PRLibSpec lspec;
    173  PathString nativePath = libFile->NativePath();
    174 #ifdef XP_WIN
    175  lspec.type = PR_LibSpec_PathnameU;
    176  lspec.value.pathname_u = nativePath.get();
    177 #else
    178  lspec.type = PR_LibSpec_Pathname;
    179  lspec.value.pathname = nativePath.get();
    180 #endif
    181 #ifdef MOZ_WIDGET_ANDROID
    182  PRLibrary* handle = PR_LoadLibraryWithFlags(lspec, PR_LD_NOW | PR_LD_GLOBAL);
    183 #else
    184  PRLibrary* handle = PR_LoadLibraryWithFlags(lspec, PR_LD_NOW | PR_LD_LOCAL);
    185 #endif
    186  if (!handle) {
    187    PRErrorCode code = PR_GetError();
    188    const char* msg = PR_ErrorToString(code, PR_LANGUAGE_I_DEFAULT);
    189    LOGE("Couldn't load onnxruntime shared library ({:x}: {})", PR_GetOSError(),
    190         msg);
    191    return nullptr;
    192  }
    193 
    194  using OrtApiBaseFn = const OrtApiBase* (*)();
    195  auto ortGetApiBaseFnPtr =
    196      reinterpret_cast<OrtApiBaseFn>(PR_FindSymbol(handle, "OrtGetApiBase"));
    197  if (!ortGetApiBaseFnPtr) {
    198    LOGE("Couldn't fetch symbol OrgGetApiBase");
    199    PR_UnloadLibrary(handle);
    200    return nullptr;
    201  }
    202  const OrtApiBase* apiBase = ortGetApiBaseFnPtr();
    203  OrtApi* ortAPI = const_cast<OrtApi*>(apiBase->GetApi(ORT_API_VERSION));
    204  if (!ortAPI) {
    205    LOGE("Couldn't get ahold of the OrtApi pointer");
    206    PR_UnloadLibrary(handle);
    207    return nullptr;
    208  }
    209 
    210  return ortAPI;
    211 }
    212 
    213 bool InferenceSession::InInferenceProcess(JSContext*, JSObject*) {
    214  if (!ContentChild::GetSingleton()) {
    215    return false;
    216  }
    217  return ContentChild::GetSingleton()->GetRemoteType().Equals(
    218      INFERENCE_REMOTE_TYPE);
    219 }
    220 
    221 nsCString InferenceSessionSessionOptionsToString(
    222    const InferenceSessionSessionOptions& aOptions) {
    223  return nsFmtCString(
    224      FMT_STRING("EnableCpuMemArena: {}, "
    225                 "EnableGraphCapture: {}, "
    226                 "EnableMemPattern: {}, "
    227                 "EnableProfiling: {}, "
    228                 "ExecutionMode: {}, "
    229                 "ExecutionProviders: {}, "
    230                 "Extra: {}, "
    231                 "FreeDimensionOverrides: {}, "
    232                 "GraphOptimizationLevel: {}, "
    233                 "InterOpNumThreads: {}, "
    234                 "IntraOpNumThreads: {}, "
    235                 "LogId: {}, "
    236                 "LogSeverityLevel: {}, "
    237                 "LogVerbosityLevel: {}, "
    238                 "OptimizedModelFilePath: {}, "
    239                 "PreferredOutputLocation: {}, "
    240                 "ProfileFilePrefix: {}"),
    241      aOptions.mEnableCpuMemArena, aOptions.mEnableGraphCapture,
    242      aOptions.mEnableMemPattern, aOptions.mEnableProfiling,
    243      aOptions.mExecutionMode,
    244      aOptions.mExecutionProviders.WasPassed() ? "<passed>" : "<not passed>",
    245      aOptions.mExtra.WasPassed() ? "<passed>" : "<not passed>",
    246      aOptions.mFreeDimensionOverrides.WasPassed() ? "<passed>"
    247                                                   : "<not passed>",
    248      aOptions.mGraphOptimizationLevel, aOptions.mInterOpNumThreads,
    249      aOptions.mIntraOpNumThreads, aOptions.mLogId, aOptions.mLogSeverityLevel,
    250      aOptions.mLogVerbosityLevel, aOptions.mOptimizedModelFilePath,
    251      aOptions.mPreferredOutputLocation.WasPassed() ? "<passed>"
    252                                                    : "<not passed>",
    253      aOptions.mProfileFilePrefix);
    254 }
    255 
    256 OrtCustomThreadHandle WrapProfilerRegister(void* options, void (*func)(void*),
    257                                           void* param) {
    258  // We don't use options for now
    259  MOZ_ASSERT(!options);
    260  auto wrapperFunc = [func](void* param) {
    261    PROFILER_REGISTER_THREAD("onnx_worker");
    262    LOGD("Starting thread");
    263    (static_cast<OrtThreadWorkerFn>(func))(param);
    264  };
    265 
    266  auto* t = new std::thread(wrapperFunc, param);
    267 
    268  return reinterpret_cast<OrtCustomThreadHandle>(t);
    269 }
    270 
    271 void WrapProfilerUnregister(OrtCustomThreadHandle thread) {
    272  LOGD("Joining thread");
    273  std::thread* t = (std::thread*)thread;
    274  t->join();
    275  delete t;
    276 }
    277 
    278 RefPtr<Promise> InferenceSession::Create(
    279    GlobalObject& aGlobal, const UTF8StringOrUint8Array& aUriOrBuffer,
    280    const InferenceSessionSessionOptions& aOptions, ErrorResult& aRv) {
    281  LOGD("{}", __PRETTY_FUNCTION__);
    282  nsCOMPtr<nsIGlobalObject> global = do_QueryInterface(aGlobal.GetAsSupports());
    283  RefPtr<Promise> p = Promise::Create(global, aRv);
    284  RefPtr<InferenceSession> session = new InferenceSession(aGlobal);
    285  session->Init(p, aUriOrBuffer, aOptions);
    286  return p;
    287 }
    288 
    289 void InferenceSession::Init(const RefPtr<Promise>& aPromise,
    290                            const UTF8StringOrUint8Array& aUriOrBuffer,
    291                            const InferenceSessionSessionOptions& aOptions) {
    292  LOGD("InferenceSession::Init called with a {}",
    293       aUriOrBuffer.IsUTF8String() ? "string" : "buffer");
    294 
    295  if (!sEnv) {
    296    sAPI = GetOrtAPI();
    297    if (!sAPI) {
    298      LOGD("Couldn't get ahold of ORT API");
    299      aPromise->MaybeReject(NS_ERROR_FAILURE);
    300      return;
    301    }
    302    OrtThreadingOptions* threadingOptions;
    303 
    304    AutoOrtStatus status = sAPI->CreateThreadingOptions(&threadingOptions);
    305    if (status) {
    306      LOGD("CreateThreadingOptions error");
    307      aPromise->MaybeRejectWithUndefined();
    308      return;
    309    }
    310    status = sAPI->SetGlobalCustomCreateThreadFn(threadingOptions,
    311                                                 WrapProfilerRegister);
    312    if (status) {
    313      LOGD("SetGlobalCustomCreateThreadFn error");
    314      aPromise->MaybeRejectWithUndefined();
    315      return;
    316    }
    317 
    318    status = sAPI->SetGlobalCustomJoinThreadFn(threadingOptions,
    319                                               WrapProfilerUnregister);
    320    if (status) {
    321      LOGD("SetGlobalCustomJoinThreadFn error");
    322      aPromise->MaybeRejectWithUndefined();
    323      return;
    324    }
    325 
    326    status = sAPI->SetGlobalInterOpNumThreads(
    327        threadingOptions, AssertedCast<int>(aOptions.mInterOpNumThreads));
    328    if (status) {
    329      LOGD("SetGlobalInterOpNumThreads error");
    330      aPromise->MaybeRejectWithUndefined();
    331      return;
    332    }
    333 
    334    status = sAPI->SetGlobalIntraOpNumThreads(
    335        threadingOptions, AssertedCast<int>(aOptions.mIntraOpNumThreads));
    336    if (status) {
    337      LOGD("SetGlobalIntraOpNumThreads error");
    338      aPromise->MaybeRejectWithUndefined();
    339      return;
    340    }
    341 
    342    status = sAPI->SetGlobalDenormalAsZero(threadingOptions);
    343    if (status) {
    344      LOGD("SetGlobalDenormalsAreZero error");
    345      aPromise->MaybeRejectWithUndefined();
    346      return;
    347    }
    348 
    349    status = sAPI->SetGlobalSpinControl(threadingOptions, 0);
    350    if (status) {
    351      LOGD("SetGlobalSpinControl error");
    352      aPromise->MaybeRejectWithUndefined();
    353      return;
    354    }
    355 
    356    status = sAPI->CreateEnvWithGlobalThreadPools(
    357        ORT_LOGGING_LEVEL_FATAL, "my_env", threadingOptions, &sEnv);
    358    if (status) {
    359      LOGD("CreateEnv error: {}", status.Message());
    360      MOZ_CRASH("Init CreateEnv");
    361    }
    362    LOGD("CreateEnv OK");
    363  }
    364 
    365  mOptions = ToOrtSessionOption(aOptions);
    366  AutoOrtStatus status = sAPI->DisablePerSessionThreads(mOptions);
    367  if (status) {
    368    LOGD("DisablePerSessionThreads error: {}", status.Message());
    369  }
    370 
    371  OrtSession* session = nullptr;
    372  if (aUriOrBuffer.IsUTF8String()) {
    373    LOGE("Passing a URI to a model isn't implemented, pass the bytes directly");
    374    aPromise->MaybeRejectWithNotSupportedError("Not implemented");
    375    return;
    376  }
    377  aUriOrBuffer.GetAsUint8Array().ProcessFixedData(
    378      [&](const Span<uint8_t>& aData) {
    379        AUTO_PROFILER_MARKER_UNTYPED("CreateSessionFromArray", ML_SETUP, {});
    380        status = sAPI->CreateSessionFromArray(
    381            sEnv, aData.data(), aData.Length(), mOptions, &session);
    382      });
    383  if (status) {
    384    LOGD("CreateSession error: {}", status.Message());
    385    MOZ_CRASH("CreateSession error");
    386  }
    387  LOGD("Successfully created ONNX Runtime session.");
    388  mSession = session;
    389  aPromise->MaybeResolve(this);
    390 }
    391 
    392 nsCString FeedsToString(
    393    const Record<nsCString, OwningNonNull<Tensor>>& aFeeds) {
    394  nsCString rv;
    395  for (const auto& input : aFeeds.Entries()) {
    396    rv.AppendFmt("[{}: {}],", input.mKey, input.mValue->ToString().get());
    397  }
    398  return rv;
    399 }
    400 
    401 already_AddRefed<Promise> InferenceSession::Run(
    402    const Record<nsCString, OwningNonNull<Tensor>>& feeds,
    403    const InferenceSessionRunOptions& options, ErrorResult& aRv) {
    404  LOGD("{} {}", __PRETTY_FUNCTION__, fmt::ptr(this));
    405  RefPtr<Promise> p = Promise::Create(GetParentObject(), aRv);
    406 
    407  if (!mSession) {
    408    LOGD("runInference: session pointer is null.");
    409  }
    410  if (!sAPI || !sEnv) {
    411    LOGD("Need API {} and Env {} here", fmt::ptr(sAPI), fmt::ptr(sEnv));
    412    MOZ_CRASH("In run");
    413    p->MaybeReject(NS_ERROR_UNEXPECTED);
    414    return p.forget();
    415  }
    416 
    417  OrtMemoryInfo* memoryInfo = nullptr;
    418  auto guard = MakeScopeExit([&] { sAPI->ReleaseMemoryInfo(memoryInfo); });
    419  AutoOrtStatus status = sAPI->CreateCpuMemoryInfo(
    420      OrtArenaAllocator, OrtMemTypeDefault, &memoryInfo);
    421  if (status) {
    422    LOGD("CreateCpuMemoryInfo failed: {}", status.Message());
    423    p->MaybeReject(NS_ERROR_UNEXPECTED);
    424    return p.forget();
    425  }
    426 
    427  LOGD("Inputs:");
    428  nsTArray<OrtValue*> inputValues;
    429  auto scope = MakeScopeExit([&] {
    430    for (auto& v : inputValues) {
    431      sAPI->ReleaseValue(v);
    432    }
    433  });
    434  for (const auto& input : feeds.Entries()) {
    435    OrtValue* inputOrt = nullptr;
    436    const auto& val = input.mValue;
    437    AutoTArray<int64_t, 16> dims64;
    438    for (uint32_t i = 0; i < val->DimsSize(); i++) {
    439      dims64.AppendElement(val->Dims()[i]);
    440    }
    441    LOGD("{}: {}", input.mKey.get(), val->ToString().get());
    442    AUTO_PROFILER_MARKER_FMT("CreateTensorWithDataAsOrtValue", ML_INFERENCE, {},
    443                             "{}", input.mKey.get());
    444    status = sAPI->CreateTensorWithDataAsOrtValue(
    445        memoryInfo, val->Data(), val->Size(), dims64.Elements(),
    446        val->DimsSize(), val->Type(), &inputOrt);
    447    if (status) {
    448      LOGD("CreateTensorWithDataAsOrtValue for input_ids {} failed: {}",
    449           input.mKey, status.Message());
    450      p->MaybeReject(NS_ERROR_UNEXPECTED);
    451      return p.forget();
    452    }
    453 
    454    inputValues.AppendElement(inputOrt);
    455  }
    456 
    457  nsTArray<nsCString> inputNames;
    458  nsTArray<const char*> inputNamesPtrs;
    459  GetNames(inputNames, NameDirection::Input);
    460  for (const auto& name : inputNames) {
    461    inputNamesPtrs.AppendElement(name.get());
    462  }
    463  nsTArray<nsCString> outputNames;
    464  nsTArray<const char*> outputNamesPtrs;
    465  GetNames(outputNames, NameDirection::Output);
    466  LOGD("Outputs names:");
    467  for (const auto& name : outputNames) {
    468    LOGD("- {}", name.get());
    469    outputNamesPtrs.AppendElement(name.get());
    470  }
    471  nsTArray<OrtValue*> outputs;
    472  outputs.SetLength(outputNames.Length());
    473  for (uint32_t i = 0; i < outputNames.Length(); i++) {
    474    outputs[i] = nullptr;
    475  }
    476  OrtValue** ptr = outputs.Elements();
    477 
    478  {
    479    AUTO_PROFILER_MARKER_UNTYPED("Ort::Run", ML_INFERENCE, {});
    480    status = sAPI->Run(mSession,
    481                       nullptr,  // Run options
    482                       inputNamesPtrs.Elements(), inputValues.Elements(),
    483                       inputNamesPtrs.Length(), outputNamesPtrs.Elements(),
    484                       outputNamesPtrs.Length(), ptr);
    485  }
    486  if (status) {
    487    LOGD("Session Run failed: {}", status.Message());
    488    p->MaybeReject(NS_ERROR_UNEXPECTED);
    489    return p.forget();
    490  }
    491 
    492  Record<nsCString, OwningNonNull<Tensor>> rv;
    493  for (size_t i = 0; i < outputs.Length(); i++) {
    494    TimeStamp start = TimeStamp::Now();
    495    // outputData has the same lifetime as output[i]. For now, the actual data
    496    // is copied into the Tensor object below. This copy will be removed in the
    497    // future.
    498    uint8_t* outputData = nullptr;
    499    status = sAPI->GetTensorMutableData(outputs[i], (void**)&outputData);
    500    if (status) {
    501      LOGD("GetTensorMutableData failed: {}", status.Message());
    502      p->MaybeReject(NS_ERROR_UNEXPECTED);
    503      return p.forget();
    504    }
    505 
    506    OrtTypeInfo* typeInfo;
    507    status = sAPI->SessionGetOutputTypeInfo(mSession, i, &typeInfo);
    508    if (status) {
    509      LOGD("GetOutputTypeInfo failed: {}", status.Message());
    510      p->MaybeReject(NS_ERROR_UNEXPECTED);
    511      return p.forget();
    512    }
    513 
    514    OrtTensorTypeAndShapeInfo* typeAndShapeInfo;
    515    status = sAPI->GetTensorTypeAndShape(outputs[i], &typeAndShapeInfo);
    516    if (status) {
    517      LOGD("GetTensorTypeAndShape failed: {}", status.Message());
    518      p->MaybeReject(NS_ERROR_UNEXPECTED);
    519      return p.forget();
    520    }
    521 
    522    ONNXType type;
    523    status = sAPI->GetOnnxTypeFromTypeInfo(typeInfo, &type);
    524    if (status) {
    525      LOGD("GetOnnxTypeFromTypeInfo failed: {}", status.Message());
    526      p->MaybeReject(NS_ERROR_UNEXPECTED);
    527      return p.forget();
    528    }
    529    MOZ_ASSERT(type == ONNX_TYPE_TENSOR);
    530 
    531    ONNXTensorElementDataType outputTensorType;
    532    status = sAPI->GetTensorElementType(typeAndShapeInfo, &outputTensorType);
    533    if (status) {
    534      LOGD("GetTensorElementType failed: {}", status.Message());
    535      p->MaybeReject(NS_ERROR_UNEXPECTED);
    536      return p.forget();
    537    }
    538 
    539    size_t dimCount;
    540    status = sAPI->GetDimensionsCount(typeAndShapeInfo, &dimCount);
    541    if (status) {
    542      LOGD("GetDimensionsCount failed: {}", status.Message());
    543      p->MaybeReject(NS_ERROR_UNEXPECTED);
    544      return p.forget();
    545    }
    546 
    547    AutoTArray<int64_t, 16> dims;
    548    dims.SetLength(dimCount);
    549    status = sAPI->GetDimensions(typeAndShapeInfo, dims.Elements(), dimCount);
    550 
    551    size_t outputSize = 1;
    552    for (size_t d = 0; d < dimCount; ++d) {
    553      outputSize *= dims[d];
    554    }
    555 
    556    // TODO skip this copy by using CreateTensorWithDataAsOrtValue
    557    nsTArray<uint8_t> output;
    558    output.AppendElements(outputData,
    559                          outputSize * Tensor::DataTypeSize(outputTensorType));
    560    GlobalObject global(mCtx, GetParentObject()->GetGlobalJSObject());
    561    auto outputTensor = MakeRefPtr<Tensor>(global, outputTensorType,
    562                                           std::move(output), std::move(dims));
    563    AUTO_PROFILER_MARKER_FMT(
    564        "Output tensor", ML_INFERENCE,
    565        MarkerOptions(MarkerTiming::IntervalUntilNowFrom(start)), "{}: {}",
    566        outputNames[i], outputTensor->ToString().get());
    567 
    568    sAPI->ReleaseTensorTypeAndShapeInfo(typeAndShapeInfo);
    569 
    570    auto elem = rv.Entries().AppendElement();
    571    elem->mKey = outputNames[i];
    572    elem->mValue = outputTensor;
    573  }
    574 
    575  p->MaybeResolve(rv);
    576 
    577  return p.forget();
    578 }
    579 
    580 void InferenceSession::Destroy() {
    581  LOGD("{} {}", __PRETTY_FUNCTION__, fmt::ptr(this));
    582  if (mSession) {
    583    sAPI->ReleaseSession(mSession);
    584  }
    585  if (mOptions) {
    586    sAPI->ReleaseSessionOptions(mOptions);
    587  }
    588 }
    589 
    590 already_AddRefed<Promise> InferenceSession::ReleaseSession() {
    591  LOGD("{} {}", __PRETTY_FUNCTION__, fmt::ptr(this));
    592 
    593  Destroy();
    594  RefPtr<Promise> p = Promise::CreateInfallible(mGlobal);
    595  p->MaybeResolveWithUndefined();
    596  return p.forget();
    597 }
    598 
// Stub: profiling start is not implemented yet; this only logs the call.
void InferenceSession::StartProfiling() {
  LOGD("{} {}", __PRETTY_FUNCTION__, fmt::ptr(this));
}
    602 
// Stub: profiling end is not implemented yet; this only logs the call.
void InferenceSession::EndProfiling() {
  LOGD("{} {}", __PRETTY_FUNCTION__, fmt::ptr(this));
}
    606 
    607 void InferenceSession::GetNames(nsTArray<nsCString>& aRetVal,
    608                                NameDirection aDirection) const {
    609  const char* NameDirection2String[2] = {"Input", "Output"};
    610 
    611  if (!mSession) {
    612    return;
    613  }
    614  size_t nameCount = 0;
    615  AutoOrtStatus status;
    616  if (aDirection == NameDirection::Input) {
    617    status = sAPI->SessionGetInputCount(mSession, &nameCount);
    618  } else {
    619    status = sAPI->SessionGetOutputCount(mSession, &nameCount);
    620  }
    621  if (status) {
    622    LOGD("SessionGet{}Count failed: ",
    623         NameDirection2String[static_cast<int>(aDirection)], status.Message());
    624    return;
    625  }
    626 
    627  OrtAllocator* allocator = nullptr;
    628  status = sAPI->GetAllocatorWithDefaultOptions(&allocator);
    629  if (status) {
    630    LOGD("GetAllocatorWithDefaultOptions failed: {}", status.Message());
    631    return;
    632  }
    633  aRetVal.SetCapacity(nameCount);
    634  for (size_t i = 0; i < nameCount; i++) {
    635    // Allocated by onnxruntiem, must be freed by AllocatorFree
    636    char* name = nullptr;
    637 
    638    if (aDirection == NameDirection::Input) {
    639      status = sAPI->SessionGetInputName(mSession, i, allocator, &name);
    640    } else {
    641      status = sAPI->SessionGetOutputName(mSession, i, allocator, &name);
    642    }
    643    if (status) {
    644      LOGD("SessionGet{}Name failed: ",
    645           NameDirection2String[static_cast<int>(aDirection)],
    646           status.Message());
    647      continue;
    648    }
    649    aRetVal.AppendElement(name);
    650    status = sAPI->AllocatorFree(allocator, name);
    651    if (status) {
    652      LOGD("AllocatorFree failed: ", status.Message());
    653      continue;
    654    }
    655  }
    656 }
    657 
    658 void InferenceSession::GetInputNames(nsTArray<nsCString>& aRetVal) const {
    659  LOGD("{} {}", __PRETTY_FUNCTION__, fmt::ptr(this));
    660  GetNames(aRetVal, NameDirection::Input);
    661  if (MOZ_LOG_TEST(gONNXLog, LogLevel::Debug)) {
    662    for (auto& name : aRetVal) {
    663      LOGD("- {}", name);
    664    }
    665  }
    666 }
    667 
    668 void InferenceSession::GetOutputNames(nsTArray<nsCString>& aRetVal) const {
    669  LOGD("{} {}", __PRETTY_FUNCTION__, fmt::ptr(this));
    670  GetNames(aRetVal, NameDirection::Output);
    671  if (MOZ_LOG_TEST(gONNXLog, LogLevel::Debug)) {
    672    for (auto& name : aRetVal) {
    673      LOGD("- {}", name);
    674    }
    675  }
    676 }
    677 
// Standard WebIDL wrapper hook: creates the JS reflector for this object via
// the generated binding.
JSObject* InferenceSession::WrapObject(JSContext* aCx,
                                      JS::Handle<JSObject*> aGivenProto) {
  return InferenceSession_Binding::Wrap(aCx, this, aGivenProto);
}
    682 
    683 }  // namespace mozilla::dom