Program Notes for Vision-Based Multi-Resolution Map Construction and Localization

Author: Omar_4321 | Published 2019-01-02 23:09

    1. Installation

    Libraries:

    Install numpy, matplotlib, sklearn, scipy, PIL, opencv, pickle, and pytorch (version 0.4 or higher).
    The code is written on top of CycleGAN and pix2pix in PyTorch.
    The Python version is 3.6 (3.5 and 3.7 also work). The code runs on both Windows and Ubuntu; on Windows you may get errors about pickling the lambda expressions used in the transforms.
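    A common workaround for that Windows error (my assumption, not part of the original code): keep the DataLoader from spawning worker processes on Windows, since it is the worker processes that must pickle the lambda-wrapped transforms.

    import platform

    def safe_num_workers(requested):
        # return 0 on Windows so DataLoader workers never have to pickle the
        # lambda-based transforms; keep the requested worker count elsewhere
        return 0 if platform.system() == 'Windows' else requested

    # e.g. num_workers=safe_num_workers(int(opt.nThreads)) when building the DataLoader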

    File and directory layout:

    C:\CODE_P
    ├─alexnet                intermediate-layer feature visualization
    ├─Checkpoints            saved training models
    ├─Cluster                clustering programs
    │ │ cluster_img.py         cluster the images
    │ │ gen_cell.py            generate cells; mainly the cell post-processing program
    ├─Data                   data preparation module
    │ │ base_dataset.py        base class; preprocessing before feeding the network
    │ │ base_data_loader.py    base class
    │ │ Dataset_gather.py      concrete implementations of the datasets, selected by parameter
    │ │ data_loader.py         data loading class; selects the data according to the parameters
    │ │ Data_manage.py         data management: reading, path generation, etc.
    ├─figure                 saved plots
    ├─label                  all labels
    ├─Models                 network models
    │ │ base_model.py          base class
    │ │ double_threads.py      two-thread example
    │ │ layers_trans.py        replace the character at position p of a string with c; used to batch-rename model layers
    │ │ models.py              model selection by parameter
    │ │ model_set.py           network invocation, optimizer definition, forward/backward passes, and loss computation
    │ │ networks.py            network architecture definitions
    │ │ resnet_layer_trans.txt
    ├─Options
    │ │ options_set.py         parameter definitions
    ├─pre_model_state_dict
    │ │ resnet18-5c106cde.pth  pretrained model
    ├─Result                 results
    ├─runs                   (ignore this folder)
    ├─util                   compute the image mean
    │ │ compute_image_mean.py
    └─Visualization          visualization programs
      └─test

    2. Overall Workflow

    The figure below shows the overall workflow. It consists of three parts: feature extraction, cell construction, and training the localization network and localizing.

    [Figure: program flowchart]

    Main program

    opt = BaseOptions().parse()            # load the configuration parameters
    clu = clusterdata()                    # instantiate the clusterdata class
    datareader = dataread(opt)             # instantiate the data-reading class
    [gps_x,gps_y] = datareader.get_gps()   # read the GPS information of the data
    c = datareader.num_img                 # image count per lane; left/middle/right are c[0], c[1], c[2]
    ll2 = clu.cluster_sequence(length,200) # split the image sequence (length = total image count) evenly into 200 cells
    ll3 = clu.cluster_sequence(length,600) # ... into 600 cells
    ll4 = clu.cluster_sequence(length,900) # ... into 900 cells
    
    # label the images of the three lanes as 0, 1, 2
    three_cla = numpy.zeros(length,dtype=int)
    three_cla[0:c[0]] = 0
    three_cla[c[0]:c[0]+c[1]] = 1
    three_cla[c[0]+c[1]:c[0]+c[1]+c[2]] = 2
    three_l = numpy.array(three_cla)
    #numpy.savetxt('3.txt',three_l,fmt='%d')
    
    # write the evenly split cell labels to txt files
    for ll, name in [(ll2,'seq200'), (ll3,'seq600'), (ll4,'seq900')]:
        with open('label/%s.txt' % name, 'w') as f:
            for j in range(len(img_dir)):
                f.write(str(img_dir[j][37:]) + ' ' + str(int(ll[j])) + '\n')
    # randomly split the data into shuffled training and test sets
    split(opt,'seq200')
    split(opt,'seq600')
    split(opt,'seq900')
    
    train_extract_features(opt.num_outputs) # train the feature extraction network
    extract_features(opt.num_outputs)       # extract image features with the trained network
    label = clu.clu_features(900)           # cluster on the new features
    numpy.save('clu_900.npy',label)         # save the clustering result
    #label = numpy.load('clu_900.npy')      # or load a saved clustering result
    # show the clustering result as a bar chart
    plt.figure(2)
    plt.bar(numpy.arange(len(label)),label,width = 1)
    plt.show()
    # instantiate the cell-generating class
    cell_gen = CELL(label,900,100,0.16,0.5,opt)  # the threshold argument (here 0.16) should be < 0.25, else all cells collapse to 0
    [cells_num,lane_cells_count] = cell_gen.gen_cell() # generate the cells, i.e. post-process the clustering result
    
    #cells_num = 631
    
    train_clustered_cell(cells_num) # train the localization network
    single_localization(cells_num)  # localize a single image
    
    train_cells(cells_num)          # train the hierarchical localization network
    localization(cells_num)         # single-image localization with the hierarchical network
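    For reference, here is a minimal sketch of the even split performed by clu.cluster_sequence(length, 200) (the real implementation lives in the Cluster module; this sketch is my assumption, not the project code):

    import numpy

    def cluster_sequence(length, num_cells):
        # assign each of `length` sequential images to one of `num_cells`
        # cells of (nearly) equal size, in sequence order
        return numpy.floor(numpy.arange(length) * num_cells / length).astype(int)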
    

    Data module

    The CreateDataLoader function prepares the data:

        train_str = 'train'
        test_str = 'test'
        dataset_train = CreateDataLoader(opt,train_str,isTrain = True)
        dataset_test = CreateDataLoader(opt,test_str,isTrain = False)
    

    CreateDataLoader instantiates the CustomDatasetDataLoader class and initializes it:

    def CreateDataLoader(opt,phase,isTrain):
        data_loader = CustomDatasetDataLoader()
        print(data_loader.name())
        data_loader.initialize(opt,phase,isTrain)
        return data_loader
    

    The CustomDatasetDataLoader class, implemented below, overrides the BaseDataLoader class. During initialization it selects the dataset by calling the CreateDataset function, then sets the torch.utils.data.DataLoader parameters, such as the batch size and whether to shuffle.

    class CustomDatasetDataLoader(BaseDataLoader):
        def name(self):
            return 'CustomDatasetDataLoader'
    
        def initialize(self, opt,phase,isTrain):
            BaseDataLoader.initialize(self, opt)
            self.dataset = CreateDataset(opt,phase,isTrain)
            self.dataloader = torch.utils.data.DataLoader(
                self.dataset,
                batch_size=opt.batchSize,
                shuffle= isTrain,  # shuffle only during training
                num_workers=int(opt.nThreads))
            print('-----------------dataloader------------------')
        def load_data(self):
            return self
    
        def __len__(self):
            return min(len(self.dataset), self.opt.max_dataset_size)
    
        def __iter__(self):
            for i, data in enumerate(self.dataloader):
                if i >= self.opt.max_dataset_size:
                    break
                yield data
    

    The CreateDataset function, implemented below, selects the dataset according to the opt.dataset_mode parameter:

    def CreateDataset(opt,phase,isTrain):
        dataset = None
        if opt.dataset_mode == 'c3_Dataset':
            from Data.Dataset_gather import c3_Dataset
            dataset = c3_Dataset()
        elif opt.dataset_mode == 'seq_Dataset':
            from Data.Dataset_gather import seq_Dataset
            dataset = seq_Dataset()
        elif opt.dataset_mode == 'cells_Dataset':
            from Data.Dataset_gather import cells_Dataset
            dataset = cells_Dataset()
        elif opt.dataset_mode == 'clustered_cells_Dataset':
            from Data.Dataset_gather import clustered_cells_Dataset
            dataset = clustered_cells_Dataset()
        elif opt.dataset_mode == 'other_test_dataset':
            from Data.Dataset_gather import other_test_dataset
            dataset = other_test_dataset()
        else:
            raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
    
        print("dataset [%s] was created" % (dataset.name()))
        dataset.initialize(opt,phase,isTrain)
        return dataset
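    Registering an additional dataset follows the same pattern: one more elif branch here, plus a class in Dataset_gather.py that implements name(), initialize(), __getitem__ and __len__. A sketch with a hypothetical my_Dataset mode:

        elif opt.dataset_mode == 'my_Dataset':          # hypothetical new mode
            from Data.Dataset_gather import my_Dataset
            dataset = my_Dataset()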
    

    Take cells_Dataset as an example. It first loads the corresponding train.txt and test.txt, plus the txt used for single-image localization. It then calls the get_transform_ function to define the preprocessing pipeline; these transforms run automatically whenever __getitem__ is called, e.g. to crop and scale each image. __getitem__ itself is invoked automatically on every iteration of

    for i, data in enumerate(dataset_train):

    and the data it yields is exactly what __getitem__ returns.

    class cells_Dataset(BaseDataset):
        def initialize(self, opt ,phase ,isTrain):
            self.opt = opt
            self.root = opt.coderoot
            self.transform_flag = True
            str_train = '/label/cell_'+ str(opt.num_outputs) +'_train.txt'
            str_test = '/label/cell_'+ str(opt.num_outputs) +'_test.txt'
            str_localiza = '/label/cell_'+ str(opt.num_outputs) +'.txt'
            if(phase == 'train'):
                split_file = self.root + str_train
                isTrain = True
            elif(phase == 'test'):
                split_file = self.root + str_test
                isTrain = False
            else:
                #isTrain = True
                self.transform_flag = False
                split_file = self.root + str_localiza
            self.path = numpy.loadtxt(split_file, dtype=str, delimiter=' ', skiprows=0, usecols=(0))
            #self.path = [os.path.join(self.opt.dataroot, path) for path in self.path]
            self.path = [(self.opt.dataroot + path) for path in self.path]
            self.lane= numpy.loadtxt(split_file, dtype=float, delimiter=' ', skiprows=0, usecols=(1))
            self.cell= numpy.loadtxt(split_file, dtype=float, delimiter=' ', skiprows=0, usecols=(2))
            self.mean_image = numpy.load(os.path.join(self.opt.dataroot , 'mean_image.npy')) # load the mean-image file
            self.size = len(self.path)
            print('len(self.path):{:}'.format(self.size))
            self.transform = get_transform_(opt,self.mean_image,self.transform_flag) # define the preprocessing pipeline
        def __getitem__(self, index):
            path = self.path[index % self.size]
            A_img = Image.open(path).convert('RGB')
            #A_img.save('pic/'+path[-9:])
            #print('************')
            cell = self.cell[index % self.size]
            lane = self.lane[index % self.size]
            img = self.transform(A_img)
            return {'img': img, 'cell': cell,
                    'path': path,'lane':lane}
    
        def __len__(self):
            return self.size
    
        def name(self):
            return 'cells_Dataset'
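    Putting it together: each iteration over the loader yields a dict with the keys returned by __getitem__ above, batched along the first dimension by DataLoader.

    for i, data in enumerate(dataset_train):
        imgs  = data['img']   # float tensor of shape [batchSize, 3, fineSize, fineSize]
        cells = data['cell']  # cell labels
        lanes = data['lane']  # lane labels (0, 1 or 2)
        paths = data['path']  # image file paths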
    
    

    The get_transform_ function is defined as follows. Lambda expressions wrap the processing functions into the transforms; every time __getitem__ above runs, these lambda-wrapped functions process the image.

    def get_transform_(opt, mean_image,isTrain = True):
        transform_list = []
        transform_list.append(transforms.Resize(opt.loadSize, Image.BICUBIC))
        transform_list.append(transforms.Lambda(lambda img: __subtract_mean(img, mean_image)))
        transform_list.append(transforms.Lambda(lambda img: __crop_image(img, opt.fineSize, isTrain)))
        transform_list.append(transforms.Lambda(lambda img: __to_tensor(img)))
        return transforms.Compose(transform_list)
    
    def __scale_width(img, target_width):
        ow, oh = img.size
        if (ow == target_width):
            return img
        w = target_width
        h = int(target_width * oh / ow)
        return img.resize((w, h), Image.BICUBIC)
    
    def __subtract_mean(img, mean_image):
        if mean_image is None:
            return numpy.array(img).astype('float')
        return numpy.array(img).astype('float') - mean_image.astype('float')
    
    def __crop_image(img, size, isTrain):
        h, w = img.shape[0:2]
        if isTrain:
            # random crop during training
            if w == size and h == size:
                return img
            x = numpy.random.randint(0, w - size)
            y = numpy.random.randint(0, h - size)
        else:
            # center crop during testing
            x = int(round((w - size) / 2.))
            y = int(round((h - size) / 2.))
        return img[y:y+size, x:x+size, :]
    
    def __to_tensor(img):
        return torch.from_numpy(img.transpose((2, 0, 1)))
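    Incidentally, these Lambda transforms are what trigger the Windows error mentioned in the installation notes: DataLoader worker processes must pickle the transform pipeline, and lambda expressions are not picklable. One possible workaround (my sketch, not part of the original code) is functools.partial over the module-level helper functions:

    import functools

    def get_transform_picklable(opt, mean_image, isTrain=True):
        # same pipeline as get_transform_ above, but partials of module-level
        # functions can be pickled, unlike lambdas
        return transforms.Compose([
            transforms.Resize(opt.loadSize, Image.BICUBIC),
            transforms.Lambda(functools.partial(__subtract_mean, mean_image=mean_image)),
            transforms.Lambda(functools.partial(__crop_image, size=opt.fineSize, isTrain=isTrain)),
            transforms.Lambda(__to_tensor),
        ])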
    

    Model module

    model = create_model(opt)
    

    create_model builds the model: depending on the opt.model parameter, it creates the network used for feature training, localization training, or lane classification.

    def create_model(opt,istest = False):
        model = None
        print(opt.model)
        if opt.model == 'RESNET18':
            from .model_set import RESNET18Model
            model = RESNET18Model()       # trains the feature extraction / localization network
        elif opt.model == 'RESNET18_CELL':
            from .model_set import RESNET18Model_CELL
            model = RESNET18Model_CELL()  # trains the hierarchical localization network
        elif opt.model == 'RESNET18_3':
            from .model_set import RESNET18Model_3
            model = RESNET18Model_3()     # trains the lane classification network
        else:
            raise ValueError("Model [%s] not recognized." % opt.model)
        model.initialize(opt, istest)
        #print("model [%s] was created" % (model.name()))
        return model
    

    Taking the RESNET18Model class as an example of what a Model class does: RESNET18Model overrides the BaseModel class, whose one important function is save_network, implemented in base_model.py:

        def save_network(self, network, network_label, epoch):
            save_filename = '%s_net_%s.pth' % (network_label, epoch)
            save_path = os.path.join(self.save_dir, '%s_%s'%(self.opt.dataset_mode,self.opt.num_outputs))
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            save_path = os.path.join(save_path,save_filename)
            torch.save(network.state_dict(), save_path)


    class RESNET18Model(BaseModel):
        def name(self):
            return 'RESNET18'
        def initialize(self, opt,isTest = False): # build the resnet network; use SGD as the optimizer; use lr_scheduler.StepLR as the training schedule
            self.opt = opt
            BaseModel.initialize(self, opt)
            self.isTrain = not isTest
            self.net = networks.RESNET18(opt.num_outputs,isTest) # the net is the RESNET18 network from the networks module
            if self.isTrain:
                self.old_lr = opt.lr
                self.criterion = torch.nn.CrossEntropyLoss() # cross-entropy loss
                self.optimizers = []
                self.optimizer_A = torch.optim.SGD(self.net.parameters() , lr = opt.lr , momentum = 0.9) # SGD optimizer
                self.optimizers.append(self.optimizer_A)
                self.schedulers = lr_scheduler.StepLR(self.optimizer_A, step_size=10, gamma=0.9) # training schedule: multiply the learning rate by 0.9 every 10 epochs
        def set_input(self, input): # set the input data
            self.input_img = input['img']
            self.cell = input['cell']
            self.image_paths = input['path']        
        def forward(self): # inference
            self.input_img = Variable(self.input_img.float().cuda())
            [self.features,self.pred] = self.net(self.input_img)
            Z = F.softmax(self.pred,dim=1) # softmax output
            _ , self.preds_= torch.max(Z, 1) # class with the highest softmax probability
        def extract_features(self):
            f = deepcopy(self.features.data.cpu().numpy()) # extract the features
            return f
        def testnet(self): # test: forward only, no backward pass
            self.forward()
            
        def trainnet(self): # train
            self.optimize_parameters()
        def get_pred_result(self):
            return self.preds_
        def get_image_paths(self):
            return self.image_paths
        def backward(self): # backpropagation
            self.loss = self.criterion(self.pred,self.cell.long().cuda())
            self.loss.backward()
        def optimize_parameters(self): # one optimization step
            self.forward()
            self.optimizer_A.zero_grad()
            self.backward()
            self.optimizer_A.step()
        def get_current_acc(self,opt): # number of correct predictions in the current batch
            self.cell = self.cell.long().cuda()
            self.running_corrects = int(torch.sum(self.preds_ == self.cell.data))
            return self.running_corrects    
        def get_current_loss(self,opt): # current loss value
            self.loss = self.criterion(self.pred,self.cell.long().cuda())
            return float(self.loss)
        def save(self, epoch):
            self.save_network(self.net, 'RESNET18', epoch)
        def forward_singlepic(self): # single-image inference
            self.input_img = Variable(self.input_img.float().cuda())
            [self.features,self.pred] = self.net(self.input_img)
            Z = F.softmax(self.pred,dim=1)
            _ , self.preds_= torch.max(Z, 1)
            return self.preds_
    

    The networks module defines the structure of each network. class RESNET18(torch.nn.Module) inherits from torch.nn.Module and overrides the initializer __init__ and the forward-propagation function forward, which is invoked automatically once image data is fed to the network. torch.load returns an OrderedDict. For model and weight loading and the weight-saving format, you can read this blog post. self.model.eval() switches the network to evaluation mode; dropout and batch normalization layers behave differently in training and in evaluation, as explained in this blog post.
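    A quick illustration of the train/eval difference (standard PyTorch behavior, not project code):

    import torch

    drop = torch.nn.Dropout(p=0.5)
    x = torch.ones(1, 8)
    drop.train()
    print(drop(x))  # about half the entries zeroed, survivors scaled by 1/(1-p) = 2
    drop.eval()
    print(drop(x))  # identity: all ones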

    class RESNET18(torch.nn.Module):
        """Constructs a ResNet-18 model.
        """
        def __init__(self, num_output, isTest=False,  gpu_ids=[]):
            super(RESNET18, self).__init__()        
            self.model_name = 'resnet18'
            self.gpu_ids = gpu_ids
            # pretrained weights; the state dict maps one key to each layer's tensor
            state_dict = torch.load('C:/code_p/Checkpoints/RESNET18/clustered_cells_Dataset_631/RESNET18_net_068.pth')
            self.model = ResNet(BasicBlock, [2, 2, 2, 2], num_output) # the concrete ResNet architecture
            pretrained = True
            if pretrained:
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[6:] # strip the 6-character 'model.' prefix so the keys match this model's layer names
                    new_state_dict[name] = v
                # load the weights into the model, matched by layer name; with strict=True a
                # mismatched name raises an error, with strict=False it is silently skipped
                self.model.load_state_dict(new_state_dict,strict = True)
            if isTest:
                self.model.eval() # evaluation mode for testing
            self.model = self.model.cuda()
            print(self.model)
        def forward(self, x): # forward propagation
            out = self.model(x)
            return out 
    

    What remains is the concrete implementation of the network layers, which is the official ResNet-18 source code. It is not covered here; plenty of references are available online.

    Training function

    Taking the training of the feature extraction network as an example: first create the data, then create the model, and in every epoch run one training pass and one test pass.

    def train_extract_features(num_outputs):
        opt = BaseOptions().parse()
        train_str = 'train'
        test_str = 'test'
        dataset_train = CreateDataLoader(opt,train_str,isTrain = True) # create the training data
        dataset_test = CreateDataLoader(opt,test_str,isTrain = False)  # create the test data
        dataset_size_train = len(dataset_train)
        dataset_size_test = len(dataset_test)
        model = create_model(opt)
        Loss_list_train = []
        Loss_list_test = []
        Accuracy_list_train = []
        Accuracy_list_test = []
        for epoch in range(opt.num_epochs): # epoch loop
            epoch_acc_train = 0
            epoch_acc_test = 0
            epoch_loss_train = 0
            epoch_loss_test = 0
            print('Training...')
            for i, data in enumerate(dataset_train): # iteration loop
                #print('[%04d/%04d] ' % (i, len(dataset_train)/opt.batchSize), end='\r')
                model.set_input(data) # feed the input data
                model.trainnet()      # train the network
                running_corrects = model.get_current_acc(opt)
                running_loss = model.get_current_loss(opt)
                epoch_acc_train = running_corrects + epoch_acc_train
                epoch_loss_train = running_loss + epoch_loss_train
                Loss_list_train.append(running_loss)
                data_batch_size = len(data['cell'])
                Accuracy_list_train.append(running_corrects/data_batch_size)
                #print(running_loss)
                #print(running_corrects)
                print('[%04d/%04d] ------------------  corrects: %04f-------------------' % (i, len(dataset_train)/opt.batchSize,epoch_acc_train/(i+1)/data_batch_size), end='\r')
            epoch_loss_train = epoch_loss_train/(i+1)
            epoch_acc_train = epoch_acc_train*100/dataset_size_train
            print(' Train epoch {:}:---- lr:{:} ----Acc: {:.4f}%  loss:{:.4f}' .format(epoch,opt.lr,epoch_acc_train,epoch_loss_train))
            print('Test...')
            for i, data in enumerate(dataset_test):
                model.set_input(data)
                model.testnet()
                running_corrects = model.get_current_acc(opt)
                running_loss = model.get_current_loss(opt)
                epoch_acc_test = running_corrects + epoch_acc_test
                epoch_loss_test = running_loss + epoch_loss_test
                Loss_list_test.append(running_loss)
                data_batch_size = len(data['cell'])
                Accuracy_list_test.append(running_corrects/data_batch_size)
                print('[%04d/%04d] ------------------  corrects: %04f-------------------' % (i, len(dataset_test)/opt.batchSize,epoch_acc_test/(i+1)/data_batch_size), end='\r')
            epoch_loss_test = epoch_loss_test/(i+1)
            epoch_acc_test = epoch_acc_test*100/dataset_size_test
            print(' Test epoch {:}:---- lr:{:} ----Acc: {:.4f}%  loss:{:.4f}' .format(epoch,opt.lr,epoch_acc_test,epoch_loss_test))
            model.save(epoch)
            if(epoch%1 == 0):
                numpy.save('D:/figure/Loss_list_train_600_.npy',Loss_list_train) # save the loss and accuracy curves
                numpy.save('D:/figure/Loss_list_test_600_.npy',Loss_list_test)
                numpy.save('D:/figure/Accuracy_list_train_600_.npy',Accuracy_list_train)
                numpy.save('D:/figure/Accuracy_list_test_600_.npy',Accuracy_list_test)
            if (epoch_acc_train > 99.9) and (epoch_acc_test > 99.9): # stop early once both accuracies saturate
                break
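    The saved .npy curves can be plotted afterwards, for example (a sketch, assuming the paths used above):

    import numpy
    from matplotlib import pyplot as plt

    loss_train = numpy.load('D:/figure/Loss_list_train_600_.npy')
    loss_test = numpy.load('D:/figure/Loss_list_test_600_.npy')
    plt.plot(loss_train, label='train')
    plt.plot(loss_test, label='test')
    plt.xlabel('iteration')
    plt.ylabel('loss')
    plt.legend()
    plt.show()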
    

    GEN_CELL module

    The GEN_CELL module post-processes the clustered cells to generate the final cells.

    #!/usr/bin/env python3
    # -*- coding: utf-8 -*-
    """
    Created on Sat Jul 28 10:19:06 2018
    
    @author: zs
    """
    import numpy 
    import os
    from matplotlib import pyplot as plt
    class CELL():
        def __init__(self,label,k,outline_range,outline_range_threshold,kkk,opt):
            '''
            outline_range and outline_range_threshold are the statistics window and the
            small-cell threshold (cells below the threshold are merged); kkk*outline_range
            is the range actually processed, where kkk is a coefficient.
            '''
            self.root = opt.coderoot
            self.outline_range = int(outline_range)
            self.outline_range_threshold = outline_range_threshold
            self.num = k
            self.kkk = kkk
            self.new_index = numpy.zeros(k)
            self.index = label
            self.dataset_size = len(label)
            self.new_label = numpy.zeros(label.shape)
            str_path = 'label/seq'+ str(opt.num_outputs) +'.txt'
            split_file = os.path.join(self.root , str_path)
            self.path = numpy.loadtxt(split_file, dtype=str, delimiter=' ', skiprows=0, usecols=(0))
            str_path = 'label/3.txt'
            split_file = os.path.join(self.root , str_path)
            self.lines = numpy.loadtxt(split_file, dtype=int, delimiter=' ', skiprows=0, usecols=(1))
            print(self.new_label.shape)
        def idx_transformation(self): # convert the unordered clustering labels into labels ordered by the image sequence
            n = 0
            for i in range(0,self.dataset_size):
                if(self.new_index[self.index[i]] == 0):
                    self.new_index[self.index[i]] = n
                    n+=1
            for i in range(0,self.dataset_size):
                self.new_label[i] = int(self.new_index[self.index[i]])-1
            save_dir = 'C:/code_p/label/new_label%d.txt'%(self.num)
            numpy.savetxt(save_dir,self.new_label,fmt='%d')
            self.label_removed = self.new_label.copy()
        def check_same_cell_differentlane(self): # check for small cells and for cells containing images from different lanes
            lane_cells_count = numpy.zeros(3,dtype=numpy.int)
            for j in range(len(self.lines)):
                lane_cells_count[self.lines[j]] += 1
            lane_cells_count[1] = lane_cells_count[1]+lane_cells_count[0]
            print(lane_cells_count)
            save_dir = 'C:/code_p/label/ttt.txt'
            f =open(save_dir,'w')
            for j in range(len(self.new_label)):
                f.write(str(self.new_label[j]))
                f.write('\n')
            f.close()
            # if the labels around a lane boundary coincide, shift all later labels up by one
            for i in range(len(lane_cells_count)-1):
                if (self.new_label[lane_cells_count[i]] == self.new_label[lane_cells_count[i]+1]):
                    for j in range(lane_cells_count[i]+1,len(self.new_label)):
                        self.new_label[j] += 1
            
            
        def remove_smallcell(self): # merge small cells into their neighbours
            print('remove small cell...')
            kkk = self.kkk
            for i in range(0,len(self.new_label)-self.outline_range):
                # count the occurrences of each cell label within [i, i+outline_range];
                # the number of cells here never exceeds 2000, so length 2000 is safe
                point_sample = numpy.zeros(int(2000),dtype=numpy.int)
                for j in range(0,self.outline_range):
                    point_sample[int(self.label_removed[i+j])] += 1
                for ii in range(0,2000):
                    flag_remove_once = 0
                    # this cell occurs in the window, but less often than the threshold
                    if(point_sample[ii]<=self.outline_range*self.outline_range_threshold)and(point_sample[ii]>0):
                        # merge the small cells inside the central kkk*outline_range part of the window
                        for jj in range(int(self.outline_range/2-self.outline_range*kkk/2),int(self.outline_range/2+self.outline_range*kkk/2)):
                            if (self.label_removed[i+jj] == ii):
                                print('remove%d'%(self.new_label[i+jj]))
                                self.label_removed[i+jj] = self.label_removed[i+jj-1]
                                flag_remove_once = 1
                    if flag_remove_once:
                        i -= 1 # intended to revisit overlapping cells (note: reassigning the loop variable has no effect in a Python for loop)
                        break
            save_dir = 'C:/code_p/label/label_removed%d.txt'%(self.num)
            numpy.savetxt(save_dir,self.label_removed,fmt='%d')
        def checkandsort_cell(self): # give repeated large cells new labels
            print('check and sort cell...')
            cell= numpy.zeros(len(self.label_removed),dtype=numpy.int)
            last = 0
            class_plus = 0
            reco = numpy.zeros(1000,dtype=numpy.int)
            self.count = 0
            for i in range(0,len(self.label_removed)):
                n = int(self.label_removed[i])
                if(reco[n]> 0)and(abs(n - last)>0.1):
                    # this label already appeared in an earlier run: a repeated cell
                    self.count += 1
                if(abs(n-last)>0.1):
                    class_plus += 1 # start a new cell whenever the label changes
                cell[i] = class_plus
                last = n
                reco[n] += 1
            print('check repeated  %d cells'%(self.count))
            return cell
        def cou_cell(self): # report cell-size statistics before and after processing
            nn = int(self.new_label[-1]+1)
            print(nn)
            cm = numpy.zeros(nn,dtype=numpy.int)
            for i in range(0,len(self.new_label)):
                cm[int(self.new_label[i])-1] += 1
            print('before remove max: %d'%(max(cm)))
            print('before remove min: %d'%(min(cm)))
            cm = numpy.zeros(nn,dtype=numpy.int)
            for i in range(0,len(self.label_removed)):
                cm[int(self.label_removed[i])-1] += 1
            print('after remove max: %d'%(max(cm)))
            print('after remove min: %d'%(min(cm)))
        def gen_cell(self):
            self.idx_transformation()
            #self.check_same_cell_differentlane()
            self.remove_smallcell()
            self.cou_cell()
            cell = self.checkandsort_cell()
            save_dir = 'C:/code_p/label/cell_%d.txt'%(cell[-1]+1)
            lane_cells_count = numpy.zeros(3,dtype=numpy.int)
            f =open(save_dir,'w')
            for j in range(len(cell)):
                lane_cells_count[self.lines[j]] += 1
                text = self.path[j]+' '+str(self.lines[j])+' ' +str(cell[j])
                f.write(text)
                f.write('\n')
            f.close()
            print(lane_cells_count)
            # first cell of each lane branch, and number of cells per lane
            lane_branch_start_cell = [0,cell[lane_cells_count[0]],cell[lane_cells_count[0]+lane_cells_count[1]]]
            print(lane_branch_start_cell)
            lane_cells_cla = [cell[lane_cells_count[0]-1]+1,cell[lane_cells_count[1]+lane_cells_count[0]-1]-cell[lane_cells_count[0]]+1,cell[lane_cells_count[2]+lane_cells_count[1]+lane_cells_count[0]-1]-cell[lane_cells_count[1]+lane_cells_count[0]]+1]
            numpy.save('C:/code_p/label/lane_cells_cla_900.npy',lane_cells_cla)
            numpy.save('C:/code_p/label/lane_branch_start_cell_900.npy',lane_branch_start_cell)
            return [cell[-1]+1,lane_cells_cla]
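    To make the renumbering step concrete, here is a minimal sketch of the logic inside checkandsort_cell (my toy reimplementation, not the project code):

    import numpy

    def renumber_runs(labels):
        # give every contiguous run of equal labels its own cell id,
        # so a label that reappears later becomes a new cell
        cell = numpy.zeros(len(labels), dtype=int)
        for i in range(1, len(labels)):
            cell[i] = cell[i-1] + (labels[i] != labels[i-1])
        return cell

    label = numpy.array([0, 0, 1, 1, 1, 0, 0, 2, 2, 2, 2, 2])
    print(renumber_runs(label))  # [0 0 1 1 1 2 2 3 3 3 3 3] -- the repeated 0 became cell 2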
    

    Chapter 4 of the thesis covers the hierarchical localization network, which pre-classifies the data based on shallow features. This is probably not your next priority; if you want to look into it, the relevant code lives mainly in the following two classes.
    In the model_set file:

    class RESNET18Model_CELL(BaseModel):
    

    In the networks file:

    class ResNet_redefined_cell(nn.Module):
    
