mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-16 10:03:41 -07:00
_
This commit is contained in:
parent
47f9bad42b
commit
268b402513
2 changed files with 3 additions and 3 deletions
|
@ -50,7 +50,7 @@ class FANSegmentator(object):
|
||||||
|
|
||||||
if training:
|
if training:
|
||||||
#self.model.compile(loss='mse', optimizer=Adam(tf_cpu_mode=2))
|
#self.model.compile(loss='mse', optimizer=Adam(tf_cpu_mode=2))
|
||||||
self.model.compile(loss='binary_crossentropy', optimizer=Adam(tf_cpu_mode=2), metrics=['accuracy'])
|
self.model.compile(loss='binary_crossentropy', optimizer=Adam(tf_cpu_mode=2) )
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self
|
return self
|
||||||
|
|
|
@ -65,9 +65,9 @@ class Model(ModelBase):
|
||||||
def onTrainOneIter(self, generators_samples, generators_list):
|
def onTrainOneIter(self, generators_samples, generators_list):
|
||||||
target_src, target_src_mask = generators_samples[0]
|
target_src, target_src_mask = generators_samples[0]
|
||||||
|
|
||||||
loss,acc = self.fan_seg.train_on_batch( [target_src], [target_src_mask] )
|
loss = self.fan_seg.train_on_batch( [target_src], [target_src_mask] )
|
||||||
|
|
||||||
return ( ('loss', loss), ('acc',acc))
|
return ( ('loss', loss), )
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def onGetPreview(self, sample):
|
def onGetPreview(self, sample):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue