1. 概述
本文主要分析Blink大神的参赛代码(Git源码地址:message-queue-cpp),难免会有疏漏,希望读者及时指正,本人会在之后的更新中补充。
2. queue_store实现
put函数
void queue_store::put(const std::string &queue_name, const MemBlock &message) {
//从队列名称中获取队列的ID
auto queueID = static_cast<unsigned long>(getQueueID(queue_name));
//队列ID模线程数获取所属的IO线程
auto thread_id = static_cast<int>(queueID % IO_THREAD);
//从IO线程数组中取出线程
asyncfileio_thread_t *ioThread = asyncfileio->work_threads_object[thread_id];
//计算这个队列在所属线程中的编号
uint64_t which_queue_in_this_io_thread = queueID / IO_THREAD;
//访问该队列的偏移量计数器并加一
uint64_t queue_offset = ioThread->queue_counter[which_queue_in_this_io_thread]++;
//计算CHUNK_ID,CHUNK_ID表示当前这条消息的索引应该写入的CHUNK的编号
uint64_t chunk_id = ((queue_offset / CLUSTER_SIZE) * (CLUSTER_SIZE * QUEUE_NUM_PER_IO_THREAD) +
(which_queue_in_this_io_thread * CLUSTER_SIZE) +
queue_offset % CLUSTER_SIZE);
//根据CHUNK_ID确定索引写入位置
uint64_t idx_file_offset = INDEX_ENTRY_SIZE * chunk_id;
//计算索引文件的序号
int which_mapped_chunk = static_cast<int>(idx_file_offset / INDEX_MAPPED_BLOCK_RAW_SIZE);
//计算索引在索引文件中的写入位置
uint64_t offset_in_mapped_chunk = idx_file_offset % INDEX_MAPPED_BLOCK_RAW_SIZE;
//如果内存中没有索引文件对应的缓冲区
if (ioThread->index_file_memory_block[which_mapped_chunk] == nullptr) {
//对缓冲区加锁(注意这里不是对mapped_block_mtx加锁)
std::unique_lock<std::mutex> lock(ioThread->mapped_block_mtx[which_mapped_chunk]);
//初始化缓冲区,并阻塞等待初始化完成
ioThread->mapped_block_cond[which_mapped_chunk].wait(lock, [ioThread, which_mapped_chunk]() -> bool {
return ioThread->index_file_memory_block[which_mapped_chunk] != nullptr;
});
}
//计算缓冲区写入位置并存为指针
char *buf = ioThread->index_file_memory_block[which_mapped_chunk] + offset_in_mapped_chunk;
//如果消息是常规尺寸(58)
if (message.size <= RAW_NORMAL_MESSAGE_SIZE) {
//使用跳表快速编码并写入缓冲区
serialize_base36_decoding_skip_index((uint8_t *) message.ptr, message.size,
(uint8_t *) buf);
} else {
//创建长消息缓冲区
unsigned char large_msg_buf[4096];
//使用跳表快速编码并写入缓冲区,保存编码后的长度
uint64_t length = (uint64_t) serialize_base36_decoding_skip_index((uint8_t *) message.ptr, message.size,
(uint8_t *) large_msg_buf);
//计算写入位置
uint64_t offset = ioThread->data_file_current_size.fetch_add(length);
//写入消息
pwrite(ioThread->data_file_fd, large_msg_buf, length, offset);
//写入标志位到索引
buf[0] = LARGE_MESSAGE_MAGIC_CHAR;
//写入偏移量到索引
memcpy(buf + 4, &offset, 8);
//写入消息长度到索引
memcpy(buf + 12, &length, 8);
}
//回收消息内存
delete[] ((char *) (message.ptr));
//累加索引块写入次数计数器
int write_times = ++(ioThread->index_mapped_block_write_counter[which_mapped_chunk]);
//根据写入次数判断索引块是否写满
if (write_times == INDEX_BLOCK_WRITE_TIMES_TO_FULL) {
//新建异步保存索引任务
asyncio_task_t *task = new asyncio_task_t(0);
//执行异步任务
ioThread->blockingQueue->put(task);
}
}
get函数
/**
 * Entry point for reads. Each calling thread draws a ticket number on its first call;
 * tickets below CHECK_THREAD_NUM are served by the index-check path (doPhase2) and all
 * other tickets by the consume path (doPhase3).
 *
 * NOTE(review): whether tickets start at 0 or 1 depends on tidCounter's initial value,
 * which is not shown here.
 */
vector<MemBlock> queue_store::get(const std::string &queue_name, long offset, long number) {
    static thread_local int tid = ++tidCounter;
    const bool is_check_thread = tid < CHECK_THREAD_NUM;
    return is_check_thread ? doPhase2(tid, queue_name, offset, number)
                           : doPhase3(tid, queue_name, offset, number);
}
doPhase2函数
/**
 * Phase-2 (index check) read path.
 *
 * Returns up to `number` messages of `queue_name` starting at `offset`, clamped to the
 * number of messages actually written. The first call blocks in
 * asyncfileio_t::waitFinishIO until the write phase has been closed.
 *
 * Index entries in chunks below index_mapped_flush_start_chunkID are read from the index
 * file. Entries in later chunks are still in memory and are copied directly.
 * The MemBlock buffers in the result are allocated by the deserializer (new[]).
 */
std::vector<MemBlock> queue_store::doPhase2(int tid, const std::string &queue_name, long offset, long number) {
    asyncfileio->waitFinishIO(tid);
    // Per-thread result vector, reused across calls; the caller receives a copy.
    static thread_local vector<MemBlock> result;
    result.clear();
    auto queueID = static_cast<unsigned long>(getQueueID(queue_name));
    int threadID = queueID % IO_THREAD;
    asyncfileio_thread_t *asyncfileio_thread = asyncfileio->work_threads_object[threadID];
    uint32_t which_queue_in_this_io_thread = queueID / IO_THREAD;
    // One past the last offset to return.
    auto max_offset = std::min(static_cast<uint32_t>(offset + number),
                               asyncfileio_thread->queue_counter[which_queue_in_this_io_thread]);
    // Position of this queue's slots inside every row of clusters (same layout as put()).
    uint64_t chunk_offset = (which_queue_in_this_io_thread * CLUSTER_SIZE);
    // Per-thread staging buffer, aligned to FILESYSTEM_BLOCK_SIZE. It holds one cluster of
    // entries plus less than one block of lead-in plus less than one block of round-up.
    // The old "+1 block" size could be overrun by the pread below.
    static thread_local unsigned char *index_record = (unsigned char *) memalign(FILESYSTEM_BLOCK_SIZE,
                                                                                 (INDEX_ENTRY_SIZE * CLUSTER_SIZE) +
                                                                                 2 * FILESYSTEM_BLOCK_SIZE);
    for (auto queue_offset = static_cast<uint64_t>(offset); queue_offset < max_offset;) {
        uint64_t chunk_id = ((queue_offset / CLUSTER_SIZE) * (CLUSTER_SIZE * QUEUE_NUM_PER_IO_THREAD) + chunk_offset +
                             queue_offset % CLUSTER_SIZE);
        // Entries left in this cluster, capped at what the caller still needs (>= 1).
        auto remaining_num = static_cast<uint32_t>(CLUSTER_SIZE - queue_offset % CLUSTER_SIZE);
        if (max_offset - queue_offset < remaining_num) {
            remaining_num = static_cast<uint32_t>(max_offset - queue_offset);
        }
        // Map the raw index position onto the file layout, where every chunk of
        // INDEX_MAPPED_BLOCK_RAW_SIZE bytes occupies INDEX_MAPPED_BLOCK_ALIGNED_SIZE bytes.
        uint64_t idx_file_offset = INDEX_ENTRY_SIZE * chunk_id;
        idx_file_offset = (idx_file_offset / INDEX_MAPPED_BLOCK_RAW_SIZE * INDEX_MAPPED_BLOCK_ALIGNED_SIZE) +
                          (idx_file_offset % INDEX_MAPPED_BLOCK_RAW_SIZE);
        uint64_t idx_file_offset_aligned_start = (idx_file_offset / FILESYSTEM_BLOCK_SIZE * FILESYSTEM_BLOCK_SIZE);
        uint64_t lead_in = idx_file_offset - idx_file_offset_aligned_start;
        uint64_t bytes_needed = INDEX_ENTRY_SIZE * remaining_num + lead_in;
        size_t which_mapped_chunk = idx_file_offset_aligned_start / INDEX_MAPPED_BLOCK_ALIGNED_SIZE;
        if (which_mapped_chunk < asyncfileio_thread->index_mapped_flush_start_chunkID) {
            // Chunk is on disk: blocking pread of whole file-system blocks, rounded up.
            uint64_t read_size = (bytes_needed + FILESYSTEM_BLOCK_SIZE - 1) / FILESYSTEM_BLOCK_SIZE *
                                 FILESYSTEM_BLOCK_SIZE;
            ssize_t got = pread(asyncfileio->index_fds[threadID], index_record, read_size,
                                idx_file_offset_aligned_start);
            if (got < 0 || static_cast<uint64_t>(got) < bytes_needed) {
                log_info("short index read: got %ld of %lu bytes", (long) got, (unsigned long) bytes_needed);
            }
        } else {
            // Chunk is still in memory: copy the needed entries directly.
            memcpy(index_record,
                   asyncfileio_thread->index_file_memory_block[which_mapped_chunk] +
                   (idx_file_offset_aligned_start % INDEX_MAPPED_BLOCK_ALIGNED_SIZE),
                   bytes_needed);
        }
        for (uint32_t i = 0; i < remaining_num; i++) {
            char *output_buf = nullptr;
            int output_length = 0;
            unsigned char *serialized = index_record + lead_in + INDEX_ENTRY_SIZE * i;
            // NOTE(review): put() stores LARGE_MESSAGE_MAGIC_CHAR unshifted in byte 0, so this
            // comparison of byte0 >> 2 only matches it if the constant is 0 — confirm.
            if ((serialized[0] & 0xff) >> 2 != LARGE_MESSAGE_MAGIC_CHAR) {
                // Regular message: the index entry itself is the serialized message.
                output_buf = (char *) deserialize_base36_encoding_add_index(serialized, INDEX_ENTRY_SIZE,
                                                                            output_length, queue_offset + i);
            } else {
                // Large message: the entry holds the data-file offset [4..12) and length [12..20).
                // (Not optimized; the article notes the test data contains no large messages.)
                log_info("big msg");
                size_t large_msg_size;
                size_t large_msg_offset;
                memcpy(&large_msg_offset, serialized + 4, 8);
                memcpy(&large_msg_size, serialized + 12, 8);
                // Sized from the stored length instead of a fixed 4 KiB stack buffer.
                std::vector<unsigned char> large_msg_buf(large_msg_size);
                pread(asyncfileio->data_fds[threadID], large_msg_buf.data(), large_msg_size,
                      large_msg_offset);
                output_buf = (char *) deserialize_base36_encoding_add_index((uint8_t *) large_msg_buf.data(),
                                                                            large_msg_size,
                                                                            output_length, queue_offset + i);
            }
            result.emplace_back(output_buf, (size_t) output_length);
        }
        queue_offset += remaining_num;
    }
    return result;
}
doPhase3函数
volatile bool startedReaderThreadFlag = false;
std::vector<MemBlock> queue_store::doPhase3(int tid, const std::string &queue_name, long offset, long number) {
//获取队列ID
auto queueID = static_cast<unsigned long>(getQueueID(queue_name));
//声明查询结果变量(线程私有)
static thread_local vector<MemBlock> result;
//清空上次查询结果
result.clear();
//获取对应线程ID
int threadID = queueID % IO_THREAD;
//获取对应IO线程
asyncfileio_thread_t *asyncfileio_thread = asyncfileio->work_threads_object[threadID];
//获取队列在线程中的编号
uint32_t which_queue_in_this_io_thread = queueID / IO_THREAD;
//获取最大偏移量(待读取的最后一条消息的位置)
size_t max_queue_offset = asyncfileio_thread->queue_counter[which_queue_in_this_io_thread];
//计算该CHUNK中的待读取消息数
size_t max_result_num = 10 < (max_queue_offset - offset) ? 10 : (max_queue_offset - offset);
//如果尚未启动异步读取线程
if (offset == 0 && !startedReaderThreadFlag) {
//阻塞等待线程初始化
barrier1->Wait([this] {
//循环初始化IO线程
for (int i = 0; i < IO_THREAD; i++) {
// for (size_t chunkID = asyncfileio->work_threads_object[i]->index_mapped_flush_start_chunkID;
// chunkID < asyncfileio->work_threads_object[i]->index_mapped_flush_end_chunkID; chunkID++) {
// //free(asyncfileio->work_threads_object[i]->index_file_memory_block[chunkID]);
// log_debug("free thread %d chunk id %ld", i, chunkID);
// }
//MallocExtension::instance()->ReleaseFreeMemory();
//关闭之前的写入fd
close(asyncfileio->index_fds[i]);
//拼接索引文件路径
string tmp_str = asyncfileio->file_prefix + "_" + std::to_string(i) + ".idx";
//只读方式打开索引文件
asyncfileio->index_fds[i] = open(tmp_str.c_str(), O_RDONLY | O_DIRECT, S_IRUSR | S_IWUSR);
// posix_fadvise(asyncfileio->index_fds[i], 0,
// asyncfileio->mapped_index_files_length[i],
// POSIX_FADV_NORMAL);
}
printf("phase3 start\n");
});
//修改标志位
startedReaderThreadFlag = true;
}
//如果读取结束,返回查询结果
if (max_result_num <= 0) {
return result;
}
//声明各线程的读缓冲指针数组
static thread_local unsigned char **reader_hash_buffer = new unsigned char *[TOTAL_QUEUE_NUM]();
//声明各线程的读缓冲偏移量
static thread_local short *reader_hash_buffer_start_offset = new short[TOTAL_QUEUE_NUM]();
//如果缓冲区未初始化
if (reader_hash_buffer[queueID] == nullptr) {
//为缓冲区分配4k对齐的内存
reader_hash_buffer[queueID] = (unsigned char *) memalign(FILESYSTEM_BLOCK_SIZE,
(INDEX_ENTRY_SIZE * CLUSTER_SIZE) +
FILESYSTEM_BLOCK_SIZE);
}
//将待读取消息数转换为size_t类型
size_t read_num_left = max_result_num;
//循环读取每个CLUSTER中的消息
for (size_t new_offset = offset; new_offset < offset + max_result_num;) {
//计算该CLUSTER中可读的记录数量
size_t this_max_read = std::min<size_t>(read_num_left, CLUSTER_SIZE - (new_offset % CLUSTER_SIZE));
//如果CLUSTER中所有数据都需要读取
if (new_offset % CLUSTER_SIZE == 0) {
uint64_t chunk_offset = (which_queue_in_this_io_thread * CLUSTER_SIZE);
//计算CHUNK_ID
uint64_t chunk_id = ((new_offset / CLUSTER_SIZE) * (CLUSTER_SIZE * QUEUE_NUM_PER_IO_THREAD) + chunk_offset +
new_offset % CLUSTER_SIZE);
//计算索引在索引文件中的偏移量
uint64_t idx_file_offset = INDEX_ENTRY_SIZE * chunk_id;
idx_file_offset = (idx_file_offset / INDEX_MAPPED_BLOCK_RAW_SIZE * INDEX_MAPPED_BLOCK_ALIGNED_SIZE) +
(idx_file_offset % INDEX_MAPPED_BLOCK_RAW_SIZE);
//计算索引所在的4k文件块的偏移量
uint64_t idx_file_offset_aligned_start = (idx_file_offset / FILESYSTEM_BLOCK_SIZE * FILESYSTEM_BLOCK_SIZE);
reader_hash_buffer_start_offset[queueID] = static_cast<short>(idx_file_offset -
idx_file_offset_aligned_start);
size_t which_mapped_chunk = idx_file_offset_aligned_start / INDEX_MAPPED_BLOCK_ALIGNED_SIZE;
if (which_mapped_chunk < asyncfileio_thread->index_mapped_flush_start_chunkID) {
//异步读取
pread(asyncfileio->index_fds[threadID], reader_hash_buffer[queueID],
((INDEX_ENTRY_SIZE * CLUSTER_SIZE + (idx_file_offset - idx_file_offset_aligned_start)) /
FILESYSTEM_BLOCK_SIZE + 1) * FILESYSTEM_BLOCK_SIZE,
idx_file_offset_aligned_start);
} else {
//直接从CHUNK中读取
memcpy(reader_hash_buffer[queueID],
asyncfileio_thread->index_file_memory_block[which_mapped_chunk] +
(idx_file_offset_aligned_start % INDEX_MAPPED_BLOCK_ALIGNED_SIZE),
(INDEX_ENTRY_SIZE * CLUSTER_SIZE) + (idx_file_offset - idx_file_offset_aligned_start));
}
}
//计算消息在CLUSTER内的偏移量
long in_cluster_offset = new_offset % CLUSTER_SIZE;
//循环解析CLUSTER内需要读取的消息
for (uint32_t i = 0; i < this_max_read; i++) {
char *output_buf = nullptr;
int output_length;
unsigned char *serialized = reader_hash_buffer[queueID] + INDEX_ENTRY_SIZE * (in_cluster_offset + i) +
reader_hash_buffer_start_offset[queueID];
//根据标志位判断为常规消息
if ((serialized[0] & 0xff) >> 2 != LARGE_MESSAGE_MAGIC_CHAR) {
output_buf = (char *) deserialize_base36_encoding_add_index(serialized, INDEX_ENTRY_SIZE,
output_length, new_offset + i);
} else { //长消息
//解析索引信息
size_t large_msg_size;
size_t large_msg_offset;
memcpy(&large_msg_offset, serialized + 4, 8);
memcpy(&large_msg_size, serialized + 12, 8);
//分配缓冲区
unsigned char large_msg_buf[4096];
//阻塞原子性取
pread(asyncfileio->data_fds[queueID % IO_THREAD], large_msg_buf, large_msg_size,
large_msg_offset);
//解码消息
output_buf = (char *) deserialize_base36_encoding_add_index((uint8_t *) large_msg_buf, large_msg_size,
output_length, new_offset + i);
}
//放入结果集
result.emplace_back(output_buf, (size_t) output_length);
}
//重新计算消息队列读取偏移量
new_offset += this_max_read;
//重新计算剩余消息数
read_num_left -= this_max_read;
}
//返回查询结果
return result;
}
构造函数
/**
 * Build the store: report the build version, create the phase-3 start barrier, then create
 * the file-IO manager and launch its worker threads.
 */
queue_store::queue_store() {
    print_version();
    // Thread barrier (a rendezvous, not a memory fence) for the CHECK_THREAD_NUM threads
    // that enter the consume phase; used by doPhase3.
    barrier1 = new Barrier(CHECK_THREAD_NUM);
    // One worker object per IO thread is created here; the threads start right after.
    auto *io = new asyncfileio_t(DATA_FILE_PATH);
    asyncfileio = io;
    io->startIOThread();
}
3. asyncfileio实现
asyncfileio_thread_t类
/**
 * State owned by one IO worker thread: its data file, its index file, and the in-memory
 * index chunks that producers write into.
 *
 * Chunk i holds INDEX_MAPPED_BLOCK_RAW_SIZE bytes of index entries. It is written to the
 * index file at INDEX_MAPPED_BLOCK_ALIGNED_SIZE * i.
 *
 * MAX_CONCURRENT_INDEX_MAPPED_BLOCK_NUM buffers are allocated up front. When the oldest
 * chunk is full, doIO writes it out and hands its buffer to chunk
 * i + MAX_CONCURRENT_INDEX_MAPPED_BLOCK_NUM.
 */
class asyncfileio_thread_t {
public:
    // Worker thread number.
    const int thread_id;
    // Bytes reserved so far in the data file (large messages are appended there).
    atomic<long> data_file_current_size;
    int data_file_fd;
    // Bytes of index written to the index file so far.
    size_t index_file_size;
    int index_file_fd;
    // Bookkeeping of the chunk window.
    size_t current_index_mapped_start_offset;
    size_t current_index_mapped_end_offset;
    // Oldest chunk that has not been written to the index file yet.
    size_t current_index_mapped_start_chunk;
    // After the write phase, chunks from index_mapped_flush_start_chunkID on stay in memory
    // and readers copy them directly; chunks below it are read from the index file.
    // index_mapped_flush_end_chunkID advances as flush_all_func writes the kept chunks.
    size_t index_mapped_flush_start_chunkID;
    size_t index_mapped_flush_end_chunkID;
    // Entries written into each chunk so far.
    atomic<int> *index_mapped_block_write_counter;
    // Per-chunk lock and condition variable used to hand a buffer to waiting producers.
    std::mutex *mapped_block_mtx;
    std::condition_variable *mapped_block_cond;
    // Buffer of each chunk, or nullptr when the chunk currently has none.
    char **index_file_memory_block;
    // Number of messages in each queue served by this thread.
    uint32_t *queue_counter;
    // Tasks for this worker (consumed by ioThreadFunction).
    BlockingQueue<asyncio_task_t *> *blockingQueue;
    enum asyncfileio_thread_status status;

    asyncfileio_thread_t(int tid, std::string file_prefix) : thread_id(tid) {
        this->data_file_current_size.store(0);
        this->index_file_size = 0;
        // Set for real when the write phase is closed (ioThreadFunction); initialized so they
        // are never read uninitialized.
        this->index_mapped_flush_start_chunkID = 0;
        this->index_mapped_flush_end_chunkID = 0;
        index_mapped_block_write_counter = new atomic<int>[MAX_MAPED_CHUNK_NUM];
        index_file_memory_block = new char *[MAX_MAPED_CHUNK_NUM];
        for (int i = 0; i < MAX_MAPED_CHUNK_NUM; i++) {
            index_mapped_block_write_counter[i].store(0);
            index_file_memory_block[i] = nullptr;
        }
        // The first MAX_CONCURRENT_INDEX_MAPPED_BLOCK_NUM chunks get buffers up front,
        // aligned to FILESYSTEM_BLOCK_SIZE because the index file is written with O_DIRECT.
        for (int i = 0; i < MAX_CONCURRENT_INDEX_MAPPED_BLOCK_NUM; i++) {
            index_file_memory_block[i] = (char *) memalign(FILESYSTEM_BLOCK_SIZE, INDEX_MAPPED_BLOCK_ALIGNED_SIZE);
        }
        queue_counter = new uint32_t[QUEUE_NUM_PER_IO_THREAD];
        memset(queue_counter, 0, sizeof(uint32_t) * QUEUE_NUM_PER_IO_THREAD);
        mapped_block_mtx = new std::mutex[MAX_MAPED_CHUNK_NUM];
        mapped_block_cond = new std::condition_variable[MAX_MAPED_CHUNK_NUM];
        this->blockingQueue = new BlockingQueue<asyncio_task_t *>;
        // Data file: emptied, then pre-sized to DATA_FILE_MAX_SIZE (shrunk again at shutdown).
        string tmp_str = file_prefix + "_" + std::to_string(tid) + ".data";
        this->data_file_fd = open(tmp_str.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
        ftruncate(data_file_fd, 0);
        ftruncate(data_file_fd, DATA_FILE_MAX_SIZE);
        // Index file: write-only, O_DIRECT | O_SYNC, starts empty.
        tmp_str = file_prefix + "_" + std::to_string(tid) + ".idx";
        this->index_file_fd = open(tmp_str.c_str(), O_WRONLY | O_CREAT | O_DIRECT | O_SYNC, S_IRUSR | S_IWUSR);
        ftruncate(index_file_fd, 0);
        this->current_index_mapped_start_chunk = 0;
        this->current_index_mapped_start_offset = 0;
        this->current_index_mapped_end_offset =
                ((size_t) INDEX_MAPPED_BLOCK_ALIGNED_SIZE) * MAX_CONCURRENT_INDEX_MAPPED_BLOCK_NUM;
    }

    /**
     * Write out, in order, every leading chunk that producers have filled, stopping at the
     * first chunk that is not full. Each written chunk's buffer moves to the chunk
     * MAX_CONCURRENT_INDEX_MAPPED_BLOCK_NUM positions later, and producers waiting for that
     * chunk are woken. The task argument is not used.
     */
    void doIO(asyncio_task_t *asyncio_task) {
        for (; current_index_mapped_start_chunk < MAX_MAPED_CHUNK_NUM; current_index_mapped_start_chunk++) {
            if (index_mapped_block_write_counter[current_index_mapped_start_chunk].load() <
                INDEX_BLOCK_WRITE_TIMES_TO_FULL) {
                break;
            }
            char *full_block = index_file_memory_block[current_index_mapped_start_chunk];
            index_file_size += INDEX_MAPPED_BLOCK_ALIGNED_SIZE;
            ftruncate(index_file_fd, index_file_size);
            pwrite(index_file_fd, full_block,
                   INDEX_MAPPED_BLOCK_ALIGNED_SIZE,
                   INDEX_MAPPED_BLOCK_ALIGNED_SIZE * current_index_mapped_start_chunk);
            current_index_mapped_start_offset += INDEX_MAPPED_BLOCK_ALIGNED_SIZE;
            current_index_mapped_end_offset += INDEX_MAPPED_BLOCK_ALIGNED_SIZE;
            size_t next_chunk = current_index_mapped_start_chunk + MAX_CONCURRENT_INDEX_MAPPED_BLOCK_NUM;
            if (next_chunk < (size_t) MAX_MAPED_CHUNK_NUM) {
                {
                    std::unique_lock<std::mutex> lock(mapped_block_mtx[next_chunk]);
                    index_file_memory_block[next_chunk] = full_block;
                }
                mapped_block_cond[next_chunk].notify_all();
                log_info("io thread %d advanced to %d", this->thread_id, static_cast<int>(next_chunk));
            } else {
                // No later chunk can reuse this buffer. The old code indexed past the end of
                // the per-chunk arrays here; release the buffer instead.
                free(full_block);
            }
            index_file_memory_block[current_index_mapped_start_chunk] = nullptr;
        }
    }
};
ioThreadFunction函数
感觉独立声明一个函数不太好,应该封装到类里面,同时避免使用全局变量。IO线程类只要实现初始化、FLUSH、缓冲区回收、分配、写入、读取6个操作,IO流程控制实现IO线程同步启动写、结束写、启动读、结束读4个操作即可。
// NOTE(review): allFlushFlag is not read or written anywhere in the code shown.
bool allFlushFlag = false;
// Set by asyncfileio_t::waitFinishIO once the write phase has been closed.
// NOTE(review): plain bool shared between threads (formally a data race).
bool ioFinished = false;
// Rendezvous used by asyncfileio_t::waitFinishIO (created in asyncfileio_t's constructor).
Barrier *barrier;
// Number of IO threads that have completed their shutdown flush.
atomic<int> *finish_thread_counter;

/**
 * Body of one IO worker thread.
 *
 * The worker takes tasks from its queue and runs doIO for each one. A task with
 * global_offset == -1, or any task seen while status is AT_CLOSING, ends the write phase.
 *
 * At that point the worker records which chunks stay in memory for readers, writes out
 * the ones that do not, fsyncs the index file, and shrinks the data file to its used size.
 * It then increments finish_thread_counter and exits.
 */
void ioThreadFunction(asyncfileio_thread_t *args) {
    asyncfileio_thread_t *work_thread = args;
    work_thread->status = AT_RUNNING;
    for (;;) {
        asyncio_task_t *task = work_thread->blockingQueue->take();
        if (work_thread->status == AT_CLOSING || task->global_offset == -1) {
            // The shutdown task is not passed to doIO; release it here (it used to leak).
            delete task;
            // NOTE(review): thread 0 writes out its current chunk now, while the other threads
            // keep theirs in memory; the reason is not visible in this code.
            size_t force_flush_chunk_num = (work_thread->thread_id < 1) ? 1 : 0;
            work_thread->index_mapped_flush_start_chunkID =
                    work_thread->current_index_mapped_start_chunk + force_flush_chunk_num;
            work_thread->index_mapped_flush_end_chunkID = work_thread->index_mapped_flush_start_chunkID;
            for (size_t i = work_thread->current_index_mapped_start_chunk;
                 i < work_thread->index_mapped_flush_start_chunkID &&
                 work_thread->index_file_memory_block[i] != nullptr; i++) {
                work_thread->index_file_size += INDEX_MAPPED_BLOCK_ALIGNED_SIZE;
                ftruncate(work_thread->index_file_fd, work_thread->index_file_size);
                pwrite(work_thread->index_file_fd, work_thread->index_file_memory_block[i],
                       INDEX_MAPPED_BLOCK_ALIGNED_SIZE, INDEX_MAPPED_BLOCK_ALIGNED_SIZE * i);
                free(work_thread->index_file_memory_block[i]);
                // Readers pread chunks below index_mapped_flush_start_chunkID, so the slot is
                // never dereferenced again; clear it rather than leave a dangling pointer.
                work_thread->index_file_memory_block[i] = nullptr;
            }
            fsync(work_thread->index_file_fd);
            // Shrink the pre-sized data file to the bytes actually used *before* reporting
            // completion, so readers released by waitFinishIO never race with the truncate.
            ftruncate(work_thread->data_file_fd, work_thread->data_file_current_size.load());
            finish_thread_counter->fetch_add(1);
            break;
        }
        work_thread->doIO(task);
        delete task;
    }
}
asyncfileio_t类
/**
 * Owner of the IO worker threads, plus the per-thread file descriptors and index sizes
 * that the read paths use once the write phase is over.
 */
class asyncfileio_t {
public:
    /**
     * @param file_prefix  common prefix of the "<prefix>_<i>.data" / "<prefix>_<i>.idx" files
     */
    asyncfileio_t(std::string file_prefix) {
        // Allocator statistics before and after asking TCMalloc to return free memory.
        malloc_stats();
        MallocExtension::instance()->ReleaseFreeMemory();
        malloc_stats();
        this->file_prefix = file_prefix;
        finish_thread_counter = new atomic<int>(0);
        barrier = new Barrier(SEND_THREAD_NUM);
        for (int i = 0; i < IO_THREAD; i++) {
            work_threads_object[i] = new asyncfileio_thread_t(i, file_prefix);
        }
    }

    /** Start one detached worker thread per asyncfileio_thread_t. */
    void startIOThread() {
        for (int i = 0; i < IO_THREAD; i++) {
            work_threads_handle[i] = std::thread(ioThreadFunction, work_threads_object[i]);
            work_threads_handle[i].detach();
        }
    }

    /**
     * Close the write phase (only the first time the global ioFinished is false).
     *
     * Callers rendezvous on `barrier`, and the last one queues a shutdown task
     * (global_offset -1) for every worker. Everybody then waits until every worker has
     * reported its shutdown flush.
     *
     * After a second rendezvous, the last caller opens the index files for reading and
     * records the data fds and index sizes.
     *
     * NOTE(review): `barrier` is sized SEND_THREAD_NUM, so exactly that many threads must
     * call this, or it never returns.
     */
    void waitFinishIO(int tid) {
        if (ioFinished) {
            return;
        }
        if (tid == 0) {
            printf("in wait_flush function %ld\n", getCurrentTimeInMS());
        }
        barrier->Wait([this] {
            printf("start send flush cmd %ld\n", getCurrentTimeInMS());
            for (int i = 0; i < IO_THREAD; i++) {
                asyncio_task_t *task = new asyncio_task_t(-1);
                work_threads_object[i]->blockingQueue->put(task);
            }
            printf("after send flush cmd %ld\n", getCurrentTimeInMS());
        });
        if (tid == 0) {
            printf("before wait flush finish %ld\n", getCurrentTimeInMS());
        }
        // Wait for all workers; yield instead of burning a core in a tight spin.
        while (finish_thread_counter->load() < IO_THREAD) {
            std::this_thread::yield();
        }
        if (tid == 0) {
            printf("after wait flush finish %ld\n", getCurrentTimeInMS());
        }
        barrier->Wait([this] {
            MallocExtension::instance()->ReleaseFreeMemory();
            for (int i = 0; i < IO_THREAD; i++) {
                string tmp_str = file_prefix + "_" + std::to_string(i) + ".idx";
                index_fds[i] = open(tmp_str.c_str(), O_RDONLY, S_IRUSR | S_IWUSR);
                data_fds[i] = work_threads_object[i]->data_file_fd;
                mapped_index_files_length[i] = work_threads_object[i]->index_file_size;
                // Hint the kernel that index reads are random access.
                int ret = posix_fadvise(index_fds[i], 0,
                                        work_threads_object[i]->index_file_size,
                                        POSIX_FADV_RANDOM);
                printf("ret %d\n", ret);
            }
        });
        if (tid == 0) {
            printf("finish wait_flush function %ld\n", getCurrentTimeInMS());
        }
        ioFinished = true;
    }

    /**
     * Flush the index chunks still held in memory before the object goes away.
     * The flush now runs synchronously. The old code passed `this` to a detached thread,
     * which could keep using the object after its storage had been released.
     */
    ~asyncfileio_t() {
        printf("f\n");
        flush_all_func(this);
    }

    string file_prefix;
    // Per-thread worker state.
    asyncfileio_thread_t *work_threads_object[IO_THREAD];
    // Worker thread handles (detached after start).
    std::thread work_threads_handle[IO_THREAD];
    // Index file size of each worker, recorded when the write phase ends.
    size_t mapped_index_files_length[IO_THREAD];
    // Read descriptors of the index files (reopened with O_DIRECT for phase 3).
    int index_fds[IO_THREAD];
    // Data file descriptors (shared with the workers).
    int data_fds[IO_THREAD];
};
void flush_all_func(void *args) {
//转换this指针类型
asyncfileio_t *asyncfileio = (asyncfileio_t *) args;
//循环回收IO线程
for (int tid = 0; tid < IO_THREAD; tid++) {
asyncfileio_thread_t *work_thread = asyncfileio->work_threads_object[tid];
//逐个CHUNK进行回收
for (size_t i = work_thread->index_mapped_flush_start_chunkID;
i < MAX_MAPED_CHUNK_NUM && work_thread->index_file_memory_block[i] != nullptr; i++) {
//更新回收队列队尾位置
work_thread->index_mapped_flush_end_chunkID++;
//计算块地址对齐后的文件大小
work_thread->index_file_size += INDEX_MAPPED_BLOCK_ALIGNED_SIZE;
//修改文件大小
ftruncate(work_thread->index_file_fd, work_thread->index_file_size);
//写入硬盘
pwrite(work_thread->index_file_fd, work_thread->index_file_memory_block[i],
INDEX_MAPPED_BLOCK_ALIGNED_SIZE, INDEX_MAPPED_BLOCK_ALIGNED_SIZE * i);
//回收内存缓冲区
free(work_thread->index_file_memory_block[i]);
}
//阻塞等待系统完成落盘操作
fsync(work_thread->index_file_fd);
}
}
4. Barrier实现
/**
 * Reusable rendezvous point for a fixed number of threads.
 *
 * Every Wait() call counts one arrival. The call that completes a group of `iCount`
 * arrivals resets the barrier for the next group, runs `func` while still holding the
 * internal mutex, and then wakes every thread waiting in the finished group.
 */
class Barrier {
public:
    explicit Barrier(std::size_t iCount)
            : threshold_(iCount), remaining_(iCount), generation_(0) {}

    void Wait(std::function<void()> func) {
        std::unique_lock<std::mutex> guard{mutex_};
        const std::size_t arrival_generation = generation_;
        --remaining_;
        if (remaining_ != 0) {
            // Not the last arrival: sleep until the last one opens the next generation.
            released_.wait(guard, [this, arrival_generation] { return generation_ != arrival_generation; });
            return;
        }
        // Last arrival of this group.
        ++generation_;
        remaining_ = threshold_;
        func();
        released_.notify_all();
    }

private:
    std::mutex mutex_;
    std::condition_variable released_;
    // Arrivals needed per group.
    std::size_t threshold_;
    // Arrivals still missing in the current group.
    std::size_t remaining_;
    // Number of completed groups.
    std::size_t generation_;
};
5. fast_base64实现
serialization.cpp中的方法
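该文件包含对消息进行序列化/反序列化的函数。序列化函数传入消息及其长度,先计算可变部分压缩后的长度和固定部分的偏移量,再拷贝固定部分、填充可变部分,然后借用Base64解码(每4个字符压缩为3个字节)完成压缩,最后返回写入的序列化字节数;反序列化函数则借用Base64编码把字节还原为字符,返回新分配的消息缓冲区指针,并通过引用参数返回消息长度。根据编译环境是否支持AVX2,分别调用fast_avx2_*或chromium_*实现。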
Base64序列化
/**
 * Pack `message` into `serialized`.
 *
 * Output layout: [base64-decoded variable part][FIXED_PART_LEN raw tail bytes][padding count].
 * The variable part (the first len - FIXED_PART_LEN bytes) is padded to a multiple of 4
 * characters and base64-decoded, which turns every 4 characters into 3 bytes.
 *
 * Side effect: padding characters from "BLINK" are written over the start of the tail in
 * `message` (the tail is copied out first; at least 3 bytes must follow the variable part).
 *
 * @return number of bytes written to `serialized`
 */
int serialize_base64_decoding(uint8_t *message, uint16_t len, uint8_t *serialized) {
    auto serialize_len = len - FIXED_PART_LEN;
    int padding_chars = (4 - serialize_len % 4) % 4;
    size_t estimated_length = 3 * (serialize_len / 4 + (serialize_len % 4 != 0));
    uint8_t *fixed_tail = message + serialize_len;
    // Save the tail before padding overwrites it.
    memcpy(serialized + estimated_length, fixed_tail, FIXED_PART_LEN);
    memcpy(fixed_tail, "BLINK", padding_chars);
    const auto packed_input_len = serialize_len + padding_chars;
#ifdef __AVX2__
    fast_avx2_base64_decode(reinterpret_cast<char *>(serialized),
                            reinterpret_cast<const char *>(message),
                            packed_input_len);
#else
    chromium_base64_decode(reinterpret_cast<char *>(serialized),
                           reinterpret_cast<const char *>(message),
                           packed_input_len);
#endif
    // The trailing byte records how many padding characters the decoder must drop.
    const size_t trailer_pos = estimated_length + FIXED_PART_LEN;
    serialized[trailer_pos] = padding_chars;
    return trailer_pos + 1;
}
Base64反序列化
序列化的逆操作,这里不再赘述。
/**
 * Inverse of serialize_base64_decoding.
 *
 * @param serialized            packed bytes: [packed][FIXED_PART_LEN tail][padding count]
 * @param total_serialized_len  number of bytes in `serialized`
 * @param len                   out: length of the rebuilt message
 * @return new[]-allocated message buffer; the caller owns it
 */
uint8_t *deserialize_base64_encoding(const uint8_t *serialized, uint16_t total_serialized_len, int &len) {
    const uint8_t padding = serialized[total_serialized_len - 1];
    auto packed_len = total_serialized_len - FIXED_PART_LEN - 1;
    // 3 packed bytes expand to 4 characters; extra room for the encoder's tail.
    auto *deserialized = new uint8_t[total_serialized_len / 3 * 4 + 16];
#ifdef __AVX2__
    size_t text_len = fast_avx2_base64_encode(reinterpret_cast<char *>(deserialized),
                                              reinterpret_cast<const char *>(serialized), packed_len);
#else
    size_t text_len = chromium_base64_encode(reinterpret_cast<char *>(deserialized),
                                             reinterpret_cast<const char *>(serialized), packed_len);
#endif
    // Drop the padding characters and re-attach the raw tail right behind the text.
    const size_t variable_len = text_len - padding;
    memcpy(deserialized + variable_len, serialized + packed_len, FIXED_PART_LEN);
    len = variable_len + FIXED_PART_LEN;
    return deserialized;
}
Base64序列化(省略索引)
serialization.hpp中的宏定义如下:
// The last FIXED_PART_LEN bytes of a message form a fixed tail laid out as
// [BASE64_INFO_LEN][INDEX_LEN][VARYING_VERIFY_LEN]; serializers copy it without packing.
// First field of the fixed tail; always stored.
#define BASE64_INFO_LEN (2u)
// Message-index field; the *_skip_index serializers drop it and the *_add_index
// deserializers write the caller-supplied idx back in its place.
#define INDEX_LEN (4u)
// Last field of the fixed tail; always stored.
#define VARYING_VERIFY_LEN (4u)
// Whole tail: BASE64_INFO_LEN + INDEX_LEN + VARYING_VERIFY_LEN.
#define FIXED_PART_LEN (10u)
很明显:
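FIXED_PART_LEN = BASE64_INFO_LEN + INDEX_LEN + VARYING_VERIFY_LEN(即 10 = 2 + 4 + 4)
其中INDEX_LEN对应的4字节索引在读取时可以重新得到:读取方会把消息在队列中的偏移量(代码中的queue_offset + i)作为idx参数传给deserialize_*_add_index,由它写回这4个字节。所以serialize_base64_decoding_skip_index方法省略了INDEX_LEN部分的存储,每条消息可节省4个字节。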
// Skip index =================================================================================================
/**
 * Like serialize_base64_decoding, but the INDEX_LEN bytes of the fixed tail are not
 * stored (the reader supplies them again).
 *
 * Output layout: [packed variable part][BASE64_INFO_LEN][VARYING_VERIFY_LEN][padding count].
 * Side effect: "BLINK" padding characters overwrite the start of the tail inside `message`
 * (at least 3 bytes must follow the variable part).
 *
 * @return number of bytes written to `serialized`
 */
int serialize_base64_decoding_skip_index(uint8_t *message, uint16_t len, uint8_t *serialized) {
    const auto var_len = len - FIXED_PART_LEN;
    const int pad = (4 - var_len % 4) % 4;
    const size_t packed_len = 3 * (var_len / 4 + (var_len % 4 != 0));
    uint8_t *const tail = message + var_len;
    uint8_t *const stored_tail = serialized + packed_len;
    // Keep the info field, skip the index field, keep the verify field.
    memcpy(stored_tail, tail, BASE64_INFO_LEN);
    memcpy(stored_tail + BASE64_INFO_LEN, tail + BASE64_INFO_LEN + INDEX_LEN, VARYING_VERIFY_LEN);
    memcpy(tail, "BLINK", pad);
#ifdef __AVX2__
    fast_avx2_base64_decode(reinterpret_cast<char *>(serialized),
                            reinterpret_cast<const char *>(message),
                            var_len + pad);
#else
    chromium_base64_decode(reinterpret_cast<char *>(serialized),
                           reinterpret_cast<const char *>(message),
                           var_len + pad);
#endif
    const size_t stored_fixed = FIXED_PART_LEN - INDEX_LEN;
    serialized[packed_len + stored_fixed] = pad;
    return packed_len + stored_fixed + 1;
}
Base64反序列化(填充索引)
/**
 * Inverse of serialize_base64_decoding_skip_index; writes `idx` into the index field.
 *
 * @param serialized            [packed][BASE64_INFO_LEN][VARYING_VERIFY_LEN][padding count]
 * @param total_serialized_len  number of bytes in `serialized`
 * @param deserialized_len      out: length of the rebuilt message
 * @param idx                   value for the INDEX_LEN field (copied in host byte order)
 * @return new[]-allocated message buffer; the caller owns it
 */
uint8_t *deserialize_base64_encoding_add_index(const uint8_t *serialized, uint16_t total_serialized_len,
                                               int &deserialized_len, int32_t idx) {
    const auto packed_len = total_serialized_len - (FIXED_PART_LEN - INDEX_LEN) - 1;
    const uint8_t *stored_tail = serialized + packed_len;
    const uint8_t padding = serialized[total_serialized_len - 1];
    auto *deserialized = new uint8_t[total_serialized_len / 3 * 4 + 16];
#ifdef __AVX2__
    const size_t encoded_len = fast_avx2_base64_encode(reinterpret_cast<char *>(deserialized),
                                                       reinterpret_cast<const char *>(serialized), packed_len);
#else
    const size_t encoded_len = chromium_base64_encode(reinterpret_cast<char *>(deserialized),
                                                      reinterpret_cast<const char *>(serialized), packed_len);
#endif
    // Rebuild the fixed tail right after the unpadded text: info, index, verify.
    uint8_t *cursor = deserialized + (encoded_len - padding);
    memcpy(cursor, stored_tail, BASE64_INFO_LEN);
    cursor += BASE64_INFO_LEN;
    memcpy(cursor, &idx, sizeof(int32_t));
    cursor += INDEX_LEN;
    memcpy(cursor, stored_tail + BASE64_INFO_LEN, VARYING_VERIFY_LEN);
    deserialized_len = encoded_len - padding + FIXED_PART_LEN;
    return deserialized;
}
Base64反序列化指定位置的消息(填充索引)
/**
 * Same reconstruction as deserialize_base64_encoding_add_index, but the message is
 * written into the caller-provided `deserialized` buffer instead of a new allocation.
 *
 * @param serialized            [packed][BASE64_INFO_LEN][VARYING_VERIFY_LEN][padding count]
 * @param total_serialized_len  number of bytes in `serialized`
 * @param deserialized          output buffer; presumably must hold the base64 text plus
 *                              FIXED_PART_LEN bytes — confirm with callers
 * @param deserialized_len      out: length of the rebuilt message
 * @param idx                   value for the INDEX_LEN field (copied in host byte order)
 */
void deserialize_base64_encoding_add_index_in_place(const uint8_t *serialized, uint16_t total_serialized_len,
                                                    uint8_t *deserialized, int &deserialized_len, int32_t idx) {
    const auto packed_len = total_serialized_len - (FIXED_PART_LEN - INDEX_LEN) - 1;
    const uint8_t *const stored_tail = serialized + packed_len;
    const uint8_t *const pad_count_byte = serialized + total_serialized_len - 1;
#ifdef __AVX2__
    const size_t text_len = fast_avx2_base64_encode(reinterpret_cast<char *>(deserialized),
                                                    reinterpret_cast<const char *>(serialized), packed_len);
#else
    const size_t text_len = chromium_base64_encode(reinterpret_cast<char *>(deserialized),
                                                   reinterpret_cast<const char *>(serialized), packed_len);
#endif
    // Rebuild the fixed tail right after the unpadded text: info, index, verify.
    uint8_t *cursor = deserialized + (text_len - *pad_count_byte);
    memcpy(cursor, stored_tail, BASE64_INFO_LEN);
    cursor += BASE64_INFO_LEN;
    memcpy(cursor, &idx, sizeof(int32_t));
    cursor += INDEX_LEN;
    memcpy(cursor, stored_tail + BASE64_INFO_LEN, VARYING_VERIFY_LEN);
    deserialized_len = text_len - *pad_count_byte + FIXED_PART_LEN;
}
// End of Skip index ========================================================================================
Base36序列化(省略索引)
// ------------------------------- Begin of Base36 -------------------------------------------------------------
/**
 * Index-less packing used for index slots, presumably for messages whose variable part
 * uses only [a-z0-9].
 *
 * Unlike serialize_base64_decoding_skip_index, no padding-count byte is stored. The
 * matching decoder strips trailing upper-case padding instead.
 *
 * Output layout: [packed variable part][BASE64_INFO_LEN][VARYING_VERIFY_LEN].
 * Side effect: "BLINK" padding overwrites the start of the tail inside `message`.
 *
 * @return number of bytes written to `serialized`
 */
int serialize_base36_decoding_skip_index(uint8_t *message, uint16_t len, uint8_t *serialized) {
    const auto varying_len = len - FIXED_PART_LEN;
    const int pad_chars = (4 - varying_len % 4) % 4;
    const size_t packed_bytes = 3 * (varying_len / 4 + (varying_len % 4 != 0));
    const uint8_t *fixed_tail = message + varying_len;
    uint8_t *out_tail = serialized + packed_bytes;
    // Copy info and verify fields; the index field in between is dropped.
    memcpy(out_tail, fixed_tail, BASE64_INFO_LEN);
    memcpy(out_tail + BASE64_INFO_LEN, fixed_tail + BASE64_INFO_LEN + INDEX_LEN, VARYING_VERIFY_LEN);
    memcpy(message + varying_len, "BLINK", pad_chars);
#ifdef __AVX2__
    fast_avx2_base64_decode(reinterpret_cast<char *>(serialized),
                            reinterpret_cast<const char *>(message),
                            varying_len + pad_chars);
#else
    chromium_base64_decode(reinterpret_cast<char *>(serialized),
                           reinterpret_cast<const char *>(message),
                           varying_len + pad_chars);
#endif
    return packed_bytes + (FIXED_PART_LEN - INDEX_LEN);
}
Base36反序列化(填充索引)
/**
 * Inverse of serialize_base36_decoding_skip_index; writes `idx` into the index field.
 *
 * @param serialized            [packed variable part][BASE64_INFO_LEN][VARYING_VERIFY_LEN]
 * @param total_serialized_len  number of bytes in `serialized`
 * @param deserialized_len      out: length of the rebuilt message
 * @param idx                   value for the INDEX_LEN field (copied in host byte order)
 * @return new[]-allocated message buffer; the caller owns it
 */
uint8_t *deserialize_base36_encoding_add_index(const uint8_t *serialized, uint16_t total_serialized_len,
                                               int &deserialized_len, int32_t idx) {
    auto serialize_len = total_serialized_len - (FIXED_PART_LEN - INDEX_LEN);
    auto *deserialized = new uint8_t[total_serialized_len / 3 * 4 + 16];
    // 1st: base64-encode the packed part back into characters.
#ifdef __AVX2__
    size_t length = fast_avx2_base64_encode(reinterpret_cast<char *>(deserialized),
                                            reinterpret_cast<const char *>(serialized), serialize_len);
#else
    size_t length = chromium_base64_encode(reinterpret_cast<char *>(deserialized),
                                           reinterpret_cast<const char *>(serialized), serialize_len);
#endif
    // 2nd: strip padding. The serializer padded with upper-case characters from "BLINK" and
    // did not record how many, so trailing 'A'-'Z' are dropped (upper-case message text is
    // therefore unsupported).
    // The bound is tested first: the old loop indexed deserialized[length - 1] before its
    // (always-true, size_t) `length >= 0` check, reading out of bounds once length hit 0.
    while (length > 0 && deserialized[length - 1] >= 'A' && deserialized[length - 1] <= 'Z') {
        --length;
    }
    // 3rd: append the fixed tail: info, the rebuilt index, verify.
    size_t offset = length;
    memcpy(deserialized + offset, serialized + serialize_len, BASE64_INFO_LEN);
    offset += BASE64_INFO_LEN;
    memcpy(deserialized + offset, &idx, sizeof(int32_t));
    offset += INDEX_LEN;
    memcpy(deserialized + offset, serialized + serialize_len + BASE64_INFO_LEN, VARYING_VERIFY_LEN);
    // 4th: total length = unpadded text + full fixed tail.
    deserialized_len = length + FIXED_PART_LEN;
    return deserialized;
}
// ------------------------------ End of Base36, do not support A-Z yet --------------------------------------------
通用序列化
/**
 * General-purpose packing (no base64):
 * [1- or 2-byte length header][5-bit codes, one per variable-part character]
 * [3 extra bits per escaped character][FIXED_PART_LEN raw tail bytes].
 *
 * Variable-part characters map 'a'..'z' -> 0..25 and '0'..'9' -> 26..35. A code of
 * MAX_FIVE_BITS_INT or more is stored as the escape MAX_FIVE_BITS_INT plus
 * (code - MAX_FIVE_BITS_INT) in 3 extra bits. Other characters are not supported
 * (their code would not fit the 3 extra bits).
 *
 * Side effect: the variable part of `message` is overwritten with the codes.
 *
 * @return number of bytes written to `serialized`
 */
int serialize(uint8_t *message, uint16_t len, uint8_t *serialized) {
// add the header to indicate raw message varying-length part size
// (1 byte with top bit clear when len < 128, otherwise 2 bytes holding a 14-bit value)
int serialize_len = len - FIXED_PART_LEN;
if (len < 128) {
serialized[0] = static_cast<uint8_t>(len - FIXED_PART_LEN);
serialized += 1;
} else {
uint16_t tmp = (len - FIXED_PART_LEN);
serialized[0] = static_cast<uint8_t>((tmp >> 7u) | 0x80); // assume < 32767
serialized[1] = static_cast<uint8_t>(tmp & (uint8_t) 0x7f);
serialized += 2;
}
// Extra 3-bit groups start right after the serialize_len 5-bit codes.
uint32_t next_extra_3bits_idx = 5u * serialize_len;
uint32_t next_5bits_idx = 0;
// attention: message is not usable later (characters are replaced by their codes)
for (int i = 0; i < serialize_len; i++) {
message[i] = message[i] >= 'a' ? message[i] - 'a' : message[i] - '0' + (uint8_t) 26;
}
// attention: must clear to be correct (fields are OR-ed in below); serialize_len bytes
// cover 5 bits per character plus at most 3 extra bits each
memset(serialized, 0, (len - FIXED_PART_LEN));
// 1) construct the compressed part
for (int i = 0; i < serialize_len; i++) {
uint16_t cur_uchar = message[i];
// Place the 5-bit code at the top of a 16-bit window, then shift it to its bit position.
uint16_t expand_uchar = cur_uchar < MAX_FIVE_BITS_INT ? (cur_uchar << 11u) : (MAX_FIVE_BITS_INT << 11u);
int shift_bits = (next_5bits_idx & 0x7u);
expand_uchar >>= shift_bits;
int idx = (next_5bits_idx >> 3u);
serialized[idx] |= (expand_uchar >> 8u);
serialized[idx + 1] |= (expand_uchar & 0xffu);
next_5bits_idx += 5;
if (cur_uchar >= MAX_FIVE_BITS_INT) {
// do extra bits operations: store (code - MAX_FIVE_BITS_INT) in the next 3-bit group
expand_uchar = ((cur_uchar - MAX_FIVE_BITS_INT) << 13u);
shift_bits = (next_extra_3bits_idx & 0x7u);
expand_uchar >>= shift_bits;
// bytes are split with shifts, so the result does not depend on host byte order
idx = (next_extra_3bits_idx >> 3u);
serialized[idx] |= (expand_uchar >> 8u);
serialized[idx + 1] |= (expand_uchar & 0xffu);
next_extra_3bits_idx += 3;
}
}
// 2) left FIXED_PART_LEN, should use memcpy; the tail starts at the first whole byte
// after all packed bits
int start_copy_byte_idx = (next_extra_3bits_idx >> 3u) + ((next_extra_3bits_idx & 0x7u) != 0);
memcpy(serialized + start_copy_byte_idx, message + serialize_len, FIXED_PART_LEN);
return start_copy_byte_idx + FIXED_PART_LEN + (len < 128 ? 1 : 2);
}
通用反序列化
/**
 * Inverse of serialize(): rebuild a message from its packed form.
 *
 * @param serialized  buffer written by serialize()
 * @param len         out: message length (variable part + FIXED_PART_LEN)
 * @return new[]-allocated buffer of `len` bytes; the caller owns it
 */
uint8_t *deserialize(const uint8_t *serialized, int &len) {
    // Header: 1 byte when the top bit is clear, otherwise 2 bytes holding a 14-bit length.
    uint16_t varying_byte_len;
    if ((serialized[0] & 0x80u) == 0) {
        varying_byte_len = serialized[0];
        serialized += 1;
    } else {
        varying_byte_len = static_cast<uint16_t>(((serialized[0] & 0x7fu) << 7u) + serialized[1]);
        serialized += 2;
    }
    // 5-bit codes come first; the 3 extra bits of escaped characters follow at bit 5 * n.
    uint32_t next_extra_3bits_idx = 5u * varying_byte_len;
    uint32_t next_5bits_idx = 0;
    // Room for the variable part plus the FIXED_PART_LEN tail copied at the end (the old
    // "+ 8" allocation was 2 bytes short of that).
    auto *deserialized = new uint8_t[varying_byte_len + FIXED_PART_LEN];
    len = varying_byte_len + FIXED_PART_LEN;
    for (int i = 0; i < varying_byte_len; i++) {
        int idx = (next_5bits_idx >> 3u);
        uint16_t value = (serialized[idx] << 8u) + serialized[idx + 1];
        value = (value >> (11u - (next_5bits_idx & 07u))) & MAX_FIVE_BITS_INT;
        if (value != MAX_FIVE_BITS_INT) {
            // Codes 0..25 -> 'a'..'z', codes from 26 -> '0'...
            deserialized[i] = static_cast<uint8_t>(value < 26 ? 'a' + value : value - 26 + '0');
        } else {
            // Escape code: the character is '5' plus the next 3 extra bits.
            idx = (next_extra_3bits_idx >> 3u);
            value = (serialized[idx] << 8u) + serialized[idx + 1];
            value = (value >> (13u - (next_extra_3bits_idx & 07u))) & (uint8_t) 0x7;
            deserialized[i] = value + '5';
            next_extra_3bits_idx += 3;
        }
        next_5bits_idx += 5;
    }
    // The raw tail starts at the first whole byte after all packed bits.
    memcpy(deserialized + varying_byte_len,
           serialized + (next_extra_3bits_idx >> 3u) + ((next_extra_3bits_idx & 0x7u) != 0), FIXED_PART_LEN);
    return deserialized;
}
chromiumbase64.c中的方法
这里应该是参考了Google Chromium浏览器中的base64编码算法实现(注意base64是一种编码方式,而不是加密算法),采用查表法实现,速度很快。
/*
 * Threshold checked by chromium_base64_decode: a combined lookup value
 * x >= BADCHAR causes MODP_B64_ERROR to be returned
 * (presumably the d0..d3 tables map invalid characters to values >= BADCHAR;
 * confirm against the table definitions).
 */
#define BADCHAR 0x01FFFFFF
/**
 * you can control if we use padding by commenting out this
 * next line. However, I highly recommend you use padding and not
 * using it should only be for compatibility with a 3rd party.
 * Also, 'no padding' is not tested!
 */
#define DOPAD 1
/*
 * if we aren't doing padding
 * set the pad character to NUL ('\0').
 * Note: DOPAD is defined just above, so this branch is inactive as written.
 */
#ifndef DOPAD
#undef CHARPAD
#define CHARPAD '\0'
#endif
/*
 * Scalar, table-driven base64 encoder.
 *
 * Encodes len bytes of str into dest. The final partial group, if any, is
 * padded with CHARPAD, and a terminating '\0' is written after the output.
 *
 * Returns the number of characters written, excluding the '\0'
 * (4 for every started 3-byte group). dest must have room for that many
 * characters plus one.
 */
size_t chromium_base64_encode(char* dest, const char* str, size_t len)
{
    /* Read through an unsigned pointer: the table indices must stay in 0..255. */
    const uint8_t* src = (const uint8_t*) str;
    uint8_t* out = (uint8_t*) dest;
    size_t remaining = len;

    /* Every full 3-byte group becomes 4 characters. */
    while (remaining >= 3) {
        const uint8_t b0 = src[0];
        const uint8_t b1 = src[1];
        const uint8_t b2 = src[2];
        out[0] = e0[b0];
        out[1] = e1[((b0 & 0x03) << 4) | ((b1 >> 4) & 0x0F)];
        out[2] = e1[((b1 & 0x0F) << 2) | ((b2 >> 6) & 0x03)];
        out[3] = e2[b2];
        src += 3;
        out += 4;
        remaining -= 3;
    }

    /* One or two trailing bytes become 2 or 3 characters plus padding. */
    if (remaining == 1) {
        const uint8_t b0 = src[0];
        out[0] = e0[b0];
        out[1] = e1[(b0 & 0x03) << 4];
        out[2] = CHARPAD;
        out[3] = CHARPAD;
        out += 4;
    } else if (remaining == 2) {
        const uint8_t b0 = src[0];
        const uint8_t b1 = src[1];
        out[0] = e0[b0];
        out[1] = e1[((b0 & 0x03) << 4) | ((b1 >> 4) & 0x0F)];
        out[2] = e2[(b1 & 0x0F) << 2];
        out[3] = CHARPAD;
        out += 4;
    }

    *out = '\0';
    return (size_t) (out - (uint8_t*) dest);
}
/*
 * Scalar, table-driven base64 decoder.
 *
 * Decodes len characters of src into dest. Returns the number of bytes
 * produced, or MODP_B64_ERROR when the length is invalid (with DOPAD) or when
 * a lookup result reaches BADCHAR.
 *
 * Each character is looked up in its position-specific table d0..d3, and the
 * results are OR-ed into one 32-bit word. Output bytes are then read from that
 * word by address, so the tables must match the host byte order
 * (NOTE(review): presumably little-endian tables -- confirm).
 * On the error paths some bytes may already have been written to dest.
 */
size_t chromium_base64_decode(char* dest, const char* src, size_t len)
{
if (len == 0) return 0;
#ifdef DOPAD
/*
* if padding is used, then the message must be at least
* 4 chars and be a multiple of 4
*/
if (len < 4 || (len % 4 != 0)) {
return MODP_B64_ERROR; /* error */
}
/* there can be at most 2 pad chars at the end; exclude them from len */
if (src[len-1] == CHARPAD) {
len--;
if (src[len -1] == CHARPAD) {
len--;
}
}
#endif
size_t i;
/* number of characters in the final, possibly partial, group */
int leftover = len % 4;
/* full 4-char groups handled by the loop; when len is a multiple of 4 the
 * last full group is handled by "case 0" below instead */
size_t chunks = (leftover == 0) ? len / 4 - 1 : len /4;
uint8_t* p = (uint8_t*)dest;
uint32_t x = 0;
const uint8_t* y = (uint8_t*)src;
/* 4 characters -> 3 bytes per iteration */
for (i = 0; i < chunks; ++i, y += 4) {
x = d0[y[0]] | d1[y[1]] | d2[y[2]] | d3[y[3]];
if (x >= BADCHAR) return MODP_B64_ERROR;
*p++ = ((uint8_t*)(&x))[0];
*p++ = ((uint8_t*)(&x))[1];
*p++ = ((uint8_t*)(&x))[2];
}
switch (leftover) {
case 0: /* last group is a full 4 chars -> 3 output bytes */
x = d0[y[0]] | d1[y[1]] | d2[y[2]] | d3[y[3]];
if (x >= BADCHAR) return MODP_B64_ERROR;
*p++ = ((uint8_t*)(&x))[0];
*p++ = ((uint8_t*)(&x))[1];
*p = ((uint8_t*)(&x))[2];
return (chunks+1)*3;
break;
case 1: /* with padding this is an impossible case (at most 2 pads are stripped) */
x = d0[y[0]];
*p = *((uint8_t*)(&x)); // lowest-addressed byte of x
break;
case 2: /* 2 chars -> 1 output byte */
x = d0[y[0]] | d1[y[1]];
*p = *((uint8_t*)(&x)); // lowest-addressed byte of x
break;
default: /* case 3: 3 chars -> 2 output bytes */
x = d0[y[0]] | d1[y[1]] | d2[y[2]]; /* 0x3c */
*p++ = ((uint8_t*)(&x))[0];
*p = ((uint8_t*)(&x))[1];
break;
}
/* validate the partial group (its bytes were already written above) */
if (x >= BADCHAR) return MODP_B64_ERROR;
/* (6*leftover)/8 = whole bytes carried by the leftover 6-bit characters */
return 3*chunks + (6*leftover)/8;
}
fastavxbase64.c
代码中的最大亮点:使用AVX2指令集实现base64编解码。具体指令说明可以参考:https://software.intel.com/en-us/node/524017?language=ru
#include "fastavxbase64.h"
#include <x86intrin.h>
#include <stdbool.h>
/**
* This code borrows from Wojciech Mula's library at
* https://github.com/WojciechMula/base64simd (published under BSD)
* as well as code from Alfred Klomp's library https://github.com/aklomp/base64 (published under BSD)
*
*/
/**
* Note : Hardware such as Knights Landing might do poorly with this AVX2 code since it relies on shuffles. Alternatives might be faster.
*/
/*
 * Spread 24 input bytes (12 per 128-bit lane) over 32 bytes, each holding one
 * 6-bit base64 index in its low bits.
 *
 * Expects the vector to have been loaded from (str - 4): lane 0 reads bytes
 * 4..15 and lane 1 reads bytes 16..27 of the vector.
 */
static inline __m256i enc_reshuffle(const __m256i input) {
    // AVX2 translation of the SSE procedure in
    // https://github.com/WojciechMula/base64simd/blob/master/encode/unpack_bigendian.cpp
    //
    // _mm256_shuffle_epi8 is a byte table lookup inside each 128-bit lane: each
    // byte of the second operand selects (by its low 4 bits) a byte of the same
    // lane of the first operand. This mask puts one 3-byte input group
    // (a, b, c) into every 32-bit word as the byte sequence b, a, c, b.
    const __m256i gather_groups = _mm256_set_epi8(
        10, 11, 9, 10,
        7, 8, 6, 7,
        4, 5, 3, 4,
        1, 2, 0, 1,
        14, 15, 13, 14,
        11, 12, 10, 11,
        8, 9, 7, 8,
        5, 6, 4, 5
    );
    const __m256i grouped = _mm256_shuffle_epi8(input, gather_groups);

    // Mask out two 6-bit fields per 16-bit half and move them into place with
    // multiplies: _mm256_mulhi_epu16 keeps the high 16 bits of each product
    // (acts as a right shift), _mm256_mullo_epi16 keeps the low 16 bits
    // (acts as a left shift). _mm256_set1_epi32 broadcasts the constant to all
    // eight 32-bit elements.
    const __m256i low_byte_fields = _mm256_mulhi_epu16(
        _mm256_and_si256(grouped, _mm256_set1_epi32(0x0fc0fc00)),
        _mm256_set1_epi32(0x04000040));
    const __m256i high_byte_fields = _mm256_mullo_epi16(
        _mm256_and_si256(grouped, _mm256_set1_epi32(0x003f03f0)),
        _mm256_set1_epi32(0x01000010));

    // The two results occupy different bytes of each 16-bit half, so a bitwise
    // OR merges them.
    return _mm256_or_si256(low_byte_fields, high_byte_fields);
}
/*
 * Map every byte's 6-bit value (0..63) to its base64 ASCII character by
 * adding a range-specific offset:
 *    0..25 -> +65  ('A'..'Z')
 *   26..51 -> +71  ('a'..'z')
 *   52..61 ->  -4  ('0'..'9')
 *       62 -> -19  ('+')
 *       63 -> -16  ('/')
 */
static inline __m256i enc_translate(const __m256i in) {
    // _mm256_setr_epi8 fills bytes in memory order (first argument -> byte 0).
    // The 16-entry offset table is repeated in both 128-bit lanes because
    // _mm256_shuffle_epi8 looks up within each lane.
    const __m256i offsets = _mm256_setr_epi8(
        65, 71, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -19, -16, 0, 0, 65, 71,
        -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -19, -16, 0, 0);

    // Unsigned saturating subtract: 0 for values <= 51, 1..12 for 52..63.
    // (_mm256_set1_epi8 broadcasts one value to all 32 bytes.)
    __m256i slot = _mm256_subs_epu8(in, _mm256_set1_epi8(51));

    // _mm256_cmpgt_epi8 yields 0xFF (-1) in each byte where in > 25 and 0
    // elsewhere; subtracting it adds 1 for every value 26..63. 'a'..'z' then
    // use slot 1 and '0'..'9', '+', '/' use slots 2..13.
    const __m256i above_25 = _mm256_cmpgt_epi8(in, _mm256_set1_epi8(25));
    slot = _mm256_sub_epi8(slot, above_25);

    return _mm256_add_epi8(in, _mm256_shuffle_epi8(offsets, slot));
}
/*
 * Pack 32 decoded 6-bit values (one per byte) into 24 output bytes, placed in
 * the low 24 bytes of the result; the top 8 bytes are zero.
 */
static inline __m256i dec_reshuffle(__m256i in) {
    // Inlined pack_madd from
    // https://github.com/WojciechMula/base64simd/blob/master/decode/pack.avx2.cpp
    // The element order is reversed; only the multiplication constants differ.
    //
    // Step 1: each 16-bit pair (a, b) becomes a * 0x40 + b, i.e. two 6-bit
    // values merged into 12 bits. (_mm256_maddubs_epi16 is likely expensive.)
    const __m256i merged12 = _mm256_maddubs_epi16(in, _mm256_set1_epi32(0x01400140));

    // Step 2: each 32-bit word (hi, lo) becomes hi * 0x1000 + lo, i.e. two
    // 12-bit values merged into 24 bits.
    const __m256i merged24 = _mm256_madd_epi16(merged12, _mm256_set1_epi32(0x00011000));

    // Step 3: inside each lane, keep the 3 payload bytes of every word in output
    // order (2, 1, 0) and zero the rest (-1 indices), discarding words 3 and 7.
    const __m256i packed = _mm256_shuffle_epi8(merged24, _mm256_setr_epi8(
        2, 1, 0, 6, 5, 4, 10, 9, 8, 14, 13, 12, -1, -1, -1, -1,
        2, 1, 0, 6, 5, 4, 10, 9, 8, 14, 13, 12, -1, -1, -1, -1
    ));

    // Step 4: move lane 1's 12 bytes up against lane 0's. The permute uses only
    // the low 3 bits of each index, so -1 selects dword 7, which is all zero.
    // Replacing this with _mm256_storeu2_m128i is possible but doubtfully faster.
    return _mm256_permutevar8x32_epi32(packed, _mm256_setr_epi32(0, 1, 2, 4, 5, 6, -1, -1));
}
/*
 * Base64-encode len bytes of str into dest. AVX2 handles the bulk and
 * chromium_base64_encode handles the remaining tail.
 *
 * Each vector step consumes 24 input bytes and stores 32 output characters.
 * Returns the total number of characters produced, or MODP_B64_ERROR if the
 * scalar tail reports one.
 */
size_t fast_avx2_base64_encode(char* dest, const char* str, size_t len) {
    const char* const dest_start = dest;

    if (len >= 32 - 4) {
        // The first vector is read from (str - 4) so the 24 payload bytes sit
        // where enc_reshuffle expects them. The mask leaves the lowest dword
        // (the 4 bytes before str) unloaded. Per Intel's documentation,
        // masked-off elements never fault, so nothing before str is accessed.
        const __m256i skip_first_dword = _mm256_set_epi32(
            0x80000000,
            0x80000000,
            0x80000000,
            0x80000000,
            0x80000000,
            0x80000000,
            0x80000000,
            0x00000000
        );
        __m256i block = _mm256_maskload_epi32((int const*)(str - 4), skip_first_dword);

        for (;;) {
            block = enc_translate(enc_reshuffle(block));
            _mm256_storeu_si256((__m256i *)dest, block);
            str += 24;
            dest += 32;
            len -= 24;
            // A plain 32-byte load of [str - 4, str + 28) stays in bounds once
            // 24 bytes have been consumed and at least 32 remain.
            if (len < 32) {
                break;
            }
            block = _mm256_loadu_si256((__m256i *)(str - 4));
        }
    }

    const size_t tail = chromium_base64_encode(dest, str, len);
    if (tail == MODP_B64_ERROR) return MODP_B64_ERROR;
    return (size_t)(dest - dest_start) + tail;
}
/*
 * Base64-decode srclen characters of src into out.
 *
 * While at least 45 characters remain, they are validated and decoded 32 at a
 * time with AVX2. Each such step stores 32 bytes at out but advances out by
 * only 24, the decoded size. The vector loop stops as soon as a vector contains
 * a character outside the base64 alphabet, or fewer than 45 characters remain.
 * The rest is then passed to chromium_base64_decode, which also handles
 * padding and reports errors.
 *
 * Returns the total number of decoded bytes, or MODP_B64_ERROR.
 */
size_t fast_avx2_base64_decode(char *out, const char *src, size_t srclen) {
    char* const out_start = out;

    // Validation/translation tables by @aqrit, from
    // https://github.com/WojciechMula/base64simd/issues/3#issuecomment-271137490
    // (transated into AVX2; both 128-bit lanes carry the same 16-entry table).
    //
    // The base64 alphabet maps back to 6-bit values as follows:
    //   #  From       To        Add   Characters
    //   1  [43]       [62]      +19   +
    //   2  [47]       [63]      +16   /
    //   3  [48..57]   [52..61]   +4   0..9
    //   4  [65..90]   [0..25]   -65   A..Z
    //   5  [97..122]  [26..51]  -71   a..z
    //  (6) everything else => invalid input
    const __m256i lut_lo = _mm256_setr_epi8(
        0x15, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
        0x11, 0x11, 0x13, 0x1A, 0x1B, 0x1B, 0x1B, 0x1A,
        0x15, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
        0x11, 0x11, 0x13, 0x1A, 0x1B, 0x1B, 0x1B, 0x1A
    );
    const __m256i lut_hi = _mm256_setr_epi8(
        0x10, 0x10, 0x01, 0x02, 0x04, 0x08, 0x04, 0x08,
        0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10,
        0x10, 0x10, 0x01, 0x02, 0x04, 0x08, 0x04, 0x08,
        0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10
    );
    const __m256i lut_roll = _mm256_setr_epi8(
        0, 16, 19, 4, -65, -65, -71, -71,
        0, 0, 0, 0, 0, 0, 0, 0,
        0, 16, 19, 4, -65, -65, -71, -71,
        0, 0, 0, 0, 0, 0, 0, 0
    );
    const __m256i mask_2F = _mm256_set1_epi8(0x2f);

    for (; srclen >= 45; srclen -= 32, src += 32, out += 24) {
        __m256i chars = _mm256_loadu_si256((__m256i *)src);

        // Nibble indices for the lookups. _mm256_shuffle_epi8 uses only the low
        // 4 bits of an index and zeroes the result byte when bit 7 is set;
        // masking with 0x2f clears bit 7.
        const __m256i hi_nibbles = _mm256_and_si256(_mm256_srli_epi32(chars, 4), mask_2F);
        const __m256i lo_nibbles = _mm256_and_si256(chars, mask_2F);

        // The tables are built so that lo & hi is non-zero for a character
        // outside the alphabet.
        const __m256i lo = _mm256_shuffle_epi8(lut_lo, lo_nibbles);
        const __m256i hi = _mm256_shuffle_epi8(lut_hi, hi_nibbles);
        if (!_mm256_testz_si256(lo, hi)) {
            break;  // invalid character: leave it to the scalar decoder
        }

        // '/' (0x2f) shares its high nibble with '+'. The -1 produced by cmpeq
        // selects the neighbouring roll entry (+16 instead of +19) for it.
        const __m256i eq_2F = _mm256_cmpeq_epi8(chars, mask_2F);
        const __m256i roll = _mm256_shuffle_epi8(lut_roll, _mm256_add_epi8(eq_2F, hi_nibbles));
        chars = _mm256_add_epi8(chars, roll);

        // Pack the 32 six-bit values into 24 bytes; the store writes 32 bytes.
        _mm256_storeu_si256((__m256i *)out, dec_reshuffle(chars));
    }

    const size_t tail = chromium_base64_decode(out, src, srclen);
    if (tail == MODP_B64_ERROR) return MODP_B64_ERROR;
    return (size_t)(out - out_start) + tail;
}
网友评论