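I recently wanted to write a 3D CNN in TensorFlow and found that TensorFlow provides the tf.nn.conv3d, tf.nn.avg_pool3d, and tf.nn.max_pool3d functions. The official documentation for these three functions is too brief.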
Conv3d
tf.nn.conv3d(input, filter, strides, padding, name=None)
Computes a 3-D convolution given 5-D input and filter tensors.
In signal processing, cross-correlation is a measure of similarity of two waveforms as a function of a time-lag applied to one of them. This is also known as a sliding dot product or sliding inner-product.
Our Conv3D implements a form of cross-correlation.
Args:
input: A Tensor. Must be one of the following types: float32, float64, int64, int32, uint8, uint16, int16, int8, complex64, complex128, qint8, quint8, qint32, half. Shape [batch, in_depth, in_height, in_width, in_channels].
filter: A Tensor. Must have the same type as input. Shape [filter_depth, filter_height, filter_width, in_channels, out_channels]. in_channels must match between input and filter.
strides: A list of ints that has length >= 5. 1-D tensor of length 5. The stride of the sliding window for each dimension of input. Must have strides[0] = strides[4] = 1.
padding: A string from: "SAME", "VALID". The type of padding algorithm to use.
name: A name for the operation (optional).
Returns:
A Tensor. Has the same type as input.
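The docs do not spell out the output shape, so here is a minimal sketch in plain Python (no TensorFlow required) that derives it from the argument shapes. The helper name conv3d_output_shape is my own and is not part of TensorFlow; the formulas are the standard ones for "SAME" and "VALID" padding.

```python
import math

def conv3d_output_shape(input_shape, filter_shape, strides, padding):
    """Hypothetical helper: output shape of a 3-D convolution.

    input_shape:  [batch, in_depth, in_height, in_width, in_channels]
    filter_shape: [filter_depth, filter_height, filter_width,
                   in_channels, out_channels]
    strides:      5 ints with strides[0] == strides[4] == 1
    padding:      "SAME" or "VALID"
    """
    if len(input_shape) != 5 or len(filter_shape) != 5 or len(strides) != 5:
        raise ValueError("input, filter and strides must all have 5 entries")
    if input_shape[4] != filter_shape[3]:
        raise ValueError("in_channels must match between input and filter")
    if strides[0] != 1 or strides[4] != 1:
        raise ValueError("strides[0] and strides[4] must be 1")

    spatial = []
    for size, k, s in zip(input_shape[1:4], filter_shape[0:3], strides[1:4]):
        if padding == "SAME":
            # The input is padded so that every stride position is used.
            spatial.append(math.ceil(size / s))
        elif padding == "VALID":
            # Only windows that fit entirely inside the input are used.
            spatial.append(math.ceil((size - k + 1) / s))
        else:
            raise ValueError("padding must be 'SAME' or 'VALID'")
    # Batch is unchanged; the channel count becomes out_channels.
    return [input_shape[0]] + spatial + [filter_shape[4]]

same = conv3d_output_shape([1, 16, 32, 32, 3], [3, 3, 3, 3, 8],
                           [1, 1, 2, 2, 1], "SAME")
valid = conv3d_output_shape([1, 16, 32, 32, 3], [3, 3, 3, 3, 8],
                            [1, 1, 2, 2, 1], "VALID")
print(same)   # [1, 16, 16, 16, 8]
print(valid)  # [1, 14, 15, 15, 8]
```

The input and filter sizes here are made up for illustration. The same arithmetic applies to the pooling functions if ksize[1:4] is used in place of the filter's spatial dimensions and the channel count is left unchanged.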
Pooling
tf.nn.avg_pool3d(input, ksize, strides, padding, name=None)
Performs 3D average pooling on the input.
Args:
input: A Tensor. Must be one of the following types: float32, float64, int64, int32, uint8, uint16, int16, int8, complex64, complex128, qint8, quint8, qint32, half. Shape [batch, depth, rows, cols, channels] tensor to pool over.
ksize: A list of ints that has length >= 5. 1-D tensor of length 5. The size of the window for each dimension of the input tensor. Must have ksize[0] = ksize[4] = 1.
strides: A list of ints that has length >= 5. 1-D tensor of length 5. The stride of the sliding window for each dimension of input. Must have strides[0] = strides[4] = 1.
padding: A string from: "SAME", "VALID". The type of padding algorithm to use.
name: A name for the operation (optional).
Returns:
A Tensor. Has the same type as input. The average pooled output tensor.
tf.nn.max_pool3d(input, ksize, strides, padding, name=None)
Performs 3D max pooling on the input.
Args:
input: A Tensor. Must be one of the following types: float32, float64, int64, int32, uint8, uint16, int16, int8, complex64, complex128, qint8, quint8, qint32, half. Shape [batch, depth, rows, cols, channels] tensor to pool over.
ksize: A list of ints that has length >= 5. 1-D tensor of length 5. The size of the window for each dimension of the input tensor. Must have ksize[0] = ksize[4] = 1.
strides: A list of ints that has length >= 5. 1-D tensor of length 5. The stride of the sliding window for each dimension of input. Must have strides[0] = strides[4] = 1.
padding: A string from: "SAME", "VALID". The type of padding algorithm to use.
name: A name for the operation (optional).
Returns:
A Tensor. Has the same type as input. The max pooled output tensor.
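To see what the two pooling functions compute, here is a rough pure-Python sketch of VALID-padded 3D pooling over a single [depth][rows][cols] volume (one batch entry, one channel). The function pool3d_valid and the sample volume are my own illustration, not TensorFlow code; TensorFlow applies the same idea independently to every batch entry and channel.

```python
def pool3d_valid(volume, ksize, strides, reduce_fn):
    """Pool a [depth][rows][cols] nested list with VALID padding.

    ksize and strides are (depth, rows, cols) triples, i.e. the middle
    three entries of the 5-element lists the TensorFlow functions expect.
    """
    dims = (len(volume), len(volume[0]), len(volume[0][0]))
    # Number of window positions along each axis under VALID padding.
    out = [(n - k) // s + 1 for n, k, s in zip(dims, ksize, strides)]
    result = []
    for d in range(out[0]):
        plane = []
        for r in range(out[1]):
            row = []
            for c in range(out[2]):
                # Top-left-front corner of the current window.
                d0, r0, c0 = d * strides[0], r * strides[1], c * strides[2]
                window = [
                    volume[d0 + dd][r0 + dr][c0 + dc]
                    for dd in range(ksize[0])
                    for dr in range(ksize[1])
                    for dc in range(ksize[2])
                ]
                row.append(reduce_fn(window))
            plane.append(row)
        result.append(plane)
    return result

def mean(values):
    return sum(values) / len(values)

# A 2 x 2 x 4 volume (depth x rows x cols).
volume = [
    [[1.0, 2.0, 3.0, 4.0],
     [5.0, 6.0, 7.0, 8.0]],
    [[8.0, 7.0, 6.0, 5.0],
     [4.0, 3.0, 2.0, 1.0]],
]

max_pooled = pool3d_valid(volume, (2, 2, 2), (1, 1, 2), max)
avg_pooled = pool3d_valid(volume, (2, 2, 2), (1, 1, 2), mean)
print(max_pooled)  # [[[8.0, 8.0]]]
print(avg_pooled)  # [[[4.5, 4.5]]]
```

The only difference between max and average pooling is the reduction applied to each window, which is why the sketch takes it as a parameter.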
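It is actually fairly simple. The key is understanding how the tensor is indexed; also, the input must have rank 5.
Here is my test code. The main goal is to get a feel for tensor indexing. Pooling and conv3d work in much the same way, so I only tested tf.nn.max_pool3d.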
import tensorflow as tf
a=tf.constant([
[
[[1.0,2.0,3.0,4.0],
[5.0,6.0,7.0,8.0],
[8.0,7.0,6.0,5.0],
[4.0,3.0,2.0,1.0]],
[[4.0,3.0,2.0,1.0],
[8.0,7.0,6.0,5.0],
[1.0,2.0,3.0,4.0],
[5.0,6.0,7.0,8.0]],
[[10.,9.,8.,7.],
[6.,7.,8.,9.],
[5.,6.,7.,8.],
[4.,5.,6.,7.]]
],
[
[[1.0,2.0,3.0,4.0],
[5.0,6.0,7.0,8.0],
[8.0,7.0,6.0,5.0],
[4.0,3.0,2.0,1.0]],
[[4.0,3.0,2.0,1.0],
[8.0,7.0,6.0,5.0],
[1.0,2.0,3.0,4.0],
[5.0,6.0,7.0,8.0]],
[[10.,9.,8.,7.],
[6.,7.,8.,9.],
[5.,6.,7.,8.],
[4.,5.,6.,7.]]
],
[
[[1.0,2.0,3.0,4.0],
[5.0,6.0,7.0,8.0],
[8.0,7.0,6.0,5.0],
[4.0,3.0,2.0,1.0]],
[[4.0,3.0,2.0,1.0],
[8.0,7.0,6.0,5.0],
[1.0,2.0,3.0,4.0],
[5.0,6.0,7.0,8.0]],
[[10.,9.,8.,7.],
[6.,7.,8.,9.],
[5.,6.,7.,8.],
[4.,5.,6.,7.]]
]
]
)
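# Reinterpret the 3*3*4*4 = 144 values as [batch, depth, rows, cols, channels].
a=tf.reshape(a,[1,3,4,6,2])
# Window: 2 (depth) x 2 (rows) x 1 (cols); strides: 1 (depth), 2 (rows), 2 (cols).
pooling_3d=tf.nn.max_pool3d(a,[1,2,2,1,1],[1,1,2,2,1],padding='VALID')
# tf.Session is the TensorFlow 1.x API. Under TensorFlow 2.x, call
# tf.compat.v1.disable_eager_execution() at the top of the script and use
# tf.compat.v1.Session() instead.
with tf.Session() as sess:
    image=sess.run(a)
    print("image:",image)
    # Print the [batch, depth, rows] slice for each column i and channel j.
    for j in range(0,2):
        for i in range(0,6):
            print("image",i+1,j+1,":",image[:,:,:,i,j])
    result3=sess.run(pooling_3d)
    # The pooled output has shape [1, 2, 2, 3, 2]: 3 columns and 2 channels.
    for j in range(0,2):
        for i in range(0,3):
            print('result3',i+1,j+1,result3[:,:,:,i,j])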
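Since the whole point of the test is the indexing, here is a small plain-Python sketch of how the reshape maps indices. It assumes row-major (C-order) layout, which is what tf.reshape uses; the helper names flat_index and unravel are my own.

```python
def flat_index(index, shape):
    """Row-major (C-order) offset of a multi-dimensional index."""
    offset = 0
    for i, n in zip(index, shape):
        if not 0 <= i < n:
            raise IndexError("index out of range")
        offset = offset * n + i
    return offset

def unravel(offset, shape):
    """Inverse of flat_index: offset -> multi-dimensional index."""
    index = []
    for n in reversed(shape):
        index.append(offset % n)
        offset //= n
    return tuple(reversed(index))

old_shape = (3, 3, 4, 4)     # shape of the literal passed to tf.constant
new_shape = (1, 3, 4, 6, 2)  # shape after tf.reshape

# Where does element [0, 1, 2, 3, 1] of the reshaped tensor come from?
offset = flat_index((0, 1, 2, 3, 1), new_shape)
source = unravel(offset, old_shape)
print(offset)  # 79
print(source)  # (1, 1, 3, 3), which holds 8.0 in the literal above
```

Note that the channel index varies fastest after the reshape, so adjacent numbers in the literal end up in different channels. The 4 x 4 matrices in the literal therefore do not line up with the rows and columns of the reshaped tensor, which is worth keeping in mind when reading the printed slices.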
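One takeaway: whether it is a convolution kernel or a pooling operation, just think of it as a sliding window. The ksize parameter fixes the size of the window, and each value in strides is the distance the window moves per step along the corresponding dimension. In 2D the window only moves top to bottom and left to right; 3D adds a third direction, front to back.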
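As a rough illustration of the sliding window, the sketch below lists where the window starts along each axis for the test above (input [depth=3, rows=4, cols=6], window 2 x 2 x 1, strides 1, 2, 2, VALID padding). The helper window_starts is my own name and is plain Python, not a TensorFlow function.

```python
def window_starts(size, k, stride):
    """Start offsets of a VALID-padded sliding window along one axis."""
    return list(range(0, size - k + 1, stride))

# (input size, window size, stride) for each spatial axis of the test.
axes = {"depth": (3, 2, 1), "rows": (4, 2, 2), "cols": (6, 1, 2)}
starts = {name: window_starts(*cfg) for name, cfg in axes.items()}
for name, positions in starts.items():
    print(name, positions)
# depth [0, 1]
# rows [0, 2]
# cols [0, 2, 4]
```

The number of start positions per axis (2, 2, 3) is the spatial shape of the pooled output, which is why the last loop in the test only runs over 3 columns.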
I had planned to explain this with a figure, but it was too much trouble, so I skipped it.