Mirror of https://github.com/iperov/DeepFaceLab.git (synced 2025-07-06 13:02:15 -07:00)
SAEHD: speed up for nvidia, duplicate code clean up
Commit 3f23135982 (parent 627df082d7)
2 changed files with 187 additions and 171 deletions

@@ -112,7 +112,31 @@ class SAEv2Model(ModelBase):
        self.true_face_training = self.options.get('true_face_training', False)
        masked_training = True

-        class SAEDFModel(object):
+        class CommonModel(object):
+            def downscale (self, dim, kernel_size=5, dilation_rate=1, use_activator=True):
+                def func(x):
+                    if not use_activator:
+                        return SubpixelDownscaler()(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x))
+                    else:
+                        return SubpixelDownscaler()(LeakyReLU(0.1)(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x)))
+                return func
+
+            def upscale (self, dim, size=(2,2)):
+                def func(x):
+                    return SubpixelUpscaler(size=size)(LeakyReLU(0.1)(Conv2D(dim * np.prod(size) , kernel_size=3, strides=1, padding='same')(x)))
+                return func
+
+            def ResidualBlock(self, dim):
+                def func(inp):
+                    x = Conv2D(dim, kernel_size=3, padding='same')(inp)
+                    x = LeakyReLU(0.2)(x)
+                    x = Conv2D(dim, kernel_size=3, padding='same')(x)
+                    x = Add()([x, inp])
+                    x = LeakyReLU(0.2)(x)
+                    return x
+                return func
+
+        class SAEDFModel(CommonModel):
            def __init__(self, resolution, ae_dims, e_ch_dims, d_ch_dims, learn_mask):
                super().__init__()
                self.learn_mask = learn_mask
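The hunk above is the "duplicate code clean up" half of the commit: the downscale/upscale/ResidualBlock builders that each architecture previously re-declared as local closures now live once on a shared CommonModel base class, and SAEDFModel (plus SAELIAEModel further below) simply inherit them. A minimal, framework-free sketch of that pattern, assuming nothing beyond plain Python; the class and layer names below are illustrative, not DeepFaceLab code:

class CommonModel(object):
    # Each builder returns a closure, mirroring downscale()/upscale() in the diff;
    # here a "layer" just tags its input so the composed pipeline is visible.
    def downscale(self, dim):
        def func(x):
            return x + ["downscale(dim=%d)" % dim]
        return func

    def upscale(self, dim):
        def func(x):
            return x + ["upscale(dim=%d)" % dim]
        return func

class DFModel(CommonModel):
    def build(self):
        x = self.downscale(64)([])    # was a local downscale() closure before this commit
        return self.upscale(128)(x)

class LIAEModel(CommonModel):
    def build(self):
        x = self.downscale(32)([])    # same shared builder, no second copy of the definition
        return self.upscale(64)(x)

print(DFModel().build())    # ['downscale(dim=64)', 'upscale(dim=128)']
print(LIAEModel().build())  # ['downscale(dim=32)', 'upscale(dim=64)']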
@@ -123,51 +147,40 @@ class SAEv2Model(ModelBase):
                lowest_dense_res = resolution // 16
                e_dims = output_nc*e_ch_dims

-                def downscale (dim, kernel_size=5, dilation_rate=1, use_activator=True):
-                    def func(x):
-                        if not use_activator:
-                            return SubpixelDownscaler()(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x))
-                        else:
-                            return SubpixelDownscaler()(LeakyReLU(0.1)(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x)))
-                    return func
-
-                def upscale (dim, size=(2,2)):
-                    def func(x):
-                        return SubpixelUpscaler(size=size)(LeakyReLU(0.1)(Conv2D(dim * np.prod(size) , kernel_size=3, strides=1, padding='same')(x)))
-                    return func
-
                def enc_flow(e_ch_dims, ae_dims, lowest_dense_res):
                    dims = output_nc * e_ch_dims
                    if dims % 2 != 0:
                        dims += 1

                    def func(inp):
-                        x = downscale(dims , 3, 1 )(inp)
-                        x = downscale(dims*2, 3, 1 )(x)
-                        x = downscale(dims*4, 3, 1 )(x)
-                        x0 = downscale(dims*8, 3, 1 )(x)
+                        x = self.downscale(dims , 3, 1 )(inp)
+                        x = self.downscale(dims*2, 3, 1 )(x)
+                        x = self.downscale(dims*4, 3, 1 )(x)
+                        x0 = self.downscale(dims*8, 3, 1 )(x)

-                        x = downscale(dims , 5, 1 )(inp)
-                        x = downscale(dims*2, 5, 1 )(x)
-                        x = downscale(dims*4, 5, 1 )(x)
-                        x1 = downscale(dims*8, 5, 1 )(x)
+                        x = self.downscale(dims , 5, 1 )(inp)
+                        x = self.downscale(dims*2, 5, 1 )(x)
+                        x = self.downscale(dims*4, 5, 1 )(x)
+                        x1 = self.downscale(dims*8, 5, 1 )(x)

-                        x = downscale(dims , 5, 2 )(inp)
-                        x = downscale(dims*2, 5, 2 )(x)
-                        x = downscale(dims*4, 5, 2 )(x)
-                        x2 = downscale(dims*8, 5, 2 )(x)
+                        x = self.downscale(dims , 5, 2 )(inp)
+                        x = self.downscale(dims*2, 5, 2 )(x)
+                        x = self.downscale(dims*4, 5, 2 )(x)
+                        x2 = self.downscale(dims*8, 5, 2 )(x)

-                        x = downscale(dims , 7, 2 )(inp)
-                        x = downscale(dims*2, 7, 2 )(x)
-                        x = downscale(dims*4, 7, 2 )(x)
-                        x3 = downscale(dims*8, 7, 2 )(x)
+                        x = self.downscale(dims , 7, 2 )(inp)
+                        x = self.downscale(dims*2, 7, 2 )(x)
+                        x = self.downscale(dims*4, 7, 2 )(x)
+                        x3 = self.downscale(dims*8, 7, 2 )(x)

                        x = Concatenate()([x0,x1,x2,x3])

                        x = Dense(ae_dims)(Flatten()(x))
                        x = Dense(lowest_dense_res * lowest_dense_res * ae_dims)(x)
                        x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims))(x)
-                        x = upscale(ae_dims)(x)
+                        x = self.upscale(ae_dims)(x)
                        return x
                    return func

@@ -175,32 +188,24 @@ class SAEv2Model(ModelBase):
                    dims = output_nc * d_ch_dims
                    if dims % 2 != 0:
                        dims += 1

-                    def ResidualBlock(dim):
-                        def func(inp):
-                            x = Conv2D(dim, kernel_size=3, padding='same')(inp)
-                            x = LeakyReLU(0.2)(x)
-                            x = Conv2D(dim, kernel_size=3, padding='same')(x)
-                            x = Add()([x, inp])
-                            x = LeakyReLU(0.2)(x)
-                            return x
-                        return func
-
                    def func(x):

                        for i in [8,4,2]:
-                            x = upscale(dims*i)(x)
+                            x = self.upscale(dims*i)(x)

                            if not is_mask:
                                x0 = x
-                                x = upscale( (dims*i)//2 )(x)
-                                x = ResidualBlock( (dims*i)//2 )(x)
-                                x = downscale( dims*i, use_activator=False ) (x)
+                                x = self.upscale( (dims*i)//2 )(x)
+                                x = self.ResidualBlock( (dims*i)//2 )(x)
+                                x = self.downscale( dims*i, use_activator=False ) (x)
                                x = Add()([x, x0])
                                x = LeakyReLU(0.2)(x)

                        return Conv2D(output_nc, kernel_size=5, padding='same', activation='sigmoid')(x)

                    return func

                self.encoder = modelify(enc_flow(e_ch_dims, ae_dims, lowest_dense_res)) ( Input(bgr_shape) )

@@ -243,7 +248,7 @@ class SAEv2Model(ModelBase):
                            [self.decoder_dstm, 'decoder_dstm.h5'] ]
                return ar

-        class SAELIAEModel(object):
+        class SAELIAEModel(CommonModel):
            def __init__(self, resolution, ae_dims, e_ch_dims, d_ch_dims, learn_mask):
                super().__init__()
                self.learn_mask = learn_mask

@@ -251,50 +256,37 @@ class SAEv2Model(ModelBase):
                output_nc = 3
                bgr_shape = (resolution, resolution, output_nc)
                mask_shape = (resolution, resolution, 1)

                lowest_dense_res = resolution // 16

-                def downscale (dim, kernel_size=5, dilation_rate=1, use_activator=True):
-                    def func(x):
-                        if not use_activator:
-                            return SubpixelDownscaler()(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x))
-                        else:
-                            return SubpixelDownscaler()(LeakyReLU(0.1)(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x)))
-                    return func
-
-                def upscale (dim):
-                    def func(x):
-                        return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, kernel_size=3, strides=1, padding='same')(x)))
-                    return func
-
                def enc_flow(e_ch_dims):
                    dims = output_nc*e_ch_dims
                    if dims % 2 != 0:
                        dims += 1

                    def func(inp):
-                        x = downscale(dims , 3, 1 )(inp)
-                        x = downscale(dims*2, 3, 1 )(x)
-                        x = downscale(dims*4, 3, 1 )(x)
-                        x0 = downscale(dims*8, 3, 1 )(x)
+                        x = self.downscale(dims , 3, 1 )(inp)
+                        x = self.downscale(dims*2, 3, 1 )(x)
+                        x = self.downscale(dims*4, 3, 1 )(x)
+                        x0 = self.downscale(dims*8, 3, 1 )(x)

-                        x = downscale(dims , 5, 1 )(inp)
-                        x = downscale(dims*2, 5, 1 )(x)
-                        x = downscale(dims*4, 5, 1 )(x)
-                        x1 = downscale(dims*8, 5, 1 )(x)
+                        x = self.downscale(dims , 5, 1 )(inp)
+                        x = self.downscale(dims*2, 5, 1 )(x)
+                        x = self.downscale(dims*4, 5, 1 )(x)
+                        x1 = self.downscale(dims*8, 5, 1 )(x)

-                        x = downscale(dims , 5, 2 )(inp)
-                        x = downscale(dims*2, 5, 2 )(x)
-                        x = downscale(dims*4, 5, 2 )(x)
-                        x2 = downscale(dims*8, 5, 2 )(x)
+                        x = self.downscale(dims , 5, 2 )(inp)
+                        x = self.downscale(dims*2, 5, 2 )(x)
+                        x = self.downscale(dims*4, 5, 2 )(x)
+                        x2 = self.downscale(dims*8, 5, 2 )(x)

-                        x = downscale(dims , 7, 2 )(inp)
-                        x = downscale(dims*2, 7, 2 )(x)
-                        x = downscale(dims*4, 7, 2 )(x)
-                        x3 = downscale(dims*8, 7, 2 )(x)
+                        x = self.downscale(dims , 7, 2 )(inp)
+                        x = self.downscale(dims*2, 7, 2 )(x)
+                        x = self.downscale(dims*4, 7, 2 )(x)
+                        x3 = self.downscale(dims*8, 7, 2 )(x)

                        x = Concatenate()([x0,x1,x2,x3])

                        x = Flatten()(x)
                        return x
                    return func

@@ -304,7 +296,7 @@ class SAEv2Model(ModelBase):
                        x = Dense(ae_dims)(x)
                        x = Dense(lowest_dense_res * lowest_dense_res * ae_dims*2)(x)
                        x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims*2))(x)
-                        x = upscale(ae_dims*2)(x)
+                        x = self.upscale(ae_dims*2)(x)
                        return x
                    return func

@@ -312,30 +304,20 @@ class SAEv2Model(ModelBase):
                    dims = output_nc * d_ch_dims
                    if dims % 2 != 0:
                        dims += 1

-                    def ResidualBlock(dim):
-                        def func(inp):
-                            x = Conv2D(dim, kernel_size=3, padding='same')(inp)
-                            x = LeakyReLU(0.2)(x)
-                            x = Conv2D(dim, kernel_size=3, padding='same')(x)
-                            x = Add()([x, inp])
-                            x = LeakyReLU(0.2)(x)
-                            return x
-                        return func
-
                    def func(x):

                        for i in [8,4,2]:
-                            x = upscale(dims*i)(x)
+                            x = self.upscale(dims*i)(x)

                            if not is_mask:
                                x0 = x
-                                x = upscale( (dims*i)//2 )(x)
-                                x = ResidualBlock( (dims*i)//2 )(x)
-                                x = downscale( dims*i, use_activator=False ) (x)
+                                x = self.upscale( (dims*i)//2 )(x)
+                                x = self.ResidualBlock( (dims*i)//2 )(x)
+                                x = self.downscale( dims*i, use_activator=False ) (x)
                                x = Add()([x, x0])
                                x = LeakyReLU(0.2)(x)

                        return Conv2D(output_nc, kernel_size=5, padding='same', activation='sigmoid')(x)

                    return func

@@ -360,7 +342,7 @@ class SAEv2Model(ModelBase):
                self.warped_src, self.warped_dst = Input(bgr_shape), Input(bgr_shape)
                self.target_src, self.target_dst = Input(bgr_shape), Input(bgr_shape)
                self.target_srcm, self.target_dstm = Input(mask_shape), Input(mask_shape)

                warped_src_code = self.encoder (self.warped_src)
                warped_src_inter_AB_code = self.inter_AB (warped_src_code)
                self.src_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])

@@ -399,14 +381,14 @@ class SAEv2Model(ModelBase):
            self.model = SAEDFModel (resolution, ae_dims, ed_ch_dims, ed_ch_dims, learn_mask)
        elif 'liae' in self.options['archi']:
            self.model = SAELIAEModel (resolution, ae_dims, ed_ch_dims, ed_ch_dims, learn_mask)

        self.opt_dis_model = []

        if self.true_face_training:
            def dis_flow(ndf=256):
                def func(x):
                    x, = x

                    code_res = K.int_shape(x)[1]

                    x = Conv2D( ndf, 4, strides=2, padding='valid')( ZeroPadding2D(1)(x) )

@@ -414,25 +396,25 @@ class SAEv2Model(ModelBase):

                    x = Conv2D( ndf*2, 3, strides=2, padding='valid')( ZeroPadding2D(1)(x) )
                    x = LeakyReLU(0.1)(x)

                    if code_res > 8:
                        x = Conv2D( ndf*4, 3, strides=2, padding='valid')( ZeroPadding2D(1)(x) )
                        x = LeakyReLU(0.1)(x)

                    if code_res > 16:
                        x = Conv2D( ndf*8, 3, strides=2, padding='valid')( ZeroPadding2D(1)(x) )
                        x = LeakyReLU(0.1)(x)

                    if code_res > 32:
                        x = Conv2D( ndf*8, 3, strides=2, padding='valid')( ZeroPadding2D(1)(x) )
                        x = LeakyReLU(0.1)(x)

                    return Conv2D( 1, 1, strides=1, padding='valid', activation='sigmoid')(x)
                return func

            sh = [ Input( K.int_shape(self.model.src_code)[1:] ) ]
            self.dis = modelify(dis_flow()) (sh)

            self.opt_dis_model = [ (self.dis, 'dis.h5') ]

        loaded, not_loaded = [], self.model.get_model_filename_list()+self.opt_dis_model

@@ -467,7 +449,7 @@ class SAEv2Model(ModelBase):
            self.src_dst_opt = RMSprop(lr=5e-5, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
            self.src_dst_mask_opt = RMSprop(lr=5e-5, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
            self.D_opt = RMSprop(lr=5e-5, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)

            src_loss = K.mean ( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_opt, pred_src_src_masked_opt) )
            src_loss += K.mean ( 10*K.square( target_src_masked_opt - pred_src_src_masked_opt ) )

@@ -484,26 +466,26 @@ class SAEv2Model(ModelBase):
            dst_loss += K.mean( 10*K.square( target_dst_masked_opt - pred_dst_dst_masked_opt ) )

            G_loss = src_loss+dst_loss

            if self.true_face_training:
                def DLoss(labels,logits):
                    return K.mean(K.binary_crossentropy(labels,logits))

                src_code_d = self.dis( self.model.src_code )
                src_code_d_ones = K.ones_like(src_code_d)
                src_code_d_zeros = K.zeros_like(src_code_d)
                dst_code_d = self.dis( self.model.dst_code )
                dst_code_d_ones = K.ones_like(dst_code_d)
                G_loss += 0.01*DLoss(src_code_d_ones, src_code_d)

                loss_D = (DLoss(dst_code_d_ones , dst_code_d) + \
                          DLoss(src_code_d_zeros, src_code_d) ) * 0.5

                self.D_train = K.function ([self.model.warped_src, self.model.warped_dst],[loss_D], self.D_opt.get_updates(loss_D, self.dis.trainable_weights) )

            self.src_dst_train = K.function ([self.model.warped_src, self.model.warped_dst, self.model.target_src, self.model.target_srcm, self.model.target_dst, self.model.target_dstm],
                                             [src_loss,dst_loss],
                                             self.src_dst_opt.get_updates( G_loss, self.model.src_dst_trainable_weights)
                                             )

            if self.options['learn_mask']:

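The true_face_training branch shown above (unchanged context in this hunk) trains a small discriminator on the latent codes and adds a lightly weighted adversarial term to the generator loss. A rough NumPy sketch of just the loss arithmetic, assuming sigmoid probability outputs from the discriminator head; dummy shapes and data, not the repository's code:

import numpy as np

def dloss(labels, probs, eps=1e-7):
    # mirrors DLoss above: mean binary cross-entropy on probabilities
    probs = np.clip(probs, eps, 1.0 - eps)
    return float(np.mean(-(labels * np.log(probs) + (1.0 - labels) * np.log(1.0 - probs))))

# pretend discriminator outputs over the src/dst latent codes
src_code_d = np.random.uniform(0.0, 1.0, size=(4, 2, 2, 1))
dst_code_d = np.random.uniform(0.0, 1.0, size=(4, 2, 2, 1))

# generator side: push src codes toward the "real" (dst) decision, weighted by 0.01
g_extra = 0.01 * dloss(np.ones_like(src_code_d), src_code_d)

# discriminator side: dst codes labelled ones, src codes labelled zeros, averaged
loss_D = 0.5 * (dloss(np.ones_like(dst_code_d), dst_code_d) +
                dloss(np.zeros_like(src_code_d), src_code_d))

print(g_extra, loss_D)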
@@ -525,14 +507,14 @@ class SAEv2Model(ModelBase):

        if self.is_training_mode:
            t = SampleProcessor.Types

            if self.options['face_type'] == 'h':
                face_type = t.FACE_TYPE_HALF
            elif self.options['face_type'] == 'mf':
                face_type = t.FACE_TYPE_MID_FULL
            elif self.options['face_type'] == 'f':
                face_type = t.FACE_TYPE_FULL

            t_mode_bgr = t.MODE_BGR if not self.pretrain else t.MODE_BGR_SHUFFLE

            training_data_src_path = self.training_data_src_path

@@ -568,7 +550,7 @@ class SAEv2Model(ModelBase):
    #override
    def onSave(self):
        self.save_weights_safe( self.get_model_filename_list()+self.opt_dis_model )

    #override
    def on_success_train_one_iter(self):
        if len(self.CA_conv_weights_list) != 0:

@@ -584,9 +566,9 @@ class SAEv2Model(ModelBase):
        feed = [warped_src, warped_dst, target_src, target_srcm, target_dst, target_dstm]

        src_loss, dst_loss, = self.src_dst_train (feed)

        if self.true_face_training:
            self.D_train([warped_src, warped_dst])

        if self.options['learn_mask']:
            feed = [ warped_src, warped_dst, target_srcm, target_dstm ]

nnlib/nnlib.py (106 changed lines)
@@ -140,7 +140,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator

        if 'CUDA_VISIBLE_DEVICES' in os.environ.keys():
            os.environ.pop('CUDA_VISIBLE_DEVICES')

        os.environ['CUDA_CACHE_MAXSIZE'] = '536870912' #512Mb (32mb default)

        os.environ['TF_MIN_GPU_MULTIPROCESSOR_COUNT'] = '2'

@@ -151,7 +151,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator

        import tensorflow as tf
        nnlib.tf = tf

        if device_config.cpu_only:
            config = tf.ConfigProto(device_count={'GPU': 0})
        else:

|
@ -473,50 +473,84 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
||||||
|
|
||||||
nnlib.PixelShuffler = PixelShuffler
|
nnlib.PixelShuffler = PixelShuffler
|
||||||
nnlib.SubpixelUpscaler = PixelShuffler
|
nnlib.SubpixelUpscaler = PixelShuffler
|
||||||
|
|
||||||
class SubpixelDownscaler(KL.Layer):
|
|
||||||
def __init__(self, size=(2, 2), data_format='channels_last', **kwargs):
|
|
||||||
super(SubpixelDownscaler, self).__init__(**kwargs)
|
|
||||||
self.data_format = data_format
|
|
||||||
self.size = size
|
|
||||||
|
|
||||||
def call(self, inputs):
|
if 'tensorflow' in backend:
|
||||||
|
class SubpixelDownscaler(KL.Layer):
|
||||||
|
def __init__(self, size=(2, 2), data_format='channels_last', **kwargs):
|
||||||
|
super(SubpixelDownscaler, self).__init__(**kwargs)
|
||||||
|
self.data_format = data_format
|
||||||
|
self.size = size
|
||||||
|
|
||||||
input_shape = K.shape(inputs)
|
def call(self, inputs):
|
||||||
if K.int_shape(input_shape)[0] != 4:
|
|
||||||
raise ValueError('Inputs should have rank 4; Received input shape:', str(K.int_shape(inputs)))
|
|
||||||
|
|
||||||
batch_size, h, w, c = input_shape[0], input_shape[1], input_shape[2], K.int_shape(inputs)[-1]
|
input_shape = K.shape(inputs)
|
||||||
rh, rw = self.size
|
if K.int_shape(input_shape)[0] != 4:
|
||||||
oh, ow = h // rh, w // rw
|
raise ValueError('Inputs should have rank 4; Received input shape:', str(K.int_shape(inputs)))
|
||||||
oc = c * (rh * rw)
|
|
||||||
|
|
||||||
out = K.reshape(inputs, (batch_size, oh, rh, ow, rw, c))
|
return K.tf.space_to_depth(inputs, self.size[0], 'NHWC')
|
||||||
out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5))
|
|
||||||
out = K.reshape(out, (batch_size, oh, ow, oc))
|
|
||||||
return out
|
|
||||||
|
|
||||||
def compute_output_shape(self, input_shape):
|
def compute_output_shape(self, input_shape):
|
||||||
if len(input_shape) != 4:
|
if len(input_shape) != 4:
|
||||||
raise ValueError('Inputs should have rank ' +
|
raise ValueError('Inputs should have rank ' +
|
||||||
str(4) +
|
str(4) +
|
||||||
'; Received input shape:', str(input_shape))
|
'; Received input shape:', str(input_shape))
|
||||||
|
|
||||||
height = input_shape[1] // self.size[0] if input_shape[1] is not None else None
|
height = input_shape[1] // self.size[0] if input_shape[1] is not None else None
|
||||||
width = input_shape[2] // self.size[1] if input_shape[2] is not None else None
|
width = input_shape[2] // self.size[1] if input_shape[2] is not None else None
|
||||||
channels = input_shape[3] * self.size[0] * self.size[1]
|
channels = input_shape[3] * self.size[0] * self.size[1]
|
||||||
|
|
||||||
return (input_shape[0], height, width, channels)
|
return (input_shape[0], height, width, channels)
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
config = {'size': self.size,
|
config = {'size': self.size,
|
||||||
'data_format': self.data_format}
|
'data_format': self.data_format}
|
||||||
base_config = super(SubpixelDownscaler, self).get_config()
|
base_config = super(SubpixelDownscaler, self).get_config()
|
||||||
|
|
||||||
|
return dict(list(base_config.items()) + list(config.items()))
|
||||||
|
else:
|
||||||
|
class SubpixelDownscaler(KL.Layer):
|
||||||
|
def __init__(self, size=(2, 2), data_format='channels_last', **kwargs):
|
||||||
|
super(SubpixelDownscaler, self).__init__(**kwargs)
|
||||||
|
self.data_format = data_format
|
||||||
|
self.size = size
|
||||||
|
|
||||||
|
def call(self, inputs):
|
||||||
|
|
||||||
|
input_shape = K.shape(inputs)
|
||||||
|
if K.int_shape(input_shape)[0] != 4:
|
||||||
|
raise ValueError('Inputs should have rank 4; Received input shape:', str(K.int_shape(inputs)))
|
||||||
|
|
||||||
|
batch_size, h, w, c = input_shape[0], input_shape[1], input_shape[2], K.int_shape(inputs)[-1]
|
||||||
|
rh, rw = self.size
|
||||||
|
oh, ow = h // rh, w // rw
|
||||||
|
oc = c * (rh * rw)
|
||||||
|
|
||||||
|
out = K.reshape(inputs, (batch_size, oh, rh, ow, rw, c))
|
||||||
|
out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5))
|
||||||
|
out = K.reshape(out, (batch_size, oh, ow, oc))
|
||||||
|
return out
|
||||||
|
|
||||||
|
def compute_output_shape(self, input_shape):
|
||||||
|
if len(input_shape) != 4:
|
||||||
|
raise ValueError('Inputs should have rank ' +
|
||||||
|
str(4) +
|
||||||
|
'; Received input shape:', str(input_shape))
|
||||||
|
|
||||||
|
height = input_shape[1] // self.size[0] if input_shape[1] is not None else None
|
||||||
|
width = input_shape[2] // self.size[1] if input_shape[2] is not None else None
|
||||||
|
channels = input_shape[3] * self.size[0] * self.size[1]
|
||||||
|
|
||||||
|
return (input_shape[0], height, width, channels)
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
config = {'size': self.size,
|
||||||
|
'data_format': self.data_format}
|
||||||
|
base_config = super(SubpixelDownscaler, self).get_config()
|
||||||
|
|
||||||
|
return dict(list(base_config.items()) + list(config.items()))
|
||||||
|
|
||||||
return dict(list(base_config.items()) + list(config.items()))
|
|
||||||
|
|
||||||
nnlib.SubpixelDownscaler = SubpixelDownscaler
|
nnlib.SubpixelDownscaler = SubpixelDownscaler
|
||||||
|
|
||||||
class BlurPool(KL.Layer):
|
class BlurPool(KL.Layer):
|
||||||
"""
|
"""
|
||||||
https://arxiv.org/abs/1904.11486 https://github.com/adobe/antialiased-cnns
|
https://arxiv.org/abs/1904.11486 https://github.com/adobe/antialiased-cnns
|
||||||
|
|
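On the TensorFlow backend the SubpixelDownscaler above now delegates to a single fused K.tf.space_to_depth call instead of the reshape/permute/reshape chain, which is the "speed up for nvidia" half of the commit; the generic implementation is kept for other backends. A small NumPy sketch (an assumption-checking illustration, not DeepFaceLab code) showing that the two code paths produce the same NHWC channel ordering for block size 2:

import numpy as np

def downscale_reshape(x, size=2):
    # generic-backend path from the diff: reshape -> transpose -> reshape
    b, h, w, c = x.shape
    oh, ow = h // size, w // size
    out = x.reshape(b, oh, size, ow, size, c)
    out = out.transpose(0, 1, 3, 2, 4, 5)
    return out.reshape(b, oh, ow, size * size * c)

def downscale_block_gather(x, size=2):
    # block-gather ordering documented for tf.space_to_depth (NHWC): each
    # spatial block is flattened ahead of the channel axis
    b, h, w, c = x.shape
    oh, ow = h // size, w // size
    out = np.empty((b, oh, ow, size * size * c), dtype=x.dtype)
    for i in range(oh):
        for j in range(ow):
            block = x[:, i*size:(i+1)*size, j*size:(j+1)*size, :]   # (b, size, size, c)
            out[:, i, j, :] = block.reshape(b, -1)
    return out

x = np.random.rand(2, 8, 8, 3).astype(np.float32)
assert np.allclose(downscale_reshape(x), downscale_block_gather(x))
print("reshape/permute path matches the space_to_depth block ordering")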