This commit is contained in:
Colombo 2020-03-09 00:18:01 +04:00
parent eda6433936
commit d731930537
2 changed files with 6 additions and 7 deletions

View file

@ -21,7 +21,7 @@ def scantree(path):
else: else:
yield entry yield entry
def get_image_paths(dir_path, image_extensions=image_extensions, subdirs=False): def get_image_paths(dir_path, image_extensions=image_extensions, subdirs=False, return_Path_class=False):
dir_path = Path (dir_path) dir_path = Path (dir_path)
result = [] result = []
@ -34,7 +34,7 @@ def get_image_paths(dir_path, image_extensions=image_extensions, subdirs=False):
for x in list(gen): for x in list(gen):
if any([x.name.lower().endswith(ext) for ext in image_extensions]): if any([x.name.lower().endswith(ext) for ext in image_extensions]):
result.append(x.path) result.append( x.path if not return_Path_class else Path(x.path) )
return sorted(result) return sorted(result)
def get_image_unique_filestem_paths(dir_path, verbose_print_func=None): def get_image_unique_filestem_paths(dir_path, verbose_print_func=None):

View file

@ -25,7 +25,7 @@ class FANSegModel(ModelBase):
if self.is_first_run() or ask_override: if self.is_first_run() or ask_override:
self.ask_autobackup_hour() self.ask_autobackup_hour()
self.ask_target_iter() self.ask_target_iter()
self.ask_batch_size(4) self.ask_batch_size(24)
#if self.is_first_run(): #if self.is_first_run():
#resolution = io.input_int("Resolution", default_resolution, add_info="64-512") #resolution = io.input_int("Resolution", default_resolution, add_info="64-512")
@ -56,7 +56,7 @@ class FANSegModel(ModelBase):
mask_shape = nn.get4Dshape(resolution,resolution,1) mask_shape = nn.get4Dshape(resolution,resolution,1)
# Initializing model classes # Initializing model classes
self.model = TernausNet('FANSeg', self.model = TernausNet(f'{self.model_name}_FANSeg',
resolution, resolution,
FaceType.toString(self.face_type), FaceType.toString(self.face_type),
load_weights=not self.is_first_run(), load_weights=not self.is_first_run(),
@ -104,7 +104,6 @@ class FANSegModel(ModelBase):
loss_gv_op = self.model.opt.get_update_op (nn.tf_average_gv_list (gpu_loss_gvs)) loss_gv_op = self.model.opt.get_update_op (nn.tf_average_gv_list (gpu_loss_gvs))
# Initializing training and view functions # Initializing training and view functions
def train(input_np, target_np): def train(input_np, target_np):
l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np }) l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np })
@ -126,7 +125,7 @@ class FANSegModel(ModelBase):
src_generator = SampleGeneratorFace(training_data_src_path, random_ct_samples_path=training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(), src_generator = SampleGeneratorFace(training_data_src_path, random_ct_samples_path=training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=True), sample_process_options=SampleProcessor.Options(random_flip=True),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'ct_mode':'idt', 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'motion_blur':(25, 5), 'gaussian_blur':(25,5), 'data_format':nn.data_format, 'resolution': resolution}, output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'ct_mode':'lct', 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'motion_blur':(25, 5), 'gaussian_blur':(25,5), 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
], ],
generators_count=src_generators_count ) generators_count=src_generators_count )