diff options
author | Ta-Wei Tu <tu.da.wei@gmail.com> | 2021-09-10 13:00:44 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-09-10 13:00:44 +0800 |
commit | fa2d21716b31a875a3334eab74a885a18faedca6 (patch) | |
tree | a508f1bd674c6518f08bad6749af7c9ffb5f0059 | |
parent | 72171a33269073a4c09940e948e82b93bf0fcf97 (diff) | |
download | grpc-grpc-fa2d21716b31a875a3334eab74a885a18faedca6.tar.gz |
[binder] Fix server-side recv_trailing_metadata (#27184)
According to the [transport explainer](https://grpc.github.io/grpc/core/md_doc_core_transport_explainer.html), the server-side `recv_trailing_metadata` should not be completed before sending trailing metadata to the client.
6 files changed, 129 insertions, 63 deletions
diff --git a/src/core/ext/transport/binder/transport/binder_stream.h b/src/core/ext/transport/binder/transport/binder_stream.h index 0d27333fd2..176d1f59b3 100644 --- a/src/core/ext/transport/binder/transport/binder_stream.h +++ b/src/core/ext/transport/binder/transport/binder_stream.h @@ -108,6 +108,9 @@ struct grpc_binder_stream { bool* call_failed_before_recv_message = nullptr; grpc_metadata_batch* recv_trailing_metadata; grpc_closure* recv_trailing_metadata_finished = nullptr; + + bool trailing_metadata_sent = false; + bool need_to_call_trailing_metadata_callback = false; }; #endif // GRPC_CORE_EXT_TRANSPORT_BINDER_TRANSPORT_BINDER_STREAM_H diff --git a/src/core/ext/transport/binder/transport/binder_transport.cc b/src/core/ext/transport/binder/transport/binder_transport.cc index 75cad8c968..3327f42e16 100644 --- a/src/core/ext/transport/binder/transport/binder_transport.cc +++ b/src/core/ext/transport/binder/transport/binder_transport.cc @@ -157,6 +157,7 @@ static void cancel_stream_locked(grpc_binder_transport* gbt, grpc_core::ExecCtx::Run(DEBUG_LOCATION, gbs->recv_message_ready, GRPC_ERROR_REF(error)); gbs->recv_message_ready = nullptr; + gbs->recv_message->reset(); gbs->recv_message = nullptr; gbs->call_failed_before_recv_message = nullptr; } @@ -173,11 +174,13 @@ static void cancel_stream_locked(grpc_binder_transport* gbt, static void recv_initial_metadata_locked(void* arg, grpc_error_handle /*error*/) { - gpr_log(GPR_INFO, "recv_initial_metadata_locked"); RecvInitialMetadataArgs* args = static_cast<RecvInitialMetadataArgs*>(arg); - grpc_binder_stream* gbs = args->gbs; + gpr_log(GPR_INFO, + "recv_initial_metadata_locked is_client = %d is_closed = %d", + gbs->is_client, gbs->is_closed); + if (!gbs->is_closed) { grpc_error_handle error = [&] { GPR_ASSERT(gbs->recv_initial_metadata); @@ -200,11 +203,12 @@ static void recv_initial_metadata_locked(void* arg, } static void recv_message_locked(void* arg, grpc_error_handle /*error*/) { - gpr_log(GPR_INFO, "recv_message_locked"); RecvMessageArgs* args = static_cast<RecvMessageArgs*>(arg); - grpc_binder_stream* gbs = args->gbs; + gpr_log(GPR_INFO, "recv_message_locked is_client = %d is_closed = %d", + gbs->is_client, gbs->is_closed); + if (!gbs->is_closed) { grpc_error_handle error = [&] { GPR_ASSERT(gbs->recv_message); @@ -246,11 +250,13 @@ static void recv_message_locked(void* arg, grpc_error_handle /*error*/) { static void recv_trailing_metadata_locked(void* arg, grpc_error_handle /*error*/) { - gpr_log(GPR_INFO, "recv_trailing_metadata_locked"); RecvTrailingMetadataArgs* args = static_cast<RecvTrailingMetadataArgs*>(arg); - grpc_binder_stream* gbs = args->gbs; + gpr_log(GPR_INFO, + "recv_trailing_metadata_locked is_client = %d is_closed = %d", + gbs->is_client, gbs->is_closed); + if (!gbs->is_closed) { grpc_error_handle error = [&] { GPR_ASSERT(gbs->recv_trailing_metadata); @@ -284,10 +290,20 @@ static void recv_trailing_metadata_locked(void* arg, return GRPC_ERROR_NONE; }(); - grpc_closure* cb = gbs->recv_trailing_metadata_finished; - gbs->recv_trailing_metadata_finished = nullptr; - gbs->recv_trailing_metadata = nullptr; - grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, error); + if (gbs->is_client || gbs->trailing_metadata_sent) { + grpc_closure* cb = gbs->recv_trailing_metadata_finished; + gbs->recv_trailing_metadata_finished = nullptr; + gbs->recv_trailing_metadata = nullptr; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, error); + } else { + // According to transport explaineer - "Server extra: This op shouldn't + // actually be considered complete until the server has also sent trailing + // metadata to provide the other side with final status" + // + // We haven't sent trailing metadata yet, so we have to delay completing + // the recv_trailing_metadata callback. + gbs->need_to_call_trailing_metadata_callback = true; + } } GRPC_BINDER_STREAM_UNREF(gbs, "recv_trailing_metadata"); } @@ -304,23 +320,29 @@ static void perform_stream_op_locked(void* stream_op, GPR_ASSERT(!op->send_initial_metadata && !op->send_message && !op->send_trailing_metadata && !op->recv_initial_metadata && !op->recv_message && !op->recv_trailing_metadata); - gpr_log(GPR_INFO, "cancel_stream"); - // Send trailing metadata to inform the other end about the cancellation, - // regardless if we'd already done that or not. - grpc_binder::Transaction cancel_tx(gbs->GetTxCode(), gbs->GetThenIncSeq(), - gbt->is_client); - cancel_tx.SetSuffix(grpc_binder::Metadata{}); - absl::Status status = gbt->wire_writer->RpcCall(cancel_tx); + gpr_log(GPR_INFO, "cancel_stream is_client = %d", gbs->is_client); + if (!gbs->is_client) { + // Send trailing metadata to inform the other end about the cancellation, + // regardless if we'd already done that or not. + grpc_binder::Transaction cancel_tx(gbs->GetTxCode(), gbs->GetThenIncSeq(), + gbt->is_client); + cancel_tx.SetSuffix(grpc_binder::Metadata{}); + cancel_tx.SetStatus(1); + absl::Status status = gbt->wire_writer->RpcCall(cancel_tx); + } cancel_stream_locked(gbt, gbs, op->payload->cancel_stream.cancel_error); if (op->on_complete != nullptr) { - grpc_core::ExecCtx::Run(DEBUG_LOCATION, op->on_complete, - absl_status_to_grpc_error(status)); + grpc_core::ExecCtx::Run(DEBUG_LOCATION, op->on_complete, GRPC_ERROR_NONE); } GRPC_BINDER_STREAM_UNREF(gbs, "perform_stream_op"); return; } if (gbs->is_closed) { + if (op->send_message) { + // Reset the send_message payload to prevent memory leaks. + op->payload->send_message.send_message.reset(); + } if (op->recv_initial_metadata) { grpc_core::ExecCtx::Run( DEBUG_LOCATION, @@ -520,6 +542,21 @@ static void perform_stream_op_locked(void* stream_op, absl::Status status = absl::OkStatus(); if (tx) { status = gbt->wire_writer->RpcCall(*tx); + if (!gbs->is_client && op->send_trailing_metadata) { + gbs->trailing_metadata_sent = true; + // According to transport explaineer - "Server extra: This op shouldn't + // actually be considered complete until the server has also sent trailing + // metadata to provide the other side with final status" + // + // Because we've done sending trailing metadata here, we can safely + // complete the recv_trailing_metadata callback here. + if (gbs->need_to_call_trailing_metadata_callback) { + grpc_closure* cb = gbs->recv_trailing_metadata_finished; + gbs->recv_trailing_metadata_finished = nullptr; + grpc_core::ExecCtx::Run(DEBUG_LOCATION, cb, GRPC_ERROR_NONE); + gbs->need_to_call_trailing_metadata_callback = false; + } + } } // Note that this should only be scheduled when all non-recv ops are // completed @@ -534,9 +571,10 @@ static void perform_stream_op_locked(void* stream_op, static void perform_stream_op(grpc_transport* gt, grpc_stream* gs, grpc_transport_stream_op_batch* op) { GPR_TIMER_SCOPE("perform_stream_op", 0); - gpr_log(GPR_INFO, "%s = %p %p %p", __func__, gt, gs, op); grpc_binder_transport* gbt = reinterpret_cast<grpc_binder_transport*>(gt); grpc_binder_stream* gbs = reinterpret_cast<grpc_binder_stream*>(gs); + gpr_log(GPR_INFO, "%s = %p %p %p is_client = %d", __func__, gt, gs, op, + gbs->is_client); GRPC_BINDER_STREAM_REF(gbs, "perform_stream_op"); op->handler_private.extra_arg = gbs; gbt->combiner->Run(GRPC_CLOSURE_INIT(&op->handler_private.closure, diff --git a/src/core/ext/transport/binder/utils/transport_stream_receiver.h b/src/core/ext/transport/binder/utils/transport_stream_receiver.h index fa1d4774ff..1b306b1378 100644 --- a/src/core/ext/transport/binder/utils/transport_stream_receiver.h +++ b/src/core/ext/transport/binder/utils/transport_stream_receiver.h @@ -60,15 +60,6 @@ class TransportStreamReceiver { virtual void NotifyRecvTrailingMetadata( StreamIdentifier id, absl::StatusOr<Metadata> trailing_metadata, int status) = 0; - - // Trailing metadata marks the end of one-side of the stream. Thus, after - // receiving trailing metadata from the other-end, we know that there will - // never be in-coming message data anymore, and all recv_message callbacks - // registered will never be satisfied. This function cancels all such - // callbacks gracefully (with GRPC_ERROR_NONE) to avoid being blocked waiting - // for them. - virtual void CancelRecvMessageCallbacksDueToTrailingMetadata( - StreamIdentifier id) = 0; // Remove all entries associated with stream number `id`. virtual void CancelStream(StreamIdentifier id) = 0; diff --git a/src/core/ext/transport/binder/utils/transport_stream_receiver_impl.cc b/src/core/ext/transport/binder/utils/transport_stream_receiver_impl.cc index 36241423e6..70652db7a3 100644 --- a/src/core/ext/transport/binder/utils/transport_stream_receiver_impl.cc +++ b/src/core/ext/transport/binder/utils/transport_stream_receiver_impl.cc @@ -37,7 +37,11 @@ void TransportStreamReceiverImpl::RegisterRecvInitialMetadata( grpc_core::MutexLock l(&m_); auto iter = pending_initial_metadata_.find(id); if (iter == pending_initial_metadata_.end()) { - initial_metadata_cbs_[id] = std::move(cb); + if (trailing_metadata_recvd_.count(id)) { + cb(absl::CancelledError("")); + } else { + initial_metadata_cbs_[id] = std::move(cb); + } cb = nullptr; } else { initial_metadata = std::move(iter->second.front()); @@ -63,7 +67,7 @@ void TransportStreamReceiverImpl::RegisterRecvMessage( if (iter == pending_message_.end()) { // If we'd already received trailing-metadata and there's no pending // messages, cancel the callback. - if (recv_message_cancelled_.count(id)) { + if (trailing_metadata_recvd_.count(id)) { cb(absl::CancelledError( TransportStreamReceiver::kGrpcBinderTransportCancelledGracefully)); } else { @@ -157,7 +161,7 @@ void TransportStreamReceiverImpl::NotifyRecvTrailingMetadata( // parsed after message data, we can safely cancel all upcoming callbacks of // recv_message. gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_); - CancelRecvMessageCallbacksDueToTrailingMetadata(id); + OnRecvTrailingMetadata(id); TrailingMetadataCallbackType cb; { grpc_core::MutexLock l(&m_); @@ -174,51 +178,73 @@ void TransportStreamReceiverImpl::NotifyRecvTrailingMetadata( cb(std::move(trailing_metadata), status); } -void TransportStreamReceiverImpl:: - CancelRecvMessageCallbacksDueToTrailingMetadata(StreamIdentifier id) { - gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_); - MessageDataCallbackType cb = nullptr; +void TransportStreamReceiverImpl::CancelInitialMetadataCallback( + StreamIdentifier id, absl::Status error) { + InitialMetadataCallbackType callback = nullptr; { grpc_core::MutexLock l(&m_); - auto iter = message_cbs_.find(id); - if (iter != message_cbs_.end()) { - cb = std::move(iter->second); - message_cbs_.erase(iter); - } - recv_message_cancelled_.insert(id); - } - if (cb != nullptr) { - // The registered callback will never be satisfied. Cancel it. - cb(absl::CancelledError( - TransportStreamReceiver::kGrpcBinderTransportCancelledGracefully)); - } -} - -void TransportStreamReceiverImpl::CancelStream(StreamIdentifier id) { - gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_); - grpc_core::MutexLock l(&m_); - { auto iter = initial_metadata_cbs_.find(id); if (iter != initial_metadata_cbs_.end()) { - iter->second(absl::CancelledError("Stream cancelled")); + callback = std::move(iter->second); initial_metadata_cbs_.erase(iter); } } + if (callback != nullptr) { + std::move(callback)(error); + } +} + +void TransportStreamReceiverImpl::CancelMessageCallback(StreamIdentifier id, + absl::Status error) { + MessageDataCallbackType callback = nullptr; { + grpc_core::MutexLock l(&m_); auto iter = message_cbs_.find(id); if (iter != message_cbs_.end()) { - iter->second(absl::CancelledError("Stream cancelled")); + callback = std::move(iter->second); message_cbs_.erase(iter); } } + if (callback != nullptr) { + std::move(callback)(error); + } +} + +void TransportStreamReceiverImpl::CancelTrailingMetadataCallback( + StreamIdentifier id, absl::Status error) { + TrailingMetadataCallbackType callback = nullptr; { + grpc_core::MutexLock l(&m_); auto iter = trailing_metadata_cbs_.find(id); if (iter != trailing_metadata_cbs_.end()) { - iter->second(absl::CancelledError("Stream cancelled"), 0); + callback = std::move(iter->second); trailing_metadata_cbs_.erase(iter); } } - recv_message_cancelled_.erase(id); + if (callback != nullptr) { + std::move(callback)(error, 0); + } +} + +void TransportStreamReceiverImpl::OnRecvTrailingMetadata(StreamIdentifier id) { + gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_); + m_.Lock(); + trailing_metadata_recvd_.insert(id); + m_.Unlock(); + CancelInitialMetadataCallback(id, absl::CancelledError("")); + CancelMessageCallback( + id, + absl::CancelledError( + TransportStreamReceiver::kGrpcBinderTransportCancelledGracefully)); +} + +void TransportStreamReceiverImpl::CancelStream(StreamIdentifier id) { + gpr_log(GPR_INFO, "%s id = %d is_client = %d", __func__, id, is_client_); + CancelInitialMetadataCallback(id, absl::CancelledError("Stream cancelled")); + CancelMessageCallback(id, absl::CancelledError("Stream cancelled")); + CancelTrailingMetadataCallback(id, absl::CancelledError("Stream cancelled")); + grpc_core::MutexLock l(&m_); + trailing_metadata_recvd_.erase(id); pending_initial_metadata_.erase(id); pending_message_.erase(id); pending_trailing_metadata_.erase(id); diff --git a/src/core/ext/transport/binder/utils/transport_stream_receiver_impl.h b/src/core/ext/transport/binder/utils/transport_stream_receiver_impl.h index 154a0db04d..2cf0355f36 100644 --- a/src/core/ext/transport/binder/utils/transport_stream_receiver_impl.h +++ b/src/core/ext/transport/binder/utils/transport_stream_receiver_impl.h @@ -50,11 +50,21 @@ class TransportStreamReceiverImpl : public TransportStreamReceiver { absl::StatusOr<Metadata> trailing_metadata, int status) override; - void CancelRecvMessageCallbacksDueToTrailingMetadata( - StreamIdentifier id) override; void CancelStream(StreamIdentifier id) override; private: + // Trailing metadata marks the end of one-side of the stream. Thus, after + // receiving trailing metadata from the other-end, we know that there will + // never be in-coming message data anymore, and all recv_message callbacks + // (as well as recv_initial_metadata callback, if there's any) registered will + // never be satisfied. This function cancels all such callbacks gracefully + // (with GRPC_ERROR_NONE) to avoid being blocked waiting for them. + void OnRecvTrailingMetadata(StreamIdentifier id); + + void CancelInitialMetadataCallback(StreamIdentifier id, absl::Status error); + void CancelMessageCallback(StreamIdentifier id, absl::Status error); + void CancelTrailingMetadataCallback(StreamIdentifier id, absl::Status error); + std::map<StreamIdentifier, InitialMetadataCallbackType> initial_metadata_cbs_; std::map<StreamIdentifier, MessageDataCallbackType> message_cbs_; std::map<StreamIdentifier, TrailingMetadataCallbackType> @@ -90,7 +100,7 @@ class TransportStreamReceiverImpl : public TransportStreamReceiver { // when RegisterRecvMessage() gets called, we should check whether // recv_message_cancelled_ contains the corresponding stream ID, and if so, // directly cancel the callback gracefully without pending it. - std::set<StreamIdentifier> recv_message_cancelled_ ABSL_GUARDED_BY(m_); + std::set<StreamIdentifier> trailing_metadata_recvd_ ABSL_GUARDED_BY(m_); bool is_client_; // Called when receiving initial metadata to inform the server about a new diff --git a/test/core/transport/binder/mock_objects.h b/test/core/transport/binder/mock_objects.h index 5a584c8eaa..d749257954 100644 --- a/test/core/transport/binder/mock_objects.h +++ b/test/core/transport/binder/mock_objects.h @@ -105,8 +105,6 @@ class MockTransportStreamReceiver : public TransportStreamReceiver { (StreamIdentifier, absl::StatusOr<std::string>), (override)); MOCK_METHOD(void, NotifyRecvTrailingMetadata, (StreamIdentifier, absl::StatusOr<Metadata>, int), (override)); - MOCK_METHOD(void, CancelRecvMessageCallbacksDueToTrailingMetadata, - (StreamIdentifier), (override)); MOCK_METHOD(void, CancelStream, (StreamIdentifier), (override)); }; |