diff --git a/nnlib/FUNIT.py b/nnlib/FUNIT.py index 37e7652..f6d72bc 100644 --- a/nnlib/FUNIT.py +++ b/nnlib/FUNIT.py @@ -51,12 +51,12 @@ class FUNIT(object): self.D_opt = RMSprop(lr=0.0001, decay=0.0001, tf_cpu_mode=2 if 'tensorflow' in nnlib.active_DeviceConfig.backend else 0) xa = Input(bgr_shape, name="xa") - la = Input(label_shape, dtype=np.int32, name="la") + la = Input(label_shape, dtype="int32", name="la") xb = Input(bgr_shape, name="xb") - lb = Input(label_shape, dtype=np.int32, name="lb") + lb = Input(label_shape, dtype="int32", name="lb") - s_xa_one = Input( (self.enc_class_model.outputs[0].shape[-1].value,), name="s_xa_input") + s_xa_one = Input( ( K.int_shape(self.enc_class_model.outputs[0])[-1],), name="s_xa_input") c_xa = self.enc_content(xa) @@ -268,7 +268,7 @@ class FUNIT(object): def func(inputs): x , class_code = inputs - nf = x.shape[-1].value + nf = K.int_shape(x)[-1] ### MLP block inside decoder mlp = class_code