mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-30 11:40:36 -07:00
nothing interesting
This commit is contained in:
parent
0e088f6415
commit
f89200d236
2 changed files with 38 additions and 26 deletions
|
@ -76,7 +76,7 @@ class PoseEstimator(object):
|
|||
|
||||
CAInitializerMP ( gather_Conv2D_layers( [self.encoder, self.decoder] ) )
|
||||
|
||||
idx_tensor = K.constant( np.array([idx for idx in range(self.class_nums[0])], dtype=K.floatx() ) )
|
||||
idx_tensor = self.idx_tensor = K.constant( np.array([idx for idx in range(self.class_nums[0])], dtype=K.floatx() ) )
|
||||
pitch_t, yaw_t, roll_t = K.sum ( bins_t[0] * idx_tensor, 1), K.sum (bins_t[1] * idx_tensor, 1), K.sum ( bins_t[2] * idx_tensor, 1)
|
||||
|
||||
if training:
|
||||
|
@ -110,11 +110,17 @@ class PoseEstimator(object):
|
|||
[bgr_loss], Adam(lr=2e-4, beta_1=0.5, beta_2=0.999).get_updates( bgr_loss, self.encoder.trainable_weights+self.decoder.trainable_weights ) )
|
||||
|
||||
self.train_l = K.function ([inp_t, inp_pitch_t, inp_yaw_t, inp_roll_t] + inp_bins_t,
|
||||
[K.mean(pitch_loss),K.mean(yaw_loss),K.mean(roll_loss)], Adam(lr=0.000001).get_updates( [pitch_loss,yaw_loss,roll_loss], self.model_l.trainable_weights) )
|
||||
[K.mean(pitch_loss),K.mean(yaw_loss),K.mean(roll_loss)], Adam(lr=0.0001).get_updates( [pitch_loss,yaw_loss,roll_loss], self.model_l.trainable_weights) )
|
||||
|
||||
|
||||
self.view = K.function ([inp_t], [pitch_t, yaw_t, roll_t] )
|
||||
|
||||
|
||||
def flow(self, x):
|
||||
bins_t = self.model(x)
|
||||
return bins_t[0], bins_t[1], bins_t[2]
|
||||
pitch_t, yaw_t, roll_t = K.sum ( bins_t[0] * self.idx_tensor, 1), K.sum (bins_t[1] * self.idx_tensor, 1), K.sum ( bins_t[2] * self.idx_tensor, 1)
|
||||
return pitch_t, yaw_t, roll_t
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue