Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
R
re3-tensorflow
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Daniel Gordon
re3-tensorflow
Commits
e4f8b246
Commit
e4f8b246
authored
7 years ago
by
Daniel Gordon
Browse files
Options
Downloads
Patches
Plain Diff
comments
parent
00df2250
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
training/unrolled_solver.py
+36
-17
36 additions, 17 deletions
training/unrolled_solver.py
with
36 additions
and
17 deletions
training/unrolled_solver.py
+
36
−
17
View file @
e4f8b246
...
...
@@ -50,7 +50,7 @@ SIMULATION_WIDTH = simulater.IMAGE_WIDTH
SIMULATION_HEIGHT
=
simulater
.
IMAGE_HEIGHT
simulater
.
NUM_DISTRACTORS
=
20
USE_IM
AGENET_PROB
=
0.5
USE_
S
IM
ULATER
=
0.5
USE_NETWORK_PROB
=
0.8
REAL_MOTION_PROB
=
1.0
/
8
AREA_CUTOFF
=
0.25
...
...
@@ -144,7 +144,7 @@ def main(FLAGS):
np
.
set_printoptions
(
suppress
=
True
)
np
.
set_printoptions
(
precision
=
4
)
# Read in and format GT
# Read in and format GT
.
# dict from (dataset_ind, video_id, track_id, image_id) to line in labels array
key_lookup
=
dict
()
datasets
=
[]
...
...
@@ -159,7 +159,7 @@ def main(FLAGS):
add_dataset
(
'
imagenet_video
'
)
# Tensorflow stu
ff
# Tensorflow s
e
tu
p
if
not
os
.
path
.
exists
(
LOG_DIR
):
os
.
makedirs
(
LOG_DIR
)
if
not
os
.
path
.
exists
(
LOG_DIR
+
'
/checkpoints
'
):
...
...
@@ -176,6 +176,7 @@ def main(FLAGS):
labelPlaceholder
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
(
ENQUEUE_BATCH_SIZE
,
delta
,
4
))
learningRate
=
tf
.
placeholder
(
tf
.
float32
)
# Set up the data queue for holding images and retrieving them from RAM rather than disk.
queue
=
tf_queue
.
TFQueue
(
sess
,
placeholders
=
[
imagePlaceholder
,
labelPlaceholder
],
max_queue_size
=
REPLAY_BUFFER_SIZE
,
...
...
@@ -205,6 +206,8 @@ def main(FLAGS):
saver
=
tf
.
train
.
Saver
()
longSaver
=
tf
.
train
.
Saver
()
# Create the nodes for single image forward passes for learning to fix mistakes.
# Parameters here are shared with the learned network.
if
'
,
'
in
FLAGS
.
cuda_visible_devices
:
with
tf
.
device
(
'
/gpu:1
'
):
forwardNetworkImagePlaceholder
=
tf
.
placeholder
(
tf
.
uint8
,
shape
=
(
2
,
CROP_SIZE
,
CROP_SIZE
,
3
))
...
...
@@ -221,6 +224,7 @@ def main(FLAGS):
forwardNetworkImagePlaceholder
,
num_unrolls
=
1
,
train
=
False
,
prevLstmState
=
prevLstmState
,
reuse
=
True
)
# Initialize the network and load saved parameters.
sess
.
run
(
init
)
startIter
=
0
if
FLAGS
.
restore
:
...
...
@@ -243,6 +247,7 @@ def main(FLAGS):
tf_util
.
conv_variable_summaries
(
var
,
scope
=
var
.
name
.
replace
(
'
/
'
,
'
_
'
)[:
-
2
])
summary_with_images
=
tf
.
summary
.
merge_all
()
# Logging stuff
robustness_ph
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
[])
lost_targets_ph
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
[])
mean_iou_ph
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
[])
...
...
@@ -261,19 +266,22 @@ def main(FLAGS):
sess
.
graph
.
finalize
()
# Read a sequence from the batch cache or simulate one and get the ground truth and crops for it.
def
get_data_sequence
():
# Preallocate the space for the images and labels.
tImage
=
np
.
zeros
((
delta
,
2
,
CROP_SIZE
,
CROP_SIZE
,
3
),
dtype
=
np
.
uint8
)
xywhLabels
=
np
.
zeros
((
delta
,
4
))
mirrored
Inds
=
random
.
random
()
<
0.5
use
ImagenetInds
=
random
.
random
()
<
USE_IM
AGENET_PROB
mirrored
=
random
.
random
()
<
0.5
use
Simulater
=
random
.
random
()
<
USE_
S
IM
ULATER
gtType
=
random
.
random
()
realMotion
=
random
.
random
()
<
REAL_MOTION_PROB
# Initialize first frame (give the network context).
if
useImagenetInds
:
if
useSimulater
:
# Initialize the simulation and run through a few frames.
trackingObj
,
trackedObjects
,
background
=
simulater
.
create_new_track
()
for
_
in
xrange
(
random
.
randint
(
0
,
200
)):
simulater
.
step
(
trackedObjects
)
...
...
@@ -294,6 +302,7 @@ def main(FLAGS):
images
=
[
np
.
zeros
((
SIMULATION_HEIGHT
,
SIMULATION_WIDTH
))]
else
:
# Read a new data sequence from batch cache and get the ground truth.
(
batchKey
,
images
)
=
getData
()
gtKey
=
batchKey
imageIndex
=
key_lookup
[
gtKey
]
...
...
@@ -302,11 +311,14 @@ def main(FLAGS):
bboxes
=
[]
cropBBoxes
=
[]
# bboxPrev starts at the initial box and is the best guess (or gt) for the image0 location.
# noisyBox holds the bboxPrev estimate plus some noise.
bboxPrev
=
initBox
lstmState
=
None
for
dd
in
xrange
(
delta
):
if
useImagenetInds
:
# bboxOn is the gt location in image1
if
useSimulater
:
bboxOn
=
trackingObj
.
get_object_box
()
else
:
newKey
=
list
(
gtKey
)
...
...
@@ -316,12 +328,12 @@ def main(FLAGS):
bboxOn
=
datasets
[
newKey
[
0
]][
imageIndex
,
:
4
].
copy
()
if
dd
==
0
:
noisyBox
=
bboxOn
.
copy
()
elif
not
realMotion
and
not
use
ImagenetInds
and
gtType
>=
USE_NETWORK_PROB
:
elif
not
realMotion
and
not
use
Simulater
and
gtType
>=
USE_NETWORK_PROB
:
noisyBox
=
add_noise
(
bboxOn
,
bboxOn
,
images
[
0
].
shape
[
1
],
images
[
0
].
shape
[
0
])
else
:
noisyBox
=
fix_bbox_intersection
(
bboxPrev
,
bboxOn
,
images
[
0
].
shape
[
1
],
images
[
0
].
shape
[
0
])
if
use
ImagenetInds
:
if
use
Simulater
:
patch
=
simulater
.
render_patch
(
bboxPrev
,
background
,
trackedObjects
)
tImage
[
dd
,
0
,...]
=
patch
if
dd
>
0
:
...
...
@@ -332,7 +344,7 @@ def main(FLAGS):
tImage
[
dd
,
0
,...]
=
im_util
.
get_cropped_input
(
images
[
max
(
dd
-
1
,
0
)],
bboxPrev
,
CROP_PAD
,
CROP_SIZE
)[
0
]
if
use
ImagenetInds
:
if
use
Simulater
:
patch
=
simulater
.
render_patch
(
noisyBox
,
background
,
trackedObjects
)
tImage
[
dd
,
1
,...]
=
patch
if
debug
:
...
...
@@ -346,12 +358,11 @@ def main(FLAGS):
xywhLabels
[
dd
,:]
=
shiftedBBoxXYWH
# Get next box.
if
gtType
<
USE_NETWORK_PROB
:
# Run through a single forward pass to get the next box estimate.
if
dd
<
delta
-
1
:
# Get next predicted box.
if
dd
==
0
:
lstmState
=
initialLstmState
,
lstmState
=
initialLstmState
feed_dict
=
{
forwardNetworkImagePlaceholder
:
tImage
[
dd
,...],
...
...
@@ -371,6 +382,7 @@ def main(FLAGS):
bboxPrev
=
bboxOn
if
FLAGS
.
debug
:
# Look at the inputs to make sure they are correct.
image0
=
tImage
[
dd
,
0
,...].
copy
()
image1
=
tImage
[
dd
,
1
,...].
copy
()
...
...
@@ -394,7 +406,7 @@ def main(FLAGS):
cv2
.
imshow
(
'
debug
'
,
subplot
[:,:,::
-
1
])
cv2
.
waitKey
(
0
)
if
mirrored
Inds
:
if
mirrored
:
tImage
=
np
.
fliplr
(
tImage
.
transpose
(
2
,
3
,
4
,
0
,
1
)).
transpose
(
3
,
4
,
0
,
1
,
2
)
xywhLabels
[...,
0
]
=
1
-
xywhLabels
[...,
0
]
...
...
@@ -414,6 +426,7 @@ def main(FLAGS):
for
_
in
xrange
(
10
):
queue
.
enqueue
(
new_data
)
else
:
# Start some data loading threads.
for
i
in
range
(
PARALLEL_SIZE
):
load_data_thread
=
threading
.
Thread
(
target
=
load_data
)
load_data_thread
.
daemon
=
True
...
...
@@ -424,6 +437,7 @@ def main(FLAGS):
timeTotal
=
0.000001
numIters
=
0
iteration
=
startIter
# Run training iterations in the main thread.
while
iteration
<
FLAGS
.
max_steps
:
if
(
iteration
-
1
)
%
10
==
0
:
currentTimeStart
=
time
.
time
()
...
...
@@ -462,6 +476,7 @@ def main(FLAGS):
print
'
Current Average: %.3f
'
%
((
time
.
time
()
-
currentTimeStart
)
/
10
)
print
''
# Save a checkpoint and remove old ones.
if
iteration
%
500
==
0
or
iteration
==
FLAGS
.
max_steps
:
checkpoint_file
=
os
.
path
.
join
(
LOG_DIR
,
'
checkpoints
'
,
'
model.ckpt
'
)
saver
.
save
(
sess
,
checkpoint_file
,
global_step
=
iteration
)
...
...
@@ -471,6 +486,7 @@ def main(FLAGS):
basename
=
os
.
path
.
basename
(
file
)
if
os
.
path
.
isfile
(
file
)
and
str
(
iteration
)
not
in
file
and
'
checkpoint
'
not
in
basename
:
os
.
remove
(
file
)
# Every once in a while save a checkpoint that isn't ever removed except by hand.
if
iteration
%
10000
==
0
or
iteration
==
FLAGS
.
max_steps
:
if
not
os
.
path
.
exists
(
LOG_DIR
+
'
/checkpoints/long_checkpoints
'
):
os
.
makedirs
(
LOG_DIR
+
'
/checkpoints/long_checkpoints
'
)
...
...
@@ -480,6 +496,7 @@ def main(FLAGS):
if
(
numIters
==
1
or
iteration
%
100
==
0
or
iteration
==
FLAGS
.
max_steps
):
# Write out the full graph sometimes.
if
(
numIters
==
1
or
iteration
==
FLAGS
.
max_steps
):
print
'
Running detailed summary
'
...
...
@@ -501,7 +518,7 @@ def main(FLAGS):
summary_writer
.
add_summary
(
summary_str
,
iteration
)
summary_writer
.
flush
()
if
(
FLAGS
.
run_val
and
(
numIters
==
1
or
iteration
%
500
==
0
)):
# Run a validation
test
in a separate process.
# Run a validation
set eval
in a separate process.
def
test_func
():
test_iter_on
=
iteration
print
'
Staring test iter
'
,
test_iter_on
...
...
@@ -523,6 +540,7 @@ def main(FLAGS):
test_thread
.
daemon
=
True
test_thread
.
start
()
if
FLAGS
.
output
:
# Look at some of the outputs.
print
'
new batch
'
queue
.
lock
.
acquire
()
images
=
debug_feed_dict
[
imagePlaceholder
].
astype
(
np
.
uint8
).
reshape
(
...
...
@@ -552,11 +570,13 @@ def main(FLAGS):
cv2
.
imshow
(
'
debug
'
,
subplot
[:,:,::
-
1
])
cv2
.
waitKey
(
0
)
queue
.
lock
.
release
()
except
KeyboardInterrupt
:
except
:
# Save if error or killed by ctrl-c.
if
not
debug
:
print
'
Saving...
'
checkpoint_file
=
os
.
path
.
join
(
LOG_DIR
,
'
checkpoints
'
,
'
model.ckpt
'
)
saver
.
save
(
sess
,
checkpoint_file
,
global_step
=
iteration
)
raise
if
__name__
==
'
__main__
'
:
parser
=
argparse
.
ArgumentParser
(
description
=
'
Training for Re3.
'
)
...
...
@@ -573,6 +593,5 @@ if __name__ == '__main__':
parser
.
add_argument
(
'
--val_device
'
,
type
=
str
,
default
=
'
0
'
,
help
=
'
Device number or string for val process to use.
'
)
parser
.
add_argument
(
'
-m
'
,
'
--max_steps
'
,
type
=
int
,
default
=
NUM_ITERATIONS
,
help
=
'
Number of steps to run trainer.
'
)
FLAGS
=
parser
.
parse_args
()
#tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
main
(
FLAGS
)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment