// lstmbe.cpp (26003B)
1 // © 2021 and later: Unicode, Inc. and others. 2 // License & terms of use: http://www.unicode.org/copyright.html 3 4 #include <complex> 5 #include <utility> 6 7 #include "unicode/utypes.h" 8 9 #if !UCONFIG_NO_BREAK_ITERATION 10 11 #include "brkeng.h" 12 #include "charstr.h" 13 #include "cmemory.h" 14 #include "lstmbe.h" 15 #include "putilimp.h" 16 #include "uassert.h" 17 #include "ubrkimpl.h" 18 #include "uresimp.h" 19 #include "uvectr32.h" 20 #include "uvector.h" 21 22 #include "unicode/brkiter.h" 23 #include "unicode/resbund.h" 24 #include "unicode/ubrk.h" 25 #include "unicode/uniset.h" 26 #include "unicode/ustring.h" 27 #include "unicode/utf.h" 28 29 U_NAMESPACE_BEGIN 30 31 // Uncomment the following #define to debug. 32 // #define LSTM_DEBUG 1 33 // #define LSTM_VECTORIZER_DEBUG 1 34 35 /** 36 * Interface for reading 1D array. 37 */ 38 class ReadArray1D { 39 public: 40 virtual ~ReadArray1D(); 41 virtual int32_t d1() const = 0; 42 virtual float get(int32_t i) const = 0; 43 44 #ifdef LSTM_DEBUG 45 void print() const { 46 printf("\n["); 47 for (int32_t i = 0; i < d1(); i++) { 48 printf("%0.8e ", get(i)); 49 if (i % 4 == 3) printf("\n"); 50 } 51 printf("]\n"); 52 } 53 #endif 54 }; 55 56 ReadArray1D::~ReadArray1D() 57 { 58 } 59 60 /** 61 * Interface for reading 2D array. 62 */ 63 class ReadArray2D { 64 public: 65 virtual ~ReadArray2D(); 66 virtual int32_t d1() const = 0; 67 virtual int32_t d2() const = 0; 68 virtual float get(int32_t i, int32_t j) const = 0; 69 }; 70 71 ReadArray2D::~ReadArray2D() 72 { 73 } 74 75 /** 76 * A class to index a float array as a 1D Array without owning the pointer or 77 * copy the data. 78 */ 79 class ConstArray1D : public ReadArray1D { 80 public: 81 ConstArray1D() : data_(nullptr), d1_(0) {} 82 83 ConstArray1D(const float* data, int32_t d1) : data_(data), d1_(d1) {} 84 85 virtual ~ConstArray1D(); 86 87 // Init the object, the object does not own the data nor copy. 
88 // It is designed to directly use data from memory mapped resources. 89 void init(const int32_t* data, int32_t d1) { 90 U_ASSERT(IEEE_754 == 1); 91 data_ = reinterpret_cast<const float*>(data); 92 d1_ = d1; 93 } 94 95 // ReadArray1D methods. 96 virtual int32_t d1() const override { return d1_; } 97 virtual float get(int32_t i) const override { 98 U_ASSERT(i < d1_); 99 return data_[i]; 100 } 101 102 private: 103 const float* data_; 104 int32_t d1_; 105 }; 106 107 ConstArray1D::~ConstArray1D() 108 { 109 } 110 111 /** 112 * A class to index a float array as a 2D Array without owning the pointer or 113 * copy the data. 114 */ 115 class ConstArray2D : public ReadArray2D { 116 public: 117 ConstArray2D() : data_(nullptr), d1_(0), d2_(0) {} 118 119 ConstArray2D(const float* data, int32_t d1, int32_t d2) 120 : data_(data), d1_(d1), d2_(d2) {} 121 122 virtual ~ConstArray2D(); 123 124 // Init the object, the object does not own the data nor copy. 125 // It is designed to directly use data from memory mapped resources. 126 void init(const int32_t* data, int32_t d1, int32_t d2) { 127 U_ASSERT(IEEE_754 == 1); 128 data_ = reinterpret_cast<const float*>(data); 129 d1_ = d1; 130 d2_ = d2; 131 } 132 133 // ReadArray2D methods. 134 inline int32_t d1() const override { return d1_; } 135 inline int32_t d2() const override { return d2_; } 136 float get(int32_t i, int32_t j) const override { 137 U_ASSERT(i < d1_); 138 U_ASSERT(j < d2_); 139 return data_[i * d2_ + j]; 140 } 141 142 // Expose the ith row as a ConstArray1D 143 inline ConstArray1D row(int32_t i) const { 144 U_ASSERT(i < d1_); 145 return ConstArray1D(data_ + i * d2_, d2_); 146 } 147 148 private: 149 const float* data_; 150 int32_t d1_; 151 int32_t d2_; 152 }; 153 154 ConstArray2D::~ConstArray2D() 155 { 156 } 157 158 /** 159 * A class to allocate data as a writable 1D array. 160 * This is the main class implement matrix operation. 
161 */ 162 class Array1D : public ReadArray1D { 163 public: 164 Array1D() : memory_(nullptr), data_(nullptr), d1_(0) {} 165 Array1D(int32_t d1, UErrorCode &status) 166 : memory_(uprv_malloc(d1 * sizeof(float))), 167 data_(static_cast<float*>(memory_)), d1_(d1) { 168 if (U_SUCCESS(status)) { 169 if (memory_ == nullptr) { 170 status = U_MEMORY_ALLOCATION_ERROR; 171 return; 172 } 173 clear(); 174 } 175 } 176 177 virtual ~Array1D(); 178 179 // A special constructor which does not own the memory but writeable 180 // as a slice of an array. 181 Array1D(float* data, int32_t d1) 182 : memory_(nullptr), data_(data), d1_(d1) {} 183 184 // ReadArray1D methods. 185 virtual int32_t d1() const override { return d1_; } 186 virtual float get(int32_t i) const override { 187 U_ASSERT(i < d1_); 188 return data_[i]; 189 } 190 191 // Return the index which point to the max data in the array. 192 inline int32_t maxIndex() const { 193 int32_t index = 0; 194 float max = data_[0]; 195 for (int32_t i = 1; i < d1_; i++) { 196 if (data_[i] > max) { 197 max = data_[i]; 198 index = i; 199 } 200 } 201 return index; 202 } 203 204 // Slice part of the array to a new one. 205 inline Array1D slice(int32_t from, int32_t size) const { 206 U_ASSERT(from >= 0); 207 U_ASSERT(from < d1_); 208 U_ASSERT(from + size <= d1_); 209 return Array1D(data_ + from, size); 210 } 211 212 // Add dot product of a 1D array and a 2D array into this one. 213 inline Array1D& addDotProduct(const ReadArray1D& a, const ReadArray2D& b) { 214 U_ASSERT(a.d1() == b.d1()); 215 U_ASSERT(b.d2() == d1()); 216 for (int32_t i = 0; i < d1(); i++) { 217 for (int32_t j = 0; j < a.d1(); j++) { 218 data_[i] += a.get(j) * b.get(j, i); 219 } 220 } 221 return *this; 222 } 223 224 // Hadamard Product the values of another array of the same size into this one. 
225 inline Array1D& hadamardProduct(const ReadArray1D& a) { 226 U_ASSERT(a.d1() == d1()); 227 for (int32_t i = 0; i < d1(); i++) { 228 data_[i] *= a.get(i); 229 } 230 return *this; 231 } 232 233 // Add the Hadamard Product of two arrays of the same size into this one. 234 inline Array1D& addHadamardProduct(const ReadArray1D& a, const ReadArray1D& b) { 235 U_ASSERT(a.d1() == d1()); 236 U_ASSERT(b.d1() == d1()); 237 for (int32_t i = 0; i < d1(); i++) { 238 data_[i] += a.get(i) * b.get(i); 239 } 240 return *this; 241 } 242 243 // Add the values of another array of the same size into this one. 244 inline Array1D& add(const ReadArray1D& a) { 245 U_ASSERT(a.d1() == d1()); 246 for (int32_t i = 0; i < d1(); i++) { 247 data_[i] += a.get(i); 248 } 249 return *this; 250 } 251 252 // Assign the values of another array of the same size into this one. 253 inline Array1D& assign(const ReadArray1D& a) { 254 U_ASSERT(a.d1() == d1()); 255 for (int32_t i = 0; i < d1(); i++) { 256 data_[i] = a.get(i); 257 } 258 return *this; 259 } 260 261 // Apply tanh to all the elements in the array. 262 inline Array1D& tanh() { 263 return tanh(*this); 264 } 265 266 // Apply tanh of a and store into this array. 267 inline Array1D& tanh(const Array1D& a) { 268 U_ASSERT(a.d1() == d1()); 269 for (int32_t i = 0; i < d1_; i++) { 270 data_[i] = std::tanh(a.get(i)); 271 } 272 return *this; 273 } 274 275 // Apply sigmoid to all the elements in the array. 
276 inline Array1D& sigmoid() { 277 for (int32_t i = 0; i < d1_; i++) { 278 data_[i] = 1.0f/(1.0f + expf(-data_[i])); 279 } 280 return *this; 281 } 282 283 inline Array1D& clear() { 284 uprv_memset(data_, 0, d1_ * sizeof(float)); 285 return *this; 286 } 287 288 private: 289 void* memory_; 290 float* data_; 291 int32_t d1_; 292 }; 293 294 Array1D::~Array1D() 295 { 296 uprv_free(memory_); 297 } 298 299 class Array2D : public ReadArray2D { 300 public: 301 Array2D() : memory_(nullptr), data_(nullptr), d1_(0), d2_(0) {} 302 Array2D(int32_t d1, int32_t d2, UErrorCode &status) 303 : memory_(uprv_malloc(d1 * d2 * sizeof(float))), 304 data_(static_cast<float*>(memory_)), d1_(d1), d2_(d2) { 305 if (U_SUCCESS(status)) { 306 if (memory_ == nullptr) { 307 status = U_MEMORY_ALLOCATION_ERROR; 308 return; 309 } 310 clear(); 311 } 312 } 313 virtual ~Array2D(); 314 315 // ReadArray2D methods. 316 virtual int32_t d1() const override { return d1_; } 317 virtual int32_t d2() const override { return d2_; } 318 virtual float get(int32_t i, int32_t j) const override { 319 U_ASSERT(i < d1_); 320 U_ASSERT(j < d2_); 321 return data_[i * d2_ + j]; 322 } 323 324 inline Array1D row(int32_t i) const { 325 U_ASSERT(i < d1_); 326 return Array1D(data_ + i * d2_, d2_); 327 } 328 329 inline Array2D& clear() { 330 uprv_memset(data_, 0, d1_ * d2_ * sizeof(float)); 331 return *this; 332 } 333 334 private: 335 void* memory_; 336 float* data_; 337 int32_t d1_; 338 int32_t d2_; 339 }; 340 341 Array2D::~Array2D() 342 { 343 uprv_free(memory_); 344 } 345 346 typedef enum { 347 BEGIN, 348 INSIDE, 349 END, 350 SINGLE 351 } LSTMClass; 352 353 typedef enum { 354 UNKNOWN, 355 CODE_POINTS, 356 GRAPHEME_CLUSTER, 357 } EmbeddingType; 358 359 struct LSTMData : public UMemory { 360 LSTMData(UResourceBundle* rb, UErrorCode &status); 361 ~LSTMData(); 362 UHashtable* fDict; 363 EmbeddingType fType; 364 const char16_t* fName; 365 ConstArray2D fEmbedding; 366 ConstArray2D fForwardW; 367 ConstArray2D fForwardU; 368 
ConstArray1D fForwardB; 369 ConstArray2D fBackwardW; 370 ConstArray2D fBackwardU; 371 ConstArray1D fBackwardB; 372 ConstArray2D fOutputW; 373 ConstArray1D fOutputB; 374 375 private: 376 UResourceBundle* fBundle; 377 }; 378 379 LSTMData::LSTMData(UResourceBundle* rb, UErrorCode &status) 380 : fDict(nullptr), fType(UNKNOWN), fName(nullptr), 381 fBundle(rb) 382 { 383 if (U_FAILURE(status)) { 384 return; 385 } 386 if (IEEE_754 != 1) { 387 status = U_UNSUPPORTED_ERROR; 388 return; 389 } 390 LocalUResourceBundlePointer embeddings_res( 391 ures_getByKey(rb, "embeddings", nullptr, &status)); 392 int32_t embedding_size = ures_getInt(embeddings_res.getAlias(), &status); 393 LocalUResourceBundlePointer hunits_res( 394 ures_getByKey(rb, "hunits", nullptr, &status)); 395 if (U_FAILURE(status)) return; 396 int32_t hunits = ures_getInt(hunits_res.getAlias(), &status); 397 const char16_t* type = ures_getStringByKey(rb, "type", nullptr, &status); 398 if (U_FAILURE(status)) return; 399 if (u_strCompare(type, -1, u"codepoints", -1, false) == 0) { 400 fType = CODE_POINTS; 401 } else if (u_strCompare(type, -1, u"graphclust", -1, false) == 0) { 402 fType = GRAPHEME_CLUSTER; 403 } 404 fName = ures_getStringByKey(rb, "model", nullptr, &status); 405 LocalUResourceBundlePointer dataRes(ures_getByKey(rb, "data", nullptr, &status)); 406 if (U_FAILURE(status)) return; 407 int32_t data_len = 0; 408 const int32_t* data = ures_getIntVector(dataRes.getAlias(), &data_len, &status); 409 fDict = uhash_open(uhash_hashUChars, uhash_compareUChars, nullptr, &status); 410 411 StackUResourceBundle stackTempBundle; 412 ResourceDataValue value; 413 ures_getValueWithFallback(rb, "dict", stackTempBundle.getAlias(), value, status); 414 ResourceArray stringArray = value.getArray(status); 415 int32_t num_index = stringArray.getSize(); 416 if (U_FAILURE(status)) { return; } 417 418 // put dict into hash 419 int32_t stringLength; 420 for (int32_t idx = 0; idx < num_index; idx++) { 421 stringArray.getValue(idx, 
value); 422 const char16_t* str = value.getString(stringLength, status); 423 uhash_putiAllowZero(fDict, (void*)str, idx, &status); 424 if (U_FAILURE(status)) return; 425 #ifdef LSTM_VECTORIZER_DEBUG 426 printf("Assign ["); 427 while (*str != 0x0000) { 428 printf("U+%04x ", *str); 429 str++; 430 } 431 printf("] map to %d\n", idx-1); 432 #endif 433 } 434 int32_t mat1_size = (num_index + 1) * embedding_size; 435 int32_t mat2_size = embedding_size * 4 * hunits; 436 int32_t mat3_size = hunits * 4 * hunits; 437 int32_t mat4_size = 4 * hunits; 438 int32_t mat5_size = mat2_size; 439 int32_t mat6_size = mat3_size; 440 int32_t mat7_size = mat4_size; 441 int32_t mat8_size = 2 * hunits * 4; 442 #if U_DEBUG 443 int32_t mat9_size = 4; 444 U_ASSERT(data_len == mat1_size + mat2_size + mat3_size + mat4_size + mat5_size + 445 mat6_size + mat7_size + mat8_size + mat9_size); 446 #endif 447 448 fEmbedding.init(data, (num_index + 1), embedding_size); 449 data += mat1_size; 450 fForwardW.init(data, embedding_size, 4 * hunits); 451 data += mat2_size; 452 fForwardU.init(data, hunits, 4 * hunits); 453 data += mat3_size; 454 fForwardB.init(data, 4 * hunits); 455 data += mat4_size; 456 fBackwardW.init(data, embedding_size, 4 * hunits); 457 data += mat5_size; 458 fBackwardU.init(data, hunits, 4 * hunits); 459 data += mat6_size; 460 fBackwardB.init(data, 4 * hunits); 461 data += mat7_size; 462 fOutputW.init(data, 2 * hunits, 4); 463 data += mat8_size; 464 fOutputB.init(data, 4); 465 } 466 467 LSTMData::~LSTMData() { 468 uhash_close(fDict); 469 ures_close(fBundle); 470 } 471 472 class Vectorizer : public UMemory { 473 public: 474 Vectorizer(UHashtable* dict) : fDict(dict) {} 475 virtual ~Vectorizer(); 476 virtual void vectorize(UText *text, int32_t startPos, int32_t endPos, 477 UVector32 &offsets, UVector32 &indices, 478 UErrorCode &status) const = 0; 479 protected: 480 int32_t stringToIndex(const char16_t* str) const { 481 UBool found = false; 482 int32_t ret = uhash_getiAndFound(fDict, (const 
void*)str, &found); 483 if (!found) { 484 ret = fDict->count; 485 } 486 #ifdef LSTM_VECTORIZER_DEBUG 487 printf("["); 488 while (*str != 0x0000) { 489 printf("U+%04x ", *str); 490 str++; 491 } 492 printf("] map to %d\n", ret); 493 #endif 494 return ret; 495 } 496 497 private: 498 UHashtable* fDict; 499 }; 500 501 Vectorizer::~Vectorizer() 502 { 503 } 504 505 class CodePointsVectorizer : public Vectorizer { 506 public: 507 CodePointsVectorizer(UHashtable* dict) : Vectorizer(dict) {} 508 virtual ~CodePointsVectorizer(); 509 virtual void vectorize(UText *text, int32_t startPos, int32_t endPos, 510 UVector32 &offsets, UVector32 &indices, 511 UErrorCode &status) const override; 512 }; 513 514 CodePointsVectorizer::~CodePointsVectorizer() 515 { 516 } 517 518 void CodePointsVectorizer::vectorize( 519 UText *text, int32_t startPos, int32_t endPos, 520 UVector32 &offsets, UVector32 &indices, UErrorCode &status) const 521 { 522 if (offsets.ensureCapacity(endPos - startPos, status) && 523 indices.ensureCapacity(endPos - startPos, status)) { 524 if (U_FAILURE(status)) return; 525 utext_setNativeIndex(text, startPos); 526 int32_t current; 527 char16_t str[2] = {0, 0}; 528 while (U_SUCCESS(status) && 529 (current = static_cast<int32_t>(utext_getNativeIndex(text))) < endPos) { 530 // Since the LSTMBreakEngine is currently only accept chars in BMP, 531 // we can ignore the possibility of hitting supplementary code 532 // point. 
533 str[0] = static_cast<char16_t>(utext_next32(text)); 534 U_ASSERT(!U_IS_SURROGATE(str[0])); 535 offsets.addElement(current, status); 536 indices.addElement(stringToIndex(str), status); 537 } 538 } 539 } 540 541 class GraphemeClusterVectorizer : public Vectorizer { 542 public: 543 GraphemeClusterVectorizer(UHashtable* dict) 544 : Vectorizer(dict) 545 { 546 } 547 virtual ~GraphemeClusterVectorizer(); 548 virtual void vectorize(UText *text, int32_t startPos, int32_t endPos, 549 UVector32 &offsets, UVector32 &indices, 550 UErrorCode &status) const override; 551 }; 552 553 GraphemeClusterVectorizer::~GraphemeClusterVectorizer() 554 { 555 } 556 557 constexpr int32_t MAX_GRAPHEME_CLSTER_LENGTH = 10; 558 559 void GraphemeClusterVectorizer::vectorize( 560 UText *text, int32_t startPos, int32_t endPos, 561 UVector32 &offsets, UVector32 &indices, UErrorCode &status) const 562 { 563 if (U_FAILURE(status)) return; 564 if (!offsets.ensureCapacity(endPos - startPos, status) || 565 !indices.ensureCapacity(endPos - startPos, status)) { 566 return; 567 } 568 if (U_FAILURE(status)) return; 569 LocalPointer<BreakIterator> graphemeIter(BreakIterator::createCharacterInstance(Locale(), status)); 570 if (U_FAILURE(status)) return; 571 graphemeIter->setText(text, status); 572 if (U_FAILURE(status)) return; 573 574 if (startPos != 0) { 575 graphemeIter->preceding(startPos); 576 } 577 int32_t last = startPos; 578 int32_t current = startPos; 579 char16_t str[MAX_GRAPHEME_CLSTER_LENGTH]; 580 while ((current = graphemeIter->next()) != BreakIterator::DONE) { 581 if (current >= endPos) { 582 break; 583 } 584 if (current > startPos) { 585 utext_extract(text, last, current, str, MAX_GRAPHEME_CLSTER_LENGTH, &status); 586 if (U_FAILURE(status)) return; 587 offsets.addElement(last, status); 588 indices.addElement(stringToIndex(str), status); 589 if (U_FAILURE(status)) return; 590 } 591 last = current; 592 } 593 if (U_FAILURE(status) || last >= endPos) { 594 return; 595 } 596 utext_extract(text, 
last, endPos, str, MAX_GRAPHEME_CLSTER_LENGTH, &status); 597 if (U_SUCCESS(status)) { 598 offsets.addElement(last, status); 599 indices.addElement(stringToIndex(str), status); 600 } 601 } 602 603 // Computing LSTM as stated in 604 // https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate 605 // ifco is temp array allocate outside which does not need to be 606 // input/output value but could avoid unnecessary memory alloc/free if passing 607 // in. 608 void compute( 609 int32_t hunits, 610 const ReadArray2D& W, const ReadArray2D& U, const ReadArray1D& b, 611 const ReadArray1D& x, Array1D& h, Array1D& c, 612 Array1D& ifco) 613 { 614 // ifco = x * W + h * U + b 615 ifco.assign(b) 616 .addDotProduct(x, W) 617 .addDotProduct(h, U); 618 619 ifco.slice(0*hunits, hunits).sigmoid(); // i: sigmod 620 ifco.slice(1*hunits, hunits).sigmoid(); // f: sigmoid 621 ifco.slice(2*hunits, hunits).tanh(); // c_: tanh 622 ifco.slice(3*hunits, hunits).sigmoid(); // o: sigmod 623 624 c.hadamardProduct(ifco.slice(hunits, hunits)) 625 .addHadamardProduct(ifco.slice(0, hunits), ifco.slice(2*hunits, hunits)); 626 627 h.tanh(c) 628 .hadamardProduct(ifco.slice(3*hunits, hunits)); 629 } 630 631 // Minimum word size 632 static const int32_t MIN_WORD = 2; 633 634 // Minimum number of characters for two words 635 static const int32_t MIN_WORD_SPAN = MIN_WORD * 2; 636 637 int32_t 638 LSTMBreakEngine::divideUpDictionaryRange( UText *text, 639 int32_t startPos, 640 int32_t endPos, 641 UVector32 &foundBreaks, 642 UBool /* isPhraseBreaking */, 643 UErrorCode& status) const { 644 if (U_FAILURE(status)) return 0; 645 int32_t beginFoundBreakSize = foundBreaks.size(); 646 utext_setNativeIndex(text, startPos); 647 utext_moveIndex32(text, MIN_WORD_SPAN); 648 if (utext_getNativeIndex(text) >= endPos) { 649 return 0; // Not enough characters for two words 650 } 651 utext_setNativeIndex(text, startPos); 652 653 UVector32 offsets(status); 654 UVector32 indices(status); 655 if 
(U_FAILURE(status)) return 0; 656 fVectorizer->vectorize(text, startPos, endPos, offsets, indices, status); 657 if (U_FAILURE(status)) return 0; 658 int32_t* offsetsBuf = offsets.getBuffer(); 659 int32_t* indicesBuf = indices.getBuffer(); 660 661 int32_t input_seq_len = indices.size(); 662 int32_t hunits = fData->fForwardU.d1(); 663 664 // ----- Begin of all the Array memory allocation needed for this function 665 // Allocate temp array used inside compute() 666 Array1D ifco(4 * hunits, status); 667 668 Array1D c(hunits, status); 669 Array1D logp(4, status); 670 671 // TODO: limit size of hBackward. If input_seq_len is too big, we could 672 // run out of memory. 673 // Backward LSTM 674 Array2D hBackward(input_seq_len, hunits, status); 675 676 // Allocate fbRow and slice the internal array in two. 677 Array1D fbRow(2 * hunits, status); 678 679 // ----- End of all the Array memory allocation needed for this function 680 if (U_FAILURE(status)) return 0; 681 682 // To save the needed memory usage, the following is different from the 683 // Python or ICU4X implementation. We first perform the Backward LSTM 684 // and then merge the iteration of the forward LSTM and the output layer 685 // together because we only neetdto remember the h[t-1] for Forward LSTM. 686 for (int32_t i = input_seq_len - 1; i >= 0; i--) { 687 Array1D hRow = hBackward.row(i); 688 if (i != input_seq_len - 1) { 689 hRow.assign(hBackward.row(i+1)); 690 } 691 #ifdef LSTM_DEBUG 692 printf("hRow %d\n", i); 693 hRow.print(); 694 printf("indicesBuf[%d] = %d\n", i, indicesBuf[i]); 695 printf("fData->fEmbedding.row(indicesBuf[%d]):\n", i); 696 fData->fEmbedding.row(indicesBuf[i]).print(); 697 #endif // LSTM_DEBUG 698 compute(hunits, 699 fData->fBackwardW, fData->fBackwardU, fData->fBackwardB, 700 fData->fEmbedding.row(indicesBuf[i]), 701 hRow, c, ifco); 702 } 703 704 705 Array1D forwardRow = fbRow.slice(0, hunits); // point to first half of data in fbRow. 
706 Array1D backwardRow = fbRow.slice(hunits, hunits); // point to second half of data n fbRow. 707 708 // The following iteration merge the forward LSTM and the output layer 709 // together. 710 c.clear(); // reuse c since it is the same size. 711 for (int32_t i = 0; i < input_seq_len; i++) { 712 #ifdef LSTM_DEBUG 713 printf("forwardRow %d\n", i); 714 forwardRow.print(); 715 #endif // LSTM_DEBUG 716 // Forward LSTM 717 // Calculate the result into forwardRow, which point to the data in the first half 718 // of fbRow. 719 compute(hunits, 720 fData->fForwardW, fData->fForwardU, fData->fForwardB, 721 fData->fEmbedding.row(indicesBuf[i]), 722 forwardRow, c, ifco); 723 724 // assign the data from hBackward.row(i) to second half of fbRowa. 725 backwardRow.assign(hBackward.row(i)); 726 727 logp.assign(fData->fOutputB).addDotProduct(fbRow, fData->fOutputW); 728 #ifdef LSTM_DEBUG 729 printf("backwardRow %d\n", i); 730 backwardRow.print(); 731 printf("logp %d\n", i); 732 logp.print(); 733 #endif // LSTM_DEBUG 734 735 // current = argmax(logp) 736 LSTMClass current = static_cast<LSTMClass>(logp.maxIndex()); 737 // BIES logic. 
738 if (current == BEGIN || current == SINGLE) { 739 if (i != 0) { 740 foundBreaks.addElement(offsetsBuf[i], status); 741 if (U_FAILURE(status)) return 0; 742 } 743 } 744 } 745 return foundBreaks.size() - beginFoundBreakSize; 746 } 747 748 Vectorizer* createVectorizer(const LSTMData* data, UErrorCode &status) { 749 if (U_FAILURE(status)) { 750 return nullptr; 751 } 752 switch (data->fType) { 753 case CODE_POINTS: 754 return new CodePointsVectorizer(data->fDict); 755 break; 756 case GRAPHEME_CLUSTER: 757 return new GraphemeClusterVectorizer(data->fDict); 758 break; 759 default: 760 break; 761 } 762 UPRV_UNREACHABLE_EXIT; 763 } 764 765 LSTMBreakEngine::LSTMBreakEngine(const LSTMData* data, const UnicodeSet& set, UErrorCode &status) 766 : DictionaryBreakEngine(), fData(data), fVectorizer(createVectorizer(fData, status)) 767 { 768 if (U_FAILURE(status)) { 769 fData = nullptr; // If failure, we should not delete fData in destructor because the caller will do so. 770 return; 771 } 772 setCharacters(set); 773 } 774 775 LSTMBreakEngine::~LSTMBreakEngine() { 776 delete fData; 777 delete fVectorizer; 778 } 779 780 const char16_t* LSTMBreakEngine::name() const { 781 return fData->fName; 782 } 783 784 UnicodeString defaultLSTM(UScriptCode script, UErrorCode& status) { 785 // open root from brkitr tree. 
786 UResourceBundle *b = ures_open(U_ICUDATA_BRKITR, "", &status); 787 b = ures_getByKeyWithFallback(b, "lstm", b, &status); 788 UnicodeString result = ures_getUnicodeStringByKey(b, uscript_getShortName(script), &status); 789 ures_close(b); 790 return result; 791 } 792 793 U_CAPI const LSTMData* U_EXPORT2 CreateLSTMDataForScript(UScriptCode script, UErrorCode& status) 794 { 795 if (script != USCRIPT_KHMER && script != USCRIPT_LAO && script != USCRIPT_MYANMAR && script != USCRIPT_THAI) { 796 return nullptr; 797 } 798 UnicodeString name = defaultLSTM(script, status); 799 if (U_FAILURE(status)) return nullptr; 800 CharString namebuf; 801 namebuf.appendInvariantChars(name, status).truncate(namebuf.lastIndexOf('.')); 802 803 LocalUResourceBundlePointer rb( 804 ures_openDirect(U_ICUDATA_BRKITR, namebuf.data(), &status)); 805 if (U_FAILURE(status)) return nullptr; 806 807 return CreateLSTMData(rb.orphan(), status); 808 } 809 810 U_CAPI const LSTMData* U_EXPORT2 CreateLSTMData(UResourceBundle* rb, UErrorCode& status) 811 { 812 if (U_FAILURE(status)) { 813 return nullptr; 814 } 815 const LSTMData* result = new LSTMData(rb, status); 816 if (U_FAILURE(status)) { 817 delete result; 818 return nullptr; 819 } 820 return result; 821 } 822 823 U_CAPI const LanguageBreakEngine* U_EXPORT2 824 CreateLSTMBreakEngine(UScriptCode script, const LSTMData* data, UErrorCode& status) 825 { 826 UnicodeString unicodeSetString; 827 switch(script) { 828 case USCRIPT_THAI: 829 unicodeSetString = UnicodeString(u"[[:Thai:]&[:LineBreak=SA:]]"); 830 break; 831 case USCRIPT_MYANMAR: 832 unicodeSetString = UnicodeString(u"[[:Mymr:]&[:LineBreak=SA:]]"); 833 break; 834 default: 835 delete data; 836 return nullptr; 837 } 838 UnicodeSet unicodeSet; 839 unicodeSet.applyPattern(unicodeSetString, status); 840 const LanguageBreakEngine* engine = new LSTMBreakEngine(data, unicodeSet, status); 841 if (U_FAILURE(status) || engine == nullptr) { 842 if (engine != nullptr) { 843 delete engine; 844 } else { 845 
status = U_MEMORY_ALLOCATION_ERROR; 846 } 847 return nullptr; 848 } 849 return engine; 850 } 851 852 U_CAPI void U_EXPORT2 DeleteLSTMData(const LSTMData* data) 853 { 854 delete data; 855 } 856 857 U_CAPI const char16_t* U_EXPORT2 LSTMDataName(const LSTMData* data) 858 { 859 return data->fName; 860 } 861 862 U_NAMESPACE_END 863 864 #endif /* #if !UCONFIG_NO_BREAK_ITERATION */