From 38408352d814d270846b9f6720a51a6b47943f6d Mon Sep 17 00:00:00 2001 From: Daniel Gordon <xkcd@cs.washington.edu> Date: Sun, 19 Nov 2017 14:47:49 -0800 Subject: [PATCH] added missing session --- tracker/re3_multi_tracker.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tracker/re3_multi_tracker.py b/tracker/re3_multi_tracker.py index 993d710..6f5c001 100644 --- a/tracker/re3_multi_tracker.py +++ b/tracker/re3_multi_tracker.py @@ -35,7 +35,7 @@ class Re3TrackerFactory(object): self.is_initialized = False def create_tracker(self, gpu_id=0): - tracker = Re3Tracker(reuse=self.is_initialized, gpu_id) + tracker = Re3Tracker(self.sess, reuse=self.is_initialized, gpu_id) if not self.is_initialized: basedir = os.path.dirname(__file__) ckpt = tf.train.get_checkpoint_state(os.path.join(basedir, '..', LOG_DIR, 'checkpoints')) @@ -45,12 +45,14 @@ class Re3TrackerFactory(object): class Re3Tracker(object): - def __init__(self, reuse=False, gpu_id=0): + def __init__(self, sess, reuse=False, gpu_id=0): if gpu_id is not None: os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id) else: os.environ['CUDA_VISIBLE_DEVICES'] = str(GPU_ID) + self.sess = sess + self.imagePlaceholder = tf.placeholder(tf.uint8, shape=(None, CROP_SIZE, CROP_SIZE, 3)) self.prevLstmState = tuple([tf.placeholder(tf.float32, shape=(None, LSTM_SIZE)) for _ in xrange(4)]) self.batch_size = tf.placeholder(tf.int32, shape=()) -- GitLab