当keras在CPU上进行图片数据的训练时,只支持NHWC格式的图片输入数据,不支持NCHW的格式。输入NCHW格式的数据会报错:
1 | tensorflow.python.framework.errors_impl.InvalidArgumentError: Conv2DCustomBackpropFilterOp only supports NHWC. |
什么是NHWC和NCHW,很简单,NCHW 的意义是:
- N:一个batch内图片的数量。
- C:Chanel,图片的通道数。灰度图是1,RGB图是3。
- H:Height,垂直高度方向的像素个数。
- W: Weight,水平宽度方向的像素个数。
两种格式的区别就是Chanel维度在张量的末尾还是第二个。
像shape是[ 600, 1, 28, 28 ] 图片数据是NCHW格式的,表示600张图片,通道数是1,高和宽都是28像素。
从库中加载的 mnist 数据集图片格式是 NCHW 格式的,但是keras用CPU计算时不支持,keras添加一个卷积层的代码是:
1 | # add Conv 1 , shape(32,28,28) |
最后一个参数意思是输入数据chanel在前,即NCHW格式,这样运行代码时会报错。需要将数据数据改成NHWC格式,date_format参数改成 ‘channels_last’ 。
改变数组(或tensor)的形状(维度位置)可以用tf.transpose()
函数,
如 NCHW 格式的 X 的 shape 是 [ 6000, 1, 28, 28 ] 改成 [ 6000, 28, 28, 1 ] :
X = tf.transpose(X, [0, 2, 3, 1])
[0, 2, 3, 1] 中的数字表示维度的编号,原本 X 维度编号为 [0, 1, 2, 3],将第二维的Chanel移到末尾就是[0, 2, 3, 1] 。
这样修改后 X 就变成 NHWC 格式的数据了。
参考链接: