mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-22 06:23:20 -07:00
_
This commit is contained in:
parent
d4a72b49c6
commit
2db46007cd
2 changed files with 5 additions and 4 deletions
|
@ -49,8 +49,9 @@ class FANSegmentator(object):
|
|||
io.log_err("Unable to load VGG11 pretrained weights from vgg11_enc_weights.npy")
|
||||
|
||||
if training:
|
||||
self.model.compile(loss='mse', optimizer=Adam(tf_cpu_mode=2))
|
||||
|
||||
#self.model.compile(loss='mse', optimizer=Adam(tf_cpu_mode=2))
|
||||
self.model.compile(loss='binary_crossentropy', optimizer=Adam(tf_cpu_mode=2), metrics=['accuracy'])
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
|
|
|
@ -56,9 +56,9 @@ class Model(ModelBase):
|
|||
def onTrainOneIter(self, generators_samples, generators_list):
|
||||
target_src, target_src_mask = generators_samples[0]
|
||||
|
||||
loss = self.fan_seg.train_on_batch( [target_src], [target_src_mask] )
|
||||
loss,acc = self.fan_seg.train_on_batch( [target_src], [target_src_mask] )
|
||||
|
||||
return ( ('loss', loss), )
|
||||
return ( ('loss', loss), ('acc',acc))
|
||||
|
||||
#override
|
||||
def onGetPreview(self, sample):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue