InferenceSession.cpp (23698B)
1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ 2 /* vim: set ts=8 sts=2 et sw=2 tw=80: */ 3 /* This Source Code Form is subject to the terms of the Mozilla Public 4 * License, v. 2.0. If a copy of the MPL was not distributed with this 5 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ 6 #include "mozilla/dom/InferenceSession.h" 7 8 #include <prlink.h> 9 10 #include <thread> 11 12 #include "ErrorList.h" 13 #include "GeckoProfiler.h" 14 #include "fmt/format.h" 15 #include "mozilla/Attributes.h" 16 #include "mozilla/FileUtils.h" 17 #include "mozilla/Logging.h" 18 #include "mozilla/RefPtr.h" 19 #include "mozilla/ScopeExit.h" 20 #include "mozilla/dom/BindingDeclarations.h" 21 #include "mozilla/dom/ContentChild.h" 22 #include "mozilla/dom/ONNXBinding.h" 23 #include "mozilla/dom/Promise.h" 24 #include "mozilla/dom/ScriptSettings.h" 25 #include "mozilla/dom/Tensor.h" 26 #include "nsString.h" 27 #include "nsXPCOMPrivate.h" 28 mozilla::LazyLogModule gONNXLog("GeckoMLONNXNative"); 29 #define LOGV(fmt, ...) \ 30 MOZ_LOG_FMT(gONNXLog, LogLevel::Verbose, fmt, ##__VA_ARGS__) 31 #define LOGD(fmt, ...) \ 32 MOZ_LOG_FMT(gONNXLog, LogLevel::Debug, fmt, ##__VA_ARGS__) 33 #define LOGE(fmt, ...) \ 34 MOZ_LOG_FMT(gONNXLog, LogLevel::Error, fmt, ##__VA_ARGS__) 35 36 namespace mozilla::dom { 37 38 // Initialized when the first InferenceSession is initialized, 39 // valid until the shutdown of the inference process. 40 static OrtEnv* sEnv = nullptr; 41 static OrtApi* sAPI = nullptr; 42 43 class AutoOrtStatus { 44 public: 45 MOZ_IMPLICIT AutoOrtStatus(OrtStatus* aStatus = nullptr) : mStatus(aStatus) { 46 MOZ_ASSERT(sAPI); 47 } 48 ~AutoOrtStatus() { 49 if (mStatus) { 50 sAPI->ReleaseStatus(mStatus); 51 } 52 } 53 explicit operator bool() const { return !!mStatus; } 54 const char* Message() const { return sAPI->GetErrorMessage(mStatus); } 55 OrtStatus* mStatus; 56 }; 57 58 NS_IMPL_CYCLE_COLLECTION_WRAPPERCACHE(InferenceSession); 59 60 NS_IMPL_CYCLE_COLLECTING_ADDREF(InferenceSession) 61 NS_IMPL_CYCLE_COLLECTING_RELEASE(InferenceSession) 62 63 NS_INTERFACE_MAP_BEGIN_CYCLE_COLLECTION(InferenceSession) 64 NS_WRAPPERCACHE_INTERFACE_MAP_ENTRY 65 NS_INTERFACE_MAP_ENTRY(nsISupports) 66 NS_INTERFACE_MAP_END 67 68 #define DYLIB_PATH "onnxruntime" 69 70 OrtSessionOptions* ToOrtSessionOption( 71 const InferenceSessionSessionOptions& aOptions) { 72 OrtSessionOptions* sessionOptions = nullptr; 73 AutoOrtStatus status = sAPI->CreateSessionOptions(&sessionOptions); 74 if (status) { 75 LOGD("CreateSessionOptions error: {}", status.Message()); 76 return nullptr; 77 } 78 #define SET_BOOL_ON_SESSION(x) \ 79 do { \ 80 if (aOptions.mEnable##x) { \ 81 status = sAPI->Enable##x(sessionOptions); \ 82 } else { \ 83 status = sAPI->Disable##x(sessionOptions); \ 84 } \ 85 if (status) { \ 86 LOGE("Setter {} (val: {}) error: {}", #x, aOptions.mEnable##x, \ 87 status.Message()); \ 88 return nullptr; \ 89 } \ 90 } while (0) 91 92 LOGD("CpuMemArena: {}", aOptions.mEnableCpuMemArena); 93 SET_BOOL_ON_SESSION(CpuMemArena); 94 LOGD("MemPattern: {}", aOptions.mEnableMemPattern); 95 SET_BOOL_ON_SESSION(MemPattern); 96 97 #define CALL_API(x, ...) \ 98 do { \ 99 status = sAPI->x(sessionOptions, __VA_ARGS__); \ 100 if (status) { \ 101 LOGD("SetSessionExecutionMode error: {}", status.Message()); \ 102 return nullptr; \ 103 } \ 104 } while (0); 105 106 LOGD("Session execution mode: {}", aOptions.mExecutionMode); 107 CALL_API(SetSessionExecutionMode, 108 aOptions.mExecutionMode.EqualsASCII("parallel") 109 ? ExecutionMode::ORT_PARALLEL 110 : ExecutionMode::ORT_SEQUENTIAL); 111 112 LOGD("Inter op num threads: {}", aOptions.mInterOpNumThreads); 113 CALL_API(SetInterOpNumThreads, aOptions.mInterOpNumThreads); 114 LOGD("Inter op num threads: {}", aOptions.mIntraOpNumThreads); 115 CALL_API(SetInterOpNumThreads, aOptions.mIntraOpNumThreads); 116 CALL_API(SetSessionLogId, aOptions.mLogId.get()); 117 CALL_API(SetSessionLogSeverityLevel, aOptions.mLogSeverityLevel); 118 CALL_API(SetSessionLogVerbosityLevel, aOptions.mLogVerbosityLevel); 119 PathString path; 120 #ifdef XP_WIN 121 path = NS_ConvertUTF8toUTF16(aOptions.mOptimizedModelFilePath.get()); 122 123 #else 124 path = aOptions.mOptimizedModelFilePath.get(); 125 #endif 126 CALL_API(SetOptimizedModelFilePath, path.get()); 127 GraphOptimizationLevel level = ORT_ENABLE_BASIC; 128 LOGD("Graph optimization level: {}", aOptions.mGraphOptimizationLevel); 129 if (aOptions.mGraphOptimizationLevel.EqualsASCII("all")) { 130 level = ORT_ENABLE_ALL; 131 } else if (aOptions.mGraphOptimizationLevel.EqualsASCII("basic")) { 132 level = ORT_ENABLE_BASIC; 133 } else if (aOptions.mGraphOptimizationLevel.EqualsASCII("extended")) { 134 level = ORT_ENABLE_EXTENDED; 135 } else if (aOptions.mGraphOptimizationLevel.EqualsASCII("all")) { 136 level = ORT_ENABLE_ALL; 137 } 138 CALL_API(SetSessionGraphOptimizationLevel, level); 139 140 if (aOptions.mFreeDimensionOverrides.WasPassed()) { 141 for (const auto& rec : aOptions.mFreeDimensionOverrides.Value().Entries()) { 142 LOGD("Adding free dimension override for key: {}, value: {}", rec.mKey, 143 rec.mValue); 144 CALL_API(AddFreeDimensionOverride, rec.mKey.get(), rec.mValue); 145 } 146 } 147 148 return sessionOptions; 149 } // namespace mozilla::dom 150 151 OrtApi* GetOrtAPI() { 152 #ifdef XP_WIN 153 PathString path = GetLibraryFilePathname(LXUL_DLL, (PRFuncPtr)&GetOrtAPI); 154 #else 155 PathString path = GetLibraryFilePathname(XUL_DLL, (PRFuncPtr)&GetOrtAPI); 156 #endif 157 if (path.IsEmpty()) { 158 LOGE("Could not locate XUL library when loading onnxruntime"); 159 return nullptr; 160 } 161 nsCOMPtr<nsIFile> libFile; 162 if (NS_FAILED(NS_NewPathStringLocalFile(path, getter_AddRefs(libFile)))) { 163 LOGE("Could not get path string for local file when loading onnxruntime"); 164 return nullptr; 165 } 166 167 if (NS_FAILED(libFile->SetNativeLeafName( 168 MOZ_DLL_PREFIX "onnxruntime" MOZ_DLL_SUFFIX ""_ns))) { 169 LOGE("SetNativeLeavName error when loading onnxruntime"); 170 return nullptr; 171 } 172 PRLibSpec lspec; 173 PathString nativePath = libFile->NativePath(); 174 #ifdef XP_WIN 175 lspec.type = PR_LibSpec_PathnameU; 176 lspec.value.pathname_u = nativePath.get(); 177 #else 178 lspec.type = PR_LibSpec_Pathname; 179 lspec.value.pathname = nativePath.get(); 180 #endif 181 #ifdef MOZ_WIDGET_ANDROID 182 PRLibrary* handle = PR_LoadLibraryWithFlags(lspec, PR_LD_NOW | PR_LD_GLOBAL); 183 #else 184 PRLibrary* handle = PR_LoadLibraryWithFlags(lspec, PR_LD_NOW | PR_LD_LOCAL); 185 #endif 186 if (!handle) { 187 PRErrorCode code = PR_GetError(); 188 const char* msg = PR_ErrorToString(code, PR_LANGUAGE_I_DEFAULT); 189 LOGE("Couldn't load onnxruntime shared library ({:x}: {})", PR_GetOSError(), 190 msg); 191 return nullptr; 192 } 193 194 using OrtApiBaseFn = const OrtApiBase* (*)(); 195 auto ortGetApiBaseFnPtr = 196 reinterpret_cast<OrtApiBaseFn>(PR_FindSymbol(handle, "OrtGetApiBase")); 197 if (!ortGetApiBaseFnPtr) { 198 LOGE("Couldn't fetch symbol OrgGetApiBase"); 199 PR_UnloadLibrary(handle); 200 return nullptr; 201 } 202 const OrtApiBase* apiBase = ortGetApiBaseFnPtr(); 203 OrtApi* ortAPI = const_cast<OrtApi*>(apiBase->GetApi(ORT_API_VERSION)); 204 if (!ortAPI) { 205 LOGE("Couldn't get ahold of the OrtApi pointer"); 206 PR_UnloadLibrary(handle); 207 return nullptr; 208 } 209 210 return ortAPI; 211 } 212 213 bool InferenceSession::InInferenceProcess(JSContext*, JSObject*) { 214 if (!ContentChild::GetSingleton()) { 215 return false; 216 } 217 return ContentChild::GetSingleton()->GetRemoteType().Equals( 218 INFERENCE_REMOTE_TYPE); 219 } 220 221 nsCString InferenceSessionSessionOptionsToString( 222 const InferenceSessionSessionOptions& aOptions) { 223 return nsFmtCString( 224 FMT_STRING("EnableCpuMemArena: {}, " 225 "EnableGraphCapture: {}, " 226 "EnableMemPattern: {}, " 227 "EnableProfiling: {}, " 228 "ExecutionMode: {}, " 229 "ExecutionProviders: {}, " 230 "Extra: {}, " 231 "FreeDimensionOverrides: {}, " 232 "GraphOptimizationLevel: {}, " 233 "InterOpNumThreads: {}, " 234 "IntraOpNumThreads: {}, " 235 "LogId: {}, " 236 "LogSeverityLevel: {}, " 237 "LogVerbosityLevel: {}, " 238 "OptimizedModelFilePath: {}, " 239 "PreferredOutputLocation: {}, " 240 "ProfileFilePrefix: {}"), 241 aOptions.mEnableCpuMemArena, aOptions.mEnableGraphCapture, 242 aOptions.mEnableMemPattern, aOptions.mEnableProfiling, 243 aOptions.mExecutionMode, 244 aOptions.mExecutionProviders.WasPassed() ? "<passed>" : "<not passed>", 245 aOptions.mExtra.WasPassed() ? "<passed>" : "<not passed>", 246 aOptions.mFreeDimensionOverrides.WasPassed() ? "<passed>" 247 : "<not passed>", 248 aOptions.mGraphOptimizationLevel, aOptions.mInterOpNumThreads, 249 aOptions.mIntraOpNumThreads, aOptions.mLogId, aOptions.mLogSeverityLevel, 250 aOptions.mLogVerbosityLevel, aOptions.mOptimizedModelFilePath, 251 aOptions.mPreferredOutputLocation.WasPassed() ? "<passed>" 252 : "<not passed>", 253 aOptions.mProfileFilePrefix); 254 } 255 256 OrtCustomThreadHandle WrapProfilerRegister(void* options, void (*func)(void*), 257 void* param) { 258 // We don't use options for now 259 MOZ_ASSERT(!options); 260 auto wrapperFunc = [func](void* param) { 261 PROFILER_REGISTER_THREAD("onnx_worker"); 262 LOGD("Starting thread"); 263 (static_cast<OrtThreadWorkerFn>(func))(param); 264 }; 265 266 auto* t = new std::thread(wrapperFunc, param); 267 268 return reinterpret_cast<OrtCustomThreadHandle>(t); 269 } 270 271 void WrapProfilerUnregister(OrtCustomThreadHandle thread) { 272 LOGD("Joining thread"); 273 std::thread* t = (std::thread*)thread; 274 t->join(); 275 delete t; 276 } 277 278 RefPtr<Promise> InferenceSession::Create( 279 GlobalObject& aGlobal, const UTF8StringOrUint8Array& aUriOrBuffer, 280 const InferenceSessionSessionOptions& aOptions, ErrorResult& aRv) { 281 LOGD("{}", __PRETTY_FUNCTION__); 282 nsCOMPtr<nsIGlobalObject> global = do_QueryInterface(aGlobal.GetAsSupports()); 283 RefPtr<Promise> p = Promise::Create(global, aRv); 284 RefPtr<InferenceSession> session = new InferenceSession(aGlobal); 285 session->Init(p, aUriOrBuffer, aOptions); 286 return p; 287 } 288 289 void InferenceSession::Init(const RefPtr<Promise>& aPromise, 290 const UTF8StringOrUint8Array& aUriOrBuffer, 291 const InferenceSessionSessionOptions& aOptions) { 292 LOGD("InferenceSession::Init called with a {}", 293 aUriOrBuffer.IsUTF8String() ? "string" : "buffer"); 294 295 if (!sEnv) { 296 sAPI = GetOrtAPI(); 297 if (!sAPI) { 298 LOGD("Couldn't get ahold of ORT API"); 299 aPromise->MaybeReject(NS_ERROR_FAILURE); 300 return; 301 } 302 OrtThreadingOptions* threadingOptions; 303 304 AutoOrtStatus status = sAPI->CreateThreadingOptions(&threadingOptions); 305 if (status) { 306 LOGD("CreateThreadingOptions error"); 307 aPromise->MaybeRejectWithUndefined(); 308 return; 309 } 310 status = sAPI->SetGlobalCustomCreateThreadFn(threadingOptions, 311 WrapProfilerRegister); 312 if (status) { 313 LOGD("SetGlobalCustomCreateThreadFn error"); 314 aPromise->MaybeRejectWithUndefined(); 315 return; 316 } 317 318 status = sAPI->SetGlobalCustomJoinThreadFn(threadingOptions, 319 WrapProfilerUnregister); 320 if (status) { 321 LOGD("SetGlobalCustomJoinThreadFn error"); 322 aPromise->MaybeRejectWithUndefined(); 323 return; 324 } 325 326 status = sAPI->SetGlobalInterOpNumThreads( 327 threadingOptions, AssertedCast<int>(aOptions.mInterOpNumThreads)); 328 if (status) { 329 LOGD("SetGlobalInterOpNumThreads error"); 330 aPromise->MaybeRejectWithUndefined(); 331 return; 332 } 333 334 status = sAPI->SetGlobalIntraOpNumThreads( 335 threadingOptions, AssertedCast<int>(aOptions.mIntraOpNumThreads)); 336 if (status) { 337 LOGD("SetGlobalIntraOpNumThreads error"); 338 aPromise->MaybeRejectWithUndefined(); 339 return; 340 } 341 342 status = sAPI->SetGlobalDenormalAsZero(threadingOptions); 343 if (status) { 344 LOGD("SetGlobalDenormalsAreZero error"); 345 aPromise->MaybeRejectWithUndefined(); 346 return; 347 } 348 349 status = sAPI->SetGlobalSpinControl(threadingOptions, 0); 350 if (status) { 351 LOGD("SetGlobalSpinControl error"); 352 aPromise->MaybeRejectWithUndefined(); 353 return; 354 } 355 356 status = sAPI->CreateEnvWithGlobalThreadPools( 357 ORT_LOGGING_LEVEL_FATAL, "my_env", threadingOptions, &sEnv); 358 if (status) { 359 LOGD("CreateEnv error: {}", status.Message()); 360 MOZ_CRASH("Init CreateEnv"); 361 } 362 LOGD("CreateEnv OK"); 363 } 364 365 mOptions = ToOrtSessionOption(aOptions); 366 AutoOrtStatus status = sAPI->DisablePerSessionThreads(mOptions); 367 if (status) { 368 LOGD("DisablePerSessionThreads error: {}", status.Message()); 369 } 370 371 OrtSession* session = nullptr; 372 if (aUriOrBuffer.IsUTF8String()) { 373 LOGE("Passing a URI to a model isn't implemented, pass the bytes directly"); 374 aPromise->MaybeRejectWithNotSupportedError("Not implemented"); 375 return; 376 } 377 aUriOrBuffer.GetAsUint8Array().ProcessFixedData( 378 [&](const Span<uint8_t>& aData) { 379 AUTO_PROFILER_MARKER_UNTYPED("CreateSessionFromArray", ML_SETUP, {}); 380 status = sAPI->CreateSessionFromArray( 381 sEnv, aData.data(), aData.Length(), mOptions, &session); 382 }); 383 if (status) { 384 LOGD("CreateSession error: {}", status.Message()); 385 MOZ_CRASH("CreateSession error"); 386 } 387 LOGD("Successfully created ONNX Runtime session."); 388 mSession = session; 389 aPromise->MaybeResolve(this); 390 } 391 392 nsCString FeedsToString( 393 const Record<nsCString, OwningNonNull<Tensor>>& aFeeds) { 394 nsCString rv; 395 for (const auto& input : aFeeds.Entries()) { 396 rv.AppendFmt("[{}: {}],", input.mKey, input.mValue->ToString().get()); 397 } 398 return rv; 399 } 400 401 already_AddRefed<Promise> InferenceSession::Run( 402 const Record<nsCString, OwningNonNull<Tensor>>& feeds, 403 const InferenceSessionRunOptions& options, ErrorResult& aRv) { 404 LOGD("{} {}", __PRETTY_FUNCTION__, fmt::ptr(this)); 405 RefPtr<Promise> p = Promise::Create(GetParentObject(), aRv); 406 407 if (!mSession) { 408 LOGD("runInference: session pointer is null."); 409 } 410 if (!sAPI || !sEnv) { 411 LOGD("Need API {} and Env {} here", fmt::ptr(sAPI), fmt::ptr(sEnv)); 412 MOZ_CRASH("In run"); 413 p->MaybeReject(NS_ERROR_UNEXPECTED); 414 return p.forget(); 415 } 416 417 OrtMemoryInfo* memoryInfo = nullptr; 418 auto guard = MakeScopeExit([&] { sAPI->ReleaseMemoryInfo(memoryInfo); }); 419 AutoOrtStatus status = sAPI->CreateCpuMemoryInfo( 420 OrtArenaAllocator, OrtMemTypeDefault, &memoryInfo); 421 if (status) { 422 LOGD("CreateCpuMemoryInfo failed: {}", status.Message()); 423 p->MaybeReject(NS_ERROR_UNEXPECTED); 424 return p.forget(); 425 } 426 427 LOGD("Inputs:"); 428 nsTArray<OrtValue*> inputValues; 429 auto scope = MakeScopeExit([&] { 430 for (auto& v : inputValues) { 431 sAPI->ReleaseValue(v); 432 } 433 }); 434 for (const auto& input : feeds.Entries()) { 435 OrtValue* inputOrt = nullptr; 436 const auto& val = input.mValue; 437 AutoTArray<int64_t, 16> dims64; 438 for (uint32_t i = 0; i < val->DimsSize(); i++) { 439 dims64.AppendElement(val->Dims()[i]); 440 } 441 LOGD("{}: {}", input.mKey.get(), val->ToString().get()); 442 AUTO_PROFILER_MARKER_FMT("CreateTensorWithDataAsOrtValue", ML_INFERENCE, {}, 443 "{}", input.mKey.get()); 444 status = sAPI->CreateTensorWithDataAsOrtValue( 445 memoryInfo, val->Data(), val->Size(), dims64.Elements(), 446 val->DimsSize(), val->Type(), &inputOrt); 447 if (status) { 448 LOGD("CreateTensorWithDataAsOrtValue for input_ids {} failed: {}", 449 input.mKey, status.Message()); 450 p->MaybeReject(NS_ERROR_UNEXPECTED); 451 return p.forget(); 452 } 453 454 inputValues.AppendElement(inputOrt); 455 } 456 457 nsTArray<nsCString> inputNames; 458 nsTArray<const char*> inputNamesPtrs; 459 GetNames(inputNames, NameDirection::Input); 460 for (const auto& name : inputNames) { 461 inputNamesPtrs.AppendElement(name.get()); 462 } 463 nsTArray<nsCString> outputNames; 464 nsTArray<const char*> outputNamesPtrs; 465 GetNames(outputNames, NameDirection::Output); 466 LOGD("Outputs names:"); 467 for (const auto& name : outputNames) { 468 LOGD("- {}", name.get()); 469 outputNamesPtrs.AppendElement(name.get()); 470 } 471 nsTArray<OrtValue*> outputs; 472 outputs.SetLength(outputNames.Length()); 473 for (uint32_t i = 0; i < outputNames.Length(); i++) { 474 outputs[i] = nullptr; 475 } 476 OrtValue** ptr = outputs.Elements(); 477 478 { 479 AUTO_PROFILER_MARKER_UNTYPED("Ort::Run", ML_INFERENCE, {}); 480 status = sAPI->Run(mSession, 481 nullptr, // Run options 482 inputNamesPtrs.Elements(), inputValues.Elements(), 483 inputNamesPtrs.Length(), outputNamesPtrs.Elements(), 484 outputNamesPtrs.Length(), ptr); 485 } 486 if (status) { 487 LOGD("Session Run failed: {}", status.Message()); 488 p->MaybeReject(NS_ERROR_UNEXPECTED); 489 return p.forget(); 490 } 491 492 Record<nsCString, OwningNonNull<Tensor>> rv; 493 for (size_t i = 0; i < outputs.Length(); i++) { 494 TimeStamp start = TimeStamp::Now(); 495 // outputData has the same lifetime as output[i]. For now, the actual data 496 // is copied into the Tensor object below. This copy will be removed in the 497 // future. 498 uint8_t* outputData = nullptr; 499 status = sAPI->GetTensorMutableData(outputs[i], (void**)&outputData); 500 if (status) { 501 LOGD("GetTensorMutableData failed: {}", status.Message()); 502 p->MaybeReject(NS_ERROR_UNEXPECTED); 503 return p.forget(); 504 } 505 506 OrtTypeInfo* typeInfo; 507 status = sAPI->SessionGetOutputTypeInfo(mSession, i, &typeInfo); 508 if (status) { 509 LOGD("GetOutputTypeInfo failed: {}", status.Message()); 510 p->MaybeReject(NS_ERROR_UNEXPECTED); 511 return p.forget(); 512 } 513 514 OrtTensorTypeAndShapeInfo* typeAndShapeInfo; 515 status = sAPI->GetTensorTypeAndShape(outputs[i], &typeAndShapeInfo); 516 if (status) { 517 LOGD("GetTensorTypeAndShape failed: {}", status.Message()); 518 p->MaybeReject(NS_ERROR_UNEXPECTED); 519 return p.forget(); 520 } 521 522 ONNXType type; 523 status = sAPI->GetOnnxTypeFromTypeInfo(typeInfo, &type); 524 if (status) { 525 LOGD("GetOnnxTypeFromTypeInfo failed: {}", status.Message()); 526 p->MaybeReject(NS_ERROR_UNEXPECTED); 527 return p.forget(); 528 } 529 MOZ_ASSERT(type == ONNX_TYPE_TENSOR); 530 531 ONNXTensorElementDataType outputTensorType; 532 status = sAPI->GetTensorElementType(typeAndShapeInfo, &outputTensorType); 533 if (status) { 534 LOGD("GetTensorElementType failed: {}", status.Message()); 535 p->MaybeReject(NS_ERROR_UNEXPECTED); 536 return p.forget(); 537 } 538 539 size_t dimCount; 540 status = sAPI->GetDimensionsCount(typeAndShapeInfo, &dimCount); 541 if (status) { 542 LOGD("GetDimensionsCount failed: {}", status.Message()); 543 p->MaybeReject(NS_ERROR_UNEXPECTED); 544 return p.forget(); 545 } 546 547 AutoTArray<int64_t, 16> dims; 548 dims.SetLength(dimCount); 549 status = sAPI->GetDimensions(typeAndShapeInfo, dims.Elements(), dimCount); 550 551 size_t outputSize = 1; 552 for (size_t d = 0; d < dimCount; ++d) { 553 outputSize *= dims[d]; 554 } 555 556 // TODO skip this copy by using CreateTensorWithDataAsOrtValue 557 nsTArray<uint8_t> output; 558 output.AppendElements(outputData, 559 outputSize * Tensor::DataTypeSize(outputTensorType)); 560 GlobalObject global(mCtx, GetParentObject()->GetGlobalJSObject()); 561 auto outputTensor = MakeRefPtr<Tensor>(global, outputTensorType, 562 std::move(output), std::move(dims)); 563 AUTO_PROFILER_MARKER_FMT( 564 "Output tensor", ML_INFERENCE, 565 MarkerOptions(MarkerTiming::IntervalUntilNowFrom(start)), "{}: {}", 566 outputNames[i], outputTensor->ToString().get()); 567 568 sAPI->ReleaseTensorTypeAndShapeInfo(typeAndShapeInfo); 569 570 auto elem = rv.Entries().AppendElement(); 571 elem->mKey = outputNames[i]; 572 elem->mValue = outputTensor; 573 } 574 575 p->MaybeResolve(rv); 576 577 return p.forget(); 578 } 579 580 void InferenceSession::Destroy() { 581 LOGD("{} {}", __PRETTY_FUNCTION__, fmt::ptr(this)); 582 if (mSession) { 583 sAPI->ReleaseSession(mSession); 584 } 585 if (mOptions) { 586 sAPI->ReleaseSessionOptions(mOptions); 587 } 588 } 589 590 already_AddRefed<Promise> InferenceSession::ReleaseSession() { 591 LOGD("{} {}", __PRETTY_FUNCTION__, fmt::ptr(this)); 592 593 Destroy(); 594 RefPtr<Promise> p = Promise::CreateInfallible(mGlobal); 595 p->MaybeResolveWithUndefined(); 596 return p.forget(); 597 } 598 599 void InferenceSession::StartProfiling() { 600 LOGD("{} {}", __PRETTY_FUNCTION__, fmt::ptr(this)); 601 } 602 603 void InferenceSession::EndProfiling() { 604 LOGD("{} {}", __PRETTY_FUNCTION__, fmt::ptr(this)); 605 } 606 607 void InferenceSession::GetNames(nsTArray<nsCString>& aRetVal, 608 NameDirection aDirection) const { 609 const char* NameDirection2String[2] = {"Input", "Output"}; 610 611 if (!mSession) { 612 return; 613 } 614 size_t nameCount = 0; 615 AutoOrtStatus status; 616 if (aDirection == NameDirection::Input) { 617 status = sAPI->SessionGetInputCount(mSession, &nameCount); 618 } else { 619 status = sAPI->SessionGetOutputCount(mSession, &nameCount); 620 } 621 if (status) { 622 LOGD("SessionGet{}Count failed: ", 623 NameDirection2String[static_cast<int>(aDirection)], status.Message()); 624 return; 625 } 626 627 OrtAllocator* allocator = nullptr; 628 status = sAPI->GetAllocatorWithDefaultOptions(&allocator); 629 if (status) { 630 LOGD("GetAllocatorWithDefaultOptions failed: {}", status.Message()); 631 return; 632 } 633 aRetVal.SetCapacity(nameCount); 634 for (size_t i = 0; i < nameCount; i++) { 635 // Allocated by onnxruntiem, must be freed by AllocatorFree 636 char* name = nullptr; 637 638 if (aDirection == NameDirection::Input) { 639 status = sAPI->SessionGetInputName(mSession, i, allocator, &name); 640 } else { 641 status = sAPI->SessionGetOutputName(mSession, i, allocator, &name); 642 } 643 if (status) { 644 LOGD("SessionGet{}Name failed: ", 645 NameDirection2String[static_cast<int>(aDirection)], 646 status.Message()); 647 continue; 648 } 649 aRetVal.AppendElement(name); 650 status = sAPI->AllocatorFree(allocator, name); 651 if (status) { 652 LOGD("AllocatorFree failed: ", status.Message()); 653 continue; 654 } 655 } 656 } 657 658 void InferenceSession::GetInputNames(nsTArray<nsCString>& aRetVal) const { 659 LOGD("{} {}", __PRETTY_FUNCTION__, fmt::ptr(this)); 660 GetNames(aRetVal, NameDirection::Input); 661 if (MOZ_LOG_TEST(gONNXLog, LogLevel::Debug)) { 662 for (auto& name : aRetVal) { 663 LOGD("- {}", name); 664 } 665 } 666 } 667 668 void InferenceSession::GetOutputNames(nsTArray<nsCString>& aRetVal) const { 669 LOGD("{} {}", __PRETTY_FUNCTION__, fmt::ptr(this)); 670 GetNames(aRetVal, NameDirection::Output); 671 if (MOZ_LOG_TEST(gONNXLog, LogLevel::Debug)) { 672 for (auto& name : aRetVal) { 673 LOGD("- {}", name); 674 } 675 } 676 } 677 678 JSObject* InferenceSession::WrapObject(JSContext* aCx, 679 JS::Handle<JSObject*> aGivenProto) { 680 return InferenceSession_Binding::Wrap(aCx, this, aGivenProto); 681 } 682 683 } // namespace mozilla::dom