temporary revert last fixes

This commit is contained in:
Colombo 2019-12-20 10:21:59 +04:00
parent dd1d5e8909
commit 068c7d0d55
5 changed files with 57 additions and 69 deletions

View file

@ -57,16 +57,16 @@ class TernausNet(object):
real_t = Input ( (resolution, resolution, 1) )
out_t = self.model(inp_t)
loss = K.mean(10*K.binary_crossentropy(real_t,out_t) )
loss = K.mean(10*K.binary_crossentropy(real_t,out_t), axis=[1,2,3] )
out_t_diff1 = out_t[:, 1:, :, :] - out_t[:, :-1, :, :]
out_t_diff2 = out_t[:, :, 1:, :] - out_t[:, :, :-1, :]
total_var_loss = K.mean( 0.1*K.abs(out_t_diff1), axis=[1, 2, 3] ) + K.mean( 0.1*K.abs(out_t_diff2), axis=[1, 2, 3] )
opt = Adam(lr=0.0001, beta_1=0.5, beta_2=0.999, tf_cpu_mode=2)
opt = RMSprop(lr=0.0001, lr_dropout=0.3, tf_cpu_mode=2)
self.train_func = K.function ( [inp_t, real_t], [K.mean(loss)], opt.get_updates( [loss], self.model.trainable_weights) )
self.train_func = K.function ( [inp_t, real_t], [K.mean(loss)], opt.get_updates( loss, self.model.trainable_weights) )
def __enter__(self):