diff --git a/models/Model_AMP/Model.py b/models/Model_AMP/Model.py index 0fbcddb..e6a0c96 100644 --- a/models/Model_AMP/Model.py +++ b/models/Model_AMP/Model.py @@ -358,10 +358,11 @@ class AMPModel(ModelBase): gpu_dst_inter_src_code, gpu_dst_inter_dst_code = self.inter_src (gpu_dst_code), self.inter_dst (gpu_dst_code) inter_dims_bin = int(inter_dims*morph_factor) - inter_rnd_binomial = tf.stack([tf.concat([tf.tile(tf.constant([1], tf.float32), ( inter_dims_bin, )), - tf.tile(tf.constant([0], tf.float32), ( inter_dims-inter_dims_bin, ))], 0 ) for _ in range(bs_per_gpu)], 0) - - inter_rnd_binomial = tf.stop_gradient(inter_rnd_binomial[...,None,None]) + with tf.device(f'/CPU:0'): + inter_rnd_binomial = tf.stack([tf.random.shuffle(tf.concat([tf.tile(tf.constant([1], tf.float32), ( inter_dims_bin, )), + tf.tile(tf.constant([0], tf.float32), ( inter_dims-inter_dims_bin, ))], 0 )) for _ in range(bs_per_gpu)], 0) + + inter_rnd_binomial = tf.stop_gradient(inter_rnd_binomial[...,None,None]) gpu_src_code = gpu_src_inter_src_code * inter_rnd_binomial + gpu_src_inter_dst_code * (1-inter_rnd_binomial) gpu_dst_code = gpu_dst_inter_dst_code