
Examples

These examples will help you get started using Intel® Extension for PyTorch* with Intel GPUs.

Prerequisites: Before running these examples, install the torchvision and transformers Python packages.

Python

Training

Single-Instance Training

To use Intel® Extension for PyTorch* for training, you need to make the following changes in your code:

  1. Import intel_extension_for_pytorch as ipex.
  2. Invoke the ipex.optimize function for an additional performance boost; it applies optimizations to both the model object and the optimizer object.
  3. Use Auto Mixed Precision (AMP) with the BFloat16 data type.
  4. Convert the input tensors, loss criterion, and model to XPU, as shown below:
...
import torch
import intel_extension_for_pytorch as ipex
...
model = Model()
criterion = ...
optimizer = ...
model.train()
# Move model and loss criterion to xpu before calling ipex.optimize()
model = model.to("xpu")
criterion = criterion.to("xpu")

# For Float32
model, optimizer = ipex.optimize(model, optimizer=optimizer)
# For BFloat16
model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)
...
dataloader = ...
for (input, target) in dataloader:
    input = input.to("xpu")
    target = target.to("xpu")
    optimizer.zero_grad()
    # For Float32
    output = model(input)

    # For BFloat16
    with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
        output = model(input)

    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
...

Below you can find complete code examples demonstrating how to use the extension for training with different data types:

Float32
BFloat16
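
The complete examples above are not reproduced here; as a condensed illustration, the following is a minimal, self-contained BFloat16 training sketch using torchvision's ResNet-50 on random data. The model, optimizer, loss, and data choices are illustrative stand-ins for a real workload, not the original example scripts.

import torch
import torchvision.models as models
import intel_extension_for_pytorch as ipex

model = models.resnet50(weights=None)
criterion = torch.nn.CrossEntropyLoss()
model.train()
# Move model and loss criterion to xpu before calling ipex.optimize()
model = model.to("xpu")
criterion = criterion.to("xpu")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Prepare model and optimizer for BFloat16 training
model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)

# Random tensors standing in for a real dataloader
for _ in range(10):
    input = torch.rand(8, 3, 224, 224).to("xpu")
    target = torch.randint(0, 1000, (8,)).to("xpu")
    optimizer.zero_grad()
    with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
        output = model(input)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()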

Inference

Get additional performance boosts for your computer vision and NLP workloads by applying the Intel® Extension for PyTorch* optimize function to your model object.

Float32

Imperative Mode
Resnet50
BERT
TorchScript Mode

We recommend using Intel® Extension for PyTorch* with TorchScript for further optimizations.

Resnet50
BERT
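
The tabbed Resnet50 and BERT examples are not reproduced here; below is a minimal Float32 inference sketch using torchvision's ResNet-50, covering both imperative mode and TorchScript tracing. The model choice and input shape are illustrative; the BERT variants follow the same pattern with a transformers model and tokenized inputs.

import torch
import torchvision.models as models
import intel_extension_for_pytorch as ipex

model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
model.eval()
model = model.to("xpu")

# Apply Intel® Extension for PyTorch* optimizations for inference
model = ipex.optimize(model)

data = torch.rand(1, 3, 224, 224).to("xpu")

with torch.no_grad():
    # Imperative mode
    output = model(data)

    # TorchScript mode: trace and freeze the model for further graph optimizations
    traced_model = torch.jit.trace(model, data)
    traced_model = torch.jit.freeze(traced_model)
    output = traced_model(data)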

BFloat16

The optimize function works for both the Float32 and BFloat16 data types. For the BFloat16 data type, set the dtype parameter to torch.bfloat16. We recommend using Auto Mixed Precision (AMP) with the BFloat16 data type.

Imperative Mode
Resnet50
BERT
TorchScript Mode

We recommend using Intel® Extension for PyTorch* with TorchScript for further optimizations.

Resnet50
BERT
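
As above, the tabbed examples are not reproduced here; the following is a minimal BFloat16 inference sketch for ResNet-50 in imperative mode (model choice and input shape are illustrative). A TorchScript variant would additionally trace the model, as in the Float32 sketch.

import torch
import torchvision.models as models
import intel_extension_for_pytorch as ipex

model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
model.eval()
model = model.to("xpu")

# Set dtype=torch.bfloat16 so ipex.optimize() prepares the model for BFloat16
model = ipex.optimize(model, dtype=torch.bfloat16)

data = torch.rand(1, 3, 224, 224).to("xpu")

with torch.no_grad():
    # Run inference under AMP with the BFloat16 data type
    with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
        output = model(data)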

Float16

The optimize function works for both the Float32 and Float16 data types. For the Float16 data type, set the dtype parameter to torch.float16. We recommend using Auto Mixed Precision (AMP) with the Float16 data type.

Imperative Mode
Resnet50
BERT
TorchScript Mode

We recommend using Intel® Extension for PyTorch* with TorchScript for further optimizations.

Resnet50
BERT
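
For variety, this minimal Float16 sketch uses a transformers BERT model in imperative mode; the bert-base-uncased checkpoint and sample sentence are illustrative, and the ResNet-50 pattern is identical apart from the input.

import torch
from transformers import BertModel, BertTokenizer
import intel_extension_for_pytorch as ipex

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()
model = model.to("xpu")

# Set dtype=torch.float16 so ipex.optimize() prepares the model for Float16
model = ipex.optimize(model, dtype=torch.float16)

inputs = tokenizer("This is a sample input.", return_tensors="pt").to("xpu")

with torch.no_grad():
    # Run inference under AMP with the Float16 data type
    with torch.xpu.amp.autocast(enabled=True, dtype=torch.float16):
        outputs = model(**inputs)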

INT8

We recommend using TorchScript for INT8 models because it supports a wider range of models. TorchScript mode also automatically enables our optimizations. For a TorchScript INT8 model, observer insertion and model quantization are performed through prepare_jit and convert_jit, respectively. A calibration process is required to collect statistics from real data. After conversion, optimizations such as operator fusion are automatically enabled.
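
A minimal sketch of the TorchScript INT8 flow described above, assuming prepare_jit and convert_jit are imported from torch.quantization.quantize_jit (located under torch.ao.quantization.quantize_jit in newer PyTorch releases); the qconfig and the single-tensor calibration loop are illustrative, and a real calibration dataset should be used.

import torch
import torchvision.models as models
import intel_extension_for_pytorch as ipex
from torch.quantization.quantize_jit import prepare_jit, convert_jit

model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
model.eval()
model = model.to("xpu")

data = torch.rand(1, 3, 224, 224).to("xpu")

with torch.no_grad():
    # Trace the model into TorchScript
    modelJit = torch.jit.trace(model, data)

    # Insert observers into the TorchScript model
    qconfig = torch.quantization.QConfig(
        activation=torch.quantization.observer.MinMaxObserver.with_args(
            qscheme=torch.per_tensor_symmetric, dtype=torch.quint8
        ),
        weight=torch.quantization.default_weight_observer,
    )
    modelJit = prepare_jit(modelJit, {"": qconfig}, inplace=True)

    # Calibration: run representative data through the model to collect statistics
    for calib_data in [data]:  # replace with a real calibration dataset
        modelJit(calib_data)

    # Convert to an INT8 model; operator fusion is auto-enabled after conversion
    modelJit = convert_jit(modelJit, inplace=True)

    output = modelJit(data)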

torch.xpu.optimize

The torch.xpu.optimize function is an alternative to ipex.optimize in Intel® Extension for PyTorch*, and provides identical usage for XPU devices only. The motivation for adding this alias is to unify the coding style in user scripts based on the torch.xpu module. Refer to the example below for usage.
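
A minimal usage sketch, assuming the same arguments as ipex.optimize shown earlier; the ResNet-50 model and BFloat16 dtype are illustrative.

import torch
import torchvision.models as models
import intel_extension_for_pytorch as ipex  # importing the extension registers torch.xpu

model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
model.eval()
model = model.to("xpu")

# torch.xpu.optimize takes the same arguments as ipex.optimize
model = torch.xpu.optimize(model, dtype=torch.bfloat16)

data = torch.rand(1, 3, 224, 224).to("xpu")

with torch.no_grad():
    with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
        output = model(data)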

C++

To work with libtorch, the PyTorch C++ library, Intel® Extension for PyTorch* provides its own C++ dynamic library. The C++ library handles inference workloads only, such as service deployment. For regular development, use the Python interface. Compared to regular libtorch usage, no specific code changes are required. Compilation follows the recommended methodology with CMake. Detailed instructions can be found in the PyTorch tutorial.

During compilation, Intel optimizations will be activated automatically after the C++ dynamic library of Intel® Extension for PyTorch* is linked.

The example code below works for all data types.

Basic Usage

Download and Install cppsdk

Ensure you have downloaded and installed the cppsdk from the installation page before compiling the C++ code.

  1. Go to the installation page.
  2. Select the desired Platform, Version, and OS.
  3. In the package section, select cppsdk.
  4. Follow the instructions on the cppsdk installation page to download and install the cppsdk into libtorch.

example-app.cpp

CMakeLists.txt

Command for compilation

$ cd examples/gpu/inference/cpp/example-app
$ mkdir build
$ cd build
$ CC=icx CXX=icpx cmake -DCMAKE_PREFIX_PATH=<LIBPYTORCH_PATH> ..
$ make

<LIBPYTORCH_PATH> is the absolute path of the libtorch directory installed in the first step.

If Found IPEX is shown with dynamic library paths in the CMake output, the extension has been linked into the binary. This can be verified with the Linux ldd command.

The values of x, y, and z in the following log will vary depending on the version you choose.

$ CC=icx CXX=icpx cmake -DCMAKE_PREFIX_PATH=/workspace/libtorch ..
-- The C compiler identification is IntelLLVM 202x.y.z
-- The CXX compiler identification is IntelLLVM 202x.y.z
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Check for working C compiler: /workspace/intel/oneapi/compiler/202x.y.z/linux/bin/icx - skipped
-- Detecting C compile features
-- Detecting C compile features - done
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /workspace/intel/oneapi/compiler/202x.y.z/linux/bin/icpx - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Success
-- Found Threads: TRUE
-- Found Torch: /workspace/libtorch/lib/libtorch.so
-- Found IPEX: /workspace/libtorch/lib/libintel-ext-pt-cpu.so;/workspace/libtorch/lib/libintel-ext-pt-gpu.so
-- Configuring done
-- Generating done
-- Build files have been written to: examples/gpu/inference/cpp/example-app/build

$ ldd example-app
        ...
        libtorch.so => /workspace/libtorch/lib/libtorch.so (0x00007fd5bb927000)
        libc10.so => /workspace/libtorch/lib/libc10.so (0x00007fd5bb895000)
        libtorch_cpu.so => /workspace/libtorch/lib/libtorch_cpu.so (0x00007fd5a44d8000)
        libintel-ext-pt-cpu.so => /workspace/libtorch/lib/libintel-ext-pt-cpu.so (0x00007fd5a1a1b000)
        libintel-ext-pt-gpu.so => /workspace/libtorch/lib/libintel-ext-pt-gpu.so (0x00007fd5862b0000)
        ...
        libmkl_intel_lp64.so.2 => /workspace/intel/oneapi/mkl/202x.y.z/lib/intel64/libmkl_intel_lp64.so.2 (0x00007fd584ab0000)
        libmkl_core.so.2 => /workspace/intel/oneapi/mkl/202x.y.z/lib/intel64/libmkl_core.so.2 (0x00007fd5806cc000)
        libmkl_gnu_thread.so.2 => /workspace/intel/oneapi/mkl/202x.y.z/lib/intel64/libmkl_gnu_thread.so.2 (0x00007fd57eb1d000)
        libmkl_sycl.so.3 => /workspace/intel/oneapi/mkl/202x.y.z/lib/intel64/libmkl_sycl.so.3 (0x00007fd55512c000)
        libOpenCL.so.1 => /workspace/intel/oneapi/compiler/202x.y.z/linux/lib/libOpenCL.so.1 (0x00007fd55511d000)
        libsvml.so => /workspace/intel/oneapi/compiler/202x.y.z/linux/compiler/lib/intel64_lin/libsvml.so (0x00007fd553b11000)
        libirng.so => /workspace/intel/oneapi/compiler/202x.y.z/linux/compiler/lib/intel64_lin/libirng.so (0x00007fd553600000)
        libimf.so => /workspace/intel/oneapi/compiler/202x.y.z/linux/compiler/lib/intel64_lin/libimf.so (0x00007fd55321b000)
        libintlc.so.5 => /workspace/intel/oneapi/compiler/202x.y.z/linux/compiler/lib/intel64_lin/libintlc.so.5 (0x00007fd553a9c000)
        libsycl.so.6 => /workspace/intel/oneapi/compiler/202x.y.z/linux/lib/libsycl.so.6 (0x00007fd552f36000)
        ...

Use SYCL code

Using SYCL code in a C++ application is also possible. The example below shows how to invoke a SYCL kernel. You need to explicitly pass -fsycl into CMAKE_CXX_FLAGS.

example-usm.cpp

CMakeLists.txt

Customize DPC++ kernels

Intel® Extension for PyTorch* provides its C++ dynamic library to allow users to implement custom DPC++ kernels to run on the XPU device. Refer to the DPC++ extension for details.

Intel® AI Reference Models

Use cases that have already been optimized by Intel engineers are available at Intel® AI Reference Models (formerly Model Zoo). A number of PyTorch use cases for benchmarking are also available in the Use Cases section. Models verified on Intel GPUs are marked in the Model Documentation column. You can get performance benefits out of the box by simply running the scripts in Intel® AI Reference Models.