Merge pull request #42 from faceshiftlabs/feat/ping-pong-2

Feat/ping pong 2
This commit is contained in:
Jeremy Hummel 2019-08-23 19:10:31 -07:00 committed by GitHub
commit 1dff2283d1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 121 additions and 108 deletions

View file

@ -14,6 +14,7 @@ import imagelib
from interact import interact as io
from nnlib import nnlib
from samplelib import SampleGeneratorBase
from samplelib.SampleGeneratorPingPong import PingPongOptions, Paddle
from utils import Path_utils, std_utils
from utils.cv2_utils import *
@ -67,7 +68,7 @@ class ModelBase(object):
self.debug = debug
self.is_training_mode = (training_data_src_path is not None and training_data_dst_path is not None)
self.paddle = 'pong'
self.paddle = Paddle.PONG
self.iter = 0
self.options = {}
@ -210,15 +211,18 @@ class ModelBase(object):
self.onInitializeOptions(self.iter == 0, ask_override)
self.ping_pong_options = PingPongOptions(enabled=self.options['ping_pong'],
iterations=self.ping_pong_iter,
model_iter=self.iter,
paddle=self.paddle,
batch_cap=self.batch_size)
nnlib.import_all(self.device_config)
self.keras = nnlib.keras
self.K = nnlib.keras.backend
self.onInitialize()
self.options['batch_size'] = self.batch_size
self.paddle = self.options.get('paddle', 'ping')
if self.debug or self.batch_size == 0:
self.batch_size = 1
@ -413,7 +417,7 @@ class ModelBase(object):
def save(self):
self.options['batch_size'] = self.batch_size
self.options['paddle'] = self.paddle
self.options['paddle'] = self.ping_pong_options.paddle
summary_path = self.get_strpath_storage_for_file('summary.txt')
Path( summary_path ).write_text(self.model_summary_text)
self.onSave()
@ -542,12 +546,11 @@ class ModelBase(object):
return [ generator.generate_next() for generator in self.generator_list]
def train_one_iter(self):
if self.iter == 1 and self.options.get('ping_pong', False):
self.set_batch_size(1)
self.paddle = 'ping'
elif not self.options.get('ping_pong', False) and self.batch_cap != self.batch_size:
self.set_batch_size(self.batch_cap)
# if self.iter == 1 and self.options.get('ping_pong', False):
# self.set_batch_size(1)
# self.paddle = 'ping'
# elif not self.options.get('ping_pong', False) and self.batch_cap != self.batch_size:
# self.set_batch_size(self.batch_cap)
sample = self.generate_next_sample()
iter_time = time.time()
losses = self.onTrainOneIter(sample, self.generator_list)
@ -574,21 +577,6 @@ class ModelBase(object):
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
cv2_imwrite (filepath, img )
if self.iter % self.ping_pong_iter == 0 and self.iter != 0 and self.options.get('ping_pong', False):
if self.batch_size == self.batch_cap:
self.paddle = 'pong'
if self.batch_size > self.batch_cap:
self.set_batch_size(self.batch_cap)
self.paddle = 'pong'
if self.batch_size == 1:
self.paddle = 'ping'
if self.paddle == 'ping':
self.save()
self.set_batch_size(self.batch_size + 1)
else:
self.save()
self.set_batch_size(self.batch_size - 1)
self.iter += 1
return self.iter, iter_time, self.batch_size

View file

@ -487,7 +487,8 @@ class SAEModel(ModelBase):
'apply_ct': apply_random_ct} for i in range(ms_count)] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
'resolution': resolution // (2 ** i)} for i in
range(ms_count)]
range(ms_count)],
ping_pong=self.ping_pong_options,
),
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
@ -500,7 +501,8 @@ class SAEModel(ModelBase):
range(ms_count)] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
'resolution': resolution // (2 ** i)} for i in
range(ms_count)])
range(ms_count)],
ping_pong=self.ping_pong_options,)
])
# override
@ -540,38 +542,8 @@ class SAEModel(ModelBase):
# override
def set_batch_size(self, batch_size):
self.batch_size = batch_size
self.set_training_data_generators(None)
self.set_training_data_generators([
SampleGeneratorFace(training_data_src_path,
sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
random_ct_samples_path=training_data_dst_path if apply_random_ct != ColorTransferMode.NONE else None,
debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip,
scale_range=np.array([-0.05,
0.05]) + self.src_scale_mod / 100.0),
output_sample_types=[{'types': (
t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
'resolution': resolution, 'apply_ct': apply_random_ct}] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr),
'resolution': resolution // (2 ** i),
'apply_ct': apply_random_ct} for i in range(ms_count)] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
'resolution': resolution // (2 ** i)} for i in
range(ms_count)]
),
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
output_sample_types=[{'types': (
t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
'resolution': resolution}] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr),
'resolution': resolution // (2 ** i)} for i in
range(ms_count)] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
'resolution': resolution // (2 ** i)} for i in
range(ms_count)])
])
for generators in self.get_training_data_generators():
generators.update_batch(batch_size)
# override
def onTrainOneIter(self, generators_samples, generators_list):

View file

@ -5,8 +5,9 @@ import cv2
import numpy as np
from facelib import LandmarksProcessor
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, SampleGeneratorPingPong,
SampleType)
from samplelib.SampleGeneratorPingPong import PingPongOptions, SampleGeneratorPingPong
from utils import iter_utils
@ -19,12 +20,12 @@ output_sample_types = [
'''
class SampleGeneratorFace(SampleGeneratorBase):
class SampleGeneratorFace(SampleGeneratorPingPong):
def __init__(self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None,
random_ct_samples_path=None, sample_process_options=SampleProcessor.Options(),
output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None,
**kwargs):
super().__init__(samples_path, debug, batch_size)
ping_pong=PingPongOptions(), **kwargs):
super().__init__(samples_path, debug, batch_size=batch_size, ping_pong=ping_pong)
self.sample_process_options = sample_process_options
self.output_sample_types = output_sample_types
self.add_sample_idx = add_sample_idx
@ -64,6 +65,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
def __next__(self):
self.generator_counter += 1
generator = self.generators[self.generator_counter % len(self.generators)]
super().__next__()
return next(generator)
def batch_func(self, param):

View file

@ -0,0 +1,50 @@
from enum import Enum
from samplelib import SampleGeneratorBase
class Paddle(Enum):
PING = 'ping' # Ascending
PONG = 'pong' # Descending
class PingPongOptions:
def __init__(self, enabled=False, iterations=1000, model_iter=1, paddle=Paddle.PING, batch_cap=1):
self.enabled = enabled
self.iterations = iterations
self.model_iter = model_iter
self.paddle = paddle
self.batch_cap = batch_cap
class SampleGeneratorPingPong(SampleGeneratorBase):
def __init__(self, *args, batch_size, ping_pong=PingPongOptions()):
self.ping_pong = ping_pong
super().__init__(*args, batch_size)
def __next__(self):
if self.ping_pong.enabled and self.ping_pong.model_iter % self.ping_pong.iterations == 0 \
and self.ping_pong.model_iter != 0:
# If batch size is greater then batch cap, set it to batch cap
if self.batch_size > self.ping_pong.batch_cap:
self.batch_size = self.ping_pong.batch_cap
# If we are at the batch cap, switch to PONG (descend)
if self.batch_size == self.ping_pong.batch_cap:
self.paddle = Paddle.PONG
# Else if we are at 1, switch to PING (ascend)
elif self.batch_size == 1:
self.paddle = Paddle.PING
# If PING (ascending) increase the batch size
if self.paddle is Paddle.PING:
self.batch_size += 1
# Else decrease the batch size
else:
self.batch_size -= 1
self.ping_pong.model_iter += 1
super().__next__()

View file

@ -4,5 +4,6 @@ from .SampleLoader import SampleLoader
from .SampleProcessor import SampleProcessor
from .SampleGeneratorBase import SampleGeneratorBase
from .SampleGeneratorFace import SampleGeneratorFace
from .SampleGeneratorPingPong import SampleGeneratorPingPong
from .SampleGeneratorFaceTemporal import SampleGeneratorFaceTemporal
from .SampleGeneratorImageTemporal import SampleGeneratorImageTemporal