TensorFlow notes: building a GRUcell with tf.scan

    LittleUqeer · 2017-02-20 14:54:27 +08:00 · 6132 views

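    Most of the RNN papers I have read concentrate on the internal structure of the RNN cell, and only a few deal with the interaction between units.

    TensorFlow ships several of the most popular RNN variants, but writing a custom one is not as convenient as writing a CNN. Here is a piece of code that builds a GRUcell with tf.scan, which can serve as a reference for custom RNN cells.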
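
    As a rough illustration of the idea behind tf.scan, here is a plain-Python sketch that needs no TensorFlow. The helper name scan is hypothetical and only mimics the calling convention: fn receives the previous result and the current element, the initializer seeds the first call, and every intermediate result is kept.

```python
def scan(fn, elems, initializer):
    """Mimic the calling convention of tf.scan on a Python sequence.

    fn(previous_result, current_element) is called once per element,
    and every intermediate result is collected in order.
    """
    acc = initializer
    outputs = []
    for elem in elems:
        acc = fn(acc, elem)
        outputs.append(acc)
    return outputs


# A running sum: each output depends on the previous output and the current input.
running_sum = scan(lambda acc, x: acc + x, [1, 2, 3, 4], 0)
print(running_sum)  # [1, 3, 6, 10]
```

    A recurrent cell fits this pattern directly: the accumulator is the hidden state and each element is the input for one time step.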

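    import numpy as np
    import pandas as pd
    import tensorflow as tf
    import pylab as pl
    from tensorflow.examples.tutorials.mnist import input_data

    # Jupyter magic for inline plots; remove this line when running as a plain script.
    %matplotlib inline

    # Download (if needed) and load MNIST with one-hot labels.
    mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)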
    
    
    class GRUcell(object):
        
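        def __init__(self):
            # Each 28x28 MNIST image is read as 28 time steps of 28 features.
            self.in_length = 28
            self.in_width = 28
            self.hidden_layer_size = 2000
            self.out_classes = 10

            # Input-to-hidden weights: reset gate (r), update gate (z), candidate state (_).
            self.Wr = tf.Variable(tf.zeros([self.in_width, self.hidden_layer_size]))
            self.Wz = tf.Variable(tf.zeros([self.in_width, self.hidden_layer_size]))
            self.W_ = tf.Variable(tf.zeros([self.in_width, self.hidden_layer_size]))
            # Hidden-to-hidden weights for the same three parts.
            self.Ur = tf.Variable(tf.truncated_normal([self.hidden_layer_size, self.hidden_layer_size]))
            self.Uz = tf.Variable(tf.truncated_normal([self.hidden_layer_size, self.hidden_layer_size]))
            self.U_ = tf.Variable(tf.truncated_normal([self.hidden_layer_size, self.hidden_layer_size]))

            # Output layer parameters.
            self.Wout = tf.Variable(tf.truncated_normal([self.hidden_layer_size, self.out_classes], mean=0., stddev=.1))
            self.bout = tf.Variable(tf.truncated_normal([self.out_classes], mean=0., stddev=.1))

            # Input batch: [batch_size, in_length, in_width].
            self.inX = tf.placeholder(shape=[None, self.in_length, self.in_width], dtype=tf.float32)
            # Zero initial state of shape [batch_size, hidden_layer_size]; multiplying by a
            # zero matrix picks up the dynamic batch size from the placeholder.
            self.initial_hidden = tf.matmul(self.inX[:,0,:], tf.zeros([self.in_width, self.hidden_layer_size]))
            # tf.scan iterates over axis 0, so make the input time-major:
            # [in_length, batch_size, in_width].
            self.X = tf.transpose(self.inX, perm=[1,0,2])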
        
        def GRU(self, hidden_states_previous, current_input_X):
            """
            One GRU step, used as the fn argument of tf.scan.
            tf.scan fixes the argument order: the first tensor is the
            output of the previous step (the previous hidden state) and
            the second tensor is the input at the current time step.
            """
            hp = hidden_states_previous
            x = current_input_X

            # Reset gate and update gate.
            r = tf.sigmoid(tf.matmul(x, self.Wr) + tf.matmul(hp, self.Ur))
            z = tf.sigmoid(tf.matmul(x, self.Wz) + tf.matmul(hp, self.Uz))
            # Candidate state, computed from the reset-gated previous state.
            h_ = tf.tanh(tf.matmul(x, self.W_) + tf.matmul(r * hp, self.U_))
            # New state: z keeps the old state, (1 - z) lets in the candidate.
            h = tf.multiply(hp, z) + tf.multiply((1 - z), h_)
            return h
        
        def PRO_TS(self):
            """
            Apply the GRU step recursively along the time axis
            and return the hidden state of every time step.
            Input format : [in_length, batch_size, in_width]
            Output format : [in_length, batch_size, hidden_layer_size]
            """
            return tf.scan(fn=self.GRU, elems=self.X, initializer=self.initial_hidden)

        def Full_Connection_Layer(self, batch_hidden_layer_states):
            """
            Map the hidden states of one time step to class scores
            through a fully connected layer with ReLU activation.
            Input format : [batch_size, hidden_layer_size]
            Output format : [batch_size, out_classes]
            """
            return tf.nn.relu(tf.nn.bias_add(tf.matmul(batch_hidden_layer_states, self.Wout), self.bout))
            
        
        def deal_hidden_layer(self):
            """
            Apply the fully connected layer to the hidden states of every time step.
            Input format : [in_length, batch_size, hidden_layer_size]
            Output format : [in_length, batch_size, out_classes]
            """
            #all_hidden_states = self.PRO_TS()
            #return tf.map_fn(self.Full_Connection_Layer, all_hidden_states)
            return tf.map_fn(self.Full_Connection_Layer, self.PRO_TS())

        def last_output(self):
            # Reverse the time axis and take index 0, i.e. the output of the last time step,
            # then turn the class scores into probabilities.
            tp = tf.reverse(self.deal_hidden_layer(), axis=[0])[0,:,:]
            return tf.nn.softmax(tp)
    
    
    y = tf.placeholder(tf.float32, shape=[None, 10],name='inputs')
    rnn = GRUcell()
    output = rnn.last_output()
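    # Clip the probabilities before the log so that an exact 0 cannot produce NaN in the loss.
    cross_entropy = -tf.reduce_sum(y * tf.log(tf.clip_by_value(output, 1e-10, 1.0)))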
    train_step = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cross_entropy)
    correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(output,1))
    accuracy = (tf.reduce_mean(tf.cast(correct_prediction, tf.float32)))
    sess=tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())
    
    
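    batch_size = 32
    ss = []  # accuracy after each training step
    for i in range(5000):
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # Flat 784-pixel rows become 28 time steps of 28 features.
        batch_x = batch_x.reshape((batch_size, 28, 28))
        sess.run(train_step, feed_dict={rnn.inX:batch_x, y:batch_y})
        # Note: this is accuracy on the training batch just used, not on held-out data.
        t = sess.run(accuracy, feed_dict={rnn.inX:batch_x, y:batch_y})
        ss.append(t)

    # Plot the training-batch accuracy over the 5000 steps.
    ttt = pd.Series(ss)
    ttt.plot()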
    

    Uses TensorFlow version 1.0 and Python 3.6.
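
    One note on the loss in the code above: it takes the log of softmax probabilities, which can hit log(0) when one class score dominates. A common alternative is to compute the cross-entropy directly from the unnormalized scores with the log-sum-exp trick. The plain-Python sketch below uses hypothetical helper names (log_softmax, cross_entropy) and made-up scores. TensorFlow 1.x also ships tf.nn.softmax_cross_entropy_with_logits, which takes unnormalized logits rather than probabilities.

```python
import math


def log_softmax(logits):
    """Log of the softmax, computed without forming the probabilities first."""
    m = max(logits)  # subtracting the maximum keeps exp() from overflowing
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]


def cross_entropy(one_hot, logits):
    """Cross-entropy between a one-hot label and unnormalized class scores."""
    return -sum(t * lp for t, lp in zip(one_hot, log_softmax(logits)))


# Uniform scores over three classes give a loss of log(3).
uniform_loss = cross_entropy([0, 1, 0], [0.0, 0.0, 0.0])

# A dominating wrong class: the softmax probability of the true class underflows to 0,
# so log(probability) would fail, while the logit-based form stays finite.
extreme_loss = cross_entropy([0, 1, 0], [1000.0, 0.0, 0.0])
print(uniform_loss, extreme_loss)
```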

    Source code: https://uqer.io/community/share/58a9332bf1973300597ae209

    2 replies    2017-02-21 10:28:13 +08:00

    #1 aosp   2017-02-20 20:08:12 +08:00
    Deep learning... this looks so advanced.

    #2 snnn   2017-02-21 10:28:13 +08:00 via Android
    Impressive!