summaryrefslogtreecommitdiff
path: root/keystore2/src/rkpd_client.rs
blob: 931782477a365151f8732800a8bcd6f4043e241f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
// Copyright 2022, The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Helper wrapper around RKPD interface.

use crate::error::{map_binder_status_code, Error, ResponseCode};
use crate::watchdog_helper::watchdog as wd;
use android_security_rkp_aidl::aidl::android::security::rkp::{
    IGetKeyCallback::BnGetKeyCallback, IGetKeyCallback::ErrorCode::ErrorCode as GetKeyErrorCode,
    IGetKeyCallback::IGetKeyCallback, IGetRegistrationCallback::BnGetRegistrationCallback,
    IGetRegistrationCallback::IGetRegistrationCallback, IRegistration::IRegistration,
    IRemoteProvisioning::IRemoteProvisioning,
    IStoreUpgradedKeyCallback::BnStoreUpgradedKeyCallback,
    IStoreUpgradedKeyCallback::IStoreUpgradedKeyCallback,
    RemotelyProvisionedKey::RemotelyProvisionedKey,
};
use android_security_rkp_aidl::binder::{BinderFeatures, Interface, Strong};
use anyhow::{Context, Result};
use message_macro::source_location_msg;
use std::sync::Mutex;
use std::time::Duration;
use tokio::sync::oneshot;
use tokio::time::timeout;

/// Upper bound on how long keystore waits for any single RKPD callback.
///
/// Outgoing calls from keystore normally block indefinitely and rely on the watchdog to report
/// deadlocks. RKPD is different: it is mainline-updatable, and a request may have to wait on the
/// network for certificates. To stay on the safe side, every wait on an RKPD callback in this file
/// is bounded by this timeout instead.
static RKPD_TIMEOUT: Duration = Duration::from_millis(10 * 1000);

/// Builds a fresh single-threaded tokio runtime with all drivers enabled.
///
/// The time driver is required by `tokio::time::timeout`, which every RKPD wait goes through.
///
/// # Panics
///
/// Panics if the runtime cannot be built.
fn tokio_rt() -> tokio::runtime::Runtime {
    let mut builder = tokio::runtime::Builder::new_current_thread();
    builder.enable_all();
    builder.build().unwrap()
}

/// Thread-safe, single-use wrapper around a [`oneshot::Sender`].
///
/// Only the first call to [`SafeSender::send`] delivers a value. Every later call is a silent
/// no-op that simply drops its argument.
struct SafeSender<T> {
    inner: Mutex<Option<oneshot::Sender<T>>>,
}

impl<T> SafeSender<T> {
    /// Wraps `sender`, arming it for exactly one delivery.
    fn new(sender: oneshot::Sender<T>) -> Self {
        let inner = Mutex::new(Some(sender));
        Self { inner }
    }

    /// Sends `value` if nothing has been sent yet; otherwise `value` is dropped.
    fn send(&self, value: T) {
        // `take()` leaves `None` behind, so at most one caller ever obtains the sender.
        let first_sender = self.inner.lock().unwrap().take();
        if let Some(sender) = first_sender {
            // The receiver may already have timed out and been dropped, in which case the send
            // fails. Nothing useful can be done about that here, so the failure is only logged.
            let delivered = sender.send(value).is_ok();
            if !delivered {
                log::error!("SafeSender::send() failed");
            }
        }
    }
}

struct GetRegistrationCallback {
    registration_tx: SafeSender<Result<binder::Strong<dyn IRegistration>>>,
}

impl GetRegistrationCallback {
    pub fn new_native_binder(
        registration_tx: oneshot::Sender<Result<binder::Strong<dyn IRegistration>>>,
    ) -> Strong<dyn IGetRegistrationCallback> {
        let result: Self =
            GetRegistrationCallback { registration_tx: SafeSender::new(registration_tx) };
        BnGetRegistrationCallback::new_binder(result, BinderFeatures::default())
    }
}

impl Interface for GetRegistrationCallback {}

impl IGetRegistrationCallback for GetRegistrationCallback {
    fn onSuccess(&self, registration: &Strong<dyn IRegistration>) -> binder::Result<()> {
        let _wp = wd::watch_millis("IGetRegistrationCallback::onSuccess", 500);
        self.registration_tx.send(Ok(registration.clone()));
        Ok(())
    }
    fn onCancel(&self) -> binder::Result<()> {
        let _wp = wd::watch_millis("IGetRegistrationCallback::onCancel", 500);
        log::warn!("IGetRegistrationCallback cancelled");
        self.registration_tx.send(
            Err(Error::Rc(ResponseCode::OUT_OF_KEYS_TRANSIENT_ERROR))
                .context(source_location_msg!("GetRegistrationCallback cancelled.")),
        );
        Ok(())
    }
    fn onError(&self, description: &str) -> binder::Result<()> {
        let _wp = wd::watch_millis("IGetRegistrationCallback::onError", 500);
        log::error!("IGetRegistrationCallback failed: '{description}'");
        self.registration_tx
            .send(Err(Error::Rc(ResponseCode::OUT_OF_KEYS_TRANSIENT_ERROR)).context(
                source_location_msg!("GetRegistrationCallback failed: {:?}", description),
            ));
        Ok(())
    }
}

/// Make a new connection to a IRegistration service.
async fn get_rkpd_registration(rpc_name: &str) -> Result<binder::Strong<dyn IRegistration>> {
    let remote_provisioning: Strong<dyn IRemoteProvisioning> =
        map_binder_status_code(binder::get_interface("remote_provisioning"))
            .context(source_location_msg!("Trying to connect to IRemoteProvisioning service."))?;

    let (tx, rx) = oneshot::channel();
    let cb = GetRegistrationCallback::new_native_binder(tx);

    remote_provisioning
        .getRegistration(rpc_name, &cb)
        .context(source_location_msg!("Trying to get registration."))?;

    match timeout(RKPD_TIMEOUT, rx).await {
        Err(e) => Err(Error::Rc(ResponseCode::SYSTEM_ERROR))
            .context(source_location_msg!("Waiting for RKPD: {:?}", e)),
        Ok(v) => v.unwrap(),
    }
}

struct GetKeyCallback {
    key_tx: SafeSender<Result<RemotelyProvisionedKey>>,
}

impl GetKeyCallback {
    pub fn new_native_binder(
        key_tx: oneshot::Sender<Result<RemotelyProvisionedKey>>,
    ) -> Strong<dyn IGetKeyCallback> {
        let result: Self = GetKeyCallback { key_tx: SafeSender::new(key_tx) };
        BnGetKeyCallback::new_binder(result, BinderFeatures::default())
    }
}

impl Interface for GetKeyCallback {}

impl IGetKeyCallback for GetKeyCallback {
    fn onSuccess(&self, key: &RemotelyProvisionedKey) -> binder::Result<()> {
        let _wp = wd::watch_millis("IGetKeyCallback::onSuccess", 500);
        self.key_tx.send(Ok(RemotelyProvisionedKey {
            keyBlob: key.keyBlob.clone(),
            encodedCertChain: key.encodedCertChain.clone(),
        }));
        Ok(())
    }
    fn onCancel(&self) -> binder::Result<()> {
        let _wp = wd::watch_millis("IGetKeyCallback::onCancel", 500);
        log::warn!("IGetKeyCallback cancelled");
        self.key_tx.send(
            Err(Error::Rc(ResponseCode::OUT_OF_KEYS_TRANSIENT_ERROR))
                .context(source_location_msg!("GetKeyCallback cancelled.")),
        );
        Ok(())
    }
    fn onError(&self, error: GetKeyErrorCode, description: &str) -> binder::Result<()> {
        let _wp = wd::watch_millis("IGetKeyCallback::onError", 500);
        log::error!("IGetKeyCallback failed: {description}");
        let rc = match error {
            GetKeyErrorCode::ERROR_UNKNOWN => ResponseCode::OUT_OF_KEYS_TRANSIENT_ERROR,
            GetKeyErrorCode::ERROR_PERMANENT => ResponseCode::OUT_OF_KEYS_PERMANENT_ERROR,
            GetKeyErrorCode::ERROR_PENDING_INTERNET_CONNECTIVITY => {
                ResponseCode::OUT_OF_KEYS_PENDING_INTERNET_CONNECTIVITY
            }
            GetKeyErrorCode::ERROR_REQUIRES_SECURITY_PATCH => {
                ResponseCode::OUT_OF_KEYS_REQUIRES_SYSTEM_UPGRADE
            }
            _ => {
                log::error!("Unexpected error from rkpd: {error:?}");
                ResponseCode::OUT_OF_KEYS_TRANSIENT_ERROR
            }
        };
        self.key_tx.send(Err(Error::Rc(rc)).context(source_location_msg!(
            "GetKeyCallback failed: {:?} {:?}",
            error,
            description
        )));
        Ok(())
    }
}

async fn get_rkpd_attestation_key_from_registration_async(
    registration: &Strong<dyn IRegistration>,
    caller_uid: u32,
) -> Result<RemotelyProvisionedKey> {
    let (tx, rx) = oneshot::channel();
    let cb = GetKeyCallback::new_native_binder(tx);

    registration
        .getKey(caller_uid.try_into().unwrap(), &cb)
        .context(source_location_msg!("Trying to get key."))?;

    match timeout(RKPD_TIMEOUT, rx).await {
        Err(e) => {
            // Make a best effort attempt to cancel the timed out request.
            if let Err(e) = registration.cancelGetKey(&cb) {
                log::error!("IRegistration::cancelGetKey failed: {:?}", e);
            }
            Err(Error::Rc(ResponseCode::OUT_OF_KEYS_TRANSIENT_ERROR))
                .context(source_location_msg!("Waiting for RKPD key timed out: {:?}", e))
        }
        Ok(v) => v.unwrap(),
    }
}

/// Looks up the RKPD registration for `rpc_name`, then fetches an attestation key for
/// `caller_uid` from it.
async fn get_rkpd_attestation_key_async(
    rpc_name: &str,
    caller_uid: u32,
) -> Result<RemotelyProvisionedKey> {
    let pending_registration = get_rkpd_registration(rpc_name);
    let registration = pending_registration
        .await
        .context(source_location_msg!("Trying to get to IRegistration service."))?;

    get_rkpd_attestation_key_from_registration_async(&registration, caller_uid).await
}

struct StoreUpgradedKeyCallback {
    completer: SafeSender<Result<()>>,
}

impl StoreUpgradedKeyCallback {
    pub fn new_native_binder(
        completer: oneshot::Sender<Result<()>>,
    ) -> Strong<dyn IStoreUpgradedKeyCallback> {
        let result: Self = StoreUpgradedKeyCallback { completer: SafeSender::new(completer) };
        BnStoreUpgradedKeyCallback::new_binder(result, BinderFeatures::default())
    }
}

impl Interface for StoreUpgradedKeyCallback {}

impl IStoreUpgradedKeyCallback for StoreUpgradedKeyCallback {
    fn onSuccess(&self) -> binder::Result<()> {
        let _wp = wd::watch_millis("IGetRegistrationCallback::onSuccess", 500);
        self.completer.send(Ok(()));
        Ok(())
    }

    fn onError(&self, error: &str) -> binder::Result<()> {
        let _wp = wd::watch_millis("IGetRegistrationCallback::onError", 500);
        log::error!("IGetRegistrationCallback failed: {error}");
        self.completer.send(
            Err(Error::Rc(ResponseCode::SYSTEM_ERROR))
                .context(source_location_msg!("Failed to store upgraded key: {:?}", error)),
        );
        Ok(())
    }
}

async fn store_rkpd_attestation_key_with_registration_async(
    registration: &Strong<dyn IRegistration>,
    key_blob: &[u8],
    upgraded_blob: &[u8],
) -> Result<()> {
    let (tx, rx) = oneshot::channel();
    let cb = StoreUpgradedKeyCallback::new_native_binder(tx);

    registration
        .storeUpgradedKeyAsync(key_blob, upgraded_blob, &cb)
        .context(source_location_msg!("Failed to store upgraded blob with RKPD."))?;

    match timeout(RKPD_TIMEOUT, rx).await {
        Err(e) => Err(Error::Rc(ResponseCode::SYSTEM_ERROR))
            .context(source_location_msg!("Waiting for RKPD to complete storing key: {:?}", e)),
        Ok(v) => v.unwrap(),
    }
}

/// Looks up the RKPD registration for `rpc_name`, then uses it to store `upgraded_blob` as the
/// upgraded form of `key_blob`.
async fn store_rkpd_attestation_key_async(
    rpc_name: &str,
    key_blob: &[u8],
    upgraded_blob: &[u8],
) -> Result<()> {
    let pending_registration = get_rkpd_registration(rpc_name);
    let registration = pending_registration
        .await
        .context(source_location_msg!("Trying to get to IRegistration service."))?;

    store_rkpd_attestation_key_with_registration_async(&registration, key_blob, upgraded_blob).await
}

/// Fetches an attestation key for `caller_uid` from RKPD for the remotely provisioned component
/// `rpc_name`.
///
/// Blocks the calling thread on a private single-threaded runtime. Each wait on RKPD is bounded
/// by [`RKPD_TIMEOUT`].
// NOTE(review): the watchdog threshold (500 ms) is much shorter than RKPD_TIMEOUT (10 s), so a
// slow-but-successful request is presumably reported by the watchdog — confirm this is intended.
pub fn get_rkpd_attestation_key(rpc_name: &str, caller_uid: u32) -> Result<RemotelyProvisionedKey> {
    let _wp = wd::watch_millis("Calling get_rkpd_attestation_key()", 500);
    let runtime = tokio_rt();
    runtime.block_on(get_rkpd_attestation_key_async(rpc_name, caller_uid))
}

/// Stores `upgraded_blob` in RKPD as the upgraded form of `key_blob` for the remotely
/// provisioned component `rpc_name`.
///
/// Blocks the calling thread on a private single-threaded runtime. Each wait on RKPD is bounded
/// by [`RKPD_TIMEOUT`].
pub fn store_rkpd_attestation_key(
    rpc_name: &str,
    key_blob: &[u8],
    upgraded_blob: &[u8],
) -> Result<()> {
    let _wp = wd::watch_millis("Calling store_rkpd_attestation_key()", 500);
    let runtime = tokio_rt();
    runtime.block_on(store_rkpd_attestation_key_async(rpc_name, key_blob, upgraded_blob))
}

// Unit tests for the RKPD client.
//
// The first group drives the callback types and the *_with_registration_async helpers against an
// in-process `MockRegistration`. The tests from `test_get_rkpd_attestation_key` onwards call the
// public entry points, which connect to the `remote_provisioning` binder service.
#[cfg(test)]
mod tests {
    use super::*;
    use android_security_rkp_aidl::aidl::android::security::rkp::IRegistration::BnRegistration;
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::{Arc, Mutex};

    /// Component name passed to the public entry points by the tests that reach the real
    /// `remote_provisioning` service.
    const DEFAULT_RPC_SERVICE_NAME: &str =
        "android.hardware.security.keymint.IRemotelyProvisionedComponent/default";

    /// State shared by a `MockRegistration`.
    struct MockRegistrationValues {
        // Key handed to every `getKey` callback.
        key: RemotelyProvisionedKey,
        // Optional delay applied on the worker thread before a callback fires.
        latency: Option<Duration>,
        // Threads spawned by `getKey`; joined when the mock is dropped.
        thread_join_handles: Vec<Option<std::thread::JoinHandle<()>>>,
    }

    /// In-process fake `IRegistration`. It answers from a background thread, optionally after a
    /// delay, so the timeout paths can be exercised.
    struct MockRegistration(Arc<Mutex<MockRegistrationValues>>);

    impl MockRegistration {
        /// Creates a binder-wrapped mock that hands out a copy of `key` after `latency`.
        pub fn new_native_binder(
            key: &RemotelyProvisionedKey,
            latency: Option<Duration>,
        ) -> Strong<dyn IRegistration> {
            let result = Self(Arc::new(Mutex::new(MockRegistrationValues {
                key: RemotelyProvisionedKey {
                    keyBlob: key.keyBlob.clone(),
                    encodedCertChain: key.encodedCertChain.clone(),
                },
                latency,
                thread_join_handles: Vec::new(),
            })));
            BnRegistration::new_binder(result, BinderFeatures::default())
        }
    }

    // Joins every thread recorded by `getKey`, so that none of them outlives the mock.
    impl Drop for MockRegistration {
        fn drop(&mut self) {
            let mut values = self.0.lock().unwrap();
            for handle in values.thread_join_handles.iter_mut() {
                // These are test threads. So, no need to worry too much about error handling.
                handle.take().unwrap().join().unwrap();
            }
        }
    }

    impl Interface for MockRegistration {}

    impl IRegistration for MockRegistration {
        /// Replies to `cb` with the configured key from a new thread, after the configured
        /// latency.
        fn getKey(&self, _: i32, cb: &Strong<dyn IGetKeyCallback>) -> binder::Result<()> {
            let mut values = self.0.lock().unwrap();
            let key = RemotelyProvisionedKey {
                keyBlob: values.key.keyBlob.clone(),
                encodedCertChain: values.key.encodedCertChain.clone(),
            };
            let latency = values.latency;
            let get_key_cb = cb.clone();

            // Need a separate thread to trigger timeout in the caller.
            let join_handle = std::thread::spawn(move || {
                if let Some(duration) = latency {
                    std::thread::sleep(duration);
                }
                get_key_cb.onSuccess(&key).unwrap();
            });
            values.thread_join_handles.push(Some(join_handle));
            Ok(())
        }

        /// No-op: the mock has nothing to cancel.
        fn cancelGetKey(&self, _: &Strong<dyn IGetKeyCallback>) -> binder::Result<()> {
            Ok(())
        }

        /// Reports success to `cb` from a new thread, after the configured latency, without
        /// storing anything.
        fn storeUpgradedKeyAsync(
            &self,
            _: &[u8],
            _: &[u8],
            cb: &Strong<dyn IStoreUpgradedKeyCallback>,
        ) -> binder::Result<()> {
            // We are primarily concerned with timing out correctly. Storing the key in this mock
            // registration isn't particularly interesting, so skip that part.
            let values = self.0.lock().unwrap();
            let store_cb = cb.clone();
            let latency = values.latency;

            // NOTE(review): unlike getKey, this JoinHandle is not pushed onto
            // thread_join_handles, so the thread is detached and can outlive the mock (e.g. after
            // test_store_mock_key_timeout). Consider recording it for consistency.
            std::thread::spawn(move || {
                if let Some(duration) = latency {
                    std::thread::sleep(duration);
                }
                store_cb.onSuccess().unwrap();
            });
            Ok(())
        }
    }

    /// Builds a `MockRegistration` and routes it through `GetRegistrationCallback`, the same path
    /// a real registration takes.
    fn get_mock_registration(
        key: &RemotelyProvisionedKey,
        latency: Option<Duration>,
    ) -> Result<binder::Strong<dyn IRegistration>> {
        let (tx, rx) = oneshot::channel();
        let cb = GetRegistrationCallback::new_native_binder(tx);
        let mock_registration = MockRegistration::new_native_binder(key, latency);

        assert!(cb.onSuccess(&mock_registration).is_ok());
        tokio_rt().block_on(rx).unwrap()
    }

    // Using the same key ID makes test cases race with each other. So, we use separate key IDs for
    // different test cases.
    fn get_next_key_id() -> u32 {
        static ID: AtomicU32 = AtomicU32::new(0);
        ID.fetch_add(1, Ordering::Relaxed)
    }

    #[test]
    fn test_get_registration_cb_success() {
        let key: RemotelyProvisionedKey = Default::default();
        let registration = get_mock_registration(&key, /*latency=*/ None);
        assert!(registration.is_ok());
    }

    #[test]
    fn test_get_registration_cb_cancel() {
        let (tx, rx) = oneshot::channel();
        let cb = GetRegistrationCallback::new_native_binder(tx);
        assert!(cb.onCancel().is_ok());

        let result = tokio_rt().block_on(rx).unwrap();
        assert_eq!(
            result.unwrap_err().downcast::<Error>().unwrap(),
            Error::Rc(ResponseCode::OUT_OF_KEYS_TRANSIENT_ERROR)
        );
    }

    #[test]
    fn test_get_registration_cb_error() {
        let (tx, rx) = oneshot::channel();
        let cb = GetRegistrationCallback::new_native_binder(tx);
        assert!(cb.onError("error").is_ok());

        let result = tokio_rt().block_on(rx).unwrap();
        assert_eq!(
            result.unwrap_err().downcast::<Error>().unwrap(),
            Error::Rc(ResponseCode::OUT_OF_KEYS_TRANSIENT_ERROR)
        );
    }

    #[test]
    fn test_get_key_cb_success() {
        let mock_key =
            RemotelyProvisionedKey { keyBlob: vec![1, 2, 3], encodedCertChain: vec![4, 5, 6] };
        let (tx, rx) = oneshot::channel();
        let cb = GetKeyCallback::new_native_binder(tx);
        assert!(cb.onSuccess(&mock_key).is_ok());

        let key = tokio_rt().block_on(rx).unwrap().unwrap();
        assert_eq!(key, mock_key);
    }

    #[test]
    fn test_get_key_cb_cancel() {
        let (tx, rx) = oneshot::channel();
        let cb = GetKeyCallback::new_native_binder(tx);
        assert!(cb.onCancel().is_ok());

        let result = tokio_rt().block_on(rx).unwrap();
        assert_eq!(
            result.unwrap_err().downcast::<Error>().unwrap(),
            Error::Rc(ResponseCode::OUT_OF_KEYS_TRANSIENT_ERROR)
        );
    }

    #[test]
    fn test_get_key_cb_error() {
        // Expected keystore response code for each RKPD error code. Indexing below panics if a
        // new AIDL value is missing from this table, which keeps the mapping in sync.
        let error_mapping = HashMap::from([
            (GetKeyErrorCode::ERROR_UNKNOWN, ResponseCode::OUT_OF_KEYS_TRANSIENT_ERROR),
            (GetKeyErrorCode::ERROR_PERMANENT, ResponseCode::OUT_OF_KEYS_PERMANENT_ERROR),
            (
                GetKeyErrorCode::ERROR_PENDING_INTERNET_CONNECTIVITY,
                ResponseCode::OUT_OF_KEYS_PENDING_INTERNET_CONNECTIVITY,
            ),
            (
                GetKeyErrorCode::ERROR_REQUIRES_SECURITY_PATCH,
                ResponseCode::OUT_OF_KEYS_REQUIRES_SYSTEM_UPGRADE,
            ),
        ]);

        // Loop over the generated list of enum values to better ensure this test stays in
        // sync with the AIDL.
        for get_key_error in GetKeyErrorCode::enum_values() {
            let (tx, rx) = oneshot::channel();
            let cb = GetKeyCallback::new_native_binder(tx);
            assert!(cb.onError(get_key_error, "error").is_ok());

            let result = tokio_rt().block_on(rx).unwrap();
            assert_eq!(
                result.unwrap_err().downcast::<Error>().unwrap(),
                Error::Rc(error_mapping[&get_key_error]),
            );
        }
    }

    #[test]
    fn test_store_upgraded_cb_success() {
        let (tx, rx) = oneshot::channel();
        let cb = StoreUpgradedKeyCallback::new_native_binder(tx);
        assert!(cb.onSuccess().is_ok());

        tokio_rt().block_on(rx).unwrap().unwrap();
    }

    #[test]
    fn test_store_upgraded_key_cb_error() {
        let (tx, rx) = oneshot::channel();
        let cb = StoreUpgradedKeyCallback::new_native_binder(tx);
        assert!(cb.onError("oh no! it failed").is_ok());

        let result = tokio_rt().block_on(rx).unwrap();
        assert_eq!(
            result.unwrap_err().downcast::<Error>().unwrap(),
            Error::Rc(ResponseCode::SYSTEM_ERROR)
        );
    }

    #[test]
    fn test_get_mock_key_success() {
        let mock_key =
            RemotelyProvisionedKey { keyBlob: vec![1, 2, 3], encodedCertChain: vec![4, 5, 6] };
        let registration = get_mock_registration(&mock_key, /*latency=*/ None).unwrap();

        let key = tokio_rt()
            .block_on(get_rkpd_attestation_key_from_registration_async(&registration, 0))
            .unwrap();
        assert_eq!(key, mock_key);
    }

    #[test]
    fn test_get_mock_key_timeout() {
        let mock_key =
            RemotelyProvisionedKey { keyBlob: vec![1, 2, 3], encodedCertChain: vec![4, 5, 6] };
        // One second past the timeout, so the caller gives up before the mock replies.
        let latency = RKPD_TIMEOUT + Duration::from_secs(1);
        let registration = get_mock_registration(&mock_key, Some(latency)).unwrap();

        let result =
            tokio_rt().block_on(get_rkpd_attestation_key_from_registration_async(&registration, 0));
        assert_eq!(
            result.unwrap_err().downcast::<Error>().unwrap(),
            Error::Rc(ResponseCode::OUT_OF_KEYS_TRANSIENT_ERROR)
        );
    }

    #[test]
    fn test_store_mock_key_success() {
        let mock_key =
            RemotelyProvisionedKey { keyBlob: vec![1, 2, 3], encodedCertChain: vec![4, 5, 6] };
        let registration = get_mock_registration(&mock_key, /*latency=*/ None).unwrap();
        tokio_rt()
            .block_on(store_rkpd_attestation_key_with_registration_async(&registration, &[], &[]))
            .unwrap();
    }

    #[test]
    fn test_store_mock_key_timeout() {
        let mock_key =
            RemotelyProvisionedKey { keyBlob: vec![1, 2, 3], encodedCertChain: vec![4, 5, 6] };
        // One second past the timeout, so the caller gives up before the mock replies.
        let latency = RKPD_TIMEOUT + Duration::from_secs(1);
        let registration = get_mock_registration(&mock_key, Some(latency)).unwrap();

        let result = tokio_rt().block_on(store_rkpd_attestation_key_with_registration_async(
            &registration,
            &[],
            &[],
        ));
        assert_eq!(
            result.unwrap_err().downcast::<Error>().unwrap(),
            Error::Rc(ResponseCode::SYSTEM_ERROR)
        );
    }

    // The tests below go through the public entry points and therefore talk to the
    // `remote_provisioning` binder service rather than to the mock.

    #[test]
    fn test_get_rkpd_attestation_key() {
        binder::ProcessState::start_thread_pool();
        let key_id = get_next_key_id();
        let key = get_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, key_id).unwrap();
        assert!(!key.keyBlob.is_empty());
        assert!(!key.encodedCertChain.is_empty());
    }

    #[test]
    fn test_get_rkpd_attestation_key_same_caller() {
        binder::ProcessState::start_thread_pool();
        let key_id = get_next_key_id();

        // Multiple calls should return the same key.
        let first_key = get_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, key_id).unwrap();
        let second_key = get_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, key_id).unwrap();

        assert_eq!(first_key.keyBlob, second_key.keyBlob);
        assert_eq!(first_key.encodedCertChain, second_key.encodedCertChain);
    }

    #[test]
    fn test_get_rkpd_attestation_key_different_caller() {
        binder::ProcessState::start_thread_pool();
        let first_key_id = get_next_key_id();
        let second_key_id = get_next_key_id();

        // Different callers should be getting different keys.
        let first_key = get_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, first_key_id).unwrap();
        let second_key = get_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, second_key_id).unwrap();

        assert_ne!(first_key.keyBlob, second_key.keyBlob);
        assert_ne!(first_key.encodedCertChain, second_key.encodedCertChain);
    }

    #[test]
    // Couple of things to note:
    // 1. This test must never run with UID of keystore. Otherwise, it can mess up keys stored by
    //    keystore.
    // 2. Storing and reading the stored key is prone to race condition. So, we only do this in one
    //    test case.
    fn test_store_rkpd_attestation_key() {
        binder::ProcessState::start_thread_pool();
        let key_id = get_next_key_id();
        let key = get_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, key_id).unwrap();
        let new_blob: [u8; 8] = rand::random();

        assert!(
            store_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, &key.keyBlob, &new_blob).is_ok()
        );

        let new_key = get_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, key_id).unwrap();

        // Restore original key so that we don't leave RKPD with invalid blobs.
        assert!(
            store_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, &new_blob, &key.keyBlob).is_ok()
        );
        assert_eq!(new_key.keyBlob, new_blob);
    }

    // Issues NTHREADS * NCALLS concurrent requests for the same key id and checks each result is
    // non-empty.
    #[test]
    fn test_stress_get_rkpd_attestation_key() {
        binder::ProcessState::start_thread_pool();
        let key_id = get_next_key_id();
        let mut threads = vec![];
        const NTHREADS: u32 = 10;
        const NCALLS: u32 = 1000;

        for _ in 0..NTHREADS {
            threads.push(std::thread::spawn(move || {
                for _ in 0..NCALLS {
                    let key = get_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, key_id).unwrap();
                    assert!(!key.keyBlob.is_empty());
                    assert!(!key.encodedCertChain.is_empty());
                }
            }));
        }

        for t in threads {
            assert!(t.join().is_ok());
        }
    }
}