mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-16 10:03:41 -07:00
SAE: added new archi 'vg'
This commit is contained in:
parent
d66829aae4
commit
f0a20b46d3
5 changed files with 378 additions and 119 deletions
|
@ -12,7 +12,7 @@ from utils import image_utils
|
|||
import cv2
|
||||
import models
|
||||
|
||||
def trainerThread (input_queue, output_queue, training_data_src_dir, training_data_dst_dir, model_path, model_name, save_interval_min=10, debug=False, **in_options):
|
||||
def trainerThread (input_queue, output_queue, training_data_src_dir, training_data_dst_dir, model_path, model_name, save_interval_min=15, debug=False, **in_options):
|
||||
|
||||
while True:
|
||||
try:
|
||||
|
@ -39,10 +39,11 @@ def trainerThread (input_queue, output_queue, training_data_src_dir, training_da
|
|||
**in_options)
|
||||
|
||||
is_reached_goal = model.is_reached_epoch_goal()
|
||||
|
||||
is_upd_save_time_after_train = False
|
||||
def model_save():
|
||||
if not debug and not is_reached_goal:
|
||||
model.save()
|
||||
is_upd_save_time_after_train = True
|
||||
|
||||
def send_preview():
|
||||
if not debug:
|
||||
|
@ -65,11 +66,15 @@ def trainerThread (input_queue, output_queue, training_data_src_dir, training_da
|
|||
print('Starting. Press "Enter" to stop training and save model.')
|
||||
|
||||
last_save_time = time.time()
|
||||
|
||||
for i in itertools.count(0,1):
|
||||
if not debug:
|
||||
if not is_reached_goal:
|
||||
loss_string = model.train_one_epoch()
|
||||
|
||||
if is_upd_save_time_after_train:
|
||||
#save resets plaidML programs, so upd last_save_time only after plaidML rebuild them
|
||||
last_save_time = time.time()
|
||||
|
||||
print (loss_string, end='\r')
|
||||
if model.get_target_epoch() != 0 and model.is_reached_epoch_goal():
|
||||
print ('Reached target epoch.')
|
||||
|
@ -78,7 +83,7 @@ def trainerThread (input_queue, output_queue, training_data_src_dir, training_da
|
|||
print ('You can use preview now.')
|
||||
|
||||
if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60:
|
||||
last_save_time = time.time()
|
||||
last_save_time = time.time()
|
||||
model_save()
|
||||
send_preview()
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue