.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_adversarial_robustness_train_robust_cifar10.py>` to download the full example code.

    .. rst-class:: sphx-glr-example-title

    .. _sphx_glr_auto_examples_adversarial_robustness_train_robust_cifar10.py:

Example of robust training on CIFAR10.
=========================================


.. code-block:: default

    from chop.adversary import Adversary
    import torch
    from tqdm import tqdm

    import chop

    from torch.optim import SGD

    from torchvision import models

    # Run on the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available()
                          else 'cpu')

    n_epochs = 128
    batch_size = 100

    # The same batch size is used for the train and test loaders.
    loaders = chop.data.load_cifar10(train_batch_size=batch_size,
                                     test_batch_size=batch_size,
                                     data_dir='~/datasets')
    trainloader, testloader = loaders.train, loaders.test

    # CIFAR-10 has 10 classes: size the classification head accordingly
    # instead of keeping torchvision's default 1000-way ImageNet head.
    # Weights are randomly initialized (no pretrained checkpoint is loaded).
    model = models.resnet18(num_classes=10)
    model.to(device)

    # Mean-reduced loss used for the model update.
    criterion = torch.nn.CrossEntropyLoss()

    # TODO: use learning rate schedulers
    optimizer = SGD(model.parameters(), lr=.1, momentum=.9, weight_decay=5e-4)

    # Define the perturbation constraint set
    max_iter_train = 7   # attack iterations per training batch
    max_iter_test = 20   # attack iterations per test batch
    alpha = 8. / 255     # argument of the L-infinity ball constraint below
    constraint = chop.constraints.LinfBall(alpha)
    # Unreduced (per-sample) loss handed to the adversary.
    criterion_adv = torch.nn.CrossEntropyLoss(reduction='none')

    print(f"Training on L{constraint.p} ball({alpha}).")


    # Adversary that builds perturbations with chop's `minimize_pgd_madry`.
    adversary = Adversary(chop.optim.minimize_pgd_madry)

    def make_prox(data):
        """Return the prox operator used to attack the batch ``data``.

        The returned callable applies ``constraint.prox`` to the perturbation
        and then projects it so that 0. <= data + delta <= 1.
        """

        def prox(delta, step_size=None):
            delta = constraint.prox(delta, step_size)
            # Clamp the perturbed image to [0, 1], then recover the
            # perturbation relative to the clean batch.
            return torch.clamp(data + delta, 0, 1) - data

        return prox


    for _ in range(n_epochs):

        # Train
        n_correct = 0
        n_correct_adv = 0
        n_seen = 0

        model.train()

        for data, target in tqdm(trainloader, total=len(trainloader)):
            data = data.to(device)
            target = target.to(device)

            _, delta = adversary.perturb(data, target, model,
                                         criterion_adv,
                                         prox=make_prox(data),
                                         lmo=constraint.lmo,
                                         step=2. / max_iter_train,
                                         max_iter=max_iter_train)

            # The clean forward pass is only used to report accuracy,
            # so no autograd graph is needed for it.
            with torch.no_grad():
                output = model(data)

            # Clear any gradients left on the parameters (from the previous
            # step, or possibly from the attack) before the adversarial
            # backward pass, then actually update the model.
            optimizer.zero_grad()
            output_adv = model(data + delta)
            loss = criterion(output_adv, target)
            loss.backward()
            optimizer.step()

            pred = torch.argmax(output, dim=-1)
            pred_adv = torch.argmax(output_adv, dim=-1)

            n_correct += (pred == target).sum().item()
            n_correct_adv += (pred_adv == target).sum().item()
            n_seen += target.size(0)

        # Normalize by the number of samples actually seen rather than a
        # hard-coded dataset size; max() guards against an empty loader.
        print(f"Train Accuracy: {n_correct / max(n_seen, 1)}")
        print(f"Train Adv Accuracy: {n_correct_adv / max(n_seen, 1)}")

        # Test
        n_correct = 0
        n_correct_adv = 0
        n_seen = 0

        model.eval()

        for data, target in tqdm(testloader, total=len(testloader)):
            data = data.to(device)
            target = target.to(device)

            # The attack runs outside of no_grad: it is given the model and
            # the loss, and presumably needs gradients w.r.t. the input.
            _, delta = adversary.perturb(data, target, model,
                                         criterion_adv,
                                         prox=make_prox(data),
                                         lmo=constraint.lmo,
                                         step=2. / max_iter_test,
                                         max_iter=max_iter_test)

            with torch.no_grad():
                output = model(data)
                output_adv = model(data + delta)

                pred = torch.argmax(output, dim=-1)
                pred_adv = torch.argmax(output_adv, dim=-1)

                n_correct += (pred == target).sum().item()
                n_correct_adv += (pred_adv == target).sum().item()
                n_seen += target.size(0)

        print(f"Test Accuracy: {n_correct / max(n_seen, 1)}")
        print(f"Test Adv Accuracy: {n_correct_adv / max(n_seen, 1)}")

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.000 seconds)

**Estimated memory usage:**  0 MB


.. _sphx_glr_download_auto_examples_adversarial_robustness_train_robust_cifar10.py:


.. only:: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: train_robust_cifar10.py <train_robust_cifar10.py>`



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: train_robust_cifar10.ipynb <train_robust_cifar10.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
