mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 21:12:07 -07:00
SAE: added new archi 'vg'
This commit is contained in:
parent
d66829aae4
commit
f0a20b46d3
5 changed files with 378 additions and 119 deletions
|
@ -30,7 +30,7 @@ class SAEModel(ModelBase):
|
|||
self.options['resolution'] = input_int("Resolution (64,128 ?:help skip:128) : ", default_resolution, [64,128], help_message="More resolution requires more VRAM.")
|
||||
self.options['face_type'] = input_str ("Half or Full face? (h/f, ?:help skip:f) : ", default_face_type, ['h','f'], help_message="Half face has better resolution, but covers less area of cheeks.").lower()
|
||||
self.options['learn_mask'] = input_bool ("Learn mask? (y/n, ?:help skip:y) : ", True, help_message="Learning mask can help model to recognize face directions. Learn without mask can reduce model size, in this case converter forced to use 'not predicted mask' that is not smooth as predicted. Model with style values can be learned without mask and produce same quality result.")
|
||||
self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:%s) : " % (default_archi) , default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower()
|
||||
self.options['archi'] = input_str ("AE architecture (df, liae, vg ?:help skip:%s) : " % (default_archi) , default_archi, ['df','liae','vg'], help_message="'df' keeps faces more natural. 'liae' can fix overly different face shapes. 'vg' - currently testing.").lower()
|
||||
else:
|
||||
self.options['resolution'] = self.options.get('resolution', default_resolution)
|
||||
self.options['face_type'] = self.options.get('face_type', default_face_type)
|
||||
|
@ -48,10 +48,14 @@ class SAEModel(ModelBase):
|
|||
|
||||
if is_first_run:
|
||||
self.options['lighter_encoder'] = input_bool ("Use lightweight encoder? (y/n, ?:help skip:n) : ", False, help_message="Lightweight encoder is 35% faster, requires less VRAM, but sacrificing overall quality.")
|
||||
self.options['multiscale_decoder'] = input_bool ("Use multiscale decoder? (y/n, ?:help skip:y) : ", True, help_message="Multiscale decoder helps to get better details.")
|
||||
|
||||
if self.options['archi'] != 'vg':
|
||||
self.options['multiscale_decoder'] = input_bool ("Use multiscale decoder? (y/n, ?:help skip:y) : ", True, help_message="Multiscale decoder helps to get better details.")
|
||||
else:
|
||||
self.options['lighter_encoder'] = self.options.get('lighter_encoder', False)
|
||||
self.options['multiscale_decoder'] = self.options.get('multiscale_decoder', True)
|
||||
|
||||
if self.options['archi'] != 'vg':
|
||||
self.options['multiscale_decoder'] = self.options.get('multiscale_decoder', True)
|
||||
|
||||
default_face_style_power = 0.0
|
||||
default_bg_style_power = 0.0
|
||||
|
@ -74,17 +78,19 @@ class SAEModel(ModelBase):
|
|||
#override
|
||||
def onInitialize(self, **in_options):
|
||||
exec(nnlib.import_all(), locals(), globals())
|
||||
|
||||
SAEModel.initialize_nn_functions()
|
||||
|
||||
self.set_vram_batch_requirements({1.5:4})
|
||||
|
||||
resolution = self.options['resolution']
|
||||
ae_dims = self.options['ae_dims']
|
||||
ed_ch_dims = self.options['ed_ch_dims']
|
||||
adapt_k_size = False
|
||||
bgr_shape = (resolution, resolution, 3)
|
||||
mask_shape = (resolution, resolution, 1)
|
||||
|
||||
self.ms_count = ms_count = 3 if (self.options['archi'] != 'vg' and self.options['multiscale_decoder']) else 1
|
||||
|
||||
self.ms_count = ms_count = 3 if self.options['multiscale_decoder'] else 1
|
||||
masked_training = True
|
||||
|
||||
epoch_alpha = Input( (1,) )
|
||||
warped_src = Input(bgr_shape)
|
||||
|
@ -101,7 +107,7 @@ class SAEModel(ModelBase):
|
|||
target_dstm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
|
||||
|
||||
if self.options['archi'] == 'liae':
|
||||
self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, adapt_k_size, self.options['lighter_encoder'], ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))
|
||||
self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, self.options['lighter_encoder'], ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))
|
||||
|
||||
enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
||||
|
||||
|
@ -143,8 +149,8 @@ class SAEModel(ModelBase):
|
|||
pred_dst_dstm = self.decoderm(warped_dst_inter_code)
|
||||
pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
|
||||
|
||||
else:
|
||||
self.encoder = modelify(SAEModel.DFEncFlow(resolution, adapt_k_size, self.options['lighter_encoder'], ae_dims=ae_dims, ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))
|
||||
elif self.options['archi'] == 'df':
|
||||
self.encoder = modelify(SAEModel.DFEncFlow(resolution, self.options['lighter_encoder'], ae_dims=ae_dims, ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))
|
||||
|
||||
dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
||||
|
||||
|
@ -173,7 +179,39 @@ class SAEModel(ModelBase):
|
|||
pred_src_srcm = self.decoder_srcm(warped_src_code)
|
||||
pred_dst_dstm = self.decoder_dstm(warped_dst_code)
|
||||
pred_src_dstm = self.decoder_srcm(warped_dst_code)
|
||||
|
||||
elif self.options['archi'] == 'vg':
|
||||
self.encoder = modelify(SAEModel.VGEncFlow(resolution, self.options['lighter_encoder'], ae_dims=ae_dims, ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))
|
||||
|
||||
dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
|
||||
|
||||
self.decoder_src = modelify(SAEModel.VGDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2 )) (dec_Inputs)
|
||||
self.decoder_dst = modelify(SAEModel.VGDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2 )) (dec_Inputs)
|
||||
|
||||
if self.options['learn_mask']:
|
||||
self.decoder_srcm = modelify(SAEModel.VGDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (dec_Inputs)
|
||||
self.decoder_dstm = modelify(SAEModel.VGDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (dec_Inputs)
|
||||
|
||||
if not self.is_first_run():
|
||||
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
|
||||
self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5))
|
||||
self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5))
|
||||
if self.options['learn_mask']:
|
||||
self.decoder_srcm.load_weights (self.get_strpath_storage_for_file(self.decoder_srcmH5))
|
||||
self.decoder_dstm.load_weights (self.get_strpath_storage_for_file(self.decoder_dstmH5))
|
||||
|
||||
warped_src_code = self.encoder (warped_src)
|
||||
warped_dst_code = self.encoder (warped_dst)
|
||||
pred_src_src = self.decoder_src(warped_src_code)
|
||||
pred_dst_dst = self.decoder_dst(warped_dst_code)
|
||||
pred_src_dst = self.decoder_src(warped_dst_code)
|
||||
|
||||
|
||||
if self.options['learn_mask']:
|
||||
pred_src_srcm = self.decoder_srcm(warped_src_code)
|
||||
pred_dst_dstm = self.decoder_dstm(warped_dst_code)
|
||||
pred_src_dstm = self.decoder_srcm(warped_dst_code)
|
||||
|
||||
pred_src_src, pred_dst_dst, pred_src_dst, = [ [x] if type(x) != list else x for x in [pred_src_src, pred_dst_dst, pred_src_dst, ] ]
|
||||
|
||||
if self.options['learn_mask']:
|
||||
|
@ -193,11 +231,20 @@ class SAEModel(ModelBase):
|
|||
pred_src_src_sigm_ar = [ x + 1 for x in pred_src_src]
|
||||
pred_dst_dst_sigm_ar = [ x + 1 for x in pred_dst_dst]
|
||||
pred_src_dst_sigm_ar = [ x + 1 for x in pred_src_dst]
|
||||
|
||||
|
||||
target_src_masked_ar = [ target_src_sigm_ar[i]*target_srcm_sigm_ar[i] for i in range(len(target_src_sigm_ar))]
|
||||
target_dst_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(target_dst_sigm_ar))]
|
||||
target_dst_anti_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(target_dst_sigm_ar))]
|
||||
|
||||
|
||||
pred_src_src_masked_ar = [ pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] for i in range(len(pred_src_src_sigm_ar))]
|
||||
pred_dst_dst_masked_ar = [ pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] for i in range(len(pred_dst_dst_sigm_ar))]
|
||||
|
||||
target_src_masked_ar_opt = target_src_masked_ar if masked_training else target_src_sigm_ar
|
||||
target_dst_masked_ar_opt = target_dst_masked_ar if masked_training else target_dst_sigm_ar
|
||||
|
||||
pred_src_src_masked_ar_opt = pred_src_src_masked_ar if masked_training else pred_src_src_sigm_ar
|
||||
pred_dst_dst_masked_ar_opt = pred_dst_dst_masked_ar if masked_training else pred_dst_dst_sigm_ar
|
||||
|
||||
psd_target_dst_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
|
||||
psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
|
||||
|
||||
|
@ -215,9 +262,9 @@ class SAEModel(ModelBase):
|
|||
src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights
|
||||
|
||||
if not self.options['pixel_loss']:
|
||||
src_loss_batch = sum([ ( 100*K.square( dssim(max_value=2.0)( target_src_masked_ar[i], pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ) )) for i in range(len(target_src_masked_ar)) ])
|
||||
src_loss_batch = sum([ ( 100*K.square( dssim(max_value=2.0)( target_src_masked_ar_opt[i], pred_src_src_masked_ar_opt[i] ) )) for i in range(len(target_src_masked_ar_opt)) ])
|
||||
else:
|
||||
src_loss_batch = sum([ K.mean ( 100*K.square( target_src_masked_ar[i] - pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ), axis=[1,2,3]) for i in range(len(target_src_masked_ar)) ])
|
||||
src_loss_batch = sum([ K.mean ( 100*K.square( target_src_masked_ar_opt[i] - pred_src_src_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_src_masked_ar_opt)) ])
|
||||
|
||||
src_loss = K.mean(src_loss_batch)
|
||||
|
||||
|
@ -235,9 +282,9 @@ class SAEModel(ModelBase):
|
|||
src_loss += bg_loss
|
||||
|
||||
if not self.options['pixel_loss']:
|
||||
dst_loss_batch = sum([ ( 100*K.square(dssim(max_value=2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ])
|
||||
dst_loss_batch = sum([ ( 100*K.square(dssim(max_value=2.0)( target_dst_masked_ar_opt[i], pred_dst_dst_masked_ar_opt[i] ) )) for i in range(len(target_dst_masked_ar_opt)) ])
|
||||
else:
|
||||
dst_loss_batch = sum([ K.mean ( 100*K.square( target_dst_masked_ar[i] - pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ), axis=[1,2,3]) for i in range(len(target_dst_masked_ar)) ])
|
||||
dst_loss_batch = sum([ K.mean ( 100*K.square( target_dst_masked_ar_opt[i] - pred_dst_dst_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_dst_masked_ar_opt)) ])
|
||||
|
||||
dst_loss = K.mean(dst_loss_batch)
|
||||
|
||||
|
@ -390,28 +437,68 @@ class SAEModel(ModelBase):
|
|||
default_blur_mask_modifier=default_blur_mask_modifier,
|
||||
clip_hborder_mask_per=0.0625 if self.options['face_type'] == 'f' else 0,
|
||||
**in_options)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def LIAEEncFlow(resolution, adapt_k_size, light_enc, ed_ch_dims=42):
|
||||
def initialize_nn_functions():
|
||||
exec (nnlib.import_all(), locals(), globals())
|
||||
|
||||
k_size = resolution // 16 + 1 if adapt_k_size else 5
|
||||
strides = resolution // 32 if adapt_k_size else 2
|
||||
|
||||
class ResidualBlock(object):
|
||||
def __init__(self, filters, kernel_size=3, padding='same', use_reflection_padding=False):
|
||||
self.filters = filters
|
||||
self.kernel_size = kernel_size
|
||||
self.padding = padding #if not use_reflection_padding else 'valid'
|
||||
self.use_reflection_padding = use_reflection_padding
|
||||
|
||||
def __call__(self, inp):
|
||||
var_x = LeakyReLU(alpha=0.2)(inp)
|
||||
|
||||
#if self.use_reflection_padding:
|
||||
# #var_x = ReflectionPadding2D(stride=1, kernel_size=kernel_size)(var_x)
|
||||
|
||||
var_x = Conv2D(self.filters, kernel_size=self.kernel_size, padding=self.padding, kernel_initializer=RandomNormal(0, 0.02) )(var_x)
|
||||
var_x = LeakyReLU(alpha=0.2)(var_x)
|
||||
|
||||
#if self.use_reflection_padding:
|
||||
# #var_x = ReflectionPadding2D(stride=1, kernel_size=kernel_size)(var_x)
|
||||
|
||||
var_x = Conv2D(self.filters, kernel_size=self.kernel_size, padding=self.padding, kernel_initializer=RandomNormal(0, 0.02) )(var_x)
|
||||
var_x = Scale(gamma_init=keras.initializers.Constant(value=0.1))(var_x)
|
||||
var_x = Add()([var_x, inp])
|
||||
var_x = LeakyReLU(alpha=0.2)(var_x)
|
||||
return var_x
|
||||
SAEModel.ResidualBlock = ResidualBlock
|
||||
|
||||
def downscale (dim):
|
||||
def func(x):
|
||||
return LeakyReLU(0.1)(Conv2D(dim, k_size, strides=strides, padding='same')(x))
|
||||
return LeakyReLU(0.1)(Conv2D(dim, kernel_size=5, strides=2, padding='same', kernel_initializer=RandomNormal(0, 0.02))(x))
|
||||
return func
|
||||
|
||||
SAEModel.downscale = downscale
|
||||
|
||||
def downscale_sep (dim):
|
||||
def func(x):
|
||||
return LeakyReLU(0.1)(SeparableConv2D(dim, k_size, strides=strides, padding='same')(x))
|
||||
return LeakyReLU(0.1)(SeparableConv2D(dim, kernel_size=5, strides=2, padding='same', depthwise_initializer=RandomNormal(0, 0.02), pointwise_initializer=RandomNormal(0, 0.02) )(x))
|
||||
return func
|
||||
|
||||
SAEModel.downscale_sep = downscale_sep
|
||||
|
||||
def upscale (dim):
|
||||
def func(x):
|
||||
return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
|
||||
return func
|
||||
return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, kernel_size=3, strides=1, padding='same', kernel_initializer=RandomNormal(0, 0.02) )(x)))
|
||||
return func
|
||||
SAEModel.upscale = upscale
|
||||
|
||||
def to_bgr (output_nc):
|
||||
def func(x):
|
||||
return Conv2D(output_nc, kernel_size=5, padding='same', activation='tanh', kernel_initializer=RandomNormal(0, 0.02))(x)
|
||||
return func
|
||||
SAEModel.to_bgr = to_bgr
|
||||
|
||||
|
||||
@staticmethod
|
||||
def LIAEEncFlow(resolution, light_enc, ed_ch_dims=42):
|
||||
exec (nnlib.import_all(), locals(), globals())
|
||||
upscale = SAEModel.upscale
|
||||
downscale = SAEModel.downscale
|
||||
downscale_sep = SAEModel.downscale_sep
|
||||
|
||||
def func(input):
|
||||
ed_dims = K.int_shape(input)[-1]*ed_ch_dims
|
||||
|
@ -434,13 +521,9 @@ class SAEModel(ModelBase):
|
|||
@staticmethod
|
||||
def LIAEInterFlow(resolution, ae_dims=256):
|
||||
exec (nnlib.import_all(), locals(), globals())
|
||||
upscale = SAEModel.upscale
|
||||
lowest_dense_res=resolution // 16
|
||||
|
||||
def upscale (dim):
|
||||
def func(x):
|
||||
return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
|
||||
return func
|
||||
|
||||
def func(input):
|
||||
x = input[0]
|
||||
x = Dense(ae_dims)(x)
|
||||
|
@ -453,17 +536,10 @@ class SAEModel(ModelBase):
|
|||
@staticmethod
|
||||
def LIAEDecFlow(output_nc,ed_ch_dims=21, multiscale_count=1):
|
||||
exec (nnlib.import_all(), locals(), globals())
|
||||
upscale = SAEModel.upscale
|
||||
to_bgr = SAEModel.to_bgr
|
||||
ed_dims = output_nc * ed_ch_dims
|
||||
|
||||
def upscale (dim):
|
||||
def func(x):
|
||||
return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
|
||||
return func
|
||||
|
||||
def to_bgr ():
|
||||
def func(x):
|
||||
return Conv2D(output_nc, kernel_size=5, padding='same', activation='tanh')(x)
|
||||
return func
|
||||
def func(input):
|
||||
x = input[0]
|
||||
|
||||
|
@ -471,45 +547,28 @@ class SAEModel(ModelBase):
|
|||
x1 = upscale(ed_dims*8)( x )
|
||||
|
||||
if multiscale_count >= 3:
|
||||
outputs += [ to_bgr() ( x1 ) ]
|
||||
outputs += [ to_bgr(output_nc) ( x1 ) ]
|
||||
|
||||
x2 = upscale(ed_dims*4)( x1 )
|
||||
|
||||
if multiscale_count >= 2:
|
||||
outputs += [ to_bgr() ( x2 ) ]
|
||||
outputs += [ to_bgr(output_nc) ( x2 ) ]
|
||||
|
||||
x3 = upscale(ed_dims*2)( x2 )
|
||||
|
||||
outputs += [ to_bgr() ( x3 ) ]
|
||||
outputs += [ to_bgr(output_nc) ( x3 ) ]
|
||||
|
||||
return outputs
|
||||
return func
|
||||
|
||||
@staticmethod
|
||||
def DFEncFlow(resolution, adapt_k_size, light_enc, ae_dims=512, ed_ch_dims=42):
|
||||
def DFEncFlow(resolution, light_enc, ae_dims=512, ed_ch_dims=42):
|
||||
exec (nnlib.import_all(), locals(), globals())
|
||||
k_size = resolution // 16 + 1 if adapt_k_size else 5
|
||||
strides = resolution // 32 if adapt_k_size else 2
|
||||
upscale = SAEModel.upscale
|
||||
downscale = SAEModel.downscale
|
||||
downscale_sep = SAEModel.downscale_sep
|
||||
lowest_dense_res = resolution // 16
|
||||
|
||||
def Conv2D (filters, kernel_size, strides=(1, 1), padding='valid', data_format=None, dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer=RandomNormal(0, 0.02), bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None):
|
||||
return keras.layers.Conv2D( filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, data_format=data_format, dilation_rate=dilation_rate, activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, kernel_constraint=kernel_constraint, bias_constraint=bias_constraint )
|
||||
|
||||
def downscale (dim):
|
||||
def func(x):
|
||||
return LeakyReLU(0.1)(Conv2D(dim, k_size, strides=strides, padding='same')(x))
|
||||
return func
|
||||
|
||||
def downscale_sep (dim):
|
||||
def func(x):
|
||||
return LeakyReLU(0.1)(SeparableConv2D(dim, k_size, strides=strides, padding='same')(x))
|
||||
return func
|
||||
|
||||
def upscale (dim):
|
||||
def func(x):
|
||||
return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
|
||||
return func
|
||||
|
||||
def func(input):
|
||||
x = input
|
||||
|
||||
|
@ -536,20 +595,10 @@ class SAEModel(ModelBase):
|
|||
@staticmethod
|
||||
def DFDecFlow(output_nc, ed_ch_dims=21, multiscale_count=1):
|
||||
exec (nnlib.import_all(), locals(), globals())
|
||||
upscale = SAEModel.upscale
|
||||
to_bgr = SAEModel.to_bgr
|
||||
ed_dims = output_nc * ed_ch_dims
|
||||
|
||||
def Conv2D (filters, kernel_size, strides=(1, 1), padding='valid', data_format=None, dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer=RandomNormal(0, 0.02), bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None):
|
||||
return keras.layers.Conv2D( filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, data_format=data_format, dilation_rate=dilation_rate, activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, kernel_constraint=kernel_constraint, bias_constraint=bias_constraint )
|
||||
|
||||
def upscale (dim):
|
||||
def func(x):
|
||||
return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
|
||||
return func
|
||||
|
||||
def to_bgr ():
|
||||
def func(x):
|
||||
return Conv2D(output_nc, kernel_size=5, padding='same', activation='tanh')(x)
|
||||
return func
|
||||
def func(input):
|
||||
x = input[0]
|
||||
|
||||
|
@ -557,18 +606,95 @@ class SAEModel(ModelBase):
|
|||
x1 = upscale(ed_dims*8)( x )
|
||||
|
||||
if multiscale_count >= 3:
|
||||
outputs += [ to_bgr() ( x1 ) ]
|
||||
outputs += [ to_bgr(output_nc) ( x1 ) ]
|
||||
|
||||
x2 = upscale(ed_dims*4)( x1 )
|
||||
|
||||
if multiscale_count >= 2:
|
||||
outputs += [ to_bgr() ( x2 ) ]
|
||||
outputs += [ to_bgr(output_nc) ( x2 ) ]
|
||||
|
||||
x3 = upscale(ed_dims*2)( x2 )
|
||||
|
||||
outputs += [ to_bgr() ( x3 ) ]
|
||||
outputs += [ to_bgr(output_nc) ( x3 ) ]
|
||||
|
||||
return outputs
|
||||
return outputs
|
||||
return func
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def VGEncFlow(resolution, light_enc, ae_dims=512, ed_ch_dims=42):
|
||||
exec (nnlib.import_all(), locals(), globals())
|
||||
upscale = SAEModel.upscale
|
||||
downscale = SAEModel.downscale
|
||||
downscale_sep = SAEModel.downscale_sep
|
||||
ResidualBlock = SAEModel.ResidualBlock
|
||||
lowest_dense_res = resolution // 16
|
||||
|
||||
def func(input):
|
||||
x = input
|
||||
ed_dims = K.int_shape(input)[-1]*ed_ch_dims
|
||||
while np.modf(ed_dims / 4)[0] != 0.0:
|
||||
ed_dims -= 1
|
||||
|
||||
in_conv_filters = ed_dims if resolution <= 128 else ed_dims + (resolution//128)*ed_ch_dims
|
||||
|
||||
x = tmp_x = Conv2D (in_conv_filters, kernel_size=5, strides=2, padding='same') (x)
|
||||
|
||||
for _ in range ( 8 if light_enc else 16 ):
|
||||
x = ResidualBlock(ed_dims)(x)
|
||||
|
||||
x = Add()([x, tmp_x])
|
||||
|
||||
x = downscale(ed_dims)(x)
|
||||
x = SubpixelUpscaler()(x)
|
||||
|
||||
x = downscale(ed_dims)(x)
|
||||
x = SubpixelUpscaler()(x)
|
||||
|
||||
x = downscale(ed_dims)(x)
|
||||
if light_enc:
|
||||
x = downscale_sep (ed_dims*2)(x)
|
||||
else:
|
||||
x = downscale (ed_dims*2)(x)
|
||||
|
||||
x = downscale(ed_dims*4)(x)
|
||||
|
||||
if light_enc:
|
||||
x = downscale_sep (ed_dims*8)(x)
|
||||
else:
|
||||
x = downscale (ed_dims*8)(x)
|
||||
|
||||
x = Dense(ae_dims)(Flatten()(x))
|
||||
x = Dense(lowest_dense_res * lowest_dense_res * ae_dims)(x)
|
||||
x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims))(x)
|
||||
x = upscale(ae_dims)(x)
|
||||
return x
|
||||
|
||||
return func
|
||||
|
||||
@staticmethod
|
||||
def VGDecFlow(output_nc, ed_ch_dims=21, multiscale_count=1):
|
||||
exec (nnlib.import_all(), locals(), globals())
|
||||
upscale = SAEModel.upscale
|
||||
to_bgr = SAEModel.to_bgr
|
||||
ResidualBlock = SAEModel.ResidualBlock
|
||||
ed_dims = output_nc * ed_ch_dims
|
||||
|
||||
def func(input):
|
||||
x = input[0]
|
||||
|
||||
x = upscale( ed_dims*8 )(x)
|
||||
x = ResidualBlock( ed_dims*8 )(x)
|
||||
|
||||
x = upscale( ed_dims*4 )(x)
|
||||
x = ResidualBlock( ed_dims*4 )(x)
|
||||
|
||||
x = upscale( ed_dims*2 )(x)
|
||||
x = ResidualBlock( ed_dims*2 )(x)
|
||||
|
||||
x = to_bgr(output_nc) (x)
|
||||
return x
|
||||
|
||||
return func
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue