This commit is contained in:
iperov 2019-04-15 10:03:44 +04:00
commit 2db46007cd
2 changed files with 5 additions and 4 deletions

View file

@ -49,8 +49,9 @@ class FANSegmentator(object):
io.log_err("Unable to load VGG11 pretrained weights from vgg11_enc_weights.npy")
if training:
self.model.compile(loss='mse', optimizer=Adam(tf_cpu_mode=2))
#self.model.compile(loss='mse', optimizer=Adam(tf_cpu_mode=2))
self.model.compile(loss='binary_crossentropy', optimizer=Adam(tf_cpu_mode=2), metrics=['accuracy'])
def __enter__(self):
return self

View file

@ -56,9 +56,9 @@ class Model(ModelBase):
def onTrainOneIter(self, generators_samples, generators_list):
target_src, target_src_mask = generators_samples[0]
loss = self.fan_seg.train_on_batch( [target_src], [target_src_mask] )
loss,acc = self.fan_seg.train_on_batch( [target_src], [target_src_mask] )
return ( ('loss', loss), )
return ( ('loss', loss), ('acc',acc))
#override
def onGetPreview(self, sample):