fix fanseg

This commit is contained in:
iperov 2019-03-19 23:49:41 +04:00
parent 034ad3cce5
commit fa4e579b95
2 changed files with 12 additions and 4 deletions

View file

@ -3,6 +3,7 @@ import os
import cv2 import cv2
from pathlib import Path from pathlib import Path
from nnlib import nnlib from nnlib import nnlib
from interact import interact as io
class FANSegmentator(object): class FANSegmentator(object):
def __init__ (self, resolution, face_type_str, load_weights=True, weights_file_root=None): def __init__ (self, resolution, face_type_str, load_weights=True, weights_file_root=None):
@ -19,6 +20,14 @@ class FANSegmentator(object):
if load_weights: if load_weights:
self.model.load_weights (str(self.weights_path)) self.model.load_weights (str(self.weights_path))
else:
io.log_info ("Initializing CA weights...")
conv_weights_list = []
for layer in self.model.layers:
if type(layer) == Conv2D:
conv_weights_list += [layer.weights[0]] # Conv2D kernel_weights
CAInitializerMP(conv_weights_list)
self.model.compile(loss='mse', optimizer=Adam(tf_cpu_mode=2))
def __enter__(self): def __enter__(self):
return self return self
@ -43,7 +52,6 @@ class FANSegmentator(object):
x = FANSegmentator.EncFlow(ngf=ngf)(x) x = FANSegmentator.EncFlow(ngf=ngf)(x)
x = FANSegmentator.DecFlow(ngf=ngf)(x) x = FANSegmentator.DecFlow(ngf=ngf)(x)
model = Model(inp,x) model = Model(inp,x)
model.compile (loss='mse', optimizer=Adam(tf_cpu_mode=2) )
return model return model
@staticmethod @staticmethod

View file

@ -37,13 +37,13 @@ class Model(ModelBase):
self.set_training_data_generators ([ self.set_training_data_generators ([
SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size, SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=True, normalize_tanh = True ), sample_process_options=SampleProcessor.Options(random_flip=True, normalize_tanh = True ),
output_sample_types=[ [f.TRANSFORMED | f_type | f.MODE_BGR, self.resolution], output_sample_types=[ [f.TRANSFORMED | f_type | f.MODE_BGR_SHUFFLE, self.resolution],
[f.TRANSFORMED | f_type | f.MODE_M | f.FACE_MASK_FULL, self.resolution] [f.TRANSFORMED | f_type | f.MODE_M | f.FACE_MASK_FULL, self.resolution]
]), ]),
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=True, normalize_tanh = True ), sample_process_options=SampleProcessor.Options(random_flip=True, normalize_tanh = True ),
output_sample_types=[ [f.TRANSFORMED | f_type | f.MODE_BGR, self.resolution] output_sample_types=[ [f.TRANSFORMED | f_type | f.MODE_BGR_SHUFFLE, self.resolution]
]) ])
]) ])