mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-14 00:53:48 -07:00
fix fanseg
This commit is contained in:
parent
034ad3cce5
commit
fa4e579b95
2 changed files with 12 additions and 4 deletions
|
@ -3,6 +3,7 @@ import os
|
||||||
import cv2
|
import cv2
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from nnlib import nnlib
|
from nnlib import nnlib
|
||||||
|
from interact import interact as io
|
||||||
|
|
||||||
class FANSegmentator(object):
|
class FANSegmentator(object):
|
||||||
def __init__ (self, resolution, face_type_str, load_weights=True, weights_file_root=None):
|
def __init__ (self, resolution, face_type_str, load_weights=True, weights_file_root=None):
|
||||||
|
@ -19,7 +20,15 @@ class FANSegmentator(object):
|
||||||
|
|
||||||
if load_weights:
|
if load_weights:
|
||||||
self.model.load_weights (str(self.weights_path))
|
self.model.load_weights (str(self.weights_path))
|
||||||
|
else:
|
||||||
|
io.log_info ("Initializing CA weights...")
|
||||||
|
conv_weights_list = []
|
||||||
|
for layer in self.model.layers:
|
||||||
|
if type(layer) == Conv2D:
|
||||||
|
conv_weights_list += [layer.weights[0]] # Conv2D kernel_weights
|
||||||
|
CAInitializerMP(conv_weights_list)
|
||||||
|
self.model.compile(loss='mse', optimizer=Adam(tf_cpu_mode=2))
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@ -43,7 +52,6 @@ class FANSegmentator(object):
|
||||||
x = FANSegmentator.EncFlow(ngf=ngf)(x)
|
x = FANSegmentator.EncFlow(ngf=ngf)(x)
|
||||||
x = FANSegmentator.DecFlow(ngf=ngf)(x)
|
x = FANSegmentator.DecFlow(ngf=ngf)(x)
|
||||||
model = Model(inp,x)
|
model = Model(inp,x)
|
||||||
model.compile (loss='mse', optimizer=Adam(tf_cpu_mode=2) )
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -37,13 +37,13 @@ class Model(ModelBase):
|
||||||
self.set_training_data_generators ([
|
self.set_training_data_generators ([
|
||||||
SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
|
SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||||
sample_process_options=SampleProcessor.Options(random_flip=True, normalize_tanh = True ),
|
sample_process_options=SampleProcessor.Options(random_flip=True, normalize_tanh = True ),
|
||||||
output_sample_types=[ [f.TRANSFORMED | f_type | f.MODE_BGR, self.resolution],
|
output_sample_types=[ [f.TRANSFORMED | f_type | f.MODE_BGR_SHUFFLE, self.resolution],
|
||||||
[f.TRANSFORMED | f_type | f.MODE_M | f.FACE_MASK_FULL, self.resolution]
|
[f.TRANSFORMED | f_type | f.MODE_M | f.FACE_MASK_FULL, self.resolution]
|
||||||
]),
|
]),
|
||||||
|
|
||||||
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
|
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||||
sample_process_options=SampleProcessor.Options(random_flip=True, normalize_tanh = True ),
|
sample_process_options=SampleProcessor.Options(random_flip=True, normalize_tanh = True ),
|
||||||
output_sample_types=[ [f.TRANSFORMED | f_type | f.MODE_BGR, self.resolution]
|
output_sample_types=[ [f.TRANSFORMED | f_type | f.MODE_BGR_SHUFFLE, self.resolution]
|
||||||
])
|
])
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue