Skip to content
Snippets Groups Projects
Commit 38408352 authored by Daniel Gordon's avatar Daniel Gordon
Browse files

added missing session

parent 878a90f6
No related branches found
No related tags found
No related merge requests found
......@@ -35,7 +35,7 @@ class Re3TrackerFactory(object):
self.is_initialized = False
def create_tracker(self, gpu_id=0):
tracker = Re3Tracker(reuse=self.is_initialized, gpu_id)
tracker = Re3Tracker(self.sess, reuse=self.is_initialized, gpu_id)
if not self.is_initialized:
basedir = os.path.dirname(__file__)
ckpt = tf.train.get_checkpoint_state(os.path.join(basedir, '..', LOG_DIR, 'checkpoints'))
......@@ -45,12 +45,14 @@ class Re3TrackerFactory(object):
class Re3Tracker(object):
def __init__(self, reuse=False, gpu_id=0):
def __init__(self, sess, reuse=False, gpu_id=0):
if gpu_id is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
else:
os.environ['CUDA_VISIBLE_DEVICES'] = str(GPU_ID)
self.sess = sess
self.imagePlaceholder = tf.placeholder(tf.uint8, shape=(None, CROP_SIZE, CROP_SIZE, 3))
self.prevLstmState = tuple([tf.placeholder(tf.float32, shape=(None, LSTM_SIZE)) for _ in xrange(4)])
self.batch_size = tf.placeholder(tf.int32, shape=())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment