diff --git a/facelib/FANSegmentator.py b/facelib/FANSegmentator.py index 54ae673..d84f23d 100644 --- a/facelib/FANSegmentator.py +++ b/facelib/FANSegmentator.py @@ -3,6 +3,7 @@ import os import cv2 from pathlib import Path from nnlib import nnlib +from interact import interact as io class FANSegmentator(object): def __init__ (self, resolution, face_type_str, load_weights=True, weights_file_root=None): @@ -19,7 +20,15 @@ class FANSegmentator(object): if load_weights: self.model.load_weights (str(self.weights_path)) - + else: + io.log_info ("Initializing CA weights...") + conv_weights_list = [] + for layer in self.model.layers: + if type(layer) == Conv2D: + conv_weights_list += [layer.weights[0]] # Conv2D kernel_weights + CAInitializerMP(conv_weights_list) + self.model.compile(loss='mse', optimizer=Adam(tf_cpu_mode=2)) + def __enter__(self): return self @@ -43,7 +52,6 @@ class FANSegmentator(object): x = FANSegmentator.EncFlow(ngf=ngf)(x) x = FANSegmentator.DecFlow(ngf=ngf)(x) model = Model(inp,x) - model.compile (loss='mse', optimizer=Adam(tf_cpu_mode=2) ) return model @staticmethod diff --git a/models/Model_FANSegmentator/Model.py b/models/Model_FANSegmentator/Model.py index 5b3918a..a1baff1 100644 --- a/models/Model_FANSegmentator/Model.py +++ b/models/Model_FANSegmentator/Model.py @@ -37,13 +37,13 @@ class Model(ModelBase): self.set_training_data_generators ([ SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size, sample_process_options=SampleProcessor.Options(random_flip=True, normalize_tanh = True ), - output_sample_types=[ [f.TRANSFORMED | f_type | f.MODE_BGR, self.resolution], + output_sample_types=[ [f.TRANSFORMED | f_type | f.MODE_BGR_SHUFFLE, self.resolution], [f.TRANSFORMED | f_type | f.MODE_M | f.FACE_MASK_FULL, self.resolution] ]), SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, sample_process_options=SampleProcessor.Options(random_flip=True, normalize_tanh = True ), - output_sample_types=[ [f.TRANSFORMED | f_type | f.MODE_BGR, self.resolution] + output_sample_types=[ [f.TRANSFORMED | f_type | f.MODE_BGR_SHUFFLE, self.resolution] ]) ])