Mirror of https://github.com/iperov/DeepFaceLab.git
fix xseg training

commit 254a7cf5cf (parent 35945b257c)
2 changed files with 62 additions and 62 deletions
@@ -29,7 +29,10 @@ class XSegNet(object):
 
         nn.initialize(data_format=data_format)
         tf = nn.tf
 
+        model_name = f'{name}_{resolution}'
+        self.model_filename_list = []
+
         with tf.device ('/CPU:0'):
             #Place holders on CPU
             self.input_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,3) )
@@ -39,18 +42,17 @@ class XSegNet(object):
         with tf.device ('/CPU:0' if place_model_on_cpu else '/GPU:0'):
             self.model = nn.XSeg(3, 32, 1, name=name)
             self.model_weights = self.model.get_weights()
+            if training:
+                if optimizer is None:
+                    raise ValueError("Optimizer should be provided for training mode.")
+                self.opt = optimizer
+                self.opt.initialize_variables (self.model_weights, vars_on_cpu=place_model_on_cpu)
+                self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ]
 
-        model_name = f'{name}_{resolution}'
-        self.model_filename_list = [ [self.model, f'{model_name}.npy'] ]
-
-        if training:
-            if optimizer is None:
-                raise ValueError("Optimizer should be provided for training mode.")
-
-            self.opt = optimizer
-            self.opt.initialize_variables (self.model_weights, vars_on_cpu=place_model_on_cpu)
-            self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ]
-        else:
+        self.model_filename_list += [ [self.model, f'{model_name}.npy'] ]
+
+        if not training:
             with tf.device ('/CPU:0' if run_on_cpu else '/GPU:0'):
                 _, pred = self.model(self.input_t)
 
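The two hunks above change how XSegNet registers its weight files: model_filename_list now starts empty, gains the optimizer state file only in training mode, and always ends with the model's own {name}_{resolution}.npy entry, while inference-only construction builds the prediction graph under "if not training:" instead of the former "else:" branch. A minimal sketch of how such a registry is typically consumed when saving; this is a hypothetical helper for illustration, not the repository's exact method:

    from pathlib import Path

    def save_registered_weights(model_filename_list, weights_file_root):
        # Each entry pairs a trainable object with the .npy file it is stored in,
        # e.g. [ [opt, 'XSeg_256_opt.npy'], [model, 'XSeg_256.npy'] ] in training mode.
        for obj, filename in model_filename_list:
            obj.save_weights( str(Path(weights_file_root) / filename) )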
@@ -15,23 +15,23 @@ class XSegModel(ModelBase):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, force_model_class_name='XSeg', **kwargs)
 
     #override
     def on_initialize_options(self):
         ask_override = self.ask_override()
 
         if not self.is_first_run() and ask_override:
             if io.input_bool(f"Restart training?", False, help_message="Reset model weights and start training from scratch."):
                 self.set_iter(0)
 
         default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf')
 
         if self.is_first_run():
             self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Choose the same as your deepfake model.").lower()
 
         if self.is_first_run() or ask_override:
             self.ask_batch_size(4, range=[2,16])
 
     #override
     def on_initialize(self):
         device_config = nn.getCurrentDeviceConfig()
@@ -44,29 +44,29 @@ class XSegModel(ModelBase):
 
         self.resolution = resolution = 256
 
 
         self.face_type = {'h'    : FaceType.HALF,
                           'mf'   : FaceType.MID_FULL,
                           'f'    : FaceType.FULL,
                           'wf'   : FaceType.WHOLE_FACE,
                           'head' : FaceType.HEAD}[ self.options['face_type'] ]
 
         place_model_on_cpu = len(devices) == 0
         models_opt_device = '/CPU:0' if place_model_on_cpu else '/GPU:0'
 
         bgr_shape = nn.get4Dshape(resolution,resolution,3)
         mask_shape = nn.get4Dshape(resolution,resolution,1)
 
         # Initializing model classes
         self.model = XSegNet(name='XSeg',
                              resolution=resolution,
                              load_weights=not self.is_first_run(),
                              weights_file_root=self.get_model_root_path(),
                              training=True,
                              place_model_on_cpu=place_model_on_cpu,
                              optimizer=nn.RMSprop(lr=0.0001, lr_dropout=0.3, name='opt'),
                              data_format=nn.data_format)
 
         if self.is_training:
             # Adjust batch size for multiple GPU
             gpu_count = max(1, len(devices) )
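The trainer above constructs the network in training mode with an RMSprop optimizer, placing the model on CPU only when no GPU is available. For contrast, a hypothetical inference-side construction using the same keyword arguments that appear in this diff; the path, flag values, and import lines are illustrative and not taken from the commit:

    # Illustrative imports; adjust to the repository layout.
    from facelib import XSegNet
    from core.leras import nn

    xseg = XSegNet(name='XSeg',
                   resolution=256,
                   load_weights=True,
                   weights_file_root='/path/to/model_dir',   # placeholder path
                   training=False,       # takes the new `if not training:` branch,
                   run_on_cpu=True,      # which only builds the prediction graph
                   data_format=nn.data_format)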
@@ -79,20 +79,21 @@ class XSegModel(ModelBase):
 
             gpu_losses = []
             gpu_loss_gvs = []
 
-            for gpu_id in range(gpu_count):
-                with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
-
+            for gpu_id in range(gpu_count):
+
+
+                with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
                     with tf.device(f'/CPU:0'):
                         # slice on CPU, otherwise all batch data will be transfered to GPU first
                         batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
                         gpu_input_t  = self.model.input_t [batch_slice,:,:,:]
                         gpu_target_t = self.model.target_t [batch_slice,:,:,:]
 
                     # process model tensors
                     gpu_pred_logits_t, gpu_pred_t = self.model.flow(gpu_input_t)
                     gpu_pred_list.append(gpu_pred_t)
 
                     gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3])
                     gpu_losses += [gpu_loss]
 
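In the loop above each GPU gets an equal slice of the global batch: the slice is taken on the CPU placeholder so only that GPU's share of the data is transferred to it, and the per-sample sigmoid cross-entropy is reduced over the height, width and channel axes, leaving one loss value per sample. A small standalone sketch of the slicing arithmetic (values illustrative):

    # Illustrative only: how batch_slice partitions a global batch of 8 across 2 GPUs.
    gpu_count = 2
    bs_per_gpu = 4   # global batch size = gpu_count * bs_per_gpu = 8
    for gpu_id in range(gpu_count):
        batch_slice = slice(gpu_id * bs_per_gpu, (gpu_id + 1) * bs_per_gpu)
        print(gpu_id, batch_slice)   # 0 -> slice(0, 4), 1 -> slice(4, 8)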
@@ -100,13 +101,13 @@ class XSegModel(ModelBase):
 
 
             # Average losses and gradients, and create optimizer update ops
-            with tf.device (models_opt_device):
-                pred = nn.concat(gpu_pred_list, 0)
-                loss = tf.reduce_mean(gpu_losses)
-
+            with tf.device(f'/CPU:0'):
+                #with tf.device (models_opt_device):
+                pred = tf.concat(gpu_pred_list, 0)
+                loss = tf.concat(gpu_losses, 0)
                 loss_gv_op = self.model.opt.get_update_op (nn.average_gv_list (gpu_loss_gvs))
 
 
             # Initializing training and view functions
             def train(input_np, target_np):
                 l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np })
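Where the graph previously collapsed the per-GPU losses to a scalar with tf.reduce_mean, it now concatenates them into a single per-sample vector; the scalar reported during training is taken later with np.mean(loss) in the onTrainOneIter hunk further below, and the optimizer update still averages the per-GPU gradient lists with nn.average_gv_list. A small numpy sketch, with illustrative numbers, showing that the two reductions agree when every GPU receives an equal slice:

    import numpy as np

    # Per-sample losses returned by two GPUs with equal slice sizes (made-up values).
    gpu_losses = [np.array([0.30, 0.50]), np.array([0.10, 0.70])]

    old_style = np.mean(np.stack(gpu_losses))           # single scalar built in the graph before
    new_style = np.mean(np.concatenate(gpu_losses, 0))  # mean of the concatenated per-sample vector

    assert np.isclose(old_style, new_style)             # both are 0.40 here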
@@ -122,29 +123,29 @@ class XSegModel(ModelBase):
             src_dst_generators_count = cpu_count // 2
             src_generators_count = cpu_count // 2
             dst_generators_count = cpu_count // 2
 
 
             srcdst_generator = SampleGeneratorFaceXSeg([self.training_data_src_path, self.training_data_dst_path],
                                                        debug=self.is_debug(),
                                                        batch_size=self.get_batch_size(),
                                                        resolution=resolution,
                                                        face_type=self.face_type,
                                                        generators_count=src_dst_generators_count,
                                                        data_format=nn.data_format)
 
             src_generator = SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                                                 sample_process_options=SampleProcessor.Options(random_flip=False),
                                                 output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                                       ],
                                                 generators_count=src_generators_count,
                                                 raise_on_no_data=False )
             dst_generator = SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
                                                 sample_process_options=SampleProcessor.Options(random_flip=False),
                                                 output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                                                       ],
                                                 generators_count=dst_generators_count,
                                                 raise_on_no_data=False )
 
             self.set_training_data_generators ([srcdst_generator, src_generator, dst_generator])
 
         #override
@@ -154,21 +155,18 @@ class XSegModel(ModelBase):
     #override
     def onSave(self):
         self.model.save_weights()
 
-    #override
-    def onTrainOneIter(self):
-
-
-        image_np, mask_np = self.generate_next_samples()[0]
-        loss = self.train (image_np, mask_np)
-
-        return ( ('loss', loss ), )
+    #override
+    def onTrainOneIter(self):
+        image_np, mask_np = self.generate_next_samples()[0]
+        loss = self.train (image_np, mask_np)
+        return ( ('loss', np.mean(loss) ), )
 
     #override
     def onGetPreview(self, samples):
         n_samples = min(4, self.get_batch_size(), 800 // self.resolution )
 
         srcdst_samples, src_samples, dst_samples = samples
         image_np, mask_np = srcdst_samples
 
         I, M, IM, = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([image_np,mask_np] + self.view (image_np) ) ]
@@ -176,41 +174,41 @@ class XSegModel(ModelBase):
 
         green_bg = np.tile( np.array([0,1,0], dtype=np.float32)[None,None,...], (self.resolution,self.resolution,1) )
 
         result = []
         st = []
         for i in range(n_samples):
             ar = I[i]*M[i]+0.5*I[i]*(1-M[i])+0.5*green_bg*(1-M[i]), IM[i], I[i]*IM[i]+0.5*I[i]*(1-IM[i]) + 0.5*green_bg*(1-IM[i])
             st.append ( np.concatenate ( ar, axis=1) )
         result += [ ('XSeg training faces', np.concatenate (st, axis=0 )), ]
 
         if len(src_samples) != 0:
             src_np, = src_samples
 
 
             D, DM, = [ np.clip(nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([src_np] + self.view (src_np) ) ]
             DM, = [ np.repeat (x, (3,), -1) for x in [DM] ]
 
             st = []
             for i in range(n_samples):
                 ar = D[i], DM[i], D[i]*DM[i] + 0.5*D[i]*(1-DM[i]) + 0.5*green_bg*(1-DM[i])
                 st.append ( np.concatenate ( ar, axis=1) )
 
             result += [ ('XSeg src faces', np.concatenate (st, axis=0 )), ]
 
         if len(dst_samples) != 0:
             dst_np, = dst_samples
 
 
             D, DM, = [ np.clip(nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([dst_np] + self.view (dst_np) ) ]
             DM, = [ np.repeat (x, (3,), -1) for x in [DM] ]
 
             st = []
             for i in range(n_samples):
                 ar = D[i], DM[i], D[i]*DM[i] + 0.5*D[i]*(1-DM[i]) + 0.5*green_bg*(1-DM[i])
                 st.append ( np.concatenate ( ar, axis=1) )
 
             result += [ ('XSeg dst faces', np.concatenate (st, axis=0 )), ]
 
         return result
 
 Model = XSegModel
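The preview code above overlays each mask on its face by keeping the masked area and showing the rest half-dimmed and half-tinted green. A standalone sketch of that compositing formula on a single image; shapes and values are illustrative:

    import numpy as np

    resolution = 256
    green_bg = np.tile(np.array([0, 1, 0], dtype=np.float32)[None, None, ...],
                       (resolution, resolution, 1))

    I = np.random.rand(resolution, resolution, 3).astype(np.float32)           # face image in [0, 1]
    M = (np.random.rand(resolution, resolution, 1) > 0.5).astype(np.float32)   # binary mask

    # Same blend used in onGetPreview: masked area kept as-is, the rest half image, half green.
    overlay = I * M + 0.5 * I * (1 - M) + 0.5 * green_bg * (1 - M)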