// Copyright (C) 2008 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include <sstream>
#include <string>
#include <cstdlib>
#include <ctime>
#include <dlib/misc_api.h>
#include <dlib/threads.h>
#include <dlib/any.h>
#include "tester.h"
namespace
{
using namespace test;
using namespace dlib;
using namespace std;
logger dlog("test.thread_pool");
struct some_struct : noncopyable
{
float val;
};
int global_var = 0;
struct add_functor
{
add_functor() { var = 1;}
add_functor(int v):var(v) {}
template <typename T, typename U, typename V>
void operator()(T a, U b, V& res)
{
dlib::sleep(20);
res = a + b;
}
void set_global_var() { global_var = 9; }
void set_global_var_const() const { global_var = 9; }
void set_global_var_arg1(int val) { global_var = val; }
void set_global_var_const_arg1(int val) const { global_var = val; }
void set_global_var_arg2(int val, int val2) { global_var = val+val2; }
void set_global_var_const_arg2(int val, int val2) const { global_var = val+val2; }
void operator()()
{
global_var = 9;
}
// use an any just so that if this object goes out of scope
// then var will get all messed up.
any var;
void operator()(int& a) { dlib::sleep(100); a = var.get<int>(); }
void operator()(int& a, int& b) { dlib::sleep(100); a = var.get<int>(); b = 2; }
void operator()(int& a, int& b, int& c) { dlib::sleep(100); a = var.get<int>(); b = 2; c = 3; }
void operator()(int& a, int& b, int& c, int& d) { dlib::sleep(100); a = var.get<int>(); b = 2; c = 3; d = 4; }
};
void set_global_var() { global_var = 9; }
void gset_struct_to_zero (some_struct& a) { a.val = 0; }
void gset_to_zero (int& a) { a = 0; }
void gincrement (int& a) { ++a; }
void gadd (int a, const int& b, int& res) { dlib::sleep(20); res = a + b; }
void gadd1(int& a, int& res) { res += a; }
void gadd2 (int c, int a, const int& b, int& res) { dlib::sleep(20); res = a + b + c; }
class thread_pool_tester : public tester
{
public:
thread_pool_tester (
) :
tester ("test_thread_pool",
"Runs tests on the thread_pool component.")
{}
void perform_test (
)
{
add_functor f;
for (int num_threads= 0; num_threads < 4; ++num_threads)
{
dlib::future<int> a, b, c, res, d;
thread_pool tp(num_threads);
print_spinner();
dlib::future<some_struct> obj;
for (int i = 0; i < 4; ++i)
{
a = 1;
b = 2;
c = 3;
res = 4;
DLIB_TEST(a==a);
DLIB_TEST(a!=b);
DLIB_TEST(a==1);
tp.add_task(gset_to_zero, a);
tp.add_task(gset_to_zero, b);
tp.add_task(*this, &thread_pool_tester::set_to_zero, c);
tp.add_task(gset_to_zero, res);
DLIB_TEST(a == 0);
DLIB_TEST(b == 0);
DLIB_TEST(c == 0);
DLIB_TEST(res == 0);
tp.add_task(gincrement, a);
tp.add_task(*this, &thread_pool_tester::increment, b);
tp.add_task(*this, &thread_pool_tester::increment, c);
tp.add_task(gincrement, res);
DLIB_TEST(a == 1);
DLIB_TEST(b == 1);
DLIB_TEST(c == 1);
DLIB_TEST(res == 1);
tp.add_task(&gincrement, a);
tp.add_task(*this, &thread_pool_tester::increment, b);
tp.add_task(*this, &thread_pool_tester::increment, c);
tp.add_task(&gincrement, res);
tp.add_task(gincrement, a);
tp.add_task(*this, &thread_pool_tester::increment, b);
tp.add_task(*this, &thread_pool_tester::increment, c);
tp.add_task(gincrement, res);
DLIB_TEST(a == 3);
DLIB_TEST(b == 3);
DLIB_TEST(c == 3);
DLIB_TEST(res == 3);
tp.add_task(*this, &thread_pool_tester::increment, c);
tp.add_task(gincrement, res);
DLIB_TEST(c == 4);
DLIB_TEST(res == 4);
tp.add_task(gadd, a, b, res);
DLIB_TEST(res == a+b);
DLIB_TEST(res == 6);
a = 3;
b = 4;
res = 99;
DLIB_TEST(res == 99);
tp.add_task(*this, &thread_pool_tester::add, a, b, res);
DLIB_TEST(res == a+b);
DLIB_TEST(res == 7);
a = 1;
b = 2;
c = 3;
res = 88;
DLIB_TEST(res == 88);
DLIB_TEST(a == 1);
DLIB_TEST(b == 2);
DLIB_TEST(c == 3);
tp.add_task(gadd2, a, b, c, res);
DLIB_TEST(res == 6);
DLIB_TEST(a == 1);
DLIB_TEST(b == 2);
DLIB_TEST(c == 3);
a = 1;
b = 2;
c = 3;
res = 88;
DLIB_TEST(res == 88);
DLIB_TEST(a == 1);
DLIB_TEST(b == 2);
DLIB_TEST(c == 3);
tp.add_task(*this, &thread_pool_tester::add2, a, b, c, res);
DLIB_TEST(res == 6);
DLIB_TEST(a == 1);
DLIB_TEST(b == 2);
DLIB_TEST(c == 3);
a = 1;
b = 2;
c = 3;
res = 88;
tp.add_task(gadd1, a, b);
DLIB_TEST(a == 1);
DLIB_TEST(b == 3);
a = 2;
tp.add_task(*this, &thread_pool_tester::add1, a, b);
DLIB_TEST(a == 2);
DLIB_TEST(b == 5);
val = 4;
uint64 id = tp.add_task(*this, &thread_pool_tester::zero_val);
tp.wait_for_task(id);
DLIB_TEST(val == 0);
id = tp.add_task(*this, &thread_pool_tester::accum2, 1,2);
tp.wait_for_all_tasks();
DLIB_TEST(val == 3);
id = tp.add_task(*this, &thread_pool_tester::accum1, 3);
tp.wait_for_task(id);
DLIB_TEST(val == 6);
obj.get().val = 8;
DLIB_TEST(obj.get().val == 8);
tp.add_task(gset_struct_to_zero, obj);
DLIB_TEST(obj.get().val == 0);
obj.get().val = 8;
DLIB_TEST(obj.get().val == 8);
tp.add_task(*this,&thread_pool_tester::set_struct_to_zero, obj);
DLIB_TEST(obj.get().val == 0);
a = 1;
b = 2;
res = 0;
tp.add_task(f, a, b, res);
DLIB_TEST(a == 1);
DLIB_TEST(b == 2);
DLIB_TEST(res == 3);
global_var = 0;
DLIB_TEST(global_var == 0);
id = tp.add_task(&set_global_var);
tp.wait_for_task(id);
DLIB_TEST(global_var == 9);
global_var = 0;
DLIB_TEST(global_var == 0);
id = tp.add_task(f);
tp.wait_for_task(id);
DLIB_TEST(global_var == 9);
global_var = 0;
DLIB_TEST(global_var == 0);
id = tp.add_task(f, &add_functor::set_global_var);
tp.wait_for_task(id);
DLIB_TEST(global_var == 9);
global_var = 0;
a = 4;
DLIB_TEST(global_var == 0);
id = tp.add_task(f, &add_functor::set_global_var_arg1, a);
tp.wait_for_task(id);
DLIB_TEST(global_var == 4);
global_var = 0;
a = 4;
DLIB_TEST(global_var == 0);
id = tp.add_task_by_value(f, &add_functor::set_global_var_arg1, a);
tp.wait_for_task(id);
DLIB_TEST(global_var == 4);
global_var = 0;
a = 4;
b = 3;
DLIB_TEST(global_var == 0);
id = tp.add_task(f, &add_functor::set_global_var_arg2, a, b);
tp.wait_for_task(id);
DLIB_TEST(global_var == 7);
global_var = 0;
a = 4;
b = 3;
DLIB_TEST(global_var == 0);
id = tp.add_task_by_value(f, &add_functor::set_global_var_arg2, a, b);
tp.wait_for_task(id);
DLIB_TEST(global_var == 7);
global_var = 0;
a = 4;
b = 3;
DLIB_TEST(global_var == 0);
id = tp.add_task(f, &add_functor::set_global_var_const_arg2, a, b);
tp.wait_for_task(id);
DLIB_TEST(global_var == 7);
global_var = 0;
a = 4;
b = 3;
DLIB_TEST(global_var == 0);
id = tp.add_task_by_value(f, &add_functor::set_global_var_const_arg2, a, b);
tp.wait_for_task(id);
DLIB_TEST(global_var == 7);
global_var = 0;
a = 4;
DLIB_TEST(global_var == 0);
id = tp.add_task(f, &add_functor::set_global_var_const_arg1, a);
tp.wait_for_task(id);
DLIB_TEST(global_var == 4);
global_var = 0;
a = 4;
DLIB_TEST(global_var == 0);
id = tp.add_task_by_value(f, &add_functor::set_global_var_const_arg1, a);
tp.wait_for_task(id);
DLIB_TEST(global_var == 4);
global_var = 0;
DLIB_TEST(global_var == 0);
id = tp.add_task_by_value(f, &add_functor::set_global_var);
tp.wait_for_task(id);
DLIB_TEST(global_var == 9);
global_var = 0;
DLIB_TEST(global_var == 0);
id = tp.add_task(f, &add_functor::set_global_var_const);
tp.wait_for_task(id);
DLIB_TEST(global_var == 9);
global_var = 0;
DLIB_TEST(global_var == 0);
id = tp.add_task_by_value(f, &add_functor::set_global_var_const);
tp.wait_for_task(id);
DLIB_TEST(global_var == 9);
}
// add this task just to to perterb the thread pool before it goes out of scope
tp.add_task(f, a, b, res);
for (int k = 0; k < 3; ++k)
{
print_spinner();
global_var = 0;
tp.add_task_by_value(add_functor());
tp.wait_for_all_tasks();
DLIB_TEST(global_var == 9);
a = 0; b = 0; c = 0; d = 0;
tp.add_task_by_value(add_functor(), a);
DLIB_TEST(a == 1);
a = 0; b = 0; c = 0; d = 0;
tp.add_task_by_value(add_functor(8), a, b);
DLIB_TEST(a == 8);
DLIB_TEST(b == 2);
a = 0; b = 0; c = 0; d = 0;
tp.add_task_by_value(add_functor(), a, b, c);
DLIB_TEST(a == 1);
DLIB_TEST(b == 2);
DLIB_TEST(c == 3);
a = 0; b = 0; c = 0; d = 0;
tp.add_task_by_value(add_functor(5), a, b, c, d);
DLIB_TEST(a == 5);
DLIB_TEST(b == 2);
DLIB_TEST(c == 3);
DLIB_TEST(d == 4);
}
tp.wait_for_all_tasks();
// make sure exception propagation from tasks works correctly.
auto f_throws = []() { throw dlib::error("test exception");};
bool got_exception = false;
try
{
tp.add_task_by_value(f_throws);
tp.wait_for_all_tasks();
}
catch(dlib::error& e)
{
DLIB_TEST(e.info == "test exception");
got_exception = true;
}
DLIB_TEST(got_exception);
dlib::future<int> aa;
auto f_throws2 = [](int& a) { a = 1; throw dlib::error("test exception");};
got_exception = false;
try
{
tp.add_task(f_throws2, aa);
aa.get();
}
catch(dlib::error& e)
{
DLIB_TEST(e.info == "test exception");
got_exception = true;
}
DLIB_TEST(got_exception);
}
}
long val;
void accum1(long a) { val += a; }
void accum2(long a, long b) { val += a + b; }
void zero_val() { dlib::sleep(20); val = 0; }
void set_struct_to_zero (some_struct& a) { a.val = 0; }
void set_to_zero (int& a) { dlib::sleep(20); a = 0; }
void increment (int& a) const { dlib::sleep(20); ++a; }
void add (int a, const int& b, int& res) { dlib::sleep(20); res = a + b; }
void add1(int& a, int& res) const { res += a; }
void add2 (int c, int a, const int& b, int& res) { res = a + b + c; }
} a;
}