mirror of https://github.com/iperov/DeepFaceLab.git (synced 2025-07-06 21:12:07 -07:00)
Enhanced SAE model; you should restart training. New default style power = 2.0. Fixed DF default batch sizes. Updated README.
This commit is contained in:
parent 22401cecfc
commit 946688567d

5 changed files with 137 additions and 134 deletions
@@ -28,21 +28,21 @@ class SAEModel(ModelBase):
         if is_first_run:
             self.options['resolution'] = input_int("Resolution (64,128 ?:help skip:128) : ", default_resolution, [64,128], help_message="More resolution requires more VRAM.")
-            self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:liae) : ", default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower()
+            self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:%s) : " % (default_archi) , default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower()
             self.options['lighter_encoder'] = input_bool ("Use lightweight encoder? (y/n, ?:help skip:n) : ", False, help_message="Lightweight encoder is 35% faster, requires less VRAM, sacrificing overall quality.")
         else:
             self.options['resolution'] = self.options.get('resolution', default_resolution)
             self.options['archi'] = self.options.get('archi', default_archi)
             self.options['lighter_encoder'] = self.options.get('lighter_encoder', False)

-        default_face_style_power = 10.0
+        default_face_style_power = 2.0
         if is_first_run or ask_override:
             default_face_style_power = default_face_style_power if is_first_run else self.options.get('face_style_power', default_face_style_power)
             self.options['face_style_power'] = np.clip ( input_number("Face style power ( 0.0 .. 100.0 ?:help skip:%.1f) : " % (default_face_style_power), default_face_style_power, help_message="How fast NN will learn dst face style during generalization of src and dst faces."), 0.0, 100.0 )
         else:
             self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power)

-        default_bg_style_power = 10.0
+        default_bg_style_power = 2.0
         if is_first_run or ask_override:
             default_bg_style_power = default_bg_style_power if is_first_run else self.options.get('bg_style_power', default_bg_style_power)
             self.options['bg_style_power'] = np.clip ( input_number("Background style power ( 0.0 .. 100.0 ?:help skip:%.1f) : " % (default_bg_style_power), default_bg_style_power, help_message="How fast NN will learn dst background style during generalization of src and dst faces."), 0.0, 100.0 )
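Aside: the style powers above are entered as 0..100 percentages and clamped with np.clip; further down in this diff they are divided by 100 and scaled into loss weights. A minimal NumPy sketch of that mapping (the helper name style_weights is hypothetical, not from the repo):

import numpy as np

def style_weights(style_power):
    # Clamp to [0, 100] as the prompts above do, then scale to 0..1.
    p = np.clip(style_power, 0.0, 100.0) / 100.0
    # The diff weights the face style loss by 0.2*p and the background
    # DSSIM term by 100*p.
    return 0.2 * p, 100.0 * p

face_w, bg_w = style_weights(2.0)  # new default 2.0 -> (0.004, 2.0)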
@@ -102,8 +102,7 @@ class SAEModel(ModelBase):
             self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (inter_output_Inputs)
             self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (inter_output_Inputs)

             if not self.is_first_run():
                 self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
                 self.inter_B.load_weights (self.get_strpath_storage_for_file(self.inter_BH5))
@@ -132,16 +131,14 @@ class SAEModel(ModelBase):
             pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
         else:
             self.encoder = modelify(SAEModel.DFEncFlow(resolution, adapt_k_size, self.options['lighter_encoder'], ae_dims=ae_dims, ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))

             dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]

             self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (dec_Inputs)
             self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2)) (dec_Inputs)

             self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5))) (dec_Inputs)
             self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5))) (dec_Inputs)

             if not self.is_first_run():
                 self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
                 self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5))
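The dec_Inputs pattern above rebuilds symbolic Input tensors matching each encoder output shape, so each decoder is a standalone model that can be saved and loaded independently of the encoder. A minimal sketch of the same idea with toy shapes (tf.keras here; all names are illustrative, not from the repo):

import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense
from tensorflow.keras.models import Model

inp = Input((64, 64, 3))
enc = Model(inp, Dense(128)(Flatten()(Conv2D(8, 5, strides=2, padding='same')(inp))))

# One fresh Input per encoder output, batch dimension dropped:
dec_inputs = [Input(tf.keras.backend.int_shape(t)[1:]) for t in enc.outputs]
dec = Model(dec_inputs, Dense(3)(dec_inputs[0]))

decoded = dec(enc(inp))  # chained exactly as encoder/decoder are chained above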
@@ -150,114 +147,106 @@ class SAEModel(ModelBase):
                 self.decoder_dstm.load_weights (self.get_strpath_storage_for_file(self.decoder_dstmH5))

         warped_src_code = self.encoder (warped_src)
         pred_src_src = self.decoder_src(warped_src_code)
         pred_src_srcm = self.decoder_srcm(warped_src_code)

         warped_dst_code = self.encoder (warped_dst)
         pred_dst_dst = self.decoder_dst(warped_dst_code)
         pred_dst_dstm = self.decoder_dstm(warped_dst_code)

         pred_src_dst = self.decoder_src(warped_dst_code)
         pred_src_dstm = self.decoder_srcm(warped_dst_code)

         target_srcm_blurred = tf_gaussian_blur(resolution // 32)(target_srcm)
         target_srcm_sigm = target_srcm_blurred / 2.0 + 0.5
         target_srcm_anti_sigm = 1.0 - target_srcm_sigm

         target_dstm_blurred = tf_gaussian_blur(resolution // 32)(target_dstm)
         target_dstm_sigm = target_dstm_blurred / 2.0 + 0.5
         target_dstm_anti_sigm = 1.0 - target_dstm_sigm

         ms_count = len(pred_src_src)

         target_src_sigm = target_src+1
         target_dst_sigm = target_dst+1

         pred_src_src_sigm = pred_src_src+1
         pred_dst_dst_sigm = pred_dst_dst+1
         pred_src_dst_sigm = pred_src_dst+1

         target_src_ar = [ target_src if i == 0 else tf.image.resize_bicubic( target_src, (resolution // (2**i) ,)*2 ) for i in range(ms_count-1, -1, -1)]
         target_srcm_ar = [ target_srcm if i == 0 else tf.image.resize_bicubic( target_srcm, (resolution // (2**i) ,)*2 ) for i in range(ms_count-1, -1, -1)]
         target_dst_ar = [ target_dst if i == 0 else tf.image.resize_bicubic( target_dst, (resolution // (2**i) ,)*2 ) for i in range(ms_count-1, -1, -1)]
         target_dstm_ar = [ target_dstm if i == 0 else tf.image.resize_bicubic( target_dstm, (resolution // (2**i) ,)*2 ) for i in range(ms_count-1, -1, -1)]

         target_src_masked = target_src_sigm*target_srcm_sigm

         target_srcm_blurred_ar = [ tf_gaussian_blur( max(1, x.get_shape().as_list()[1] // 32) )(x) for x in target_srcm_ar]
         target_srcm_sigm_ar = [ x / 2.0 + 0.5 for x in target_srcm_blurred_ar]
         target_srcm_anti_sigm_ar = [ 1.0 - x for x in target_srcm_sigm_ar]

         target_dstm_blurred_ar = [ tf_gaussian_blur( max(1, x.get_shape().as_list()[1] // 32) )(x) for x in target_dstm_ar]
         target_dstm_sigm_ar = [ x / 2.0 + 0.5 for x in target_dstm_blurred_ar]
         target_dstm_anti_sigm_ar = [ 1.0 - x for x in target_dstm_sigm_ar]

         target_dst_masked = target_dst_sigm * target_dstm_sigm
         target_dst_anti_masked = target_dst_sigm * target_dstm_anti_sigm

         pred_src_src_masked = pred_src_src_sigm * target_srcm_sigm
         pred_dst_dst_masked = pred_dst_dst_sigm * target_dstm_sigm

         psd_target_dst_masked = pred_src_dst_sigm * target_dstm_sigm
         psd_target_dst_anti_masked = pred_src_dst_sigm * target_dstm_anti_sigm

         target_src_sigm_ar = [ x + 1 for x in target_src_ar]
         target_dst_sigm_ar = [ x + 1 for x in target_dst_ar]

         pred_src_src_sigm_ar = [ x + 1 for x in pred_src_src]
         pred_dst_dst_sigm_ar = [ x + 1 for x in pred_dst_dst]
         pred_src_dst_sigm_ar = [ x + 1 for x in pred_src_dst]

         target_src_masked_ar = [ target_src_sigm_ar[i]*target_srcm_sigm_ar[i] for i in range(len(target_src_sigm_ar))]
         target_dst_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(target_dst_sigm_ar))]
         target_dst_anti_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(target_dst_sigm_ar))]

         src_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked, pred_src_src_masked )) )

         def optimizer():
             return Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)

         if self.options['face_style_power'] != 0:
             face_style_power = self.options['face_style_power'] / 100.0
             src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*face_style_power)(psd_target_dst_masked, target_dst_masked)

         if self.options['bg_style_power'] != 0:
             bg_style_power = self.options['bg_style_power'] / 100.0
             src_loss += K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked, target_dst_anti_masked )))

         if self.options['archi'] == 'liae':
             src_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
         else:
             src_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights
         self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss],
                                      optimizer().get_updates(src_loss, src_train_weights) )

         dst_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked, pred_dst_dst_masked )) )

         if self.options['archi'] == 'liae':
             dst_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
         else:
             dst_train_weights = self.encoder.trainable_weights + self.decoder_dst.trainable_weights
         self.dst_train = K.function ([warped_dst, target_dst, target_dstm],[dst_loss],
                                      optimizer().get_updates(dst_loss, dst_train_weights) )

         src_mask_loss = K.mean(K.square(target_srcm-pred_src_srcm))

         if self.options['archi'] == 'liae':
             src_mask_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
         else:
             src_mask_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights

         self.src_mask_train = K.function ([warped_src, target_srcm],[src_mask_loss],
                                           optimizer().get_updates(src_mask_loss, src_mask_train_weights ) )

         dst_mask_loss = K.mean(K.square(target_dstm-pred_dst_dstm))

         if self.options['archi'] == 'liae':
             dst_mask_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
         else:
             dst_mask_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights

         self.dst_mask_train = K.function ([warped_dst, target_dstm],[dst_mask_loss],
                                           optimizer().get_updates(dst_mask_loss, dst_mask_train_weights) )

         self.AE_view = K.function ([warped_src, warped_dst],[pred_src_src, pred_src_srcm, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm])
         self.AE_convert = K.function ([warped_dst],[pred_src_dst, pred_src_dstm])

         psd_target_dst_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
         psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]

         if self.is_training_mode:
             f = SampleProcessor.TypeFlags
             def optimizer():
                 return Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)

             face_type = f.FACE_ALIGN_FULL if self.options['face_type'] == 'f' else f.FACE_ALIGN_HALF

             if self.options['archi'] == 'liae':
                 src_loss_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
                 src_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
                 dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
                 dst_mask_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
             else:
                 src_loss_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights
                 src_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights
                 dst_loss_train_weights = self.encoder.trainable_weights + self.decoder_dst.trainable_weights
                 dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights

             src_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked_ar[i], pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ) )) for i in range(len(target_src_masked_ar)) ])

             if self.options['face_style_power'] != 0:
                 face_style_power = self.options['face_style_power'] / 100.0
                 src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*face_style_power)( psd_target_dst_masked_ar[-1], target_dst_masked_ar[-1] )

             if self.options['bg_style_power'] != 0:
                 bg_style_power = self.options['bg_style_power'] / 100.0
                 src_loss += K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))

             self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss], optimizer().get_updates(src_loss, src_loss_train_weights) )

             src_mask_loss = sum([ K.mean(K.square(target_srcm_ar[i]-pred_src_srcm[i])) for i in range(len(target_srcm_ar)) ])
             self.src_mask_train = K.function ([warped_src, target_srcm],[src_mask_loss], optimizer().get_updates(src_mask_loss, src_mask_loss_train_weights) )

             dst_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ])
             self.dst_train = K.function ([warped_dst, target_dst, target_dstm ],[dst_loss], optimizer().get_updates(dst_loss, dst_loss_train_weights) )

             dst_mask_loss = sum([ K.mean(K.square(target_dstm_ar[i]-pred_dst_dstm[i])) for i in range(len(target_dstm_ar)) ])
             self.dst_mask_train = K.function ([warped_dst, target_dstm],[dst_mask_loss], optimizer().get_updates(dst_mask_loss, dst_mask_loss_train_weights) )

             self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_src_srcm[-1], pred_dst_dst[-1], pred_dst_dstm[-1], pred_src_dst[-1], pred_src_dstm[-1]] )
         else:
             self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1], pred_src_dstm[-1] ])

         if self.is_training_mode:
             f = SampleProcessor.TypeFlags
             face_type = f.FACE_ALIGN_FULL if self.options['face_type'] == 'f' else f.FACE_ALIGN_HALF
             self.set_training_data_generators ([
                     SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
                                         debug=self.is_debug(), batch_size=self.batch_size,
                                         sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, normalize_tanh = True, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
                                         output_sample_types=[ [f.WARPED_TRANSFORMED | face_type | f.MODE_BGR, resolution],
                                                               [f.TRANSFORMED | face_type | f.MODE_BGR, resolution],
                                                               [f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution]
                                                             ] ),

                     SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
                                         sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, normalize_tanh = True),
                                         output_sample_types=[ [f.WARPED_TRANSFORMED | face_type | f.MODE_BGR, resolution],
                                                               [f.TRANSFORMED | face_type | f.MODE_BGR, resolution],
                                                               [f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution]
                                                             ] )
                 ])

     #override
     def onSave(self):
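The core of this hunk: instead of one loss at full resolution, targets are resized into a pyramid with tf.image.resize_bicubic and a DSSIM term is summed per scale (ms_count scales, coarsest first, full resolution last). A minimal sketch of the pyramid construction, using the TF 1.x API the diff itself uses:

import tensorflow as tf

def build_pyramid(t, resolution, ms_count):
    # Mirrors target_src_ar above: i runs ms_count-1 .. 0, so the list is
    # ordered coarsest -> finest, with the untouched tensor last (i == 0).
    return [t if i == 0 else
            tf.image.resize_bicubic(t, (resolution // (2 ** i),) * 2)
            for i in range(ms_count - 1, -1, -1)]

# resolution=128, ms_count=3 -> 32x32, 64x64, 128x128 targets;
# training then sums K.mean(100*K.square(dssim)) over these scales.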
@@ -279,7 +268,7 @@ class SAEModel(ModelBase):
     #override
     def onTrainOneEpoch(self, sample):
         warped_src, target_src, target_src_mask = sample[0]
         warped_dst, target_dst, target_dst_mask = sample[1]

         src_loss, = self.src_train ([warped_src, target_src, target_src_mask, warped_dst, target_dst, target_dst_mask])
         dst_loss, = self.dst_train ([warped_dst, target_dst, target_dst_mask])
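onTrainOneEpoch simply feeds NumPy batches to the K.function objects built earlier; each call performs one optimizer step because the Adam updates were attached when the function was compiled. A minimal sketch of that pattern with a toy model (graph-mode Keras, as this TF 1.x-era code base assumes; names are illustrative):

import numpy as np
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

inp = Input((8,))
model = Model(inp, Dense(1)(inp))
target = K.placeholder((None, 1))

loss = K.mean(K.square(target - model.output))
# Third argument is the list of update ops; running the function applies them.
train = K.function([inp, target], [loss],
                   Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
                       .get_updates(loss, model.trainable_weights))

loss_val, = train([np.zeros((4, 8)), np.ones((4, 1))])  # one training step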
@@ -296,24 +285,19 @@ class SAEModel(ModelBase):
         test_A_m = sample[0][2][0:4] #first 4 samples
         test_B = sample[1][1][0:4]
         test_B_m = sample[1][2][0:4]

-        S = test_A
-        D = test_B
-
-        SS, SM, DD, DM, SD, SDM = self.AE_view ([test_A, test_B])
-        S, D, SS, SM, DD, DM, SD, SDM = [ x / 2 + 0.5 for x in [S, D, SS, SM, DD, DM, SD, SDM] ]
-
-        SM, DM, SDM = [ np.repeat (x, (3,), -1) for x in [SM, DM, SDM] ]
-
-        st = []
+        S, D, SS, SM, DD, DM, SD, SDM, = [ x / 2 + 0.5 for x in ([test_A,test_B] + self.AE_view ([test_A, test_B]) ) ]
+        #SM, DM, SDM = [ np.repeat (x, (3,), -1) for x in [SM, DM, SDM] ]
+
+        st_x3 = []
         for i in range(0, len(test_A)):
-            st.append ( np.concatenate ( (
+            st_x3.append ( np.concatenate ( (
                 S[i], SS[i], #SM[i],
                 D[i], DD[i], #DM[i],
                 SD[i], #SDM[i]
                 ), axis=1) )

-        return [ ('SAE', np.concatenate ( st, axis=0 ) ) ]
+        return [ ('SAE', np.concatenate (st_x3, axis=0 )), ]

     def predictor_func (self, face):
         face = face * 2.0 - 1.0
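Note the range conventions here: the sample generators run with normalize_tanh=True, so images live in [-1, 1]; previews map back with x / 2 + 0.5, and predictor_func maps a [0, 1] face in with face * 2.0 - 1.0. A trivial sketch of the two mappings (both expressions taken verbatim from the code above; the function names are hypothetical):

import numpy as np

def tanh_to_unit(x):
    return x / 2 + 0.5        # [-1, 1] -> [0, 1], as the preview code does

def unit_to_tanh(face):
    return face * 2.0 - 1.0   # [0, 1] -> [-1, 1], as predictor_func does

tanh_to_unit(np.array([-1.0, 0.0, 1.0]))  # -> [0. , 0.5, 1. ]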
@@ -371,7 +355,7 @@ class SAEModel(ModelBase):
                 x = downscale(ed_dims*8)(x)
             else:
                 x = downscale_sep(ed_dims*2)(x)
-                x = downscale_sep(ed_dims*4)(x)
+                x = downscale(ed_dims*4)(x)
                 x = downscale_sep(ed_dims*8)(x)

             x = Flatten()(x)
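The lightweight-encoder branch above mixes downscale and downscale_sep blocks; their definitions are outside this hunk. A plausible sketch of the pair, assuming the usual strided-conv versus depthwise-separable-conv trade behind the "35% faster" claim in the option help (kernel size and activation here are guesses, not taken from the repo):

from tensorflow.keras.layers import Conv2D, SeparableConv2D, LeakyReLU

def downscale(dim):
    def func(x):
        # Regular strided convolution: halves H and W.
        return LeakyReLU(0.1)(Conv2D(dim, 5, strides=2, padding='same')(x))
    return func

def downscale_sep(dim):
    def func(x):
        # Depthwise-separable variant: same geometry, far fewer weights.
        return LeakyReLU(0.1)(SeparableConv2D(dim, 5, strides=2, padding='same')(x))
    return func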
@@ -400,26 +384,31 @@ class SAEModel(ModelBase):
     @staticmethod
     def LIAEDecFlow(output_nc,ed_ch_dims=21,activation='tanh'):
         exec (nnlib.import_all(), locals(), globals())
+        ed_dims = output_nc * ed_ch_dims

         def upscale (dim):
             def func(x):
                 return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
             return func

-        def func(input):
-            ed_dims = output_nc * ed_ch_dims
-            x = input[0]
-            x = upscale(ed_dims*8)(x)
-            x = upscale(ed_dims*4)(x)
-            x = upscale(ed_dims*2)(x)
-
-            x = Conv2D(output_nc, kernel_size=5, padding='same', activation=activation)(x)
-            return x
-        return func
+        def to_bgr ():
+            def func(x):
+                return Conv2D(output_nc, kernel_size=5, padding='same', activation='tanh')(x)
+            return func
+
+        def func(input):
+            x = input[0]
+            x1 = upscale(ed_dims*8)( x )
+            x1_bgr = to_bgr() ( x1 )
+
+            x2 = upscale(ed_dims*4)( x1 )
+            x2_bgr = to_bgr() ( x2 )
+
+            x3 = upscale(ed_dims*2)( x2 )
+            x3_bgr = to_bgr() ( x3 )
+
+            return [ x1_bgr, x2_bgr, x3_bgr ]
+        return func

     @staticmethod
     def DFEncFlow(resolution, adapt_k_size, light_enc, ae_dims=512, ed_ch_dims=42):
         exec (nnlib.import_all(), locals(), globals())
@@ -466,23 +455,31 @@ class SAEModel(ModelBase):
             return func

     @staticmethod
-    def DFDecFlow(output_nc, ed_ch_dims=21, activation='tanh'):
+    def DFDecFlow(output_nc, ed_ch_dims=21):
         exec (nnlib.import_all(), locals(), globals())
         ed_dims = output_nc * ed_ch_dims

         def upscale (dim):
             def func(x):
                 return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
             return func

-        def func(input):
-            x = input[0]
-            x = upscale(ed_dims*8)(x)
-            x = upscale(ed_dims*4)(x)
-            x = upscale(ed_dims*2)(x)
-
-            x = Conv2D(output_nc, kernel_size=5, padding='same', activation=activation)(x)
-            return x
-        return func
+        def to_bgr ():
+            def func(x):
+                return Conv2D(output_nc, kernel_size=5, padding='same', activation='tanh')(x)
+            return func
+
+        def func(input):
+            x = input[0]
+            x1 = upscale(ed_dims*8)( x )
+            x1_bgr = to_bgr() ( x1 )
+
+            x2 = upscale(ed_dims*4)( x1 )
+            x2_bgr = to_bgr() ( x2 )
+
+            x3 = upscale(ed_dims*2)( x2 )
+            x3_bgr = to_bgr() ( x3 )
+
+            return [ x1_bgr, x2_bgr, x3_bgr ]
+        return func

 Model = SAEModel
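With this commit both decoder flows return three BGR outputs ordered coarsest to finest, which is why AE_view and AE_convert above index [-1] for the full-resolution image while the training losses sum one term per scale. A toy consumption sketch (shapes are illustrative only):

import numpy as np

# e.g. a decoder emitting 16x16, 32x32 and 64x64 BGR batches:
outputs = [np.zeros((4, 16, 16, 3)),
           np.zeros((4, 32, 32, 3)),
           np.zeros((4, 64, 64, 3))]

final = outputs[-1]                      # full-resolution result, as AE_convert uses
per_scale = [o.mean() for o in outputs]  # training sums one loss term per scale
total = sum(per_scale)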