mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 21:12:07 -07:00
enhanced SAE model. You should to restart training.
new default style power = 2.0 fix DF default batch sizes. upd readme
This commit is contained in:
parent
22401cecfc
commit
946688567d
5 changed files with 137 additions and 134 deletions
11
README.md
11
README.md
|
@ -82,7 +82,9 @@ LIAEF128 Cage video:
|
|||
|
||||
[](https://www.youtube.com/watch?v=mRsexePEVco)
|
||||
|
||||
- **SAE (2GB+)** - Styled AutoEncoder - new superior model based on style loss. Morphing/stylizing done directly by neural network. Face obstructions also reconstructed without any masks. Converter mode 'overlay' should be used. Model has several options on start for fine tuning to fit your GPU.
|
||||
- **SAE ( minimum 2GB+, recommended 11GB+ )** - Styled AutoEncoder - new superior model based on style loss. Morphing/stylizing done directly by neural network. Face obstructions also reconstructed without any masks. Converter mode 'overlay' should be used. Model has several options on start for fine tuning to fit your GPU.
|
||||
|
||||
SAE actually contains all other models. Just set style powers to 0.0 to get default models.
|
||||
|
||||

|
||||
|
||||
|
@ -96,7 +98,6 @@ SAE model Putin-Navalny video: https://www.youtube.com/watch?v=Jj7b3mqx-Mw
|
|||
|
||||

|
||||
|
||||
|
||||
### **Tips and tricks**:
|
||||
|
||||
unfortunately deepfaking is time/eletricpower consuming topic and has a lot of nuances.
|
||||
|
@ -119,13 +120,17 @@ Narrow src face is better fakeable than wide. This is why Cage is so popular in
|
|||
|
||||
SAE tips:
|
||||
|
||||
- SAE actually contains all other models. Just set style power options to 0.0 to get default models.
|
||||
|
||||
- if src faceset has number of faces more than dst faceset, model can be not converged. In this case try 'Feed faces to network sorted by yaw' option.
|
||||
|
||||
- if src face wider than dst, model can be not converged. In this case try to decrease 'Src face scale modifier' to -5.
|
||||
|
||||
- architecture 'df' make predicted face looking more like src, but if model not converges try default 'liae'.
|
||||
|
||||
- most scenes converge fine with batch size = 8. In this case better to increase 'encoder/decoder dims per channel' to get more sharp result.
|
||||
- if you have a lot of VRAM, you can choose between batch size that affects quality of generalization and enc/dec dims that affects image quality.
|
||||
|
||||
- if style speed is too fast, you will get the artifacts before the face becomes sharp.
|
||||
|
||||
### **Sort tool**:
|
||||
|
||||
|
|
|
@ -163,10 +163,10 @@ def previewThread (input_queue, output_queue):
|
|||
|
||||
if update_preview:
|
||||
update_preview = False
|
||||
(h,w,c) = previews[0][1].shape
|
||||
|
||||
selected_preview_name = previews[selected_preview][0]
|
||||
selected_preview_rgb = previews[selected_preview][1]
|
||||
(h,w,c) = selected_preview_rgb.shape
|
||||
|
||||
# HEAD
|
||||
head_text_color = [0.8]*c
|
||||
|
|
|
@ -14,7 +14,7 @@ class Model(ModelBase):
|
|||
#override
|
||||
def onInitialize(self, **in_options):
|
||||
exec(nnlib.import_all(), locals(), globals())
|
||||
self.set_vram_batch_requirements( {4.5:16,5:16,6:16,7:16,8:24,9:24,10:32,11:32,12:32,13:48} )
|
||||
self.set_vram_batch_requirements( {4.5:4,5:6,6:8,7:16,8:24,9:24,10:32,11:32,12:32,13:48} )
|
||||
|
||||
ae_input_layer = Input(shape=(128, 128, 3))
|
||||
mask_layer = Input(shape=(128, 128, 1)) #same as output
|
||||
|
|
|
@ -28,21 +28,21 @@ class SAEModel(ModelBase):
|
|||
|
||||
if is_first_run:
|
||||
self.options['resolution'] = input_int("Resolution (64,128 ?:help skip:128) : ", default_resolution, [64,128], help_message="More resolution requires more VRAM.")
|
||||
self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:liae) : ", default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower()
|
||||
self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:%s) : " % (default_archi) , default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower()
|
||||
self.options['lighter_encoder'] = input_bool ("Use lightweight encoder? (y/n, ?:help skip:n) : ", False, help_message="Lightweight encoder is 35% faster, requires less VRAM, sacrificing overall quality.")
|
||||
else:
|
||||
self.options['resolution'] = self.options.get('resolution', default_resolution)
|
||||
self.options['archi'] = self.options.get('archi', default_archi)
|
||||
self.options['lighter_encoder'] = self.options.get('lighter_encoder', False)
|
||||
|
||||
default_face_style_power = 10.0
|
||||
default_face_style_power = 2.0
|
||||
if is_first_run or ask_override:
|
||||
default_face_style_power = default_face_style_power if is_first_run else self.options.get('face_style_power', default_face_style_power)
|
||||
self.options['face_style_power'] = np.clip ( input_number("Face style power ( 0.0 .. 100.0 ?:help skip:%.1f) : " % (default_face_style_power), default_face_style_power, help_message="How fast NN will learn dst face style during generalization of src and dst faces."), 0.0, 100.0 )
|
||||
else:
|
||||
self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power)
|
||||
|
||||
default_bg_style_power = 10.0
|
||||
default_bg_style_power = 2.0
|
||||
if is_first_run or ask_override:
|
||||
default_bg_style_power = default_bg_style_power if is_first_run else self.options.get('bg_style_power', default_bg_style_power)
|
||||
self.options['bg_style_power'] = np.clip ( input_number("Background style power ( 0.0 .. 100.0 ?:help skip:%.1f) : " % (default_bg_style_power), default_bg_style_power, help_message="How fast NN will learn dst background style during generalization of src and dst faces."), 0.0, 100.0 )
|
||||
|
@ -103,7 +103,6 @@ class SAEModel(ModelBase):
|
|||
self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (inter_output_Inputs)
|
||||
self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (inter_output_Inputs)
|
||||
|
||||
|
||||
if not self.is_first_run():
|
||||
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
|
||||
self.inter_B.load_weights (self.get_strpath_storage_for_file(self.inter_BH5))
|
||||
|
@ -137,11 +136,9 @@ class SAEModel(ModelBase):
|
|||
|
||||
self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (dec_Inputs)
|
||||
self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (dec_Inputs)
|
||||
|
||||
self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5))) (dec_Inputs)
|
||||
self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5))) (dec_Inputs)
|
||||
|
||||
|
||||
if not self.is_first_run():
|
||||
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
|
||||
self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5))
|
||||
|
@ -150,100 +147,91 @@ class SAEModel(ModelBase):
|
|||
self.decoder_dstm.load_weights (self.get_strpath_storage_for_file(self.decoder_dstmH5))
|
||||
|
||||
warped_src_code = self.encoder (warped_src)
|
||||
|
||||
pred_src_src = self.decoder_src(warped_src_code)
|
||||
pred_src_srcm = self.decoder_srcm(warped_src_code)
|
||||
|
||||
warped_dst_code = self.encoder (warped_dst)
|
||||
|
||||
pred_dst_dst = self.decoder_dst(warped_dst_code)
|
||||
pred_dst_dstm = self.decoder_dstm(warped_dst_code)
|
||||
|
||||
pred_src_dst = self.decoder_src(warped_dst_code)
|
||||
pred_src_dstm = self.decoder_srcm(warped_dst_code)
|
||||
|
||||
ms_count = len(pred_src_src)
|
||||
|
||||
target_srcm_blurred = tf_gaussian_blur(resolution // 32)(target_srcm)
|
||||
target_srcm_sigm = target_srcm_blurred / 2.0 + 0.5
|
||||
target_srcm_anti_sigm = 1.0 - target_srcm_sigm
|
||||
target_src_ar = [ target_src if i == 0 else tf.image.resize_bicubic( target_src, (resolution // (2**i) ,)*2 ) for i in range(ms_count-1, -1, -1)]
|
||||
target_srcm_ar = [ target_srcm if i == 0 else tf.image.resize_bicubic( target_srcm, (resolution // (2**i) ,)*2 ) for i in range(ms_count-1, -1, -1)]
|
||||
target_dst_ar = [ target_dst if i == 0 else tf.image.resize_bicubic( target_dst, (resolution // (2**i) ,)*2 ) for i in range(ms_count-1, -1, -1)]
|
||||
target_dstm_ar = [ target_dstm if i == 0 else tf.image.resize_bicubic( target_dstm, (resolution // (2**i) ,)*2 ) for i in range(ms_count-1, -1, -1)]
|
||||
|
||||
target_dstm_blurred = tf_gaussian_blur(resolution // 32)(target_dstm)
|
||||
target_dstm_sigm = target_dstm_blurred / 2.0 + 0.5
|
||||
target_dstm_anti_sigm = 1.0 - target_dstm_sigm
|
||||
target_srcm_blurred_ar = [ tf_gaussian_blur( max(1, x.get_shape().as_list()[1] // 32) )(x) for x in target_srcm_ar]
|
||||
target_srcm_sigm_ar = [ x / 2.0 + 0.5 for x in target_srcm_blurred_ar]
|
||||
target_srcm_anti_sigm_ar = [ 1.0 - x for x in target_srcm_sigm_ar]
|
||||
|
||||
target_src_sigm = target_src+1
|
||||
target_dst_sigm = target_dst+1
|
||||
target_dstm_blurred_ar = [ tf_gaussian_blur( max(1, x.get_shape().as_list()[1] // 32) )(x) for x in target_dstm_ar]
|
||||
target_dstm_sigm_ar = [ x / 2.0 + 0.5 for x in target_dstm_blurred_ar]
|
||||
target_dstm_anti_sigm_ar = [ 1.0 - x for x in target_dstm_sigm_ar]
|
||||
|
||||
pred_src_src_sigm = pred_src_src+1
|
||||
pred_dst_dst_sigm = pred_dst_dst+1
|
||||
pred_src_dst_sigm = pred_src_dst+1
|
||||
target_src_sigm_ar = [ x + 1 for x in target_src_ar]
|
||||
target_dst_sigm_ar = [ x + 1 for x in target_dst_ar]
|
||||
|
||||
target_src_masked = target_src_sigm*target_srcm_sigm
|
||||
pred_src_src_sigm_ar = [ x + 1 for x in pred_src_src]
|
||||
pred_dst_dst_sigm_ar = [ x + 1 for x in pred_dst_dst]
|
||||
pred_src_dst_sigm_ar = [ x + 1 for x in pred_src_dst]
|
||||
|
||||
target_dst_masked = target_dst_sigm * target_dstm_sigm
|
||||
target_dst_anti_masked = target_dst_sigm * target_dstm_anti_sigm
|
||||
target_src_masked_ar = [ target_src_sigm_ar[i]*target_srcm_sigm_ar[i] for i in range(len(target_src_sigm_ar))]
|
||||
target_dst_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(target_dst_sigm_ar))]
|
||||
target_dst_anti_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(target_dst_sigm_ar))]
|
||||
|
||||
pred_src_src_masked = pred_src_src_sigm * target_srcm_sigm
|
||||
pred_dst_dst_masked = pred_dst_dst_sigm * target_dstm_sigm
|
||||
|
||||
psd_target_dst_masked = pred_src_dst_sigm * target_dstm_sigm
|
||||
psd_target_dst_anti_masked = pred_src_dst_sigm * target_dstm_anti_sigm
|
||||
|
||||
src_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked, pred_src_src_masked )) )
|
||||
psd_target_dst_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
|
||||
psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
|
||||
|
||||
if self.is_training_mode:
|
||||
def optimizer():
|
||||
return Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
|
||||
|
||||
|
||||
if self.options['archi'] == 'liae':
|
||||
src_loss_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
|
||||
src_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
|
||||
dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
|
||||
dst_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
|
||||
else:
|
||||
src_loss_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights
|
||||
src_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights
|
||||
dst_loss_train_weights = self.encoder.trainable_weights + self.decoder_dst.trainable_weights
|
||||
dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights
|
||||
|
||||
src_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked_ar[i], pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ) )) for i in range(len(target_src_masked_ar)) ])
|
||||
|
||||
if self.options['face_style_power'] != 0:
|
||||
face_style_power = self.options['face_style_power'] / 100.0
|
||||
src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*face_style_power)(psd_target_dst_masked, target_dst_masked)
|
||||
src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*face_style_power)( psd_target_dst_masked_ar[-1], target_dst_masked_ar[-1] )
|
||||
|
||||
if self.options['bg_style_power'] != 0:
|
||||
bg_style_power = self.options['bg_style_power'] / 100.0
|
||||
src_loss += K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked, target_dst_anti_masked )))
|
||||
src_loss += K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))
|
||||
|
||||
if self.options['archi'] == 'liae':
|
||||
src_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
|
||||
self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss], optimizer().get_updates(src_loss, src_loss_train_weights) )
|
||||
|
||||
src_mask_loss = sum([ K.mean(K.square(target_srcm_ar[i]-pred_src_srcm[i])) for i in range(len(target_srcm_ar)) ])
|
||||
self.src_mask_train = K.function ([warped_src, target_srcm],[src_mask_loss], optimizer().get_updates(src_mask_loss, src_mask_loss_train_weights) )
|
||||
|
||||
dst_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ])
|
||||
self.dst_train = K.function ([warped_dst, target_dst, target_dstm ],[dst_loss], optimizer().get_updates(dst_loss, dst_loss_train_weights) )
|
||||
|
||||
dst_mask_loss = sum([ K.mean(K.square(target_dstm_ar[i]-pred_dst_dstm[i])) for i in range(len(target_dstm_ar)) ])
|
||||
self.dst_mask_train = K.function ([warped_dst, target_dstm],[dst_mask_loss], optimizer().get_updates(dst_mask_loss, dst_mask_loss_train_weights) )
|
||||
|
||||
self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_src_srcm[-1], pred_dst_dst[-1], pred_dst_dstm[-1], pred_src_dst[-1], pred_src_dstm[-1]] )
|
||||
else:
|
||||
src_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights
|
||||
self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss],
|
||||
optimizer().get_updates(src_loss, src_train_weights) )
|
||||
|
||||
dst_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked, pred_dst_dst_masked )) )
|
||||
|
||||
if self.options['archi'] == 'liae':
|
||||
dst_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
|
||||
else:
|
||||
dst_train_weights = self.encoder.trainable_weights + self.decoder_dst.trainable_weights
|
||||
self.dst_train = K.function ([warped_dst, target_dst, target_dstm],[dst_loss],
|
||||
optimizer().get_updates(dst_loss, dst_train_weights) )
|
||||
|
||||
src_mask_loss = K.mean(K.square(target_srcm-pred_src_srcm))
|
||||
|
||||
if self.options['archi'] == 'liae':
|
||||
src_mask_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
|
||||
else:
|
||||
src_mask_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights
|
||||
|
||||
self.src_mask_train = K.function ([warped_src, target_srcm],[src_mask_loss],
|
||||
optimizer().get_updates(src_mask_loss, src_mask_train_weights ) )
|
||||
|
||||
dst_mask_loss = K.mean(K.square(target_dstm-pred_dst_dstm))
|
||||
|
||||
if self.options['archi'] == 'liae':
|
||||
dst_mask_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
|
||||
else:
|
||||
dst_mask_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights
|
||||
|
||||
self.dst_mask_train = K.function ([warped_dst, target_dstm],[dst_mask_loss],
|
||||
optimizer().get_updates(dst_mask_loss, dst_mask_train_weights) )
|
||||
|
||||
self.AE_view = K.function ([warped_src, warped_dst],[pred_src_src, pred_src_srcm, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm])
|
||||
self.AE_convert = K.function ([warped_dst],[pred_src_dst, pred_src_dstm])
|
||||
self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1], pred_src_dstm[-1] ])
|
||||
|
||||
if self.is_training_mode:
|
||||
f = SampleProcessor.TypeFlags
|
||||
|
||||
face_type = f.FACE_ALIGN_FULL if self.options['face_type'] == 'f' else f.FACE_ALIGN_HALF
|
||||
|
||||
self.set_training_data_generators ([
|
||||
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
|
||||
debug=self.is_debug(), batch_size=self.batch_size,
|
||||
|
@ -257,7 +245,8 @@ class SAEModel(ModelBase):
|
|||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, normalize_tanh = True),
|
||||
output_sample_types=[ [f.WARPED_TRANSFORMED | face_type | f.MODE_BGR, resolution],
|
||||
[f.TRANSFORMED | face_type | f.MODE_BGR, resolution],
|
||||
[f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution] ] )
|
||||
[f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution]
|
||||
] )
|
||||
])
|
||||
#override
|
||||
def onSave(self):
|
||||
|
@ -297,23 +286,18 @@ class SAEModel(ModelBase):
|
|||
test_B = sample[1][1][0:4]
|
||||
test_B_m = sample[1][2][0:4]
|
||||
|
||||
S = test_A
|
||||
D = test_B
|
||||
S, D, SS, SM, DD, DM, SD, SDM, = [ x / 2 + 0.5 for x in ([test_A,test_B] + self.AE_view ([test_A, test_B]) ) ]
|
||||
#SM, DM, SDM = [ np.repeat (x, (3,), -1) for x in [SM, DM, SDM] ]
|
||||
|
||||
SS, SM, DD, DM, SD, SDM = self.AE_view ([test_A, test_B])
|
||||
S, D, SS, SM, DD, DM, SD, SDM = [ x / 2 + 0.5 for x in [S, D, SS, SM, DD, DM, SD, SDM] ]
|
||||
|
||||
SM, DM, SDM = [ np.repeat (x, (3,), -1) for x in [SM, DM, SDM] ]
|
||||
|
||||
st = []
|
||||
st_x3 = []
|
||||
for i in range(0, len(test_A)):
|
||||
st.append ( np.concatenate ( (
|
||||
st_x3.append ( np.concatenate ( (
|
||||
S[i], SS[i], #SM[i],
|
||||
D[i], DD[i], #DM[i],
|
||||
SD[i], #SDM[i]
|
||||
), axis=1) )
|
||||
|
||||
return [ ('SAE', np.concatenate ( st, axis=0 ) ) ]
|
||||
return [ ('SAE', np.concatenate (st_x3, axis=0 )), ]
|
||||
|
||||
def predictor_func (self, face):
|
||||
face = face * 2.0 - 1.0
|
||||
|
@ -371,7 +355,7 @@ class SAEModel(ModelBase):
|
|||
x = downscale(ed_dims*8)(x)
|
||||
else:
|
||||
x = downscale_sep(ed_dims*2)(x)
|
||||
x = downscale_sep(ed_dims*4)(x)
|
||||
x = downscale(ed_dims*4)(x)
|
||||
x = downscale_sep(ed_dims*8)(x)
|
||||
|
||||
x = Flatten()(x)
|
||||
|
@ -400,25 +384,30 @@ class SAEModel(ModelBase):
|
|||
@staticmethod
|
||||
def LIAEDecFlow(output_nc,ed_ch_dims=21,activation='tanh'):
|
||||
exec (nnlib.import_all(), locals(), globals())
|
||||
ed_dims = output_nc * ed_ch_dims
|
||||
|
||||
def upscale (dim):
|
||||
def func(x):
|
||||
return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
|
||||
return func
|
||||
|
||||
def func(input):
|
||||
ed_dims = output_nc * ed_ch_dims
|
||||
|
||||
x = input[0]
|
||||
x = upscale(ed_dims*8)(x)
|
||||
x = upscale(ed_dims*4)(x)
|
||||
x = upscale(ed_dims*2)(x)
|
||||
|
||||
x = Conv2D(output_nc, kernel_size=5, padding='same', activation=activation)(x)
|
||||
return x
|
||||
|
||||
def to_bgr ():
|
||||
def func(x):
|
||||
return Conv2D(output_nc, kernel_size=5, padding='same', activation='tanh')(x)
|
||||
return func
|
||||
def func(input):
|
||||
x = input[0]
|
||||
x1 = upscale(ed_dims*8)( x )
|
||||
x1_bgr = to_bgr() ( x1 )
|
||||
|
||||
x2 = upscale(ed_dims*4)( x1 )
|
||||
x2_bgr = to_bgr() ( x2 )
|
||||
|
||||
x3 = upscale(ed_dims*2)( x2 )
|
||||
x3_bgr = to_bgr() ( x3 )
|
||||
|
||||
return [ x1_bgr, x2_bgr, x3_bgr ]
|
||||
return func
|
||||
|
||||
@staticmethod
|
||||
def DFEncFlow(resolution, adapt_k_size, light_enc, ae_dims=512, ed_ch_dims=42):
|
||||
|
@ -466,7 +455,7 @@ class SAEModel(ModelBase):
|
|||
return func
|
||||
|
||||
@staticmethod
|
||||
def DFDecFlow(output_nc, ed_ch_dims=21, activation='tanh'):
|
||||
def DFDecFlow(output_nc, ed_ch_dims=21):
|
||||
exec (nnlib.import_all(), locals(), globals())
|
||||
ed_dims = output_nc * ed_ch_dims
|
||||
|
||||
|
@ -474,15 +463,23 @@ class SAEModel(ModelBase):
|
|||
def func(x):
|
||||
return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
|
||||
return func
|
||||
|
||||
def to_bgr ():
|
||||
def func(x):
|
||||
return Conv2D(output_nc, kernel_size=5, padding='same', activation='tanh')(x)
|
||||
return func
|
||||
def func(input):
|
||||
x = input[0]
|
||||
x = upscale(ed_dims*8)(x)
|
||||
x = upscale(ed_dims*4)(x)
|
||||
x = upscale(ed_dims*2)(x)
|
||||
x1 = upscale(ed_dims*8)( x )
|
||||
x1_bgr = to_bgr() ( x1 )
|
||||
|
||||
x = Conv2D(output_nc, kernel_size=5, padding='same', activation=activation)(x)
|
||||
return x
|
||||
x2 = upscale(ed_dims*4)( x1 )
|
||||
x2_bgr = to_bgr() ( x2 )
|
||||
|
||||
x3 = upscale(ed_dims*2)( x2 )
|
||||
x3_bgr = to_bgr() ( x3 )
|
||||
|
||||
return [ x1_bgr, x2_bgr, x3_bgr ]
|
||||
return func
|
||||
|
||||
Model = SAEModel
|
|
@ -75,6 +75,7 @@ Conv2D = keras.layers.Conv2D
|
|||
Conv2DTranspose = keras.layers.Conv2DTranspose
|
||||
SeparableConv2D = keras.layers.SeparableConv2D
|
||||
MaxPooling2D = keras.layers.MaxPooling2D
|
||||
UpSampling2D = keras.layers.UpSampling2D
|
||||
BatchNormalization = keras.layers.BatchNormalization
|
||||
|
||||
LeakyReLU = keras.layers.LeakyReLU
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue