enhanced SAE model. You should to restart training.

new default style power = 2.0
fix DF default batch sizes.
upd readme
This commit is contained in:
iperov 2019-01-17 21:41:40 +04:00
parent 22401cecfc
commit 946688567d
5 changed files with 137 additions and 134 deletions

View file

@ -82,7 +82,9 @@ LIAEF128 Cage video:
[![Watch the video](https://img.youtube.com/vi/mRsexePEVco/0.jpg)](https://www.youtube.com/watch?v=mRsexePEVco) [![Watch the video](https://img.youtube.com/vi/mRsexePEVco/0.jpg)](https://www.youtube.com/watch?v=mRsexePEVco)
- **SAE (2GB+)** - Styled AutoEncoder - new superior model based on style loss. Morphing/stylizing done directly by neural network. Face obstructions also reconstructed without any masks. Converter mode 'overlay' should be used. Model has several options on start for fine tuning to fit your GPU. - **SAE ( minimum 2GB+, recommended 11GB+ )** - Styled AutoEncoder - new superior model based on style loss. Morphing/stylizing done directly by neural network. Face obstructions also reconstructed without any masks. Converter mode 'overlay' should be used. Model has several options on start for fine tuning to fit your GPU.
SAE actually contains all other models. Just set style powers to 0.0 to get default models.
![](https://github.com/iperov/DeepFaceLab/blob/master/doc/SAE_Cage_0.jpg) ![](https://github.com/iperov/DeepFaceLab/blob/master/doc/SAE_Cage_0.jpg)
@ -96,7 +98,6 @@ SAE model Putin-Navalny video: https://www.youtube.com/watch?v=Jj7b3mqx-Mw
![](https://github.com/iperov/DeepFaceLab/blob/master/doc/DeepFaceLab_convertor_overview.png) ![](https://github.com/iperov/DeepFaceLab/blob/master/doc/DeepFaceLab_convertor_overview.png)
### **Tips and tricks**: ### **Tips and tricks**:
unfortunately deepfaking is time/eletricpower consuming topic and has a lot of nuances. unfortunately deepfaking is time/eletricpower consuming topic and has a lot of nuances.
@ -119,13 +120,17 @@ Narrow src face is better fakeable than wide. This is why Cage is so popular in
SAE tips: SAE tips:
- SAE actually contains all other models. Just set style power options to 0.0 to get default models.
- if src faceset has number of faces more than dst faceset, model can be not converged. In this case try 'Feed faces to network sorted by yaw' option. - if src faceset has number of faces more than dst faceset, model can be not converged. In this case try 'Feed faces to network sorted by yaw' option.
- if src face wider than dst, model can be not converged. In this case try to decrease 'Src face scale modifier' to -5. - if src face wider than dst, model can be not converged. In this case try to decrease 'Src face scale modifier' to -5.
- architecture 'df' make predicted face looking more like src, but if model not converges try default 'liae'. - architecture 'df' make predicted face looking more like src, but if model not converges try default 'liae'.
- most scenes converge fine with batch size = 8. In this case better to increase 'encoder/decoder dims per channel' to get more sharp result. - if you have a lot of VRAM, you can choose between batch size that affects quality of generalization and enc/dec dims that affects image quality.
- if style speed is too fast, you will get the artifacts before the face becomes sharp.
### **Sort tool**: ### **Sort tool**:

View file

@ -163,10 +163,10 @@ def previewThread (input_queue, output_queue):
if update_preview: if update_preview:
update_preview = False update_preview = False
(h,w,c) = previews[0][1].shape
selected_preview_name = previews[selected_preview][0] selected_preview_name = previews[selected_preview][0]
selected_preview_rgb = previews[selected_preview][1] selected_preview_rgb = previews[selected_preview][1]
(h,w,c) = selected_preview_rgb.shape
# HEAD # HEAD
head_text_color = [0.8]*c head_text_color = [0.8]*c

View file

@ -14,7 +14,7 @@ class Model(ModelBase):
#override #override
def onInitialize(self, **in_options): def onInitialize(self, **in_options):
exec(nnlib.import_all(), locals(), globals()) exec(nnlib.import_all(), locals(), globals())
self.set_vram_batch_requirements( {4.5:16,5:16,6:16,7:16,8:24,9:24,10:32,11:32,12:32,13:48} ) self.set_vram_batch_requirements( {4.5:4,5:6,6:8,7:16,8:24,9:24,10:32,11:32,12:32,13:48} )
ae_input_layer = Input(shape=(128, 128, 3)) ae_input_layer = Input(shape=(128, 128, 3))
mask_layer = Input(shape=(128, 128, 1)) #same as output mask_layer = Input(shape=(128, 128, 1)) #same as output

View file

@ -28,21 +28,21 @@ class SAEModel(ModelBase):
if is_first_run: if is_first_run:
self.options['resolution'] = input_int("Resolution (64,128 ?:help skip:128) : ", default_resolution, [64,128], help_message="More resolution requires more VRAM.") self.options['resolution'] = input_int("Resolution (64,128 ?:help skip:128) : ", default_resolution, [64,128], help_message="More resolution requires more VRAM.")
self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:liae) : ", default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower() self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:%s) : " % (default_archi) , default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower()
self.options['lighter_encoder'] = input_bool ("Use lightweight encoder? (y/n, ?:help skip:n) : ", False, help_message="Lightweight encoder is 35% faster, requires less VRAM, sacrificing overall quality.") self.options['lighter_encoder'] = input_bool ("Use lightweight encoder? (y/n, ?:help skip:n) : ", False, help_message="Lightweight encoder is 35% faster, requires less VRAM, sacrificing overall quality.")
else: else:
self.options['resolution'] = self.options.get('resolution', default_resolution) self.options['resolution'] = self.options.get('resolution', default_resolution)
self.options['archi'] = self.options.get('archi', default_archi) self.options['archi'] = self.options.get('archi', default_archi)
self.options['lighter_encoder'] = self.options.get('lighter_encoder', False) self.options['lighter_encoder'] = self.options.get('lighter_encoder', False)
default_face_style_power = 10.0 default_face_style_power = 2.0
if is_first_run or ask_override: if is_first_run or ask_override:
default_face_style_power = default_face_style_power if is_first_run else self.options.get('face_style_power', default_face_style_power) default_face_style_power = default_face_style_power if is_first_run else self.options.get('face_style_power', default_face_style_power)
self.options['face_style_power'] = np.clip ( input_number("Face style power ( 0.0 .. 100.0 ?:help skip:%.1f) : " % (default_face_style_power), default_face_style_power, help_message="How fast NN will learn dst face style during generalization of src and dst faces."), 0.0, 100.0 ) self.options['face_style_power'] = np.clip ( input_number("Face style power ( 0.0 .. 100.0 ?:help skip:%.1f) : " % (default_face_style_power), default_face_style_power, help_message="How fast NN will learn dst face style during generalization of src and dst faces."), 0.0, 100.0 )
else: else:
self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power) self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power)
default_bg_style_power = 10.0 default_bg_style_power = 2.0
if is_first_run or ask_override: if is_first_run or ask_override:
default_bg_style_power = default_bg_style_power if is_first_run else self.options.get('bg_style_power', default_bg_style_power) default_bg_style_power = default_bg_style_power if is_first_run else self.options.get('bg_style_power', default_bg_style_power)
self.options['bg_style_power'] = np.clip ( input_number("Background style power ( 0.0 .. 100.0 ?:help skip:%.1f) : " % (default_bg_style_power), default_bg_style_power, help_message="How fast NN will learn dst background style during generalization of src and dst faces."), 0.0, 100.0 ) self.options['bg_style_power'] = np.clip ( input_number("Background style power ( 0.0 .. 100.0 ?:help skip:%.1f) : " % (default_bg_style_power), default_bg_style_power, help_message="How fast NN will learn dst background style during generalization of src and dst faces."), 0.0, 100.0 )
@ -103,7 +103,6 @@ class SAEModel(ModelBase):
self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (inter_output_Inputs) self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (inter_output_Inputs)
self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (inter_output_Inputs) self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (inter_output_Inputs)
if not self.is_first_run(): if not self.is_first_run():
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5)) self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
self.inter_B.load_weights (self.get_strpath_storage_for_file(self.inter_BH5)) self.inter_B.load_weights (self.get_strpath_storage_for_file(self.inter_BH5))
@ -137,11 +136,9 @@ class SAEModel(ModelBase):
self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (dec_Inputs) self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (dec_Inputs)
self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (dec_Inputs) self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (dec_Inputs)
self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5))) (dec_Inputs) self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5))) (dec_Inputs)
self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5))) (dec_Inputs) self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5))) (dec_Inputs)
if not self.is_first_run(): if not self.is_first_run():
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5)) self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5)) self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5))
@ -150,100 +147,91 @@ class SAEModel(ModelBase):
self.decoder_dstm.load_weights (self.get_strpath_storage_for_file(self.decoder_dstmH5)) self.decoder_dstm.load_weights (self.get_strpath_storage_for_file(self.decoder_dstmH5))
warped_src_code = self.encoder (warped_src) warped_src_code = self.encoder (warped_src)
pred_src_src = self.decoder_src(warped_src_code) pred_src_src = self.decoder_src(warped_src_code)
pred_src_srcm = self.decoder_srcm(warped_src_code) pred_src_srcm = self.decoder_srcm(warped_src_code)
warped_dst_code = self.encoder (warped_dst) warped_dst_code = self.encoder (warped_dst)
pred_dst_dst = self.decoder_dst(warped_dst_code) pred_dst_dst = self.decoder_dst(warped_dst_code)
pred_dst_dstm = self.decoder_dstm(warped_dst_code) pred_dst_dstm = self.decoder_dstm(warped_dst_code)
pred_src_dst = self.decoder_src(warped_dst_code) pred_src_dst = self.decoder_src(warped_dst_code)
pred_src_dstm = self.decoder_srcm(warped_dst_code) pred_src_dstm = self.decoder_srcm(warped_dst_code)
ms_count = len(pred_src_src)
target_srcm_blurred = tf_gaussian_blur(resolution // 32)(target_srcm) target_src_ar = [ target_src if i == 0 else tf.image.resize_bicubic( target_src, (resolution // (2**i) ,)*2 ) for i in range(ms_count-1, -1, -1)]
target_srcm_sigm = target_srcm_blurred / 2.0 + 0.5 target_srcm_ar = [ target_srcm if i == 0 else tf.image.resize_bicubic( target_srcm, (resolution // (2**i) ,)*2 ) for i in range(ms_count-1, -1, -1)]
target_srcm_anti_sigm = 1.0 - target_srcm_sigm target_dst_ar = [ target_dst if i == 0 else tf.image.resize_bicubic( target_dst, (resolution // (2**i) ,)*2 ) for i in range(ms_count-1, -1, -1)]
target_dstm_ar = [ target_dstm if i == 0 else tf.image.resize_bicubic( target_dstm, (resolution // (2**i) ,)*2 ) for i in range(ms_count-1, -1, -1)]
target_dstm_blurred = tf_gaussian_blur(resolution // 32)(target_dstm) target_srcm_blurred_ar = [ tf_gaussian_blur( max(1, x.get_shape().as_list()[1] // 32) )(x) for x in target_srcm_ar]
target_dstm_sigm = target_dstm_blurred / 2.0 + 0.5 target_srcm_sigm_ar = [ x / 2.0 + 0.5 for x in target_srcm_blurred_ar]
target_dstm_anti_sigm = 1.0 - target_dstm_sigm target_srcm_anti_sigm_ar = [ 1.0 - x for x in target_srcm_sigm_ar]
target_src_sigm = target_src+1 target_dstm_blurred_ar = [ tf_gaussian_blur( max(1, x.get_shape().as_list()[1] // 32) )(x) for x in target_dstm_ar]
target_dst_sigm = target_dst+1 target_dstm_sigm_ar = [ x / 2.0 + 0.5 for x in target_dstm_blurred_ar]
target_dstm_anti_sigm_ar = [ 1.0 - x for x in target_dstm_sigm_ar]
pred_src_src_sigm = pred_src_src+1 target_src_sigm_ar = [ x + 1 for x in target_src_ar]
pred_dst_dst_sigm = pred_dst_dst+1 target_dst_sigm_ar = [ x + 1 for x in target_dst_ar]
pred_src_dst_sigm = pred_src_dst+1
target_src_masked = target_src_sigm*target_srcm_sigm pred_src_src_sigm_ar = [ x + 1 for x in pred_src_src]
pred_dst_dst_sigm_ar = [ x + 1 for x in pred_dst_dst]
pred_src_dst_sigm_ar = [ x + 1 for x in pred_src_dst]
target_dst_masked = target_dst_sigm * target_dstm_sigm target_src_masked_ar = [ target_src_sigm_ar[i]*target_srcm_sigm_ar[i] for i in range(len(target_src_sigm_ar))]
target_dst_anti_masked = target_dst_sigm * target_dstm_anti_sigm target_dst_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(target_dst_sigm_ar))]
target_dst_anti_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(target_dst_sigm_ar))]
pred_src_src_masked = pred_src_src_sigm * target_srcm_sigm psd_target_dst_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
pred_dst_dst_masked = pred_dst_dst_sigm * target_dstm_sigm psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
psd_target_dst_masked = pred_src_dst_sigm * target_dstm_sigm if self.is_training_mode:
psd_target_dst_anti_masked = pred_src_dst_sigm * target_dstm_anti_sigm def optimizer():
return Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
src_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked, pred_src_src_masked )) )
def optimizer(): if self.options['archi'] == 'liae':
return Adam(lr=5e-5, beta_1=0.5, beta_2=0.999) src_loss_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
src_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
dst_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
else:
src_loss_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights
src_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights
dst_loss_train_weights = self.encoder.trainable_weights + self.decoder_dst.trainable_weights
dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights
if self.options['face_style_power'] != 0: src_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked_ar[i], pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ) )) for i in range(len(target_src_masked_ar)) ])
face_style_power = self.options['face_style_power'] / 100.0
src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*face_style_power)(psd_target_dst_masked, target_dst_masked)
if self.options['bg_style_power'] != 0: if self.options['face_style_power'] != 0:
bg_style_power = self.options['bg_style_power'] / 100.0 face_style_power = self.options['face_style_power'] / 100.0
src_loss += K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked, target_dst_anti_masked ))) src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*face_style_power)( psd_target_dst_masked_ar[-1], target_dst_masked_ar[-1] )
if self.options['archi'] == 'liae': if self.options['bg_style_power'] != 0:
src_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights bg_style_power = self.options['bg_style_power'] / 100.0
src_loss += K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))
self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss], optimizer().get_updates(src_loss, src_loss_train_weights) )
src_mask_loss = sum([ K.mean(K.square(target_srcm_ar[i]-pred_src_srcm[i])) for i in range(len(target_srcm_ar)) ])
self.src_mask_train = K.function ([warped_src, target_srcm],[src_mask_loss], optimizer().get_updates(src_mask_loss, src_mask_loss_train_weights) )
dst_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ])
self.dst_train = K.function ([warped_dst, target_dst, target_dstm ],[dst_loss], optimizer().get_updates(dst_loss, dst_loss_train_weights) )
dst_mask_loss = sum([ K.mean(K.square(target_dstm_ar[i]-pred_dst_dstm[i])) for i in range(len(target_dstm_ar)) ])
self.dst_mask_train = K.function ([warped_dst, target_dstm],[dst_mask_loss], optimizer().get_updates(dst_mask_loss, dst_mask_loss_train_weights) )
self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_src_srcm[-1], pred_dst_dst[-1], pred_dst_dstm[-1], pred_src_dst[-1], pred_src_dstm[-1]] )
else: else:
src_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1], pred_src_dstm[-1] ])
self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss],
optimizer().get_updates(src_loss, src_train_weights) )
dst_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked, pred_dst_dst_masked )) )
if self.options['archi'] == 'liae':
dst_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
else:
dst_train_weights = self.encoder.trainable_weights + self.decoder_dst.trainable_weights
self.dst_train = K.function ([warped_dst, target_dst, target_dstm],[dst_loss],
optimizer().get_updates(dst_loss, dst_train_weights) )
src_mask_loss = K.mean(K.square(target_srcm-pred_src_srcm))
if self.options['archi'] == 'liae':
src_mask_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
else:
src_mask_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights
self.src_mask_train = K.function ([warped_src, target_srcm],[src_mask_loss],
optimizer().get_updates(src_mask_loss, src_mask_train_weights ) )
dst_mask_loss = K.mean(K.square(target_dstm-pred_dst_dstm))
if self.options['archi'] == 'liae':
dst_mask_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
else:
dst_mask_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights
self.dst_mask_train = K.function ([warped_dst, target_dstm],[dst_mask_loss],
optimizer().get_updates(dst_mask_loss, dst_mask_train_weights) )
self.AE_view = K.function ([warped_src, warped_dst],[pred_src_src, pred_src_srcm, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm])
self.AE_convert = K.function ([warped_dst],[pred_src_dst, pred_src_dstm])
if self.is_training_mode: if self.is_training_mode:
f = SampleProcessor.TypeFlags f = SampleProcessor.TypeFlags
face_type = f.FACE_ALIGN_FULL if self.options['face_type'] == 'f' else f.FACE_ALIGN_HALF face_type = f.FACE_ALIGN_FULL if self.options['face_type'] == 'f' else f.FACE_ALIGN_HALF
self.set_training_data_generators ([ self.set_training_data_generators ([
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None, SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
debug=self.is_debug(), batch_size=self.batch_size, debug=self.is_debug(), batch_size=self.batch_size,
@ -257,7 +245,8 @@ class SAEModel(ModelBase):
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, normalize_tanh = True), sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, normalize_tanh = True),
output_sample_types=[ [f.WARPED_TRANSFORMED | face_type | f.MODE_BGR, resolution], output_sample_types=[ [f.WARPED_TRANSFORMED | face_type | f.MODE_BGR, resolution],
[f.TRANSFORMED | face_type | f.MODE_BGR, resolution], [f.TRANSFORMED | face_type | f.MODE_BGR, resolution],
[f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution] ] ) [f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution]
] )
]) ])
#override #override
def onSave(self): def onSave(self):
@ -297,23 +286,18 @@ class SAEModel(ModelBase):
test_B = sample[1][1][0:4] test_B = sample[1][1][0:4]
test_B_m = sample[1][2][0:4] test_B_m = sample[1][2][0:4]
S = test_A S, D, SS, SM, DD, DM, SD, SDM, = [ x / 2 + 0.5 for x in ([test_A,test_B] + self.AE_view ([test_A, test_B]) ) ]
D = test_B #SM, DM, SDM = [ np.repeat (x, (3,), -1) for x in [SM, DM, SDM] ]
SS, SM, DD, DM, SD, SDM = self.AE_view ([test_A, test_B]) st_x3 = []
S, D, SS, SM, DD, DM, SD, SDM = [ x / 2 + 0.5 for x in [S, D, SS, SM, DD, DM, SD, SDM] ]
SM, DM, SDM = [ np.repeat (x, (3,), -1) for x in [SM, DM, SDM] ]
st = []
for i in range(0, len(test_A)): for i in range(0, len(test_A)):
st.append ( np.concatenate ( ( st_x3.append ( np.concatenate ( (
S[i], SS[i], #SM[i], S[i], SS[i], #SM[i],
D[i], DD[i], #DM[i], D[i], DD[i], #DM[i],
SD[i], #SDM[i] SD[i], #SDM[i]
), axis=1) ) ), axis=1) )
return [ ('SAE', np.concatenate ( st, axis=0 ) ) ] return [ ('SAE', np.concatenate (st_x3, axis=0 )), ]
def predictor_func (self, face): def predictor_func (self, face):
face = face * 2.0 - 1.0 face = face * 2.0 - 1.0
@ -371,7 +355,7 @@ class SAEModel(ModelBase):
x = downscale(ed_dims*8)(x) x = downscale(ed_dims*8)(x)
else: else:
x = downscale_sep(ed_dims*2)(x) x = downscale_sep(ed_dims*2)(x)
x = downscale_sep(ed_dims*4)(x) x = downscale(ed_dims*4)(x)
x = downscale_sep(ed_dims*8)(x) x = downscale_sep(ed_dims*8)(x)
x = Flatten()(x) x = Flatten()(x)
@ -400,26 +384,31 @@ class SAEModel(ModelBase):
@staticmethod @staticmethod
def LIAEDecFlow(output_nc,ed_ch_dims=21,activation='tanh'): def LIAEDecFlow(output_nc,ed_ch_dims=21,activation='tanh'):
exec (nnlib.import_all(), locals(), globals()) exec (nnlib.import_all(), locals(), globals())
ed_dims = output_nc * ed_ch_dims
def upscale (dim): def upscale (dim):
def func(x): def func(x):
return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x))) return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
return func return func
def to_bgr ():
def func(x):
return Conv2D(output_nc, kernel_size=5, padding='same', activation='tanh')(x)
return func
def func(input): def func(input):
ed_dims = output_nc * ed_ch_dims
x = input[0] x = input[0]
x = upscale(ed_dims*8)(x) x1 = upscale(ed_dims*8)( x )
x = upscale(ed_dims*4)(x) x1_bgr = to_bgr() ( x1 )
x = upscale(ed_dims*2)(x)
x = Conv2D(output_nc, kernel_size=5, padding='same', activation=activation)(x) x2 = upscale(ed_dims*4)( x1 )
return x x2_bgr = to_bgr() ( x2 )
x3 = upscale(ed_dims*2)( x2 )
x3_bgr = to_bgr() ( x3 )
return [ x1_bgr, x2_bgr, x3_bgr ]
return func return func
@staticmethod @staticmethod
def DFEncFlow(resolution, adapt_k_size, light_enc, ae_dims=512, ed_ch_dims=42): def DFEncFlow(resolution, adapt_k_size, light_enc, ae_dims=512, ed_ch_dims=42):
exec (nnlib.import_all(), locals(), globals()) exec (nnlib.import_all(), locals(), globals())
@ -466,7 +455,7 @@ class SAEModel(ModelBase):
return func return func
@staticmethod @staticmethod
def DFDecFlow(output_nc, ed_ch_dims=21, activation='tanh'): def DFDecFlow(output_nc, ed_ch_dims=21):
exec (nnlib.import_all(), locals(), globals()) exec (nnlib.import_all(), locals(), globals())
ed_dims = output_nc * ed_ch_dims ed_dims = output_nc * ed_ch_dims
@ -474,15 +463,23 @@ class SAEModel(ModelBase):
def func(x): def func(x):
return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x))) return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
return func return func
def to_bgr ():
def func(x):
return Conv2D(output_nc, kernel_size=5, padding='same', activation='tanh')(x)
return func
def func(input): def func(input):
x = input[0] x = input[0]
x = upscale(ed_dims*8)(x) x1 = upscale(ed_dims*8)( x )
x = upscale(ed_dims*4)(x) x1_bgr = to_bgr() ( x1 )
x = upscale(ed_dims*2)(x)
x = Conv2D(output_nc, kernel_size=5, padding='same', activation=activation)(x) x2 = upscale(ed_dims*4)( x1 )
return x x2_bgr = to_bgr() ( x2 )
x3 = upscale(ed_dims*2)( x2 )
x3_bgr = to_bgr() ( x3 )
return [ x1_bgr, x2_bgr, x3_bgr ]
return func return func
Model = SAEModel Model = SAEModel

View file

@ -75,6 +75,7 @@ Conv2D = keras.layers.Conv2D
Conv2DTranspose = keras.layers.Conv2DTranspose Conv2DTranspose = keras.layers.Conv2DTranspose
SeparableConv2D = keras.layers.SeparableConv2D SeparableConv2D = keras.layers.SeparableConv2D
MaxPooling2D = keras.layers.MaxPooling2D MaxPooling2D = keras.layers.MaxPooling2D
UpSampling2D = keras.layers.UpSampling2D
BatchNormalization = keras.layers.BatchNormalization BatchNormalization = keras.layers.BatchNormalization
LeakyReLU = keras.layers.LeakyReLU LeakyReLU = keras.layers.LeakyReLU