From 46c69bb323de2eb3b41d04d0960083fa0c28c0cc Mon Sep 17 00:00:00 2001 From: Daniel Gordon <xkcd@cs.washington.edu> Date: Fri, 11 Aug 2017 11:51:02 -0700 Subject: [PATCH] fixed tiny bugs --- re3_utils/tensorflow_util/tf_queue.py | 2 +- tracker/network.py | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/re3_utils/tensorflow_util/tf_queue.py b/re3_utils/tensorflow_util/tf_queue.py index d475e3d..327fc71 100644 --- a/re3_utils/tensorflow_util/tf_queue.py +++ b/re3_utils/tensorflow_util/tf_queue.py @@ -83,7 +83,7 @@ class TFQueue(object): len(self.data_buffer), (len(self.data_buffer) - len(self.data_counts[self.data_counts > 0])), np.max(self.data_counts), - np.median(self.data_counts))) + np.median(self.data_counts[:len(self.data_buffer)]))) else: print('Buffer Full. Num unused: %d Max times used: %d Median times used: %d\n' % ( (len(self.data_buffer) - len(self.data_counts[self.data_counts > 0])), diff --git a/tracker/network.py b/tracker/network.py index 1735e2b..b1c3d18 100644 --- a/tracker/network.py +++ b/tracker/network.py @@ -141,11 +141,8 @@ def inference(inputs, num_unrolls, train, batch_size=None, prevLstmState=None, r lstmVars = [var for var in tf.trainable_variables() if 'lstm2' in var.name] for var in lstmVars: tf_util.variable_summaries(var, var.name[:-2]) - - with tf.variable_scope('lstm_output_concat'): - # BxTxC - outputs_concat = tf.concat(lstm2_outputs, 0) - outputs_reshape = tf_util.remove_axis(outputs_concat, 1) + # (BxT)xC + outputs_reshape = tf_util.remove_axis(lstm2_outputs, 1) # Final FC layer. with tf.variable_scope('fc_output'): -- GitLab