# Fewer than 100 lines of TensorFlow code!
# ONE
import tensorflow as tf
from tensorflow.contrib import layers, rnn
import os
import time
import math
import numpy as np
# Seed the graph-level random generator so weight initialisation is reproducible.
tf.set_random_seed(0)
# model parameters
SEQLEN = 30  # time steps per training sequence (passed to the minibatch sequencer)
BATCHSIZE = 200  # sequences per minibatch
ALPHASIZE = 89  # alphabet size, used as the one-hot depth for inputs and targets
INTERNALSIZE = 512  # units in each GRU cell
NLAYERS = 3  # number of stacked GRU cells
learning_rate = 0.001  # value fed to the `lr` placeholder
dropout_pkeep = 0.8  # value fed to the `pkeep` placeholder (keep probability)
# NOTE(review): load_data() is neither defined nor imported in the visible
# source — confirm where it comes from. Its result is unpacked into the
# training text, validation text and book ranges.
codetext, valitext, bookranges = load_data()
# the model
lr = tf.placeholder(tf.float32, name='lr')  # learning rate
# NOTE(review): pkeep is fed during training, but no op in the visible source
# consumes it, so no dropout is actually applied. Wire it into
# rnn.DropoutWrapper around the cells, or drop the placeholder.
pkeep = tf.placeholder(tf.float32, name='pkeep')  # dropout keep probability
batchsize = tf.placeholder(tf.int32, name='batchsize')
# inputs: uint8 symbol codes, presumably [batch, time]
# (dynamic_rnn's default batch-major layout), one-hot encoded to ALPHASIZE.
X = tf.placeholder(tf.uint8, [None, None], name='X')
Xo = tf.one_hot(X, ALPHASIZE, 1.0, 0.0)
# expected outputs (targets), same layout as X.
# The graph name is 'Y_', not 'Y', so that the prediction tensor can own the
# name "Y". If both were named 'Y', TensorFlow would silently rename the
# prediction to "Y_1", and fetching "Y:0" from a saved graph would return
# this placeholder instead.
Y = tf.placeholder(tf.uint8, [None, None], name='Y_')
# Second handle on the targets placeholder: the module-level name `Y` is
# rebound to the prediction tensor further down, so feed targets via `Y_`.
Y_ = Y
Yo = tf.one_hot(Y, ALPHASIZE, 1.0, 0.0)
# input state: the states of all GRU layers concatenated into one vector
# (MultiRNNCell with state_is_tuple=False).
Hin = tf.placeholder(tf.float32, [None, INTERNALSIZE*NLAYERS], name='Hin')
# hidden layers: NLAYERS stacked GRU cells of INTERNALSIZE units each
cells = [rnn.GRUCell(INTERNALSIZE) for _ in range(NLAYERS)]
multicell = rnn.MultiRNNCell(cells, state_is_tuple=False)
# Unroll the stacked GRU over the whole input sequence.
#   Yr: per-step outputs of the top cell, [batch, time, INTERNALSIZE]
#   H:  final state of all layers,        [batch, INTERNALSIZE*NLAYERS]
Yr, H = tf.nn.dynamic_rnn(multicell, Xo, dtype=tf.float32, initial_state=Hin)
H = tf.identity(H, name='H')  # give the final state a stable graph name
# Softmax layer implementation.
# Merge batch and time into one leading axis (-1 lets reshape infer it), so a
# single shared linear layer scores every time step. A literal 1 here fails
# for any input with more than one step in total.
Yflat = tf.reshape(Yr, [-1, INTERNALSIZE])
Ylogits = layers.linear(Yflat, ALPHASIZE)
# One-hot targets, flattened the same way so rows line up with Ylogits.
Yflat = tf.reshape(Yo, [-1, ALPHASIZE])
loss = tf.nn.softmax_cross_entropy_with_logits(logits=Ylogits, labels=Yflat)
# Back to [batch, time]: one loss value per step of each sequence.
loss = tf.reshape(loss, [batchsize, -1])
# Predictions. Note that `Yo` and `Y` are rebound here: from this point on
# they name output tensors, not the one-hot targets / target placeholder.
Yo = tf.nn.softmax(Ylogits, name='Yo')
Y = tf.argmax(Yo, 1)  # most probable symbol at every step
Y = tf.reshape(Y, [batchsize, -1], name="Y")
# minimize() on a non-scalar loss differentiates the sum of its elements.
trainstep = tf.train.AdamOptimizer(lr).minimize(loss)
# Init for saving models: make sure the checkpoint directory exists.
if not os.path.exists("checkpoints"):
    os.mkdir("checkpoints")
saver = tf.train.Saver(max_to_keep=1000)
# init: all-zero RNN state for the first minibatch, float32 to match `Hin`.
istate = np.zeros([BATCHSIZE, INTERNALSIZE*NLAYERS], dtype=np.float32)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
step = 0
# train on one minibatch at a time
for x, y, epoch in txt.rnnminibatch_sequencer(codetext, BATCHSIZE, SEQLEN, nb_ep
feed_dict = {X: x, Ye: ye, Hin: istate, lr: learning_rate, pkeep: dropout_pkeep, batc
, y, ostate = sess.run([trainstep, Y, H], feed_dict=feed_dict)
if step // 10 % _50_BATCHES == 0:
saved_file = saver.save(sess, 'checkpoints/rnn_train' + timestamp, global_st
print("Saved file: " + saved_file)
istate = ostate
step += BATCHSIZE * SEQLEN