aboutsummaryrefslogtreecommitdiff
path: root/src/unlock_notify.rs
blob: 5f9cdbfb6d27186d383126751ded0630c029052a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
//! [Unlock Notification](http://sqlite.org/unlock_notify.html)

use std::os::raw::c_int;
#[cfg(feature = "unlock_notify")]
use std::os::raw::c_void;
#[cfg(feature = "unlock_notify")]
use std::panic::catch_unwind;
#[cfg(feature = "unlock_notify")]
use std::sync::{Condvar, Mutex};

use crate::ffi;

/// One-shot "has the unlock-notify callback run yet?" signal.
///
/// A `bool` guarded by a mutex is paired with a condition variable: `fired()`
/// sets the flag and notifies, `wait()` blocks until the flag is set.
#[cfg(feature = "unlock_notify")]
struct UnlockNotification {
    /// Condition variable that `wait()` blocks on and `fired()` notifies.
    cond: Condvar,
    /// Flag protected by the mutex: starts `false`, set to `true` by `fired()`.
    mutex: Mutex<bool>,
}

#[cfg(feature = "unlock_notify")]
#[allow(clippy::mutex_atomic)]
impl UnlockNotification {
    /// Creates a notification whose flag is not yet set.
    fn new() -> Self {
        Self {
            cond: Condvar::new(),
            mutex: Mutex::new(false),
        }
    }

    /// Marks the notification as delivered and wakes one waiting thread.
    ///
    /// Panics if the mutex is poisoned.
    fn fired(&self) {
        let mut guard = self.mutex.lock().unwrap();
        *guard = true;
        // Notify while the lock is still held: a thread blocked in `wait()`
        // cannot return (and let its owner drop `self`) until the guard is
        // released, so `self.cond` is still alive for this call.
        self.cond.notify_one();
        drop(guard);
    }

    /// Blocks the calling thread until `fired()` has set the flag.
    ///
    /// Returns immediately if the flag is already set. Spurious wake-ups are
    /// absorbed because the flag is re-checked after every wake-up.
    /// Panics if the mutex is poisoned.
    fn wait(&self) {
        let guard = self.mutex.lock().unwrap();
        let _guard = self.cond.wait_while(guard, |fired| !*fired).unwrap();
    }
}

/// Unlock-notify callback handed to `sqlite3_unlock_notify()`.
///
/// `ap_arg` points to an array of `n_arg` context pointers; each one is
/// expected to be the `UnlockNotification` that was registered as the
/// notification context. Every notification in the array is marked as fired.
///
/// The callback is defensive about its inputs because it sits on an FFI
/// boundary: a null array, a non-positive count, or a null entry is ignored
/// instead of being turned into an invalid slice or reference.
///
/// # Safety
///
/// When `ap_arg` is non-null and `n_arg` is positive, `ap_arg` must point to
/// `n_arg` readable pointers, and every non-null entry must point to a live
/// `UnlockNotification`.
#[cfg(feature = "unlock_notify")]
unsafe extern "C" fn unlock_notify_cb(ap_arg: *mut *mut c_void, n_arg: c_int) {
    use std::slice::from_raw_parts;
    // `from_raw_parts` needs a non-null pointer, and a negative count cast
    // with `as usize` would become an enormous length.
    if ap_arg.is_null() || n_arg <= 0 {
        return;
    }
    let contexts = from_raw_parts(ap_arg as *const *const UnlockNotification, n_arg as usize);
    for &context in contexts {
        // Go through `as_ref()` so a null entry is skipped rather than
        // materialised as an (invalid) reference.
        if let Some(un) = context.as_ref() {
            // A panic must not unwind into SQLite's C frames; swallow it here.
            let _ = catch_unwind(std::panic::AssertUnwindSafe(|| un.fired()));
        }
    }
}

/// Tells whether `rc` reports a shared-cache lock on `db`.
///
/// Returns `true` when `rc` is `SQLITE_LOCKED_SHAREDCACHE` itself, or when its
/// primary result code (low byte) is `SQLITE_LOCKED` and the connection's
/// extended error code is `SQLITE_LOCKED_SHAREDCACHE`. The extended error code
/// is only queried in that second case.
///
/// # Safety
///
/// `db` is passed to `sqlite3_extended_errcode()`.
// NOTE(review): presumably `db` must be a valid, open connection handle — confirm against callers.
#[cfg(feature = "unlock_notify")]
pub unsafe fn is_locked(db: *mut ffi::sqlite3, rc: c_int) -> bool {
    if rc == ffi::SQLITE_LOCKED_SHAREDCACHE {
        return true;
    }
    // Mask off the extended bits to compare the primary result code only.
    let primary = rc & 0xFF;
    if primary != ffi::SQLITE_LOCKED {
        return false;
    }
    ffi::sqlite3_extended_errcode(db) == ffi::SQLITE_LOCKED_SHAREDCACHE
}

/// Blocks the calling thread until the lock that made an SQLite call fail has
/// been released.
///
/// Meant to be called right after an SQLite API call (either
/// `sqlite3_prepare_v2()` or `sqlite3_step()`) has returned `SQLITE_LOCKED`
/// on the connection `db`.
///
/// The function registers an unlock-notify callback with
/// `sqlite3_unlock_notify()` and then:
///
/// * if registration returned `SQLITE_OK`, waits until the callback has been
///   delivered and returns `SQLITE_OK`; the caller should retry the failed
///   operation;
/// * otherwise returns the registration result (`SQLITE_LOCKED`) immediately,
///   without waiting. `sqlite3_unlock_notify()` reports this when blocking
///   would deadlock the system, so the caller should not retry and should
///   roll back the current transaction (if any).
///
/// # Safety
///
/// `db` is passed straight to `sqlite3_unlock_notify()`.
// NOTE(review): presumably `db` must be a valid, open connection handle — confirm against callers.
#[cfg(feature = "unlock_notify")]
pub unsafe fn wait_for_unlock_notify(db: *mut ffi::sqlite3) -> c_int {
    let notification = UnlockNotification::new();
    // Context pointer that SQLite hands back to `unlock_notify_cb`; it refers
    // to the stack-allocated `notification` above.
    let context = &notification as *const UnlockNotification as *mut c_void;
    let rc = ffi::sqlite3_unlock_notify(db, Some(unlock_notify_cb), context);
    debug_assert!(
        rc == ffi::SQLITE_LOCKED || rc == ffi::SQLITE_LOCKED_SHAREDCACHE || rc == ffi::SQLITE_OK
    );
    if rc != ffi::SQLITE_OK {
        // Registration was refused: report it without blocking.
        return rc;
    }
    notification.wait();
    rc
}

/// Variant of `is_locked` compiled when the `unlock_notify` feature is
/// disabled.
///
/// Both arguments are ignored and the body is `unreachable!()`, so calling
/// this function panics.
// NOTE(review): presumably every call site is only reached with the feature enabled — confirm.
#[cfg(not(feature = "unlock_notify"))]
pub unsafe fn is_locked(_db: *mut ffi::sqlite3, _rc: c_int) -> bool {
    unreachable!()
}

/// Variant of `wait_for_unlock_notify` compiled when the `unlock_notify`
/// feature is disabled.
///
/// The argument is ignored and the body is `unreachable!()`, so calling this
/// function panics.
// NOTE(review): presumably every call site is only reached with the feature enabled — confirm.
#[cfg(not(feature = "unlock_notify"))]
pub unsafe fn wait_for_unlock_notify(_db: *mut ffi::sqlite3) -> c_int {
    unreachable!()
}

#[cfg(feature = "unlock_notify")]
#[cfg(test)]
mod test {
    use crate::{Connection, OpenFlags, Result, Transaction, TransactionBehavior};
    use std::sync::mpsc::sync_channel;
    use std::thread;
    use std::time;

    /// A reader on a shared-cache in-memory database must see the row written
    /// by a second connection whose transaction is still open when the read
    /// starts.
    #[test]
    fn test_unlock_notify() -> Result<()> {
        let url = "file::memory:?cache=shared";
        let flags = OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_URI;
        let db1 = Connection::open_with_flags(url, flags)?;
        db1.execute_batch("CREATE TABLE foo (x)")?;

        // Zero-capacity channel: `send` only returns once the main thread has
        // received, so the read below starts while the writer's transaction
        // is still open.
        let (ready_tx, ready_rx) = sync_channel(0);
        let writer = thread::spawn(move || {
            let mut db2 = Connection::open_with_flags(url, flags).unwrap();
            let txn = Transaction::new(&mut db2, TransactionBehavior::Immediate).unwrap();
            txn.execute_batch("INSERT INTO foo VALUES (42)").unwrap();
            ready_tx.send(1).unwrap();
            // Keep the transaction open briefly before committing.
            thread::sleep(time::Duration::from_millis(10));
            txn.commit().unwrap();
        });

        assert_eq!(ready_rx.recv().unwrap(), 1);
        let the_answer: i64 = db1.query_row("SELECT x FROM foo", [], |r| r.get(0))?;
        assert_eq!(42i64, the_answer);
        writer.join().unwrap();
        Ok(())
    }
}