mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 05:22:06 -07:00
added 'sort by vggface': sorting by face similarity using VGGFace model.
Requires 4GB+ VRAM and internet connection for the first run.
This commit is contained in:
parent
0d3b25812d
commit
734d97d729
8 changed files with 186 additions and 43 deletions
|
@ -162,10 +162,6 @@ class FUNIT(object):
|
|||
for w in weights_list:
|
||||
K.set_value( w, K.get_value(initer(K.int_shape(w))) )
|
||||
|
||||
#if not self.is_first_run():
|
||||
# self.load_weights_safe(self.get_model_filename_list())
|
||||
|
||||
|
||||
|
||||
if load_weights_locally:
|
||||
pass
|
||||
|
@ -188,9 +184,6 @@ class FUNIT(object):
|
|||
[self.D_opt, 'D_opt.h5'],
|
||||
]
|
||||
|
||||
#def save_weights(self):
|
||||
# self.model.save_weights (str(self.weights_path))
|
||||
|
||||
def train(self, xa,la,xb,lb):
|
||||
D_loss, = self.D_train ([xa,la,xb,lb])
|
||||
G_loss, = self.G_train ([xa,la,xb,lb])
|
||||
|
@ -209,17 +202,17 @@ class FUNIT(object):
|
|||
def ResBlock(dim):
|
||||
def func(input):
|
||||
x = input
|
||||
x = Conv2D(dim, 3, strides=1, padding='valid')(ZeroPadding2D(1)(x))
|
||||
x = Conv2D(dim, 3, strides=1, padding='same')(x)
|
||||
x = InstanceNormalization()(x)
|
||||
x = ReLU()(x)
|
||||
x = Conv2D(dim, 3, strides=1, padding='valid')(ZeroPadding2D(1)(x))
|
||||
x = Conv2D(dim, 3, strides=1, padding='same')(x)
|
||||
x = InstanceNormalization()(x)
|
||||
|
||||
return Add()([x,input])
|
||||
return func
|
||||
|
||||
def func(x):
|
||||
x = Conv2D (nf, kernel_size=7, strides=1, padding='valid')(ZeroPadding2D(3)(x))
|
||||
x = Conv2D (nf, kernel_size=7, strides=1, padding='same')(x)
|
||||
x = InstanceNormalization()(x)
|
||||
x = ReLU()(x)
|
||||
for i in range(downs):
|
||||
|
@ -237,11 +230,11 @@ class FUNIT(object):
|
|||
exec (nnlib.import_all(), locals(), globals())
|
||||
|
||||
def func(x):
|
||||
x = Conv2D (nf, kernel_size=7, strides=1, padding='valid', activation='relu')(ZeroPadding2D(3)(x))
|
||||
x = Conv2D (nf, kernel_size=7, strides=1, padding='same', activation='relu')(x)
|
||||
for i in range(downs):
|
||||
x = Conv2D (nf * min ( 4, 2**(i+1) ), kernel_size=4, strides=2, padding='valid', activation='relu')(ZeroPadding2D(1)(x))
|
||||
x = GlobalAveragePooling2D()(x)
|
||||
x = Dense(nf)(x)
|
||||
x = Dense(latent_dim)(x)
|
||||
return x
|
||||
|
||||
return func
|
||||
|
@ -250,16 +243,14 @@ class FUNIT(object):
|
|||
def DecoderFlow(ups, n_res_blks=2, mlp_blks=2, subpixel_decoder=False ):
|
||||
exec (nnlib.import_all(), locals(), globals())
|
||||
|
||||
|
||||
|
||||
def ResBlock(dim):
|
||||
def func(input):
|
||||
inp, mlp = input
|
||||
x = inp
|
||||
x = Conv2D(dim, 3, strides=1, padding='valid')(ZeroPadding2D(1)(x))
|
||||
x = Conv2D(dim, 3, strides=1, padding='same')(x)
|
||||
x = FUNITAdain(kernel_initializer='he_normal')([x,mlp])
|
||||
x = ReLU()(x)
|
||||
x = Conv2D(dim, 3, strides=1, padding='valid')(ZeroPadding2D(1)(x))
|
||||
x = Conv2D(dim, 3, strides=1, padding='same')(x)
|
||||
x = FUNITAdain(kernel_initializer='he_normal')([x,mlp])
|
||||
return Add()([x,inp])
|
||||
return func
|
||||
|
@ -280,16 +271,16 @@ class FUNIT(object):
|
|||
for i in range(ups):
|
||||
|
||||
if subpixel_decoder:
|
||||
x = Conv2D (4* (nf // 2**(i+1)), kernel_size=3, strides=1, padding='valid')(ZeroPadding2D(1)(x))
|
||||
x = Conv2D (4* (nf // 2**(i+1)), kernel_size=3, strides=1, padding='same')(x)
|
||||
x = SubpixelUpscaler()(x)
|
||||
else:
|
||||
x = UpSampling2D()(x)
|
||||
x = Conv2D (nf // 2**(i+1), kernel_size=5, strides=1, padding='valid')(ZeroPadding2D(2)(x))
|
||||
x = Conv2D (nf // 2**(i+1), kernel_size=5, strides=1, padding='same')(x)
|
||||
|
||||
x = InstanceNormalization()(x)
|
||||
x = ReLU()(x)
|
||||
|
||||
rgb = Conv2D (3, kernel_size=7, strides=1, padding='valid', activation='tanh')(ZeroPadding2D(3)(x))
|
||||
rgb = Conv2D (3, kernel_size=7, strides=1, padding='same', activation='tanh')(x)
|
||||
return rgb
|
||||
|
||||
return func
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue