aihot  2017-12-02 08:57:09  Machine Learning

 Linear Regression

 

  The following code comes from GitHub - aymericdamien/TensorFlow-Examples: TensorFlow Tutorial and Examples for beginners, and is reproduced here for learning purposes only.

    import tensorflow.compat.v1 as tf  # TF1-style graph API; also available in TensorFlow 2.x
    import numpy
    import matplotlib.pyplot as plt

    tf.disable_v2_behavior()  # graph mode, so placeholders and Session work under TF 2.x
    rng = numpy.random

    # Parameters
    learning_rate = 0.01
    training_epochs = 2000
    display_step = 50

    # Training Data
    train_X = numpy.asarray([3.3,4.4,5.5,6.71,6.93,4.168,9.779,6.182,7.59,2.167,7.042,10.791,5.313,7.997,5.654,9.27,3.1])
    train_Y = numpy.asarray([1.7,2.76,2.09,3.19,1.694,1.573,3.366,2.596,2.53,1.221,2.827,3.465,1.65,2.904,2.42,2.94,1.3])
    n_samples = train_X.shape[0]

    # tf Graph Input
    X = tf.placeholder("float")
    Y = tf.placeholder("float")

    # Create Model

    # Set model weights (random scalar initial values)
    W = tf.Variable(rng.randn(), name="weight")
    b = tf.Variable(rng.randn(), name="bias")

    # Construct a linear model: activation = X * W + b
    # (tf.mul was renamed tf.multiply in TensorFlow 1.0)
    activation = tf.add(tf.multiply(X, W), b)

    # Minimize the squared errors
    cost = tf.reduce_sum(tf.pow(activation - Y, 2)) / (2 * n_samples)  # L2 loss
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)  # Gradient descent

    # Initializing the variables (replaces the deprecated tf.initialize_all_variables)
    init = tf.global_variables_initializer()

    # Launch the graph
    with tf.Session() as sess:
        sess.run(init)

        # Fit all training data: one gradient step per (x, y) sample
        for epoch in range(training_epochs):
            for (x, y) in zip(train_X, train_Y):
                sess.run(optimizer, feed_dict={X: x, Y: y})

            # Display logs per epoch step
            if epoch % display_step == 0:
                c = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
                print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(c),
                      "W=", sess.run(W), "b=", sess.run(b))

        print("Optimization Finished!")
        print("cost=", sess.run(cost, feed_dict={X: train_X, Y: train_Y}),
              "W=", sess.run(W), "b=", sess.run(b))

        # Graphic display
        plt.plot(train_X, train_Y, 'ro', label='Original data')
        plt.plot(train_X, sess.run(W) * train_X + sess.run(b), label='Fitted line')
        plt.legend()
        plt.show()
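
Written out, the listing fits $\hat{y} = Wx + b$ by minimizing the cost below, where $n$ is `n_samples` (17 here) and $\eta$ is `learning_rate`. One detail that is easy to miss: inside the training loop each `sess.run(optimizer, ...)` call feeds a single pair $(x_i, y_i)$, so the cost differentiated at that step is only the $i$-th term, still divided by $2n$. The last two lines are the resulting per-sample update, derived here from the code as written rather than quoted from the original tutorial.

```latex
\begin{aligned}
\hat{y}_i &= W x_i + b \\
J(W, b) &= \frac{1}{2n} \sum_{i=1}^{n} \left( W x_i + b - y_i \right)^2 \\
J_i(W, b) &= \frac{1}{2n} \left( W x_i + b - y_i \right)^2 \\
\frac{\partial J_i}{\partial W} &= \frac{1}{n} \left( W x_i + b - y_i \right) x_i,
\qquad
\frac{\partial J_i}{\partial b} = \frac{1}{n} \left( W x_i + b - y_i \right) \\
W &\leftarrow W - \frac{\eta}{n} \left( W x_i + b - y_i \right) x_i,
\qquad
b \leftarrow b - \frac{\eta}{n} \left( W x_i + b - y_i \right)
\end{aligned}
```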

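To see what `GradientDescentOptimizer` does in this loop without TensorFlow, here is a minimal NumPy-only sketch, assuming the same data, learning rate, epoch count and per-sample update order as the listing. The helper name `fit_linear_sgd` and the zero initialization (instead of `rng.randn()`) are illustrative choices, not part of the original example, so the numbers will not match a TensorFlow run exactly. The closed-form least-squares line from `np.polyfit` is printed alongside as a reference.

```python
import numpy as np

# Same training data as the TensorFlow listing
train_X = np.asarray([3.3, 4.4, 5.5, 6.71, 6.93, 4.168, 9.779, 6.182, 7.59, 2.167,
                      7.042, 10.791, 5.313, 7.997, 5.654, 9.27, 3.1])
train_Y = np.asarray([1.7, 2.76, 2.09, 3.19, 1.694, 1.573, 3.366, 2.596, 2.53, 1.221,
                      2.827, 3.465, 1.65, 2.904, 2.42, 2.94, 1.3])


def cost(W, b, X, Y):
    """Same L2 loss as the listing: sum of squared errors divided by 2n."""
    n = X.shape[0]
    return np.sum((W * X + b - Y) ** 2) / (2 * n)


def fit_linear_sgd(X, Y, learning_rate=0.01, training_epochs=2000, W=0.0, b=0.0):
    """Hypothetical helper: one gradient step per (x, y) sample, like the TF loop.

    Each step differentiates (W*x + b - y)**2 / (2n), so both gradients carry
    a 1/n factor even though only one sample is used.
    """
    n = X.shape[0]
    for _ in range(training_epochs):
        for x, y in zip(X, Y):
            err = W * x + b - y          # residual from the current W, b
            grad_W = err * x / n         # dJ_i/dW
            grad_b = err / n             # dJ_i/db
            W -= learning_rate * grad_W  # both updates use the same residual
            b -= learning_rate * grad_b
    return W, b


W, b = fit_linear_sgd(train_X, train_Y)
print("SGD:           cost=", cost(W, b, train_X, train_Y), "W=", W, "b=", b)

# Closed-form least-squares line for comparison (returns slope, intercept)
W_ls, b_ls = np.polyfit(train_X, train_Y, 1)
print("least squares: cost=", cost(W_ls, b_ls, train_X, train_Y), "W=", W_ls, "b=", b_ls)
```

Each step moves the parameters by only about 0.01 / 17 ≈ 0.0006 times the gradient term, and because the x values are far from zero the intercept `b` converges much more slowly than the slope, so after 2000 epochs the SGD estimate should be close to, but not identical with, the `np.polyfit` line. Raising `training_epochs` narrows the gap.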
 

 

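Unless otherwise noted, all articles on this site are original to 赢咖4注册. When reposting, please credit the source: 机器学习进阶笔记之一 | TensorFlow安装与入门 (Machine Learning Advanced Notes, Part 1 | TensorFlow Installation and Getting Started).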
