.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_adversarial_robustness_plot_train_robust_cifar10.py>`     to download the full example code
    .. rst-class:: sphx-glr-example-title

    .. _sphx_glr_auto_examples_adversarial_robustness_plot_train_robust_cifar10.py:


Example of robust training on CIFAR10.
=========================================
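
Robust training can be sketched as the min-max problem below. This assumes the standard formulation of Madry et al., which the script's use of ``chop.optim.minimize_pgd_madry`` suggests. Here :math:`f_\theta` is the network, :math:`\ell` the cross-entropy loss, and :math:`\Delta(x)` the set of admissible perturbations of an image :math:`x`.

.. math::

    \min_{\theta} \; \mathbb{E}_{(x, y)} \left[ \max_{\delta \in \Delta(x)} \ell\big(f_\theta(x + \delta), y\big) \right],
    \qquad
    \Delta(x) = \left\{ \delta : \|\delta\|_\infty \le \varepsilon, \; 0 \le x + \delta \le 1 \right\}

In the script, :math:`\varepsilon` is ``alpha = 8. / 255``. The inner maximization is approximated with ``max_iter_train = 7`` PGD iterations per batch, and the outer minimization uses SGD with momentum.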



.. image:: /auto_examples/adversarial_robustness/images/sphx_glr_plot_train_robust_cifar10_001.png
    :alt: Clean data accuracies, Adversarial data accuracies
    :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Files already downloaded and verified
    Files already downloaded and verified
    Training on Linf ball(0.03137254901960784).
    Train Accuracy: 32.6%
    Train Adv Accuracy: 22.9%
    Test Accuracy: 30.7%
    Test Adv Accuracy: 21.2%
    Train Accuracy: 46.3%
    Train Adv Accuracy: 27.6%
    Test Accuracy: 44.9%
    Test Adv Accuracy: 27.0%
    Train Accuracy: 54.6%
    Train Adv Accuracy: 29.2%
    Test Accuracy: 38.1%
    Test Adv Accuracy: 27.2%
    Train Accuracy: 59.5%
    Train Adv Accuracy: 30.7%
    Test Accuracy: 53.8%
    Test Adv Accuracy: 27.7%
    Train Accuracy: 63.4%
    Train Adv Accuracy: 31.9%
    Test Accuracy: 52.7%
    Test Adv Accuracy: 26.3%
    Train Accuracy: 65.8%
    Train Adv Accuracy: 31.9%
    Test Accuracy: 44.2%
    Test Adv Accuracy: 28.0%
    Train Accuracy: 67.4%
    Train Adv Accuracy: 32.4%
    Test Accuracy: 57.8%
    Test Adv Accuracy: 32.5%
    Train Accuracy: 69.1%
    Train Adv Accuracy: 32.1%
    Test Accuracy: 59.3%
    Test Adv Accuracy: 26.5%
    Train Accuracy: 70.2%
    Train Adv Accuracy: 33.0%
    Test Accuracy: 58.7%
    Test Adv Accuracy: 29.1%
    Train Accuracy: 71.2%
    Train Adv Accuracy: 33.3%
    Test Accuracy: 62.5%
    Test Adv Accuracy: 31.9%
    Train Accuracy: 71.6%
    Train Adv Accuracy: 33.1%
    Test Accuracy: 60.6%
    Test Adv Accuracy: 28.7%
    Train Accuracy: 72.2%
    Train Adv Accuracy: 33.3%
    Test Accuracy: 60.7%
    Test Adv Accuracy: 30.0%
    Train Accuracy: 72.6%
    Train Adv Accuracy: 33.1%
    Test Accuracy: 66.1%
    Test Adv Accuracy: 23.6%
    Train Accuracy: 73.2%
    Train Adv Accuracy: 33.6%
    Test Accuracy: 60.9%
    Test Adv Accuracy: 28.9%
    Train Accuracy: 73.6%
    Train Adv Accuracy: 34.0%
    Test Accuracy: 60.7%
    Test Adv Accuracy: 31.1%
    Train Accuracy: 74.0%
    Train Adv Accuracy: 34.0%
    Test Accuracy: 63.8%
    Test Adv Accuracy: 27.6%
    Train Accuracy: 74.4%
    Train Adv Accuracy: 34.1%
    Test Accuracy: 63.6%
    Test Adv Accuracy: 28.6%
    Train Accuracy: 74.7%
    Train Adv Accuracy: 33.8%
    Test Accuracy: 62.9%
    Test Adv Accuracy: 26.4%
    Train Accuracy: 74.9%
    Train Adv Accuracy: 33.9%
    Test Accuracy: 61.2%
    Test Adv Accuracy: 28.0%
    Train Accuracy: 75.1%
    Train Adv Accuracy: 34.2%
    Test Accuracy: 61.8%
    Test Adv Accuracy: 28.1%
    Train Accuracy: 75.0%
    Train Adv Accuracy: 34.1%
    Test Accuracy: 63.5%
    Test Adv Accuracy: 32.7%
    Train Accuracy: 75.4%
    Train Adv Accuracy: 34.2%
    Test Accuracy: 62.5%
    Test Adv Accuracy: 28.4%
    Train Accuracy: 75.3%
    Train Adv Accuracy: 34.1%
    Test Accuracy: 62.2%
    Test Adv Accuracy: 30.5%
    Train Accuracy: 75.8%
    Train Adv Accuracy: 34.2%
    Test Accuracy: 56.4%
    Test Adv Accuracy: 29.4%
    Train Accuracy: 76.0%
    Train Adv Accuracy: 33.9%
    Test Accuracy: 62.2%
    Test Adv Accuracy: 27.7%
    Train Accuracy: 76.0%
    Train Adv Accuracy: 34.6%
    Test Accuracy: 63.8%
    Test Adv Accuracy: 28.6%
    Train Accuracy: 76.1%
    Train Adv Accuracy: 34.3%
    Test Accuracy: 58.4%
    Test Adv Accuracy: 30.3%
    Train Accuracy: 76.1%
    Train Adv Accuracy: 34.6%
    Test Accuracy: 66.1%
    Test Adv Accuracy: 30.0%
    Train Accuracy: 75.9%
    Train Adv Accuracy: 34.6%
    Test Accuracy: 64.3%
    Test Adv Accuracy: 31.4%
    Train Accuracy: 76.3%
    Train Adv Accuracy: 34.3%
    Test Accuracy: 60.1%
    Test Adv Accuracy: 31.6%
    Train Accuracy: 76.4%
    Train Adv Accuracy: 34.4%
    Test Accuracy: 51.9%
    Test Adv Accuracy: 31.9%
    Train Accuracy: 76.6%
    Train Adv Accuracy: 34.8%
    Test Accuracy: 65.5%
    Test Adv Accuracy: 33.0%
    Train Accuracy: 76.5%
    Train Adv Accuracy: 34.5%
    Test Accuracy: 56.4%
    Test Adv Accuracy: 31.8%
    Train Accuracy: 76.5%
    Train Adv Accuracy: 34.2%
    Test Accuracy: 65.2%
    Test Adv Accuracy: 29.7%
    Train Accuracy: 76.6%
    Train Adv Accuracy: 34.6%
    Test Accuracy: 62.9%
    Test Adv Accuracy: 31.8%
    Train Accuracy: 76.5%
    Train Adv Accuracy: 34.5%
    Test Accuracy: 58.2%
    Test Adv Accuracy: 25.9%
    Train Accuracy: 76.4%
    Train Adv Accuracy: 34.7%
    Test Accuracy: 63.8%
    Test Adv Accuracy: 29.7%
    Train Accuracy: 77.0%
    Train Adv Accuracy: 34.7%
    Test Accuracy: 57.9%
    Test Adv Accuracy: 30.2%
    Train Accuracy: 77.0%
    Train Adv Accuracy: 34.0%
    Test Accuracy: 64.6%
    Test Adv Accuracy: 29.4%
    Train Accuracy: 76.7%
    Train Adv Accuracy: 34.5%
    Test Accuracy: 62.4%
    Test Adv Accuracy: 28.2%
    Train Accuracy: 76.8%
    Train Adv Accuracy: 34.7%
    Test Accuracy: 59.6%
    Test Adv Accuracy: 26.4%
    Train Accuracy: 77.3%
    Train Adv Accuracy: 34.7%
    Test Accuracy: 64.6%
    Test Adv Accuracy: 27.7%
    Train Accuracy: 77.0%
    Train Adv Accuracy: 35.0%
    Test Accuracy: 57.5%
    Test Adv Accuracy: 30.7%
    Train Accuracy: 76.8%
    Train Adv Accuracy: 34.8%
    Test Accuracy: 65.8%
    Test Adv Accuracy: 32.9%
    Train Accuracy: 77.3%
    Train Adv Accuracy: 35.1%
    Test Accuracy: 61.3%
    Test Adv Accuracy: 27.5%
    Train Accuracy: 77.4%
    Train Adv Accuracy: 34.6%
    Test Accuracy: 62.4%
    Test Adv Accuracy: 31.8%
    Train Accuracy: 77.0%
    Train Adv Accuracy: 34.6%
    Test Accuracy: 57.0%
    Test Adv Accuracy: 31.7%
    Train Accuracy: 77.2%
    Train Adv Accuracy: 35.2%
    Test Accuracy: 57.7%
    Test Adv Accuracy: 31.7%
    Train Accuracy: 77.2%
    Train Adv Accuracy: 35.0%
    Test Accuracy: 65.6%
    Test Adv Accuracy: 24.2%
    Train Accuracy: 77.6%
    Train Adv Accuracy: 34.8%
    Test Accuracy: 63.5%
    Test Adv Accuracy: 30.3%
    Train Accuracy: 77.7%
    Train Adv Accuracy: 35.2%
    Test Accuracy: 63.7%
    Test Adv Accuracy: 30.8%
    Train Accuracy: 77.3%
    Train Adv Accuracy: 35.0%
    Test Accuracy: 63.5%
    Test Adv Accuracy: 30.8%
    Train Accuracy: 77.4%
    Train Adv Accuracy: 35.1%
    Test Accuracy: 59.3%
    Test Adv Accuracy: 32.5%
    Train Accuracy: 77.4%
    Train Adv Accuracy: 34.8%
    Test Accuracy: 56.7%
    Test Adv Accuracy: 30.5%
    Train Accuracy: 77.3%
    Train Adv Accuracy: 35.3%
    Test Accuracy: 67.7%
    Test Adv Accuracy: 32.0%
    Train Accuracy: 77.7%
    Train Adv Accuracy: 35.2%
    Test Accuracy: 61.4%
    Test Adv Accuracy: 32.5%
    Train Accuracy: 77.6%
    Train Adv Accuracy: 34.9%
    Test Accuracy: 63.1%
    Test Adv Accuracy: 31.6%
    Train Accuracy: 77.8%
    Train Adv Accuracy: 35.0%
    Test Accuracy: 63.5%
    Test Adv Accuracy: 29.2%
    Train Accuracy: 77.2%
    Train Adv Accuracy: 35.5%
    Test Accuracy: 61.5%
    Test Adv Accuracy: 33.4%
    Train Accuracy: 77.6%
    Train Adv Accuracy: 35.2%
    Test Accuracy: 66.1%
    Test Adv Accuracy: 27.1%
    Train Accuracy: 77.6%
    Train Adv Accuracy: 34.9%
    Test Accuracy: 63.4%
    Test Adv Accuracy: 31.2%
    Train Accuracy: 77.7%
    Train Adv Accuracy: 35.2%
    Test Accuracy: 62.4%
    Test Adv Accuracy: 31.3%
    Train Accuracy: 77.3%
    Train Adv Accuracy: 35.0%
    Test Accuracy: 65.2%
    Test Adv Accuracy: 28.5%
    Train Accuracy: 77.9%
    Train Adv Accuracy: 34.8%
    Test Accuracy: 63.9%
    Test Adv Accuracy: 27.9%
    Train Accuracy: 77.8%
    Train Adv Accuracy: 35.1%
    Test Accuracy: 63.7%
    Test Adv Accuracy: 28.0%
    Train Accuracy: 77.7%
    Train Adv Accuracy: 35.5%
    Test Accuracy: 58.6%
    Test Adv Accuracy: 31.7%
    Train Accuracy: 77.8%
    Train Adv Accuracy: 35.1%
    Test Accuracy: 63.4%
    Test Adv Accuracy: 31.5%
    Train Accuracy: 77.9%
    Train Adv Accuracy: 34.8%
    Test Accuracy: 57.1%
    Test Adv Accuracy: 31.2%
    Train Accuracy: 77.7%
    Train Adv Accuracy: 35.2%
    Test Accuracy: 62.8%
    Test Adv Accuracy: 31.0%
    Train Accuracy: 77.1%
    Train Adv Accuracy: 34.7%
    Test Accuracy: 57.0%
    Test Adv Accuracy: 32.3%
    Train Accuracy: 77.9%
    Train Adv Accuracy: 35.1%
    Test Accuracy: 64.8%
    Test Adv Accuracy: 29.7%
    Train Accuracy: 77.9%
    Train Adv Accuracy: 35.3%
    Test Accuracy: 65.1%
    Test Adv Accuracy: 29.5%
    Train Accuracy: 78.0%
    Train Adv Accuracy: 34.9%
    Test Accuracy: 64.1%
    Test Adv Accuracy: 31.1%
    Train Accuracy: 77.8%
    Train Adv Accuracy: 35.7%
    Test Accuracy: 58.1%
    Test Adv Accuracy: 33.8%
    Train Accuracy: 77.8%
    Train Adv Accuracy: 35.2%
    Test Accuracy: 66.6%
    Test Adv Accuracy: 30.8%
    Train Accuracy: 78.1%
    Train Adv Accuracy: 35.4%
    Test Accuracy: 57.5%
    Test Adv Accuracy: 27.9%
    Train Accuracy: 78.1%
    Train Adv Accuracy: 35.1%
    Test Accuracy: 60.9%
    Test Adv Accuracy: 29.9%
    Train Accuracy: 77.9%
    Train Adv Accuracy: 35.1%
    Test Accuracy: 64.3%
    Test Adv Accuracy: 26.8%
    Train Accuracy: 77.8%
    Train Adv Accuracy: 34.9%
    Test Accuracy: 54.2%
    Test Adv Accuracy: 30.5%
    Train Accuracy: 78.2%
    Train Adv Accuracy: 35.0%
    Test Accuracy: 54.6%
    Test Adv Accuracy: 31.0%
    Train Accuracy: 78.2%
    Train Adv Accuracy: 35.2%
    Test Accuracy: 62.4%
    Test Adv Accuracy: 31.7%
    Train Accuracy: 77.9%
    Train Adv Accuracy: 35.4%
    Test Accuracy: 61.1%
    Test Adv Accuracy: 34.1%
    Train Accuracy: 78.1%
    Train Adv Accuracy: 34.8%
    Test Accuracy: 51.0%
    Test Adv Accuracy: 31.5%
    Train Accuracy: 77.9%
    Train Adv Accuracy: 35.0%
    Test Accuracy: 64.0%
    Test Adv Accuracy: 26.5%
    Train Accuracy: 78.2%
    Train Adv Accuracy: 35.1%
    Test Accuracy: 54.2%
    Test Adv Accuracy: 33.4%
    Train Accuracy: 77.6%
    Train Adv Accuracy: 35.3%
    Test Accuracy: 63.6%
    Test Adv Accuracy: 33.1%
    Train Accuracy: 77.9%
    Train Adv Accuracy: 35.4%
    Test Accuracy: 65.6%
    Test Adv Accuracy: 29.6%
    Train Accuracy: 78.0%
    Train Adv Accuracy: 35.1%
    Test Accuracy: 65.7%
    Test Adv Accuracy: 27.6%
    Train Accuracy: 77.8%
    Train Adv Accuracy: 35.5%
    Test Accuracy: 57.6%
    Test Adv Accuracy: 32.7%
    Train Accuracy: 78.0%
    Train Adv Accuracy: 35.0%
    Test Accuracy: 50.4%
    Test Adv Accuracy: 30.7%
    Train Accuracy: 78.3%
    Train Adv Accuracy: 35.0%
    Test Accuracy: 59.2%
    Test Adv Accuracy: 27.7%
    Train Accuracy: 78.0%
    Train Adv Accuracy: 34.5%
    Test Accuracy: 59.7%
    Test Adv Accuracy: 32.5%
    Train Accuracy: 78.2%
    Train Adv Accuracy: 35.4%
    Test Accuracy: 57.2%
    Test Adv Accuracy: 29.3%
    Train Accuracy: 78.1%
    Train Adv Accuracy: 35.3%
    Test Accuracy: 53.1%
    Test Adv Accuracy: 30.8%
    Train Accuracy: 77.7%
    Train Adv Accuracy: 34.9%
    Test Accuracy: 59.7%
    Test Adv Accuracy: 29.9%
    Train Accuracy: 78.2%
    Train Adv Accuracy: 35.3%
    Test Accuracy: 56.8%
    Test Adv Accuracy: 28.0%
    Train Accuracy: 77.9%
    Train Adv Accuracy: 35.0%
    Test Accuracy: 57.8%
    Test Adv Accuracy: 31.8%
    Train Accuracy: 78.1%
    Train Adv Accuracy: 35.0%
    Test Accuracy: 55.6%
    Test Adv Accuracy: 30.0%
    Train Accuracy: 78.2%
    Train Adv Accuracy: 35.1%
    Test Accuracy: 63.7%
    Test Adv Accuracy: 30.3%
    Train Accuracy: 78.2%
    Train Adv Accuracy: 35.6%
    Test Accuracy: 63.1%
    Test Adv Accuracy: 35.0%






|
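
In the script, ``adversary.perturb`` with ``chop.optim.minimize_pgd_madry`` approximately solves the inner maximization. The model is then updated on ``data + delta``. The sketch below shows the same two steps in plain PyTorch. It is an assumption about what a Madry-style :math:`L_\infty` PGD attack does, not a copy of chop's implementation. Each iteration takes a signed-gradient ascent step on the perturbation, then projects it onto the :math:`L_\infty` ball and the valid image range. The helper ``pgd_linf``, the toy linear model, the tensor sizes and the step size ``2.5 * eps / n_iter`` (a common heuristic) are all hypothetical.

.. code-block:: python

    import torch
    import torch.nn as nn


    def pgd_linf(model, loss_fn, data, target, eps, step_size, n_iter):
        """Hypothetical helper: approximately maximize ``loss_fn`` over
        perturbations with ||delta||_inf <= eps and 0 <= data + delta <= 1."""
        delta = torch.zeros_like(data, requires_grad=True)
        for _ in range(n_iter):
            loss = loss_fn(model(data + delta), target)
            # Gradient w.r.t. the perturbation only (model .grad buffers untouched)
            grad, = torch.autograd.grad(loss, delta)
            with torch.no_grad():
                delta += step_size * grad.sign()      # signed-gradient ascent step
                delta.clamp_(-eps, eps)               # project onto the L-inf ball
                delta.copy_(torch.clamp(data + delta, 0., 1.) - data)  # image range
        return delta.detach()


    torch.manual_seed(0)
    # Hypothetical toy model and data standing in for ResNet-18 and CIFAR-10
    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    data = torch.rand(16, 3, 8, 8)
    target = torch.randint(0, 10, (16,))
    eps, n_iter = 8. / 255, 7

    # Inner maximization: craft the perturbation with the current model
    delta = pgd_linf(model, loss_fn, data, target, eps,
                     step_size=2.5 * eps / n_iter, n_iter=n_iter)

    # Outer minimization: one SGD step on the loss at the adversarial examples
    optimizer.zero_grad()
    adv_loss = loss_fn(model(data + delta), target)
    adv_loss.backward()
    optimizer.step()

The attack only needs gradients with respect to ``delta``, so ``torch.autograd.grad`` is used instead of ``backward()``. This keeps the model's ``.grad`` buffers clean for the subsequent parameter update.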


.. code-block:: default

    import matplotlib.pyplot as plt
    from chop.adversary import Adversary
    import torch
    from easydict import EasyDict

    import chop

    from torch.optim import SGD

    from torchvision import models

    device = torch.device('cuda' if torch.cuda.is_available()
                          else 'cpu')

    n_epochs = 100
    batch_size = 128
    batch_size_test = 100

    loaders = chop.data.load_cifar10(train_batch_size=batch_size,
                                     test_batch_size=batch_size_test,
                                     data_dir='~/datasets',
                                     augment_train=True)

    trainloader, testloader = loaders.train, loaders.test
    n_train = len(trainloader.dataset)
    n_test = len(testloader.dataset)

    # Randomly initialized ResNet-18 with one output per CIFAR-10 class
    model = models.resnet18(num_classes=10)
    model.to(device)

    criterion = torch.nn.CrossEntropyLoss()

    optimizer = SGD(model.parameters(), lr=.1, momentum=.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

    # Define the perturbation constraint set
    max_iter_train = 7
    max_iter_test = 20
    alpha = 8. / 255
    constraint = chop.constraints.LinfBall(alpha)
    criterion_adv = torch.nn.CrossEntropyLoss(reduction='none')

    print(f"Training on L{constraint.p} ball({alpha}).")


    adversary = Adversary(chop.optim.minimize_pgd_madry)

    results = EasyDict(train_acc=[], test_acc=[],
                       train_acc_adv=[], test_acc_adv=[],
                       train_adv_loss=[],
                       test_adv_loss=[])

    for _ in range(n_epochs):

        # Train
        n_correct = 0
        n_correct_adv = 0

        model.train()

        for k, (data, target) in enumerate(trainloader):
            data = data.to(device)
            target = target.to(device)

            @torch.no_grad()
            def image_constraint_prox(delta, step_size=None):
                """Projects perturbation delta
                so that 0. <= data + delta <= 1."""

                adv_img = torch.clamp(data + delta, 0, 1)
                delta = adv_img - data
                return delta

            @torch.no_grad()
            def prox(delta, step_size=None):
                delta = constraint.prox(delta, step_size)
                delta = image_constraint_prox(delta, step_size)
                return delta

            _, delta = adversary.perturb(data, target, model,
                                         criterion_adv,
                                         prox=prox,
                                         lmo=constraint.lmo,
                                         step=2. / max_iter_train,
                                         max_iter=max_iter_train)

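            optimizer.zero_grad()

            # Clean predictions are only used for accuracy bookkeeping
            with torch.no_grad():
                output = model(data)

            # Robust training: the loss is computed on the adversarial examples
            output_adv = model(data + delta)
            loss = criterion(output_adv, target)
            loss.backward()

            optimizer.step()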

            pred = torch.argmax(output, dim=-1)
            pred_adv = torch.argmax(output_adv, dim=-1)

            n_correct += (pred == target).sum().item()
            n_correct_adv += (pred_adv == target).sum().item()

        results.train_acc.append(100. * n_correct / n_train)
        results.train_acc_adv.append(100. * n_correct_adv / n_train)
        print(f"Train Accuracy: {results.train_acc[-1]:.1f}%")
        print(f"Train Adv Accuracy: {results.train_acc_adv[-1]:.1f}%")

        # Test
        n_correct = 0
        n_correct_adv = 0

        model.eval()

        for k, (data, target) in enumerate(testloader):
            data = data.to(device)
            target = target.to(device)

            @torch.no_grad()
            def image_constraint_prox(delta, step_size=None):
                """Projects perturbation delta
                so that 0. <= data + delta <= 1."""

                adv_img = torch.clamp(data + delta, 0, 1)
                delta = adv_img - data
                return delta

            @torch.no_grad()
            def prox(delta, step_size=None):
                delta = constraint.prox(delta, step_size)
                delta = image_constraint_prox(delta, step_size)
                return delta

            _, delta = adversary.perturb(data, target, model,
                                         criterion_adv,
                                         prox=prox,
                                         lmo=constraint.lmo,
                                         step=2. / max_iter_test,
                                         max_iter=max_iter_test)

            with torch.no_grad():
                output = model(data)
                output_adv = model(data + delta)

                pred = torch.argmax(output, dim=-1)
                pred_adv = torch.argmax(output_adv, dim=-1)

            n_correct += (pred == target).sum().item()
            n_correct_adv += (pred_adv == target).sum().item()

        results.test_acc.append(100. * n_correct / n_test)
        results.test_acc_adv.append(100. * n_correct_adv / n_test)

        print(f"Test Accuracy: {results.test_acc[-1]:.1f}%")
        print(f"Test Adv Accuracy: {results.test_acc_adv[-1]:.1f}%")

        # Advance the cosine learning-rate schedule once per epoch
        scheduler.step()


    fig, ax = plt.subplots(nrows=2, sharex=True)

    ax[0].set_title("Clean data accuracies")
    ax[0].plot(results.train_acc, label='Train Acc')
    ax[0].plot(results.test_acc, label='Test Acc')
    ax[0].set_ylabel("Accuracy (%)")
    ax[0].legend()
    ax[1].set_title("Adversarial data accuracies")
    ax[1].plot(results.train_acc_adv, label='Train Acc Adv')
    ax[1].plot(results.test_acc_adv, label='Test Acc Adv')
    ax[1].set_xlabel("Epoch")
    ax[1].set_ylabel("Accuracy (%)")
    ax[1].legend()
    plt.show()
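
The ``prox`` closure in the script composes ``constraint.prox`` with ``image_constraint_prox``. Below is a minimal sketch of that composition. It assumes the :math:`L_\infty`-ball prox is a coordinate-wise clamp to ``[-alpha, alpha]``, and ``project_linf_box`` is a hypothetical name, not a chop function. Both constraints are intervals in each coordinate, and their intersection always contains 0 because pixel values lie in ``[0, 1]``. Clamping to one interval and then the other therefore gives the exact Euclidean projection onto the intersection.

.. code-block:: python

    import torch


    def project_linf_box(delta, data, eps):
        """Hypothetical helper: project ``delta`` onto the set
        {d : ||d||_inf <= eps and 0 <= data + d <= 1}."""
        # Step 1: L-inf ball of radius eps (assumed to be a coordinate-wise clamp)
        delta = torch.clamp(delta, -eps, eps)
        # Step 2: keep the perturbed image inside the valid pixel range [0, 1]
        return torch.clamp(data + delta, 0., 1.) - data


    torch.manual_seed(0)
    eps = 8. / 255
    data = torch.rand(4, 3, 32, 32)         # fake images with pixels in [0, 1]
    delta = 0.1 * torch.randn_like(data)    # arbitrary, mostly infeasible perturbation
    projected = project_linf_box(delta, data, eps)
    adv = data + projected

    # Per coordinate the feasible set is the interval [max(-eps, -x), min(eps, 1 - x)],
    # so clamping directly to it should match the two-step result up to rounding.
    lower = torch.maximum(torch.full_like(data, -eps), -data)
    upper = torch.minimum(torch.full_like(data, eps), 1. - data)
    direct = torch.minimum(torch.maximum(delta, lower), upper)
    print((projected - direct).abs().max().item())

Applying ``project_linf_box`` a second time leaves the result unchanged up to floating-point rounding, as expected of a projection.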


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (624 minutes 28.935 seconds)

**Estimated memory usage:**  2425 MB
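
The script builds a ``CosineAnnealingLR`` scheduler. It only changes the learning rate when ``scheduler.step()`` is called, typically once per epoch. The sketch below checks PyTorch's scheduler against the closed-form cosine schedule :math:`\eta_t = \eta_{\min} + \tfrac{1}{2}(\eta_0 - \eta_{\min})\,(1 + \cos(\pi t / T_{\max}))`. The helper ``cosine_lr`` and the single dummy parameter are hypothetical stand-ins for the model.

.. code-block:: python

    import math

    import torch


    def cosine_lr(epoch, base_lr, t_max, eta_min=0.):
        """Hypothetical helper: closed-form cosine-annealed rate after ``epoch`` steps."""
        return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


    base_lr, t_max, n_epochs = 0.1, 200, 100
    param = torch.nn.Parameter(torch.zeros(1))   # dummy parameter standing in for a model
    optimizer = torch.optim.SGD([param], lr=base_lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max)

    lrs = []
    for epoch in range(n_epochs):
        optimizer.step()    # one (dummy) epoch of parameter updates
        scheduler.step()    # then advance the schedule once per epoch
        lrs.append(scheduler.get_last_lr()[0])

    print(lrs[-1], cosine_lr(n_epochs, base_lr, t_max))

With ``T_max=200`` and only 100 epochs, the learning rate ends at half its initial value (0.05 for a base rate of 0.1). Setting ``T_max`` to the number of epochs anneals it down to ``eta_min``.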


.. _sphx_glr_download_auto_examples_adversarial_robustness_plot_train_robust_cifar10.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_train_robust_cifar10.py <plot_train_robust_cifar10.py>`



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_train_robust_cifar10.ipynb <plot_train_robust_cifar10.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
