This commit is contained in:
Colombo 2020-01-26 12:56:21 +04:00
parent 76ca79216e
commit c485e1718a
4 changed files with 10 additions and 22 deletions

View file

@ -450,9 +450,10 @@ class SAEHDModel(ModelBase):
for gpu_id in range(gpu_count):
with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
with tf.device(f'/CPU:0'):
# slice on CPU, otherwise all batch data will be transfered to GPU first
batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
gpu_warped_src = self.warped_src [batch_slice,:,:,:]
gpu_warped_dst = self.warped_dst [batch_slice,:,:,:]
gpu_target_src = self.target_src [batch_slice,:,:,:]
@ -646,7 +647,6 @@ class SAEHDModel(ModelBase):
model.init_weights()
# initializing sample generators
if self.is_training:
t = SampleProcessor.Types
if self.options['face_type'] == 'h':
@ -710,12 +710,8 @@ class SAEHDModel(ModelBase):
#override
def onGetPreview(self, samples):
n_samples = min(4, self.get_batch_size() )
( (warped_src, target_src, target_srcm),
(warped_dst, target_dst, target_dstm) ) = \
[ [sample[0:n_samples] for sample in sample_list ]
for sample_list in samples ]
(warped_dst, target_dst, target_dstm) ) = samples
if self.options['learn_mask']:
S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ]
@ -725,6 +721,7 @@ class SAEHDModel(ModelBase):
target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )]
n_samples = min(4, self.get_batch_size() )
result = []
st = []
for i in range(n_samples):