SAE: added new archi 'vg'

This commit is contained in:
iperov 2019-02-21 17:53:59 +04:00
parent d66829aae4
commit f0a20b46d3
5 changed files with 378 additions and 119 deletions

View file

@ -12,7 +12,7 @@ from utils import image_utils
import cv2
import models
def trainerThread (input_queue, output_queue, training_data_src_dir, training_data_dst_dir, model_path, model_name, save_interval_min=10, debug=False, **in_options):
def trainerThread (input_queue, output_queue, training_data_src_dir, training_data_dst_dir, model_path, model_name, save_interval_min=15, debug=False, **in_options):
while True:
try:
@ -39,10 +39,11 @@ def trainerThread (input_queue, output_queue, training_data_src_dir, training_da
**in_options)
is_reached_goal = model.is_reached_epoch_goal()
is_upd_save_time_after_train = False
def model_save():
if not debug and not is_reached_goal:
model.save()
is_upd_save_time_after_train = True
def send_preview():
if not debug:
@ -65,11 +66,15 @@ def trainerThread (input_queue, output_queue, training_data_src_dir, training_da
print('Starting. Press "Enter" to stop training and save model.')
last_save_time = time.time()
for i in itertools.count(0,1):
if not debug:
if not is_reached_goal:
loss_string = model.train_one_epoch()
if is_upd_save_time_after_train:
#save resets plaidML programs, so upd last_save_time only after plaidML rebuild them
last_save_time = time.time()
print (loss_string, end='\r')
if model.get_target_epoch() != 0 and model.is_reached_epoch_goal():
print ('Reached target epoch.')
@ -78,7 +83,7 @@ def trainerThread (input_queue, output_queue, training_data_src_dir, training_da
print ('You can use preview now.')
if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60:
last_save_time = time.time()
last_save_time = time.time()
model_save()
send_preview()