SAEHD: speed up for nvidia, duplicate code clean up

This commit is contained in:
Colombo 2019-10-08 21:02:20 +04:00
parent 627df082d7
commit 3f23135982
2 changed files with 187 additions and 171 deletions

View file

@ -112,7 +112,31 @@ class SAEv2Model(ModelBase):
self.true_face_training = self.options.get('true_face_training', False)
masked_training = True
class SAEDFModel(object):
class CommonModel(object):
def downscale (self, dim, kernel_size=5, dilation_rate=1, use_activator=True):
def func(x):
if not use_activator:
return SubpixelDownscaler()(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x))
else:
return SubpixelDownscaler()(LeakyReLU(0.1)(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x)))
return func
def upscale (self, dim, size=(2,2)):
def func(x):
return SubpixelUpscaler(size=size)(LeakyReLU(0.1)(Conv2D(dim * np.prod(size) , kernel_size=3, strides=1, padding='same')(x)))
return func
def ResidualBlock(self, dim):
def func(inp):
x = Conv2D(dim, kernel_size=3, padding='same')(inp)
x = LeakyReLU(0.2)(x)
x = Conv2D(dim, kernel_size=3, padding='same')(x)
x = Add()([x, inp])
x = LeakyReLU(0.2)(x)
return x
return func
class SAEDFModel(CommonModel):
def __init__(self, resolution, ae_dims, e_ch_dims, d_ch_dims, learn_mask):
super().__init__()
self.learn_mask = learn_mask
@ -123,51 +147,40 @@ class SAEv2Model(ModelBase):
lowest_dense_res = resolution // 16
e_dims = output_nc*e_ch_dims
def downscale (dim, kernel_size=5, dilation_rate=1, use_activator=True):
def func(x):
if not use_activator:
return SubpixelDownscaler()(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x))
else:
return SubpixelDownscaler()(LeakyReLU(0.1)(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x)))
return func
def upscale (dim, size=(2,2)):
def func(x):
return SubpixelUpscaler(size=size)(LeakyReLU(0.1)(Conv2D(dim * np.prod(size) , kernel_size=3, strides=1, padding='same')(x)))
return func
def enc_flow(e_ch_dims, ae_dims, lowest_dense_res):
dims = output_nc * e_ch_dims
if dims % 2 != 0:
dims += 1
def func(inp):
x = downscale(dims , 3, 1 )(inp)
x = downscale(dims*2, 3, 1 )(x)
x = downscale(dims*4, 3, 1 )(x)
x0 = downscale(dims*8, 3, 1 )(x)
x = downscale(dims , 5, 1 )(inp)
x = downscale(dims*2, 5, 1 )(x)
x = downscale(dims*4, 5, 1 )(x)
x1 = downscale(dims*8, 5, 1 )(x)
x = downscale(dims , 5, 2 )(inp)
x = downscale(dims*2, 5, 2 )(x)
x = downscale(dims*4, 5, 2 )(x)
x2 = downscale(dims*8, 5, 2 )(x)
x = downscale(dims , 7, 2 )(inp)
x = downscale(dims*2, 7, 2 )(x)
x = downscale(dims*4, 7, 2 )(x)
x3 = downscale(dims*8, 7, 2 )(x)
def func(inp):
x = self.downscale(dims , 3, 1 )(inp)
x = self.downscale(dims*2, 3, 1 )(x)
x = self.downscale(dims*4, 3, 1 )(x)
x0 = self.downscale(dims*8, 3, 1 )(x)
x = self.downscale(dims , 5, 1 )(inp)
x = self.downscale(dims*2, 5, 1 )(x)
x = self.downscale(dims*4, 5, 1 )(x)
x1 = self.downscale(dims*8, 5, 1 )(x)
x = self.downscale(dims , 5, 2 )(inp)
x = self.downscale(dims*2, 5, 2 )(x)
x = self.downscale(dims*4, 5, 2 )(x)
x2 = self.downscale(dims*8, 5, 2 )(x)
x = self.downscale(dims , 7, 2 )(inp)
x = self.downscale(dims*2, 7, 2 )(x)
x = self.downscale(dims*4, 7, 2 )(x)
x3 = self.downscale(dims*8, 7, 2 )(x)
x = Concatenate()([x0,x1,x2,x3])
x = Dense(ae_dims)(Flatten()(x))
x = Dense(lowest_dense_res * lowest_dense_res * ae_dims)(x)
x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims))(x)
x = upscale(ae_dims)(x)
x = self.upscale(ae_dims)(x)
return x
return func
@ -175,32 +188,24 @@ class SAEv2Model(ModelBase):
dims = output_nc * d_ch_dims
if dims % 2 != 0:
dims += 1
def ResidualBlock(dim):
def func(inp):
x = Conv2D(dim, kernel_size=3, padding='same')(inp)
x = LeakyReLU(0.2)(x)
x = Conv2D(dim, kernel_size=3, padding='same')(x)
x = Add()([x, inp])
x = LeakyReLU(0.2)(x)
return x
return func
def func(x):
for i in [8,4,2]:
x = upscale(dims*i)(x)
x = self.upscale(dims*i)(x)
if not is_mask:
x0 = x
x = upscale( (dims*i)//2 )(x)
x = ResidualBlock( (dims*i)//2 )(x)
x = downscale( dims*i, use_activator=False ) (x)
x = self.upscale( (dims*i)//2 )(x)
x = self.ResidualBlock( (dims*i)//2 )(x)
x = self.downscale( dims*i, use_activator=False ) (x)
x = Add()([x, x0])
x = LeakyReLU(0.2)(x)
return Conv2D(output_nc, kernel_size=5, padding='same', activation='sigmoid')(x)
return func
self.encoder = modelify(enc_flow(e_ch_dims, ae_dims, lowest_dense_res)) ( Input(bgr_shape) )
@ -243,7 +248,7 @@ class SAEv2Model(ModelBase):
[self.decoder_dstm, 'decoder_dstm.h5'] ]
return ar
class SAELIAEModel(object):
class SAELIAEModel(CommonModel):
def __init__(self, resolution, ae_dims, e_ch_dims, d_ch_dims, learn_mask):
super().__init__()
self.learn_mask = learn_mask
@ -251,50 +256,37 @@ class SAEv2Model(ModelBase):
output_nc = 3
bgr_shape = (resolution, resolution, output_nc)
mask_shape = (resolution, resolution, 1)
lowest_dense_res = resolution // 16
def downscale (dim, kernel_size=5, dilation_rate=1, use_activator=True):
def func(x):
if not use_activator:
return SubpixelDownscaler()(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x))
else:
return SubpixelDownscaler()(LeakyReLU(0.1)(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x)))
return func
def upscale (dim):
def func(x):
return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, kernel_size=3, strides=1, padding='same')(x)))
return func
def enc_flow(e_ch_dims):
dims = output_nc*e_ch_dims
if dims % 2 != 0:
dims += 1
def func(inp):
x = downscale(dims , 3, 1 )(inp)
x = downscale(dims*2, 3, 1 )(x)
x = downscale(dims*4, 3, 1 )(x)
x0 = downscale(dims*8, 3, 1 )(x)
x = downscale(dims , 5, 1 )(inp)
x = downscale(dims*2, 5, 1 )(x)
x = downscale(dims*4, 5, 1 )(x)
x1 = downscale(dims*8, 5, 1 )(x)
x = downscale(dims , 5, 2 )(inp)
x = downscale(dims*2, 5, 2 )(x)
x = downscale(dims*4, 5, 2 )(x)
x2 = downscale(dims*8, 5, 2 )(x)
x = downscale(dims , 7, 2 )(inp)
x = downscale(dims*2, 7, 2 )(x)
x = downscale(dims*4, 7, 2 )(x)
x3 = downscale(dims*8, 7, 2 )(x)
def func(inp):
x = self.downscale(dims , 3, 1 )(inp)
x = self.downscale(dims*2, 3, 1 )(x)
x = self.downscale(dims*4, 3, 1 )(x)
x0 = self.downscale(dims*8, 3, 1 )(x)
x = self.downscale(dims , 5, 1 )(inp)
x = self.downscale(dims*2, 5, 1 )(x)
x = self.downscale(dims*4, 5, 1 )(x)
x1 = self.downscale(dims*8, 5, 1 )(x)
x = self.downscale(dims , 5, 2 )(inp)
x = self.downscale(dims*2, 5, 2 )(x)
x = self.downscale(dims*4, 5, 2 )(x)
x2 = self.downscale(dims*8, 5, 2 )(x)
x = self.downscale(dims , 7, 2 )(inp)
x = self.downscale(dims*2, 7, 2 )(x)
x = self.downscale(dims*4, 7, 2 )(x)
x3 = self.downscale(dims*8, 7, 2 )(x)
x = Concatenate()([x0,x1,x2,x3])
x = Flatten()(x)
return x
return func
@ -304,7 +296,7 @@ class SAEv2Model(ModelBase):
x = Dense(ae_dims)(x)
x = Dense(lowest_dense_res * lowest_dense_res * ae_dims*2)(x)
x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims*2))(x)
x = upscale(ae_dims*2)(x)
x = self.upscale(ae_dims*2)(x)
return x
return func
@ -312,30 +304,20 @@ class SAEv2Model(ModelBase):
dims = output_nc * d_ch_dims
if dims % 2 != 0:
dims += 1
def ResidualBlock(dim):
def func(inp):
x = Conv2D(dim, kernel_size=3, padding='same')(inp)
x = LeakyReLU(0.2)(x)
x = Conv2D(dim, kernel_size=3, padding='same')(x)
x = Add()([x, inp])
x = LeakyReLU(0.2)(x)
return x
return func
def func(x):
for i in [8,4,2]:
x = upscale(dims*i)(x)
x = self.upscale(dims*i)(x)
if not is_mask:
x0 = x
x = upscale( (dims*i)//2 )(x)
x = ResidualBlock( (dims*i)//2 )(x)
x = downscale( dims*i, use_activator=False ) (x)
x = self.upscale( (dims*i)//2 )(x)
x = self.ResidualBlock( (dims*i)//2 )(x)
x = self.downscale( dims*i, use_activator=False ) (x)
x = Add()([x, x0])
x = LeakyReLU(0.2)(x)
return Conv2D(output_nc, kernel_size=5, padding='same', activation='sigmoid')(x)
return func
@ -360,7 +342,7 @@ class SAEv2Model(ModelBase):
self.warped_src, self.warped_dst = Input(bgr_shape), Input(bgr_shape)
self.target_src, self.target_dst = Input(bgr_shape), Input(bgr_shape)
self.target_srcm, self.target_dstm = Input(mask_shape), Input(mask_shape)
warped_src_code = self.encoder (self.warped_src)
warped_src_inter_AB_code = self.inter_AB (warped_src_code)
self.src_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])
@ -399,14 +381,14 @@ class SAEv2Model(ModelBase):
self.model = SAEDFModel (resolution, ae_dims, ed_ch_dims, ed_ch_dims, learn_mask)
elif 'liae' in self.options['archi']:
self.model = SAELIAEModel (resolution, ae_dims, ed_ch_dims, ed_ch_dims, learn_mask)
self.opt_dis_model = []
if self.true_face_training:
def dis_flow(ndf=256):
def func(x):
x, = x
code_res = K.int_shape(x)[1]
x = Conv2D( ndf, 4, strides=2, padding='valid')( ZeroPadding2D(1)(x) )
@ -414,25 +396,25 @@ class SAEv2Model(ModelBase):
x = Conv2D( ndf*2, 3, strides=2, padding='valid')( ZeroPadding2D(1)(x) )
x = LeakyReLU(0.1)(x)
if code_res > 8:
x = Conv2D( ndf*4, 3, strides=2, padding='valid')( ZeroPadding2D(1)(x) )
x = LeakyReLU(0.1)(x)
if code_res > 16:
x = Conv2D( ndf*8, 3, strides=2, padding='valid')( ZeroPadding2D(1)(x) )
x = LeakyReLU(0.1)(x)
if code_res > 32:
x = Conv2D( ndf*8, 3, strides=2, padding='valid')( ZeroPadding2D(1)(x) )
x = LeakyReLU(0.1)(x)
return Conv2D( 1, 1, strides=1, padding='valid', activation='sigmoid')(x)
return func
sh = [ Input( K.int_shape(self.model.src_code)[1:] ) ]
self.dis = modelify(dis_flow()) (sh)
self.opt_dis_model = [ (self.dis, 'dis.h5') ]
loaded, not_loaded = [], self.model.get_model_filename_list()+self.opt_dis_model
@ -467,7 +449,7 @@ class SAEv2Model(ModelBase):
self.src_dst_opt = RMSprop(lr=5e-5, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
self.src_dst_mask_opt = RMSprop(lr=5e-5, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
self.D_opt = RMSprop(lr=5e-5, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
src_loss = K.mean ( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_opt, pred_src_src_masked_opt) )
src_loss += K.mean ( 10*K.square( target_src_masked_opt - pred_src_src_masked_opt ) )
@ -484,26 +466,26 @@ class SAEv2Model(ModelBase):
dst_loss += K.mean( 10*K.square( target_dst_masked_opt - pred_dst_dst_masked_opt ) )
G_loss = src_loss+dst_loss
if self.true_face_training:
if self.true_face_training:
def DLoss(labels,logits):
return K.mean(K.binary_crossentropy(labels,logits))
src_code_d = self.dis( self.model.src_code )
src_code_d_ones = K.ones_like(src_code_d)
src_code_d_zeros = K.zeros_like(src_code_d)
dst_code_d = self.dis( self.model.dst_code )
dst_code_d_ones = K.ones_like(dst_code_d)
G_loss += 0.01*DLoss(src_code_d_ones, src_code_d)
loss_D = (DLoss(dst_code_d_ones , dst_code_d) + \
DLoss(src_code_d_zeros, src_code_d) ) * 0.5
self.D_train = K.function ([self.model.warped_src, self.model.warped_dst],[loss_D], self.D_opt.get_updates(loss_D, self.dis.trainable_weights) )
self.src_dst_train = K.function ([self.model.warped_src, self.model.warped_dst, self.model.target_src, self.model.target_srcm, self.model.target_dst, self.model.target_dstm],
[src_loss,dst_loss],
self.src_dst_opt.get_updates( G_loss, self.model.src_dst_trainable_weights)
[src_loss,dst_loss],
self.src_dst_opt.get_updates( G_loss, self.model.src_dst_trainable_weights)
)
if self.options['learn_mask']:
@ -525,14 +507,14 @@ class SAEv2Model(ModelBase):
if self.is_training_mode:
t = SampleProcessor.Types
if self.options['face_type'] == 'h':
face_type = t.FACE_TYPE_HALF
elif self.options['face_type'] == 'mf':
face_type = t.FACE_TYPE_MID_FULL
elif self.options['face_type'] == 'f':
face_type = t.FACE_TYPE_FULL
t_mode_bgr = t.MODE_BGR if not self.pretrain else t.MODE_BGR_SHUFFLE
training_data_src_path = self.training_data_src_path
@ -568,7 +550,7 @@ class SAEv2Model(ModelBase):
#override
def onSave(self):
self.save_weights_safe( self.get_model_filename_list()+self.opt_dis_model )
#override
def on_success_train_one_iter(self):
if len(self.CA_conv_weights_list) != 0:
@ -584,9 +566,9 @@ class SAEv2Model(ModelBase):
feed = [warped_src, warped_dst, target_src, target_srcm, target_dst, target_dstm]
src_loss, dst_loss, = self.src_dst_train (feed)
if self.true_face_training:
self.D_train([warped_src, warped_dst])
self.D_train([warped_src, warped_dst])
if self.options['learn_mask']:
feed = [ warped_src, warped_dst, target_srcm, target_dstm ]