将所有图片生成一个二进制数据集文件的过程
示例代码
#将图片和标签制作成二进制文件后,训练时直接从该二进制文件读取数据,可以提高内存利用率。
#训练数据的特征用键值对的形式表示
def write_tfRecord(tfRecordName, image_path, label_path):
    """Serialize labelled images into a single TFRecord file.

    Each non-blank line of the label file is split on whitespace: the first
    field is an image file name and the second is an integer class index.
    Every image becomes one ``tf.train.Example`` holding two features:
    ``img_raw`` (the bytes returned by ``img.tobytes()``) and ``label``
    (a 10-element one-hot list of ints).

    Args:
        tfRecordName: Path of the TFRecord file to create.
        image_path: Prefix prepended to each image file name by plain string
            concatenation, so it must already end with a path separator.
        label_path: Path of the text file listing ``<file name> <class>``
            pairs, one per line.

    Raises:
        IndexError: If a non-blank line has fewer than two fields, or the
            class index is 10 or greater.
        ValueError: If the class field is not an integer.
    """
    # Read the whole label file first so its handle is released before the
    # (much longer) image-processing loop starts.
    with open(label_path, 'r') as f:
        contents = f.readlines()

    num_pic = 0
    # Create the writer.
    writer = tf.python_io.TFRecordWriter(tfRecordName)
    try:
        # Walk over every image / label pair.
        for content in contents:
            value = content.split()
            if not value:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            img_path = image_path + value[0]
            img = Image.open(img_path)
            img_raw = img.tobytes()

            # One-hot encode the class index over 10 classes.
            labels = [0] * 10
            labels[int(value[1])] = 1

            # Features are expressed as key/value pairs inside an Example.
            example = tf.train.Example(features=tf.train.Features(feature={
                'img_raw': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[img_raw])),
                'label': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=labels)),
            }))
            # Serialize the Example and append it to the record file.
            writer.write(example.SerializeToString())
            num_pic += 1
            # Progress output: running count of images written so far.
            print("the number of picture:", num_pic)
    finally:
        # Always flush and close the writer, even if an image fails to load.
        writer.close()
def generate_tfRecord():
    """Make sure the output directory exists, then write both TFRecord files.

    Prints whether ``data_path`` had to be created, then serializes the
    training set followed by the test set via ``write_tfRecord``.
    """
    if os.path.exists(data_path):
        print("Already Exists")
    else:
        os.makedirs(data_path)
        print("Created")

    # Training set first, then the test set.
    write_tfRecord(tfRecord_train, image_train_path, label_train_path)
    write_tfRecord(tfRecord_test, image_test_path, label_test_path)
#解析文件
def read_tfRecord(tfRecord_path):
    """Build the graph ops that read and decode one example from a TFRecord.

    Args:
        tfRecord_path: Path of the TFRecord file to read.

    Returns:
        A tuple ``(img, label)``:
        ``img`` is a float32 tensor of shape ``[784]`` with the raw uint8
        bytes scaled by 1/255; ``label`` is a float32 tensor of shape
        ``[10]``.
    """
    # Create a first-in-first-out queue of input file names.
    filename_queue = tf.train.string_input_producer([tfRecord_path])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    # The keys and the label length must match what was written into the file.
    features = tf.parse_single_example(serialized_example, features={
        'label': tf.FixedLenFeature([10], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string),
    })

    img = tf.decode_raw(features['img_raw'], tf.uint8)
    # NOTE(review): 784 presumably corresponds to 28x28 single-channel
    # images - confirm against the dataset.
    img.set_shape([784])
    img = tf.cast(img, tf.float32) * (1. / 255)
    label = tf.cast(features['label'], tf.float32)
    return img, label
def get_tfrecord(num, isTrain=True):
    """Return shuffled batches of images and labels from a TFRecord file.

    Args:
        num: Batch size.
        isTrain: When true, read from ``tfRecord_train``; otherwise read from
            ``tfRecord_test``.

    Returns:
        A tuple ``(img_batch, label_batch)`` of tensors produced by
        ``tf.train.shuffle_batch``.
    """
    # Use a distinct local name: assigning to ``tfRecord_path`` here made the
    # name function-local, so the old self-assignment in the training branch
    # raised UnboundLocalError instead of selecting the training file.
    record_path = tfRecord_train if isTrain else tfRecord_test
    img, label = read_tfRecord(record_path)
    img_batch, label_batch = tf.train.shuffle_batch(
        [img, label],
        batch_size=num,
        num_threads=2,
        capacity=1000,
        min_after_dequeue=700,
    )
    return img_batch, label_batch
def main():
    """Script entry point: generate the TFRecord dataset files."""
    generate_tfRecord()


# Only generate the dataset when run as a script, not on import.
if __name__ == "__main__":
    main()