TRTorch: A compiler for TorchScript Targeting NVIDIA GPUs

Writing a compiler for PyTorch that optimizes neural nets for NVIDIA GPUs

March 20, 2020 - 18 minute read -
cv ai

Few tools are as approachable as PyTorch for developing and experimenting with machine learning models. The power of PyTorch comes from its deep integration into Python, its flexibility, and its approach to automatic differentiation and execution (eager execution). However, when moving from research into production, the requirements change: we may no longer want that deep Python integration, and we want optimization to get the best performance we can on our deployment platform. In PyTorch 1.0, TorchScript was introduced as a method to separate your PyTorch model from Python and make it portable and optimizable. TorchScript uses PyTorch’s JIT compiler to transform your normal PyTorch code, which is interpreted by the Python interpreter, into an intermediate representation (IR) that can have optimizations run on it and that is interpreted at runtime by the PyTorch JIT interpreter. For PyTorch this has opened up a whole new world of possibilities, including deployment in other languages like C++. It also introduces a structured, graph-based format that we can use to optimize models for inference down to the kernel level.

When deploying on NVIDIA GPUs, TensorRT, NVIDIA’s Deep Learning Optimization SDK and Runtime, is able to take models from any major framework and tune them to perform better on specific target hardware in the NVIDIA family, be it an A100, TITAN V, Jetson Xavier or NVIDIA’s Deep Learning Accelerator. TensorRT performs a couple of sets of optimizations to achieve this: it fuses layers and tensors in the model graph, and it then uses a large kernel library to select the implementations that perform best on the target GPU. TensorRT also has strong support for reduced-precision execution, which allows users to leverage the Tensor Cores on Volta and newer GPUs as well as reduce memory and computation footprints on device.

TRTorch is a compiler that uses TensorRT to optimize TorchScript code, compiling standard TorchScript modules into ones that internally run with TensorRT optimizations. This enables you to remain in the PyTorch ecosystem, using all the great features PyTorch has such as module composability, its flexible tensor implementation, data loaders and more. TRTorch is available to use with both PyTorch and LibTorch. Here is a quick example of how to use it.

Creating a TorchScript Module

When developing models in PyTorch, modules are the unit of composition: you create or use subclasses of torch.nn.Module, which are composed into a super module representing your model. Here is a simple LeNet implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNetFeatExtractor(nn.Module):
    def __init__(self):
        super(LeNetFeatExtractor, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        return x

class LeNetClassifier(nn.Module):
    def __init__(self):
        super(LeNetClassifier, self).__init__()
        self.fc1 = nn.Linear(16 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = torch.flatten(x,1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.feat = LeNetFeatExtractor()
        self.classifer = LeNetClassifier()

    def forward(self, x):
        x = self.feat(x)
        x = self.classifer(x)
        return x

Here we create two submodules for a feature extractor and a classifier and stitch them together in a single LeNet module. In this case this is overkill but modules give us granular control over our program including where we decide to optimize and where we don’t. It is also the unit that the TorchScript compiler operates on. So you can decide to only convert/optimize the feature extractor and leave the classifier in standard PyTorch or you can convert the whole thing. When compiling your module to TorchScript, there are two paths: Tracing and Scripting.


Tracing follows the path of execution when the module is called and records what happens. This recording is what the TorchScript IR will describe. To trace an instance of our LeNet module, we can call torch.jit.trace with an example input.

model = LeNet()
traced_model = torch.jit.trace(model, torch.empty([1,1,32,32]))


Scripting actually inspects your code with a compiler and generates an equivalent TorchScript program. The difference is that tracing simply follows the execution of your module, so it cannot pick up control flow: it will only record the code path that a particular input triggers. By working from the Python code, the script compiler can include these components. We can run the script compiler on our LeNet module by calling torch.jit.script.

model = LeNet()
script_model = torch.jit.script(model)

There are reasons to use one path or the other; the PyTorch documentation has information on how to choose, as well as more information on TorchScript. From a TRTorch perspective, there is better support for traced modules (i.e., your module is more likely to compile) because a trace doesn’t include all the complexities of a complete programming language, though both paths are supported.
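To make the control-flow limitation of tracing concrete, here is a toy recorder in plain Python. It is only a sketch of the idea and not how torch.jit.trace is implemented; the names Proxy, model and trace are hypothetical and exist only for this illustration.

```python
# Toy illustration only: a tiny "tracer" that records the operations a
# function performs on one example input. It is not PyTorch's tracer, but it
# shows why a trace only contains the branch that actually ran.

class Proxy:
    """Wraps a number and logs every operation applied to it."""

    def __init__(self, value, log):
        self.value = value
        self.log = log

    def __mul__(self, other):
        self.log.append(f"mul {other}")
        return Proxy(self.value * other, self.log)

    def __add__(self, other):
        self.log.append(f"add {other}")
        return Proxy(self.value + other, self.log)


def model(x):
    # Data-dependent control flow: the branch depends on the input value.
    if x.value > 0:
        return x * 2
    return x + 10


def trace(fn, example):
    """Run fn once on the example input and return the recorded ops."""
    log = []
    fn(Proxy(example, log))
    return log


positive_trace = trace(model, 3)   # only the `x * 2` branch is recorded
negative_trace = trace(model, -3)  # only the `x + 10` branch is recorded
print(positive_trace)
print(negative_trace)
```

Neither recording contains the `if`, so replaying one of them on an input of the opposite sign would silently take the wrong branch. A script compiler avoids this by working from the source code rather than from one execution.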

After scripting or tracing your module, you are given back a TorchScript Module. This contains the code and parameters used to run the module stored in an intermediate representation. Here is what the LeNet traced module IR looks like:

graph(%self.1 : __torch__.___torch_mangle_10.LeNet,
    %input.1 : Float(1, 1, 32, 32)):
    %129 : __torch__.___torch_mangle_9.LeNetClassifier = prim::GetAttr[name="classifer"](%self.1)
    %119 : __torch__.___torch_mangle_5.LeNetFeatExtractor = prim::GetAttr[name="feat"](%self.1)
    %137 : Tensor = prim::CallMethod[name="forward"](%119, %input.1)
    %138 : Tensor = prim::CallMethod[name="forward"](%129, %137)
    return (%138)

and the LeNet scripted module IR:

graph(%self : __torch__.LeNet,
    %x.1 : Tensor):
    %2 : __torch__.LeNetFeatExtractor = prim::GetAttr[name="feat"](%self)
    %x.3 : Tensor = prim::CallMethod[name="forward"](%2, %x.1) #
    %5 : __torch__.LeNetClassifier = prim::GetAttr[name="classifer"](%self)
    %x.5 : Tensor = prim::CallMethod[name="forward"](%5, %x.3) #
    return (%x.5)

You can see that the IR preserves the module structure we have in our Python code.

From the user perspective it behaves like a normal PyTorch module. You can run the forward pass using the forward method or by calling the module directly: torch_script_module(in_tensor). The JIT interpreter will run the module and return the results. From here we are in a position to compile our module further to run more efficiently on GPUs with TRTorch.

Compiling with TRTorch in Python

To compile your TorchScript module with TRTorch, all you need to do is provide the module and some compiler settings to TRTorch, and you will be returned an optimized TorchScript module to run or add into another PyTorch module. The only required setting is the input size. For applications with a static input size, it is defined as a list of shapes, each given as a list-like type: a list, a tuple or a PyTorch size object. If your application needs to support a range of sizes, entries in the list should instead be dictionaries of minimum, optimal and maximum sizes, which describe the range of dimensions that the auto-tuner should use for optimization. The auto-tuner will select kernels which have the lowest runtime for the optimal size but are valid through the range of dimensions specified. You can also specify settings such as the operating precision for the engine or the target device.

import torch
import trtorch

compile_settings = {
    "input_shapes": [
        {
            "min": [1, 3, 224, 224],
            "opt": [1, 3, 512, 512],
            "max": [1, 3, 1024, 1024]
        }, # For static size [1, 3, 224, 224]
    ],
    "op_precision": torch.half # Run with FP16
}

# torch_script_module is the traced or scripted module from the previous section
trt_ts_module = trtorch.compile(torch_script_module, compile_settings)

# Inputs must match the operating precision of the engine
input_data = input_data.half()
result = trt_ts_module(input_data)
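The two forms an "input_shapes" entry can take are easy to mix up, so here is a small plain-Python sketch that builds both. The helpers static_shape and dynamic_shape are hypothetical and not part of TRTorch; the min <= opt <= max check is a sanity check assumed here for illustration, not documented TRTorch behavior.

```python
# Hypothetical helpers (not part of TRTorch) that build the "input_shapes"
# entries described above. The ordering check is an assumption made for this
# sketch so that mistakes surface before compilation.

def static_shape(shape):
    """A static input is just the shape itself, normalized to a list."""
    return list(shape)


def dynamic_shape(min_shape, opt_shape, max_shape):
    """A dynamic input is a dict of minimum, optimal and maximum shapes."""
    if not (len(min_shape) == len(opt_shape) == len(max_shape)):
        raise ValueError("min, opt and max must have the same rank")
    for lo, opt, hi in zip(min_shape, opt_shape, max_shape):
        if not lo <= opt <= hi:
            raise ValueError(f"expected min <= opt <= max, got {lo}, {opt}, {hi}")
    return {"min": list(min_shape), "opt": list(opt_shape), "max": list(max_shape)}


static_settings = {"input_shapes": [static_shape((1, 3, 224, 224))]}
dynamic_settings = {
    "input_shapes": [
        dynamic_shape([1, 3, 224, 224], [1, 3, 512, 512], [1, 3, 1024, 1024])
    ]
}
print(static_settings)
print(dynamic_settings)
```

Either dictionary has the same layout as the compile settings shown above, with one list entry per module input.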

Working with TorchScript in C++

After converting a PyTorch module to TorchScript, our dependency on Python ends. We can save our TorchScript module to disk (for example with torch.jit.save), which serializes the TorchScript code, weights and other information into a package, and then load it in a C++ application.

#include <torch/script.h> // One-stop header.

#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
    torch::jit::Module mod;
    try {
        // Deserialize the ScriptModule from a file using torch::jit::load().
        mod = torch::jit::load("<PATH TO SAVED TS MOD>");
    } catch (const c10::Error& e) {
        std::cerr << "error loading the model\n";
        return -1;
    }

    std::cout << "ok\n";
}

You can do full training and inference in C++ with LibTorch if you would like; you can even define your modules in C++ and have access to the same powerful tensor library that backs PyTorch. For instance, we can run inference with our module like this:

torch::Tensor in = torch::randn({1, 1, 32, 32});
// forward() takes a list of inputs, so the tensor is wrapped in braces
auto out = mod.forward({in});

and to run on the GPU:

torch::Tensor in = torch::randn({1, 1, 32, 32}, torch::kCUDA);
// forward() takes a list of inputs, so the tensor is wrapped in braces
auto out = mod.forward({in});

As you can see, it is pretty similar to the Python API. When you call the forward method, you invoke the JIT interpreter, just as calling a TorchScript module in Python does, to run the TorchScript code and return the results.

Compiling with TRTorch in C++

We are also at the point where we can compile and optimize our module with TRTorch in C++. With our module loaded, we can feed it into the TRTorch compiler. When we do so, we must provide a similar set of additional information on the expected input size, along with any other settings we want to configure.

#include "torch/script.h"
#include "trtorch/trtorch.h"

    auto in = torch::randn({1, 1, 32, 32}, {torch::kCUDA});
    // Describe the expected input with a single static-size InputRange
    auto trt_mod = trtorch::CompileGraph(mod, std::vector<trtorch::ExtraInfo::InputRange>{{in.sizes()}});
    auto out = trt_mod.forward({in});

That’s it! Now the graph runs a TensorRT optimized version of the module through the JIT interpreter.

We can also set settings like operating precision to run in FP16.

#include "torch/script.h"
#include "trtorch/trtorch.h"

    auto in = torch::randn({1, 1, 32, 32}, {torch::kCUDA}).to(torch::kHALF);
    auto input_sizes = std::vector<trtorch::ExtraInfo::InputRange>({in.sizes()});
    trtorch::ExtraInfo info(input_sizes);
    info.op_precision = torch::kHALF;
    auto trt_mod = trtorch::CompileGraph(mod, info);
    auto out = trt_mod.forward({in});

And now we are running the module in FP16 precision.

If you want to save the engine produced by TRTorch to use in a TensorRT application, you can use the ConvertGraphToTRTEngine API (this API is available in Python as well and is compatible with the TensorRT Python API to deploy).

#include <fstream>

#include "torch/script.h"
#include "trtorch/trtorch.h"

    auto in = torch::randn({1, 1, 32, 32}, {torch::kCUDA}).to(torch::kHALF);
    auto input_sizes = std::vector<trtorch::ExtraInfo::InputRange>({in.sizes()});
    trtorch::ExtraInfo info(input_sizes);
    info.op_precision = torch::kHALF;
    // Convert the "forward" method into a serialized TensorRT engine
    auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", info);
    // Write the serialized engine to disk for use in a TensorRT application
    std::ofstream out("/tmp/engine_converted_from_jit.trt");
    out << engine;
    out.close();

Compiler Internals

When a module is provided to TRTorch, it goes through three main phases: lowering, conversion, and runtime.

Lowering Phase

The lowering phase is made up of a set of passes (some from PyTorch and some specific to TRTorch) run over the graph IR to map the large PyTorch opset to a reduced opset that is easier to convert to TensorRT. Here is the LeNet graph after lowering:

graph(%input.2 : Tensor):
    %2 : Float(84, 10) = prim::Constant[value=<Tensor>]()
    %3 : Float(120, 84) = prim::Constant[value=<Tensor>]()
    %4 : Float(576, 120) = prim::Constant[value=<Tensor>]()
    %5 : int = prim::Constant[value=-1]() #
    %6 : int[] = prim::Constant[value=annotate(List[int], [])]()
    %7 : int[] = prim::Constant[value=[2, 2]]()
    %8 : int[] = prim::Constant[value=[0, 0]]()
    %9 : int[] = prim::Constant[value=[1, 1]]()
    %10 : bool = prim::Constant[value=1]() # ~/.local/lib/python3.6/site-packages/torch/nn/modules/
    %11 : int = prim::Constant[value=1]() # ~/.local/lib/python3.6/site-packages/torch/nn/
    %12 : bool = prim::Constant[value=0]() # ~/.local/lib/python3.6/site-packages/torch/nn/
    %self.classifer.fc3.bias : Float(10) = prim::Constant[value= 0.0464  0.0383  0.0678  0.0932  0.1045 -0.0805 -0.0435 -0.0818  0.0208 -0.0358 [ CUDAFloatType{10} ]]()
    %self.classifer.fc2.bias : Float(84) = prim::Constant[value=<Tensor>]()
    %self.classifer.fc1.bias : Float(120) = prim::Constant[value=<Tensor>]()
    %self.feat.conv2.weight : Float(16, 6, 3, 3) = prim::Constant[value=<Tensor>]()
    %self.feat.conv2.bias : Float(16) = prim::Constant[value=<Tensor>]()
    %self.feat.conv1.weight : Float(6, 1, 3, 3) = prim::Constant[value=<Tensor>]()
    %self.feat.conv1.bias : Float(6) = prim::Constant[value= 0.0530 -0.1691  0.2802  0.1502  0.1056 -0.1549 [ CUDAFloatType{6} ]]()
    %input0.4 : Tensor = aten::_convolution(%input.2, %self.feat.conv1.weight, %self.feat.conv1.bias, %9, %8, %9, %12, %8, %11, %12, %12, %10) # ~/.local/lib/python3.6/site-packages/torch/nn/modules/
    %input0.5 : Tensor = aten::relu(%input0.4) # ~/.local/lib/python3.6/site-packages/torch/nn/
    %input1.2 : Tensor = aten::max_pool2d(%input0.5, %7, %6, %8, %9, %12) # ~/.local/lib/python3.6/site-packages/torch/nn/
    %input0.6 : Tensor = aten::_convolution(%input1.2, %self.feat.conv2.weight, %self.feat.conv2.bias, %9, %8, %9, %12, %8, %11, %12, %12, %10) # ~/.local/lib/python3.6/site-packages/torch/nn/modules/
    %input2.1 : Tensor = aten::relu(%input0.6) # ~/.local/lib/python3.6/site-packages/torch/nn/
    %x.1 : Tensor = aten::max_pool2d(%input2.1, %7, %6, %8, %9, %12) # ~/.local/lib/python3.6/site-packages/torch/nn/
    %input.1 : Tensor = aten::flatten(%x.1, %11, %5) #
    %27 : Tensor = aten::matmul(%input.1, %4)
    %28 : Tensor = trt::const(%self.classifer.fc1.bias)
    %29 : Tensor = aten::add_(%28, %27, %11)
    %input0.2 : Tensor = aten::relu(%29) # ~/.local/lib/python3.6/site-packages/torch/nn/
    %31 : Tensor = aten::matmul(%input0.2, %3)
    %32 : Tensor = trt::const(%self.classifer.fc2.bias)
    %33 : Tensor = aten::add_(%32, %31, %11)
    %input1.1 : Tensor = aten::relu(%33) # ~/.local/lib/python3.6/site-packages/torch/nn/
    %35 : Tensor = aten::matmul(%input1.1, %2)
    %36 : Tensor = trt::const(%self.classifer.fc3.bias)
    %37 : Tensor = aten::add_(%36, %35, %11)
    return (%37)

The graph has now been transformed from a collection of modules, each managing its own parameters, into a single graph with the parameters inlined and all of the operations laid out. TRTorch has also executed a number of optimizations and mappings to make the graph easier to translate to TensorRT and to produce more efficient TensorRT engines. From here the compiler can assemble the TensorRT engine by following the dataflow through the graph.
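As a rough illustration of what one lowering pass does, the sketch below rewrites a fused "linear" operation into the matmul and add pair that shows up for the fully connected layers in the lowered graph. It is plain Python over a hypothetical list-of-tuples graph, assumes the matmul/add pairs come from decomposing the linear layers, and is far simpler than the real passes, which operate on the TorchScript IR.

```python
# Toy lowering pass over a list of (output, op, inputs) tuples. The graph
# format and the function name are hypothetical; this only sketches the idea
# of mapping a larger opset onto a smaller one.

def lower_linear(graph):
    lowered = []
    for out, op, inputs in graph:
        if op == "linear":
            x, weight, bias = inputs
            tmp = out + ".mm"  # fresh name for the intermediate matmul result
            lowered.append((tmp, "matmul", [x, weight]))
            lowered.append((out, "add", [bias, tmp]))
        else:
            # Ops that are already in the reduced opset pass through unchanged.
            lowered.append((out, op, inputs))
    return lowered


graph = [
    ("%a", "flatten", ["%x"]),
    ("%b", "linear", ["%a", "%fc1.weight", "%fc1.bias"]),
    ("%c", "relu", ["%b"]),
]
lowered = lower_linear(graph)
for node in lowered:
    print(node)
```

The rewritten node keeps its original output name, so downstream consumers such as the relu do not need to change.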

Conversion Phase

In the conversion phase we traverse the lowered graph and construct an equivalent TensorRT graph. The conversion phase is made up of three main components: a context to manage compile-time data, an evaluator library which executes operations that can be resolved at compile time, and a converter library which maps an op from JIT to TensorRT. Operations are mapped to TensorRT through the use of modular converters, functions that take a node from the JIT graph and produce an equivalent layer or subgraph in TensorRT. TRTorch ships with a library of these converters stored in a registry, which are executed depending on the node being analyzed. For instance, an aten::relu(%input0.4) instruction will trigger the ReLU converter to be run on it, producing an activation layer in the TensorRT graph.

For example, if you try to compile a graph with a build of TRTorch that doesn’t support the flatten operation (aten::flatten), you may see this error:

terminate called after throwing an instance of 'trtorch::Error'
what():  [enforce fail at core/conversion/conversion.cpp:109] Expected converter to be true but got false
Unable to convert node: %input.1 : Tensor = aten::flatten(%x.1, %11, %5) # (conversion.AddLayer)
Schema: aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)
Converter for aten::flatten requested, but no such converter was found.
If you need a converter for this operator, you can try implementing one yourself
or request a converter:

We register converters in the global converter registry, associating a function schema like aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor) with a lambda that will take the state of the conversion, the node/operation in question to convert and all of the inputs to the node and produces as a side effect a new layer in the TensorRT network. Arguments are passed as a vector of inspectable unions of TensorRT ITensors and Torch IValues (PyTorch data such as Tensors, ints, lists, etc.) in the order arguments are listed in the schema.
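The registry idea can be sketched in a few lines of plain Python. Everything here (CONVERTERS, register, convert_node, the dictionary-based context and node) is hypothetical and only mirrors the description above: a schema string maps to a function that receives the conversion state, the node and its arguments in schema order, and adds a layer as a side effect.

```python
# Toy converter registry (hypothetical names, plain Python), not TRTorch code.

CONVERTERS = {}


def register(schema):
    """Decorator that stores a converter under its function schema."""
    def wrap(fn):
        CONVERTERS[schema] = fn
        return fn
    return wrap


@register("aten::relu(Tensor self) -> (Tensor)")
def relu_converter(ctx, node, args):
    # Side effect: append an activation layer to the toy network.
    ctx["layers"].append(("activation_relu", args[0]))
    return True


def convert_node(ctx, node):
    """Look up the converter for a node's schema and run it."""
    converter = CONVERTERS.get(node["schema"])
    if converter is None:
        raise RuntimeError(f"Unable to convert node: no converter for {node['schema']}")
    return converter(ctx, node, node["args"])


ctx = {"layers": []}
relu_node = {"schema": "aten::relu(Tensor self) -> (Tensor)", "args": ["%input0.4"]}
convert_node(ctx, relu_node)
print(ctx["layers"])
```

A node whose schema has no registered converter fails at lookup time, which is the situation the error message above reports for aten::flatten.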

Below is an implementation of an aten::flatten converter.

#include "torch/script.h"
#include "trtorch/trtorch.h"
#include "trtorch/core/conversion/converters/converters.h"

static auto flatten_converter = trtorch::core::conversion::converters::RegisterNodeConversionPatterns()
    .pattern({
        "aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)",
        [](trtorch::core::conversion::ConversionCtx* ctx,
            const torch::jit::Node* n,
            trtorch::core::conversion::converters::args& args) -> bool {
            // Arguments arrive in schema order: self, start_dim, end_dim
            auto in = args[0].ITensor();
            auto start_dim = args[1].unwrapToInt();
            auto end_dim = args[2].unwrapToInt();
            // Compute the flattened shape by running torch::flatten on a dummy tensor
            auto in_shape = trtorch::core::util::toVec(in->getDimensions());
            auto out_shape = torch::flatten(torch::rand(in_shape), start_dim, end_dim).sizes();

            // A shuffle layer performs the reshape in TensorRT
            auto shuffle = ctx->net->addShuffle(*in);
            shuffle->setReshapeDimensions(trtorch::core::util::toDims(out_shape));

            // Record which TensorRT tensor corresponds to the node's output value
            auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
            return true;
        }
    });

Writing Converters

Converter Contract

What is guaranteed to converters

  • In the args there will be an entry for each node input value, either an ITensor or an IValue
  • Inputs will be provided in order according to the function schema

Responsibilities of a converter

  • A converter must be sure of an argument’s type before unwrapping the Arg union without checking; typically, input tensor arguments can be expected to be ITensors
  • Any weights or static values must be guaranteed to be valid until the end of conversion time
  • Converters are expected to produce an IValue or ITensor for each output of a node. The compiler will check this and produce warnings if there are Values that don’t have associated ITensors or IValues.
  • Outputs must be annotated. There must be an association between a JIT node’s output values and the new TRT layer’s output tensors in the value_tensor_map in the conversion context

Conversion Context

The conversion context maintains the state of conversion. It manages the network definition; two maps, one that stores associations between Values and IValues (the evaluated_value_map) and one that stores associations between Values and ITensors (the value_tensor_map); and any memory that needs to live until the end of conversion. The main APIs that you will interface with in converters are direct access to the network definition through ctx->net, which you use to add layers to the TensorRT network, and the data association functions ctx->AssociateValueAndTensor() and ctx->AssociateValueAndIValue(), which you use to log pairs of node outputs and TensorRT layer outputs or static values.
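A minimal plain-Python sketch of this bookkeeping follows. ToyConversionCtx and its method names are hypothetical stand-ins for the C++ context; only the two map names are taken from the text above, and strings stand in for JIT values and TensorRT tensors.

```python
# Toy stand-in for the conversion context (hypothetical, not TRTorch code).

class ToyConversionCtx:
    def __init__(self):
        self.net = []                  # stands in for the network definition
        self.value_tensor_map = {}     # JIT value -> TensorRT tensor
        self.evaluated_value_map = {}  # JIT value -> static value

    def associate_value_and_tensor(self, value, tensor):
        """Record that a JIT value is produced by a TensorRT tensor."""
        self.value_tensor_map[value] = tensor
        return tensor

    def associate_value_and_ivalue(self, value, ivalue):
        """Record that a JIT value was resolved to a static value."""
        self.evaluated_value_map[value] = ivalue
        return ivalue


ctx = ToyConversionCtx()
# A constant is resolved at compile time and stored as a static value.
ctx.associate_value_and_ivalue("%11", 1)
# A converter adds a layer and associates the node output with its tensor.
ctx.net.append("relu_layer")
ctx.associate_value_and_tensor("%input0.5", "relu_layer:output0")

# Values with no association are what the compiler would warn about.
unmapped = [
    v for v in ["%11", "%input0.5", "%x.1"]
    if v not in ctx.value_tensor_map and v not in ctx.evaluated_value_map
]
print(unmapped)
```

Keeping the two maps separate lets later converters tell at lookup time whether an input is symbolic dataflow or a value already known at compile time.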


Arguments provided to the converter are inspectable unions of nvinfer1::ITensors and torch::jit::IValues (i.e. abstract dataflow in the TensorRT graph and static values). You are guaranteed to have an argument for each input value of the node, and they are provided in the order of the function schema. It can be expected that inputs (meaning the parameters that would be passed into the forward function of a module in PyTorch) will be ITensors, but the Arg class also has mechanisms to inspect arguments safely before unwrapping if you are unsure. Args also have deep unwrap methods that let you get straight to the underlying data in an IValue if you know it’s safe. You can also pass in a fallback value if there is a chance the IValue is None. IValues have been extended to be able to hold a wrapper around ITensors, but only in the case of TensorLists. You can get an ITensor from an IValue with a pattern similar to this: ivalue.toCustomClass()->tensor(). You can tell whether an IValue contains a Tensor or an ITensor by using ivalue.isTensor() or ivalue.isCustomClass().
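The inspect-then-unwrap pattern, including the fallback for None, might look like this in plain Python. ToyArg and ToyITensor are hypothetical classes written for this sketch; they imitate the behavior described above rather than the actual C++ Arg interface.

```python
# Toy version of an inspectable union of "tensor" and "static value"
# (hypothetical names, not the TRTorch Arg class).

class ToyITensor:
    """Stands in for a symbolic tensor flowing through the TensorRT graph."""

    def __init__(self, name):
        self.name = name


class ToyArg:
    def __init__(self, payload):
        self._payload = payload

    def is_itensor(self):
        return isinstance(self._payload, ToyITensor)

    def is_ivalue(self):
        return not self.is_itensor()

    def itensor(self):
        """Unwrap as a tensor, failing loudly if the arg is a static value."""
        if not self.is_itensor():
            raise TypeError("argument holds a static value, not a tensor")
        return self._payload

    def unwrap_to_int(self, fallback=None):
        """Deep unwrap to an int, using the fallback when the value is None."""
        if self.is_itensor():
            raise TypeError("argument holds a tensor, not a static value")
        if self._payload is None:
            if fallback is None:
                raise ValueError("value is None and no fallback was given")
            return fallback
        return int(self._payload)


# Arguments in schema order for flatten: self, start_dim, end_dim.
args = [ToyArg(ToyITensor("%x.1")), ToyArg(1), ToyArg(None)]
tensor_name = args[0].itensor().name
start_dim = args[1].unwrap_to_int()
end_dim = args[2].unwrap_to_int(fallback=-1)
print(tensor_name, start_dim, end_dim)
```

Checking is_itensor() or is_ivalue() first is the safe path; calling an unwrap method directly is only appropriate when the schema makes the type certain.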


Weights are used during build time, so any weights need to be guaranteed to live until the end of the conversion phase. TensorRT also uses its own weights structure to hold the weights. There is a wrapper around this class available to converters which abstracts a lot of this.

The weights wrapper class can accept either at::Tensors or singular values (for now). You also need to pass the conversion context when constructing these weights because internally the weights class will allocate memory managed by the conversion context to store a copy of the tensor data. This data gets freed when the conversion context is destroyed, so converters don’t really need to think about it.

There is metadata generated from the shape of the input data which becomes useful in interfacing with TensorRT, such as number of input maps, number of output maps and kernel shape.

The compiler will topologically traverse the graph from the lowering phase and iteratively construct the TensorRT graph. Once it is complete, the graph is optimized for the target hardware using TensorRT’s library of tuned kernel implementations. TRTorch then creates a JIT Module to execute the TensorRT engine, which will be instantiated and managed by the TRTorch runtime.
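The traversal order matters because a layer can only be added once the tensors it consumes exist. The plain-Python sketch below uses the standard library's graphlib (Python 3.9+) on a hypothetical four-value graph; the value names and the ops table are made up for illustration.

```python
from graphlib import TopologicalSorter

# Hypothetical mini-graph: each value maps to the values it depends on.
dependencies = {
    "%pool": ["%relu"],
    "%relu": ["%conv"],
    "%conv": ["%input"],
    "%input": [],
}
# How each value is produced; "%input" is the graph input, not a layer.
ops = {"%conv": "convolution", "%relu": "activation", "%pool": "pooling"}

network = []
# static_order() yields every value after all of its dependencies.
for value in TopologicalSorter(dependencies).static_order():
    if value in ops:
        network.append((ops[value], value))
print(network)
```

Even though the dictionary lists the pooling node first, the layers are added in dataflow order.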

Here is the graph that you get back after compilation is complete:

graph(%self.1 : __torch__.___torch_mangle_10.LeNet_trt,
    %2 : Tensor):
    %1 : int = prim::Constant[value=94106001690080]()
    %3 : Tensor = trt::execute_engine(%1, %2)
    return (%3)

You can see the call where the engine is executed. It takes a constant, which is the ID of the engine and tells JIT how to find it, and the input tensor which will be fed to TensorRT.
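One way to picture the ID-based lookup is a runtime-side table of engines keyed by ID, as in the plain-Python sketch below. All names are hypothetical, and a simple counter stands in for however the real runtime assigns engine IDs.

```python
import itertools

# Toy runtime registry (hypothetical): engines are stored under an ID and the
# compiled "module" only carries that ID.
_ENGINES = {}
_next_id = itertools.count(1)


def register_engine(engine_fn):
    """Store an engine and return the ID a compiled graph would embed."""
    engine_id = next(_next_id)
    _ENGINES[engine_id] = engine_fn
    return engine_id


def execute_engine(engine_id, inputs):
    """Find the engine by ID and run it on the inputs."""
    return _ENGINES[engine_id](inputs)


def make_module(engine_id):
    """Build a callable whose forward pass is a single execute_engine call."""
    def forward(x):
        return execute_engine(engine_id, x)
    return forward


# A stand-in "engine" that doubles every element of its input.
engine_id = register_engine(lambda xs: [2 * x for x in xs])
module = make_module(engine_id)
result = module([1, 2, 3])
print(result)
```

This also shows why the runtime must be loaded for the module to work: without the table, the ID embedded in the graph has nothing to resolve against.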

Runtime Phase

After the graph is provided to the user, it is a fully compliant TorchScript module. You can serialize and deserialize it the same way as any other TorchScript module; all that is required is to have the TRTorch runtime loaded in memory.

Getting TRTorch

TRTorch 0.1.0 precompiled binaries are available on GitHub. You can get a tarball distribution of the C++ API (either pre-CXX11 ABI or CXX11 ABI builds) and wheel files for the Python package, which is installable via pip. To get the latest and greatest, you can compile the library and/or the Python package from source. Compilation on embedded platforms like Jetson is possible via minor modifications to the build system.

More Information

I gave two talks on TRTorch at the GPU Technology Conference.
