mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
fix xseg training
This commit is contained in:
parent
35945b257c
commit
254a7cf5cf
2 changed files with 62 additions and 62 deletions
|
@ -29,7 +29,10 @@ class XSegNet(object):
|
|||
|
||||
nn.initialize(data_format=data_format)
|
||||
tf = nn.tf
|
||||
|
||||
|
||||
model_name = f'{name}_{resolution}'
|
||||
self.model_filename_list = []
|
||||
|
||||
with tf.device ('/CPU:0'):
|
||||
#Place holders on CPU
|
||||
self.input_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,3) )
|
||||
|
@ -39,18 +42,17 @@ class XSegNet(object):
|
|||
with tf.device ('/CPU:0' if place_model_on_cpu else '/GPU:0'):
|
||||
self.model = nn.XSeg(3, 32, 1, name=name)
|
||||
self.model_weights = self.model.get_weights()
|
||||
if training:
|
||||
if optimizer is None:
|
||||
raise ValueError("Optimizer should be provided for training mode.")
|
||||
self.opt = optimizer
|
||||
self.opt.initialize_variables (self.model_weights, vars_on_cpu=place_model_on_cpu)
|
||||
self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ]
|
||||
|
||||
|
||||
self.model_filename_list += [ [self.model, f'{model_name}.npy'] ]
|
||||
|
||||
model_name = f'{name}_{resolution}'
|
||||
self.model_filename_list = [ [self.model, f'{model_name}.npy'] ]
|
||||
|
||||
if training:
|
||||
if optimizer is None:
|
||||
raise ValueError("Optimizer should be provided for training mode.")
|
||||
|
||||
self.opt = optimizer
|
||||
self.opt.initialize_variables (self.model_weights, vars_on_cpu=place_model_on_cpu)
|
||||
self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ]
|
||||
else:
|
||||
if not training:
|
||||
with tf.device ('/CPU:0' if run_on_cpu else '/GPU:0'):
|
||||
_, pred = self.model(self.input_t)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue