XSeg: added pretrain option.

This commit is contained in:
iperov 2021-07-30 17:24:21 +04:00
parent 83b1412da7
commit 55b947eab5
3 changed files with 111 additions and 46 deletions

View file

@ -90,7 +90,7 @@ class XSeg(nn.ModelBase):
self.out_conv = nn.Conv2D (base_ch, out_ch, kernel_size=3, padding='SAME') self.out_conv = nn.Conv2D (base_ch, out_ch, kernel_size=3, padding='SAME')
def forward(self, inp): def forward(self, inp, pretrain=False):
x = inp x = inp
x = self.conv01(x) x = self.conv01(x)
@ -126,29 +126,41 @@ class XSeg(nn.ModelBase):
x = nn.reshape_4D (x, 4, 4, self.base_ch*8 ) x = nn.reshape_4D (x, 4, 4, self.base_ch*8 )
x = self.up5(x) x = self.up5(x)
if pretrain:
x5 = tf.zeros_like(x5)
x = self.uconv53(tf.concat([x,x5],axis=nn.conv2d_ch_axis)) x = self.uconv53(tf.concat([x,x5],axis=nn.conv2d_ch_axis))
x = self.uconv52(x) x = self.uconv52(x)
x = self.uconv51(x) x = self.uconv51(x)
x = self.up4(x) x = self.up4(x)
if pretrain:
x4 = tf.zeros_like(x4)
x = self.uconv43(tf.concat([x,x4],axis=nn.conv2d_ch_axis)) x = self.uconv43(tf.concat([x,x4],axis=nn.conv2d_ch_axis))
x = self.uconv42(x) x = self.uconv42(x)
x = self.uconv41(x) x = self.uconv41(x)
x = self.up3(x) x = self.up3(x)
if pretrain:
x3 = tf.zeros_like(x3)
x = self.uconv33(tf.concat([x,x3],axis=nn.conv2d_ch_axis)) x = self.uconv33(tf.concat([x,x3],axis=nn.conv2d_ch_axis))
x = self.uconv32(x) x = self.uconv32(x)
x = self.uconv31(x) x = self.uconv31(x)
x = self.up2(x) x = self.up2(x)
if pretrain:
x2 = tf.zeros_like(x2)
x = self.uconv22(tf.concat([x,x2],axis=nn.conv2d_ch_axis)) x = self.uconv22(tf.concat([x,x2],axis=nn.conv2d_ch_axis))
x = self.uconv21(x) x = self.uconv21(x)
x = self.up1(x) x = self.up1(x)
if pretrain:
x1 = tf.zeros_like(x1)
x = self.uconv12(tf.concat([x,x1],axis=nn.conv2d_ch_axis)) x = self.uconv12(tf.concat([x,x1],axis=nn.conv2d_ch_axis))
x = self.uconv11(x) x = self.uconv11(x)
x = self.up0(x) x = self.up0(x)
if pretrain:
x0 = tf.zeros_like(x0)
x = self.uconv02(tf.concat([x,x0],axis=nn.conv2d_ch_axis)) x = self.uconv02(tf.concat([x,x0],axis=nn.conv2d_ch_axis))
x = self.uconv01(x) x = self.uconv01(x)

View file

@ -81,8 +81,8 @@ class XSegNet(object):
def get_resolution(self): def get_resolution(self):
return self.resolution return self.resolution
def flow(self, x): def flow(self, x, pretrain=False):
return self.model(x) return self.model(x, pretrain=pretrain)
def get_weights(self): def get_weights(self):
return self.model_weights return self.model_weights

View file

@ -25,12 +25,19 @@ class XSegModel(ModelBase):
self.set_iter(0) self.set_iter(0)
default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf') default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf')
default_pretrain = self.options['pretrain'] = self.load_or_def_option('pretrain', False)
if self.is_first_run(): if self.is_first_run():
self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Choose the same as your deepfake model.").lower() self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Choose the same as your deepfake model.").lower()
if self.is_first_run() or ask_override: if self.is_first_run() or ask_override:
self.ask_batch_size(4, range=[2,16]) self.ask_batch_size(4, range=[2,16])
self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain)
if self.options['pretrain'] and self.get_pretraining_data_path() is None:
raise Exception("pretraining_data_path is not defined")
self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False)
#override #override
def on_initialize(self): def on_initialize(self):
@ -51,6 +58,7 @@ class XSegModel(ModelBase):
'wf' : FaceType.WHOLE_FACE, 'wf' : FaceType.WHOLE_FACE,
'head' : FaceType.HEAD}[ self.options['face_type'] ] 'head' : FaceType.HEAD}[ self.options['face_type'] ]
place_model_on_cpu = len(devices) == 0 place_model_on_cpu = len(devices) == 0
models_opt_device = '/CPU:0' if place_model_on_cpu else nn.tf_default_device_name models_opt_device = '/CPU:0' if place_model_on_cpu else nn.tf_default_device_name
@ -67,12 +75,17 @@ class XSegModel(ModelBase):
optimizer=nn.RMSprop(lr=0.0001, lr_dropout=0.3, name='opt'), optimizer=nn.RMSprop(lr=0.0001, lr_dropout=0.3, name='opt'),
data_format=nn.data_format) data_format=nn.data_format)
self.pretrain = self.options['pretrain']
if self.pretrain_just_disabled:
self.set_iter(0)
if self.is_training: if self.is_training:
# Adjust batch size for multiple GPU # Adjust batch size for multiple GPU
gpu_count = max(1, len(devices) ) gpu_count = max(1, len(devices) )
bs_per_gpu = max(1, self.get_batch_size() // gpu_count) bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
self.set_batch_size( gpu_count*bs_per_gpu) self.set_batch_size( gpu_count*bs_per_gpu)
targetm_t = tf.placeholder (nn.floatx, mask_shape)
# Compute losses per GPU # Compute losses per GPU
gpu_pred_list = [] gpu_pred_list = []
@ -81,19 +94,32 @@ class XSegModel(ModelBase):
gpu_loss_gvs = [] gpu_loss_gvs = []
for gpu_id in range(gpu_count): for gpu_id in range(gpu_count):
with tf.device(f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): with tf.device(f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
with tf.device(f'/CPU:0'): with tf.device(f'/CPU:0'):
# slice on CPU, otherwise all batch data will be transfered to GPU first # slice on CPU, otherwise all batch data will be transfered to GPU first
batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
gpu_input_t = self.model.input_t [batch_slice,:,:,:] gpu_input_t = self.model.input_t [batch_slice,:,:,:]
gpu_target_t = self.model.target_t [batch_slice,:,:,:] gpu_target_t = self.model.target_t [batch_slice,:,:,:]
gpu_targetm_t = targetm_t [batch_slice,:,:,:]
# process model tensors # process model tensors
gpu_pred_logits_t, gpu_pred_t = self.model.flow(gpu_input_t) gpu_pred_logits_t, gpu_pred_t = self.model.flow(gpu_input_t, pretrain=self.pretrain)
gpu_pred_list.append(gpu_pred_t) gpu_pred_list.append(gpu_pred_t)
if self.pretrain:
gpu_targetm_blur = nn.gaussian_blur(gpu_targetm_t, max(1, resolution // 32) )
gpu_targetm_blur = tf.clip_by_value(gpu_targetm_blur, 0, 0.5) * 2
gpu_target_t_blur = gpu_target_t*gpu_targetm_blur
gpu_pred_t_blur = gpu_pred_t*gpu_targetm_t
# Structural loss
gpu_loss = tf.reduce_mean (5*nn.dssim(gpu_target_t_blur, gpu_pred_t_blur, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_loss += tf.reduce_mean (5*nn.dssim(gpu_target_t_blur, gpu_pred_t_blur, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
# Pixel loss
gpu_loss += tf.reduce_mean (10*tf.square(gpu_target_t_blur-gpu_pred_t_blur), axis=[1,2,3])
else:
gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3]) gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3])
gpu_losses += [gpu_loss] gpu_losses += [gpu_loss]
@ -110,6 +136,11 @@ class XSegModel(ModelBase):
# Initializing training and view functions # Initializing training and view functions
if self.pretrain:
def train(input_np, target_np, targetm_np):
l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np, targetm_t :targetm_np })
return l
else:
def train(input_np, target_np): def train(input_np, target_np):
l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np }) l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np })
return l return l
@ -125,7 +156,17 @@ class XSegModel(ModelBase):
src_generators_count = cpu_count // 2 src_generators_count = cpu_count // 2
dst_generators_count = cpu_count // 2 dst_generators_count = cpu_count // 2
if self.pretrain:
pretrain_gen = SampleGeneratorFace(self.get_pretraining_data_path(), debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=True),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
],
uniform_yaw_distribution=False,
generators_count=cpu_count )
self.set_training_data_generators ([pretrain_gen])
else:
srcdst_generator = SampleGeneratorFaceXSeg([self.training_data_src_path, self.training_data_dst_path], srcdst_generator = SampleGeneratorFaceXSeg([self.training_data_src_path, self.training_data_dst_path],
debug=self.is_debug(), debug=self.is_debug(),
batch_size=self.get_batch_size(), batch_size=self.get_batch_size(),
@ -159,14 +200,23 @@ class XSegModel(ModelBase):
#override #override
def onTrainOneIter(self): def onTrainOneIter(self):
if self.pretrain:
image_np, target_np, targetm_np = self.generate_next_samples()[0]
loss = self.train (image_np, target_np, targetm_np)
else:
image_np, mask_np = self.generate_next_samples()[0] image_np, mask_np = self.generate_next_samples()[0]
loss = self.train (image_np, mask_np) loss = self.train (image_np, mask_np)
return ( ('loss', np.mean(loss) ), ) return ( ('loss', np.mean(loss) ), )
#override #override
def onGetPreview(self, samples, for_history=False): def onGetPreview(self, samples, for_history=False):
n_samples = min(4, self.get_batch_size(), 800 // self.resolution ) n_samples = min(4, self.get_batch_size(), 800 // self.resolution )
if self.pretrain:
srcdst_samples, = samples
image_np, mask_np, _ = srcdst_samples
else:
srcdst_samples, src_samples, dst_samples = samples srcdst_samples, src_samples, dst_samples = samples
image_np, mask_np = srcdst_samples image_np, mask_np = srcdst_samples
@ -178,11 +228,14 @@ class XSegModel(ModelBase):
result = [] result = []
st = [] st = []
for i in range(n_samples): for i in range(n_samples):
if self.pretrain:
ar = I[i], IM[i]
else:
ar = I[i]*M[i]+0.5*I[i]*(1-M[i])+0.5*green_bg*(1-M[i]), IM[i], I[i]*IM[i]+0.5*I[i]*(1-IM[i]) + 0.5*green_bg*(1-IM[i]) ar = I[i]*M[i]+0.5*I[i]*(1-M[i])+0.5*green_bg*(1-M[i]), IM[i], I[i]*IM[i]+0.5*I[i]*(1-IM[i]) + 0.5*green_bg*(1-IM[i])
st.append ( np.concatenate ( ar, axis=1) ) st.append ( np.concatenate ( ar, axis=1) )
result += [ ('XSeg training faces', np.concatenate (st, axis=0 )), ] result += [ ('XSeg training faces', np.concatenate (st, axis=0 )), ]
if len(src_samples) != 0: if not self.pretrain and len(src_samples) != 0:
src_np, = src_samples src_np, = src_samples
@ -196,7 +249,7 @@ class XSegModel(ModelBase):
result += [ ('XSeg src faces', np.concatenate (st, axis=0 )), ] result += [ ('XSeg src faces', np.concatenate (st, axis=0 )), ]
if len(dst_samples) != 0: if not self.pretrain and len(dst_samples) != 0:
dst_np, = dst_samples dst_np, = dst_samples