summaryrefslogtreecommitdiff
path: root/native/lang_id/features/char-ngram-feature.cc
blob: 31faf2f3148c8f12030de5b08fa9ee7dbf831c3f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "lang_id/features/char-ngram-feature.h"

#include <string>
#include <utility>
#include <vector>

#include "lang_id/common/fel/feature-types.h"
#include "lang_id/common/fel/task-context.h"
#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/math/hash.h"
#include "lang_id/common/utf8.h"

namespace libtextclassifier3 {
namespace mobile {
namespace lang_id {

// Reads the parameters from the feature function descriptor and validates
// them.  Returns false (after logging an error) when the descriptor requests a
// configuration this implementation does not handle, or when "id_dim" / "size"
// are not strictly positive.  On success, stores the two numeric parameters and
// resets counts_ to ngram_id_dimension_ zeros.
//
// NOTE: |context| is not used by this method.
bool ContinuousBagOfNgramsFunction::Setup(TaskContext *context) {
  // Only include_terminators=true is accepted; anything else (including the
  // default value, false) is rejected.
  const bool include_terminators =
      GetBoolParameter("include_terminators", false);
  if (!include_terminators) {
    SAFTM_LOG(ERROR) << "No support for include_terminators=false";
    return false;
  }

  const bool include_spaces = GetBoolParameter("include_spaces", false);
  if (include_spaces) {
    SAFTM_LOG(ERROR) << "No support for include_spaces=true";
    return false;
  }

  const bool use_equal_ngram_weight =
      GetBoolParameter("use_equal_weight", false);
  if (use_equal_ngram_weight) {
    SAFTM_LOG(ERROR) << "No support for use_equal_weight=true";
    return false;
  }

  ngram_id_dimension_ = GetIntParameter("id_dim", 10000);
  ngram_size_ = GetIntParameter("size", 3);

  // ngram_id_dimension_ is used both as the size of counts_ and as the modulus
  // when computing ngram ids, so it must be strictly positive: 0 would mean a
  // division by zero, and a negative value would turn into a huge size.
  if (ngram_id_dimension_ <= 0) {
    SAFTM_LOG(ERROR) << "id_dim should be > 0; found " << ngram_id_dimension_;
    return false;
  }

  // An ngram must contain at least one character.
  if (ngram_size_ <= 0) {
    SAFTM_LOG(ERROR) << "size should be > 0; found " << ngram_size_;
    return false;
  }

  counts_.assign(ngram_id_dimension_, 0);
  return true;
}

// Installs the feature type for this function: a NumericFeatureType built from
// this function's name() and ngram_id_dimension_.  Always returns true.
//
// NOTE: |context| is not used by this method.  The newly allocated object is
// handed to set_feature_type(), which presumably takes ownership of it.
bool ContinuousBagOfNgramsFunction::Init(TaskContext *context) {
  auto *numeric_type = new NumericFeatureType(name(), ngram_id_dimension_);
  set_feature_type(numeric_type);
  return true;
}

// Counts the char ngrams (ngram_size_ UTF8 characters each) from all tokens of
// |sentence|.
//
// For each ngram, the id is hash(ngram) % ngram_id_dimension_.  On return,
// counts_[id] holds the number of ngrams with that id, and
// non_zero_count_indices_ lists every id with a non-zero count (each id once,
// in order of first occurrence).  Returns the total number of ngrams counted.
//
// Preconditions (checked): counts_ has ngram_id_dimension_ elements and
// non_zero_count_indices_ is empty.
//
// Empty tokens, and tokens with fewer than ngram_size_ characters, contribute
// nothing.  No pointer ever moves past the end of the current token, even if
// the token ends in the middle of a multi-byte UTF8 character.
int ContinuousBagOfNgramsFunction::ComputeNgramCounts(
    const LightSentence &sentence) const {
  SAFTM_CHECK_EQ(counts_.size(), ngram_id_dimension_);
  SAFTM_CHECK_EQ(non_zero_count_indices_.size(), 0);

  int total_count = 0;

  for (const std::string &word : sentence) {
    if (word.empty()) {
      // Nothing to extract; without this guard, the code below would treat the
      // byte right after the (empty) token as a character.
      continue;
    }

    const char *const word_end = word.data() + word.size();

    // Returns the position one UTF8 character after |pos|, but never past
    // word_end.  The clamping matters for malformed input: the length reported
    // for the character at |pos| may exceed the number of bytes left in the
    // token (e.g., a truncated multi-byte sequence), and we must not hash
    // bytes that are outside the token.
    const auto next_char = [word_end](const char *pos) -> const char * {
      const int char_len = utils::OneCharLen(pos);
      return (word_end - pos < char_len) ? word_end : pos + char_len;
    };

    // Set ngram_start at the start of the current token (word).
    const char *ngram_start = word.data();

    // Set ngram_end ngram_size UTF8 characters after ngram_start.  Note: each
    // UTF8 character contains between 1 and 4 bytes.
    const char *ngram_end = ngram_start;
    int num_utf8_chars = 0;
    do {
      ngram_end = next_char(ngram_end);
      num_utf8_chars++;
    } while ((num_utf8_chars < ngram_size_) && (ngram_end < word_end));

    if (num_utf8_chars < ngram_size_) {
      // Current token is so small, it does not contain a single ngram of
      // ngram_size UTF8 characters.  Not much we can do in this case ...
      continue;
    }

    // At this point, [ngram_start, ngram_end) is the first ngram of ngram_size
    // UTF8 characters from current token.
    while (true) {
      // Compute ngram id: hash(ngram) % ngram_id_dimension
      int ngram_id = (
          utils::Hash32WithDefaultSeed(ngram_start, ngram_end - ngram_start)
          % ngram_id_dimension_);

      // Use a reference to the actual count, such that we can both test whether
      // the count was 0 and increment it without performing two lookups.
      int &ref_to_count_for_ngram = counts_[ngram_id];
      if (ref_to_count_for_ngram == 0) {
        non_zero_count_indices_.push_back(ngram_id);
      }
      ref_to_count_for_ngram++;
      total_count++;
      if (ngram_end >= word_end) {
        break;
      }

      // Advance both ngram_start and ngram_end by one UTF8 character.  This
      // way, the number of UTF8 characters between them remains constant
      // (ngram_size).
      ngram_start = next_char(ngram_start);
      ngram_end = next_char(ngram_end);
    }
  }  // end of loop over tokens.

  return total_count;
}

void ContinuousBagOfNgramsFunction::Evaluate(const WorkspaceSet &workspaces,
                                             const LightSentence &sentence,
                                             FeatureVector *result) const {
  // NOTE: we use std::* constructs (instead of absl::Mutex & co) to simplify
  // porting to Android and to avoid pulling in absl (which increases our code
  // size).
  std::lock_guard<std::mutex> mlock(state_mutex_);

  // Find the char ngram counts.
  int total_count = ComputeNgramCounts(sentence);

  // Populate the feature vector.
  const float norm = static_cast<float>(total_count);

  // TODO(salcianu): explore treating dense vectors (i.e., many non-zero
  // elements) separately.
  for (int ngram_id : non_zero_count_indices_) {
    const float weight = counts_[ngram_id] / norm;
    FloatFeatureValue value(ngram_id, weight);
    result->add(feature_type(), value.discrete_value);

    // Clear up counts_, for the next invocation of Evaluate().
    counts_[ngram_id] = 0;
  }

  // Clear up non_zero_count_indices_, for the next invocation of Evaluate().
  non_zero_count_indices_.clear();
}

// Registers ContinuousBagOfNgramsFunction via the SAFTM_STATIC_REGISTRATION
// macro (defined outside this file; presumably this makes the class available
// for lookup by name at static-initialization time -- confirm against the
// macro definition).
SAFTM_STATIC_REGISTRATION(ContinuousBagOfNgramsFunction);

}  // namespace lang_id
}  // namespace mobile
}  // namespace libtextclassifier3