SAE collapse fix (#245)

* test

* _

* _

* upd dev_poseest

* SAE: finally collapses are fixed

* fix batch size help
This commit is contained in:
iperov 2019-04-24 09:38:26 +04:00 committed by GitHub
parent d62785ca5a
commit e1da9c56b4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 87 additions and 70 deletions

View file

@ -12,6 +12,7 @@ from nnlib import nnlib
"""
PoseEstimator estimates pitch, yaw, roll, from FAN aligned face.
trained on https://www.umdfaces.io
based on https://arxiv.org/pdf/1901.06778.pdf HYBRID COARSE-FINE CLASSIFICATION FOR HEAD POSE ESTIMATION
"""
class PoseEstimator(object):
@ -19,9 +20,12 @@ class PoseEstimator(object):
def __init__ (self, resolution, face_type_str, load_weights=True, weights_file_root=None, training=False):
exec( nnlib.import_all(), locals(), globals() )
self.class_num = 91
self.angles = [90, 45, 30, 10, 2]
self.alpha_cat_losses = [0.07,0.05,0.03,0.01,0.01]
self.class_nums = [ angle+1 for angle in self.angles ]
self.model = PoseEstimator.BuildModel(resolution, class_nums=self.class_nums)
self.model = PoseEstimator.BuildModel(resolution, class_num=self.class_num)
if weights_file_root is not None:
weights_file_root = Path(weights_file_root)
@ -33,42 +37,44 @@ class PoseEstimator(object):
if load_weights:
self.model.load_weights (str(self.weights_path))
idx_tensor = np.array([idx for idx in range(self.class_num)], dtype=K.floatx() )
idx_tensor = K.constant(idx_tensor)
inp_t, = self.model.inputs
pitch_bins_t, yaw_bins_t, roll_bins_t = self.model.outputs
bins_t = self.model.outputs
pitch_t, yaw_t, roll_t = K.sum ( pitch_bins_t * idx_tensor, 1), K.sum ( yaw_bins_t * idx_tensor, 1), K.sum ( roll_bins_t * idx_tensor, 1)
inp_pitch_bins_t = Input ( (self.class_num,) )
inp_pitch_t = Input ( (1,) )
inp_yaw_bins_t = Input ( (self.class_num,) )
inp_yaw_t = Input ( (1,) )
inp_roll_bins_t = Input ( (self.class_num,) )
inp_roll_t = Input ( (1,) )
alpha = 0.001
pitch_loss = K.categorical_crossentropy(inp_pitch_bins_t, pitch_bins_t) \
+ alpha * K.mean(K.square( inp_pitch_t - pitch_t), -1)
inp_bins_t = []
for class_num in self.class_nums:
inp_bins_t += [ Input ((class_num,)), Input ((class_num,)), Input ((class_num,)) ]
yaw_loss = K.categorical_crossentropy(inp_yaw_bins_t, yaw_bins_t) \
+ alpha * K.mean(K.square( inp_yaw_t - yaw_t), -1)
roll_loss = K.categorical_crossentropy(inp_roll_bins_t, roll_bins_t) \
+ alpha * K.mean(K.square( inp_roll_t - roll_t), -1)
loss_pitch = []
loss_yaw = []
loss_roll = []
for i,class_num in enumerate(self.class_nums):
a = self.alpha_cat_losses[i]
loss_pitch += [ a*K.categorical_crossentropy( inp_bins_t[i*3+0], bins_t[i*3+0] ) ]
loss_yaw += [ a*K.categorical_crossentropy( inp_bins_t[i*3+1], bins_t[i*3+1] ) ]
loss_roll += [ a*K.categorical_crossentropy( inp_bins_t[i*3+2], bins_t[i*3+2] ) ]
idx_tensor = K.constant( np.array([idx for idx in range(self.class_nums[0])], dtype=K.floatx() ) )
pitch_t, yaw_t, roll_t = K.sum ( bins_t[0] * idx_tensor, 1), K.sum ( bins_t[1] * idx_tensor, 1), K.sum ( bins_t[2] * idx_tensor, 1)
loss = K.mean( pitch_loss + yaw_loss + roll_loss )
reg_alpha = 0.02
reg_pitch_loss = reg_alpha * K.mean(K.square( inp_pitch_t - pitch_t), -1)
reg_yaw_loss = reg_alpha * K.mean(K.square( inp_yaw_t - yaw_t), -1)
reg_roll_loss = reg_alpha * K.mean(K.square( inp_roll_t - roll_t), -1)
pitch_loss = reg_pitch_loss + sum(loss_pitch)
yaw_loss = reg_yaw_loss + sum(loss_yaw)
roll_loss = reg_roll_loss + sum(loss_roll)
opt = Adam(lr=0.001, tf_cpu_mode=2)
opt = Adam(lr=0.001, tf_cpu_mode=0)
if training:
self.train = K.function ([inp_t, inp_pitch_bins_t, inp_pitch_t, inp_yaw_bins_t, inp_yaw_t, inp_roll_bins_t, inp_roll_t],
[loss], opt.get_updates(loss, self.model.trainable_weights) )
self.train = K.function ([inp_t, inp_pitch_t, inp_yaw_t, inp_roll_t] + inp_bins_t,
[K.mean(pitch_loss),K.mean(yaw_loss),K.mean(roll_loss)], opt.get_updates( [pitch_loss,yaw_loss,roll_loss], self.model.trainable_weights) )
self.view = K.function ([inp_t], [pitch_t, yaw_t, roll_t] )
@ -82,19 +88,27 @@ class PoseEstimator(object):
self.model.save_weights (str(self.weights_path))
def train_on_batch(self, imgs, pitch_yaw_roll):
c = ( (pitch_yaw_roll+1) * 45.0 ).astype(np.int).astype(K.floatx())
pyr = pitch_yaw_roll+1
feed = [imgs]
inp_pitch = c[:,0:1]
inp_yaw = c[:,1:2]
inp_roll = c[:,2:3]
for i, (angle, class_num) in enumerate(zip(self.angles, self.class_nums)):
c = np.round(pyr * (angle / 2) ).astype(K.floatx())
inp_pitch = c[:,0:1]
inp_yaw = c[:,1:2]
inp_roll = c[:,2:3]
if i == 0:
feed += [inp_pitch, inp_yaw, inp_roll]
inp_pitch_bins = keras.utils.to_categorical(inp_pitch, class_num )
inp_yaw_bins = keras.utils.to_categorical(inp_yaw, class_num )
inp_roll_bins = keras.utils.to_categorical(inp_roll, class_num )
feed += [inp_pitch_bins, inp_yaw_bins, inp_roll_bins]
#import code
#code.interact(local=dict(globals(), **locals()))
inp_pitch_bins = keras.utils.to_categorical(inp_pitch, self.class_num )
inp_yaw_bins = keras.utils.to_categorical(inp_yaw, self.class_num )
inp_roll_bins = keras.utils.to_categorical(inp_roll, self.class_num )
loss, = self.train( [imgs, inp_pitch_bins, inp_pitch, inp_yaw_bins, inp_yaw, inp_roll_bins, inp_roll] )
return loss
pitch_loss,yaw_loss,roll_loss = self.train(feed)
return pitch_loss,yaw_loss,roll_loss
def extract (self, input_image, is_input_tanh=False):
if is_input_tanh:
@ -106,7 +120,7 @@ class PoseEstimator(object):
pitch, yaw, roll = self.view( [input_image] )
result = np.concatenate( (pitch[...,np.newaxis], yaw[...,np.newaxis], roll[...,np.newaxis]), -1 )
result = np.clip ( result / 45.0 - 1, -1.0, 1.0 )
result = np.clip ( result / (self.angles[0] / 2) - 1, -1.0, 1.0 )
if input_shape_len == 3:
result = result[0]
@ -114,28 +128,31 @@ class PoseEstimator(object):
return result
@staticmethod
def BuildModel ( resolution, class_num):
def BuildModel ( resolution, class_nums):
exec( nnlib.import_all(), locals(), globals() )
inp = Input ( (resolution,resolution,3) )
x = inp
x = PoseEstimator.Flow(class_num=class_num)(x)
x = PoseEstimator.Flow(class_nums=class_nums)(x)
model = Model(inp,x)
return model
@staticmethod
def Flow(class_num):
def Flow(class_nums):
exec( nnlib.import_all(), locals(), globals() )
def func(input):
x = input
# resnet50 = keras.applications.ResNet50(include_top=False, weights='imagenet', input_shape=K.int_shape(x)[1:], pooling='avg')
# resnet50 = keras.applications.ResNet50(include_top=False, weights=None, input_shape=K.int_shape(x)[1:], pooling='avg')
# x = resnet50(x)
# pitch = Dense(class_num, activation='softmax', name='pitch')(x)
# yaw = Dense(class_num, activation='softmax', name='yaw')(x)
# roll = Dense(class_num, activation='softmax', name='roll')(x)
# return [pitch, yaw, roll]
# output = []
# for class_num in class_nums:
# pitch = Dense(class_num, activation='softmax')(x)
# yaw = Dense(class_num, activation='softmax')(x)
# roll = Dense(class_num, activation='softmax')(x)
# output += [pitch,yaw,roll]
# return output
x = Conv2D(64, kernel_size=11, strides=4, padding='same', activation='relu')(x)
x = MaxPooling2D( (3,3), strides=2 )(x)
@ -153,10 +170,13 @@ class PoseEstimator(object):
x = Dropout(0.5)(x)
x = Dense(1024, activation='relu')(x)
pitch = Dense(class_num, activation='softmax', name='pitch')(x)
yaw = Dense(class_num, activation='softmax', name='yaw')(x)
roll = Dense(class_num, activation='softmax', name='roll')(x)
return [pitch, yaw, roll]
output = []
for class_num in class_nums:
pitch = Dense(class_num, activation='softmax')(x)
yaw = Dense(class_num, activation='softmax')(x)
roll = Dense(class_num, activation='softmax')(x)
output += [pitch,yaw,roll]
return output
return func