added tf_cpu_mode option for funit models

This commit is contained in:
Colombo 2019-09-21 09:44:34 +04:00
parent d9d10f91c2
commit 2a3b3f0021
3 changed files with 20 additions and 5 deletions

View file

@ -30,7 +30,8 @@ class FUNIT(object):
load_weights_locally=False,
weights_file_root=None,
is_training=True
is_training=True,
tf_cpu_mode=0,
):
exec( nnlib.import_all(), locals(), globals() )
@ -47,8 +48,8 @@ class FUNIT(object):
self.dis = modelify ( FUNIT.DiscriminatorFlow(nf=dis_nf, n_res_blks=dis_res_blks, num_classes=num_classes) ) (Input(bgr_shape))
self.G_opt = RMSprop(lr=0.0001, decay=0.0001, tf_cpu_mode=2 if 'tensorflow' in nnlib.active_DeviceConfig.backend else 0)
self.D_opt = RMSprop(lr=0.0001, decay=0.0001, tf_cpu_mode=2 if 'tensorflow' in nnlib.active_DeviceConfig.backend else 0)
self.G_opt = RMSprop(lr=0.0001, decay=0.0001, tf_cpu_mode=tf_cpu_mode)
self.D_opt = RMSprop(lr=0.0001, decay=0.0001, tf_cpu_mode=tf_cpu_mode)
xa = Input(bgr_shape, name="xa")
la = Input(label_shape, dtype="int32", name="la")