tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

Tensor.h (2958B)


      1 /* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
      2 /* vim:set ts=2 sw=2 sts=2 et cindent: */
      3 /* This Source Code Form is subject to the terms of the Mozilla Public
      4 * License, v. 2.0. If a copy of the MPL was not distributed with this
      5 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
      6 
      7 #ifndef DOM_TENSOR_H_
      8 #define DOM_TENSOR_H_
      9 
     10 #include "js/TypeDecls.h"
     11 #include "mozilla/ErrorResult.h"
     12 #include "mozilla/dom/BindingDeclarations.h"
     13 #include "mozilla/dom/ONNXBinding.h"
     14 #include "mozilla/dom/onnxruntime_c_api.h"
     15 #include "nsCycleCollectionParticipant.h"
     16 #include "nsIGlobalObject.h"
     17 #include "nsWrapperCache.h"
     18 
     19 namespace mozilla::dom {
     20 class Promise;
     21 
     22 class Tensor final : public nsISupports, public nsWrapperCache {
     23 public:
     24  NS_DECL_CYCLE_COLLECTING_ISUPPORTS
     25  NS_DECL_CYCLE_COLLECTION_WRAPPERCACHE_CLASS(Tensor)
     26 
     27 public:
     28  // Used when created from js using a regular js array, containing numbers.
     29  Tensor(const GlobalObject& global, const nsACString& type,
     30         const nsTArray<uint8_t>& data, const Sequence<int32_t>& dims);
     31  // Used when created from JS, e.g. input tensor, with a type array (it can be
     32  // of any type)
     33  Tensor(const GlobalObject& global, const nsACString& type,
     34         const ArrayBufferView& data, const Sequence<int32_t>& dims);
     35  // Used when created from C++, e.g. output tensor
     36  Tensor(const GlobalObject& aGlobal, ONNXTensorElementDataType aType,
     37         nsTArray<uint8_t> aData, nsTArray<int64_t> aDims);
     38  static already_AddRefed<Tensor> Constructor(
     39      const GlobalObject& global, const nsACString& type,
     40      const ArrayBufferViewOrAnySequence& data, const Sequence<int32_t>& dims,
     41      ErrorResult& aRv);
     42 
     43 protected:
     44  ~Tensor() = default;
     45 
     46 public:
     47  nsIGlobalObject* GetParentObject() const { return mGlobal; };
     48  JSObject* WrapObject(JSContext* aCx,
     49                       JS::Handle<JSObject*> aGivenProto) override;
     50  void GetDims(nsTArray<int32_t>& aRetVal);
     51  void SetDims(const nsTArray<int32_t>& aVal);
     52  void GetType(nsCString& aRetVal) const;
     53  void GetData(JSContext* cx, JS::MutableHandle<JSObject*> aRetVal) const;
     54  TensorDataLocation Location() const;
     55  already_AddRefed<Promise> GetData(const Optional<bool>& releaseData);
     56 
     57  void Dispose();
     58  uint8_t* Data() { return mData.Elements(); }
     59  size_t Size() { return mData.Length(); }
     60  int32_t* Dims() { return mDims.Elements(); }
     61  size_t DimsSize() { return mDims.Length(); }
     62 
     63  ONNXTensorElementDataType Type() const;
     64  nsCString TypeString() const;
     65  nsLiteralCString ONNXTypeToString(ONNXTensorElementDataType aType) const;
     66  nsCString ToString() const;
     67  static ONNXTensorElementDataType StringToONNXDataType(
     68      const nsACString& aString);
     69  static size_t DataTypeSize(ONNXTensorElementDataType aType);
     70 
     71 private:
     72  nsCOMPtr<nsIGlobalObject> mGlobal;
     73  nsCString mType;
     74  nsTArray<uint8_t> mData;
     75  nsTArray<int32_t> mDims;
     76 };
     77 
     78 }  // namespace mozilla::dom
     79 
     80 #endif  // DOM_TENSOR_H_