Tensor.h (2958B)
1 /* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ 2 /* vim:set ts=2 sw=2 sts=2 et cindent: */ 3 /* This Source Code Form is subject to the terms of the Mozilla Public 4 * License, v. 2.0. If a copy of the MPL was not distributed with this 5 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ 6 7 #ifndef DOM_TENSOR_H_ 8 #define DOM_TENSOR_H_ 9 10 #include "js/TypeDecls.h" 11 #include "mozilla/ErrorResult.h" 12 #include "mozilla/dom/BindingDeclarations.h" 13 #include "mozilla/dom/ONNXBinding.h" 14 #include "mozilla/dom/onnxruntime_c_api.h" 15 #include "nsCycleCollectionParticipant.h" 16 #include "nsIGlobalObject.h" 17 #include "nsWrapperCache.h" 18 19 namespace mozilla::dom { 20 class Promise; 21 22 class Tensor final : public nsISupports, public nsWrapperCache { 23 public: 24 NS_DECL_CYCLE_COLLECTING_ISUPPORTS 25 NS_DECL_CYCLE_COLLECTION_WRAPPERCACHE_CLASS(Tensor) 26 27 public: 28 // Used when created from js using a regular js array, containing numbers. 29 Tensor(const GlobalObject& global, const nsACString& type, 30 const nsTArray<uint8_t>& data, const Sequence<int32_t>& dims); 31 // Used when created from JS, e.g. input tensor, with a type array (it can be 32 // of any type) 33 Tensor(const GlobalObject& global, const nsACString& type, 34 const ArrayBufferView& data, const Sequence<int32_t>& dims); 35 // Used when created from C++, e.g. output tensor 36 Tensor(const GlobalObject& aGlobal, ONNXTensorElementDataType aType, 37 nsTArray<uint8_t> aData, nsTArray<int64_t> aDims); 38 static already_AddRefed<Tensor> Constructor( 39 const GlobalObject& global, const nsACString& type, 40 const ArrayBufferViewOrAnySequence& data, const Sequence<int32_t>& dims, 41 ErrorResult& aRv); 42 43 protected: 44 ~Tensor() = default; 45 46 public: 47 nsIGlobalObject* GetParentObject() const { return mGlobal; }; 48 JSObject* WrapObject(JSContext* aCx, 49 JS::Handle<JSObject*> aGivenProto) override; 50 void GetDims(nsTArray<int32_t>& aRetVal); 51 void SetDims(const nsTArray<int32_t>& aVal); 52 void GetType(nsCString& aRetVal) const; 53 void GetData(JSContext* cx, JS::MutableHandle<JSObject*> aRetVal) const; 54 TensorDataLocation Location() const; 55 already_AddRefed<Promise> GetData(const Optional<bool>& releaseData); 56 57 void Dispose(); 58 uint8_t* Data() { return mData.Elements(); } 59 size_t Size() { return mData.Length(); } 60 int32_t* Dims() { return mDims.Elements(); } 61 size_t DimsSize() { return mDims.Length(); } 62 63 ONNXTensorElementDataType Type() const; 64 nsCString TypeString() const; 65 nsLiteralCString ONNXTypeToString(ONNXTensorElementDataType aType) const; 66 nsCString ToString() const; 67 static ONNXTensorElementDataType StringToONNXDataType( 68 const nsACString& aString); 69 static size_t DataTypeSize(ONNXTensorElementDataType aType); 70 71 private: 72 nsCOMPtr<nsIGlobalObject> mGlobal; 73 nsCString mType; 74 nsTArray<uint8_t> mData; 75 nsTArray<int32_t> mDims; 76 }; 77 78 } // namespace mozilla::dom 79 80 #endif // DOM_TENSOR_H_