机器人与人工智能爱好者论坛

 找回密码
 立即注册
查看: 20117|回复: 1
打印 上一主题 下一主题

y君用tensorflow跑MNIST的一个代码示例(带中文注释)

[复制链接]

285

主题

451

帖子

1万

积分

超级版主

Rank: 8Rank: 8

积分
13725
跳转到指定楼层
楼主
发表于 2015-12-9 19:50:27 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
y君用tensorflow跑MNIST的一个代码示例(带中文注释)

作者:y君   整理:morinson  来源:FGM学习组

y君同学在FGM学习研究组中比较早的搭建起环境,跑了示例。并在2015.12.4日的小组线上分享会议中进行了精彩演示和讲解。并分享了其做了详细中文注释的示例代码。获得大家一致好评。

最近有事情耽搁,一直没有将其给的带中文注释的代码发布出来,希望大家谅解。现在将代码及注释,发布出来。帖子中只是一部分,完整的请下载附件,阅读里面的readme.txt使用。


readme.txt
  1. 把yjun这个文件夹放到以下目录
  2. /root/tensorflow-master/tensorflow/g3doc/tutorials/mnist
  3. get_data.py 对应着原来的 input_data.py
  4. mnist_test.py 是我自己写的测试
  5. 还有部分没有完全理解 这周暂时就这么多东西了sss
复制代码

get_data.py
  1. # -*- coding: utf-8 -*-
  2. """Functions for downloading and reading MNIST data."""
  3. from __future__ import absolute_import
  4. from __future__ import division
  5. from __future__ import print_function
  6. import gzip
  7. import os
  8. import numpy
  9. from six.moves import urllib
  10. from six.moves import xrange  # pylint: disable=redefined-builtin

  11. SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'

  12. #DataSet对象
  13. '''
  14. self     类似 C++ this
  15. '''
  16. class DataSet(object):
  17.   def __init__(self, images, labels, fake_data=False):
  18.     if fake_data:
  19.       self._num_examples = 10000
  20.     else:
  21.       assert images.shape[0] == labels.shape[0], (
  22.           "images.shape: %s labels.shape: %s" % (images.shape,
  23.                                                  labels.shape))
  24.       #保存数据长度到num_examples
  25.       self._num_examples = images.shape[0]
  26.       
  27.       print ("self._num_examples : ", (self._num_examples))
  28.       
  29.       # Convert shape from [num examples, rows, columns, depth]
  30.       # to [num examples, rows*columns] (assuming depth == 1)
  31.       
  32.       print ("num ",images.shape[0])
  33.       print ("rows",images.shape[1])
  34.       print ("cols",images.shape[2])
  35.       
  36.       assert images.shape[3] == 1
  37.       images = images.reshape(images.shape[0],
  38.                               images.shape[1] * images.shape[2])
  39.       
  40.       # Convert from [0, 255] -> [0.0, 1.0].
  41.       images = images.astype(numpy.float32)
  42.       images = numpy.multiply(images, 1.0 / 255.0)
  43.       
  44.     print ("-------")
  45.     self._images = images
  46.     self._labels = labels
  47.     self._epochs_completed = 0
  48.     self._index_in_epoch = 0
  49.     print ("__init__ end")

  50.   
  51.   @property
  52.   def images(self):
  53.     return self._images

  54.   @property
  55.   def labels(self):
  56.     return self._labels

  57.   @property
  58.   def num_examples(self):
  59.     return self._num_examples

  60.   @property
  61.   def epochs_completed(self):
  62.     return self._epochs_completed
  63.   #作用为切换样本
  64.   def next_batch(self, batch_size, fake_data=False):
  65.     """Return the next `batch_size` examples from this data set."""
  66.     if fake_data:
  67.       fake_image = [1.0 for _ in xrange(784)]
  68.       fake_label = 0
  69.       return [fake_image for _ in xrange(batch_size)], [
  70.           fake_label for _ in xrange(batch_size)]
  71.    
  72.     start = self._index_in_epoch
  73.     self._index_in_epoch += batch_size
  74.     #判断index是否超过最大样本数量 超过的话就是训练完了
  75.     if self._index_in_epoch > self._num_examples:
  76.       # Finished epoch
  77.       self._epochs_completed += 1
  78.       # Shuffle the data
  79.       perm = numpy.arange(self._num_examples)
  80.       numpy.random.shuffle(perm)
  81.       self._images = self._images[perm]
  82.       self._labels = self._labels[perm]
  83.       # Start next epoch
  84.       start = 0
  85.       self._index_in_epoch = batch_size
  86.       assert batch_size <= self._num_examples
  87.     #计算切换样本后的index
  88.     end = self._index_in_epoch
  89.     #根据计算后的index  在iamges 和label 中读取数据
  90.     return self._images[start:end], self._labels[start:end]

  91. #读取一个int32类型的数据
  92. def _read32(bytestream):
  93.   dt = numpy.dtype(numpy.uint32).newbyteorder('>')
  94.   result = numpy.frombuffer(bytestream.read(4), dtype=dt)
  95.   return result

  96. def dense_to_one_hot(labels_dense, num_classes=10):
  97.   """Convert class labels from scalars to one-hot vectors."""
  98.   #获取标签数量
  99.   num_labels = labels_dense.shape[0]
  100.   index_offset = numpy.arange(num_labels) * num_classes
  101.   #清零
  102.   labels_one_hot = numpy.zeROS((num_labels, num_classes))
  103.   #末尾填1 ?
  104.   labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
  105.   return labels_one_hot

  106. #函数功能是下载文件 如果有文件就不下载了
  107. def maybe_download(filename, work_directory):
  108.   
  109.   if not os.path.exists(work_directory):
  110.     #如果不存在目录 就创建一个
  111.     os.mkdir(work_directory)
  112.   #把文件目录和文件名合成
  113.   filepath = os.path.join(work_directory, filename)
  114.   print('filename : ', filename)
  115.   if not os.path.exists(filepath):
  116.     #如果不存在文件就开始下载
  117.     print('SOURCE_URL + filename : ', SOURCE_URL + filename)
  118.     filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
  119.     #获取文件状态
  120.     statinfo = os.stat(filepath)
  121.     #输出下载成功
  122.     print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  123.   return filepath

  124. def getfiledata(filepath):
  125.   print('mapfile : ', filepath)
  126.   #打开目标文件
  127.   with gzip.open(filepath) as bytestream:
  128.     magic = _read32(bytestream)
  129.     if magic != 2051:
  130.       raise ValueError(
  131.           'Invalid magic number %d in MNIST image file: %s' %
  132.           (magic, filepath))
  133.     num_images = _read32(bytestream)
  134.     rows = _read32(bytestream)
  135.     cols = _read32(bytestream)
  136.     buf = bytestream.read(rows * cols * num_images)
  137.     data = numpy.frombuffer(buf, dtype=numpy.uint8)
  138.     #扩展为 data[num_images][rows][cols][1]
  139.     data = data.reshape(num_images, rows, cols,1)
  140.    
  141.     return data
  142.   
  143. def getlabeldata(filepath, one_hot=False):
  144.   """Extract the labels into a 1D uint8 numpy array [index]."""
  145.   print('mapfile : ', filepath)
  146.   with gzip.open(filepath) as bytestream:
  147.     magic = _read32(bytestream)
  148.     if magic != 2049:
  149.       raise ValueError(
  150.           'Invalid magic number %d in MNIST label file: %s' %
  151.           (magic, filepath))
  152.     num_items = _read32(bytestream)
  153.     buf = bytestream.read(num_items)
  154.     labels = numpy.frombuffer(buf, dtype=numpy.uint8)
  155.     if one_hot:
  156.       return dense_to_one_hot(labels)
  157.     return labels
  158.   
  159. def downloadimagefile(train_dir,fake_data=False, one_hot=False):
  160.   class DataSets(object):
  161.     pass
  162.   data_sets = DataSets()
  163.   #55000训练图像,5000验证图片训练图像,5000验证图片
  164.   TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
  165.   #训练集标签
  166.   TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
  167.   #测试集图像 - 10000图片
  168.   TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
  169.   #测试集标签
  170.   TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
  171.   VALIDATION_SIZE = 5000
  172.   print ("load file")
  173.   #下载训练集图像  local_file 存放的是下载下来的文件路径
  174.   file_path = maybe_download(TRAIN_IMAGES, train_dir)
  175.   train_images = getfiledata(file_path)
  176.   #训练集标签
  177.   file_path = maybe_download(TRAIN_LABELS, train_dir)
  178.   train_labels = getlabeldata(file_path, one_hot=one_hot)
  179.   #测试集图像
  180.   file_path = maybe_download(TEST_IMAGES, train_dir)
  181.   test_images = getfiledata(file_path)
  182.   #测试集标签
  183.   file_path = maybe_download(TEST_LABELS, train_dir)
  184.   test_labels = getlabeldata(file_path, one_hot=one_hot)
  185.   
  186.   validation_images = train_images[:VALIDATION_SIZE]
  187.   validation_labels = train_labels[:VALIDATION_SIZE]
  188.   
  189.   train_images = train_images[VALIDATION_SIZE:]
  190.   train_labels = train_labels[VALIDATION_SIZE:]
  191.   
  192.   #作为主要训练集。
  193.   data_sets.train = DataSet(train_images, train_labels)
  194.   #用于迭代验证训练准确度。
  195.   data_sets.validation = DataSet(validation_images, validation_labels)
  196.   #用于最终测试训练准确度
  197.   data_sets.test = DataSet(test_images, test_labels)
  198.   
  199.   return data_sets
  200.   
复制代码

mnist_test.py
  1. # -*- coding: utf-8 -*-
  2. import tensorflow as tf
  3. import get_data
  4. import numpy

  5. #下载数据
  6. import math
  7. import numpy as np
  8. from array import array

  9. '''
  10. x = np.arange(1, 1001).reshape(20, 50)
  11. print x
  12. print x.flat[3]
  13. a = numpy.array([[1,2,3],[4,5,6],[7,8,9]])
  14. print a.shape
  15. '''

  16. mnist = get_data.downloadimagefile("MNIST_data/",one_hot=True)


  17. #创建一个session 方便在shell中用
  18. sess = tf.InteractiveSession()

  19. #开始构建回归模型
  20. #先创建占位符 占位符并不能直接输出 因为这个格式只有tf自己能够识别 这是一个二维数组 一维任意长度
  21. #一维为batch大小 二维是图片所有像素点 28*28=784
  22. x = tf.placeholder("float", [None, 784])
  23. y_ = tf.placeholder("float", shape=[None, 10])

  24. #定义权重W和偏置b 全部初始化为0
  25. #w是一个28*28*10的矩阵 因为有28*28个像素点 和10个输出值
  26. W = tf.Variable(tf.zeros([784,10]))

  27. #因为有10个分类 所以b 长度为10
  28. b = tf.Variable(tf.zeros([10]))

  29. #Variable 定义的变量需要在session之前调用initialize_all_variables完成初始化 才能载session中使用
  30. sess.run(tf.initialize_all_variables())

  31. #计算每个分类的softmax概率值
  32. y = tf.nn.softmax(tf.matmul(x,W) + b)

  33. #计算交叉熵
  34. #tf.reduce_sum把minibatch里的每张图片的交叉熵值都加起来了。我们计算的交叉熵是指整个minibatch的。
  35. cross_entropy = -tf.reduce_sum(y_*tf.log(y))

  36. #我们已经定义好了模型和训练的时候用的损失函数,接下来使用TensorFlow来训练。
  37. #因为TensorFlow知道整个计算图,它会用自动微分法来找到损失函数对于各个变量的梯度。
  38. #TensorFlow有大量内置的优化算法. 这个例子中,我们用最速下降法让交叉熵下降,步长为0.01.
  39. train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
  40. #这一行代码实际上是用来往计算图上添加一个新操作,其中包括计算梯度,计算每个参数的步长变化,并且计算出新的参数值。

  41. #train_step这个操作,用梯度下降来更新权值。因此,整个模型的训练可以通过反复地运行train_step来完成。
  42. #每次迭代加载50个样本
  43. rows = 28
  44. cols = 28
  45. num_images = 1

  46. for i in range(1000):
  47.   #next_batch会自动把数据转换成tf能认的
  48.   batch = mnist.train.next_batch(50)
  49.   train_step.run(feed_dict={x: batch[0], y_: batch[1]})

  50. correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
  51. accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

  52. print accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels})






  53. #下面构建一个多层卷积网络 上面的代码只有%91准确率
  54. #在创建模型之前,我们先来创建权重和偏置。一般来说,初始化时应加入轻微噪声,来打破对称性,防止零梯度的问题。
  55. #因为我们用的是ReLU,所以用稍大于0的值来初始化偏置能够避免节点输出恒为0的问题(dead neurons)。
  56. #为了不在建立模型的时候反复做初始化操作,我们定义两个函数用于初始化。
  57. def weight_variable(shape):
  58.   initial = tf.truncated_normal(shape, stddev=0.1)
  59.   return tf.Variable(initial)

  60. def bias_variable(shape):
  61.   initial = tf.constant(0.1, shape=shape)
  62.   return tf.Variable(initial)
  63. #卷积和池化
  64. #TensorFlow在卷积和池化上有很强的灵活性。
  65. #我们怎么处理边界?步长应该设多大?在这个实例里,我们会一直使用vanilla版本。
  66. #我们的卷积使用1步长(stride size),0边距(padding size)的模板,保证输出和输入是同一个大小。
  67. #我们的池化用简单传统的2x2大小的模板做max pooling。为了代码更简洁,我们把这部分抽象成一个函数。
  68. def conv2d(x, W):
  69.   return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

  70. def max_pool_2x2(x):
  71.   return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
  72.                         strides=[1, 2, 2, 1], padding='SAME')
  73. #第一层卷积
  74. #现在我们可以开始实现第一层了。它由一个卷积接一个max pooling完成。
  75. #卷积在每个5x5的patch中算出32个特征。权重是一个[5, 5, 1, 32]的张量,
  76. #前两个维度是patch的大小,接着是输入的通道数目,最后是输出的通道数目。输出对应一个同样大小的偏置向量。
  77. W_conv1 = weight_variable([5, 5, 1, 32])
  78. b_conv1 = bias_variable([32])
  79. #为了用这一层,我们把x变成一个4d向量,第2、3维对应图片的宽高,最后一维代表颜色通道。
  80. x_image = tf.reshape(x, [-1,28,28,1])

  81. #我们把x_image和权值向量进行卷积相乘,加上偏置,使用ReLU激活函数,最后max pooling。
  82. h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
  83. h_pool1 = max_pool_2x2(h_conv1)

  84. #第二层卷积
  85. #为了构建一个更深的网络,我们会把几个类似的层堆叠起来。第二层中,每个5x5的patch会得到64个特征。
  86. W_conv2 = weight_variable([5, 5, 32, 64])
  87. b_conv2 = bias_variable([64])

  88. h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
  89. h_pool2 = max_pool_2x2(h_conv2)

  90. #密集连接层
  91. #现在,图片降维到7x7,我们加入一个有1024个神经元的全连接层,用于处理整个图片。
  92. #我们把池化层输出的张量reshape成一些向量,乘上权重矩阵,加上偏置,使用ReLU激活。
  93. W_fc1 = weight_variable([7 * 7 * 64, 1024])
  94. b_fc1 = bias_variable([1024])

  95. h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
  96. h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

  97. #Dropout
  98. #为了减少过拟合,我们在输出层之前加入dropout。我们用一个placeholder来代表一个神经元在dropout中被保留的概率。
  99. #这样我们可以在训练过程中启用dropout,在测试过程中关闭dropout。
  100. #TensorFlow的tf.nn.dropout操作会自动处理神经元输出值的scale。所以用dropout的时候可以不用考虑scale。
  101. keep_prob = tf.placeholder("float")
  102. h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

  103. #输出层
  104. #最后,我们添加一个softmax层,就像前面的单层softmax regression一样。
  105. W_fc2 = weight_variable([1024, 10])
  106. b_fc2 = bias_variable([10])

  107. y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

  108. #训练和评估模型
  109. #这次效果又有多好呢?我们用前面几乎一样的代码来测测看。
  110. #只是我们会用更加复杂的ADAM优化器来做梯度最速下降,在feed_dict中加入额外的参数keep_prob来控制dropout比例。
  111. #然后每100次迭代输出一次日志。
  112. cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
  113. train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
  114. correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
  115. accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
  116. sess.run(tf.initialize_all_variables())
  117. for i in range(20000):
  118.   batch = mnist.train.next_batch(50)
  119.   if i%100 == 0:
  120.     train_accuracy = accuracy.eval(feed_dict={
  121.         x:batch[0], y_: batch[1], keep_prob: 1.0})
  122.     print "step %d, training accuracy %g"%(i, train_accuracy)
  123.   train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

  124. print "test accuracy %g"%accuracy.eval(feed_dict={
  125.     x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})






复制代码


yjun12-04.zip (18 KB, 下载次数: 100)


我是笨鸟,我先飞!
回复

使用道具 举报

0

主题

5

帖子

15

积分

注册会员

Rank: 2

积分
15
沙发
发表于 2015-12-10 09:37:17 | 只看该作者
已经下载,等俺安装完了再来评论,谢谢
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

关闭

站长推荐上一条 /1 下一条

QQ|Archiver|手机版|小黑屋|陕ICP备15012670号-1    

GMT+8, 2024-5-7 04:31 , Processed in 0.061634 second(s), 26 queries .

Powered by Discuz! X3.2

© 2001-2013 Comsenz Inc.

快速回复 返回顶部 返回列表