diff --git a/ThreadPool.h b/ThreadPool.h
index a3a6690..40e3a26 100755
--- a/ThreadPool.h
+++ b/ThreadPool.h
@@ -41,21 +41,49 @@ namespace progschj {
 
+class would_block
+    : public std::runtime_error
+{
+    using std::runtime_error::runtime_error;
+};
+
+
 class ThreadPool {
 public:
+    template <typename F, typename... Args>
+    using return_type =
+#if defined(__cpp_lib_is_invocable) && __cpp_lib_is_invocable >= 201703L
+        typename std::invoke_result<F, Args...>::type;
+#else
+        typename std::result_of<F(Args...)>::type;
+#endif
+
     explicit ThreadPool(std::size_t threads
-        = (std::max)(2u, std::thread::hardware_concurrency() * 2));
-    template<class F, class... Args>
-    auto enqueue(F&& f, Args&&... args)
-        -> std::future<typename std::result_of<F(Args...)>::type>;
+        = (std::max)(2u, std::thread::hardware_concurrency()));
+    template<class F, class... Args>
+    auto enqueue_block(F&& f, Args&&... args) -> std::future<return_type<F, Args...>>;
+    template<class F, class... Args>
+    auto enqueue(F&& f, Args&&... args) -> std::future<return_type<F, Args...>>;
     void wait_until_empty();
     void wait_until_nothing_in_flight();
     void set_queue_size_limit(std::size_t limit);
+    void set_pool_size(std::size_t limit);
     ~ThreadPool();
 
 private:
+    void start_worker(std::size_t worker_number,
+        std::unique_lock<std::mutex> const &lock);
+
+    template<class F, class... Args>
+    auto enqueue_worker(bool, F&& f, Args&&... args) -> std::future<return_type<F, Args...>>;
+
+    template <typename T>
+    static std::future<T> make_exception_future (std::exception_ptr ex_ptr);
+
     // need to keep track of threads so we can join them
     std::vector< std::thread > workers;
+    // target pool size
+    std::size_t pool_size;
     // the task queue
     std::queue< std::function<void()> > tasks;
    // queue length limit
@@ -97,65 +125,58 @@ class ThreadPool {
 
 // the constructor just launches some amount of workers
 inline ThreadPool::ThreadPool(std::size_t threads)
-    : in_flight(0)
+    : pool_size(threads)
+    , in_flight(0)
 {
-    for(size_t i = 0;i<threads;++i)
-        workers.emplace_back(
-            [this]
-            {
-                for(;;)
-                {
-                    std::function<void()> task;
-                    bool notify;
-
-                    {
-                        std::unique_lock<std::mutex> lock(this->queue_mutex);
-                        this->condition_consumers.wait(lock,
-                            [this]{ return this->stop || !this->tasks.empty(); });
-                        if(this->stop && this->tasks.empty())
-                            return;
-                        task = std::move(this->tasks.front());
-                        this->tasks.pop();
-                        notify = this->tasks.size() + 1 == max_queue_size
-                            || this->tasks.empty();
-                    }
-
-                    handle_in_flight_decrement guard(*this);
-
-                    if (notify)
-                    {
-                        std::unique_lock<std::mutex> lock(this->queue_mutex);
-                        condition_producers.notify_all();
-                    }
+    std::unique_lock<std::mutex> lock(this->queue_mutex);
+    for (std::size_t i = 0; i != threads; ++i)
+        start_worker(i, lock);
+}
 
-                    task();
-                }
-            }
-            );
+// add new work item to the pool and block if the queue is full
+template<class F, class... Args>
+auto ThreadPool::enqueue_block(F&& f, Args&&... args) -> std::future<return_type<F, Args...>>
+{
+    return enqueue_worker (true, std::forward<F> (f), std::forward<Args> (args)...);
 }
 
-// add new work item to the pool
+// add new work item to the pool and return future with would_block exception if it is full
 template<class F, class... Args>
-auto ThreadPool::enqueue(F&& f, Args&&... args)
-    -> std::future<typename std::result_of<F(Args...)>::type>
+auto ThreadPool::enqueue(F&& f, Args&&... args) -> std::future<return_type<F, Args...>>
 {
-    using return_type = typename std::result_of<F(Args...)>::type;
+    return enqueue_worker (false, std::forward<F> (f), std::forward<Args> (args)...);
+}
 
-    auto task = std::make_shared< std::packaged_task<return_type()> >(
+template<class F, class... Args>
+auto ThreadPool::enqueue_worker(bool block, F&& f, Args&&... args) -> std::future<return_type<F, Args...>>
+{
+    auto task = std::make_shared< std::packaged_task<return_type<F, Args...>()> >(
         std::bind(std::forward<F>(f), std::forward<Args>(args)...)
         );
 
-    std::future<return_type> res = task->get_future();
+    std::future<return_type<F, Args...>> res = task->get_future();
 
     std::unique_lock<std::mutex> lock(queue_mutex);
+    if (tasks.size () >= max_queue_size)
-    // wait for the queue to empty or be stopped
-    condition_producers.wait(lock,
-        [this]
-        {
-            return tasks.size () < max_queue_size
-                || stop;
-        });
+    {
+        if (block)
+        {
+            // wait for the queue to empty or be stopped
+            condition_producers.wait(lock,
+                [this]
+                {
+                    return tasks.size () < max_queue_size
+                        || stop;
+                });
+        }
+        else
+        {
+            return ThreadPool::make_exception_future<return_type<F, Args...>> (
+                std::make_exception_ptr (would_block("queue full")));
+        }
+    }
 
     // don't allow enqueueing after stopping the pool
     if (stop)
@@ -170,20 +191,15 @@ auto ThreadPool::enqueue(F&& f, Args&&... args)
     return res;
 }
 
-
 // the destructor joins all threads
 inline ThreadPool::~ThreadPool()
 {
-    {
-        std::unique_lock<std::mutex> lock(queue_mutex);
-        stop = true;
-        condition_consumers.notify_all();
-        condition_producers.notify_all();
-    }
-
-    for(std::thread &worker: workers)
-        worker.join();
-
+    std::unique_lock<std::mutex> lock(queue_mutex);
+    stop = true;
+    pool_size = 0;
+    condition_consumers.notify_all();
+    condition_producers.notify_all();
+    condition_consumers.wait(lock, [this]{ return this->workers.empty(); });
    assert(in_flight == 0);
 }
 
@@ -204,12 +220,121 @@ inline void ThreadPool::wait_until_nothing_in_flight()
 inline void ThreadPool::set_queue_size_limit(std::size_t limit)
 {
     std::unique_lock<std::mutex> lock(this->queue_mutex);
+
+    if (stop)
+        return;
+
     std::size_t const old_limit = max_queue_size;
     max_queue_size = (std::max)(limit, std::size_t(1));
     if (old_limit < max_queue_size)
         condition_producers.notify_all();
 }
 
+inline void ThreadPool::set_pool_size(std::size_t limit)
+{
+    if (limit < 1)
+        limit = 1;
+
+    std::unique_lock<std::mutex> lock(this->queue_mutex);
+
+    if (stop)
+        return;
+
+    std::size_t const old_size = pool_size;
+    assert(this->workers.size() >= old_size);
+
+    pool_size = limit;
+    if (pool_size > old_size)
+    {
+        // create new worker threads
+        // it is possible that some of these are still running because
+        // they have not stopped yet after a pool size reduction, such
+        // workers will just keep running
+        for (std::size_t i = old_size; i != pool_size; ++i)
+            start_worker(i, lock);
+    }
+    else if (pool_size < old_size)
+        // notify all worker threads to start downsizing
+        this->condition_consumers.notify_all();
+}
+
+inline void ThreadPool::start_worker(
+    std::size_t worker_number, std::unique_lock<std::mutex> const &lock)
+{
+    assert(lock.owns_lock() && lock.mutex() == &this->queue_mutex);
+    assert(worker_number <= this->workers.size());
+
+    auto worker_func =
+        [this, worker_number]
+        {
+            for(;;)
+            {
+                std::function<void()> task;
+                bool notify;
+
+                {
+                    std::unique_lock<std::mutex> lock(this->queue_mutex);
+                    this->condition_consumers.wait(lock,
+                        [this, worker_number]{
+                            return this->stop || !this->tasks.empty()
+                                || pool_size < worker_number + 1; });
+
+                    // deal with downsizing of thread pool or shutdown
+                    if ((this->stop && this->tasks.empty())
+                        || (!this->stop && pool_size < worker_number + 1))
+                    {
+                        // detach this worker, effectively marking it stopped
+                        this->workers[worker_number].detach();
+                        // downsize the workers vector as much as possible
+                        while (this->workers.size() > pool_size
+                            && !this->workers.back().joinable())
+                            this->workers.pop_back();
+                        // if this was the last worker, notify the destructor
+                        if (this->workers.empty())
+                            this->condition_consumers.notify_all();
+                        return;
+                    }
+                    else if (!this->tasks.empty())
+                    {
+                        task = std::move(this->tasks.front());
+                        this->tasks.pop();
+                        notify = this->tasks.size() + 1 == max_queue_size
+                            || this->tasks.empty();
+                    }
+                    else
+                        continue;
+                }
+
+                handle_in_flight_decrement guard(*this);
+
+                if (notify)
+                {
+                    std::unique_lock<std::mutex> lock(this->queue_mutex);
+                    condition_producers.notify_all();
+                }
+
+                task();
+            }
+        };
+
+    if (worker_number < this->workers.size()) {
+        std::thread & worker = this->workers[worker_number];
+        // start only if not already running
+        if (!worker.joinable()) {
+            worker = std::thread(worker_func);
+        }
+    } else
+        this->workers.push_back(std::thread(worker_func));
+}
+
+template <typename T>
+inline std::future<T> ThreadPool::make_exception_future (std::exception_ptr ex_ptr)
+{
+    std::promise<T> p;
+    p.set_exception (ex_ptr);
+    return p.get_future ();
+}
+
 } // namespace progschj
 
 #endif // THREAD_POOL_H_7ea1ee6b_4f17_4c09_b76b_3d44e102400c
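
Usage sketch (not part of the patch): a minimal illustration of the two submission paths this change introduces, assuming the patched header is saved as ThreadPool.h. The `work` lambda, pool sizes, and queue limit are arbitrary example values, not anything prescribed by the patch.

// example.cpp -- enqueue() vs enqueue_block(), plus run-time resizing
#include "ThreadPool.h"

#include <cstddef>
#include <future>
#include <iostream>
#include <vector>

int main()
{
    progschj::ThreadPool pool(2);   // start with two workers
    pool.set_queue_size_limit(4);   // small queue so rejection is easy to hit

    auto work = [](int i) { return i * i; };

    // enqueue() no longer blocks the producer: if the queue is full, the
    // returned future carries a would_block exception instead of a result.
    std::vector<std::future<int>> results;
    for (int i = 0; i != 32; ++i)
        results.push_back(pool.enqueue(work, i));

    for (std::size_t i = 0; i != results.size(); ++i)
    {
        try
        {
            std::cout << results[i].get() << '\n';
        }
        catch (progschj::would_block const &)
        {
            // Task i was rejected at submission time; resubmit with the
            // blocking variant, which waits for queue space instead.
            std::cout << pool.enqueue_block(work, static_cast<int>(i)).get() << '\n';
        }
    }

    pool.set_pool_size(8);   // the pool can also be grown or shrunk at run time
    pool.wait_until_nothing_in_flight();
}

The split mirrors blocking vs non-blocking I/O: enqueue_block() keeps the old backpressure behavior, while enqueue() fails fast through the future, so producers that must not stall can detect a full queue and react.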