mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-22 14:24:40 -07:00
_
This commit is contained in:
parent
d4a72b49c6
commit
2db46007cd
2 changed files with 5 additions and 4 deletions
|
@ -49,7 +49,8 @@ class FANSegmentator(object):
|
||||||
io.log_err("Unable to load VGG11 pretrained weights from vgg11_enc_weights.npy")
|
io.log_err("Unable to load VGG11 pretrained weights from vgg11_enc_weights.npy")
|
||||||
|
|
||||||
if training:
|
if training:
|
||||||
self.model.compile(loss='mse', optimizer=Adam(tf_cpu_mode=2))
|
#self.model.compile(loss='mse', optimizer=Adam(tf_cpu_mode=2))
|
||||||
|
self.model.compile(loss='binary_crossentropy', optimizer=Adam(tf_cpu_mode=2), metrics=['accuracy'])
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self
|
return self
|
||||||
|
|
|
@ -56,9 +56,9 @@ class Model(ModelBase):
|
||||||
def onTrainOneIter(self, generators_samples, generators_list):
|
def onTrainOneIter(self, generators_samples, generators_list):
|
||||||
target_src, target_src_mask = generators_samples[0]
|
target_src, target_src_mask = generators_samples[0]
|
||||||
|
|
||||||
loss = self.fan_seg.train_on_batch( [target_src], [target_src_mask] )
|
loss,acc = self.fan_seg.train_on_batch( [target_src], [target_src_mask] )
|
||||||
|
|
||||||
return ( ('loss', loss), )
|
return ( ('loss', loss), ('acc',acc))
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def onGetPreview(self, sample):
|
def onGetPreview(self, sample):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue