ZaynPei Lv6

shared_ptr 实现

手写一个 shared_ptr 的实现, 主要需要注意以下几点: - 成员变量是两个指针, 一个指向被管理对象, 一个指向引用计数器(其实是指向控制块, 但是这里简化为指向引用计数器) - 拷贝构造拷贝赋值运算符都需要增加引用计数器, 因为新对象共享同一个被管理对象 - 移动构造移动赋值运算符不增加引用计数器, 因为资源从源对象转移到目标对象, 源对象不再管理该资源 - 需要在析构函数中减少引用计数器, 当引用计数器为0时, 释放对象内存.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
template<class T>
class MysharedPtr{
T* ptr;
int* count;

void cleanup(){ // 非空判断由被调用者处理, 调用者不需要关心
if(count!=nullptr){
(*count)--;
if(*count==0){
delete ptr;
delete count;
}
ptr = nullptr;
count = nullptr;
}
}
public:
explicit MysharedPtr(T* node = nullptr): ptr(node){
if(ptr!=nullptr){
count = new int(1);
}else{
count = nullptr;
}
}
~MysharedPtr(){
cleanup();
}

MysharedPtr(const MysharedPtr<T>& other){
ptr = other.ptr;
count = other.count;

if(count!=nullptr) (*count)++;
}

MysharedPtr<T>& operator=(const MysharedPtr<T>& other){
if(&other==this) return *this; // 这里的比较必须是地址的比较, 因为对象之间没有重载比较运算符, 也就是不能 other==*this

cleanup();

ptr = other.ptr;
count = other.count;

if(count!=nullptr) (*count)++;

return *this;
}

MysharedPtr(MysharedPtr<T>&& other){
ptr = other.ptr;
count = other.count;

other.ptr = nullptr;
other.count = nullptr;
}

MysharedPtr<T>& operator=(MysharedPtr<T>&& other){
if(&other==this) return *this;

cleanup();

ptr = other.ptr;
count = other.count;

other.ptr = nullptr;
other.count = nullptr;

return *this;
}

T* operator->(){ return ptr; }

T& operator*(){ return *ptr;}

T* get(){
return ptr;
}

int use_count(){
return (count==nullptr)? 0: *count;
}

};

不过, 上述实现有个比较重要的缺陷: 它不是线程安全的. 在多线程环境下, 多个线程可能同时修改引用计数器, 导致竞态条件. 解决这个问题通常需要使用原子操作来保护引用计数器的修改.(但是依旧只保证了引用计数器的线程安全, 并不能保证被管理对象本身的线程安全, 如果需要的话还需要额外对对象本身加锁或者使用线程安全的数据结构)

下面是一个使用 C++11 原子操作实现的线程安全 shared_ptr, 同时使用了控制块来管理引用计数器和被管理对象:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
template<class T>
class MysharedPtr{
// -------------------------------
// 控制块(ControlBlock): 管理引用计数
// -------------------------------
struct ControlBlock{
std::atomic<int> strong; // 强引用计数(shared_ptr 使用)

ControlBlock(): strong(1){} // 初始计数为1
};

T* ptr; // 管理的对象指针
ControlBlock* cb; // 控制块指针,用于管理引用计数

// -------------------------------
// 清理函数:减少引用计数并释放资源
// 调用者不需要判断空指针,内部已处理
// -------------------------------
void cleanup(){
if(cb!=nullptr){ // 如果有控制块
// 原子减少引用计数,如果减少后为0,说明这是最后一个shared_ptr
if(cb->strong.fetch_sub(1)==1){
delete ptr; // 删除管理的对象
delete cb; // 删除控制块
}
ptr = nullptr; // 清空指针,避免悬挂引用
cb = nullptr;
}
}

public:
// -------------------------------
// 构造函数
// -------------------------------
explicit MysharedPtr(T* node = nullptr): ptr(node){
if(ptr!=nullptr){
cb = new ControlBlock(); // 新建控制块,计数为1
}else{
cb = nullptr; // 空指针不需要控制块
}
}

// -------------------------------
// 析构函数
// -------------------------------
~MysharedPtr(){
cleanup(); // 自动释放资源
}

// -------------------------------
// 拷贝构造函数
// -------------------------------
MysharedPtr(const MysharedPtr<T>& other){
ptr = other.ptr; // 指向相同对象
cb = other.cb; // 共用同一个控制块

if(cb!=nullptr)
cb->strong.fetch_add(1); // 原子增加引用计数
}

// -------------------------------
// 拷贝赋值
// -------------------------------
MysharedPtr<T>& operator=(const MysharedPtr<T>& other){
if(&other==this) return *this; // 自赋值检查,避免错误. 注意这里不能使用 other==*this 进行比较, 因为没有重载比较运算符

cleanup(); // 先释放当前资源

ptr = other.ptr; // 指向同一对象
cb = other.cb;

if(cb!=nullptr) // 增加之前要判断是否为空
cb->strong.fetch_add(1); // 增加引用计数

return *this;
}

// -------------------------------
// 移动构造
// -------------------------------
MysharedPtr(MysharedPtr<T>&& other) noexcept{
ptr = other.ptr; // 接管资源
cb = other.cb;

other.ptr = nullptr; // 清空原对象
other.cb = nullptr;
}

// -------------------------------
// 移动赋值
// -------------------------------
MysharedPtr<T>& operator=(MysharedPtr<T>&& other) noexcept{
if(&other==this) return *this;

cleanup(); // 释放当前资源

ptr = other.ptr; // 接管资源
cb = other.cb;

other.ptr = nullptr; // 清空原对象
other.cb = nullptr;

return *this;
}

// -------------------------------
// 访问管理对象
// -------------------------------
T* operator->() const{ return ptr; } // 访问成员
T& operator*() const{ return *ptr; } // 解引用
T* get() const{ return ptr; } // 返回裸指针

// -------------------------------
// 查询引用计数
// -------------------------------
int use_count() const {
return (cb==nullptr)? 0: cb->strong.load(); // 原子读取
}

// -------------------------------
// reset():释放当前资源
// -------------------------------
void reset(){
cleanup(); // 调用 cleanup 自动释放
}

// reset(T*): 换绑定对象
void reset(T* node){
cleanup(); // 先释放旧资源
*this = MysharedPtr(node); // 利用构造函数重新赋值
}

// -------------------------------
// swap(): 交换两个 shared_ptr 的资源
// -------------------------------
void swap(MysharedPtr<T>& other){
std::swap(ptr, other.ptr); // 交换对象指针
std::swap(cb, other.cb); // 交换控制块指针
}
};

memcpy 实现

手写一个简单的 memcpy 函数, 用于内存拷贝. 需要注意以下几点: - 函数参数包括: 目标地址 dest, 源地址 src, 拷贝字节数 n - 函数返回值是目标地址 dest - 除了实现简单的逐字节拷贝之外, 还需要解决两个问题: - 效率问题: 可以考虑按更大单位(如4字节或8字节)进行拷贝, 提高效率. 但是需要考虑地址对齐问题 - 重叠问题: 如果源地址和目标地址有重叠, 则需要考虑使用 memmove 的策略, 避免数据被覆盖

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#include <cstddef>  // for size_t
#include <cassert> // for assert

void* my_memcpy(void* dest, const void* src, size_t n) {
assert(dest != nullptr && src != nullptr);

// 使用 char* 来保证逐字节操作, 因为 void* 是“无类型”指针,不能直接解引用或做算术。我们必须先把它转换成具体类型的指针。
unsigned char* d = static_cast<unsigned char*>(dest);
const unsigned char* s = static_cast<const unsigned char*>(src);

// 假设不重叠,从前向后逐字节拷贝
for (size_t i = 0; i < n; ++i) {
d[i] = s[i];
}

return dest;
}

