aihot  2017-12-09 23:01:43  Machine Learning

VGG-16 implementation in tflearn

 

The official tflearn GitHub repository provides a VGG-16 implementation built with tflearn:

    from __future__ import division, print_function, absolute_import

    import tflearn
    from tflearn.layers.core import input_data, dropout, fully_connected
    from tflearn.layers.conv import conv_2d, max_pool_2d
    from tflearn.layers.estimator import regression

    # Data loading and preprocessing
    # Oxford 17 Category Flower dataset, labels one-hot encoded (17 classes)
    import tflearn.datasets.oxflower17 as oxflower17
    X, Y = oxflower17.load_data(one_hot=True)

    # Building 'VGG Network'
    network = input_data(shape=[None, 224, 224, 3])

    # Block 1: two 3x3 conv layers with 64 filters, then 2x2 max-pool
    network = conv_2d(network, 64, 3, activation='relu')
    network = conv_2d(network, 64, 3, activation='relu')
    network = max_pool_2d(network, 2, strides=2)

    # Block 2: two 3x3 conv layers with 128 filters, then 2x2 max-pool
    network = conv_2d(network, 128, 3, activation='relu')
    network = conv_2d(network, 128, 3, activation='relu')
    network = max_pool_2d(network, 2, strides=2)

    # Block 3: three 3x3 conv layers with 256 filters, then 2x2 max-pool
    network = conv_2d(network, 256, 3, activation='relu')
    network = conv_2d(network, 256, 3, activation='relu')
    network = conv_2d(network, 256, 3, activation='relu')
    network = max_pool_2d(network, 2, strides=2)

    # Block 4: three 3x3 conv layers with 512 filters, then 2x2 max-pool
    network = conv_2d(network, 512, 3, activation='relu')
    network = conv_2d(network, 512, 3, activation='relu')
    network = conv_2d(network, 512, 3, activation='relu')
    network = max_pool_2d(network, 2, strides=2)

    # Block 5: three 3x3 conv layers with 512 filters, then 2x2 max-pool
    network = conv_2d(network, 512, 3, activation='relu')
    network = conv_2d(network, 512, 3, activation='relu')
    network = conv_2d(network, 512, 3, activation='relu')
    network = max_pool_2d(network, 2, strides=2)

    # Classifier: two 4096-unit fully connected layers with dropout,
    # then a 17-way softmax (one unit per flower class)
    network = fully_connected(network, 4096, activation='relu')
    network = dropout(network, 0.5)
    network = fully_connected(network, 4096, activation='relu')
    network = dropout(network, 0.5)
    network = fully_connected(network, 17, activation='softmax')

    # RMSProp optimizer with categorical cross-entropy loss
    network = regression(network, optimizer='rmsprop',
                         loss='categorical_crossentropy',
                         learning_rate=0.001)

    # Training: keep only the latest checkpoint, snapshot every 500 steps
    model = tflearn.DNN(network, checkpoint_path='model_vgg',
                        max_checkpoints=1, tensorboard_verbose=0)
    model.fit(X, Y, n_epoch=500, shuffle=True,
              show_metric=True, batch_size=32, snapshot_step=500,
              snapshot_epoch=False, run_id='vgg_oxflowers17')
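To make the layer stack above easier to follow, here is a small pure-Python sketch (no tflearn needed) that traces the feature-map shape after every layer. It assumes 'same' padding, which as far as I know is the default for tflearn's `conv_2d` and `max_pool_2d`; `VGG16_BLOCKS` and `trace_shapes` are hypothetical helpers written for this post, not part of tflearn.

```python
# Hypothetical helper (not part of tflearn): trace feature-map shapes through the
# VGG-16 stack defined above. Assumes 'same' padding, so each 3x3 conv keeps
# height and width, and each 2x2 / stride-2 max-pool outputs ceil(n / 2).

# (number of 3x3 conv layers, filters per layer) for each of the five blocks
VGG16_BLOCKS = [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)]


def trace_shapes(height=224, width=224, channels=3):
    """Return [(layer_name, (h, w, c)), ...] for the input and every layer."""
    shapes = [("input", (height, width, channels))]
    for b, (n_convs, filters) in enumerate(VGG16_BLOCKS, start=1):
        for i in range(1, n_convs + 1):
            channels = filters  # 'same' padding: only the depth changes
            shapes.append((f"conv{b}_{i}", (height, width, channels)))
        height, width = -(-height // 2), -(-width // 2)  # ceiling division
        shapes.append((f"pool{b}", (height, width, channels)))
    return shapes


if __name__ == "__main__":
    for name, shape in trace_shapes():
        print(f"{name:8s} {shape}")
    h, w, c = trace_shapes()[-1][1]
    print("input size of the first fully_connected layer:", h * w * c)
```

The spatial size halves at each of the five pools (224, 112, 56, 28, 14, 7), so the first `fully_connected` layer receives a flattened 7 × 7 × 512 = 25088-dimensional input.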

The VGG-16 graph is shown below:

[Figure: VGG-16 graph]
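To put numbers on the graph, a similar back-of-the-envelope sketch counts the trainable parameters of each weight layer. It assumes every layer has a bias term (tflearn's default, as far as I know) and the 7 × 7 × 512 feature map produced by a 224 × 224 input; `CONV_FILTERS` and `count_params` are hypothetical names written for this post.

```python
# Hypothetical helper (not part of tflearn): count trainable parameters of the
# VGG-16 network defined above. Assumes every conv_2d / fully_connected layer
# has a bias and that the last pool outputs a 7x7x512 feature map.

# Output channels of the 13 conv layers, in order (all use 3x3 kernels).
CONV_FILTERS = [64, 64, 128, 128, 256, 256, 256, 512, 512, 512, 512, 512, 512]


def count_params(num_classes=17, in_channels=3, kernel=3, flat_size=7 * 7 * 512):
    """Return a list with the parameter count of each weight layer, in order."""
    counts = []
    c_in = in_channels
    for c_out in CONV_FILTERS:
        # kernel * kernel * c_in weights per filter, c_out filters, one bias each
        counts.append(kernel * kernel * c_in * c_out + c_out)
        c_in = c_out
    n_in = flat_size
    for n_out in [4096, 4096, num_classes]:
        # dense layer: full weight matrix plus one bias per output unit
        counts.append(n_in * n_out + n_out)
        n_in = n_out
    return counts


if __name__ == "__main__":
    counts = count_params()
    print("conv layers: ", sum(counts[:13]))
    print("dense layers:", sum(counts[13:]))
    print("total:       ", sum(counts))
```

For the 17-class network this comes to 134,330,193 parameters, of which 102,764,544 belong to the first fully connected layer alone. With `num_classes=1000` (the ImageNet setting) the same formula gives 138,357,544, the figure usually quoted for VGG-16.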

Personally, I don't think VGG has many standout ideas. Its pre-trained models are very handy in practice, but it never struck me the way GoogLeNet did.

 

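Unless otherwise noted, all articles on this site are original. When reposting, please credit the source: Advanced Machine Learning Notes, Part 5 | Understanding VGG\Residual Network in Depth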
