tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

mlbe.cpp (10730B)


      1 // © 2022 and later: Unicode, Inc. and others.
      2 // License & terms of use: http://www.unicode.org/copyright.html
      3 
      4 #include "unicode/utypes.h"
      5 
      6 #if !UCONFIG_NO_BREAK_ITERATION
      7 
      8 #include "cmemory.h"
      9 #include "mlbe.h"
     10 #include "uassert.h"
     11 #include "ubrkimpl.h"
     12 #include "unicode/resbund.h"
     13 #include "unicode/udata.h"
     14 #include "unicode/utf16.h"
     15 #include "uresimp.h"
     16 #include "util.h"
     17 #include "uvectr32.h"
     18 
     19 U_NAMESPACE_BEGIN
     20 
     21 enum class ModelIndex { kUWStart = 0, kBWStart = 6, kTWStart = 9 };
     22 
     23 MlBreakEngine::MlBreakEngine(const UnicodeSet &digitOrOpenPunctuationOrAlphabetSet,
     24                             const UnicodeSet &closePunctuationSet, UErrorCode &status)
     25    : fDigitOrOpenPunctuationOrAlphabetSet(digitOrOpenPunctuationOrAlphabetSet),
     26      fClosePunctuationSet(closePunctuationSet),
     27      fNegativeSum(0) {
     28    if (U_FAILURE(status)) {
     29        return;
     30    }
     31    loadMLModel(status);
     32 }
     33 
     34 MlBreakEngine::~MlBreakEngine() {}
     35 
     36 int32_t MlBreakEngine::divideUpRange(UText *inText, int32_t rangeStart, int32_t rangeEnd,
     37                                     UVector32 &foundBreaks, const UnicodeString &inString,
     38                                     const LocalPointer<UVector32> &inputMap,
     39                                     UErrorCode &status) const {
     40    if (U_FAILURE(status)) {
     41        return 0;
     42    }
     43    if (rangeStart >= rangeEnd) {
     44        status = U_ILLEGAL_ARGUMENT_ERROR;
     45        return 0;
     46    }
     47 
     48    UVector32 boundary(inString.countChar32() + 1, status);
     49    if (U_FAILURE(status)) {
     50        return 0;
     51    }
     52    int32_t numBreaks = 0;
     53    int32_t codePointLength = inString.countChar32();
     54    // The ML algorithm groups six char and evaluates whether the 4th char is a breakpoint.
     55    // In each iteration, it evaluates the 4th char and then moves forward one char like a sliding
     56    // window. Initially, the first six values in the indexList are [-1, -1, 0, 1, 2, 3]. After
     57    // moving forward, finally the last six values in the indexList are
     58    // [length-4, length-3, length-2, length-1, -1, -1]. The "+4" here means four extra "-1".
     59    int32_t indexSize = codePointLength + 4;
     60    LocalMemory<int32_t> indexList(static_cast<int32_t*>(uprv_malloc(indexSize * sizeof(int32_t))));
     61    if (indexList.isNull()) {
     62        status = U_MEMORY_ALLOCATION_ERROR;
     63        return 0;
     64    }
     65    int32_t numCodeUnits = initIndexList(inString, indexList.getAlias(), status);
     66 
     67    // Add a break for the start.
     68    boundary.addElement(0, status);
     69    numBreaks++;
     70    if (U_FAILURE(status)) return 0;
     71 
     72    for (int32_t idx = 0; idx + 1 < codePointLength && U_SUCCESS(status); idx++) {
     73        numBreaks =
     74            evaluateBreakpoint(inString, indexList.getAlias(), idx, numCodeUnits, numBreaks, boundary, status);
     75        if (idx + 4 < codePointLength) {
     76            indexList[idx + 6] = numCodeUnits;
     77            numCodeUnits += U16_LENGTH(inString.char32At(indexList[idx + 6]));
     78        }
     79    }
     80 
     81    if (U_FAILURE(status)) return 0;
     82 
     83    // Add a break for the end if there is not one there already.
     84    if (boundary.lastElementi() != inString.countChar32()) {
     85        boundary.addElement(inString.countChar32(), status);
     86        numBreaks++;
     87    }
     88 
     89    int32_t prevCPPos = -1;
     90    int32_t prevUTextPos = -1;
     91    int32_t correctedNumBreaks = 0;
     92    for (int32_t i = 0; i < numBreaks; i++) {
     93        int32_t cpPos = boundary.elementAti(i);
     94        int32_t utextPos = inputMap.isValid() ? inputMap->elementAti(cpPos) : cpPos + rangeStart;
     95        U_ASSERT(cpPos > prevCPPos);
     96        U_ASSERT(utextPos >= prevUTextPos);
     97 
     98        if (utextPos > prevUTextPos) {
     99            if (utextPos != rangeStart ||
    100                (utextPos > 0 &&
    101                 fClosePunctuationSet.contains(utext_char32At(inText, utextPos - 1)))) {
    102                foundBreaks.push(utextPos, status);
    103                correctedNumBreaks++;
    104            }
    105        } else {
    106            // Normalization expanded the input text, the dictionary found a boundary
    107            // within the expansion, giving two boundaries with the same index in the
    108            // original text. Ignore the second. See ticket #12918.
    109            --numBreaks;
    110        }
    111        prevCPPos = cpPos;
    112        prevUTextPos = utextPos;
    113    }
    114    (void)prevCPPos;  // suppress compiler warnings about unused variable
    115 
    116    UChar32 nextChar = utext_char32At(inText, rangeEnd);
    117    if (!foundBreaks.isEmpty() && foundBreaks.peeki() == rangeEnd) {
    118        // In phrase breaking, there has to be a breakpoint between Cj character and
    119        // the number/open punctuation.
    120        // E.g. る文字「そうだ、京都」->る▁文字▁「そうだ、▁京都」-> breakpoint between 字 and「
    121        // E.g. 乗車率90%程度だろうか -> 乗車▁率▁90%▁程度だろうか -> breakpoint between 率 and 9
    122        // E.g. しかもロゴがUnicode! -> しかも▁ロゴが▁Unicode!-> breakpoint between が and U
    123        if (!fDigitOrOpenPunctuationOrAlphabetSet.contains(nextChar)) {
    124            foundBreaks.popi();
    125            correctedNumBreaks--;
    126        }
    127    }
    128 
    129    return correctedNumBreaks;
    130 }
    131 
    132 int32_t MlBreakEngine::evaluateBreakpoint(const UnicodeString &inString, int32_t *indexList,
    133                                          int32_t startIdx, int32_t numCodeUnits, int32_t numBreaks,
    134                                          UVector32 &boundary, UErrorCode &status) const {
    135    if (U_FAILURE(status)) {
    136        return numBreaks;
    137    }
    138    int32_t start = 0, end = 0;
    139    int32_t score = fNegativeSum;
    140 
    141    for (int i = 0; i < 6; i++) {
    142        // UW1 ~ UW6
    143        start = startIdx + i;
    144        if (indexList[start] != -1) {
    145            end = (indexList[start + 1] != -1) ? indexList[start + 1] : numCodeUnits;
    146            score += fModel[static_cast<int32_t>(ModelIndex::kUWStart) + i].geti(
    147                inString.tempSubString(indexList[start], end - indexList[start]));
    148        }
    149    }
    150    for (int i = 0; i < 3; i++) {
    151        // BW1 ~ BW3
    152        start = startIdx + i + 1;
    153        if (indexList[start] != -1 && indexList[start + 1] != -1) {
    154            end = (indexList[start + 2] != -1) ? indexList[start + 2] : numCodeUnits;
    155            score += fModel[static_cast<int32_t>(ModelIndex::kBWStart) + i].geti(
    156                inString.tempSubString(indexList[start], end - indexList[start]));
    157        }
    158    }
    159    for (int i = 0; i < 4; i++) {
    160        // TW1 ~ TW4
    161        start = startIdx + i;
    162        if (indexList[start] != -1 && indexList[start + 1] != -1 && indexList[start + 2] != -1) {
    163            end = (indexList[start + 3] != -1) ? indexList[start + 3] : numCodeUnits;
    164            score += fModel[static_cast<int32_t>(ModelIndex::kTWStart) + i].geti(
    165                inString.tempSubString(indexList[start], end - indexList[start]));
    166        }
    167    }
    168 
    169    if (score > 0) {
    170        boundary.addElement(startIdx + 1, status);
    171        numBreaks++;
    172    }
    173    return numBreaks;
    174 }
    175 
    176 int32_t MlBreakEngine::initIndexList(const UnicodeString &inString, int32_t *indexList,
    177                                     UErrorCode &status) const {
    178    if (U_FAILURE(status)) {
    179        return 0;
    180    }
    181    int32_t index = 0;
    182    int32_t length = inString.countChar32();
    183    // Set all (lenght+4) items inside indexLength to -1 presuming -1 is 4 bytes of 0xff.
    184    uprv_memset(indexList, 0xff, (length + 4) * sizeof(int32_t));
    185    if (length > 0) {
    186        indexList[2] = 0;
    187        index = U16_LENGTH(inString.char32At(0));
    188        if (length > 1) {
    189            indexList[3] = index;
    190            index += U16_LENGTH(inString.char32At(index));
    191            if (length > 2) {
    192                indexList[4] = index;
    193                index += U16_LENGTH(inString.char32At(index));
    194                if (length > 3) {
    195                    indexList[5] = index;
    196                    index += U16_LENGTH(inString.char32At(index));
    197                }
    198            }
    199        }
    200    }
    201    return index;
    202 }
    203 
    204 void MlBreakEngine::loadMLModel(UErrorCode &error) {
    205    // BudouX's model consists of thirteen categories, each of which is make up of pairs of the
    206    // feature and its score. As integrating it into jaml.txt, we define thirteen kinds of key and
    207    // value to represent the feature and the corresponding score respectively.
    208 
    209    if (U_FAILURE(error)) return;
    210 
    211    UnicodeString key;
    212    StackUResourceBundle stackTempBundle;
    213    ResourceDataValue modelKey;
    214 
    215    LocalUResourceBundlePointer rbp(ures_openDirect(U_ICUDATA_BRKITR, "jaml", &error));
    216    UResourceBundle *rb = rbp.getAlias();
    217    if (U_FAILURE(error)) return;
    218 
    219    int32_t index = 0;
    220    initKeyValue(rb, "UW1Keys", "UW1Values", fModel[index++], error);
    221    initKeyValue(rb, "UW2Keys", "UW2Values", fModel[index++], error);
    222    initKeyValue(rb, "UW3Keys", "UW3Values", fModel[index++], error);
    223    initKeyValue(rb, "UW4Keys", "UW4Values", fModel[index++], error);
    224    initKeyValue(rb, "UW5Keys", "UW5Values", fModel[index++], error);
    225    initKeyValue(rb, "UW6Keys", "UW6Values", fModel[index++], error);
    226    initKeyValue(rb, "BW1Keys", "BW1Values", fModel[index++], error);
    227    initKeyValue(rb, "BW2Keys", "BW2Values", fModel[index++], error);
    228    initKeyValue(rb, "BW3Keys", "BW3Values", fModel[index++], error);
    229    initKeyValue(rb, "TW1Keys", "TW1Values", fModel[index++], error);
    230    initKeyValue(rb, "TW2Keys", "TW2Values", fModel[index++], error);
    231    initKeyValue(rb, "TW3Keys", "TW3Values", fModel[index++], error);
    232    initKeyValue(rb, "TW4Keys", "TW4Values", fModel[index++], error);
    233    fNegativeSum /= 2;
    234 }
    235 
    236 void MlBreakEngine::initKeyValue(UResourceBundle *rb, const char *keyName, const char *valueName,
    237                                 Hashtable &model, UErrorCode &error) {
    238    int32_t keySize = 0;
    239    int32_t valueSize = 0;
    240    int32_t stringLength = 0;
    241    UnicodeString key;
    242    StackUResourceBundle stackTempBundle;
    243    ResourceDataValue modelKey;
    244 
    245    // get modelValues
    246    LocalUResourceBundlePointer modelValue(ures_getByKey(rb, valueName, nullptr, &error));
    247    const int32_t *value = ures_getIntVector(modelValue.getAlias(), &valueSize, &error);
    248    if (U_FAILURE(error)) return;
    249 
    250    // get modelKeys
    251    ures_getValueWithFallback(rb, keyName, stackTempBundle.getAlias(), modelKey, error);
    252    ResourceArray stringArray = modelKey.getArray(error);
    253    keySize = stringArray.getSize();
    254    if (U_FAILURE(error)) return;
    255 
    256    for (int32_t idx = 0; idx < keySize; idx++) {
    257        stringArray.getValue(idx, modelKey);
    258        key = UnicodeString(modelKey.getString(stringLength, error));
    259        if (U_SUCCESS(error)) {
    260            U_ASSERT(idx < valueSize);
    261            fNegativeSum -= value[idx];
    262            model.puti(key, value[idx], error);
    263        }
    264    }
    265 }
    266 
    267 U_NAMESPACE_END
    268 
    269 #endif /* #if !UCONFIG_NO_BREAK_ITERATION */