Introduction to TorchRec


To get the most out of this tutorial, we suggest using this Colab version, which lets you experiment with the code presented below.

Follow along with the video below or on YouTube.

Frequently, when building recommendation systems, we want to represent entities like products or pages with embeddings. For an example, see Meta AI’s Deep Learning Recommendation Model, or DLRM. As the number of entities grows, the embedding tables can exceed a single GPU’s memory. A common practice is to shard the embedding tables across devices, a type of model parallelism. To that end, TorchRec introduces its primary API, DistributedModelParallel (DMP). Like PyTorch’s DistributedDataParallel, DMP wraps a model to enable distributed training.
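To make the memory argument concrete, here is a back-of-the-envelope estimate. The table size and the 16 GiB device capacity are hypothetical numbers chosen for illustration, and the per-device figure assumes an even row-wise split of the weights only (no optimizer state or other overheads).

```python
# Back-of-the-envelope memory estimate for one embedding table.
# All sizes below are hypothetical, chosen only for illustration.

BYTES_PER_FLOAT32 = 4
GIB = 1024 ** 3

def table_bytes(num_embeddings: int, embedding_dim: int,
                bytes_per_elem: int = BYTES_PER_FLOAT32) -> int:
    """Bytes needed for the weights of a num_embeddings x embedding_dim table."""
    return num_embeddings * embedding_dim * bytes_per_elem

# Hypothetical: 100 million product IDs with 128-dimensional fp32 embeddings.
full = table_bytes(100_000_000, 128)
print(f"full table: {full / GIB:.1f} GiB")

# Hypothetical 16 GiB device: the whole table does not fit on one GPU...
device_capacity = 16 * GIB
print("fits on one device:", full <= device_capacity)

# ...but an even row-wise split over 4 devices does (weights only).
num_devices = 4
per_device = full // num_devices
print(f"per device with {num_devices} shards: {per_device / GIB:.1f} GiB")
print("each shard fits:", per_device <= device_capacity)
```

Real sharding also has to account for optimizer state and activations, so actual per-device needs are higher than this weights-only figure.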


Requirements: python >= 3.7

We highly recommend CUDA when using TorchRec. If using CUDA: cuda >= 11.0

# install pytorch with cudatoolkit 11.3
conda install pytorch cudatoolkit=11.3 -c pytorch-nightly -y
# install TorchRec
pip3 install torchrec-nightly


This tutorial covers three pieces of TorchRec: the nn.Module EmbeddingBagCollection, the DistributedModelParallel API, and the KeyedJaggedTensor data structure.

Distributed Setup

We set up our environment with torch.distributed. For more information on distributed training, see this tutorial.

Here, we use one rank (the Colab process), corresponding to our single Colab GPU.

import os
import torch
import torchrec
import torch.distributed as dist

os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"

# Note: you will need a V100 or A100 to run this tutorial as is!
# If using an older GPU (such as the free Colab K80),
# you will need to compile fbgemm with the appropriate CUDA architecture
# or run with "gloo" on CPUs.
dist.init_process_group(backend="nccl")
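If you do not have a V100 or A100, the setup comment suggests running with "gloo" on CPUs. A minimal sketch of choosing the backend at runtime, assuming the same single-process environment variables as above, could look like this. The `is_initialized()` guard only avoids a second initialization when a notebook cell is re-run. A CPU-only run would also need the later `"cuda"` device arguments changed.

```python
import os

import torch
import torch.distributed as dist

# Single-process "world": rank 0 out of 1, rendezvous on localhost.
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")

# NCCL requires CUDA GPUs; gloo also runs on CPUs.
backend = "nccl" if torch.cuda.is_available() else "gloo"

# Skip initialization if a process group already exists (e.g. a re-run cell).
if not dist.is_initialized():
    dist.init_process_group(backend=backend)

print(f"backend={dist.get_backend()} rank={dist.get_rank()} "
      f"world_size={dist.get_world_size()}")
```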

From EmbeddingBag to EmbeddingBagCollection

PyTorch represents embeddings through torch.nn.Embedding and torch.nn.EmbeddingBag. EmbeddingBag is a pooled version of Embedding.
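To see what "pooled" means, the sketch below looks up the same IDs with both modules, sharing one weight matrix, and checks that EmbeddingBag's output equals summing the rows returned by Embedding. It sets `mode="sum"` explicitly; EmbeddingBag's default mode is `"mean"`. The sizes are arbitrary.

```python
import torch

torch.manual_seed(0)

num_embeddings, embedding_dim = 10, 4
emb = torch.nn.Embedding(num_embeddings, embedding_dim)
bag = torch.nn.EmbeddingBag(num_embeddings, embedding_dim, mode="sum")

# Give both modules identical weights so their outputs are comparable.
with torch.no_grad():
    bag.weight.copy_(emb.weight)

# A 2-D input to EmbeddingBag means: each row is one fixed-size bag.
ids = torch.tensor([[1, 2, 3], [4, 4, 7]])

per_id = emb(ids)            # shape (2, 3, 4): one vector per ID
manual = per_id.sum(dim=1)   # shape (2, 4): pool each row by summing
pooled = bag(ids)            # shape (2, 4): EmbeddingBag pools internally

print(per_id.shape, pooled.shape)
print(torch.allclose(manual, pooled))
```

EmbeddingBag never materializes the intermediate (2, 3, 4) tensor, which is why it is the preferred building block when only the pooled vector is needed.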

TorchRec extends these modules by creating collections of embeddings. We will use EmbeddingBagCollection to represent a group of EmbeddingBags.

Here, we create an EmbeddingBagCollection (EBC) with two embedding bags. Each table, product_table and user_table, holds 4096 embeddings of dimension 64. Note how we initially allocate the EBC on the “meta” device, which tells the EBC not to allocate memory yet.

ebc = torchrec.EmbeddingBagCollection(


Now, we’re ready to wrap our model with DistributedModelParallel (DMP). Instantiating DMP will:

  1. Decide how to shard the model. DMP will collect the available ‘sharders’ and come up with a ‘plan’ of the optimal way to shard the embedding table(s) (i.e., the EmbeddingBagCollection).

  2. Actually shard the model. This includes allocating memory for each embedding table on the appropriate device(s).

In this toy example, since we have two embedding tables and one GPU, TorchRec will place both on the single GPU.

model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device("cuda"))
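DMP's planner considers several sharding strategies and cost estimates, so the following is only a toy to make the idea of a "plan" concrete. It is a hypothetical greedy, table-wise placement in plain Python, not TorchRec's planner: each whole table goes to the device with the most free memory. With one GPU, both tables land on it, matching the description above; with two, they are split.

```python
from typing import Dict, List, Tuple

def greedy_table_wise_plan(
    tables: List[Tuple[str, int]],  # (table name, size in bytes)
    device_capacity: List[int],     # free bytes per device
) -> Dict[str, int]:
    """Hypothetical toy planner: put each whole table on the least-loaded device."""
    free = list(device_capacity)
    plan: Dict[str, int] = {}
    # Place the largest tables first so they get the most room.
    for name, size in sorted(tables, key=lambda t: t[1], reverse=True):
        device = max(range(len(free)), key=lambda d: free[d])
        if free[device] < size:
            # The emptiest device cannot hold it, so no device can.
            raise RuntimeError(f"table {name!r} ({size} bytes) fits on no device")
        plan[name] = device
        free[device] -= size
    return plan

# The tutorial's tables: 4096 rows x 64 dims x 4 bytes (fp32) each.
size = 4096 * 64 * 4
tables = [("product_table", size), ("user_table", size)]

print(greedy_table_wise_plan(tables, [16 * 1024**3]))      # {'product_table': 0, 'user_table': 0}
print(greedy_table_wise_plan(tables, [16 * 1024**3] * 2))  # {'product_table': 0, 'user_table': 1}
```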

Query vanilla nn.EmbeddingBag with input and offsets

We query nn.EmbeddingBag with input and offsets (nn.Embedding takes only input). Input is a 1-D tensor containing the lookup values. Offsets is a 1-D tensor giving the index in input where each example's bag starts; it is the cumulative sum of the number of values to pool per example, starting from 0.
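If your data arrives as per-example counts (lengths) rather than offsets, an exclusive cumulative sum converts one into the other. The sketch below does this for the product column of the example that follows; the variable names are illustrative.

```python
import torch

# Number of product IDs per example: [101, 202], [], [303]
lengths = torch.tensor([2, 0, 1])

# Offsets are the start index of each example's bag in `input`:
# an exclusive cumulative sum of lengths, so it starts at 0.
offsets = torch.cat([lengths.new_zeros(1), torch.cumsum(lengths, dim=0)[:-1]])
print(offsets)  # tensor([0, 2, 2])
```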

Let’s look at an example, recreating the product EmbeddingBag above:

| product ID |
| [101, 202] |
| []         |
| [303]      |
product_eb = torch.nn.EmbeddingBag(4096, 64)
product_eb(input=torch.tensor([101, 202, 303]), offsets=torch.tensor([0, 2, 2]))
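To check what the call above returns, the sketch below recomputes each bag by hand. It relies on EmbeddingBag's default `mode="mean"` and on its documented behavior that an empty bag yields a vector of zeros; the manual seed only makes the run repeatable.

```python
import torch

torch.manual_seed(0)
product_eb = torch.nn.EmbeddingBag(4096, 64)  # default mode="mean"

ids = torch.tensor([101, 202, 303])
offsets = torch.tensor([0, 2, 2])
out = product_eb(input=ids, offsets=offsets)

print(out.shape)  # torch.Size([3, 64]): one pooled vector per example

W = product_eb.weight
# Example 0 pools IDs 101 and 202 (mean of their rows).
print(torch.allclose(out[0], W[[101, 202]].mean(dim=0), atol=1e-6))
# Example 1 has no IDs, so its bag is empty and the result is all zeros.
print(torch.count_nonzero(out[1]).item() == 0)
# Example 2 pools only ID 303.
print(torch.allclose(out[2], W[303], atol=1e-6))
```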

Representing minibatches with KeyedJaggedTensor

We need an efficient way to represent a minibatch in which each example can have an arbitrary number of entity IDs per feature. To enable this “jagged” representation, we use the TorchRec data structure KeyedJaggedTensor (KJT).

Let’s look at how to look up a collection of two embedding bags, “product” and “user”. Assume the minibatch is made up of three examples for three users: the first has two product IDs, the second has none, and the third has one.

| product ID | user ID    |
| [101, 202] | [404]      |
| []         | [505]      |
| [303]      | [606]      |

The query should be:

mb = torchrec.KeyedJaggedTensor(
    keys = ["product", "user"],
    values = torch.tensor([101, 202, 303, 404, 505, 606]).cuda(),
    lengths = torch.tensor([2, 0, 1, 1, 1, 1], dtype=torch.int64).cuda(),
)

Note that the KJT batch size is batch_size = len(lengths)//len(keys). In the above example, batch_size is 3.
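The layout behind this formula: values and lengths are both stored key-major, so all "product" entries come first, then all "user" entries, and lengths holds one count per (key, example) pair. The plain-Python sketch below uses a hypothetical helper, `unpack_kjt`, rather than any TorchRec API, to walk that layout back into the table above.

```python
from typing import Dict, List

def unpack_kjt(keys: List[str], values: List[int],
               lengths: List[int]) -> Dict[str, List[List[int]]]:
    """Hypothetical helper: rebuild per-key, per-example ID lists from a key-major jagged layout."""
    batch_size = len(lengths) // len(keys)
    out: Dict[str, List[List[int]]] = {}
    pos = 0  # read position in the flat `values` list
    for k, key in enumerate(keys):
        rows = []
        # lengths[k * batch_size : (k + 1) * batch_size] belong to this key
        for n in lengths[k * batch_size:(k + 1) * batch_size]:
            rows.append(values[pos:pos + n])
            pos += n
        out[key] = rows
    return out

batch = unpack_kjt(
    keys=["product", "user"],
    values=[101, 202, 303, 404, 505, 606],
    lengths=[2, 0, 1, 1, 1, 1],
)
print(batch["product"])  # [[101, 202], [], [303]]
print(batch["user"])     # [[404], [505], [606]]
```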

Putting it all together, querying our distributed model with a KJT minibatch

Finally, we can query our model using our minibatch of products and users.

The result is a KeyedTensor, in which each key (or feature) maps to a 2D tensor of size 3x64 (batch_size x embedding_dim).

pooled_embeddings = model(mb)
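For intuition about what `model(mb)` computes, here is an unsharded, single-device reference built from plain nn.EmbeddingBag modules. Everything in it is a hypothetical sketch: the sum pooling mirrors an assumption about the table configuration, and the final concatenation along dimension 1 mimics how a KeyedTensor stores its per-key results side by side (one 3x64 block per key, 3x128 overall). TorchRec's sharded implementation is not reproduced here.

```python
from typing import Dict

import torch

torch.manual_seed(0)

# Hypothetical unsharded stand-in for the EBC: one EmbeddingBag per feature.
tables: Dict[str, torch.nn.EmbeddingBag] = {
    "product": torch.nn.EmbeddingBag(4096, 64, mode="sum"),
    "user": torch.nn.EmbeddingBag(4096, 64, mode="sum"),
}

keys = ["product", "user"]
values = torch.tensor([101, 202, 303, 404, 505, 606])
lengths = torch.tensor([2, 0, 1, 1, 1, 1])
batch_size = len(lengths) // len(keys)

pooled: Dict[str, torch.Tensor] = {}
start = 0  # offset of the current key's IDs within `values`
for k, key in enumerate(keys):
    key_lengths = lengths[k * batch_size:(k + 1) * batch_size]
    n = int(key_lengths.sum())
    key_values = values[start:start + n]
    # Exclusive cumulative sum turns per-example lengths into bag offsets.
    offsets = torch.cat([key_lengths.new_zeros(1), torch.cumsum(key_lengths, 0)[:-1]])
    pooled[key] = tables[key](key_values, offsets)  # (batch_size, 64)
    start += n

# Side-by-side layout, like a KeyedTensor's values: (batch_size, 64 + 64).
concatenated = torch.cat([pooled[k] for k in keys], dim=1)
print({k: tuple(v.shape) for k, v in pooled.items()}, tuple(concatenated.shape))
```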

More resources

For more information, please see our DLRM example, which includes multi-node training on the Criteo Terabyte dataset using Meta’s DLRM.

© Copyright 2022, PyTorch & PyTorch Korea User Group.