aboutsummaryrefslogtreecommitdiff
path: root/discovery/mdns/mdns_trackers.cc
blob: f45c2148168837123601c7357bba56bce5d61765 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "discovery/mdns/mdns_trackers.h"

#include <array>

#include "discovery/mdns/mdns_random.h"
#include "discovery/mdns/mdns_record_changed_callback.h"
#include "discovery/mdns/mdns_sender.h"
#include "util/std_util.h"

namespace openscreen {
namespace discovery {

namespace {

// RFC 6762 Section 5.2
// https://tools.ietf.org/html/rfc6762#section-5.2

// Points in a record's lifetime, expressed as fractions of its TTL, at which
// the record tracker wakes up. The first four entries (80%, 85%, 90%, 95%) are
// refresh attempts. The trailing 1.00 entry is the moment the TTL runs out,
// i.e. when the record is reported as expired instead of being re-queried.
constexpr double kTtlFractions[] = {0.80, 0.85, 0.90, 0.95, 1.00};

// Each interval between successive queries must be at least this many times
// longer than the previous one.
constexpr int kIntervalIncreaseFactor = 2;

// Lower bound for the interval between the first two queries.
constexpr auto kMinimumQueryInterval = std::chrono::seconds(1);

// The querier may cap the question refresh interval at 60 minutes.
constexpr auto kMaximumQueryInterval = std::chrono::minutes(60);

// RFC 6762 Section 10.1
// https://tools.ietf.org/html/rfc6762#section-10.1

// TTL the querier substitutes when it receives a goodbye record. The record is
// not dropped immediately; it is kept for one more second.
constexpr auto kGoodbyeRecordTtl = std::chrono::seconds(1);

// Returns true when |record| is a goodbye record, i.e. it carries a TTL of 0.
bool IsGoodbyeRecord(const MdnsRecord& record) {
  return record.ttl() == std::chrono::seconds::zero();
}

}  // namespace

MdnsTracker::MdnsTracker(MdnsSender* sender,
                         TaskRunner* task_runner,
                         ClockNowFunctionPtr now_function,
                         MdnsRandom* random_delay)
    : sender_(sender),
      task_runner_(task_runner),
      now_function_(now_function),
      send_alarm_(now_function, task_runner),
      random_delay_(random_delay) {
  OSP_DCHECK(task_runner);
  OSP_DCHECK(now_function);
  OSP_DCHECK(random_delay);
  OSP_DCHECK(sender);
}

// Begins tracking |record|. The tracker's clock starts now, and the first
// alarm is armed at the time computed by GetNextSendTime().
// |record_expired_callback| is required. It is invoked with the tracked
// record once its TTL has elapsed without a refresh.
MdnsRecordTracker::MdnsRecordTracker(
    MdnsRecord record,
    MdnsSender* sender,
    TaskRunner* task_runner,
    ClockNowFunctionPtr now_function,
    MdnsRandom* random_delay,
    std::function<void(const MdnsRecord&)> record_expired_callback)
    : MdnsTracker(sender, task_runner, now_function, random_delay),
      record_(std::move(record)),
      start_time_(now_function_()),
      // The callback is a by-value sink parameter. Move it into place rather
      // than copying it (and whatever state its target captures) again.
      record_expired_callback_(std::move(record_expired_callback)) {
  // Check the member: the parameter has been moved from at this point.
  OSP_DCHECK(record_expired_callback_);

  send_alarm_.Schedule([this] { SendQuery(); }, GetNextSendTime());
}

// Replaces the tracked record with |new_record| and restarts the refresh
// schedule from the current time.
//
// Returns kParameterInvalid when |new_record| does not describe the same
// (name, type, class) as the tracked record. It is also returned when
// |new_record| is a goodbye record whose RDATA differs from the tracked one.
// Otherwise it returns:
//   kGoodbye - |new_record| has TTL 0. It is stored with a 1 second TTL and
//              no further refresh queries are scheduled.
//   kTTLOnly - same RDATA; only the TTL was refreshed.
//   kRdata   - the RDATA changed.
ErrorOr<MdnsRecordTracker::UpdateType> MdnsRecordTracker::Update(
    const MdnsRecord& new_record) {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  const bool has_same_rdata = record_.rdata() == new_record.rdata();
  const bool is_goodbye = IsGoodbyeRecord(new_record);

  // A record with a different key belongs to some other tracker. Goodbye
  // records must also repeat the RDATA they retract.
  // RFC 6762 Section 10.1
  // https://tools.ietf.org/html/rfc6762#section-10.1
  const bool key_differs = (record_.dns_type() != new_record.dns_type()) ||
                           (record_.dns_class() != new_record.dns_class()) ||
                           (record_.name() != new_record.name());
  if (key_differs || (is_goodbye && !has_same_rdata)) {
    return Error::Code::kParameterInvalid;
  }

  const UpdateType result =
      is_goodbye ? UpdateType::kGoodbye
                 : (has_same_rdata ? UpdateType::kTTLOnly : UpdateType::kRdata);

  if (!is_goodbye) {
    record_ = new_record;
    attempt_count_ = 0;
  } else {
    record_ = MdnsRecord(new_record.name(), new_record.dns_type(),
                         new_record.dns_class(), new_record.record_type(),
                         kGoodbyeRecordTtl, new_record.rdata());

    // Skip straight to the final TTL fraction (100%, expiration) so that no
    // refresh query is ever sent for a goodbye record.
    attempt_count_ = countof(kTtlFractions) - 1;
  }

  start_time_ = now_function_();
  send_alarm_.Schedule([this] { SendQuery(); }, GetNextSendTime());

  return result;
}

// Treats the tracked record as if a goodbye had been received. Its TTL becomes
// the 1 second goodbye TTL measured from now, and the only remaining alarm is
// the one at expiration.
void MdnsRecordTracker::ExpireSoon() {
  OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());

  // Point at the last TTL fraction (100%, i.e. expiration) so that no further
  // refresh queries are issued.
  attempt_count_ = countof(kTtlFractions) - 1;
  start_time_ = now_function_();
  record_ = MdnsRecord(record_.name(), record_.dns_type(), record_.dns_class(),
                       record_.record_type(), kGoodbyeRecordTtl,
                       record_.rdata());

  send_alarm_.Schedule([this] { SendQuery(); }, GetNextSendTime());
}

void MdnsRecordTracker::SendQuery() {
  const Clock::time_point expiration_time = start_time_ + record_.ttl();
  const bool is_expired = (now_function_() >= expiration_time);
  if (!is_expired) {
    MdnsQuestion question(record_.name(), record_.dns_type(),
                          record_.dns_class(), ResponseType::kMulticast);
    MdnsMessage message(CreateMessageId(), MessageType::Query);
    message.AddQuestion(std::move(question));
    sender_->SendMulticast(message);
    send_alarm_.Schedule([this] { MdnsRecordTracker::SendQuery(); },
                         GetNextSendTime());
  } else {
    record_expired_callback_(record_);
  }
}

// Consumes the next entry of kTtlFractions and returns the absolute time at
// which that fraction of the record's TTL is reached, measured from
// start_time_. Refresh points get random jitter added. The final (expiration)
// point does not, so expiry lands exactly on the TTL.
Clock::time_point MdnsRecordTracker::GetNextSendTime() {
  OSP_DCHECK(attempt_count_ < countof(kTtlFractions));

  const auto fraction_index = attempt_count_++;
  double fraction = kTtlFractions[fraction_index];

  const bool is_expiration_point = (attempt_count_ == countof(kTtlFractions));
  if (!is_expiration_point) {
    fraction += random_delay_->GetRecordTtlVariation();
  }

  return start_time_ +
         std::chrono::duration_cast<Clock::duration>(record_.ttl() * fraction);
}

// Begins tracking |question|. The query interval starts at the one-second
// minimum. The first query is scheduled after a random initial delay.
MdnsQuestionTracker::MdnsQuestionTracker(MdnsQuestion question,
                                         MdnsSender* sender,
                                         TaskRunner* task_runner,
                                         ClockNowFunctionPtr now_function,
                                         MdnsRandom* random_delay)
    : MdnsTracker(sender, task_runner, now_function, random_delay),
      question_(std::move(question)),
      send_delay_(kMinimumQueryInterval) {
  // RFC 6762 Section 5.2: the first query is sent after a random delay of
  // 20-120 milliseconds.
  const Clock::duration initial_delay = random_delay_->GetInitialQueryDelay();
  const Clock::time_point first_send_time = now_function_() + initial_delay;
  send_alarm_.Schedule([this] { MdnsQuestionTracker::SendQuery(); },
                       first_send_time);
}

void MdnsQuestionTracker::SendQuery() {
  MdnsMessage message(CreateMessageId(), MessageType::Query);
  message.AddQuestion(question_);
  // TODO(yakimakha): Implement known-answer suppression by adding known
  // answers to the question
  sender_->SendMulticast(message);
  send_alarm_.Schedule([this] { MdnsQuestionTracker::SendQuery(); },
                       now_function_() + send_delay_);
  send_delay_ = send_delay_ * kIntervalIncreaseFactor;
  if (send_delay_ > kMaximumQueryInterval) {
    send_delay_ = kMaximumQueryInterval;
  }
}

}  // namespace discovery
}  // namespace openscreen