diff --git a/facelib/TernausNet.py b/facelib/TernausNet.py index 00a807e..1efab9e 100644 --- a/facelib/TernausNet.py +++ b/facelib/TernausNet.py @@ -19,7 +19,7 @@ TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentat class TernausNet(object): VERSION = 1 - def __init__ (self, name, resolution, face_type_str, load_weights=True, weights_file_root=None, training=False, place_model_on_cpu=False): + def __init__ (self, name, resolution, face_type_str=None, load_weights=True, weights_file_root=None, training=False, place_model_on_cpu=False): nn.initialize(data_format="NHWC") tf = nn.tf @@ -125,12 +125,16 @@ class TernausNet(object): self.net = Ternaus(3, 64, name='Ternaus') self.net_weights = self.net.get_weights() - self.model_filename_list = [ [self.net, '%s_%d_%s.npy' % (name, resolution, face_type_str) ] ] + model_name = f'{name}_{resolution}' + if face_type_str is not None: + model_name += f'_{face_type_str}' + + self.model_filename_list = [ [self.net, f'{model_name}.npy'] ] if training: self.opt = nn.TFRMSpropOptimizer(lr=0.0001, name='opt') self.opt.initialize_variables (self.net_weights, vars_on_cpu=place_model_on_cpu) - self.model_filename_list += [ [self.opt, '%s_%d_%s_opt.npy' % (name, resolution, face_type_str) ] ] + self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ] else: _, pred = self.net([self.input_t]) def net_run(input_np):