tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

InferenceSession.h (3089B)


      1 /* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
      2 /* vim:set ts=2 sw=2 sts=2 et cindent: */
      3 /* This Source Code Form is subject to the terms of the Mozilla Public
      4 * License, v. 2.0. If a copy of the MPL was not distributed with this
      5 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
      6 
      7 #ifndef DOM_INFERENCESESSION_H_
      8 #define DOM_INFERENCESESSION_H_
      9 
     10 #include "js/TypeDecls.h"
     11 #include "mozilla/AlreadyAddRefed.h"
     12 #include "mozilla/ErrorResult.h"
     13 #include "mozilla/dom/BindingDeclarations.h"
     14 #include "mozilla/dom/BindingUtils.h"
     15 #include "mozilla/dom/IOUtilsBinding.h"
     16 #include "mozilla/dom/ONNXBinding.h"
     17 #include "mozilla/dom/Record.h"
     18 #include "mozilla/dom/onnxruntime_c_api.h"
     19 #include "nsCycleCollectionParticipant.h"
     20 #include "nsIGlobalObject.h"
     21 #include "nsISupports.h"
     22 #include "nsWrapperCache.h"
     23 
     24 namespace mozilla::dom {
     25 OrtApi* GetOrtAPI();
     26 struct InferenceSessionRunOptions;
     27 class Promise;
     28 class Tensor;
     29 
     30 class InferenceSession final : public nsISupports, public nsWrapperCache {
     31 public:
     32  explicit InferenceSession(GlobalObject& aGlobal) {
     33    nsCOMPtr<nsIGlobalObject> global =
     34        do_QueryInterface(aGlobal.GetAsSupports());
     35    mGlobal = global;
     36    mCtx = aGlobal.Context();
     37  }
     38 
     39  static bool InInferenceProcess(JSContext*, JSObject*);
     40 
     41 protected:
     42  virtual ~InferenceSession() { Destroy(); }
     43 
     44 public:
     45  NS_DECL_CYCLE_COLLECTING_ISUPPORTS;
     46  NS_DECL_CYCLE_COLLECTION_WRAPPERCACHE_CLASS(InferenceSession);
     47 
     48  static RefPtr<Promise> Create(GlobalObject& aGlobal,
     49                                const UTF8StringOrUint8Array& aUriOrBuffer,
     50                                const InferenceSessionSessionOptions& aOptions,
     51                                ErrorResult& aRv);
     52 
     53  void Init(const RefPtr<Promise>& aPromise,
     54            const UTF8StringOrUint8Array& aUriOrBuffer,
     55            const InferenceSessionSessionOptions& aOptions);
     56 
     57  nsIGlobalObject* GetParentObject() const { return mGlobal; };
     58 
     59  JSObject* WrapObject(JSContext* aCx,
     60                       JS::Handle<JSObject*> aGivenProto) override;
     61 
     62  // Return a raw pointer here to avoid refcounting, but make sure it's safe
     63  // (the object should be kept alive by the callee).
     64  already_AddRefed<Promise> Run(
     65      const Record<nsCString, OwningNonNull<Tensor>>& feeds,
     66      const InferenceSessionRunOptions& options, ErrorResult& aRv);
     67 
     68  void Destroy();
     69 
     70  // This implements "release()" in the JS API but needs to be renamed to
     71  // avoid collliding with our AddRef/Release methods.
     72  already_AddRefed<Promise> ReleaseSession();
     73 
     74  void StartProfiling();
     75 
     76  void EndProfiling();
     77 
     78  void GetInputNames(nsTArray<nsCString>& aRetVal) const;
     79 
     80  void GetOutputNames(nsTArray<nsCString>& aRetVal) const;
     81 
     82 protected:
     83  enum class NameDirection { Input, Output };
     84  void GetNames(nsTArray<nsCString>& aRetVal,
     85                NameDirection aDirectionInput) const;
     86  nsCOMPtr<nsIGlobalObject> mGlobal;
     87  JSContext* mCtx;
     88  OrtSessionOptions* mOptions = nullptr;
     89  OrtSession* mSession = nullptr;
     90 };
     91 
     92 }  // namespace mozilla::dom
     93 
     94 #endif  // DOM_INFERENCESESSION_H_