summaryrefslogtreecommitdiff
path: root/mojo/public/cpp/bindings/lib/message_header_validator.cc
blob: 9f8c6278c07d0c520ad3ef11e3191d65307e4982 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "mojo/public/cpp/bindings/message_header_validator.h"

#include "mojo/public/cpp/bindings/lib/array_internal.h"
#include "mojo/public/cpp/bindings/lib/validate_params.h"
#include "mojo/public/cpp/bindings/lib/validation_context.h"
#include "mojo/public/cpp/bindings/lib/validation_errors.h"
#include "mojo/public/cpp/bindings/lib/validation_util.h"

namespace mojo {
namespace {

// TODO(yzshen): Define a mojom struct for message header and use the generated
// validation and data view code.

// Returns true if |header->num_bytes| is acceptable for |header->version|.
// Versions 0, 1 and 2 must match the size of their header struct exactly.
// Any later version only needs to be at least as large as MessageHeaderV2, so
// that fields appended by future versions are tolerated.
bool HasExpectedStructSize(const internal::MessageHeader* header) {
  const auto version = header->version;
  const auto num_bytes = header->num_bytes;
  if (version == 0)
    return num_bytes == sizeof(internal::MessageHeader);
  if (version == 1)
    return num_bytes == sizeof(internal::MessageHeaderV1);
  if (version == 2)
    return num_bytes == sizeof(internal::MessageHeaderV2);
  if (version > 2)
    return num_bytes >= sizeof(internal::MessageHeaderV2);
  // Any other value is rejected.
  return false;
}

// Validates the mojo message header pointed to by |header|.
// Reports the first problem found through |validation_context| where this
// function does so explicitly, and returns false. Returns true if the header
// passes all checks below.
bool IsValidMessageHeader(const internal::MessageHeader* header,
                          internal::ValidationContext* validation_context) {
  // NOTE: Our goal is to preserve support for future extension of the message
  // header. If we encounter fields we do not understand, we must ignore them.

  // Extra validation of the struct header: size must agree with version.
  if (!HasExpectedStructSize(header)) {
    internal::ReportValidationError(
        validation_context,
        internal::VALIDATION_ERROR_UNEXPECTED_STRUCT_HEADER);
    return false;
  }

  // Validate flags. Unknown bits are allowed; only the flags that require a
  // request ID are constrained here.
  constexpr uint32_t kRequestIdFlags =
      Message::kFlagExpectsResponse | Message::kFlagIsResponse;
  const bool uses_request_id = (header->flags & kRequestIdFlags) != 0;

  // A version 0 header may not set any flag that needs a request ID.
  if (header->version == 0 && uses_request_id) {
    internal::ReportValidationError(
        validation_context,
        internal::VALIDATION_ERROR_MESSAGE_HEADER_MISSING_REQUEST_ID);
    return false;
  }

  // "Expects response" and "is response" may not both be set.
  if ((header->flags & kRequestIdFlags) == kRequestIdFlags) {
    internal::ReportValidationError(
        validation_context,
        internal::VALIDATION_ERROR_MESSAGE_HEADER_INVALID_FLAGS);
    return false;
  }

  // Headers older than version 2 have nothing further to check.
  if (header->version < 2)
    return true;

  auto* header_v2 = static_cast<const internal::MessageHeaderV2*>(header);

  // For a non-null payload pointer:
  // - Check that the pointer can be safely decoded.
  // - Claim one byte that the pointer points to. This makes sure not only that
  //   the address is within the message, but also that it precedes the array
  //   storing interface IDs (which is important for safely calculating the
  //   payload size).
  // - Validation of the payload contents is done separately, based on the
  //   payload type.
  if (!header_v2->payload.is_null()) {
    if (!internal::ValidatePointer(header_v2->payload, validation_context))
      return false;
    // NOTE(review): no validation error is reported explicitly here when
    // ClaimMemory() fails; confirm whether ClaimMemory() reports one itself.
    if (!validation_context->ClaimMemory(header_v2->payload.Get(), 1))
      return false;
  }

  const internal::ContainerValidateParams validate_params(0, false, nullptr);
  if (!internal::ValidateContainer(header_v2->payload_interface_ids,
                                   validation_context, &validate_params)) {
    return false;
  }

  // No interface IDs attached: nothing left to check.
  if (header_v2->payload_interface_ids.is_null())
    return true;

  // Every attached interface ID must be valid and must not be the master
  // interface ID.
  auto* id_array = header_v2->payload_interface_ids.Get();
  const uint32_t* const ids_begin = id_array->storage();
  const uint32_t* const ids_end = ids_begin + id_array->size();
  for (const uint32_t* it = ids_begin; it != ids_end; ++it) {
    if (!IsValidInterfaceId(*it) || IsMasterInterfaceId(*it)) {
      internal::ReportValidationError(
          validation_context,
          internal::VALIDATION_ERROR_ILLEGAL_INTERFACE_ID);
      return false;
    }
  }

  return true;
}

}  // namespace

// Creates a validator labelled with the generic description
// "MessageHeaderValidator".
MessageHeaderValidator::MessageHeaderValidator()
    : MessageHeaderValidator(std::string("MessageHeaderValidator")) {}

// Creates a validator with the given |description|. The description is passed
// to the ValidationContext that Accept() builds for each message.
MessageHeaderValidator::MessageHeaderValidator(const std::string& description)
    : description_(description) {}

// Replaces the description used by subsequent calls to Accept().
void MessageHeaderValidator::SetDescription(const std::string& description) {
  description_ = description;
}

// Returns true if |message| starts with a well-formed struct header that is
// also a valid message header; returns false otherwise.
bool MessageHeaderValidator::Accept(Message* message) {
  // The header is not expected to contain handles or associated endpoint
  // handles, even when |message| itself carries some. The context is
  // therefore created with zero of each.
  internal::ValidationContext context(message->data(),
                                      message->data_num_bytes(), 0, 0, message,
                                      description_);

  // The generic struct header is checked (and its memory claimed) first. The
  // message-specific checks run only if that succeeds.
  const bool struct_header_ok =
      internal::ValidateStructHeaderAndClaimMemory(message->data(), &context);
  return struct_header_ok &&
         IsValidMessageHeader(message->header(), &context);
}

}  // namespace mojo