在TensorFlow框架下实现DBN网络

xiaoxiao2021-02-28  104

前言 在上一篇博客‘Windows下安装Tensorflow’的基础上,实现深度学习网络——DBN网络,以经典的手写字体识别为例子。

一、下载手写字体数据集,官方网址为:http://yann.lecun.com/exdb/mnist/ 下载: train-images-idx3-ubyte.gz: training set images (9912422 bytes) train-labels-idx1-ubyte.gz: training set labels (28881 bytes) t10k-images-idx3-ubyte.gz: test set images (1648877 bytes) t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)

二、参考github上的DBN实现源码为:https://github.com/myme5261314/dbn_tf

原作者实现的过程:https://gist.github.com/myme5261314/005ceac0483fc5a581cc

注:博主选择的python编辑器为pycharm编辑器,在windows下安装tensorflow后,需要在pycharm中配置路径,具体参考:http://blog.csdn.net/wx7788250/article/details/60877166

三,将github上下载的代码导入到pycharm中,会有几处错误,下面一一修改: 问题1:官方PIL(python image library)目前只支持python2.7及以下版本,所以当导入代码后,会提示ImportError: No module named 'Image'也就是没有import到image这个库。 解决办法:下载非官方库从http://www.lfd.uci.edu/~gohlke/pythonlibs/ 该网址上提供了支持64位系统的PIL文件,网站上叫做Pillow,下载下来,是个 .whl 结尾的文件,这个其实就是python使用的一种压缩文件,后缀名改成zip,可以打开。 这个需要用 pip 安装,在cmd下,进入该.whl下载的文件夹,输入

pip install Pillow‑4.1.1‑cp35‑cp35m‑win_amd64.whl

即可安装。 注意,这里有一段 Pillow is a replacement for PIL, the Python Image Library, which provides image processing functionality and supports many file formats. Use from PIL import Image instead of import Image. 意思就是说,要用 ‘ from PIL import Image’ 代替 ‘import Image’ 也就是将导入的程序代码中import Image 改为 from PIL import Image 即可。

from PIL import Image

问题2:代码中会有几处print的错误,由于版本问题在python3.X中,print的输出要加‘()’,所以修改后的样子为:

print (sess.run( err_sum, feed_dict={X: trX, rbm_w: n_w, rbm_vb: n_vb, rbm_hb: n_hb}))

问题3: 修改上述问题后,看似没有错误了,运行代码rbm_MNIST_test.py发现依旧存在问题,问题描述如下:

C:\Users\Administrator\AppData\Local\Programs\Python\Python35\python.exe D:/hlDL/tensorflow-DBN/dbn_tf-master/rbm_MNIST_test.py Traceback (most recent call last): File "D:/hlDL/tensorflow-DBN/dbn_tf-master/rbm_MNIST_test.py", line 16, in <module> mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) Extracting MNIST_data/train-images-idx3-ubyte.gz File "D:\hlDL\tensorflow-DBN\dbn_tf-master\input_data.py", line 150, in read_data_sets train_images = extract_images(local_file) File "D:\hlDL\tensorflow-DBN\dbn_tf-master\input_data.py", line 39, in extract_images buf = bytestream.read(rows * cols * num_images) File "C:\Users\Administrator\AppData\Local\Programs\Python\Python35\lib\gzip.py", line 274, in read return self._buffer.read(size) TypeError: only integer scalar arrays can be converted to a scalar index

修改方式:在input_data.py中找到代码

def _read32(bytestream): dt = numpy.dtype(numpy.uint32).newbyteorder('>') return numpy.frombuffer(bytestream.read(4), dtype=dt)

修改为:

def _read32(bytestream): dt = numpy.dtype(numpy.uint32).newbyteorder('>') return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]

在末尾出加上[0],再次运行,问题得到解决。这个问题似乎是新版本的Numpy存在的,最近的更新中,将单个元素数组作为一个标量来进行索引。

四、运行结果 注:根据具体代码,输出结果每10000次输出一次精度,共迭代60000次,同时每10000次输出一个识别结果的image,命名为:rbm_x.png。这里贴出第一幅识别结果和最后一次识别结果图如下: rbm_0 rbm_5

转载请注明原文地址: https://www.6miu.com/read-55224.html

最新回复(0)