fix reshape

This commit is contained in:
jh 2021-03-16 19:20:24 -07:00
commit a89cf63e9f

View file

@ -514,7 +514,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
else: else:
label = np.random.uniform(1-smoothing/2, 1.0) label = np.random.uniform(1-smoothing/2, 1.0)
labels.append(label) labels.append(label)
return tf.reshape(labels, tensor.shape) return tf.reshape(labels, (self.batch_size,) + tensor.shape[1:])
gpu_pred_src_src_d_ones = get_smooth_noisy_labels(1, gpu_pred_src_src_d, smoothing=0.2, noise=0.05) gpu_pred_src_src_d_ones = get_smooth_noisy_labels(1, gpu_pred_src_src_d, smoothing=0.2, noise=0.05)
gpu_pred_src_src_d_zeros = get_smooth_noisy_labels(0, gpu_pred_src_src_d, smoothing=0.2, noise=0.05) gpu_pred_src_src_d_zeros = get_smooth_noisy_labels(0, gpu_pred_src_src_d, smoothing=0.2, noise=0.05)