InferenceSession.h (3089B)
1 /* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ 2 /* vim:set ts=2 sw=2 sts=2 et cindent: */ 3 /* This Source Code Form is subject to the terms of the Mozilla Public 4 * License, v. 2.0. If a copy of the MPL was not distributed with this 5 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ 6 7 #ifndef DOM_INFERENCESESSION_H_ 8 #define DOM_INFERENCESESSION_H_ 9 10 #include "js/TypeDecls.h" 11 #include "mozilla/AlreadyAddRefed.h" 12 #include "mozilla/ErrorResult.h" 13 #include "mozilla/dom/BindingDeclarations.h" 14 #include "mozilla/dom/BindingUtils.h" 15 #include "mozilla/dom/IOUtilsBinding.h" 16 #include "mozilla/dom/ONNXBinding.h" 17 #include "mozilla/dom/Record.h" 18 #include "mozilla/dom/onnxruntime_c_api.h" 19 #include "nsCycleCollectionParticipant.h" 20 #include "nsIGlobalObject.h" 21 #include "nsISupports.h" 22 #include "nsWrapperCache.h" 23 24 namespace mozilla::dom { 25 OrtApi* GetOrtAPI(); 26 struct InferenceSessionRunOptions; 27 class Promise; 28 class Tensor; 29 30 class InferenceSession final : public nsISupports, public nsWrapperCache { 31 public: 32 explicit InferenceSession(GlobalObject& aGlobal) { 33 nsCOMPtr<nsIGlobalObject> global = 34 do_QueryInterface(aGlobal.GetAsSupports()); 35 mGlobal = global; 36 mCtx = aGlobal.Context(); 37 } 38 39 static bool InInferenceProcess(JSContext*, JSObject*); 40 41 protected: 42 virtual ~InferenceSession() { Destroy(); } 43 44 public: 45 NS_DECL_CYCLE_COLLECTING_ISUPPORTS; 46 NS_DECL_CYCLE_COLLECTION_WRAPPERCACHE_CLASS(InferenceSession); 47 48 static RefPtr<Promise> Create(GlobalObject& aGlobal, 49 const UTF8StringOrUint8Array& aUriOrBuffer, 50 const InferenceSessionSessionOptions& aOptions, 51 ErrorResult& aRv); 52 53 void Init(const RefPtr<Promise>& aPromise, 54 const UTF8StringOrUint8Array& aUriOrBuffer, 55 const InferenceSessionSessionOptions& aOptions); 56 57 nsIGlobalObject* GetParentObject() const { return mGlobal; }; 58 59 JSObject* WrapObject(JSContext* aCx, 60 JS::Handle<JSObject*> aGivenProto) override; 61 62 // Return a raw pointer here to avoid refcounting, but make sure it's safe 63 // (the object should be kept alive by the callee). 64 already_AddRefed<Promise> Run( 65 const Record<nsCString, OwningNonNull<Tensor>>& feeds, 66 const InferenceSessionRunOptions& options, ErrorResult& aRv); 67 68 void Destroy(); 69 70 // This implements "release()" in the JS API but needs to be renamed to 71 // avoid collliding with our AddRef/Release methods. 72 already_AddRefed<Promise> ReleaseSession(); 73 74 void StartProfiling(); 75 76 void EndProfiling(); 77 78 void GetInputNames(nsTArray<nsCString>& aRetVal) const; 79 80 void GetOutputNames(nsTArray<nsCString>& aRetVal) const; 81 82 protected: 83 enum class NameDirection { Input, Output }; 84 void GetNames(nsTArray<nsCString>& aRetVal, 85 NameDirection aDirectionInput) const; 86 nsCOMPtr<nsIGlobalObject> mGlobal; 87 JSContext* mCtx; 88 OrtSessionOptions* mOptions = nullptr; 89 OrtSession* mSession = nullptr; 90 }; 91 92 } // namespace mozilla::dom 93 94 #endif // DOM_INFERENCESESSION_H_