C++标准规定:unsigned char唯一保证可以无歧义地访问任意类型原始内存字节的类型。它是“字节的同义词”,适合做内存拷贝、内存操作等底层工作。 上面的实现是一个简单版本的 memcpy, 它假设源地址和目标地址不重叠, 并且逐字节进行拷贝.

下面实现一个更高效的版本, 它按照8字节对齐进行拷贝. 这个实现需要注意的点有: - 先处理前导字节, 直到目标地址对齐到8字节边界 - 然后按8字节块进行拷贝 - 最后处理剩余的字节

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#include <cstddef>  // for size_t
#include <cassert> // for assert
void* my_memcpy(void* dest, const void* src, size_t n) {
assert(dest != nullptr && src != nullptr);

unsigned char* d = static_cast<unsigned char*>(dest);
const unsigned char* s = static_cast<const unsigned char*>(src);

// 处理前导字节,直到目标地址对齐到8字节边界
// 先把指针d转换为uintptr_t类型, 再对8取模判断是否对齐
while (n > 0 && (reinterpret_cast<uintptr_t>(d) % 8 != 0)) {
*d++ = *s++;
--n;
}

// 按照8字节块进行拷贝
size_t num_blocks = n / 8;
for (size_t i = 0; i < num_blocks; ++i) {
// 每次拷贝8字节, 具体方法是将源地址和目标地址都转换为 uint64_t* 指针, 这样指针指向的大小就是8字节, 然后解引用赋值, 目标地址 d 开始的8个字节就被赋值为源地址 s 开始的8个字节
*reinterpret_cast<uint64_t*>(d) = *reinterpret_cast<const uint64_t*>(s);
d += 8;
s += 8;
}

// 处理剩余的字节
n %= 8;
while (n > 0) {
*d++ = *s++;
--n;
}

return dest;
}

对于重叠的情况, 需要根据源地址和目标地址的相对位置来决定拷贝的方向: - 如果目标地址在源地址之后, 则从后向前拷贝 - 例如: 假设 dest = 0x1004, src = 0x1000, n = 8 - 则拷贝顺序应该是: 拷贝 src[7] 到 dest[7], 然后拷贝 src[6] 到 dest[6], 依此类推 - 如果目标地址在源地址之前, 则从前向后拷贝, 因为此时不会覆盖未拷贝的数据

这也就是标准库中的 memmove 函数的实现思路.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#include <cstddef>  // for size_t
#include <cassert> // for assert

// 需要注意这里的 const 只是保证函数不会主动通过 src 写数据, 但并不保证 src 所指向的内存区域不会被修改(比如通过 dest 修改)
void* my_memmove(void* dest, const void* src, size_t n) {
assert(dest != nullptr && src != nullptr);

unsigned char* d = static_cast<unsigned char*>(dest);
const unsigned char* s = static_cast<const unsigned char*>(src);

if (d < s) {
// 从前向后拷贝
for (size_t i = 0; i < n; ++i) {
d[i] = s[i];
}
} else if (d > s) { // 如果 src 在 dest 之前,可能会有重叠
// 从后向前拷贝
for (size_t i = n; i > 0; --i) {
d[i - 1] = s[i - 1];
}
}
// 如果 d == s, 不需要拷贝

return dest;
}

多线程交替打印

实现三个线程交替打印 ABCABC , 需要使用互斥锁条件变量进行线程同步. 主要思路是: - 使用一个互斥锁保护共享变量, 该变量表示当前轮到哪个线程打印 - 使用条件变量让线程等待和通知, 当轮到某个线程打印时, 该线程被唤醒进行打印, 打印完成后修改共享变量并通知所有线程, 每个线程根据共享变量判断是否轮到自己打印

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include <iostream>
#include <thread> // 包含 std::thread
#include <mutex> // 包含 std::mutex 和 std::unique_lock
#include <condition_variable> // 包含 std::condition_variable

