Commit e4f8b246 authored by Daniel Gordon

comments

parent 00df2250
@@ -50,7 +50,7 @@ SIMULATION_WIDTH = simulater.IMAGE_WIDTH
SIMULATION_HEIGHT = simulater.IMAGE_HEIGHT
simulater.NUM_DISTRACTORS = 20
-USE_IMAGENET_PROB = 0.5
+USE_SIMULATER = 0.5
USE_NETWORK_PROB = 0.8
REAL_MOTION_PROB = 1.0 / 8
AREA_CUTOFF = 0.25
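These constants gate independent per-sequence random draws in get_data_sequence further down: USE_SIMULATER picks the synthetic-data branch, a draw against USE_NETWORK_PROB decides whether the network's own predictions (rather than ground truth) supply each next crop location, and REAL_MOTION_PROB occasionally forces real frame-to-frame motion. A minimal sketch of the draws, assuming one set per generated sequence:

    import random

    def draw_sequence_options():
        # One independent set of draws per generated sequence.
        use_simulater = random.random() < USE_SIMULATER   # synthetic vs. real data
        gt_type = random.random()                         # < USE_NETWORK_PROB: use predicted boxes
        real_motion = random.random() < REAL_MOTION_PROB  # force real motion sometimes
        return use_simulater, gt_type, real_motion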
@@ -144,7 +144,7 @@ def main(FLAGS):
np.set_printoptions(suppress=True)
np.set_printoptions(precision=4)
-# Read in and format GT
+# Read in and format GT.
# dict from (dataset_ind, video_id, track_id, image_id) to line in labels array
key_lookup = dict()
datasets = []
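key_lookup maps a (dataset_ind, video_id, track_id, image_id) tuple to the row of that dataset's labels array, so a ground-truth box can be fetched in constant time while sequences are built. A hypothetical sketch of how add_dataset could populate it (the real loader and label layout live elsewhere in the repo):

    def add_dataset_sketch(labels, dataset_ind):
        # Hypothetical layout: columns 0-3 hold the box,
        # columns 4-6 hold (video_id, track_id, image_id).
        datasets.append(labels)
        for row in xrange(labels.shape[0]):
            video_id, track_id, image_id = labels[row, 4:7].astype(int)
            key_lookup[(dataset_ind, video_id, track_id, image_id)] = row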
@@ -159,7 +159,7 @@ def main(FLAGS):
add_dataset('imagenet_video')
-# Tensorflow stuff
+# Tensorflow setup
if not os.path.exists(LOG_DIR):
os.makedirs(LOG_DIR)
if not os.path.exists(LOG_DIR + '/checkpoints'):
@@ -176,6 +176,7 @@ def main(FLAGS):
labelPlaceholder = tf.placeholder(tf.float32, shape=(ENQUEUE_BATCH_SIZE, delta, 4))
learningRate = tf.placeholder(tf.float32)
+# Set up the data queue for holding images and retrieving them from RAM rather than disk.
queue = tf_queue.TFQueue(sess,
placeholders=[imagePlaceholder, labelPlaceholder],
max_queue_size=REPLAY_BUFFER_SIZE,
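tf_queue.TFQueue is a project-specific wrapper (it also keeps a Python-side buffer, a lock, and feeder threads); a minimal stand-in built on TensorFlow's own FIFOQueue would look roughly like this, where BATCH_SIZE is assumed to be the training batch size defined above this excerpt:

    queue_sketch = tf.FIFOQueue(
        capacity=REPLAY_BUFFER_SIZE,
        dtypes=[tf.uint8, tf.float32],
        shapes=[(delta, 2, CROP_SIZE, CROP_SIZE, 3), (delta, 4)])
    # enqueue_many splits the leading ENQUEUE_BATCH_SIZE dimension into
    # individual elements; dequeue_many re-batches them for training.
    enqueue_op = queue_sketch.enqueue_many([imagePlaceholder, labelPlaceholder])
    imageBatch, labelBatch = queue_sketch.dequeue_many(BATCH_SIZE)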
@@ -205,6 +206,8 @@ def main(FLAGS):
saver = tf.train.Saver()
longSaver = tf.train.Saver()
+# Create the nodes for single-image forward passes used when learning to fix mistakes.
+# Parameters here are shared with the learned network.
if ',' in FLAGS.cuda_visible_devices:
with tf.device('/gpu:1'):
forwardNetworkImagePlaceholder = tf.placeholder(tf.uint8, shape=(2, CROP_SIZE, CROP_SIZE, 3))
@@ -221,6 +224,7 @@ def main(FLAGS):
forwardNetworkImagePlaceholder, num_unrolls=1, train=False,
prevLstmState=prevLstmState, reuse=True)
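The reuse=True above is what ties this single-step network to the training tower: in TensorFlow 1.x, rebuilding the same graph inside a variable scope with reuse=True returns the existing variables instead of creating new ones. A minimal illustration of the pattern (scope and layer names here are invented, not the repo's):

    def build_net_sketch(images, reuse=None):
        with tf.variable_scope('net_sketch', reuse=reuse):
            # get_variable returns the same underlying variable on the reused call.
            w = tf.get_variable('w', shape=[3, 3, 3, 16])
            return tf.nn.conv2d(images, w, strides=[1, 1, 1, 1], padding='SAME')

    train_images = tf.placeholder(tf.float32, (None, CROP_SIZE, CROP_SIZE, 3))
    single_pair = tf.placeholder(tf.float32, (2, CROP_SIZE, CROP_SIZE, 3))
    train_out = build_net_sketch(train_images)              # creates the weights
    single_out = build_net_sketch(single_pair, reuse=True)  # shares them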
+# Initialize the network and load saved parameters.
sess.run(init)
startIter = 0
if FLAGS.restore:
@@ -243,6 +247,7 @@ def main(FLAGS):
tf_util.conv_variable_summaries(var, scope=var.name.replace('/', '_')[:-2])
summary_with_images = tf.summary.merge_all()
+# Logging stuff
robustness_ph = tf.placeholder(tf.float32, shape=[])
lost_targets_ph = tf.placeholder(tf.float32, shape=[])
mean_iou_ph = tf.placeholder(tf.float32, shape=[])
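These placeholders let metrics computed in plain Python (here, by the validation runs) reach TensorBoard: each feeds a tf.summary.scalar whose value is supplied only when a summary is written. Roughly, with an illustrative tag and value name:

    robustness_summary = tf.summary.scalar('val/robustness', robustness_ph)
    # Later, once the metric exists as a plain Python number:
    summary_str = sess.run(robustness_summary,
                           feed_dict={robustness_ph: robustness_value})
    summary_writer.add_summary(summary_str, iteration)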
@@ -261,19 +266,22 @@ def main(FLAGS):
sess.graph.finalize()
+# Read a sequence from the batch cache or simulate one and get the ground truth and crops for it.
def get_data_sequence():
+# Preallocate the space for the images and labels.
tImage = np.zeros((delta, 2, CROP_SIZE, CROP_SIZE, 3),
dtype=np.uint8)
xywhLabels = np.zeros((delta, 4))
-mirroredInds = random.random() < 0.5
-useImagenetInds = random.random() < USE_IMAGENET_PROB
+mirrored = random.random() < 0.5
+useSimulater = random.random() < USE_SIMULATER
gtType = random.random()
realMotion = random.random() < REAL_MOTION_PROB
+# Initialize first frame (give the network context).
-if useImagenetInds:
+if useSimulater:
+# Initialize the simulation and run through a few frames.
trackingObj, trackedObjects, background = simulater.create_new_track()
for _ in xrange(random.randint(0,200)):
simulater.step(trackedObjects)
@@ -294,6 +302,7 @@ def main(FLAGS):
images = [np.zeros((SIMULATION_HEIGHT, SIMULATION_WIDTH))]
else:
+# Read a new data sequence from the batch cache and get the ground truth.
(batchKey, images) = getData()
gtKey = batchKey
imageIndex = key_lookup[gtKey]
@@ -302,11 +311,14 @@ def main(FLAGS):
bboxes = []
cropBBoxes = []
+# bboxPrev starts at the initial box and is the best guess (or gt) for the image0 location.
+# noisyBox holds the bboxPrev estimate plus some noise.
bboxPrev = initBox
lstmState = None
for dd in xrange(delta):
-if useImagenetInds:
+# bboxOn is the gt location in image1
+if useSimulater:
bboxOn = trackingObj.get_object_box()
else:
newKey = list(gtKey)
@@ -316,12 +328,12 @@ def main(FLAGS):
bboxOn = datasets[newKey[0]][imageIndex, :4].copy()
if dd == 0:
noisyBox = bboxOn.copy()
-elif not realMotion and not useImagenetInds and gtType >= USE_NETWORK_PROB:
+elif not realMotion and not useSimulater and gtType >= USE_NETWORK_PROB:
noisyBox = add_noise(bboxOn, bboxOn, images[0].shape[1], images[0].shape[0])
else:
noisyBox = fix_bbox_intersection(bboxPrev, bboxOn, images[0].shape[1], images[0].shape[0])
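add_noise and fix_bbox_intersection are repo helpers: the first perturbs a box so the network trains on imperfect previous estimates, the second nudges the previous box until it still intersects the target inside the image. A rough, hypothetical version of the jitter (the repo's actual noise model may differ):

    def add_noise_sketch(bbox, _, image_width, image_height):
        # bbox is [x1, y1, x2, y2]; jitter each corner by up to
        # 10% of the box size, then clip to the image bounds.
        w = bbox[2] - bbox[0]
        h = bbox[3] - bbox[1]
        noisy = np.array([bbox[0] + (random.random() - 0.5) * 0.2 * w,
                          bbox[1] + (random.random() - 0.5) * 0.2 * h,
                          bbox[2] + (random.random() - 0.5) * 0.2 * w,
                          bbox[3] + (random.random() - 0.5) * 0.2 * h])
        noisy[[0, 2]] = np.clip(noisy[[0, 2]], 0, image_width)
        noisy[[1, 3]] = np.clip(noisy[[1, 3]], 0, image_height)
        return noisy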
-if useImagenetInds:
+if useSimulater:
patch = simulater.render_patch(bboxPrev, background, trackedObjects)
tImage[dd,0,...] = patch
if dd > 0:
@@ -332,7 +344,7 @@ def main(FLAGS):
tImage[dd,0,...] = im_util.get_cropped_input(
images[max(dd-1, 0)], bboxPrev, CROP_PAD, CROP_SIZE)[0]
-if useImagenetInds:
+if useSimulater:
patch = simulater.render_patch(noisyBox, background, trackedObjects)
tImage[dd,1,...] = patch
if debug:
@@ -346,12 +358,11 @@ def main(FLAGS):
xywhLabels[dd,:] = shiftedBBoxXYWH
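Each label is the target box re-expressed in the coordinate frame of the padded input crop, as (center x, center y, width, height) normalized to [0, 1]; that normalization is why the mirroring code below only needs xywhLabels[...,0] = 1 - xywhLabels[...,0]. A hedged sketch of the conversion (the repo does this via its im_util helpers):

    def to_crop_relative_xywh_sketch(bbox, crop_bbox):
        # bbox, crop_bbox: [x1, y1, x2, y2] in image coordinates;
        # crop_bbox is the padded region the network actually sees.
        crop_w = float(crop_bbox[2] - crop_bbox[0])
        crop_h = float(crop_bbox[3] - crop_bbox[1])
        cx = ((bbox[0] + bbox[2]) / 2.0 - crop_bbox[0]) / crop_w
        cy = ((bbox[1] + bbox[3]) / 2.0 - crop_bbox[1]) / crop_h
        return np.array([cx, cy,
                         (bbox[2] - bbox[0]) / crop_w,
                         (bbox[3] - bbox[1]) / crop_h])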
# Get next box.
if gtType < USE_NETWORK_PROB:
+# Run through a single forward pass to get the next box estimate.
if dd < delta - 1:
-# Get next predicted box.
if dd == 0:
-lstmState = initialLstmState,
+lstmState = initialLstmState
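The one-character fix above is subtle but real: in Python a trailing comma makes a tuple (x = 5, binds x to (5,)), so the old line wrapped the initial LSTM state in an extra one-element tuple before it reached the feed dict.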
feed_dict = {
forwardNetworkImagePlaceholder : tImage[dd,...],
@@ -371,6 +382,7 @@ def main(FLAGS):
bboxPrev = bboxOn
if FLAGS.debug:
+# Look at the inputs to make sure they are correct.
image0 = tImage[dd,0,...].copy()
image1 = tImage[dd,1,...].copy()
@@ -394,7 +406,7 @@ def main(FLAGS):
cv2.imshow('debug', subplot[:,:,::-1])
cv2.waitKey(0)
-if mirroredInds:
+if mirrored:
tImage = np.fliplr(
tImage.transpose(2,3,4,0,1)).transpose(3,4,0,1,2)
xywhLabels[...,0] = 1 - xywhLabels[...,0]
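The transpose/fliplr/transpose round-trip moves the width axis of the (delta, 2, crop, crop, 3) array into np.fliplr's flipping position and back; an equivalent and arguably clearer form is a reversed slice on the width axis:

    # Same effect as the transpose/fliplr/transpose above:
    tImage = tImage[:, :, :, ::-1, :]            # flip width (axis 3)
    xywhLabels[..., 0] = 1 - xywhLabels[..., 0]  # mirror normalized center x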
@@ -414,6 +426,7 @@ def main(FLAGS):
for _ in xrange(10):
queue.enqueue(new_data)
else:
+# Start some data loading threads.
for i in range(PARALLEL_SIZE):
load_data_thread = threading.Thread(target=load_data)
load_data_thread.daemon = True
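load_data itself is outside this hunk; judging from the debug branch above, it presumably loops forever, building one sequence at a time and pushing it into the shared queue, along these lines:

    def load_data_sketch():
        while True:
            new_data = get_data_sequence()
            queue.enqueue(new_data)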
@@ -424,6 +437,7 @@ def main(FLAGS):
timeTotal = 0.000001
numIters = 0
iteration = startIter
+# Run training iterations in the main thread.
while iteration < FLAGS.max_steps:
if (iteration - 1) % 10 == 0:
currentTimeStart = time.time()
@@ -462,6 +476,7 @@ def main(FLAGS):
print 'Current Average: %.3f' % ((time.time() - currentTimeStart) / 10)
print ''
+# Save a checkpoint and remove old ones.
if iteration % 500 == 0 or iteration == FLAGS.max_steps:
checkpoint_file = os.path.join(LOG_DIR, 'checkpoints', 'model.ckpt')
saver.save(sess, checkpoint_file, global_step=iteration)
@@ -471,6 +486,7 @@ def main(FLAGS):
basename = os.path.basename(file)
if os.path.isfile(file) and str(iteration) not in file and 'checkpoint' not in basename:
os.remove(file)
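The manual cleanup loop reimplements what tf.train.Saver can do natively; the retention cap could instead be set at construction time, though the hand-rolled loop gives finer control over which files survive:

    # Alternative: let the Saver discard old checkpoints itself.
    saver = tf.train.Saver(max_to_keep=1)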
+# Every once in a while save a checkpoint that isn't ever removed except by hand.
if iteration % 10000 == 0 or iteration == FLAGS.max_steps:
if not os.path.exists(LOG_DIR + '/checkpoints/long_checkpoints'):
os.makedirs(LOG_DIR + '/checkpoints/long_checkpoints')
@@ -480,6 +496,7 @@ def main(FLAGS):
if (numIters == 1 or
iteration % 100 == 0 or
iteration == FLAGS.max_steps):
+# Write out the full graph sometimes.
if (numIters == 1 or
iteration == FLAGS.max_steps):
print 'Running detailed summary'
@@ -501,7 +518,7 @@ def main(FLAGS):
summary_writer.add_summary(summary_str, iteration)
summary_writer.flush()
if (FLAGS.run_val and (numIters == 1 or iteration % 500 == 0)):
-# Run a validation test in a separate process.
+# Run a validation set eval in a separate process.
def test_func():
test_iter_on = iteration
print 'Starting test iter', test_iter_on
@@ -523,6 +540,7 @@ def main(FLAGS):
test_thread.daemon = True
test_thread.start()
if FLAGS.output:
+# Look at some of the outputs.
print 'new batch'
queue.lock.acquire()
images = debug_feed_dict[imagePlaceholder].astype(np.uint8).reshape(
@@ -552,11 +570,13 @@ def main(FLAGS):
cv2.imshow('debug', subplot[:,:,::-1])
cv2.waitKey(0)
queue.lock.release()
-except KeyboardInterrupt:
+except:
+# Save if error or killed by ctrl-c.
if not debug:
print 'Saving...'
checkpoint_file = os.path.join(LOG_DIR, 'checkpoints', 'model.ckpt')
saver.save(sess, checkpoint_file, global_step=iteration)
raise
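Widening except KeyboardInterrupt: to a bare except: means the emergency checkpoint is written on any failure (data errors or OOM as well as ctrl-c), and the trailing raise preserves the original traceback. The explicit modern spelling would be except BaseException:, which makes the intent to also catch KeyboardInterrupt and SystemExit visible.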
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Training for Re3.')
@@ -573,6 +593,5 @@ if __name__ == '__main__':
parser.add_argument('--val_device', type=str, default='0', help='Device number or string for val process to use.')
parser.add_argument('-m', '--max_steps', type=int, default=NUM_ITERATIONS, help='Number of steps to run trainer.')
FLAGS = parser.parse_args()
-#tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
main(FLAGS)
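A typical launch, using a hypothetical script name (the real filename is not visible in this view) and flag spellings inferred from the FLAGS attributes referenced above:

    python re3_train_sketch.py --max_steps 1000000 \
        --cuda_visible_devices 0,1 --run_val --val_device 1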