diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/call/server.rs | 20 | ||||
-rw-r--r-- | src/channel.rs | 2 | ||||
-rw-r--r-- | src/lib.rs | 4 | ||||
-rw-r--r-- | src/server.rs | 39 |
4 files changed, 61 insertions, 4 deletions
diff --git a/src/call/server.rs b/src/call/server.rs index add9874..875555e 100644 --- a/src/call/server.rs +++ b/src/call/server.rs @@ -25,8 +25,10 @@ use crate::codec::{DeserializeFn, SerializeFn}; use crate::cq::CompletionQueue; use crate::error::{Error, Result}; use crate::metadata::Metadata; +use crate::server::ServerChecker; use crate::server::{BoxHandler, RequestCallContext}; use crate::task::{BatchFuture, CallTag, Executor, Kicker}; +use crate::CheckResult; pub struct Deadline { spec: gpr_timespec, @@ -74,12 +76,13 @@ impl RequestContext { cq: &CompletionQueue, rc: &mut RequestCallContext, ) -> result::Result<(), Self> { + let checker = rc.get_checker(); let handler = unsafe { rc.get_handler(self.method()) }; match handler { Some(handler) => match handler.method_type() { MethodType::Unary | MethodType::ServerStreaming => Err(self), _ => { - execute(self, cq, None, handler); + execute(self, cq, None, handler, checker); Ok(()) } }, @@ -225,9 +228,10 @@ impl UnaryRequestContext { cq: &CompletionQueue, reader: Option<MessageReader>, ) { + let checker = rc.get_checker(); let handler = unsafe { rc.get_handler(self.request.method()).unwrap() }; if reader.is_some() { - return execute(self.request, cq, reader, handler); + return execute(self.request, cq, reader, handler, checker); } let status = RpcStatus::new(RpcStatusCode::INTERNAL, Some("No payload".to_owned())); @@ -775,7 +779,19 @@ fn execute( cq: &CompletionQueue, payload: Option<MessageReader>, f: &mut BoxHandler, + mut checkers: Vec<Box<dyn ServerChecker>>, ) { let rpc_ctx = RpcContext::new(ctx, cq); + + for handler in checkers.iter_mut() { + match handler.check(&rpc_ctx) { + CheckResult::Continue => {} + CheckResult::Abort(status) => { + rpc_ctx.call().abort(&status); + return; + } + } + } + f.handle(rpc_ctx, payload) } diff --git a/src/channel.rs b/src/channel.rs index bdf95ce..a33a4be 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -28,7 +28,7 @@ pub use crate::grpc_sys::{ /// Ref: http://www.grpc.io/docs/guides/wire.html#user-agents fn format_user_agent_string(agent: &str) -> CString { - let version = "0.7.0"; + let version = "0.7.1"; let trimed_agent = agent.trim(); let val = if trimed_agent.is_empty() { format!("grpc-rust/{}", version) @@ -77,4 +77,6 @@ pub use crate::security::{ CertificateRequestType, ChannelCredentials, ChannelCredentialsBuilder, ServerCredentials, ServerCredentialsBuilder, ServerCredentialsFetcher, }; -pub use crate::server::{Server, ServerBuilder, Service, ServiceBuilder, ShutdownFuture}; +pub use crate::server::{ + CheckResult, Server, ServerBuilder, ServerChecker, Service, ServiceBuilder, ShutdownFuture, +}; diff --git a/src/server.rs b/src/server.rs index 8cb6a87..0f01690 100644 --- a/src/server.rs +++ b/src/server.rs @@ -21,6 +21,7 @@ use crate::env::Environment; use crate::error::{Error, Result}; use crate::task::{CallTag, CqFuture}; use crate::RpcContext; +use crate::RpcStatus; const DEFAULT_REQUEST_SLOTS_PER_CQ: usize = 1024; @@ -266,6 +267,24 @@ impl ServiceBuilder { } } +/// Used to indicate the result of the check. If it returns `Abort`, +/// skip the subsequent checkers and abort the grpc call. +pub enum CheckResult { + Continue, + Abort(RpcStatus), +} + +pub trait ServerChecker: Send { + fn check(&mut self, ctx: &RpcContext) -> CheckResult; + fn box_clone(&self) -> Box<dyn ServerChecker>; +} + +impl Clone for Box<dyn ServerChecker> { + fn clone(&self) -> Self { + self.box_clone() + } +} + /// A gRPC service. /// /// Use [`ServiceBuilder`] to build a [`Service`]. @@ -280,6 +299,7 @@ pub struct ServerBuilder { args: Option<ChannelArgs>, slots_per_cq: usize, handlers: HashMap<&'static [u8], BoxHandler>, + checkers: Vec<Box<dyn ServerChecker>>, } impl ServerBuilder { @@ -291,6 +311,7 @@ impl ServerBuilder { args: None, slots_per_cq: DEFAULT_REQUEST_SLOTS_PER_CQ, handlers: HashMap::new(), + checkers: Vec::new(), } } @@ -320,6 +341,16 @@ impl ServerBuilder { self } + /// Add a custom checker to handle some tasks before the grpc call handler starts. + /// This allows users to operate grpc call based on the context. Users can add + /// multiple checkers and they will be executed in the order added. + /// + /// TODO: Extend this interface to intercepte each payload like grpc-c++. + pub fn add_checker<C: ServerChecker + 'static>(mut self, checker: C) -> ServerBuilder { + self.checkers.push(Box::new(checker)); + self + } + /// Finalize the [`ServerBuilder`] and build the [`Server`]. pub fn build(mut self) -> Result<Server> { let args = self @@ -355,6 +386,7 @@ impl ServerBuilder { slots_per_cq: self.slots_per_cq, }), handlers: self.handlers, + checkers: self.checkers, }) } } @@ -439,6 +471,7 @@ pub type BoxHandler = Box<dyn CloneableHandler>; pub struct RequestCallContext { server: Arc<ServerCore>, registry: Arc<UnsafeCell<HashMap<&'static [u8], BoxHandler>>>, + checkers: Vec<Box<dyn ServerChecker>>, } impl RequestCallContext { @@ -449,6 +482,10 @@ impl RequestCallContext { let registry = &mut *self.registry.get(); registry.get_mut(path) } + + pub(crate) fn get_checker(&self) -> Vec<Box<dyn ServerChecker>> { + self.checkers.clone() + } } // Apparently, its life time is guaranteed by the ref count, hence is safe to be sent @@ -506,6 +543,7 @@ pub struct Server { env: Arc<Environment>, core: Arc<ServerCore>, handlers: HashMap<&'static [u8], BoxHandler>, + checkers: Vec<Box<dyn ServerChecker>>, } impl Server { @@ -549,6 +587,7 @@ impl Server { let rc = RequestCallContext { server: self.core.clone(), registry: Arc::new(UnsafeCell::new(registry)), + checkers: self.checkers.clone(), }; for _ in 0..self.core.slots_per_cq { request_call(rc.clone(), cq); |