# =============================================================================
# Shared helpers
# =============================================================================


def _extract_argmax_and_embed(embedding, output_projection=None,
                              update_embedding=True):
    """Get a loop_function that extracts the previous symbol and embeds it.

    The returned function performs greedy decoding: it takes the argmax over
    axis 1 of the previous step's output and looks that symbol up in
    ``embedding`` to build the next step's input.

    Args:
      embedding: embedding tensor for symbols.
      output_projection: None or a pair (W, B). If provided, each fed previous
        output will first be multiplied by W and added B.
      update_embedding: Boolean; if False, the gradients will not propagate
        through the embeddings.

    Returns:
      A loop function with signature ``loop_function(prev, i)``; the step
      index ``i`` is ignored.
    """
    def loop_function(prev, _):
        # assumes prev is [batch_size x num_symbols] (argmax is over axis 1)
        if output_projection is not None:
            prev = tf.nn.xw_plus_b(prev, output_projection[0],
                                   output_projection[1])
        prev_symbol = tf.argmax(prev, 1)
        # Note that gradients will not propagate through the second parameter
        # of embedding_lookup.
        emb_prev = tf.nn.embedding_lookup(embedding, prev_symbol)
        if not update_embedding:
            emb_prev = tf.stop_gradient(emb_prev)
        return emb_prev
    return loop_function


def _build_placeholders():
    """Create the int32 input placeholders and their time-major transposes.

    Reads ``enc_sentence_length`` and ``dec_sentence_length``, which are
    defined outside this section.

    Returns:
      Tuple ``(enc_inputs, sequence_lengths, dec_inputs, enc_inputs_t,
      dec_inputs_t)``. The first three are placeholders; the ``*_t`` tensors
      are the batch-major id matrices transposed to time-major order.
    """
    enc_inputs = tf.placeholder(
        tf.int32, shape=[None, enc_sentence_length], name='input_sentences')
    sequence_lengths = tf.placeholder(
        tf.int32, shape=[None], name='sentences_length')
    dec_inputs = tf.placeholder(
        tf.int32, shape=[None, dec_sentence_length + 1],
        name='output_sentences')

    # batch_major => time_major
    enc_inputs_t = tf.transpose(enc_inputs, [1, 0])
    dec_inputs_t = tf.transpose(dec_inputs, [1, 0])
    return enc_inputs, sequence_lengths, dec_inputs, enc_inputs_t, dec_inputs_t


def _build_encoder(enc_inputs_t, sequence_lengths):
    """Build the embedding + BasicRNNCell encoder under the 'encoder' scope.

    Args:
      enc_inputs_t: time-major int32 encoder token ids.
      sequence_lengths: int32 placeholder passed to static_rnn as the
        per-example sequence lengths.

    Returns:
      Tuple ``(enc_cell, enc_outputs, enc_last_state)``. ``enc_outputs`` is
      a list with one hidden_size-wide output per encoder time step.
    """
    with tf.variable_scope('encoder'):
        enc_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
        # The wrapper embeds the integer ids; the cell output stays
        # hidden_size wide.
        enc_cell = EmbeddingWrapper(enc_cell, enc_vocab_size + 1, enc_emb_size)
        # enc_outputs: enc_sentence_length x [batch_size x hidden_size]
        enc_outputs, enc_last_state = tf.contrib.rnn.static_rnn(
            cell=enc_cell,
            inputs=tf.unstack(enc_inputs_t),
            sequence_length=sequence_lengths,
            dtype=tf.float32)
    return enc_cell, enc_outputs, enc_last_state


def _build_loss_and_train_op(dec_outputs, dec_inputs_t, learning_rate=0.0001):
    """Build predictions, softmax cross-entropy loss and the RMSProp step.

    Args:
      dec_outputs: list of per-step decoder logits, each
        [batch_size x dec_vocab_size+2].
      dec_inputs_t: time-major int32 decoder token ids, used as the labels.
      learning_rate: RMSProp learning rate.

    Returns:
      Tuple ``(predictions, labels, logits, loss, training_op)``.
    """
    # predictions: [batch_size x dec_sentence_length+1]
    predictions = tf.transpose(
        tf.argmax(tf.stack(dec_outputs), axis=-1), [1, 0])

    # labels & logits: [dec_sentence_length+1 x batch_size x dec_vocab_size+2]
    # NOTE(review): the labels are the decoder inputs at the same time step
    # (not shifted by one). If dec_inputs begins with a GO symbol, step 0 is
    # trained to emit GO itself -- confirm this matches the data pipeline.
    labels = tf.one_hot(dec_inputs_t, dec_vocab_size + 2)
    logits = tf.stack(dec_outputs)
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
        labels=labels, logits=logits))

    # Alternative optimizer: tf.train.AdamOptimizer(learning_rate=learning_rate)
    training_op = tf.train.RMSPropOptimizer(
        learning_rate=learning_rate).minimize(loss)
    return predictions, labels, logits, loss, training_op