// --- 共享数据和同步原语 ---
std::mutex mtx; // 互斥锁,保护 'turn'
std::condition_variable cv; // 条件变量,用于等待和通知
int turn = 1; // 共享的 "轮次" 标志, 初始值为 1,表示轮到线程 1 打印
// -----------------------------

void PrintA() {
for (int i = 0; i < 10; ++i) { // 假设我们希望打印 10 次
unique_lock<mutex> lock(mtx); // 获取锁
cv.wait(lock, [] { return turn == 1; }); // 等待轮到线程 1 打印

cout << "A"; // 打印 A

turn = 2; // 修改轮次,轮到线程 2 打印
cv.notify_all(); // 通知所有等待的线程
}
}

void PrintB() {
for (int i = 0; i < 10; ++i) { // 假设我们希望打印 10 次
unique_lock<mutex> lock(mtx); // 获取锁
cv.wait(lock, [] { return turn == 2; }); // 等待轮到线程 2 打印

cout << "B"; // 打印 B

turn = 3; // 修改轮次,轮到线程 3 打印
cv.notify_all(); // 通知所有等待的线程
}
}

void PrintC() {
for (int i = 0; i < 10; ++i) { // 假设我们希望打印 10 次
unique_lock<mutex> lock(mtx); // 获取锁
cv.wait(lock, [] { return turn == 3; }); // 等待轮到线程 3 打印

cout << "C"; // 打印 C

turn = 1; // 修改轮次,轮到线程 1 打印
cv.notify_all(); // 通知所有等待的线程
}
}

int main() {
// 设置我们希望交替打印的次数
const int iterations = 10;

// 创建三个线程
std::thread t1(print_numbers, iterations);
std::thread t2(print_chars, iterations);
std::thread t3(print_symbols, iterations);

// 等待三个线程都执行完毕
t1.join();
t2.join();
t3.join();

return 0;
}

简单HashMap 实现

手写一个简单的 HashMap 实现, 主要功能包括插入、查找和删除键值对. 需要注意以下几点: - 使用链地址法解决哈希冲突 - 实现基本的哈希函数, 即根据键计算哈希值 - 提供插入、查找和删除操作的接口

下面是一个简单的 HashMap 实现, 使用字符串作为键, 整数作为值, 暂时不支持扩容机制

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#include <iostream>
#include <string>
#include <vector>
#include <list> // 用于在每个桶中实现“链表”(拉链法)
#include <utility> // 用于 std::pair
#include <functional> // 用于 std::hash

