diff --git a/facelib/FANSegmentator.py b/facelib/FANSegmentator.py index 4bcfe3c..47a4fa1 100644 --- a/facelib/FANSegmentator.py +++ b/facelib/FANSegmentator.py @@ -49,8 +49,9 @@ class FANSegmentator(object): io.log_err("Unable to load VGG11 pretrained weights from vgg11_enc_weights.npy") if training: - self.model.compile(loss='mse', optimizer=Adam(tf_cpu_mode=2)) - + #self.model.compile(loss='mse', optimizer=Adam(tf_cpu_mode=2)) + self.model.compile(loss='binary_crossentropy', optimizer=Adam(tf_cpu_mode=2), metrics=['accuracy']) + def __enter__(self): return self diff --git a/models/Model_FANSEG/Model.py b/models/Model_FANSEG/Model.py index 0153bab..7f92161 100644 --- a/models/Model_FANSEG/Model.py +++ b/models/Model_FANSEG/Model.py @@ -56,9 +56,9 @@ class Model(ModelBase): def onTrainOneIter(self, generators_samples, generators_list): target_src, target_src_mask = generators_samples[0] - loss = self.fan_seg.train_on_batch( [target_src], [target_src_mask] ) + loss,acc = self.fan_seg.train_on_batch( [target_src], [target_src_mask] ) - return ( ('loss', loss), ) + return ( ('loss', loss), ('acc',acc)) #override def onGetPreview(self, sample):