aboutsummaryrefslogtreecommitdiff
path: root/src/js.rs
blob: 574c4dc326030d0ce882bcddd44de2d2db2f47e3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
// Copyright 2018 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
use crate::Error;

extern crate std;
use std::thread_local;

use js_sys::{global, Function, Uint8Array};
use wasm_bindgen::{prelude::wasm_bindgen, JsCast, JsValue};

// Size of our temporary Uint8Array buffer used with WebCrypto methods.
// Random bytes are requested in chunks of at most this many bytes, and
// `getRandomValues` accepts at most 65536 bytes per call, so this value must
// stay at or below that limit. See
// https://developer.mozilla.org/en-US/docs/Web/API/Crypto/getRandomValues
const WEB_CRYPTO_BUFFER_SIZE: usize = 256;

/// The per-thread source of random bytes chosen by `getrandom_init`.
enum RngSource {
    /// A Web Crypto object plus a reusable scratch `Uint8Array` allocated on
    /// the JS side; random bytes are written there first and then copied
    /// into WASM memory.
    Web(WebCrypto, Uint8Array),
    /// The Node.js `crypto` module, obtained through `require("crypto")`.
    Node(NodeCrypto),
}

// A `JsValue` cannot be shared between threads, so each thread builds and
// caches its own `RngSource` on first use. A failed initialisation is cached
// too, and reported on every subsequent call from that thread.
//   See: https://github.com/rustwasm/wasm-bindgen/pull/955
thread_local! {
    static RNG_SOURCE: Result<RngSource, Error> = getrandom_init();
}

/// Fills `dest` with random bytes from this thread's cached `RngSource`.
///
/// # Errors
///
/// * Any error recorded when the thread's source was initialised
///   (`Error::WEB_CRYPTO`, `Error::NODE_CRYPTO` or `Error::NODE_ES_MODULE`).
/// * `Error::NODE_RANDOM_FILL_SYNC` if Node's `crypto.randomFillSync` throws.
/// * `Error::WEB_GET_RANDOM_VALUES` if `crypto.getRandomValues` throws.
pub(crate) fn getrandom_inner(dest: &mut [u8]) -> Result<(), Error> {
    // Node's `crypto.randomFillSync` rejects (throws a RangeError for) any
    // request larger than 2^31 - 1 bytes, and a wasm32 slice can be bigger
    // than that, so Node requests are split into chunks of at most this size.
    const NODE_MAX_BUFFER_SIZE: usize = (1 << 31) - 1;

    RNG_SOURCE.with(|result| {
        // Initialisation errors are cached per thread; surface them here.
        let source = result.as_ref().map_err(|&e| e)?;

        match source {
            RngSource::Node(n) => {
                for chunk in dest.chunks_mut(NODE_MAX_BUFFER_SIZE) {
                    if n.random_fill_sync(chunk).is_err() {
                        return Err(Error::NODE_RANDOM_FILL_SYNC);
                    }
                }
            }
            RngSource::Web(crypto, buf) => {
                // getRandomValues does not work with all types of WASM memory,
                // so we initially write to browser memory to avoid exceptions.
                for chunk in dest.chunks_mut(WEB_CRYPTO_BUFFER_SIZE) {
                    // The chunk can be smaller than buf's length, so we call to
                    // JS to create a smaller view of buf without allocation.
                    // `chunk.len() <= WEB_CRYPTO_BUFFER_SIZE`, so the cast
                    // cannot truncate.
                    let sub_buf = buf.subarray(0, chunk.len() as u32);

                    if crypto.get_random_values(&sub_buf).is_err() {
                        return Err(Error::WEB_GET_RANDOM_VALUES);
                    }
                    sub_buf.copy_to(chunk);
                }
            }
        }
        Ok(())
    })
}

fn getrandom_init() -> Result<RngSource, Error> {
    let global: Global = global().unchecked_into();

    // Get the Web Crypto interface if we are in a browser, Web Worker, Deno,
    // or another environment that supports the Web Cryptography API. This
    // also allows for user-provided polyfills in unsupported environments.
    let crypto = match global.crypto() {
        // Standard Web Crypto interface
        c if c.is_object() => c,
        // Node.js CommonJS Crypto module
        _ if is_node(&global) => {
            // If module.require isn't a valid function, we are in an ES module.
            match Module::require_fn().and_then(JsCast::dyn_into::<Function>) {
                Ok(require_fn) => match require_fn.call1(&global, &JsValue::from_str("crypto")) {
                    Ok(n) => return Ok(RngSource::Node(n.unchecked_into())),
                    Err(_) => return Err(Error::NODE_CRYPTO),
                },
                Err(_) => return Err(Error::NODE_ES_MODULE),
            }
        }
        // IE 11 Workaround
        _ => match global.ms_crypto() {
            c if c.is_object() => c,
            _ => return Err(Error::WEB_CRYPTO),
        },
    };

    let buf = Uint8Array::new_with_length(WEB_CRYPTO_BUFFER_SIZE as u32);
    Ok(RngSource::Web(crypto, buf))
}

// Detects Node.js by checking that `process.versions.node` is a string.
// Taken from https://www.npmjs.com/package/browser-or-node
fn is_node(global: &Global) -> bool {
    let process = global.process();
    if !process.is_object() {
        return false;
    }
    let versions = process.versions();
    versions.is_object() && versions.node().is_string()
}

// JavaScript APIs imported through wasm-bindgen. The `Global` getters read
// properties of the global object returned by `js_sys::global()`. Items marked
// `catch` return `Err(JsValue)` instead of propagating a thrown JS exception.
#[wasm_bindgen]
extern "C" {
    // Return type of js_sys::global()
    type Global;

    // Web Crypto API: Crypto interface (https://www.w3.org/TR/WebCryptoAPI/)
    type WebCrypto;
    // Getters for the WebCrypto API: the standard `crypto` property and the
    // vendor-prefixed `msCrypto` used by IE 11.
    #[wasm_bindgen(method, getter)]
    fn crypto(this: &Global) -> WebCrypto;
    #[wasm_bindgen(method, getter, js_name = msCrypto)]
    fn ms_crypto(this: &Global) -> WebCrypto;
    // Crypto.getRandomValues(); fills `buf` in place.
    #[wasm_bindgen(method, js_name = getRandomValues, catch)]
    fn get_random_values(this: &WebCrypto, buf: &Uint8Array) -> Result<(), JsValue>;

    // Node JS crypto module (https://nodejs.org/api/crypto.html)
    type NodeCrypto;
    // crypto.randomFillSync(); fills `buf` in place.
    #[wasm_bindgen(method, js_name = randomFillSync, catch)]
    fn random_fill_sync(this: &NodeCrypto, buf: &mut [u8]) -> Result<(), JsValue>;

    // Ideally, we would just use `fn require(s: &str)` here. However, doing
    // this causes a Webpack warning. So we instead return the function itself
    // and manually invoke it using call1. This also lets us check that the
    // function actually exists, allowing for better error messages. See:
    //   https://github.com/rust-random/getrandom/issues/224
    //   https://github.com/rust-random/getrandom/issues/256
    // `require_fn` reads `module.require` as a static getter; `catch` turns a
    // failure to read it into `Err` rather than a thrown exception.
    type Module;
    #[wasm_bindgen(getter, static_method_of = Module, js_class = module, js_name = require, catch)]
    fn require_fn() -> Result<JsValue, JsValue>;

    // Node JS process Object (https://nodejs.org/api/process.html)
    // Only used to read `process.versions.node` when detecting Node.js.
    #[wasm_bindgen(method, getter)]
    fn process(this: &Global) -> Process;
    type Process;
    #[wasm_bindgen(method, getter)]
    fn versions(this: &Process) -> Versions;
    type Versions;
    #[wasm_bindgen(method, getter)]
    fn node(this: &Versions) -> JsValue;
}