将keras/TensorFlow的变量格式从NCHW转换为NHWC

当keras在CPU上进行图片数据的训练时,只支持NHWC格式的图片输入数据,不支持NCHW的格式。输入NCHW格式的数据会报错:

1
tensorflow.python.framework.errors_impl.InvalidArgumentError:  Conv2DCustomBackpropFilterOp only supports NHWC.

什么是NHWC和NCHW,很简单,NCHW 的意义是:

  • N:一个batch内图片的数量。
  • C:Chanel,图片的通道数。灰度图是1,RGB图是3。
  • H:Height,垂直高度方向的像素个数。
  • W: Weight,水平宽度方向的像素个数。

两种格式的区别就是Chanel维度在张量的末尾还是第二个。

像shape是[ 600, 1, 28, 28 ] 图片数据是NCHW格式的,表示600张图片,通道数是1,高和宽都是28像素。

从库中加载的 mnist 数据集图片格式是 NCHW 格式的,但是keras用CPU计算时不支持,keras添加一个卷积层的代码是:

1
2
3
4
5
6
7
8
9
# add Conv 1 , shape(32,28,28)
model.add(Convolution2D(
batch_input_shape=(None, 28, 28, 1),
filters=32,
kernel_size=5,
strides=1,
padding='same',
data_format='channels_first'
)

最后一个参数意思是输入数据chanel在前,即NCHW格式,这样运行代码时会报错。需要将数据数据改成NHWC格式,date_format参数改成 ‘channels_last’ 。

改变数组(或tensor)的形状(维度位置)可以用tf.transpose() 函数,

如 NCHW 格式的 X 的 shape 是 [ 6000, 1, 28, 28 ] 改成 [ 6000, 28, 28, 1 ] :

X = tf.transpose(X, [0, 2, 3, 1])

[0, 2, 3, 1] 中的数字表示维度的编号,原本 X 维度编号为 [0, 1, 2, 3],将第二维的Chanel移到末尾就是[0, 2, 3, 1] 。

这样修改后 X 就变成 NHWC 格式的数据了。


参考链接:

http://www.cppcns.com/jiaoben/python/324161.html

https://yinguobing.com/convert-nchw-to-nhwc-in-tensorflow/