mlbe.cpp (10730B)
1 // © 2022 and later: Unicode, Inc. and others. 2 // License & terms of use: http://www.unicode.org/copyright.html 3 4 #include "unicode/utypes.h" 5 6 #if !UCONFIG_NO_BREAK_ITERATION 7 8 #include "cmemory.h" 9 #include "mlbe.h" 10 #include "uassert.h" 11 #include "ubrkimpl.h" 12 #include "unicode/resbund.h" 13 #include "unicode/udata.h" 14 #include "unicode/utf16.h" 15 #include "uresimp.h" 16 #include "util.h" 17 #include "uvectr32.h" 18 19 U_NAMESPACE_BEGIN 20 21 enum class ModelIndex { kUWStart = 0, kBWStart = 6, kTWStart = 9 }; 22 23 MlBreakEngine::MlBreakEngine(const UnicodeSet &digitOrOpenPunctuationOrAlphabetSet, 24 const UnicodeSet &closePunctuationSet, UErrorCode &status) 25 : fDigitOrOpenPunctuationOrAlphabetSet(digitOrOpenPunctuationOrAlphabetSet), 26 fClosePunctuationSet(closePunctuationSet), 27 fNegativeSum(0) { 28 if (U_FAILURE(status)) { 29 return; 30 } 31 loadMLModel(status); 32 } 33 34 MlBreakEngine::~MlBreakEngine() {} 35 36 int32_t MlBreakEngine::divideUpRange(UText *inText, int32_t rangeStart, int32_t rangeEnd, 37 UVector32 &foundBreaks, const UnicodeString &inString, 38 const LocalPointer<UVector32> &inputMap, 39 UErrorCode &status) const { 40 if (U_FAILURE(status)) { 41 return 0; 42 } 43 if (rangeStart >= rangeEnd) { 44 status = U_ILLEGAL_ARGUMENT_ERROR; 45 return 0; 46 } 47 48 UVector32 boundary(inString.countChar32() + 1, status); 49 if (U_FAILURE(status)) { 50 return 0; 51 } 52 int32_t numBreaks = 0; 53 int32_t codePointLength = inString.countChar32(); 54 // The ML algorithm groups six char and evaluates whether the 4th char is a breakpoint. 55 // In each iteration, it evaluates the 4th char and then moves forward one char like a sliding 56 // window. Initially, the first six values in the indexList are [-1, -1, 0, 1, 2, 3]. After 57 // moving forward, finally the last six values in the indexList are 58 // [length-4, length-3, length-2, length-1, -1, -1]. The "+4" here means four extra "-1". 59 int32_t indexSize = codePointLength + 4; 60 LocalMemory<int32_t> indexList(static_cast<int32_t*>(uprv_malloc(indexSize * sizeof(int32_t)))); 61 if (indexList.isNull()) { 62 status = U_MEMORY_ALLOCATION_ERROR; 63 return 0; 64 } 65 int32_t numCodeUnits = initIndexList(inString, indexList.getAlias(), status); 66 67 // Add a break for the start. 68 boundary.addElement(0, status); 69 numBreaks++; 70 if (U_FAILURE(status)) return 0; 71 72 for (int32_t idx = 0; idx + 1 < codePointLength && U_SUCCESS(status); idx++) { 73 numBreaks = 74 evaluateBreakpoint(inString, indexList.getAlias(), idx, numCodeUnits, numBreaks, boundary, status); 75 if (idx + 4 < codePointLength) { 76 indexList[idx + 6] = numCodeUnits; 77 numCodeUnits += U16_LENGTH(inString.char32At(indexList[idx + 6])); 78 } 79 } 80 81 if (U_FAILURE(status)) return 0; 82 83 // Add a break for the end if there is not one there already. 84 if (boundary.lastElementi() != inString.countChar32()) { 85 boundary.addElement(inString.countChar32(), status); 86 numBreaks++; 87 } 88 89 int32_t prevCPPos = -1; 90 int32_t prevUTextPos = -1; 91 int32_t correctedNumBreaks = 0; 92 for (int32_t i = 0; i < numBreaks; i++) { 93 int32_t cpPos = boundary.elementAti(i); 94 int32_t utextPos = inputMap.isValid() ? inputMap->elementAti(cpPos) : cpPos + rangeStart; 95 U_ASSERT(cpPos > prevCPPos); 96 U_ASSERT(utextPos >= prevUTextPos); 97 98 if (utextPos > prevUTextPos) { 99 if (utextPos != rangeStart || 100 (utextPos > 0 && 101 fClosePunctuationSet.contains(utext_char32At(inText, utextPos - 1)))) { 102 foundBreaks.push(utextPos, status); 103 correctedNumBreaks++; 104 } 105 } else { 106 // Normalization expanded the input text, the dictionary found a boundary 107 // within the expansion, giving two boundaries with the same index in the 108 // original text. Ignore the second. See ticket #12918. 109 --numBreaks; 110 } 111 prevCPPos = cpPos; 112 prevUTextPos = utextPos; 113 } 114 (void)prevCPPos; // suppress compiler warnings about unused variable 115 116 UChar32 nextChar = utext_char32At(inText, rangeEnd); 117 if (!foundBreaks.isEmpty() && foundBreaks.peeki() == rangeEnd) { 118 // In phrase breaking, there has to be a breakpoint between Cj character and 119 // the number/open punctuation. 120 // E.g. る文字「そうだ、京都」->る▁文字▁「そうだ、▁京都」-> breakpoint between 字 and「 121 // E.g. 乗車率90%程度だろうか -> 乗車▁率▁90%▁程度だろうか -> breakpoint between 率 and 9 122 // E.g. しかもロゴがUnicode! -> しかも▁ロゴが▁Unicode!-> breakpoint between が and U 123 if (!fDigitOrOpenPunctuationOrAlphabetSet.contains(nextChar)) { 124 foundBreaks.popi(); 125 correctedNumBreaks--; 126 } 127 } 128 129 return correctedNumBreaks; 130 } 131 132 int32_t MlBreakEngine::evaluateBreakpoint(const UnicodeString &inString, int32_t *indexList, 133 int32_t startIdx, int32_t numCodeUnits, int32_t numBreaks, 134 UVector32 &boundary, UErrorCode &status) const { 135 if (U_FAILURE(status)) { 136 return numBreaks; 137 } 138 int32_t start = 0, end = 0; 139 int32_t score = fNegativeSum; 140 141 for (int i = 0; i < 6; i++) { 142 // UW1 ~ UW6 143 start = startIdx + i; 144 if (indexList[start] != -1) { 145 end = (indexList[start + 1] != -1) ? indexList[start + 1] : numCodeUnits; 146 score += fModel[static_cast<int32_t>(ModelIndex::kUWStart) + i].geti( 147 inString.tempSubString(indexList[start], end - indexList[start])); 148 } 149 } 150 for (int i = 0; i < 3; i++) { 151 // BW1 ~ BW3 152 start = startIdx + i + 1; 153 if (indexList[start] != -1 && indexList[start + 1] != -1) { 154 end = (indexList[start + 2] != -1) ? indexList[start + 2] : numCodeUnits; 155 score += fModel[static_cast<int32_t>(ModelIndex::kBWStart) + i].geti( 156 inString.tempSubString(indexList[start], end - indexList[start])); 157 } 158 } 159 for (int i = 0; i < 4; i++) { 160 // TW1 ~ TW4 161 start = startIdx + i; 162 if (indexList[start] != -1 && indexList[start + 1] != -1 && indexList[start + 2] != -1) { 163 end = (indexList[start + 3] != -1) ? indexList[start + 3] : numCodeUnits; 164 score += fModel[static_cast<int32_t>(ModelIndex::kTWStart) + i].geti( 165 inString.tempSubString(indexList[start], end - indexList[start])); 166 } 167 } 168 169 if (score > 0) { 170 boundary.addElement(startIdx + 1, status); 171 numBreaks++; 172 } 173 return numBreaks; 174 } 175 176 int32_t MlBreakEngine::initIndexList(const UnicodeString &inString, int32_t *indexList, 177 UErrorCode &status) const { 178 if (U_FAILURE(status)) { 179 return 0; 180 } 181 int32_t index = 0; 182 int32_t length = inString.countChar32(); 183 // Set all (lenght+4) items inside indexLength to -1 presuming -1 is 4 bytes of 0xff. 184 uprv_memset(indexList, 0xff, (length + 4) * sizeof(int32_t)); 185 if (length > 0) { 186 indexList[2] = 0; 187 index = U16_LENGTH(inString.char32At(0)); 188 if (length > 1) { 189 indexList[3] = index; 190 index += U16_LENGTH(inString.char32At(index)); 191 if (length > 2) { 192 indexList[4] = index; 193 index += U16_LENGTH(inString.char32At(index)); 194 if (length > 3) { 195 indexList[5] = index; 196 index += U16_LENGTH(inString.char32At(index)); 197 } 198 } 199 } 200 } 201 return index; 202 } 203 204 void MlBreakEngine::loadMLModel(UErrorCode &error) { 205 // BudouX's model consists of thirteen categories, each of which is make up of pairs of the 206 // feature and its score. As integrating it into jaml.txt, we define thirteen kinds of key and 207 // value to represent the feature and the corresponding score respectively. 208 209 if (U_FAILURE(error)) return; 210 211 UnicodeString key; 212 StackUResourceBundle stackTempBundle; 213 ResourceDataValue modelKey; 214 215 LocalUResourceBundlePointer rbp(ures_openDirect(U_ICUDATA_BRKITR, "jaml", &error)); 216 UResourceBundle *rb = rbp.getAlias(); 217 if (U_FAILURE(error)) return; 218 219 int32_t index = 0; 220 initKeyValue(rb, "UW1Keys", "UW1Values", fModel[index++], error); 221 initKeyValue(rb, "UW2Keys", "UW2Values", fModel[index++], error); 222 initKeyValue(rb, "UW3Keys", "UW3Values", fModel[index++], error); 223 initKeyValue(rb, "UW4Keys", "UW4Values", fModel[index++], error); 224 initKeyValue(rb, "UW5Keys", "UW5Values", fModel[index++], error); 225 initKeyValue(rb, "UW6Keys", "UW6Values", fModel[index++], error); 226 initKeyValue(rb, "BW1Keys", "BW1Values", fModel[index++], error); 227 initKeyValue(rb, "BW2Keys", "BW2Values", fModel[index++], error); 228 initKeyValue(rb, "BW3Keys", "BW3Values", fModel[index++], error); 229 initKeyValue(rb, "TW1Keys", "TW1Values", fModel[index++], error); 230 initKeyValue(rb, "TW2Keys", "TW2Values", fModel[index++], error); 231 initKeyValue(rb, "TW3Keys", "TW3Values", fModel[index++], error); 232 initKeyValue(rb, "TW4Keys", "TW4Values", fModel[index++], error); 233 fNegativeSum /= 2; 234 } 235 236 void MlBreakEngine::initKeyValue(UResourceBundle *rb, const char *keyName, const char *valueName, 237 Hashtable &model, UErrorCode &error) { 238 int32_t keySize = 0; 239 int32_t valueSize = 0; 240 int32_t stringLength = 0; 241 UnicodeString key; 242 StackUResourceBundle stackTempBundle; 243 ResourceDataValue modelKey; 244 245 // get modelValues 246 LocalUResourceBundlePointer modelValue(ures_getByKey(rb, valueName, nullptr, &error)); 247 const int32_t *value = ures_getIntVector(modelValue.getAlias(), &valueSize, &error); 248 if (U_FAILURE(error)) return; 249 250 // get modelKeys 251 ures_getValueWithFallback(rb, keyName, stackTempBundle.getAlias(), modelKey, error); 252 ResourceArray stringArray = modelKey.getArray(error); 253 keySize = stringArray.getSize(); 254 if (U_FAILURE(error)) return; 255 256 for (int32_t idx = 0; idx < keySize; idx++) { 257 stringArray.getValue(idx, modelKey); 258 key = UnicodeString(modelKey.getString(stringLength, error)); 259 if (U_SUCCESS(error)) { 260 U_ASSERT(idx < valueSize); 261 fNegativeSum -= value[idx]; 262 model.puti(key, value[idx], error); 263 } 264 } 265 } 266 267 U_NAMESPACE_END 268 269 #endif /* #if !UCONFIG_NO_BREAK_ITERATION */