aboutsummaryrefslogtreecommitdiff
path: root/fcp/base/match.h
diff options
context:
space:
mode:
Diffstat (limited to 'fcp/base/match.h')
-rw-r--r--fcp/base/match.h292
1 files changed, 292 insertions, 0 deletions
diff --git a/fcp/base/match.h b/fcp/base/match.h
new file mode 100644
index 0000000..253b59b
--- /dev/null
+++ b/fcp/base/match.h
@@ -0,0 +1,292 @@
+/*
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// 'Match' expressions for {std, absl}::variant.
+//
+// {std, absl}::variant is an algebraic sum type. However, the standard library
+// does not provide a convenient way to destructure or match on them - unlike in
+// Haskell, Rust, etc.
+//
+// This file provides a way to match on :variant in a way akin to a switch
+// statement.
+//
+// Example:
+//
+// using V = std::variant<X, Y, Z>;
+// V v = ...;
+// ...
+// int i = Match(v,
+// [](X const& x) { return 1; },
+// [](Y const& y) { return 2; },
+// [](Z const& z) { return 3; });
+//
+// It is a compile-time error if the match is not exhaustive. A 'Default' case
+// can be provided:
+//
+// int i = Match(v,
+// [](X const& x) { return 1; },
+// // Called with the otherwise-unhandled alternative (see decltype(alt)).
+// [](Default, auto const& alt) { ...; });
+//
+// int i = Match(v,
+// [](X const& x) { return 1; },
+// // Called with the variant itself.
+// [](Default, V const& v) { ...; });
+//
+// If constructing the matcher lambdas is non-trivial, it might be worthwhile to
+// create a re-usable matcher object. See 'MakeMatcher'.
+
+#ifndef FCP_BASE_MATCH_H_
+#define FCP_BASE_MATCH_H_
+
+#include <optional>
+#include <type_traits>
+#include <variant>
+
+#include "fcp/base/meta.h"
+
+namespace fcp {
+
+// Marker type for default match cases.
+struct Default {};
+
+namespace match_internal {
+
+template <typename... CaseFns>
+struct MatchCasesCallable : public CaseFns... {
+ // Each CaseFn provides operator(). We want to pick one by overload
+ // resolution.
+ using CaseFns::operator()...;
+};
+
+template <typename ToType, typename... CaseFns>
+class MatchCases {
+ public:
+ explicit constexpr MatchCases(MatchCasesCallable<CaseFns...> c)
+ : callable_(std::move(c)) {}
+
+ // False by default
+ template <typename Enable, typename... T>
+ struct IsCaseHandledImpl : public std::false_type {};
+
+ // True when m.MatchCases(args...) is well-formed, for a
+ // MatchCases<CaseFns...> m and T arg.
+ template <typename... T>
+ struct IsCaseHandledImpl<
+ std::void_t<decltype(std::declval<MatchCasesCallable<CaseFns...>>()(
+ std::declval<T>()...))>,
+ T...> : public std::true_type {};
+
+ template <typename... T>
+ static constexpr bool IsCaseHandled() {
+ return IsCaseHandledImpl<void, T...>::value;
+ }
+
+ template <typename ToType_ = ToType, typename... Args>
+ constexpr auto operator()(Args&&... args) const {
+ if constexpr (std::is_void_v<ToType_>) {
+ return callable_(std::forward<Args>(args)...);
+ } else {
+ return ToType_(callable_(std::forward<Args>(args)...));
+ }
+ }
+
+ private:
+ MatchCasesCallable<CaseFns...> callable_;
+};
+
+template <typename ToType, typename... CaseFns>
+constexpr MatchCases<ToType, CaseFns...> MakeMatchCases(CaseFns... case_fns) {
+ return MatchCases<ToType, CaseFns...>(
+ MatchCasesCallable<CaseFns...>{case_fns...});
+}
+
+template <typename CasesType, typename VariantType, typename ArgType>
+constexpr auto ApplyCase(CasesType const& cases, VariantType&& v,
+ ArgType&& arg) {
+ if constexpr (CasesType::template IsCaseHandled<ArgType>()) {
+ return cases(std::forward<ArgType>(arg));
+ } else if constexpr (CasesType::template IsCaseHandled<Default, ArgType>()) {
+ return cases(Default{}, std::forward<ArgType>(arg));
+ } else if constexpr (CasesType::template IsCaseHandled<Default,
+ VariantType>()) {
+ return cases(Default{}, std::forward<VariantType>(v));
+ } else if constexpr (CasesType::template IsCaseHandled<Default>()) {
+ return cases(Default{});
+ } else {
+ static_assert(
+ FailIfReached<ArgType>(),
+ "Provide a case for all variant alternatives, or a 'Default' case");
+ }
+}
+
+template <typename Traits, typename CasesType>
+class VariantMatcherImpl {
+ public:
+ using ValueType = typename Traits::ValueType;
+
+ explicit constexpr VariantMatcherImpl(CasesType cases)
+ : cases_(std::move(cases)) {}
+
+ constexpr auto Match(ValueType* v) const { return MatchInternal(v); }
+
+ constexpr auto Match(ValueType const& v) const { return MatchInternal(v); }
+
+ constexpr auto Match(ValueType&& v) const {
+ return MatchInternal(std::move(v));
+ }
+
+ private:
+ template <typename FromType>
+ constexpr auto MatchInternal(FromType&& v) const {
+ return Traits::Visit(std::forward<FromType>(v), [this, &v](auto&& alt) {
+ return ApplyCase(cases_, std::forward<FromType>(v),
+ std::forward<decltype(alt)>(alt));
+ });
+ }
+
+ CasesType cases_;
+};
+
+template <typename T, typename Enable = void>
+struct MatchTraits {
+ static_assert(FailIfReached<T>(),
+ "Only variant-like (e.g. std::variant<...> types can be "
+ "matched. See MatchTraits.");
+};
+
+template <typename... AltTypes>
+struct MatchTraits<std::variant<AltTypes...>> {
+ using ValueType = std::variant<AltTypes...>;
+
+ template <typename VisitFn>
+ static constexpr auto Visit(ValueType const& v, VisitFn&& fn) {
+ return absl::visit(std::forward<VisitFn>(fn), v);
+ }
+
+ template <typename VisitFn>
+ static constexpr auto Visit(ValueType&& v, VisitFn&& fn) {
+ return absl::visit(std::forward<VisitFn>(fn), std::move(v));
+ }
+
+ template <typename VisitFn>
+ static constexpr auto Visit(ValueType* v, VisitFn&& fn) {
+ return absl::visit([fn = std::forward<VisitFn>(fn)](
+ auto& alt) mutable { return fn(&alt); },
+ *v);
+ }
+};
+
+template <typename T>
+struct MatchTraits<std::optional<T>> {
+ using ValueType = std::optional<T>;
+
+ static constexpr auto Wrap(std::optional<T>* o)
+ -> std::variant<T*, std::nullopt_t> {
+ if (o->has_value()) {
+ return &**o;
+ } else {
+ return std::nullopt;
+ }
+ }
+
+ static constexpr auto Wrap(std::optional<T> const& o)
+ -> std::variant<std::reference_wrapper<T const>, std::nullopt_t> {
+ if (o.has_value()) {
+ return std::ref(*o);
+ } else {
+ return std::nullopt;
+ }
+ }
+
+ static constexpr auto Wrap(std::optional<T>&& o)
+ -> std::variant<T, std::nullopt_t> {
+ if (o.has_value()) {
+ return *std::move(o);
+ } else {
+ return std::nullopt;
+ }
+ }
+
+ template <typename V, typename VisitFn>
+ static constexpr auto Visit(V&& v, VisitFn&& fn) {
+ return absl::visit(std::forward<VisitFn>(fn), Wrap(std::forward<V>(v)));
+ }
+};
+
+template <typename T>
+struct MatchTraits<T, std::void_t<typename T::VariantType>> {
+ using ValueType = T;
+
+ template <typename VisitFn>
+ static constexpr auto Visit(ValueType const& v, VisitFn&& fn) {
+ return MatchTraits<typename T::VariantType>::Visit(
+ v.variant(), std::forward<VisitFn>(fn));
+ }
+
+ template <typename VisitFn>
+ static constexpr auto Visit(ValueType&& v, VisitFn&& fn) {
+ return MatchTraits<typename T::VariantType>::Visit(
+ std::move(v).variant(), std::forward<VisitFn>(fn));
+ }
+
+ template <typename VisitFn>
+ static constexpr auto Visit(ValueType* v, VisitFn&& fn) {
+ return MatchTraits<typename T::VariantType>::Visit(
+ &v->variant(), std::forward<VisitFn>(fn));
+ }
+};
+
+template <typename VariantType, typename CasesType>
+constexpr auto CreateMatcherImpl(CasesType cases) {
+ return VariantMatcherImpl<MatchTraits<VariantType>, CasesType>(
+ std::move(cases));
+}
+
+} // namespace match_internal
+
+// See file remarks.
+template <typename From, typename To = void, typename... CaseFnTypes>
+constexpr auto MakeMatcher(CaseFnTypes... fns) {
+ return match_internal::CreateMatcherImpl<From>(
+ match_internal::MakeMatchCases<To>(fns...));
+}
+
+// See file remarks.
+//
+// Note that the order of template arguments differs from MakeMatcher; it is
+// expected that 'From' is always deduced (but it can be useful to specify 'To'
+// explicitly).
+template <typename To = void, typename From, typename... CaseFnTypes>
+constexpr auto Match(From&& v, CaseFnTypes... fns) {
+ // 'From' is intended to be deduced. For MakeMatcher, we want V (not e.g. V
+ // const&).
+ auto m = MakeMatcher<std::decay_t<From>, To>(fns...);
+ // The full type is still relevant for forwarding.
+ return m.Match(std::forward<From>(v));
+}
+
+template <typename To = void, typename From, typename... CaseFnTypes>
+constexpr auto Match(From* v, CaseFnTypes... fns) {
+ // 'From' is intended to be deduced. For MakeMatcher, we want V (not e.g. V
+ // const*).
+ auto m = MakeMatcher<std::decay_t<From>, To>(fns...);
+ return m.Match(v);
+}
+
+} // namespace fcp
+
+#endif // FCP_BASE_MATCH_H_