一.安装
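依赖库:
需安装 numpy、matplotlib、scikit-learn(sklearn)、scipy、Pillow(PIL)、opencv、PyTorch(版本 ≥ 0.4)及 torchvision;pickle 为 Python 标准库,无需单独安装。
代码在 CycleGAN and pix2pix in PyTorch 的基础上编写。
Python 版本为 3.6(3.5 和 3.7 也能运行),在 Windows 和 Ubuntu 下都能运行。Windows 下 DataLoader 以多进程加载数据(num_workers > 0)时,可能会报 Lambda 表达式无法序列化(pickle)的错误。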
文件目录结构:
C:\CODE_P
├─alexnet 中间层特征可视化
├─Checkpoints 保存训练模型
├─Cluster 聚类相关程序
│ │ cluster_img.py 对图像聚类
│ │ gen_cell.py 生成Cell,主要为Cell后处理程序
├─Data 数据准备模块
│ │ base_dataset.py 基类,在送入网络前进行处理
│ │ base_data_loader.py 基类
│ │ Dataset_gather.py 根据不同参数调用不同的数据的具体实现
│ │ data_loader.py 数据加载类,根据不同参数加载不同的数据集
│ │ Data_manage.py 数据管理, 读取、生成路径等
├─figure 保存图表
├─label 保存所有label
├─Models 网络模型相关
│ │ base_model.py 基类
│ │ double_threads.py 双线程示例
│ │ layers_trans.py 替换字符串string中指定位置p的字符为c,用于批量转换模型各层的名字
│ │ models.py 根据参数进行模型选择
│ │ model_set.py 网络调用、优化函数定义、前向和反向传播及损失值计算
│ │ networks.py 网络模型定义
│ │ resnet_layer_trans.txt
│
├─Options
│ │ options_set.py 参数定义
├─pre_model_state_dict
│ resnet18-5c106cde.pth 预训练模型
├─Result 结果
├─runs
│ │ 可忽略此文件夹
├─util 工具程序
│ compute_image_mean.py 计算图像均值
├─Visualization 可视化相关程序
│ ├─test
└─
二.总体流程
下图是总体流程,分为三个部分:特征提取、构建 Cell、训练定位网络并进行定位。
主程序
# Main pipeline: feature extraction -> cell construction -> localization.
# NOTE(review): `length`, `img_dir` and `dataset` are used below but are not
# defined in this excerpt; they must come from the surrounding script.
opt = BaseOptions().parse()  # parse configuration options
clu = clusterdata()  # clustering helper
datareader = dataread(opt)  # data reader
[gps_x, gps_y] = datareader.get_gps()  # GPS information of every image
c = dataset.num_img  # images per lane: left/middle/right = c[0], c[1], c[2]
# Split the image sequence evenly into 200, 600 and 900 cells.
ll2 = clu.cluster_sequence(length, 200)
ll3 = clu.cluster_sequence(length, 600)
ll4 = clu.cluster_sequence(length, 900)
# Tag the images of the three lanes with 0, 1 and 2 (lanes are contiguous
# ranges of the image sequence, in the order left, middle, right).
three_cla = numpy.zeros(length, dtype=int)
three_cla[0:c[0]] = 0
three_cla[c[0]:c[0] + c[1]] = 1
three_cla[c[0] + c[1]:c[0] + c[1] + c[2]] = 2
three_l = numpy.array(three_cla)
# numpy.savetxt('3.txt', three_l, fmt='%d')
# Write "<image path> <cell index>" for each even split to label/seqN.txt.
# `with` guarantees each file is closed even if a write fails.
# img_dir[j][37:] drops a fixed 37-character prefix from each path
# (presumably the dataset root) -- NOTE(review): confirm against opt.dataroot.
for seq_name, seq_labels in (('seq200', ll2), ('seq600', ll3), ('seq900', ll4)):
    with open('label/' + seq_name + '.txt', 'w') as f:
        for j, image_path in enumerate(img_dir):
            f.write(str(image_path[37:]) + ' ' + str(int(seq_labels[j])))
            f.write('\n')
# Randomly split every labelling into shuffled training and test sets.
split(opt, 'seq200')
split(opt, 'seq600')
split(opt, 'seq900')
train_extract_features(opt.num_outputs)  # train the feature-extraction network
extract_features(opt.num_outputs)  # extract image features with the trained network
label = clu.clu_features(900)  # cluster the extracted features (k = 900)
numpy.save('clu_900.npy', label)  # save the clustering result
# label = numpy.load('clu_900.npy')  # reload a saved clustering result instead
# Bar chart of the cluster label of every image.
plt.figure(2)
plt.bar(numpy.arange(len(label)), label, width=1)
plt.show()
# Post-process the clustering into cells.
# Per the original author: outline_range_threshold (0.16 here) should be
# < 0.25, otherwise all cells end up as 0.
cell_gen = CELL(label, 900, 100, 0.16, 0.5, opt)
[cells_num, lane_cells_count] = cell_gen.gen_cell()  # generate the final cells
# cells_num = 631
train_clustered_cell(cells_num)  # train the localization network
single_localization(cells_num)  # single-image localization
train_cells(cells_num)  # train the hierarchical localization network
localization(cells_num)  # single-image localization with the hierarchical network
Data模块
CreateDataLoader函数准备数据
# Build the training loader (isTrain=True) and the test loader (isTrain=False).
train_str, test_str = 'train', 'test'
dataset_train = CreateDataLoader(opt, train_str, isTrain=True)
dataset_test = CreateDataLoader(opt, test_str, isTrain=False)
CreateDataLoader,实例化CustomDatasetDataLoader类并进行初始化
def CreateDataLoader(opt, phase, isTrain):
    """Create a CustomDatasetDataLoader for ``phase`` and initialise it.

    Prints the loader's name, forwards ``(opt, phase, isTrain)`` to its
    ``initialize`` method and returns the ready loader.
    """
    loader = CustomDatasetDataLoader()
    print(loader.name())
    loader.initialize(opt, phase, isTrain)
    return loader
CustomDatasetDataLoader 类实现如下:它继承 BaseDataLoader 类,在初始化时调用 CreateDataset 函数选择数据集,再设置 torch.utils.data.DataLoader 的参数,如 batch 大小、是否打乱顺序(训练集打乱,测试集不打乱)等。
class CustomDatasetDataLoader(BaseDataLoader):
    """Wrap the dataset chosen by ``CreateDataset`` in a torch DataLoader.

    ``len()`` reports the number of samples, capped at
    ``opt.max_dataset_size``. Iteration stops once that many samples have
    been yielded, rounded up to a whole batch.
    """

    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt, phase, isTrain):
        """Create the dataset for ``phase`` and the DataLoader around it.

        Batches are shuffled only when ``isTrain`` is true; otherwise
        samples keep their file order.
        """
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt, phase, isTrain)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=isTrain,
            num_workers=int(opt.nThreads))
        print('-----------------dataloader------------------')

    def load_data(self):
        return self

    def __len__(self):
        """Number of samples, capped at ``opt.max_dataset_size``."""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Yield batches until ``opt.max_dataset_size`` samples were produced."""
        for i, data in enumerate(self.dataloader):
            # max_dataset_size counts samples while i counts batches, so
            # convert the batch index to a sample count before comparing.
            if i * self.opt.batchSize >= self.opt.max_dataset_size:
                break
            yield data
CreateDataset函数实现如下:通过opt.dataset_mode参数选择数据集
def CreateDataset(opt, phase, isTrain):
    """Instantiate the dataset class named by ``opt.dataset_mode``.

    The class of that name is taken from ``Data.Dataset_gather``, initialised
    with ``(opt, phase, isTrain)`` and returned.

    Raises:
        ValueError: if ``opt.dataset_mode`` is not a supported dataset name.
    """
    supported = ('c3_Dataset', 'seq_Dataset', 'cells_Dataset',
                 'clustered_cells_Dataset', 'other_test_dataset')
    if opt.dataset_mode not in supported:
        raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
    # Import lazily so only a supported mode touches the Data package.
    from Data import Dataset_gather
    dataset = getattr(Dataset_gather, opt.dataset_mode)()
    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt, phase, isTrain)
    return dataset
以 cells_Dataset 为例:先读取对应的 train.txt、test.txt 以及单张图片定位所用的 txt 文件,再调用 get_transform_ 函数定义数据处理过程(如缩放、减均值、裁剪等),这些处理会在 __getitem__ 被调用时自动执行。在 `for i, data in enumerate(dataset_train):` 循环中,每次迭代都会自动调用 __getitem__;DataLoader 会把多次 __getitem__ 返回的字典拼成一个 batch,得到的就是 data。
class cells_Dataset(BaseDataset):
    """Dataset of (image, lane, cell) samples listed in a label/cell_N*.txt file.

    Each line of the split file is ``<image path> <lane> <cell>`` separated
    by single spaces. The image path is relative to ``opt.dataroot``.
    """

    def initialize(self, opt, phase, isTrain):
        """Load the sample list for ``phase`` and build the image transform.

        Args:
            opt: options; reads coderoot, dataroot and num_outputs.
            phase: 'train' or 'test' selects the matching split file. Any
                other value selects the full list used for single-image
                localization and disables random cropping (center crop).
            isTrain: kept for interface compatibility; ``phase`` decides.
        """
        self.opt = opt
        self.root = opt.coderoot
        self.transform_flag = True
        prefix = '/label/cell_' + str(opt.num_outputs)
        if phase == 'train':
            split_file = self.root + prefix + '_train.txt'
        elif phase == 'test':
            # NOTE(review): transform_flag stays True here, so test images are
            # randomly cropped just like training images -- confirm intent.
            split_file = self.root + prefix + '_test.txt'
        else:
            self.transform_flag = False
            split_file = self.root + prefix + '.txt'
        # Parse the file once; ndmin=2 keeps a single-line file 2-D so that
        # the column indexing below still works.
        columns = numpy.loadtxt(split_file, dtype=str, delimiter=' ', ndmin=2)
        self.path = [(self.opt.dataroot + path) for path in columns[:, 0]]
        self.lane = columns[:, 1].astype(float)
        self.cell = columns[:, 2].astype(float)
        # Mean image subtracted from every input by the transform.
        self.mean_image = numpy.load(os.path.join(self.opt.dataroot, 'mean_image.npy'))
        self.size = len(self.path)
        print('len(self.path):{:}'.format(self.size))
        self.transform = get_transform_(opt, self.mean_image, self.transform_flag)

    def __getitem__(self, index):
        """Return the transformed image with its cell label, lane label and path."""
        i = index % self.size
        path = self.path[i]
        A_img = Image.open(path).convert('RGB')
        img = self.transform(A_img)
        return {'img': img, 'cell': self.cell[i],
                'path': path, 'lane': self.lane[i]}

    def __len__(self):
        return self.size

    def name(self):
        return 'cells_Dataset'
get_transform_ 函数的定义如下:它用 transforms.Compose 把若干处理步骤(缩放 → 减均值 → 裁剪 → 转为张量)串联起来;每次执行上面的 __getitem__ 函数时,这些步骤都会依次作用于当前图片。
def get_transform_(opt, mean_image, isTrain=True):
    """Build the per-image preprocessing pipeline.

    Steps, in order:
      1. ``transforms.Resize(opt.loadSize)`` with bicubic interpolation.
      2. Convert to a float array and subtract ``mean_image``.
      3. Crop to ``opt.fineSize``: random crop if ``isTrain``, else center.
      4. Convert the HWC array to a CHW torch tensor.

    Steps 2-4 are module-level functions bound with ``functools.partial``
    instead of lambdas. Lambdas defined inside this function cannot be
    pickled, which breaks DataLoader worker processes on Windows.
    """
    import functools

    transform_list = [
        transforms.Resize(opt.loadSize, Image.BICUBIC),
        functools.partial(__subtract_mean, mean_image=mean_image),
        functools.partial(__crop_image, size=opt.fineSize, isTrain=isTrain),
        __to_tensor,
    ]
    # Compose simply calls each step in turn, so plain callables work here.
    return transforms.Compose(transform_list)
def __scale_width(img, target_width):
    """Resize ``img`` to ``target_width`` pixels wide, keeping its aspect ratio.

    ``img`` is used through ``.size`` and ``.resize`` (PIL-style). It is
    returned unchanged when it already has the requested width.
    """
    width, height = img.size
    if width != target_width:
        new_height = int(target_width * height / width)
        img = img.resize((target_width, new_height), Image.BICUBIC)
    return img
def __subtract_mean(img, mean_image):
    """Return ``img`` as a new float array, minus ``mean_image`` if one is given."""
    arr = numpy.array(img).astype('float')
    if mean_image is not None:
        arr = arr - mean_image.astype('float')
    return arr
def __crop_image(img, size, isTrain):
    """Crop a ``size`` x ``size`` patch from an HWC array.

    Uses random offsets when ``isTrain`` is true and a center crop otherwise.
    Assumes both image dimensions are >= ``size`` (not validated here).
    """
    h, w = img.shape[0:2]
    if isTrain:
        if w == size and h == size:
            return img
        # randint's upper bound is exclusive. The +1 lets the last valid
        # offset (w - size) be drawn, and a dimension equal to `size` gives
        # offset 0 instead of raising ValueError for randint(0, 0).
        x = numpy.random.randint(0, w - size + 1)
        y = numpy.random.randint(0, h - size + 1)
    else:
        x = int(round((w - size) / 2.))
        y = int(round((h - size) / 2.))
    return img[y:y + size, x:x + size, :]
def __to_tensor(img):
    """Convert an HWC numpy array into a CHW torch tensor (values unchanged)."""
    chw = numpy.transpose(img, (2, 0, 1))
    return torch.from_numpy(chw)
Model模块
model = create_model(opt, istest=False)  # network wrapper chosen by opt.model
create_model创建model,根据 opt.model参数创建用于特征训练、定位训练和车道分类的网络
def create_model(opt, istest=False):
    """Create and initialise the network wrapper selected by ``opt.model``.

    'RESNET18'      -> RESNET18Model (feature-extraction / localization net)
    'RESNET18_CELL' -> RESNET18Model_CELL (hierarchical localization net)
    'RESNET18_3'    -> RESNET18Model_3 (lane classification net)

    Args:
        opt: options; ``opt.model`` selects the wrapper.
        istest: forwarded to ``model.initialize``.

    Raises:
        ValueError: if ``opt.model`` is not one of the names above.
    """
    model = None
    print(opt.model)
    if opt.model == 'RESNET18':
        from .model_set import RESNET18Model
        model = RESNET18Model()
    elif opt.model == 'RESNET18_CELL':
        from .model_set import RESNET18Model_CELL
        model = RESNET18Model_CELL()
    elif opt.model == 'RESNET18_3':
        from .model_set import RESNET18Model_3
        model = RESNET18Model_3()
    else:
        raise ValueError("Model [%s] not recognized." % opt.model)
    model.initialize(opt, istest)
    return model
def save_network(self, network, network_label, epoch):
    """Save ``network.state_dict()`` for one epoch.

    The file is written to
    ``<save_dir>/<dataset_mode>_<num_outputs>/<network_label>_net_<epoch>.pth``.
    The directory is created if needed.
    """
    save_filename = '%s_net_%s.pth' % (network_label, epoch)
    save_path = os.path.join(self.save_dir, '%s_%s' % (self.opt.dataset_mode, self.opt.num_outputs))
    # exist_ok avoids the race between a separate exists() check and makedirs().
    os.makedirs(save_path, exist_ok=True)
    torch.save(network.state_dict(), os.path.join(save_path, save_filename))
以 RESNET18Model 类为例讲解 Model 类的功能。RESNET18Model 类继承了 BaseModel 类;BaseModel 类中有一个重要的函数 save_network(即上面的代码),定义在 base_model.py 文件中。
class RESNET18Model(BaseModel):
    """Training/inference wrapper around ``networks.RESNET18``.

    Used to train the feature-extraction network and the localization
    network: cross-entropy over ``opt.num_outputs`` classes, optimised with
    SGD.
    """

    def name(self):
        return 'RESNET18'

    def initialize(self, opt, isTest=False):
        """Build the network and, when training, the loss/optimizer/scheduler.

        Loss: CrossEntropyLoss. Optimizer: SGD(lr=opt.lr, momentum=0.9).
        Scheduler: StepLR(step_size=10, gamma=0.9), i.e. lr *= 0.9 every 10
        calls to ``schedulers.step()`` (intended as every 10 epochs).
        """
        self.opt = opt
        BaseModel.initialize(self, opt)
        self.isTrain = not isTest
        self.net = networks.RESNET18(opt.num_outputs, isTest)
        if self.isTrain:
            self.old_lr = opt.lr
            self.criterion = torch.nn.CrossEntropyLoss()
            self.optimizers = []
            self.optimizer_A = torch.optim.SGD(self.net.parameters(), lr=opt.lr, momentum=0.9)
            self.optimizers.append(self.optimizer_A)
            self.schedulers = lr_scheduler.StepLR(self.optimizer_A, step_size=10, gamma=0.9)

    def set_input(self, input):
        """Store one batch: input['img'], input['cell'] (labels), input['path']."""
        self.input_img = input['img']
        self.cell = input['cell']
        self.image_paths = input['path']

    def _infer(self):
        """Run the net on self.input_img; store features, logits and argmax class."""
        self.input_img = Variable(self.input_img.float().cuda())
        [self.features, self.pred] = self.net(self.input_img)
        Z = F.softmax(self.pred, dim=1)
        _, self.preds_ = torch.max(Z, 1)  # most probable class per sample

    def forward(self):
        """Forward pass that records the autograd graph (used for training)."""
        self._infer()

    def extract_features(self):
        """Return a numpy copy of the features from the last forward pass."""
        f = deepcopy(self.features.data.cpu().numpy())
        return f

    def testnet(self):
        """Evaluate the current batch without building an autograd graph."""
        # Nothing backpropagates from a test pass, so no_grad saves memory.
        with torch.no_grad():
            self.forward()

    def trainnet(self):
        """Run one optimisation step on the current batch."""
        self.optimize_parameters()

    def get_pred_result(self):
        return self.preds_

    def get_image_paths(self):
        return self.image_paths

    def backward(self):
        """Compute the cross-entropy loss and backpropagate it."""
        self.loss = self.criterion(self.pred, self.cell.long().cuda())
        self.loss.backward()

    def optimize_parameters(self):
        """Forward, zero gradients, backward, then take an SGD step."""
        self.forward()
        self.optimizer_A.zero_grad()
        self.backward()
        self.optimizer_A.step()

    def get_current_acc(self, opt):
        """Return the number of correct predictions in the current batch."""
        self.cell = self.cell.long().cuda()
        self.running_corrects = int(torch.sum(self.preds_ == self.cell.data))
        return self.running_corrects

    def get_current_loss(self, opt):
        """Return the cross-entropy loss of the current batch as a float."""
        self.loss = self.criterion(self.pred, self.cell.long().cuda())
        return float(self.loss)

    def save(self, epoch):
        self.save_network(self.net, 'RESNET18', epoch)

    def forward_singlepic(self):
        """Predict the class of the current input without autograd; return it."""
        with torch.no_grad():
            self._infer()
        return self.preds_
networks 模块定义了各个网络的结构。其中 class RESNET18(torch.nn.Module) 继承了 torch.nn.Module 类,重写了初始化函数 __init__ 和前向传播函数 forward;向网络输入图片数据(调用 net(x))时会自动调用 forward 函数。torch.load 读取保存的权重文件(state_dict),返回的是一个 OrderedDict。关于模型和权重的下载以及权重的保存格式等,可以阅读相关博客(原文链接缺失)。self.model.eval() 将网络切换到测试模式:dropout 和 batch normalization 层在训练和测试时的行为不同,具体讲解可参考相关博客(原文链接缺失)。
class RESNET18(torch.nn.Module):
    """ResNet-18 wrapper that optionally loads saved weights and runs on the GPU.

    Args:
        num_output: number of output classes passed to ``ResNet``.
        isTest: accepted for compatibility; ``eval()`` is applied regardless
            (see the NOTE in ``__init__``).
        gpu_ids: stored as ``self.gpu_ids``; not used otherwise here.
        pretrained: if true, load weights from ``pretrained_path``.
        pretrained_path: checkpoint file to load. It defaults to the
            previously hard-coded path. That file lives under
            ``clustered_cells_Dataset_631``, i.e. it was saved with
            num_outputs=631, so a strict load presumably fails for other
            ``num_output`` values -- confirm before reusing.
    """

    def __init__(self, num_output, isTest=False, gpu_ids=None, pretrained=True,
                 pretrained_path='C:/code_p/Checkpoints/RESNET18/clustered_cells_Dataset_631/RESNET18_net_068.pth'):
        super(RESNET18, self).__init__()
        self.model_name = 'resnet18'
        # Give each instance its own list instead of a shared mutable default.
        self.gpu_ids = [] if gpu_ids is None else gpu_ids
        self.model = ResNet(BasicBlock, [2, 2, 2, 2], num_output)
        if pretrained:
            # The checkpoint is an OrderedDict of layer name -> tensor, saved
            # from this wrapper, so every key starts with 'model.'. k[6:]
            # strips that 6-character prefix to match self.model's names.
            state_dict = torch.load(pretrained_path)
            new_state_dict = OrderedDict((k[6:], v) for k, v in state_dict.items())
            # strict=True: any name that does not match raises an error.
            self.model.load_state_dict(new_state_dict, strict=True)
        # NOTE(review): eval() is applied unconditionally, not only when
        # isTest, so training also runs with BN/dropout in eval mode.
        # Behaviour preserved from the original; confirm it is intended.
        self.model.eval()
        self.model = self.model.cuda()
        print(self.model)

    def forward(self, x):
        """Run the wrapped ResNet on ``x``."""
        return self.model(x)
还有网络结构的具体实现,这部分为官方对resnet18的实现源码。这里不讲解,可以去网上搜一下资料。
训练函数
以训练特征提取网络为例,先创建数据,然后创建模型,在每次epoch中进行一次训练和一次测试。
def _run_feature_phase(model, loader, opt, train, loss_list, acc_list):
    """Run one pass of ``model`` over ``loader``, training or evaluating.

    Appends each batch's loss and accuracy to ``loss_list`` / ``acc_list``
    and prints a running-accuracy progress line.

    Returns:
        (total correct predictions, sum of per-batch losses, number of batches)
    """
    total_correct = 0
    total_loss = 0
    n_batches = 0
    for i, data in enumerate(loader):
        model.set_input(data)
        if train:
            model.trainnet()
        else:
            model.testnet()
        running_corrects = model.get_current_acc(opt)
        running_loss = model.get_current_loss(opt)
        total_correct += running_corrects
        total_loss += running_loss
        n_batches = i + 1
        loss_list.append(running_loss)
        data_batch_size = len(data['cell'])
        acc_list.append(running_corrects / data_batch_size)
        print('[%04d/%04d] ------------------ corrects: %04f-------------------'
              % (i, len(loader) / opt.batchSize,
                 total_correct / (i + 1) / data_batch_size), end='\r')
    return total_correct, total_loss, n_batches


def train_extract_features(num_outputs):
    """Train the feature-extraction network, evaluating on the test split each epoch.

    The model is saved every epoch, and the loss/accuracy histories are
    written to D:/figure/*.npy. Training stops early when both train and
    test accuracy exceed 99.9%.

    Args:
        num_outputs: unused. Options are re-parsed from BaseOptions and
            ``opt.num_outputs`` is what the model uses. Kept for callers.
    """
    opt = BaseOptions().parse()
    dataset_train = CreateDataLoader(opt, 'train', isTrain=True)  # training data
    dataset_test = CreateDataLoader(opt, 'test', isTrain=False)  # test data
    dataset_size_train = len(dataset_train)
    dataset_size_test = len(dataset_test)
    model = create_model(opt)
    Loss_list_train = []
    Loss_list_test = []
    Accuracy_list_train = []
    Accuracy_list_test = []
    for epoch in range(opt.num_epochs):
        print('Training...')
        correct, loss_sum, n_batches = _run_feature_phase(
            model, dataset_train, opt, True, Loss_list_train, Accuracy_list_train)
        # max(..., 1): an empty loader reports 0 instead of crashing.
        epoch_loss_train = loss_sum / max(n_batches, 1)
        epoch_acc_train = correct * 100 / max(dataset_size_train, 1)
        print(' Train epoch {:}:---- lr:{:} ----Acc: {:.4f}% loss:{:.4f}'.format(
            epoch, opt.lr, epoch_acc_train, epoch_loss_train))
        print('Test...')
        correct, loss_sum, n_batches = _run_feature_phase(
            model, dataset_test, opt, False, Loss_list_test, Accuracy_list_test)
        epoch_loss_test = loss_sum / max(n_batches, 1)
        epoch_acc_test = correct * 100 / max(dataset_size_test, 1)
        print(' Test epoch {:}:---- lr:{:} ----Acc: {:.4f}% loss:{:.4f}'.format(
            epoch, opt.lr, epoch_acc_test, epoch_loss_test))
        model.save(epoch)
        # Overwrite the per-batch loss/accuracy histories after every epoch.
        # NOTE(review): the file names say "600" regardless of opt.num_outputs.
        numpy.save('D:/figure/Loss_list_train_600_.npy', Loss_list_train)
        numpy.save('D:/figure/Loss_list_test_600_.npy', Loss_list_test)
        numpy.save('D:/figure/Accuracy_list_train_600_.npy', Accuracy_list_train)
        numpy.save('D:/figure/Accuracy_list_test_600_.npy', Accuracy_list_test)
        # NOTE(review): model.schedulers is never stepped here, so the
        # learning rate stays at opt.lr for the whole run.
        if epoch_acc_train > 99.9 and epoch_acc_test > 99.9:
            break
GEN_CELL模块
GEN_CELL模块是对聚类后的cell进行后处理生成最终cell的模块
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Jul 28 10:19:06 2018
@author: zs
"""
import numpy
import os
from matplotlib import pyplot as plt
class CELL():
    """Post-process cluster labels into the final, ordered localization cells.

    gen_cell() runs these steps:
      1. Renumber cluster labels in order of first appearance.
      2. Merge small cells into the preceding cell.
      3. Give every contiguous run a new cell id.
      4. Write label/cell_<N>.txt and the per-lane .npy files.
    """

    def __init__(self, label, k, outline_range, outline_range_threshold, kkk, opt):
        """
        Args:
            label: cluster label of every image, in image-sequence order
                (used as indices into a length-k array, so assumed to be
                integers in [0, k)).
            k: number of clusters.
            outline_range: length of the sliding window used to count labels.
            outline_range_threshold: a label occurring at most
                outline_range * outline_range_threshold times in a window is
                a small cell and gets merged.
            kkk: width, as a fraction of the window, of the centred
                sub-window in which small cells are merged.
            opt: options; reads coderoot and num_outputs.
        """
        self.root = opt.coderoot
        self.outline_range = int(outline_range)
        self.outline_range_threshold = outline_range_threshold
        self.num = k
        self.kkk = kkk
        self.new_index = numpy.zeros(k)
        self.index = label
        self.dataset_size = len(label)
        self.new_label = numpy.zeros(label.shape)
        # Image paths (column 0) of label/seq<num_outputs>.txt.
        str_path = 'label/seq' + str(opt.num_outputs) + '.txt'
        split_file = os.path.join(self.root, str_path)
        self.path = numpy.loadtxt(split_file, dtype=str, delimiter=' ', skiprows=0, usecols=(0))
        # Lane index (0, 1 or 2) of every image, column 1 of label/3.txt.
        str_path = 'label/3.txt'
        split_file = os.path.join(self.root, str_path)
        self.lines = numpy.loadtxt(split_file, dtype=int, delimiter=' ', skiprows=0, usecols=(1))
        print(self.new_label.shape)

    def _label_file(self, filename):
        """Return <coderoot>/label/<filename>, where the dataset readers look."""
        return os.path.join(self.root, 'label', filename)

    def idx_transformation(self):
        """Renumber cluster labels 0, 1, 2, ... in order of first appearance.

        Writes self.new_label and label/new_label<k>.txt, and initialises
        self.label_removed as a copy of self.new_label.
        """
        # new_index uses 0 to mean "not seen yet", so assigned ids start at 1.
        # Starting at 0 let the first label be renumbered again, or become -1.
        n = 1
        for i in range(self.dataset_size):
            if self.new_index[self.index[i]] == 0:
                self.new_index[self.index[i]] = n
                n += 1
        for i in range(self.dataset_size):
            self.new_label[i] = int(self.new_index[self.index[i]]) - 1
        numpy.savetxt(self._label_file('new_label%d.txt' % (self.num)), self.new_label, fmt='%d')
        self.label_removed = self.new_label.copy()

    def check_same_cell_differentlane(self):
        """Shift labels at lane boundaries (currently unused; see gen_cell).

        For each of the first two lane boundaries b, if new_label[b] equals
        new_label[b + 1], every label after b is incremented by one. Also
        dumps new_label to label/ttt.txt.
        """
        lane_cells_count = numpy.zeros(3, dtype=int)
        for lane in self.lines:
            lane_cells_count[lane] += 1
        # NOTE(review): only entry 1 is made cumulative; entry 2 stays a raw count.
        lane_cells_count[1] = lane_cells_count[1] + lane_cells_count[0]
        print(lane_cells_count)
        with open(self._label_file('ttt.txt'), 'w') as f:
            for value in self.new_label:
                f.write(str(value))
                f.write('\n')
        for i in range(len(lane_cells_count) - 1):
            if self.new_label[lane_cells_count[i]] == self.new_label[lane_cells_count[i] + 1]:
                for j in range(lane_cells_count[i] + 1, len(self.new_label)):
                    self.new_label[j] += 1

    def remove_smallcell(self):
        """Merge small cells into the preceding cell with a sliding window.

        For each window start i, the labels of
        label_removed[i:i + outline_range] are counted. Labels with
        0 < count <= outline_range * outline_range_threshold are visited in
        ascending order. Their occurrences inside the centred sub-window of
        relative width kkk are replaced by the label of the preceding image.
        At most one label is merged per window position.

        Writes label_removed to label/label_removed<k>.txt.
        """
        print('remove small cell...')
        kkk = self.kkk
        window = self.outline_range
        limit = window * self.outline_range_threshold
        # Offsets (within the window) of the centred merge region.
        jj_start = int(window / 2 - window * kkk / 2)
        jj_stop = int(window / 2 + window * kkk / 2)
        for i in range(0, len(self.new_label) - window):
            counts = numpy.bincount(self.label_removed[i:i + window].astype(int))
            small_labels = numpy.nonzero((counts > 0) & (counts <= limit))[0]
            for ii in small_labels:
                flag_remove_once = 0
                for jj in range(jj_start, jj_stop):
                    if self.label_removed[i + jj] == ii:
                        print('remove%d' % (self.new_label[i + jj]))
                        # NOTE(review): if i + jj == 0 (only when kkk >= 1) this
                        # reads index -1, i.e. the last element.
                        self.label_removed[i + jj] = self.label_removed[i + jj - 1]
                        flag_remove_once = 1
                if flag_remove_once:
                    # The original also did `i -= 1` here to re-scan this
                    # window. Rebinding a for-loop variable has no effect in
                    # Python, so the scan always moved on, and that is kept.
                    # A real re-scan would repeat forever when the preceding
                    # label is the same small label.
                    break
        numpy.savetxt(self._label_file('label_removed%d.txt' % (self.num)), self.label_removed, fmt='%d')

    def checkandsort_cell(self):
        """Give every contiguous run of equal labels its own consecutive id.

        Returns an int array with one cell id per image. ``last`` starts at 0,
        so a first label other than 0 opens cell 1. A label that re-appears
        after a different label gets a new id; such repeats are counted in
        self.count.
        """
        print('check and sort cell...')
        cell = numpy.zeros(len(self.label_removed), dtype=int)
        last = 0
        class_plus = 0
        seen = set()
        self.count = 0
        for i in range(len(self.label_removed)):
            n = int(self.label_removed[i])
            if n != last:
                if n in seen:
                    self.count += 1
                class_plus += 1
            cell[i] = class_plus
            last = n
            seen.add(n)
        print('check repeated %d cells' % (self.count))
        return cell

    def cou_cell(self):
        """Print the largest and smallest cell size before and after merging."""
        nn = int(self.new_label[-1] + 1)
        print(nn)
        # One bin per label; minlength keeps labels that no longer occur
        # (count 0) in the statistics.
        cm = numpy.bincount(self.new_label.astype(int), minlength=nn)
        print('before remove max: %d' % (max(cm)))
        print('before remove min: %d' % (min(cm)))
        cm = numpy.bincount(self.label_removed.astype(int), minlength=nn)
        print('after remove max: %d' % (max(cm)))
        print('after remove min: %d' % (min(cm)))

    def gen_cell(self):
        """Run the full post-processing and write the resulting label files.

        Writes label/cell_<N>.txt with lines "<path> <lane> <cell>". Also
        saves label/lane_cells_cla_900.npy and
        label/lane_branch_start_cell_900.npy.

        Returns:
            [number of cells N, per-lane cell counts lane_cells_cla]
        """
        self.idx_transformation()
        # self.check_same_cell_differentlane()
        self.remove_smallcell()
        self.cou_cell()
        cell = self.checkandsort_cell()
        save_dir = self._label_file('cell_%d.txt' % (cell[-1] + 1))
        lane_cells_count = numpy.zeros(3, dtype=int)
        with open(save_dir, 'w') as f:
            for j in range(len(cell)):
                lane_cells_count[self.lines[j]] += 1
                f.write(self.path[j] + ' ' + str(self.lines[j]) + ' ' + str(cell[j]))
                f.write('\n')
        print(lane_cells_count)
        # Image indices where lanes 1 and 2 start, and where lane 2 ends.
        end0 = lane_cells_count[0]
        end1 = end0 + lane_cells_count[1]
        end2 = end1 + lane_cells_count[2]
        lane_branch_start_cell = [0, cell[end0], cell[end1]]  # first cell id of each lane
        print(lane_branch_start_cell)
        lane_cells_cla = [cell[end0 - 1] + 1,  # number of cells in each lane
                          cell[end1 - 1] - cell[end0] + 1,
                          cell[end2 - 1] - cell[end1] + 1]
        numpy.save(self._label_file('lane_cells_cla_900.npy'), lane_cells_cla)
        numpy.save(self._label_file('lane_branch_start_cell_900.npy'), lane_branch_start_cell)
        return [cell[-1] + 1, lane_cells_cla]
论文第四章是分层定位网络相关的内容,主要根据浅层特征对数据进行预分类。这部分应该不是你下一步的重点;如果想了解,相关程序主要在以下两个类中:
- model_set 文件中的 class RESNET18Model_CELL(BaseModel)
- networks 文件中的 class ResNet_redefined_cell(nn.Module)