From 2db46007cd0ba1b88e24567281d90f98f36d2df1 Mon Sep 17 00:00:00 2001 From: iperov Date: Mon, 15 Apr 2019 10:03:44 +0400 Subject: [PATCH] _ --- facelib/FANSegmentator.py | 5 +++-- models/Model_FANSEG/Model.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/facelib/FANSegmentator.py b/facelib/FANSegmentator.py index 4bcfe3c..47a4fa1 100644 --- a/facelib/FANSegmentator.py +++ b/facelib/FANSegmentator.py @@ -49,8 +49,9 @@ class FANSegmentator(object): io.log_err("Unable to load VGG11 pretrained weights from vgg11_enc_weights.npy") if training: - self.model.compile(loss='mse', optimizer=Adam(tf_cpu_mode=2)) - + #self.model.compile(loss='mse', optimizer=Adam(tf_cpu_mode=2)) + self.model.compile(loss='binary_crossentropy', optimizer=Adam(tf_cpu_mode=2), metrics=['accuracy']) + def __enter__(self): return self diff --git a/models/Model_FANSEG/Model.py b/models/Model_FANSEG/Model.py index 0153bab..7f92161 100644 --- a/models/Model_FANSEG/Model.py +++ b/models/Model_FANSEG/Model.py @@ -56,9 +56,9 @@ class Model(ModelBase): def onTrainOneIter(self, generators_samples, generators_list): target_src, target_src_mask = generators_samples[0] - loss = self.fan_seg.train_on_batch( [target_src], [target_src_mask] ) + loss,acc = self.fan_seg.train_on_batch( [target_src], [target_src_mask] ) - return ( ('loss', loss), ) + return ( ('loss', loss), ('acc',acc)) #override def onGetPreview(self, sample):