aboutsummaryrefslogtreecommitdiff
path: root/net/dcsctp/fuzzers/dcsctp_fuzzers.h
diff options
context:
space:
mode:
Diffstat (limited to 'net/dcsctp/fuzzers/dcsctp_fuzzers.h')
-rw-r--r--net/dcsctp/fuzzers/dcsctp_fuzzers.h116
1 files changed, 116 insertions, 0 deletions
diff --git a/net/dcsctp/fuzzers/dcsctp_fuzzers.h b/net/dcsctp/fuzzers/dcsctp_fuzzers.h
new file mode 100644
index 0000000000..f3de0722f4
--- /dev/null
+++ b/net/dcsctp/fuzzers/dcsctp_fuzzers.h
@@ -0,0 +1,116 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#ifndef NET_DCSCTP_FUZZERS_DCSCTP_FUZZERS_H_
+#define NET_DCSCTP_FUZZERS_DCSCTP_FUZZERS_H_
+
+#include <deque>
+#include <memory>
+#include <set>
+#include <vector>
+
+#include "api/array_view.h"
+#include "net/dcsctp/public/dcsctp_socket.h"
+
+namespace dcsctp {
+namespace dcsctp_fuzzers {
+
+// A fake timeout used during fuzzing.
+class FuzzerTimeout : public Timeout {
+ public:
+ explicit FuzzerTimeout(std::set<TimeoutID>& active_timeouts)
+ : active_timeouts_(active_timeouts) {}
+
+ void Start(DurationMs duration_ms, TimeoutID timeout_id) override {
+ // Start is only allowed to be called on stopped or expired timeouts.
+ if (timeout_id_.has_value()) {
+ // It has been started before, but maybe it expired. Ensure that it's not
+ // running at least.
+ RTC_DCHECK(active_timeouts_.find(*timeout_id_) == active_timeouts_.end());
+ }
+ timeout_id_ = timeout_id;
+ RTC_DCHECK(active_timeouts_.insert(timeout_id).second);
+ }
+
+ void Stop() override {
+ // Stop is only allowed to be called on active timeouts. Not stopped or
+ // expired.
+ RTC_DCHECK(timeout_id_.has_value());
+ RTC_DCHECK(active_timeouts_.erase(*timeout_id_) == 1);
+ timeout_id_ = absl::nullopt;
+ }
+
+ // A set of all active timeouts, managed by `FuzzerCallbacks`.
+ std::set<TimeoutID>& active_timeouts_;
+ // If present, the timout is active and will expire reported as `timeout_id`.
+ absl::optional<TimeoutID> timeout_id_;
+};
+
+class FuzzerCallbacks : public DcSctpSocketCallbacks {
+ public:
+ static constexpr int kRandomValue = 42;
+ void SendPacket(rtc::ArrayView<const uint8_t> data) override {
+ sent_packets_.emplace_back(std::vector<uint8_t>(data.begin(), data.end()));
+ }
+ std::unique_ptr<Timeout> CreateTimeout() override {
+ return std::make_unique<FuzzerTimeout>(active_timeouts_);
+ }
+ TimeMs TimeMillis() override { return TimeMs(42); }
+ uint32_t GetRandomInt(uint32_t low, uint32_t high) override {
+ return kRandomValue;
+ }
+ void OnMessageReceived(DcSctpMessage message) override {}
+ void OnError(ErrorKind error, absl::string_view message) override {}
+ void OnAborted(ErrorKind error, absl::string_view message) override {}
+ void OnConnected() override {}
+ void OnClosed() override {}
+ void OnConnectionRestarted() override {}
+ void OnStreamsResetFailed(rtc::ArrayView<const StreamID> outgoing_streams,
+ absl::string_view reason) override {}
+ void OnStreamsResetPerformed(
+ rtc::ArrayView<const StreamID> outgoing_streams) override {}
+ void OnIncomingStreamsReset(
+ rtc::ArrayView<const StreamID> incoming_streams) override {}
+
+ std::vector<uint8_t> ConsumeSentPacket() {
+ if (sent_packets_.empty()) {
+ return {};
+ }
+ std::vector<uint8_t> ret = sent_packets_.front();
+ sent_packets_.pop_front();
+ return ret;
+ }
+
+ // Given an index among the active timeouts, will expire that one.
+ absl::optional<TimeoutID> ExpireTimeout(size_t index) {
+ if (index < active_timeouts_.size()) {
+ auto it = active_timeouts_.begin();
+ std::advance(it, index);
+ TimeoutID timeout_id = *it;
+ active_timeouts_.erase(it);
+ return timeout_id;
+ }
+ return absl::nullopt;
+ }
+
+ private:
+ // Needs to be ordered, to allow fuzzers to expire timers.
+ std::set<TimeoutID> active_timeouts_;
+ std::deque<std::vector<uint8_t>> sent_packets_;
+};
+
+// Given some fuzzing `data` will send packets to the socket as well as calling
+// API methods.
+void FuzzSocket(DcSctpSocketInterface& socket,
+ FuzzerCallbacks& cb,
+ rtc::ArrayView<const uint8_t> data);
+
+} // namespace dcsctp_fuzzers
+} // namespace dcsctp
+#endif // NET_DCSCTP_FUZZERS_DCSCTP_FUZZERS_H_