summaryrefslogtreecommitdiff
path: root/native/lang_id/features/relevant-script-feature.cc
blob: f24c23c6af9086f1c705c5709cadd6f25f84ef2b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "lang_id/features/relevant-script-feature.h"

#include <string>
#include <vector>

#include "lang_id/common/fel/feature-types.h"
#include "lang_id/common/fel/task-context.h"
#include "lang_id/common/fel/workspace.h"
#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/utf8.h"
#include "lang_id/script/script-detector.h"

namespace libtextclassifier3 {
namespace mobile {
namespace lang_id {

// Reads this feature's parameters and instantiates the script detector.
//
// Parameters consulted:
//   "script_detector_name"   name handed to ScriptDetector::Create()
//                            (defaults to "tiny-script-detector").
//   "num_supported_scripts"  number of script ids the feature reports
//                            (defaults to 172).
//
// Returns false if no script detector could be created, true otherwise.
bool RelevantScriptFeature::Setup(TaskContext *context) {
  const std::string detector_name = GetParameter(
      "script_detector_name", /* default_value = */ "tiny-script-detector");

  // We don't use absl::WrapUnique, nor the rest of absl, see http://b/71873194
  script_detector_.reset(ScriptDetector::Create(detector_name));
  if (!script_detector_) {
    // ScriptDetector::Create() did not find a detector registered under the
    // requested name; it has already logged an error message, so we just
    // report the failure.
    return false;
  }

  // The default of 172 is the number of scripts supported by the first model
  // trained with this feature (see http://b/70617713).  Newer models may
  // support more scripts and override the value through the parameter.
  num_supported_scripts_ = GetIntParameter("num_supported_scripts", 172);
  return true;
}

// Registers the feature type for this feature: a numeric feature type named
// after this feature, sized by num_supported_scripts_.  Always returns true.
bool RelevantScriptFeature::Init(TaskContext *context) {
  // NOTE(review): set_feature_type() presumably takes ownership of the
  // pointer -- confirm against its declaration.
  auto *script_feature_type =
      new NumericFeatureType(name(), num_supported_scripts_);
  set_feature_type(script_feature_type);
  return true;
}

// Computes, for every supported script, the fraction of the sentence's
// characters that belong to that script, and appends one feature value per
// script with a non-zero count to |result|.
//
// Each word of |sentence| is expected to be a token of the form "^...$"; the
// start marker '^' and the end marker '$' are not counted.  Characters whose
// script id is outside [0, num_supported_scripts_) are ignored: they neither
// contribute to a per-script count nor to the total used as denominator.
//
// |workspaces| is not used by this feature.
void RelevantScriptFeature::Evaluate(
    const WorkspaceSet &workspaces, const LightSentence &sentence,
    FeatureVector *result) const {
  // counts[s] is the number of characters with script s.
  std::vector<int> counts(num_supported_scripts_);
  int total_count = 0;
  for (const std::string &word : sentence) {
    const char *const word_end = word.data() + word.size();
    const char *curr = word.data();

    // Skip over token start '^'.
    SAFTM_DCHECK_EQ(*curr, '^');
    curr += utils::OneCharLen(curr);

    // Bounding the loop by word_end keeps us from dereferencing past the
    // buffer when the word is empty or consists of the start marker only.
    while (curr < word_end) {
      const int num_bytes = utils::OneCharLen(curr);
      const char *const next = curr + num_bytes;

      // The last character of the word is the token end '$', which must not
      // be counted.  We test this *before* asking for the script so that a
      // truncated trailing multi-byte sequence is never read beyond word_end.
      if (next >= word_end) {
        SAFTM_DCHECK_EQ(*curr, '$');
        break;
      }

      const int script = script_detector_->GetScript(curr, num_bytes);
      curr = next;
      SAFTM_DCHECK_GE(script, 0);

      // The explicit lower bound protects counts[] from a negative index in
      // builds where the DCHECK above is compiled out.
      if (script >= 0 && script < num_supported_scripts_) {
        counts[script]++;
        total_count++;
      } else {
        // Unsupported script: this usually indicates a script that is
        // recognized by newer versions of the code, after the model was
        // trained.  E.g., new code running with old model.
      }
    }
  }

  // total_count is the sum of all counts[], so it is strictly positive
  // whenever some count is; the division below cannot be by zero.
  for (int script_id = 0; script_id < num_supported_scripts_; ++script_id) {
    const int count = counts[script_id];
    if (count > 0) {
      const float weight = static_cast<float>(count) / total_count;
      FloatFeatureValue value(script_id, weight);
      result->add(feature_type(), value.discrete_value);
    }
  }
}

SAFTM_STATIC_REGISTRATION(RelevantScriptFeature);

}  // namespace lang_id
}  // namespace mobile
}  // namespace libtextclassifier3