tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

lstmbe.cpp (26003B)


      1 // © 2021 and later: Unicode, Inc. and others.
      2 // License & terms of use: http://www.unicode.org/copyright.html
      3 
      4 #include <complex>
      5 #include <utility>
      6 
      7 #include "unicode/utypes.h"
      8 
      9 #if !UCONFIG_NO_BREAK_ITERATION
     10 
     11 #include "brkeng.h"
     12 #include "charstr.h"
     13 #include "cmemory.h"
     14 #include "lstmbe.h"
     15 #include "putilimp.h"
     16 #include "uassert.h"
     17 #include "ubrkimpl.h"
     18 #include "uresimp.h"
     19 #include "uvectr32.h"
     20 #include "uvector.h"
     21 
     22 #include "unicode/brkiter.h"
     23 #include "unicode/resbund.h"
     24 #include "unicode/ubrk.h"
     25 #include "unicode/uniset.h"
     26 #include "unicode/ustring.h"
     27 #include "unicode/utf.h"
     28 
     29 U_NAMESPACE_BEGIN
     30 
     31 // Uncomment the following #define to debug.
     32 // #define LSTM_DEBUG 1
     33 // #define LSTM_VECTORIZER_DEBUG 1
     34 
     35 /**
     36 * Interface for reading 1D array.
     37 */
     38 class ReadArray1D {
     39 public:
     40    virtual ~ReadArray1D();
     41    virtual int32_t d1() const = 0;
     42    virtual float get(int32_t i) const = 0;
     43 
     44 #ifdef LSTM_DEBUG
     45    void print() const {
     46        printf("\n[");
     47        for (int32_t i = 0; i < d1(); i++) {
     48           printf("%0.8e ", get(i));
     49           if (i % 4 == 3) printf("\n");
     50        }
     51        printf("]\n");
     52    }
     53 #endif
     54 };
     55 
     56 ReadArray1D::~ReadArray1D()
     57 {
     58 }
     59 
     60 /**
     61 * Interface for reading 2D array.
     62 */
     63 class ReadArray2D {
     64 public:
     65    virtual ~ReadArray2D();
     66    virtual int32_t d1() const = 0;
     67    virtual int32_t d2() const = 0;
     68    virtual float get(int32_t i, int32_t j) const = 0;
     69 };
     70 
     71 ReadArray2D::~ReadArray2D()
     72 {
     73 }
     74 
     75 /**
     76 * A class to index a float array as a 1D Array without owning the pointer or
     77 * copy the data.
     78 */
     79 class ConstArray1D : public ReadArray1D {
     80 public:
     81    ConstArray1D() : data_(nullptr), d1_(0) {}
     82 
     83    ConstArray1D(const float* data, int32_t d1) : data_(data), d1_(d1) {}
     84 
     85    virtual ~ConstArray1D();
     86 
     87    // Init the object, the object does not own the data nor copy.
     88    // It is designed to directly use data from memory mapped resources.
     89    void init(const int32_t* data, int32_t d1) {
     90        U_ASSERT(IEEE_754 == 1);
     91        data_ = reinterpret_cast<const float*>(data);
     92        d1_ = d1;
     93    }
     94 
     95    // ReadArray1D methods.
     96    virtual int32_t d1() const override { return d1_; }
     97    virtual float get(int32_t i) const override {
     98        U_ASSERT(i < d1_);
     99        return data_[i];
    100    }
    101 
    102 private:
    103    const float* data_;
    104    int32_t d1_;
    105 };
    106 
    107 ConstArray1D::~ConstArray1D()
    108 {
    109 }
    110 
    111 /**
    112 * A class to index a float array as a 2D Array without owning the pointer or
    113 * copy the data.
    114 */
    115 class ConstArray2D : public ReadArray2D {
    116 public:
    117    ConstArray2D() : data_(nullptr), d1_(0), d2_(0) {}
    118 
    119    ConstArray2D(const float* data, int32_t d1, int32_t d2)
    120        : data_(data), d1_(d1), d2_(d2) {}
    121 
    122    virtual ~ConstArray2D();
    123 
    124    // Init the object, the object does not own the data nor copy.
    125    // It is designed to directly use data from memory mapped resources.
    126    void init(const int32_t* data, int32_t d1, int32_t d2) {
    127        U_ASSERT(IEEE_754 == 1);
    128        data_ = reinterpret_cast<const float*>(data);
    129        d1_ = d1;
    130        d2_ = d2;
    131    }
    132 
    133    // ReadArray2D methods.
    134    inline int32_t d1() const override { return d1_; }
    135    inline int32_t d2() const override { return d2_; }
    136    float get(int32_t i, int32_t j) const override {
    137        U_ASSERT(i < d1_);
    138        U_ASSERT(j < d2_);
    139        return data_[i * d2_ + j];
    140    }
    141 
    142    // Expose the ith row as a ConstArray1D
    143    inline ConstArray1D row(int32_t i) const {
    144        U_ASSERT(i < d1_);
    145        return ConstArray1D(data_ + i * d2_, d2_);
    146    }
    147 
    148 private:
    149    const float* data_;
    150    int32_t d1_;
    151    int32_t d2_;
    152 };
    153 
    154 ConstArray2D::~ConstArray2D()
    155 {
    156 }
    157 
    158 /**
    159 * A class to allocate data as a writable 1D array.
    160 * This is the main class implement matrix operation.
    161 */
    162 class Array1D : public ReadArray1D {
    163 public:
    164    Array1D() : memory_(nullptr), data_(nullptr), d1_(0) {}
    165    Array1D(int32_t d1, UErrorCode &status)
    166        : memory_(uprv_malloc(d1 * sizeof(float))),
    167          data_(static_cast<float*>(memory_)), d1_(d1) {
    168        if (U_SUCCESS(status)) {
    169            if (memory_ == nullptr) {
    170                status = U_MEMORY_ALLOCATION_ERROR;
    171                return;
    172            }
    173            clear();
    174        }
    175    }
    176 
    177    virtual ~Array1D();
    178 
    179    // A special constructor which does not own the memory but writeable
    180    // as a slice of an array.
    181    Array1D(float* data, int32_t d1)
    182        : memory_(nullptr), data_(data), d1_(d1) {}
    183 
    184    // ReadArray1D methods.
    185    virtual int32_t d1() const override { return d1_; }
    186    virtual float get(int32_t i) const override {
    187        U_ASSERT(i < d1_);
    188        return data_[i];
    189    }
    190 
    191    // Return the index which point to the max data in the array.
    192    inline int32_t maxIndex() const {
    193        int32_t index = 0;
    194        float max = data_[0];
    195        for (int32_t i = 1; i < d1_; i++) {
    196            if (data_[i] > max) {
    197                max = data_[i];
    198                index = i;
    199            }
    200        }
    201        return index;
    202    }
    203 
    204    // Slice part of the array to a new one.
    205    inline Array1D slice(int32_t from, int32_t size) const {
    206        U_ASSERT(from >= 0);
    207        U_ASSERT(from < d1_);
    208        U_ASSERT(from + size <= d1_);
    209        return Array1D(data_ + from, size);
    210    }
    211 
    212    // Add dot product of a 1D array and a 2D array into this one.
    213    inline Array1D& addDotProduct(const ReadArray1D& a, const ReadArray2D& b) {
    214        U_ASSERT(a.d1() == b.d1());
    215        U_ASSERT(b.d2() == d1());
    216        for (int32_t i = 0; i < d1(); i++) {
    217            for (int32_t j = 0; j < a.d1(); j++) {
    218                data_[i] += a.get(j) * b.get(j, i);
    219            }
    220        }
    221        return *this;
    222    }
    223 
    224    // Hadamard Product the values of another array of the same size into this one.
    225    inline Array1D& hadamardProduct(const ReadArray1D& a) {
    226        U_ASSERT(a.d1() == d1());
    227        for (int32_t i = 0; i < d1(); i++) {
    228            data_[i] *= a.get(i);
    229        }
    230        return *this;
    231    }
    232 
    233    // Add the Hadamard Product of two arrays of the same size into this one.
    234    inline Array1D& addHadamardProduct(const ReadArray1D& a, const ReadArray1D& b) {
    235        U_ASSERT(a.d1() == d1());
    236        U_ASSERT(b.d1() == d1());
    237        for (int32_t i = 0; i < d1(); i++) {
    238            data_[i] += a.get(i) * b.get(i);
    239        }
    240        return *this;
    241    }
    242 
    243    // Add the values of another array of the same size into this one.
    244    inline Array1D& add(const ReadArray1D& a) {
    245        U_ASSERT(a.d1() == d1());
    246        for (int32_t i = 0; i < d1(); i++) {
    247            data_[i] += a.get(i);
    248        }
    249        return *this;
    250    }
    251 
    252    // Assign the values of another array of the same size into this one.
    253    inline Array1D& assign(const ReadArray1D& a) {
    254        U_ASSERT(a.d1() == d1());
    255        for (int32_t i = 0; i < d1(); i++) {
    256            data_[i] = a.get(i);
    257        }
    258        return *this;
    259    }
    260 
    261    // Apply tanh to all the elements in the array.
    262    inline Array1D& tanh() {
    263        return tanh(*this);
    264    }
    265 
    266    // Apply tanh of a and store into this array.
    267    inline Array1D& tanh(const Array1D& a) {
    268        U_ASSERT(a.d1() == d1());
    269        for (int32_t i = 0; i < d1_; i++) {
    270            data_[i] = std::tanh(a.get(i));
    271        }
    272        return *this;
    273    }
    274 
    275    // Apply sigmoid to all the elements in the array.
    276    inline Array1D& sigmoid() {
    277        for (int32_t i = 0; i < d1_; i++) {
    278            data_[i] = 1.0f/(1.0f + expf(-data_[i]));
    279        }
    280        return *this;
    281    }
    282 
    283    inline Array1D& clear() {
    284        uprv_memset(data_, 0, d1_ * sizeof(float));
    285        return *this;
    286    }
    287 
    288 private:
    289    void* memory_;
    290    float* data_;
    291    int32_t d1_;
    292 };
    293 
    294 Array1D::~Array1D()
    295 {
    296    uprv_free(memory_);
    297 }
    298 
    299 class Array2D : public ReadArray2D {
    300 public:
    301    Array2D() : memory_(nullptr), data_(nullptr), d1_(0), d2_(0) {}
    302    Array2D(int32_t d1, int32_t d2, UErrorCode &status)
    303        : memory_(uprv_malloc(d1 * d2 * sizeof(float))),
    304          data_(static_cast<float*>(memory_)), d1_(d1), d2_(d2) {
    305        if (U_SUCCESS(status)) {
    306            if (memory_ == nullptr) {
    307                status = U_MEMORY_ALLOCATION_ERROR;
    308                return;
    309            }
    310            clear();
    311        }
    312    }
    313    virtual ~Array2D();
    314 
    315    // ReadArray2D methods.
    316    virtual int32_t d1() const override { return d1_; }
    317    virtual int32_t d2() const override { return d2_; }
    318    virtual float get(int32_t i, int32_t j) const override {
    319        U_ASSERT(i < d1_);
    320        U_ASSERT(j < d2_);
    321        return data_[i * d2_ + j];
    322    }
    323 
    324    inline Array1D row(int32_t i) const {
    325        U_ASSERT(i < d1_);
    326        return Array1D(data_ + i * d2_, d2_);
    327    }
    328 
    329    inline Array2D& clear() {
    330        uprv_memset(data_, 0, d1_ * d2_ * sizeof(float));
    331        return *this;
    332    }
    333 
    334 private:
    335    void* memory_;
    336    float* data_;
    337    int32_t d1_;
    338    int32_t d2_;
    339 };
    340 
    341 Array2D::~Array2D()
    342 {
    343    uprv_free(memory_);
    344 }
    345 
    346 typedef enum {
    347    BEGIN,
    348    INSIDE,
    349    END,
    350    SINGLE
    351 } LSTMClass;
    352 
    353 typedef enum {
    354    UNKNOWN,
    355    CODE_POINTS,
    356    GRAPHEME_CLUSTER,
    357 } EmbeddingType;
    358 
    359 struct LSTMData : public UMemory {
    360    LSTMData(UResourceBundle* rb, UErrorCode &status);
    361    ~LSTMData();
    362    UHashtable* fDict;
    363    EmbeddingType fType;
    364    const char16_t* fName;
    365    ConstArray2D fEmbedding;
    366    ConstArray2D fForwardW;
    367    ConstArray2D fForwardU;
    368    ConstArray1D fForwardB;
    369    ConstArray2D fBackwardW;
    370    ConstArray2D fBackwardU;
    371    ConstArray1D fBackwardB;
    372    ConstArray2D fOutputW;
    373    ConstArray1D fOutputB;
    374 
    375 private:
    376    UResourceBundle* fBundle;
    377 };
    378 
    379 LSTMData::LSTMData(UResourceBundle* rb, UErrorCode &status)
    380    : fDict(nullptr), fType(UNKNOWN), fName(nullptr),
    381      fBundle(rb)
    382 {
    383    if (U_FAILURE(status)) {
    384        return;
    385    }
    386    if (IEEE_754 != 1) {
    387        status = U_UNSUPPORTED_ERROR;
    388        return;
    389    }
    390    LocalUResourceBundlePointer embeddings_res(
    391        ures_getByKey(rb, "embeddings", nullptr, &status));
    392    int32_t embedding_size = ures_getInt(embeddings_res.getAlias(), &status);
    393    LocalUResourceBundlePointer hunits_res(
    394        ures_getByKey(rb, "hunits", nullptr, &status));
    395    if (U_FAILURE(status)) return;
    396    int32_t hunits = ures_getInt(hunits_res.getAlias(), &status);
    397    const char16_t* type = ures_getStringByKey(rb, "type", nullptr, &status);
    398    if (U_FAILURE(status)) return;
    399    if (u_strCompare(type, -1, u"codepoints", -1, false) == 0) {
    400        fType = CODE_POINTS;
    401    } else if (u_strCompare(type, -1, u"graphclust", -1, false) == 0) {
    402        fType = GRAPHEME_CLUSTER;
    403    }
    404    fName = ures_getStringByKey(rb, "model", nullptr, &status);
    405    LocalUResourceBundlePointer dataRes(ures_getByKey(rb, "data", nullptr, &status));
    406    if (U_FAILURE(status)) return;
    407    int32_t data_len = 0;
    408    const int32_t* data = ures_getIntVector(dataRes.getAlias(), &data_len, &status);
    409    fDict = uhash_open(uhash_hashUChars, uhash_compareUChars, nullptr, &status);
    410 
    411    StackUResourceBundle stackTempBundle;
    412    ResourceDataValue value;
    413    ures_getValueWithFallback(rb, "dict", stackTempBundle.getAlias(), value, status);
    414    ResourceArray stringArray = value.getArray(status);
    415    int32_t num_index = stringArray.getSize();
    416    if (U_FAILURE(status)) { return; }
    417 
    418    // put dict into hash
    419    int32_t stringLength;
    420    for (int32_t idx = 0; idx < num_index; idx++) {
    421        stringArray.getValue(idx, value);
    422        const char16_t* str = value.getString(stringLength, status);
    423        uhash_putiAllowZero(fDict, (void*)str, idx, &status);
    424        if (U_FAILURE(status)) return;
    425 #ifdef LSTM_VECTORIZER_DEBUG
    426        printf("Assign [");
    427        while (*str != 0x0000) {
    428            printf("U+%04x ", *str);
    429            str++;
    430        }
    431        printf("] map to %d\n", idx-1);
    432 #endif
    433    }
    434    int32_t mat1_size = (num_index + 1) * embedding_size;
    435    int32_t mat2_size = embedding_size * 4 * hunits;
    436    int32_t mat3_size = hunits * 4 * hunits;
    437    int32_t mat4_size = 4 * hunits;
    438    int32_t mat5_size = mat2_size;
    439    int32_t mat6_size = mat3_size;
    440    int32_t mat7_size = mat4_size;
    441    int32_t mat8_size = 2 * hunits * 4;
    442 #if U_DEBUG
    443    int32_t mat9_size = 4;
    444    U_ASSERT(data_len == mat1_size + mat2_size + mat3_size + mat4_size + mat5_size +
    445        mat6_size + mat7_size + mat8_size + mat9_size);
    446 #endif
    447 
    448    fEmbedding.init(data, (num_index + 1), embedding_size);
    449    data += mat1_size;
    450    fForwardW.init(data, embedding_size, 4 * hunits);
    451    data += mat2_size;
    452    fForwardU.init(data, hunits, 4 * hunits);
    453    data += mat3_size;
    454    fForwardB.init(data, 4 * hunits);
    455    data += mat4_size;
    456    fBackwardW.init(data, embedding_size, 4 * hunits);
    457    data += mat5_size;
    458    fBackwardU.init(data, hunits, 4 * hunits);
    459    data += mat6_size;
    460    fBackwardB.init(data, 4 * hunits);
    461    data += mat7_size;
    462    fOutputW.init(data, 2 * hunits, 4);
    463    data += mat8_size;
    464    fOutputB.init(data, 4);
    465 }
    466 
    467 LSTMData::~LSTMData() {
    468    uhash_close(fDict);
    469    ures_close(fBundle);
    470 }
    471 
    472 class Vectorizer : public UMemory {
    473 public:
    474    Vectorizer(UHashtable* dict) : fDict(dict) {}
    475    virtual ~Vectorizer();
    476    virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
    477                           UVector32 &offsets, UVector32 &indices,
    478                           UErrorCode &status) const = 0;
    479 protected:
    480    int32_t stringToIndex(const char16_t* str) const {
    481        UBool found = false;
    482        int32_t ret = uhash_getiAndFound(fDict, (const void*)str, &found);
    483        if (!found) {
    484            ret = fDict->count;
    485        }
    486 #ifdef LSTM_VECTORIZER_DEBUG
    487        printf("[");
    488        while (*str != 0x0000) {
    489            printf("U+%04x ", *str);
    490            str++;
    491        }
    492        printf("] map to %d\n", ret);
    493 #endif
    494        return ret;
    495    }
    496 
    497 private:
    498    UHashtable* fDict;
    499 };
    500 
    501 Vectorizer::~Vectorizer()
    502 {
    503 }
    504 
    505 class CodePointsVectorizer : public Vectorizer {
    506 public:
    507    CodePointsVectorizer(UHashtable* dict) : Vectorizer(dict) {}
    508    virtual ~CodePointsVectorizer();
    509    virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
    510                           UVector32 &offsets, UVector32 &indices,
    511                           UErrorCode &status) const override;
    512 };
    513 
    514 CodePointsVectorizer::~CodePointsVectorizer()
    515 {
    516 }
    517 
    518 void CodePointsVectorizer::vectorize(
    519    UText *text, int32_t startPos, int32_t endPos,
    520    UVector32 &offsets, UVector32 &indices, UErrorCode &status) const
    521 {
    522    if (offsets.ensureCapacity(endPos - startPos, status) &&
    523            indices.ensureCapacity(endPos - startPos, status)) {
    524        if (U_FAILURE(status)) return;
    525        utext_setNativeIndex(text, startPos);
    526        int32_t current;
    527        char16_t str[2] = {0, 0};
    528        while (U_SUCCESS(status) &&
    529               (current = static_cast<int32_t>(utext_getNativeIndex(text))) < endPos) {
    530            // Since the LSTMBreakEngine is currently only accept chars in BMP,
    531            // we can ignore the possibility of hitting supplementary code
    532            // point.
    533            str[0] = static_cast<char16_t>(utext_next32(text));
    534            U_ASSERT(!U_IS_SURROGATE(str[0]));
    535            offsets.addElement(current, status);
    536            indices.addElement(stringToIndex(str), status);
    537        }
    538    }
    539 }
    540 
    541 class GraphemeClusterVectorizer : public Vectorizer {
    542 public:
    543    GraphemeClusterVectorizer(UHashtable* dict)
    544        : Vectorizer(dict)
    545    {
    546    }
    547    virtual ~GraphemeClusterVectorizer();
    548    virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
    549                           UVector32 &offsets, UVector32 &indices,
    550                           UErrorCode &status) const override;
    551 };
    552 
    553 GraphemeClusterVectorizer::~GraphemeClusterVectorizer()
    554 {
    555 }
    556 
    557 constexpr int32_t MAX_GRAPHEME_CLSTER_LENGTH = 10;
    558 
    559 void GraphemeClusterVectorizer::vectorize(
    560    UText *text, int32_t startPos, int32_t endPos,
    561    UVector32 &offsets, UVector32 &indices, UErrorCode &status) const
    562 {
    563    if (U_FAILURE(status)) return;
    564    if (!offsets.ensureCapacity(endPos - startPos, status) ||
    565            !indices.ensureCapacity(endPos - startPos, status)) {
    566        return;
    567    }
    568    if (U_FAILURE(status)) return;
    569    LocalPointer<BreakIterator> graphemeIter(BreakIterator::createCharacterInstance(Locale(), status));
    570    if (U_FAILURE(status)) return;
    571    graphemeIter->setText(text, status);
    572    if (U_FAILURE(status)) return;
    573 
    574    if (startPos != 0) {
    575        graphemeIter->preceding(startPos);
    576    }
    577    int32_t last = startPos;
    578    int32_t current = startPos;
    579    char16_t str[MAX_GRAPHEME_CLSTER_LENGTH];
    580    while ((current = graphemeIter->next()) != BreakIterator::DONE) {
    581        if (current >= endPos) {
    582            break;
    583        }
    584        if (current > startPos) {
    585            utext_extract(text, last, current, str, MAX_GRAPHEME_CLSTER_LENGTH, &status);
    586            if (U_FAILURE(status)) return;
    587            offsets.addElement(last, status);
    588            indices.addElement(stringToIndex(str), status);
    589            if (U_FAILURE(status)) return;
    590        }
    591        last = current;
    592    }
    593    if (U_FAILURE(status) || last >= endPos) {
    594        return;
    595    }
    596    utext_extract(text, last, endPos, str, MAX_GRAPHEME_CLSTER_LENGTH, &status);
    597    if (U_SUCCESS(status)) {
    598        offsets.addElement(last, status);
    599        indices.addElement(stringToIndex(str), status);
    600    }
    601 }
    602 
    603 // Computing LSTM as stated in
    604 // https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate
    605 // ifco is temp array allocate outside which does not need to be
    606 // input/output value but could avoid unnecessary memory alloc/free if passing
    607 // in.
    608 void compute(
    609    int32_t hunits,
    610    const ReadArray2D& W, const ReadArray2D& U, const ReadArray1D& b,
    611    const ReadArray1D& x, Array1D& h, Array1D& c,
    612    Array1D& ifco)
    613 {
    614    // ifco = x * W + h * U + b
    615    ifco.assign(b)
    616        .addDotProduct(x, W)
    617        .addDotProduct(h, U);
    618 
    619    ifco.slice(0*hunits, hunits).sigmoid();  // i: sigmod
    620    ifco.slice(1*hunits, hunits).sigmoid(); // f: sigmoid
    621    ifco.slice(2*hunits, hunits).tanh(); // c_: tanh
    622    ifco.slice(3*hunits, hunits).sigmoid(); // o: sigmod
    623 
    624    c.hadamardProduct(ifco.slice(hunits, hunits))
    625        .addHadamardProduct(ifco.slice(0, hunits), ifco.slice(2*hunits, hunits));
    626 
    627    h.tanh(c)
    628        .hadamardProduct(ifco.slice(3*hunits, hunits));
    629 }
    630 
    631 // Minimum word size
    632 static const int32_t MIN_WORD = 2;
    633 
    634 // Minimum number of characters for two words
    635 static const int32_t MIN_WORD_SPAN = MIN_WORD * 2;
    636 
    637 int32_t
    638 LSTMBreakEngine::divideUpDictionaryRange( UText *text,
    639                                                int32_t startPos,
    640                                                int32_t endPos,
    641                                                UVector32 &foundBreaks,
    642                                                UBool /* isPhraseBreaking */,
    643                                                UErrorCode& status) const {
    644    if (U_FAILURE(status)) return 0;
    645    int32_t beginFoundBreakSize = foundBreaks.size();
    646    utext_setNativeIndex(text, startPos);
    647    utext_moveIndex32(text, MIN_WORD_SPAN);
    648    if (utext_getNativeIndex(text) >= endPos) {
    649        return 0;       // Not enough characters for two words
    650    }
    651    utext_setNativeIndex(text, startPos);
    652 
    653    UVector32 offsets(status);
    654    UVector32 indices(status);
    655    if (U_FAILURE(status)) return 0;
    656    fVectorizer->vectorize(text, startPos, endPos, offsets, indices, status);
    657    if (U_FAILURE(status)) return 0;
    658    int32_t* offsetsBuf = offsets.getBuffer();
    659    int32_t* indicesBuf = indices.getBuffer();
    660 
    661    int32_t input_seq_len = indices.size();
    662    int32_t hunits = fData->fForwardU.d1();
    663 
    664    // ----- Begin of all the Array memory allocation needed for this function
    665    // Allocate temp array used inside compute()
    666    Array1D ifco(4 * hunits, status);
    667 
    668    Array1D c(hunits, status);
    669    Array1D logp(4, status);
    670 
    671    // TODO: limit size of hBackward. If input_seq_len is too big, we could
    672    // run out of memory.
    673    // Backward LSTM
    674    Array2D hBackward(input_seq_len, hunits, status);
    675 
    676    // Allocate fbRow and slice the internal array in two.
    677    Array1D fbRow(2 * hunits, status);
    678 
    679    // ----- End of all the Array memory allocation needed for this function
    680    if (U_FAILURE(status)) return 0;
    681 
    682    // To save the needed memory usage, the following is different from the
    683    // Python or ICU4X implementation. We first perform the Backward LSTM
    684    // and then merge the iteration of the forward LSTM and the output layer
    685    // together because we only neetdto remember the h[t-1] for Forward LSTM.
    686    for (int32_t i = input_seq_len - 1; i >= 0; i--) {
    687        Array1D hRow = hBackward.row(i);
    688        if (i != input_seq_len - 1) {
    689            hRow.assign(hBackward.row(i+1));
    690        }
    691 #ifdef LSTM_DEBUG
    692        printf("hRow %d\n", i);
    693        hRow.print();
    694        printf("indicesBuf[%d] = %d\n", i, indicesBuf[i]);
    695        printf("fData->fEmbedding.row(indicesBuf[%d]):\n", i);
    696        fData->fEmbedding.row(indicesBuf[i]).print();
    697 #endif  // LSTM_DEBUG
    698        compute(hunits,
    699                fData->fBackwardW, fData->fBackwardU, fData->fBackwardB,
    700                fData->fEmbedding.row(indicesBuf[i]),
    701                hRow, c, ifco);
    702    }
    703 
    704 
    705    Array1D forwardRow = fbRow.slice(0, hunits);  // point to first half of data in fbRow.
    706    Array1D backwardRow = fbRow.slice(hunits, hunits);  // point to second half of data n fbRow.
    707 
    708    // The following iteration merge the forward LSTM and the output layer
    709    // together.
    710    c.clear();  // reuse c since it is the same size.
    711    for (int32_t i = 0; i < input_seq_len; i++) {
    712 #ifdef LSTM_DEBUG
    713        printf("forwardRow %d\n", i);
    714        forwardRow.print();
    715 #endif  // LSTM_DEBUG
    716        // Forward LSTM
    717        // Calculate the result into forwardRow, which point to the data in the first half
    718        // of fbRow.
    719        compute(hunits,
    720                fData->fForwardW, fData->fForwardU, fData->fForwardB,
    721                fData->fEmbedding.row(indicesBuf[i]),
    722                forwardRow, c, ifco);
    723 
    724        // assign the data from hBackward.row(i) to second half of fbRowa.
    725        backwardRow.assign(hBackward.row(i));
    726 
    727        logp.assign(fData->fOutputB).addDotProduct(fbRow, fData->fOutputW);
    728 #ifdef LSTM_DEBUG
    729        printf("backwardRow %d\n", i);
    730        backwardRow.print();
    731        printf("logp %d\n", i);
    732        logp.print();
    733 #endif  // LSTM_DEBUG
    734 
    735        // current = argmax(logp)
    736        LSTMClass current = static_cast<LSTMClass>(logp.maxIndex());
    737        // BIES logic.
    738        if (current == BEGIN || current == SINGLE) {
    739            if (i != 0) {
    740                foundBreaks.addElement(offsetsBuf[i], status);
    741                if (U_FAILURE(status)) return 0;
    742            }
    743        }
    744    }
    745    return foundBreaks.size() - beginFoundBreakSize;
    746 }
    747 
    748 Vectorizer* createVectorizer(const LSTMData* data, UErrorCode &status) {
    749    if (U_FAILURE(status)) {
    750        return nullptr;
    751    }
    752    switch (data->fType) {
    753        case CODE_POINTS:
    754            return new CodePointsVectorizer(data->fDict);
    755            break;
    756        case GRAPHEME_CLUSTER:
    757            return new GraphemeClusterVectorizer(data->fDict);
    758            break;
    759        default:
    760            break;
    761    }
    762    UPRV_UNREACHABLE_EXIT;
    763 }
    764 
    765 LSTMBreakEngine::LSTMBreakEngine(const LSTMData* data, const UnicodeSet& set, UErrorCode &status)
    766    : DictionaryBreakEngine(), fData(data), fVectorizer(createVectorizer(fData, status))
    767 {
    768    if (U_FAILURE(status)) {
    769      fData = nullptr;  // If failure, we should not delete fData in destructor because the caller will do so.
    770      return;
    771    }
    772    setCharacters(set);
    773 }
    774 
    775 LSTMBreakEngine::~LSTMBreakEngine() {
    776    delete fData;
    777    delete fVectorizer;
    778 }
    779 
    780 const char16_t* LSTMBreakEngine::name() const {
    781    return fData->fName;
    782 }
    783 
    784 UnicodeString defaultLSTM(UScriptCode script, UErrorCode& status) {
    785    // open root from brkitr tree.
    786    UResourceBundle *b = ures_open(U_ICUDATA_BRKITR, "", &status);
    787    b = ures_getByKeyWithFallback(b, "lstm", b, &status);
    788    UnicodeString result = ures_getUnicodeStringByKey(b, uscript_getShortName(script), &status);
    789    ures_close(b);
    790    return result;
    791 }
    792 
    793 U_CAPI const LSTMData* U_EXPORT2 CreateLSTMDataForScript(UScriptCode script, UErrorCode& status)
    794 {
    795    if (script != USCRIPT_KHMER && script != USCRIPT_LAO && script != USCRIPT_MYANMAR && script != USCRIPT_THAI) {
    796        return nullptr;
    797    }
    798    UnicodeString name = defaultLSTM(script, status);
    799    if (U_FAILURE(status)) return nullptr;
    800    CharString namebuf;
    801    namebuf.appendInvariantChars(name, status).truncate(namebuf.lastIndexOf('.'));
    802 
    803    LocalUResourceBundlePointer rb(
    804        ures_openDirect(U_ICUDATA_BRKITR, namebuf.data(), &status));
    805    if (U_FAILURE(status)) return nullptr;
    806 
    807    return CreateLSTMData(rb.orphan(), status);
    808 }
    809 
    810 U_CAPI const LSTMData* U_EXPORT2 CreateLSTMData(UResourceBundle* rb, UErrorCode& status)
    811 {
    812    if (U_FAILURE(status)) {
    813        return nullptr;
    814    }
    815    const LSTMData* result = new LSTMData(rb, status);
    816    if (U_FAILURE(status)) {
    817        delete result;
    818        return nullptr;
    819    }
    820    return result;
    821 }
    822 
    823 U_CAPI const LanguageBreakEngine* U_EXPORT2
    824 CreateLSTMBreakEngine(UScriptCode script, const LSTMData* data, UErrorCode& status)
    825 {
    826    UnicodeString unicodeSetString;
    827    switch(script) {
    828        case USCRIPT_THAI:
    829            unicodeSetString = UnicodeString(u"[[:Thai:]&[:LineBreak=SA:]]");
    830            break;
    831        case USCRIPT_MYANMAR:
    832            unicodeSetString = UnicodeString(u"[[:Mymr:]&[:LineBreak=SA:]]");
    833            break;
    834        default:
    835            delete data;
    836            return nullptr;
    837    }
    838    UnicodeSet unicodeSet;
    839    unicodeSet.applyPattern(unicodeSetString, status);
    840    const LanguageBreakEngine* engine = new LSTMBreakEngine(data, unicodeSet, status);
    841    if (U_FAILURE(status) || engine == nullptr) {
    842        if (engine != nullptr) {
    843            delete engine;
    844        } else {
    845            status = U_MEMORY_ALLOCATION_ERROR;
    846        }
    847        return nullptr;
    848    }
    849    return engine;
    850 }
    851 
    852 U_CAPI void U_EXPORT2 DeleteLSTMData(const LSTMData* data)
    853 {
    854    delete data;
    855 }
    856 
    857 U_CAPI const char16_t* U_EXPORT2 LSTMDataName(const LSTMData* data)
    858 {
    859    return data->fName;
    860 }
    861 
    862 U_NAMESPACE_END
    863 
    864 #endif /* #if !UCONFIG_NO_BREAK_ITERATION */