# =============================================================================
# 1. rnn_decoder
# =============================================================================

tf.reset_default_graph()

(enc_inputs, sequence_lengths, dec_inputs,
 enc_inputs_t, dec_inputs_t) = _build_placeholders()

with tf.device('/cpu:0'):
    dec_Wemb = tf.get_variable(
        'dec_word_emb',
        initializer=tf.random_uniform([dec_vocab_size + 2, dec_emb_size]))

enc_cell, enc_outputs, enc_last_state = _build_encoder(
    enc_inputs_t, sequence_lengths)

with tf.variable_scope('decoder'):
    dec_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
    # Projects each step's output to dec_vocab_size+2 logits.
    dec_cell = OutputProjectionWrapper(dec_cell, dec_vocab_size + 2)

    # EmbeddingWrapper & tf.unstack(dec_inputs_t) raises dimension error,
    # so the decoder ids are embedded explicitly with dec_Wemb instead.
    dec_emb_inputs = tf.nn.embedding_lookup(dec_Wemb, dec_inputs_t)

    # With a loop_function, rnn_decoder (presumably
    # tf.contrib.legacy_seq2seq.rnn_decoder) feeds back its own greedy
    # predictions, using only the first decoder input.
    # dec_outputs: (dec_sentence_length+1) x [batch_size x dec_vocab_size+2]
    dec_outputs, dec_last_state = rnn_decoder(
        decoder_inputs=tf.unstack(dec_emb_inputs),
        initial_state=enc_last_state,
        cell=dec_cell,
        loop_function=_extract_argmax_and_embed(dec_Wemb))

predictions, labels, logits, loss, training_op = _build_loss_and_train_op(
    dec_outputs, dec_inputs_t)


# =============================================================================
# 2. embedding_rnn_decoder
# =============================================================================

tf.reset_default_graph()

(enc_inputs, sequence_lengths, dec_inputs,
 enc_inputs_t, dec_inputs_t) = _build_placeholders()

enc_cell, enc_outputs, enc_last_state = _build_encoder(
    enc_inputs_t, sequence_lengths)

with tf.variable_scope('decoder'):
    dec_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
    # Projects each step's output to dec_vocab_size+2 logits.
    dec_cell = OutputProjectionWrapper(dec_cell, dec_vocab_size + 2)

    # dec_outputs: (dec_sentence_length+1) x [batch_size x dec_vocab_size+2]
    dec_outputs, dec_last_state = embedding_rnn_decoder(
        decoder_inputs=tf.unstack(dec_inputs_t),
        initial_state=enc_last_state,
        cell=dec_cell,
        num_symbols=dec_vocab_size + 2,
        embedding_size=dec_emb_size,
        feed_previous=True)

predictions, labels, logits, loss, training_op = _build_loss_and_train_op(
    dec_outputs, dec_inputs_t)


# =============================================================================
# 3. embedding_rnn_seq2seq
# =============================================================================

tf.reset_default_graph()

# NOTE(review): sequence_lengths is created but not used in this variant;
# embedding_rnn_seq2seq is not given per-example lengths.
(enc_inputs, sequence_lengths, dec_inputs,
 enc_inputs_t, dec_inputs_t) = _build_placeholders()

rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)

with tf.variable_scope("embedding_rnn_seq2seq"):
    # Only enc_emb_size is passed, so the decoder embedding also uses it
    # (dec_emb_size is not used here).
    # dec_outputs: (dec_sentence_length+1) x [batch_size x dec_vocab_size+2]
    dec_outputs, dec_last_state = embedding_rnn_seq2seq(
        encoder_inputs=tf.unstack(enc_inputs_t),
        decoder_inputs=tf.unstack(dec_inputs_t),
        cell=rnn_cell,
        num_encoder_symbols=enc_vocab_size + 1,
        num_decoder_symbols=dec_vocab_size + 2,
        embedding_size=enc_emb_size,
        feed_previous=True)

predictions, labels, logits, loss, training_op = _build_loss_and_train_op(
    dec_outputs, dec_inputs_t)