Mirror of https://github.com/iperov/DeepFaceLab.git, synced 2025-07-06 13:02:15 -07:00
SAEHD: speed up for nvidia, duplicate code clean up
parent 627df082d7
commit 3f23135982
2 changed files with 187 additions and 171 deletions
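The core of the cleanup is visible in the diff below: the downscale / upscale / ResidualBlock helpers, previously re-defined inside both architecture classes, are hoisted into a shared CommonModel base class that SAEDFModel and SAELIAEModel inherit from. A minimal, dependency-free sketch of that pattern (illustrative only, with placeholder builders instead of the real Keras layers):

# Illustrative sketch of the refactor, not DeepFaceLab code: shared layer
# builders live on a common base class instead of being re-defined inside
# each architecture class.
class CommonModel(object):
    def downscale(self, dim, kernel_size=5, dilation_rate=1):
        # The real method returns a Conv2D + SubpixelDownscaler block; here it
        # just records the call so the sketch stays dependency-free.
        def func(x):
            return x + [('down', dim, kernel_size, dilation_rate)]
        return func

    def upscale(self, dim):
        def func(x):
            return x + [('up', dim)]
        return func

class SAEDFModel(CommonModel):       # DF architecture reuses the helpers
    def encode(self, inp):
        x = self.downscale(64, 3, 1)(inp)
        return self.downscale(128, 5, 2)(x)

class SAELIAEModel(CommonModel):     # LIAE architecture reuses them too
    def encode(self, inp):
        return self.downscale(64, 5, 1)(inp)

print(SAEDFModel().encode([]))   # [('down', 64, 3, 1), ('down', 128, 5, 2)]
print(SAELIAEModel().encode([])) # [('down', 64, 5, 1)]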
@@ -112,7 +112,31 @@ class SAEv2Model(ModelBase):
self.true_face_training = self.options.get('true_face_training', False)
masked_training = True

class SAEDFModel(object):
class CommonModel(object):
    def downscale (self, dim, kernel_size=5, dilation_rate=1, use_activator=True):
        def func(x):
            if not use_activator:
                return SubpixelDownscaler()(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x))
            else:
                return SubpixelDownscaler()(LeakyReLU(0.1)(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x)))
        return func

    def upscale (self, dim, size=(2,2)):
        def func(x):
            return SubpixelUpscaler(size=size)(LeakyReLU(0.1)(Conv2D(dim * np.prod(size) , kernel_size=3, strides=1, padding='same')(x)))
        return func

    def ResidualBlock(self, dim):
        def func(inp):
            x = Conv2D(dim, kernel_size=3, padding='same')(inp)
            x = LeakyReLU(0.2)(x)
            x = Conv2D(dim, kernel_size=3, padding='same')(x)
            x = Add()([x, inp])
            x = LeakyReLU(0.2)(x)
            return x
        return func

class SAEDFModel(CommonModel):
    def __init__(self, resolution, ae_dims, e_ch_dims, d_ch_dims, learn_mask):
        super().__init__()
        self.learn_mask = learn_mask
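The SubpixelUpscaler and SubpixelDownscaler layers used above are defined elsewhere in the repo and do not appear in this diff. A rough sketch of what such layers typically do, assuming the usual pixel-shuffle pattern (this is not the library's implementation):

# Sketch only: assumes SubpixelUpscaler / SubpixelDownscaler follow the common
# pixel-shuffle pattern (depth_to_space / space_to_depth). Not DeepFaceLab code.
import tensorflow as tf

def subpixel_upscale(x, factor=2):
    # (B, H, W, C*factor^2) -> (B, H*factor, W*factor, C)
    return tf.nn.depth_to_space(x, factor)

def subpixel_downscale(x, factor=2):
    # (B, H, W, C) -> (B, H//factor, W//factor, C*factor^2)
    return tf.nn.space_to_depth(x, factor)

x = tf.zeros((1, 64, 64, 12))
print(subpixel_upscale(x).shape)    # (1, 128, 128, 3)
print(subpixel_downscale(x).shape)  # (1, 32, 32, 48)

Under that assumption, upscale() convolving to dim * np.prod(size) channels and downscale() convolving to dim // 4 channels both end up with dim output channels after the shuffle.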
@@ -123,51 +147,40 @@ class SAEv2Model(ModelBase):
lowest_dense_res = resolution // 16
e_dims = output_nc*e_ch_dims

def downscale (dim, kernel_size=5, dilation_rate=1, use_activator=True):
    def func(x):
        if not use_activator:
            return SubpixelDownscaler()(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x))
        else:
            return SubpixelDownscaler()(LeakyReLU(0.1)(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x)))
    return func

def upscale (dim, size=(2,2)):
    def func(x):
        return SubpixelUpscaler(size=size)(LeakyReLU(0.1)(Conv2D(dim * np.prod(size) , kernel_size=3, strides=1, padding='same')(x)))
    return func

def enc_flow(e_ch_dims, ae_dims, lowest_dense_res):
    dims = output_nc * e_ch_dims
    if dims % 2 != 0:
        dims += 1

    def func(inp):
        x = downscale(dims , 3, 1 )(inp)
        x = downscale(dims*2, 3, 1 )(x)
        x = downscale(dims*4, 3, 1 )(x)
        x0 = downscale(dims*8, 3, 1 )(x)

        x = downscale(dims , 5, 1 )(inp)
        x = downscale(dims*2, 5, 1 )(x)
        x = downscale(dims*4, 5, 1 )(x)
        x1 = downscale(dims*8, 5, 1 )(x)

        x = downscale(dims , 5, 2 )(inp)
        x = downscale(dims*2, 5, 2 )(x)
        x = downscale(dims*4, 5, 2 )(x)
        x2 = downscale(dims*8, 5, 2 )(x)

        x = downscale(dims , 7, 2 )(inp)
        x = downscale(dims*2, 7, 2 )(x)
        x = downscale(dims*4, 7, 2 )(x)
        x3 = downscale(dims*8, 7, 2 )(x)

    def func(inp):
        x = self.downscale(dims , 3, 1 )(inp)
        x = self.downscale(dims*2, 3, 1 )(x)
        x = self.downscale(dims*4, 3, 1 )(x)
        x0 = self.downscale(dims*8, 3, 1 )(x)

        x = self.downscale(dims , 5, 1 )(inp)
        x = self.downscale(dims*2, 5, 1 )(x)
        x = self.downscale(dims*4, 5, 1 )(x)
        x1 = self.downscale(dims*8, 5, 1 )(x)

        x = self.downscale(dims , 5, 2 )(inp)
        x = self.downscale(dims*2, 5, 2 )(x)
        x = self.downscale(dims*4, 5, 2 )(x)
        x2 = self.downscale(dims*8, 5, 2 )(x)

        x = self.downscale(dims , 7, 2 )(inp)
        x = self.downscale(dims*2, 7, 2 )(x)
        x = self.downscale(dims*4, 7, 2 )(x)
        x3 = self.downscale(dims*8, 7, 2 )(x)

        x = Concatenate()([x0,x1,x2,x3])

        x = Dense(ae_dims)(Flatten()(x))
        x = Dense(lowest_dense_res * lowest_dense_res * ae_dims)(x)
        x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims))(x)
        x = upscale(ae_dims)(x)
        x = self.upscale(ae_dims)(x)
        return x
    return func
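The encoder runs four parallel downscale towers that differ only in (kernel_size, dilation_rate) -- (3,1), (5,1), (5,2), (7,2) -- and concatenates them, so the latent code mixes several receptive-field sizes. A plain-Python bookkeeping sketch of the repeated lines above (illustrative only, not model code):

# Each tower applies four subpixel downscales, so every tower ends at
# resolution // 16 -- exactly the lowest_dense_res used for the Dense bottleneck.
def encoder_towers(resolution, dims):
    towers = []
    for kernel_size, dilation_rate in [(3, 1), (5, 1), (5, 2), (7, 2)]:
        size = resolution
        for _ in [1, 2, 4, 8]:      # dims, dims*2, dims*4, dims*8 downscales
            size //= 2
        towers.append((kernel_size, dilation_rate, size, dims * 8))
    return towers

print(encoder_towers(resolution=128, dims=64))
# [(3, 1, 8, 512), (5, 1, 8, 512), (5, 2, 8, 512), (7, 2, 8, 512)] -- 8 == 128 // 16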
@@ -175,32 +188,24 @@ class SAEv2Model(ModelBase):
dims = output_nc * d_ch_dims
if dims % 2 != 0:
    dims += 1

def ResidualBlock(dim):
    def func(inp):
        x = Conv2D(dim, kernel_size=3, padding='same')(inp)
        x = LeakyReLU(0.2)(x)
        x = Conv2D(dim, kernel_size=3, padding='same')(x)
        x = Add()([x, inp])
        x = LeakyReLU(0.2)(x)
        return x
    return func

def func(x):
    for i in [8,4,2]:
        x = upscale(dims*i)(x)
        x = self.upscale(dims*i)(x)

        if not is_mask:
            x0 = x
            x = upscale( (dims*i)//2 )(x)
            x = ResidualBlock( (dims*i)//2 )(x)
            x = downscale( dims*i, use_activator=False ) (x)
            x = self.upscale( (dims*i)//2 )(x)
            x = self.ResidualBlock( (dims*i)//2 )(x)
            x = self.downscale( dims*i, use_activator=False ) (x)
            x = Add()([x, x0])
            x = LeakyReLU(0.2)(x)

    return Conv2D(output_nc, kernel_size=5, padding='same', activation='sigmoid')(x)

return func

self.encoder = modelify(enc_flow(e_ch_dims, ae_dims, lowest_dense_res)) ( Input(bgr_shape) )
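In the decoder's non-mask path, each upscale is followed by a refinement detour: branch off the trunk, upscale again at half the channel count, run a residual block, subpixel-downscale back with use_activator=False, then add the result onto the trunk. A dependency-free skeleton of that wrap-around pattern (sketch only; up(), res(), down() stand in for the Keras blocks in the diff):

def refine_step(x, up, res, down):
    x0 = x           # keep the trunk
    x = up(x)        # refine at 2x resolution, half the channels
    x = res(x)       # residual block at that scale
    x = down(x)      # subpixel-downscale back (no activation)
    return x + x0    # skip connection onto the trunk (the real code follows with LeakyReLU(0.2))

# Toy usage with scalars in place of tensors:
print(refine_step(1.0, up=lambda v: v * 2, res=lambda v: v + 1, down=lambda v: v / 2))  # 2.5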
@@ -243,7 +248,7 @@ class SAEv2Model(ModelBase):
            [self.decoder_dstm, 'decoder_dstm.h5'] ]
    return ar

class SAELIAEModel(object):
class SAELIAEModel(CommonModel):
    def __init__(self, resolution, ae_dims, e_ch_dims, d_ch_dims, learn_mask):
        super().__init__()
        self.learn_mask = learn_mask
@@ -251,50 +256,37 @@ class SAEv2Model(ModelBase):
output_nc = 3
bgr_shape = (resolution, resolution, output_nc)
mask_shape = (resolution, resolution, 1)

lowest_dense_res = resolution // 16

def downscale (dim, kernel_size=5, dilation_rate=1, use_activator=True):
    def func(x):
        if not use_activator:
            return SubpixelDownscaler()(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x))
        else:
            return SubpixelDownscaler()(LeakyReLU(0.1)(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x)))
    return func

def upscale (dim):
    def func(x):
        return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, kernel_size=3, strides=1, padding='same')(x)))
    return func

def enc_flow(e_ch_dims):
    dims = output_nc*e_ch_dims
    if dims % 2 != 0:
        dims += 1

    def func(inp):
        x = downscale(dims , 3, 1 )(inp)
        x = downscale(dims*2, 3, 1 )(x)
        x = downscale(dims*4, 3, 1 )(x)
        x0 = downscale(dims*8, 3, 1 )(x)

        x = downscale(dims , 5, 1 )(inp)
        x = downscale(dims*2, 5, 1 )(x)
        x = downscale(dims*4, 5, 1 )(x)
        x1 = downscale(dims*8, 5, 1 )(x)

        x = downscale(dims , 5, 2 )(inp)
        x = downscale(dims*2, 5, 2 )(x)
        x = downscale(dims*4, 5, 2 )(x)
        x2 = downscale(dims*8, 5, 2 )(x)

        x = downscale(dims , 7, 2 )(inp)
        x = downscale(dims*2, 7, 2 )(x)
        x = downscale(dims*4, 7, 2 )(x)
        x3 = downscale(dims*8, 7, 2 )(x)

    def func(inp):
        x = self.downscale(dims , 3, 1 )(inp)
        x = self.downscale(dims*2, 3, 1 )(x)
        x = self.downscale(dims*4, 3, 1 )(x)
        x0 = self.downscale(dims*8, 3, 1 )(x)

        x = self.downscale(dims , 5, 1 )(inp)
        x = self.downscale(dims*2, 5, 1 )(x)
        x = self.downscale(dims*4, 5, 1 )(x)
        x1 = self.downscale(dims*8, 5, 1 )(x)

        x = self.downscale(dims , 5, 2 )(inp)
        x = self.downscale(dims*2, 5, 2 )(x)
        x = self.downscale(dims*4, 5, 2 )(x)
        x2 = self.downscale(dims*8, 5, 2 )(x)

        x = self.downscale(dims , 7, 2 )(inp)
        x = self.downscale(dims*2, 7, 2 )(x)
        x = self.downscale(dims*4, 7, 2 )(x)
        x3 = self.downscale(dims*8, 7, 2 )(x)

        x = Concatenate()([x0,x1,x2,x3])

        x = Flatten()(x)
        return x
    return func
@@ -304,7 +296,7 @@ class SAEv2Model(ModelBase):
        x = Dense(ae_dims)(x)
        x = Dense(lowest_dense_res * lowest_dense_res * ae_dims*2)(x)
        x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims*2))(x)
        x = upscale(ae_dims*2)(x)
        x = self.upscale(ae_dims*2)(x)
        return x
    return func
@@ -312,30 +304,20 @@ class SAEv2Model(ModelBase):
dims = output_nc * d_ch_dims
if dims % 2 != 0:
    dims += 1

def ResidualBlock(dim):
    def func(inp):
        x = Conv2D(dim, kernel_size=3, padding='same')(inp)
        x = LeakyReLU(0.2)(x)
        x = Conv2D(dim, kernel_size=3, padding='same')(x)
        x = Add()([x, inp])
        x = LeakyReLU(0.2)(x)
        return x
    return func

def func(x):
    for i in [8,4,2]:
        x = upscale(dims*i)(x)
        x = self.upscale(dims*i)(x)

        if not is_mask:
            x0 = x
            x = upscale( (dims*i)//2 )(x)
            x = ResidualBlock( (dims*i)//2 )(x)
            x = downscale( dims*i, use_activator=False ) (x)
            x = self.upscale( (dims*i)//2 )(x)
            x = self.ResidualBlock( (dims*i)//2 )(x)
            x = self.downscale( dims*i, use_activator=False ) (x)
            x = Add()([x, x0])
            x = LeakyReLU(0.2)(x)

    return Conv2D(output_nc, kernel_size=5, padding='same', activation='sigmoid')(x)

return func
@@ -360,7 +342,7 @@ class SAEv2Model(ModelBase):
self.warped_src, self.warped_dst = Input(bgr_shape), Input(bgr_shape)
self.target_src, self.target_dst = Input(bgr_shape), Input(bgr_shape)
self.target_srcm, self.target_dstm = Input(mask_shape), Input(mask_shape)

warped_src_code = self.encoder (self.warped_src)
warped_src_inter_AB_code = self.inter_AB (warped_src_code)
self.src_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])
@@ -399,14 +381,14 @@ class SAEv2Model(ModelBase):
    self.model = SAEDFModel (resolution, ae_dims, ed_ch_dims, ed_ch_dims, learn_mask)
elif 'liae' in self.options['archi']:
    self.model = SAELIAEModel (resolution, ae_dims, ed_ch_dims, ed_ch_dims, learn_mask)

self.opt_dis_model = []

if self.true_face_training:
    def dis_flow(ndf=256):
        def func(x):
            x, = x

            code_res = K.int_shape(x)[1]

            x = Conv2D( ndf, 4, strides=2, padding='valid')( ZeroPadding2D(1)(x) )
@@ -414,25 +396,25 @@ class SAEv2Model(ModelBase):
            x = Conv2D( ndf*2, 3, strides=2, padding='valid')( ZeroPadding2D(1)(x) )
            x = LeakyReLU(0.1)(x)

            if code_res > 8:
                x = Conv2D( ndf*4, 3, strides=2, padding='valid')( ZeroPadding2D(1)(x) )
                x = LeakyReLU(0.1)(x)

            if code_res > 16:
                x = Conv2D( ndf*8, 3, strides=2, padding='valid')( ZeroPadding2D(1)(x) )
                x = LeakyReLU(0.1)(x)

            if code_res > 32:
                x = Conv2D( ndf*8, 3, strides=2, padding='valid')( ZeroPadding2D(1)(x) )
                x = LeakyReLU(0.1)(x)

            return Conv2D( 1, 1, strides=1, padding='valid', activation='sigmoid')(x)
        return func

    sh = [ Input( K.int_shape(self.model.src_code)[1:] ) ]
    self.dis = modelify(dis_flow()) (sh)

    self.opt_dis_model = [ (self.dis, 'dis.h5') ]

loaded, not_loaded = [], self.model.get_model_filename_list()+self.opt_dis_model
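The dis_flow discriminator works on the spatial latent code rather than on images: it adds extra stride-2 convolutions while code_res is large, so the sigmoid output stays a small patch map regardless of the code resolution. A rough bookkeeping sketch (plain Python, illustrative; it assumes each ZeroPadding2D(1) + stride-2 conv roughly halves the spatial size):

def dis_conv_count(code_res):
    count = 2                      # the two unconditional stride-2 convs
    for threshold in (8, 16, 32):  # the code_res-gated convs in the diff
        if code_res > threshold:
            count += 1
    return count

for code_res in (8, 16, 32, 64):
    n = dis_conv_count(code_res)
    print(code_res, '->', n, 'stride-2 convs, ~', max(code_res >> n, 1), 'px patch map')
# every case lands on roughly a 2x2 patch map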
@@ -467,7 +449,7 @@ class SAEv2Model(ModelBase):
self.src_dst_opt = RMSprop(lr=5e-5, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
self.src_dst_mask_opt = RMSprop(lr=5e-5, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
self.D_opt = RMSprop(lr=5e-5, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)

src_loss = K.mean ( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_opt, pred_src_src_masked_opt) )
src_loss += K.mean ( 10*K.square( target_src_masked_opt - pred_src_src_masked_opt ) )
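The reconstruction loss combines a DSSIM term and a pixel L2 term, both on the masked targets and predictions, both scaled by 10, with the DSSIM window size tied to the resolution (resolution / 11.6). A hedged sketch of the same combination with stock TensorFlow ops (tf.image.ssim as a stand-in for the repo's dssim helper, which is not identical):

import tensorflow as tf

def reconstruction_loss(target, pred, resolution):
    win = int(resolution / 11.6)                              # e.g. 128 -> 11
    dssim = (1.0 - tf.image.ssim(target, pred, max_val=1.0, filter_size=win)) / 2.0
    loss  = tf.reduce_mean(10.0 * dssim)                      # structural term
    loss += tf.reduce_mean(10.0 * tf.square(target - pred))   # pixel L2 term
    return loss

t = tf.random.uniform((4, 128, 128, 3))
p = tf.random.uniform((4, 128, 128, 3))
print(float(reconstruction_loss(t, p, 128)))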
@@ -484,26 +466,26 @@ class SAEv2Model(ModelBase):
dst_loss += K.mean( 10*K.square( target_dst_masked_opt - pred_dst_dst_masked_opt ) )

G_loss = src_loss+dst_loss

if self.true_face_training:
if self.true_face_training:
    def DLoss(labels,logits):
        return K.mean(K.binary_crossentropy(labels,logits))

    src_code_d = self.dis( self.model.src_code )
    src_code_d_ones = K.ones_like(src_code_d)
    src_code_d_zeros = K.zeros_like(src_code_d)
    dst_code_d = self.dis( self.model.dst_code )
    dst_code_d_ones = K.ones_like(dst_code_d)
    G_loss += 0.01*DLoss(src_code_d_ones, src_code_d)

    loss_D = (DLoss(dst_code_d_ones , dst_code_d) + \
              DLoss(src_code_d_zeros, src_code_d) ) * 0.5

    self.D_train = K.function ([self.model.warped_src, self.model.warped_dst],[loss_D], self.D_opt.get_updates(loss_D, self.dis.trainable_weights) )

self.src_dst_train = K.function ([self.model.warped_src, self.model.warped_dst, self.model.target_src, self.model.target_srcm, self.model.target_dst, self.model.target_dstm],
                                 [src_loss,dst_loss],
                                 self.src_dst_opt.get_updates( G_loss, self.model.src_dst_trainable_weights)
                                 [src_loss,dst_loss],
                                 self.src_dst_opt.get_updates( G_loss, self.model.src_dst_trainable_weights)
                                 )

if self.options['learn_mask']:
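When true_face_training is enabled, a small GAN runs on the latent codes rather than on images: the discriminator is trained to label dst codes as real and src codes as fake, while the generator gets an extra 0.01-weighted term that pushes src codes toward "real", nudging the src latent distribution toward the dst one. A minimal standalone sketch of those two loss terms (illustrative; plain Keras BCE stands in for K.binary_crossentropy):

import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy()

def d_loss(src_code_d, dst_code_d):
    # discriminator: dst codes -> 1 ("real"), src codes -> 0 ("fake")
    return 0.5 * (bce(tf.ones_like(dst_code_d), dst_code_d) +
                  bce(tf.zeros_like(src_code_d), src_code_d))

def g_extra_loss(src_code_d):
    # generator: make src codes look "real" to the discriminator, small weight
    return 0.01 * bce(tf.ones_like(src_code_d), src_code_d)

src_d = tf.random.uniform((4, 2, 2, 1))   # discriminator patch outputs (sigmoid)
dst_d = tf.random.uniform((4, 2, 2, 1))
print(float(d_loss(src_d, dst_d)), float(g_extra_loss(src_d)))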
@@ -525,14 +507,14 @@ class SAEv2Model(ModelBase):
if self.is_training_mode:
    t = SampleProcessor.Types

    if self.options['face_type'] == 'h':
        face_type = t.FACE_TYPE_HALF
    elif self.options['face_type'] == 'mf':
        face_type = t.FACE_TYPE_MID_FULL
    elif self.options['face_type'] == 'f':
        face_type = t.FACE_TYPE_FULL

    t_mode_bgr = t.MODE_BGR if not self.pretrain else t.MODE_BGR_SHUFFLE

    training_data_src_path = self.training_data_src_path
@@ -568,7 +550,7 @@ class SAEv2Model(ModelBase):
#override
def onSave(self):
    self.save_weights_safe( self.get_model_filename_list()+self.opt_dis_model )

#override
def on_success_train_one_iter(self):
    if len(self.CA_conv_weights_list) != 0:
@@ -584,9 +566,9 @@ class SAEv2Model(ModelBase):
feed = [warped_src, warped_dst, target_src, target_srcm, target_dst, target_dstm]

src_loss, dst_loss, = self.src_dst_train (feed)

if self.true_face_training:
    self.D_train([warped_src, warped_dst])
    self.D_train([warped_src, warped_dst])

if self.options['learn_mask']:
    feed = [ warped_src, warped_dst, target_srcm, target_dstm ]
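Per the training-iteration hunk above, each step runs the combined src/dst reconstruction update first, then (if true_face_training) a discriminator update on the same warped batch, then an optional mask update. A minimal sketch of that ordering (illustrative; the *_train callables stand in for the K.function objects built in onInitialize):

def train_one_iter(batch, src_dst_train, D_train=None, mask_train=None):
    warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm = batch

    # 1) reconstruction step for both autoencoder branches
    src_loss, dst_loss = src_dst_train([warped_src, warped_dst,
                                        target_src, target_srcm,
                                        target_dst, target_dstm])

    # 2) optional latent-code discriminator step (true_face_training)
    if D_train is not None:
        D_train([warped_src, warped_dst])

    # 3) optional mask decoder step (learn_mask)
    if mask_train is not None:
        mask_train([warped_src, warped_dst, target_srcm, target_dstm])

    return src_loss, dst_loss

# Toy usage with dummy callables:
print(train_one_iter([None] * 6, src_dst_train=lambda feed: (0.0, 0.0)))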