From e6aa9968140436589c2d5233433f8cc3fc133cbf Mon Sep 17 00:00:00 2001 From: Colombo Date: Wed, 1 Jul 2020 22:17:35 +0400 Subject: [PATCH] leras: add ability to save sub layers in a dict --- core/leras/models/ModelBase.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/core/leras/models/ModelBase.py b/core/leras/models/ModelBase.py index d96e03f..77ac284 100644 --- a/core/leras/models/ModelBase.py +++ b/core/leras/models/ModelBase.py @@ -18,6 +18,10 @@ class ModelBase(nn.Saveable): if isinstance (layer, list): for i,sublayer in enumerate(layer): self._build_sub(sublayer, f"{name}_{i}") + elif isinstance (layer, dict): + for subname in layer.keys(): + sublayer = layer[subname] + self._build_sub(sublayer, f"{name}_{subname}") elif isinstance (layer, nn.LayerBase) or \ isinstance (layer, ModelBase): @@ -32,7 +36,7 @@ class ModelBase(nn.Saveable): self.layers.append (layer) self.layers_by_name[layer.name] = layer - + def xor_list(self, lst1, lst2): return [value for value in lst1+lst2 if (value not in lst1) or (value not in lst2) ] @@ -79,7 +83,7 @@ class ModelBase(nn.Saveable): def get_layer_by_name(self, name): return self.layers_by_name.get(name, None) - + def get_layers(self): if not self.built: self.build()