// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_SQLiTE_H_
#define DLIB_SQLiTE_H_
#include "sqlite_abstract.h"
#include <iostream>
#include <limits>
#include <memory>
#include <vector>
#include "../algs.h"
#include <sqlite3.h>
#include "../serialize.h"
// --------------------------------------------------------------------------------------------
namespace dlib
{
// --------------------------------------------------------------------------------------------
struct sqlite_error : public error
{
sqlite_error(const std::string& message): error(message) {}
};
// --------------------------------------------------------------------------------------------
namespace impl
{
struct db_deleter
{
void operator()(
sqlite3* db
)const
{
sqlite3_close(db);
}
};
}
// --------------------------------------------------------------------------------------------
class database : noncopyable
{
public:
database(
)
{
}
database (
const std::string& file
)
{
open(file);
}
bool is_open (
) const
{
return db.get() != 0;
}
void open (
const std::string& file
)
{
filename = file;
sqlite3* ptr = 0;
int status = sqlite3_open(file.c_str(), &ptr);
db.reset(ptr, impl::db_deleter());
if (status != SQLITE_OK)
{
throw sqlite_error(sqlite3_errmsg(db.get()));
}
}
const std::string& get_database_filename (
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(is_open() == true,
"\t std::string database::get_database_filename()"
<< "\n\t The database must be opened before calling this routine."
<< "\n\t this: " << this
);
return filename;
}
inline void exec (
const std::string& sql_statement
);
int64 last_insert_rowid (
) const
{
return sqlite3_last_insert_rowid(db.get());
}
private:
friend class statement;
std::string filename;
std::shared_ptr<sqlite3> db;
};
// --------------------------------------------------------------------------------------------
class statement : noncopyable
{
public:
statement (
database& db_,
const std::string sql_statement
) :
needs_reset(false),
step_status(SQLITE_DONE),
at_first_step(true),
db(db_.db),
stmt(0),
sql_string(sql_statement)
{
// make sure requires clause is not broken
DLIB_ASSERT(db_.is_open() == true,
"\t statement::statement()"
<< "\n\t The database must be opened before calling this routine."
<< "\n\t this: " << this
);
int status = sqlite3_prepare_v2(db.get(),
sql_string.c_str(),
sql_string.size()+1,
&stmt,
NULL);
if (status != SQLITE_OK)
{
sqlite3_finalize(stmt);
throw sqlite_error(sqlite3_errmsg(db.get()));
}
if (stmt == 0)
{
throw sqlite_error("Invalid SQL statement");
}
}
~statement(
)
{
sqlite3_finalize(stmt);
}
void exec(
)
{
reset();
step_status = sqlite3_step(stmt);
needs_reset = true;
if (step_status != SQLITE_DONE && step_status != SQLITE_ROW)
{
if (step_status == SQLITE_ERROR)
throw sqlite_error(sqlite3_errmsg(db.get()));
else if (step_status == SQLITE_BUSY)
throw sqlite_error("statement::exec() failed. SQLITE_BUSY returned");
else
throw sqlite_error("statement::exec() failed.");
}
}
bool move_next (
)
{
if (step_status == SQLITE_ROW)
{
if (at_first_step)
{
at_first_step = false;
return true;
}
else
{
step_status = sqlite3_step(stmt);
if (step_status == SQLITE_DONE)
{
return false;
}
else if (step_status == SQLITE_ROW)
{
return true;
}
else
{
throw sqlite_error(sqlite3_errmsg(db.get()));
}
}
}
else
{
return false;
}
}
unsigned long get_num_columns(
) const
{
if( (at_first_step==false) && (step_status==SQLITE_ROW))
{
return sqlite3_column_count(stmt);
}
else
{
return 0;
}
}
const std::string& get_sql_string (
) const
{
return sql_string;
}
template <typename T>
typename enable_if_c<std::numeric_limits<T>::is_integer>::type get_column (
unsigned long idx,
T& item
) const
{
// unsigned ints won't fit into int all the time so put those into 64bit ints.
if (sizeof(T) < sizeof(int) || (sizeof(T)==sizeof(int) && is_signed_type<T>::value))
item = get_column_as_int(idx);
else
item = get_column_as_int64(idx);
}
void get_column(unsigned long idx, std::string& item) const { item = get_column_as_text(idx); }
void get_column(unsigned long idx, float& item ) const { item = get_column_as_double(idx); }
void get_column(unsigned long idx, double& item ) const { item = get_column_as_double(idx); }
void get_column(unsigned long idx, long double& item) const { item = get_column_as_double(idx); }
template <typename T>
typename disable_if_c<std::numeric_limits<T>::is_integer>::type get_column (
unsigned long idx,
T& item
) const
{
get_column_as_object(idx, item);
}
const std::vector<char> get_column_as_blob (
unsigned long idx
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(idx < get_num_columns(),
"\t std::vector<char> statement::get_column_as_blob()"
<< "\n\t Invalid column index."
<< "\n\t idx: " << idx
<< "\n\t this: " << this
);
const char* data = static_cast<const char*>(sqlite3_column_blob(stmt, idx));
const int size = sqlite3_column_bytes(stmt, idx);
return std::vector<char>(data, data+size);
}
template <typename T>
void get_column_as_object (
unsigned long idx,
T& item
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(idx < get_num_columns(),
"\t void statement::get_column_as_object()"
<< "\n\t Invalid column index."
<< "\n\t idx: " << idx
<< "\n\t this: " << this
);
const char* data = static_cast<const char*>(sqlite3_column_blob(stmt, idx));
const int size = sqlite3_column_bytes(stmt, idx);
std::istringstream sin(std::string(data,size));
deserialize(item, sin);
}
const std::string get_column_as_text (
unsigned long idx
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(idx < get_num_columns(),
"\t std::string statement::get_column_as_text()"
<< "\n\t Invalid column index."
<< "\n\t idx: " << idx
<< "\n\t this: " << this
);
const char* data = reinterpret_cast<const char*>(sqlite3_column_text(stmt, idx));
if (data != 0)
return std::string(data);
else
return std::string();
}
double get_column_as_double (
unsigned long idx
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(idx < get_num_columns(),
"\t double statement::get_column_as_double()"
<< "\n\t Invalid column index."
<< "\n\t idx: " << idx
<< "\n\t this: " << this
);
return sqlite3_column_double(stmt, idx);
}
int get_column_as_int (
unsigned long idx
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(idx < get_num_columns(),
"\t int statement::get_column_as_int()"
<< "\n\t Invalid column index."
<< "\n\t idx: " << idx
<< "\n\t this: " << this
);
return sqlite3_column_int(stmt, idx);
}
int64 get_column_as_int64 (
unsigned long idx
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(idx < get_num_columns(),
"\t int64 statement::get_column_as_int64()"
<< "\n\t Invalid column index."
<< "\n\t idx: " << idx
<< "\n\t this: " << this
);
return sqlite3_column_int64(stmt, idx);
}
const std::string get_column_name (
unsigned long idx
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(idx < get_num_columns(),
"\t std::string statement::get_column_name()"
<< "\n\t Invalid column index."
<< "\n\t idx: " << idx
<< "\n\t this: " << this
);
return std::string(sqlite3_column_name(stmt,idx));
}
unsigned long get_max_parameter_id (
) const
{
return sqlite3_limit(db.get(), SQLITE_LIMIT_VARIABLE_NUMBER, -1);
}
unsigned long get_parameter_id (
const std::string& name
) const
{
return sqlite3_bind_parameter_index(stmt, name.c_str());
}
template <typename T>
typename enable_if_c<std::numeric_limits<T>::is_integer>::type bind (
unsigned long idx,
const T& item
)
{
// unsigned ints won't fit into int all the time so put those into 64bit ints.
if (sizeof(T) < sizeof(int) || (sizeof(T)==sizeof(int) && is_signed_type<T>::value))
bind_int(idx, item);
else
bind_int64(idx, item);
}
void bind(unsigned long idx, const std::string& item) { bind_text(idx, item); }
void bind(unsigned long idx, const float& item ) { bind_double(idx, item); }
void bind(unsigned long idx, const double& item ) { bind_double(idx, item); }
void bind(unsigned long idx, const long double& item) { bind_double(idx, item); }
template <typename T>
typename disable_if_c<std::numeric_limits<T>::is_integer>::type bind (
unsigned long idx,
const T& item
)
{
bind_object(idx, item);
}
void bind_blob (
unsigned long parameter_id,
const std::vector<char>& item
)
{
// make sure requires clause is not broken
DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(),
"\t void statement::bind_blob()"
<< "\n\t Invalid parameter id."
<< "\n\t parameter_id: " << parameter_id
<< "\n\t get_max_parameter_id(): " << get_max_parameter_id()
<< "\n\t this: " << this
);
reset();
int status = sqlite3_bind_blob(stmt, parameter_id, &item[0], item.size(), SQLITE_TRANSIENT);
if (status != SQLITE_OK)
{
throw sqlite_error(sqlite3_errmsg(db.get()));
}
}
template <typename T>
void bind_object (
unsigned long parameter_id,
const T& item
)
{
// make sure requires clause is not broken
DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(),
"\t void statement::bind_object()"
<< "\n\t Invalid parameter id."
<< "\n\t parameter_id: " << parameter_id
<< "\n\t get_max_parameter_id(): " << get_max_parameter_id()
<< "\n\t this: " << this
);
reset();
std::ostringstream sout;
serialize(item, sout);
const std::string& str = sout.str();
int status = sqlite3_bind_blob(stmt, parameter_id, str.data(), str.size(), SQLITE_TRANSIENT);
if (status != SQLITE_OK)
{
throw sqlite_error(sqlite3_errmsg(db.get()));
}
}
void bind_double (
unsigned long parameter_id,
const double& item
)
{
// make sure requires clause is not broken
DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(),
"\t void statement::bind_double()"
<< "\n\t Invalid parameter id."
<< "\n\t parameter_id: " << parameter_id
<< "\n\t get_max_parameter_id(): " << get_max_parameter_id()
<< "\n\t this: " << this
);
reset();
int status = sqlite3_bind_double(stmt, parameter_id, item);
if (status != SQLITE_OK)
{
throw sqlite_error(sqlite3_errmsg(db.get()));
}
}
void bind_int (
unsigned long parameter_id,
const int& item
)
{
// make sure requires clause is not broken
DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(),
"\t void statement::bind_int()"
<< "\n\t Invalid parameter id."
<< "\n\t parameter_id: " << parameter_id
<< "\n\t get_max_parameter_id(): " << get_max_parameter_id()
<< "\n\t this: " << this
);
reset();
int status = sqlite3_bind_int(stmt, parameter_id, item);
if (status != SQLITE_OK)
{
throw sqlite_error(sqlite3_errmsg(db.get()));
}
}
void bind_int64 (
unsigned long parameter_id,
const int64& item
)
{
// make sure requires clause is not broken
DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(),
"\t void statement::bind_int64()"
<< "\n\t Invalid parameter id."
<< "\n\t parameter_id: " << parameter_id
<< "\n\t get_max_parameter_id(): " << get_max_parameter_id()
<< "\n\t this: " << this
);
reset();
int status = sqlite3_bind_int64(stmt, parameter_id, item);
if (status != SQLITE_OK)
{
throw sqlite_error(sqlite3_errmsg(db.get()));
}
}
void bind_null (
unsigned long parameter_id
)
{
// make sure requires clause is not broken
DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(),
"\t void statement::bind_null()"
<< "\n\t Invalid parameter id."
<< "\n\t parameter_id: " << parameter_id
<< "\n\t get_max_parameter_id(): " << get_max_parameter_id()
<< "\n\t this: " << this
);
reset();
int status = sqlite3_bind_null(stmt, parameter_id);
if (status != SQLITE_OK)
{
throw sqlite_error(sqlite3_errmsg(db.get()));
}
}
void bind_text (
unsigned long parameter_id,
const std::string& item
)
{
// make sure requires clause is not broken
DLIB_ASSERT(1 <= parameter_id && parameter_id <= get_max_parameter_id(),
"\t void statement::bind_text()"
<< "\n\t Invalid parameter id."
<< "\n\t parameter_id: " << parameter_id
<< "\n\t get_max_parameter_id(): " << get_max_parameter_id()
<< "\n\t this: " << this
);
reset();
int status = sqlite3_bind_text(stmt, parameter_id, item.c_str(), -1, SQLITE_TRANSIENT);
if (status != SQLITE_OK)
{
throw sqlite_error(sqlite3_errmsg(db.get()));
}
}
private:
void reset()
{
if (needs_reset)
{
if (sqlite3_reset(stmt) != SQLITE_OK)
{
step_status = SQLITE_DONE;
throw sqlite_error(sqlite3_errmsg(db.get()));
}
needs_reset = false;
step_status = SQLITE_DONE;
at_first_step = true;
}
}
bool needs_reset; // true if sqlite3_step() has been called more recently than sqlite3_reset()
int step_status;
bool at_first_step;
std::shared_ptr<sqlite3> db;
sqlite3_stmt* stmt;
std::string sql_string;
};
// --------------------------------------------------------------------------------------------
void database::
exec (
const std::string& sql_statement
)
{
// make sure requires clause is not broken
DLIB_ASSERT(is_open() == true,
"\t void database::exec()"
<< "\n\t The database must be opened before calling this routine."
<< "\n\t this: " << this
);
statement(*this, sql_statement).exec();
}
// --------------------------------------------------------------------------------------------
}
#endif // DLIB_SQLiTE_H_