1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
|
/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//! Provides a backing task to implement a Dispatcher
use crate::boot_time::{BootTime, Duration};
use anyhow::{bail, Result};
use log::{debug, trace, warn};
use std::collections::HashMap;
use tokio::sync::{mpsc, oneshot};
use super::{Command, QueryError, Response};
use crate::network::{Network, ServerInfo, SocketTagger, ValidationReporter};
use crate::{config, network};
/// State owned by the dispatcher's backing task.
///
/// The driver consumes [`Command`]s from `command_rx` and keeps one [`Network`] per `net_id`.
/// Networks are created on demand when a probe arrives for an unknown `net_id`.
pub struct Driver {
    /// Receiving end of the command channel. The driver stops once this yields `None`.
    command_rx: mpsc::Receiver<Command>,
    /// Networks keyed by `net_id`. Entries are inserted by probes and removed on
    /// `Command::Clear` or when a probe arrives with mismatched server info.
    networks: HashMap<u32, Network>,
    /// Reporter cloned into every `Network` this driver constructs.
    validation: ValidationReporter,
    /// Socket tagger cloned into every `Network` this driver constructs.
    tagger: SocketTagger,
    /// Configs looked up by certificate path. Garbage-collected after a network is cleared.
    config_cache: config::Cache,
}
/// Swallows a per-command failure, logging it at debug level.
///
/// A failed probe or query must not tear down the dispatcher loop, so errors are only
/// recorded in the log and never propagated.
fn debug_err(r: Result<()>) {
    match r {
        Ok(()) => {}
        Err(failure) => debug!("Dispatcher loop got {:?}", failure),
    }
}
impl Driver {
pub fn new(
command_rx: mpsc::Receiver<Command>,
validation: ValidationReporter,
tagger: SocketTagger,
) -> Self {
Self {
command_rx,
networks: HashMap::new(),
validation,
tagger,
config_cache: config::Cache::new(),
}
}
pub async fn drive(mut self) -> Result<()> {
loop {
self.drive_once().await?
}
}
async fn drive_once(&mut self) -> Result<()> {
if let Some(command) = self.command_rx.recv().await {
trace!("dispatch command: {:?}", command);
match command {
Command::Probe { info, timeout } => debug_err(self.probe(info, timeout).await),
Command::Query { net_id, base64_query, expired_time, resp } => {
debug_err(self.query(net_id, base64_query, expired_time, resp).await)
}
Command::Clear { net_id } => {
self.networks.remove(&net_id);
self.config_cache.garbage_collect();
}
Command::Exit => {
bail!("Death due to Exit")
}
}
Ok(())
} else {
bail!("Death due to command_tx dying")
}
}
async fn query(
&mut self,
net_id: u32,
query: String,
expiry: BootTime,
response: oneshot::Sender<Response>,
) -> Result<()> {
if let Some(network) = self.networks.get_mut(&net_id) {
network.query(network::Query { query, response, expiry }).await?;
} else {
warn!("Tried to send a query to non-existent network net_id={}", net_id);
response.send(Response::Error { error: QueryError::Unexpected }).unwrap_or_else(|_| {
warn!("Unable to send reply for non-existent network net_id={}", net_id);
})
}
Ok(())
}
async fn probe(&mut self, info: ServerInfo, timeout: Duration) -> Result<()> {
use std::collections::hash_map::Entry;
if !self.networks.get(&info.net_id).map_or(true, |net| net.get_info() == &info) {
// If we have a network registered to the provided net_id, but the server info doesn't
// match, our API has been used incorrectly. Attempt to recover by deleting the old
// network and recreating it according to the probe request.
warn!("Probing net_id={} with mismatched server info", info.net_id);
self.networks.remove(&info.net_id);
}
// Can't use or_insert_with because creating a network may fail
let net = match self.networks.entry(info.net_id) {
Entry::Occupied(network) => network.into_mut(),
Entry::Vacant(vacant) => {
let config = self.config_cache.from_cert_path(&info.cert_path)?;
vacant.insert(
Network::new(info, config, self.validation.clone(), self.tagger.clone())
.await?,
)
}
};
net.probe(timeout).await?;
Ok(())
}
}
|