diff options
Diffstat (limited to 'src/functions.rs')
-rw-r--r-- | src/functions.rs | 379 |
1 files changed, 207 insertions, 172 deletions
diff --git a/src/functions.rs b/src/functions.rs index 3531391..e613182 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -1,17 +1,17 @@ -//! `feature = "functions"` Create or redefine SQL functions. +//! Create or redefine SQL functions. //! //! # Example //! //! Adding a `regexp` function to a connection in which compiled regular //! expressions are cached in a `HashMap`. For an alternative implementation -//! that uses SQLite's [Function Auxilliary Data](https://www.sqlite.org/c3ref/get_auxdata.html) interface +//! that uses SQLite's [Function Auxiliary Data](https://www.sqlite.org/c3ref/get_auxdata.html) interface //! to avoid recompiling regular expressions, see the unit tests for this //! module. //! //! ```rust //! use regex::Regex; //! use rusqlite::functions::FunctionFlags; -//! use rusqlite::{Connection, Error, Result, NO_PARAMS}; +//! use rusqlite::{Connection, Error, Result}; //! use std::sync::Arc; //! type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>; //! @@ -22,10 +22,9 @@ //! FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, //! move |ctx| { //! assert_eq!(ctx.len(), 2, "called with unexpected number of arguments"); -//! let regexp: Arc<Regex> = ctx -//! .get_or_create_aux(0, |vr| -> Result<_, BoxError> { -//! Ok(Regex::new(vr.as_str()?)?) -//! })?; +//! let regexp: Arc<Regex> = ctx.get_or_create_aux(0, |vr| -> Result<_, BoxError> { +//! Ok(Regex::new(vr.as_str()?)?) +//! })?; //! let is_match = { //! let text = ctx //! .get_raw(1) @@ -44,17 +43,18 @@ //! let db = Connection::open_in_memory()?; //! add_regexp_function(&db)?; //! -//! let is_match: bool = db.query_row( -//! "SELECT regexp('[aeiou]*', 'aaaaeeeiii')", -//! NO_PARAMS, -//! |row| row.get(0), -//! )?; +//! let is_match: bool = +//! db.query_row("SELECT regexp('[aeiou]*', 'aaaaeeeiii')", [], |row| { +//! row.get(0) +//! })?; //! //! assert!(is_match); //! Ok(()) //! } //! ``` use std::any::Any; +use std::marker::PhantomData; +use std::ops::Deref; use std::os::raw::{c_int, c_void}; use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe}; use std::ptr; @@ -84,27 +84,24 @@ unsafe fn report_error(ctx: *mut sqlite3_context, err: &Error) { ffi::SQLITE_CONSTRAINT } - match *err { - Error::SqliteFailure(ref err, ref s) => { - ffi::sqlite3_result_error_code(ctx, err.extended_code); - if let Some(Ok(cstr)) = s.as_ref().map(|s| str_to_cstring(s)) { - ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); - } + if let Error::SqliteFailure(ref err, ref s) = *err { + ffi::sqlite3_result_error_code(ctx, err.extended_code); + if let Some(Ok(cstr)) = s.as_ref().map(|s| str_to_cstring(s)) { + ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); } - _ => { - ffi::sqlite3_result_error_code(ctx, constraint_error_code()); - if let Ok(cstr) = str_to_cstring(&err.to_string()) { - ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); - } + } else { + ffi::sqlite3_result_error_code(ctx, constraint_error_code()); + if let Ok(cstr) = str_to_cstring(&err.to_string()) { + ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); } } } unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) { - drop(Box::from_raw(p as *mut T)); + drop(Box::from_raw(p.cast::<T>())); } -/// `feature = "functions"` Context is a wrapper for the SQLite function +/// Context is a wrapper for the SQLite function /// evaluation context. pub struct Context<'a> { ctx: *mut sqlite3_context, @@ -113,11 +110,15 @@ pub struct Context<'a> { impl Context<'_> { /// Returns the number of arguments to the function. + #[inline] + #[must_use] pub fn len(&self) -> usize { self.args.len() } /// Returns `true` when there is no argument. + #[inline] + #[must_use] pub fn is_empty(&self) -> bool { self.args.is_empty() } @@ -126,7 +127,8 @@ impl Context<'_> { /// /// # Failure /// - /// Will panic if `idx` is greater than or equal to `self.len()`. + /// Will panic if `idx` is greater than or equal to + /// [`self.len()`](Context::len). /// /// Will return Err if the underlying SQLite type cannot be converted to a /// `T`. @@ -141,12 +143,7 @@ impl Context<'_> { FromSqlError::Other(err) => { Error::FromSqlConversionFailure(idx, value.data_type(), err) } - #[cfg(feature = "i128_blob")] - FromSqlError::InvalidI128Size(_) => { - Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err)) - } - #[cfg(feature = "uuid")] - FromSqlError::InvalidUuidSize(_) => { + FromSqlError::InvalidBlobSize { .. } => { Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err)) } }) @@ -156,17 +153,21 @@ impl Context<'_> { /// /// # Failure /// - /// Will panic if `idx` is greater than or equal to `self.len()`. + /// Will panic if `idx` is greater than or equal to + /// [`self.len()`](Context::len). + #[inline] + #[must_use] pub fn get_raw(&self, idx: usize) -> ValueRef<'_> { let arg = self.args[idx]; unsafe { ValueRef::from_value(arg) } } - /// Fetch or insert the the auxilliary data associated with a particular + /// Fetch or insert the auxiliary data associated with a particular /// parameter. This is intended to be an easier-to-use way of fetching it - /// compared to calling `get_aux` and `set_aux` separately. + /// compared to calling [`get_aux`](Context::get_aux) and + /// [`set_aux`](Context::set_aux) separately. /// - /// See https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of + /// See `https://www.sqlite.org/c3ref/get_auxdata.html` for a discussion of /// this feature, or the unit tests of this module for an example. pub fn get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>> where @@ -185,8 +186,8 @@ impl Context<'_> { } } - /// Sets the auxilliary data associated with a particular parameter. See - /// https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of + /// Sets the auxiliary data associated with a particular parameter. See + /// `https://www.sqlite.org/c3ref/get_auxdata.html` for a discussion of /// this feature, or the unit tests of this module for an example. pub fn set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>> { let orig: Arc<T> = Arc::new(value); @@ -197,17 +198,17 @@ impl Context<'_> { ffi::sqlite3_set_auxdata( self.ctx, arg, - raw as *mut _, + raw.cast(), Some(free_boxed_value::<AuxInner>), - ) + ); }; Ok(orig) } - /// Gets the auxilliary data that was associated with a given parameter via - /// `set_aux`. Returns `Ok(None)` if no data has been associated, and - /// Ok(Some(v)) if it has. Returns an error if the requested type does not - /// match. + /// Gets the auxiliary data that was associated with a given parameter via + /// [`set_aux`](Context::set_aux). Returns `Ok(None)` if no data has been + /// associated, and Ok(Some(v)) if it has. Returns an error if the + /// requested type does not match. pub fn get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>> { let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxInner }; if p.is_null() { @@ -219,11 +220,42 @@ impl Context<'_> { .map_err(|_| Error::GetAuxWrongType) } } + + /// Get the db connection handle via [sqlite3_context_db_handle](https://www.sqlite.org/c3ref/context_db_handle.html) + /// + /// # Safety + /// + /// This function is marked unsafe because there is a potential for other + /// references to the connection to be sent across threads, [see this comment](https://github.com/rusqlite/rusqlite/issues/643#issuecomment-640181213). + pub unsafe fn get_connection(&self) -> Result<ConnectionRef<'_>> { + let handle = ffi::sqlite3_context_db_handle(self.ctx); + Ok(ConnectionRef { + conn: Connection::from_handle(handle)?, + phantom: PhantomData, + }) + } +} + +/// A reference to a connection handle with a lifetime bound to something. +pub struct ConnectionRef<'ctx> { + // comes from Connection::from_handle(sqlite3_context_db_handle(...)) + // and is non-owning + conn: Connection, + phantom: PhantomData<&'ctx Context<'ctx>>, +} + +impl Deref for ConnectionRef<'_> { + type Target = Connection; + + #[inline] + fn deref(&self) -> &Connection { + &self.conn + } } type AuxInner = Arc<dyn Any + Send + Sync + 'static>; -/// `feature = "functions"` Aggregate is the callback interface for user-defined +/// Aggregate is the callback interface for user-defined /// aggregate function. /// /// `A` is the type of the aggregation context and `T` is the type of the final @@ -234,25 +266,31 @@ where T: ToSql, { /// Initializes the aggregation context. Will be called prior to the first - /// call to `step()` to set up the context for an invocation of the - /// function. (Note: `init()` will not be called if there are no rows.) - fn init(&self) -> A; + /// call to [`step()`](Aggregate::step) to set up the context for an + /// invocation of the function. (Note: `init()` will not be called if + /// there are no rows.) + fn init(&self, _: &mut Context<'_>) -> Result<A>; /// "step" function called once for each row in an aggregate group. May be /// called 0 times if there are no rows. fn step(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>; /// Computes and returns the final result. Will be called exactly once for - /// each invocation of the function. If `step()` was called at least - /// once, will be given `Some(A)` (the same `A` as was created by - /// `init` and given to `step`); if `step()` was not called (because - /// the function is running against 0 rows), will be given `None`. - fn finalize(&self, _: Option<A>) -> Result<T>; + /// each invocation of the function. If [`step()`](Aggregate::step) was + /// called at least once, will be given `Some(A)` (the same `A` as was + /// created by [`init`](Aggregate::init) and given to + /// [`step`](Aggregate::step)); if [`step()`](Aggregate::step) was not + /// called (because the function is running against 0 rows), will be + /// given `None`. + /// + /// The passed context will have no arguments. + fn finalize(&self, _: &mut Context<'_>, _: Option<A>) -> Result<T>; } -/// `feature = "window"` WindowAggregate is the callback interface for +/// `WindowAggregate` is the callback interface for /// user-defined aggregate window function. #[cfg(feature = "window")] +#[cfg_attr(docsrs, doc(cfg(feature = "window")))] pub trait WindowAggregate<A, T>: Aggregate<A, T> where A: RefUnwindSafe + UnwindSafe, @@ -292,13 +330,14 @@ bitflags::bitflags! { } impl Default for FunctionFlags { + #[inline] fn default() -> FunctionFlags { FunctionFlags::SQLITE_UTF8 } } impl Connection { - /// `feature = "functions"` Attach a user-defined scalar function to + /// Attach a user-defined scalar function to /// this database connection. /// /// `fn_name` is the name the function will be accessible from SQL. @@ -307,12 +346,13 @@ impl Connection { /// given the same input, `deterministic` should be `true`. /// /// The function will remain available until the connection is closed or - /// until it is explicitly removed via `remove_function`. + /// until it is explicitly removed via + /// [`remove_function`](Connection::remove_function). /// /// # Example /// /// ```rust - /// # use rusqlite::{Connection, Result, NO_PARAMS}; + /// # use rusqlite::{Connection, Result}; /// # use rusqlite::functions::FunctionFlags; /// fn scalar_function_example(db: Connection) -> Result<()> { /// db.create_scalar_function( @@ -325,7 +365,7 @@ impl Connection { /// }, /// )?; /// - /// let six_halved: f64 = db.query_row("SELECT halve(6)", NO_PARAMS, |r| r.get(0))?; + /// let six_halved: f64 = db.query_row("SELECT halve(6)", [], |r| r.get(0))?; /// assert_eq!(six_halved, 3f64); /// Ok(()) /// } @@ -334,6 +374,7 @@ impl Connection { /// # Failure /// /// Will return Err if the function could not be attached to the connection. + #[inline] pub fn create_scalar_function<F, T>( &self, fn_name: &str, @@ -350,12 +391,13 @@ impl Connection { .create_scalar_function(fn_name, n_arg, flags, x_func) } - /// `feature = "functions"` Attach a user-defined aggregate function to this + /// Attach a user-defined aggregate function to this /// database connection. /// /// # Failure /// /// Will return Err if the function could not be attached to the connection. + #[inline] pub fn create_aggregate_function<A, D, T>( &self, fn_name: &str, @@ -365,7 +407,7 @@ impl Connection { ) -> Result<()> where A: RefUnwindSafe + UnwindSafe, - D: Aggregate<A, T>, + D: Aggregate<A, T> + 'static, T: ToSql, { self.db @@ -373,12 +415,14 @@ impl Connection { .create_aggregate_function(fn_name, n_arg, flags, aggr) } - /// `feature = "window"` Attach a user-defined aggregate window function to + /// Attach a user-defined aggregate window function to /// this database connection. /// - /// See https://sqlite.org/windowfunctions.html#udfwinfunc for more + /// See `https://sqlite.org/windowfunctions.html#udfwinfunc` for more /// information. #[cfg(feature = "window")] + #[cfg_attr(docsrs, doc(cfg(feature = "window")))] + #[inline] pub fn create_window_function<A, W, T>( &self, fn_name: &str, @@ -388,7 +432,7 @@ impl Connection { ) -> Result<()> where A: RefUnwindSafe + UnwindSafe, - W: WindowAggregate<A, T>, + W: WindowAggregate<A, T> + 'static, T: ToSql, { self.db @@ -396,15 +440,17 @@ impl Connection { .create_window_function(fn_name, n_arg, flags, aggr) } - /// `feature = "functions"` Removes a user-defined function from this + /// Removes a user-defined function from this /// database connection. /// /// `fn_name` and `n_arg` should match the name and number of arguments - /// given to `create_scalar_function` or `create_aggregate_function`. + /// given to [`create_scalar_function`](Connection::create_scalar_function) + /// or [`create_aggregate_function`](Connection::create_aggregate_function). /// /// # Failure /// /// Will return Err if the function could not be removed. + #[inline] pub fn remove_function(&self, fn_name: &str, n_arg: c_int) -> Result<()> { self.db.borrow_mut().remove_function(fn_name, n_arg) } @@ -431,7 +477,7 @@ impl InnerConnection { T: ToSql, { let r = catch_unwind(|| { - let boxed_f: *mut F = ffi::sqlite3_user_data(ctx) as *mut F; + let boxed_f: *mut F = ffi::sqlite3_user_data(ctx).cast::<F>(); assert!(!boxed_f.is_null(), "Internal error - null function pointer"); let ctx = Context { ctx, @@ -463,7 +509,7 @@ impl InnerConnection { c_name.as_ptr(), n_arg, flags.bits(), - boxed_f as *mut c_void, + boxed_f.cast::<c_void>(), Some(call_boxed_closure::<F, T>), None, None, @@ -482,7 +528,7 @@ impl InnerConnection { ) -> Result<()> where A: RefUnwindSafe + UnwindSafe, - D: Aggregate<A, T>, + D: Aggregate<A, T> + 'static, T: ToSql, { let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr)); @@ -493,7 +539,7 @@ impl InnerConnection { c_name.as_ptr(), n_arg, flags.bits(), - boxed_aggr as *mut c_void, + boxed_aggr.cast::<c_void>(), None, Some(call_boxed_step::<A, D, T>), Some(call_boxed_final::<A, D, T>), @@ -513,7 +559,7 @@ impl InnerConnection { ) -> Result<()> where A: RefUnwindSafe + UnwindSafe, - W: WindowAggregate<A, T>, + W: WindowAggregate<A, T> + 'static, T: ToSql, { let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr)); @@ -524,7 +570,7 @@ impl InnerConnection { c_name.as_ptr(), n_arg, flags.bits(), - boxed_aggr as *mut c_void, + boxed_aggr.cast::<c_void>(), Some(call_boxed_step::<A, W, T>), Some(call_boxed_final::<A, W, T>), Some(call_boxed_value::<A, W, T>), @@ -571,27 +617,28 @@ unsafe extern "C" fn call_boxed_step<A, D, T>( D: Aggregate<A, T>, T: ToSql, { - let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) { - Some(pac) => pac, - None => { - ffi::sqlite3_result_error_nomem(ctx); - return; - } + let pac = if let Some(pac) = aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) { + pac + } else { + ffi::sqlite3_result_error_nomem(ctx); + return; }; let r = catch_unwind(|| { - let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D; + let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>(); assert!( !boxed_aggr.is_null(), "Internal error - null aggregate pointer" ); - if (*pac as *mut A).is_null() { - *pac = Box::into_raw(Box::new((*boxed_aggr).init())); - } let mut ctx = Context { ctx, args: slice::from_raw_parts(argv, argc as usize), }; + + if (*pac as *mut A).is_null() { + *pac = Box::into_raw(Box::new((*boxed_aggr).init(&mut ctx)?)); + } + (*boxed_aggr).step(&mut ctx, &mut **pac) }); let r = match r { @@ -617,16 +664,15 @@ unsafe extern "C" fn call_boxed_inverse<A, W, T>( W: WindowAggregate<A, T>, T: ToSql, { - let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) { - Some(pac) => pac, - None => { - ffi::sqlite3_result_error_nomem(ctx); - return; - } + let pac = if let Some(pac) = aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) { + pac + } else { + ffi::sqlite3_result_error_nomem(ctx); + return; }; let r = catch_unwind(|| { - let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx) as *mut W; + let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>(); assert!( !boxed_aggr.is_null(), "Internal error - null aggregate pointer" @@ -671,12 +717,13 @@ where }; let r = catch_unwind(|| { - let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D; + let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>(); assert!( !boxed_aggr.is_null(), "Internal error - null aggregate pointer" ); - (*boxed_aggr).finalize(a) + let mut ctx = Context { ctx, args: &mut [] }; + (*boxed_aggr).finalize(&mut ctx, a) }); let t = match r { Err(_) => { @@ -715,7 +762,7 @@ where }; let r = catch_unwind(|| { - let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx) as *mut W; + let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>(); assert!( !boxed_aggr.is_null(), "Internal error - null aggregate pointer" @@ -746,7 +793,7 @@ mod test { #[cfg(feature = "window")] use crate::functions::WindowAggregate; use crate::functions::{Aggregate, Context, FunctionFlags}; - use crate::{Connection, Error, Result, NO_PARAMS}; + use crate::{Connection, Error, Result}; fn half(ctx: &Context<'_>) -> Result<c_double> { assert_eq!(ctx.len(), 1, "called with unexpected number of arguments"); @@ -755,39 +802,39 @@ mod test { } #[test] - fn test_function_half() { - let db = Connection::open_in_memory().unwrap(); + fn test_function_half() -> Result<()> { + let db = Connection::open_in_memory()?; db.create_scalar_function( "half", 1, FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, half, - ) - .unwrap(); - let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0)); + )?; + let result: Result<f64> = db.query_row("SELECT half(6)", [], |r| r.get(0)); - assert!((3f64 - result.unwrap()).abs() < EPSILON); + assert!((3f64 - result?).abs() < EPSILON); + Ok(()) } #[test] - fn test_remove_function() { - let db = Connection::open_in_memory().unwrap(); + fn test_remove_function() -> Result<()> { + let db = Connection::open_in_memory()?; db.create_scalar_function( "half", 1, FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, half, - ) - .unwrap(); - let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0)); - assert!((3f64 - result.unwrap()).abs() < EPSILON); + )?; + let result: Result<f64> = db.query_row("SELECT half(6)", [], |r| r.get(0)); + assert!((3f64 - result?).abs() < EPSILON); - db.remove_function("half", 1).unwrap(); - let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0)); + db.remove_function("half", 1)?; + let result: Result<f64> = db.query_row("SELECT half(6)", [], |r| r.get(0)); assert!(result.is_err()); + Ok(()) } - // This implementation of a regexp scalar function uses SQLite's auxilliary data + // This implementation of a regexp scalar function uses SQLite's auxiliary data // (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular // expression multiple times within one query. fn regexp_with_auxilliary(ctx: &Context<'_>) -> Result<bool> { @@ -811,8 +858,8 @@ mod test { } #[test] - fn test_function_regexp_with_auxilliary() { - let db = Connection::open_in_memory().unwrap(); + fn test_function_regexp_with_auxilliary() -> Result<()> { + let db = Connection::open_in_memory()?; db.execute_batch( "BEGIN; CREATE TABLE foo (x string); @@ -820,35 +867,32 @@ mod test { INSERT INTO foo VALUES ('lXsi'); INSERT INTO foo VALUES ('lisX'); END;", - ) - .unwrap(); + )?; db.create_scalar_function( "regexp", 2, FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, regexp_with_auxilliary, - ) - .unwrap(); + )?; let result: Result<bool> = - db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", NO_PARAMS, |r| { - r.get(0) - }); + db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", [], |r| r.get(0)); - assert_eq!(true, result.unwrap()); + assert!(result?); let result: Result<i64> = db.query_row( "SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1", - NO_PARAMS, + [], |r| r.get(0), ); - assert_eq!(2, result.unwrap()); + assert_eq!(2, result?); + Ok(()) } #[test] - fn test_varargs_function() { - let db = Connection::open_in_memory().unwrap(); + fn test_varargs_function() -> Result<()> { + let db = Connection::open_in_memory()?; db.create_scalar_function( "my_concat", -1, @@ -863,50 +907,48 @@ mod test { Ok(ret) }, - ) - .unwrap(); + )?; for &(expected, query) in &[ ("", "SELECT my_concat()"), ("onetwo", "SELECT my_concat('one', 'two')"), ("abc", "SELECT my_concat('a', 'b', 'c')"), ] { - let result: String = db.query_row(query, NO_PARAMS, |r| r.get(0)).unwrap(); + let result: String = db.query_row(query, [], |r| r.get(0))?; assert_eq!(expected, result); } + Ok(()) } #[test] - fn test_get_aux_type_checking() { - let db = Connection::open_in_memory().unwrap(); + fn test_get_aux_type_checking() -> Result<()> { + let db = Connection::open_in_memory()?; db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| { if !ctx.get::<bool>(1)? { ctx.set_aux::<i64>(0, 100)?; } else { assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType)); - assert_eq!(*ctx.get_aux::<i64>(0).unwrap().unwrap(), 100); + assert_eq!(*ctx.get_aux::<i64>(0)?.unwrap(), 100); } Ok(true) - }) - .unwrap(); + })?; - let res: bool = db - .query_row( - "SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)", - NO_PARAMS, - |r| r.get(0), - ) - .unwrap(); + let res: bool = db.query_row( + "SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)", + [], + |r| r.get(0), + )?; // Doesn't actually matter, we'll assert in the function if there's a problem. assert!(res); + Ok(()) } struct Sum; struct Count; impl Aggregate<i64, Option<i64>> for Sum { - fn init(&self) -> i64 { - 0 + fn init(&self, _: &mut Context<'_>) -> Result<i64> { + Ok(0) } fn step(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> { @@ -914,14 +956,14 @@ mod test { Ok(()) } - fn finalize(&self, sum: Option<i64>) -> Result<Option<i64>> { + fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<Option<i64>> { Ok(sum) } } impl Aggregate<i64, i64> for Count { - fn init(&self) -> i64 { - 0 + fn init(&self, _: &mut Context<'_>) -> Result<i64> { + Ok(0) } fn step(&self, _ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> { @@ -929,58 +971,56 @@ mod test { Ok(()) } - fn finalize(&self, sum: Option<i64>) -> Result<i64> { + fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<i64> { Ok(sum.unwrap_or(0)) } } #[test] - fn test_sum() { - let db = Connection::open_in_memory().unwrap(); + fn test_sum() -> Result<()> { + let db = Connection::open_in_memory()?; db.create_aggregate_function( "my_sum", 1, FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, Sum, - ) - .unwrap(); + )?; // sum should return NULL when given no columns (contrast with count below) let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)"; - let result: Option<i64> = db.query_row(no_result, NO_PARAMS, |r| r.get(0)).unwrap(); + let result: Option<i64> = db.query_row(no_result, [], |r| r.get(0))?; assert!(result.is_none()); let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)"; - let result: i64 = db.query_row(single_sum, NO_PARAMS, |r| r.get(0)).unwrap(); + let result: i64 = db.query_row(single_sum, [], |r| r.get(0))?; assert_eq!(4, result); let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \ 2, 1)"; - let result: (i64, i64) = db - .query_row(dual_sum, NO_PARAMS, |r| Ok((r.get(0)?, r.get(1)?))) - .unwrap(); + let result: (i64, i64) = db.query_row(dual_sum, [], |r| Ok((r.get(0)?, r.get(1)?)))?; assert_eq!((4, 2), result); + Ok(()) } #[test] - fn test_count() { - let db = Connection::open_in_memory().unwrap(); + fn test_count() -> Result<()> { + let db = Connection::open_in_memory()?; db.create_aggregate_function( "my_count", -1, FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, Count, - ) - .unwrap(); + )?; // count should return 0 when given no columns (contrast with sum above) let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)"; - let result: i64 = db.query_row(no_result, NO_PARAMS, |r| r.get(0)).unwrap(); + let result: i64 = db.query_row(no_result, [], |r| r.get(0))?; assert_eq!(result, 0); let single_sum = "SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)"; - let result: i64 = db.query_row(single_sum, NO_PARAMS, |r| r.get(0)).unwrap(); + let result: i64 = db.query_row(single_sum, [], |r| r.get(0))?; assert_eq!(2, result); + Ok(()) } #[cfg(feature = "window")] @@ -997,17 +1037,16 @@ mod test { #[test] #[cfg(feature = "window")] - fn test_window() { + fn test_window() -> Result<()> { use fallible_iterator::FallibleIterator; - let db = Connection::open_in_memory().unwrap(); + let db = Connection::open_in_memory()?; db.create_window_function( "sumint", 1, FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, Sum, - ) - .unwrap(); + )?; db.execute_batch( "CREATE TABLE t3(x, y); INSERT INTO t3 VALUES('a', 4), @@ -1015,24 +1054,19 @@ mod test { ('c', 3), ('d', 8), ('e', 1);", - ) - .unwrap(); + )?; - let mut stmt = db - .prepare( - "SELECT x, sumint(y) OVER ( + let mut stmt = db.prepare( + "SELECT x, sumint(y) OVER ( ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING ) AS sum_y FROM t3 ORDER BY x;", - ) - .unwrap(); + )?; let results: Vec<(String, i64)> = stmt - .query(NO_PARAMS) - .unwrap() + .query([])? .map(|row| Ok((row.get("x")?, row.get("sum_y")?))) - .collect() - .unwrap(); + .collect()?; let expected = vec![ ("a".to_owned(), 9), ("b".to_owned(), 12), @@ -1041,5 +1075,6 @@ mod test { ("e".to_owned(), 9), ]; assert_eq!(expected, results); + Ok(()) } } |