mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-19 13:09:56 -07:00
SAE: revert back CA weights option
This commit is contained in:
parent
0cd8dd7296
commit
9535a657d2
6 changed files with 8 additions and 5 deletions
|
@ -23,7 +23,6 @@ class PoseEstimator(object):
|
|||
self.angles = [90, 45, 30, 10, 2]
|
||||
self.alpha_cat_losses = [7,5,3,1,1]
|
||||
self.class_nums = [ angle+1 for angle in self.angles ]
|
||||
|
||||
self.model = PoseEstimator.BuildModel(resolution, class_nums=self.class_nums)
|
||||
|
||||
|
||||
|
@ -57,7 +56,7 @@ class PoseEstimator(object):
|
|||
loss_pitch = []
|
||||
loss_yaw = []
|
||||
loss_roll = []
|
||||
|
||||
|
||||
for i,class_num in enumerate(self.class_nums):
|
||||
a = self.alpha_cat_losses[i]
|
||||
loss_pitch += [ a*K.categorical_crossentropy( inp_bins_t[i*3+0], bins_t[i*3+0] ) ]
|
||||
|
@ -65,7 +64,12 @@ class PoseEstimator(object):
|
|||
loss_roll += [ a*K.categorical_crossentropy( inp_bins_t[i*3+2], bins_t[i*3+2] ) ]
|
||||
|
||||
idx_tensor = K.constant( np.array([idx for idx in range(self.class_nums[0])], dtype=K.floatx() ) )
|
||||
pitch_t, yaw_t, roll_t = K.sum ( bins_t[0] * idx_tensor, 1), K.sum ( bins_t[1] * idx_tensor, 1), K.sum ( bins_t[2] * idx_tensor, 1)
|
||||
#pitch_t, yaw_t, roll_t = K.sum ( bins_t[0] * idx_tensor, 1), K.sum ( bins_t[1] * idx_tensor, 1), K.sum ( bins_t[2] * idx_tensor, 1)
|
||||
|
||||
pitch_t, yaw_t, roll_t = nnlib.tf.reduce_sum ( bins_t[0] * idx_tensor, 1), nnlib.tf.reduce_sum ( bins_t[1] * idx_tensor, 1), nnlib.tf.reduce_sum ( bins_t[2] * idx_tensor, 1)
|
||||
|
||||
|
||||
|
||||
|
||||
reg_alpha = 2
|
||||
reg_pitch_loss = reg_alpha * K.mean(K.square( inp_pitch_t - pitch_t), -1)
|
||||
|
@ -75,7 +79,6 @@ class PoseEstimator(object):
|
|||
pitch_loss = reg_pitch_loss + sum(loss_pitch)
|
||||
yaw_loss = reg_yaw_loss + sum(loss_yaw)
|
||||
roll_loss = reg_roll_loss + sum(loss_roll)
|
||||
|
||||
opt = Adam(lr=0.000001)
|
||||
|
||||
if training:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue