// Copyright 2019 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "discovery/mdns/mdns_records.h" #include #include #include #include #include #include "absl/strings/ascii.h" #include "absl/strings/match.h" #include "absl/strings/str_join.h" #include "discovery/mdns/mdns_writer.h" namespace openscreen { namespace discovery { namespace { constexpr size_t kMaxRawRecordSize = std::numeric_limits::max(); constexpr size_t kMaxMessageFieldEntryCount = std::numeric_limits::max(); inline int CompareIgnoreCase(const std::string& x, const std::string& y) { size_t i = 0; for (; i < x.size(); i++) { if (i == y.size()) { return 1; } const char& x_char = std::tolower(x[i]); const char& y_char = std::tolower(y[i]); if (x_char < y_char) { return -1; } else if (y_char < x_char) { return 1; } } return i == y.size() ? 0 : -1; } template bool IsGreaterThan(const Rdata& lhs, const Rdata& rhs) { const RDataType& lhs_cast = absl::get(lhs); const RDataType& rhs_cast = absl::get(rhs); // The Extra 2 in length is from the record size that Write() prepends to the // result. const size_t lhs_size = lhs_cast.MaxWireSize() + 2; const size_t rhs_size = rhs_cast.MaxWireSize() + 2; uint8_t lhs_bytes[lhs_size]; uint8_t rhs_bytes[rhs_size]; MdnsWriter lhs_writer(lhs_bytes, lhs_size); MdnsWriter rhs_writer(rhs_bytes, rhs_size); const bool lhs_write = lhs_writer.Write(lhs_cast); const bool rhs_write = rhs_writer.Write(rhs_cast); OSP_DCHECK(lhs_write); OSP_DCHECK(rhs_write); // Skip the size bits. const size_t min_size = std::min(lhs_writer.offset(), rhs_writer.offset()); for (size_t i = 2; i < min_size; i++) { if (lhs_bytes[i] != rhs_bytes[i]) { return lhs_bytes[i] > rhs_bytes[i]; } } return lhs_size > rhs_size; } bool IsGreaterThan(DnsType type, const Rdata& lhs, const Rdata& rhs) { switch (type) { case DnsType::kA: return IsGreaterThan(lhs, rhs); case DnsType::kPTR: return IsGreaterThan(lhs, rhs); case DnsType::kTXT: return IsGreaterThan(lhs, rhs); case DnsType::kAAAA: return IsGreaterThan(lhs, rhs); case DnsType::kSRV: return IsGreaterThan(lhs, rhs); case DnsType::kNSEC: return IsGreaterThan(lhs, rhs); default: return IsGreaterThan(lhs, rhs); } } } // namespace bool IsValidDomainLabel(absl::string_view label) { const size_t label_size = label.size(); return label_size > 0 && label_size <= kMaxLabelLength; } DomainName::DomainName() = default; DomainName::DomainName(std::vector labels) : DomainName(labels.begin(), labels.end()) {} DomainName::DomainName(const std::vector& labels) : DomainName(labels.begin(), labels.end()) {} DomainName::DomainName(std::initializer_list labels) : DomainName(labels.begin(), labels.end()) {} DomainName::DomainName(std::vector labels, size_t max_wire_size) : max_wire_size_(max_wire_size), labels_(std::move(labels)) {} DomainName::DomainName(const DomainName& other) = default; DomainName::DomainName(DomainName&& other) noexcept = default; DomainName& DomainName::operator=(const DomainName& rhs) = default; DomainName& DomainName::operator=(DomainName&& rhs) = default; std::string DomainName::ToString() const { return absl::StrJoin(labels_, "."); } bool DomainName::operator<(const DomainName& rhs) const { size_t i = 0; for (; i < labels_.size(); i++) { if (i == rhs.labels_.size()) { return false; } else { int result = CompareIgnoreCase(labels_[i], rhs.labels_[i]); if (result < 0) { return true; } else if (result > 0) { return false; } } } return i < rhs.labels_.size(); } bool DomainName::operator<=(const DomainName& rhs) const { return (*this < rhs) || (*this == rhs); } bool DomainName::operator>(const DomainName& rhs) const { return !(*this < rhs) && !(*this == rhs); } bool DomainName::operator>=(const DomainName& rhs) const { return !(*this < rhs); } bool DomainName::operator==(const DomainName& rhs) const { if (labels_.size() != rhs.labels_.size()) { return false; } for (size_t i = 0; i < labels_.size(); i++) { if (CompareIgnoreCase(labels_[i], rhs.labels_[i]) != 0) { return false; } } return true; } bool DomainName::operator!=(const DomainName& rhs) const { return !(*this == rhs); } size_t DomainName::MaxWireSize() const { return max_wire_size_; } // static ErrorOr RawRecordRdata::TryCreate(std::vector rdata) { if (rdata.size() > kMaxRawRecordSize) { return Error::Code::kIndexOutOfBounds; } else { return RawRecordRdata(std::move(rdata)); } } RawRecordRdata::RawRecordRdata() = default; RawRecordRdata::RawRecordRdata(std::vector rdata) : rdata_(std::move(rdata)) { // Ensure RDATA length does not exceed the maximum allowed. OSP_DCHECK(rdata_.size() <= kMaxRawRecordSize); } RawRecordRdata::RawRecordRdata(const uint8_t* begin, size_t size) : RawRecordRdata(std::vector(begin, begin + size)) {} RawRecordRdata::RawRecordRdata(const RawRecordRdata& other) = default; RawRecordRdata::RawRecordRdata(RawRecordRdata&& other) noexcept = default; RawRecordRdata& RawRecordRdata::operator=(const RawRecordRdata& rhs) = default; RawRecordRdata& RawRecordRdata::operator=(RawRecordRdata&& rhs) = default; bool RawRecordRdata::operator==(const RawRecordRdata& rhs) const { return rdata_ == rhs.rdata_; } bool RawRecordRdata::operator!=(const RawRecordRdata& rhs) const { return !(*this == rhs); } size_t RawRecordRdata::MaxWireSize() const { // max_wire_size includes uint16_t record length field. return sizeof(uint16_t) + rdata_.size(); } SrvRecordRdata::SrvRecordRdata() = default; SrvRecordRdata::SrvRecordRdata(uint16_t priority, uint16_t weight, uint16_t port, DomainName target) : priority_(priority), weight_(weight), port_(port), target_(std::move(target)) {} SrvRecordRdata::SrvRecordRdata(const SrvRecordRdata& other) = default; SrvRecordRdata::SrvRecordRdata(SrvRecordRdata&& other) noexcept = default; SrvRecordRdata& SrvRecordRdata::operator=(const SrvRecordRdata& rhs) = default; SrvRecordRdata& SrvRecordRdata::operator=(SrvRecordRdata&& rhs) = default; bool SrvRecordRdata::operator==(const SrvRecordRdata& rhs) const { return priority_ == rhs.priority_ && weight_ == rhs.weight_ && port_ == rhs.port_ && target_ == rhs.target_; } bool SrvRecordRdata::operator!=(const SrvRecordRdata& rhs) const { return !(*this == rhs); } size_t SrvRecordRdata::MaxWireSize() const { // max_wire_size includes uint16_t record length field. return sizeof(uint16_t) + sizeof(priority_) + sizeof(weight_) + sizeof(port_) + target_.MaxWireSize(); } ARecordRdata::ARecordRdata() = default; ARecordRdata::ARecordRdata(IPAddress ipv4_address, NetworkInterfaceIndex interface_index) : ipv4_address_(std::move(ipv4_address)), interface_index_(interface_index) { OSP_CHECK(ipv4_address_.IsV4()); } ARecordRdata::ARecordRdata(const ARecordRdata& other) = default; ARecordRdata::ARecordRdata(ARecordRdata&& other) noexcept = default; ARecordRdata& ARecordRdata::operator=(const ARecordRdata& rhs) = default; ARecordRdata& ARecordRdata::operator=(ARecordRdata&& rhs) = default; bool ARecordRdata::operator==(const ARecordRdata& rhs) const { return ipv4_address_ == rhs.ipv4_address_ && interface_index_ == rhs.interface_index_; } bool ARecordRdata::operator!=(const ARecordRdata& rhs) const { return !(*this == rhs); } size_t ARecordRdata::MaxWireSize() const { // max_wire_size includes uint16_t record length field. return sizeof(uint16_t) + IPAddress::kV4Size; } AAAARecordRdata::AAAARecordRdata() = default; AAAARecordRdata::AAAARecordRdata(IPAddress ipv6_address, NetworkInterfaceIndex interface_index) : ipv6_address_(std::move(ipv6_address)), interface_index_(interface_index) { OSP_CHECK(ipv6_address_.IsV6()); } AAAARecordRdata::AAAARecordRdata(const AAAARecordRdata& other) = default; AAAARecordRdata::AAAARecordRdata(AAAARecordRdata&& other) noexcept = default; AAAARecordRdata& AAAARecordRdata::operator=(const AAAARecordRdata& rhs) = default; AAAARecordRdata& AAAARecordRdata::operator=(AAAARecordRdata&& rhs) = default; bool AAAARecordRdata::operator==(const AAAARecordRdata& rhs) const { return ipv6_address_ == rhs.ipv6_address_ && interface_index_ == rhs.interface_index_; } bool AAAARecordRdata::operator!=(const AAAARecordRdata& rhs) const { return !(*this == rhs); } size_t AAAARecordRdata::MaxWireSize() const { // max_wire_size includes uint16_t record length field. return sizeof(uint16_t) + IPAddress::kV6Size; } PtrRecordRdata::PtrRecordRdata() = default; PtrRecordRdata::PtrRecordRdata(DomainName ptr_domain) : ptr_domain_(ptr_domain) {} PtrRecordRdata::PtrRecordRdata(const PtrRecordRdata& other) = default; PtrRecordRdata::PtrRecordRdata(PtrRecordRdata&& other) noexcept = default; PtrRecordRdata& PtrRecordRdata::operator=(const PtrRecordRdata& rhs) = default; PtrRecordRdata& PtrRecordRdata::operator=(PtrRecordRdata&& rhs) = default; bool PtrRecordRdata::operator==(const PtrRecordRdata& rhs) const { return ptr_domain_ == rhs.ptr_domain_; } bool PtrRecordRdata::operator!=(const PtrRecordRdata& rhs) const { return !(*this == rhs); } size_t PtrRecordRdata::MaxWireSize() const { // max_wire_size includes uint16_t record length field. return sizeof(uint16_t) + ptr_domain_.MaxWireSize(); } // static ErrorOr TxtRecordRdata::TryCreate(std::vector texts) { std::vector str_texts; size_t max_wire_size = 3; if (texts.size() > 0) { str_texts.reserve(texts.size()); // max_wire_size includes uint16_t record length field. max_wire_size = sizeof(uint16_t); for (const auto& text : texts) { if (text.empty()) { return Error::Code::kParameterInvalid; } str_texts.push_back( std::string(reinterpret_cast(text.data()), text.size())); // Include the length byte in the size calculation. max_wire_size += text.size() + 1; } } return TxtRecordRdata(std::move(str_texts), max_wire_size); } TxtRecordRdata::TxtRecordRdata() = default; TxtRecordRdata::TxtRecordRdata(std::vector texts) { ErrorOr rdata = TxtRecordRdata::TryCreate(std::move(texts)); *this = std::move(rdata.value()); } TxtRecordRdata::TxtRecordRdata(std::vector texts, size_t max_wire_size) : max_wire_size_(max_wire_size), texts_(std::move(texts)) {} TxtRecordRdata::TxtRecordRdata(const TxtRecordRdata& other) = default; TxtRecordRdata::TxtRecordRdata(TxtRecordRdata&& other) noexcept = default; TxtRecordRdata& TxtRecordRdata::operator=(const TxtRecordRdata& rhs) = default; TxtRecordRdata& TxtRecordRdata::operator=(TxtRecordRdata&& rhs) = default; bool TxtRecordRdata::operator==(const TxtRecordRdata& rhs) const { return texts_ == rhs.texts_; } bool TxtRecordRdata::operator!=(const TxtRecordRdata& rhs) const { return !(*this == rhs); } size_t TxtRecordRdata::MaxWireSize() const { return max_wire_size_; } NsecRecordRdata::NsecRecordRdata() = default; NsecRecordRdata::NsecRecordRdata(DomainName next_domain_name, std::vector types) : types_(std::move(types)), next_domain_name_(std::move(next_domain_name)) { // Sort the types_ array for easier comparison later. std::sort(types_.begin(), types_.end()); // Calculate the bitmaps as described in RFC 4034 Section 4.1.2. std::vector block_contents; uint8_t current_block = 0; for (auto type : types_) { const uint16_t type_int = static_cast(type); const uint8_t block = static_cast(type_int >> 8); const uint8_t block_position = static_cast(type_int & 0xFF); const uint8_t byte_bit_is_at = block_position >> 3; // First 5 bits. const uint8_t byte_mask = 0x80 >> (block_position & 0x07); // Last 3 bits. // If the block has changed, write the previous block's info and all of its // contents to the |encoded_types_| vector. if (block > current_block) { if (!block_contents.empty()) { encoded_types_.push_back(current_block); encoded_types_.push_back(static_cast(block_contents.size())); encoded_types_.insert(encoded_types_.end(), block_contents.begin(), block_contents.end()); } block_contents = std::vector(); current_block = block; } // Make sure |block_contents| is large enough to hold the bit representing // the new type , then set it. if (block_contents.size() <= byte_bit_is_at) { block_contents.insert(block_contents.end(), byte_bit_is_at - block_contents.size() + 1, 0x00); } block_contents[byte_bit_is_at] |= byte_mask; } if (!block_contents.empty()) { encoded_types_.push_back(current_block); encoded_types_.push_back(static_cast(block_contents.size())); encoded_types_.insert(encoded_types_.end(), block_contents.begin(), block_contents.end()); } } NsecRecordRdata::NsecRecordRdata(const NsecRecordRdata& other) = default; NsecRecordRdata::NsecRecordRdata(NsecRecordRdata&& other) noexcept = default; NsecRecordRdata& NsecRecordRdata::operator=(const NsecRecordRdata& rhs) = default; NsecRecordRdata& NsecRecordRdata::operator=(NsecRecordRdata&& rhs) = default; bool NsecRecordRdata::operator==(const NsecRecordRdata& rhs) const { return types_ == rhs.types_ && next_domain_name_ == rhs.next_domain_name_; } bool NsecRecordRdata::operator!=(const NsecRecordRdata& rhs) const { return !(*this == rhs); } size_t NsecRecordRdata::MaxWireSize() const { return next_domain_name_.MaxWireSize() + encoded_types_.size(); } size_t OptRecordRdata::Option::MaxWireSize() const { // One uint16_t for each of OPTION-LENGTH and OPTION-CODE as defined in RFC // 6891 section 6.1.2. constexpr size_t kOptionLengthAndCodeSize = 2 * sizeof(uint16_t); return data.size() + kOptionLengthAndCodeSize; } bool OptRecordRdata::Option::operator>( const OptRecordRdata::Option& rhs) const { if (code != rhs.code) { return code > rhs.code; } else if (length != rhs.length) { return length > rhs.length; } else if (data.size() != rhs.data.size()) { return data.size() > rhs.data.size(); } for (int i = 0; i < static_cast(data.size()); i++) { if (data[i] != rhs.data[i]) { return data[i] > rhs.data[i]; } } return false; } bool OptRecordRdata::Option::operator<( const OptRecordRdata::Option& rhs) const { return rhs > *this; } bool OptRecordRdata::Option::operator>=( const OptRecordRdata::Option& rhs) const { return !(*this < rhs); } bool OptRecordRdata::Option::operator<=( const OptRecordRdata::Option& rhs) const { return !(*this > rhs); } bool OptRecordRdata::Option::operator==( const OptRecordRdata::Option& rhs) const { return *this >= rhs && *this <= rhs; } bool OptRecordRdata::Option::operator!=( const OptRecordRdata::Option& rhs) const { return !(*this == rhs); } OptRecordRdata::OptRecordRdata() = default; OptRecordRdata::OptRecordRdata(std::vector