mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-22 06:23:20 -07:00
Move ping pong logic into generator class
This commit is contained in:
parent
f2ed44f3c4
commit
1ecc6f1b62
5 changed files with 122 additions and 78 deletions
|
@ -14,6 +14,7 @@ import imagelib
|
|||
from interact import interact as io
|
||||
from nnlib import nnlib
|
||||
from samplelib import SampleGeneratorBase
|
||||
from samplelib.SampleGeneratorPingPong import PingPongOptions, Paddle
|
||||
from utils import Path_utils, std_utils
|
||||
from utils.cv2_utils import *
|
||||
|
||||
|
@ -67,7 +68,7 @@ class ModelBase(object):
|
|||
self.debug = debug
|
||||
self.is_training_mode = (training_data_src_path is not None and training_data_dst_path is not None)
|
||||
|
||||
self.paddle = 'pong'
|
||||
self.paddle = Paddle.PONG
|
||||
|
||||
self.iter = 0
|
||||
self.options = {}
|
||||
|
@ -210,15 +211,18 @@ class ModelBase(object):
|
|||
|
||||
self.onInitializeOptions(self.iter == 0, ask_override)
|
||||
|
||||
self.ping_pong_options = PingPongOptions(enabled=self.options['ping_pong'],
|
||||
iterations=self.ping_pong_iter,
|
||||
model_iter=self.iter,
|
||||
paddle=self.paddle,
|
||||
batch_cap=self.batch_size)
|
||||
|
||||
nnlib.import_all(self.device_config)
|
||||
self.keras = nnlib.keras
|
||||
self.K = nnlib.keras.backend
|
||||
|
||||
self.onInitialize()
|
||||
|
||||
self.options['batch_size'] = self.batch_size
|
||||
self.paddle = self.options.get('paddle', 'ping')
|
||||
|
||||
if self.debug or self.batch_size == 0:
|
||||
self.batch_size = 1
|
||||
|
||||
|
@ -413,7 +417,7 @@ class ModelBase(object):
|
|||
|
||||
def save(self):
|
||||
self.options['batch_size'] = self.batch_size
|
||||
self.options['paddle'] = self.paddle
|
||||
self.options['paddle'] = self.ping_pong_options.paddle
|
||||
summary_path = self.get_strpath_storage_for_file('summary.txt')
|
||||
Path( summary_path ).write_text(self.model_summary_text)
|
||||
self.onSave()
|
||||
|
@ -542,12 +546,11 @@ class ModelBase(object):
|
|||
return [ generator.generate_next() for generator in self.generator_list]
|
||||
|
||||
def train_one_iter(self):
|
||||
|
||||
if self.iter == 1 and self.options.get('ping_pong', False):
|
||||
self.set_batch_size(1)
|
||||
self.paddle = 'ping'
|
||||
elif not self.options.get('ping_pong', False) and self.batch_cap != self.batch_size:
|
||||
self.set_batch_size(self.batch_cap)
|
||||
# if self.iter == 1 and self.options.get('ping_pong', False):
|
||||
# self.set_batch_size(1)
|
||||
# self.paddle = 'ping'
|
||||
# elif not self.options.get('ping_pong', False) and self.batch_cap != self.batch_size:
|
||||
# self.set_batch_size(self.batch_cap)
|
||||
sample = self.generate_next_sample()
|
||||
iter_time = time.time()
|
||||
losses = self.onTrainOneIter(sample, self.generator_list)
|
||||
|
@ -574,21 +577,6 @@ class ModelBase(object):
|
|||
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
|
||||
cv2_imwrite (filepath, img )
|
||||
|
||||
if self.iter % self.ping_pong_iter == 0 and self.iter != 0 and self.options.get('ping_pong', False):
|
||||
if self.batch_size == self.batch_cap:
|
||||
self.paddle = 'pong'
|
||||
if self.batch_size > self.batch_cap:
|
||||
self.set_batch_size(self.batch_cap)
|
||||
self.paddle = 'pong'
|
||||
if self.batch_size == 1:
|
||||
self.paddle = 'ping'
|
||||
if self.paddle == 'ping':
|
||||
self.save()
|
||||
self.set_batch_size(self.batch_size + 1)
|
||||
else:
|
||||
self.save()
|
||||
self.set_batch_size(self.batch_size - 1)
|
||||
|
||||
self.iter += 1
|
||||
|
||||
return self.iter, iter_time, self.batch_size
|
||||
|
|
|
@ -487,7 +487,8 @@ class SAEModel(ModelBase):
|
|||
'apply_ct': apply_random_ct} for i in range(ms_count)] + \
|
||||
[{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
|
||||
'resolution': resolution // (2 ** i)} for i in
|
||||
range(ms_count)]
|
||||
range(ms_count)],
|
||||
ping_pong=self.ping_pong_options,
|
||||
),
|
||||
|
||||
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
|
@ -500,7 +501,8 @@ class SAEModel(ModelBase):
|
|||
range(ms_count)] + \
|
||||
[{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
|
||||
'resolution': resolution // (2 ** i)} for i in
|
||||
range(ms_count)])
|
||||
range(ms_count)],
|
||||
ping_pong=self.ping_pong_options,)
|
||||
])
|
||||
|
||||
# override
|
||||
|
@ -557,7 +559,8 @@ class SAEModel(ModelBase):
|
|||
'apply_ct': apply_random_ct} for i in range(ms_count)] + \
|
||||
[{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
|
||||
'resolution': resolution // (2 ** i)} for i in
|
||||
range(ms_count)]
|
||||
range(ms_count)],
|
||||
ping_pong=self.ping_pong_options,
|
||||
),
|
||||
|
||||
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
|
||||
|
@ -570,7 +573,8 @@ class SAEModel(ModelBase):
|
|||
range(ms_count)] + \
|
||||
[{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
|
||||
'resolution': resolution // (2 ** i)} for i in
|
||||
range(ms_count)])
|
||||
range(ms_count)],
|
||||
ping_pong=self.ping_pong_options,)
|
||||
])
|
||||
|
||||
# override
|
||||
|
|
|
@ -5,8 +5,9 @@ import cv2
|
|||
import numpy as np
|
||||
|
||||
from facelib import LandmarksProcessor
|
||||
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
|
||||
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, SampleGeneratorPingPong,
|
||||
SampleType)
|
||||
from samplelib.SampleGeneratorPingPong import PingPongOptions
|
||||
from utils import iter_utils
|
||||
|
||||
|
||||
|
@ -19,12 +20,12 @@ output_sample_types = [
|
|||
'''
|
||||
|
||||
|
||||
class SampleGeneratorFace(SampleGeneratorBase):
|
||||
class SampleGeneratorFace(SampleGeneratorPingPong):
|
||||
def __init__(self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None,
|
||||
random_ct_samples_path=None, sample_process_options=SampleProcessor.Options(),
|
||||
output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None,
|
||||
**kwargs):
|
||||
super().__init__(samples_path, debug, batch_size)
|
||||
ping_pong=PingPongOptions(), **kwargs):
|
||||
super().__init__(samples_path, debug, batch_size, ping_pong)
|
||||
self.sample_process_options = sample_process_options
|
||||
self.output_sample_types = output_sample_types
|
||||
self.add_sample_idx = add_sample_idx
|
||||
|
@ -64,6 +65,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
def __next__(self):
|
||||
self.generator_counter += 1
|
||||
generator = self.generators[self.generator_counter % len(self.generators)]
|
||||
super().__next__()
|
||||
return next(generator)
|
||||
|
||||
def batch_func(self, param):
|
||||
|
|
50
samplelib/SampleGeneratorPingPong.py
Normal file
50
samplelib/SampleGeneratorPingPong.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
from enum import Enum
|
||||
from samplelib import SampleGeneratorBase
|
||||
|
||||
|
||||
class Paddle(Enum):
|
||||
PING = 'ping' # Ascending
|
||||
PONG = 'pong' # Descending
|
||||
|
||||
|
||||
class PingPongOptions:
|
||||
def __init__(self, enabled=False, iterations=1000, model_iter=1, paddle=Paddle.PING, batch_cap=1):
|
||||
self.enabled = enabled
|
||||
self.iterations = iterations
|
||||
self.model_iter = model_iter
|
||||
self.paddle = paddle
|
||||
self.batch_cap = batch_cap
|
||||
|
||||
|
||||
class SampleGeneratorPingPong(SampleGeneratorBase):
|
||||
def __init__(self, *args, batch_size, ping_pong=PingPongOptions()):
|
||||
self.ping_pong = ping_pong
|
||||
super().__init__(*args, batch_size)
|
||||
|
||||
def __next__(self):
|
||||
if self.ping_pong.enabled and self.ping_pong.model_iter % self.ping_pong.iterations == 0 \
|
||||
and self.ping_pong.model_iter != 0:
|
||||
|
||||
# If batch size is greater then batch cap, set it to batch cap
|
||||
if self.batch_size > self.ping_pong.batch_cap:
|
||||
self.batch_size = self.ping_pong.batch_cap
|
||||
|
||||
# If we are at the batch cap, switch to PONG (descend)
|
||||
if self.batch_size == self.ping_pong.batch_cap:
|
||||
self.paddle = Paddle.PONG
|
||||
# Else if we are at 1, switch to PING (ascend)
|
||||
elif self.batch_size == 1:
|
||||
self.paddle = Paddle.PING
|
||||
|
||||
# If PING (ascending) increase the batch size
|
||||
if self.paddle is Paddle.PING:
|
||||
self.batch_size += 1
|
||||
# Else decrease the batch size
|
||||
else:
|
||||
self.batch_size -= 1
|
||||
|
||||
self.ping_pong.model_iter += 1
|
||||
super().__next__()
|
||||
|
||||
|
||||
|
Loading…
Add table
Add a link
Reference in a new issue