diff --git a/tracker/re3_multi_tracker.py b/tracker/re3_multi_tracker.py
index 993d7107746a1d3f53b6c7f5fa63daa38be700d4..6f5c001f0d4a1c16c085b941c862762ce7493ed8 100644
--- a/tracker/re3_multi_tracker.py
+++ b/tracker/re3_multi_tracker.py
@@ -35,7 +35,7 @@ class Re3TrackerFactory(object):
         self.is_initialized = False
 
     def create_tracker(self, gpu_id=0):
-        tracker = Re3Tracker(reuse=self.is_initialized, gpu_id)
+        tracker = Re3Tracker(self.sess, reuse=self.is_initialized, gpu_id)
         if not self.is_initialized:
             basedir = os.path.dirname(__file__)
             ckpt = tf.train.get_checkpoint_state(os.path.join(basedir, '..', LOG_DIR, 'checkpoints'))
@@ -45,12 +45,14 @@ class Re3TrackerFactory(object):
 
 
 class Re3Tracker(object):
-    def __init__(self, reuse=False, gpu_id=0):
+    def __init__(self, sess, reuse=False, gpu_id=0):
         if gpu_id is not None:
             os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
         else:
             os.environ['CUDA_VISIBLE_DEVICES'] = str(GPU_ID)
 
+        self.sess = sess
+
         self.imagePlaceholder = tf.placeholder(tf.uint8, shape=(None, CROP_SIZE, CROP_SIZE, 3))
         self.prevLstmState = tuple([tf.placeholder(tf.float32, shape=(None, LSTM_SIZE)) for _ in xrange(4)])
         self.batch_size = tf.placeholder(tf.int32, shape=())