aboutsummaryrefslogtreecommitdiff
path: root/test/test_blocking_counter.cc
blob: 34d963db4f28bc478489eeb81de475060c984c6c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <atomic>  // NOLINT
#include <cstdlib>
#include <iostream>
#include <memory>
#include <vector>

#include "../internal/multi_thread_gemm.h"
#include "../profiling/pthread_everywhere.h"
#include "test.h"

namespace gemmlowp {

// Test helper: starts a thread on construction that calls
// BlockingCounter::DecrementCount() a fixed number of times, recording the
// value returned by the most recent call.
//
// The underlying thread is never joined through pthread_join; completion is
// observed through the atomic finished_ flag instead (see the comment on that
// member). The object must outlive its thread, which the destructor
// guarantees by waiting in Join().
class Thread {
 public:
  // Starts the thread immediately.
  //
  // blocking_counter: counter to decrement; not owned, must outlive this
  //   object.
  // number_of_times_to_decrement: how many DecrementCount() calls to make.
  //
  // When GEMMLOWP_USE_PTHREAD is defined, any failure while setting up or
  // creating the thread prints a message to std::cerr and aborts.
  Thread(BlockingCounter* blocking_counter, int number_of_times_to_decrement)
      : blocking_counter_(blocking_counter),
        number_of_times_to_decrement_(number_of_times_to_decrement),
        made_the_last_decrement_(false),
        finished_(false) {
#if defined GEMMLOWP_USE_PTHREAD
    // Limit the stack size so as not to deplete memory when creating
    // many threads.
    pthread_attr_t attr;
    int err = pthread_attr_init(&attr);
    if (!err) {
      size_t stack_size;
      err = pthread_attr_getstacksize(&attr, &stack_size);
      if (!err && stack_size > max_stack_size_) {
        err = pthread_attr_setstacksize(&attr, max_stack_size_);
      }
      if (!err) {
        err = pthread_create(&thread_, &attr, ThreadFunc, this);
      }
      // The attribute object is only needed while creating the thread;
      // release it on every path where pthread_attr_init succeeded. Its
      // return value is deliberately ignored so that it cannot mask an
      // earlier error stored in err.
      pthread_attr_destroy(&attr);
    }
    if (err) {
      std::cerr << "Failed to create a thread.\n";
      std::abort();
    }
#else
    pthread_create(&thread_, nullptr, ThreadFunc, this);
#endif
  }

  // Waits for the thread body to finish before the members it uses go away.
  ~Thread() { Join(); }

  // Busy-waits until the thread body has finished, then returns the value
  // returned by the last DecrementCount() call made by this thread (false if
  // it made none). May be called more than once.
  bool Join() {
    while (!finished_.load()) {
    }
    // Reading the plain bool is safe here: it is written before the
    // finished_.store(true) that the load above has observed.
    return made_the_last_decrement_;
  }

 private:
  // Not copyable: the running thread holds a pointer to this object.
  Thread(const Thread& other) = delete;
  Thread& operator=(const Thread& other) = delete;

  // Thread body: performs the decrements, checking before each one that no
  // earlier decrement made by this thread returned true.
  void ThreadFunc() {
    for (int i = 0; i < number_of_times_to_decrement_; i++) {
      Check(!made_the_last_decrement_);
      made_the_last_decrement_ = blocking_counter_->DecrementCount();
    }
    finished_.store(true);
  }

  // Entry point handed to pthread_create; ptr is the Thread instance.
  static void* ThreadFunc(void* ptr) {
    static_cast<Thread*>(ptr)->ThreadFunc();
    return nullptr;
  }

  // Upper bound applied to the thread stack size, in bytes.
  static constexpr size_t max_stack_size_ = 256 * 1024;
  BlockingCounter* const blocking_counter_;
  const int number_of_times_to_decrement_;
  pthread_t thread_;
  bool made_the_last_decrement_;
  // finished_ is used to manually implement Join() by busy-waiting.
  // I wanted to use pthread_join / std::thread::join, but the behavior
  // observed on Android was that pthread_join aborts when the thread has
  // already joined before calling pthread_join, making that hard to use.
  // It appeared simplest to just implement this simple spinlock, and that
  // is good enough as this is just a test.
  std::atomic<bool> finished_;
};

// Runs one scenario against blocking_counter.
//
// Resets the counter to num_decrements_to_wait_for, starts num_threads
// threads that each call DecrementCount() num_decrements_per_thread times,
// waits on the counter, then joins every thread and checks that exactly one
// of them reported (via Thread::Join()) having made the last decrement.
//
// blocking_counter is not owned and must remain valid for the whole call.
void test_blocking_counter(BlockingCounter* blocking_counter, int num_threads,
                           int num_decrements_per_thread,
                           int num_decrements_to_wait_for) {
  // Threads are owned by the vector so they are released on every exit path.
  std::vector<std::unique_ptr<Thread>> threads;
  // The counter must be reset before any thread can start decrementing it.
  blocking_counter->Reset(num_decrements_to_wait_for);
  for (int i = 0; i < num_threads; i++) {
    threads.push_back(std::unique_ptr<Thread>(
        new Thread(blocking_counter, num_decrements_per_thread)));
  }
  blocking_counter->Wait();

  int num_threads_that_made_the_last_decrement = 0;
  for (const std::unique_ptr<Thread>& thread : threads) {
    if (thread->Join()) {
      num_threads_that_made_the_last_decrement++;
    }
  }
  Check(num_threads_that_made_the_last_decrement == 1);
}

// Drives the BlockingCounter test over a grid of configurations:
// 1 to 5 threads, each making 1, 16, 256 or 4096 decrements, with the counter
// reset to the total number of decrements for each configuration. A single
// BlockingCounter instance is reused for every run.
void test_blocking_counter() {
  // Automatic storage: no manual delete needed on any exit path.
  BlockingCounter blocking_counter;

  // repeating the entire test sequence ensures that we test
  // non-monotonic changes.
  for (int repeat = 1; repeat <= 2; repeat++) {
    for (int num_threads = 1; num_threads <= 5; num_threads++) {
      for (int num_decrements_per_thread = 1;
           num_decrements_per_thread <= 4 * 1024;
           num_decrements_per_thread *= 16) {
        test_blocking_counter(&blocking_counter, num_threads,
                              num_decrements_per_thread,
                              num_threads * num_decrements_per_thread);
      }
    }
  }
}

}  // end namespace gemmlowp

// Program entry point: runs the BlockingCounter test sequence and exits with
// status 0 once it returns.
int main() {
  gemmlowp::test_blocking_counter();
  return 0;
}