// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
|
|
|
|
#include "discovery/mdns/mdns_records.h"
|
|
|
|
#include <algorithm>
|
|
#include <cctype>
|
|
#include <limits>
|
|
#include <sstream>
|
|
#include <vector>
|
|
|
|
#include "absl/strings/ascii.h"
|
|
#include "absl/strings/match.h"
|
|
#include "absl/strings/str_join.h"
|
|
#include "discovery/mdns/mdns_writer.h"
|
|
|
|
namespace openscreen {
|
|
namespace discovery {
|
|
|
|
namespace {
|
|
|
|
// Upper bound, in bytes, on the RDATA held by a RawRecordRdata; equal to the
// largest value a uint16_t can represent.
constexpr size_t kMaxRawRecordSize{std::numeric_limits<uint16_t>::max()};

// Bound used when validating the number of entries in each section of an
// MdnsMessage; equal to the largest value a uint16_t can represent.
constexpr size_t kMaxMessageFieldEntryCount{
    std::numeric_limits<uint16_t>::max()};
|
|
|
|
// Compares |x| and |y| character by character after lowercasing each
// character. Returns -1 if |x| orders before |y|, 0 if the two are equal
// ignoring case, and 1 if |x| orders after |y|. When one string is a strict
// prefix of the other, the shorter string orders first.
inline int CompareIgnoreCase(const std::string& x, const std::string& y) {
  // std::tolower() has undefined behavior for negative arguments other than
  // EOF, so every character is routed through unsigned char before the call.
  const auto to_lower = [](char c) {
    return static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
  };

  const size_t common_size = std::min(x.size(), y.size());
  for (size_t i = 0; i < common_size; ++i) {
    const char x_char = to_lower(x[i]);
    const char y_char = to_lower(y[i]);
    if (x_char != y_char) {
      return x_char < y_char ? -1 : 1;
    }
  }

  // All shared characters match; the lengths decide the result.
  if (x.size() == y.size()) {
    return 0;
  }
  return x.size() < y.size() ? -1 : 1;
}
|
|
|
|
template <typename RDataType>
|
|
bool IsGreaterThan(const Rdata& lhs, const Rdata& rhs) {
|
|
const RDataType& lhs_cast = absl::get<RDataType>(lhs);
|
|
const RDataType& rhs_cast = absl::get<RDataType>(rhs);
|
|
|
|
// The Extra 2 in length is from the record size that Write() prepends to the
|
|
// result.
|
|
const size_t lhs_size = lhs_cast.MaxWireSize() + 2;
|
|
const size_t rhs_size = rhs_cast.MaxWireSize() + 2;
|
|
|
|
uint8_t lhs_bytes[lhs_size];
|
|
uint8_t rhs_bytes[rhs_size];
|
|
MdnsWriter lhs_writer(lhs_bytes, lhs_size);
|
|
MdnsWriter rhs_writer(rhs_bytes, rhs_size);
|
|
|
|
const bool lhs_write = lhs_writer.Write(lhs_cast);
|
|
const bool rhs_write = rhs_writer.Write(rhs_cast);
|
|
OSP_DCHECK(lhs_write);
|
|
OSP_DCHECK(rhs_write);
|
|
|
|
// Skip the size bits.
|
|
const size_t min_size = std::min(lhs_writer.offset(), rhs_writer.offset());
|
|
for (size_t i = 2; i < min_size; i++) {
|
|
if (lhs_bytes[i] != rhs_bytes[i]) {
|
|
return lhs_bytes[i] > rhs_bytes[i];
|
|
}
|
|
}
|
|
|
|
return lhs_size > rhs_size;
|
|
}
|
|
|
|
// Selects the IsGreaterThan<> instantiation that corresponds to |type| and
// forwards |lhs| and |rhs| to it. Types without a dedicated case are compared
// as RawRecordRdata.
// NOTE(review): every instantiation reads the variant through absl::get<>, so
// both operands are presumably expected to hold the alternative chosen here -
// confirm that callers guarantee this.
bool IsGreaterThan(DnsType type, const Rdata& lhs, const Rdata& rhs) {
  switch (type) {
    case DnsType::kA:
      return IsGreaterThan<ARecordRdata>(lhs, rhs);
    case DnsType::kAAAA:
      return IsGreaterThan<AAAARecordRdata>(lhs, rhs);
    case DnsType::kPTR:
      return IsGreaterThan<PtrRecordRdata>(lhs, rhs);
    case DnsType::kSRV:
      return IsGreaterThan<SrvRecordRdata>(lhs, rhs);
    case DnsType::kTXT:
      return IsGreaterThan<TxtRecordRdata>(lhs, rhs);
    case DnsType::kNSEC:
      return IsGreaterThan<NsecRecordRdata>(lhs, rhs);
    default:
      return IsGreaterThan<RawRecordRdata>(lhs, rhs);
  }
}
|
|
|
|
} // namespace
|
|
|
|
// A label is valid when it is non-empty and at most kMaxLabelLength characters
// long.
bool IsValidDomainLabel(absl::string_view label) {
  if (label.size() == 0) {
    return false;
  }
  return label.size() <= kMaxLabelLength;
}
|
|
|
|
// Default construction is compiler-generated.
DomainName::DomainName() = default;

// The three constructors below delegate to the iterator-range constructor,
// handing it the full range of |labels|.
DomainName::DomainName(std::vector<std::string> labels)
    : DomainName(labels.begin(), labels.end()) {}

DomainName::DomainName(const std::vector<absl::string_view>& labels)
    : DomainName(labels.begin(), labels.end()) {}

DomainName::DomainName(std::initializer_list<absl::string_view> labels)
    : DomainName(labels.begin(), labels.end()) {}

// Takes ownership of |labels| and stores |max_wire_size| exactly as supplied.
DomainName::DomainName(std::vector<std::string> labels, size_t max_wire_size)
    : max_wire_size_(max_wire_size), labels_(std::move(labels)) {}

// Copy and move operations are compiler-generated.
DomainName::DomainName(const DomainName& other) = default;

DomainName::DomainName(DomainName&& other) noexcept = default;

DomainName& DomainName::operator=(const DomainName& rhs) = default;

DomainName& DomainName::operator=(DomainName&& rhs) = default;
|
|
|
|
std::string DomainName::ToString() const {
|
|
return absl::StrJoin(labels_, ".");
|
|
}
|
|
|
|
// Orders names label by label using a case-insensitive comparison. The first
// label pair that differs decides the result; if every shared label matches,
// the name with fewer labels orders first.
bool DomainName::operator<(const DomainName& rhs) const {
  const size_t shared_count = std::min(labels_.size(), rhs.labels_.size());
  for (size_t index = 0; index < shared_count; ++index) {
    const int order = CompareIgnoreCase(labels_[index], rhs.labels_[index]);
    if (order != 0) {
      return order < 0;
    }
  }
  return labels_.size() < rhs.labels_.size();
}
|
|
|
|
bool DomainName::operator<=(const DomainName& rhs) const {
|
|
return (*this < rhs) || (*this == rhs);
|
|
}
|
|
|
|
bool DomainName::operator>(const DomainName& rhs) const {
|
|
return !(*this < rhs) && !(*this == rhs);
|
|
}
|
|
|
|
bool DomainName::operator>=(const DomainName& rhs) const {
|
|
return !(*this < rhs);
|
|
}
|
|
|
|
bool DomainName::operator==(const DomainName& rhs) const {
|
|
if (labels_.size() != rhs.labels_.size()) {
|
|
return false;
|
|
}
|
|
for (size_t i = 0; i < labels_.size(); i++) {
|
|
if (CompareIgnoreCase(labels_[i], rhs.labels_[i]) != 0) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool DomainName::operator!=(const DomainName& rhs) const {
|
|
return !(*this == rhs);
|
|
}
|
|
|
|
size_t DomainName::MaxWireSize() const {
|
|
return max_wire_size_;
|
|
}
|
|
|
|
// static
|
|
ErrorOr<RawRecordRdata> RawRecordRdata::TryCreate(std::vector<uint8_t> rdata) {
|
|
if (rdata.size() > kMaxRawRecordSize) {
|
|
return Error::Code::kIndexOutOfBounds;
|
|
} else {
|
|
return RawRecordRdata(std::move(rdata));
|
|
}
|
|
}
|
|
|
|
RawRecordRdata::RawRecordRdata() = default;
|
|
|
|
RawRecordRdata::RawRecordRdata(std::vector<uint8_t> rdata)
|
|
: rdata_(std::move(rdata)) {
|
|
// Ensure RDATA length does not exceed the maximum allowed.
|
|
OSP_DCHECK(rdata_.size() <= kMaxRawRecordSize);
|
|
}
|
|
|
|
RawRecordRdata::RawRecordRdata(const uint8_t* begin, size_t size)
|
|
: RawRecordRdata(std::vector<uint8_t>(begin, begin + size)) {}
|
|
|
|
RawRecordRdata::RawRecordRdata(const RawRecordRdata& other) = default;
|
|
|
|
RawRecordRdata::RawRecordRdata(RawRecordRdata&& other) noexcept = default;
|
|
|
|
RawRecordRdata& RawRecordRdata::operator=(const RawRecordRdata& rhs) = default;
|
|
|
|
RawRecordRdata& RawRecordRdata::operator=(RawRecordRdata&& rhs) = default;
|
|
|
|
bool RawRecordRdata::operator==(const RawRecordRdata& rhs) const {
|
|
return rdata_ == rhs.rdata_;
|
|
}
|
|
|
|
bool RawRecordRdata::operator!=(const RawRecordRdata& rhs) const {
|
|
return !(*this == rhs);
|
|
}
|
|
|
|
size_t RawRecordRdata::MaxWireSize() const {
|
|
// max_wire_size includes uint16_t record length field.
|
|
return sizeof(uint16_t) + rdata_.size();
|
|
}
|
|
|
|
SrvRecordRdata::SrvRecordRdata() = default;
|
|
|
|
SrvRecordRdata::SrvRecordRdata(uint16_t priority,
|
|
uint16_t weight,
|
|
uint16_t port,
|
|
DomainName target)
|
|
: priority_(priority),
|
|
weight_(weight),
|
|
port_(port),
|
|
target_(std::move(target)) {}
|
|
|
|
SrvRecordRdata::SrvRecordRdata(const SrvRecordRdata& other) = default;
|
|
|
|
SrvRecordRdata::SrvRecordRdata(SrvRecordRdata&& other) noexcept = default;
|
|
|
|
SrvRecordRdata& SrvRecordRdata::operator=(const SrvRecordRdata& rhs) = default;
|
|
|
|
SrvRecordRdata& SrvRecordRdata::operator=(SrvRecordRdata&& rhs) = default;
|
|
|
|
bool SrvRecordRdata::operator==(const SrvRecordRdata& rhs) const {
|
|
return priority_ == rhs.priority_ && weight_ == rhs.weight_ &&
|
|
port_ == rhs.port_ && target_ == rhs.target_;
|
|
}
|
|
|
|
bool SrvRecordRdata::operator!=(const SrvRecordRdata& rhs) const {
|
|
return !(*this == rhs);
|
|
}
|
|
|
|
size_t SrvRecordRdata::MaxWireSize() const {
|
|
// max_wire_size includes uint16_t record length field.
|
|
return sizeof(uint16_t) + sizeof(priority_) + sizeof(weight_) +
|
|
sizeof(port_) + target_.MaxWireSize();
|
|
}
|
|
|
|
ARecordRdata::ARecordRdata() = default;
|
|
|
|
ARecordRdata::ARecordRdata(IPAddress ipv4_address,
|
|
NetworkInterfaceIndex interface_index)
|
|
: ipv4_address_(std::move(ipv4_address)),
|
|
interface_index_(interface_index) {
|
|
OSP_CHECK(ipv4_address_.IsV4());
|
|
}
|
|
|
|
ARecordRdata::ARecordRdata(const ARecordRdata& other) = default;
|
|
|
|
ARecordRdata::ARecordRdata(ARecordRdata&& other) noexcept = default;
|
|
|
|
ARecordRdata& ARecordRdata::operator=(const ARecordRdata& rhs) = default;
|
|
|
|
ARecordRdata& ARecordRdata::operator=(ARecordRdata&& rhs) = default;
|
|
|
|
bool ARecordRdata::operator==(const ARecordRdata& rhs) const {
|
|
return ipv4_address_ == rhs.ipv4_address_ &&
|
|
interface_index_ == rhs.interface_index_;
|
|
}
|
|
|
|
bool ARecordRdata::operator!=(const ARecordRdata& rhs) const {
|
|
return !(*this == rhs);
|
|
}
|
|
|
|
size_t ARecordRdata::MaxWireSize() const {
|
|
// max_wire_size includes uint16_t record length field.
|
|
return sizeof(uint16_t) + IPAddress::kV4Size;
|
|
}
|
|
|
|
AAAARecordRdata::AAAARecordRdata() = default;
|
|
|
|
AAAARecordRdata::AAAARecordRdata(IPAddress ipv6_address,
|
|
NetworkInterfaceIndex interface_index)
|
|
: ipv6_address_(std::move(ipv6_address)),
|
|
interface_index_(interface_index) {
|
|
OSP_CHECK(ipv6_address_.IsV6());
|
|
}
|
|
|
|
AAAARecordRdata::AAAARecordRdata(const AAAARecordRdata& other) = default;
|
|
|
|
AAAARecordRdata::AAAARecordRdata(AAAARecordRdata&& other) noexcept = default;
|
|
|
|
AAAARecordRdata& AAAARecordRdata::operator=(const AAAARecordRdata& rhs) =
|
|
default;
|
|
|
|
AAAARecordRdata& AAAARecordRdata::operator=(AAAARecordRdata&& rhs) = default;
|
|
|
|
bool AAAARecordRdata::operator==(const AAAARecordRdata& rhs) const {
|
|
return ipv6_address_ == rhs.ipv6_address_ &&
|
|
interface_index_ == rhs.interface_index_;
|
|
}
|
|
|
|
bool AAAARecordRdata::operator!=(const AAAARecordRdata& rhs) const {
|
|
return !(*this == rhs);
|
|
}
|
|
|
|
size_t AAAARecordRdata::MaxWireSize() const {
|
|
// max_wire_size includes uint16_t record length field.
|
|
return sizeof(uint16_t) + IPAddress::kV6Size;
|
|
}
|
|
|
|
PtrRecordRdata::PtrRecordRdata() = default;
|
|
|
|
PtrRecordRdata::PtrRecordRdata(DomainName ptr_domain)
|
|
: ptr_domain_(ptr_domain) {}
|
|
|
|
PtrRecordRdata::PtrRecordRdata(const PtrRecordRdata& other) = default;
|
|
|
|
PtrRecordRdata::PtrRecordRdata(PtrRecordRdata&& other) noexcept = default;
|
|
|
|
PtrRecordRdata& PtrRecordRdata::operator=(const PtrRecordRdata& rhs) = default;
|
|
|
|
PtrRecordRdata& PtrRecordRdata::operator=(PtrRecordRdata&& rhs) = default;
|
|
|
|
bool PtrRecordRdata::operator==(const PtrRecordRdata& rhs) const {
|
|
return ptr_domain_ == rhs.ptr_domain_;
|
|
}
|
|
|
|
bool PtrRecordRdata::operator!=(const PtrRecordRdata& rhs) const {
|
|
return !(*this == rhs);
|
|
}
|
|
|
|
size_t PtrRecordRdata::MaxWireSize() const {
|
|
// max_wire_size includes uint16_t record length field.
|
|
return sizeof(uint16_t) + ptr_domain_.MaxWireSize();
|
|
}
|
|
|
|
// static
|
|
ErrorOr<TxtRecordRdata> TxtRecordRdata::TryCreate(std::vector<Entry> texts) {
|
|
std::vector<std::string> str_texts;
|
|
size_t max_wire_size = 3;
|
|
if (texts.size() > 0) {
|
|
str_texts.reserve(texts.size());
|
|
// max_wire_size includes uint16_t record length field.
|
|
max_wire_size = sizeof(uint16_t);
|
|
for (const auto& text : texts) {
|
|
if (text.empty()) {
|
|
return Error::Code::kParameterInvalid;
|
|
}
|
|
str_texts.push_back(
|
|
std::string(reinterpret_cast<const char*>(text.data()), text.size()));
|
|
// Include the length byte in the size calculation.
|
|
max_wire_size += text.size() + 1;
|
|
}
|
|
}
|
|
return TxtRecordRdata(std::move(str_texts), max_wire_size);
|
|
}
|
|
|
|
TxtRecordRdata::TxtRecordRdata() = default;
|
|
|
|
TxtRecordRdata::TxtRecordRdata(std::vector<Entry> texts) {
|
|
ErrorOr<TxtRecordRdata> rdata = TxtRecordRdata::TryCreate(std::move(texts));
|
|
*this = std::move(rdata.value());
|
|
}
|
|
|
|
TxtRecordRdata::TxtRecordRdata(std::vector<std::string> texts,
|
|
size_t max_wire_size)
|
|
: max_wire_size_(max_wire_size), texts_(std::move(texts)) {}
|
|
|
|
TxtRecordRdata::TxtRecordRdata(const TxtRecordRdata& other) = default;
|
|
|
|
TxtRecordRdata::TxtRecordRdata(TxtRecordRdata&& other) noexcept = default;
|
|
|
|
TxtRecordRdata& TxtRecordRdata::operator=(const TxtRecordRdata& rhs) = default;
|
|
|
|
TxtRecordRdata& TxtRecordRdata::operator=(TxtRecordRdata&& rhs) = default;
|
|
|
|
bool TxtRecordRdata::operator==(const TxtRecordRdata& rhs) const {
|
|
return texts_ == rhs.texts_;
|
|
}
|
|
|
|
bool TxtRecordRdata::operator!=(const TxtRecordRdata& rhs) const {
|
|
return !(*this == rhs);
|
|
}
|
|
|
|
size_t TxtRecordRdata::MaxWireSize() const {
|
|
return max_wire_size_;
|
|
}
|
|
|
|
// Default construction is compiler-generated.
NsecRecordRdata::NsecRecordRdata() = default;

// Stores |next_domain_name| and |types|, sorts the types, and pre-computes the
// encoded type bitmaps in |encoded_types_|.
NsecRecordRdata::NsecRecordRdata(DomainName next_domain_name,
                                 std::vector<DnsType> types)
    : types_(std::move(types)), next_domain_name_(std::move(next_domain_name)) {
  // Sort the types_ array for easier comparison later.
  std::sort(types_.begin(), types_.end());

  // Calculate the bitmaps as described in RFC 4034 Section 4.1.2.
  std::vector<uint8_t> bitmap;
  uint8_t active_block = 0;

  // Appends the active block number, the bitmap length and the bitmap bytes to
  // |encoded_types_|. Does nothing while the bitmap is still empty.
  const auto flush_bitmap = [this, &bitmap, &active_block]() {
    if (bitmap.empty()) {
      return;
    }
    encoded_types_.push_back(active_block);
    encoded_types_.push_back(static_cast<uint8_t>(bitmap.size()));
    encoded_types_.insert(encoded_types_.end(), bitmap.begin(), bitmap.end());
  };

  for (const DnsType type : types_) {
    const uint16_t type_value = static_cast<uint16_t>(type);
    const uint8_t block = static_cast<uint8_t>(type_value >> 8);
    const uint8_t position = static_cast<uint8_t>(type_value & 0xFF);
    const uint8_t byte_index = position >> 3;            // First 5 bits.
    const uint8_t bit_mask = 0x80 >> (position & 0x07);  // Last 3 bits.

    // If the block has changed, write the previous block's info and all of its
    // contents to the |encoded_types_| vector.
    if (block > active_block) {
      flush_bitmap();
      bitmap.clear();
      active_block = block;
    }

    // Make sure |bitmap| is large enough to hold the bit representing the new
    // type, then set it.
    if (bitmap.size() <= byte_index) {
      bitmap.resize(static_cast<size_t>(byte_index) + 1, 0x00);
    }
    bitmap[byte_index] |= bit_mask;
  }

  // Write out whatever block was still being built when the loop ended.
  flush_bitmap();
}
|
|
|
|
NsecRecordRdata::NsecRecordRdata(const NsecRecordRdata& other) = default;
|
|
|
|
NsecRecordRdata::NsecRecordRdata(NsecRecordRdata&& other) noexcept = default;
|
|
|
|
NsecRecordRdata& NsecRecordRdata::operator=(const NsecRecordRdata& rhs) =
|
|
default;
|
|
|
|
NsecRecordRdata& NsecRecordRdata::operator=(NsecRecordRdata&& rhs) = default;
|
|
|
|
bool NsecRecordRdata::operator==(const NsecRecordRdata& rhs) const {
|
|
return types_ == rhs.types_ && next_domain_name_ == rhs.next_domain_name_;
|
|
}
|
|
|
|
bool NsecRecordRdata::operator!=(const NsecRecordRdata& rhs) const {
|
|
return !(*this == rhs);
|
|
}
|
|
|
|
size_t NsecRecordRdata::MaxWireSize() const {
|
|
return next_domain_name_.MaxWireSize() + encoded_types_.size();
|
|
}
|
|
|
|
size_t OptRecordRdata::Option::MaxWireSize() const {
|
|
// One uint16_t for each of OPTION-LENGTH and OPTION-CODE as defined in RFC
|
|
// 6891 section 6.1.2.
|
|
constexpr size_t kOptionLengthAndCodeSize = 2 * sizeof(uint16_t);
|
|
return data.size() + kOptionLengthAndCodeSize;
|
|
}
|
|
|
|
bool OptRecordRdata::Option::operator>(
|
|
const OptRecordRdata::Option& rhs) const {
|
|
if (code != rhs.code) {
|
|
return code > rhs.code;
|
|
} else if (length != rhs.length) {
|
|
return length > rhs.length;
|
|
} else if (data.size() != rhs.data.size()) {
|
|
return data.size() > rhs.data.size();
|
|
}
|
|
|
|
for (int i = 0; i < static_cast<int>(data.size()); i++) {
|
|
if (data[i] != rhs.data[i]) {
|
|
return data[i] > rhs.data[i];
|
|
}
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
bool OptRecordRdata::Option::operator<(
|
|
const OptRecordRdata::Option& rhs) const {
|
|
return rhs > *this;
|
|
}
|
|
|
|
bool OptRecordRdata::Option::operator>=(
|
|
const OptRecordRdata::Option& rhs) const {
|
|
return !(*this < rhs);
|
|
}
|
|
|
|
bool OptRecordRdata::Option::operator<=(
|
|
const OptRecordRdata::Option& rhs) const {
|
|
return !(*this > rhs);
|
|
}
|
|
|
|
bool OptRecordRdata::Option::operator==(
|
|
const OptRecordRdata::Option& rhs) const {
|
|
return *this >= rhs && *this <= rhs;
|
|
}
|
|
|
|
bool OptRecordRdata::Option::operator!=(
|
|
const OptRecordRdata::Option& rhs) const {
|
|
return !(*this == rhs);
|
|
}
|
|
|
|
OptRecordRdata::OptRecordRdata() = default;
|
|
|
|
OptRecordRdata::OptRecordRdata(std::vector<Option> options)
|
|
: options_(std::move(options)) {
|
|
for (const auto& option : options_) {
|
|
max_wire_size_ += option.MaxWireSize();
|
|
}
|
|
std::sort(options_.begin(), options_.end());
|
|
}
|
|
|
|
OptRecordRdata::OptRecordRdata(const OptRecordRdata& other) = default;
|
|
|
|
OptRecordRdata::OptRecordRdata(OptRecordRdata&& other) noexcept = default;
|
|
|
|
OptRecordRdata& OptRecordRdata::operator=(const OptRecordRdata& rhs) = default;
|
|
|
|
OptRecordRdata& OptRecordRdata::operator=(OptRecordRdata&& rhs) = default;
|
|
|
|
bool OptRecordRdata::operator==(const OptRecordRdata& rhs) const {
|
|
return options_ == rhs.options_;
|
|
}
|
|
|
|
bool OptRecordRdata::operator!=(const OptRecordRdata& rhs) const {
|
|
return !(*this == rhs);
|
|
}
|
|
|
|
// static
|
|
ErrorOr<MdnsRecord> MdnsRecord::TryCreate(DomainName name,
|
|
DnsType dns_type,
|
|
DnsClass dns_class,
|
|
RecordType record_type,
|
|
std::chrono::seconds ttl,
|
|
Rdata rdata) {
|
|
if (!IsValidConfig(name, dns_type, ttl, rdata)) {
|
|
return Error::Code::kParameterInvalid;
|
|
} else {
|
|
return MdnsRecord(std::move(name), dns_type, dns_class, record_type, ttl,
|
|
std::move(rdata));
|
|
}
|
|
}
|
|
|
|
// Default construction is compiler-generated.
MdnsRecord::MdnsRecord() = default;

// Stores every field as given, taking ownership of |name| and |rdata|. The
// resulting configuration is checked with IsValidConfig() through OSP_DCHECK.
MdnsRecord::MdnsRecord(DomainName name,
                       DnsType dns_type,
                       DnsClass dns_class,
                       RecordType record_type,
                       std::chrono::seconds ttl,
                       Rdata rdata)
    : name_(std::move(name)),
      dns_type_(dns_type),
      dns_class_(dns_class),
      record_type_(record_type),
      ttl_(ttl),
      rdata_(std::move(rdata)) {
  OSP_DCHECK(IsValidConfig(name_, dns_type, ttl_, rdata_));
}

// Copy and move operations are compiler-generated.
MdnsRecord::MdnsRecord(const MdnsRecord& other) = default;

MdnsRecord::MdnsRecord(MdnsRecord&& other) noexcept = default;

MdnsRecord& MdnsRecord::operator=(const MdnsRecord& rhs) = default;

MdnsRecord& MdnsRecord::operator=(MdnsRecord&& rhs) = default;
|
|
|
|
// static
|
|
bool MdnsRecord::IsValidConfig(const DomainName& name,
|
|
DnsType dns_type,
|
|
std::chrono::seconds ttl,
|
|
const Rdata& rdata) {
|
|
// NOTE: Although the name_ field was initially expected to be non-empty, this
|
|
// validation is no longer accurate for some record types (such as OPT
|
|
// records). To ensure that future record types correctly parse into
|
|
// RawRecordData types and do not invalidate the received message, this check
|
|
// has been removed.
|
|
return ttl.count() <= std::numeric_limits<uint32_t>::max() &&
|
|
((dns_type == DnsType::kSRV &&
|
|
absl::holds_alternative<SrvRecordRdata>(rdata)) ||
|
|
(dns_type == DnsType::kA &&
|
|
absl::holds_alternative<ARecordRdata>(rdata)) ||
|
|
(dns_type == DnsType::kAAAA &&
|
|
absl::holds_alternative<AAAARecordRdata>(rdata)) ||
|
|
(dns_type == DnsType::kPTR &&
|
|
absl::holds_alternative<PtrRecordRdata>(rdata)) ||
|
|
(dns_type == DnsType::kTXT &&
|
|
absl::holds_alternative<TxtRecordRdata>(rdata)) ||
|
|
(dns_type == DnsType::kNSEC &&
|
|
absl::holds_alternative<NsecRecordRdata>(rdata)) ||
|
|
(dns_type == DnsType::kOPT &&
|
|
absl::holds_alternative<OptRecordRdata>(rdata)) ||
|
|
absl::holds_alternative<RawRecordRdata>(rdata));
|
|
}
|
|
|
|
bool MdnsRecord::operator==(const MdnsRecord& rhs) const {
|
|
return IsReannouncementOf(rhs) && ttl_ == rhs.ttl_;
|
|
}
|
|
|
|
bool MdnsRecord::operator!=(const MdnsRecord& rhs) const {
|
|
return !(*this == rhs);
|
|
}
|
|
|
|
// Returns true if this record is lexicographically later than |rhs|. The
// result is decided by the first of the following that differs: the name, the
// record type (a kUnique record is the later one), the DNS class, the DNS type
// masked with kClassMask, and finally a raw comparison of the binary content
// of the rdata without regard for meaning or structure.
// NOTE: Per RFC, the TTL is not included in this comparison.
bool MdnsRecord::operator>(const MdnsRecord& rhs) const {
  const auto& this_name = name();
  const auto& other_name = rhs.name();
  if (this_name != other_name) {
    return this_name > other_name;
  }

  if (record_type() != rhs.record_type()) {
    return record_type() == RecordType::kUnique;
  }

  if (dns_class() != rhs.dns_class()) {
    return dns_class() > rhs.dns_class();
  }

  const uint16_t this_masked_type =
      static_cast<uint16_t>(dns_type()) & kClassMask;
  const uint16_t other_masked_type =
      static_cast<uint16_t>(rhs.dns_type()) & kClassMask;
  if (this_masked_type != other_masked_type) {
    return this_masked_type > other_masked_type;
  }

  return IsGreaterThan(dns_type(), rdata(), rhs.rdata());
}
|
|
|
|
bool MdnsRecord::operator<(const MdnsRecord& rhs) const {
|
|
return rhs > *this;
|
|
}
|
|
|
|
bool MdnsRecord::operator<=(const MdnsRecord& rhs) const {
|
|
return !(*this > rhs);
|
|
}
|
|
|
|
bool MdnsRecord::operator>=(const MdnsRecord& rhs) const {
|
|
return !(*this < rhs);
|
|
}
|
|
|
|
bool MdnsRecord::IsReannouncementOf(const MdnsRecord& rhs) const {
|
|
return dns_type_ == rhs.dns_type_ && dns_class_ == rhs.dns_class_ &&
|
|
record_type_ == rhs.record_type_ && name_ == rhs.name_ &&
|
|
rdata_ == rhs.rdata_;
|
|
}
|
|
|
|
size_t MdnsRecord::MaxWireSize() const {
|
|
auto wire_size_visitor = [](auto&& arg) { return arg.MaxWireSize(); };
|
|
// NAME size, 2-byte TYPE, 2-byte CLASS, 4-byte TTL, RDATA size
|
|
return name_.MaxWireSize() + absl::visit(wire_size_visitor, rdata_) + 8;
|
|
}
|
|
|
|
std::string MdnsRecord::ToString() const {
|
|
std::stringstream ss;
|
|
ss << "name: '" << name_.ToString() << "'";
|
|
ss << ", type: " << dns_type_;
|
|
|
|
if (dns_type_ == DnsType::kPTR) {
|
|
const DomainName& target = absl::get<PtrRecordRdata>(rdata_).ptr_domain();
|
|
ss << ", target: '" << target.ToString() << "'";
|
|
} else if (dns_type_ == DnsType::kSRV) {
|
|
const DomainName& target = absl::get<SrvRecordRdata>(rdata_).target();
|
|
ss << ", target: '" << target.ToString() << "'";
|
|
} else if (dns_type_ == DnsType::kNSEC) {
|
|
const auto& nsec_rdata = absl::get<NsecRecordRdata>(rdata_);
|
|
std::vector<DnsType> types = nsec_rdata.types();
|
|
ss << ", representing [";
|
|
if (!types.empty()) {
|
|
auto it = types.begin();
|
|
ss << *it++;
|
|
while (it != types.end()) {
|
|
ss << ", " << *it++;
|
|
}
|
|
ss << "]";
|
|
}
|
|
}
|
|
|
|
return ss.str();
|
|
}
|
|
|
|
// Builds a unique, IN-class address record for |name|: an A record with
// kARecordTtl when |address| is IPv4, otherwise an AAAA record with
// kAAAARecordTtl.
MdnsRecord CreateAddressRecord(DomainName name, const IPAddress& address) {
  const bool is_v4 = address.IsV4();
  const DnsType record_dns_type = is_v4 ? DnsType::kA : DnsType::kAAAA;

  std::chrono::seconds record_ttl;
  Rdata payload;
  if (is_v4) {
    record_ttl = kARecordTtl;
    payload = ARecordRdata(address);
  } else {
    record_ttl = kAAAARecordTtl;
    payload = AAAARecordRdata(address);
  }

  return MdnsRecord(std::move(name), record_dns_type, DnsClass::kIN,
                    RecordType::kUnique, record_ttl, std::move(payload));
}
|
|
|
|
// static
|
|
ErrorOr<MdnsQuestion> MdnsQuestion::TryCreate(DomainName name,
|
|
DnsType dns_type,
|
|
DnsClass dns_class,
|
|
ResponseType response_type) {
|
|
if (name.empty()) {
|
|
return Error::Code::kParameterInvalid;
|
|
}
|
|
|
|
return MdnsQuestion(std::move(name), dns_type, dns_class, response_type);
|
|
}
|
|
|
|
MdnsQuestion::MdnsQuestion(DomainName name,
|
|
DnsType dns_type,
|
|
DnsClass dns_class,
|
|
ResponseType response_type)
|
|
: name_(std::move(name)),
|
|
dns_type_(dns_type),
|
|
dns_class_(dns_class),
|
|
response_type_(response_type) {
|
|
OSP_CHECK(!name_.empty());
|
|
}
|
|
|
|
bool MdnsQuestion::operator==(const MdnsQuestion& rhs) const {
|
|
return dns_type_ == rhs.dns_type_ && dns_class_ == rhs.dns_class_ &&
|
|
response_type_ == rhs.response_type_ && name_ == rhs.name_;
|
|
}
|
|
|
|
bool MdnsQuestion::operator!=(const MdnsQuestion& rhs) const {
|
|
return !(*this == rhs);
|
|
}
|
|
|
|
size_t MdnsQuestion::MaxWireSize() const {
|
|
// NAME size, 2-byte TYPE, 2-byte CLASS
|
|
return name_.MaxWireSize() + 4;
|
|
}
|
|
|
|
// static
|
|
ErrorOr<MdnsMessage> MdnsMessage::TryCreate(
|
|
uint16_t id,
|
|
MessageType type,
|
|
std::vector<MdnsQuestion> questions,
|
|
std::vector<MdnsRecord> answers,
|
|
std::vector<MdnsRecord> authority_records,
|
|
std::vector<MdnsRecord> additional_records) {
|
|
if (questions.size() >= kMaxMessageFieldEntryCount ||
|
|
answers.size() >= kMaxMessageFieldEntryCount ||
|
|
authority_records.size() >= kMaxMessageFieldEntryCount ||
|
|
additional_records.size() >= kMaxMessageFieldEntryCount) {
|
|
return Error::Code::kParameterInvalid;
|
|
}
|
|
|
|
return MdnsMessage(id, type, std::move(questions), std::move(answers),
|
|
std::move(authority_records),
|
|
std::move(additional_records));
|
|
}
|
|
|
|
// Creates a message with the given id and type and no entries.
MdnsMessage::MdnsMessage(uint16_t id, MessageType type)
    : id_(id), type_(type) {}

// Creates a message that takes ownership of the four sections, then adds the
// wire size of every question and record to |max_wire_size_|.
MdnsMessage::MdnsMessage(uint16_t id,
                         MessageType type,
                         std::vector<MdnsQuestion> questions,
                         std::vector<MdnsRecord> answers,
                         std::vector<MdnsRecord> authority_records,
                         std::vector<MdnsRecord> additional_records)
    : id_(id),
      type_(type),
      questions_(std::move(questions)),
      answers_(std::move(answers)),
      authority_records_(std::move(authority_records)),
      additional_records_(std::move(additional_records)) {
  OSP_DCHECK(questions_.size() < kMaxMessageFieldEntryCount);
  OSP_DCHECK(answers_.size() < kMaxMessageFieldEntryCount);
  OSP_DCHECK(authority_records_.size() < kMaxMessageFieldEntryCount);
  OSP_DCHECK(additional_records_.size() < kMaxMessageFieldEntryCount);

  // Adds the wire size of each entry in |entries| to the running total.
  const auto accumulate_sizes = [this](const auto& entries) {
    for (const auto& entry : entries) {
      max_wire_size_ += entry.MaxWireSize();
    }
  };
  accumulate_sizes(questions_);
  accumulate_sizes(answers_);
  accumulate_sizes(authority_records_);
  accumulate_sizes(additional_records_);
}
|
|
|
|
bool MdnsMessage::operator==(const MdnsMessage& rhs) const {
|
|
return id_ == rhs.id_ && type_ == rhs.type_ && questions_ == rhs.questions_ &&
|
|
answers_ == rhs.answers_ &&
|
|
authority_records_ == rhs.authority_records_ &&
|
|
additional_records_ == rhs.additional_records_;
|
|
}
|
|
|
|
bool MdnsMessage::operator!=(const MdnsMessage& rhs) const {
|
|
return !(*this == rhs);
|
|
}
|
|
|
|
bool MdnsMessage::IsProbeQuery() const {
|
|
// A message is a probe query if it contains records in the authority section
|
|
// which answer the question being asked.
|
|
if (questions().empty() || authority_records().empty()) {
|
|
return false;
|
|
}
|
|
|
|
for (const MdnsQuestion& question : questions_) {
|
|
for (const MdnsRecord& record : authority_records_) {
|
|
if (question.name() == record.name() &&
|
|
((question.dns_type() == record.dns_type()) ||
|
|
(question.dns_type() == DnsType::kANY)) &&
|
|
((question.dns_class() == record.dns_class()) ||
|
|
(question.dns_class() == DnsClass::kANY))) {
|
|
return true;
|
|
}
|
|
}
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
size_t MdnsMessage::MaxWireSize() const {
|
|
return max_wire_size_;
|
|
}
|
|
|
|
void MdnsMessage::AddQuestion(MdnsQuestion question) {
|
|
OSP_DCHECK(questions_.size() < kMaxMessageFieldEntryCount);
|
|
max_wire_size_ += question.MaxWireSize();
|
|
questions_.emplace_back(std::move(question));
|
|
}
|
|
|
|
void MdnsMessage::AddAnswer(MdnsRecord record) {
|
|
OSP_DCHECK(answers_.size() < kMaxMessageFieldEntryCount);
|
|
max_wire_size_ += record.MaxWireSize();
|
|
answers_.emplace_back(std::move(record));
|
|
}
|
|
|
|
void MdnsMessage::AddAuthorityRecord(MdnsRecord record) {
|
|
OSP_DCHECK(authority_records_.size() < kMaxMessageFieldEntryCount);
|
|
max_wire_size_ += record.MaxWireSize();
|
|
authority_records_.emplace_back(std::move(record));
|
|
}
|
|
|
|
void MdnsMessage::AddAdditionalRecord(MdnsRecord record) {
|
|
OSP_DCHECK(additional_records_.size() < kMaxMessageFieldEntryCount);
|
|
max_wire_size_ += record.MaxWireSize();
|
|
additional_records_.emplace_back(std::move(record));
|
|
}
|
|
|
|
bool MdnsMessage::CanAddRecord(const MdnsRecord& record) {
|
|
return (max_wire_size_ + record.MaxWireSize()) < kMaxMulticastMessageSize;
|
|
}
|
|
|
|
// Returns the current value of a function-local counter and then advances it.
// The counter starts at zero.
uint16_t CreateMessageId() {
  static uint16_t next_id = 0;
  const uint16_t assigned_id = next_id;
  ++next_id;
  return assigned_id;
}
|
|
|
|
bool CanBePublished(DnsType type) {
|
|
// NOTE: A 'default' switch statement has intentionally been avoided below to
|
|
// enforce that new DnsTypes added must be added below through a compile-time
|
|
// check.
|
|
switch (type) {
|
|
case DnsType::kA:
|
|
case DnsType::kAAAA:
|
|
case DnsType::kPTR:
|
|
case DnsType::kTXT:
|
|
case DnsType::kSRV:
|
|
return true;
|
|
case DnsType::kOPT:
|
|
case DnsType::kNSEC:
|
|
case DnsType::kANY:
|
|
break;
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
bool CanBePublished(const MdnsRecord& record) {
|
|
return CanBePublished(record.dns_type());
|
|
}
|
|
|
|
// Returns true for A, AAAA, PTR, TXT, SRV and ANY; false for every other type.
bool CanBeQueried(DnsType type) {
  // NOTE: A 'default' switch statement has intentionally been avoided below to
  // enforce that new DnsTypes added must be added below through a compile-time
  // check.
  switch (type) {
    case DnsType::kOPT:
    case DnsType::kNSEC:
      return false;
    case DnsType::kA:
    case DnsType::kAAAA:
    case DnsType::kPTR:
    case DnsType::kTXT:
    case DnsType::kSRV:
    case DnsType::kANY:
      return true;
  }

  // Reached only for values that match none of the enumerators above.
  return false;
}
|
|
|
|
// Returns true for A, AAAA, PTR, TXT, SRV and NSEC; false for every other
// type.
bool CanBeProcessed(DnsType type) {
  // NOTE: A 'default' switch statement has intentionally been avoided below to
  // enforce that new DnsTypes added must be added below through a compile-time
  // check.
  switch (type) {
    case DnsType::kOPT:
    case DnsType::kANY:
      return false;
    case DnsType::kA:
    case DnsType::kAAAA:
    case DnsType::kPTR:
    case DnsType::kTXT:
    case DnsType::kSRV:
    case DnsType::kNSEC:
      return true;
  }

  // Reached only for values that match none of the enumerators above.
  return false;
}
|
|
|
|
} // namespace discovery
|
|
} // namespace openscreen
|