fix xseg training

This commit is contained in:
iperov 2020-12-11 15:47:11 +04:00
parent 35945b257c
commit 254a7cf5cf
2 changed files with 62 additions and 62 deletions

View file

@ -29,7 +29,10 @@ class XSegNet(object):
nn.initialize(data_format=data_format) nn.initialize(data_format=data_format)
tf = nn.tf tf = nn.tf
model_name = f'{name}_{resolution}'
self.model_filename_list = []
with tf.device ('/CPU:0'): with tf.device ('/CPU:0'):
#Place holders on CPU #Place holders on CPU
self.input_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,3) ) self.input_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,3) )
@ -39,18 +42,17 @@ class XSegNet(object):
with tf.device ('/CPU:0' if place_model_on_cpu else '/GPU:0'): with tf.device ('/CPU:0' if place_model_on_cpu else '/GPU:0'):
self.model = nn.XSeg(3, 32, 1, name=name) self.model = nn.XSeg(3, 32, 1, name=name)
self.model_weights = self.model.get_weights() self.model_weights = self.model.get_weights()
if training:
if optimizer is None:
raise ValueError("Optimizer should be provided for training mode.")
self.opt = optimizer
self.opt.initialize_variables (self.model_weights, vars_on_cpu=place_model_on_cpu)
self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ]
self.model_filename_list += [ [self.model, f'{model_name}.npy'] ]
model_name = f'{name}_{resolution}' if not training:
self.model_filename_list = [ [self.model, f'{model_name}.npy'] ]
if training:
if optimizer is None:
raise ValueError("Optimizer should be provided for training mode.")
self.opt = optimizer
self.opt.initialize_variables (self.model_weights, vars_on_cpu=place_model_on_cpu)
self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ]
else:
with tf.device ('/CPU:0' if run_on_cpu else '/GPU:0'): with tf.device ('/CPU:0' if run_on_cpu else '/GPU:0'):
_, pred = self.model(self.input_t) _, pred = self.model(self.input_t)

View file

