mirror of https://github.com/iperov/DeepFaceLab.git
synced 2025-07-30 03:29:59 -07:00

fix fanseg

This commit is contained in:
parent a3df04999c
commit 6169e6ba8a

2 changed files with 59 additions and 56 deletions
@@ -6,44 +6,46 @@ from nnlib import nnlib
 from interact import interact as io
 
 class FANSegmentator(object):
-    def __init__ (self, resolution, face_type_str, load_weights=True, weights_file_root=None):
+    def __init__ (self, resolution, face_type_str, load_weights=True, weights_file_root=None, training=False):
         exec( nnlib.import_all(), locals(), globals() )
 
         self.model = FANSegmentator.BuildModel(resolution, ngf=32)
 
         if weights_file_root:
             weights_file_root = Path(weights_file_root)
         else:
             weights_file_root = Path(__file__).parent
 
         self.weights_path = weights_file_root / ('FANSeg_%d_%s.h5' % (resolution, face_type_str) )
 
         if load_weights:
             self.model.load_weights (str(self.weights_path))
         else:
-            io.log_info ("Initializing CA weights...")
-            conv_weights_list = []
-            for layer in self.model.layers:
-                if type(layer) == Conv2D:
-                    conv_weights_list += [layer.weights[0]] # Conv2D kernel_weights
-            CAInitializerMP(conv_weights_list)
+            if training:
+                io.log_info ("Initializing CA weights...")
+                conv_weights_list = []
+                for layer in self.model.layers:
+                    if type(layer) == Conv2D:
+                        conv_weights_list += [layer.weights[0]] # Conv2D kernel_weights
+                CAInitializerMP(conv_weights_list)
 
-        self.model.compile(loss='mse', optimizer=Adam(tf_cpu_mode=2))
+        if training:
+            self.model.compile(loss='mse', optimizer=Adam(tf_cpu_mode=2))
 
     def __enter__(self):
         return self
 
     def __exit__(self, exc_type=None, exc_value=None, traceback=None):
         return False # pass exceptions raised in the with-block through to the outer level
 
     def save_weights(self):
         self.model.save_weights (str(self.weights_path))
 
     def train_on_batch(self, inp, outp):
         return self.model.train_on_batch(inp, outp)
 
     def extract_from_bgr (self, input_image):
         return np.clip ( (self.model.predict(input_image) + 1) / 2.0, 0, 1.0 )
 
     @staticmethod
     def BuildModel ( resolution, ngf=64):
         exec( nnlib.import_all(), locals(), globals() )
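The substance of the fix is in this hunk: CA weight initialization and compilation are now gated on the new training flag, so a consumer that only needs inference can construct the segmentator without ever building optimizer state. A self-contained sketch of the same pattern in plain tf.keras; the class and layer choices below are illustrative stand-ins, not the repo's nnlib API:

    import numpy as np
    import tensorflow as tf

    class MaskNet:
        """Illustrative stand-in for FANSegmentator's training/inference split."""
        def __init__(self, resolution=256, weights_path='mask.h5',
                     load_weights=True, training=False):
            self.weights_path = weights_path
            self.model = tf.keras.Sequential([
                tf.keras.layers.InputLayer(input_shape=(resolution, resolution, 3)),
                tf.keras.layers.Conv2D(32, 5, strides=2, padding='same'),
                tf.keras.layers.LeakyReLU(0.1),
                tf.keras.layers.Conv2DTranspose(1, 5, strides=2, padding='same',
                                                activation='tanh'),
            ])
            if load_weights:
                self.model.load_weights(self.weights_path)
            if training:
                # Optimizer state exists only in the trainer process.
                self.model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam())

        def extract_from_bgr(self, batch):
            # tanh output in [-1, 1] mapped to a [0, 1] mask, as in the diff.
            return np.clip((self.model.predict(batch) + 1) / 2.0, 0, 1.0)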
@@ -53,7 +55,7 @@ class FANSegmentator(object):
         x = FANSegmentator.DecFlow(ngf=ngf)(x)
         model = Model(inp,x)
         return model
 
     @staticmethod
     def EncFlow(ngf=64, num_downs=4):
         exec( nnlib.import_all(), locals(), globals() )
@@ -65,19 +67,19 @@ class FANSegmentator(object):
         def downscale (dim):
             def func(x):
                 return LeakyReLU(0.1)(XNormalization(Conv2D(dim, kernel_size=5, strides=2, padding='same', kernel_initializer=RandomNormal(0, 0.02))(x)))
             return func
 
         def func(input):
             x = input
 
             result = []
             for i in range(num_downs):
                 x = downscale ( min(ngf*(2**i), ngf*8) )(x)
                 result += [x]
 
             return result
         return func
 
     @staticmethod
     def DecFlow(output_nc=1, ngf=64, activation='tanh'):
         exec (nnlib.import_all(), locals(), globals())
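EncFlow's inner func runs num_downs stride-2 blocks and returns every intermediate activation, i.e. a U-Net style encoder that keeps its skip connections; channel width doubles per level and is capped at ngf*8. A rough tf.keras rendering of that loop, with BatchNormalization standing in for nnlib's XNormalization/InstanceNormalization:

    import tensorflow as tf

    def enc_flow(x, ngf=64, num_downs=4):
        # Returns one feature map per level, shallowest first, for skip connections.
        skips = []
        for i in range(num_downs):
            dim = min(ngf * (2 ** i), ngf * 8)           # cap channel growth
            x = tf.keras.layers.Conv2D(
                dim, 5, strides=2, padding='same',
                kernel_initializer=tf.keras.initializers.RandomNormal(0, 0.02))(x)
            x = tf.keras.layers.BatchNormalization()(x)  # InstanceNorm in the repo
            x = tf.keras.layers.LeakyReLU(0.1)(x)
            skips.append(x)
        return skips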
@@ -85,23 +87,23 @@ class FANSegmentator(object):
         use_bias = True
         def XNormalization(x):
             return InstanceNormalization (axis=3, gamma_initializer=RandomNormal(1., 0.02))(x)
 
         def Conv2D (filters, kernel_size, strides=(1, 1), padding='valid', data_format=None, dilation_rate=(1, 1), activation=None, use_bias=use_bias, kernel_initializer=RandomNormal(0, 0.02), bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None):
             return keras.layers.Conv2D( filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, data_format=data_format, dilation_rate=dilation_rate, activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, kernel_constraint=kernel_constraint, bias_constraint=bias_constraint )
 
         def upscale (dim):
             def func(x):
                 return SubpixelUpscaler()( LeakyReLU(0.1)(XNormalization(Conv2D(dim, kernel_size=3, strides=1, padding='same', kernel_initializer=RandomNormal(0, 0.02))(x))))
             return func
 
         def func(input):
             input_len = len(input)
 
             x = input[input_len-1]
             for i in range(input_len-1, -1, -1):
                 x = upscale( min(ngf* (2**i) *4, ngf*8 *4 ) )(x)
                 if i != 0:
                     x = Concatenate(axis=3)([ input[i-1] , x])
 
             return Conv2D(output_nc, 3, 1, 'same', activation=activation)(x)
         return func
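DecFlow consumes the skip list in reverse: each step is a 3x3 conv into 4x the target channels, normalization, LeakyReLU, then a 2x subpixel upscale, concatenating the matching encoder map until the shallowest level; a final conv with tanh produces the 1-channel mask. A hedged sketch, with tf.nn.depth_to_space standing in for nnlib's SubpixelUpscaler and normalization omitted for brevity:

    import tensorflow as tf

    def dec_flow(skips, output_nc=1, ngf=64):
        x = skips[-1]
        for i in range(len(skips) - 1, -1, -1):
            dim = min(ngf * (2 ** i) * 4, ngf * 8 * 4)    # *4 feeds the 2x pixel shuffle
            x = tf.keras.layers.Conv2D(dim, 3, strides=1, padding='same')(x)
            x = tf.keras.layers.LeakyReLU(0.1)(x)
            x = tf.keras.layers.Lambda(
                lambda t: tf.nn.depth_to_space(t, 2))(x)  # SubpixelUpscaler stand-in
            if i != 0:
                x = tf.keras.layers.Concatenate(axis=3)([skips[i - 1], x])
        # tanh keeps the mask in [-1, 1]; extract_from_bgr rescales it to [0, 1].
        return tf.keras.layers.Conv2D(output_nc, 3, strides=1, padding='same',
                                      activation='tanh')(x)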
@@ -10,13 +10,13 @@ from interact import interact as io
 class Model(ModelBase):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs,
                          ask_write_preview_history=False,
                          ask_target_iter=False,
                          ask_sort_by_yaw=False,
                          ask_random_flip=False,
                          ask_src_scale_mod=False)
 
     #override
     def onInitialize(self):
         exec(nnlib.import_all(), locals(), globals())
@@ -24,33 +24,34 @@ class Model(ModelBase):
         self.resolution = 256
         self.face_type = FaceType.FULL
 
         self.fan_seg = FANSegmentator(self.resolution,
                                       FaceType.toString(self.face_type),
                                       load_weights=not self.is_first_run(),
-                                      weights_file_root=self.get_model_root_path() )
+                                      weights_file_root=self.get_model_root_path(),
+                                      training=True)
 
         if self.is_training_mode:
             f = SampleProcessor.TypeFlags
             f_type = f.FACE_ALIGN_FULL
 
             self.set_training_data_generators ([
                     SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
                                         sample_process_options=SampleProcessor.Options(random_flip=True, normalize_tanh = True ),
                                         output_sample_types=[ [f.TRANSFORMED | f_type | f.MODE_BGR_SHUFFLE, self.resolution],
                                                               [f.TRANSFORMED | f_type | f.MODE_M | f.FACE_MASK_FULL, self.resolution]
                                         ]),
 
                     SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
                                         sample_process_options=SampleProcessor.Options(random_flip=True, normalize_tanh = True ),
                                         output_sample_types=[ [f.TRANSFORMED | f_type | f.MODE_BGR_SHUFFLE, self.resolution]
                                         ])
                 ])
 
     #override
     def onSave(self):
         self.fan_seg.save_weights()
 
     #override
     def onTrainOneIter(self, generators_samples, generators_list):
         target_src, target_src_mask = generators_samples[0]
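The training model above now passes training=True, while other call sites keep the default and get a weights-only object. Hypothetical call sites for both modes; the 'full_face' string assumes FaceType.toString(FaceType.FULL) yields it, and the import path and model directory are placeholders:

    from facelib import FANSegmentator  # import location assumed

    # Trainer (this file): CA init on first run, plus the MSE/Adam compile.
    fan_seg = FANSegmentator(256, 'full_face',
                             load_weights=False,
                             weights_file_root='workspace/model',  # placeholder path
                             training=True)

    # Inference-only consumer: loads FANSeg_256_full_face.h5, no optimizer built.
    fan_seg = FANSegmentator(256, 'full_face', load_weights=True)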
@@ -58,20 +59,20 @@ class Model(ModelBase):
         loss = self.fan_seg.train_on_batch( [target_src], [target_src_mask] )
 
         return ( ('loss', loss), )
 
     #override
     def onGetPreview(self, sample):
         test_A = sample[0][0][0:4] #first 4 samples
         test_B = sample[1][0][0:4] #first 4 samples
 
         mAA = self.fan_seg.extract_from_bgr([test_A])
         mBB = self.fan_seg.extract_from_bgr([test_B])
 
         test_A, test_B, = [ np.clip( (x + 1.0)/2.0, 0.0, 1.0) for x in [test_A, test_B] ]
 
         mAA = np.repeat ( mAA, (3,), -1)
         mBB = np.repeat ( mBB, (3,), -1)
 
         st = []
         for i in range(0, len(test_A)):
             st.append ( np.concatenate ( (
@@ -79,7 +80,7 @@ class Model(ModelBase):
                 mAA[i],
                 test_A[i,:,:,0:3]*mAA[i],
                 ), axis=1) )
 
         st2 = []
         for i in range(0, len(test_B)):
             st2.append ( np.concatenate ( (
@@ -87,7 +88,7 @@ class Model(ModelBase):
                 mBB[i],
                 test_B[i,:,:,0:3]*mBB[i],
                 ), axis=1) )
 
         return [ ('FANSegmentator', np.concatenate ( st, axis=0 ) ),
                  ('never seen', np.concatenate ( st2, axis=0 ) ),
                ]
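For reference, the preview assembly in these last hunks amounts to: denormalize the tanh images, broadcast the 1-channel mask to 3 channels, tile source / mask / masked-source side by side per sample, and stack the batch vertically. A compact numpy restatement (the helper name is mine, not the repo's):

    import numpy as np

    def make_preview(images_tanh, masks):          # images in [-1,1], masks in [0,1]
        imgs = np.clip((images_tanh + 1.0) / 2.0, 0.0, 1.0)
        masks3 = np.repeat(masks, 3, axis=-1)      # 1-channel mask -> 3 channels
        rows = [np.concatenate((img, m, img * m), axis=1)   # source | mask | masked
                for img, m in zip(imgs, masks3)]
        return np.concatenate(rows, axis=0)        # stack samples vertically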