mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 13:02:15 -07:00
fix xseg training
This commit is contained in:
parent
35945b257c
commit
254a7cf5cf
2 changed files with 62 additions and 62 deletions
|
@ -30,6 +30,9 @@ class XSegNet(object):
|
||||||
nn.initialize(data_format=data_format)
|
nn.initialize(data_format=data_format)
|
||||||
tf = nn.tf
|
tf = nn.tf
|
||||||
|
|
||||||
|
model_name = f'{name}_{resolution}'
|
||||||
|
self.model_filename_list = []
|
||||||
|
|
||||||
with tf.device ('/CPU:0'):
|
with tf.device ('/CPU:0'):
|
||||||
#Place holders on CPU
|
#Place holders on CPU
|
||||||
self.input_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,3) )
|
self.input_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,3) )
|
||||||
|
@ -39,18 +42,17 @@ class XSegNet(object):
|
||||||
with tf.device ('/CPU:0' if place_model_on_cpu else '/GPU:0'):
|
with tf.device ('/CPU:0' if place_model_on_cpu else '/GPU:0'):
|
||||||
self.model = nn.XSeg(3, 32, 1, name=name)
|
self.model = nn.XSeg(3, 32, 1, name=name)
|
||||||
self.model_weights = self.model.get_weights()
|
self.model_weights = self.model.get_weights()
|
||||||
|
|
||||||
model_name = f'{name}_{resolution}'
|
|
||||||
self.model_filename_list = [ [self.model, f'{model_name}.npy'] ]
|
|
||||||
|
|
||||||
if training:
|
if training:
|
||||||
if optimizer is None:
|
if optimizer is None:
|
||||||
raise ValueError("Optimizer should be provided for training mode.")
|
raise ValueError("Optimizer should be provided for training mode.")
|
||||||
|
|
||||||
self.opt = optimizer
|
self.opt = optimizer
|
||||||
self.opt.initialize_variables (self.model_weights, vars_on_cpu=place_model_on_cpu)
|
self.opt.initialize_variables (self.model_weights, vars_on_cpu=place_model_on_cpu)
|
||||||
self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ]
|
self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ]
|
||||||
else:
|
|
||||||
|
|
||||||
|
self.model_filename_list += [ [self.model, f'{model_name}.npy'] ]
|
||||||
|
|
||||||
|
if not training:
|
||||||
with tf.device ('/CPU:0' if run_on_cpu else '/GPU:0'):
|
with tf.device ('/CPU:0' if run_on_cpu else '/GPU:0'):
|
||||||
_, pred = self.model(self.input_t)
|
_, pred = self.model(self.input_t)
|
||||||
|
|
||||||
|
|
|
@ -81,8 +81,9 @@ class XSegModel(ModelBase):
|
||||||
gpu_loss_gvs = []
|
gpu_loss_gvs = []
|
||||||
|
|
||||||
for gpu_id in range(gpu_count):
|
for gpu_id in range(gpu_count):
|
||||||
with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
|
|
||||||
|
|
||||||
|
|
||||||
|
with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
|
||||||
with tf.device(f'/CPU:0'):
|
with tf.device(f'/CPU:0'):
|
||||||
# slice on CPU, otherwise all batch data will be transfered to GPU first
|
# slice on CPU, otherwise all batch data will be transfered to GPU first
|
||||||
batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
|
batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
|
||||||
|
@ -100,10 +101,10 @@ class XSegModel(ModelBase):
|
||||||
|
|
||||||
|
|
||||||
# Average losses and gradients, and create optimizer update ops
|
# Average losses and gradients, and create optimizer update ops
|
||||||
with tf.device (models_opt_device):
|
with tf.device(f'/CPU:0'):
|
||||||
pred = nn.concat(gpu_pred_list, 0)
|
#with tf.device (models_opt_device):
|
||||||
loss = tf.reduce_mean(gpu_losses)
|
pred = tf.concat(gpu_pred_list, 0)
|
||||||
|
loss = tf.concat(gpu_losses, 0)
|
||||||
loss_gv_op = self.model.opt.get_update_op (nn.average_gv_list (gpu_loss_gvs))
|
loss_gv_op = self.model.opt.get_update_op (nn.average_gv_list (gpu_loss_gvs))
|
||||||
|
|
||||||
|
|
||||||
|
@ -157,12 +158,9 @@ class XSegModel(ModelBase):
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def onTrainOneIter(self):
|
def onTrainOneIter(self):
|
||||||
|
|
||||||
|
|
||||||
image_np, mask_np = self.generate_next_samples()[0]
|
image_np, mask_np = self.generate_next_samples()[0]
|
||||||
loss = self.train (image_np, mask_np)
|
loss = self.train (image_np, mask_np)
|
||||||
|
return ( ('loss', np.mean(loss) ), )
|
||||||
return ( ('loss', loss ), )
|
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def onGetPreview(self, samples):
|
def onGetPreview(self, samples):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue