aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/call/server.rs20
-rw-r--r--src/channel.rs2
-rw-r--r--src/lib.rs4
-rw-r--r--src/server.rs39
4 files changed, 61 insertions, 4 deletions
diff --git a/src/call/server.rs b/src/call/server.rs
index add9874..875555e 100644
--- a/src/call/server.rs
+++ b/src/call/server.rs
@@ -25,8 +25,10 @@ use crate::codec::{DeserializeFn, SerializeFn};
use crate::cq::CompletionQueue;
use crate::error::{Error, Result};
use crate::metadata::Metadata;
+use crate::server::ServerChecker;
use crate::server::{BoxHandler, RequestCallContext};
use crate::task::{BatchFuture, CallTag, Executor, Kicker};
+use crate::CheckResult;
pub struct Deadline {
spec: gpr_timespec,
@@ -74,12 +76,13 @@ impl RequestContext {
cq: &CompletionQueue,
rc: &mut RequestCallContext,
) -> result::Result<(), Self> {
+ let checker = rc.get_checker();
let handler = unsafe { rc.get_handler(self.method()) };
match handler {
Some(handler) => match handler.method_type() {
MethodType::Unary | MethodType::ServerStreaming => Err(self),
_ => {
- execute(self, cq, None, handler);
+ execute(self, cq, None, handler, checker);
Ok(())
}
},
@@ -225,9 +228,10 @@ impl UnaryRequestContext {
cq: &CompletionQueue,
reader: Option<MessageReader>,
) {
+ let checker = rc.get_checker();
let handler = unsafe { rc.get_handler(self.request.method()).unwrap() };
if reader.is_some() {
- return execute(self.request, cq, reader, handler);
+ return execute(self.request, cq, reader, handler, checker);
}
let status = RpcStatus::new(RpcStatusCode::INTERNAL, Some("No payload".to_owned()));
@@ -775,7 +779,19 @@ fn execute(
cq: &CompletionQueue,
payload: Option<MessageReader>,
f: &mut BoxHandler,
+ mut checkers: Vec<Box<dyn ServerChecker>>,
) {
let rpc_ctx = RpcContext::new(ctx, cq);
+
+ for handler in checkers.iter_mut() {
+ match handler.check(&rpc_ctx) {
+ CheckResult::Continue => {}
+ CheckResult::Abort(status) => {
+ rpc_ctx.call().abort(&status);
+ return;
+ }
+ }
+ }
+
f.handle(rpc_ctx, payload)
}
diff --git a/src/channel.rs b/src/channel.rs
index bdf95ce..a33a4be 100644
--- a/src/channel.rs
+++ b/src/channel.rs
@@ -28,7 +28,7 @@ pub use crate::grpc_sys::{
/// Ref: http://www.grpc.io/docs/guides/wire.html#user-agents
fn format_user_agent_string(agent: &str) -> CString {
- let version = "0.7.0";
+ let version = "0.7.1";
let trimed_agent = agent.trim();
let val = if trimed_agent.is_empty() {
format!("grpc-rust/{}", version)
diff --git a/src/lib.rs b/src/lib.rs
index 2bac988..2bb0c11 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -77,4 +77,6 @@ pub use crate::security::{
CertificateRequestType, ChannelCredentials, ChannelCredentialsBuilder, ServerCredentials,
ServerCredentialsBuilder, ServerCredentialsFetcher,
};
-pub use crate::server::{Server, ServerBuilder, Service, ServiceBuilder, ShutdownFuture};
+pub use crate::server::{
+ CheckResult, Server, ServerBuilder, ServerChecker, Service, ServiceBuilder, ShutdownFuture,
+};
diff --git a/src/server.rs b/src/server.rs
index 8cb6a87..0f01690 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -21,6 +21,7 @@ use crate::env::Environment;
use crate::error::{Error, Result};
use crate::task::{CallTag, CqFuture};
use crate::RpcContext;
+use crate::RpcStatus;
const DEFAULT_REQUEST_SLOTS_PER_CQ: usize = 1024;
@@ -266,6 +267,24 @@ impl ServiceBuilder {
}
}
+/// Used to indicate the result of the check. If it returns `Abort`,
+/// skip the subsequent checkers and abort the grpc call.
+pub enum CheckResult {
+ Continue,
+ Abort(RpcStatus),
+}
+
+pub trait ServerChecker: Send {
+ fn check(&mut self, ctx: &RpcContext) -> CheckResult;
+ fn box_clone(&self) -> Box<dyn ServerChecker>;
+}
+
+impl Clone for Box<dyn ServerChecker> {
+ fn clone(&self) -> Self {
+ self.box_clone()
+ }
+}
+
/// A gRPC service.
///
/// Use [`ServiceBuilder`] to build a [`Service`].
@@ -280,6 +299,7 @@ pub struct ServerBuilder {
args: Option<ChannelArgs>,
slots_per_cq: usize,
handlers: HashMap<&'static [u8], BoxHandler>,
+ checkers: Vec<Box<dyn ServerChecker>>,
}
impl ServerBuilder {
@@ -291,6 +311,7 @@ impl ServerBuilder {
args: None,
slots_per_cq: DEFAULT_REQUEST_SLOTS_PER_CQ,
handlers: HashMap::new(),
+ checkers: Vec::new(),
}
}
@@ -320,6 +341,16 @@ impl ServerBuilder {
self
}
+ /// Add a custom checker to handle some tasks before the grpc call handler starts.
+ /// This allows users to operate grpc call based on the context. Users can add
+ /// multiple checkers and they will be executed in the order added.
+ ///
+ /// TODO: Extend this interface to intercepte each payload like grpc-c++.
+ pub fn add_checker<C: ServerChecker + 'static>(mut self, checker: C) -> ServerBuilder {
+ self.checkers.push(Box::new(checker));
+ self
+ }
+
/// Finalize the [`ServerBuilder`] and build the [`Server`].
pub fn build(mut self) -> Result<Server> {
let args = self
@@ -355,6 +386,7 @@ impl ServerBuilder {
slots_per_cq: self.slots_per_cq,
}),
handlers: self.handlers,
+ checkers: self.checkers,
})
}
}
@@ -439,6 +471,7 @@ pub type BoxHandler = Box<dyn CloneableHandler>;
pub struct RequestCallContext {
server: Arc<ServerCore>,
registry: Arc<UnsafeCell<HashMap<&'static [u8], BoxHandler>>>,
+ checkers: Vec<Box<dyn ServerChecker>>,
}
impl RequestCallContext {
@@ -449,6 +482,10 @@ impl RequestCallContext {
let registry = &mut *self.registry.get();
registry.get_mut(path)
}
+
+ pub(crate) fn get_checker(&self) -> Vec<Box<dyn ServerChecker>> {
+ self.checkers.clone()
+ }
}
// Apparently, its life time is guaranteed by the ref count, hence is safe to be sent
@@ -506,6 +543,7 @@ pub struct Server {
env: Arc<Environment>,
core: Arc<ServerCore>,
handlers: HashMap<&'static [u8], BoxHandler>,
+ checkers: Vec<Box<dyn ServerChecker>>,
}
impl Server {
@@ -549,6 +587,7 @@ impl Server {
let rc = RequestCallContext {
server: self.core.clone(),
registry: Arc::new(UnsafeCell::new(registry)),
+ checkers: self.checkers.clone(),
};
for _ in 0..self.core.slots_per_cq {
request_call(rc.clone(), cq);