.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_training_constrained_net_on_mnist.py>`
        to download the full example code.

    .. rst-class:: sphx-glr-example-title

    .. _sphx_glr_auto_examples_training_constrained_net_on_mnist.py:


Constrained neural network training.
======================================
Trains a LeNet5 model on MNIST using constraints on the weights.


.. code-block:: python

    from tqdm import tqdm

    import numpy as np
    import torch
    from torch import nn

    from easydict import EasyDict

    from advertorch.test_utils import LeNet5
    from advertorch_examples.utils import get_mnist_train_loader
    from advertorch_examples.utils import get_mnist_test_loader

    import chop

    # Setup: fix the RNG seed for reproducible runs.
    torch.manual_seed(0)

    # Run on the GPU when available, otherwise fall back to the CPU.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Data loaders for the MNIST train and test splits.
    train_loader = get_mnist_train_loader(batch_size=50, shuffle=True)
    test_loader = get_mnist_test_loader(batch_size=512, shuffle=True)

    # Model setup
    model = LeNet5()
    model.to(device)

    criterion = nn.CrossEntropyLoss()

    # Outer optimization parameters
    nb_epochs = 20
    momentum = .9
    lr = 0.3

    # Choose optimizer here
    # optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Make constraints: each parameter is constrained to an
    # L-infinity ball (radius given by alpha).
    alpha = 1.
    constraint = chop.constraints.LinfBall(alpha)

    # Project model parameters in the constraint set, so that the
    # optimizer starts from a feasible point.
    constraint.make_feasible(model)

    # Stochastic Frank-Wolfe optimizer over the constraint set.
    optimizer = chop.stochastic.FrankWolfe(model.parameters(), constraint, lr=lr, momentum=momentum)

    # Training loop
    for epoch in range(nb_epochs):
        model.train()
        train_loss = 0.
        for data, target in tqdm(train_loader, desc=f'Training epoch {epoch}/{nb_epochs - 1}'):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
        # Report the mean batch loss for the epoch.
        train_loss /= len(train_loader)
        print(f'Training loss: {train_loss:.3f}')
        # TODO: get accuracy

        # Evaluate on clean test data. NOTE(review): only nb_test and
        # correct are updated below; the correct_adv_* counters are
        # unused placeholders.
        model.eval()
        report = EasyDict(nb_test=0, correct=0, correct_adv_pgd=0,
                          correct_adv_pgd_madry=0,
                          correct_adv_fw=0, correct_adv_mfw=0)

        # No gradients are needed for evaluation; disabling autograd
        # avoids building the computation graph for the forward passes.
        with torch.no_grad():
            for data, target in tqdm(test_loader, desc=f'Val epoch {epoch}/{nb_epochs - 1}'):
                data, target = data.to(device), target.to(device)

                # Predicted class is the argmax over the logits.
                _, pred = model(data).max(1)

                # Get clean accuracies
                report.nb_test += data.size(0)
                report.correct += pred.eq(target).sum().item()

        print(f'Val acc on clean examples (%): {report.correct / report.nb_test * 100.:.3f}')


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.000 seconds)

**Estimated memory usage:**  0 MB


.. _sphx_glr_download_auto_examples_training_constrained_net_on_mnist.py:


.. only:: html

   .. container:: sphx-glr-footer
      :class: sphx-glr-footer-example


      .. container:: sphx-glr-download sphx-glr-download-python

         :download:`Download Python source code: training_constrained_net_on_mnist.py <training_constrained_net_on_mnist.py>`



      .. container:: sphx-glr-download sphx-glr-download-jupyter

         :download:`Download Jupyter notebook: training_constrained_net_on_mnist.ipynb <training_constrained_net_on_mnist.ipynb>`


.. only:: html

   .. rst-class:: sphx-glr-signature

      `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
