mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-19 21:13:20 -07:00
fix fanseg
This commit is contained in:
parent
a3df04999c
commit
6169e6ba8a
2 changed files with 59 additions and 56 deletions
|
@ -6,7 +6,7 @@ from nnlib import nnlib
|
||||||
from interact import interact as io
|
from interact import interact as io
|
||||||
|
|
||||||
class FANSegmentator(object):
|
class FANSegmentator(object):
|
||||||
def __init__ (self, resolution, face_type_str, load_weights=True, weights_file_root=None):
|
def __init__ (self, resolution, face_type_str, load_weights=True, weights_file_root=None, training=False):
|
||||||
exec( nnlib.import_all(), locals(), globals() )
|
exec( nnlib.import_all(), locals(), globals() )
|
||||||
|
|
||||||
self.model = FANSegmentator.BuildModel(resolution, ngf=32)
|
self.model = FANSegmentator.BuildModel(resolution, ngf=32)
|
||||||
|
@ -21,12 +21,14 @@ class FANSegmentator(object):
|
||||||
if load_weights:
|
if load_weights:
|
||||||
self.model.load_weights (str(self.weights_path))
|
self.model.load_weights (str(self.weights_path))
|
||||||
else:
|
else:
|
||||||
io.log_info ("Initializing CA weights...")
|
if training:
|
||||||
conv_weights_list = []
|
io.log_info ("Initializing CA weights...")
|
||||||
for layer in self.model.layers:
|
conv_weights_list = []
|
||||||
if type(layer) == Conv2D:
|
for layer in self.model.layers:
|
||||||
conv_weights_list += [layer.weights[0]] # Conv2D kernel_weights
|
if type(layer) == Conv2D:
|
||||||
CAInitializerMP(conv_weights_list)
|
conv_weights_list += [layer.weights[0]] # Conv2D kernel_weights
|
||||||
|
CAInitializerMP(conv_weights_list)
|
||||||
|
if training:
|
||||||
self.model.compile(loss='mse', optimizer=Adam(tf_cpu_mode=2))
|
self.model.compile(loss='mse', optimizer=Adam(tf_cpu_mode=2))
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
|
|
|
@ -28,7 +28,8 @@ class Model(ModelBase):
|
||||||
self.fan_seg = FANSegmentator(self.resolution,
|
self.fan_seg = FANSegmentator(self.resolution,
|
||||||
FaceType.toString(self.face_type),
|
FaceType.toString(self.face_type),
|
||||||
load_weights=not self.is_first_run(),
|
load_weights=not self.is_first_run(),
|
||||||
weights_file_root=self.get_model_root_path() )
|
weights_file_root=self.get_model_root_path(),
|
||||||
|
training=True)
|
||||||
|
|
||||||
if self.is_training_mode:
|
if self.is_training_mode:
|
||||||
f = SampleProcessor.TypeFlags
|
f = SampleProcessor.TypeFlags
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue