aboutsummaryrefslogtreecommitdiff
path: root/src/functions.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/functions.rs')
-rw-r--r--src/functions.rs379
1 files changed, 207 insertions, 172 deletions
diff --git a/src/functions.rs b/src/functions.rs
index 3531391..e613182 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -1,17 +1,17 @@
-//! `feature = "functions"` Create or redefine SQL functions.
+//! Create or redefine SQL functions.
//!
//! # Example
//!
//! Adding a `regexp` function to a connection in which compiled regular
//! expressions are cached in a `HashMap`. For an alternative implementation
-//! that uses SQLite's [Function Auxilliary Data](https://www.sqlite.org/c3ref/get_auxdata.html) interface
+//! that uses SQLite's [Function Auxiliary Data](https://www.sqlite.org/c3ref/get_auxdata.html) interface
//! to avoid recompiling regular expressions, see the unit tests for this
//! module.
//!
//! ```rust
//! use regex::Regex;
//! use rusqlite::functions::FunctionFlags;
-//! use rusqlite::{Connection, Error, Result, NO_PARAMS};
+//! use rusqlite::{Connection, Error, Result};
//! use std::sync::Arc;
//! type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
//!
@@ -22,10 +22,9 @@
//! FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
//! move |ctx| {
//! assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
-//! let regexp: Arc<Regex> = ctx
-//! .get_or_create_aux(0, |vr| -> Result<_, BoxError> {
-//! Ok(Regex::new(vr.as_str()?)?)
-//! })?;
+//! let regexp: Arc<Regex> = ctx.get_or_create_aux(0, |vr| -> Result<_, BoxError> {
+//! Ok(Regex::new(vr.as_str()?)?)
+//! })?;
//! let is_match = {
//! let text = ctx
//! .get_raw(1)
@@ -44,17 +43,18 @@
//! let db = Connection::open_in_memory()?;
//! add_regexp_function(&db)?;
//!
-//! let is_match: bool = db.query_row(
-//! "SELECT regexp('[aeiou]*', 'aaaaeeeiii')",
-//! NO_PARAMS,
-//! |row| row.get(0),
-//! )?;
+//! let is_match: bool =
+//! db.query_row("SELECT regexp('[aeiou]*', 'aaaaeeeiii')", [], |row| {
+//! row.get(0)
+//! })?;
//!
//! assert!(is_match);
//! Ok(())
//! }
//! ```
use std::any::Any;
+use std::marker::PhantomData;
+use std::ops::Deref;
use std::os::raw::{c_int, c_void};
use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe};
use std::ptr;
@@ -84,27 +84,24 @@ unsafe fn report_error(ctx: *mut sqlite3_context, err: &Error) {
ffi::SQLITE_CONSTRAINT
}
- match *err {
- Error::SqliteFailure(ref err, ref s) => {
- ffi::sqlite3_result_error_code(ctx, err.extended_code);
- if let Some(Ok(cstr)) = s.as_ref().map(|s| str_to_cstring(s)) {
- ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
- }
+ if let Error::SqliteFailure(ref err, ref s) = *err {
+ ffi::sqlite3_result_error_code(ctx, err.extended_code);
+ if let Some(Ok(cstr)) = s.as_ref().map(|s| str_to_cstring(s)) {
+ ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
}
- _ => {
- ffi::sqlite3_result_error_code(ctx, constraint_error_code());
- if let Ok(cstr) = str_to_cstring(&err.to_string()) {
- ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
- }
+ } else {
+ ffi::sqlite3_result_error_code(ctx, constraint_error_code());
+ if let Ok(cstr) = str_to_cstring(&err.to_string()) {
+ ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
}
}
}
unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
- drop(Box::from_raw(p as *mut T));
+ drop(Box::from_raw(p.cast::<T>()));
}
-/// `feature = "functions"` Context is a wrapper for the SQLite function
+/// Context is a wrapper for the SQLite function
/// evaluation context.
pub struct Context<'a> {
ctx: *mut sqlite3_context,
@@ -113,11 +110,15 @@ pub struct Context<'a> {
impl Context<'_> {
/// Returns the number of arguments to the function.
+ #[inline]
+ #[must_use]
pub fn len(&self) -> usize {
self.args.len()
}
/// Returns `true` when there is no argument.
+ #[inline]
+ #[must_use]
pub fn is_empty(&self) -> bool {
self.args.is_empty()
}
@@ -126,7 +127,8 @@ impl Context<'_> {
///
/// # Failure
///
- /// Will panic if `idx` is greater than or equal to `self.len()`.
+ /// Will panic if `idx` is greater than or equal to
+ /// [`self.len()`](Context::len).
///
/// Will return Err if the underlying SQLite type cannot be converted to a
/// `T`.
@@ -141,12 +143,7 @@ impl Context<'_> {
FromSqlError::Other(err) => {
Error::FromSqlConversionFailure(idx, value.data_type(), err)
}
- #[cfg(feature = "i128_blob")]
- FromSqlError::InvalidI128Size(_) => {
- Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err))
- }
- #[cfg(feature = "uuid")]
- FromSqlError::InvalidUuidSize(_) => {
+ FromSqlError::InvalidBlobSize { .. } => {
Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err))
}
})
@@ -156,17 +153,21 @@ impl Context<'_> {
///
/// # Failure
///
- /// Will panic if `idx` is greater than or equal to `self.len()`.
+ /// Will panic if `idx` is greater than or equal to
+ /// [`self.len()`](Context::len).
+ #[inline]
+ #[must_use]
pub fn get_raw(&self, idx: usize) -> ValueRef<'_> {
let arg = self.args[idx];
unsafe { ValueRef::from_value(arg) }
}
- /// Fetch or insert the the auxilliary data associated with a particular
+ /// Fetch or insert the auxiliary data associated with a particular
/// parameter. This is intended to be an easier-to-use way of fetching it
- /// compared to calling `get_aux` and `set_aux` separately.
+ /// compared to calling [`get_aux`](Context::get_aux) and
+ /// [`set_aux`](Context::set_aux) separately.
///
- /// See https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of
+ /// See `https://www.sqlite.org/c3ref/get_auxdata.html` for a discussion of
/// this feature, or the unit tests of this module for an example.
pub fn get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>>
where
@@ -185,8 +186,8 @@ impl Context<'_> {
}
}
- /// Sets the auxilliary data associated with a particular parameter. See
- /// https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of
+ /// Sets the auxiliary data associated with a particular parameter. See
+ /// `https://www.sqlite.org/c3ref/get_auxdata.html` for a discussion of
/// this feature, or the unit tests of this module for an example.
pub fn set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>> {
let orig: Arc<T> = Arc::new(value);
@@ -197,17 +198,17 @@ impl Context<'_> {
ffi::sqlite3_set_auxdata(
self.ctx,
arg,
- raw as *mut _,
+ raw.cast(),
Some(free_boxed_value::<AuxInner>),
- )
+ );
};
Ok(orig)
}
- /// Gets the auxilliary data that was associated with a given parameter via
- /// `set_aux`. Returns `Ok(None)` if no data has been associated, and
- /// Ok(Some(v)) if it has. Returns an error if the requested type does not
- /// match.
+ /// Gets the auxiliary data that was associated with a given parameter via
+ /// [`set_aux`](Context::set_aux). Returns `Ok(None)` if no data has been
+ /// associated, and Ok(Some(v)) if it has. Returns an error if the
+ /// requested type does not match.
pub fn get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>> {
let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxInner };
if p.is_null() {
@@ -219,11 +220,42 @@ impl Context<'_> {
.map_err(|_| Error::GetAuxWrongType)
}
}
+
+ /// Get the db connection handle via [sqlite3_context_db_handle](https://www.sqlite.org/c3ref/context_db_handle.html)
+ ///
+ /// # Safety
+ ///
+ /// This function is marked unsafe because there is a potential for other
+ /// references to the connection to be sent across threads, [see this comment](https://github.com/rusqlite/rusqlite/issues/643#issuecomment-640181213).
+ pub unsafe fn get_connection(&self) -> Result<ConnectionRef<'_>> {
+ let handle = ffi::sqlite3_context_db_handle(self.ctx);
+ Ok(ConnectionRef {
+ conn: Connection::from_handle(handle)?,
+ phantom: PhantomData,
+ })
+ }
+}
+
+/// A reference to a connection handle with a lifetime bound to something.
+pub struct ConnectionRef<'ctx> {
+ // comes from Connection::from_handle(sqlite3_context_db_handle(...))
+ // and is non-owning
+ conn: Connection,
+ phantom: PhantomData<&'ctx Context<'ctx>>,
+}
+
+impl Deref for ConnectionRef<'_> {
+ type Target = Connection;
+
+ #[inline]
+ fn deref(&self) -> &Connection {
+ &self.conn
+ }
}
type AuxInner = Arc<dyn Any + Send + Sync + 'static>;
-/// `feature = "functions"` Aggregate is the callback interface for user-defined
+/// Aggregate is the callback interface for user-defined
/// aggregate function.
///
/// `A` is the type of the aggregation context and `T` is the type of the final
@@ -234,25 +266,31 @@ where
T: ToSql,
{
/// Initializes the aggregation context. Will be called prior to the first
- /// call to `step()` to set up the context for an invocation of the
- /// function. (Note: `init()` will not be called if there are no rows.)
- fn init(&self) -> A;
+ /// call to [`step()`](Aggregate::step) to set up the context for an
+ /// invocation of the function. (Note: `init()` will not be called if
+ /// there are no rows.)
+ fn init(&self, _: &mut Context<'_>) -> Result<A>;
/// "step" function called once for each row in an aggregate group. May be
/// called 0 times if there are no rows.
fn step(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>;
/// Computes and returns the final result. Will be called exactly once for
- /// each invocation of the function. If `step()` was called at least
- /// once, will be given `Some(A)` (the same `A` as was created by
- /// `init` and given to `step`); if `step()` was not called (because
- /// the function is running against 0 rows), will be given `None`.
- fn finalize(&self, _: Option<A>) -> Result<T>;
+ /// each invocation of the function. If [`step()`](Aggregate::step) was
+ /// called at least once, will be given `Some(A)` (the same `A` as was
+ /// created by [`init`](Aggregate::init) and given to
+ /// [`step`](Aggregate::step)); if [`step()`](Aggregate::step) was not
+ /// called (because the function is running against 0 rows), will be
+ /// given `None`.
+ ///
+ /// The passed context will have no arguments.
+ fn finalize(&self, _: &mut Context<'_>, _: Option<A>) -> Result<T>;
}
-/// `feature = "window"` WindowAggregate is the callback interface for
+/// `WindowAggregate` is the callback interface for
/// user-defined aggregate window function.
#[cfg(feature = "window")]
+#[cfg_attr(docsrs, doc(cfg(feature = "window")))]
pub trait WindowAggregate<A, T>: Aggregate<A, T>
where
A: RefUnwindSafe + UnwindSafe,
@@ -292,13 +330,14 @@ bitflags::bitflags! {
}
impl Default for FunctionFlags {
+ #[inline]
fn default() -> FunctionFlags {
FunctionFlags::SQLITE_UTF8
}
}
impl Connection {
- /// `feature = "functions"` Attach a user-defined scalar function to
+ /// Attach a user-defined scalar function to
/// this database connection.
///
/// `fn_name` is the name the function will be accessible from SQL.
@@ -307,12 +346,13 @@ impl Connection {
/// given the same input, `deterministic` should be `true`.
///
/// The function will remain available until the connection is closed or
- /// until it is explicitly removed via `remove_function`.
+ /// until it is explicitly removed via
+ /// [`remove_function`](Connection::remove_function).
///
/// # Example
///
/// ```rust
- /// # use rusqlite::{Connection, Result, NO_PARAMS};
+ /// # use rusqlite::{Connection, Result};
/// # use rusqlite::functions::FunctionFlags;
/// fn scalar_function_example(db: Connection) -> Result<()> {
/// db.create_scalar_function(
@@ -325,7 +365,7 @@ impl Connection {
/// },
/// )?;
///
- /// let six_halved: f64 = db.query_row("SELECT halve(6)", NO_PARAMS, |r| r.get(0))?;
+ /// let six_halved: f64 = db.query_row("SELECT halve(6)", [], |r| r.get(0))?;
/// assert_eq!(six_halved, 3f64);
/// Ok(())
/// }
@@ -334,6 +374,7 @@ impl Connection {
/// # Failure
///
/// Will return Err if the function could not be attached to the connection.
+ #[inline]
pub fn create_scalar_function<F, T>(
&self,
fn_name: &str,
@@ -350,12 +391,13 @@ impl Connection {
.create_scalar_function(fn_name, n_arg, flags, x_func)
}
- /// `feature = "functions"` Attach a user-defined aggregate function to this
+ /// Attach a user-defined aggregate function to this
/// database connection.
///
/// # Failure
///
/// Will return Err if the function could not be attached to the connection.
+ #[inline]
pub fn create_aggregate_function<A, D, T>(
&self,
fn_name: &str,
@@ -365,7 +407,7 @@ impl Connection {
) -> Result<()>
where
A: RefUnwindSafe + UnwindSafe,
- D: Aggregate<A, T>,
+ D: Aggregate<A, T> + 'static,
T: ToSql,
{
self.db
@@ -373,12 +415,14 @@ impl Connection {
.create_aggregate_function(fn_name, n_arg, flags, aggr)
}
- /// `feature = "window"` Attach a user-defined aggregate window function to
+ /// Attach a user-defined aggregate window function to
/// this database connection.
///
- /// See https://sqlite.org/windowfunctions.html#udfwinfunc for more
+ /// See `https://sqlite.org/windowfunctions.html#udfwinfunc` for more
/// information.
#[cfg(feature = "window")]
+ #[cfg_attr(docsrs, doc(cfg(feature = "window")))]
+ #[inline]
pub fn create_window_function<A, W, T>(
&self,
fn_name: &str,
@@ -388,7 +432,7 @@ impl Connection {
) -> Result<()>
where
A: RefUnwindSafe + UnwindSafe,
- W: WindowAggregate<A, T>,
+ W: WindowAggregate<A, T> + 'static,
T: ToSql,
{
self.db
@@ -396,15 +440,17 @@ impl Connection {
.create_window_function(fn_name, n_arg, flags, aggr)
}
- /// `feature = "functions"` Removes a user-defined function from this
+ /// Removes a user-defined function from this
/// database connection.
///
/// `fn_name` and `n_arg` should match the name and number of arguments
- /// given to `create_scalar_function` or `create_aggregate_function`.
+ /// given to [`create_scalar_function`](Connection::create_scalar_function)
+ /// or [`create_aggregate_function`](Connection::create_aggregate_function).
///
/// # Failure
///
/// Will return Err if the function could not be removed.
+ #[inline]
pub fn remove_function(&self, fn_name: &str, n_arg: c_int) -> Result<()> {
self.db.borrow_mut().remove_function(fn_name, n_arg)
}
@@ -431,7 +477,7 @@ impl InnerConnection {
T: ToSql,
{
let r = catch_unwind(|| {
- let boxed_f: *mut F = ffi::sqlite3_user_data(ctx) as *mut F;
+ let boxed_f: *mut F = ffi::sqlite3_user_data(ctx).cast::<F>();
assert!(!boxed_f.is_null(), "Internal error - null function pointer");
let ctx = Context {
ctx,
@@ -463,7 +509,7 @@ impl InnerConnection {
c_name.as_ptr(),
n_arg,
flags.bits(),
- boxed_f as *mut c_void,
+ boxed_f.cast::<c_void>(),
Some(call_boxed_closure::<F, T>),
None,
None,
@@ -482,7 +528,7 @@ impl InnerConnection {
) -> Result<()>
where
A: RefUnwindSafe + UnwindSafe,
- D: Aggregate<A, T>,
+ D: Aggregate<A, T> + 'static,
T: ToSql,
{
let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
@@ -493,7 +539,7 @@ impl InnerConnection {
c_name.as_ptr(),
n_arg,
flags.bits(),
- boxed_aggr as *mut c_void,
+ boxed_aggr.cast::<c_void>(),
None,
Some(call_boxed_step::<A, D, T>),
Some(call_boxed_final::<A, D, T>),
@@ -513,7 +559,7 @@ impl InnerConnection {
) -> Result<()>
where
A: RefUnwindSafe + UnwindSafe,
- W: WindowAggregate<A, T>,
+ W: WindowAggregate<A, T> + 'static,
T: ToSql,
{
let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr));
@@ -524,7 +570,7 @@ impl InnerConnection {
c_name.as_ptr(),
n_arg,
flags.bits(),
- boxed_aggr as *mut c_void,
+ boxed_aggr.cast::<c_void>(),
Some(call_boxed_step::<A, W, T>),
Some(call_boxed_final::<A, W, T>),
Some(call_boxed_value::<A, W, T>),
@@ -571,27 +617,28 @@ unsafe extern "C" fn call_boxed_step<A, D, T>(
D: Aggregate<A, T>,
T: ToSql,
{
- let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) {
- Some(pac) => pac,
- None => {
- ffi::sqlite3_result_error_nomem(ctx);
- return;
- }
+ let pac = if let Some(pac) = aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) {
+ pac
+ } else {
+ ffi::sqlite3_result_error_nomem(ctx);
+ return;
};
let r = catch_unwind(|| {
- let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D;
+ let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>();
assert!(
!boxed_aggr.is_null(),
"Internal error - null aggregate pointer"
);
- if (*pac as *mut A).is_null() {
- *pac = Box::into_raw(Box::new((*boxed_aggr).init()));
- }
let mut ctx = Context {
ctx,
args: slice::from_raw_parts(argv, argc as usize),
};
+
+ if (*pac as *mut A).is_null() {
+ *pac = Box::into_raw(Box::new((*boxed_aggr).init(&mut ctx)?));
+ }
+
(*boxed_aggr).step(&mut ctx, &mut **pac)
});
let r = match r {
@@ -617,16 +664,15 @@ unsafe extern "C" fn call_boxed_inverse<A, W, T>(
W: WindowAggregate<A, T>,
T: ToSql,
{
- let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) {
- Some(pac) => pac,
- None => {
- ffi::sqlite3_result_error_nomem(ctx);
- return;
- }
+ let pac = if let Some(pac) = aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) {
+ pac
+ } else {
+ ffi::sqlite3_result_error_nomem(ctx);
+ return;
};
let r = catch_unwind(|| {
- let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx) as *mut W;
+ let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>();
assert!(
!boxed_aggr.is_null(),
"Internal error - null aggregate pointer"
@@ -671,12 +717,13 @@ where
};
let r = catch_unwind(|| {
- let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D;
+ let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>();
assert!(
!boxed_aggr.is_null(),
"Internal error - null aggregate pointer"
);
- (*boxed_aggr).finalize(a)
+ let mut ctx = Context { ctx, args: &mut [] };
+ (*boxed_aggr).finalize(&mut ctx, a)
});
let t = match r {
Err(_) => {
@@ -715,7 +762,7 @@ where
};
let r = catch_unwind(|| {
- let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx) as *mut W;
+ let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>();
assert!(
!boxed_aggr.is_null(),
"Internal error - null aggregate pointer"
@@ -746,7 +793,7 @@ mod test {
#[cfg(feature = "window")]
use crate::functions::WindowAggregate;
use crate::functions::{Aggregate, Context, FunctionFlags};
- use crate::{Connection, Error, Result, NO_PARAMS};
+ use crate::{Connection, Error, Result};
fn half(ctx: &Context<'_>) -> Result<c_double> {
assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
@@ -755,39 +802,39 @@ mod test {
}
#[test]
- fn test_function_half() {
- let db = Connection::open_in_memory().unwrap();
+ fn test_function_half() -> Result<()> {
+ let db = Connection::open_in_memory()?;
db.create_scalar_function(
"half",
1,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
half,
- )
- .unwrap();
- let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0));
+ )?;
+ let result: Result<f64> = db.query_row("SELECT half(6)", [], |r| r.get(0));
- assert!((3f64 - result.unwrap()).abs() < EPSILON);
+ assert!((3f64 - result?).abs() < EPSILON);
+ Ok(())
}
#[test]
- fn test_remove_function() {
- let db = Connection::open_in_memory().unwrap();
+ fn test_remove_function() -> Result<()> {
+ let db = Connection::open_in_memory()?;
db.create_scalar_function(
"half",
1,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
half,
- )
- .unwrap();
- let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0));
- assert!((3f64 - result.unwrap()).abs() < EPSILON);
+ )?;
+ let result: Result<f64> = db.query_row("SELECT half(6)", [], |r| r.get(0));
+ assert!((3f64 - result?).abs() < EPSILON);
- db.remove_function("half", 1).unwrap();
- let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0));
+ db.remove_function("half", 1)?;
+ let result: Result<f64> = db.query_row("SELECT half(6)", [], |r| r.get(0));
assert!(result.is_err());
+ Ok(())
}
- // This implementation of a regexp scalar function uses SQLite's auxilliary data
+ // This implementation of a regexp scalar function uses SQLite's auxiliary data
// (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular
// expression multiple times within one query.
fn regexp_with_auxilliary(ctx: &Context<'_>) -> Result<bool> {
@@ -811,8 +858,8 @@ mod test {
}
#[test]
- fn test_function_regexp_with_auxilliary() {
- let db = Connection::open_in_memory().unwrap();
+ fn test_function_regexp_with_auxilliary() -> Result<()> {
+ let db = Connection::open_in_memory()?;
db.execute_batch(
"BEGIN;
CREATE TABLE foo (x string);
@@ -820,35 +867,32 @@ mod test {
INSERT INTO foo VALUES ('lXsi');
INSERT INTO foo VALUES ('lisX');
END;",
- )
- .unwrap();
+ )?;
db.create_scalar_function(
"regexp",
2,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
regexp_with_auxilliary,
- )
- .unwrap();
+ )?;
let result: Result<bool> =
- db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", NO_PARAMS, |r| {
- r.get(0)
- });
+ db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", [], |r| r.get(0));
- assert_eq!(true, result.unwrap());
+ assert!(result?);
let result: Result<i64> = db.query_row(
"SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1",
- NO_PARAMS,
+ [],
|r| r.get(0),
);
- assert_eq!(2, result.unwrap());
+ assert_eq!(2, result?);
+ Ok(())
}
#[test]
- fn test_varargs_function() {
- let db = Connection::open_in_memory().unwrap();
+ fn test_varargs_function() -> Result<()> {
+ let db = Connection::open_in_memory()?;
db.create_scalar_function(
"my_concat",
-1,
@@ -863,50 +907,48 @@ mod test {
Ok(ret)
},
- )
- .unwrap();
+ )?;
for &(expected, query) in &[
("", "SELECT my_concat()"),
("onetwo", "SELECT my_concat('one', 'two')"),
("abc", "SELECT my_concat('a', 'b', 'c')"),
] {
- let result: String = db.query_row(query, NO_PARAMS, |r| r.get(0)).unwrap();
+ let result: String = db.query_row(query, [], |r| r.get(0))?;
assert_eq!(expected, result);
}
+ Ok(())
}
#[test]
- fn test_get_aux_type_checking() {
- let db = Connection::open_in_memory().unwrap();
+ fn test_get_aux_type_checking() -> Result<()> {
+ let db = Connection::open_in_memory()?;
db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| {
if !ctx.get::<bool>(1)? {
ctx.set_aux::<i64>(0, 100)?;
} else {
assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType));
- assert_eq!(*ctx.get_aux::<i64>(0).unwrap().unwrap(), 100);
+ assert_eq!(*ctx.get_aux::<i64>(0)?.unwrap(), 100);
}
Ok(true)
- })
- .unwrap();
+ })?;
- let res: bool = db
- .query_row(
- "SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)",
- NO_PARAMS,
- |r| r.get(0),
- )
- .unwrap();
+ let res: bool = db.query_row(
+ "SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)",
+ [],
+ |r| r.get(0),
+ )?;
// Doesn't actually matter, we'll assert in the function if there's a problem.
assert!(res);
+ Ok(())
}
struct Sum;
struct Count;
impl Aggregate<i64, Option<i64>> for Sum {
- fn init(&self) -> i64 {
- 0
+ fn init(&self, _: &mut Context<'_>) -> Result<i64> {
+ Ok(0)
}
fn step(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
@@ -914,14 +956,14 @@ mod test {
Ok(())
}
- fn finalize(&self, sum: Option<i64>) -> Result<Option<i64>> {
+ fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<Option<i64>> {
Ok(sum)
}
}
impl Aggregate<i64, i64> for Count {
- fn init(&self) -> i64 {
- 0
+ fn init(&self, _: &mut Context<'_>) -> Result<i64> {
+ Ok(0)
}
fn step(&self, _ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
@@ -929,58 +971,56 @@ mod test {
Ok(())
}
- fn finalize(&self, sum: Option<i64>) -> Result<i64> {
+ fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<i64> {
Ok(sum.unwrap_or(0))
}
}
#[test]
- fn test_sum() {
- let db = Connection::open_in_memory().unwrap();
+ fn test_sum() -> Result<()> {
+ let db = Connection::open_in_memory()?;
db.create_aggregate_function(
"my_sum",
1,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
Sum,
- )
- .unwrap();
+ )?;
// sum should return NULL when given no columns (contrast with count below)
let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
- let result: Option<i64> = db.query_row(no_result, NO_PARAMS, |r| r.get(0)).unwrap();
+ let result: Option<i64> = db.query_row(no_result, [], |r| r.get(0))?;
assert!(result.is_none());
let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
- let result: i64 = db.query_row(single_sum, NO_PARAMS, |r| r.get(0)).unwrap();
+ let result: i64 = db.query_row(single_sum, [], |r| r.get(0))?;
assert_eq!(4, result);
let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \
2, 1)";
- let result: (i64, i64) = db
- .query_row(dual_sum, NO_PARAMS, |r| Ok((r.get(0)?, r.get(1)?)))
- .unwrap();
+ let result: (i64, i64) = db.query_row(dual_sum, [], |r| Ok((r.get(0)?, r.get(1)?)))?;
assert_eq!((4, 2), result);
+ Ok(())
}
#[test]
- fn test_count() {
- let db = Connection::open_in_memory().unwrap();
+ fn test_count() -> Result<()> {
+ let db = Connection::open_in_memory()?;
db.create_aggregate_function(
"my_count",
-1,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
Count,
- )
- .unwrap();
+ )?;
// count should return 0 when given no columns (contrast with sum above)
let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
- let result: i64 = db.query_row(no_result, NO_PARAMS, |r| r.get(0)).unwrap();
+ let result: i64 = db.query_row(no_result, [], |r| r.get(0))?;
assert_eq!(result, 0);
let single_sum = "SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
- let result: i64 = db.query_row(single_sum, NO_PARAMS, |r| r.get(0)).unwrap();
+ let result: i64 = db.query_row(single_sum, [], |r| r.get(0))?;
assert_eq!(2, result);
+ Ok(())
}
#[cfg(feature = "window")]
@@ -997,17 +1037,16 @@ mod test {
#[test]
#[cfg(feature = "window")]
- fn test_window() {
+ fn test_window() -> Result<()> {
use fallible_iterator::FallibleIterator;
- let db = Connection::open_in_memory().unwrap();
+ let db = Connection::open_in_memory()?;
db.create_window_function(
"sumint",
1,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
Sum,
- )
- .unwrap();
+ )?;
db.execute_batch(
"CREATE TABLE t3(x, y);
INSERT INTO t3 VALUES('a', 4),
@@ -1015,24 +1054,19 @@ mod test {
('c', 3),
('d', 8),
('e', 1);",
- )
- .unwrap();
+ )?;
- let mut stmt = db
- .prepare(
- "SELECT x, sumint(y) OVER (
+ let mut stmt = db.prepare(
+ "SELECT x, sumint(y) OVER (
ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
) AS sum_y
FROM t3 ORDER BY x;",
- )
- .unwrap();
+ )?;
let results: Vec<(String, i64)> = stmt
- .query(NO_PARAMS)
- .unwrap()
+ .query([])?
.map(|row| Ok((row.get("x")?, row.get("sum_y")?)))
- .collect()
- .unwrap();
+ .collect()?;
let expected = vec![
("a".to_owned(), 9),
("b".to_owned(), 12),
@@ -1041,5 +1075,6 @@ mod test {
("e".to_owned(), 9),
];
assert_eq!(expected, results);
+ Ok(())
}
}