This commit is contained in:
Colombo 2019-09-19 17:29:12 +04:00
parent a325d0353b
commit ef90316f27

View file

@ -51,12 +51,12 @@ class FUNIT(object):
self.D_opt = RMSprop(lr=0.0001, decay=0.0001, tf_cpu_mode=2 if 'tensorflow' in nnlib.active_DeviceConfig.backend else 0)
xa = Input(bgr_shape, name="xa")
la = Input(label_shape, dtype=np.int32, name="la")
la = Input(label_shape, dtype="int32", name="la")
xb = Input(bgr_shape, name="xb")
lb = Input(label_shape, dtype=np.int32, name="lb")
lb = Input(label_shape, dtype="int32", name="lb")
s_xa_one = Input( (self.enc_class_model.outputs[0].shape[-1].value,), name="s_xa_input")
s_xa_one = Input( ( K.int_shape(self.enc_class_model.outputs[0])[-1],), name="s_xa_input")
c_xa = self.enc_content(xa)
@ -268,7 +268,7 @@ class FUNIT(object):
def func(inputs):
x , class_code = inputs
nf = x.shape[-1].value
nf = K.int_shape(x)[-1]
### MLP block inside decoder
mlp = class_code