summaryrefslogtreecommitdiff
path: root/icu4j
diff options
context:
space:
mode:
authorallenwtsu <allenwtsu@google.com>2023-01-31 18:17:02 +0800
committerallenwtsu <allenwtsu@google.com>2023-02-04 05:54:01 +0800
commitd59fec558de2456791e2e659c8f3aad70dc59a38 (patch)
tree633ba36a9f890120649750dedd6b946f674c851f /icu4j
parent0556545bf9f357f9676d1a4b1fa12c9ded5f3ced (diff)
downloadicu-d59fec558de2456791e2e659c8f3aad70dc59a38.tar.gz
ICU-22100 Modify ML model to improve Japanese phrase breaking performance
cherry-pick from https://github.com/unicode-org/icu/pull/2297 Bug: 219529457 Test: atest CtsIcuTestCases Change-Id: I482d01f57848ef90ba3a64ea5880978a2535dd8d
Diffstat (limited to 'icu4j')
-rw-r--r--icu4j/main/classes/core/src/com/ibm/icu/impl/breakiter/MlBreakEngine.java283
-rw-r--r--icu4j/main/shared/data/icudata.jarbin12491010 -> 12490325 bytes
2 files changed, 143 insertions, 140 deletions
diff --git a/icu4j/main/classes/core/src/com/ibm/icu/impl/breakiter/MlBreakEngine.java b/icu4j/main/classes/core/src/com/ibm/icu/impl/breakiter/MlBreakEngine.java
index 196579d0a..e09c1763d 100644
--- a/icu4j/main/classes/core/src/com/ibm/icu/impl/breakiter/MlBreakEngine.java
+++ b/icu4j/main/classes/core/src/com/ibm/icu/impl/breakiter/MlBreakEngine.java
@@ -8,26 +8,36 @@ import static com.ibm.icu.impl.CharacterIteration.current32;
import static com.ibm.icu.impl.CharacterIteration.next32;
import static com.ibm.icu.impl.CharacterIteration.previous32;
-import com.ibm.icu.impl.Assert;
import com.ibm.icu.impl.ICUData;
-import com.ibm.icu.lang.UCharacter;
import com.ibm.icu.text.UnicodeSet;
import com.ibm.icu.util.UResourceBundle;
import com.ibm.icu.util.UResourceBundleIterator;
-import java.lang.System;
import java.text.CharacterIterator;
+import java.util.Arrays;
import java.util.ArrayList;
+import java.util.List;
import java.util.HashMap;
-public class MlBreakEngine {
+enum ModelIndex {
+ kUWStart(0), kBWStart(6), kTWStart(9);
+ private final int value;
+
+ private ModelIndex(int value) {
+ this.value = value;
+ }
+
+ public int getValue() {
+ return value;
+ }
+}
- private static final int INVALID = '|';
- private static final String INVALID_STRING = "|";
+public class MlBreakEngine {
+ // {UW1, UW2, ... UW6, BW1, ... BW3, TW1, TW2, ... TW4} 6+3+4= 13
private static final int MAX_FEATURE = 13;
private UnicodeSet fDigitOrOpenPunctuationOrAlphabetSet;
private UnicodeSet fClosePunctuationSet;
- private HashMap<String, Integer> fModel;
+ private List<HashMap<String, Integer>> fModel;
private int fNegativeSum;
/**
@@ -41,7 +51,10 @@ public class MlBreakEngine {
UnicodeSet closePunctuationSet) {
fDigitOrOpenPunctuationOrAlphabetSet = digitOrOpenPunctuationOrAlphabetSet;
fClosePunctuationSet = closePunctuationSet;
- fModel = new HashMap<String, Integer>();
+ fModel = new ArrayList<HashMap<String, Integer>>(MAX_FEATURE);
+ for (int i = 0; i < MAX_FEATURE; i++) {
+ fModel.add(new HashMap<String, Integer>());
+ }
fNegativeSum = 0;
loadMLModel();
}
@@ -49,42 +62,47 @@ public class MlBreakEngine {
/**
* Divide up a range of characters handled by this break engine.
*
- * @param inText A input text.
- * @param startPos The start index of the input text.
- * @param endPos The end index of the input text.
- * @param inString A input string normalized from inText from startPos to endPos
- * @param numCodePts The number of code points of inString
- * @param charPositions A map that transforms inString's code point index to code unit index.
- * @param foundBreaks A list to store the breakpoint.
+ * @param inText An input text.
+ * @param startPos The start index of the input text.
+ * @param endPos The end index of the input text.
+ * @param inString A input string normalized from inText from startPos to endPos
+ * @param codePointLength The number of code points of inString
+ * @param charPositions A map that transforms inString's code point index to code unit index.
+ * @param foundBreaks A list to store the breakpoint.
* @return The number of breakpoints
*/
public int divideUpRange(CharacterIterator inText, int startPos, int endPos,
- CharacterIterator inString, int numCodePts, int[] charPositions,
+ CharacterIterator inString, int codePointLength, int[] charPositions,
DictionaryBreakEngine.DequeI foundBreaks) {
if (startPos >= endPos) {
return 0;
}
- ArrayList<Integer> boundary = new ArrayList<Integer>(numCodePts);
- // The ML model groups six char to evaluate if the 4th char is a breakpoint.
- // Like a sliding window, the elementList removes the first char and appends the new char
- // from inString in each iteration so that its size always remains at six.
- int elementList[] = new int[6];
- initElementList(inString, elementList, numCodePts);
+ ArrayList<Integer> boundary = new ArrayList<Integer>(codePointLength);
+ String inputStr = transform(inString);
+ // The ML algorithm groups six char and evaluates whether the 4th char is a breakpoint.
+ // In each iteration, it evaluates the 4th char and then moves forward one char like
+ // sliding window. Initially, the first six values in the indexList are
+ // [-1, -1, 0, 1, 2, 3]. After moving forward, finally the last six values in the indexList
+ // are [length-4, length-3, length-2, length-1, -1, -1]. The "+4" here means four extra
+ // "-1".
+ int indexSize = codePointLength + 4;
+ int indexList[] = new int[indexSize];
+ int numCodeUnits = initIndexList(inString, indexList, codePointLength);
// Add a break for the start.
boundary.add(0, 0);
- for (int i = 1; i < numCodePts; i++) {
- evaluateBreakpoint(elementList, i, boundary);
- if (i + 1 > numCodePts) {
- break;
+
+ for (int idx = 0; idx + 1 < codePointLength; idx++) {
+ evaluateBreakpoint(inputStr, indexList, idx, numCodeUnits, boundary);
+ if (idx + 4 < codePointLength) {
+ indexList[idx + 6] = numCodeUnits;
+ numCodeUnits += Character.charCount(next32(inString));
}
- shiftLeftOne(elementList);
- elementList[5] = (i + 3) < numCodePts ? next32(inString) : INVALID;
}
// Add a break for the end if there is not one there already.
- if (boundary.get(boundary.size() - 1) != numCodePts) {
- boundary.add(numCodePts);
+ if (boundary.get(boundary.size() - 1) != codePointLength) {
+ boundary.add(codePointLength);
}
int correctedNumBreaks = 0;
@@ -127,137 +145,94 @@ public class MlBreakEngine {
return correctedNumBreaks;
}
- private void shiftLeftOne(int[] elementList) {
- int length = elementList.length;
- for (int i = 1; i < length; i++) {
- elementList[i - 1] = elementList[i];
+ /**
+ * Transform a CharacterIterator into a String.
+ */
+ private String transform(CharacterIterator inString) {
+ StringBuilder sb = new StringBuilder();
+ inString.setIndex(0);
+ for (char c = inString.first(); c != CharacterIterator.DONE; c = inString.next()) {
+ sb.append(c);
}
+ return sb.toString();
}
/**
- * Evaluate whether the index is a potential breakpoint.
+ * Evaluate whether the breakpointIdx is a potential breakpoint.
*
- * @param elementList A list including six elements for the breakpoint evaluation.
- * @param index The breakpoint index to be evaluated.
- * @param boundary An list including the index of the breakpoint.
+ * @param inputStr An input string to be segmented.
+ * @param indexList A code unit index list of the inputStr.
+ * @param startIdx The start index of the indexList.
+ * @param numCodeUnits The current code unit boundary of the indexList.
+ * @param boundary A list including the index of the breakpoint.
*/
- private void evaluateBreakpoint(int[] elementList, int index, ArrayList<Integer> boundary) {
- String[] featureList = new String[MAX_FEATURE];
- final int w1 = elementList[0];
- final int w2 = elementList[1];
- final int w3 = elementList[2];
- final int w4 = elementList[3];
- final int w5 = elementList[4];
- final int w6 = elementList[5];
+ private void evaluateBreakpoint(String inputStr, int[] indexList, int startIdx,
+ int numCodeUnits, ArrayList<Integer> boundary) {
+ int start = 0, end = 0;
+ int score = fNegativeSum;
- StringBuilder sb = new StringBuilder();
- int idx = 0;
- if (w1 != INVALID) {
- featureList[idx++] = sb.append("UW1:").appendCodePoint(w1).toString();
- }
- if (w2 != INVALID) {
- sb.setLength(0);
- featureList[idx++] = sb.append("UW2:").appendCodePoint(w2).toString();
- }
- if (w3 != INVALID) {
- sb.setLength(0);
- featureList[idx++] = sb.append("UW3:").appendCodePoint(w3).toString();
- }
- if (w4 != INVALID) {
- sb.setLength(0);
- featureList[idx++] = sb.append("UW4:").appendCodePoint(w4).toString();
- }
- if (w5 != INVALID) {
- sb.setLength(0);
- featureList[idx++] = sb.append("UW5:").appendCodePoint(w5).toString();
- }
- if (w6 != INVALID) {
- sb.setLength(0);
- featureList[idx++] = sb.append("UW6:").appendCodePoint(w6).toString();
- }
- if (w2 != INVALID && w3 != INVALID) {
- sb.setLength(0);
- featureList[idx++] = sb.append("BW1:").appendCodePoint(w2).appendCodePoint(
- w3).toString();
- }
- if (w3 != INVALID && w4 != INVALID) {
- sb.setLength(0);
- featureList[idx++] = sb.append("BW2:").appendCodePoint(w3).appendCodePoint(
- w4).toString();
- }
- if (w4 != INVALID && w5 != INVALID) {
- sb.setLength(0);
- featureList[idx++] = sb.append("BW3:").appendCodePoint(w4).appendCodePoint(
- w5).toString();
- }
- if (w1 != INVALID && w2 != INVALID && w3 != INVALID) {
- sb.setLength(0);
- featureList[idx++] = sb.append("TW1:").appendCodePoint(w1).appendCodePoint(
- w2).appendCodePoint(w3).toString();
- }
- if (w2 != INVALID && w3 != INVALID && w4 != INVALID) {
- sb.setLength(0);
- featureList[idx++] = sb.append("TW2:").appendCodePoint(w2).appendCodePoint(
- w3).appendCodePoint(w4).toString();
- }
- if (w3 != INVALID && w4 != INVALID && w5 != INVALID) {
- sb.setLength(0);
- featureList[idx++] = sb.append("TW3:").appendCodePoint(w3).appendCodePoint(
- w4).appendCodePoint(w5).toString();
+ for (int i = 0; i < 6; i++) {
+ // UW1 ~ UW6
+ start = startIdx + i;
+ if (indexList[start] != -1) {
+ end = (indexList[start + 1] != -1) ? indexList[start + 1] : numCodeUnits;
+ score += fModel.get(ModelIndex.kUWStart.getValue() + i).getOrDefault(
+ inputStr.substring(indexList[start], end), 0);
+ }
}
- if (w4 != INVALID && w5 != INVALID && w6 != INVALID) {
- sb.setLength(0);
- featureList[idx++] = sb.append("TW4:").appendCodePoint(w4).appendCodePoint(
- w5).appendCodePoint(w6).toString();
+ for (int i = 0; i < 3; i++) {
+ // BW1 ~ BW3
+ start = startIdx + i + 1;
+ if (indexList[start] != -1 && indexList[start + 1] != -1) {
+ end = (indexList[start + 2] != -1) ? indexList[start + 2] : numCodeUnits;
+ score += fModel.get(ModelIndex.kBWStart.getValue() + i).getOrDefault(
+ inputStr.substring(indexList[start], end), 0);
+ }
}
-
- int score = fNegativeSum;
- for (int j = 0; j < idx; j++) {
- if (fModel.containsKey(featureList[j])) {
- score += (2 * fModel.get(featureList[j]));
+ for (int i = 0; i < 4; i++) {
+ // TW1 ~ TW4
+ start = startIdx + i;
+ if (indexList[start] != -1
+ && indexList[start + 1] != -1
+ && indexList[start + 2] != -1) {
+ end = (indexList[start + 3] != -1) ? indexList[start + 3] : numCodeUnits;
+ score += fModel.get(ModelIndex.kTWStart.getValue() + i).getOrDefault(
+ inputStr.substring(indexList[start], end), 0);
}
}
if (score > 0) {
- boundary.add(index);
+ boundary.add(startIdx + 1);
}
}
/**
- * Initialize the element list from the input string.
+ * Initialize the index list from the input string.
*
- * @param inString A input string to be segmented.
- * @param elementList A list to store the first six characters.
- * @param numCodePts The number of code points of input string
+ * @param inString An input string to be segmented.
+ * @param indexList A code unit index list of the inString.
+ * @param codePointLength The number of code points of the input string
* @return The number of the code units of the first six characters in inString.
*/
- private int initElementList(CharacterIterator inString, int[] elementList, int numCodePts) {
+ private int initIndexList(CharacterIterator inString, int[] indexList, int codePointLength) {
int index = 0;
inString.setIndex(index);
- int w1, w2, w3, w4, w5, w6;
- w1 = w2 = w3 = w4 = w5 = w6 = INVALID;
- if (numCodePts > 0) {
- w3 = current32(inString);
- index += Character.charCount(w3);
- if (numCodePts > 1) {
- w4 = next32(inString);
- index += Character.charCount(w3);
- if (numCodePts > 2) {
- w5 = next32(inString);
- index += Character.charCount(w5);
- if (numCodePts > 3) {
- w6 = next32(inString);
- index += Character.charCount(w6);
+ Arrays.fill(indexList, -1);
+ if (codePointLength > 0) {
+ indexList[2] = 0;
+ index += Character.charCount(current32(inString));
+ if (codePointLength > 1) {
+ indexList[3] = index;
+ index += Character.charCount(next32(inString));
+ if (codePointLength > 2) {
+ indexList[4] = index;
+ index += Character.charCount(next32(inString));
+ if (codePointLength > 3) {
+ indexList[5] = index;
+ index += Character.charCount(next32(inString));
}
}
}
}
- elementList[0] = w1;
- elementList[1] = w2;
- elementList[2] = w3;
- elementList[3] = w4;
- elementList[4] = w5;
- elementList[5] = w6;
-
return index;
}
@@ -268,13 +243,41 @@ public class MlBreakEngine {
int index = 0;
UResourceBundle rb = UResourceBundle.getBundleInstance(ICUData.ICU_BRKITR_BASE_NAME,
"jaml");
- UResourceBundle keyBundle = rb.get("modelKeys");
- UResourceBundle valueBundle = rb.get("modelValues");
+ initKeyValue(rb, "UW1Keys", "UW1Values", fModel.get(index++));
+ initKeyValue(rb, "UW2Keys", "UW2Values", fModel.get(index++));
+ initKeyValue(rb, "UW3Keys", "UW3Values", fModel.get(index++));
+ initKeyValue(rb, "UW4Keys", "UW4Values", fModel.get(index++));
+ initKeyValue(rb, "UW5Keys", "UW5Values", fModel.get(index++));
+ initKeyValue(rb, "UW6Keys", "UW6Values", fModel.get(index++));
+ initKeyValue(rb, "BW1Keys", "BW1Values", fModel.get(index++));
+ initKeyValue(rb, "BW2Keys", "BW2Values", fModel.get(index++));
+ initKeyValue(rb, "BW3Keys", "BW3Values", fModel.get(index++));
+ initKeyValue(rb, "TW1Keys", "TW1Values", fModel.get(index++));
+ initKeyValue(rb, "TW2Keys", "TW2Values", fModel.get(index++));
+ initKeyValue(rb, "TW3Keys", "TW3Values", fModel.get(index++));
+ initKeyValue(rb, "TW4Keys", "TW4Values", fModel.get(index++));
+ fNegativeSum /= 2;
+ }
+
+ /**
+ * In the machine learning's model file, specify the name of the key and value to load the
+ * corresponding feature and its score.
+ *
+ * @param rb A RedouceBundle corresponding to the model file.
+ * @param keyName The kay name in the model file.
+ * @param valueName The value name in the model file.
+ * @param map A HashMap to store the pairs of the feature and its score.
+ */
+ private void initKeyValue(UResourceBundle rb, String keyName, String valueName,
+ HashMap<String, Integer> map) {
+ int idx = 0;
+ UResourceBundle keyBundle = rb.get(keyName);
+ UResourceBundle valueBundle = rb.get(valueName);
int[] value = valueBundle.getIntVector();
UResourceBundleIterator iterator = keyBundle.getIterator();
while (iterator.hasNext()) {
- fNegativeSum -= value[index];
- fModel.put(iterator.nextString(), value[index++]);
+ fNegativeSum -= value[idx];
+ map.put(iterator.nextString(), value[idx++]);
}
}
}
diff --git a/icu4j/main/shared/data/icudata.jar b/icu4j/main/shared/data/icudata.jar
index 0495ee693..f2f0b68e7 100644
--- a/icu4j/main/shared/data/icudata.jar
+++ b/icu4j/main/shared/data/icudata.jar
Binary files differ