
Customize Process Group Backends Using Cpp Extensions

Author: Feng Tian, Shen Li, Min Si


This tutorial demonstrates how to implement a custom ProcessGroup backend and plug it into the PyTorch distributed package using cpp extensions. This is helpful when you need a specialized software stack for your hardware, or when you would like to experiment with new collective communication algorithms.


PyTorch collective communications power several widely adopted distributed training features, including DistributedDataParallel, ZeroRedundancyOptimizer, and FullyShardedDataParallel. To make the same collective communication API work with different communication backends, the distributed package abstracts collective communication operations into a ProcessGroup class. Different backends can then be implemented as subclasses of ProcessGroup using preferred third-party libraries. PyTorch distributed comes with three default backends: ProcessGroupNCCL, ProcessGroupGloo, and ProcessGroupMPI. Beyond these three, however, there are other communication libraries (e.g., UCC, OneCCL), different types of hardware (e.g., TPU, Trainium), and emerging communication algorithms (e.g., Herring, Reduction Server). Therefore, the distributed package exposes extension APIs that allow customizing collective communication backends.
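The abstraction described above can be pictured with a small pure-Python analogy. This is only a sketch of the design idea, not PyTorch's implementation: the class names `ToyProcessGroup`, `ToySumBackend`, and `ToyZeroBackend` are hypothetical, and plain Python lists stand in for tensors.

```python
from abc import ABC, abstractmethod


class ToyProcessGroup(ABC):
    """Hypothetical stand-in for the ProcessGroup abstraction."""

    def __init__(self, rank, size):
        self.rank = rank
        self.size = size

    @abstractmethod
    def allreduce(self, values):
        """Combine `values` across ranks, in place."""


class ToySumBackend(ToyProcessGroup):
    """Pretends every rank holds the same data, so the sum is size * value."""

    def allreduce(self, values):
        for i, v in enumerate(values):
            values[i] = v * self.size


class ToyZeroBackend(ToyProcessGroup):
    """Mirrors the tutorial's dummy backend: results are simply set to 0."""

    def allreduce(self, values):
        for i in range(len(values)):
            values[i] = 0


def train_step(group, grads):
    # Application code only sees the shared interface, so the backend
    # can be swapped without touching this function.
    group.allreduce(grads)
    return grads


summed = train_step(ToySumBackend(rank=0, size=4), [1.0, 2.0])
zeroed = train_step(ToyZeroBackend(rank=0, size=4), [1.0, 2.0])
```

The point of the sketch is that `train_step` never names a concrete backend, which is the property that lets features such as DistributedDataParallel run unchanged on different communication libraries.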

The 4 steps below show how to implement a dummy ProcessGroup backend and use that in Python application code. Please note that this tutorial focuses on demonstrating the extension APIs, instead of developing a functioning communication backend. Hence, the dummy backend just covers a subset of the APIs (all_reduce and all_gather), and simply sets the values of tensors to 0.

Step 1: Implement a Subclass of ProcessGroup

The first step is to implement a ProcessGroup subclass that overrides the target collective communication APIs and runs the custom communication algorithm. The extension also needs to implement a Work subclass, which serves as a future of communication results and allows asynchronous execution in application code. If the extension uses third-party libraries, it can include their headers and call into the library APIs from the ProcessGroupDummy subclass. The two code snippets below present the implementation of dummy.hpp and dummy.cpp. See the dummy collectives repository for the full implementation.

// file name: dummy.hpp
#include <torch/python.h>

#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>

#include <pybind11/chrono.h>

namespace c10d {

class ProcessGroupDummy : public ProcessGroup {
  public:
    ProcessGroupDummy(int rank, int size);

    c10::intrusive_ptr<Work> allgather(
        std::vector<std::vector<at::Tensor>>& outputTensors,
        std::vector<at::Tensor>& inputTensors,
        const AllgatherOptions& opts = AllgatherOptions()) override;

    c10::intrusive_ptr<Work> allreduce(
        std::vector<at::Tensor>& tensors,
        const AllreduceOptions& opts = AllreduceOptions()) override;

    // The collective communication APIs without a custom implementation
    // will error out if invoked by application code.
};

class WorkDummy : public Work {
  public:
    WorkDummy(
      OpType opType,
      c10::intrusive_ptr<c10::ivalue::Future> future) // future of the output
      : Work(
          -1, // rank, only used by recvAnySource, irrelevant in this demo
          opType),
      future_(std::move(future)) {}
    // There are several additional helper functions that need to be
    // implemented. Please refer to https://github.com/mrshenli/dummy_collectives
    // for the full implementation.

  private:
    c10::intrusive_ptr<c10::ivalue::Future> future_;
};
} // namespace c10d

// file name: dummy.cpp
#include "dummy.hpp"

namespace c10d {

// This is a dummy allgather that sets all output tensors to zero
// Modify the implementation to conduct real communication asynchronously
c10::intrusive_ptr<Work> ProcessGroupDummy::allgather(
        std::vector<std::vector<at::Tensor>>& outputTensors,
        std::vector<at::Tensor>& inputTensors,
        const AllgatherOptions& /* unused */) {
    for (auto& outputTensorVec : outputTensors) {
        for (auto& outputTensor : outputTensorVec) {
            outputTensor.zero_();
        }
    }

    // The future holds a list of lists of tensors and is completed immediately.
    auto future = c10::make_intrusive<c10::ivalue::Future>(
        c10::ListType::create(c10::ListType::create(c10::TensorType::get())));
    future->markCompleted(c10::IValue(outputTensors));
    return c10::make_intrusive<WorkDummy>(OpType::ALLGATHER, std::move(future));
}

// This is a dummy allreduce that sets all output tensors to zero
// Modify the implementation to conduct real communication asynchronously
c10::intrusive_ptr<Work> ProcessGroupDummy::allreduce(
        std::vector<at::Tensor>& tensors,
        const AllreduceOptions& opts) {
    for (auto& tensor : tensors) {
        tensor.zero_();
    }

    // The future holds a list of tensors and is completed immediately.
    auto future = c10::make_intrusive<c10::ivalue::Future>(
        c10::ListType::create(c10::TensorType::get()));
    future->markCompleted(c10::IValue(tensors));
    return c10::make_intrusive<WorkDummy>(OpType::ALLREDUCE, std::move(future));
}
} // namespace c10d
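To make the "Work as a future" idea from this step more concrete, here is a rough Python analogy built on the standard library's `concurrent.futures`. It is a sketch under stated assumptions rather than the c10d implementation: `ToyWork`, `zero_fill`, `dummy_allreduce_sync`, and `dummy_allreduce_async` are hypothetical names, and lists stand in for tensors.

```python
from concurrent.futures import Future, ThreadPoolExecutor


class ToyWork:
    """Hypothetical analogue of a Work object: a handle on a pending result."""

    def __init__(self, op_name, future):
        self.op_name = op_name
        self._future = future

    def is_completed(self):
        return self._future.done()

    def wait(self, timeout=None):
        # Block until the communication finishes; re-raises any error from it.
        self._future.result(timeout=timeout)
        return True

    def result(self):
        return self._future.result()


def zero_fill(values):
    # Stand-in for real communication: overwrite every element with 0.
    for i in range(len(values)):
        values[i] = 0
    return values


def dummy_allreduce_sync(values):
    # Do the (trivial) work eagerly and hand back an already completed future.
    future = Future()
    future.set_result(zero_fill(values))
    return ToyWork("allreduce", future)


def dummy_allreduce_async(executor, values):
    # What a real backend would aim for: start the work in the background
    # and return a handle immediately.
    return ToyWork("allreduce", executor.submit(zero_fill, values))


data = [1.0, 1.0, 1.0]
work = dummy_allreduce_sync(data)
sync_done = work.is_completed()

with ThreadPoolExecutor(max_workers=1) as pool:
    async_work = dummy_allreduce_async(pool, [2.0, 2.0])
    async_work.wait()
    async_result = async_work.result()
```

In both variants the caller interacts only with the handle, which is why application code can choose between blocking on `wait()` and overlapping communication with other work.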

Step 2: Expose The Extension Python APIs

The backend constructors are called from the Python side, so the extension also needs to expose the constructor APIs to Python. This can be done by adding the following methods. In this example, store and timeout are ignored by the ProcessGroupDummy instantiation method, as they are not used in this dummy implementation. However, real-world extensions should consider using the store to perform rendezvous and supporting the timeout argument.

// file name: dummy.hpp
class ProcessGroupDummy : public ProcessGroup {
  public:
    // ... constructor and collectives from Step 1 ...

    static c10::intrusive_ptr<ProcessGroup> createProcessGroupDummy(
        const c10::intrusive_ptr<::c10d::Store>& store,
        int rank,
        int size,
        const std::chrono::duration<float>& timeout);

    // Runs automatically when the shared library is loaded.
    static void ProcessGroupDummyConstructor() __attribute__((constructor)) {
        py::object module = py::module::import("torch.distributed");
        py::object register_backend =
            module.attr("Backend").attr("register_backend");
        // torch.distributed.Backend.register_backend will add `dummy` as a
        // new valid backend.
        register_backend("dummy", py::cpp_function(createProcessGroupDummy));
    }
};

// file name: dummy.cpp
c10::intrusive_ptr<ProcessGroup> ProcessGroupDummy::createProcessGroupDummy(
        const c10::intrusive_ptr<::c10d::Store>& /* unused */,
        int rank,
        int size,
        const std::chrono::duration<float>& /* unused */) {
    return c10::make_intrusive<ProcessGroupDummy>(rank, size);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("createProcessGroupDummy", &ProcessGroupDummy::createProcessGroupDummy);
}
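The registration flow in this step can be sketched in plain Python as a name-to-constructor registry. This is an illustrative assumption about the general pattern, not the actual torch.distributed internals: `toy_register_backend`, `toy_init_process_group`, `ToyDummyGroup`, and `create_dummy_group` are hypothetical names, and the 30-minute timeout is an arbitrary placeholder.

```python
from datetime import timedelta

# Hypothetical registry; the real one lives inside torch.distributed.
_BACKENDS = {}


def toy_register_backend(name, creator):
    """Record `creator` under `name`, refusing duplicate registrations."""
    key = name.lower()
    if key in _BACKENDS:
        raise ValueError(f"backend {key!r} is already registered")
    _BACKENDS[key] = creator


def toy_init_process_group(backend, store, rank, world_size,
                           timeout=timedelta(minutes=30)):
    """Look up the creator by name and call it with the standard arguments."""
    try:
        creator = _BACKENDS[backend.lower()]
    except KeyError:
        raise ValueError(f"unknown backend {backend!r}") from None
    return creator(store, rank, world_size, timeout)


class ToyDummyGroup:
    def __init__(self, rank, size):
        self.rank = rank
        self.size = size


def create_dummy_group(store, rank, size, timeout):
    # Like createProcessGroupDummy, this ignores store and timeout.
    return ToyDummyGroup(rank, size)


# In the C++ extension, the equivalent call runs when the library is loaded.
toy_register_backend("dummy", create_dummy_group)

group = toy_init_process_group("dummy", store=None, rank=0, world_size=1)
```

A real backend would typically use the `store` argument to exchange connection information between ranks and honor `timeout` while doing so, instead of discarding both.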

Step 3: Build The Custom Extension

Now that the extension source files are ready, we can use cpp extensions to build them. To do that, create a setup.py file that prepares the paths and commands. Then call python setup.py install to install the extension.

If the extension depends on third-party libraries, you can also pass library_dirs and libraries to the cpp extension APIs. See the torch ucc project as a real-world example.

# file name: setup.py
import os
import sys
import torch
from setuptools import setup
from torch.utils import cpp_extension

sources = ["src/dummy.cpp"]
include_dirs = [f"{os.path.dirname(os.path.abspath(__file__))}/include/"]

# Build against CUDA when it is available, otherwise build a CPU-only extension.
if torch.cuda.is_available():
    module = cpp_extension.CUDAExtension(
        name = "dummy_collectives",
        sources = sources,
        include_dirs = include_dirs,
    )
else:
    module = cpp_extension.CppExtension(
        name = "dummy_collectives",
        sources = sources,
        include_dirs = include_dirs,
    )

setup(
    name = "Dummy-Collectives",
    version = "0.0.1",
    ext_modules = [module],
    cmdclass={'build_ext': cpp_extension.BuildExtension}
)
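For the third-party library case mentioned earlier in this step, the keyword arguments might be assembled roughly as follows. `cpp_extension.CppExtension` forwards extra keyword arguments to setuptools, so `library_dirs` and `libraries` can be supplied alongside `include_dirs`. The helper `extension_kwargs`, the library name `mycomm`, and the `/opt/mycomm` prefix are hypothetical placeholders chosen for illustration.

```python
import os


def extension_kwargs(name, sources, include_dirs, third_party_prefix, libraries):
    """Assemble keyword arguments for a C++ extension that links extra libraries."""
    return {
        "name": name,
        "sources": list(sources),
        # Add the third-party headers to the include search path.
        "include_dirs": list(include_dirs)
        + [os.path.join(third_party_prefix, "include")],
        # Directories searched for libraries at link time.
        "library_dirs": [os.path.join(third_party_prefix, "lib")],
        # Library names without the "lib" prefix or the file extension.
        "libraries": list(libraries),
    }


kwargs = extension_kwargs(
    name="dummy_collectives",
    sources=["src/dummy.cpp"],
    include_dirs=["include"],
    third_party_prefix="/opt/mycomm",  # hypothetical install location
    libraries=["mycomm"],              # hypothetical library name
)
# In setup.py this dictionary would then be used as:
#   module = cpp_extension.CppExtension(**kwargs)
```

Keeping the arguments in a dictionary makes it easy to pass the same settings to either `CppExtension` or `CUDAExtension`.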

Step 4: Use The Extension in Application

After installation, you can use the dummy backend when calling init_process_group as if it were a built-in backend.

import os

import torch
# importing dummy_collectives makes torch.distributed recognize `dummy`
# as a valid backend.
import dummy_collectives

import torch.distributed as dist

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'

dist.init_process_group("dummy", rank=0, world_size=1)

x = torch.ones(6)
dist.all_reduce(x)
print(f"cpu allreduce: {x}")
if torch.cuda.is_available():
    y = x.cuda()
    dist.all_reduce(y)
    print(f"cuda allreduce: {y}")

# broadcast is not overridden by the dummy backend, so it raises.
try:
    dist.broadcast(x, 0)
except RuntimeError:
    print("got RuntimeError as broadcast is not implemented in Dummy ProcessGroup")
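The behavior of this script can be mimicked without PyTorch to show why the broadcast call fails while all_reduce succeeds. The sketch below is an analogy only: `ToyBaseGroup` and `ToyDummyBackend` are hypothetical classes, lists stand in for tensors, and the error messages are invented for the example.

```python
class ToyBaseGroup:
    """Hypothetical base class: every collective errors unless overridden."""

    def allreduce(self, values):
        raise RuntimeError(f"{type(self).__name__} does not support allreduce")

    def broadcast(self, values, src):
        raise RuntimeError(f"{type(self).__name__} does not support broadcast")


class ToyDummyBackend(ToyBaseGroup):
    # Only allreduce is overridden, mirroring the subset of APIs covered
    # by the dummy backend.
    def allreduce(self, values):
        for i in range(len(values)):
            values[i] = 0.0


group = ToyDummyBackend()

x = [1.0] * 6
group.allreduce(x)  # supported: x now holds zeros

try:
    group.broadcast(x, 0)  # not overridden: falls through to the base class
    broadcast_error = None
except RuntimeError as err:
    broadcast_error = str(err)
```

This is the same pattern the application code relies on: collectives that the backend implements run normally, and the rest surface as a RuntimeError that the caller can catch.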