@ -15,23 +15,23 @@ class XSegModel(ModelBase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, force_model_class_name='XSeg', **kwargs) super().__init__(*args, force_model_class_name='XSeg', **kwargs)
#override #override
def on_initialize_options(self): def on_initialize_options(self):
ask_override = self.ask_override() ask_override = self.ask_override()
if not self.is_first_run() and ask_override: if not self.is_first_run() and ask_override:
if io.input_bool(f"Restart training?", False, help_message="Reset model weights and start training from scratch."): if io.input_bool(f"Restart training?", False, help_message="Reset model weights and start training from scratch."):
self.set_iter(0) self.set_iter(0)
default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf') default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf')
if self.is_first_run(): if self.is_first_run():
self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Choose the same as your deepfake model.").lower() self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Choose the same as your deepfake model.").lower()
if self.is_first_run() or ask_override: if self.is_first_run() or ask_override:
self.ask_batch_size(4, range=[2,16]) self.ask_batch_size(4, range=[2,16])
#override #override
def on_initialize(self): def on_initialize(self):
device_config = nn.getCurrentDeviceConfig() device_config = nn.getCurrentDeviceConfig()
@ -44,29 +44,29 @@ class XSegModel(ModelBase):
self.resolution = resolution = 256 self.resolution = resolution = 256
self.face_type = {'h' : FaceType.HALF, self.face_type = {'h' : FaceType.HALF,
'mf' : FaceType.MID_FULL, 'mf' : FaceType.MID_FULL,
'f' : FaceType.FULL, 'f' : FaceType.FULL,
'wf' : FaceType.WHOLE_FACE, 'wf' : FaceType.WHOLE_FACE,
'head' : FaceType.HEAD}[ self.options['face_type'] ] 'head' : FaceType.HEAD}[ self.options['face_type'] ]
place_model_on_cpu = len(devices) == 0 place_model_on_cpu = len(devices) == 0
models_opt_device = '/CPU:0' if place_model_on_cpu else '/GPU:0' models_opt_device = '/CPU:0' if place_model_on_cpu else '/GPU:0'
bgr_shape = nn.get4Dshape(resolution,resolution,3) bgr_shape = nn.get4Dshape(resolution,resolution,3)
mask_shape = nn.get4Dshape(resolution,resolution,1) mask_shape = nn.get4Dshape(resolution,resolution,1)
# Initializing model classes # Initializing model classes
self.model = XSegNet(name='XSeg', self.model = XSegNet(name='XSeg',
resolution=resolution, resolution=resolution,
load_weights=not self.is_first_run(), load_weights=not self.is_first_run(),
weights_file_root=self.get_model_root_path(), weights_file_root=self.get_model_root_path(),
training=True, training=True,
place_model_on_cpu=place_model_on_cpu, place_model_on_cpu=place_model_on_cpu,
optimizer=nn.RMSprop(lr=0.0001, lr_dropout=0.3, name='opt'), optimizer=nn.RMSprop(lr=0.0001, lr_dropout=0.3, name='opt'),
data_format=nn.data_format) data_format=nn.data_format)
if self.is_training: if self.is_training:
# Adjust batch size for multiple GPU # Adjust batch size for multiple GPU
gpu_count = max(1, len(devices) ) gpu_count = max(1, len(devices) )
@ -79,20 +79,21 @@ class XSegModel(ModelBase):
gpu_losses = [] gpu_losses = []
gpu_loss_gvs = [] gpu_loss_gvs = []
for gpu_id in range(gpu_count):
with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
for gpu_id in range(gpu_count):
with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
with tf.device(f'/CPU:0'): with tf.device(f'/CPU:0'):
# slice on CPU, otherwise all batch data will be transfered to GPU first # slice on CPU, otherwise all batch data will be transfered to GPU first
batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
gpu_input_t = self.model.input_t [batch_slice,:,:,:] gpu_input_t = self.model.input_t [batch_slice,:,:,:]
gpu_target_t = self.model.target_t [batch_slice,:,:,:] gpu_target_t = self.model.target_t [batch_slice,:,:,:]
# process model tensors # process model tensors
gpu_pred_logits_t, gpu_pred_t = self.model.flow(gpu_input_t) gpu_pred_logits_t, gpu_pred_t = self.model.flow(gpu_input_t)
gpu_pred_list.append(gpu_pred_t) gpu_pred_list.append(gpu_pred_t)
gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3]) gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3])
gpu_losses += [gpu_loss] gpu_losses += [gpu_loss]
@ -100,13 +101,13 @@ class XSegModel(ModelBase):
# Average losses and gradients, and create optimizer update ops # Average losses and gradients, and create optimizer update ops
with tf.device (models_opt_device): with tf.device(f'/CPU:0'):
pred = nn.concat(gpu_pred_list, 0) #with tf.device (models_opt_device):
loss = tf.reduce_mean(gpu_losses) pred = tf.concat(gpu_pred_list, 0)
loss = tf.concat(gpu_losses, 0)
loss_gv_op = self.model.opt.get_update_op (nn.average_gv_list (gpu_loss_gvs)) loss_gv_op = self.model.opt.get_update_op (nn.average_gv_list (gpu_loss_gvs))
# Initializing training and view functions # Initializing training and view functions
def train(input_np, target_np): def train(input_np, target_np):
l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np }) l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np })
@ -122,29 +123,29 @@ class XSegModel(ModelBase):
src_dst_generators_count = cpu_count // 2 src_dst_generators_count = cpu_count // 2
src_generators_count = cpu_count // 2 src_generators_count = cpu_count // 2
dst_generators_count = cpu_count // 2 dst_generators_count = cpu_count // 2
srcdst_generator = SampleGeneratorFaceXSeg([self.training_data_src_path, self.training_data_dst_path], srcdst_generator = SampleGeneratorFaceXSeg([self.training_data_src_path, self.training_data_dst_path],
debug=self.is_debug(), debug=self.is_debug(),
batch_size=self.get_batch_size(), batch_size=self.get_batch_size(),
resolution=resolution, resolution=resolution,
face_type=self.face_type, face_type=self.face_type,
generators_count=src_dst_generators_count, generators_count=src_dst_generators_count,
data_format=nn.data_format) data_format=nn.data_format)
src_generator = SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(), src_generator = SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=False), sample_process_options=SampleProcessor.Options(random_flip=False),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
], ],
generators_count=src_generators_count, generators_count=src_generators_count,
raise_on_no_data=False ) raise_on_no_data=False )
dst_generator = SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), dst_generator = SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=False), sample_process_options=SampleProcessor.Options(random_flip=False),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
], ],
generators_count=dst_generators_count, generators_count=dst_generators_count,
raise_on_no_data=False ) raise_on_no_data=False )
self.set_training_data_generators ([srcdst_generator, src_generator, dst_generator]) self.set_training_data_generators ([srcdst_generator, src_generator, dst_generator])
#override #override
@ -154,21 +155,18 @@ class XSegModel(ModelBase):
#override #override
def onSave(self): def onSave(self):
self.model.save_weights() self.model.save_weights()
#override
def onTrainOneIter(self):
image_np, mask_np = self.generate_next_samples()[0]
loss = self.train (image_np, mask_np)
return ( ('loss', loss ), ) #override
def onTrainOneIter(self):
image_np, mask_np = self.generate_next_samples()[0]
loss = self.train (image_np, mask_np)
return ( ('loss', np.mean(loss) ), )
#override #override
def onGetPreview(self, samples): def onGetPreview(self, samples):
n_samples = min(4, self.get_batch_size(), 800 // self.resolution ) n_samples = min(4, self.get_batch_size(), 800 // self.resolution )
srcdst_samples, src_samples, dst_samples = samples srcdst_samples, src_samples, dst_samples = samples
image_np, mask_np = srcdst_samples image_np, mask_np = srcdst_samples
I, M, IM, = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([image_np,mask_np] + self.view (image_np) ) ] I, M, IM, = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([image_np,mask_np] + self.view (image_np) ) ]
@ -176,41 +174,41 @@ class XSegModel(ModelBase):
green_bg = np.tile( np.array([0,1,0], dtype=np.float32)[None,None,...], (self.resolution,self.resolution,1) ) green_bg = np.tile( np.array([0,1,0], dtype=np.float32)[None,None,...], (self.resolution,self.resolution,1) )
result = [] result = []
st = [] st = []
for i in range(n_samples): for i in range(n_samples):
ar = I[i]*M[i]+0.5*I[i]*(1-M[i])+0.5*green_bg*(1-M[i]), IM[i], I[i]*IM[i]+0.5*I[i]*(1-IM[i]) + 0.5*green_bg*(1-IM[i]) ar = I[i]*M[i]+0.5*I[i]*(1-M[i])+0.5*green_bg*(1-M[i]), IM[i], I[i]*IM[i]+0.5*I[i]*(1-IM[i]) + 0.5*green_bg*(1-IM[i])
st.append ( np.concatenate ( ar, axis=1) ) st.append ( np.concatenate ( ar, axis=1) )
result += [ ('XSeg training faces', np.concatenate (st, axis=0 )), ] result += [ ('XSeg training faces', np.concatenate (st, axis=0 )), ]
if len(src_samples) != 0: if len(src_samples) != 0:
src_np, = src_samples src_np, = src_samples
D, DM, = [ np.clip(nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([src_np] + self.view (src_np) ) ] D, DM, = [ np.clip(nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([src_np] + self.view (src_np) ) ]
DM, = [ np.repeat (x, (3,), -1) for x in [DM] ] DM, = [ np.repeat (x, (3,), -1) for x in [DM] ]
st = [] st = []
for i in range(n_samples): for i in range(n_samples):
ar = D[i], DM[i], D[i]*DM[i] + 0.5*D[i]*(1-DM[i]) + 0.5*green_bg*(1-DM[i]) ar = D[i], DM[i], D[i]*DM[i] + 0.5*D[i]*(1-DM[i]) + 0.5*green_bg*(1-DM[i])
st.append ( np.concatenate ( ar, axis=1) ) st.append ( np.concatenate ( ar, axis=1) )
result += [ ('XSeg src faces', np.concatenate (st, axis=0 )), ] result += [ ('XSeg src faces', np.concatenate (st, axis=0 )), ]
if len(dst_samples) != 0: if len(dst_samples) != 0:
dst_np, = dst_samples dst_np, = dst_samples
D, DM, = [ np.clip(nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([dst_np] + self.view (dst_np) ) ] D, DM, = [ np.clip(nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([dst_np] + self.view (dst_np) ) ]
DM, = [ np.repeat (x, (3,), -1) for x in [DM] ] DM, = [ np.repeat (x, (3,), -1) for x in [DM] ]
st = [] st = []
for i in range(n_samples): for i in range(n_samples):
ar = D[i], DM[i], D[i]*DM[i] + 0.5*D[i]*(1-DM[i]) + 0.5*green_bg*(1-DM[i]) ar = D[i], DM[i], D[i]*DM[i] + 0.5*D[i]*(1-DM[i]) + 0.5*green_bg*(1-DM[i])
st.append ( np.concatenate ( ar, axis=1) ) st.append ( np.concatenate ( ar, axis=1) )
result += [ ('XSeg dst faces', np.concatenate (st, axis=0 )), ] result += [ ('XSeg dst faces', np.concatenate (st, axis=0 )), ]
return result return result
Model = XSegModel Model = XSegModel