fix fanseg

This commit is contained in:
iperov 2019-03-20 09:08:42 +04:00
commit 6169e6ba8a
2 changed files with 59 additions and 56 deletions

View file

@ -6,7 +6,7 @@ from nnlib import nnlib
from interact import interact as io from interact import interact as io
class FANSegmentator(object): class FANSegmentator(object):
def __init__ (self, resolution, face_type_str, load_weights=True, weights_file_root=None): def __init__ (self, resolution, face_type_str, load_weights=True, weights_file_root=None, training=False):
exec( nnlib.import_all(), locals(), globals() ) exec( nnlib.import_all(), locals(), globals() )
self.model = FANSegmentator.BuildModel(resolution, ngf=32) self.model = FANSegmentator.BuildModel(resolution, ngf=32)
@ -21,12 +21,14 @@ class FANSegmentator(object):
if load_weights: if load_weights:
self.model.load_weights (str(self.weights_path)) self.model.load_weights (str(self.weights_path))
else: else:
io.log_info ("Initializing CA weights...") if training:
conv_weights_list = [] io.log_info ("Initializing CA weights...")
for layer in self.model.layers: conv_weights_list = []
if type(layer) == Conv2D: for layer in self.model.layers:
conv_weights_list += [layer.weights[0]] # Conv2D kernel_weights if type(layer) == Conv2D:
CAInitializerMP(conv_weights_list) conv_weights_list += [layer.weights[0]] # Conv2D kernel_weights
CAInitializerMP(conv_weights_list)
if training:
self.model.compile(loss='mse', optimizer=Adam(tf_cpu_mode=2)) self.model.compile(loss='mse', optimizer=Adam(tf_cpu_mode=2))
def __enter__(self): def __enter__(self):

View file

@ -28,7 +28,8 @@ class Model(ModelBase):
self.fan_seg = FANSegmentator(self.resolution, self.fan_seg = FANSegmentator(self.resolution,
FaceType.toString(self.face_type), FaceType.toString(self.face_type),
load_weights=not self.is_first_run(), load_weights=not self.is_first_run(),
weights_file_root=self.get_model_root_path() ) weights_file_root=self.get_model_root_path(),
training=True)
if self.is_training_mode: if self.is_training_mode:
f = SampleProcessor.TypeFlags f = SampleProcessor.TypeFlags