From 42398ebebc865e241fcff3075f4187079e66947d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 4 May 2020 19:07:34 -0400 Subject: [PATCH 1/3] Add a trainer based on Pytorch Lightning --- docs/source/conf.py | 12 +- setup.py | 43 +-- tests/torchgan/test_layers.py | 4 +- tests/torchgan/test_losses.py | 15 +- tests/torchgan/test_metrics.py | 4 +- tests/torchgan/test_trainer.py | 24 +- torchgan/layers/denseblock.py | 54 +++- torchgan/layers/minibatchdiscrimination.py | 4 +- torchgan/layers/residual.py | 12 +- torchgan/layers/spectralnorm.py | 12 +- torchgan/logging/logger.py | 12 +- torchgan/logging/visualize.py | 52 +++- torchgan/losses/auxclassifier.py | 9 +- torchgan/losses/boundaryequilibrium.py | 13 +- torchgan/losses/draganpenalty.py | 17 +- torchgan/losses/energybased.py | 20 +- torchgan/losses/featurematching.py | 12 +- torchgan/losses/functional.py | 28 +- torchgan/losses/historical.py | 19 +- torchgan/losses/leastsquares.py | 17 +- torchgan/losses/loss.py | 8 +- torchgan/losses/minimax.py | 21 +- torchgan/losses/mutualinfo.py | 8 +- torchgan/losses/wasserstein.py | 16 +- torchgan/metrics/classifierscore.py | 8 +- torchgan/models/acgan.py | 10 +- torchgan/models/autoencoding.py | 62 ++-- torchgan/models/conditional.py | 8 +- torchgan/models/dcgan.py | 35 ++- torchgan/models/infogan.py | 8 +- torchgan/trainer/base_trainer.py | 41 ++- torchgan/trainer/lightning_trainer.py | 319 +++++++++++++++++++++ torchgan/trainer/parallel_trainer.py | 12 +- torchgan/trainer/trainer.py | 12 +- 34 files changed, 794 insertions(+), 157 deletions(-) create mode 100644 torchgan/trainer/lightning_trainer.py diff --git a/docs/source/conf.py b/docs/source/conf.py index 82786cf..6ae896e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -23,7 +23,9 @@ import sphinx_rtd_theme -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) +sys.path.insert( + 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +) # -- General configuration ------------------------------------------------ @@ -32,7 +34,13 @@ # needs_sphinx = '1.0' # Mock Imports -autodoc_mock_imports = ["torch", "pillow", "torchvision", "tensorboardX", "visdom"] +autodoc_mock_imports = [ + "torch", + "pillow", + "torchvision", + "tensorboardX", + "visdom", +] # Add any Sphinx extension module names here, as strings. 
They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom diff --git a/setup.py b/setup.py index 372c922..c7540fb 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,9 @@ def get_dist(pkgname): def find_version(*file_paths): version_file = read(*file_paths) - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M) + version_match = re.search( + r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M + ) if version_match: return version_match.group(1) raise RuntimeError("Unable to find version string.") @@ -36,18 +38,19 @@ def find_version(*file_paths): VERSION = find_version("torchgan", "__init__.py") -def load_requirements(path_dir=PATH_ROOT, comment_char='#'): - with open(os.path.join(path_dir, 'requirements.txt'), 'r') as file: +def load_requirements(path_dir=PATH_ROOT, comment_char="#"): + with open(os.path.join(path_dir, "requirements.txt"), "r") as file: lines = [ln.strip() for ln in file.readlines()] reqs = [] for ln in lines: # filer all comments if comment_char in ln: - ln = ln[:ln.index(comment_char)] + ln = ln[: ln.index(comment_char)] if ln: # if requirement is not empty reqs.append(ln) return reqs + setup( # Metadata name="torchgan", @@ -61,29 +64,29 @@ def load_requirements(path_dir=PATH_ROOT, comment_char='#'): packages=find_packages(exclude=("test",)), zip_safe=True, install_requires=load_requirements(PATH_ROOT), - long_description=open('README.md', encoding='utf-8').read(), - long_description_content_type='text/markdown', + long_description=open("README.md", encoding="utf-8").read(), + long_description_content_type="text/markdown", include_package_data=True, - keywords=['deep learning', 'pytorch', 'GAN', 'AI'], - python_requires='>=3.6', + keywords=["deep learning", "pytorch", "GAN", "AI"], + python_requires=">=3.6", classifiers=[ - 'Environment :: Console', - 'Natural Language :: English', + "Environment :: Console", + "Natural Language :: English", # How mature is this project? Common values are # 3 - Alpha, 4 - Beta, 5 - Production/Stable - 'Development Status :: 4 - Beta', + "Development Status :: 4 - Beta", # Indicate who your project is intended for - 'Intended Audience :: Developers', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: Scientific/Engineering :: Image Recognition', - 'Topic :: Scientific/Engineering :: Deep Learning', + "Intended Audience :: Developers", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Image Recognition", + "Topic :: Scientific/Engineering :: Deep Learning", # Pick your license as you wish - 'License :: OSI Approved :: MIT License', - 'Operating System :: OS Independent', + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", # Specify the Python versions you support here. In particular, ensure # that you indicate whether you support Python 2, Python 3 or both. 
- 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", ], ) diff --git a/tests/torchgan/test_layers.py b/tests/torchgan/test_layers.py index a9c7b33..c19769f 100644 --- a/tests/torchgan/test_layers.py +++ b/tests/torchgan/test_layers.py @@ -37,7 +37,9 @@ def test_residual_block2d(self): def test_transposed_residula_block2d(self): input = torch.rand(16, 3, 10, 10) - layer = ResidualBlockTranspose2d([3, 16, 32, 3], [3, 3, 1], paddings=[1, 1, 0]) + layer = ResidualBlockTranspose2d( + [3, 16, 32, 3], [3, 3, 1], paddings=[1, 1, 0] + ) self.match_layer_outputs(layer, input, (16, 3, 10, 10)) diff --git a/tests/torchgan/test_losses.py b/tests/torchgan/test_losses.py index c440425..5c174c2 100644 --- a/tests/torchgan/test_losses.py +++ b/tests/torchgan/test_losses.py @@ -32,7 +32,9 @@ def match_losses( l_d.reduction = "none" loss_none = l_d(D_X, D_GZ).view(-1, 1) for i in range(4): - self.assertAlmostEqual(d_loss_none[i], loss_none[i].item(), places=5) + self.assertAlmostEqual( + d_loss_none[i], loss_none[i].item(), places=5 + ) self.assertAlmostEqual(gen_loss_mean, l_g(D_GZ).item(), places=5) l_g.reduction = "sum" @@ -40,7 +42,9 @@ def match_losses( l_g.reduction = "none" loss_none = l_g(D_GZ).view(-1, 1) for i in range(4): - self.assertAlmostEqual(gen_loss_none[i], loss_none[i].item(), places=5) + self.assertAlmostEqual( + gen_loss_none[i], loss_none[i].item(), places=5 + ) def test_wasserstein_loss(self): dx = [1.3, 2.9, 8.4, 6.3] @@ -136,7 +140,12 @@ def test_minimax_nonsaturating_loss(self): gen_loss_mean = 0.9509911 gen_loss_sum = 3.8039644 - gen_loss_none = [8.1960661e-03, 2.6328245e-01, 3.5297503e00, 2.7356991e-03] + gen_loss_none = [ + 8.1960661e-03, + 2.6328245e-01, + 3.5297503e00, + 2.7356991e-03, + ] d_loss_mean = 3.1251488 d_loss_sum = 12.500595 diff --git a/tests/torchgan/test_metrics.py b/tests/torchgan/test_metrics.py index 62b1a5a..602aaa5 100644 --- a/tests/torchgan/test_metrics.py +++ b/tests/torchgan/test_metrics.py @@ -12,4 +12,6 @@ class TestMetrics(unittest.TestCase): def test_inception_score(self): inception_score = ClassifierScore() x = torch.Tensor([[1.0, 2.0, 3.0], [-1.0, 5.0, 3.1]]) - self.assertAlmostEqual(inception_score.calculate_score(x).item(), 1.24357, 4) + self.assertAlmostEqual( + inception_score.calculate_score(x).item(), 1.24357, 4 + ) diff --git a/tests/torchgan/test_trainer.py b/tests/torchgan/test_trainer.py index b4bbcfb..b54fc00 100644 --- a/tests/torchgan/test_trainer.py +++ b/tests/torchgan/test_trainer.py @@ -68,7 +68,11 @@ def test_trainer_cgan(self): network_params = { "generator": { "name": ConditionalGANGenerator, - "args": {"num_classes": 10, "out_channels": 1, "step_channels": 4}, + "args": { + "num_classes": 10, + "out_channels": 1, + "step_channels": 4, + }, "optimizer": { "name": Adam, "args": {"lr": 0.0002, "betas": (0.5, 0.999)}, @@ -76,7 +80,11 @@ def test_trainer_cgan(self): }, "discriminator": { "name": ConditionalGANDiscriminator, - "args": {"num_classes": 10, "in_channels": 1, "step_channels": 4}, + "args": { + "num_classes": 10, + "in_channels": 1, + "step_channels": 4, + }, "optimizer": { "name": Adam, "args": {"lr": 0.0002, "betas": (0.5, 0.999)}, @@ -97,7 +105,11 @@ def test_trainer_acgan(self): network_params = { "generator": { "name": ACGANGenerator, - "args": {"num_classes": 10, "out_channels": 1, "step_channels": 4}, + "args": { + 
"num_classes": 10, + "out_channels": 1, + "step_channels": 4, + }, "optimizer": { "name": Adam, "args": {"lr": 0.0002, "betas": (0.5, 0.999)}, @@ -105,7 +117,11 @@ def test_trainer_acgan(self): }, "discriminator": { "name": ACGANDiscriminator, - "args": {"num_classes": 10, "in_channels": 1, "step_channels": 4}, + "args": { + "num_classes": 10, + "in_channels": 1, + "step_channels": 4, + }, "optimizer": { "name": Adam, "args": {"lr": 0.0002, "betas": (0.5, 0.999)}, diff --git a/torchgan/layers/denseblock.py b/torchgan/layers/denseblock.py index fafc319..5001d9b 100644 --- a/torchgan/layers/denseblock.py +++ b/torchgan/layers/denseblock.py @@ -48,14 +48,24 @@ def __init__( nn.BatchNorm2d(in_channels), nl, nn.Conv2d( - in_channels, out_channels, kernel, stride, padding, bias=False + in_channels, + out_channels, + kernel, + stride, + padding, + bias=False, ), ) else: self.model = nn.Sequential( nl, nn.Conv2d( - in_channels, out_channels, kernel, stride, padding, bias=True + in_channels, + out_channels, + kernel, + stride, + padding, + bias=True, ), ) @@ -110,14 +120,18 @@ def __init__( ): super(BottleneckBlock2d, self).__init__() bottleneck_channels = ( - 4 * in_channels if bottleneck_channels is None else bottleneck_channels + 4 * in_channels + if bottleneck_channels is None + else bottleneck_channels ) nl = nn.LeakyReLU(0.2) if nonlinearity is None else nonlinearity if batchnorm is True: self.model = nn.Sequential( nn.BatchNorm2d(in_channels), nl, - nn.Conv2d(in_channels, bottleneck_channels, 1, 1, 0, bias=False), + nn.Conv2d( + in_channels, bottleneck_channels, 1, 1, 0, bias=False + ), nn.BatchNorm2d(bottleneck_channels), nl, nn.Conv2d( @@ -132,7 +146,9 @@ def __init__( else: self.model = nn.Sequential( nl, - nn.Conv2d(in_channels, bottleneck_channels, 1, 1, 0, bias=True), + nn.Conv2d( + in_channels, bottleneck_channels, 1, 1, 0, bias=True + ), nl, nn.Conv2d( bottleneck_channels, @@ -191,14 +207,24 @@ def __init__( nn.BatchNorm2d(in_channels), nl, nn.Conv2d( - in_channels, out_channels, kernel, stride, padding, bias=False + in_channels, + out_channels, + kernel, + stride, + padding, + bias=False, ), ) else: self.model = nn.Sequential( nl, nn.Conv2d( - in_channels, out_channels, kernel, stride, padding, bias=True + in_channels, + out_channels, + kernel, + stride, + padding, + bias=True, ), ) @@ -246,14 +272,24 @@ def __init__( nn.BatchNorm2d(in_channels), nl, nn.ConvTranspose2d( - in_channels, out_channels, kernel, stride, padding, bias=False + in_channels, + out_channels, + kernel, + stride, + padding, + bias=False, ), ) else: self.model = nn.Sequential( nl, nn.ConvTranspose2d( - in_channels, out_channels, kernel, stride, padding, bias=True + in_channels, + out_channels, + kernel, + stride, + padding, + bias=True, ), ) diff --git a/torchgan/layers/minibatchdiscrimination.py b/torchgan/layers/minibatchdiscrimination.py index 0c24afc..d11c868 100644 --- a/torchgan/layers/minibatchdiscrimination.py +++ b/torchgan/layers/minibatchdiscrimination.py @@ -56,7 +56,9 @@ def forward(self, x): 3D Torch Tensor of size :math: `(N,infeatures + outfeatures)` after applying Minibatch Discrimination """ M = torch.mm(x, self.T.view(self.in_features, -1)) - M = M.view(-1, self.out_features, self.intermediate_features).unsqueeze(0) + M = M.view( + -1, self.out_features, self.intermediate_features + ).unsqueeze(0) M_t = M.permute(1, 0, 2, 3) # Broadcasting reduces the matrix subtraction to the form desired in the paper out = torch.sum(torch.exp(-(torch.abs(M - M_t).sum(3))), dim=0) - 1 diff --git 
a/torchgan/layers/residual.py b/torchgan/layers/residual.py index bd71fde..52999fc 100644 --- a/torchgan/layers/residual.py +++ b/torchgan/layers/residual.py @@ -93,7 +93,11 @@ def forward(self, x): out += self.shortcut(x) else: out += x - return out if self.last_nonlinearity is None else self.last_nonlinearity(out) + return ( + out + if self.last_nonlinearity is None + else self.last_nonlinearity(out) + ) class ResidualBlockTranspose2d(nn.Module): @@ -185,4 +189,8 @@ def forward(self, x): out += self.shortcut(x) else: out += x - return out if self.last_nonlinearity is None else self.last_nonlinearity(out) + return ( + out + if self.last_nonlinearity is None + else self.last_nonlinearity(out) + ) diff --git a/torchgan/layers/spectralnorm.py b/torchgan/layers/spectralnorm.py index fb5a724..2f4f6c2 100644 --- a/torchgan/layers/spectralnorm.py +++ b/torchgan/layers/spectralnorm.py @@ -43,8 +43,12 @@ def __init__(self, module, name="weight", power_iterations=1): w = getattr(self.module, self.name) height = w.data.shape[0] width = w.view(height, -1).data.shape[1] - self.u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) - self.v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) + self.u = Parameter( + w.data.new(height).normal_(0, 1), requires_grad=False + ) + self.v = Parameter( + w.data.new(width).normal_(0, 1), requires_grad=False + ) self.u.data = self._l2normalize(self.u.data) self.v.data = self._l2normalize(self.v.data) self.w_bar = Parameter(w.data) @@ -78,5 +82,7 @@ def forward(self, *args): torch.mv(self.w_bar.view(height, -1), self.v) ) sigma = self.u.dot(self.w_bar.view(height, -1).mv(self.v)) - setattr(self.module, self.name, self.w_bar / sigma.expand_as(self.w_bar)) + setattr( + self.module, self.name, self.w_bar / sigma.expand_as(self.w_bar) + ) return self.module.forward(*args) diff --git a/torchgan/logging/logger.py b/torchgan/logging/logger.py index 2cad638..6e717be 100644 --- a/torchgan/logging/logger.py +++ b/torchgan/logging/logger.py @@ -71,7 +71,9 @@ def __init__( self.logger_end_epoch.append( MetricVisualize(metrics_list, writer=self.writer) ) - self.logger_mid_epoch.append(LossVisualize(losses_list, writer=self.writer)) + self.logger_mid_epoch.append( + LossVisualize(losses_list, writer=self.writer) + ) def get_loss_viz(self): r"""Get the LossVisualize object. @@ -97,9 +99,13 @@ def register(self, visualize, *args, mid_epoch=True, **kwargs): over. Otherwise it is executed after every call to the ``train_iter``. """ if mid_epoch: - self.logger_mid_epoch.append(visualize(*args, writer=self.writer, **kwargs)) + self.logger_mid_epoch.append( + visualize(*args, writer=self.writer, **kwargs) + ) else: - self.logger_end_epoch.append(visualize(*args, writer=self.writer, **kwargs)) + self.logger_end_epoch.append( + visualize(*args, writer=self.writer, **kwargs) + ) def close(self): r"""Turns off the tensorboard ``SummaryWriter`` if it were created. diff --git a/torchgan/logging/visualize.py b/torchgan/logging/visualize.py index d51c588..0e92244 100644 --- a/torchgan/logging/visualize.py +++ b/torchgan/logging/visualize.py @@ -32,7 +32,9 @@ class Visualize(object): don't want to start a new SummaryWriter. 
""" - def __init__(self, visualize_list, visdom_port=8097, log_dir=None, writer=None): + def __init__( + self, visualize_list, visdom_port=8097, log_dir=None, writer=None + ): self.logs = {} for item in visualize_list: name = type(item).__name__ @@ -153,7 +155,9 @@ def log_tensorboard(self, running_losses): "Losses/{}-Discriminator".format(name), val[1], self.step ) else: - self.writer.add_scalar("Losses/{}".format(name), val, self.step) + self.writer.add_scalar( + "Losses/{}".format(name), val, self.step + ) def log_console(self, running_losses): r"""Console logging function. This function logs the mean ``generator`` and ``discriminator`` @@ -227,14 +231,18 @@ def log_visdom(self, running_losses): [self.step], win=name1, update="append", - opts=dict(title=name1, xlabel="Time Step", ylabel="Loss Value"), + opts=dict( + title=name1, xlabel="Time Step", ylabel="Loss Value" + ), ) self.vis.line( [val[1]], [self.step], win=name2, update="append", - opts=dict(title=name2, xlabel="Time Step", ylabel="Loss Value"), + opts=dict( + title=name2, xlabel="Time Step", ylabel="Loss Value" + ), ) else: self.vis.line( @@ -242,7 +250,9 @@ def log_visdom(self, running_losses): [self.step], win=name, update="append", - opts=dict(title=name, xlabel="Time Step", ylabel="Loss Value"), + opts=dict( + title=name, xlabel="Time Step", ylabel="Loss Value" + ), ) def __call__(self, trainer, **kwargs): @@ -279,7 +289,9 @@ def log_tensorboard(self): r"""Tensorboard logging function. This function logs the values of the individual metrics. """ for name, value in self.logs.items(): - self.writer.add_scalar("Metrics/{}".format(name), value[-1], self.step) + self.writer.add_scalar( + "Metrics/{}".format(name), value[-1], self.step + ) def log_console(self): r"""Console logging function. This function logs the mean metrics. @@ -296,7 +308,9 @@ def log_visdom(self): [self.step], win=name, update="append", - opts=dict(title=name, xlabel="Time Step", ylabel="Metric Value"), + opts=dict( + title=name, xlabel="Time Step", ylabel="Metric Value" + ), ) @@ -314,7 +328,9 @@ class GradientVisualize(Visualize): don't want to start a new SummaryWriter. """ - def __init__(self, visualize_list, visdom_port=8097, log_dir=None, writer=None): + def __init__( + self, visualize_list, visdom_port=8097, log_dir=None, writer=None + ): if visualize_list is None or len(visualize_list) == 0: raise Exception("Gradient Visualizer requires list of model names") self.logs = {} @@ -345,7 +361,9 @@ def log_console(self, name): name (str): Name of the model whose gradients are to be logged. """ print( - "{} Gradients : {}".format(name, self.logs[name][len(self.logs[name]) - 1]) + "{} Gradients : {}".format( + name, self.logs[name][len(self.logs[name]) - 1] + ) ) def log_visdom(self, name): @@ -384,7 +402,9 @@ def report_end_epoch(self): """ if CONSOLE_LOGGING == 1: for key, val in self.logs.items(): - print("{} Mean Gradients : {}".format(key, sum(val) / len(val))) + print( + "{} Mean Gradients : {}".format(key, sum(val) / len(val)) + ) def __call__(self, trainer, **kwargs): for name in trainer.model_names: @@ -425,7 +445,9 @@ def __init__( for model in trainer.model_names: if isinstance(getattr(trainer, model), Generator): self.test_noise.append( - getattr(trainer, model).sampler(trainer.sample_size, trainer.device) + getattr(trainer, model).sampler( + trainer.sample_size, trainer.device + ) if test_noise is None else test_noise ) @@ -440,7 +462,9 @@ def log_tensorboard(self, trainer, image, model): image (Image): The generated image. 
model (str): The name of the model which generated the ``image``. """ - self.writer.add_image("Generated Samples/{}".format(model), image, self.step) + self.writer.add_image( + "Generated Samples/{}".format(model), image, self.step + ) def log_console(self, trainer, image, model): r"""Saves a generated image at the end of an epoch. The path where the image is @@ -463,7 +487,9 @@ def log_visdom(self, trainer, image, model): image (Image): The generated image. model (str): The name of the model which generated the ``image``. """ - self.vis.image(image, opts=dict(caption="Generated Samples/{}".format(model))) + self.vis.image( + image, opts=dict(caption="Generated Samples/{}".format(model)) + ) def __call__(self, trainer, **kwargs): pos = 0 diff --git a/torchgan/losses/auxclassifier.py b/torchgan/losses/auxclassifier.py index a904036..7970a29 100644 --- a/torchgan/losses/auxclassifier.py +++ b/torchgan/losses/auxclassifier.py @@ -3,7 +3,10 @@ from .functional import auxiliary_classification_loss from .loss import DiscriminatorLoss, GeneratorLoss -__all__ = ["AuxiliaryClassifierGeneratorLoss", "AuxiliaryClassifierDiscriminatorLoss"] +__all__ = [ + "AuxiliaryClassifierGeneratorLoss", + "AuxiliaryClassifierDiscriminatorLoss", +] class AuxiliaryClassifierGeneratorLoss(GeneratorLoss): @@ -72,7 +75,9 @@ def train_ops( noise = torch.randn(batch_size, generator.encoding_dims, device=device) optimizer_generator.zero_grad() if generator.label_type == "none": - raise Exception("Incorrect Model: ACGAN generator must require labels") + raise Exception( + "Incorrect Model: ACGAN generator must require labels" + ) if generator.label_type == "required": fake = generator(noise, labels) elif generator.label_type == "generated": diff --git a/torchgan/losses/boundaryequilibrium.py b/torchgan/losses/boundaryequilibrium.py index 32c29a7..5f3aeee 100644 --- a/torchgan/losses/boundaryequilibrium.py +++ b/torchgan/losses/boundaryequilibrium.py @@ -6,7 +6,10 @@ ) from .loss import DiscriminatorLoss, GeneratorLoss -__all__ = ["BoundaryEquilibriumGeneratorLoss", "BoundaryEquilibriumDiscriminatorLoss"] +__all__ = [ + "BoundaryEquilibriumGeneratorLoss", + "BoundaryEquilibriumDiscriminatorLoss", +] class BoundaryEquilibriumGeneratorLoss(GeneratorLoss): @@ -107,7 +110,9 @@ def forward(self, dx, dgz): A tuple of 3 loss values, namely the ``total loss``, ``loss due to real data`` and ``loss due to fake data``. 
""" - return boundary_equilibrium_discriminator_loss(dx, dgz, self.k, self.reduction) + return boundary_equilibrium_discriminator_loss( + dx, dgz, self.k, self.reduction + ) def set_k(self, k=0.0): r"""Change the default value of k @@ -186,7 +191,9 @@ def train_ops( ): raise Exception("GAN model requires labels for training") batch_size = real_inputs.size(0) - noise = torch.randn(batch_size, generator.encoding_dims, device=device) + noise = torch.randn( + batch_size, generator.encoding_dims, device=device + ) if generator.label_type == "generated": label_gen = torch.randint( 0, generator.num_classes, (batch_size,), device=device diff --git a/torchgan/losses/draganpenalty.py b/torchgan/losses/draganpenalty.py index ed080b5..2663594 100644 --- a/torchgan/losses/draganpenalty.py +++ b/torchgan/losses/draganpenalty.py @@ -33,7 +33,9 @@ class DraganGradientPenalty(DiscriminatorLoss): override_train_ops (function, optional): Function to be used in place of the default ``train_ops`` """ - def __init__(self, reduction="mean", lambd=10.0, k=1.0, override_train_ops=None): + def __init__( + self, reduction="mean", lambd=10.0, k=1.0, override_train_ops=None + ): super(DraganGradientPenalty, self).__init__(reduction) self.lambd = lambd self.override_train_ops = override_train_ops @@ -104,12 +106,19 @@ def train_ops( alpha = torch.rand( size=real_inputs.shape, device=device, requires_grad=True ) - beta = torch.rand(size=real_inputs.shape, device=device, requires_grad=True) + beta = torch.rand( + size=real_inputs.shape, device=device, requires_grad=True + ) optimizer_discriminator.zero_grad() - interpolate = real_inputs + (1 - alpha) * 0.5 * real_inputs.std() * beta + interpolate = ( + real_inputs + (1 - alpha) * 0.5 * real_inputs.std() * beta + ) if generator.label_type == "generated": label_gen = torch.randint( - 0, generator.num_classes, (real_inputs.size(0),), device=device + 0, + generator.num_classes, + (real_inputs.size(0),), + device=device, ) if discriminator.label_type == "none": d_interpolate = discriminator(interpolate) diff --git a/torchgan/losses/energybased.py b/torchgan/losses/energybased.py index 2b24470..78586fd 100644 --- a/torchgan/losses/energybased.py +++ b/torchgan/losses/energybased.py @@ -121,7 +121,9 @@ class EnergyBasedPullingAwayTerm(GeneratorLoss): """ def __init__(self, pt_ratio=0.1, override_train_ops=None): - super(EnergyBasedPullingAwayTerm, self).__init__("mean", override_train_ops) + super(EnergyBasedPullingAwayTerm, self).__init__( + "mean", override_train_ops + ) self.pt_ratio = pt_ratio def forward(self, dgz, d_hid): @@ -184,10 +186,16 @@ def train_ops( "EBGAN PT requires the Discriminator to be a AutoEncoder" ) if not generator.label_type == "none": - raise Exception("EBGAN PT supports models which donot require labels") + raise Exception( + "EBGAN PT supports models which donot require labels" + ) if not discriminator.embeddings: - raise Exception("EBGAN PT requires the embeddings for loss computation") - noise = torch.randn(batch_size, generator.encoding_dims, device=device) + raise Exception( + "EBGAN PT requires the embeddings for loss computation" + ) + noise = torch.randn( + batch_size, generator.encoding_dims, device=device + ) optimizer_generator.zero_grad() fake = generator(noise) d_hid, dgz = discriminator(fake) @@ -258,7 +266,9 @@ def forward(self, dx, dgz): Returns: scalar if reduction is applied else Tensor with dimensions (N, \*). 
""" - return energy_based_discriminator_loss(dx, dgz, self.margin, self.reduction) + return energy_based_discriminator_loss( + dx, dgz, self.margin, self.reduction + ) def train_ops( self, diff --git a/torchgan/losses/featurematching.py b/torchgan/losses/featurematching.py index 7c2d2cd..bd8c7d6 100644 --- a/torchgan/losses/featurematching.py +++ b/torchgan/losses/featurematching.py @@ -88,7 +88,9 @@ def train_ops( if labels is None and generator.label_type == "required": raise Exception("GAN model requires labels for training") batch_size = real_inputs.size(0) - noise = torch.randn(batch_size, generator.encoding_dims, device=device) + noise = torch.randn( + batch_size, generator.encoding_dims, device=device + ) optimizer_generator.zero_grad() if generator.label_type == "generated": label_gen = torch.randint( @@ -106,9 +108,13 @@ def train_ops( fgz = discriminator(fake, feature_matching=True) else: if discriminator.label_type == "generated": - fx = discriminator(real_inputs, label_gen, feature_matching=True) + fx = discriminator( + real_inputs, label_gen, feature_matching=True + ) else: - fx = discriminator(real_inputs, labels, feature_matching=True) + fx = discriminator( + real_inputs, labels, feature_matching=True + ) if generator.label_type == "generated": fgz = discriminator(fake, label_gen, feature_matching=True) else: diff --git a/torchgan/losses/functional.py b/torchgan/losses/functional.py index 7373d11..3711b77 100644 --- a/torchgan/losses/functional.py +++ b/torchgan/losses/functional.py @@ -28,7 +28,9 @@ def minimax_generator_loss(dgz, nonsaturating=True, reduction="mean"): if nonsaturating: target = torch.ones_like(dgz) - return F.binary_cross_entropy_with_logits(dgz, target, reduction=reduction) + return F.binary_cross_entropy_with_logits( + dgz, target, reduction=reduction + ) else: target = torch.zeros_like(dgz) return -1.0 * F.binary_cross_entropy_with_logits( @@ -39,8 +41,12 @@ def minimax_generator_loss(dgz, nonsaturating=True, reduction="mean"): def minimax_discriminator_loss(dx, dgz, label_smoothing=0.0, reduction="mean"): target_ones = torch.ones_like(dgz) * (1.0 - label_smoothing) target_zeros = torch.zeros_like(dx) - loss = F.binary_cross_entropy_with_logits(dx, target_ones, reduction=reduction) - loss += F.binary_cross_entropy_with_logits(dgz, target_zeros, reduction=reduction) + loss = F.binary_cross_entropy_with_logits( + dx, target_ones, reduction=reduction + ) + loss += F.binary_cross_entropy_with_logits( + dgz, target_zeros, reduction=reduction + ) return loss @@ -52,13 +58,17 @@ def least_squares_generator_loss(dgz, c=1.0, reduction="mean"): def least_squares_discriminator_loss(dx, dgz, a=0.0, b=1.0, reduction="mean"): - return 0.5 * (reduce((dx - b) ** 2, reduction) + reduce((dgz - a) ** 2, reduction)) + return 0.5 * ( + reduce((dx - b) ** 2, reduction) + reduce((dgz - a) ** 2, reduction) + ) # Mutual Information Penalty -def mutual_information_penalty(c_dis, c_cont, dist_dis, dist_cont, reduction="mean"): +def mutual_information_penalty( + c_dis, c_cont, dist_dis, dist_cont, reduction="mean" +): log_probs = torch.Tensor( [ torch.mean(dist.log_prob(c)) @@ -97,7 +107,9 @@ def wasserstein_gradient_penalty(interpolate, d_interpolate, reduction="mean"): # Dragan Penalty -def dragan_gradient_penalty(interpolate, d_interpolate, k=1.0, reduction="mean"): +def dragan_gradient_penalty( + interpolate, d_interpolate, k=1.0, reduction="mean" +): grad_outputs = torch.ones_like(d_interpolate) gradients = autograd.grad( outputs=d_interpolate, @@ -135,7 +147,9 @@ def 
energy_based_pulling_away_term(d_hid): d_hid_normalized = F.normalize(d_hid, p=2, dim=0) n = d_hid_normalized.size(0) d_hid_normalized = d_hid_normalized.view(n, -1) - similarity = torch.matmul(d_hid_normalized, d_hid_normalized.transpose(1, 0)) + similarity = torch.matmul( + d_hid_normalized, d_hid_normalized.transpose(1, 0) + ) loss_pt = torch.sum(similarity ** 2) / (n * (n - 1)) return loss_pt diff --git a/torchgan/losses/historical.py b/torchgan/losses/historical.py index 3737964..24e2b64 100644 --- a/torchgan/losses/historical.py +++ b/torchgan/losses/historical.py @@ -3,7 +3,10 @@ from ..utils import reduce from .loss import DiscriminatorLoss, GeneratorLoss -__all__ = ["HistoricalAverageGeneratorLoss", "HistoricalAverageDiscriminatorLoss"] +__all__ = [ + "HistoricalAverageGeneratorLoss", + "HistoricalAverageDiscriminatorLoss", +] class HistoricalAverageGeneratorLoss(GeneratorLoss): @@ -58,7 +61,9 @@ def __init__( def train_ops(self, generator, optimizer_generator): if self.override_train_ops is not None: - return self.override_train_ops(self, generator, optimizer_generator) + return self.override_train_ops( + self, generator, optimizer_generator + ) else: if self.timesteps == 0: for p in generator.parameters(): @@ -71,7 +76,8 @@ def train_ops(self, generator, optimizer_generator): loss = 0.0 for i, p in enumerate(generator.parameters()): loss += torch.sum( - (p - (self.sum_parameters[i].data / self.timesteps)) ** 2 + (p - (self.sum_parameters[i].data / self.timesteps)) + ** 2 ) self.sum_parameters[i] += p.data.clone() self.timesteps += 1 @@ -133,7 +139,9 @@ def __init__( def train_ops(self, discriminator, optimizer_discriminator): if self.override_train_ops is not None: - return self.override_train_ops(self, discriminator, optimizer_discriminator) + return self.override_train_ops( + self, discriminator, optimizer_discriminator + ) else: if self.timesteps == 0: for p in discriminator.parameters(): @@ -146,7 +154,8 @@ def train_ops(self, discriminator, optimizer_discriminator): loss = 0.0 for i, p in enumerate(discriminator.parameters()): loss += torch.sum( - (p - (self.sum_parameters[i].data / self.timesteps)) ** 2 + (p - (self.sum_parameters[i].data / self.timesteps)) + ** 2 ) self.sum_parameters[i] += p.data.clone() self.timesteps += 1 diff --git a/torchgan/losses/leastsquares.py b/torchgan/losses/leastsquares.py index a687574..8333149 100644 --- a/torchgan/losses/leastsquares.py +++ b/torchgan/losses/leastsquares.py @@ -1,6 +1,9 @@ import torch -from .functional import least_squares_discriminator_loss, least_squares_generator_loss +from .functional import ( + least_squares_discriminator_loss, + least_squares_generator_loss, +) from .loss import DiscriminatorLoss, GeneratorLoss __all__ = ["LeastSquaresGeneratorLoss", "LeastSquaresDiscriminatorLoss"] @@ -30,7 +33,9 @@ class LeastSquaresGeneratorLoss(GeneratorLoss): """ def __init__(self, reduction="mean", c=1.0, override_train_ops=None): - super(LeastSquaresGeneratorLoss, self).__init__(reduction, override_train_ops) + super(LeastSquaresGeneratorLoss, self).__init__( + reduction, override_train_ops + ) self.c = c def forward(self, dgz): @@ -71,7 +76,9 @@ class LeastSquaresDiscriminatorLoss(DiscriminatorLoss): override_train_ops (function, optional): Function to be used in place of the default ``train_ops`` """ - def __init__(self, reduction="mean", a=0.0, b=1.0, override_train_ops=None): + def __init__( + self, reduction="mean", a=0.0, b=1.0, override_train_ops=None + ): super(LeastSquaresDiscriminatorLoss, self).__init__( reduction, 
override_train_ops ) @@ -92,4 +99,6 @@ def forward(self, dx, dgz): Returns: scalar if reduction is applied else Tensor with dimensions (N, \*). """ - return least_squares_discriminator_loss(dx, dgz, self.a, self.b, self.reduction) + return least_squares_discriminator_loss( + dx, dgz, self.a, self.b, self.reduction + ) diff --git a/torchgan/losses/loss.py b/torchgan/losses/loss.py index a274ea4..7a039bc 100644 --- a/torchgan/losses/loss.py +++ b/torchgan/losses/loss.py @@ -85,7 +85,9 @@ def train_ops( else: if labels is None and generator.label_type == "required": raise Exception("GAN model requires labels for training") - noise = torch.randn(batch_size, generator.encoding_dims, device=device) + noise = torch.randn( + batch_size, generator.encoding_dims, device=device + ) optimizer_generator.zero_grad() if generator.label_type == "generated": label_gen = torch.randint( @@ -199,7 +201,9 @@ def train_ops( ): raise Exception("GAN model requires labels for training") batch_size = real_inputs.size(0) - noise = torch.randn(batch_size, generator.encoding_dims, device=device) + noise = torch.randn( + batch_size, generator.encoding_dims, device=device + ) if generator.label_type == "generated": label_gen = torch.randint( 0, generator.num_classes, (batch_size,), device=device diff --git a/torchgan/losses/minimax.py b/torchgan/losses/minimax.py index 0915d77..df9a1ce 100644 --- a/torchgan/losses/minimax.py +++ b/torchgan/losses/minimax.py @@ -33,8 +33,12 @@ class MinimaxGeneratorLoss(GeneratorLoss): loss for the generator. """ - def __init__(self, reduction="mean", nonsaturating=True, override_train_ops=None): - super(MinimaxGeneratorLoss, self).__init__(reduction, override_train_ops) + def __init__( + self, reduction="mean", nonsaturating=True, override_train_ops=None + ): + super(MinimaxGeneratorLoss, self).__init__( + reduction, override_train_ops + ) self.nonsaturating = nonsaturating def forward(self, dgz): @@ -77,8 +81,12 @@ class MinimaxDiscriminatorLoss(DiscriminatorLoss): if the default ``train_ops`` is not to be used. """ - def __init__(self, label_smoothing=0.0, reduction="mean", override_train_ops=None): - super(MinimaxDiscriminatorLoss, self).__init__(reduction, override_train_ops) + def __init__( + self, label_smoothing=0.0, reduction="mean", override_train_ops=None + ): + super(MinimaxDiscriminatorLoss, self).__init__( + reduction, override_train_ops + ) self.label_smoothing = label_smoothing def forward(self, dx, dgz): @@ -96,5 +104,8 @@ def forward(self, dx, dgz): scalar if reduction is applied else Tensor with dimensions (N, \*). 
""" return minimax_discriminator_loss( - dx, dgz, label_smoothing=self.label_smoothing, reduction=self.reduction + dx, + dgz, + label_smoothing=self.label_smoothing, + reduction=self.reduction, ) diff --git a/torchgan/losses/mutualinfo.py b/torchgan/losses/mutualinfo.py index b67dd1e..5809e98 100644 --- a/torchgan/losses/mutualinfo.py +++ b/torchgan/losses/mutualinfo.py @@ -31,7 +31,9 @@ class MutualInformationPenalty(GeneratorLoss, DiscriminatorLoss): """ def __init__(self, lambd=1.0, reduction="mean", override_train_ops=None): - super(MutualInformationPenalty, self).__init__(reduction, override_train_ops) + super(MutualInformationPenalty, self).__init__( + reduction, override_train_ops + ) self.lambd = lambd def forward(self, c_dis, c_cont, dist_dis, dist_cont): @@ -75,7 +77,9 @@ def train_ops( batch_size, ) else: - noise = torch.randn(batch_size, generator.encoding_dims, device=device) + noise = torch.randn( + batch_size, generator.encoding_dims, device=device + ) optimizer_discriminator.zero_grad() optimizer_generator.zero_grad() fake = generator(noise, dis_code, cont_code) diff --git a/torchgan/losses/wasserstein.py b/torchgan/losses/wasserstein.py index 79c92f6..ecbe45a 100644 --- a/torchgan/losses/wasserstein.py +++ b/torchgan/losses/wasserstein.py @@ -80,7 +80,9 @@ def __init__(self, reduction="mean", clip=None, override_train_ops=None): super(WassersteinDiscriminatorLoss, self).__init__( reduction, override_train_ops ) - if (isinstance(clip, tuple) or isinstance(clip, list)) and len(clip) > 1: + if (isinstance(clip, tuple) or isinstance(clip, list)) and len( + clip + ) > 1: self.clip = clip else: self.clip = None @@ -187,7 +189,9 @@ class WassersteinGradientPenalty(DiscriminatorLoss): """ def __init__(self, reduction="mean", lambd=10.0, override_train_ops=None): - super(WassersteinGradientPenalty, self).__init__(reduction, override_train_ops) + super(WassersteinGradientPenalty, self).__init__( + reduction, override_train_ops + ) self.lambd = lambd self.override_train_ops = override_train_ops @@ -207,7 +211,9 @@ def forward(self, interpolate, d_interpolate): # TODO(Aniket1998): Check for performance bottlenecks # If found, write the backprop yourself instead of # relying on autograd - return wasserstein_gradient_penalty(interpolate, d_interpolate, self.reduction) + return wasserstein_gradient_penalty( + interpolate, d_interpolate, self.reduction + ) def train_ops( self, @@ -260,7 +266,9 @@ def train_ops( ): raise Exception("GAN model requires labels for training") batch_size = real_inputs.size(0) - noise = torch.randn(batch_size, generator.encoding_dims, device=device) + noise = torch.randn( + batch_size, generator.encoding_dims, device=device + ) if generator.label_type == "generated": label_gen = torch.randint( 0, generator.num_classes, (batch_size,), device=device diff --git a/torchgan/metrics/classifierscore.py b/torchgan/metrics/classifierscore.py index 1f55e07..6b61154 100644 --- a/torchgan/metrics/classifierscore.py +++ b/torchgan/metrics/classifierscore.py @@ -29,7 +29,9 @@ class ClassifierScore(EvaluationMetric): def __init__(self, classifier=None, transform=None, sample_size=1): super(ClassifierScore, self).__init__() self.classifier = ( - torchvision.models.inception_v3(True) if classifier is None else classifier + torchvision.models.inception_v3(True) + if classifier is None + else classifier ) self.classifier.eval() self.transform = transform @@ -74,7 +76,9 @@ def metric_ops(self, generator, device): Returns: The Classifier Score (scalar quantity) """ - noise = 
torch.randn(self.sample_size, generator.encoding_dims, device=device) + noise = torch.randn( + self.sample_size, generator.encoding_dims, device=device + ) img = generator(noise).detach() score = self.__call__(img) return score diff --git a/torchgan/models/acgan.py b/torchgan/models/acgan.py index 91328dd..9720ee8 100644 --- a/torchgan/models/acgan.py +++ b/torchgan/models/acgan.py @@ -52,7 +52,9 @@ def __init__( ) self.encoding_dims = encoding_dims self.num_classes = num_classes - self.label_embeddings = nn.Embedding(self.num_classes, self.encoding_dims) + self.label_embeddings = nn.Embedding( + self.num_classes, self.encoding_dims + ) def forward(self, z, y): r"""Calculates the output tensor on passing the encoding ``z`` through the Generator. @@ -115,7 +117,11 @@ def __init__( last_nonlinearity, label_type="none", ) - last_nl = nn.LeakyReLU(0.2) if last_nonlinearity is None else last_nonlinearity + last_nl = ( + nn.LeakyReLU(0.2) + if last_nonlinearity is None + else last_nonlinearity + ) self.input_dims = in_channels self.num_classes = num_classes d = self.n * 2 ** (in_size.bit_length() - 4) diff --git a/torchgan/models/autoencoding.py b/torchgan/models/autoencoding.py index ef5ec9f..2d6f102 100644 --- a/torchgan/models/autoencoding.py +++ b/torchgan/models/autoencoding.py @@ -48,9 +48,9 @@ def __init__( label_type="none", ): super(AutoEncodingGenerator, self).__init__(encoding_dims, label_type) - if out_size < (scale_factor ** 4) or ceil(log(out_size, scale_factor)) != log( - out_size, scale_factor - ): + if out_size < (scale_factor ** 4) or ceil( + log(out_size, scale_factor) + ) != log(out_size, scale_factor): raise Exception( "Target image size must be at least {} and a perfect power of {}".format( scale_factor ** 4, scale_factor @@ -84,10 +84,14 @@ def __init__( nl, ) initial_unit = nn.Sequential( - nn.Conv2d(self.n, self.n, same_filters, 1, same_pad, bias=use_bias), + nn.Conv2d( + self.n, self.n, same_filters, 1, same_pad, bias=use_bias + ), nn.BatchNorm2d(self.n), nl, - nn.Conv2d(self.n, self.n, same_filters, 1, same_pad, bias=use_bias), + nn.Conv2d( + self.n, self.n, same_filters, 1, same_pad, bias=use_bias + ), nn.BatchNorm2d(self.n), nl, ) @@ -103,7 +107,9 @@ def __init__( ), nn.BatchNorm2d(self.n), nl, - nn.Conv2d(self.n, self.n, same_filters, 1, same_pad, bias=use_bias), + nn.Conv2d( + self.n, self.n, same_filters, 1, same_pad, bias=use_bias + ), nn.BatchNorm2d(self.n), nl, ) @@ -112,9 +118,13 @@ def __init__( nn.Linear(self.encoding_dims, (init_dim ** 2) * self.n), nl ) initial_unit = nn.Sequential( - nn.Conv2d(self.n, self.n, same_filters, 1, same_pad, bias=use_bias), + nn.Conv2d( + self.n, self.n, same_filters, 1, same_pad, bias=use_bias + ), nl, - nn.Conv2d(self.n, self.n, same_filters, 1, same_pad, bias=use_bias), + nn.Conv2d( + self.n, self.n, same_filters, 1, same_pad, bias=use_bias + ), nl, ) upsample_unit = nn.Sequential( @@ -128,12 +138,15 @@ def __init__( bias=use_bias, ), nl, - nn.Conv2d(self.n, self.n, same_filters, 1, same_pad, bias=use_bias), + nn.Conv2d( + self.n, self.n, same_filters, 1, same_pad, bias=use_bias + ), nl, ) last_unit = nn.Sequential( - nn.Conv2d(self.n, self.ch, same_filters, 1, same_pad, bias=True), last_nl + nn.Conv2d(self.n, self.ch, same_filters, 1, same_pad, bias=True), + last_nl, ) model = [initial_unit] for i in range(num_repeats): @@ -199,10 +212,12 @@ def __init__( embeddings=False, label_type="none", ): - super(AutoEncodingDiscriminator, self).__init__(in_channels, label_type) - if in_size < (scale_factor ** 4) or ceil(log(in_size, 
scale_factor)) != log( - in_size, scale_factor - ): + super(AutoEncodingDiscriminator, self).__init__( + in_channels, label_type + ) + if in_size < (scale_factor ** 4) or ceil( + log(in_size, scale_factor) + ) != log(in_size, scale_factor): raise Exception( "Input image size must be at least {} and a perfect power of {}".format( scale_factor ** 4, scale_factor @@ -229,7 +244,12 @@ def __init__( model.append( nn.Sequential( nn.Conv2d( - self.input_dims, self.n, same_filters, 1, same_pad, bias=True + self.input_dims, + self.n, + same_filters, + 1, + same_pad, + bias=True, ), nl, ) @@ -285,7 +305,9 @@ def __init__( ) ) self.fc = nn.Sequential( - nn.Linear((init_dim ** 2) * (num_repeats + 1) * self.n, encoding_dims), + nn.Linear( + (init_dim ** 2) * (num_repeats + 1) * self.n, encoding_dims + ), nn.BatchNorm1d(encoding_dims), last_nl, ) @@ -293,7 +315,9 @@ def __init__( for i in range(1, num_repeats + 1): model.append( nn.Sequential( - nn.Conv2d(self.n * i, self.n * i, 3, 1, 1, bias=use_bias), + nn.Conv2d( + self.n * i, self.n * i, 3, 1, 1, bias=use_bias + ), nl, nn.Conv2d( self.n * i, @@ -329,7 +353,9 @@ def __init__( ) ) self.fc = nn.Sequential( - nn.Linear((init_dim ** 2) * (num_repeats + 1) * self.n, encoding_dims), + nn.Linear( + (init_dim ** 2) * (num_repeats + 1) * self.n, encoding_dims + ), last_nl, ) self.encoder = nn.Sequential(*model) diff --git a/torchgan/models/conditional.py b/torchgan/models/conditional.py index 5342790..a27d8f8 100644 --- a/torchgan/models/conditional.py +++ b/torchgan/models/conditional.py @@ -52,7 +52,9 @@ def __init__( ) self.encoding_dims = encoding_dims self.num_classes = num_classes - self.label_embeddings = nn.Embedding(self.num_classes, self.num_classes) + self.label_embeddings = nn.Embedding( + self.num_classes, self.num_classes + ) def forward(self, z, y): r"""Calculates the output tensor on passing the encoding ``z`` through the Generator. @@ -119,7 +121,9 @@ def __init__( ) self.input_dims = in_channels self.num_classes = num_classes - self.label_embeddings = nn.Embedding(self.num_classes, self.num_classes) + self.label_embeddings = nn.Embedding( + self.num_classes, self.num_classes + ) def forward(self, x, y, feature_matching=False): r"""Calculates the output tensor on passing the image ``x`` through the Discriminator. 
diff --git a/torchgan/models/dcgan.py b/torchgan/models/dcgan.py index 7f94b4e..4713c69 100644 --- a/torchgan/models/dcgan.py +++ b/torchgan/models/dcgan.py @@ -59,7 +59,9 @@ def __init__( if batchnorm is True: model.append( nn.Sequential( - nn.ConvTranspose2d(self.encoding_dims, d, 4, 1, 0, bias=use_bias), + nn.ConvTranspose2d( + self.encoding_dims, d, 4, 1, 0, bias=use_bias + ), nn.BatchNorm2d(d), nl, ) @@ -76,20 +78,25 @@ def __init__( else: model.append( nn.Sequential( - nn.ConvTranspose2d(self.encoding_dims, d, 4, 1, 0, bias=use_bias), + nn.ConvTranspose2d( + self.encoding_dims, d, 4, 1, 0, bias=use_bias + ), nl, ) ) for i in range(num_repeats): model.append( nn.Sequential( - nn.ConvTranspose2d(d, d // 2, 4, 2, 1, bias=use_bias), nl + nn.ConvTranspose2d(d, d // 2, 4, 2, 1, bias=use_bias), + nl, ) ) d = d // 2 model.append( - nn.Sequential(nn.ConvTranspose2d(d, self.ch, 4, 2, 1, bias=True), last_nl) + nn.Sequential( + nn.ConvTranspose2d(d, self.ch, 4, 2, 1, bias=True), last_nl + ) ) self.model = nn.Sequential(*model) self._weight_initializer() @@ -152,9 +159,17 @@ def __init__( self.n = step_channels use_bias = not batchnorm nl = nn.LeakyReLU(0.2) if nonlinearity is None else nonlinearity - last_nl = nn.LeakyReLU(0.2) if last_nonlinearity is None else last_nonlinearity + last_nl = ( + nn.LeakyReLU(0.2) + if last_nonlinearity is None + else last_nonlinearity + ) d = self.n - model = [nn.Sequential(nn.Conv2d(self.input_dims, d, 4, 2, 1, bias=True), nl)] + model = [ + nn.Sequential( + nn.Conv2d(self.input_dims, d, 4, 2, 1, bias=True), nl + ) + ] if batchnorm is True: for i in range(num_repeats): model.append( @@ -168,10 +183,14 @@ def __init__( else: for i in range(num_repeats): model.append( - nn.Sequential(nn.Conv2d(d, d * 2, 4, 2, 1, bias=use_bias), nl) + nn.Sequential( + nn.Conv2d(d, d * 2, 4, 2, 1, bias=use_bias), nl + ) ) d *= 2 - self.disc = nn.Sequential(nn.Conv2d(d, 1, 4, 1, 0, bias=use_bias), last_nl) + self.disc = nn.Sequential( + nn.Conv2d(d, 1, 4, 1, 0, bias=use_bias), last_nl + ) self.model = nn.Sequential(*model) self._weight_initializer() diff --git a/torchgan/models/infogan.py b/torchgan/models/infogan.py index 43c287e..8274ee2 100644 --- a/torchgan/models/infogan.py +++ b/torchgan/models/infogan.py @@ -130,7 +130,9 @@ def __init__( ) self.latent_nl = ( - nn.LeakyReLU(0.2) if latent_nonlinearity is None else latent_nonlinearity + nn.LeakyReLU(0.2) + if latent_nonlinearity is None + else latent_nonlinearity ) d = self.n * 2 ** (in_size.bit_length() - 4) if batchnorm is True: @@ -155,7 +157,9 @@ def forward(self, x, return_latents=False, feature_matching=False): return x critic_score = self.disc(x) x = self.dist_conv(x).view(-1, x.size(1)) - dist_dis = distributions.OneHotCategorical(logits=self.dis_categorical(x)) + dist_dis = distributions.OneHotCategorical( + logits=self.dis_categorical(x) + ) dist_cont = distributions.Normal( loc=self.cont_mean(x), scale=torch.exp(0.5 * self.cont_logvar(x)) ) diff --git a/torchgan/trainer/base_trainer.py b/torchgan/trainer/base_trainer.py index 2b40f64..7b9d7f8 100644 --- a/torchgan/trainer/base_trainer.py +++ b/torchgan/trainer/base_trainer.py @@ -151,7 +151,9 @@ def save_model(self, epoch, save_items=None): """ if self.last_retained_checkpoint == self.retain_checkpoints: self.last_retained_checkpoint = 0 - save_path = self.checkpoints + str(self.last_retained_checkpoint) + ".model" + save_path = ( + self.checkpoints + str(self.last_retained_checkpoint) + ".model" + ) self.last_retained_checkpoint += 1 print("Saving Model at 
'{}'".format(save_path)) model = { @@ -195,7 +197,11 @@ def load_model(self, load_path="", load_items=None): from scratch. So make sure that item was saved. """ if load_path == "": - load_path = self.checkpoints + str(self.last_retained_checkpoint) + ".model" + load_path = ( + self.checkpoints + + str(self.last_retained_checkpoint) + + ".model" + ) print("Loading Model From '{}'".format(load_path)) try: checkpoint = torch.load(load_path) @@ -213,7 +219,9 @@ def load_model(self, load_path="", load_items=None): else: setattr(self, load_items, checkpoint["load_items"]) except: - raise Exception("Model could not be loaded from {}.".format(load_path)) + raise Exception( + "Model could not be loaded from {}.".format(load_path) + ) def _get_argument_maps(self, default_map, func): r"""Extracts the signature of the `func`. Then it returns the list of arguments that @@ -239,8 +247,14 @@ def _get_argument_maps(self, default_map, func): if arg_name in self.__dict__: arg_map.update({arg: arg_name}) else: - if arg_name not in self.__dict__ and arg != "kwargs" and arg != "args": - raise Exception("Argument : {} not present.".format(arg_name)) + if ( + arg_name not in self.__dict__ + and arg != "kwargs" + and arg != "args" + ): + raise Exception( + "Argument : {} not present.".format(arg_name) + ) else: arg_map.update({arg: arg_name}) return arg_map @@ -306,7 +320,9 @@ def train_iter(self): loss_logs = self.logger.get_loss_viz() grad_logs = self.logger.get_grad_viz() for name, loss in self.losses.items(): - if isinstance(loss, GeneratorLoss) and isinstance(loss, DiscriminatorLoss): + if isinstance(loss, GeneratorLoss) and isinstance( + loss, DiscriminatorLoss + ): # NOTE(avik-pal): In most cases this loss is meant to optimize the Discriminator # but we might need to think of a better solution if self.loss_information["generator_iters"] % self.ngen == 0: @@ -325,9 +341,14 @@ def train_iter(self): # NOTE(avik-pal): We assume that it is a Discriminator Loss by default. ldis, dis_iter = ldis + cur_loss, dis_iter + 1 for model_name in self.model_names: - grad_logs.update_grads(model_name, getattr(self, model_name)) + grad_logs.update_grads( + model_name, getattr(self, model_name) + ) elif isinstance(loss, GeneratorLoss): - if self.loss_information["discriminator_iters"] % self.ncritic == 0: + if ( + self.loss_information["discriminator_iters"] % self.ncritic + == 0 + ): cur_loss = loss.train_ops( **self._get_arguments(self.loss_arg_maps[name]) ) @@ -358,7 +379,9 @@ def eval_ops(self, **kwargs): for name, metric in self.metrics.items(): metric_logs = self.logger.get_metric_viz() metric_logs.logs[name].append( - metric.metric_ops(**self._get_arguments(self.metric_arg_maps[name])) + metric.metric_ops( + **self._get_arguments(self.metric_arg_maps[name]) + ) ) def optim_ops(self): diff --git a/torchgan/trainer/lightning_trainer.py b/torchgan/trainer/lightning_trainer.py new file mode 100644 index 0000000..0fd8897 --- /dev/null +++ b/torchgan/trainer/lightning_trainer.py @@ -0,0 +1,319 @@ +import os +import time +from types import * +import logging as lg +import warnings +from inspect import _empty, signature +from warnings import warn + +import torch +import torchvision + +from torchgan.losses import DiscriminatorLoss, GeneratorLoss +from torchgan.models import Discriminator, Generator + +import pytorch_lightning as pl + + +class LightningTrainer(pl.LightningModule): + r"""Trainer for TorchGAN built on top of Pytorch Lightning. 
This shall be the + default trainer post 0.1 release and all other trainers shall be deprecated. + """ + + def __init__( + self, + models_config, + losses_list, + train_dataloader, + val_dataloader=None, + metrics_list=None, + ncritic=1, + sample_size=8, + test_noise=None, + batch_size=1, + **kwargs, + ): + self.model_names = [] + self.optimizer_names = [] + self.schedulers = [] + + for key, model_config in models_config.items(): + # Model creation and storage + self.model_names.append(key) + if "model" in model_config: + setattr(self, key, model_config["model"]) + elif "args" in model_config or "name" in model_config: + warnings.warn( + "This is the old TorchGAN API. It is deprecated" + + "and shall be removed post v0.1. Please update to" + + "instantiating the model and pass it using the key" + + "`model`", + FutureWarning, + ) + args = model_config.get("args", {}) + # Instantiate a GAN model + setattr(self, key, (model_config["name"])(**args)) + else: + raise Exception( + f"Couldn't find/instantiate the model corresponding to" + + f"{key}" + ) + model = getattr(self, key) + + # Dealing with the optimizers + opt = model_config.get("optimizer", {}) + if type(opt) is dict: + if "optimizer" not in model_config: + warnings.warn( + "This is the old TorchGAN API. It is deprecated" + + "and shall be removed post v0.1. Please update to" + + "creating a lambda function taking as input" + + "the model parameters", + FutureWarning, + ) + opt_name = opt.get("var", f"optimizer_{key}") + self.optimizer_names.append(opt_name) + setattr( + self, + opt_name, + opt.get("name", torch.optim.Adam)( + model.parameters(), **opt.get("args", {}) + ), + ) + elif type(opt) is FunctionType: + opt_name = opt.get("optimizer_name", f"optimizer_{key}") + self.optimizer_names.append(opt_name) + setattr(self, opt_name, opt(model.parameters())) + else: + raise Exception( + f"Couldn't find/instantiate the optimizer corresponding to" + + f"{key}" + ) + + # TODO: Deal with schedulers + + self.losses = {} + for loss in losses_list: + self.losses[type(loss).__name__] = loss + + if metrics_list is None: + self.metrics = None + else: + self.metrics = {} + for metric in metrics_list: + self.metrics[type(metric).__name__] = metric + + self.sample_size = sample_size + self.batch_size = batch_size + + # Not needed but we need to store this to avoid errors. 
+ # Also makes life simpler + self.noise = None + self.real_inputs = None + self.labels = None + self.batch_size = 1 + + self.generator_steps = 0 + self.discriminator_steps = 0 + + assert ncritic != 0 + if ncritic > 0: + self.ncritic = ncritic + self.ngen = 1 + else: + self.ncritic = 1 + self.ngen = abs(ncritic) + + for key, val in kwargs.items(): + if key in self.__dict__: + warn( + "Overiding the default value of {} from {} to {}".format( + key, getattr(self, key), val + ) + ) + setattr(self, key, val) + + # This is only temporarily stored and is deleted + self.train_dataloader_cached = train_dataloader + self.val_dataloader_cached = val_dataloader + + def configure_optimizers(self): + optimizers = [getattr(self, name) for name in self.optimizer_names] + return optimizers # , self.schedulers + + @pl.data_loader + def train_dataloader(self): + # TODO: Test if this actually works + train_dataloader_cached = self.train_dataloader_cached + del self.train_dataloader_cached + return train_dataloader_cached + + @pl.data_loader + def val_dataloader(self): + # TODO: Test if this actually works + val_dataloader_cached = self.val_dataloader_cached + del self.val_dataloader_cached + return val_dataloader_cached + + def _get_argument_maps(self, default_map, func): + r"""Extracts the signature of the `func`. Then it returns the list of + arguments that are present in the object and need to be mapped and + passed to the `func` when calling it. + + Args: + default_map (dict): The keys of this dictionary override the + function signature. + func (function): Function whose argument map is to be generated. + + Returns: + List of arguments that need to be fed into the function. It contains + all the positional arguments and keyword arguments that are stored + in the object. If any of the required arguments are not present an + error is thrown. + """ + sig = signature(func) + arg_map = {} + for sig_param in sig.parameters.values(): + arg = sig_param.name + arg_name = arg + if arg in default_map: + arg_name = default_map[arg] + if sig_param.default is not _empty: + if arg_name in self.__dict__: + arg_map.update({arg: arg_name}) + else: + if ( + arg_name not in self.__dict__ + and arg != "kwargs" + and arg != "args" + ): + raise Exception( + "Argument : {} not present.".format(arg_name) + ) + else: + arg_map.update({arg: arg_name}) + return arg_map + + def _store_metric_maps(self): + r"""Creates a mapping between the metrics and the arguments from the object that need to be + passed to it. + """ + if self.metrics is not None: + self.metric_arg_maps = {} + for name, metric in self.metrics.items(): + self.metric_arg_maps[name] = self._get_argument_maps( + metric.arg_map, metric.metric_ops + ) + + def _store_loss_maps(self): + r"""Creates a mapping between the losses and the arguments from the object that need to be + passed to it. + """ + self.loss_arg_maps = {} + for name, loss in self.losses.items(): + self.loss_arg_maps[name] = self._get_argument_maps( + loss.arg_map, loss.train_ops + ) + + def _get_arguments(self, arg_map): + r"""Get the argument values from the object and create a dictionary. + + Args: + arg_map (dict): A dict of arguments that is generated by `_get_argument_maps`. + + Returns: + A dictionary mapping the argument name to the value of the argument. 
+ """ + args = {} + for key, val in arg_map.items(): + args[key] = self.__dict__[val] + return args + + def optimizer_step( + self, + current_epoch, + batch_idx, + optimizer, + optimizer_idx, + second_order_closure=None, + ): + # We handle the optimizer step and zero_grad in the train_ops + # itself, so override Pytorch Lightning's default function + return + + def handle_data_batch(self, batch): + if type(batch) in (tuple, list): + self.real_inputs = data[0].to(self.device) + self.labels = data[1].to(self.device) + elif type(data) is torch.Tensor: + self.real_inputs = data.to(self.device) + else: + self.real_inputs = data + + def training_step(self, batch, batch_idx, opt_idx): + self.handle_data_batch(batch) + gen_loss = 0.0 + dis_loss = 0.0 + + train_gen = self.discriminator_steps % self.ncritic == 0 + train_dis = self.generator_steps % self.ngen == 0 + + for name, loss in self.losses.items(): + lgen = isinstance(loss, GeneratorLoss) + ldis = isinstance(loss, DiscriminatorLoss) + + if lgen and ldis: + if train_dis: + cur_loss = loss.train_ops( + **self._get_arguments(self.loss_arg_maps[name]) + ) + + if type(cur_loss) in (tuple, list): + gen_loss += cur_loss[0] + self.generator_steps += 1 + dis_loss += cur_loss[1] + self.discriminator_steps += 1 + else: + dis_loss += cur_loss + self.discriminator_steps += 1 + elif lgen: + if train_gen: + cur_loss = loss.train_ops( + **self._get_arguments(self.loss_arg_maps[name]) + ) + + gen_loss += cur_loss + self.generator_steps += 1 + elif ldis: + if train_dis: + cur_loss = loss.train_ops( + **self._get_arguments(self.loss_arg_maps[name]) + ) + + dis_loss += cur_loss + self.discriminator_steps += 1 + else: + raise Exception( + f"type({loss}) is {type(loss)} which is not a subclass" + + f"of GeneratorLoss / DiscriminatorLoss" + ) + # Bypass Lightning by passing a zero loss + loss = torch.zeros(1) + loss.requires_grad = True + return { + "loss": loss, + "log": {"Generator Loss": gen_loss, "DiscriminatorLoss": dis_loss}, + } + + # + # def validation_step(self, batch, batch_idx): + # # OPTIONAL + # x, y = batch + # y_hat = self.forward(x) + # return {'val_loss': F.cross_entropy(y_hat, y)} + # + # def validation_end(self, outputs): + # # OPTIONAL + # avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() + # tensorboard_logs = {'val_loss': avg_loss} + # return {'avg_val_loss': avg_loss, 'log': tensorboard_logs} diff --git a/torchgan/trainer/parallel_trainer.py b/torchgan/trainer/parallel_trainer.py index 547cecd..7697d9d 100644 --- a/torchgan/trainer/parallel_trainer.py +++ b/torchgan/trainer/parallel_trainer.py @@ -99,7 +99,9 @@ def __init__( for key, model in models.items(): self.model_names.append(key) if "args" in model: - setattr(self, key, (model["name"](**model["args"])).to(self.device)) + setattr( + self, key, (model["name"](**model["args"])).to(self.device) + ) else: setattr(self, key, (model["name"]()).to(self.device)) for m in getattr(self, key)._modules: @@ -113,7 +115,9 @@ def __init__( self.optimizer_names.append(opt_name) model_params = getattr(self, key).parameters() if "args" in opt: - setattr(self, opt_name, (opt["name"](model_params, **opt["args"]))) + setattr( + self, opt_name, (opt["name"](model_params, **opt["args"])) + ) else: setattr(self, opt_name, (opt["name"](model_params))) if "scheduler" in opt: @@ -123,7 +127,9 @@ def __init__( sched["name"](getattr(self, opt_name), **sched["args"]) ) else: - self.schedulers.append(sched["name"](getattr(self, opt_name))) + self.schedulers.append( + sched["name"](getattr(self, 
opt_name)) + ) self.logger = Logger( self, diff --git a/torchgan/trainer/trainer.py b/torchgan/trainer/trainer.py index 61ce6a1..a6bab6b 100755 --- a/torchgan/trainer/trainer.py +++ b/torchgan/trainer/trainer.py @@ -100,7 +100,9 @@ def __init__( for key, model in models.items(): self.model_names.append(key) if "args" in model: - setattr(self, key, (model["name"](**model["args"])).to(self.device)) + setattr( + self, key, (model["name"](**model["args"])).to(self.device) + ) else: setattr(self, key, (model["name"]()).to(self.device)) opt = model["optimizer"] @@ -110,7 +112,9 @@ def __init__( self.optimizer_names.append(opt_name) model_params = getattr(self, key).parameters() if "args" in opt: - setattr(self, opt_name, (opt["name"](model_params, **opt["args"]))) + setattr( + self, opt_name, (opt["name"](model_params, **opt["args"])) + ) else: setattr(self, opt_name, (opt["name"](model_params))) if "scheduler" in opt: @@ -120,7 +124,9 @@ def __init__( sched["name"](getattr(self, opt_name), **sched["args"]) ) else: - self.schedulers.append(sched["name"](getattr(self, opt_name))) + self.schedulers.append( + sched["name"](getattr(self, opt_name)) + ) self.logger = Logger( self, From fb3db2ba4ec716145edc4718925e4f72aa827a49 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 4 May 2020 21:25:12 -0400 Subject: [PATCH 2/3] Fix import path --- torchgan/trainer/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchgan/trainer/__init__.py b/torchgan/trainer/__init__.py index 1f9a0cf..f27e558 100644 --- a/torchgan/trainer/__init__.py +++ b/torchgan/trainer/__init__.py @@ -1,3 +1,4 @@ from .base_trainer import * +from .lightning_trainer import * from .parallel_trainer import * from .trainer import * From c51f8aee33f51103db3dc58797c00cdcac974abd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 4 May 2020 22:56:17 -0400 Subject: [PATCH 3/3] Fix attributes specific to lightning --- torchgan/trainer/lightning_trainer.py | 80 +++++++++++++++------------ 1 file changed, 46 insertions(+), 34 deletions(-) diff --git a/torchgan/trainer/lightning_trainer.py b/torchgan/trainer/lightning_trainer.py index 0fd8897..ab9c9b0 100644 --- a/torchgan/trainer/lightning_trainer.py +++ b/torchgan/trainer/lightning_trainer.py @@ -15,7 +15,7 @@ import pytorch_lightning as pl -class LightningTrainer(pl.LightningModule): +class LightningGANModule(pl.LightningModule): r"""Trainer for TorchGAN built on top of Pytorch Lightning. This shall be the default trainer post 0.1 release and all other trainers shall be deprecated. """ @@ -30,9 +30,10 @@ def __init__( ncritic=1, sample_size=8, test_noise=None, - batch_size=1, **kwargs, ): + super().__init__() + self.model_names = [] self.optimizer_names = [] self.schedulers = [] @@ -104,14 +105,12 @@ def __init__( self.metrics[type(metric).__name__] = metric self.sample_size = sample_size - self.batch_size = batch_size # Not needed but we need to store this to avoid errors. # Also makes life simpler self.noise = None self.real_inputs = None self.labels = None - self.batch_size = 1 self.generator_steps = 0 self.discriminator_steps = 0 @@ -124,6 +123,10 @@ def __init__( self.ncritic = 1 self.ngen = abs(ncritic) + # This exists for convenience. 
We will handle the device from data in + # the `training_step` function + self.device = torch.device("cpu") + for key, val in kwargs.items(): if key in self.__dict__: warn( @@ -137,24 +140,19 @@ def __init__( self.train_dataloader_cached = train_dataloader self.val_dataloader_cached = val_dataloader + self._store_loss_maps() + self._store_metric_maps() + + def configure_optimizers(self): optimizers = [getattr(self, name) for name in self.optimizer_names] return optimizers # , self.schedulers @pl.data_loader def train_dataloader(self): - # TODO: Test if this actually works train_dataloader_cached = self.train_dataloader_cached - del self.train_dataloader_cached return train_dataloader_cached - @pl.data_loader - def val_dataloader(self): - # TODO: Test if this actually works - val_dataloader_cached = self.val_dataloader_cached - del self.val_dataloader_cached - return val_dataloader_cached - def _get_argument_maps(self, default_map, func): r"""Extracts the signature of the `func`. Then it returns the list of arguments that are present in the object and need to be mapped and @@ -179,11 +177,12 @@ def _get_argument_maps(self, default_map, func): if arg in default_map: arg_name = default_map[arg] if sig_param.default is not _empty: - if arg_name in self.__dict__: + if arg_name in self.__dict__ or arg_name in self.__dict__["_modules"]: arg_map.update({arg: arg_name}) else: if ( arg_name not in self.__dict__ + and arg_name not in self.__dict__["_modules"] and arg != "kwargs" and arg != "args" ): @@ -226,7 +225,13 @@ def _get_arguments(self, arg_map): """ args = {} for key, val in arg_map.items(): - args[key] = self.__dict__[val] + if val == "device": + args[key] = self._get_device_from_tensor(self.real_inputs) + continue + if val in self.__dict__: + args[key] = self.__dict__[val] + else: + args[key] = self.__dict__["_modules"][val] return args def optimizer_step( @@ -243,15 +248,31 @@ def optimizer_step( def handle_data_batch(self, batch): if type(batch) in (tuple, list): - self.real_inputs = data[0].to(self.device) - self.labels = data[1].to(self.device) - elif type(data) is torch.Tensor: - self.real_inputs = data.to(self.device) + self.real_inputs = batch[0] + self.labels = batch[1] + elif type(batch) is torch.Tensor: + self.real_inputs = batch else: - self.real_inputs = data + self.real_inputs = batch + + def _unfreeze_parameters(self): + for name in self.model_names: + model = getattr(self, name) + for param in model.parameters(): + param.requires_grad = True + + def _get_device_from_tensor(self, x: torch.Tensor): + if self.on_gpu: + device = torch.device(f"cuda:{x.device.index}") + return device + return torch.device("cpu") - def training_step(self, batch, batch_idx, opt_idx): + def training_step(self, batch, batch_idx, optimizer_idx): self.handle_data_batch(batch) + # FIXME: PyLightning seems to convert all the parameters to + # require no grad. 
+        self._unfreeze_parameters()
+
         gen_loss = 0.0
         dis_loss = 0.0

@@ -302,18 +323,9 @@ def training_step(self, batch, batch_idx, opt_idx):
         loss.requires_grad = True
         return {
             "loss": loss,
-            "log": {"Generator Loss": gen_loss, "DiscriminatorLoss": dis_loss},
+            "progress_bar": {"Generator Loss": gen_loss, "Discriminator Loss": dis_loss},
         }

-    #
-    # def validation_step(self, batch, batch_idx):
-    #     # OPTIONAL
-    #     x, y = batch
-    #     y_hat = self.forward(x)
-    #     return {'val_loss': F.cross_entropy(y_hat, y)}
-    #
-    # def validation_end(self, outputs):
-    #     # OPTIONAL
-    #     avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
-    #     tensorboard_logs = {'val_loss': avg_loss}
-    #     return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
+    def forward(self, x):
+        pass
+
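
Note on the ncritic argument of LightningGANModule: a positive value requests the familiar WGAN-style schedule of several discriminator (critic) updates per generator update, while a negative value flips the ratio and runs abs(ncritic) generator updates per discriminator update; zero is rejected by the assert. The standalone sketch below simply mirrors the normalization performed in __init__ and is illustrative only, not part of the patch.

# Mirrors the ncritic normalization in LightningGANModule.__init__.
def normalize_ncritic(ncritic):
    assert ncritic != 0
    if ncritic > 0:
        return ncritic, 1          # (ncritic, ngen)
    return 1, abs(ncritic)


print(normalize_ncritic(5))        # (5, 1): 5 critic steps per generator step
print(normalize_ncritic(-3))       # (1, 3): 3 generator steps per critic step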
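
The _get_argument_maps / _get_arguments pair routes arguments to each loss's train_ops by name: every parameter in the train_ops signature is resolved to an attribute of the module, after applying the loss's arg_map renames, and a required parameter with no matching attribute raises an error. A simplified standalone illustration follows; DemoLoss and DemoTrainer are hypothetical stand-ins, not torchgan classes.

# Simplified illustration of signature-based argument routing.
from inspect import Parameter, signature


class DemoLoss:
    # train_ops wants the trainer's `generator` attribute passed in as `gen`.
    arg_map = {"gen": "generator"}

    def train_ops(self, gen, batch_size, labels=None):
        return gen, batch_size, labels


class DemoTrainer:
    def __init__(self):
        self.generator = "generator-network"
        self.batch_size = 32
        self.labels = None

    def build_arg_map(self, loss):
        arg_map = {}
        for param in signature(loss.train_ops).parameters.values():
            name = loss.arg_map.get(param.name, param.name)
            if name in self.__dict__:
                arg_map[param.name] = name
            elif param.default is Parameter.empty and param.name not in ("args", "kwargs"):
                raise AttributeError("Argument {} not present.".format(name))
        return arg_map

    def gather_arguments(self, arg_map):
        return {key: getattr(self, val) for key, val in arg_map.items()}


trainer = DemoTrainer()
loss = DemoLoss()
mapping = trainer.build_arg_map(loss)
# mapping == {"gen": "generator", "batch_size": "batch_size", "labels": "labels"}
print(loss.train_ops(**trainer.gather_arguments(mapping)))
# -> ('generator-network', 32, None)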
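
Finally, a hypothetical end-to-end usage sketch of the new module with the Lightning Trainer. The positional order (models, losses_list) and the train_dataloader / ncritic keywords are assumptions inferred from the attributes cached in __init__; the exact constructor signature should be confirmed against the merged code, and the FIXMEs noted in the diff may still affect a real run.

# Hypothetical usage sketch; argument names are assumed, not confirmed by this patch.
import torch
import pytorch_lightning as pl
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

from torchgan.losses import MinimaxDiscriminatorLoss, MinimaxGeneratorLoss
from torchgan.models import DCGANDiscriminator, DCGANGenerator
from torchgan.trainer.lightning_trainer import LightningGANModule

# Synthetic 32x32 RGB images in [-1, 1] with dummy labels, standing in for a
# real dataset such as CIFAR10.
images = torch.rand(256, 3, 32, 32) * 2 - 1
targets = torch.zeros(256, dtype=torch.long)
loader = DataLoader(TensorDataset(images, targets), batch_size=32, shuffle=True)

models = {
    "generator": {
        "name": DCGANGenerator,
        "args": {"encoding_dims": 100, "out_channels": 3, "step_channels": 16},
        "optimizer": {"name": Adam, "args": {"lr": 1e-4, "betas": (0.5, 0.999)}},
    },
    "discriminator": {
        "name": DCGANDiscriminator,
        "args": {"in_channels": 3, "step_channels": 16},
        "optimizer": {"name": Adam, "args": {"lr": 3e-4, "betas": (0.5, 0.999)}},
    },
}
losses = [MinimaxGeneratorLoss(), MinimaxDiscriminatorLoss()]

module = LightningGANModule(models, losses, train_dataloader=loader, ncritic=1)
trainer = pl.Trainer(max_epochs=1)  # add gpus=1 to train on a single GPU
trainer.fit(module)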