How to write a thread pool

2 minute read

Clarification

a. We can specify number of threads in the pool
b. We can enqueue tasks (For simplicity, task takes no argument and returns an integer) into the thread pool
c. We can retrieve the result of task from the thread pool

API definition

// Start the thread pool
void start(int num_of_threads);
// Stop the thread pool
void stop();
// Enqueue a task into thread pool and with a future object return
std::future<int> enqueue(packaged_task<int()> task);

Implementation

We need a mutex, a condition_variable, a boolean & a vector of threads to create the most basic form of a thread pool

class ThreadPool {
public:
    explicit ThreadPool(int num_of_threads) {
        stopping = false;
        Start(num_of_threads);
    }

    virtual ~ThreadPool() {
        Stop();
    }
private:
    std::mutex m;
    std::condition_variable cv;
    bool stopping;
    std::vector<std::thread> threads;
    std::queue<std::packaged_task<void()>> q;

    void Start(int num_of_threads) {
        for (int i = 0; i < num_of_threads; i++) {
            threads.push_back(std::thread([=]() {
                while (true) {
                    std::unique_lock<std::mutex> lock(m);
                    // The wait of condition_variable will block the current thread, will unlock the lock, waiting for the stopping to be true.
                    // If other thread notify_one() or notify_all() and stopping is false, it will wait again.
                    cv.wait(lock, [=]() {
                        return stopping;
                    });
                    if (stopping) {
                        break;
                    }
                }
            }));
        }
    }

    void Stop() {
        {
            std::unique_lock<std::mutex> lock(m);
            stopping = true;
        }
        cv.notify_all();
        for (auto &th : threads) {
            th.join();
        }
    }
};

We need a queue in order to support adding tasks

class ThreadPool {
public:
    explicit ThreadPool(int num_of_threads) {
        stopping = false;
        Start(num_of_threads);
    }

    virtual ~ThreadPool() {
        Stop();
    }

    void enqueue(std::packaged_task<void()> task) {
        {
            std::unique_lock<std::mutex> lock(m);
            q.push(std::move(task));
        }
        cv.notify_one();
    }
private:
    std::mutex m;
    std::condition_variable cv;
    bool stopping;
    std::vector<std::thread> threads;
    std::queue<std::packaged_task<void()>> q;

    void Start(int num_of_threads) {
        for (int i = 0; i < num_of_threads; i++) {
            threads.push_back(std::thread([=]() {
                while (true) {
                    std::packaged_task<void()> task;
                    {
                        std::unique_lock<std::mutex> lock(m);
                        cv.wait(lock, [=]() {
                            return stopping || !q.empty();
                        });
                        if (stopping && q.empty()) {
                            break;
                        }
                        task = std::move(q.front());
                        q.pop();
                    }
                    if (task.valid()) {
                        task();
                    }
                }
            }));
        }
    }

    void Stop() {
        {
            std::unique_lock<std::mutex> lock(m);
            stopping = true;
        }
        cv.notify_all();
        for (auto &th : threads) {
            th.join();
        }
    }
};

We can finalise the implementation by adding a std::future in the return of enqueue function and change void to int

class ThreadPool {
public:
    explicit ThreadPool(int num_of_threads) {
        stopping = false;
        Start(num_of_threads);
    }

    virtual ~ThreadPool() {
        Stop();
    }

    std::future<int> enqueue(std::packaged_task<int()> task) {
        auto future = task.get_future();
        {
            std::unique_lock<std::mutex> lock(m);
            q.push(std::move(task));
        }
        cv.notify_one();
        return future;
    }
private:
    std::mutex m;
    std::condition_variable cv;
    bool stopping;
    std::vector<std::thread> threads;
    std::queue<std::packaged_task<int()>> q;

    void Start(int num_of_threads) {
        for (int i = 0; i < num_of_threads; i++) {
            threads.push_back(std::thread([=]() {
                while (true) {
                    std::packaged_task<int()> task;
                    {
                        std::unique_lock<std::mutex> lock(m);
                        cv.wait(lock, [=]() {
                            return stopping || !q.empty();
                        });
                        if (stopping && q.empty()) {
                            break;
                        }
                        task = std::move(q.front());
                        q.pop();
                    }
                    if (task.valid()) {
                        task();
                    }
                }
            }));
        }
    }

    void Stop() {
        {
            std::unique_lock<std::mutex> lock(m);
            stopping = true;
        }
        cv.notify_all();
        for (auto &th : threads) {
            th.join();
        }
    }
};