aboutsummaryrefslogtreecommitdiff
path: root/discovery/mdns/mdns_trackers.cc
blob: b059761366100f076acd67ca4bf63d9606c74cc3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "discovery/mdns/mdns_trackers.h"

#include <algorithm>
#include <array>
#include <chrono>
#include <functional>
#include <memory>
#include <utility>

#include "discovery/mdns/mdns_random.h"
#include "discovery/mdns/mdns_record_changed_callback.h"
#include "discovery/mdns/mdns_sender.h"
#include "util/std_util.h"

namespace openscreen {
namespace discovery {

namespace {

// RFC 6762 Section 5.2
// https://tools.ietf.org/html/rfc6762#section-5.2

// Points in a record's lifetime, expressed as fractions of its TTL, at which a
// record tracker wakes up. Refresh queries are sent at 80%, 85%, 90% and 95% of
// the TTL. The final 1.00 entry is the expiration time itself: when it is
// reached the record is reported as expired instead of being queried again.
constexpr double kTtlFractions[] = {0.80, 0.85, 0.90, 0.95, 1.00};

// The last wake-up must coincide with expiration, because
// MdnsRecordTracker::SendQuery keeps rescheduling itself until the TTL has
// fully elapsed and would otherwise run past the end of kTtlFractions.
static_assert(kTtlFractions[sizeof(kTtlFractions) / sizeof(double) - 1] ==
                  1.00,
              "The last TTL fraction must be the expiration time (1.00)");

// Intervals between successive queries must increase by at least a factor of 2.
constexpr int kIntervalIncreaseFactor = 2;

// The interval between the first two queries must be at least one second.
constexpr std::chrono::seconds kMinimumQueryInterval{1};

// The querier may cap the question refresh interval to a maximum of 60 minutes.
constexpr std::chrono::minutes kMaximumQueryInterval{60};

// RFC 6762 Section 10.1
// https://tools.ietf.org/html/rfc6762#section-10.1

// A goodbye record is a record with TTL of 0.
bool IsGoodbyeRecord(const MdnsRecord& record) {
  return record.ttl() == std::chrono::seconds{0};
}

// RFC 6762 Section 10.1
// https://tools.ietf.org/html/rfc6762#section-10.1
// On receiving a goodbye record, the querier should not drop the record
// immediately but set its TTL to 1 second.
constexpr std::chrono::seconds kGoodbyeRecordTtl{1};

}  // namespace

MdnsTracker::MdnsTracker(MdnsSender* sender,
                         TaskRunner* task_runner,
                         ClockNowFunctionPtr now_function,
                         MdnsRandom* random_delay)
    : sender_(sender),
      task_runner_(task_runner),
      now_function_(now_function),
      send_alarm_(now_function, task_runner),
      random_delay_(random_delay) {
  OSP_DCHECK(task_runner);
  OSP_DCHECK(now_function);
  OSP_DCHECK(random_delay);
  OSP_DCHECK(sender);
}

// Creates a tracker for a single record.
// |record_updated_callback| is invoked when an update changes the record's
// RDATA; |record_expired_callback| is invoked when the record's TTL elapses.
// Both callbacks are required.
MdnsRecordTracker::MdnsRecordTracker(
    MdnsSender* sender,
    TaskRunner* task_runner,
    ClockNowFunctionPtr now_function,
    MdnsRandom* random_delay,
    std::function<void(const MdnsRecord&)> record_updated_callback,
    std::function<void(const MdnsRecord&)> record_expired_callback)
    : MdnsTracker(sender, task_runner, now_function, random_delay),
      // The callbacks are sink parameters taken by value: move them into the
      // members instead of copying them a second time.
      record_updated_callback_(std::move(record_updated_callback)),
      record_expired_callback_(std::move(record_expired_callback)) {
  // Check the members, not the parameters: the parameters are moved-from now.
  OSP_DCHECK(record_updated_callback_);
  OSP_DCHECK(record_expired_callback_);
}

// Begins tracking |record|. Fails with kOperationInvalid if a record is
// already being tracked. Schedules the first refresh query.
Error MdnsRecordTracker::Start(MdnsRecord record) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  // A tracker follows at most one record at a time.
  if (record_) {
    return Error::Code::kOperationInvalid;
  }

  // Reset the schedule: refresh points are measured from this moment.
  start_time_ = now_function_();
  send_count_ = 0;
  record_.emplace(std::move(record));