class MyHashMap {
private:
// 1. 定义桶数组
size_t m_bucket_size; // 桶的数量
std::vector<std::list<std::pair<std::string, int>>> m_buckets; // 每个桶是一个链表,链表中存储 (key, value) 对

// 内部哈希函数
size_t _hash(const std::string& key) {
// 使用 C++ 内置的 std::hash 来获取哈希值
// std::hash<std::string>{}() 是一个可调用对象,用于计算字符串的哈希
size_t hash_value = std::hash<std::string>{}(key); // 临时创建一个 std::hash 对象来计算哈希值

// 步骤 2: 取模运算,将哈希值映射到桶索引
return hash_value % m_bucket_size;
}

public:
MyHashMap(size_t size = 100) : m_bucket_size(size) {
// 初始化桶数组, 将 vector 的大小调整为 m_bucket_size,
// 每个桶默认都是一个空的 std::list
m_buckets.resize(m_bucket_size);
}

// 插入或更新一个 (key, value) 对
void put(const std::string& key, int value) {
// 步骤 1: 计算 key 对应的桶索引
size_t index = _hash(key);

// 步骤 2: 获取该索引处的链表(桶)
// 注意:我们使用引用(&),以便能直接修改链表
std::list<std::pair<std::string, int>>& bucket_list = m_buckets[index];

// 步骤 3: 检查 key 是否已存在于链表中
for (auto& pair : bucket_list) {
if (pair.first == key) {
// 目标:Key 已存在
// 说明:我们只需要更新其 value
std::cout << "更新 Key: " << key << " 的值为 " << value << std::endl;
pair.second = value;
return; // 完成操作,退出函数
}
}

// 步骤 4: Key 不存在,将其插入链表尾部
// 说明:如果循环结束仍未返回,表示这是一个新的 key
std::cout << "插入 Key: " << key << ", Value: " << value << std::endl;
bucket_list.push_back({key, value});
}

// 根据 key 获取 value
int get(const std::string& key) {
// 步骤 1: 计算 key 对应的桶索引
size_t index = _hash(key);

// 步骤 2: 获取该索引处的链表(桶)
// 注意:这里使用 const 引用,因为 get 不应修改数据
const std::list<std::pair<std::string, int>>& bucket_list = m_buckets[index];

// 步骤 3: 遍历链表查找 key
for (const auto& pair : bucket_list) {
if (pair.first == key) {
// 目标:Key 已找到
// 说明:返回对应的 value
return pair.second;
}
}

// 步骤 4: Key 未找到
// 说明:如果循环结束仍未返回,表示 key 不存在
return -1;
}

// 删除一个 (key, value) 对
void remove(const std::string& key) {
// 步骤 1: 计算 key 对应的桶索引
size_t index = _hash(key);

// 步骤 2: 获取该索引处的链表(桶)
std::list<std::pair<std::string, int>>& bucket_list = m_buckets[index];

// 步骤 3: 遍历链表查找并删除 key
for (auto it = bucket_list.begin(); it != bucket_list.end(); ++it) {
if (it->first == key) {
// 目标:Key 已找到
// 说明:删除该节点
std::cout << "删除 Key: " << key << std::endl;
bucket_list.erase(it);
return; // 完成操作,退出函数
}
}

// 步骤 4: Key 未找到
// 说明:如果循环结束仍未返回,表示 key 不存在
std::cout << "Key: " << key << " 不存在,无法删除" << std::endl;
}
};

生产者-消费者模型

实现一个简单的生产者-消费者模型, 使用条件变量和互斥锁进行同步. 主要功能包括: - 生产者线程不断生成数据并放入缓冲区 - 消费者线程不断从缓冲区取出数据进行处理 - 使用条件变量通知生产者和消费者线程, 注意需要两个条件变量, 一个用于通知生产者缓冲区有空位, 另一个用于通知消费者缓冲区有数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#include<iostream>
#include<queue>
#include<mutex>
#include<condition_variable>

using namespace std;
class SPSC_Buffer{
private:
queue<int> queue_;
size_t capacity_;
mutex mtx_;
condition_variable not_full_;
condition_variable not_empty_;
public:
explicit SPSC_Buffer(size_t capacity): capacity_(capacity){}

void put(int item){
unique_lock<mutex> lk(mtx_);
not_full_.wait(lk, [this](){
return queue_.size()<capacity_;
});

queue_.push(item);
not_empty_.notify_one();
}

int get(){
unique_lock<mutex> lk(mtx_);
not_empty_.wait(lk, [this](){
return !queue_.empty();
});
int ret = queue_.front();
queue_.pop();
not_full_.notify_one();
return ret;
}

};

int main(){
SPSC_Buffer bf(5);
thread producer([&](){
for(int i=0;i<20;i++){
bf.put(i);
cout << "Produced: " << i << endl;
}
});

thread Consumer([&](){
for(int i=0;i<20;i++){
int ret = bf.get();
cout << "Consumer: " << i <<endl;
}
});

producer.join();
Consumer.join();
return 0;
}

自旋锁实现

手写一个简单的自旋锁实现, 主要功能包括: - 提供加锁和解锁的接口 - 使用原子变量实现自旋锁的状态标志

1
2
3
4
5
6
7
8
9
10
11
12
13
class SpinLock{
atomic_flag flag = ATOMIC_FLAG_INIT;
public:
void lock(){
while(flag.test_and_set(memory_order_acquire)){

}
}

void unlock(){
flag.clear(memory_order_release);
}
};

单例模式实现

手写一个线程安全的懒汉式单例模式实现: - 成员变量是一个静态指针, 指向唯一实例; 还需要一个互斥锁保护实例创建过程 - 构造函数私有化, 防止外部实例化 - 确保类只有一个实例 - 提供一个全局访问点获取该实例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
template<typename T>
class Singleton {
private:
inline static T* instance = nullptr;
inline static std::mutex mtx;

Singleton(){} = default;
~Singleton(){} = default;

Singleton(const Singleton&) = delete;
Singleton& operator=(const Singleton&) = delete;
public:
static T& getInstance(){
if(instance==nullptr){
std::unique_lock<std::mutex> lk(mtx);
if(instance==nullptr){
instance = new T();
}
}
return *instance;
}

void destroyInstance(){
std::lock_guard<std::mutex> lk(mtx);
if(instance!=nullptr){
delete instance;
instance = nullptr;
}
}
};
这里还实现了双重检查锁定(Double-Checked Locking)来减少锁的开销, 只有在实例未创建时才加锁.双重检查锁定 (Double-Checked Locking Pattern, DCLP) 是一种用于延迟初始化单例对象的设计模式。它旨在减少获取单例实例时的锁开销,同时确保线程安全。这种模式的核心思想是,在获取单例实例时,先进行一次非锁定的检查,如果实例已经存在,则直接返回实例;如果实例不存在,则进入锁定区域,再次检查实例是否存在,如果仍然不存在,则创建实例。

下面是一个C++11使用静态局部变量实现的线程安全单例模式示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
template<typename T>
class Singleton {
private:
// 没有成员变量

Singleton(){} = default;
~Singleton(){} = default;

Singleton(const Singleton&) = delete;
Singleton& operator=(const Singleton&) = delete;
public:
static T& getInstance(){
// 直接在函数中定义静态局部变量, 也不需要在析构函数中销毁实例, 因为静态局部变量会在程序结束时自动销毁
static T instance; // C++11 保证静态局部变量的初始化是线程安全的
return instance;
}
};

线程池实现

手写一个简单的线程池实现, 主要功能包括: - 初始化一定数量的工作线程 - 提供提交任务的接口 - 工作线程从任务队列中取出任务并执行 - 使用互斥锁和条件变量进行线程同步

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class ThreadPool{
vector<thread> workers;
queue<function<void()>> tasks;

mutex mtx;
condition_variable cv;
atomic<bool> stop;
public:
ThreadPool(int nums): stop(false){
for(int i=0;i<nums;i++){
workers.emplace_back([this]{ // this 捕获避免拷贝整个对象
while(1){
function<void()> task;
{
unique_lock<mutex> lock(mtx);
cv.wait(lock, [this](){
return !tasks.empty() || stop.load();
});
if(stop.load() && tasks.empty()) return;
task = move(tasks.front()); // move 避免多余拷贝
tasks.pop();
}
task();
}
});
}
}

void enqueue(function<void()> task){
{
unique_lock<mutex> lock(mtx);
tasks.emplace(move(task)); // move 避免多余拷贝
}
cv.notify_one();
}

~ThreadPool(){
stop.store(true); // 原子变量设置为true
cv.notify_all();
for(auto& worker: workers){ // 必须是引用, 因为线程对象不可复制
worker.join();
}
}
};
上面的线程池实现了基本的功能, 不过它只支持提交 function<void()> 类型的任务, 我们可以使用模板来实现一个更通用的 enqueue 方法, 支持任意可调用对象和参数:

1
2
3
4
5
6
7
8
9
template<class F, class... Args>
void enqueue(F&& f, Args&&... args){
auto task = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
{
std::unique_lock<std::mutex> lock(mtx);
tasks.emplace(std::move(task));
}
cv.notify_one();
}

不过上面的实现还是无法获取任务的返回值, 我们可以使用 std::packaged_taskstd::future 来实现支持返回值的任务提交:

1