upd fan segmentator

This commit is contained in:
iperov 2019-03-19 19:44:14 +04:00
commit 034ad3cce5
3 changed files with 27 additions and 37 deletions

View file

@ -33,7 +33,6 @@ class FANSegmentator(object):
return self.model.train_on_batch(inp, outp)
def extract_from_bgr (self, input_image):
#return np.clip ( self.model.predict(input_image), 0, 1.0 )
return np.clip ( (self.model.predict(input_image) + 1) / 2.0, 0, 1.0 )
@staticmethod
@ -44,8 +43,7 @@ class FANSegmentator(object):
x = FANSegmentator.EncFlow(ngf=ngf)(x)
x = FANSegmentator.DecFlow(ngf=ngf)(x)
model = Model(inp,x)
model.compile (loss='mse', optimizer=Padam(tf_cpu_mode=2) )
#model.compile (loss='mse', optimizer=Adam(tf_cpu_mode=2) )
model.compile (loss='mse', optimizer=Adam(tf_cpu_mode=2) )
return model
@staticmethod