mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 05:22:06 -07:00
temporary revert last fixes
This commit is contained in:
parent
dd1d5e8909
commit
068c7d0d55
5 changed files with 57 additions and 69 deletions
|
@ -57,16 +57,16 @@ class TernausNet(object):
|
|||
real_t = Input ( (resolution, resolution, 1) )
|
||||
out_t = self.model(inp_t)
|
||||
|
||||
loss = K.mean(10*K.binary_crossentropy(real_t,out_t) )
|
||||
loss = K.mean(10*K.binary_crossentropy(real_t,out_t), axis=[1,2,3] )
|
||||
|
||||
out_t_diff1 = out_t[:, 1:, :, :] - out_t[:, :-1, :, :]
|
||||
out_t_diff2 = out_t[:, :, 1:, :] - out_t[:, :, :-1, :]
|
||||
|
||||
total_var_loss = K.mean( 0.1*K.abs(out_t_diff1), axis=[1, 2, 3] ) + K.mean( 0.1*K.abs(out_t_diff2), axis=[1, 2, 3] )
|
||||
|
||||
opt = Adam(lr=0.0001, beta_1=0.5, beta_2=0.999, tf_cpu_mode=2)
|
||||
opt = RMSprop(lr=0.0001, lr_dropout=0.3, tf_cpu_mode=2)
|
||||
|
||||
self.train_func = K.function ( [inp_t, real_t], [K.mean(loss)], opt.get_updates( [loss], self.model.trainable_weights) )
|
||||
self.train_func = K.function ( [inp_t, real_t], [K.mean(loss)], opt.get_updates( loss, self.model.trainable_weights) )
|
||||
|
||||
|
||||
def __enter__(self):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue