diff options
author | allenwtsu <allenwtsu@google.com> | 2023-01-31 18:17:02 +0800 |
---|---|---|
committer | allenwtsu <allenwtsu@google.com> | 2023-02-04 05:54:01 +0800 |
commit | d59fec558de2456791e2e659c8f3aad70dc59a38 (patch) | |
tree | 633ba36a9f890120649750dedd6b946f674c851f /icu4j | |
parent | 0556545bf9f357f9676d1a4b1fa12c9ded5f3ced (diff) | |
download | icu-d59fec558de2456791e2e659c8f3aad70dc59a38.tar.gz |
ICU-22100 Modify ML model to improve Japanese phrase breaking performance
cherry-pick from https://github.com/unicode-org/icu/pull/2297
Bug: 219529457
Test: atest CtsIcuTestCases
Change-Id: I482d01f57848ef90ba3a64ea5880978a2535dd8d
Diffstat (limited to 'icu4j')
-rw-r--r-- | icu4j/main/classes/core/src/com/ibm/icu/impl/breakiter/MlBreakEngine.java | 283 | ||||
-rw-r--r-- | icu4j/main/shared/data/icudata.jar | bin | 12491010 -> 12490325 bytes |
2 files changed, 143 insertions, 140 deletions
diff --git a/icu4j/main/classes/core/src/com/ibm/icu/impl/breakiter/MlBreakEngine.java b/icu4j/main/classes/core/src/com/ibm/icu/impl/breakiter/MlBreakEngine.java index 196579d0a..e09c1763d 100644 --- a/icu4j/main/classes/core/src/com/ibm/icu/impl/breakiter/MlBreakEngine.java +++ b/icu4j/main/classes/core/src/com/ibm/icu/impl/breakiter/MlBreakEngine.java @@ -8,26 +8,36 @@ import static com.ibm.icu.impl.CharacterIteration.current32; import static com.ibm.icu.impl.CharacterIteration.next32; import static com.ibm.icu.impl.CharacterIteration.previous32; -import com.ibm.icu.impl.Assert; import com.ibm.icu.impl.ICUData; -import com.ibm.icu.lang.UCharacter; import com.ibm.icu.text.UnicodeSet; import com.ibm.icu.util.UResourceBundle; import com.ibm.icu.util.UResourceBundleIterator; -import java.lang.System; import java.text.CharacterIterator; +import java.util.Arrays; import java.util.ArrayList; +import java.util.List; import java.util.HashMap; -public class MlBreakEngine { +enum ModelIndex { + kUWStart(0), kBWStart(6), kTWStart(9); + private final int value; + + private ModelIndex(int value) { + this.value = value; + } + + public int getValue() { + return value; + } +} - private static final int INVALID = '|'; - private static final String INVALID_STRING = "|"; +public class MlBreakEngine { + // {UW1, UW2, ... UW6, BW1, ... BW3, TW1, TW2, ... 
TW4} 6+3+4= 13 private static final int MAX_FEATURE = 13; private UnicodeSet fDigitOrOpenPunctuationOrAlphabetSet; private UnicodeSet fClosePunctuationSet; - private HashMap<String, Integer> fModel; + private List<HashMap<String, Integer>> fModel; private int fNegativeSum; /** @@ -41,7 +51,10 @@ public class MlBreakEngine { UnicodeSet closePunctuationSet) { fDigitOrOpenPunctuationOrAlphabetSet = digitOrOpenPunctuationOrAlphabetSet; fClosePunctuationSet = closePunctuationSet; - fModel = new HashMap<String, Integer>(); + fModel = new ArrayList<HashMap<String, Integer>>(MAX_FEATURE); + for (int i = 0; i < MAX_FEATURE; i++) { + fModel.add(new HashMap<String, Integer>()); + } fNegativeSum = 0; loadMLModel(); } @@ -49,42 +62,47 @@ public class MlBreakEngine { /** * Divide up a range of characters handled by this break engine. * - * @param inText A input text. - * @param startPos The start index of the input text. - * @param endPos The end index of the input text. - * @param inString A input string normalized from inText from startPos to endPos - * @param numCodePts The number of code points of inString - * @param charPositions A map that transforms inString's code point index to code unit index. - * @param foundBreaks A list to store the breakpoint. + * @param inText An input text. + * @param startPos The start index of the input text. + * @param endPos The end index of the input text. + * @param inString An input string normalized from inText from startPos to endPos + * @param codePointLength The number of code points of inString + * @param charPositions A map that transforms inString's code point index to code unit index. + * @param foundBreaks A list to store the breakpoint. 
* @return The number of breakpoints */ public int divideUpRange(CharacterIterator inText, int startPos, int endPos, - CharacterIterator inString, int numCodePts, int[] charPositions, + CharacterIterator inString, int codePointLength, int[] charPositions, DictionaryBreakEngine.DequeI foundBreaks) { if (startPos >= endPos) { return 0; } - ArrayList<Integer> boundary = new ArrayList<Integer>(numCodePts); - // The ML model groups six char to evaluate if the 4th char is a breakpoint. - // Like a sliding window, the elementList removes the first char and appends the new char - // from inString in each iteration so that its size always remains at six. - int elementList[] = new int[6]; - initElementList(inString, elementList, numCodePts); + ArrayList<Integer> boundary = new ArrayList<Integer>(codePointLength); + String inputStr = transform(inString); + // The ML algorithm groups six char and evaluates whether the 4th char is a breakpoint. + // In each iteration, it evaluates the 4th char and then moves forward one char like + // sliding window. Initially, the first six values in the indexList are + // [-1, -1, 0, 1, 2, 3]. After moving forward, finally the last six values in the indexList + // are [length-4, length-3, length-2, length-1, -1, -1]. The "+4" here means four extra + // "-1". + int indexSize = codePointLength + 4; + int indexList[] = new int[indexSize]; + int numCodeUnits = initIndexList(inString, indexList, codePointLength); // Add a break for the start. boundary.add(0, 0); - for (int i = 1; i < numCodePts; i++) { - evaluateBreakpoint(elementList, i, boundary); - if (i + 1 > numCodePts) { - break; + + for (int idx = 0; idx + 1 < codePointLength; idx++) { + evaluateBreakpoint(inputStr, indexList, idx, numCodeUnits, boundary); + if (idx + 4 < codePointLength) { + indexList[idx + 6] = numCodeUnits; + numCodeUnits += Character.charCount(next32(inString)); } - shiftLeftOne(elementList); - elementList[5] = (i + 3) < numCodePts ? 
next32(inString) : INVALID; } // Add a break for the end if there is not one there already. - if (boundary.get(boundary.size() - 1) != numCodePts) { - boundary.add(numCodePts); + if (boundary.get(boundary.size() - 1) != codePointLength) { + boundary.add(codePointLength); } int correctedNumBreaks = 0; @@ -127,137 +145,94 @@ public class MlBreakEngine { return correctedNumBreaks; } - private void shiftLeftOne(int[] elementList) { - int length = elementList.length; - for (int i = 1; i < length; i++) { - elementList[i - 1] = elementList[i]; + /** + * Transform a CharacterIterator into a String. + */ + private String transform(CharacterIterator inString) { + StringBuilder sb = new StringBuilder(); + inString.setIndex(0); + for (char c = inString.first(); c != CharacterIterator.DONE; c = inString.next()) { + sb.append(c); } + return sb.toString(); } /** - * Evaluate whether the index is a potential breakpoint. + * Evaluate whether the breakpointIdx is a potential breakpoint. * - * @param elementList A list including six elements for the breakpoint evaluation. - * @param index The breakpoint index to be evaluated. - * @param boundary An list including the index of the breakpoint. + * @param inputStr An input string to be segmented. + * @param indexList A code unit index list of the inputStr. + * @param startIdx The start index of the indexList. + * @param numCodeUnits The current code unit boundary of the indexList. + * @param boundary A list including the index of the breakpoint. 
*/ - private void evaluateBreakpoint(int[] elementList, int index, ArrayList<Integer> boundary) { - String[] featureList = new String[MAX_FEATURE]; - final int w1 = elementList[0]; - final int w2 = elementList[1]; - final int w3 = elementList[2]; - final int w4 = elementList[3]; - final int w5 = elementList[4]; - final int w6 = elementList[5]; + private void evaluateBreakpoint(String inputStr, int[] indexList, int startIdx, + int numCodeUnits, ArrayList<Integer> boundary) { + int start = 0, end = 0; + int score = fNegativeSum; - StringBuilder sb = new StringBuilder(); - int idx = 0; - if (w1 != INVALID) { - featureList[idx++] = sb.append("UW1:").appendCodePoint(w1).toString(); - } - if (w2 != INVALID) { - sb.setLength(0); - featureList[idx++] = sb.append("UW2:").appendCodePoint(w2).toString(); - } - if (w3 != INVALID) { - sb.setLength(0); - featureList[idx++] = sb.append("UW3:").appendCodePoint(w3).toString(); - } - if (w4 != INVALID) { - sb.setLength(0); - featureList[idx++] = sb.append("UW4:").appendCodePoint(w4).toString(); - } - if (w5 != INVALID) { - sb.setLength(0); - featureList[idx++] = sb.append("UW5:").appendCodePoint(w5).toString(); - } - if (w6 != INVALID) { - sb.setLength(0); - featureList[idx++] = sb.append("UW6:").appendCodePoint(w6).toString(); - } - if (w2 != INVALID && w3 != INVALID) { - sb.setLength(0); - featureList[idx++] = sb.append("BW1:").appendCodePoint(w2).appendCodePoint( - w3).toString(); - } - if (w3 != INVALID && w4 != INVALID) { - sb.setLength(0); - featureList[idx++] = sb.append("BW2:").appendCodePoint(w3).appendCodePoint( - w4).toString(); - } - if (w4 != INVALID && w5 != INVALID) { - sb.setLength(0); - featureList[idx++] = sb.append("BW3:").appendCodePoint(w4).appendCodePoint( - w5).toString(); - } - if (w1 != INVALID && w2 != INVALID && w3 != INVALID) { - sb.setLength(0); - featureList[idx++] = sb.append("TW1:").appendCodePoint(w1).appendCodePoint( - w2).appendCodePoint(w3).toString(); - } - if (w2 != INVALID && w3 != INVALID && 
w4 != INVALID) { - sb.setLength(0); - featureList[idx++] = sb.append("TW2:").appendCodePoint(w2).appendCodePoint( - w3).appendCodePoint(w4).toString(); - } - if (w3 != INVALID && w4 != INVALID && w5 != INVALID) { - sb.setLength(0); - featureList[idx++] = sb.append("TW3:").appendCodePoint(w3).appendCodePoint( - w4).appendCodePoint(w5).toString(); + for (int i = 0; i < 6; i++) { + // UW1 ~ UW6 + start = startIdx + i; + if (indexList[start] != -1) { + end = (indexList[start + 1] != -1) ? indexList[start + 1] : numCodeUnits; + score += fModel.get(ModelIndex.kUWStart.getValue() + i).getOrDefault( + inputStr.substring(indexList[start], end), 0); + } } - if (w4 != INVALID && w5 != INVALID && w6 != INVALID) { - sb.setLength(0); - featureList[idx++] = sb.append("TW4:").appendCodePoint(w4).appendCodePoint( - w5).appendCodePoint(w6).toString(); + for (int i = 0; i < 3; i++) { + // BW1 ~ BW3 + start = startIdx + i + 1; + if (indexList[start] != -1 && indexList[start + 1] != -1) { + end = (indexList[start + 2] != -1) ? indexList[start + 2] : numCodeUnits; + score += fModel.get(ModelIndex.kBWStart.getValue() + i).getOrDefault( + inputStr.substring(indexList[start], end), 0); + } } - - int score = fNegativeSum; - for (int j = 0; j < idx; j++) { - if (fModel.containsKey(featureList[j])) { - score += (2 * fModel.get(featureList[j])); + for (int i = 0; i < 4; i++) { + // TW1 ~ TW4 + start = startIdx + i; + if (indexList[start] != -1 + && indexList[start + 1] != -1 + && indexList[start + 2] != -1) { + end = (indexList[start + 3] != -1) ? indexList[start + 3] : numCodeUnits; + score += fModel.get(ModelIndex.kTWStart.getValue() + i).getOrDefault( + inputStr.substring(indexList[start], end), 0); } } if (score > 0) { - boundary.add(index); + boundary.add(startIdx + 1); } } /** - * Initialize the element list from the input string. + * Initialize the index list from the input string. * - * @param inString A input string to be segmented. 
- * @param elementList A list to store the first six characters. - * @param numCodePts The number of code points of input string + * @param inString An input string to be segmented. + * @param indexList A code unit index list of the inString. + * @param codePointLength The number of code points of the input string * @return The number of the code units of the first six characters in inString. */ - private int initElementList(CharacterIterator inString, int[] elementList, int numCodePts) { + private int initIndexList(CharacterIterator inString, int[] indexList, int codePointLength) { int index = 0; inString.setIndex(index); - int w1, w2, w3, w4, w5, w6; - w1 = w2 = w3 = w4 = w5 = w6 = INVALID; - if (numCodePts > 0) { - w3 = current32(inString); - index += Character.charCount(w3); - if (numCodePts > 1) { - w4 = next32(inString); - index += Character.charCount(w3); - if (numCodePts > 2) { - w5 = next32(inString); - index += Character.charCount(w5); - if (numCodePts > 3) { - w6 = next32(inString); - index += Character.charCount(w6); + Arrays.fill(indexList, -1); + if (codePointLength > 0) { + indexList[2] = 0; + index += Character.charCount(current32(inString)); + if (codePointLength > 1) { + indexList[3] = index; + index += Character.charCount(next32(inString)); + if (codePointLength > 2) { + indexList[4] = index; + index += Character.charCount(next32(inString)); + if (codePointLength > 3) { + indexList[5] = index; + index += Character.charCount(next32(inString)); } } } } - elementList[0] = w1; - elementList[1] = w2; - elementList[2] = w3; - elementList[3] = w4; - elementList[4] = w5; - elementList[5] = w6; - return index; } @@ -268,13 +243,41 @@ public class MlBreakEngine { int index = 0; UResourceBundle rb = UResourceBundle.getBundleInstance(ICUData.ICU_BRKITR_BASE_NAME, "jaml"); - UResourceBundle keyBundle = rb.get("modelKeys"); - UResourceBundle valueBundle = rb.get("modelValues"); + initKeyValue(rb, "UW1Keys", "UW1Values", fModel.get(index++)); + initKeyValue(rb, 
"UW2Keys", "UW2Values", fModel.get(index++)); + initKeyValue(rb, "UW3Keys", "UW3Values", fModel.get(index++)); + initKeyValue(rb, "UW4Keys", "UW4Values", fModel.get(index++)); + initKeyValue(rb, "UW5Keys", "UW5Values", fModel.get(index++)); + initKeyValue(rb, "UW6Keys", "UW6Values", fModel.get(index++)); + initKeyValue(rb, "BW1Keys", "BW1Values", fModel.get(index++)); + initKeyValue(rb, "BW2Keys", "BW2Values", fModel.get(index++)); + initKeyValue(rb, "BW3Keys", "BW3Values", fModel.get(index++)); + initKeyValue(rb, "TW1Keys", "TW1Values", fModel.get(index++)); + initKeyValue(rb, "TW2Keys", "TW2Values", fModel.get(index++)); + initKeyValue(rb, "TW3Keys", "TW3Values", fModel.get(index++)); + initKeyValue(rb, "TW4Keys", "TW4Values", fModel.get(index++)); + fNegativeSum /= 2; + } + + /** + * In the machine learning's model file, specify the name of the key and value to load the + * corresponding feature and its score. + * + * @param rb A ResourceBundle corresponding to the model file. + * @param keyName The key name in the model file. + * @param valueName The value name in the model file. + * @param map A HashMap to store the pairs of the feature and its score. + */ + private void initKeyValue(UResourceBundle rb, String keyName, String valueName, + HashMap<String, Integer> map) { + int idx = 0; + UResourceBundle keyBundle = rb.get(keyName); + UResourceBundle valueBundle = rb.get(valueName); int[] value = valueBundle.getIntVector(); UResourceBundleIterator iterator = keyBundle.getIterator(); while (iterator.hasNext()) { - fNegativeSum -= value[index]; - fModel.put(iterator.nextString(), value[index++]); + fNegativeSum -= value[idx]; + map.put(iterator.nextString(), value[idx++]); } } } diff --git a/icu4j/main/shared/data/icudata.jar b/icu4j/main/shared/data/icudata.jar Binary files differindex 0495ee693..f2f0b68e7 100644 --- a/icu4j/main/shared/data/icudata.jar +++ b/icu4j/main/shared/data/icudata.jar |