美文网首页
第四届中间件性能挑战赛冠军代码解析-复赛C代码

第四届中间件性能挑战赛冠军代码解析-复赛C代码

作者: CXYMichael | 来源:发表于2018-08-19 17:10 被阅读45次

1. 概述

本文主要分析Blink大神的参赛代码(Git源码地址:message-queue-cpp),难免会出现各种疏漏,希望读者及时指正,本人会在之后的更新中补充。

2. queue_store实现

put函数

/**
 * Appends one message to the queue `queue_name`.
 *
 * The message is encoded by serialize_base36_decoding_skip_index (going by its
 * name and its base64 counterpart, this variant omits the 4-byte index field,
 * which readers re-insert from the queue offset) into a fixed-size index entry
 * inside an in-memory index block owned by the queue's IO thread. Messages
 * longer than RAW_NORMAL_MESSAGE_SIZE are encoded into a stack buffer and
 * appended to the IO thread's data file instead; their index entry then holds
 * LARGE_MESSAGE_MAGIC_CHAR, the data-file offset and the encoded length.
 *
 * Takes ownership of message.ptr: it is released with delete[] here.
 * When a block has received INDEX_BLOCK_WRITE_TIMES_TO_FULL entries, a task is
 * queued so the IO thread can flush that block (see asyncfileio_thread_t::doIO).
 */
void queue_store::put(const std::string &queue_name, const MemBlock &message) {
    // Map the queue name to its numeric queue ID.
    auto queueID = static_cast<unsigned long>(getQueueID(queue_name));
    // Queues are spread over IO threads by queueID % IO_THREAD.
    auto thread_id = static_cast<int>(queueID % IO_THREAD);
    // The IO thread object that owns this queue.
    asyncfileio_thread_t *ioThread = asyncfileio->work_threads_object[thread_id];
    // Slot of this queue among the queues owned by that IO thread.
    uint64_t which_queue_in_this_io_thread = queueID / IO_THREAD;
    // Take this message's queue offset and advance the per-queue counter.
    // NOTE(review): queue_counter is a plain uint32_t array, so this assumes a
    // queue is never written by two threads concurrently — confirm with callers.
    uint64_t queue_offset = ioThread->queue_counter[which_queue_in_this_io_thread]++;
    // Entry ID: each queue owns runs of CLUSTER_SIZE consecutive entries, and the
    // runs of all QUEUE_NUM_PER_IO_THREAD queues of this thread are interleaved.
    uint64_t chunk_id = ((queue_offset / CLUSTER_SIZE) * (CLUSTER_SIZE * QUEUE_NUM_PER_IO_THREAD) +
                         (which_queue_in_this_io_thread * CLUSTER_SIZE) +
                         queue_offset % CLUSTER_SIZE);
    // Byte position of the entry in the unpadded ("raw") index layout.
    uint64_t idx_file_offset = INDEX_ENTRY_SIZE * chunk_id;
    // Index block that holds the entry.
    int which_mapped_chunk = static_cast<int>(idx_file_offset / INDEX_MAPPED_BLOCK_RAW_SIZE);
    // Position of the entry inside that block.
    uint64_t offset_in_mapped_chunk = idx_file_offset % INDEX_MAPPED_BLOCK_RAW_SIZE;
    // The block has no buffer yet: its buffer is only published once the IO
    // thread recycles a flushed block (doIO).
    // NOTE(review): this first nullptr check reads the pointer without holding
    // the mutex; the authoritative re-check happens inside wait() below.
    if (ioThread->index_file_memory_block[which_mapped_chunk] == nullptr) {
        // Lock this block's mutex (mapped_block_mtx[which_mapped_chunk]).
        std::unique_lock<std::mutex> lock(ioThread->mapped_block_mtx[which_mapped_chunk]);
        // Sleep until the IO thread publishes a buffer for this block; nothing
        // is allocated here.
        ioThread->mapped_block_cond[which_mapped_chunk].wait(lock, [ioThread, which_mapped_chunk]() -> bool {
            return ioThread->index_file_memory_block[which_mapped_chunk] != nullptr;
        });

    }
    // Destination of this entry inside the block's buffer.
    char *buf = ioThread->index_file_memory_block[which_mapped_chunk] + offset_in_mapped_chunk;
    // Normal-size message: encode it straight into the index entry.
    if (message.size <= RAW_NORMAL_MESSAGE_SIZE) {
        serialize_base36_decoding_skip_index((uint8_t *) message.ptr, message.size,
                                             (uint8_t *) buf);
    } else {
        // Large message: encode into a stack buffer first.
        // NOTE(review): there is no check that the encoded message fits in
        // these 4096 bytes.
        unsigned char large_msg_buf[4096];
        uint64_t length = (uint64_t) serialize_base36_decoding_skip_index((uint8_t *) message.ptr, message.size,
                                                                          (uint8_t *) large_msg_buf);
        // Reserve `length` bytes at the end of the data file (atomic bump).
        uint64_t offset = ioThread->data_file_current_size.fetch_add(length);
        // Write the encoded message; the return value is not checked.
        pwrite(ioThread->data_file_fd, large_msg_buf, length, offset);

        // Index entry of a large message: magic byte, then offset and length.
        // NOTE(review): readers test `(serialized[0] & 0xff) >> 2` against
        // LARGE_MESSAGE_MAGIC_CHAR, while the raw magic is stored here — confirm
        // these agree for the constant's actual value.
        buf[0] = LARGE_MESSAGE_MAGIC_CHAR;
        memcpy(buf + 4, &offset, 8);
        memcpy(buf + 12, &length, 8);
    }
    // The store owns the message memory; release it.
    delete[] ((char *) (message.ptr));
    // Count this entry; the counter is bumped only after the entry is complete.
    int write_times = ++(ioThread->index_mapped_block_write_counter[which_mapped_chunk]);
    // Exactly one writer observes the "block full" count.
    if (write_times == INDEX_BLOCK_WRITE_TIMES_TO_FULL) {
        // Ask the IO thread to flush completed blocks.
        asyncio_task_t *task = new asyncio_task_t(0);
        ioThread->blockingQueue->put(task);
    }
}

get函数

/**
 * Reads up to `number` messages of `queue_name`, starting at `offset`.
 *
 * Every calling thread draws an ID from tidCounter on its first call and keeps
 * it for its lifetime. Threads whose ID is below CHECK_THREAD_NUM are served
 * by the index-check path (doPhase2); every other thread is served by the
 * consume path (doPhase3).
 */
vector<MemBlock> queue_store::get(const std::string &queue_name, long offset, long number) {
    static thread_local int tid = ++tidCounter;
    const bool servedByIndexCheck = tid < CHECK_THREAD_NUM;
    return servedByIndexCheck ? doPhase2(tid, queue_name, offset, number)
                              : doPhase3(tid, queue_name, offset, number);
}

doPhase2函数

/**
 * Index-check read path (used by get() for threads with tid < CHECK_THREAD_NUM).
 *
 * Returns the messages at queue offsets [offset, offset + number) of
 * `queue_name`, clipped to the number of messages written to that queue.
 * Index entries are fetched one cluster at a time. If their block was flushed
 * to the index file before the write phase ended, they are read with pread.
 * If the block is still held in memory, they are copied with memcpy. Each
 * entry is then decoded, and its queue offset is re-inserted as the message
 * index.
 *
 * @param tid        per-thread ID assigned by get()
 * @param queue_name queue to read
 * @param offset     first queue offset to return (negative: nothing returned)
 * @param number     maximum number of messages (<= 0: nothing returned)
 * @return a copy of a thread_local vector that this thread clears on its next call
 */
std::vector<MemBlock> queue_store::doPhase2(int tid, const std::string &queue_name, long offset, long number) {
    // Make sure the write phase has been flushed before reading.
    asyncfileio->waitFinishIO(tid);
    // Per-thread result buffer, reused across calls.
    static thread_local vector<MemBlock> result;
    result.clear();
    // Locate the queue: its IO thread and its slot inside that thread.
    auto queueID = static_cast<unsigned long>(getQueueID(queue_name));
    int threadID = queueID % IO_THREAD;
    asyncfileio_thread_t *asyncfileio_thread = asyncfileio->work_threads_object[threadID];
    uint32_t which_queue_in_this_io_thread = queueID / IO_THREAD;
    // Exclusive end of the read, clipped to the queue size. Computed in 64 bits:
    // narrowing offset + number to uint32_t would wrap large requests (e.g. a
    // number close to LONG_MAX) to a small, wrong bound.
    const uint64_t queue_size = asyncfileio_thread->queue_counter[which_queue_in_this_io_thread];
    uint64_t max_offset = 0;
    if (offset >= 0 && number > 0) {
        const auto first = static_cast<uint64_t>(offset);
        max_offset = first < queue_size
                     ? first + std::min<uint64_t>(static_cast<uint64_t>(number), queue_size - first)
                     : first;
    }
    // Position of this queue's run of entries inside every cluster group.
    uint64_t chunk_offset = static_cast<uint64_t>(which_queue_in_this_io_thread) * CLUSTER_SIZE;
    // Block-aligned scratch buffer for one cluster of index entries. The pread
    // below reads floor((entry bytes + in-block offset) / block) + 1 blocks,
    // which can exceed INDEX_ENTRY_SIZE * CLUSTER_SIZE + one block unless that
    // product is block-aligned, so two spare blocks are reserved.
    static thread_local unsigned char *index_record = (unsigned char *) memalign(FILESYSTEM_BLOCK_SIZE,
                                                                                 (INDEX_ENTRY_SIZE * CLUSTER_SIZE) +
                                                                                 2 * FILESYSTEM_BLOCK_SIZE);
    for (auto queue_offset = static_cast<uint64_t>(offset); queue_offset < max_offset;) {
        // Entry ID, computed exactly as in put().
        uint64_t chunk_id = ((queue_offset / CLUSTER_SIZE) * (CLUSTER_SIZE * QUEUE_NUM_PER_IO_THREAD) + chunk_offset +
                             queue_offset % CLUSTER_SIZE);
        // Entries left in this cluster (>= 1), limited to what is still requested.
        auto remaining_num = static_cast<uint32_t>(CLUSTER_SIZE - queue_offset % CLUSTER_SIZE);
        if (max_offset - queue_offset < remaining_num) {
            remaining_num = static_cast<uint32_t>(max_offset - queue_offset);
        }
        // put() addresses entries in raw (unpadded) block units; in the file and
        // in memory every block occupies INDEX_MAPPED_BLOCK_ALIGNED_SIZE bytes.
        uint64_t idx_file_offset = INDEX_ENTRY_SIZE * chunk_id;
        idx_file_offset = (idx_file_offset / INDEX_MAPPED_BLOCK_RAW_SIZE * INDEX_MAPPED_BLOCK_ALIGNED_SIZE) +
                          (idx_file_offset % INDEX_MAPPED_BLOCK_RAW_SIZE);
        // Start of the file-system block holding the first wanted entry.
        uint64_t idx_file_offset_aligned_start = (idx_file_offset / FILESYSTEM_BLOCK_SIZE * FILESYSTEM_BLOCK_SIZE);

        size_t which_mapped_chunk = idx_file_offset_aligned_start / INDEX_MAPPED_BLOCK_ALIGNED_SIZE;
        if (which_mapped_chunk < asyncfileio_thread->index_mapped_flush_start_chunkID) {
            // Block was flushed to the index file: synchronous, block-aligned pread.
            pread(asyncfileio->index_fds[queueID % IO_THREAD], index_record,
                  ((INDEX_ENTRY_SIZE * remaining_num + (idx_file_offset - idx_file_offset_aligned_start)) /
                   FILESYSTEM_BLOCK_SIZE + 1) * FILESYSTEM_BLOCK_SIZE,
                  idx_file_offset_aligned_start);
        } else {
            // Block is still held in memory: copy the entries directly.
            memcpy(index_record,
                   asyncfileio_thread->index_file_memory_block[which_mapped_chunk] +
                   (idx_file_offset_aligned_start % INDEX_MAPPED_BLOCK_ALIGNED_SIZE),
                   (INDEX_ENTRY_SIZE * remaining_num) + (idx_file_offset - idx_file_offset_aligned_start));
        }
        // Decode every fetched entry of this cluster.
        for (uint32_t i = 0; i < remaining_num; i++) {
            char *output_buf = nullptr;
            int output_length;
            unsigned char *serialized =
                    index_record + INDEX_ENTRY_SIZE * i + idx_file_offset - idx_file_offset_aligned_start;
            if ((serialized[0] & 0xff) >> 2 != LARGE_MESSAGE_MAGIC_CHAR) {
                // Normal message: the index entry holds the whole encoded message.
                output_buf = (char *) deserialize_base36_encoding_add_index(serialized, INDEX_ENTRY_SIZE,
                                                                            output_length, queue_offset + i);
            } else {
                // Large message: the entry holds the data-file offset and length.
                log_info("big msg");
                size_t large_msg_size;
                size_t large_msg_offset;
                memcpy(&large_msg_offset, serialized + 4, 8);
                memcpy(&large_msg_size, serialized + 12, 8);
                unsigned char large_msg_buf[4096];
                // put() encodes large messages through a buffer of the same size,
                // so a longer length indicates a corrupt entry; never let pread
                // overrun the stack buffer.
                if (large_msg_size > sizeof(large_msg_buf)) {
                    large_msg_size = sizeof(large_msg_buf);
                }
                pread(asyncfileio->data_fds[queueID % IO_THREAD], large_msg_buf, large_msg_size,
                      large_msg_offset);
                output_buf = (char *) deserialize_base36_encoding_add_index((uint8_t *) large_msg_buf, large_msg_size,
                                                                            output_length, queue_offset + i);
            }
            result.emplace_back(output_buf, (size_t) output_length);
        }
        // Continue with the next cluster.
        queue_offset += remaining_num;
    }

    return result;
}

doPhase3函数

// Set once the consume phase has reopened the index files for O_DIRECT reads.
// NOTE(review): volatile provides no inter-thread synchronization; this flag is
// read and written by several consumer threads, so std::atomic<bool> would be
// the race-free choice.
volatile bool startedReaderThreadFlag = false;

/**
 * Consume read path (used by get() for threads with tid >= CHECK_THREAD_NUM).
 *
 * Returns up to 10 messages of `queue_name` starting at `offset`, clipped to
 * the number of messages written to that queue.
 *
 * The first callers asking for offset 0 before the switch-over join barrier1.
 * The last participant closes the index file descriptors and reopens them
 * read-only with O_DIRECT.
 *
 * Index entries are cached per queue, one cluster at a time, in thread_local
 * buffers. A buffer is (re)filled only when a read starts at a cluster
 * boundary, so this path relies on each queue being consumed sequentially by
 * the same thread — presumably guaranteed by the workload; confirm.
 *
 * NOTE(review): `number` is not used; the batch size is fixed at 10.
 *
 * @param tid        per-thread ID assigned by get() (unused here)
 * @param queue_name queue to read
 * @param offset     first queue offset to return (negative or past the end:
 *                   nothing returned)
 * @param number     ignored (see note above)
 * @return a copy of a thread_local vector that this thread clears on its next call
 */
std::vector<MemBlock> queue_store::doPhase3(int tid, const std::string &queue_name, long offset, long number) {
    auto queueID = static_cast<unsigned long>(getQueueID(queue_name));
    // Per-thread result buffer, reused across calls.
    static thread_local vector<MemBlock> result;
    result.clear();
    // Locate the queue: its IO thread and its slot inside that thread.
    int threadID = queueID % IO_THREAD;
    asyncfileio_thread_t *asyncfileio_thread = asyncfileio->work_threads_object[threadID];
    uint32_t which_queue_in_this_io_thread = queueID / IO_THREAD;
    // Number of messages written to this queue.
    size_t max_queue_offset = asyncfileio_thread->queue_counter[which_queue_in_this_io_thread];
    // Messages to return: at most 10, clipped to what was written. The explicit
    // range check matters because the unsigned difference
    // max_queue_offset - offset wraps to a huge value when offset lies past the
    // end, which would make the loop below decode entries beyond the queue.
    size_t max_result_num = 0;
    if (offset >= 0 && static_cast<size_t>(offset) < max_queue_offset) {
        max_result_num = std::min<size_t>(10, max_queue_offset - static_cast<size_t>(offset));
    }
    // One-time switch-over to the read-only O_DIRECT index descriptors. This
    // runs before the early return below so the barrier is always reached.
    if (offset == 0 && !startedReaderThreadFlag) {
        barrier1->Wait([this] {
            for (int i = 0; i < IO_THREAD; i++) {
                // Drop the descriptor opened by waitFinishIO() ...
                close(asyncfileio->index_fds[i]);
                // ... and reopen the same index file read-only, bypassing the page cache.
                string tmp_str = asyncfileio->file_prefix + "_" + std::to_string(i) + ".idx";
                asyncfileio->index_fds[i] = open(tmp_str.c_str(), O_RDONLY | O_DIRECT, S_IRUSR | S_IWUSR);
            }
            printf("phase3 start\n");
        });
        startedReaderThreadFlag = true;
    }
    // Nothing (left) to read.
    if (max_result_num == 0) {
        return result;
    }
    // Per-thread, per-queue cluster caches and the in-block start of each cached cluster.
    static thread_local unsigned char **reader_hash_buffer = new unsigned char *[TOTAL_QUEUE_NUM]();
    static thread_local short *reader_hash_buffer_start_offset = new short[TOTAL_QUEUE_NUM]();
    // Lazily allocate this queue's block-aligned cluster cache.
    // NOTE(review): the pread below reads floor((entry bytes + in-block offset)
    // / block) + 1 blocks, which only fits this allocation when
    // INDEX_ENTRY_SIZE * CLUSTER_SIZE is block-aligned — confirm the constants.
    if (reader_hash_buffer[queueID] == nullptr) {
        reader_hash_buffer[queueID] = (unsigned char *) memalign(FILESYSTEM_BLOCK_SIZE,
                                                                 (INDEX_ENTRY_SIZE * CLUSTER_SIZE) +
                                                                 FILESYSTEM_BLOCK_SIZE);
    }
    size_t read_num_left = max_result_num;
    // Walk the requested range cluster by cluster.
    for (size_t new_offset = offset; new_offset < offset + max_result_num;) {
        // Entries readable from the current cluster in this call.
        size_t this_max_read = std::min<size_t>(read_num_left, CLUSTER_SIZE - (new_offset % CLUSTER_SIZE));
        // At a cluster boundary: load the whole cluster into the cache.
        if (new_offset % CLUSTER_SIZE == 0) {
            uint64_t chunk_offset = static_cast<uint64_t>(which_queue_in_this_io_thread) * CLUSTER_SIZE;
            // Entry ID, computed exactly as in put().
            uint64_t chunk_id = ((new_offset / CLUSTER_SIZE) * (CLUSTER_SIZE * QUEUE_NUM_PER_IO_THREAD) + chunk_offset +
                                 new_offset % CLUSTER_SIZE);
            // Translate the raw (unpadded) position into the aligned block layout.
            uint64_t idx_file_offset = INDEX_ENTRY_SIZE * chunk_id;
            idx_file_offset = (idx_file_offset / INDEX_MAPPED_BLOCK_RAW_SIZE * INDEX_MAPPED_BLOCK_ALIGNED_SIZE) +
                              (idx_file_offset % INDEX_MAPPED_BLOCK_RAW_SIZE);
            // Start of the file-system block holding the cluster's first entry.
            uint64_t idx_file_offset_aligned_start = (idx_file_offset / FILESYSTEM_BLOCK_SIZE * FILESYSTEM_BLOCK_SIZE);
            reader_hash_buffer_start_offset[queueID] = static_cast<short>(idx_file_offset -
                                                                          idx_file_offset_aligned_start);

            size_t which_mapped_chunk = idx_file_offset_aligned_start / INDEX_MAPPED_BLOCK_ALIGNED_SIZE;

            if (which_mapped_chunk < asyncfileio_thread->index_mapped_flush_start_chunkID) {
                // Block was flushed to the index file: synchronous, block-aligned pread.
                pread(asyncfileio->index_fds[threadID], reader_hash_buffer[queueID],
                      ((INDEX_ENTRY_SIZE * CLUSTER_SIZE + (idx_file_offset - idx_file_offset_aligned_start)) /
                       FILESYSTEM_BLOCK_SIZE + 1) * FILESYSTEM_BLOCK_SIZE,
                      idx_file_offset_aligned_start);
            } else {
                // Block is still held in memory: copy the cluster directly.
                memcpy(reader_hash_buffer[queueID],
                       asyncfileio_thread->index_file_memory_block[which_mapped_chunk] +
                       (idx_file_offset_aligned_start % INDEX_MAPPED_BLOCK_ALIGNED_SIZE),
                       (INDEX_ENTRY_SIZE * CLUSTER_SIZE) + (idx_file_offset - idx_file_offset_aligned_start));
            }

        }
        // Position of the first wanted entry inside the cached cluster.
        long in_cluster_offset = new_offset % CLUSTER_SIZE;
        for (uint32_t i = 0; i < this_max_read; i++) {
            char *output_buf = nullptr;
            int output_length;
            unsigned char *serialized = reader_hash_buffer[queueID] + INDEX_ENTRY_SIZE * (in_cluster_offset + i) +
                                        reader_hash_buffer_start_offset[queueID];
            if ((serialized[0] & 0xff) >> 2 != LARGE_MESSAGE_MAGIC_CHAR) {
                // Normal message: the index entry holds the whole encoded message.
                output_buf = (char *) deserialize_base36_encoding_add_index(serialized, INDEX_ENTRY_SIZE,
                                                                            output_length, new_offset + i);
            } else {
                // Large message: the entry holds the data-file offset and length.
                size_t large_msg_size;
                size_t large_msg_offset;
                memcpy(&large_msg_offset, serialized + 4, 8);
                memcpy(&large_msg_size, serialized + 12, 8);
                unsigned char large_msg_buf[4096];
                // put() encodes large messages through a buffer of the same size,
                // so a longer length indicates a corrupt entry; never let pread
                // overrun the stack buffer.
                if (large_msg_size > sizeof(large_msg_buf)) {
                    large_msg_size = sizeof(large_msg_buf);
                }
                pread(asyncfileio->data_fds[queueID % IO_THREAD], large_msg_buf, large_msg_size,
                      large_msg_offset);
                output_buf = (char *) deserialize_base36_encoding_add_index((uint8_t *) large_msg_buf, large_msg_size,
                                                                            output_length, new_offset + i);
            }
            result.emplace_back(output_buf, (size_t) output_length);
        }
        // Advance to the next cluster.
        new_offset += this_max_read;
        read_num_left -= this_max_read;
    }
    return result;
}

构造函数

/**
 * Sets up the store: prints the build version, creates the thread barrier
 * used when the consume phase starts (CHECK_THREAD_NUM participants), and
 * creates and starts the asynchronous IO workers that write under
 * DATA_FILE_PATH.
 */
queue_store::queue_store() {
    print_version();

    this->barrier1 = new Barrier(CHECK_THREAD_NUM);

    this->asyncfileio = new asyncfileio_t(DATA_FILE_PATH);
    this->asyncfileio->startIOThread();
}

3. asyncfileio实现

asyncfileio_thread_t类

/**
 * State owned by one asynchronous IO thread.
 *
 * Each IO thread owns a data file (<prefix>_<tid>.data), which receives large
 * messages, and an index file (<prefix>_<tid>.idx). Index entries are written
 * into a sliding window of MAX_CONCURRENT_INDEX_MAPPED_BLOCK_NUM in-memory
 * blocks. When a block has received INDEX_BLOCK_WRITE_TIMES_TO_FULL entries,
 * doIO() writes it to the index file and hands its buffer to the block
 * MAX_CONCURRENT_INDEX_MAPPED_BLOCK_NUM positions further on.
 */
class asyncfileio_thread_t {
public:
    // ID of this IO thread; also part of its file names.
    const int thread_id;
    // Bytes appended to the data file so far (offset of the next large message).
    atomic<long> data_file_current_size;
    // Descriptor of <prefix>_<thread_id>.data.
    int data_file_fd;
    // Bytes of index blocks written to the index file so far.
    size_t index_file_size;
    // Descriptor of <prefix>_<thread_id>.idx (write-only, O_DIRECT | O_SYNC).
    int index_file_fd;
    // Byte range of the in-memory block window; advanced by doIO().
    size_t current_index_mapped_start_offset;
    size_t current_index_mapped_end_offset;
    // First index block that doIO() has not flushed yet.
    size_t current_index_mapped_start_chunk;
    // Set at shutdown by ioThreadFunction(): blocks from
    // index_mapped_flush_start_chunkID on stay in memory, and readers copy them
    // from there instead of reading the index file.
    size_t index_mapped_flush_start_chunkID;
    size_t index_mapped_flush_end_chunkID;
    // Number of index entries written into each block.
    atomic<int> *index_mapped_block_write_counter;
    // Per-block mutex / condition variable used to publish a recycled buffer
    // to writers waiting for that block.
    std::mutex *mapped_block_mtx;
    std::condition_variable *mapped_block_cond;
    // Buffer of each index block, or nullptr while the block has none.
    char **index_file_memory_block;
    // Number of messages written to each queue owned by this IO thread.
    uint32_t *queue_counter;
    // Work queue consumed by ioThreadFunction().
    BlockingQueue<asyncio_task_t *> *blockingQueue;
    // Lifecycle state (set to AT_RUNNING by ioThreadFunction()).
    enum asyncfileio_thread_status status;

    asyncfileio_thread_t(int tid, std::string file_prefix) : thread_id(tid) {
        this->data_file_current_size.store(0);
        this->index_file_size = 0;
        // Every block starts with no entries and no buffer.
        index_mapped_block_write_counter = new atomic<int>[MAX_MAPED_CHUNK_NUM];
        index_file_memory_block = new char *[MAX_MAPED_CHUNK_NUM];
        for (int i = 0; i < MAX_MAPED_CHUNK_NUM; i++) {
            index_mapped_block_write_counter[i].store(0);
            index_file_memory_block[i] = nullptr;
        }
        // Only the first window of blocks gets buffers up front; doIO() later
        // passes each flushed buffer forward. Buffers are aligned to
        // FILESYSTEM_BLOCK_SIZE because the index file is written with O_DIRECT.
        for (int i = 0; i < MAX_CONCURRENT_INDEX_MAPPED_BLOCK_NUM; i++) {
            index_file_memory_block[i] = (char *) memalign(FILESYSTEM_BLOCK_SIZE, INDEX_MAPPED_BLOCK_ALIGNED_SIZE);
        }
        // Per-queue message counters, plus per-block locks and condition variables.
        queue_counter = new uint32_t[QUEUE_NUM_PER_IO_THREAD];
        memset(queue_counter, 0, sizeof(uint32_t) * QUEUE_NUM_PER_IO_THREAD);
        mapped_block_mtx = new std::mutex[MAX_MAPED_CHUNK_NUM];
        mapped_block_cond = new std::condition_variable[MAX_MAPED_CHUNK_NUM];

        this->blockingQueue = new BlockingQueue<asyncio_task_t *>;
        // Data file: emptied, then pre-sized to DATA_FILE_MAX_SIZE.
        string tmp_str = file_prefix + "_" + std::to_string(tid) + ".data";
        this->data_file_fd = open(tmp_str.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
        ftruncate(data_file_fd, 0);
        ftruncate(data_file_fd, DATA_FILE_MAX_SIZE);
        // Index file: emptied; it grows one block at a time as blocks are flushed.
        tmp_str = file_prefix + "_" + std::to_string(tid) + ".idx";
        this->index_file_fd = open(tmp_str.c_str(), O_WRONLY | O_CREAT | O_DIRECT | O_SYNC, S_IRUSR | S_IWUSR);
        ftruncate(index_file_fd, 0);
        // The window initially covers blocks [0, MAX_CONCURRENT_INDEX_MAPPED_BLOCK_NUM).
        this->current_index_mapped_start_chunk = 0;
        this->current_index_mapped_start_offset = 0;
        this->current_index_mapped_end_offset =
                ((size_t) INDEX_MAPPED_BLOCK_ALIGNED_SIZE) * MAX_CONCURRENT_INDEX_MAPPED_BLOCK_NUM;
        this->index_file_size = 0;
        // Not meaningful until shutdown, but never left indeterminate.
        this->index_mapped_flush_start_chunkID = 0;
        this->index_mapped_flush_end_chunkID = 0;
    }

    /**
     * Flushes, in order, every leading block that has received
     * INDEX_BLOCK_WRITE_TIMES_TO_FULL entries. Each flushed buffer is then
     * handed to the block MAX_CONCURRENT_INDEX_MAPPED_BLOCK_NUM positions
     * ahead, and writers waiting on that block are woken. When that block
     * would lie past MAX_MAPED_CHUNK_NUM, the buffer is freed instead of
     * indexing the per-block arrays out of bounds.
     *
     * @param asyncio_task the triggering task; its contents are not used
     */
    void doIO(asyncio_task_t *asyncio_task) {
        for (; current_index_mapped_start_chunk < MAX_MAPED_CHUNK_NUM; current_index_mapped_start_chunk++) {
            // Stop at the first block that is not full yet.
            if (index_mapped_block_write_counter[current_index_mapped_start_chunk].load() <
                INDEX_BLOCK_WRITE_TIMES_TO_FULL) {
                break;
            }
            // Grow the index file by one block and write this block in place.
            index_file_size += INDEX_MAPPED_BLOCK_ALIGNED_SIZE;
            ftruncate(index_file_fd, index_file_size);
            pwrite(index_file_fd, index_file_memory_block[current_index_mapped_start_chunk],
                   INDEX_MAPPED_BLOCK_ALIGNED_SIZE,
                   INDEX_MAPPED_BLOCK_ALIGNED_SIZE * current_index_mapped_start_chunk);
            // Slide the window forward by one block.
            int next_chunk = current_index_mapped_start_chunk + MAX_CONCURRENT_INDEX_MAPPED_BLOCK_NUM;
            current_index_mapped_start_offset += INDEX_MAPPED_BLOCK_ALIGNED_SIZE;
            current_index_mapped_end_offset += INDEX_MAPPED_BLOCK_ALIGNED_SIZE;
            if (next_chunk < MAX_MAPED_CHUNK_NUM) {
                // Publish the buffer to the block entering the window, under its lock.
                {
                    std::unique_lock<std::mutex> lock(mapped_block_mtx[next_chunk]);
                    index_file_memory_block[next_chunk] = index_file_memory_block[current_index_mapped_start_chunk];
                }
                // Wake put() calls waiting for next_chunk's buffer.
                mapped_block_cond[next_chunk].notify_all();
                log_info("io thread %d advanced to %d", this->thread_id, next_chunk);
            } else {
                // No block left to recycle into: release the buffer.
                free(index_file_memory_block[current_index_mapped_start_chunk]);
            }
            // The flushed block no longer owns a buffer.
            index_file_memory_block[current_index_mapped_start_chunk] = nullptr;
        }
    }
};

ioThreadFunction函数

感觉独立声明一个函数不太好,应该封装到类里面,同时避免使用全局变量。IO线程类只要实现初始化、FLUSH、缓冲区回收、分配、写入、读取6个操作,IO流程控制实现IO线程同步启动写、结束写、启动读、结束读4个操作即可。

// Globals shared by the IO threads and asyncfileio_t.
// NOTE(review): allFlushFlag and ioFinished are plain bools accessed from
// several threads; std::atomic<bool> would avoid the data race.
// allFlushFlag is not referenced in the code shown.
bool allFlushFlag = false;
// Set by asyncfileio_t::waitFinishIO() once the write phase has ended.
bool ioFinished = false;
// Rendezvous barrier used by asyncfileio_t::waitFinishIO().
Barrier *barrier;
// Number of IO threads that have completed their shutdown work.
atomic<int> *finish_thread_counter;

/**
 * Body of one IO thread.
 *
 * Takes tasks from the thread's blocking queue and runs doIO() for each one.
 * The loop stops on the shutdown task (global_offset == -1) or when the status
 * is AT_CLOSING. On shutdown:
 *  - Thread 0 flushes its first not-yet-flushed block (if it has a buffer)
 *    and frees it; other threads flush none. The reason for the asymmetry is
 *    not visible here.
 *  - All later blocks stay in memory; index_mapped_flush_start_chunkID tells
 *    readers where the in-memory blocks begin.
 *  - The index file is fsynced and the data file is shrunk to the bytes
 *    actually written.
 *  - Only after all of that is finish_thread_counter incremented, so
 *    waitFinishIO() never proceeds while this thread still touches its files.
 * Every task taken from the queue, including the shutdown task, is deleted.
 *
 * @param args the IO thread's state object
 */
void ioThreadFunction(asyncfileio_thread_t *args) {
    asyncfileio_thread_t *work_thread = args;
    work_thread->status = AT_RUNNING;
    for (;;) {
        asyncio_task_t *task = work_thread->blockingQueue->take();
        if (work_thread->status == AT_CLOSING || task->global_offset == -1) {
            // The shutdown task is never passed to doIO(); release it here.
            delete task;
            size_t force_flush_chunk_num = 0;
            // Thread 0 flushes one extra block at shutdown.
            if (work_thread->thread_id < 1) {
                force_flush_chunk_num = 1;
            }
            work_thread->index_mapped_flush_start_chunkID =
                    work_thread->current_index_mapped_start_chunk + force_flush_chunk_num;
            work_thread->index_mapped_flush_end_chunkID = work_thread->index_mapped_flush_start_chunkID;
            // Flush and free the blocks in [current start, flush start).
            for (size_t i = work_thread->current_index_mapped_start_chunk;
                 i < work_thread->index_mapped_flush_start_chunkID &&
                 work_thread->index_file_memory_block[i] != nullptr; i++) {
                work_thread->index_file_size += INDEX_MAPPED_BLOCK_ALIGNED_SIZE;
                ftruncate(work_thread->index_file_fd, work_thread->index_file_size);
                pwrite(work_thread->index_file_fd, work_thread->index_file_memory_block[i],
                       INDEX_MAPPED_BLOCK_ALIGNED_SIZE, INDEX_MAPPED_BLOCK_ALIGNED_SIZE * i);
                free(work_thread->index_file_memory_block[i]);
                // Do not leave a dangling pointer behind.
                work_thread->index_file_memory_block[i] = nullptr;
            }
            // Make the index file durable.
            fsync(work_thread->index_file_fd);
            // Shrink the pre-sized data file to the bytes actually appended.
            ftruncate(work_thread->data_file_fd, work_thread->data_file_current_size.load());
            // Report completion last: everything above happens-before the
            // readers that observe this increment.
            finish_thread_counter->fetch_add(1);
            break;
        }
        // Regular task: flush any blocks that became full.
        work_thread->doIO(task);
        delete task;
    }
}

asyncfileio_t类

/**
 * Owns the IO_THREAD asynchronous IO workers and the per-thread index/data
 * file descriptors used by the read phases.
 *
 * Write phase: startIOThread() runs ioThreadFunction on one detached thread
 * per worker.
 * End of the write phase: waitFinishIO() sends every worker a shutdown task,
 * waits until all of them report completion, and opens the index files for
 * reading.
 */
class asyncfileio_t {

public:
    // Creates one asyncfileio_thread_t per IO thread; each worker's files are
    // <file_prefix>_<i>.data and <file_prefix>_<i>.idx.
    asyncfileio_t(std::string file_prefix) {
        // Print allocator statistics before and after asking TCMalloc to
        // release free memory.
        malloc_stats();
        MallocExtension::instance()->ReleaseFreeMemory();
        malloc_stats();
        // File name prefix shared by all data/index files.
        this->file_prefix = file_prefix;
        // Counts IO threads that have finished their shutdown work.
        finish_thread_counter = new atomic<int>(0);
        // Rendezvous barrier for waitFinishIO(), sized SEND_THREAD_NUM.
        barrier = new Barrier(SEND_THREAD_NUM);
        // One state object per IO thread.
        for (int i = 0; i < IO_THREAD; i++) {
            work_threads_object[i] = new asyncfileio_thread_t(i, file_prefix);
        }
    }

    // Starts one ioThreadFunction thread per worker.
    void startIOThread() {
        for (int i = 0; i < IO_THREAD; i++) {
            work_threads_handle[i] = std::thread(ioThreadFunction, work_threads_object[i]);
            // Detached: the handles are never joined.
            work_threads_handle[i].detach();
        }
    }

    /**
     * Ends the write phase; called at the start of queue_store::doPhase2().
     *
     * Sequence:
     *  1. First barrier round: the last of SEND_THREAD_NUM callers enqueues a
     *     shutdown task (global_offset == -1) for every IO thread.
     *  2. Every caller busy-waits until all IO_THREAD workers have incremented
     *     finish_thread_counter.
     *  3. Second barrier round: the index files are opened read-only, data fds
     *     and index sizes are recorded, and random access is hinted.
     * Afterwards ioFinished makes later calls return immediately.
     *
     * NOTE(review): ioFinished is a plain bool shared between threads; the
     * barrier participant count (SEND_THREAD_NUM) must match the number of
     * threads that call this — presumably the CHECK_THREAD_NUM check threads;
     * confirm.
     */
    void waitFinishIO(int tid) {
        // Only the first call per process does any work.
        if (!ioFinished) {
            if (tid == 0) {
                printf("in wait_flush function %ld\n", getCurrentTimeInMS());
            }
            // Round 1: once all participants have arrived, tell every IO
            // thread to shut down.
            barrier->Wait([this] {
                printf("start send flush cmd %ld\n", getCurrentTimeInMS());
                // The shutdown task is queued behind any pending flush tasks.
                for (int i = 0; i < IO_THREAD; i++) {
                    asyncio_task_t *task = new asyncio_task_t(-1);
                    work_threads_object[i]->blockingQueue->put(task);
                }
                printf("after send flush cmd %ld\n", getCurrentTimeInMS());
            });
            if (tid == 0) {
                printf("before wait flush finish %ld\n", getCurrentTimeInMS());
            }
            // Busy-wait (no sleep or yield) until every IO thread is done.
            while (finish_thread_counter->load() < IO_THREAD) {};
            if (tid == 0) {
                printf("after wait flush finish %ld\n", getCurrentTimeInMS());
            }
            // Round 2: the last participant prepares the read-side descriptors.
            barrier->Wait([this] {
                 //malloc_stats();
                // Ask TCMalloc to release free memory before the read phase.
                MallocExtension::instance()->ReleaseFreeMemory();
                //malloc_stats();
                // Set up the read-side descriptors for every IO thread.
                for (int i = 0; i < IO_THREAD; i++) {
                    // Reopen the index file read-only (doPhase3 later replaces
                    // this descriptor with an O_DIRECT one).
                    string tmp_str = file_prefix + "_" + std::to_string(i) + ".idx";
                    index_fds[i] = open(tmp_str.c_str(), O_RDONLY, S_IRUSR | S_IWUSR);
                    data_fds[i] = work_threads_object[i]->data_file_fd;
                    mapped_index_files_length[i] = work_threads_object[i]->index_file_size;
                    // Hint random access to the index file (on Linux this
                    // disables readahead rather than enabling it).
                    int ret = posix_fadvise(index_fds[i], 0,
                                            work_threads_object[i]->index_file_size,
                                            POSIX_FADV_RANDOM);
                    printf("ret %d\n", ret);
                }

            });
            if (tid == 0) {
                printf("finish wait_flush function %ld\n", getCurrentTimeInMS());
            }
            // Later calls skip all of the above.
            ioFinished = true;
        }
    }
    /**
     * Starts a detached thread running flush_all_func(this), which writes the
     * index blocks still held in memory to disk.
     *
     * NOTE(review): the detached thread keeps using `this` after the destructor
     * returns, when the object's storage may already be released — a potential
     * use-after-free. This presumably relies on the process exiting first;
     * confirm the intent.
     */
    ~asyncfileio_t() {
        printf("f\n");
        std::thread flush_thread(flush_all_func, this);
        flush_thread.detach();
    }

    // Prefix of all data/index file names.
    string file_prefix;
    // Per-IO-thread state objects.
    asyncfileio_thread_t *work_threads_object[IO_THREAD];
    // Thread handles (detached in startIOThread()).
    std::thread work_threads_handle[IO_THREAD];
    // Index file size of each IO thread, recorded at the end of the write phase.
    size_t mapped_index_files_length[IO_THREAD];
    // Read-side index file descriptors.
    int index_fds[IO_THREAD];
    // Data file descriptors (shared with the IO threads).
    int data_fds[IO_THREAD];

};

void flush_all_func(void *args) {
    //转换this指针类型
    asyncfileio_t *asyncfileio = (asyncfileio_t *) args;
    //循环回收IO线程
    for (int tid = 0; tid < IO_THREAD; tid++) {
        asyncfileio_thread_t *work_thread = asyncfileio->work_threads_object[tid];
        //逐个CHUNK进行回收
        for (size_t i = work_thread->index_mapped_flush_start_chunkID;
             i < MAX_MAPED_CHUNK_NUM && work_thread->index_file_memory_block[i] != nullptr; i++) {
            //更新回收队列队尾位置
            work_thread->index_mapped_flush_end_chunkID++;
            //计算块地址对齐后的文件大小
            work_thread->index_file_size += INDEX_MAPPED_BLOCK_ALIGNED_SIZE;
            //修改文件大小
            ftruncate(work_thread->index_file_fd, work_thread->index_file_size);
            //写入硬盘
            pwrite(work_thread->index_file_fd, work_thread->index_file_memory_block[i],
                   INDEX_MAPPED_BLOCK_ALIGNED_SIZE, INDEX_MAPPED_BLOCK_ALIGNED_SIZE * i);
            //回收内存缓冲区
            free(work_thread->index_file_memory_block[i]);
        }
        //阻塞等待系统完成落盘操作
        fsync(work_thread->index_file_fd);
    }
}

4. Barrier实现

/**
 * Reusable thread barrier (a rendezvous point, not a memory fence).
 *
 * Wait() blocks until `iCount` threads have called it in the current round.
 * The last thread to arrive starts the next round and runs `func` while
 * holding the internal mutex. It then wakes the others, so every waiter
 * returns only after `func` has completed.
 */
class Barrier {
public:
    explicit Barrier(std::size_t iCount)
            : mThreshold(iCount), mCount(iCount), mGeneration(0) {}

    void Wait(std::function<void()> func) {
        std::unique_lock<std::mutex> guard{mMutex};
        const std::size_t arrivedInRound = mGeneration;
        --mCount;
        if (mCount != 0) {
            // Not the last arrival: sleep until the round number changes.
            mCond.wait(guard, [this, arrivedInRound] { return mGeneration != arrivedInRound; });
            return;
        }
        // Last arrival: open the next round, run the callback, release everyone.
        ++mGeneration;
        mCount = mThreshold;
        func();
        mCond.notify_all();
    }

private:
    // Guards all counters below.
    std::mutex mMutex;
    // Signalled when a round completes.
    std::condition_variable mCond;
    // Participants per round.
    std::size_t mThreshold;
    // Arrivals still missing in the current round.
    std::size_t mCount;
    // Number of completed rounds.
    std::size_t mGeneration;
};

5. fast_base64实现

serialization.cpp中的方法

该文件包含了对消息进行序列化和反序列化的函数。序列化函数传入消息及其长度,首先计算需要压缩部分的长度和偏移量,然后拷贝不压缩的固定部分并填充补齐字符,再根据编译时是否支持AVX2指令集选择fast_avx2_base64_decode或chromium_base64_decode完成压缩,最后返回序列化后的数据长度;反序列化函数执行逆操作,返回一个新分配的、还原后的消息缓冲区指针。

Base64序列化

/**
 * Compresses a message by base64-*decoding* its text part: every 4 input
 * characters become 3 bytes, hence "decoding" in the name.
 *
 * Output layout in `serialized`:
 *   [decoded text: estimated_length bytes][fixed tail: FIXED_PART_LEN bytes]
 *   [1 byte: number of padding characters added]
 *
 * Side effect: up to 3 padding characters are written into `message` itself,
 * starting at offset len - FIXED_PART_LEN. This overwrites the start of the
 * fixed tail, which has already been copied out.
 * Assumes len >= FIXED_PART_LEN (the subtraction is unsigned) and that the
 * text part consists of base64 alphabet characters.
 *
 * @return total number of bytes written to `serialized`
 */
int serialize_base64_decoding(uint8_t *message, uint16_t len, uint8_t *serialized) {
    // Length of the text part that gets compressed. The original note says
    // this is 58 - 10 = 48 for the benchmark's messages, a multiple of 4, so no
    // padding is needed there.
    auto serialize_len = len - FIXED_PART_LEN;
    // Characters needed to round the text part up to a multiple of 4.
    int padding_chars = (4 - serialize_len % 4) % 4;
    uint8_t *buf = message;
    // Decoded size: 3 bytes per (possibly padded) group of 4 characters.
    size_t estimated_length = 3 * (serialize_len / 4 + (serialize_len % 4 == 0 ? 0 : 1));
    // Copy the uncompressed fixed tail first: the padding below overwrites it
    // in `message`, so this order matters.
    memcpy(serialized + estimated_length, message + serialize_len, FIXED_PART_LEN);
    // Pad the text part with characters taken from "BLINK" (valid base64).
    // attention: add padding chars, assume following chars enough >= 3
    memcpy(message + serialize_len, "BLINK", padding_chars);
// Use the AVX2 decoder when the build targets AVX2, otherwise the portable one.
#ifdef __AVX2__
    fast_avx2_base64_decode(reinterpret_cast<char *>(serialized),
                                            reinterpret_cast<const char *>(buf),
                                            serialize_len + padding_chars);
#else
    // Decode only the text part plus its padding.
    chromium_base64_decode(reinterpret_cast<char *>(serialized),
                           reinterpret_cast<const char *>(buf),
                           serialize_len + padding_chars);
#endif
    // Trailing byte: padding count, so deserialization can drop the padding.
    serialized[estimated_length + FIXED_PART_LEN] = padding_chars;
    return estimated_length + FIXED_PART_LEN + 1;
}

Base64反序列化

序列化的逆操作,这里不再赘述。

/**
 * Inverse of serialize_base64_decoding(). It base64-encodes the packed bytes
 * back into text, drops the padding characters added during serialization,
 * and appends the uncompressed fixed tail.
 *
 * @param serialized           buffer produced by serialize_base64_decoding()
 * @param total_serialized_len its full length, including the trailing
 *                             padding-count byte
 * @param len                  [out] length of the restored message
 * @return newly allocated (new[]) buffer with the restored message; the
 *         caller owns it
 */
uint8_t *deserialize_base64_encoding(const uint8_t *serialized, uint16_t total_serialized_len, int &len) {
    const uint8_t padding = serialized[total_serialized_len - 1];
    auto payload_len = total_serialized_len - FIXED_PART_LEN - 1;
    auto *restored = new uint8_t[total_serialized_len / 3 * 4 + 16];

#ifdef __AVX2__
    size_t encoded = fast_avx2_base64_encode(reinterpret_cast<char *>(restored),
                                             reinterpret_cast<const char *>(serialized), payload_len);
#else
    size_t encoded = chromium_base64_encode(reinterpret_cast<char *>(restored),
                                            reinterpret_cast<const char *>(serialized), payload_len);
#endif
    // The fixed tail goes right after the text, overwriting the padding chars.
    const size_t text_len = encoded - padding;
    memcpy(restored + text_len, serialized + payload_len, FIXED_PART_LEN);
    len = text_len + FIXED_PART_LEN;
    return restored;
}

Base64序列化(省略索引)

serialization.hpp中的宏定义如下:

// Fixed (not base64-compressed) tail of every message, in this order:
// a 2-byte info field, the 4-byte message index, and a 4-byte verify field.
#define BASE64_INFO_LEN (2u)
// The index is omitted by the *_skip_index serializers and re-inserted on read.
#define INDEX_LEN (4u)
#define VARYING_VERIFY_LEN (4u)
// Total tail length, derived from its parts so the sizes cannot drift apart (= 10u).
#define FIXED_PART_LEN (BASE64_INFO_LEN + INDEX_LEN + VARYING_VERIFY_LEN)

很明显:

FIXED_PART_LEN = BASE64_INFO_LEN + INDEX_LEN + VARYING_VERIFY_LEN 

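而其中INDEX_LEN对应的索引字段就是消息在队列中的偏移量,读取时可以根据queue_offset重新计算得出(doPhase2/doPhase3把queue_offset + i作为idx传给deserialize_*_add_index),所以serialize_base64_decoding_skip_index方法省略了INDEX_LEN部分的存储,每条消息少存4个字节。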

// Skip index =================================================================================================
/**
 * Same as serialize_base64_decoding(), but the INDEX_LEN-byte index field of
 * the fixed tail is not stored. Readers re-insert it from the queue offset.
 *
 * Output layout in `serialized`:
 *   [decoded text: estimated_length bytes][info: BASE64_INFO_LEN bytes]
 *   [verify: VARYING_VERIFY_LEN bytes][1 byte: number of padding characters]
 *
 * Side effect: up to 3 padding characters are written into `message` at
 * offset len - FIXED_PART_LEN, after the tail fields have been copied out.
 * Assumes len >= FIXED_PART_LEN (the subtraction is unsigned).
 *
 * @return total number of bytes written (estimated_length + 7 with the
 *         current constants)
 */
int serialize_base64_decoding_skip_index(uint8_t *message, uint16_t len, uint8_t *serialized) {
    // Length of the compressed text part.
    auto serialize_len = len - FIXED_PART_LEN;
    // Characters needed to round the text part up to a multiple of 4.
    int padding_chars = (4 - serialize_len % 4) % 4;
    uint8_t *buf = message;

    // Decoded size: 3 bytes per (possibly padded) group of 4 characters.
    size_t estimated_length = 3 * (serialize_len / 4 + (serialize_len % 4 == 0 ? 0 : 1));
    // Copy the info field, then the verify field, skipping the index between them.
    memcpy(serialized + estimated_length, message + serialize_len, BASE64_INFO_LEN);
    memcpy(serialized + estimated_length + BASE64_INFO_LEN, message + serialize_len + BASE64_INFO_LEN + INDEX_LEN,
           VARYING_VERIFY_LEN);
    // Padding must come after the copies above, since it overwrites the tail in `message`.
    // attention: add padding chars, assume following chars enough >= 3
    memcpy(message + serialize_len, "BLINK", padding_chars);

// Use the AVX2 decoder when the build targets AVX2, otherwise the portable one.
#ifdef __AVX2__
    fast_avx2_base64_decode(reinterpret_cast<char *>(serialized),
                                            reinterpret_cast<const char *>(buf),
                                            serialize_len + padding_chars);
#else

    chromium_base64_decode(reinterpret_cast<char *>(serialized),
                           reinterpret_cast<const char *>(buf),
                           serialize_len + padding_chars);
#endif
    // Trailing byte: padding count, placed after the 6 stored tail bytes.
    serialized[estimated_length + FIXED_PART_LEN - INDEX_LEN] = padding_chars;
    return estimated_length + FIXED_PART_LEN - INDEX_LEN + 1;
}

Base64反序列化(填充索引)

uint8_t *deserialize_base64_encoding_add_index(const uint8_t *serialized, uint16_t total_serialized_len,
                                               int &deserialized_len, int32_t idx) {
    // Inverse of serialize_base64_decoding_skip_index.
    // `serialized` layout: [packed bytes][BASE64_INFO][VARYING_VERIFY][padding count].
    // Re-encodes the packed bytes as base64 text, drops the recorded number of padding
    // chars, then rebuilds the fixed tail with `idx` inserted as the INDEX field.
    // Returns a new[]-allocated buffer (caller must delete[] it); its length is stored
    // in `deserialized_len`.
    const auto packed_len = total_serialized_len - (FIXED_PART_LEN - INDEX_LEN) - 1;
    uint8_t *const out = new uint8_t[total_serialized_len / 3 * 4 + 16];

#ifdef __AVX2__
    const size_t encoded_len = fast_avx2_base64_encode(reinterpret_cast<char *>(out),
                                                       reinterpret_cast<const char *>(serialized), packed_len);
#else
    const size_t encoded_len = chromium_base64_encode(reinterpret_cast<char *>(out),
                                                      reinterpret_cast<const char *>(serialized), packed_len);
#endif

    const uint8_t pad_count = serialized[total_serialized_len - 1];
    const uint8_t *const tail_src = serialized + packed_len;
    uint8_t *cursor = out + (encoded_len - pad_count);

    memcpy(cursor, tail_src, BASE64_INFO_LEN);
    cursor += BASE64_INFO_LEN;
    memcpy(cursor, &idx, sizeof(int32_t));
    cursor += INDEX_LEN;
    memcpy(cursor, tail_src + BASE64_INFO_LEN, VARYING_VERIFY_LEN);

    deserialized_len = encoded_len - pad_count + FIXED_PART_LEN;
    return out;
}

Base64反序列化到调用方提供的缓冲区(原地版本,不分配内存,填充索引)

void deserialize_base64_encoding_add_index_in_place(const uint8_t *serialized, uint16_t total_serialized_len,
                                                    uint8_t *deserialized, int &deserialized_len, int32_t idx) {
    // Same reconstruction as deserialize_base64_encoding_add_index, but the result is
    // written into the caller-provided `deserialized` buffer instead of a new allocation.
    // The buffer must hold the base64 text of the packed bytes plus FIXED_PART_LEN bytes
    // (the scalar encoder also writes a '\0' right after the text).
    const auto packed_len = total_serialized_len - (FIXED_PART_LEN - INDEX_LEN) - 1;

#ifdef __AVX2__
    const size_t encoded_len = fast_avx2_base64_encode(reinterpret_cast<char *>(deserialized),
                                                       reinterpret_cast<const char *>(serialized), packed_len);
#else
    const size_t encoded_len = chromium_base64_encode(reinterpret_cast<char *>(deserialized),
                                                      reinterpret_cast<const char *>(serialized), packed_len);
#endif

    // Last byte of the record: number of padding chars to drop from the re-encoded text.
    const uint8_t pad_count = serialized[total_serialized_len - 1];
    const uint8_t *const tail_src = serialized + packed_len;
    uint8_t *cursor = deserialized + (encoded_len - pad_count);

    memcpy(cursor, tail_src, BASE64_INFO_LEN);
    cursor += BASE64_INFO_LEN;
    memcpy(cursor, &idx, sizeof(int32_t));
    cursor += INDEX_LEN;
    memcpy(cursor, tail_src + BASE64_INFO_LEN, VARYING_VERIFY_LEN);

    deserialized_len = encoded_len - pad_count + FIXED_PART_LEN;
}
// End of Skip index ========================================================================================

Base36序列化(省略索引)

// ------------------------------- Begin of Base36 -------------------------------------------------------------
int serialize_base36_decoding_skip_index(uint8_t *message, uint16_t len, uint8_t *serialized) {
    // Packs one message as
    //   [varying part, base64-decoded][BASE64_INFO][VARYING_VERIFY]
    // Same as the Base64 variant except that no padding-count byte is appended; the
    // deserializer instead strips the re-encoded padding chars by their letter range.
    // Side effect: up to 3 padding chars are written into `message` right after the
    // varying part. Returns the number of bytes written to `serialized`.
    const auto text_len = len - FIXED_PART_LEN;
    const int pad_count = (4 - text_len % 4) % 4;
    const size_t packed_len = 3 * (text_len / 4 + (text_len % 4 != 0 ? 1 : 0));

    const uint8_t *fixed_src = message + text_len;
    uint8_t *fixed_dst = serialized + packed_len;

    // Save the kept fixed-part fields first: the padding below overwrites them in `message`.
    memcpy(fixed_dst, fixed_src, BASE64_INFO_LEN);
    memcpy(fixed_dst + BASE64_INFO_LEN, fixed_src + BASE64_INFO_LEN + INDEX_LEN, VARYING_VERIFY_LEN);

    // Round the text up to whole 4-char groups with valid base64 chars
    // (assumes at least 3 writable bytes follow the varying part).
    memcpy(message + text_len, "BLINK", pad_count);

#ifdef __AVX2__
    fast_avx2_base64_decode(reinterpret_cast<char *>(serialized),
                            reinterpret_cast<const char *>(message),
                            text_len + pad_count);
#else
    chromium_base64_decode(reinterpret_cast<char *>(serialized),
                           reinterpret_cast<const char *>(message),
                           text_len + pad_count);
#endif

    return packed_len + FIXED_PART_LEN - INDEX_LEN;
}

Base36反序列化(填充索引)

/**
 * Inverse of serialize_base36_decoding_skip_index.
 *
 * Base64-encodes the packed varying part back to text and strips the trailing padding
 * characters added by the serializer. It then re-inserts the 4-byte index `idx` between
 * the BASE64_INFO and VARYING_VERIFY fields.
 *
 * @param serialized            record laid out as [packed bytes][BASE64_INFO][VARYING_VERIFY]
 * @param total_serialized_len  total length of `serialized` in bytes
 * @param deserialized_len      out: length of the returned message in bytes
 * @param idx                   value written back as the INDEX field
 * @return buffer allocated with new[]; the caller must delete[] it
 *
 * No padding count is stored for this variant. The serializer pads with chars from
 * "BLINK", which come back as uppercase letters after re-encoding, so padding is
 * recovered by stripping trailing 'A'-'Z'. This assumes the original varying part never
 * ends with an uppercase letter (A-Z is not supported yet).
 */
uint8_t *deserialize_base36_encoding_add_index(const uint8_t *serialized, uint16_t total_serialized_len,
                                               int &deserialized_len, int32_t idx) {
    auto serialize_len = total_serialized_len - (FIXED_PART_LEN - INDEX_LEN);
    auto *deserialized = new uint8_t[total_serialized_len / 3 * 4 + 16];
    // 1st: deserialize preparation: base64 encoding
#ifdef __AVX2__
    size_t length = fast_avx2_base64_encode(reinterpret_cast<char *>(deserialized),
                                            reinterpret_cast<const char *>(serialized), serialize_len);
#else
    size_t length = chromium_base64_encode(reinterpret_cast<char *>(deserialized),
                                           reinterpret_cast<const char *>(serialized), serialize_len);

#endif
    // 2nd: strip the trailing padding chars (uppercase letters, see above).
    // `length` is a size_t, so the bound must be checked BEFORE reading
    // deserialized[length - 1]. The old trailing `length >= 0` test was always true and
    // let an all-uppercase result index before the start of the buffer.
    while (length > 0 && deserialized[length - 1] >= 'A' && deserialized[length - 1] <= 'Z') {
        --length;
    }

    // 3rd: append the fixed tail with the index re-inserted
    size_t offset = length;
    memcpy(deserialized + offset, serialized + serialize_len, BASE64_INFO_LEN);
    offset += BASE64_INFO_LEN;
    memcpy(deserialized + offset, &idx, sizeof(int32_t));
    offset += INDEX_LEN;
    memcpy(deserialized + offset, serialized + serialize_len + BASE64_INFO_LEN, VARYING_VERIFY_LEN);

    // 4th: assign the correct length
    deserialized_len = length + FIXED_PART_LEN;
    return deserialized;
}
// ------------------------------ End of Base36, do not support A-Z yet --------------------------------------------

通用序列化

// Compresses a message whose varying part uses only 'a'..'z' and '0'..'9'.
//
// Output layout: [length header (1 or 2 bytes)][5-bit codes][3-bit escape extensions]
//                [FIXED_PART_LEN bytes copied verbatim]
// Each char is mapped to a value 0..35 ('a'..'z' -> 0..25, '0'..'9' -> 26..35). Values
// below MAX_FIVE_BITS_INT are stored as one 5-bit code. Larger values store the code
// MAX_FIVE_BITS_INT plus a 3-bit extension (value - MAX_FIVE_BITS_INT), packed after all
// 5-bit codes. NOTE(review): MAX_FIVE_BITS_INT is defined elsewhere, presumably 31 (0x1f),
// given how deserialize() masks with it — confirm.
// Returns the number of bytes written to `serialized`.
// Assumes len >= FIXED_PART_LEN; chars outside the two ranges are not round-tripped.
int serialize(uint8_t *message, uint16_t len, uint8_t *serialized) {
    // add the header to indicate raw message varying-length part size
    int serialize_len = len - FIXED_PART_LEN;
    if (len < 128) {
        // one byte, high bit clear (value < 118)
        serialized[0] = static_cast<uint8_t>(len - FIXED_PART_LEN);
        serialized += 1;
    } else {
        // two bytes: high bit set + upper 7 bits, then lower 7 bits
        uint16_t tmp = (len - FIXED_PART_LEN);
        serialized[0] = static_cast<uint8_t>((tmp >> 7u) | 0x80); // assumes tmp < 16384: only 7 + 7 bits are kept
        serialized[1] = static_cast<uint8_t>(tmp & (uint8_t) 0x7f);
        serialized += 2;
    }
    // escape extensions start right after the n 5-bit codes (5 * n bits)
    uint32_t next_extra_3bits_idx = 5u * serialize_len;
    uint32_t next_5bits_idx = 0;

    // attention: message is not usable later (its varying part is overwritten with 0..35 values)
    for (int i = 0; i < serialize_len; i++) {
        message[i] = message[i] >= 'a' ? message[i] - 'a' : message[i] - '0' + (uint8_t) 26;
    }
    // attention: must clear to be correct (bits below are OR-ed in)
    memset(serialized, 0, (len - FIXED_PART_LEN));
    // 1) construct the compressed part
    for (int i = 0; i < serialize_len; i++) {
        uint16_t cur_uchar = message[i];
        // place the 5-bit code in the top 5 bits of a 16-bit window
        uint16_t expand_uchar = cur_uchar < MAX_FIVE_BITS_INT ? (cur_uchar << 11u) : (MAX_FIVE_BITS_INT << 11u);

        int shift_bits = (next_5bits_idx & 0x7u);
        expand_uchar >>= shift_bits;
        int idx = (next_5bits_idx >> 3u);
        serialized[idx] |= (expand_uchar >> 8u);
        serialized[idx + 1] |= (expand_uchar & 0xffu);
        next_5bits_idx += 5;

        if (cur_uchar >= MAX_FIVE_BITS_INT) {
            // do extra bits operations: 3-bit extension in the top 3 bits of a 16-bit window
            expand_uchar = ((cur_uchar - MAX_FIVE_BITS_INT) << 13u);
            shift_bits = (next_extra_3bits_idx & 0x7u);
            expand_uchar >>= shift_bits;
            // bytes are written high-then-low explicitly, so host byte order does not matter
            idx = (next_extra_3bits_idx >> 3u);
            serialized[idx] |= (expand_uchar >> 8u);
            serialized[idx + 1] |= (expand_uchar & 0xffu);
            next_extra_3bits_idx += 3;
        }
    }

    // 2) copy the FIXED_PART_LEN tail verbatim, starting at the first byte after the last used bit
    int start_copy_byte_idx = (next_extra_3bits_idx >> 3u) + ((next_extra_3bits_idx & 0x7u) != 0);
    memcpy(serialized + start_copy_byte_idx, message + serialize_len, FIXED_PART_LEN);
    return start_copy_byte_idx + FIXED_PART_LEN + (len < 128 ? 1 : 2);
}

通用反序列化

/**
 * Inverse of serialize().
 *
 * Unpacks the 5-bit codes (plus 3-bit escape extensions) back into 'a'..'z' / '0'..'9'
 * characters and appends the FIXED_PART_LEN-byte fixed part.
 *
 * @param serialized  record produced by serialize() (length header first)
 * @param len         out: total length of the returned message in bytes
 * @return buffer allocated with new[] holding `len` bytes; the caller must delete[] it
 */
uint8_t *deserialize(const uint8_t *serialized, int &len) {
    // Length header: one byte when the high bit is clear, otherwise two bytes
    // carrying 7 + 7 bits (upper bits first).
    uint16_t varying_byte_len;
    if ((serialized[0] & 0x80u) == 0) {
        varying_byte_len = serialized[0];
        serialized += 1;
    } else {
        varying_byte_len = static_cast<uint16_t>(((serialized[0] & 0x7fu) << 7u) + serialized[1]);
        serialized += 2;
    }
    // Escape extensions are stored after all 5-bit codes (which take 5 * n bits).
    uint32_t next_extra_3bits_idx = 5u * varying_byte_len;
    uint32_t next_5bits_idx = 0;

    // The varying part is followed by the FIXED_PART_LEN-byte fixed part (memcpy below),
    // so the buffer needs varying_byte_len + FIXED_PART_LEN bytes. It used to be
    // `varying_byte_len + 8`, which that copy overran by 2 bytes.
    auto *deserialized = new uint8_t[varying_byte_len + FIXED_PART_LEN];
    len = varying_byte_len + FIXED_PART_LEN;
    // 1) decode the varying part
    for (int i = 0; i < varying_byte_len; i++) {
        int idx = (next_5bits_idx >> 3u);
        // 16-bit window starting at byte idx; the code sits at bit offset next_5bits_idx % 8
        uint16_t value = (serialized[idx] << 8u) + serialized[idx + 1];
        value = (value >> (11u - (next_5bits_idx & 07u))) & MAX_FIVE_BITS_INT;
        if (value != MAX_FIVE_BITS_INT) {
            deserialized[i] = static_cast<uint8_t>(value < 26 ? 'a' + value : value - 26 + '0');
        } else {
            // escape code: the next 3-bit extension holds (char - '5')
            idx = (next_extra_3bits_idx >> 3u);
            value = (serialized[idx] << 8u) + serialized[idx + 1];
            value = (value >> (13u - (next_extra_3bits_idx & 07u))) & (uint8_t) 0x7;
            deserialized[i] = value + '5';
            next_extra_3bits_idx += 3;
        }
        next_5bits_idx += 5;
    }

    // 2) copy the fixed part, which starts at the first byte after the last extension bit
    memcpy(deserialized + varying_byte_len,
           serialized + (next_extra_3bits_idx >> 3u) + ((next_extra_3bits_idx & 0x7u) != 0), FIXED_PART_LEN);
    return deserialized;
}

chromiumbase64.c中的方法

这里应该是参考了 Google Chromium 浏览器中的 base64 编解码实现(base64 只是一种编码,并不是加密),采用查表法实现,速度很快。

/*
 * Error sentinel for the decode tables: chromium_base64_decode treats a combined
 * d0|d1|d2|d3 lookup value >= BADCHAR as invalid input. NOTE(review): presumably the
 * tables map every non-base64 char to a value >= BADCHAR; the tables are defined elsewhere.
 */
#define BADCHAR 0x01FFFFFF

/**
 * You can control whether padding is used by commenting out the
 * next line. However, padding is highly recommended; not using it
 * should only be for compatibility with a 3rd party.
 * Also, 'no padding' is not tested!
 * With DOPAD defined, the decoder only accepts input whose length is a
 * multiple of 4 and strips up to two trailing CHARPAD chars.
 */
#define DOPAD 1

/*
 * if we aren't doing padding
 * set the pad character to NUL ('\0')
 */
#ifndef DOPAD
#undef CHARPAD
#define CHARPAD '\0'
#endif

size_t chromium_base64_encode(char* dest, const char* str, size_t len)
{
    size_t i = 0;
    uint8_t* p = (uint8_t*) dest;

    /* unsigned here is important! */
    uint8_t t1, t2, t3;

    if (len > 2) {
        for (; i < len - 2; i += 3) {
            t1 = str[i]; t2 = str[i+1]; t3 = str[i+2];
            *p++ = e0[t1];
            *p++ = e1[((t1 & 0x03) << 4) | ((t2 >> 4) & 0x0F)];
            *p++ = e1[((t2 & 0x0F) << 2) | ((t3 >> 6) & 0x03)];
            *p++ = e2[t3];
        }
    }

    switch (len - i) {
    case 0:
        break;
    case 1:
        t1 = str[i];
        *p++ = e0[t1];
        *p++ = e1[(t1 & 0x03) << 4];
        *p++ = CHARPAD;
        *p++ = CHARPAD;
        break;
    default: /* case 2 */
        t1 = str[i]; t2 = str[i+1];
        *p++ = e0[t1];
        *p++ = e1[((t1 & 0x03) << 4) | ((t2 >> 4) & 0x0F)];
        *p++ = e2[(t2 & 0x0F) << 2];
        *p++ = CHARPAD;
    }

    *p = '\0';
    return p - (uint8_t*)dest;
}


/*
 * Decode base64 text `src` (`len` chars) into `dest` using the d0..d3 lookup tables.
 * Returns the number of bytes written, 0 for len == 0, or MODP_B64_ERROR when:
 *   - DOPAD is defined and len is < 4 or not a multiple of 4, or
 *   - any lookup result is >= BADCHAR.
 * The output is not NUL-terminated. On an error in the final partial group,
 * bytes may already have been written to `dest`.
 * Each 4-char group is combined into a uint32_t `x`, and the output bytes are read from
 * x's memory representation. NOTE(review): this relies on the tables (defined elsewhere)
 * being built for the host byte order — confirm.
 */
size_t chromium_base64_decode(char* dest, const char* src, size_t len)
{
    if (len == 0) return 0;

#ifdef DOPAD
    /*
     * if padding is used, then the message must be at least
     * 4 chars and be a multiple of 4
     */
    if (len < 4 || (len % 4 != 0)) {
      return MODP_B64_ERROR; /* error */
    }
    /* there can be at most 2 pad chars at the end */
    if (src[len-1] == CHARPAD) {
        len--;
        if (src[len -1] == CHARPAD) {
            len--;
        }
    }
#endif

    size_t i;
    int leftover = len % 4;
    /* a trailing full group (leftover == 0) is left for the switch below */
    size_t chunks = (leftover == 0) ? len / 4 - 1 : len /4;

    uint8_t* p = (uint8_t*)dest;
    uint32_t x = 0;
    const uint8_t* y = (uint8_t*)src;
    for (i = 0; i < chunks; ++i, y += 4) {
        x = d0[y[0]] | d1[y[1]] | d2[y[2]] | d3[y[3]];
        if (x >= BADCHAR) return MODP_B64_ERROR;
        *p++ =  ((uint8_t*)(&x))[0];
        *p++ =  ((uint8_t*)(&x))[1];
        *p++ =  ((uint8_t*)(&x))[2];
    }

    switch (leftover) {
    case 0:
        x = d0[y[0]] | d1[y[1]] | d2[y[2]] | d3[y[3]];

        if (x >= BADCHAR) return MODP_B64_ERROR;
        *p++ =  ((uint8_t*)(&x))[0];
        *p++ =  ((uint8_t*)(&x))[1];
        *p =    ((uint8_t*)(&x))[2];
        return (chunks+1)*3;
        break; /* unreachable */
    case 1:  /* with padding this is an impossible case (len % 4 was 0, at most 2 chars stripped) */
        x = d0[y[0]];
        *p = *((uint8_t*)(&x)); // i.e. first char/byte in int
        break;
    case 2: /* case 2: 1 output byte */
        x = d0[y[0]] | d1[y[1]];
        *p = *((uint8_t*)(&x)); // i.e. first char
        break;
    default: /* case 3, 2 output bytes */
        x = d0[y[0]] | d1[y[1]] | d2[y[2]];  /* 0x3c */
        *p++ =  ((uint8_t*)(&x))[0];
        *p =  ((uint8_t*)(&x))[1];
        break;
    }

    if (x >= BADCHAR) return MODP_B64_ERROR;

    /* leftover 2 -> 1 byte, 3 -> 2 bytes, 1 -> 0 bytes */
    return 3*chunks + (6*leftover)/8;
}

fastavxbase64.c

这是整份代码中的最大亮点:使用 AVX2 指令集实现 base64 编解码。各条指令的具体说明可以参考 Intel Intrinsics Guide:https://software.intel.com/en-us/node/524017?language=ru

#include "fastavxbase64.h"

#include <x86intrin.h>
#include <stdbool.h>

/**
* This code borrows from Wojciech Mula's library at
* https://github.com/WojciechMula/base64simd (published under BSD)
* as well as code from Alfred Klomp's library https://github.com/aklomp/base64 (published under BSD)
*
*/

/**
* Note : Hardware such as Knights Landing might do poorly with this AVX2 code since it relies on shuffles. Alternatives might be faster.
*/


static inline __m256i enc_reshuffle(const __m256i input) {
    // Spreads 24 source bytes into 32 bytes that each hold one 6-bit base64 index.
    // The shuffle below reads bytes 4..15 of the low 128-bit lane and bytes 0..11 of the
    // high lane, matching the `str - 4` loads in fast_avx2_base64_encode.
    //
    // _mm256_shuffle_epi8(a, b): per-byte table lookup of `a` indexed by `b`, done
    // independently within each 128-bit lane. _mm256_set_epi8 lists bytes from the
    // highest (31) down to the lowest (0). Each 32-bit group receives its 3 source bytes
    // in the order [b1, b0, b2, b1].
    // translation from SSE into AVX2 of procedure
    // https://github.com/WojciechMula/base64simd/blob/master/encode/unpack_bigendian.cpp
    const __m256i in = _mm256_shuffle_epi8(input, _mm256_set_epi8(
        10, 11,  9, 10,
         7,  8,  6,  7,
         4,  5,  3,  4,
         1,  2,  0,  1,

        14, 15, 13, 14,
        11, 12, 10, 11,
         8,  9,  7,  8,
         5,  6,  4,  5
    ));
    // _mm256_set1_epi32: broadcast one 32-bit constant to all 8 elements.
    // _mm256_and_si256: bitwise AND (selects two of the four 6-bit fields per group).
    // _mm256_mulhi_epu16: unsigned 16-bit multiply keeping the HIGH 16 bits of each
    //   product; with power-of-two factors it acts as a per-element right shift.
    // _mm256_mullo_epi16: 16-bit multiply keeping the LOW 16 bits of each product;
    //   with power-of-two factors it acts as a per-element left shift.
    const __m256i t0 = _mm256_and_si256(in, _mm256_set1_epi32(0x0fc0fc00));
    const __m256i t1 = _mm256_mulhi_epu16(t0, _mm256_set1_epi32(0x04000040));

    const __m256i t2 = _mm256_and_si256(in, _mm256_set1_epi32(0x003f03f0));
    const __m256i t3 = _mm256_mullo_epi16(t2, _mm256_set1_epi32(0x01000010));
    // _mm256_or_si256: bitwise OR, merging the two halves into one byte per 6-bit field.
    return _mm256_or_si256(t1, t3);
}

static inline __m256i enc_translate(const __m256i in) {
  // Maps each byte's 6-bit value (0..63) to its base64 ASCII char by adding an offset
  // picked from `lut`:
  //   0..25 -> +65 ('A'..'Z'), 26..51 -> +71 ('a'..'z'), 52..61 -> -4 ('0'..'9'),
  //   62 -> -19 ('+'), 63 -> -16 ('/').
  // _mm256_setr_epi8: bytes listed from the lowest (0) to the highest (31); the first
  //   argument goes to byte 0 (the reverse argument order of _mm256_set_epi8).
  // _mm256_set1_epi8: broadcast one byte to all 32 bytes.
  // _mm256_subs_epu8: unsigned saturating per-byte subtraction (clamps at 0).
  // _mm256_cmpgt_epi8: signed per-byte a > b; 0xFF (-1) where true, 0 otherwise.
  const __m256i lut = _mm256_setr_epi8(
      65, 71, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -19, -16, 0, 0, 65, 71,
      -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -19, -16, 0, 0);
  __m256i indices = _mm256_subs_epu8(in, _mm256_set1_epi8(51));  // 0 for 0..51, 1..12 for 52..63
  __m256i mask = _mm256_cmpgt_epi8((in), _mm256_set1_epi8(25));  // -1 for 26..63
  indices = _mm256_sub_epi8(indices, mask);                      // +1 for 26..63
  __m256i out = _mm256_add_epi8(in, _mm256_shuffle_epi8(lut, indices));
  return out;
}

static inline __m256i dec_reshuffle(__m256i in) {
  // Packs 32 bytes holding 6-bit values (already translated from ASCII) into 24 output
  // bytes in the low 24 bytes of the result; the top 8 bytes end up zero.

  // inlined procedure pack_madd from https://github.com/WojciechMula/base64simd/blob/master/decode/pack.avx2.cpp
  // The only difference is that elements are reversed,
  // only the multiplication constants were changed.

  // Byte pairs (a, b) -> a * 0x40 + b: one 12-bit value per 16-bit element.
  const __m256i merge_ab_and_bc = _mm256_maddubs_epi16(in, _mm256_set1_epi32(0x01400140)); //_mm256_maddubs_epi16 is likely expensive
  // 16-bit pairs (ab, cd) -> ab * 0x1000 + cd: one 24-bit value per 32-bit element.
  __m256i out = _mm256_madd_epi16(merge_ab_and_bc, _mm256_set1_epi32(0x00011000));
  // end of inlined

  // Pack bytes together within 32-bit words, discarding words 3 and 7
  // (indices 2,1,0 emit each 24-bit value high byte first; -1 zeroes a byte):
  out = _mm256_shuffle_epi8(out, _mm256_setr_epi8(
        2, 1, 0, 6, 5, 4, 10, 9, 8, 14, 13, 12, -1, -1, -1, -1,
        2, 1, 0, 6, 5, 4, 10, 9, 8, 14, 13, 12, -1, -1, -1, -1
  ));
  // Move words 4..6 (high lane) down next to words 0..2. The -1 indices use only their
  // low 3 bits and select the zeroed word 7.
  // the call to _mm256_permutevar8x32_epi32 could be replaced by a call to _mm256_storeu2_m128i but it is doubtful that it would help
  return _mm256_permutevar8x32_epi32(
      out, _mm256_setr_epi32(0, 1, 2, 4, 5, 6, -1, -1));
}


/*
 * Base64-encode `len` bytes from `str` into `dest`. Returns the number of chars written,
 * or MODP_B64_ERROR if the scalar tail encoder reports one.
 * While enough input remains, each AVX2 step turns 24 input bytes into 32 chars; the
 * remainder is handled by chromium_base64_encode, which also NUL-terminates.
 * Vector loads start 4 bytes before the current input position (enc_reshuffle expects
 * the source bytes at that offset).
 */
size_t fast_avx2_base64_encode(char* dest, const char* str, size_t len) {
      char* out = dest;
      const char* in = str;
      size_t remaining = len;

      if (remaining >= 32 - 4) {
        // Mask bit clear for the lowest dword: the 4 bytes before `str` are not loaded.
        // Intel docs: masked-off elements of a maskload never cause memory faults.
        const __m256i skip_first_dword = _mm256_set_epi32(
            0x80000000, 0x80000000, 0x80000000, 0x80000000,
            0x80000000, 0x80000000, 0x80000000, 0x00000000);
        __m256i block = _mm256_maskload_epi32((int const*)(in - 4), skip_first_dword);

        for (;;) {
          block = enc_translate(enc_reshuffle(block));
          _mm256_storeu_si256((__m256i *)out, block);
          in += 24;
          out += 32;
          remaining -= 24;
          if (remaining < 32) {
            break;
          }
          // no mask needed from here on: in - 4 lies inside already-consumed input
          block = _mm256_loadu_si256((__m256i *)(in - 4));
        }
      }

      const size_t tail = chromium_base64_encode(out, in, remaining);
      if (tail == MODP_B64_ERROR) return MODP_B64_ERROR;
      return (out - dest) + tail;
}

/*
 * Base64-decode `srclen` chars from `src` into `out`. Returns the number of bytes
 * written, or MODP_B64_ERROR if the scalar tail decoder reports one.
 * While at least 45 chars remain, each AVX2 step validates and decodes 32 chars into
 * 24 bytes, but each store writes 32 bytes. NOTE(review): presumably the 45 (= 32 + 13)
 * threshold guarantees the scalar tail later overwrites those 8 extra bytes — confirm
 * before changing it.
 * A block containing a non-base64 char ends the vector loop, and the rest (including
 * that block) goes to chromium_base64_decode.
 */
size_t fast_avx2_base64_decode(char *out, const char *src, size_t srclen) {
      char* out_orig = out;
      while (srclen >= 45) {

        // The input consists of six character sets in the Base64 alphabet,
        // which we need to map back to the 6-bit values they represent.
        // There are three ranges, two singles, and then there's the rest.
        //
        //  #  From       To        Add  Characters
        //  1  [43]       [62]      +19  +
        //  2  [47]       [63]      +16  /
        //  3  [48..57]   [52..61]   +4  0..9
        //  4  [65..90]   [0..25]   -65  A..Z
        //  5  [97..122]  [26..51]  -71  a..z
        // (6) Everything else => invalid input

        __m256i str = _mm256_loadu_si256((__m256i *)src);

        // code by @aqrit from
        // https://github.com/WojciechMula/base64simd/issues/3#issuecomment-271137490
        // translated into AVX2
        const __m256i lut_lo = _mm256_setr_epi8(
            0x15, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
            0x11, 0x11, 0x13, 0x1A, 0x1B, 0x1B, 0x1B, 0x1A,
            0x15, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
            0x11, 0x11, 0x13, 0x1A, 0x1B, 0x1B, 0x1B, 0x1A
        );
        const __m256i lut_hi = _mm256_setr_epi8(
            0x10, 0x10, 0x01, 0x02, 0x04, 0x08, 0x04, 0x08,
            0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10,
            0x10, 0x10, 0x01, 0x02, 0x04, 0x08, 0x04, 0x08,
            0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10
        );
        const __m256i lut_roll = _mm256_setr_epi8(
            0,   16,  19,   4, -65, -65, -71, -71,
            0,   0,   0,   0,   0,   0,   0,   0,
            0,   16,  19,   4, -65, -65, -71, -71,
            0,   0,   0,   0,   0,   0,   0,   0
        );

        const __m256i mask_2F = _mm256_set1_epi8(0x2f);

        // lookup
        __m256i hi_nibbles  = _mm256_srli_epi32(str, 4);
        __m256i lo_nibbles  = _mm256_and_si256(str, mask_2F);

        const __m256i lo    = _mm256_shuffle_epi8(lut_lo, lo_nibbles);
        const __m256i eq_2F = _mm256_cmpeq_epi8(str, mask_2F);

        hi_nibbles = _mm256_and_si256(hi_nibbles, mask_2F);
        const __m256i hi    = _mm256_shuffle_epi8(lut_hi, hi_nibbles);
        const __m256i roll  = _mm256_shuffle_epi8(lut_roll, _mm256_add_epi8(eq_2F, hi_nibbles));

        // lo & hi non-zero for some byte => invalid char; leave this block to the scalar path
        if (!_mm256_testz_si256(lo, hi)) {
            break;
        }

        // ASCII -> 6-bit values by adding the per-range offset from lut_roll
        str = _mm256_add_epi8(str, roll);
        // end of copied function

        srclen -= 32;
        src += 32;

        // end of inlined function

        // Reshuffle the input to packed 12-byte output format:
        str = dec_reshuffle(str);
        // 32-byte store, of which only the low 24 bytes are payload
        _mm256_storeu_si256((__m256i *)out, str);
        out += 24;
      }
      size_t scalarret = chromium_base64_decode(out, src, srclen);
      if(scalarret == MODP_B64_ERROR) return MODP_B64_ERROR;
      return (out - out_orig) + scalarret;
}

相关文章

网友评论

      本文标题:第四届中间件性能挑战赛冠军代码解析-复赛C代码

      本文链接:https://www.haomeiwen.com/subject/ybbciftx.html