  send_alarm_.Schedule([this] { SendQuery(); }, GetNextSendTime());
  return Error::None();
}

// Stops tracking the current record and cancels any pending query.
// Fails with kOperationInvalid if no record is being tracked.
Error MdnsRecordTracker::Stop() {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  if (record_) {
    send_alarm_.Cancel();
    record_.reset();
    return Error::None();
  }

  return Error::Code::kOperationInvalid;
}

// Replaces the tracked record with |new_record| and restarts its TTL schedule.
// Returns kOperationInvalid if nothing is tracked, and kParameterInvalid if
// |new_record| differs in name, type or class from the tracked record.
// A goodbye record is re-tracked with a 1 second TTL. The update callback
// fires only when the RDATA changed.
Error MdnsRecordTracker::Update(const MdnsRecord& new_record) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  if (!record_) {
    return Error::Code::kOperationInvalid;
  }

  // Inspect the currently tracked record inside a nested scope: Stop() below
  // destroys it, so the reference must not be reachable past this block.
  bool rdata_changed = false;
  {
    const MdnsRecord& current = *record_;
    const bool belongs_elsewhere =
        (current.dns_type() != new_record.dns_type()) ||
        (current.dns_class() != new_record.dns_class()) ||
        (current.name() != new_record.name());
    if (belongs_elsewhere) {
      // The new record has been passed to a wrong tracker
      return Error::Code::kParameterInvalid;
    }
    rdata_changed = (new_record.rdata() != current.rdata());
  }

  Error result = Stop();
  if (!result.ok()) {
    return result;
  }

  // RFC 6762 Section 10.1: a goodbye record is kept for one more second
  // (kGoodbyeRecordTtl) instead of being dropped immediately.
  result = IsGoodbyeRecord(new_record)
               ? Start(MdnsRecord(new_record.name(), new_record.dns_type(),
                                  new_record.dns_class(),
                                  new_record.record_type(), kGoodbyeRecordTtl,
                                  new_record.rdata()))
               : Start(new_record);
  if (!result.ok()) {
    return result;
  }

  if (rdata_changed) {
    record_updated_callback_(record_.value());
  }

  return Error::None();
}

// True while a record is being tracked (between Start() and Stop()).
bool MdnsRecordTracker::IsStarted() {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
  return static_cast<bool>(record_);
}

void MdnsRecordTracker::SendQuery() {
  const MdnsRecord& record = record_.value();
  const Clock::time_point expiration_time = start_time_ + record.ttl();
  const bool is_expired = (now_function_() >= expiration_time);
  if (!is_expired) {
    MdnsQuestion question(record.name(), record.dns_type(), record.dns_class(),
                          ResponseType::kMulticast);
    MdnsMessage message(CreateMessageId(), MessageType::Query);
    message.AddQuestion(std::move(question));
    sender_->SendMulticast(message);
    send_alarm_.Schedule([this] { MdnsRecordTracker::SendQuery(); },
                         GetNextSendTime());
  } else {
    record_expired_callback_(record);
  }
}

// Returns the time of the next wake-up and advances |send_count_|.
// Each wake-up is start_time_ + TTL * kTtlFractions[i]. Every entry except the
// last gets a random variation; the last is the exact expiration time.
openscreen::platform::Clock::time_point MdnsRecordTracker::GetNextSendTime() {
  OSP_DCHECK(send_count_ < openscreen::countof(kTtlFractions));

  double fraction = kTtlFractions[send_count_];
  ++send_count_;

  // Do not add random variation to the expiration time (last fraction of TTL)
  const bool is_expiration_point =
      (send_count_ == openscreen::countof(kTtlFractions));
  if (!is_expiration_point) {
    fraction += random_delay_->GetRecordTtlVariation();
  }

  return start_time_ + std::chrono::duration_cast<Clock::duration>(
                           record_.value().ttl() * fraction);
}

MdnsQuestionTracker::MdnsQuestionTracker(MdnsSender* sender,
                                         TaskRunner* task_runner,
                                         ClockNowFunctionPtr now_function,
                                         MdnsRandom* random_delay)
    : MdnsTracker(sender, task_runner, now_function, random_delay) {}

// Begins continuous querying for |question|. Fails with kOperationInvalid if
// a question is already being tracked.
Error MdnsQuestionTracker::Start(MdnsQuestion question) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  if (question_) {
    return Error::Code::kOperationInvalid;
  }

  question_.emplace(std::move(question));
  send_delay_ = kMinimumQueryInterval;

  // The first query goes out after a random delay of 20-120 milliseconds.
  const Clock::duration initial_delay = random_delay_->GetInitialQueryDelay();
  const Clock::time_point first_send_time = now_function_() + initial_delay;
  send_alarm_.Schedule([this] { MdnsQuestionTracker::SendQuery(); },
                       first_send_time);
  return Error::None();
}

// Stops querying, forgets the question and drops all record trackers.
// Fails with kOperationInvalid if no question is being tracked.
Error MdnsQuestionTracker::Stop() {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  if (question_) {
    send_alarm_.Cancel();
    question_.reset();
    record_trackers_.clear();
    return Error::None();
  }

  return Error::Code::kOperationInvalid;
}

// True while a question is being tracked (between Start() and Stop()).
bool MdnsQuestionTracker::IsStarted() {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
  return static_cast<bool>(question_);
}

// Registers |callback| for record change notifications. Adding the same
// callback twice has no effect.
void MdnsQuestionTracker::AddCallback(MdnsRecordChangedCallback* callback) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  const bool already_registered =
      std::find(callbacks_.begin(), callbacks_.end(), callback) !=
      callbacks_.end();
  if (already_registered) {
    return;
  }

  callbacks_.push_back(callback);
  // TODO(yakimakha): Notify the new callback with all known answers
}

// Unregisters |callback|; unknown callbacks are ignored.
void MdnsQuestionTracker::RemoveCallback(MdnsRecordChangedCallback* callback) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  auto position = std::find(callbacks_.begin(), callbacks_.end(), callback);
  if (position == callbacks_.end()) {
    return;
  }
  callbacks_.erase(position);
}

// True if at least one change callback is registered.
bool MdnsQuestionTracker::HasCallbacks() const {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  const bool has_any_callback = !callbacks_.empty();
  return has_any_callback;
}

void MdnsQuestionTracker::OnRecordReceived(const MdnsRecord& record) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  if (!question_.has_value()) {
    return;
  }

  const RecordKey key(record.name(), record.dns_type(), record.dns_class());

  const auto find_result = record_trackers_.find(key);
  if (find_result != record_trackers_.end()) {
    MdnsRecordTracker* record_tracker = find_result->second.get();
    record_tracker->Update(record);
    return;
  }

  std::unique_ptr<MdnsRecordTracker> record_tracker =
      std::make_unique<MdnsRecordTracker>(
          sender_, task_runner_, now_function_, random_delay_,
          [this](const MdnsRecord& record) {
            MdnsQuestionTracker::OnRecordUpdated(record);
          },
          [this](const MdnsRecord& record) {
            MdnsQuestionTracker::OnRecordExpired(record);
          });

  record_tracker->Start(record);
  record_trackers_.emplace(key, std::move(record_tracker));

  for (auto callback : callbacks_) {
    callback->OnRecordChanged(record, RecordChangedEvent::kCreated);
  }
}

void MdnsQuestionTracker::OnRecordExpired(const MdnsRecord& record) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  for (auto callback : callbacks_) {
    callback->OnRecordChanged(record, RecordChangedEvent::kDeleted);
  }

  const RecordKey key(record.name(), record.dns_type(), record.dns_class());
  record_trackers_.erase(key);
}

void MdnsQuestionTracker::OnRecordUpdated(const MdnsRecord& record) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  for (auto* callback : callbacks_) {
    callback->OnRecordChanged(record, RecordChangedEvent::kUpdated);
  }
}

void MdnsQuestionTracker::SendQuery() {
  MdnsMessage message(CreateMessageId(), MessageType::Query);
  message.AddQuestion(question_.value());
  // TODO(yakimakha): Implement known-answer suppression by adding known
  // answers to the question
  sender_->SendMulticast(message);
  send_alarm_.Schedule([this] { MdnsQuestionTracker::SendQuery(); },
                       now_function_() + send_delay_);
  send_delay_ = send_delay_ * kIntervalIncreaseFactor;
  if (send_delay_ > kMaximumQueryInterval) {
    send_delay_ = kMaximumQueryInterval;
  }
}

}  // namespace discovery
}  // namespace openscreen